diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index 08c09e68..3e9b976a 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -154,6 +154,7 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s if typ, exists := solved[eq.Expr1]; exists { // wow, now known, so tell the partials! + // TODO: this assumes typ is a list, is that guaranteed? listPartials[eq.Expr1][eq.Expr2Val] = typ.Val } @@ -211,6 +212,7 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s if typ, exists := solved[eq.Expr1]; exists { // wow, now known, so tell the partials! + // TODO: this assumes typ is a map, is that guaranteed? mapPartials[eq.Expr1][eq.Expr2Key] = typ.Key mapPartials[eq.Expr1][eq.Expr2Val] = typ.Val } @@ -281,6 +283,10 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s if typ, exists := solved[eq.Expr1]; exists { // wow, now known, so tell the partials! + // TODO: this assumes typ is a struct, is that guaranteed? + if len(typ.Ord) != len(eq.Expr2Ord) { + return nil, fmt.Errorf("struct field count differs") + } for i, name := range eq.Expr2Ord { expr := eq.Expr2Map[name] // assume key exists structPartials[eq.Expr1][expr] = typ.Map[typ.Ord[i]] // assume key exists @@ -351,6 +357,10 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s if typ, exists := solved[eq.Expr1]; exists { // wow, now known, so tell the partials! + // TODO: this assumes typ is a func, is that guaranteed? + if len(typ.Ord) != len(eq.Expr2Ord) { + return nil, fmt.Errorf("func arg count differs") + } for i, name := range eq.Expr2Ord { expr := eq.Expr2Map[name] // assume key exists funcPartials[eq.Expr1][expr] = typ.Map[typ.Ord[i]] // assume key exists diff --git a/lang/unification_test.go b/lang/unification_test.go index 6dcc351e..e8001b88 100644 --- a/lang/unification_test.go +++ b/lang/unification_test.go @@ -473,6 +473,41 @@ func TestUnification1(t *testing.T) { fail: true, }) } + { + //test "t1" { + // stringptr => getenv("GOPATH", "bug"), # bad (two args vs. one) + //} + expr := &ExprCall{ + Name: "getenv", + Args: []interfaces.Expr{ + &ExprStr{ + V: "GOPATH", + }, + &ExprStr{ + V: "bug", + }, + }, + } + stmt := &StmtProg{ + Prog: []interfaces.Stmt{ + &StmtRes{ + Kind: "test", + Name: &ExprStr{V: "t1"}, + Fields: []*StmtResField{ + { + Field: "stringptr", + Value: expr, + }, + }, + }, + }, + } + values = append(values, test{ + name: "function, wrong arg count", + ast: stmt, + fail: true, + }) + } for index, test := range values { // run all the tests t.Run(fmt.Sprintf("test #%d (%s)", index, test.name), func(t *testing.T) {