diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index 4149b05f..f14d17b0 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -65,9 +65,14 @@ const ( // SimpleInvariantSolver with the log parameter of your choice specified. The // result satisfies the correct signature for the solver parameter of the // Unification function. +// TODO: Get rid of this function and consider just using the struct directly. func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) { return func(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { - return SimpleInvariantSolver(ctx, invariants, expected, logf) + sis := &SimpleInvariantSolver{ + Debug: false, // TODO: consider plumbing this through + Logf: logf, + } + return sis.Solve(ctx, invariants, expected) } } @@ -196,8 +201,14 @@ func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []inter // SimpleInvariantSolver is an iterative invariant solver for AST expressions. // It is intended to be very simple, even if it's computationally inefficient. -func SimpleInvariantSolver(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) { - debug := false // XXX: add to interface +// TODO: Move some of the global solver constants into this struct as params. +type SimpleInvariantSolver struct { + Debug bool + Logf func(format string, v ...interface{}) +} + +// Solve is the actual solve implementation of the solver. +func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) { equalities := []interfaces.Invariant{} exclusives := []*interfaces.ExclusiveInvariant{} @@ -289,7 +300,7 @@ func SimpleInvariantSolver(ctx context.Context, invariants []interfaces.Invarian } used = append(used, i) // mark equality as used up } - logf("%s: got %d equalities left after %d used up", Name, len(equalities)-len(used), len(used)) + obj.Logf("%s: got %d equalities left after %d used up", Name, len(equalities)-len(used), len(used)) // delete used equalities, in reverse order to preserve indexing! for i := len(used) - 1; i >= 0; i-- { ix := used[i] // delete index that was marked as used! @@ -304,9 +315,9 @@ func SimpleInvariantSolver(ctx context.Context, invariants []interfaces.Invarian return equalities, exclusives, nil } - logf("%s: invariants:", Name) + obj.Logf("%s: invariants:", Name) for i, x := range invariants { - logf("invariant(%d): %T: %s", i, x, x) + obj.Logf("invariant(%d): %T: %s", i, x, x) } solved := make(map[interfaces.Expr]*types.Type) @@ -392,7 +403,7 @@ func SimpleInvariantSolver(ctx context.Context, invariants []interfaces.Invarian return active } - logf("%s: starting loop with %d equalities", Name, len(equalities)) + obj.Logf("%s: starting loop with %d equalities", Name, len(equalities)) // run until we're solved, stop consuming equalities, or type clash Loop: for { @@ -408,14 +419,14 @@ Loop: // Every generator gets to run once, and if that does not change // the result, then we mark it as inactive. - logf("%s: iterate...", Name) + obj.Logf("%s: iterate...", Name) if len(equalities) == 0 && len(exclusives) == 0 && activeGenerators() == 0 { break // we're done, nothing left } used := []int{} for eqi := 0; eqi < len(equalities); eqi++ { eqx := equalities[eqi] - logf("%s: match(%T): %+v", Name, eqx, eqx) + obj.Logf("%s: match(%T): %+v", Name, eqx, eqx) // TODO: could each of these cases be implemented as a // method on the Invariant type to simplify this code? @@ -426,7 +437,7 @@ Loop: if !exists { solved[eq.Expr] = eq.Type // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved trivial equality", Name) + obj.Logf("%s: solved trivial equality", Name) continue } // we already specified this, so check the repeat is consistent @@ -436,7 +447,7 @@ Loop: return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with equals") } used = append(used, eqi) // mark equality as duplicate - logf("%s: duplicate trivial equality", Name) + obj.Logf("%s: duplicate trivial equality", Name) continue // partials @@ -465,7 +476,7 @@ Loop: if newTyp, exists := solved[y]; !exists { solved[y] = typ // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial list val equality", Name) + obj.Logf("%s: solved partial list val equality", Name) } else if err := newTyp.Cmp(typ); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial list val equality") } @@ -504,7 +515,7 @@ Loop: solved[eq.Expr1] = typ // yay, we learned something! solved[eq.Expr2Val] = typ.Val // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved list wrap partial", Name) + obj.Logf("%s: solved list wrap partial", Name) continue } @@ -534,7 +545,7 @@ Loop: if newTyp, exists := solved[y]; !exists { solved[y] = typ // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial map key/val equality", Name) + obj.Logf("%s: solved partial map key/val equality", Name) } else if err := newTyp.Cmp(typ); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial map key/val equality") } @@ -585,7 +596,7 @@ Loop: solved[eq.Expr2Key] = typ.Key // yay, we learned something! solved[eq.Expr2Val] = typ.Val // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved map wrap partial", Name) + obj.Logf("%s: solved map wrap partial", Name) continue } @@ -620,7 +631,7 @@ Loop: if newTyp, exists := solved[y]; !exists { solved[y] = typ // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial struct field equality", Name) + obj.Logf("%s: solved partial struct field equality", Name) } else if err := newTyp.Cmp(typ); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial struct field equality") } @@ -669,7 +680,7 @@ Loop: solved[y] = typ.Map[name] // yay, we learned something! } used = append(used, eqi) // mark equality as used up - logf("%s: solved struct wrap partial", Name) + obj.Logf("%s: solved struct wrap partial", Name) continue } @@ -705,7 +716,7 @@ Loop: if newTyp, exists := solved[y]; !exists { solved[y] = typ // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial func arg equality", Name) + obj.Logf("%s: solved partial func arg equality", Name) } else if err := newTyp.Cmp(typ); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial func arg equality") } @@ -729,7 +740,7 @@ Loop: if newTyp, exists := solved[y]; !exists { solved[y] = typ // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial func return equality", Name) + obj.Logf("%s: solved partial func return equality", Name) } else if err := newTyp.Cmp(typ); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial func return equality") } @@ -742,10 +753,10 @@ Loop: } equivs := listConnectedFn(eq.Expr1, eqInvariants) // or equivalent! - if debug && len(equivs) > 0 { - logf("%s: equiv %d: %p %+v", Name, len(equivs), eq.Expr1, eq.Expr1) + if obj.Debug && len(equivs) > 0 { + obj.Logf("%s: equiv %d: %p %+v", Name, len(equivs), eq.Expr1, eq.Expr1) for i, x := range equivs { - logf("%s: equiv(%d): %p %+v", Name, i, x, x) + obj.Logf("%s: equiv(%d): %p %+v", Name, i, x, x) } } // This determines if a pointer is equivalent to @@ -762,13 +773,13 @@ Loop: for _, fn := range fnInvariants { // is this fn.Expr1 related by equivalency graph to eq.Expr1 ? if (eq.Expr1 != fn.Expr1) && !inEquiv(fn.Expr1) { - if debug { - logf("%s: equiv skip: %p %+v", Name, fn.Expr1, fn.Expr1) + if obj.Debug { + obj.Logf("%s: equiv skip: %p %+v", Name, fn.Expr1, fn.Expr1) } continue } - if debug { - logf("%s: equiv used: %p %+v", Name, fn.Expr1, fn.Expr1) + if obj.Debug { + obj.Logf("%s: equiv used: %p %+v", Name, fn.Expr1, fn.Expr1) } //if eq.Expr1 != fn.Expr1 { // previously // continue @@ -795,7 +806,7 @@ Loop: Expr2: rhsExpr, } if !eqContains(newEq, eqInvariants) { - logf("%s: new equality: %p %+v <-> %p %+v", Name, newEq.Expr1, newEq.Expr1, newEq.Expr2, newEq.Expr2) + obj.Logf("%s: new equality: %p %+v <-> %p %+v", Name, newEq.Expr1, newEq.Expr1, newEq.Expr2, newEq.Expr2) eqInvariants = append(eqInvariants, newEq) // TODO: add it as a generator instead? equalities = append(equalities, newEq) @@ -811,7 +822,7 @@ Loop: if newTyp, exists := solved[rhsExpr]; !exists { solved[rhsExpr] = lhsTyp // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial rhs func arg equality", Name) + obj.Logf("%s: solved partial rhs func arg equality", Name) } else if err := newTyp.Cmp(lhsTyp); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial rhs func arg equality") } @@ -831,7 +842,7 @@ Loop: if newTyp, exists := solved[lhsExpr]; !exists { solved[lhsExpr] = rhsTyp // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial lhs func arg equality", Name) + obj.Logf("%s: solved partial lhs func arg equality", Name) } else if err := newTyp.Cmp(rhsTyp); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial lhs func arg equality") } @@ -858,7 +869,7 @@ Loop: Expr2: rhsExpr, } if !eqContains(newEq, eqInvariants) { - logf("%s: new equality: %p %+v <-> %p %+v", Name, newEq.Expr1, newEq.Expr1, newEq.Expr2, newEq.Expr2) + obj.Logf("%s: new equality: %p %+v <-> %p %+v", Name, newEq.Expr1, newEq.Expr1, newEq.Expr2, newEq.Expr2) eqInvariants = append(eqInvariants, newEq) // TODO: add it as a generator instead? equalities = append(equalities, newEq) @@ -874,7 +885,7 @@ Loop: if newTyp, exists := solved[rhsExpr]; !exists { solved[rhsExpr] = lhsTyp // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial rhs func return equality", Name) + obj.Logf("%s: solved partial rhs func return equality", Name) } else if err := newTyp.Cmp(lhsTyp); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial rhs func return equality") } @@ -894,7 +905,7 @@ Loop: if newTyp, exists := solved[lhsExpr]; !exists { solved[lhsExpr] = rhsTyp // yay, we learned something! //used = append(used, i) // mark equality as used up when complete! - logf("%s: solved partial lhs func return equality", Name) + obj.Logf("%s: solved partial lhs func return equality", Name) } else if err := newTyp.Cmp(rhsTyp); err != nil { return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with partial lhs func return equality") } @@ -957,7 +968,7 @@ Loop: } solved[eq.Expr2Out] = typ.Out // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved func wrap partial", Name) + obj.Logf("%s: solved func wrap partial", Name) continue } @@ -993,7 +1004,7 @@ Loop: solved[eq.Expr1] = typ // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved call wrap partial", Name) + obj.Logf("%s: solved call wrap partial", Name) continue } @@ -1013,19 +1024,19 @@ Loop: return nil, errwrap.Wrapf(err, "can't unify, invariant illogicality with equality") } used = append(used, eqi) // mark equality as used up - logf("%s: duplicate regular equality", Name) + obj.Logf("%s: duplicate regular equality", Name) continue } if exists1 && !exists2 { // first equality already connects solved[eq.Expr2] = typ1 // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved regular equality", Name) + obj.Logf("%s: solved regular equality", Name) continue } if exists2 && !exists1 { // second equality already connects solved[eq.Expr1] = typ2 // yay, we learned something! used = append(used, eqi) // mark equality as used up - logf("%s: solved regular equality", Name) + obj.Logf("%s: solved regular equality", Name) continue } @@ -1071,7 +1082,7 @@ Loop: exclusives = append(exclusives, exs...) used = append(used, eqi) // mark equality as used up - logf("%s: solved `generator` equality", Name) + obj.Logf("%s: solved `generator` equality", Name) // reset all other generator equality "inactive" flags for _, x := range equalities { gen, ok := x.(*interfaces.GeneratorInvariant) @@ -1088,7 +1099,7 @@ Loop: // this basically ensures that the expr gets solved if _, exists := solved[eq.Expr]; exists { used = append(used, eqi) // mark equality as used up - logf("%s: solved `any` equality", Name) + obj.Logf("%s: solved `any` equality", Name) } continue @@ -1134,12 +1145,12 @@ Loop: // same algorithm and code, so they're combined here... _, isSolved := isSolvedFn(solved) if isSolved { - logf("%s: solved early with %d exclusives left!", Name, len(exclusives)) + obj.Logf("%s: solved early with %d exclusives left!", Name, len(exclusives)) } else { - logf("%s: unsolved with %d exclusives left!", Name, len(exclusives)) - if debug { + obj.Logf("%s: unsolved with %d exclusives left!", Name, len(exclusives)) + if obj.Debug { for i, x := range exclusives { - logf("%s: exclusive(%d) left: %s", Name, i, x) + obj.Logf("%s: exclusive(%d) left: %s", Name, i, x) } } } @@ -1155,13 +1166,13 @@ Loop: } // check for consistency against remaining invariants - logf("%s: checking for consistency against %d exclusives...", Name, len(exclusives)) + obj.Logf("%s: checking for consistency against %d exclusives...", Name, len(exclusives)) done := []int{} for i, invar := range exclusives { // test each one to see if at least one works match, err := invar.Matches(solved) if err != nil { - logf("exclusive invar failed: %+v", invar) + obj.Logf("exclusive invar failed: %+v", invar) return nil, errwrap.Wrapf(err, "inconsistent exclusive") } if !match { @@ -1169,7 +1180,7 @@ Loop: } done = append(done, i) } - logf("%s: removed %d consistent exclusives...", Name, len(done)) + obj.Logf("%s: removed %d consistent exclusives...", Name, len(done)) // Remove exclusives that matched correctly. for i := len(done) - 1; i >= 0; i-- { @@ -1201,7 +1212,7 @@ Loop: } used = append(used, i) // mark equality as used up } - logf("%s: got %d equalities left after %d value invariants used up", Name, len(equalities)-len(used), len(used)) + obj.Logf("%s: got %d equalities left after %d value invariants used up", Name, len(equalities)-len(used), len(used)) // delete used equalities, in reverse order to preserve indexing! for i := len(used) - 1; i >= 0; i-- { ix := used[i] // delete index that was marked as used! @@ -1222,7 +1233,7 @@ Loop: } used = append(used, i) // mark equality as used up } - logf("%s: got %d equalities left after %d generators used up", Name, len(equalities)-len(used), len(used)) + obj.Logf("%s: got %d equalities left after %d generators used up", Name, len(equalities)-len(used), len(used)) // delete used equalities, in reverse order to preserve indexing! for i := len(used) - 1; i >= 0; i-- { ix := used[i] // delete index that was marked as used! @@ -1236,7 +1247,7 @@ Loop: // what have we learned for sure so far? partialSolutions := []interfaces.Invariant{} - logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives)) + obj.Logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives)) if len(exclusives) > 0 { // FIXME: can we do this loop in a deterministic, sorted way? for expr, typ := range solved { @@ -1245,16 +1256,16 @@ Loop: Type: typ, } partialSolutions = append(partialSolutions, invar) - logf("%s: solved: %+v", Name, invar) + obj.Logf("%s: solved: %+v", Name, invar) } // also include anything that hasn't been solved yet for _, x := range equalities { partialSolutions = append(partialSolutions, x) - logf("%s: unsolved: %+v", Name, x) + obj.Logf("%s: unsolved: %+v", Name, x) } } - logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities)) + obj.Logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities)) // Lastly, we could loop through each exclusive and see // if it only has a single, easy solution. For example, @@ -1266,7 +1277,7 @@ Loop: // simplify method) so that if we're lucky, we rarely // need to run the raw exclusive combinatorial solver, // which is slow. - logf("%s: attempting to simplify %d exclusives...", Name, len(exclusives)) + obj.Logf("%s: attempting to simplify %d exclusives...", Name, len(exclusives)) done = []int{} // clear for re-use simplified := []interfaces.Invariant{} @@ -1275,13 +1286,13 @@ Loop: // exclusives... We look at each individually. s, err := invar.Simplify(partialSolutions) // XXX: pass in the solver? if err != nil { - logf("exclusive simplification failed: %+v", invar) + obj.Logf("exclusive simplification failed: %+v", invar) continue } done = append(done, i) simplified = append(simplified, s...) } - logf("%s: simplified %d exclusives...", Name, len(done)) + obj.Logf("%s: simplified %d exclusives...", Name, len(done)) // Remove exclusives that matched correctly. for i := len(done) - 1; i >= 0; i-- { @@ -1307,12 +1318,12 @@ Loop: // exclusive solver with a real SAT solver algorithm. if !AllowRecursion || len(exclusives) > RecursionInvariantLimit { - logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives)) + obj.Logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives)) for i, eq := range equalities { - logf("%s: (%d) equality: %s", Name, i, eq) + obj.Logf("%s: (%d) equality: %s", Name, i, eq) } for i, ex := range exclusives { - logf("%s: (%d) exclusive: %s", Name, i, ex) + obj.Logf("%s: (%d) exclusive: %s", Name, i, ex) } // these can be very slow, so try to avoid them @@ -1327,7 +1338,7 @@ Loop: default: // pass } - logf("%s: exclusive(%d):\n%+v", Name, i, ex) + obj.Logf("%s: exclusive(%d):\n%+v", Name, i, ex) // we could waste a lot of cpu, and start from // the beginning, but instead we could use the // list of known solutions found and continue! @@ -1336,29 +1347,29 @@ Loop: recursiveInvariants = append(recursiveInvariants, partialSolutions...) recursiveInvariants = append(recursiveInvariants, ex...) // FIXME: implement RecursionDepthLimit - logf("%s: recursing...", Name) - solution, err := SimpleInvariantSolver(ctx, recursiveInvariants, expected, logf) + obj.Logf("%s: recursing...", Name) + solution, err := obj.Solve(ctx, recursiveInvariants, expected) if err != nil { - logf("%s: recursive solution failed: %+v", Name, err) + obj.Logf("%s: recursive solution failed: %+v", Name, err) continue // no solution found here... } // solution found! - logf("%s: recursive solution found!", Name) + obj.Logf("%s: recursive solution found!", Name) return solution, nil } // TODO: print ambiguity - logf("%s: ================ ambiguity ================", Name) + obj.Logf("%s: ================ ambiguity ================", Name) unsolved, isSolved := isSolvedFn(solved) - logf("%s: isSolved: %+v", Name, isSolved) + obj.Logf("%s: isSolved: %+v", Name, isSolved) for _, x := range equalities { - logf("%s: unsolved equality: %+v", Name, x) + obj.Logf("%s: unsolved equality: %+v", Name, x) } for x := range unsolved { - logf("%s: unsolved expected: (%p) %+v", Name, x, x) + obj.Logf("%s: unsolved expected: (%p) %+v", Name, x, x) } for expr, typ := range solved { - logf("%s: solved: (%p) => %+v", Name, expr, typ) + obj.Logf("%s: solved: (%p) => %+v", Name, expr, typ) } return nil, ErrAmbiguous }