lang: Add modern type unification implementation
This adds a modern type unification algorithm, which drastically improves performance, particularly for bigger programs. This required a change to the AST to add TypeCheck methods (for Stmt) and Infer/Check methods (for Expr). This also changed how the functions express their invariants, and as a result this was changed as well. This greatly improves the way we express these invariants, and as a result it makes adding new polymorphic functions significantly easier. This also makes error output for the user a lot better in pretty much all scenarios. The one downside of this patch is that a good chunk of it is merged in this giant single commit since it was hard to do it step-wise. That's not the end of the world. This couldn't be done without the guidance of Sam who helped me in explaining, debugging, and writing all the sneaky algorithmic parts and much more. Thanks again Sam! Co-authored-by: Samuel Gélineau <gelisam@gmail.com>
This commit is contained in:
@@ -34,8 +34,6 @@ package unification
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/purpleidea/mgmt/lang/interfaces"
|
||||
"github.com/purpleidea/mgmt/lang/types"
|
||||
@@ -99,77 +97,28 @@ func (obj *Unifier) Unify(ctx context.Context) error {
|
||||
if obj.Debug {
|
||||
obj.Logf("tree: %+v", obj.AST)
|
||||
}
|
||||
invariants, err := obj.AST.Unify()
|
||||
|
||||
// This used to take a map[string]*types.Type type context as in/output.
|
||||
unificationInvariants, err := obj.AST.TypeCheck() // ([]*UnificationInvariant, error)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// build a list of what we think we need to solve for to succeed
|
||||
exprs := []interfaces.Expr{}
|
||||
skips := make(map[interfaces.Expr]struct{})
|
||||
for _, x := range invariants {
|
||||
if si, ok := x.(*interfaces.SkipInvariant); ok {
|
||||
skips[si.Expr] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
exprs = append(exprs, x.ExprList()...)
|
||||
data := &Data{
|
||||
UnificationInvariants: unificationInvariants,
|
||||
}
|
||||
exprMap := ExprListToExprMap(exprs) // makes searching faster
|
||||
exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates)
|
||||
|
||||
solved, err := obj.Solver.Solve(ctx, invariants, exprList)
|
||||
solved, err := obj.Solver.Solve(ctx, data) // often does union find
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// determine what expr's we need to solve for
|
||||
obj.Logf("found a solution of length: %d", len(solved.Solutions))
|
||||
if obj.Debug {
|
||||
obj.Logf("expr count: %d", len(exprList))
|
||||
//for _, x := range exprList {
|
||||
// obj.Logf("> %p (%+v)", x, x)
|
||||
//}
|
||||
}
|
||||
|
||||
// XXX: why doesn't `len(exprList)` always == `len(solved.Solutions)` ?
|
||||
// XXX: is it due to the extra ExprAny ??? I see an extra function sometimes...
|
||||
|
||||
if obj.Debug {
|
||||
obj.Logf("solutions count: %d", len(solved.Solutions))
|
||||
//for _, x := range solved.Solutions {
|
||||
// obj.Logf("> %p (%+v) -- %s", x.Expr, x.Type, x.Expr.String())
|
||||
//}
|
||||
}
|
||||
|
||||
// Determine that our solver produced a solution for every expr that
|
||||
// we're interested in. If it didn't, and it didn't error, then it's a
|
||||
// bug. We check for this because we care about safety, this ensures
|
||||
// that our AST will get fully populated with the correct types!
|
||||
for _, x := range solved.Solutions {
|
||||
delete(exprMap, x.Expr) // remove everything we know about
|
||||
}
|
||||
if c := len(exprMap); c > 0 { // if there's anything left, it's bad...
|
||||
ptrs := []string{}
|
||||
disp := make(map[string]string) // display hack
|
||||
for i := range exprMap {
|
||||
s := fmt.Sprintf("%p", i) // pointer
|
||||
ptrs = append(ptrs, s)
|
||||
disp[s] = i.String()
|
||||
for _, x := range solved.Solutions {
|
||||
obj.Logf("> %p %s -- %s", x.Expr, x.Type, x.Expr.String())
|
||||
}
|
||||
sort.Strings(ptrs)
|
||||
// programming error!
|
||||
s := strings.Join(ptrs, ", ")
|
||||
|
||||
obj.Logf("got %d unbound expr's: %s", c, s)
|
||||
for i, s := range ptrs {
|
||||
obj.Logf("(%d) %s => %s", i, s, disp[s])
|
||||
}
|
||||
return fmt.Errorf("got %d unbound expr's: %s", c, s)
|
||||
}
|
||||
|
||||
if obj.Debug {
|
||||
obj.Logf("found a solution!")
|
||||
}
|
||||
// solver has found a solution, apply it...
|
||||
// we're modifying the AST, so code can't error now...
|
||||
for _, x := range solved.Solutions {
|
||||
@@ -177,22 +126,17 @@ func (obj *Unifier) Unify(ctx context.Context) error {
|
||||
// programming error ?
|
||||
return fmt.Errorf("unexpected invalid solution at: %p", x)
|
||||
}
|
||||
if _, exists := skips[x.Expr]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if obj.Debug {
|
||||
obj.Logf("solution: %p => %+v\t(%+v)", x.Expr, x.Type, x.Expr.String())
|
||||
}
|
||||
// apply this to each AST node
|
||||
if err := x.Expr.SetType(x.Type); err != nil {
|
||||
// programming error!
|
||||
// If we error here, it's probably a bug. Likely we
|
||||
// should have caught something during type unification,
|
||||
// but it slipped through and the function Build API is
|
||||
// catching it instead. Try and root cause it to avoid
|
||||
// leaving any ghosts in the code.
|
||||
return fmt.Errorf("error setting type: %+v, error: %+v", x.Expr, err)
|
||||
// SetType calls the Build() API, which functions as a
|
||||
// "check" step to add additional constraints that were
|
||||
// not possible during type unification.
|
||||
// TODO: Improve this error message!
|
||||
return fmt.Errorf("error setting type: %+v, error: %s", x.Expr, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -201,54 +145,12 @@ func (obj *Unifier) Unify(ctx context.Context) error {
|
||||
// InvariantSolution lists a trivial set of EqualsInvariant mappings so that you
|
||||
// can populate your AST with SetType calls in a simple loop.
|
||||
type InvariantSolution struct {
|
||||
Solutions []*interfaces.EqualsInvariant // list of trivial solutions for each node
|
||||
Solutions []*EqualsInvariant // list of trivial solutions for each node
|
||||
}
|
||||
|
||||
// ExprList returns the list of valid expressions. This struct is not part of
|
||||
// the invariant interface, but it implements this anyways.
|
||||
func (obj *InvariantSolution) ExprList() []interfaces.Expr {
|
||||
exprs := []interfaces.Expr{}
|
||||
for _, x := range obj.Solutions {
|
||||
exprs = append(exprs, x.ExprList()...)
|
||||
}
|
||||
return exprs
|
||||
}
|
||||
|
||||
// ExclusivesProduct returns a list of different products produced from the
|
||||
// combinatorial product of the list of exclusives. Each ExclusiveInvariant must
|
||||
// contain between one and more Invariants. This takes every combination of
|
||||
// Invariants (choosing one from each ExclusiveInvariant) and returns that list.
|
||||
// In other words, if you have three exclusives, with invariants named (A1, B1),
|
||||
// (A2), and (A3, B3, C3) you'll get: (A1, A2, A3), (A1, A2, B3), (A1, A2, C3),
|
||||
// (B1, A2, A3), (B1, A2, B3), (B1, A2, C3) as results for this function call.
|
||||
func ExclusivesProduct(exclusives []*interfaces.ExclusiveInvariant) [][]interfaces.Invariant {
|
||||
if len(exclusives) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
length := func(i int) int { return len(exclusives[i].Invariants) }
|
||||
|
||||
// NextIx sets ix to the lexicographically next value,
|
||||
// such that for each i > 0, 0 <= ix[i] < length(i).
|
||||
NextIx := func(ix []int) {
|
||||
for i := len(ix) - 1; i >= 0; i-- {
|
||||
ix[i]++
|
||||
if i == 0 || ix[i] < length(i) {
|
||||
return
|
||||
}
|
||||
ix[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
results := [][]interfaces.Invariant{}
|
||||
|
||||
for ix := make([]int, len(exclusives)); ix[0] < length(0); NextIx(ix) {
|
||||
x := []interfaces.Invariant{}
|
||||
for j, k := range ix {
|
||||
x = append(x, exclusives[j].Invariants[k])
|
||||
}
|
||||
results = append(results, x)
|
||||
}
|
||||
|
||||
return results
|
||||
// EqualsInvariant is an invariant that symbolizes that the expression has a
|
||||
// known type. It is used for producing solutions.
|
||||
type EqualsInvariant struct {
|
||||
Expr interfaces.Expr
|
||||
Type *types.Type
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user