diff --git a/engine/util/util.go b/engine/util/util.go index dcd3eef4..62850430 100644 --- a/engine/util/util.go +++ b/engine/util/util.go @@ -94,10 +94,22 @@ func B64ToRes(str string) (engine.Res, error) { // StructTagToFieldName returns a mapping from recommended alias to actual field // name. It returns an error if it finds a collision. It uses the `lang` tags. -func StructTagToFieldName(res engine.Res) (map[string]string, error) { +// It must be passed a ptr to a struct or it will error. +func StructTagToFieldName(stptr interface{}) (map[string]string, error) { // TODO: fallback to looking up yaml tags, although harder to parse result := make(map[string]string) // `lang` field tag -> field name - st := reflect.TypeOf(res).Elem() // elem for ptr to res + if stptr == nil { + return nil, fmt.Errorf("got nil input instead of ptr to struct") + } + typ := reflect.TypeOf(stptr) + if k := typ.Kind(); k != reflect.Ptr { // we only look at *Struct's + return nil, fmt.Errorf("input is not a ptr, got: %+v", k) + } + st := typ.Elem() // elem for ptr to struct (dereference the pointer) + if k := st.Kind(); k != reflect.Struct { // this should be a struct now + return nil, fmt.Errorf("input doesn't point to a struct, got: %+v", k) + } + for i := 0; i < st.NumField(); i++ { field := st.Field(i) name := field.Name diff --git a/engine/util/util_test.go b/engine/util/util_test.go index 661c6e9e..c6ebdc00 100644 --- a/engine/util/util_test.go +++ b/engine/util/util_test.go @@ -21,6 +21,7 @@ package util import ( "os/user" + "reflect" "strconv" "testing" ) @@ -105,3 +106,61 @@ func TestCurrentUserGroupById(t *testing.T) { t.Errorf("gid didn't match current user's: %s vs %s", strconv.Itoa(gid), currentGID) } } + +func TestStructTagToFieldName0(t *testing.T) { + type foo struct { + A string `lang:"aaa"` + B bool `lang:"bbb"` + C int64 `lang:"ccc"` + } + f := &foo{ // a ptr! + A: "hello", + B: true, + C: 13, + } + m, err := StructTagToFieldName(f) // (map[string]string, error) + if err != nil { + t.Errorf("got error: %+v", err) + return + } + t.Logf("got output: %+v", m) + expected := map[string]string{ + "aaa": "A", + "bbb": "B", + "ccc": "C", + } + if !reflect.DeepEqual(m, expected) { + t.Errorf("unexpected result") + return + } +} + +func TestStructTagToFieldName1(t *testing.T) { + type foo struct { + A string `lang:"aaa"` + B bool `lang:"bbb"` + C int64 `lang:"ccc"` + } + f := foo{ // not a ptr! + A: "hello", + B: true, + C: 13, + } + m, err := StructTagToFieldName(f) // (map[string]string, error) + if err == nil { + t.Errorf("expected error, got nil") + //return + } + t.Logf("got output: %+v", m) + t.Logf("got error: %+v", err) +} + +func TestStructTagToFieldName2(t *testing.T) { + m, err := StructTagToFieldName(nil) // (map[string]string, error) + if err == nil { + t.Errorf("expected error, got nil") + //return + } + t.Logf("got output: %+v", m) + t.Logf("got error: %+v", err) +}