diff --git a/lang/gapi/gapi.go b/lang/gapi/gapi.go index b75fead9..208555bf 100644 --- a/lang/gapi/gapi.go +++ b/lang/gapi/gapi.go @@ -269,13 +269,18 @@ func (obj *GAPI) Cli(info *gapi.Info) (*gapi.Deploy, error) { } } logf("running type unification...") - startTime := time.Now() + + solver, err := unification.LookupDefault() + if err != nil { + return nil, errwrap.Wrapf(err, "could not get default solver") + } unifier := &unification.Unifier{ AST: iast, - Solver: unification.SimpleInvariantSolverLogger(unificationLogf), + Solver: solver, Debug: debug, Logf: unificationLogf, } + startTime := time.Now() unifyErr := unifier.Unify(context.TODO()) delta := time.Since(startTime) formatted := delta.String() diff --git a/lang/interpret_test.go b/lang/interpret_test.go index cd63617f..03854b52 100644 --- a/lang/interpret_test.go +++ b/lang/interpret_test.go @@ -458,9 +458,15 @@ func TestAstFunc1(t *testing.T) { xlogf := func(format string, v ...interface{}) { logf("unification: "+format, v...) } + solver, err := unification.LookupDefault() + if err != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: solver lookup failed with: %+v", index, err) + return + } unifier := &unification.Unifier{ AST: iast, - Solver: unification.SimpleInvariantSolverLogger(xlogf), + Solver: solver, Debug: testing.Verbose(), Logf: xlogf, } @@ -1028,9 +1034,15 @@ func TestAstFunc2(t *testing.T) { xlogf := func(format string, v ...interface{}) { logf("unification: "+format, v...) } + solver, err := unification.LookupDefault() + if err != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: solver lookup failed with: %+v", index, err) + return + } unifier := &unification.Unifier{ AST: iast, - Solver: unification.SimpleInvariantSolverLogger(xlogf), + Solver: solver, Debug: testing.Verbose(), Logf: xlogf, } @@ -1830,9 +1842,15 @@ func TestAstFunc3(t *testing.T) { xlogf := func(format string, v ...interface{}) { logf("unification: "+format, v...) } + solver, err := unification.LookupDefault() + if err != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: solver lookup failed with: %+v", index, err) + return + } unifier := &unification.Unifier{ AST: iast, - Solver: unification.SimpleInvariantSolverLogger(xlogf), + Solver: solver, Debug: testing.Verbose(), Logf: xlogf, } diff --git a/lang/lang.go b/lang/lang.go index 3c9bd657..6e03fa04 100644 --- a/lang/lang.go +++ b/lang/lang.go @@ -50,6 +50,7 @@ import ( "github.com/purpleidea/mgmt/lang/interpret" "github.com/purpleidea/mgmt/lang/parser" "github.com/purpleidea/mgmt/lang/unification" + _ "github.com/purpleidea/mgmt/lang/unification/solvers" // import so the solvers register "github.com/purpleidea/mgmt/pgraph" "github.com/purpleidea/mgmt/util" "github.com/purpleidea/mgmt/util/errwrap" @@ -225,13 +226,18 @@ func (obj *Lang) Init(ctx context.Context) error { } } obj.Logf("running type unification...") - timing = time.Now() + + solver, err := unification.LookupDefault() + if err != nil { + return errwrap.Wrapf(err, "could not get default solver") + } unifier := &unification.Unifier{ AST: obj.ast, - Solver: unification.SimpleInvariantSolverLogger(logf), + Solver: solver, Debug: obj.Debug, Logf: logf, } + timing = time.Now() // NOTE: This is the "real" Unify that runs. (This is not for deploy.) unifyErr := unifier.Unify(ctx) obj.Logf("type unification took: %s", time.Since(timing)) diff --git a/lang/unification/interfaces.go b/lang/unification/interfaces.go new file mode 100644 index 00000000..8e31deb5 --- /dev/null +++ b/lang/unification/interfaces.go @@ -0,0 +1,231 @@ +// Mgmt +// Copyright (C) 2013-2024+ James Shubin and the project contributors +// Written by James Shubin and the project contributors +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . +// +// Additional permission under GNU GPL version 3 section 7 +// +// If you modify this program, or any covered work, by linking or combining it +// with embedded mcl code and modules (and that the embedded mcl code and +// modules which link with this program, contain a copy of their source code in +// the authoritative form) containing parts covered by the terms of any other +// license, the licensors of this program grant you additional permission to +// convey the resulting work. Furthermore, the licensors of this program grant +// the original author, James Shubin, additional permission to update this +// additional permission if he deems it necessary to achieve the goals of this +// additional permission. + +package unification + +import ( + "context" + "fmt" + "sort" + + "github.com/purpleidea/mgmt/lang/interfaces" + "github.com/purpleidea/mgmt/lang/types" +) + +const ( + // ErrAmbiguous means we couldn't find a solution, but we weren't + // inconsistent. + ErrAmbiguous = interfaces.Error("can't unify, no equalities were consumed, we're ambiguous") +) + +// Init contains some handles that are used to initialize every solver. Each +// individual solver can choose to omit using some of the fields. +type Init struct { + Debug bool + Logf func(format string, v ...interface{}) +} + +// Solver is the general interface that any solver needs to implement. +type Solver interface { + // Init initializes the solver struct before first use. + Init(*Init) error + + // Solve performs the actual solving. It must return as soon as possible + // if the context is closed. + Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) +} + +// registeredSolvers is a global map of all possible unification solvers which +// can be used. You should never touch this map directly. Use methods like +// Register instead. +var registeredSolvers = make(map[string]func() Solver) // must initialize + +// Register takes a solver and its name and makes it available for use. It is +// commonly called in the init() method of the solver at program startup. There +// is no matching Unregister function. +func Register(name string, solver func() Solver) { + if _, exists := registeredSolvers[name]; exists { + panic(fmt.Sprintf("a solver named %s is already registered", name)) + } + + //gob.Register(solver()) + registeredSolvers[name] = solver +} + +// Lookup returns a pointer to the solver's struct. +func Lookup(name string) (Solver, error) { + solver, exists := registeredSolvers[name] + if !exists { + return nil, fmt.Errorf("not found") + } + return solver(), nil +} + +// LookupDefault attempts to return a "default" solver. +func LookupDefault() (Solver, error) { + if len(registeredSolvers) == 0 { + return nil, fmt.Errorf("no registered solvers") + } + if len(registeredSolvers) == 1 { + for _, solver := range registeredSolvers { + return solver(), nil // return the first and only one + } + } + + // TODO: Should we remove this empty string feature? + // If one was registered with no name, then use that as the default. + if solver, exists := registeredSolvers[""]; exists { // empty name + return solver(), nil + } + + return nil, fmt.Errorf("no registered default solver") +} + +// DebugSolverState helps us in understanding the state of the type unification +// solver in a more mainstream format. +// Example: +// +// solver state: +// +// * str("foo") :: str +// * call:f(str("foo")) [0xc000ac9f10] :: ?1 +// * var(x) [0xc00088d840] :: ?2 +// * param(x) [0xc00000f950] :: ?3 +// * func(x) { var(x) } [0xc0000e9680] :: ?4 +// * ?2 = ?3 +// * ?4 = func(arg0 str) ?1 +// * ?4 = func(x str) ?2 +// * ?1 = ?2 +func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string { + s := "" + + // all the relevant Exprs + count := 0 + exprs := make(map[interfaces.Expr]int) + for _, equality := range equalities { + for _, expr := range equality.ExprList() { + count++ + exprs[expr] = count // for sorting + } + } + + // print the solved Exprs first + for expr, typ := range solved { + s += fmt.Sprintf("%v :: %v\n", expr, typ) + delete(exprs, expr) + } + + sortedExprs := []interfaces.Expr{} + for k := range exprs { + sortedExprs = append(sortedExprs, k) + } + sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] }) + + // for each remaining expr, generate a shorter name than the full pointer + nextVar := 1 + shortNames := map[interfaces.Expr]string{} + for _, expr := range sortedExprs { + shortNames[expr] = fmt.Sprintf("?%d", nextVar) + nextVar++ + s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr]) + } + + // print all the equalities using the short names + for _, equality := range equalities { + switch e := equality.(type) { + case *interfaces.EqualsInvariant: + _, ok := solved[e.Expr] + if !ok { + s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type) + } else { + // if solved, then this is redundant, don't print anything + } + + case *interfaces.EqualityInvariant: + type1, ok1 := solved[e.Expr1] + type2, ok2 := solved[e.Expr2] + if !ok1 && !ok2 { + s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2]) + } else if ok1 && !ok2 { + s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2]) + } else if !ok1 && ok2 { + s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2) + } else { + // if completely solved, then this is redundant, don't print anything + } + + case *interfaces.EqualityWrapFuncInvariant: + funcType, funcOk := solved[e.Expr1] + + args := "" + argsOk := true + for i, argName := range e.Expr2Ord { + if i > 0 { + args += ", " + } + argExpr := e.Expr2Map[argName] + argType, ok := solved[argExpr] + if !ok { + args += fmt.Sprintf("%s %s", argName, shortNames[argExpr]) + argsOk = false + } else { + args += fmt.Sprintf("%s %s", argName, argType) + } + } + + outType, outOk := solved[e.Expr2Out] + + if !funcOk || !argsOk || !outOk { + if !funcOk && !outOk { + s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out]) + } else if !funcOk && outOk { + s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType) + } else if funcOk && !outOk { + s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out]) + } else { + s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType) + } + } + + case *interfaces.CallFuncArgsValueInvariant: + // skip, not used in the examples I care about + + case *interfaces.AnyInvariant: + // skip, not used in the examples I care about + + case *interfaces.SkipInvariant: + // we don't care about this one + + default: + s += fmt.Sprintf("%v\n", equality) + } + } + + return s +} diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver/simplesolver.go similarity index 90% rename from lang/unification/simplesolver.go rename to lang/unification/simplesolver/simplesolver.go index 6c3f0d0e..92c98a4c 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver/simplesolver.go @@ -27,25 +27,21 @@ // additional permission if he deems it necessary to achieve the goals of this // additional permission. -package unification // TODO: can we put this solver in a sub-package? +package simplesolver import ( "context" "fmt" - "sort" "github.com/purpleidea/mgmt/lang/interfaces" "github.com/purpleidea/mgmt/lang/types" + "github.com/purpleidea/mgmt/lang/unification" "github.com/purpleidea/mgmt/util/errwrap" ) const ( // Name is the prefix for our solver log messages. - Name = "solver: simple" - - // ErrAmbiguous means we couldn't find a solution, but we weren't - // inconsistent. - ErrAmbiguous = interfaces.Error("can't unify, no equalities were consumed, we're ambiguous") + Name = "simple" // AllowRecursion specifies whether we're allowed to use the recursive // solver or not. It uses an absurd amount of memory, and might hang @@ -61,154 +57,32 @@ const ( RecursionInvariantLimit = 5 // TODO: pick a better value ? ) -// SimpleInvariantSolverLogger is a wrapper which returns a -// SimpleInvariantSolver with the log parameter of your choice specified. The -// result satisfies the correct signature for the solver parameter of the -// Unification function. -// TODO: Get rid of this function and consider just using the struct directly. -func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) { - return func(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { - sis := &SimpleInvariantSolver{ - Debug: false, // TODO: consider plumbing this through - Logf: logf, - } - return sis.Solve(ctx, invariants, expected) - } -} - -// DebugSolverState helps us in understanding the state of the type unification -// solver in a more mainstream format. -// Example: -// -// solver state: -// -// * str("foo") :: str -// * call:f(str("foo")) [0xc000ac9f10] :: ?1 -// * var(x) [0xc00088d840] :: ?2 -// * param(x) [0xc00000f950] :: ?3 -// * func(x) { var(x) } [0xc0000e9680] :: ?4 -// * ?2 = ?3 -// * ?4 = func(arg0 str) ?1 -// * ?4 = func(x str) ?2 -// * ?1 = ?2 -func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string { - s := "" - - // all the relevant Exprs - count := 0 - exprs := make(map[interfaces.Expr]int) - for _, equality := range equalities { - for _, expr := range equality.ExprList() { - count++ - exprs[expr] = count // for sorting - } - } - - // print the solved Exprs first - for expr, typ := range solved { - s += fmt.Sprintf("%v :: %v\n", expr, typ) - delete(exprs, expr) - } - - sortedExprs := []interfaces.Expr{} - for k := range exprs { - sortedExprs = append(sortedExprs, k) - } - sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] }) - - // for each remaining expr, generate a shorter name than the full pointer - nextVar := 1 - shortNames := map[interfaces.Expr]string{} - for _, expr := range sortedExprs { - shortNames[expr] = fmt.Sprintf("?%d", nextVar) - nextVar++ - s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr]) - } - - // print all the equalities using the short names - for _, equality := range equalities { - switch e := equality.(type) { - case *interfaces.EqualsInvariant: - _, ok := solved[e.Expr] - if !ok { - s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type) - } else { - // if solved, then this is redundant, don't print anything - } - - case *interfaces.EqualityInvariant: - type1, ok1 := solved[e.Expr1] - type2, ok2 := solved[e.Expr2] - if !ok1 && !ok2 { - s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2]) - } else if ok1 && !ok2 { - s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2]) - } else if !ok1 && ok2 { - s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2) - } else { - // if completely solved, then this is redundant, don't print anything - } - - case *interfaces.EqualityWrapFuncInvariant: - funcType, funcOk := solved[e.Expr1] - - args := "" - argsOk := true - for i, argName := range e.Expr2Ord { - if i > 0 { - args += ", " - } - argExpr := e.Expr2Map[argName] - argType, ok := solved[argExpr] - if !ok { - args += fmt.Sprintf("%s %s", argName, shortNames[argExpr]) - argsOk = false - } else { - args += fmt.Sprintf("%s %s", argName, argType) - } - } - - outType, outOk := solved[e.Expr2Out] - - if !funcOk || !argsOk || !outOk { - if !funcOk && !outOk { - s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out]) - } else if !funcOk && outOk { - s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType) - } else if funcOk && !outOk { - s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out]) - } else { - s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType) - } - } - - case *interfaces.CallFuncArgsValueInvariant: - // skip, not used in the examples I care about - - case *interfaces.AnyInvariant: - // skip, not used in the examples I care about - - case *interfaces.SkipInvariant: - // we don't care about this one - - default: - s += fmt.Sprintf("%v\n", equality) - } - } - - return s +func init() { + unification.Register(Name, func() unification.Solver { return &SimpleInvariantSolver{} }) } // SimpleInvariantSolver is an iterative invariant solver for AST expressions. // It is intended to be very simple, even if it's computationally inefficient. // TODO: Move some of the global solver constants into this struct as params. type SimpleInvariantSolver struct { + // Strategy is a series of methodologies to heuristically improve the + // solver. + Strategy map[string]string + Debug bool Logf func(format string, v ...interface{}) } +// Init contains some handles that are used to initialize the solver. +func (obj *SimpleInvariantSolver) Init(init *unification.Init) error { + obj.Debug = init.Debug + obj.Logf = init.Logf + + return nil +} + // Solve is the actual solve implementation of the solver. -func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { +func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*unification.InvariantSolution, error) { process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) { equalities := []interfaces.Invariant{} exclusives := []*interfaces.ExclusiveInvariant{} @@ -351,7 +225,7 @@ func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interf // list all the expr's connected to expr, use pairs as chains listConnectedFn := func(expr interfaces.Expr, exprs []*interfaces.EqualityInvariant) []interfaces.Expr { - pairsType := pairs(exprs) + pairsType := unification.Pairs(exprs) return pairsType.DFS(expr) } @@ -1272,7 +1146,7 @@ Loop: obj.Logf("%s: unsolved: %+v", Name, x) } } - obj.Logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities)) + obj.Logf("%s: solver state:\n%s", Name, unification.DebugSolverState(solved, equalities)) // Lastly, we could loop through each exclusive and see // if it only has a single, easy solution. For example, @@ -1338,7 +1212,7 @@ Loop: } // let's try each combination, one at a time... - for i, ex := range exclusivesProduct(exclusives) { // [][]interfaces.Invariant + for i, ex := range unification.ExclusivesProduct(exclusives) { // [][]interfaces.Invariant select { case <-ctx.Done(): return nil, ctx.Err() @@ -1378,7 +1252,7 @@ Loop: for expr, typ := range solved { obj.Logf("%s: solved: (%p) => %+v", Name, expr, typ) } - return nil, ErrAmbiguous + return nil, unification.ErrAmbiguous } // delete used equalities, in reverse order to preserve indexing! for i := len(used) - 1; i >= 0; i-- { @@ -1403,7 +1277,7 @@ Loop: } solutions = append(solutions, invar) } - return &InvariantSolution{ + return &unification.InvariantSolution{ Solutions: solutions, }, nil } diff --git a/lang/unification/simplesolver_test.go b/lang/unification/solvers/simplesolver_test.go similarity index 94% rename from lang/unification/simplesolver_test.go rename to lang/unification/solvers/simplesolver_test.go index 84c2b47a..adbd585f 100644 --- a/lang/unification/simplesolver_test.go +++ b/lang/unification/solvers/simplesolver_test.go @@ -29,7 +29,7 @@ //go:build !root -package unification +package solvers import ( "context" @@ -40,6 +40,7 @@ import ( "github.com/purpleidea/mgmt/lang/ast" "github.com/purpleidea/mgmt/lang/interfaces" "github.com/purpleidea/mgmt/lang/types" + "github.com/purpleidea/mgmt/lang/unification" "github.com/purpleidea/mgmt/util" ) @@ -259,14 +260,27 @@ func TestSimpleSolver1(t *testing.T) { t.Run(fmt.Sprintf("test #%d (%s)", index, tc.name), func(t *testing.T) { invariants, expected, fail, expect, experr, experrstr := tc.invariants, tc.expected, tc.fail, tc.expect, tc.experr, tc.experrstr + debug := testing.Verbose() logf := func(format string, v ...interface{}) { t.Logf(fmt.Sprintf("test #%d", index)+": "+format, v...) } - debug := testing.Verbose() - solver := SimpleInvariantSolverLogger(logf) // generates a solver with built-in logging - - solution, err := solver(context.TODO(), invariants, expected) + solver, err := unification.LookupDefault() + if err != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: solver lookup failed with: %+v", index, err) + return + } + init := &unification.Init{ + Debug: debug, + Logf: logf, + } + if err := solver.Init(init); err != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: solver init failed with: %+v", index, err) + return + } + solution, err := solver.Solve(context.TODO(), invariants, expected) t.Logf("test #%d: solver completed with: %+v", index, err) if !fail && err != nil { diff --git a/lang/unification/solvers/solvers.go b/lang/unification/solvers/solvers.go new file mode 100644 index 00000000..c1c4d109 --- /dev/null +++ b/lang/unification/solvers/solvers.go @@ -0,0 +1,37 @@ +// Mgmt +// Copyright (C) 2013-2024+ James Shubin and the project contributors +// Written by James Shubin and the project contributors +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . +// +// Additional permission under GNU GPL version 3 section 7 +// +// If you modify this program, or any covered work, by linking or combining it +// with embedded mcl code and modules (and that the embedded mcl code and +// modules which link with this program, contain a copy of their source code in +// the authoritative form) containing parts covered by the terms of any other +// license, the licensors of this program grant you additional permission to +// convey the resulting work. Furthermore, the licensors of this program grant +// the original author, James Shubin, additional permission to update this +// additional permission if he deems it necessary to achieve the goals of this +// additional permission. + +// Package solvers is used to have a central place to import all solvers from. +// It is also a good locus to run all of the unification tests from. +package solvers + +import ( + // import so the solver registers + _ "github.com/purpleidea/mgmt/lang/unification/simplesolver" +) diff --git a/lang/unification_test.go b/lang/unification/solvers/unification_test.go similarity index 98% rename from lang/unification_test.go rename to lang/unification/solvers/unification_test.go index 429a84e2..ad281533 100644 --- a/lang/unification_test.go +++ b/lang/unification/solvers/unification_test.go @@ -29,7 +29,7 @@ //go:build !root -package lang // XXX: move this to the unification package +package solvers import ( "context" @@ -37,6 +37,7 @@ import ( "strings" "testing" + _ "github.com/purpleidea/mgmt/engine/resources" // import so the resources register "github.com/purpleidea/mgmt/lang/ast" "github.com/purpleidea/mgmt/lang/funcs" "github.com/purpleidea/mgmt/lang/funcs/vars" @@ -848,13 +849,21 @@ func TestUnification1(t *testing.T) { } // apply type unification + debug := testing.Verbose() logf := func(format string, v ...interface{}) { t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...) } + + solver, err := unification.LookupDefault() + if err != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: solver lookup failed with: %+v", index, err) + return + } unifier := &unification.Unifier{ AST: xast, - Solver: unification.SimpleInvariantSolverLogger(logf), - Debug: testing.Verbose(), + Solver: solver, + Debug: debug, Logf: logf, } err = unifier.Unify(context.TODO()) diff --git a/lang/unification/unification.go b/lang/unification/unification.go index 1b939d54..f3888cf5 100644 --- a/lang/unification/unification.go +++ b/lang/unification/unification.go @@ -46,8 +46,7 @@ type Unifier struct { AST interfaces.Stmt // Solver is the solver algorithm implementation to use. - // XXX: Solver should be a solver interface, not a function signature. - Solver func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) + Solver Solver Debug bool Logf func(format string, v ...interface{}) @@ -76,6 +75,14 @@ func (obj *Unifier) Unify(ctx context.Context) error { return fmt.Errorf("the Logf function is missing") } + init := &Init{ + Logf: obj.Logf, + Debug: obj.Debug, + } + if err := obj.Solver.Init(init); err != nil { + return err + } + if obj.Debug { obj.Logf("tree: %+v", obj.AST) } @@ -98,7 +105,7 @@ func (obj *Unifier) Unify(ctx context.Context) error { exprMap := ExprListToExprMap(exprs) // makes searching faster exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates) - solved, err := obj.Solver(ctx, invariants, exprList) + solved, err := obj.Solver.Solve(ctx, invariants, exprList) if err != nil { return err } @@ -194,14 +201,14 @@ func (obj *InvariantSolution) ExprList() []interfaces.Expr { return exprs } -// exclusivesProduct returns a list of different products produced from the +// ExclusivesProduct returns a list of different products produced from the // combinatorial product of the list of exclusives. Each ExclusiveInvariant must // contain between one and more Invariants. This takes every combination of // Invariants (choosing one from each ExclusiveInvariant) and returns that list. // In other words, if you have three exclusives, with invariants named (A1, B1), // (A2), and (A3, B3, C3) you'll get: (A1, A2, A3), (A1, A2, B3), (A1, A2, C3), // (B1, A2, A3), (B1, A2, B3), (B1, A2, C3) as results for this function call. -func exclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant { +func ExclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant { if len(exclusives) == 0 { return nil } diff --git a/lang/unification/util.go b/lang/unification/util.go index 39a04948..69e7a4bd 100644 --- a/lang/unification/util.go +++ b/lang/unification/util.go @@ -76,13 +76,13 @@ func ExprContains(needle interfaces.Expr, haystack []interfaces.Expr) bool { return false } -// pairs is a simple list of pairs of expressions which can be used as a simple +// Pairs is a simple list of pairs of expressions which can be used as a simple // undirected graph structure, or as a simple list of equalities. -type pairs []*interfaces.EqualityInvariant +type Pairs []*interfaces.EqualityInvariant // Vertices returns the list of vertices that the input expr is directly // connected to. -func (obj pairs) Vertices(expr interfaces.Expr) []interfaces.Expr { +func (obj Pairs) Vertices(expr interfaces.Expr) []interfaces.Expr { m := make(map[interfaces.Expr]struct{}) for _, x := range obj { if x.Expr1 == x.Expr2 { // skip circular @@ -106,7 +106,7 @@ func (obj pairs) Vertices(expr interfaces.Expr) []interfaces.Expr { } // DFS returns a depth first search for the graph, starting at the input vertex. -func (obj pairs) DFS(start interfaces.Expr) []interfaces.Expr { +func (obj Pairs) DFS(start interfaces.Expr) []interfaces.Expr { var d []interfaces.Expr // discovered var s []interfaces.Expr // stack found := false