lang: Structurally refactor type unification

This will make it easier to add new solvers and also cleans up some
pending issues.
This commit is contained in:
James Shubin
2024-03-30 16:55:20 -04:00
parent 964bd8ba61
commit cede7e5ac0
10 changed files with 374 additions and 173 deletions

View File

@@ -0,0 +1,231 @@
// Mgmt
// Copyright (C) 2013-2024+ James Shubin and the project contributors
// Written by James Shubin <james@shubin.ca> and the project contributors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//
// Additional permission under GNU GPL version 3 section 7
//
// If you modify this program, or any covered work, by linking or combining it
// with embedded mcl code and modules (and that the embedded mcl code and
// modules which link with this program, contain a copy of their source code in
// the authoritative form) containing parts covered by the terms of any other
// license, the licensors of this program grant you additional permission to
// convey the resulting work. Furthermore, the licensors of this program grant
// the original author, James Shubin, additional permission to update this
// additional permission if he deems it necessary to achieve the goals of this
// additional permission.
package unification
import (
"context"
"fmt"
"sort"
"github.com/purpleidea/mgmt/lang/interfaces"
"github.com/purpleidea/mgmt/lang/types"
)
const (
// ErrAmbiguous means we couldn't find a solution, but we weren't
// inconsistent.
ErrAmbiguous = interfaces.Error("can't unify, no equalities were consumed, we're ambiguous")
)
// Init contains some handles that are used to initialize every solver. Each
// individual solver can choose to omit using some of the fields.
type Init struct {
Debug bool
Logf func(format string, v ...interface{})
}
// Solver is the general interface that any solver needs to implement.
type Solver interface {
// Init initializes the solver struct before first use.
Init(*Init) error
// Solve performs the actual solving. It must return as soon as possible
// if the context is closed.
Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error)
}
// registeredSolvers is a global map of all possible unification solvers which
// can be used. You should never touch this map directly. Use methods like
// Register instead.
var registeredSolvers = make(map[string]func() Solver) // must initialize
// Register takes a solver and its name and makes it available for use. It is
// commonly called in the init() method of the solver at program startup. There
// is no matching Unregister function.
func Register(name string, solver func() Solver) {
if _, exists := registeredSolvers[name]; exists {
panic(fmt.Sprintf("a solver named %s is already registered", name))
}
//gob.Register(solver())
registeredSolvers[name] = solver
}
// Lookup returns a pointer to the solver's struct.
func Lookup(name string) (Solver, error) {
solver, exists := registeredSolvers[name]
if !exists {
return nil, fmt.Errorf("not found")
}
return solver(), nil
}
// LookupDefault attempts to return a "default" solver.
func LookupDefault() (Solver, error) {
if len(registeredSolvers) == 0 {
return nil, fmt.Errorf("no registered solvers")
}
if len(registeredSolvers) == 1 {
for _, solver := range registeredSolvers {
return solver(), nil // return the first and only one
}
}
// TODO: Should we remove this empty string feature?
// If one was registered with no name, then use that as the default.
if solver, exists := registeredSolvers[""]; exists { // empty name
return solver(), nil
}
return nil, fmt.Errorf("no registered default solver")
}
// DebugSolverState helps us in understanding the state of the type unification
// solver in a more mainstream format.
// Example:
//
// solver state:
//
// * str("foo") :: str
// * call:f(str("foo")) [0xc000ac9f10] :: ?1
// * var(x) [0xc00088d840] :: ?2
// * param(x) [0xc00000f950] :: ?3
// * func(x) { var(x) } [0xc0000e9680] :: ?4
// * ?2 = ?3
// * ?4 = func(arg0 str) ?1
// * ?4 = func(x str) ?2
// * ?1 = ?2
func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string {
s := ""
// all the relevant Exprs
count := 0
exprs := make(map[interfaces.Expr]int)
for _, equality := range equalities {
for _, expr := range equality.ExprList() {
count++
exprs[expr] = count // for sorting
}
}
// print the solved Exprs first
for expr, typ := range solved {
s += fmt.Sprintf("%v :: %v\n", expr, typ)
delete(exprs, expr)
}
sortedExprs := []interfaces.Expr{}
for k := range exprs {
sortedExprs = append(sortedExprs, k)
}
sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] })
// for each remaining expr, generate a shorter name than the full pointer
nextVar := 1
shortNames := map[interfaces.Expr]string{}
for _, expr := range sortedExprs {
shortNames[expr] = fmt.Sprintf("?%d", nextVar)
nextVar++
s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr])
}
// print all the equalities using the short names
for _, equality := range equalities {
switch e := equality.(type) {
case *interfaces.EqualsInvariant:
_, ok := solved[e.Expr]
if !ok {
s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type)
} else {
// if solved, then this is redundant, don't print anything
}
case *interfaces.EqualityInvariant:
type1, ok1 := solved[e.Expr1]
type2, ok2 := solved[e.Expr2]
if !ok1 && !ok2 {
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2])
} else if ok1 && !ok2 {
s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2])
} else if !ok1 && ok2 {
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2)
} else {
// if completely solved, then this is redundant, don't print anything
}
case *interfaces.EqualityWrapFuncInvariant:
funcType, funcOk := solved[e.Expr1]
args := ""
argsOk := true
for i, argName := range e.Expr2Ord {
if i > 0 {
args += ", "
}
argExpr := e.Expr2Map[argName]
argType, ok := solved[argExpr]
if !ok {
args += fmt.Sprintf("%s %s", argName, shortNames[argExpr])
argsOk = false
} else {
args += fmt.Sprintf("%s %s", argName, argType)
}
}
outType, outOk := solved[e.Expr2Out]
if !funcOk || !argsOk || !outOk {
if !funcOk && !outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out])
} else if !funcOk && outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType)
} else if funcOk && !outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out])
} else {
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType)
}
}
case *interfaces.CallFuncArgsValueInvariant:
// skip, not used in the examples I care about
case *interfaces.AnyInvariant:
// skip, not used in the examples I care about
case *interfaces.SkipInvariant:
// we don't care about this one
default:
s += fmt.Sprintf("%v\n", equality)
}
}
return s
}

View File

@@ -27,25 +27,21 @@
// additional permission if he deems it necessary to achieve the goals of this
// additional permission.
package unification // TODO: can we put this solver in a sub-package?
package simplesolver
import (
"context"
"fmt"
"sort"
"github.com/purpleidea/mgmt/lang/interfaces"
"github.com/purpleidea/mgmt/lang/types"
"github.com/purpleidea/mgmt/lang/unification"
"github.com/purpleidea/mgmt/util/errwrap"
)
const (
// Name is the prefix for our solver log messages.
Name = "solver: simple"
// ErrAmbiguous means we couldn't find a solution, but we weren't
// inconsistent.
ErrAmbiguous = interfaces.Error("can't unify, no equalities were consumed, we're ambiguous")
Name = "simple"
// AllowRecursion specifies whether we're allowed to use the recursive
// solver or not. It uses an absurd amount of memory, and might hang
@@ -61,154 +57,32 @@ const (
RecursionInvariantLimit = 5 // TODO: pick a better value ?
)
// SimpleInvariantSolverLogger is a wrapper which returns a
// SimpleInvariantSolver with the log parameter of your choice specified. The
// result satisfies the correct signature for the solver parameter of the
// Unification function.
// TODO: Get rid of this function and consider just using the struct directly.
func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) {
return func(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
sis := &SimpleInvariantSolver{
Debug: false, // TODO: consider plumbing this through
Logf: logf,
}
return sis.Solve(ctx, invariants, expected)
}
}
// DebugSolverState helps us in understanding the state of the type unification
// solver in a more mainstream format.
// Example:
//
// solver state:
//
// * str("foo") :: str
// * call:f(str("foo")) [0xc000ac9f10] :: ?1
// * var(x) [0xc00088d840] :: ?2
// * param(x) [0xc00000f950] :: ?3
// * func(x) { var(x) } [0xc0000e9680] :: ?4
// * ?2 = ?3
// * ?4 = func(arg0 str) ?1
// * ?4 = func(x str) ?2
// * ?1 = ?2
func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string {
s := ""
// all the relevant Exprs
count := 0
exprs := make(map[interfaces.Expr]int)
for _, equality := range equalities {
for _, expr := range equality.ExprList() {
count++
exprs[expr] = count // for sorting
}
}
// print the solved Exprs first
for expr, typ := range solved {
s += fmt.Sprintf("%v :: %v\n", expr, typ)
delete(exprs, expr)
}
sortedExprs := []interfaces.Expr{}
for k := range exprs {
sortedExprs = append(sortedExprs, k)
}
sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] })
// for each remaining expr, generate a shorter name than the full pointer
nextVar := 1
shortNames := map[interfaces.Expr]string{}
for _, expr := range sortedExprs {
shortNames[expr] = fmt.Sprintf("?%d", nextVar)
nextVar++
s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr])
}
// print all the equalities using the short names
for _, equality := range equalities {
switch e := equality.(type) {
case *interfaces.EqualsInvariant:
_, ok := solved[e.Expr]
if !ok {
s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type)
} else {
// if solved, then this is redundant, don't print anything
}
case *interfaces.EqualityInvariant:
type1, ok1 := solved[e.Expr1]
type2, ok2 := solved[e.Expr2]
if !ok1 && !ok2 {
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2])
} else if ok1 && !ok2 {
s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2])
} else if !ok1 && ok2 {
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2)
} else {
// if completely solved, then this is redundant, don't print anything
}
case *interfaces.EqualityWrapFuncInvariant:
funcType, funcOk := solved[e.Expr1]
args := ""
argsOk := true
for i, argName := range e.Expr2Ord {
if i > 0 {
args += ", "
}
argExpr := e.Expr2Map[argName]
argType, ok := solved[argExpr]
if !ok {
args += fmt.Sprintf("%s %s", argName, shortNames[argExpr])
argsOk = false
} else {
args += fmt.Sprintf("%s %s", argName, argType)
}
}
outType, outOk := solved[e.Expr2Out]
if !funcOk || !argsOk || !outOk {
if !funcOk && !outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out])
} else if !funcOk && outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType)
} else if funcOk && !outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out])
} else {
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType)
}
}
case *interfaces.CallFuncArgsValueInvariant:
// skip, not used in the examples I care about
case *interfaces.AnyInvariant:
// skip, not used in the examples I care about
case *interfaces.SkipInvariant:
// we don't care about this one
default:
s += fmt.Sprintf("%v\n", equality)
}
}
return s
func init() {
unification.Register(Name, func() unification.Solver { return &SimpleInvariantSolver{} })
}
// SimpleInvariantSolver is an iterative invariant solver for AST expressions.
// It is intended to be very simple, even if it's computationally inefficient.
// TODO: Move some of the global solver constants into this struct as params.
type SimpleInvariantSolver struct {
// Strategy is a series of methodologies to heuristically improve the
// solver.
Strategy map[string]string
Debug bool
Logf func(format string, v ...interface{})
}
// Init contains some handles that are used to initialize the solver.
func (obj *SimpleInvariantSolver) Init(init *unification.Init) error {
obj.Debug = init.Debug
obj.Logf = init.Logf
return nil
}
// Solve is the actual solve implementation of the solver.
func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*unification.InvariantSolution, error) {
process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) {
equalities := []interfaces.Invariant{}
exclusives := []*interfaces.ExclusiveInvariant{}
@@ -351,7 +225,7 @@ func (obj *SimpleInvariantSolver) Solve(ctx context.Context, invariants []interf
// list all the expr's connected to expr, use pairs as chains
listConnectedFn := func(expr interfaces.Expr, exprs []*interfaces.EqualityInvariant) []interfaces.Expr {
pairsType := pairs(exprs)
pairsType := unification.Pairs(exprs)
return pairsType.DFS(expr)
}
@@ -1272,7 +1146,7 @@ Loop:
obj.Logf("%s: unsolved: %+v", Name, x)
}
}
obj.Logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities))
obj.Logf("%s: solver state:\n%s", Name, unification.DebugSolverState(solved, equalities))
// Lastly, we could loop through each exclusive and see
// if it only has a single, easy solution. For example,
@@ -1338,7 +1212,7 @@ Loop:
}
// let's try each combination, one at a time...
for i, ex := range exclusivesProduct(exclusives) { // [][]interfaces.Invariant
for i, ex := range unification.ExclusivesProduct(exclusives) { // [][]interfaces.Invariant
select {
case <-ctx.Done():
return nil, ctx.Err()
@@ -1378,7 +1252,7 @@ Loop:
for expr, typ := range solved {
obj.Logf("%s: solved: (%p) => %+v", Name, expr, typ)
}
return nil, ErrAmbiguous
return nil, unification.ErrAmbiguous
}
// delete used equalities, in reverse order to preserve indexing!
for i := len(used) - 1; i >= 0; i-- {
@@ -1403,7 +1277,7 @@ Loop:
}
solutions = append(solutions, invar)
}
return &InvariantSolution{
return &unification.InvariantSolution{
Solutions: solutions,
}, nil
}

View File

@@ -29,7 +29,7 @@
//go:build !root
package unification
package solvers
import (
"context"
@@ -40,6 +40,7 @@ import (
"github.com/purpleidea/mgmt/lang/ast"
"github.com/purpleidea/mgmt/lang/interfaces"
"github.com/purpleidea/mgmt/lang/types"
"github.com/purpleidea/mgmt/lang/unification"
"github.com/purpleidea/mgmt/util"
)
@@ -259,14 +260,27 @@ func TestSimpleSolver1(t *testing.T) {
t.Run(fmt.Sprintf("test #%d (%s)", index, tc.name), func(t *testing.T) {
invariants, expected, fail, expect, experr, experrstr := tc.invariants, tc.expected, tc.fail, tc.expect, tc.experr, tc.experrstr
debug := testing.Verbose()
logf := func(format string, v ...interface{}) {
t.Logf(fmt.Sprintf("test #%d", index)+": "+format, v...)
}
debug := testing.Verbose()
solver := SimpleInvariantSolverLogger(logf) // generates a solver with built-in logging
solution, err := solver(context.TODO(), invariants, expected)
solver, err := unification.LookupDefault()
if err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
return
}
init := &unification.Init{
Debug: debug,
Logf: logf,
}
if err := solver.Init(init); err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: solver init failed with: %+v", index, err)
return
}
solution, err := solver.Solve(context.TODO(), invariants, expected)
t.Logf("test #%d: solver completed with: %+v", index, err)
if !fail && err != nil {

View File

@@ -0,0 +1,37 @@
// Mgmt
// Copyright (C) 2013-2024+ James Shubin and the project contributors
// Written by James Shubin <james@shubin.ca> and the project contributors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//
// Additional permission under GNU GPL version 3 section 7
//
// If you modify this program, or any covered work, by linking or combining it
// with embedded mcl code and modules (and that the embedded mcl code and
// modules which link with this program, contain a copy of their source code in
// the authoritative form) containing parts covered by the terms of any other
// license, the licensors of this program grant you additional permission to
// convey the resulting work. Furthermore, the licensors of this program grant
// the original author, James Shubin, additional permission to update this
// additional permission if he deems it necessary to achieve the goals of this
// additional permission.
// Package solvers is used to have a central place to import all solvers from.
// It is also a good locus to run all of the unification tests from.
package solvers
import (
// import so the solver registers
_ "github.com/purpleidea/mgmt/lang/unification/simplesolver"
)

View File

@@ -0,0 +1,930 @@
// Mgmt
// Copyright (C) 2013-2024+ James Shubin and the project contributors
// Written by James Shubin <james@shubin.ca> and the project contributors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//
// Additional permission under GNU GPL version 3 section 7
//
// If you modify this program, or any covered work, by linking or combining it
// with embedded mcl code and modules (and that the embedded mcl code and
// modules which link with this program, contain a copy of their source code in
// the authoritative form) containing parts covered by the terms of any other
// license, the licensors of this program grant you additional permission to
// convey the resulting work. Furthermore, the licensors of this program grant
// the original author, James Shubin, additional permission to update this
// additional permission if he deems it necessary to achieve the goals of this
// additional permission.
//go:build !root
package solvers
import (
"context"
"fmt"
"strings"
"testing"
_ "github.com/purpleidea/mgmt/engine/resources" // import so the resources register
"github.com/purpleidea/mgmt/lang/ast"
"github.com/purpleidea/mgmt/lang/funcs"
"github.com/purpleidea/mgmt/lang/funcs/vars"
"github.com/purpleidea/mgmt/lang/interfaces"
"github.com/purpleidea/mgmt/lang/types"
"github.com/purpleidea/mgmt/lang/unification"
"github.com/purpleidea/mgmt/util"
)
func TestUnification1(t *testing.T) {
type test struct { // an individual test
name string
ast interfaces.Stmt // raw AST
fail bool
expect map[interfaces.Expr]*types.Type
experr error // expected error if fail == true (nil ignores it)
experrstr string // expected error prefix
}
testCases := []test{}
// this causes a panic, so it can't be used
//{
// testCases = append(testCases, test{
// "nil",
// nil,
// true, // expect error
// nil, // no AST
// })
//}
{
expr := &ast.ExprStr{V: "hello"}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "t1"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "str",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "one res",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
expr: types.TypeStr,
},
})
}
{
v1 := &ast.ExprStr{}
v2 := &ast.ExprStr{}
v3 := &ast.ExprStr{}
expr := &ast.ExprList{
Elements: []interfaces.Expr{
v1,
v2,
v3,
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "test"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "slicestring",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "list of strings",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
v1: types.TypeStr,
v2: types.TypeStr,
v3: types.TypeStr,
expr: types.NewType("[]str"),
},
})
}
{
k1 := &ast.ExprInt{}
k2 := &ast.ExprInt{}
k3 := &ast.ExprInt{}
v1 := &ast.ExprFloat{}
v2 := &ast.ExprFloat{}
v3 := &ast.ExprFloat{}
expr := &ast.ExprMap{
KVs: []*ast.ExprMapKV{
{Key: k1, Val: v1},
{Key: k2, Val: v2},
{Key: k3, Val: v3},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "test"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "mapintfloat",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "map of int->float",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
k1: types.TypeInt,
k2: types.TypeInt,
k3: types.TypeInt,
v1: types.TypeFloat,
v2: types.TypeFloat,
v3: types.TypeFloat,
expr: types.NewType("map{int: float}"),
},
})
}
{
b := &ast.ExprBool{}
s := &ast.ExprStr{}
i := &ast.ExprInt{}
f := &ast.ExprFloat{}
expr := &ast.ExprStruct{
Fields: []*ast.ExprStructField{
{Name: "somebool", Value: b},
{Name: "somestr", Value: s},
{Name: "someint", Value: i},
{Name: "somefloat", Value: f},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "test"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "mixedstruct",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "simple struct",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
b: types.TypeBool,
s: types.TypeStr,
i: types.TypeInt,
f: types.TypeFloat,
expr: types.NewType("struct{somebool bool; somestr str; someint int; somefloat float}"),
},
})
}
{
// test "n1" {
// int64ptr => 13 + 42,
//}
expr := &ast.ExprCall{
Name: funcs.OperatorFuncName,
Args: []interfaces.Expr{
&ast.ExprStr{
V: "+",
},
&ast.ExprInt{
V: 13,
},
&ast.ExprInt{
V: 42,
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{
V: "n1",
},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "int64ptr",
Value: expr, // func
},
},
},
},
}
testCases = append(testCases, test{
name: "func call",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
expr: types.NewType("int"),
},
})
}
{
//test "n1" {
// int64ptr => 13 + 42 - 4,
//}
innerFunc := &ast.ExprCall{
Name: funcs.OperatorFuncName,
Args: []interfaces.Expr{
&ast.ExprStr{
V: "-",
},
&ast.ExprInt{
V: 42,
},
&ast.ExprInt{
V: 4,
},
},
}
expr := &ast.ExprCall{
Name: funcs.OperatorFuncName,
Args: []interfaces.Expr{
&ast.ExprStr{
V: "+",
},
&ast.ExprInt{
V: 13,
},
innerFunc, // nested func, can we unify?
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{
V: "n1",
},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "int64ptr",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "func call, multiple ints",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
innerFunc: types.NewType("int"),
expr: types.NewType("int"),
},
})
}
{
//test "n1" {
// float32 => -25.38789 + 32.6 + 13.7,
//}
innerFunc := &ast.ExprCall{
Name: funcs.OperatorFuncName,
Args: []interfaces.Expr{
&ast.ExprStr{
V: "+",
},
&ast.ExprFloat{
V: 32.6,
},
&ast.ExprFloat{
V: 13.7,
},
},
}
expr := &ast.ExprCall{
Name: funcs.OperatorFuncName,
Args: []interfaces.Expr{
&ast.ExprStr{
V: "+",
},
&ast.ExprFloat{
V: -25.38789,
},
innerFunc, // nested func, can we unify?
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{
V: "n1",
},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "float32",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "func call, multiple floats",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
innerFunc: types.NewType("float"),
expr: types.NewType("float"),
},
})
}
{
//$x = 42 - 13
//test "t1" {
// int64 => $x,
//}
innerFunc := &ast.ExprCall{
Name: funcs.OperatorFuncName,
Args: []interfaces.Expr{
&ast.ExprStr{
V: "-",
},
&ast.ExprInt{
V: 42,
},
&ast.ExprInt{
V: 13,
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtBind{
Ident: "x",
Value: innerFunc,
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{
V: "t1",
},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "int64",
Value: &ast.ExprVar{
Name: "x",
},
},
},
},
},
}
testCases = append(testCases, test{
name: "assign from func call or two ints",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
innerFunc: types.NewType("int"),
},
})
}
{
//$x = template("hello", 42)
//test "t1" {
// anotherstr => $x,
//}
innerFunc := &ast.ExprCall{
Name: "template",
Args: []interfaces.Expr{
&ast.ExprStr{
V: "hello",
},
&ast.ExprInt{
V: 42,
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtBind{
Ident: "x",
Value: innerFunc,
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{
V: "t1",
},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "anotherstr",
Value: &ast.ExprVar{
Name: "x",
},
},
},
},
},
}
testCases = append(testCases, test{
name: "simple template",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
innerFunc: types.NewType("str"),
},
})
}
{
// import "datetime"
//test "t1" {
// stringptr => datetime.now(), # bad (str vs. int)
//}
expr := &ast.ExprCall{
Name: "datetime.now",
Args: []interfaces.Expr{},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtImport{
Name: "datetime",
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "t1"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "stringptr",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "single fact unification",
ast: stmt,
fail: true,
})
}
{
//import "sys"
//test "t1" {
// stringptr => sys.getenv("GOPATH", "bug"), # bad (two args vs. one)
//}
expr := &ast.ExprCall{
Name: "sys.getenv",
Args: []interfaces.Expr{
&ast.ExprStr{
V: "GOPATH",
},
&ast.ExprStr{
V: "bug",
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtImport{
Name: "sys",
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "t1"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "stringptr",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "function, wrong arg count",
ast: stmt,
fail: true,
})
}
// XXX: add these tests when we fix the bug!
//{
// //import "fmt"
// //test "t1" {
// // stringptr => fmt.printf("hello %s and %s", "one"), # bad
// //}
// expr := &ast.ExprCall{
// Name: "fmt.printf",
// Args: []interfaces.Expr{
// &ast.ExprStr{
// V: "hello %s and %s",
// },
// &ast.ExprStr{
// V: "one",
// },
// },
// }
// stmt := &ast.StmtProg{
// Body: []interfaces.Stmt{
// &ast.StmtImport{
// Name: "fmt",
// },
// &ast.StmtRes{
// Kind: "test",
// Name: &ast.ExprStr{V: "t1"},
// Contents: []ast.StmtResContents{
// &ast.StmtResField{
// Field: "stringptr",
// Value: expr,
// },
// },
// },
// },
// }
// testCases = append(testCases, test{
// name: "function, missing arg for printf",
// ast: stmt,
// fail: true,
// })
//}
//{
// //import "fmt"
// //test "t1" {
// // stringptr => fmt.printf("hello %s and %s", "one", "two", "three"), # bad
// //}
// expr := &ast.ExprCall{
// Name: "fmt.printf",
// Args: []interfaces.Expr{
// &ast.ExprStr{
// V: "hello %s and %s",
// },
// &ast.ExprStr{
// V: "one",
// },
// &ast.ExprStr{
// V: "two",
// },
// &ast.ExprStr{
// V: "three",
// },
// },
// }
// stmt := &ast.StmtProg{
// Body: []interfaces.Stmt{
// &ast.StmtImport{
// Name: "fmt",
// },
// &ast.StmtRes{
// Kind: "test",
// Name: &ast.ExprStr{V: "t1"},
// Contents: []ast.StmtResContents{
// &ast.StmtResField{
// Field: "stringptr",
// Value: expr,
// },
// },
// },
// },
// }
// testCases = append(testCases, test{
// name: "function, extra arg for printf",
// ast: stmt,
// fail: true,
// })
//}
{
//import "fmt"
//test "t1" {
// stringptr => fmt.printf("hello %s and %s", "one", "two"),
//}
expr := &ast.ExprCall{
Name: "fmt.printf",
Args: []interfaces.Expr{
&ast.ExprStr{
V: "hello %s and %s",
},
&ast.ExprStr{
V: "one",
},
&ast.ExprStr{
V: "two",
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtImport{
Name: "fmt",
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "t1"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "stringptr",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "function, regular printf unification",
ast: stmt,
fail: false,
expect: map[interfaces.Expr]*types.Type{
expr: types.NewType("str"),
},
})
}
{
//import "fmt"
//$x str = if true { # should fail unification
// 42
//} else {
// 13
//}
//test "t1" {
// stringptr => fmt.printf("hello %s", $x),
//}
cond := &ast.ExprIf{
Condition: &ast.ExprBool{V: true},
ThenBranch: &ast.ExprInt{V: 42},
ElseBranch: &ast.ExprInt{V: 13},
}
cond.SetType(types.TypeStr) // should fail unification
expr := &ast.ExprCall{
Name: "fmt.printf",
Args: []interfaces.Expr{
&ast.ExprStr{
V: "hello %s",
},
&ast.ExprVar{
Name: "x", // the var
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtImport{
Name: "fmt",
},
&ast.StmtBind{
Ident: "x", // the var
Value: cond,
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "t1"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "anotherstr",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "typed if expr",
ast: stmt,
fail: true,
experrstr: "can't unify, invariant illogicality with equality: base kind does not match (Str != Int)",
})
}
{
//import "fmt"
//$w = true
//$x str = $w # should fail unification
//test "t1" {
// stringptr => fmt.printf("hello %s", $x),
//}
wvar := &ast.ExprBool{V: true}
xvar := &ast.ExprVar{Name: "w"}
xvar.SetType(types.TypeStr) // should fail unification
expr := &ast.ExprCall{
Name: "fmt.printf",
Args: []interfaces.Expr{
&ast.ExprStr{
V: "hello %s",
},
&ast.ExprVar{
Name: "x", // the var
},
},
}
stmt := &ast.StmtProg{
Body: []interfaces.Stmt{
&ast.StmtImport{
Name: "fmt",
},
&ast.StmtBind{
Ident: "w",
Value: wvar,
},
&ast.StmtBind{
Ident: "x", // the var
Value: xvar,
},
&ast.StmtRes{
Kind: "test",
Name: &ast.ExprStr{V: "t1"},
Contents: []ast.StmtResContents{
&ast.StmtResField{
Field: "anotherstr",
Value: expr,
},
},
},
},
}
testCases = append(testCases, test{
name: "typed var expr",
ast: stmt,
fail: true,
experrstr: "can't unify, invariant illogicality with equality: base kind does not match (Str != Bool)",
})
}
names := []string{}
for index, tc := range testCases { // run all the tests
if tc.name == "" {
t.Errorf("test #%d: not named", index)
continue
}
if util.StrInList(tc.name, names) {
t.Errorf("test #%d: duplicate sub test name of: %s", index, tc.name)
continue
}
names = append(names, tc.name)
t.Run(fmt.Sprintf("test #%d (%s)", index, tc.name), func(t *testing.T) {
xast, fail, expect, experr, experrstr := tc.ast, tc.fail, tc.expect, tc.experr, tc.experrstr
//str := strings.NewReader(code)
//xast, err := parser.LexParse(str)
//if err != nil {
// t.Errorf("test #%d: lex/parse failed with: %+v", index, err)
// return
//}
// TODO: print out the AST's so that we can see the types
t.Logf("\n\ntest #%d: AST (before): %+v\n", index, xast)
data := &interfaces.Data{
// TODO: add missing fields here if/when needed
Debug: testing.Verbose(), // set via the -test.v flag to `go test`
Logf: func(format string, v ...interface{}) {
t.Logf(fmt.Sprintf("test #%d", index)+": ast: "+format, v...)
},
}
// some of this might happen *after* interpolate in SetScope or Unify...
if err := xast.Init(data); err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: could not init and validate AST: %+v", index, err)
return
}
// skip interpolation in this test so that the node pointers
// aren't changed and so we can compare directly to expected
//astInterpolated, err := ast.Interpolate() // interpolate strings in ast
//if err != nil {
// t.Errorf("test #%d: interpolate failed with: %+v", index, err)
// return
//}
//t.Logf("test #%d: astInterpolated: %+v", index, astInterpolated)
variables := map[string]interfaces.Expr{
"purpleidea": &ast.ExprStr{V: "hello world!"}, // james says hi
//"hostname": &ast.ExprStr{V: obj.Hostname},
}
consts := ast.VarPrefixToVariablesScope(vars.ConstNamespace) // strips prefix!
addback := vars.ConstNamespace + interfaces.ModuleSep // add it back...
var err error
variables, err = ast.MergeExprMaps(variables, consts, addback)
if err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: couldn't merge in consts: %+v", index, err)
return
}
// top-level, built-in, initial global scope
scope := &interfaces.Scope{
Variables: variables,
// all the built-in top-level, core functions enter here...
Functions: ast.FuncPrefixToFunctionsScope(""), // runs funcs.LookupPrefix
}
// propagate the scope down through the AST...
if err := xast.SetScope(scope); err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: set scope failed with: %+v", index, err)
return
}
// apply type unification
debug := testing.Verbose()
logf := func(format string, v ...interface{}) {
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
}
solver, err := unification.LookupDefault()
if err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: solver lookup failed with: %+v", index, err)
return
}
unifier := &unification.Unifier{
AST: xast,
Solver: solver,
Debug: debug,
Logf: logf,
}
err = unifier.Unify(context.TODO())
// TODO: print out the AST's so that we can see the types
t.Logf("\n\ntest #%d: AST (after): %+v\n", index, xast)
if !fail && err != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: unification failed with: %+v", index, err)
return
}
if fail && err == nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: unification passed, expected fail", index)
return
}
if fail && experr != nil && err != experr { // test for specific error!
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: expected fail, got wrong error", index)
t.Errorf("test #%d: got error: %+v", index, err)
t.Errorf("test #%d: exp error: %+v", index, experr)
return
}
if fail && err != nil {
t.Logf("test #%d: err: %+v", index, err)
}
// test for specific error string!
if fail && experrstr != "" && !strings.HasPrefix(err.Error(), experrstr) {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: expected fail, got wrong error", index)
t.Errorf("test #%d: got error: %s", index, err.Error())
t.Errorf("test #%d: exp error: %s", index, experrstr)
return
}
if expect == nil { // test done early
return
}
// TODO: do this in sorted order
var failed bool
for expr, exptyp := range expect {
typ, err := expr.Type() // lookup type
if err != nil {
t.Errorf("test #%d: type lookup of %+v failed with: %+v", index, expr, err)
failed = true
break
}
if err := typ.Cmp(exptyp); err != nil {
t.Errorf("test #%d: type cmp failed with: %+v", index, err)
t.Logf("test #%d: got: %+v", index, typ)
t.Logf("test #%d: exp: %+v", index, exptyp)
failed = true
break
}
}
if failed {
return
}
})
}
}

View File

@@ -46,8 +46,7 @@ type Unifier struct {
AST interfaces.Stmt
// Solver is the solver algorithm implementation to use.
// XXX: Solver should be a solver interface, not a function signature.
Solver func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error)
Solver Solver
Debug bool
Logf func(format string, v ...interface{})
@@ -76,6 +75,14 @@ func (obj *Unifier) Unify(ctx context.Context) error {
return fmt.Errorf("the Logf function is missing")
}
init := &Init{
Logf: obj.Logf,
Debug: obj.Debug,
}
if err := obj.Solver.Init(init); err != nil {
return err
}
if obj.Debug {
obj.Logf("tree: %+v", obj.AST)
}
@@ -98,7 +105,7 @@ func (obj *Unifier) Unify(ctx context.Context) error {
exprMap := ExprListToExprMap(exprs) // makes searching faster
exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates)
solved, err := obj.Solver(ctx, invariants, exprList)
solved, err := obj.Solver.Solve(ctx, invariants, exprList)
if err != nil {
return err
}
@@ -194,14 +201,14 @@ func (obj *InvariantSolution) ExprList() []interfaces.Expr {
return exprs
}
// exclusivesProduct returns a list of different products produced from the
// ExclusivesProduct returns a list of different products produced from the
// combinatorial product of the list of exclusives. Each ExclusiveInvariant must
// contain between one and more Invariants. This takes every combination of
// Invariants (choosing one from each ExclusiveInvariant) and returns that list.
// In other words, if you have three exclusives, with invariants named (A1, B1),
// (A2), and (A3, B3, C3) you'll get: (A1, A2, A3), (A1, A2, B3), (A1, A2, C3),
// (B1, A2, A3), (B1, A2, B3), (B1, A2, C3) as results for this function call.
func exclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant {
func ExclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant {
if len(exclusives) == 0 {
return nil
}

View File

@@ -76,13 +76,13 @@ func ExprContains(needle interfaces.Expr, haystack []interfaces.Expr) bool {
return false
}
// pairs is a simple list of pairs of expressions which can be used as a simple
// Pairs is a simple list of pairs of expressions which can be used as a simple
// undirected graph structure, or as a simple list of equalities.
type pairs []*interfaces.EqualityInvariant
type Pairs []*interfaces.EqualityInvariant
// Vertices returns the list of vertices that the input expr is directly
// connected to.
func (obj pairs) Vertices(expr interfaces.Expr) []interfaces.Expr {
func (obj Pairs) Vertices(expr interfaces.Expr) []interfaces.Expr {
m := make(map[interfaces.Expr]struct{})
for _, x := range obj {
if x.Expr1 == x.Expr2 { // skip circular
@@ -106,7 +106,7 @@ func (obj pairs) Vertices(expr interfaces.Expr) []interfaces.Expr {
}
// DFS returns a depth first search for the graph, starting at the input vertex.
func (obj pairs) DFS(start interfaces.Expr) []interfaces.Expr {
func (obj Pairs) DFS(start interfaces.Expr) []interfaces.Expr {
var d []interfaces.Expr // discovered
var s []interfaces.Expr // stack
found := false