lang: Plumb through a context into unification

If we have a long type unification, we might want to cancel it early.
This also helps us visualize where we want context to be seen.
This commit is contained in:
James Shubin
2024-03-16 00:30:47 -04:00
parent a8b945e36e
commit 10319dd641
7 changed files with 32 additions and 15 deletions

View File

@@ -30,6 +30,7 @@
package unification // TODO: can we put this solver in a sub-package?
import (
"context"
"fmt"
"sort"
@@ -64,9 +65,9 @@ const (
// SimpleInvariantSolver with the log parameter of your choice specified. The
// result satisfies the correct signature for the solver parameter of the
// Unification function.
func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) {
return func(invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
return SimpleInvariantSolver(invariants, expected, logf)
func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func(context.Context, []interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) {
return func(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) {
return SimpleInvariantSolver(ctx, invariants, expected, logf)
}
}
@@ -195,7 +196,7 @@ func DebugSolverState(solved map[interfaces.Expr]*types.Type, equalities []inter
// SimpleInvariantSolver is an iterative invariant solver for AST expressions.
// It is intended to be very simple, even if it's computationally inefficient.
func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) {
func SimpleInvariantSolver(ctx context.Context, invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) {
debug := false // XXX: add to interface
process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) {
equalities := []interfaces.Invariant{}
@@ -395,6 +396,12 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfa
// run until we're solved, stop consuming equalities, or type clash
Loop:
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// pass
}
// Once we're done solving everything else except the generators
// then we can exit, but we want to make sure the generators had
// a chance to "speak up" and make sure they were part of Unify.
@@ -1314,6 +1321,12 @@ Loop:
// let's try each combination, one at a time...
for i, ex := range exclusivesProduct(exclusives) { // [][]interfaces.Invariant
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// pass
}
logf("%s: exclusive(%d):\n%+v", Name, i, ex)
// we could waste a lot of cpu, and start from
// the beginning, but instead we could use the
@@ -1324,7 +1337,7 @@ Loop:
recursiveInvariants = append(recursiveInvariants, ex...)
// FIXME: implement RecursionDepthLimit
logf("%s: recursing...", Name)
solution, err := SimpleInvariantSolver(recursiveInvariants, expected, logf)
solution, err := SimpleInvariantSolver(ctx, recursiveInvariants, expected, logf)
if err != nil {
logf("%s: recursive solution failed: %+v", Name, err)
continue // no solution found here...