diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index 511b2fe3..e6c20b7b 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -195,6 +195,16 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfa return unsolved, result } + // build a static list that won't get consumed + fnInvariants := []*interfaces.EqualityWrapFuncInvariant{} + for _, x := range equalities { + eq, ok := x.(*interfaces.EqualityWrapFuncInvariant) + if !ok { + continue + } + fnInvariants = append(fnInvariants, eq) + } + logf("%s: starting loop with %d equalities", Name, len(equalities)) // run until we're solved, stop consuming equalities, or type clash Loop: @@ -481,6 +491,78 @@ Loop: } } + // is there another EqualityWrapFuncInvariant with the same Expr1 pointer? + for _, fn := range fnInvariants { + if eq.Expr1 != fn.Expr1 { + continue + } + // wow they match + + if len(eq.Expr2Ord) != len(fn.Expr2Ord) { + return nil, fmt.Errorf("func arg count differs") + } + for i := range eq.Expr2Ord { + lhsName := eq.Expr2Ord[i] + lhsExpr := eq.Expr2Map[lhsName] // assume key exists + rhsName := fn.Expr2Ord[i] + rhsExpr := fn.Expr2Map[rhsName] // assume key exists + + lhsTyp, lhsExists := solved[lhsExpr] + rhsTyp, rhsExists := solved[rhsExpr] + + // both solved or both unsolved we skip + if lhsExists && !rhsExists { // teach rhs + typ, exists := funcPartials[eq.Expr1][rhsExpr] + if !exists { + funcPartials[eq.Expr1][rhsExpr] = lhsTyp // learn! + continue + } + if err := typ.Cmp(lhsTyp); err != nil { + return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial func arg") + } + } + if rhsExists && !lhsExists { // teach lhs + typ, exists := funcPartials[eq.Expr1][lhsExpr] + if !exists { + funcPartials[eq.Expr1][lhsExpr] = rhsTyp // learn! + continue + } + if err := typ.Cmp(rhsTyp); err != nil { + return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial func arg") + } + } + } + + lhsExpr := eq.Expr2Out + rhsExpr := fn.Expr2Out + + lhsTyp, lhsExists := solved[lhsExpr] + rhsTyp, rhsExists := solved[rhsExpr] + + // both solved or both unsolved we skip + if lhsExists && !rhsExists { // teach rhs + typ, exists := funcPartials[eq.Expr1][rhsExpr] + if !exists { + funcPartials[eq.Expr1][rhsExpr] = lhsTyp // learn! + continue + } + if err := typ.Cmp(lhsTyp); err != nil { + return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial func arg") + } + } + if rhsExists && !lhsExists { // teach lhs + typ, exists := funcPartials[eq.Expr1][lhsExpr] + if !exists { + funcPartials[eq.Expr1][lhsExpr] = rhsTyp // learn! + continue + } + if err := typ.Cmp(rhsTyp); err != nil { + return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial func arg") + } + } + + } + // can we solve anything? var ready = true // assume ready typ := &types.Type{ diff --git a/lang/unification/simplesolver_test.go b/lang/unification/simplesolver_test.go index cf46026f..8a2716f2 100644 --- a/lang/unification/simplesolver_test.go +++ b/lang/unification/simplesolver_test.go @@ -96,6 +96,70 @@ func TestSimpleSolver1(t *testing.T) { //experr: ErrAmbiguous, }) } + { + // ?1 = func(x ?2) ?3 + // ?1 = func(arg0 str) ?4 + // ?3 = str # needed since we don't know what the func body is + expr1 := &interfaces.ExprAny{} // ?1 + expr2 := &interfaces.ExprAny{} // ?2 + expr3 := &interfaces.ExprAny{} // ?3 + expr4 := &interfaces.ExprAny{} // ?4 + + arg0 := &interfaces.ExprAny{} // arg0 + + invarA := &interfaces.EqualityWrapFuncInvariant{ + Expr1: expr1, // Expr + Expr2Map: map[string]interfaces.Expr{ // map[string]Expr + "x": expr2, + }, + Expr2Ord: []string{"x"}, // []string + Expr2Out: expr3, // Expr + } + + invarB := &interfaces.EqualityWrapFuncInvariant{ + Expr1: expr1, // Expr + Expr2Map: map[string]interfaces.Expr{ // map[string]Expr + "arg0": arg0, + }, + Expr2Ord: []string{"arg0"}, // []string + Expr2Out: expr4, // Expr + } + + invarC := &interfaces.EqualsInvariant{ + Expr: expr3, + Type: types.NewType("str"), + } + + invarD := &interfaces.EqualsInvariant{ + Expr: arg0, + Type: types.NewType("str"), + } + + testCases = append(testCases, test{ + name: "dual functions", + invariants: []interfaces.Invariant{ + invarA, + invarB, + invarC, + invarD, + }, + expected: []interfaces.Expr{ + expr1, + expr2, + expr3, + expr4, + arg0, + }, + fail: false, + expect: map[interfaces.Expr]*types.Type{ + expr1: types.NewType("func(str) str"), + expr2: types.NewType("str"), + expr3: types.NewType("str"), + expr4: types.NewType("str"), + arg0: types.NewType("str"), + }, + }) + } names := []string{} for index, tc := range testCases { // run all the tests