From 10319dd641db230fad2b3383adf1277cc4d97bcc Mon Sep 17 00:00:00 2001 From: James Shubin Date: Sat, 16 Mar 2024 00:30:47 -0400 Subject: [PATCH] lang: Plumb through a context into unification If we have a long type unification, we might want to cancel it early. This also helps us visualize where we want context to be seen. --- lang/gapi/gapi.go | 2 +- lang/interpret_test.go | 6 +++--- lang/lang.go | 2 +- lang/unification/simplesolver.go | 23 ++++++++++++++++++----- lang/unification/simplesolver_test.go | 3 ++- lang/unification/unification.go | 8 +++++--- lang/unification_test.go | 3 ++- 7 files changed, 32 insertions(+), 15 deletions(-) diff --git a/lang/gapi/gapi.go b/lang/gapi/gapi.go index 25f31f61..5f498aaa 100644 --- a/lang/gapi/gapi.go +++ b/lang/gapi/gapi.go @@ -276,7 +276,7 @@ func (obj *GAPI) Cli(info *gapi.Info) (*gapi.Deploy, error) { Debug: debug, Logf: unificationLogf, } - unifyErr := unifier.Unify() + unifyErr := unifier.Unify(context.TODO()) delta := time.Since(startTime) formatted := delta.String() if delta.Milliseconds() > 1000 { // 1 second diff --git a/lang/interpret_test.go b/lang/interpret_test.go index c4c30a76..cd63617f 100644 --- a/lang/interpret_test.go +++ b/lang/interpret_test.go @@ -464,7 +464,7 @@ func TestAstFunc1(t *testing.T) { Debug: testing.Verbose(), Logf: xlogf, } - err = unifier.Unify() + err = unifier.Unify(context.TODO()) if (!fail || !failUnify) && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not unify types: %+v", index, err) @@ -1034,7 +1034,7 @@ func TestAstFunc2(t *testing.T) { Debug: testing.Verbose(), Logf: xlogf, } - err = unifier.Unify() + err = unifier.Unify(context.TODO()) if (!fail || !failUnify) && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not unify types: %+v", index, err) @@ -1836,7 +1836,7 @@ func TestAstFunc3(t *testing.T) { Debug: testing.Verbose(), Logf: xlogf, } - err = unifier.Unify() + err = unifier.Unify(context.TODO()) if (!fail || !failUnify) && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not unify types: %+v", index, err) diff --git a/lang/lang.go b/lang/lang.go index 195d1906..7b37ece2 100644 --- a/lang/lang.go +++ b/lang/lang.go @@ -232,7 +232,7 @@ func (obj *Lang) Init() error { Debug: obj.Debug, Logf: logf, } - unifyErr := unifier.Unify() + unifyErr := unifier.Unify(context.TODO()) obj.Logf("type unification took: %s", time.Since(timing)) if unifyErr != nil { return errwrap.Wrapf(unifyErr, "could not unify types") diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index 5186e035..4149b05f 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -30,6 +30,7 @@ package unification // TODO: can we put this solver in a sub-package? import ( + "context" "fmt" "sort" @@ -64,9 +65,9 @@ const ( // SimpleInvariantSolver with the log parameter of your choice specified. The // result satisfies the correct signature for the solver parameter of the // Unification function. -func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) { - return func(invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { - return SimpleInvariantSolver(invariants, expected, logf) +func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) { + return func(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { + return SimpleInvariantSolver(ctx, invariants, expected, logf) } } @@ -195,7 +196,7 @@ func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []inter // SimpleInvariantSolver is an iterative invariant solver for AST expressions. // It is intended to be very simple, even if it's computationally inefficient. -func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) { +func SimpleInvariantSolver(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) { debug := false // XXX: add to interface process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) { equalities := []interfaces.Invariant{} @@ -395,6 +396,12 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfa // run until we're solved, stop consuming equalities, or type clash Loop: for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + // pass + } // Once we're done solving everything else except the generators // then we can exit, but we want to make sure the generators had // a chance to "speak up" and make sure they were part of Unify. @@ -1314,6 +1321,12 @@ Loop: // let's try each combination, one at a time... for i, ex := range exclusivesProduct(exclusives) { // [][]interfaces.Invariant + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + // pass + } logf("%s: exclusive(%d):\n%+v", Name, i, ex) // we could waste a lot of cpu, and start from // the beginning, but instead we could use the @@ -1324,7 +1337,7 @@ Loop: recursiveInvariants = append(recursiveInvariants, ex...) // FIXME: implement RecursionDepthLimit logf("%s: recursing...", Name) - solution, err := SimpleInvariantSolver(recursiveInvariants, expected, logf) + solution, err := SimpleInvariantSolver(ctx, recursiveInvariants, expected, logf) if err != nil { logf("%s: recursive solution failed: %+v", Name, err) continue // no solution found here... diff --git a/lang/unification/simplesolver_test.go b/lang/unification/simplesolver_test.go index 30975800..84c2b47a 100644 --- a/lang/unification/simplesolver_test.go +++ b/lang/unification/simplesolver_test.go @@ -32,6 +32,7 @@ package unification import ( + "context" "fmt" "strings" "testing" @@ -265,7 +266,7 @@ func TestSimpleSolver1(t *testing.T) { solver := SimpleInvariantSolverLogger(logf) // generates a solver with built-in logging - solution, err := solver(invariants, expected) + solution, err := solver(context.TODO(), invariants, expected) t.Logf("test #%d: solver completed with: %+v", index, err) if !fail && err != nil { diff --git a/lang/unification/unification.go b/lang/unification/unification.go index 64d6e990..1b939d54 100644 --- a/lang/unification/unification.go +++ b/lang/unification/unification.go @@ -32,6 +32,7 @@ package unification import ( + "context" "fmt" "sort" "strings" @@ -45,7 +46,8 @@ type Unifier struct { AST interfaces.Stmt // Solver is the solver algorithm implementation to use. - Solver func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) + // XXX: Solver should be a solver interface, not a function signature. + Solver func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) Debug bool Logf func(format string, v ...interface{}) @@ -63,7 +65,7 @@ type Unifier struct { // type. This function and logic was invented after the author could not find // any proper literature or examples describing a well-known implementation of // this process. Improvements and polite recommendations are welcome. -func (obj *Unifier) Unify() error { +func (obj *Unifier) Unify(ctx context.Context) error { if obj.AST == nil { return fmt.Errorf("the AST is nil") } @@ -96,7 +98,7 @@ func (obj *Unifier) Unify() error { exprMap := ExprListToExprMap(exprs) // makes searching faster exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates) - solved, err := obj.Solver(invariants, exprList) + solved, err := obj.Solver(ctx, invariants, exprList) if err != nil { return err } diff --git a/lang/unification_test.go b/lang/unification_test.go index f6b9f06d..429a84e2 100644 --- a/lang/unification_test.go +++ b/lang/unification_test.go @@ -32,6 +32,7 @@ package lang // XXX: move this to the unification package import ( + "context" "fmt" "strings" "testing" @@ -856,7 +857,7 @@ func TestUnification1(t *testing.T) { Debug: testing.Verbose(), Logf: logf, } - err = unifier.Unify() + err = unifier.Unify(context.TODO()) // TODO: print out the AST's so that we can see the types t.Logf("\n\ntest #%d: AST (after): %+v\n", index, xast)