lang: Structurally refactor type unification
This will make it easier to add new solvers and also cleans up some pending issues.
This commit is contained in:
@@ -269,13 +269,18 @@ func (obj *GAPI) Cli(info *gapi.Info) (*gapi.Deploy, error) {
|
||||
}
|
||||
}
|
||||
logf("running type unification...")
|
||||
startTime := time.Now()
|
||||
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(err, "could not get default solver")
|
||||
}
|
||||
unifier := &unification.Unifier{
|
||||
AST: iast,
|
||||
Solver: unification.SimpleInvariantSolverLogger(unificationLogf),
|
||||
Solver: solver,
|
||||
Debug: debug,
|
||||
Logf: unificationLogf,
|
||||
}
|
||||
startTime := time.Now()
|
||||
unifyErr := unifier.Unify(context.TODO())
|
||||
delta := time.Since(startTime)
|
||||
formatted := delta.String()
|
||||
|
||||
@@ -458,9 +458,15 @@ func TestAstFunc1(t *testing.T) {
|
||||
xlogf := func(format string, v ...interface{}) {
|
||||
logf("unification: "+format, v...)
|
||||
}
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
t.Errorf("test #%d: FAIL", index)
|
||||
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
|
||||
return
|
||||
}
|
||||
unifier := &unification.Unifier{
|
||||
AST: iast,
|
||||
Solver: unification.SimpleInvariantSolverLogger(xlogf),
|
||||
Solver: solver,
|
||||
Debug: testing.Verbose(),
|
||||
Logf: xlogf,
|
||||
}
|
||||
@@ -1028,9 +1034,15 @@ func TestAstFunc2(t *testing.T) {
|
||||
xlogf := func(format string, v ...interface{}) {
|
||||
logf("unification: "+format, v...)
|
||||
}
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
t.Errorf("test #%d: FAIL", index)
|
||||
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
|
||||
return
|
||||
}
|
||||
unifier := &unification.Unifier{
|
||||
AST: iast,
|
||||
Solver: unification.SimpleInvariantSolverLogger(xlogf),
|
||||
Solver: solver,
|
||||
Debug: testing.Verbose(),
|
||||
Logf: xlogf,
|
||||
}
|
||||
@@ -1830,9 +1842,15 @@ func TestAstFunc3(t *testing.T) {
|
||||
xlogf := func(format string, v ...interface{}) {
|
||||
logf("unification: "+format, v...)
|
||||
}
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
t.Errorf("test #%d: FAIL", index)
|
||||
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
|
||||
return
|
||||
}
|
||||
unifier := &unification.Unifier{
|
||||
AST: iast,
|
||||
Solver: unification.SimpleInvariantSolverLogger(xlogf),
|
||||
Solver: solver,
|
||||
Debug: testing.Verbose(),
|
||||
Logf: xlogf,
|
||||
}
|
||||
|
||||
10
lang/lang.go
10
lang/lang.go
@@ -50,6 +50,7 @@ import (
|
||||
"github.com/purpleidea/mgmt/lang/interpret"
|
||||
"github.com/purpleidea/mgmt/lang/parser"
|
||||
"github.com/purpleidea/mgmt/lang/unification"
|
||||
_ "github.com/purpleidea/mgmt/lang/unification/solvers" // import so the solvers register
|
||||
"github.com/purpleidea/mgmt/pgraph"
|
||||
"github.com/purpleidea/mgmt/util"
|
||||
"github.com/purpleidea/mgmt/util/errwrap"
|
||||
@@ -225,13 +226,18 @@ func (obj *Lang) Init(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
obj.Logf("running type unification...")
|
||||
timing = time.Now()
|
||||
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(err, "could not get default solver")
|
||||
}
|
||||
unifier := &unification.Unifier{
|
||||
AST: obj.ast,
|
||||
Solver: unification.SimpleInvariantSolverLogger(logf),
|
||||
Solver: solver,
|
||||
Debug: obj.Debug,
|
||||
Logf: logf,
|
||||
}
|
||||
timing = time.Now()
|
||||
// NOTE: This is the "real" Unify that runs. (This is not for deploy.)
|
||||
unifyErr := unifier.Unify(ctx)
|
||||
obj.Logf("type unification took: %s", time.Since(timing))
|
||||
|
||||
231
lang/unification/interfaces.go
Normal file
231
lang/unification/interfaces.go
Normal file
@@ -0,0 +1,231 @@
|
||||
// Mgmt
|
||||
// Copyright (C) 2013-2024+ James Shubin and the project contributors
|
||||
// Written by James Shubin <james@shubin.ca> and the project contributors
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
//
|
||||
// Additional permission under GNU GPL version 3 section 7
|
||||
//
|
||||
// If you modify this program, or any covered work, by linking or combining it
|
||||
// with embedded mcl code and modules (and that the embedded mcl code and
|
||||
// modules which link with this program, contain a copy of their source code in
|
||||
// the authoritative form) containing parts covered by the terms of any other
|
||||
// license, the licensors of this program grant you additional permission to
|
||||
// convey the resulting work. Furthermore, the licensors of this program grant
|
||||
// the original author, James Shubin, additional permission to update this
|
||||
// additional permission if he deems it necessary to achieve the goals of this
|
||||
// additional permission.
|
||||
|
||||
package unification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/purpleidea/mgmt/lang/interfaces"
|
||||
"github.com/purpleidea/mgmt/lang/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// ErrAmbiguous means we couldn't find a solution, but we weren't
|
||||
// inconsistent.
|
||||
ErrAmbiguous = interfaces.Error("can't unify, no equalities were consumed, we're ambiguous")
|
||||
)
|
||||
|
||||
// Init contains some handles that are used to initialize every solver. Each
|
||||
// individual solver can choose to omit using some of the fields.
|
||||
type Init struct {
|
||||
Debug bool
|
||||
Logf func(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// Solver is the general interface that any solver needs to implement.
|
||||
type Solver interface {
|
||||
// Init initializes the solver struct before first use.
|
||||
Init(*Init) error
|
||||
|
||||
// Solve performs the actual solving. It must return as soon as possible
|
||||
// if the context is closed.
|
||||
Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error)
|
||||
}
|
||||
|
||||
// registeredSolvers is a global map of all possible unification solvers which
|
||||
// can be used. You should never touch this map directly. Use methods like
|
||||
// Register instead.
|
||||
var registeredSolvers = make(map[string]func() Solver) // must initialize
|
||||
|
||||
// Register takes a solver and its name and makes it available for use. It is
|
||||
// commonly called in the init() method of the solver at program startup. There
|
||||
// is no matching Unregister function.
|
||||
func Register(name string, solver func() Solver) {
|
||||
if _, exists := registeredSolvers[name]; exists {
|
||||
panic(fmt.Sprintf("a solver named %s is already registered", name))
|
||||
}
|
||||
|
||||
//gob.Register(solver())
|
||||
registeredSolvers[name] = solver
|
||||
}
|
||||
|
||||
// Lookup returns a pointer to the solver's struct.
|
||||
func Lookup(name string) (Solver, error) {
|
||||
solver, exists := registeredSolvers[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
return solver(), nil
|
||||
}
|
||||
|
||||
// LookupDefault attempts to return a "default" solver.
|
||||
func LookupDefault() (Solver, error) {
|
||||
if len(registeredSolvers) == 0 {
|
||||
return nil, fmt.Errorf("no registered solvers")
|
||||
}
|
||||
if len(registeredSolvers) == 1 {
|
||||
for _, solver := range registeredSolvers {
|
||||
return solver(), nil // return the first and only one
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Should we remove this empty string feature?
|
||||
// If one was registered with no name, then use that as the default.
|
||||
if solver, exists := registeredSolvers[""]; exists { // empty name
|
||||
return solver(), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no registered default solver")
|
||||
}
|
||||
|
||||
// DebugSolverState helps us in understanding the state of the type unification
|
||||
// solver in a more mainstream format.
|
||||
// Example:
|
||||
//
|
||||
// solver state:
|
||||
//
|
||||
// * str("foo") :: str
|
||||
// * call:f(str("foo")) [0xc000ac9f10] :: ?1
|
||||
// * var(x) [0xc00088d840] :: ?2
|
||||
// * param(x) [0xc00000f950] :: ?3
|
||||
// * func(x) { var(x) } [0xc0000e9680] :: ?4
|
||||
// * ?2 = ?3
|
||||
// * ?4 = func(arg0 str) ?1
|
||||
// * ?4 = func(x str) ?2
|
||||
// * ?1 = ?2
|
||||
func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string {
|
||||
s := ""
|
||||
|
||||
// all the relevant Exprs
|
||||
count := 0
|
||||
exprs := make(map[interfaces.Expr]int)
|
||||
for _, equality := range equalities {
|
||||
for _, expr := range equality.ExprList() {
|
||||
count++
|
||||
exprs[expr] = count // for sorting
|
||||
}
|
||||
}
|
||||
|
||||
// print the solved Exprs first
|
||||
for expr, typ := range solved {
|
||||
s += fmt.Sprintf("%v :: %v\n", expr, typ)
|
||||
delete(exprs, expr)
|
||||
}
|
||||
|
||||
sortedExprs := []interfaces.Expr{}
|
||||
for k := range exprs {
|
||||
sortedExprs = append(sortedExprs, k)
|
||||
}
|
||||
sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] })
|
||||
|
||||
// for each remaining expr, generate a shorter name than the full pointer
|
||||
nextVar := 1
|
||||
shortNames := map[interfaces.Expr]string{}
|
||||
for _, expr := range sortedExprs {
|
||||
shortNames[expr] = fmt.Sprintf("?%d", nextVar)
|
||||
nextVar++
|
||||
s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr])
|
||||
}
|
||||
|
||||
// print all the equalities using the short names
|
||||
for _, equality := range equalities {
|
||||
switch e := equality.(type) {
|
||||
case *interfaces.EqualsInvariant:
|
||||
_, ok := solved[e.Expr]
|
||||
if !ok {
|
||||
s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type)
|
||||
} else {
|
||||
// if solved, then this is redundant, don't print anything
|
||||
}
|
||||
|
||||
case *interfaces.EqualityInvariant:
|
||||
type1, ok1 := solved[e.Expr1]
|
||||
type2, ok2 := solved[e.Expr2]
|
||||
if !ok1 && !ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2])
|
||||
} else if ok1 && !ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2])
|
||||
} else if !ok1 && ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2)
|
||||
} else {
|
||||
// if completely solved, then this is redundant, don't print anything
|
||||
}
|
||||
|
||||
case *interfaces.EqualityWrapFuncInvariant:
|
||||
funcType, funcOk := solved[e.Expr1]
|
||||
|
||||
args := ""
|
||||
argsOk := true
|
||||
for i, argName := range e.Expr2Ord {
|
||||
if i > 0 {
|
||||
args += ", "
|
||||
}
|
||||
argExpr := e.Expr2Map[argName]
|
||||
argType, ok := solved[argExpr]
|
||||
if !ok {
|
||||
args += fmt.Sprintf("%s %s", argName, shortNames[argExpr])
|
||||
argsOk = false
|
||||
} else {
|
||||
args += fmt.Sprintf("%s %s", argName, argType)
|
||||
}
|
||||
}
|
||||
|
||||
outType, outOk := solved[e.Expr2Out]
|
||||
|
||||
if !funcOk || !argsOk || !outOk {
|
||||
if !funcOk && !outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out])
|
||||
} else if !funcOk && outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType)
|
||||
} else if funcOk && !outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out])
|
||||
} else {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType)
|
||||
}
|
||||
}
|
||||
|
||||
case *interfaces.CallFuncArgsValueInvariant:
|
||||
// skip, not used in the examples I care about
|
||||
|
||||
case *interfaces.AnyInvariant:
|
||||
// skip, not used in the examples I care about
|
||||
|
||||
case *interfaces.SkipInvariant:
|
||||
// we don't care about this one
|
||||
|
||||
default:
|
||||
s += fmt.Sprintf("%v\n", equality)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
@@ -27,25 +27,21 @@
|
||||
// additional permission if he deems it necessary to achieve the goals of this
|
||||
// additional permission.
|
||||
|
||||
package unification // TODO: can we put this solver in a sub-package?
|
||||
package simplesolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/purpleidea/mgmt/lang/interfaces"
|
||||
"github.com/purpleidea/mgmt/lang/types"
|
||||
"github.com/purpleidea/mgmt/lang/unification"
|
||||
"github.com/purpleidea/mgmt/util/errwrap"
|
||||
)
|
||||
|
||||
const (
|
||||
// Name is the prefix for our solver log messages.
|
||||
Name = "solver: simple"
|
||||
|
||||
// ErrAmbiguous means we couldn't find a solution, but we weren't
|
||||
// inconsistent.
|
||||
ErrAmbiguous = interfaces.Error("can't unify, no equalities were consumed, we're ambiguous")
|
||||
Name = "simple"
|
||||
|
||||
// AllowRecursion specifies whether we're allowed to use the recursive
|
||||
// solver or not. It uses an absurd amount of memory, and might hang
|
||||
@@ -61,154 +57,32 @@ const (
|
||||
RecursionInvariantLimit = 5 // TODO: pick a better value ?
|
||||
)
|
||||
|
||||
// SimpleInvariantSolverLogger is a wrapper which returns a
|
||||
// SimpleInvariantSolver with the log parameter of your choice specified. The
|
||||
// result satisfies the correct signature for the solver parameter of the
|
||||
// Unification function.
|
||||
// TODO: Get rid of this function and consider just using the struct directly.
|
||||
func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) {
|
||||
return func(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
|
||||
sis := &SimpleInvariantSolver{
|
||||
Debug: false, // TODO: consider plumbing this through
|
||||
Logf: logf,
|
||||
}
|
||||
return sis.Solve(ctx, invariants, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// DebugSolverState helps us in understanding the state of the type unification
|
||||
// solver in a more mainstream format.
|
||||
// Example:
|
||||
//
|
||||
// solver state:
|
||||
//
|
||||
// * str("foo") :: str
|
||||
// * call:f(str("foo")) [0xc000ac9f10] :: ?1
|
||||
// * var(x) [0xc00088d840] :: ?2
|
||||
// * param(x) [0xc00000f950] :: ?3
|
||||
// * func(x) { var(x) } [0xc0000e9680] :: ?4
|
||||
// * ?2 = ?3
|
||||
// * ?4 = func(arg0 str) ?1
|
||||
// * ?4 = func(x str) ?2
|
||||
// * ?1 = ?2
|
||||
func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string {
|
||||
s := ""
|
||||
|
||||
// all the relevant Exprs
|
||||
count := 0
|
||||
exprs := make(map[interfaces.Expr]int)
|
||||
for _, equality := range equalities {
|
||||
for _, expr := range equality.ExprList() {
|
||||
count++
|
||||
exprs[expr] = count // for sorting
|
||||
}
|
||||
}
|
||||
|
||||
// print the solved Exprs first
|
||||
for expr, typ := range solved {
|
||||
s += fmt.Sprintf("%v :: %v\n", expr, typ)
|
||||
delete(exprs, expr)
|
||||
}
|
||||
|
||||
sortedExprs := []interfaces.Expr{}
|
||||
for k := range exprs {
|
||||
sortedExprs = append(sortedExprs, k)
|
||||
}
|
||||
sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] })
|
||||
|
||||
// for each remaining expr, generate a shorter name than the full pointer
|
||||
nextVar := 1
|
||||
shortNames := map[interfaces.Expr]string{}
|
||||
for _, expr := range sortedExprs {
|
||||
shortNames[expr] = fmt.Sprintf("?%d", nextVar)
|
||||
nextVar++
|
||||
s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr])
|
||||
}
|
||||
|
||||
// print all the equalities using the short names
|
||||
for _, equality := range equalities {
|
||||
switch e := equality.(type) {
|
||||
case *interfaces.EqualsInvariant:
|
||||
_, ok := solved[e.Expr]
|
||||
if !ok {
|
||||
s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type)
|
||||
} else {
|
||||
// if solved, then this is redundant, don't print anything
|
||||
}
|
||||
|
||||
case *interfaces.EqualityInvariant:
|
||||
type1, ok1 := solved[e.Expr1]
|
||||
type2, ok2 := solved[e.Expr2]
|
||||
if !ok1 && !ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2])
|
||||
} else if ok1 && !ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2])
|
||||
} else if !ok1 && ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2)
|
||||
} else {
|
||||
// if completely solved, then this is redundant, don't print anything
|
||||
}
|
||||
|
||||
case *interfaces.EqualityWrapFuncInvariant:
|
||||
funcType, funcOk := solved[e.Expr1]
|
||||
|
||||
args := ""
|
||||
argsOk := true
|
||||
for i, argName := range e.Expr2Ord {
|
||||
if i > 0 {
|
||||
args += ", "
|
||||
}
|
||||
argExpr := e.Expr2Map[argName]
|
||||
argType, ok := solved[argExpr]
|
||||
if !ok {
|
||||
args += fmt.Sprintf("%s %s", argName, shortNames[argExpr])
|
||||
argsOk = false
|
||||
} else {
|
||||
args += fmt.Sprintf("%s %s", argName, argType)
|
||||
}
|
||||
}
|
||||
|
||||
outType, outOk := solved[e.Expr2Out]
|
||||
|
||||
if !funcOk || !argsOk || !outOk {
|
||||
if !funcOk && !outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out])
|
||||
} else if !funcOk && outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType)
|
||||
} else if funcOk && !outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out])
|
||||
} else {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType)
|
||||
}
|
||||
}
|
||||
|
||||
case *interfaces.CallFuncArgsValueInvariant:
|
||||
// skip, not used in the examples I care about
|
||||
|
||||
case *interfaces.AnyInvariant:
|
||||
// skip, not used in the examples I care about
|
||||
|
||||
case *interfaces.SkipInvariant:
|
||||
// we don't care about this one
|
||||
|
||||
default:
|
||||
s += fmt.Sprintf("%v\n", equality)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
func init() {
|
||||
unification.Register(Name, func() unification.Solver { return &SimpleInvariantSolver{} })
|
||||
}
|
||||
|
||||
// SimpleInvariantSolver is an iterative invariant solver for AST expressions.
|
||||
// It is intended to be very simple, even if it's computationally inefficient.
|
||||
// TODO: Move some of the global solver constants into this struct as params.
|
||||
type SimpleInvariantSolver struct {
|
||||
// Strategy is a series of methodologies to heuristically improve the
|
||||
// solver.
|
||||
Strategy map[string]string
|
||||
|
||||
Debug bool
|
||||
Logf func(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// Init contains some handles that are used to initialize the solver.
|
||||
func (obj *SimpleInvariantSolver) Init(init *unification.Init) error {
|
||||
obj.Debug = init.Debug
|
||||
obj.Logf = init.Logf
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Solve is the actual solve implementation of the solver.
|
||||
func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
|
||||
func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*unification.InvariantSolution, error) {
|
||||
process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) {
|
||||
equalities := []interfaces.Invariant{}
|
||||
exclusives := []*interfaces.ExclusiveInvariant{}
|
||||
@@ -351,7 +225,7 @@ func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interf
|
||||
|
||||
// list all the expr's connected to expr, use pairs as chains
|
||||
listConnectedFn := func(expr interfaces.Expr, exprs []*interfaces.EqualityInvariant) []interfaces.Expr {
|
||||
pairsType := pairs(exprs)
|
||||
pairsType := unification.Pairs(exprs)
|
||||
return pairsType.DFS(expr)
|
||||
}
|
||||
|
||||
@@ -1272,7 +1146,7 @@ Loop:
|
||||
obj.Logf("%s: unsolved: %+v", Name, x)
|
||||
}
|
||||
}
|
||||
obj.Logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities))
|
||||
obj.Logf("%s: solver state:\n%s", Name, unification.DebugSolverState(solved, equalities))
|
||||
|
||||
// Lastly, we could loop through each exclusive and see
|
||||
// if it only has a single, easy solution. For example,
|
||||
@@ -1338,7 +1212,7 @@ Loop:
|
||||
}
|
||||
|
||||
// let's try each combination, one at a time...
|
||||
for i, ex := range exclusivesProduct(exclusives) { // [][]interfaces.Invariant
|
||||
for i, ex := range unification.ExclusivesProduct(exclusives) { // [][]interfaces.Invariant
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -1378,7 +1252,7 @@ Loop:
|
||||
for expr, typ := range solved {
|
||||
obj.Logf("%s: solved: (%p) => %+v", Name, expr, typ)
|
||||
}
|
||||
return nil, ErrAmbiguous
|
||||
return nil, unification.ErrAmbiguous
|
||||
}
|
||||
// delete used equalities, in reverse order to preserve indexing!
|
||||
for i := len(used) - 1; i >= 0; i-- {
|
||||
@@ -1403,7 +1277,7 @@ Loop:
|
||||
}
|
||||
solutions = append(solutions, invar)
|
||||
}
|
||||
return &InvariantSolution{
|
||||
return &unification.InvariantSolution{
|
||||
Solutions: solutions,
|
||||
}, nil
|
||||
}
|
||||
@@ -29,7 +29,7 @@
|
||||
|
||||
//go:build !root
|
||||
|
||||
package unification
|
||||
package solvers
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/purpleidea/mgmt/lang/ast"
|
||||
"github.com/purpleidea/mgmt/lang/interfaces"
|
||||
"github.com/purpleidea/mgmt/lang/types"
|
||||
"github.com/purpleidea/mgmt/lang/unification"
|
||||
"github.com/purpleidea/mgmt/util"
|
||||
)
|
||||
|
||||
@@ -259,14 +260,27 @@ func TestSimpleSolver1(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("test #%d (%s)", index, tc.name), func(t *testing.T) {
|
||||
invariants, expected, fail, expect, experr, experrstr := tc.invariants, tc.expected, tc.fail, tc.expect, tc.experr, tc.experrstr
|
||||
|
||||
debug := testing.Verbose()
|
||||
logf := func(format string, v ...interface{}) {
|
||||
t.Logf(fmt.Sprintf("test #%d", index)+": "+format, v...)
|
||||
}
|
||||
debug := testing.Verbose()
|
||||
|
||||
solver := SimpleInvariantSolverLogger(logf) // generates a solver with built-in logging
|
||||
|
||||
solution, err := solver(context.TODO(), invariants, expected)
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
t.Errorf("test #%d: FAIL", index)
|
||||
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
|
||||
return
|
||||
}
|
||||
init := &unification.Init{
|
||||
Debug: debug,
|
||||
Logf: logf,
|
||||
}
|
||||
if err := solver.Init(init); err != nil {
|
||||
t.Errorf("test #%d: FAIL", index)
|
||||
t.Errorf("test #%d: solver init failed with: %+v", index, err)
|
||||
return
|
||||
}
|
||||
solution, err := solver.Solve(context.TODO(), invariants, expected)
|
||||
t.Logf("test #%d: solver completed with: %+v", index, err)
|
||||
|
||||
if !fail && err != nil {
|
||||
37
lang/unification/solvers/solvers.go
Normal file
37
lang/unification/solvers/solvers.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Mgmt
|
||||
// Copyright (C) 2013-2024+ James Shubin and the project contributors
|
||||
// Written by James Shubin <james@shubin.ca> and the project contributors
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
//
|
||||
// Additional permission under GNU GPL version 3 section 7
|
||||
//
|
||||
// If you modify this program, or any covered work, by linking or combining it
|
||||
// with embedded mcl code and modules (and that the embedded mcl code and
|
||||
// modules which link with this program, contain a copy of their source code in
|
||||
// the authoritative form) containing parts covered by the terms of any other
|
||||
// license, the licensors of this program grant you additional permission to
|
||||
// convey the resulting work. Furthermore, the licensors of this program grant
|
||||
// the original author, James Shubin, additional permission to update this
|
||||
// additional permission if he deems it necessary to achieve the goals of this
|
||||
// additional permission.
|
||||
|
||||
// Package solvers is used to have a central place to import all solvers from.
|
||||
// It is also a good locus to run all of the unification tests from.
|
||||
package solvers
|
||||
|
||||
import (
|
||||
// import so the solver registers
|
||||
_ "github.com/purpleidea/mgmt/lang/unification/simplesolver"
|
||||
)
|
||||
@@ -29,7 +29,7 @@
|
||||
|
||||
//go:build !root
|
||||
|
||||
package lang // XXX: move this to the unification package
|
||||
package solvers
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/purpleidea/mgmt/engine/resources" // import so the resources register
|
||||
"github.com/purpleidea/mgmt/lang/ast"
|
||||
"github.com/purpleidea/mgmt/lang/funcs"
|
||||
"github.com/purpleidea/mgmt/lang/funcs/vars"
|
||||
@@ -848,13 +849,21 @@ func TestUnification1(t *testing.T) {
|
||||
}
|
||||
|
||||
// apply type unification
|
||||
debug := testing.Verbose()
|
||||
logf := func(format string, v ...interface{}) {
|
||||
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
|
||||
}
|
||||
|
||||
solver, err := unification.LookupDefault()
|
||||
if err != nil {
|
||||
t.Errorf("test #%d: FAIL", index)
|
||||
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
|
||||
return
|
||||
}
|
||||
unifier := &unification.Unifier{
|
||||
AST: xast,
|
||||
Solver: unification.SimpleInvariantSolverLogger(logf),
|
||||
Debug: testing.Verbose(),
|
||||
Solver: solver,
|
||||
Debug: debug,
|
||||
Logf: logf,
|
||||
}
|
||||
err = unifier.Unify(context.TODO())
|
||||
@@ -46,8 +46,7 @@ type Unifier struct {
|
||||
AST interfaces.Stmt
|
||||
|
||||
// Solver is the solver algorithm implementation to use.
|
||||
// XXX: Solver should be a solver interface, not a function signature.
|
||||
Solver func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error)
|
||||
Solver Solver
|
||||
|
||||
Debug bool
|
||||
Logf func(format string, v ...interface{})
|
||||
@@ -76,6 +75,14 @@ func (obj *Unifier) Unify(ctx context.Context) error {
|
||||
return fmt.Errorf("the Logf function is missing")
|
||||
}
|
||||
|
||||
init := &Init{
|
||||
Logf: obj.Logf,
|
||||
Debug: obj.Debug,
|
||||
}
|
||||
if err := obj.Solver.Init(init); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if obj.Debug {
|
||||
obj.Logf("tree: %+v", obj.AST)
|
||||
}
|
||||
@@ -98,7 +105,7 @@ func (obj *Unifier) Unify(ctx context.Context) error {
|
||||
exprMap := ExprListToExprMap(exprs) // makes searching faster
|
||||
exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates)
|
||||
|
||||
solved, err := obj.Solver(ctx, invariants, exprList)
|
||||
solved, err := obj.Solver.Solve(ctx, invariants, exprList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -194,14 +201,14 @@ func (obj *InvariantSolution) ExprList() []interfaces.Expr {
|
||||
return exprs
|
||||
}
|
||||
|
||||
// exclusivesProduct returns a list of different products produced from the
|
||||
// ExclusivesProduct returns a list of different products produced from the
|
||||
// combinatorial product of the list of exclusives. Each ExclusiveInvariant must
|
||||
// contain between one and more Invariants. This takes every combination of
|
||||
// Invariants (choosing one from each ExclusiveInvariant) and returns that list.
|
||||
// In other words, if you have three exclusives, with invariants named (A1, B1),
|
||||
// (A2), and (A3, B3, C3) you'll get: (A1, A2, A3), (A1, A2, B3), (A1, A2, C3),
|
||||
// (B1, A2, A3), (B1, A2, B3), (B1, A2, C3) as results for this function call.
|
||||
func exclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant {
|
||||
func ExclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant {
|
||||
if len(exclusives) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -76,13 +76,13 @@ func ExprContains(needle interfaces.Expr, haystack []interfaces.Expr) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// pairs is a simple list of pairs of expressions which can be used as a simple
|
||||
// Pairs is a simple list of pairs of expressions which can be used as a simple
|
||||
// undirected graph structure, or as a simple list of equalities.
|
||||
type pairs []*interfaces.EqualityInvariant
|
||||
type Pairs []*interfaces.EqualityInvariant
|
||||
|
||||
// Vertices returns the list of vertices that the input expr is directly
|
||||
// connected to.
|
||||
func (obj pairs) Vertices(expr interfaces.Expr) []interfaces.Expr {
|
||||
func (obj Pairs) Vertices(expr interfaces.Expr) []interfaces.Expr {
|
||||
m := make(map[interfaces.Expr]struct{})
|
||||
for _, x := range obj {
|
||||
if x.Expr1 == x.Expr2 { // skip circular
|
||||
@@ -106,7 +106,7 @@ func (obj pairs) Vertices(expr interfaces.Expr) []interfaces.Expr {
|
||||
}
|
||||
|
||||
// DFS returns a depth first search for the graph, starting at the input vertex.
|
||||
func (obj pairs) DFS(start interfaces.Expr) []interfaces.Expr {
|
||||
func (obj Pairs) DFS(start interfaces.Expr) []interfaces.Expr {
|
||||
var d []interfaces.Expr // discovered
|
||||
var s []interfaces.Expr // stack
|
||||
found := false
|
||||
|
||||
Reference in New Issue
Block a user