golang的反射机制

前端之家收集整理的这篇文章主要介绍了golang的反射机制,前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

go反射

什么是反射?使用反射可以实现什么功能?

反射提供了一种可以操作任意类型数据结构的能力。通过反射你可以实现对任意类型的深度复制,可以对任意类型执行自定义的操作。另外由于golang可以给结构体定义tag属性,结合反射,于是就可以实现结构体与数据库表、结构体与json对象之间的相互转换。

使用反射需要注意什么事情?

使用反射时需要先确定要操作的值是否是期望的类型,是否是可以进行“赋值”操作的,否则reflect包将会毫不留情地产生一个panic。

struct initializer示例

go-struct-initializer 是我在学习golang反射过程中实现的一个可以根据struct tag初始化结构体的go library,这里对其中用到的反射技术做一下说明。

package structinitializer

import (
    "reflect"
    "fmt"
    "strconv"
    "strings"
)

// Reason codes stored in Error.Reason so callers can tell failures apart
// without parsing the human-readable message.
const (
    // ERROR_NOT_A_POINTER: the value handed to Initialise is not a pointer.
    ERROR_NOT_A_POINTER = "NotAPointer"
    // ERROR_NOT_IMPLEMENT: a field has a kind the initialiser does not handle.
    ERROR_NOT_IMPLEMENT = "NotImplement"
    // ERROR_POINTE_TO_NON_STRUCT: the pointer handed to Initialise does not
    // point to a struct.
    ERROR_POINTE_TO_NON_STRUCT = "PointToNonStruct"
    // ERROR_TYPE_CONVERSION_Failed: a default tag could not be converted to
    // the field's type.
    ERROR_TYPE_CONVERSION_Failed = "TypeConversionFailed"
    // ERROR_MULTI_ERRORS_FOUND: several errors were merged into one.
    ERROR_MULTI_ERRORS_FOUND = "MultiErrorsFound"
)

// Error is the error type produced by this package.
type Error struct {
    Reason  string // one of the ERROR_* reason codes
    Message string // human-readable description
    Stack   string // dotted path of the field being processed; may be empty
}

// NewError builds an *Error from its three parts and returns it as an error.
func NewError(reason, message string, stack string) error {
    return &Error{Reason: reason, Message: message, Stack: stack}
}

// mergedError collapses a list of errors into a single error value.
//
// It returns nil for an empty list and the sole element for a one-element
// list. Otherwise it returns an *Error with reason ERROR_MULTI_ERRORS_FOUND
// whose message holds one "* <error>" bullet per line and whose stack is
// stackName.
func mergedError(errors []error, stackName string) error {
    switch len(errors) {
    case 0:
        // Literal nil, so callers comparing against nil see "no error".
        return nil
    case 1:
        return errors[0]
    }
    points := make([]string, len(errors))
    for i, err := range errors {
        points[i] = fmt.Sprintf("* %s", err)
    }
    return &Error{
        Reason:  ERROR_MULTI_ERRORS_FOUND,
        Message: strings.Join(points, "\n"),
        Stack:   stackName,
    }
}

// Error implements the error interface.
//
// Without a stack the text is "<Reason>: <Message>"; with one it is
// "<Reason>:(<Stack>) <Message>".
func (e *Error) Error() string {
    if e.Stack == "" {
        return fmt.Sprintf("%s: %s", e.Reason, e.Message)
    }
    // The format has three verbs, so Reason must be passed as well; with only
    // Stack and Message the output ended in "%!s(MISSING)".
    return fmt.Sprintf("%s:(%s) %s", e.Reason, e.Stack, e.Message)
}

// InitialiserConfig holds the settings of an Initialiser.
type InitialiserConfig struct {
    // TagName is the struct-tag key that is looked up on each field to find
    // its default value.
    TagName string
}

// Initialiser fills struct fields from the values in their struct tags.
type Initialiser struct {
    // config supplies the tag key that the methods read via config.TagName.
    config *InitialiserConfig
}

// NewStructInitialiser returns an Initialiser that reads default values from
// the "default" struct tag.
func NewStructInitialiser() Initialiser {
    return Initialiser{
        config: &InitialiserConfig{TagName: "default"},
    }
}

// initialiseInt applies a signed-integer default parsed from tag to val.
//
// stackName is the dotted field path reported in errors, tag is the raw
// default text, and val must be a settable value of a signed integer kind.
// An empty tag is a no-op. A field that already holds a non-zero value is left
// alone so data supplied by the caller is never overwritten. A tag that is
// not a base-10 integer, or that does not fit in val's type, yields an error
// with reason ERROR_TYPE_CONVERSION_Failed.
func (in *Initialiser) initialiseInt(stackName, tag string, val reflect.Value) error {
    if tag == "" {
        return nil
    }
    i, err := strconv.ParseInt(tag, 10, 64)
    if err != nil {
        return NewError(
            ERROR_TYPE_CONVERSION_Failed,
            fmt.Sprintf(`"%s" can't convert to int64`, tag),
            stackName,
        )
    }
    // SetInt silently truncates, so reject defaults that do not fit the
    // field's width (e.g. "300" on an int8) instead of storing a wrong value.
    if val.OverflowInt(i) {
        return NewError(
            ERROR_TYPE_CONVERSION_Failed,
            fmt.Sprintf(`"%s" overflows %s`, tag, val.Type()),
            stackName,
        )
    }
    if val.Int() == 0 {
        val.SetInt(i)
    }
    return nil
}

func (self *Initialiser) initialiseUint(stackName,err := strconv.ParseUint(tag,fmt.Sprintf(`"%s" can't convert to uint64`,)
        }
        if val.Uint() == 0 {
            val.SetUint(uint64(i))  // 设置value
        }
    }
    return nil
}

// initialiseString stores tag in val when tag is non-empty and val is still
// the empty string.
//
// The signature mirrors initialiseInt: stackName is the dotted field path
// (unused here because this step cannot fail), tag is the default text, and
// val must be a settable string value. It always returns nil.
func (in *Initialiser) initialiseString(stackName, tag string, val reflect.Value) error {
    if tag != "" && val.String() == "" {
        val.SetString(tag)
    }
    return nil
}

// checkStructHasDefaultValue reports whether any field of the struct type
// typ, or of a struct nested in it by value, carries a non-empty value for
// the configured tag. Pointer fields are not followed.
func (in *Initialiser) checkStructHasDefaultValue(typ reflect.Type) bool {
    for idx := 0; idx < typ.NumField(); idx++ {
        f := typ.Field(idx)
        switch {
        case f.Tag.Get(in.config.TagName) != "":
            // This field declares its own default.
            return true
        case f.Type.Kind() == reflect.Struct && in.checkStructHasDefaultValue(f.Type):
            // A nested struct declares a default somewhere inside it.
            return true
        }
    }
    return false
}

// initialisePtr initialises the value that the pointer val refers to.
//
// stackName is the dotted field path reported in errors, tag is the field's
// default text, and val must be a settable pointer value.
//
//   - A tag of "-" skips the field.
//   - With an empty tag the pointer is only touched when it points to a struct
//     type in which at least one field declares a default; otherwise there is
//     nothing to initialise.
//   - A pointer that is already non-nil is initialised in place, so a value
//     supplied by the caller is kept rather than replaced.
//   - A nil pointer gets a freshly allocated value, which is stored only if
//     its initialisation succeeded.
func (in *Initialiser) initialisePtr(stackName, tag string, val reflect.Value) error {
    if tag == "-" {
        return nil
    }
    typ := val.Type().Elem()
    if tag == "" && (typ.Kind() != reflect.Struct || !in.checkStructHasDefaultValue(typ)) {
        return nil
    }
    if !val.IsNil() {
        // Fill the existing pointee; the per-kind helpers only set zero values.
        return in.initialise(stackName, tag, val.Elem())
    }
    realVal := reflect.New(typ) // allocate a zero value of the pointee type
    // Indirect yields the value the new pointer refers to.
    if err := in.initialise(stackName, tag, reflect.Indirect(realVal)); err != nil {
        return err
    }
    val.Set(realVal)
    return nil
}

// initialise dispatches on the kind of val and applies the default in tag.
//
// stackName is the dotted field path reported in errors. Structs are handed
// to initialiseStruct (tag is not used for them), signed and unsigned integers,
// strings and pointers go to their dedicated helpers, and every other kind
// yields an error with reason ERROR_NOT_IMPLEMENT, even when tag is empty.
func (in *Initialiser) initialise(stackName, tag string, val reflect.Value) error {
    switch val.Kind() {
    case reflect.Struct:
        return in.initialiseStruct(stackName, val)
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return in.initialiseInt(stackName, tag, val)
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
        return in.initialiseUint(stackName, tag, val)
    case reflect.String:
        return in.initialiseString(stackName, tag, val)
    case reflect.Ptr:
        return in.initialisePtr(stackName, tag, val)
    default:
        return NewError(ERROR_NOT_IMPLEMENT, "not implement", stackName)
    }
}

// initialiseStruct applies tag defaults to every field of structVal.
//
// stackName is the dotted path of structVal itself and is prefixed to each
// field name in error reports. Embedded (anonymous) struct fields are queued
// and processed as if their fields belonged to the outer struct. Errors from
// individual fields do not stop the walk; they are collected and returned
// together through mergedError, which yields nil when there were none.
func (in *Initialiser) initialiseStruct(stackName string, structVal reflect.Value) error {
    pending := []reflect.Value{structVal}
    var errs []error
    for len(pending) > 0 {
        current := pending[0]
        pending = pending[1:]
        currentTyp := current.Type()
        for i := 0; i < currentTyp.NumField(); i++ {
            field := currentTyp.Field(i)
            // Only embedded structs can be walked field by field; an embedded
            // pointer or basic type is handled like any other field so that
            // NumField is never called on a non-struct type.
            if field.Anonymous && field.Type.Kind() == reflect.Struct {
                pending = append(pending, current.Field(i))
                continue
            }
            fieldName := field.Name
            if stackName != "" {
                fieldName = fmt.Sprintf("%s.%s", stackName, fieldName)
            }
            defaultTag := field.Tag.Get(in.config.TagName)
            if err := in.initialise(fieldName, defaultTag, current.Field(i)); err != nil {
                errs = append(errs, err)
            }
        }
    }
    return mergedError(errs, stackName)
}

func (self *Initialiser) Initialise(inf interface{}) error {
    typ := reflect.TypeOf(inf)
    kind := typ.Kind() // kind与type的区别在于:kind无法区分基本类型的别名
    if kind != reflect.Ptr { // 只有指针才能执行初始化操作
        return NewError(
            ERROR_NOT_A_POINTER,fmt.Sprintf("%s not a pointer",typ),"",)
    }
    val := reflect.ValueOf(inf).Elem() // 获取指针对应的值信息
    kind = val.Kind()
    if kind != reflect.Struct {
        return NewError(
            ERROR_POINTE_TO_NON_STRUCT,fmt.Sprintf("%s point to non struct",)
    }
    return self.initialiseStruct(val.Type().Name(),val)
}

// InitializeStruct applies the "default" tag values to struct_ using a freshly
// created Initialiser and returns whatever Initialise reports.
func InitializeStruct(struct_ interface{}) error {
    // A local variable is required: Initialise has a pointer receiver and a
    // function result is not addressable.
    defaultInitialiser := NewStructInitialiser()
    return defaultInitialiser.Initialise(struct_)
}
package structinitializer

import (
    "testing"
    "fmt"
    "reflect"
)

// Validator is implemented by the test fixtures; checkStruct calls Valid
// after initialising a fixture.
type Validator interface {
    Valid() bool
}

// TNormStruct is a fixture whose fields carry no default tag.
type TNormStruct struct  {
    AInt int
    AStr string
}

// TStruct is a fixture with an int and a string field that both declare a
// default.
type TStruct struct {
    AInt int `default:"10"`
    AStr string `default:"hello"`
}

// Valid reports whether both fields hold the values declared in their tags.
func (ts *TStruct) Valid() bool {
    if ts.AStr != "hello" {
        return false
    }
    return ts.AInt == 10
}

// TWrongStruct is a fixture whose AInt default ("hhhhh") is not an integer;
// TestWrongTag uses it to provoke a type-conversion error.
type TWrongStruct struct {
    AInt int `default:"hhhhh"`
    AStr string `default:"hello"`
}

// Valid always returns false.
func (self *TWrongStruct) Valid() bool {
    return false
}

// TEmbeddStruct is a fixture holding a TStruct in a named field next to a
// tagged int of its own.
type TEmbeddStruct struct  {
    AStruct TStruct
    AInt int `default:"101"`
}

// Valid reports whether the nested TStruct is valid and AInt equals 101.
func (es *TEmbeddStruct) Valid() bool {
    if !es.AStruct.Valid() {
        return false
    }
    return es.AInt == 101
}

// TEmbeddStruct2 is a fixture whose only default tags live inside the nested
// TStruct field; its own AInt has none.
type TEmbeddStruct2 struct  {
    AStruct TStruct
    AInt int
}

// TAnonymousStruct is a fixture that embeds TStruct anonymously and adds a
// tagged uint16.
type TAnonymousStruct struct  {
    TStruct
    AUint uint16 `default:"20"`
}

// Valid reports whether the embedded TStruct is valid and AUint equals 20.
func (as *TAnonymousStruct) Valid() bool  {
    if !as.TStruct.Valid() {
        return false
    }
    return as.AUint == 20
}

// TPointerStruct is a fixture made only of pointer fields: a pointer to a
// struct with defaults, a tagged *int32, a *string tagged "-", an untagged
// *uint, and a pointer to a struct without defaults.
type TPointerStruct struct {
    AStructPtr *TStruct
    AIntPtr *int32 `default:"20"`
    AStrPtr *string `default:"-"`
    AUintPtr *uint
    ANormStruct *TNormStruct
}

// Valid reports whether AStructPtr points to a valid TStruct, AIntPtr points
// to 20, and the remaining three pointers are still nil.
func (ps *TPointerStruct) Valid() bool {
    if ps.AStructPtr == nil || !ps.AStructPtr.Valid() {
        return false
    }
    if ps.AIntPtr == nil || *ps.AIntPtr != 20 {
        return false
    }
    return ps.AStrPtr == nil && ps.AUintPtr == nil && ps.ANormStruct == nil
}

// AssertTrue marks the test as failed with msg when a is false.
//
// msg is reported verbatim via t.Error rather than being used as a format
// string, so a '%' inside it cannot corrupt the output.
func AssertTrue(t *testing.T, a bool, msg string) {
    t.Helper() // attribute the failure to the calling line
    if !a {
        t.Error(msg)
    }
}

// TestCheckStructHasDefaultValue checks the default-tag detection on a tagged
// struct, an untagged struct, and a struct whose tags are only in a nested
// struct field.
func TestCheckStructHasDefaultValue(t *testing.T) {
    initialiser := NewStructInitialiser()
    cases := []struct {
        typ  reflect.Type
        want bool
        msg  string
    }{
        {reflect.TypeOf(TStruct{}), true, "TStruct check Failed!"},
        {reflect.TypeOf(TNormStruct{}), false, "TNormStrcut check succeed!"},
        {reflect.TypeOf(TEmbeddStruct2{}), true, "TEmbeddStruct2 check Failed!"},
    }
    for _, c := range cases {
        AssertTrue(t, initialiser.checkStructHasDefaultValue(c.typ) == c.want, c.msg)
    }
}

// checkError fails the test unless err is a non-nil *Error whose Reason equals
// wantReason. On success the error is printed for inspection.
func checkError(t *testing.T, err error, wantReason string) {
    t.Helper()
    if err == nil {
        t.Errorf("expected an error with reason '%s' but got nil", wantReason)
        return
    }
    // Use the two-value assertion so a foreign error type fails the test
    // instead of panicking.
    err1, ok := err.(*Error)
    if !ok {
        t.Errorf("want *Error with reason '%s' but got %T: %v", wantReason, err, err)
        return
    }
    if err1.Reason != wantReason {
        t.Errorf("want '%s' but '%s': %s", wantReason, err1.Reason, err1.Message)
        return
    }
    fmt.Printf("%#v\n", err)
}

func TestNonPointer(t *testing.T) {
    aStruct := TStruct{}
    err := InitializeStruct(aStruct)
    checkError(t,err,ERROR_NOT_A_POINTER)
}

// TestPointToNonStruct passes a pointer to an int and expects the
// point-to-non-struct error.
func TestPointToNonStruct(t *testing.T) {
    aInt := 100
    err := InitializeStruct(&aInt)
    checkError(t, err, ERROR_POINTE_TO_NON_STRUCT)
}

// TestWrongTag initialises a struct whose int field has a non-numeric default
// and expects the type-conversion error.
func TestWrongTag(t *testing.T) {
    aStruct := TWrongStruct{}
    err := InitializeStruct(&aStruct)
    checkError(t, err, ERROR_TYPE_CONVERSION_Failed)
}

// TestDefaultSetted verifies that a field the caller has already set is not
// overwritten by its default.
func TestDefaultSetted(t *testing.T) {
    aStruct := TStruct{
        AInt: 103,
    }
    err := InitializeStruct(&aStruct)
    if err != nil || aStruct.AInt != 103 {
        // %v, not %s: err is nil when only the value check failed, and %s
        // would print "%!s(<nil>)".
        t.Errorf("BOOM!!! rewrite a user setted value? err: %v", err)
    }
    fmt.Printf("%#v\n", aStruct)
}

// checkStruct initialises v and fails the test when initialisation returns an
// error or when v does not report itself as valid afterwards. On success v is
// printed for inspection.
func checkStruct(t *testing.T, v Validator) {
    t.Helper()
    if err := InitializeStruct(v); err != nil {
        // Both verbs need an argument: the error first, then the struct.
        t.Errorf("unexcept err: %#v,aStruct: %#v\n", err, v)
        return
    }

    if !v.Valid() {
        t.Errorf("init incorret: %#v\n", v)
        return
    }
    fmt.Printf("%#v\n", v)
}

func TestEmbeddStruct(t *testing.T) {
    aStruct := TEmbeddStruct{}
    checkStruct(t,&aStruct)
}

func TestAnonymousStrct(t *testing.T) {
    aStruct := TAnonymousStruct{}
    checkStruct(t,&aStruct)
}

func TestNormalStruct(t *testing.T) {
    aStruct := TStruct{}
    checkStruct(t,&aStruct)
}

func TestPointerStruct(t *testing.T) {
    aStruct := TPointerStruct{}
    checkStruct(t,&aStruct)
}

参考

原文链接:https://www.f2er.com/go/189445.html

猜你在找的Go相关文章