diff --git a/lang/funcs/simple/simple.go b/lang/funcs/simple/simple.go index 42348dbd..5d84ccdb 100644 --- a/lang/funcs/simple/simple.go +++ b/lang/funcs/simple/simple.go @@ -32,6 +32,8 @@ package simple import ( "context" "fmt" + "reflect" + "strings" "github.com/purpleidea/mgmt/lang/funcs" "github.com/purpleidea/mgmt/lang/interfaces" @@ -196,3 +198,67 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error { } } } + +// StructRegister takes an CLI args struct with optional struct tags, and +// generates simple functions from the contained fields in the specified +// namespace. If no struct field named `func` is included, then a default +// function name which is the lower case representation of the field name will +// be used, otherwise the struct tag contents are used. If the struct tag +// contains the `-` character, then the field will be skipped. +// TODO: An alternative version of this might choose to return all of the values +// as a single giant struct. +func StructRegister(moduleName string, args interface{}) error { + if args == nil { + // programming error + return fmt.Errorf("could not convert/access our struct") + } + //fmt.Printf("A: %+v\n", args) + + val := reflect.ValueOf(args) + if val.Kind() == reflect.Ptr { // max one de-referencing + val = val.Elem() + } + typ := val.Type() + + for i := 0; i < typ.NumField(); i++ { + v := val.Field(i) // value of the field + t := typ.Field(i) // struct type, get real type with .Type + + name := strings.ToLower(t.Name) // default + if alias, ok := t.Tag.Lookup("func"); ok { + if alias == "-" { // skip + continue + } + name = alias + } + //fmt.Printf("N: %+v\n", name) // debug + if len(strings.Trim(name, "abcdefghijklmnopqrstuvwxyz_")) > 0 { + return fmt.Errorf("struct field index(%d) has invalid char(s) in function name", i) + } + + typed, err := types.TypeOf(t.Type) // reflect.Type -> (*types.Type, error) + if err != nil { + return err + } + //fmt.Printf("T: %+v\n", typed.String()) // debug + + ModuleRegister(moduleName, name, &types.FuncValue{ + T: types.NewType(fmt.Sprintf("func() %s", typed.String())), + V: func(input []types.Value) (types.Value, error) { + //if args == nil { + // // programming error + // return nil, fmt.Errorf("could not convert/access our struct") + //} + + value, err := types.ValueOf(v) // reflect.Value -> (types.Value, error) + if err != nil { + return nil, errwrap.Wrapf(err, "func `%s.%s()` has nil value", moduleName, name) + } + //fmt.Printf("V: %+v\n", value) // debug + return value, nil + }, + }) + } + + return nil +}