From aec8e1db2d9e6533c923d54f363347376d0d3fad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20G=C3=A9lineau?= Date: Sat, 24 Jun 2023 13:06:17 -0400 Subject: [PATCH] lang: unification: Add type inference state debugging This presents things in a more formal way for those who are more familiar with standard type unification syntax. Patch created by Sam, and cleaned up by James. --- lang/unification/simplesolver.go | 122 +++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index e6c20b7b..7a4e5f6f 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -19,6 +19,7 @@ package unification // TODO: can we put this solver in a sub-package? import ( "fmt" + "sort" "github.com/purpleidea/mgmt/lang/interfaces" "github.com/purpleidea/mgmt/lang/types" @@ -57,6 +58,126 @@ func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) fun } } +// DebugSolverState helps us in understanding the state of the type unification +// solver in a more mainstream format. +// Example: +// +// solver state: +// +// * str("foo") :: str +// * call:f(str("foo")) [0xc000ac9f10] :: ?1 +// * var(x) [0xc00088d840] :: ?2 +// * param(x) [0xc00000f950] :: ?3 +// * func(x) { var(x) } [0xc0000e9680] :: ?4 +// * ?2 = ?3 +// * ?4 = func(arg0 str) ?1 +// * ?4 = func(x str) ?2 +// * ?1 = ?2 +func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []interfaces.Invariant) string { + s := "" + + // all the relevant Exprs + count := 0 + exprs := make(map[interfaces.Expr]int) + for _, equality := range equalities { + for _, expr := range equality.ExprList() { + count++ + exprs[expr] = count // for sorting + } + } + + // print the solved Exprs first + for expr, typ := range solved { + s += fmt.Sprintf("%v :: %v\n", expr, typ) + delete(exprs, expr) + } + + sortedExprs := []interfaces.Expr{} + for k := range exprs { + sortedExprs = append(sortedExprs, k) + } + sort.Slice(sortedExprs, func(i, j int) bool { return exprs[sortedExprs[i]] < exprs[sortedExprs[j]] }) + + // for each remaining expr, generate a shorter name than the full pointer + nextVar := 1 + shortNames := map[interfaces.Expr]string{} + for _, expr := range sortedExprs { + shortNames[expr] = fmt.Sprintf("?%d", nextVar) + nextVar++ + s += fmt.Sprintf("%p %v :: %s\n", expr, expr, shortNames[expr]) + } + + // print all the equalities using the short names + for _, equality := range equalities { + switch e := equality.(type) { + case *interfaces.EqualsInvariant: + _, ok := solved[e.Expr] + if !ok { + s += fmt.Sprintf("%s = %v\n", shortNames[e.Expr], e.Type) + } else { + // if solved, then this is redundant, don't print anything + } + + case *interfaces.EqualityInvariant: + type1, ok1 := solved[e.Expr1] + type2, ok2 := solved[e.Expr2] + if !ok1 && !ok2 { + s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], shortNames[e.Expr2]) + } else if ok1 && !ok2 { + s += fmt.Sprintf("%s = %s\n", type1, shortNames[e.Expr2]) + } else if !ok1 && ok2 { + s += fmt.Sprintf("%s = %s\n", shortNames[e.Expr1], type2) + } else { + // if completely solved, then this is redundant, don't print anything + } + + case *interfaces.EqualityWrapFuncInvariant: + funcType, funcOk := solved[e.Expr1] + + args := "" + argsOk := true + for i, argName := range e.Expr2Ord { + if i > 0 { + args += ", " + } + argExpr := e.Expr2Map[argName] + argType, ok := solved[argExpr] + if !ok { + args += fmt.Sprintf("%s %s", argName, shortNames[argExpr]) + argsOk = false + } else { + args += fmt.Sprintf("%s %s", argName, argType) + } + } + + outType, outOk := solved[e.Expr2Out] + + if !funcOk || !argsOk || !outOk { + if !funcOk && !outOk { + s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, shortNames[e.Expr2Out]) + } else if !funcOk && outOk { + s += fmt.Sprintf("%s = func(%s) %s\n", shortNames[e.Expr1], args, outType) + } else if funcOk && !outOk { + s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, shortNames[e.Expr2Out]) + } else { + s += fmt.Sprintf("%s = func(%s) %s\n", funcType, args, outType) + } + } + + case *interfaces.CallFuncArgsValueInvariant: + // skip, not used in the examples I care about + + case *interfaces.AnyInvariant: + // skip, not used in the examples I care about + + default: + s += fmt.Sprintf("%v\n", equality) + } + } + + return s +} + // SimpleInvariantSolver is an iterative invariant solver for AST expressions. // It is intended to be very simple, even if it's computationally inefficient. func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) { @@ -876,6 +997,7 @@ Loop: logf("%s: unsolved: %+v", Name, x) } } + logf("%s: solver state:\n%s", Name, DebugSolverState(solved, equalities)) // Lastly, we could loop through each exclusive and see // if it only has a single, easy solution. For example,