diff --git a/lang/lexer.nex b/lang/lexer.nex index 265a9b6e..c1241659 100644 --- a/lang/lexer.nex +++ b/lang/lexer.nex @@ -184,6 +184,11 @@ lval.str = yylex.Text() return STRUCT_IDENTIFIER } +/func/ { + yylex.pos(lval) // our pos + lval.str = yylex.Text() + return FUNC_IDENTIFIER + } /class/ { yylex.pos(lval) // our pos lval.str = yylex.Text() diff --git a/lang/lexparse_test.go b/lang/lexparse_test.go index 85aa8aa6..da8a90c5 100644 --- a/lang/lexparse_test.go +++ b/lang/lexparse_test.go @@ -31,6 +31,7 @@ import ( "github.com/purpleidea/mgmt/util" "github.com/davecgh/go-spew/spew" + "github.com/kylelemons/godebug/pretty" ) func TestLexParse0(t *testing.T) { @@ -1606,6 +1607,168 @@ func TestLexParse0(t *testing.T) { exp: exp, }) } + { + exp := &StmtProg{ + Prog: []interfaces.Stmt{ + &StmtFunc{ + Name: "f1", + Func: &ExprFunc{ + Body: &ExprInt{ + V: 42, + }, + }, + }, + }, + } + values = append(values, test{ + name: "simple function stmt 1", + code: ` + func f1() { + 42 + } + `, + fail: false, + exp: exp, + }) + } + { + fn := &ExprFunc{ + Return: types.TypeInt, + Body: &ExprCall{ + Name: operatorFuncName, + Args: []interfaces.Expr{ + &ExprStr{ + V: "+", + }, + &ExprInt{ + V: 13, + }, + &ExprInt{ + V: 42, + }, + }, + }, + } + // sometimes, the type can get set by the parser when it's known + if err := fn.SetType(types.NewType("func() int")); err != nil { + t.Fatal("could not build type") + } + exp := &StmtProg{ + Prog: []interfaces.Stmt{ + &StmtFunc{ + Name: "f2", + Func: fn, + }, + }, + } + values = append(values, test{ + name: "simple function stmt 2", + code: ` + func f2() int { + 13 + 42 + } + `, + fail: false, + exp: exp, + }) + } + { + fn := &ExprFunc{ + Args: []*Arg{ + { + Name: "a", + Type: types.TypeInt, + }, + { + Name: "b", + //Type: &types.Type{}, + }, + }, + Return: types.TypeInt, + Body: &ExprCall{ + Name: operatorFuncName, + Args: []interfaces.Expr{ + &ExprStr{ + V: "+", + }, + &ExprVar{ + Name: "a", + }, + &ExprVar{ + Name: "b", + }, + }, + }, + } + // we can't set the type here, because it's only partially known + //if err := fn.SetType(types.NewType("func() int")); err != nil { + // t.Fatal("could not build type") + //} + exp := &StmtProg{ + Prog: []interfaces.Stmt{ + &StmtFunc{ + Name: "f3", + Func: fn, + }, + }, + } + values = append(values, test{ + name: "simple function stmt 3", + code: ` + func f3($a int, $b) int { + $a + $b + } + `, + fail: false, + exp: exp, + }) + } + { + fn := &ExprFunc{ + Args: []*Arg{ + { + Name: "x", + Type: types.TypeStr, + }, + }, + Return: types.TypeStr, + Body: &ExprCall{ + Name: operatorFuncName, + Args: []interfaces.Expr{ + &ExprStr{ + V: "+", + }, + &ExprStr{ + V: "hello", + }, + &ExprVar{ + Name: "x", + }, + }, + }, + } + if err := fn.SetType(types.NewType("func(x str) str")); err != nil { + t.Fatal("could not build type") + } + exp := &StmtProg{ + Prog: []interfaces.Stmt{ + &StmtFunc{ + Name: "f4", + Func: fn, + }, + }, + } + values = append(values, test{ + name: "simple function stmt 4", + code: ` + func f4($x str) str { + "hello" + $x + } + `, + fail: false, + exp: exp, + }) + } names := []string{} for index, test := range values { // run all the tests @@ -1647,11 +1810,27 @@ func TestLexParse0(t *testing.T) { if exp != nil { if !reflect.DeepEqual(ast, exp) { - t.Errorf("test #%d: AST did not match expected", index) - // TODO: consider making our own recursive print function - t.Logf("test #%d: actual: \n\n%s\n", index, spew.Sdump(ast)) - t.Logf("test #%d: expected: \n\n%s", index, spew.Sdump(exp)) - continue + // double check because DeepEqual is different since the func exists + diff := pretty.Compare(ast, exp) + if diff != "" { // bonus + t.Errorf("test #%d: AST did not match expected", index) + // TODO: consider making our own recursive print function + t.Logf("test #%d: actual: \n\n%s\n", index, spew.Sdump(ast)) + t.Logf("test #%d: expected: \n\n%s", index, spew.Sdump(exp)) + + // more details, for tricky cases: + diffable := &pretty.Config{ + Diffable: true, + IncludeUnexported: true, + //PrintStringers: false, + //PrintTextMarshalers: false, + //SkipZeroFields: false, + } + t.Logf("test #%d: actual: \n\n%s\n", index, diffable.Sprint(ast)) + t.Logf("test #%d: expected: \n\n%s", index, diffable.Sprint(exp)) + t.Logf("test #%d: diff:\n%s", index, diff) + continue + } } } } diff --git a/lang/parser.y b/lang/parser.y index 45f0b243..f5830531 100644 --- a/lang/parser.y +++ b/lang/parser.y @@ -87,6 +87,7 @@ func init() { %token VAR_IDENTIFIER_HX %token RES_IDENTIFIER CAPITALIZED_RES_IDENTIFIER %token IDENTIFIER CAPITALIZED_IDENTIFIER +%token FUNC_IDENTIFIER %token CLASS_IDENTIFIER INCLUDE_IDENTIFIER %token IMPORT_IDENTIFIER AS_IDENTIFIER %token COMMENT ERROR @@ -192,6 +193,60 @@ stmt: ElseBranch: $8.stmt, } } + // this is the named version, iow, a user-defined function (statement) + // `func name() { }` + // `func name() { }` + // `func name(, ) { }` +| FUNC_IDENTIFIER IDENTIFIER OPEN_PAREN args CLOSE_PAREN OPEN_CURLY expr CLOSE_CURLY + { + posLast(yylex, yyDollar) // our pos + $$.stmt = &StmtFunc{ + Name: $2.str, + Func: &ExprFunc{ + Args: $4.args, + //Return: nil, + Body: $7.expr, + }, + } + } + // `func name(...) { }` +| FUNC_IDENTIFIER IDENTIFIER OPEN_PAREN args CLOSE_PAREN type OPEN_CURLY expr CLOSE_CURLY + { + posLast(yylex, yyDollar) // our pos + fn := &ExprFunc{ + Args: $4.args, + Return: $6.typ, // return type is known + Body: $8.expr, + } + isFullyTyped := $6.typ != nil // true if set + m := make(map[string]*types.Type) + ord := []string{} + for _, a := range $4.args { + if a.Type == nil { + // at least one is unknown, can't run SetType... + isFullyTyped = false + break + } + m[a.Name] = a.Type + ord = append(ord, a.Name) + } + if isFullyTyped { + typ := &types.Type{ + Kind: types.KindFunc, + Map: m, + Ord: ord, + Out: $6.typ, + } + if err := fn.SetType(typ); err != nil { + // this will ultimately cause a parser error to occur... + yylex.Error(fmt.Sprintf("%s: %+v", ErrParseSetType, err)) + } + } + $$.stmt = &StmtFunc{ + Name: $2.str, + Func: fn, + } + } // `class name { }` | CLASS_IDENTIFIER IDENTIFIER OPEN_CURLY prog CLOSE_CURLY { @@ -325,6 +380,12 @@ expr: // TODO: var could be squashed in here directly... $$.expr = $1.expr } +| func + { + posLast(yylex, yyDollar) // our pos + // TODO: var could be squashed in here directly... + $$.expr = $1.expr + } | IF expr OPEN_CURLY expr CLOSE_CURLY ELSE OPEN_CURLY expr CLOSE_CURLY { posLast(yylex, yyDollar) // our pos @@ -699,6 +760,55 @@ var: } } ; +func: + // this is the lambda version, iow, a function as a value (expression) + // `func() { }` + // `func() { }` + // `func(, ) { }` + FUNC_IDENTIFIER OPEN_PAREN args CLOSE_PAREN OPEN_CURLY expr CLOSE_CURLY + { + posLast(yylex, yyDollar) // our pos + $$.expr = &ExprFunc{ + Args: $3.args, + //Return: nil, + Body: $6.expr, + } + } + // `func(...) { }` +| FUNC_IDENTIFIER OPEN_PAREN args CLOSE_PAREN type OPEN_CURLY expr CLOSE_CURLY + { + posLast(yylex, yyDollar) // our pos + $$.expr = &ExprFunc{ + Args: $3.args, + Return: $5.typ, // return type is known + Body: $7.expr, + } + isFullyTyped := $5.typ != nil // true if set + m := make(map[string]*types.Type) + ord := []string{} + for _, a := range $3.args { + if a.Type == nil { + // at least one is unknown, can't run SetType... + isFullyTyped = false + break + } + m[a.Name] = a.Type + ord = append(ord, a.Name) + } + if isFullyTyped { + typ := &types.Type{ + Kind: types.KindFunc, + Map: m, + Ord: ord, + Out: $5.typ, + } + if err := $$.expr.SetType(typ); err != nil { + // this will ultimately cause a parser error to occur... + yylex.Error(fmt.Sprintf("%s: %+v", ErrParseSetType, err)) + } + } + } +; args: /* end of list */ { diff --git a/lang/structs.go b/lang/structs.go index 65882412..9063a0c8 100644 --- a/lang/structs.go +++ b/lang/structs.go @@ -1542,6 +1542,86 @@ func (obj *StmtProg) Output() (*interfaces.Output, error) { }, nil } +// StmtFunc represents a user defined function. It binds the specified name to +// the supplied function in the current scope and irrespective of the order of +// definition. +type StmtFunc struct { + Name string + //Func *ExprFunc // TODO: should it be this instead? + Func interfaces.Expr // TODO: is this correct? +} + +// Apply is a general purpose iterator method that operates on any AST node. It +// is not used as the primary AST traversal function because it is less readable +// and easy to reason about than manually implementing traversal for each node. +// Nevertheless, it is a useful facility for operations that might only apply to +// a select number of node types, since they won't need extra noop iterators... +func (obj *StmtFunc) Apply(fn func(interfaces.Node) error) error { + if err := obj.Func.Apply(fn); err != nil { + return err + } + return fn(obj) +} + +// Init initializes this branch of the AST, and returns an error if it fails to +// validate. +func (obj *StmtFunc) Init(data *interfaces.Data) error { + //obj.data = data // TODO: ??? + if err := obj.Func.Init(data); err != nil { + return err + } + return nil +} + +// Interpolate returns a new node (or itself) once it has been expanded. This +// generally increases the size of the AST when it is used. It calls Interpolate +// on any child elements and builds the new node with those new node contents. +func (obj *StmtFunc) Interpolate() (interfaces.Stmt, error) { + interpolated, err := obj.Func.Interpolate() + if err != nil { + return nil, err + } + + return &StmtFunc{ + Name: obj.Name, + Func: interpolated, + }, nil +} + +// SetScope sets the scope of the child expression bound to it. It seems this is +// necessary in order to reach this, in particular in situations when a bound +// expression points to a previously bound expression. +func (obj *StmtFunc) SetScope(scope *interfaces.Scope) error { + return obj.Func.SetScope(scope) +} + +// Unify returns the list of invariants that this node produces. It recursively +// calls Unify on any children elements that exist in the AST, and returns the +// collection to the caller. +func (obj *StmtFunc) Unify() ([]interfaces.Invariant, error) { + if obj.Name == "" { + return nil, fmt.Errorf("missing function name") + } + return obj.Func.Unify() +} + +// Graph returns the reactive function graph which is expressed by this node. It +// includes any vertices produced by this node, and the appropriate edges to any +// vertices that are produced by its children. Nodes which fulfill the Expr +// interface directly produce vertices (and possible children) where as nodes +// that fulfill the Stmt interface do not produces vertices, where as their +// children might. This particular func statement adds its linked expression to +// the graph. +func (obj *StmtFunc) Graph() (*pgraph.Graph, error) { + return obj.Func.Graph() +} + +// Output for the func statement produces no output. Any values of interest come +// from the use of the func which this binds the function to. +func (obj *StmtFunc) Output() (*interfaces.Output, error) { + return (&interfaces.Output{}).Empty(), nil +} + // StmtClass represents a user defined class. It's effectively a program body // that can optionally take some parameterized inputs. // TODO: We don't currently support defining polymorphic classes (eg: different @@ -3324,6 +3404,10 @@ type ExprStructField struct { // call, that is represented by ExprCall. // XXX: this is currently not fully implemented, and parts may be incorrect. type ExprFunc struct { + Args []*Arg + Return *types.Type // return type if specified + Body interfaces.Expr + typ *types.Type V func([]types.Value) (types.Value, error) @@ -3342,7 +3426,20 @@ func (obj *ExprFunc) Apply(fn func(interfaces.Node) error) error { // String returns a short representation of this expression. // FIXME: fmt.Sprintf("func(%+v)", obj.V) fails `go vet` (bug?), so wait until // we have a better printable function value and put that here instead. -func (obj *ExprFunc) String() string { return fmt.Sprintf("func(???)") } // TODO: print nicely +//func (obj *ExprFunc) String() string { return fmt.Sprintf("func(???)") } // TODO: print nicely +func (obj *ExprFunc) String() string { + var a []string + for _, x := range obj.Args { + a = append(a, fmt.Sprintf("%s", x.String())) + } + args := strings.Join(a, ", ") + s := fmt.Sprintf("func(%s)", args) + if obj.Return != nil { + s += fmt.Sprintf(" %s", obj.Return.String()) + } + s += fmt.Sprintf(" { %s }", obj.Body.String()) + return s +} // Init initializes this branch of the AST, and returns an error if it fails to // validate. @@ -4127,6 +4224,15 @@ type Arg struct { Type *types.Type // nil if unspecified (needs to be solved for) } +// String returns a short representation of this arg. +func (obj *Arg) String() string { + s := obj.Name + if obj.Type != nil { + s += fmt.Sprintf(" %s", obj.Type.String()) + } + return s +} + // ExprIf represents an if expression which *must* have both branches, and which // returns a value. As a result, it has a type. This is different from a StmtIf, // which does not need to have both branches, and which does not return a value.