lang: unification: Add type inference state debugging

This presents things in a more formal way for those who are more
familiar with standard type unification syntax.

Patch created by Sam, and cleaned up by James.
This commit is contained in:
Samuel Gélineau
2023-06-24 13:06:17 -04:00
committed by James Shubin
parent 170fb64bfc
commit aec8e1db2d

View File

@@ -19,6 +19,7 @@ package unification // TODO: can we put this solver in a sub-package?
import ( import (
"fmt" "fmt"
"sort"
"github.com/purpleidea/mgmt/lang/interfaces" "github.com/purpleidea/mgmt/lang/interfaces"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -57,6 +58,126 @@ func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) fun
} }
} }
// DebugSolverState helps us in understanding the state of the type unification
// solver in a more mainstream format.
// Example:
//
// solver state:
//
// * str("foo") :: str
// * call:f(str("foo")) [0xc000ac9f10] :: ?1
// * var(x) [0xc00088d840] :: ?2
// * param(x) [0xc00000f950] :: ?3
// * func(x) { var(x) } [0xc0000e9680] :: ?4
// * ?2 = ?3
// * ?4 = func(arg0 str) ?1
// * ?4 = func(x str) ?2
// * ?1 = ?2
func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string {
s := ""
// all the relevant Exprs
count := 0
exprs := make(map[interfaces.Expr]int)
for _, equality := range equalities {
for _, expr := range equality.ExprList() {
count++
exprs[expr] = count // for sorting
}
}
// print the solved Exprs first
for expr, typ := range solved {
s += fmt.Sprintf("%v :: %v\n", expr, typ)
delete(exprs, expr)
}
sortedExprs := []interfaces.Expr{}
for k := range exprs {
sortedExprs = append(sortedExprs, k)
}
sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] })
// for each remaining expr, generate a shorter name than the full pointer
nextVar := 1
shortNames := map[interfaces.Expr]string{}
for _, expr := range sortedExprs {
shortNames[expr] = fmt.Sprintf("?%d", nextVar)
nextVar++
s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr])
}
// print all the equalities using the short names
for _, equality := range equalities {
switch e := equality.(type) {
case *interfaces.EqualsInvariant:
_, ok := solved[e.Expr]
if !ok {
s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type)
} else {
// if solved, then this is redundant, don't print anything
}
case *interfaces.EqualityInvariant:
type1, ok1 := solved[e.Expr1]
type2, ok2 := solved[e.Expr2]
if !ok1 && !ok2 {
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2])
} else if ok1 && !ok2 {
s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2])
} else if !ok1 && ok2 {
s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2)
} else {
// if completely solved, then this is redundant, don't print anything
}
case *interfaces.EqualityWrapFuncInvariant:
funcType, funcOk := solved[e.Expr1]
args := ""
argsOk := true
for i, argName := range e.Expr2Ord {
if i > 0 {
args += ", "
}
argExpr := e.Expr2Map[argName]
argType, ok := solved[argExpr]
if !ok {
args += fmt.Sprintf("%s %s", argName, shortNames[argExpr])
argsOk = false
} else {
args += fmt.Sprintf("%s %s", argName, argType)
}
}
outType, outOk := solved[e.Expr2Out]
if !funcOk || !argsOk || !outOk {
if !funcOk && !outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out])
} else if !funcOk && outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType)
} else if funcOk && !outOk {
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out])
} else {
s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType)
}
}
case *interfaces.CallFuncArgsValueInvariant:
// skip, not used in the examples I care about
case *interfaces.AnyInvariant:
// skip, not used in the examples I care about
default:
s += fmt.Sprintf("%v\n", equality)
}
}
return s
}
// SimpleInvariantSolver is an iterative invariant solver for AST expressions. // SimpleInvariantSolver is an iterative invariant solver for AST expressions.
// It is intended to be very simple, even if it's computationally inefficient. // It is intended to be very simple, even if it's computationally inefficient.
func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) { func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) {
@@ -876,6 +997,7 @@ Loop:
logf("%s: unsolved: %+v", Name, x) logf("%s: unsolved: %+v", Name, x)
} }
} }
logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities))
// Lastly, we could loop through each exclusive and see // Lastly, we could loop through each exclusive and see
// if it only has a single, easy solution. For example, // if it only has a single, easy solution. For example,