lang: unification: Add type inference state debugging
This presents things in a more formal way for those who are more familiar with standard type unification syntax. Patch created by Sam, and cleaned up by James.
This commit is contained in:
committed by
James Shubin
parent
170fb64bfc
commit
aec8e1db2d
@@ -19,6 +19,7 @@ package unification // TODO: can we put this solver in a sub-package?
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/purpleidea/mgmt/lang/interfaces"
|
||||
"github.com/purpleidea/mgmt/lang/types"
|
||||
@@ -57,6 +58,126 @@ func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) fun
|
||||
}
|
||||
}
|
||||
|
||||
// DebugSolverState helps us in understanding the state of the type unification
|
||||
// solver in a more mainstream format.
|
||||
// Example:
|
||||
//
|
||||
// solver state:
|
||||
//
|
||||
// * str("foo") :: str
|
||||
// * call:f(str("foo")) [0xc000ac9f10] :: ?1
|
||||
// * var(x) [0xc00088d840] :: ?2
|
||||
// * param(x) [0xc00000f950] :: ?3
|
||||
// * func(x) { var(x) } [0xc0000e9680] :: ?4
|
||||
// * ?2 = ?3
|
||||
// * ?4 = func(arg0 str) ?1
|
||||
// * ?4 = func(x str) ?2
|
||||
// * ?1 = ?2
|
||||
func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string {
|
||||
s := ""
|
||||
|
||||
// all the relevant Exprs
|
||||
count := 0
|
||||
exprs := make(map[interfaces.Expr]int)
|
||||
for _, equality := range equalities {
|
||||
for _, expr := range equality.ExprList() {
|
||||
count++
|
||||
exprs[expr] = count // for sorting
|
||||
}
|
||||
}
|
||||
|
||||
// print the solved Exprs first
|
||||
for expr, typ := range solved {
|
||||
s += fmt.Sprintf("%v :: %v\n", expr, typ)
|
||||
delete(exprs, expr)
|
||||
}
|
||||
|
||||
sortedExprs := []interfaces.Expr{}
|
||||
for k := range exprs {
|
||||
sortedExprs = append(sortedExprs, k)
|
||||
}
|
||||
sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] })
|
||||
|
||||
// for each remaining expr, generate a shorter name than the full pointer
|
||||
nextVar := 1
|
||||
shortNames := map[interfaces.Expr]string{}
|
||||
for _, expr := range sortedExprs {
|
||||
shortNames[expr] = fmt.Sprintf("?%d", nextVar)
|
||||
nextVar++
|
||||
s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr])
|
||||
}
|
||||
|
||||
// print all the equalities using the short names
|
||||
for _, equality := range equalities {
|
||||
switch e := equality.(type) {
|
||||
case *interfaces.EqualsInvariant:
|
||||
_, ok := solved[e.Expr]
|
||||
if !ok {
|
||||
s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type)
|
||||
} else {
|
||||
// if solved, then this is redundant, don't print anything
|
||||
}
|
||||
|
||||
case *interfaces.EqualityInvariant:
|
||||
type1, ok1 := solved[e.Expr1]
|
||||
type2, ok2 := solved[e.Expr2]
|
||||
if !ok1 && !ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2])
|
||||
} else if ok1 && !ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2])
|
||||
} else if !ok1 && ok2 {
|
||||
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2)
|
||||
} else {
|
||||
// if completely solved, then this is redundant, don't print anything
|
||||
}
|
||||
|
||||
case *interfaces.EqualityWrapFuncInvariant:
|
||||
funcType, funcOk := solved[e.Expr1]
|
||||
|
||||
args := ""
|
||||
argsOk := true
|
||||
for i, argName := range e.Expr2Ord {
|
||||
if i > 0 {
|
||||
args += ", "
|
||||
}
|
||||
argExpr := e.Expr2Map[argName]
|
||||
argType, ok := solved[argExpr]
|
||||
if !ok {
|
||||
args += fmt.Sprintf("%s %s", argName, shortNames[argExpr])
|
||||
argsOk = false
|
||||
} else {
|
||||
args += fmt.Sprintf("%s %s", argName, argType)
|
||||
}
|
||||
}
|
||||
|
||||
outType, outOk := solved[e.Expr2Out]
|
||||
|
||||
if !funcOk || !argsOk || !outOk {
|
||||
if !funcOk && !outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out])
|
||||
} else if !funcOk && outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType)
|
||||
} else if funcOk && !outOk {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out])
|
||||
} else {
|
||||
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType)
|
||||
}
|
||||
}
|
||||
|
||||
case *interfaces.CallFuncArgsValueInvariant:
|
||||
// skip, not used in the examples I care about
|
||||
|
||||
case *interfaces.AnyInvariant:
|
||||
// skip, not used in the examples I care about
|
||||
|
||||
default:
|
||||
s += fmt.Sprintf("%v\n", equality)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SimpleInvariantSolver is an iterative invariant solver for AST expressions.
|
||||
// It is intended to be very simple, even if it's computationally inefficient.
|
||||
func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) {
|
||||
@@ -876,6 +997,7 @@ Loop:
|
||||
logf("%s: unsolved: %+v", Name, x)
|
||||
}
|
||||
}
|
||||
logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities))
|
||||
|
||||
// Lastly, we could loop through each exclusive and see
|
||||
// if it only has a single, easy solution. For example,
|
||||
|
||||
Reference in New Issue
Block a user