lang: unification: Improve type unification algorithm
The simple type unification algorithm suffered from some serious performance and memory problems when used with certain code bases. This adds some crucial optimizations that improve performance drastically.
This commit is contained in:
@@ -2,7 +2,6 @@ import "world"
|
|||||||
|
|
||||||
$ns = "estate"
|
$ns = "estate"
|
||||||
$exchanged = world.kvlookup($ns)
|
$exchanged = world.kvlookup($ns)
|
||||||
|
|
||||||
$state = maplookup($exchanged, $hostname, "default")
|
$state = maplookup($exchanged, $hostname, "default")
|
||||||
|
|
||||||
if $state == "one" || $state == "default" {
|
if $state == "one" || $state == "default" {
|
||||||
|
|||||||
@@ -293,7 +293,13 @@ func (obj *GAPI) Cli(cliInfo *gapi.CliInfo) (*gapi.Deploy, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
logf("running type unification...")
|
logf("running type unification...")
|
||||||
if err := unification.Unify(interpolated, unification.SimpleInvariantSolverLogger(unificationLogf)); err != nil {
|
unifier := &unification.Unifier{
|
||||||
|
AST: interpolated,
|
||||||
|
Solver: unification.SimpleInvariantSolverLogger(unificationLogf),
|
||||||
|
Debug: debug,
|
||||||
|
Logf: unificationLogf,
|
||||||
|
}
|
||||||
|
if err := unifier.Unify(); err != nil {
|
||||||
return nil, errwrap.Wrapf(err, "could not unify types")
|
return nil, errwrap.Wrapf(err, "could not unify types")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import (
|
|||||||
// often since we usually know which kind of node we want.
|
// often since we usually know which kind of node we want.
|
||||||
type Node interface {
|
type Node interface {
|
||||||
Apply(fn func(Node) error) error
|
Apply(fn func(Node) error) error
|
||||||
|
//Parent() Node // TODO: should we implement this?
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stmt represents a statement node in the language. A stmt could be a resource,
|
// Stmt represents a statement node in the language. A stmt could be a resource,
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ package interfaces
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/purpleidea/mgmt/lang/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Invariant represents a constraint that is described by the Expr's and Stmt's,
|
// Invariant represents a constraint that is described by the Expr's and Stmt's,
|
||||||
@@ -27,4 +29,11 @@ import (
|
|||||||
type Invariant interface {
|
type Invariant interface {
|
||||||
// TODO: should we add any other methods to this type?
|
// TODO: should we add any other methods to this type?
|
||||||
fmt.Stringer
|
fmt.Stringer
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
ExprList() []Expr
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution.
|
||||||
|
// If it is inconsistent, then it errors.
|
||||||
|
Matches(solved map[Expr]*types.Type) (bool, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -467,7 +467,13 @@ func TestAstFunc0(t *testing.T) {
|
|||||||
logf := func(format string, v ...interface{}) {
|
logf := func(format string, v ...interface{}) {
|
||||||
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
|
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
|
||||||
}
|
}
|
||||||
err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(logf))
|
unifier := &unification.Unifier{
|
||||||
|
AST: iast,
|
||||||
|
Solver: unification.SimpleInvariantSolverLogger(logf),
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
Logf: logf,
|
||||||
|
}
|
||||||
|
err = unifier.Unify()
|
||||||
if !fail && err != nil {
|
if !fail && err != nil {
|
||||||
t.Errorf("test #%d: FAIL", index)
|
t.Errorf("test #%d: FAIL", index)
|
||||||
t.Errorf("test #%d: could not unify types: %+v", index, err)
|
t.Errorf("test #%d: could not unify types: %+v", index, err)
|
||||||
@@ -822,7 +828,13 @@ func TestAstFunc1(t *testing.T) {
|
|||||||
xlogf := func(format string, v ...interface{}) {
|
xlogf := func(format string, v ...interface{}) {
|
||||||
logf("unification: "+format, v...)
|
logf("unification: "+format, v...)
|
||||||
}
|
}
|
||||||
err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(xlogf))
|
unifier := &unification.Unifier{
|
||||||
|
AST: iast,
|
||||||
|
Solver: unification.SimpleInvariantSolverLogger(xlogf),
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
Logf: xlogf,
|
||||||
|
}
|
||||||
|
err = unifier.Unify()
|
||||||
if !fail && err != nil {
|
if !fail && err != nil {
|
||||||
t.Errorf("test #%d: FAIL", index)
|
t.Errorf("test #%d: FAIL", index)
|
||||||
t.Errorf("test #%d: could not unify types: %+v", index, err)
|
t.Errorf("test #%d: could not unify types: %+v", index, err)
|
||||||
@@ -1216,7 +1228,13 @@ func TestAstFunc2(t *testing.T) {
|
|||||||
xlogf := func(format string, v ...interface{}) {
|
xlogf := func(format string, v ...interface{}) {
|
||||||
logf("unification: "+format, v...)
|
logf("unification: "+format, v...)
|
||||||
}
|
}
|
||||||
err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(xlogf))
|
unifier := &unification.Unifier{
|
||||||
|
AST: iast,
|
||||||
|
Solver: unification.SimpleInvariantSolverLogger(xlogf),
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
Logf: xlogf,
|
||||||
|
}
|
||||||
|
err = unifier.Unify()
|
||||||
if !fail && err != nil {
|
if !fail && err != nil {
|
||||||
t.Errorf("test #%d: FAIL", index)
|
t.Errorf("test #%d: FAIL", index)
|
||||||
t.Errorf("test #%d: could not unify types: %+v", index, err)
|
t.Errorf("test #%d: could not unify types: %+v", index, err)
|
||||||
|
|||||||
11
lang/interpret_test/TestAstFunc1/doubleinclude.graph
Normal file
11
lang/interpret_test/TestAstFunc1/doubleinclude.graph
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
Edge: str("hey") -> var(foo) # foo
|
||||||
|
Edge: str("hey") -> var(foo) # foo
|
||||||
|
Edge: str("t1") -> var(a) # a
|
||||||
|
Edge: str("t2") -> var(a) # a
|
||||||
|
Vertex: str("hey")
|
||||||
|
Vertex: str("t1")
|
||||||
|
Vertex: str("t2")
|
||||||
|
Vertex: var(a)
|
||||||
|
Vertex: var(a)
|
||||||
|
Vertex: var(foo)
|
||||||
|
Vertex: var(foo)
|
||||||
8
lang/interpret_test/TestAstFunc1/doubleinclude/main.mcl
Normal file
8
lang/interpret_test/TestAstFunc1/doubleinclude/main.mcl
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
include c1("t1")
|
||||||
|
include c1("t2")
|
||||||
|
class c1($a) {
|
||||||
|
test $a {
|
||||||
|
stringptr => $foo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$foo = "hey"
|
||||||
32
lang/interpret_test/TestAstFunc1/polydoubleinclude.graph
Normal file
32
lang/interpret_test/TestAstFunc1/polydoubleinclude.graph
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
Edge: call:len(var(b)) -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # b
|
||||||
|
Edge: call:len(var(b)) -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # b
|
||||||
|
Edge: int(-37) -> list(int(13), int(42), int(0), int(-37)) # 3
|
||||||
|
Edge: int(0) -> list(int(13), int(42), int(0), int(-37)) # 2
|
||||||
|
Edge: int(13) -> list(int(13), int(42), int(0), int(-37)) # 0
|
||||||
|
Edge: int(42) -> list(int(13), int(42), int(0), int(-37)) # 1
|
||||||
|
Edge: list(int(13), int(42), int(0), int(-37)) -> var(b) # b
|
||||||
|
Edge: str("hello") -> var(b) # b
|
||||||
|
Edge: str("len is: %d") -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # a
|
||||||
|
Edge: str("len is: %d") -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # a
|
||||||
|
Edge: str("t1") -> var(a) # a
|
||||||
|
Edge: str("t2") -> var(a) # a
|
||||||
|
Edge: var(b) -> call:len(var(b)) # 0
|
||||||
|
Edge: var(b) -> call:len(var(b)) # 0
|
||||||
|
Vertex: call:fmt.printf(str("len is: %d"), call:len(var(b)))
|
||||||
|
Vertex: call:fmt.printf(str("len is: %d"), call:len(var(b)))
|
||||||
|
Vertex: call:len(var(b))
|
||||||
|
Vertex: call:len(var(b))
|
||||||
|
Vertex: int(-37)
|
||||||
|
Vertex: int(0)
|
||||||
|
Vertex: int(13)
|
||||||
|
Vertex: int(42)
|
||||||
|
Vertex: list(int(13), int(42), int(0), int(-37))
|
||||||
|
Vertex: str("hello")
|
||||||
|
Vertex: str("len is: %d")
|
||||||
|
Vertex: str("len is: %d")
|
||||||
|
Vertex: str("t1")
|
||||||
|
Vertex: str("t2")
|
||||||
|
Vertex: var(a)
|
||||||
|
Vertex: var(a)
|
||||||
|
Vertex: var(b)
|
||||||
|
Vertex: var(b)
|
||||||
10
lang/interpret_test/TestAstFunc1/polydoubleinclude/main.mcl
Normal file
10
lang/interpret_test/TestAstFunc1/polydoubleinclude/main.mcl
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import "fmt"
|
||||||
|
|
||||||
|
# note that the class can have two separate types for $b
|
||||||
|
include c1("t1", "hello") # len is 5
|
||||||
|
include c1("t2", [13, 42, 0, -37,]) # len is 4
|
||||||
|
class c1($a, $b) {
|
||||||
|
test $a {
|
||||||
|
anotherstr => fmt.printf("len is: %d", len($b)),
|
||||||
|
}
|
||||||
|
}
|
||||||
88
lang/interpret_test/TestAstFunc1/slow_unification0.graph
Normal file
88
lang/interpret_test/TestAstFunc1/slow_unification0.graph
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
Edge: call:_operator(str("=="), var(state), str("default")) -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # b
|
||||||
|
Edge: call:_operator(str("=="), var(state), str("one")) -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # a
|
||||||
|
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
|
||||||
|
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
|
||||||
|
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
|
||||||
|
Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state
|
||||||
|
Edge: call:world.kvlookup(var(ns)) -> var(exchanged) # exchanged
|
||||||
|
Edge: str("") -> var(hostname) # hostname
|
||||||
|
Edge: str("==") -> call:_operator(str("=="), var(state), str("default")) # x
|
||||||
|
Edge: str("==") -> call:_operator(str("=="), var(state), str("one")) # x
|
||||||
|
Edge: str("==") -> call:_operator(str("=="), var(state), str("three")) # x
|
||||||
|
Edge: str("==") -> call:_operator(str("=="), var(state), str("two")) # x
|
||||||
|
Edge: str("default") -> call:_operator(str("=="), var(state), str("default")) # b
|
||||||
|
Edge: str("default") -> call:maplookup(var(exchanged), var(hostname), str("default")) # default
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("estate") -> var(ns) # ns
|
||||||
|
Edge: str("one") -> call:_operator(str("=="), var(state), str("one")) # b
|
||||||
|
Edge: str("three") -> call:_operator(str("=="), var(state), str("three")) # b
|
||||||
|
Edge: str("two") -> call:_operator(str("=="), var(state), str("two")) # b
|
||||||
|
Edge: str("||") -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # x
|
||||||
|
Edge: var(exchanged) -> call:maplookup(var(exchanged), var(hostname), str("default")) # map
|
||||||
|
Edge: var(hostname) -> call:maplookup(var(exchanged), var(hostname), str("default")) # key
|
||||||
|
Edge: var(ns) -> call:world.kvlookup(var(ns)) # namespace
|
||||||
|
Edge: var(state) -> call:_operator(str("=="), var(state), str("default")) # a
|
||||||
|
Edge: var(state) -> call:_operator(str("=="), var(state), str("one")) # a
|
||||||
|
Edge: var(state) -> call:_operator(str("=="), var(state), str("three")) # a
|
||||||
|
Edge: var(state) -> call:_operator(str("=="), var(state), str("two")) # a
|
||||||
|
Vertex: call:_operator(str("=="), var(state), str("default"))
|
||||||
|
Vertex: call:_operator(str("=="), var(state), str("one"))
|
||||||
|
Vertex: call:_operator(str("=="), var(state), str("three"))
|
||||||
|
Vertex: call:_operator(str("=="), var(state), str("two"))
|
||||||
|
Vertex: call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default")))
|
||||||
|
Vertex: call:maplookup(var(exchanged), var(hostname), str("default"))
|
||||||
|
Vertex: call:world.kvlookup(var(ns))
|
||||||
|
Vertex: str("")
|
||||||
|
Vertex: str("/tmp/mgmt/state")
|
||||||
|
Vertex: str("/tmp/mgmt/state")
|
||||||
|
Vertex: str("/tmp/mgmt/state")
|
||||||
|
Vertex: str("/usr/bin/sleep 1s")
|
||||||
|
Vertex: str("/usr/bin/sleep 1s")
|
||||||
|
Vertex: str("/usr/bin/sleep 1s")
|
||||||
|
Vertex: str("==")
|
||||||
|
Vertex: str("==")
|
||||||
|
Vertex: str("==")
|
||||||
|
Vertex: str("==")
|
||||||
|
Vertex: str("default")
|
||||||
|
Vertex: str("default")
|
||||||
|
Vertex: str("estate")
|
||||||
|
Vertex: str("one")
|
||||||
|
Vertex: str("one")
|
||||||
|
Vertex: str("state: one\n")
|
||||||
|
Vertex: str("state: three\n")
|
||||||
|
Vertex: str("state: two\n")
|
||||||
|
Vertex: str("three")
|
||||||
|
Vertex: str("three")
|
||||||
|
Vertex: str("timer")
|
||||||
|
Vertex: str("timer")
|
||||||
|
Vertex: str("timer")
|
||||||
|
Vertex: str("timer")
|
||||||
|
Vertex: str("timer")
|
||||||
|
Vertex: str("timer")
|
||||||
|
Vertex: str("two")
|
||||||
|
Vertex: str("two")
|
||||||
|
Vertex: str("||")
|
||||||
|
Vertex: var(exchanged)
|
||||||
|
Vertex: var(hostname)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(ns)
|
||||||
|
Vertex: var(state)
|
||||||
|
Vertex: var(state)
|
||||||
|
Vertex: var(state)
|
||||||
|
Vertex: var(state)
|
||||||
52
lang/interpret_test/TestAstFunc1/slow_unification0/main.mcl
Normal file
52
lang/interpret_test/TestAstFunc1/slow_unification0/main.mcl
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# state machine that previously experienced unusable slow type unification
|
||||||
|
import "world"
|
||||||
|
|
||||||
|
$ns = "estate"
|
||||||
|
$exchanged = world.kvlookup($ns)
|
||||||
|
$state = maplookup($exchanged, $hostname, "default")
|
||||||
|
|
||||||
|
if $state == "one" || $state == "default" {
|
||||||
|
|
||||||
|
file "/tmp/mgmt/state" {
|
||||||
|
content => "state: one\n",
|
||||||
|
}
|
||||||
|
|
||||||
|
exec "timer" {
|
||||||
|
cmd => "/usr/bin/sleep 1s",
|
||||||
|
}
|
||||||
|
kv "${ns}" {
|
||||||
|
key => $ns,
|
||||||
|
value => "two",
|
||||||
|
}
|
||||||
|
Exec["timer"] -> Kv["${ns}"]
|
||||||
|
}
|
||||||
|
if $state == "two" {
|
||||||
|
|
||||||
|
file "/tmp/mgmt/state" {
|
||||||
|
content => "state: two\n",
|
||||||
|
}
|
||||||
|
|
||||||
|
exec "timer" {
|
||||||
|
cmd => "/usr/bin/sleep 1s",
|
||||||
|
}
|
||||||
|
kv "${ns}" {
|
||||||
|
key => $ns,
|
||||||
|
value => "three",
|
||||||
|
}
|
||||||
|
Exec["timer"] -> Kv["${ns}"]
|
||||||
|
}
|
||||||
|
if $state == "three" {
|
||||||
|
|
||||||
|
file "/tmp/mgmt/state" {
|
||||||
|
content => "state: three\n",
|
||||||
|
}
|
||||||
|
|
||||||
|
exec "timer" {
|
||||||
|
cmd => "/usr/bin/sleep 1s",
|
||||||
|
}
|
||||||
|
kv "${ns}" {
|
||||||
|
key => $ns,
|
||||||
|
value => "one",
|
||||||
|
}
|
||||||
|
Exec["timer"] -> Kv["${ns}"]
|
||||||
|
}
|
||||||
@@ -185,7 +185,13 @@ func (obj *Lang) Init() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
obj.Logf("running type unification...")
|
obj.Logf("running type unification...")
|
||||||
if err := unification.Unify(obj.ast, unification.SimpleInvariantSolverLogger(logf)); err != nil {
|
unifier := &unification.Unifier{
|
||||||
|
AST: obj.ast,
|
||||||
|
Solver: unification.SimpleInvariantSolverLogger(logf),
|
||||||
|
Debug: obj.Debug,
|
||||||
|
Logf: logf,
|
||||||
|
}
|
||||||
|
if err := unifier.Unify(); err != nil {
|
||||||
return errwrap.Wrapf(err, "could not unify types")
|
return errwrap.Wrapf(err, "could not unify types")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2977,6 +2977,15 @@ type StmtInclude struct {
|
|||||||
// Nevertheless, it is a useful facility for operations that might only apply to
|
// Nevertheless, it is a useful facility for operations that might only apply to
|
||||||
// a select number of node types, since they won't need extra noop iterators...
|
// a select number of node types, since they won't need extra noop iterators...
|
||||||
func (obj *StmtInclude) Apply(fn func(interfaces.Node) error) error {
|
func (obj *StmtInclude) Apply(fn func(interfaces.Node) error) error {
|
||||||
|
// If the class exists, then descend into it, because at this point, the
|
||||||
|
// copy of the original class that is stored here, is the effective
|
||||||
|
// class that we care about for type unification, and everything else...
|
||||||
|
// It's not clear if this is needed, but it's probably nor harmful atm.
|
||||||
|
if obj.class != nil {
|
||||||
|
if err := obj.class.Apply(fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
if obj.Args != nil {
|
if obj.Args != nil {
|
||||||
for _, x := range obj.Args {
|
for _, x := range obj.Args {
|
||||||
if err := x.Apply(fn); err != nil {
|
if err := x.Apply(fn); err != nil {
|
||||||
@@ -4890,7 +4899,11 @@ func (obj *ExprFunc) String() string {
|
|||||||
if obj.Return != nil {
|
if obj.Return != nil {
|
||||||
s += fmt.Sprintf(" %s", obj.Return.String())
|
s += fmt.Sprintf(" %s", obj.Return.String())
|
||||||
}
|
}
|
||||||
|
if obj.Body == nil {
|
||||||
|
s += fmt.Sprintf(" { ??? }") // TODO: why does this happen?
|
||||||
|
} else {
|
||||||
s += fmt.Sprintf(" { %s }", obj.Body.String())
|
s += fmt.Sprintf(" { %s }", obj.Body.String())
|
||||||
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,15 +34,16 @@ const (
|
|||||||
// SimpleInvariantSolver with the log parameter of your choice specified. The
|
// SimpleInvariantSolver with the log parameter of your choice specified. The
|
||||||
// result satisfies the correct signature for the solver parameter of the
|
// result satisfies the correct signature for the solver parameter of the
|
||||||
// Unification function.
|
// Unification function.
|
||||||
func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func([]interfaces.Invariant) (*InvariantSolution, error) {
|
func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) {
|
||||||
return func(invariants []interfaces.Invariant) (*InvariantSolution, error) {
|
return func(invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
|
||||||
return SimpleInvariantSolver(invariants, logf)
|
return SimpleInvariantSolver(invariants, expected, logf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimpleInvariantSolver is an iterative invariant solver for AST expressions.
|
// SimpleInvariantSolver is an iterative invariant solver for AST expressions.
|
||||||
// It is intended to be very simple, even if it's computationally inefficient.
|
// It is intended to be very simple, even if it's computationally inefficient.
|
||||||
func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format string, v ...interface{})) (*InvariantSolution, error) {
|
func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) {
|
||||||
|
debug := false // XXX: add to interface
|
||||||
logf("%s: invariants:", Name)
|
logf("%s: invariants:", Name)
|
||||||
for i, x := range invariants {
|
for i, x := range invariants {
|
||||||
logf("invariant(%d): %T: %s", i, x, x)
|
logf("invariant(%d): %T: %s", i, x, x)
|
||||||
@@ -112,8 +113,18 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s
|
|||||||
structPartials := make(map[interfaces.Expr]map[interfaces.Expr]*types.Type)
|
structPartials := make(map[interfaces.Expr]map[interfaces.Expr]*types.Type)
|
||||||
funcPartials := make(map[interfaces.Expr]map[interfaces.Expr]*types.Type)
|
funcPartials := make(map[interfaces.Expr]map[interfaces.Expr]*types.Type)
|
||||||
|
|
||||||
|
isSolved := func(solved map[interfaces.Expr]*types.Type) bool {
|
||||||
|
for _, x := range expected {
|
||||||
|
if typ, exists := solved[x]; !exists || typ == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
logf("%s: starting loop with %d equalities", Name, len(equalities))
|
logf("%s: starting loop with %d equalities", Name, len(equalities))
|
||||||
// run until we're solved, stop consuming equalities, or type clash
|
// run until we're solved, stop consuming equalities, or type clash
|
||||||
|
Loop:
|
||||||
for {
|
for {
|
||||||
logf("%s: iterate...", Name)
|
logf("%s: iterate...", Name)
|
||||||
if len(equalities) == 0 && len(exclusives) == 0 {
|
if len(equalities) == 0 && len(exclusives) == 0 {
|
||||||
@@ -498,11 +509,71 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s
|
|||||||
}
|
}
|
||||||
} // end inner for loop
|
} // end inner for loop
|
||||||
if len(used) == 0 {
|
if len(used) == 0 {
|
||||||
// looks like we're now ambiguous, but if we have any
|
// Looks like we're now ambiguous, but if we have any
|
||||||
// exclusives, recurse into each possibility to see if
|
// exclusives, recurse into each possibility to see if
|
||||||
// one of them can help solve this! first one wins. add
|
// one of them can help solve this! first one wins. Add
|
||||||
// in the exclusive to the current set of equalities!
|
// in the exclusive to the current set of equalities!
|
||||||
|
|
||||||
|
// To decrease the problem space, first check if we have
|
||||||
|
// enough solutions to solve everything. If so, then we
|
||||||
|
// don't need to solve any exclusives, and instead we
|
||||||
|
// only need to verify that they don't conflict with the
|
||||||
|
// found solution, which reduces the search space...
|
||||||
|
|
||||||
|
// Another optimization that can be done before we run
|
||||||
|
// the combinatorial exclusive solver, is we can look at
|
||||||
|
// each exclusive, and remove the ones that already
|
||||||
|
// match, because they don't tell us any new information
|
||||||
|
// that we don't already know. We can also fail early
|
||||||
|
// if anything proves we're already inconsistent.
|
||||||
|
|
||||||
|
// These two optimizations turn out to use the exact
|
||||||
|
// same algorithm and code, so they're combined here...
|
||||||
|
if isSolved(solved) {
|
||||||
|
logf("%s: solved early with %d exclusives left!", Name, len(exclusives))
|
||||||
|
} else {
|
||||||
|
logf("%s: unsolved with %d exclusives left!", Name, len(exclusives))
|
||||||
|
}
|
||||||
|
// check for consistency against remaining invariants
|
||||||
|
done := []int{}
|
||||||
|
for i, invar := range exclusives {
|
||||||
|
// test each one to see if at least one works
|
||||||
|
match, err := invar.Matches(solved)
|
||||||
|
if err != nil {
|
||||||
|
if debug {
|
||||||
|
logf("exclusive invar failed: %+v", invar)
|
||||||
|
}
|
||||||
|
return nil, errwrap.Wrapf(err, "inconsistent exclusive")
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
done = append(done, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove exclusives that matched correctly
|
||||||
|
for i := len(done) - 1; i >= 0; i-- {
|
||||||
|
ix := done[i] // delete index that was marked as done!
|
||||||
|
exclusives = append(exclusives[:ix], exclusives[ix+1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(exclusives) == 0 {
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Lastly, we could loop through each exclusive
|
||||||
|
// and see if it only has a single, easy solution. For
|
||||||
|
// example, if we know that an exclusive is A or B or C
|
||||||
|
// and that B and C are inconsistent, then we can
|
||||||
|
// replace the exclusive with a single invariant and
|
||||||
|
// then run that through our solver. We can do this
|
||||||
|
// iteratively (recursively in our case) so that if
|
||||||
|
// we're lucky, we rarely need to run the raw exclusive
|
||||||
|
// combinatorial solver which is slow.
|
||||||
|
|
||||||
|
// TODO: We could try and replace our combinatorial
|
||||||
|
// exclusive solver with a real SAT solver algorithm.
|
||||||
|
|
||||||
// what have we learned for sure so far?
|
// what have we learned for sure so far?
|
||||||
partialSolutions := []interfaces.Invariant{}
|
partialSolutions := []interfaces.Invariant{}
|
||||||
logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives))
|
logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives))
|
||||||
@@ -535,7 +606,7 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s
|
|||||||
recursiveInvariants = append(recursiveInvariants, partialSolutions...)
|
recursiveInvariants = append(recursiveInvariants, partialSolutions...)
|
||||||
recursiveInvariants = append(recursiveInvariants, ex...)
|
recursiveInvariants = append(recursiveInvariants, ex...)
|
||||||
logf("%s: recursing...", Name)
|
logf("%s: recursing...", Name)
|
||||||
solution, err := SimpleInvariantSolver(recursiveInvariants, logf)
|
solution, err := SimpleInvariantSolver(recursiveInvariants, expected, logf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf("%s: recursive solution failed: %+v", Name, err)
|
logf("%s: recursive solution failed: %+v", Name, err)
|
||||||
continue // no solution found here...
|
continue // no solution found here...
|
||||||
|
|||||||
@@ -25,6 +25,18 @@ import (
|
|||||||
"github.com/purpleidea/mgmt/lang/types"
|
"github.com/purpleidea/mgmt/lang/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Unifier holds all the data that the Unify function will need for it to run.
|
||||||
|
type Unifier struct {
|
||||||
|
// AST is the input abstract syntax tree to unify.
|
||||||
|
AST interfaces.Stmt
|
||||||
|
|
||||||
|
// Solver is the solver algorithm implementation to use.
|
||||||
|
Solver func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error)
|
||||||
|
|
||||||
|
Debug bool
|
||||||
|
Logf func(format string, v ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
// Unify takes an AST expression tree and attempts to assign types to every node
|
// Unify takes an AST expression tree and attempts to assign types to every node
|
||||||
// using the specified solver. The expression tree returns a list of invariants
|
// using the specified solver. The expression tree returns a list of invariants
|
||||||
// (or constraints) which must be met in order to find a unique value for the
|
// (or constraints) which must be met in order to find a unique value for the
|
||||||
@@ -37,32 +49,77 @@ import (
|
|||||||
// type. This function and logic was invented after the author could not find
|
// type. This function and logic was invented after the author could not find
|
||||||
// any proper literature or examples describing a well-known implementation of
|
// any proper literature or examples describing a well-known implementation of
|
||||||
// this process. Improvements and polite recommendations are welcome.
|
// this process. Improvements and polite recommendations are welcome.
|
||||||
func Unify(ast interfaces.Stmt, solver func([]interfaces.Invariant) (*InvariantSolution, error)) error {
|
func (obj *Unifier) Unify() error {
|
||||||
//log.Printf("unification: tree: %+v", ast) // debug
|
if obj.AST == nil {
|
||||||
if ast == nil {
|
return fmt.Errorf("the AST is nil")
|
||||||
return fmt.Errorf("AST is nil")
|
}
|
||||||
|
if obj.Solver == nil {
|
||||||
|
return fmt.Errorf("the Solver is missing")
|
||||||
|
}
|
||||||
|
if obj.Logf == nil {
|
||||||
|
return fmt.Errorf("the Logf function is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
invariants, err := ast.Unify()
|
if obj.Debug {
|
||||||
|
obj.Logf("tree: %+v", obj.AST)
|
||||||
|
}
|
||||||
|
invariants, err := obj.AST.Unify()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
solved, err := solver(invariants)
|
// build a list of what we think we need to solve for to succeed
|
||||||
|
exprs := []interfaces.Expr{}
|
||||||
|
for _, x := range invariants {
|
||||||
|
exprs = append(exprs, x.ExprList()...)
|
||||||
|
}
|
||||||
|
exprMap := ExprListToExprMap(exprs) // makes searching faster
|
||||||
|
exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates)
|
||||||
|
|
||||||
|
solved, err := obj.Solver(invariants, exprList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: ideally we would know how many different expressions need their
|
// determine what expr's we need to solve for
|
||||||
// types set in the AST and then ensure we have this many unique
|
if obj.Debug {
|
||||||
// solutions, and if not, then fail. This would ensure we don't have an
|
obj.Logf("expr count: %d", len(exprList))
|
||||||
// AST that is only partially populated with the correct types.
|
//for _, x := range exprList {
|
||||||
|
// obj.Logf("> %p (%+v)", x, x)
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
|
||||||
//log.Printf("unification: found a solution!") // TODO: get a logf function passed in...
|
// XXX: why doesn't `len(exprList)` always == `len(solved.Solutions)` ?
|
||||||
|
// XXX: is it due to the extra ExprAny ??? I see an extra function sometimes...
|
||||||
|
|
||||||
|
if obj.Debug {
|
||||||
|
obj.Logf("solutions count: %d", len(solved.Solutions))
|
||||||
|
//for _, x := range solved.Solutions {
|
||||||
|
// obj.Logf("> %p (%+v) -- %s", x.Expr, x.Type, x.Expr.String())
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine that our solver produced a solution for every expr that
|
||||||
|
// we're interested in. If it didn't, and it didn't error, then it's a
|
||||||
|
// bug. We check for this because we care about safety, this ensures
|
||||||
|
// that our AST will get fully populated with the correct types!
|
||||||
|
for _, x := range solved.Solutions {
|
||||||
|
delete(exprMap, x.Expr) // remove everything we know about
|
||||||
|
}
|
||||||
|
if c := len(exprMap); c > 0 { // if there's anything left, it's bad...
|
||||||
|
// programming error!
|
||||||
|
return fmt.Errorf("got %d unbound expr's", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.Debug {
|
||||||
|
obj.Logf("found a solution!")
|
||||||
|
}
|
||||||
// solver has found a solution, apply it...
|
// solver has found a solution, apply it...
|
||||||
// we're modifying the AST, so code can't error now...
|
// we're modifying the AST, so code can't error now...
|
||||||
for _, x := range solved.Solutions {
|
for _, x := range solved.Solutions {
|
||||||
//log.Printf("unification: solution: %p => %+v\t(%+v)", x.Expr, x.Type, x.Expr.String()) // debug
|
if obj.Debug {
|
||||||
|
obj.Logf("solution: %p => %+v\t(%+v)", x.Expr, x.Type, x.Expr.String())
|
||||||
|
}
|
||||||
// apply this to each AST node
|
// apply this to each AST node
|
||||||
if err := x.Expr.SetType(x.Type); err != nil {
|
if err := x.Expr.SetType(x.Type); err != nil {
|
||||||
// programming error!
|
// programming error!
|
||||||
@@ -85,6 +142,24 @@ func (obj *EqualsInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == %s", obj.Expr, obj.Type)
|
return fmt.Sprintf("%p == %s", obj.Expr, obj.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualsInvariant) ExprList() []interfaces.Expr {
|
||||||
|
return []interfaces.Expr{obj.Expr}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualsInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
typ, exists := solved[obj.Expr]
|
||||||
|
if !exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err := typ.Cmp(obj.Type); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// EqualityInvariant is an invariant that symbolizes that the two expressions
|
// EqualityInvariant is an invariant that symbolizes that the two expressions
|
||||||
// must have equivalent types.
|
// must have equivalent types.
|
||||||
// TODO: is there a better name than EqualityInvariant
|
// TODO: is there a better name than EqualityInvariant
|
||||||
@@ -98,6 +173,26 @@ func (obj *EqualityInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == %p", obj.Expr1, obj.Expr2)
|
return fmt.Sprintf("%p == %p", obj.Expr1, obj.Expr2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualityInvariant) ExprList() []interfaces.Expr {
|
||||||
|
return []interfaces.Expr{obj.Expr1, obj.Expr2}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualityInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
t1, exists1 := solved[obj.Expr1]
|
||||||
|
t2, exists2 := solved[obj.Expr2]
|
||||||
|
if !exists1 || !exists2 {
|
||||||
|
return false, nil // not matched yet
|
||||||
|
}
|
||||||
|
if err := t1.Cmp(t2); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil // matched!
|
||||||
|
}
|
||||||
|
|
||||||
// EqualityInvariantList is an invariant that symbolizes that all the
|
// EqualityInvariantList is an invariant that symbolizes that all the
|
||||||
// expressions listed must have equivalent types.
|
// expressions listed must have equivalent types.
|
||||||
type EqualityInvariantList struct {
|
type EqualityInvariantList struct {
|
||||||
@@ -113,6 +208,32 @@ func (obj *EqualityInvariantList) String() string {
|
|||||||
return fmt.Sprintf("[%s]", strings.Join(a, ", "))
|
return fmt.Sprintf("[%s]", strings.Join(a, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualityInvariantList) ExprList() []interfaces.Expr {
|
||||||
|
return obj.Exprs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualityInvariantList) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
found := true // assume true
|
||||||
|
var typ *types.Type
|
||||||
|
for _, x := range obj.Exprs {
|
||||||
|
t, exists := solved[x]
|
||||||
|
if !exists {
|
||||||
|
found = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if typ == nil { // set the first time
|
||||||
|
typ = t
|
||||||
|
}
|
||||||
|
if err := typ.Cmp(t); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
// EqualityWrapListInvariant expresses that a list in Expr1 must have elements
|
// EqualityWrapListInvariant expresses that a list in Expr1 must have elements
|
||||||
// that have the same type as the expression in Expr2Val.
|
// that have the same type as the expression in Expr2Val.
|
||||||
type EqualityWrapListInvariant struct {
|
type EqualityWrapListInvariant struct {
|
||||||
@@ -125,6 +246,28 @@ func (obj *EqualityWrapListInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == [%p]", obj.Expr1, obj.Expr2Val)
|
return fmt.Sprintf("%p == [%p]", obj.Expr1, obj.Expr2Val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualityWrapListInvariant) ExprList() []interfaces.Expr {
|
||||||
|
return []interfaces.Expr{obj.Expr1, obj.Expr2Val}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualityWrapListInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
t1, exists1 := solved[obj.Expr1] // list type
|
||||||
|
t2, exists2 := solved[obj.Expr2Val]
|
||||||
|
if !exists1 || !exists2 {
|
||||||
|
return false, nil // not matched yet
|
||||||
|
}
|
||||||
|
if t1.Kind != types.KindList {
|
||||||
|
return false, fmt.Errorf("expected list kind")
|
||||||
|
}
|
||||||
|
if err := t1.Val.Cmp(t2); err != nil {
|
||||||
|
return false, err // inconsistent!
|
||||||
|
}
|
||||||
|
return true, nil // matched!
|
||||||
|
}
|
||||||
|
|
||||||
// EqualityWrapMapInvariant expresses that a map in Expr1 must have keys that
|
// EqualityWrapMapInvariant expresses that a map in Expr1 must have keys that
|
||||||
// match the type of the expression in Expr2Key and values that match the type
|
// match the type of the expression in Expr2Key and values that match the type
|
||||||
// of the expression in Expr2Val.
|
// of the expression in Expr2Val.
|
||||||
@@ -139,6 +282,32 @@ func (obj *EqualityWrapMapInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == {%p: %p}", obj.Expr1, obj.Expr2Key, obj.Expr2Val)
|
return fmt.Sprintf("%p == {%p: %p}", obj.Expr1, obj.Expr2Key, obj.Expr2Val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualityWrapMapInvariant) ExprList() []interfaces.Expr {
|
||||||
|
return []interfaces.Expr{obj.Expr1, obj.Expr2Key, obj.Expr2Val}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualityWrapMapInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
t1, exists1 := solved[obj.Expr1] // list type
|
||||||
|
t2, exists2 := solved[obj.Expr2Key]
|
||||||
|
t3, exists3 := solved[obj.Expr2Val]
|
||||||
|
if !exists1 || !exists2 || !exists3 {
|
||||||
|
return false, nil // not matched yet
|
||||||
|
}
|
||||||
|
if t1.Kind != types.KindMap {
|
||||||
|
return false, fmt.Errorf("expected map kind")
|
||||||
|
}
|
||||||
|
if err := t1.Key.Cmp(t2); err != nil {
|
||||||
|
return false, err // inconsistent!
|
||||||
|
}
|
||||||
|
if err := t1.Val.Cmp(t3); err != nil {
|
||||||
|
return false, err // inconsistent!
|
||||||
|
}
|
||||||
|
return true, nil // matched!
|
||||||
|
}
|
||||||
|
|
||||||
// EqualityWrapStructInvariant expresses that a struct in Expr1 must have fields
|
// EqualityWrapStructInvariant expresses that a struct in Expr1 must have fields
|
||||||
// that match the type of the expressions listed in Expr2Map.
|
// that match the type of the expressions listed in Expr2Map.
|
||||||
type EqualityWrapStructInvariant struct {
|
type EqualityWrapStructInvariant struct {
|
||||||
@@ -163,6 +332,49 @@ func (obj *EqualityWrapStructInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == struct{%s}", obj.Expr1, strings.Join(s, "; "))
|
return fmt.Sprintf("%p == struct{%s}", obj.Expr1, strings.Join(s, "; "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualityWrapStructInvariant) ExprList() []interfaces.Expr {
|
||||||
|
exprs := []interfaces.Expr{obj.Expr1}
|
||||||
|
for _, x := range obj.Expr2Map {
|
||||||
|
exprs = append(exprs, x)
|
||||||
|
}
|
||||||
|
return exprs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualityWrapStructInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
t1, exists1 := solved[obj.Expr1] // list type
|
||||||
|
if !exists1 {
|
||||||
|
return false, nil // not matched yet
|
||||||
|
}
|
||||||
|
if t1.Kind != types.KindStruct {
|
||||||
|
return false, fmt.Errorf("expected struct kind")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := true // assume true
|
||||||
|
for _, key := range obj.Expr2Ord {
|
||||||
|
_, exists := t1.Map[key]
|
||||||
|
if !exists {
|
||||||
|
return false, fmt.Errorf("missing invariant struct key of: `%s`", key)
|
||||||
|
}
|
||||||
|
e, exists := obj.Expr2Map[key]
|
||||||
|
if !exists {
|
||||||
|
return false, fmt.Errorf("missing matched struct key of: `%s`", key)
|
||||||
|
}
|
||||||
|
t, exists := solved[e]
|
||||||
|
if !exists {
|
||||||
|
found = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := t1.Map[key].Cmp(t); err != nil {
|
||||||
|
return false, err // inconsistent!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return found, nil // matched!
|
||||||
|
}
|
||||||
|
|
||||||
// EqualityWrapFuncInvariant expresses that a func in Expr1 must have args that
|
// EqualityWrapFuncInvariant expresses that a func in Expr1 must have args that
|
||||||
// match the type of the expressions listed in Expr2Map and a return value that
|
// match the type of the expressions listed in Expr2Map and a return value that
|
||||||
// matches the type of the expression in Expr2Out.
|
// matches the type of the expression in Expr2Out.
|
||||||
@@ -190,6 +402,58 @@ func (obj *EqualityWrapFuncInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == func{%s} %p", obj.Expr1, strings.Join(s, "; "), obj.Expr2Out)
|
return fmt.Sprintf("%p == func{%s} %p", obj.Expr1, strings.Join(s, "; "), obj.Expr2Out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *EqualityWrapFuncInvariant) ExprList() []interfaces.Expr {
|
||||||
|
exprs := []interfaces.Expr{obj.Expr1}
|
||||||
|
for _, x := range obj.Expr2Map {
|
||||||
|
exprs = append(exprs, x)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, obj.Expr2Out)
|
||||||
|
return exprs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *EqualityWrapFuncInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
t1, exists1 := solved[obj.Expr1] // list type
|
||||||
|
if !exists1 {
|
||||||
|
return false, nil // not matched yet
|
||||||
|
}
|
||||||
|
if t1.Kind != types.KindFunc {
|
||||||
|
return false, fmt.Errorf("expected func kind")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := true // assume true
|
||||||
|
for _, key := range obj.Expr2Ord {
|
||||||
|
_, exists := t1.Map[key]
|
||||||
|
if !exists {
|
||||||
|
return false, fmt.Errorf("missing invariant struct key of: `%s`", key)
|
||||||
|
}
|
||||||
|
e, exists := obj.Expr2Map[key]
|
||||||
|
if !exists {
|
||||||
|
return false, fmt.Errorf("missing matched struct key of: `%s`", key)
|
||||||
|
}
|
||||||
|
t, exists := solved[e]
|
||||||
|
if !exists {
|
||||||
|
found = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := t1.Map[key].Cmp(t); err != nil {
|
||||||
|
return false, err // inconsistent!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t, exists := solved[obj.Expr2Out]
|
||||||
|
if !exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err := t1.Out.Cmp(t); err != nil {
|
||||||
|
return false, err // inconsistent!
|
||||||
|
}
|
||||||
|
|
||||||
|
return found, nil // matched!
|
||||||
|
}
|
||||||
|
|
||||||
// ConjunctionInvariant represents a list of invariants which must all be true
|
// ConjunctionInvariant represents a list of invariants which must all be true
|
||||||
// together. In other words, it's a grouping construct for a set of invariants.
|
// together. In other words, it's a grouping construct for a set of invariants.
|
||||||
type ConjunctionInvariant struct {
|
type ConjunctionInvariant struct {
|
||||||
@@ -206,6 +470,31 @@ func (obj *ConjunctionInvariant) String() string {
|
|||||||
return fmt.Sprintf("[%s]", strings.Join(a, ", "))
|
return fmt.Sprintf("[%s]", strings.Join(a, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *ConjunctionInvariant) ExprList() []interfaces.Expr {
|
||||||
|
exprs := []interfaces.Expr{}
|
||||||
|
for _, x := range obj.Invariants {
|
||||||
|
exprs = append(exprs, x.ExprList()...)
|
||||||
|
}
|
||||||
|
return exprs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *ConjunctionInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
found := true // assume true
|
||||||
|
for _, invar := range obj.Invariants {
|
||||||
|
match, err := invar.Matches(solved)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
found = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ExclusiveInvariant represents a list of invariants where one and *only* one
|
// ExclusiveInvariant represents a list of invariants where one and *only* one
|
||||||
// should hold true. To combine multiple invariants in one of the list elements,
|
// should hold true. To combine multiple invariants in one of the list elements,
|
||||||
// you can group multiple invariants together using a ConjunctionInvariant. Do
|
// you can group multiple invariants together using a ConjunctionInvariant. Do
|
||||||
@@ -226,6 +515,54 @@ func (obj *ExclusiveInvariant) String() string {
|
|||||||
return fmt.Sprintf("[%s]", strings.Join(a, ", "))
|
return fmt.Sprintf("[%s]", strings.Join(a, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *ExclusiveInvariant) ExprList() []interfaces.Expr {
|
||||||
|
// XXX: We should do this if we assume that exclusives don't have some
|
||||||
|
// sort of transient expr to satisfy that doesn't disappear depending on
|
||||||
|
// which choice in the exclusive is chosen...
|
||||||
|
//exprs := []interfaces.Expr{}
|
||||||
|
//for _, x := range obj.Invariants {
|
||||||
|
// exprs = append(exprs, x.ExprList()...)
|
||||||
|
//}
|
||||||
|
//return exprs
|
||||||
|
// XXX: But if we ever specify an expr in this exclusive that isn't
|
||||||
|
// referenced anywhere else, then we'd need to use the above so that our
|
||||||
|
// type unification algorithm knows not to stop too early.
|
||||||
|
return []interfaces.Expr{} // XXX: Do we want to the set instead?
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors. Because this partial invariant requires only
|
||||||
|
// one to be true, it will mask children errors, since it's normal for only one
|
||||||
|
// to be consistent.
|
||||||
|
func (obj *ExclusiveInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
found := false
|
||||||
|
reterr := fmt.Errorf("all exclusives errored")
|
||||||
|
for _, invar := range obj.Invariants {
|
||||||
|
match, err := invar.Matches(solved)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
// at least one was false, so we're not done here yet...
|
||||||
|
// we don't want to error yet, since we can't know there
|
||||||
|
// won't be a conflict once we get more data about this!
|
||||||
|
reterr = nil // clear the error
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if found { // we already found one
|
||||||
|
return false, fmt.Errorf("more than one exclusive solution")
|
||||||
|
}
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if found { // we got exactly one valid solution
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, reterr
|
||||||
|
}
|
||||||
|
|
||||||
// exclusivesProduct returns a list of different products produced from the
|
// exclusivesProduct returns a list of different products produced from the
|
||||||
// combinatorial product of the list of exclusives. Each ExclusiveInvariant
|
// combinatorial product of the list of exclusives. Each ExclusiveInvariant
|
||||||
// must contain between one and more Invariants. This takes every combination of
|
// must contain between one and more Invariants. This takes every combination of
|
||||||
@@ -278,8 +615,30 @@ func (obj *AnyInvariant) String() string {
|
|||||||
return fmt.Sprintf("%p == *", obj.Expr)
|
return fmt.Sprintf("%p == *", obj.Expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions in this invariant.
|
||||||
|
func (obj *AnyInvariant) ExprList() []interfaces.Expr {
|
||||||
|
return []interfaces.Expr{obj.Expr}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matches returns whether an invariant matches the existing solution. If it is
|
||||||
|
// inconsistent, then it errors.
|
||||||
|
func (obj *AnyInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) {
|
||||||
|
_, exists := solved[obj.Expr] // we only care that it is found.
|
||||||
|
return exists, nil
|
||||||
|
}
|
||||||
|
|
||||||
// InvariantSolution lists a trivial set of EqualsInvariant mappings so that you
|
// InvariantSolution lists a trivial set of EqualsInvariant mappings so that you
|
||||||
// can populate your AST with SetType calls in a simple loop.
|
// can populate your AST with SetType calls in a simple loop.
|
||||||
type InvariantSolution struct {
|
type InvariantSolution struct {
|
||||||
Solutions []*EqualsInvariant // list of trivial solutions for each node
|
Solutions []*EqualsInvariant // list of trivial solutions for each node
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExprList returns the list of valid expressions. This struct is not part of
|
||||||
|
// the invariant interface, but it implements this anyways.
|
||||||
|
func (obj *InvariantSolution) ExprList() []interfaces.Expr {
|
||||||
|
exprs := []interfaces.Expr{}
|
||||||
|
for _, x := range obj.Solutions {
|
||||||
|
exprs = append(exprs, x.ExprList()...)
|
||||||
|
}
|
||||||
|
return exprs
|
||||||
|
}
|
||||||
|
|||||||
54
lang/unification/util.go
Normal file
54
lang/unification/util.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// Mgmt
|
||||||
|
// Copyright (C) 2013-2019+ James Shubin and the project contributors
|
||||||
|
// Written by James Shubin <james@shubin.ca> and the project contributors
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU General Public License
|
||||||
|
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package unification
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/purpleidea/mgmt/lang/interfaces"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExprListToExprMap converts a list of expressions to a map that has the unique
|
||||||
|
// expr pointers as the keys. This is just an alternate representation of the
|
||||||
|
// same data structure. If you have any duplicate values in your list, they'll
|
||||||
|
// get removed when stored as a map.
|
||||||
|
func ExprListToExprMap(exprList []interfaces.Expr) map[interfaces.Expr]struct{} {
|
||||||
|
exprMap := make(map[interfaces.Expr]struct{})
|
||||||
|
for _, x := range exprList {
|
||||||
|
exprMap[x] = struct{}{}
|
||||||
|
}
|
||||||
|
return exprMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExprMapToExprList converts a map of expressions to a list that has the unique
|
||||||
|
// expr pointers as the values. This is just an alternate representation of the
|
||||||
|
// same data structure.
|
||||||
|
func ExprMapToExprList(exprMap map[interfaces.Expr]struct{}) []interfaces.Expr {
|
||||||
|
exprList := []interfaces.Expr{}
|
||||||
|
// TODO: sort by pointer address for determinism ?
|
||||||
|
for x := range exprMap {
|
||||||
|
exprList = append(exprList, x)
|
||||||
|
}
|
||||||
|
return exprList
|
||||||
|
}
|
||||||
|
|
||||||
|
// UniqueExprList returns a unique list of expressions with no duplicates. It
|
||||||
|
// does this my converting it to a map and then back. This isn't necessarily the
|
||||||
|
// most efficient way, and doesn't preserve list ordering.
|
||||||
|
func UniqueExprList(exprList []interfaces.Expr) []interfaces.Expr {
|
||||||
|
exprMap := ExprListToExprMap(exprList)
|
||||||
|
return ExprMapToExprList(exprMap)
|
||||||
|
}
|
||||||
@@ -819,7 +819,13 @@ func TestUnification1(t *testing.T) {
|
|||||||
logf := func(format string, v ...interface{}) {
|
logf := func(format string, v ...interface{}) {
|
||||||
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
|
t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...)
|
||||||
}
|
}
|
||||||
err := unification.Unify(ast, unification.SimpleInvariantSolverLogger(logf))
|
unifier := &unification.Unifier{
|
||||||
|
AST: ast,
|
||||||
|
Solver: unification.SimpleInvariantSolverLogger(logf),
|
||||||
|
Debug: testing.Verbose(),
|
||||||
|
Logf: logf,
|
||||||
|
}
|
||||||
|
err := unifier.Unify()
|
||||||
|
|
||||||
// TODO: print out the AST's so that we can see the types
|
// TODO: print out the AST's so that we can see the types
|
||||||
t.Logf("\n\ntest #%d: AST (after): %+v\n", index, ast)
|
t.Logf("\n\ntest #%d: AST (after): %+v\n", index, ast)
|
||||||
|
|||||||
Reference in New Issue
Block a user