diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index 5354c967..c6b0411e 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -64,8 +64,10 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfa process := func(invariants []interfaces.Invariant) ([]interfaces.Invariant, []*interfaces.ExclusiveInvariant, error) { equalities := []interfaces.Invariant{} exclusives := []*interfaces.ExclusiveInvariant{} + generators := []interfaces.Invariant{} - for _, x := range invariants { + for ix := 0; len(invariants) > ix; ix++ { // while + x := invariants[ix] switch invariant := x.(type) { case *interfaces.EqualsInvariant: equalities = append(equalities, invariant) @@ -103,11 +105,13 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfa case *interfaces.EqualityWrapCallInvariant: equalities = append(equalities, invariant) + case *interfaces.GeneratorInvariant: + // these are special, note the different list + generators = append(generators, invariant) + // contains a list of invariants which this represents case *interfaces.ConjunctionInvariant: - for _, invar := range invariant.Invariants { - equalities = append(equalities, invar) - } + invariants = append(invariants, invariant.Invariants...) case *interfaces.ExclusiveInvariant: // these are special, note the different list @@ -118,11 +122,41 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfa case *interfaces.AnyInvariant: equalities = append(equalities, invariant) + case *interfaces.ValueInvariant: + equalities = append(equalities, invariant) + + case *interfaces.CallFuncArgsValueInvariant: + equalities = append(equalities, invariant) + default: return nil, nil, fmt.Errorf("unknown invariant type: %T", x) } } + // optimization: if we have zero generator invariants, we can + // discard the value invariants! + if len(generators) == 0 { + used := []int{} + for i, x := range equalities { + if _, ok := x.(*interfaces.ValueInvariant); !ok { + continue + } + if _, ok := x.(*interfaces.CallFuncArgsValueInvariant); !ok { + continue + } + used = append(used, i) // mark equality as used up + } + // delete used equalities, in reverse order to preserve indexing! + for i := len(used) - 1; i >= 0; i-- { + ix := used[i] // delete index that was marked as used! + equalities = append(equalities[:ix], equalities[ix+1:]...) + } + } + + // append the generators at the end + // (they can go in any order, but it's more optimal this way) + equalities = append(equalities, generators...) + return equalities, exclusives, nil } @@ -564,6 +598,40 @@ Loop: panic("reached unexpected code") + case *interfaces.GeneratorInvariant: + // this invariant can generate new ones + + // optimization: we want to run the generators + // last (but before the exclusives) because + // they take longer to run. So as long as we've + // made progress this time around, don't run + // this just yet, there's still time left... + if len(used) > 0 { + continue + } + + // If this returns nil, we add the invariants + // it returned and we remove it from the list. + // If we error, it's because we don't have any + // new information to provide at this time... + // XXX: should we pass in `invariants` instead? + gi, err := eq.Func(equalities, solved) + if err != nil { + continue + } + + eqs, exs, err := process(gi) // process like at the top + if err != nil { + // programming error? + return nil, errwrap.Wrapf(err, "processing error") + } + equalities = append(equalities, eqs...) + exclusives = append(exclusives, exs...) + + used = append(used, i) // mark equality as used up + logf("%s: solved `generator` equality", Name) + continue + // wtf matching case *interfaces.AnyInvariant: // this basically ensures that the expr gets solved @@ -573,6 +641,16 @@ Loop: } continue + case *interfaces.ValueInvariant: + // don't consume these, they're stored in case + // a generator invariant wants to read them... + continue + + case *interfaces.CallFuncArgsValueInvariant: + // don't consume these, they're stored in case + // a generator invariant wants to read them... + continue + default: return nil, fmt.Errorf("unknown invariant type: %T", x) } @@ -692,13 +770,13 @@ Loop: } // Add new equalities and exclusives onto state globals. - eq, ex, err := process(simplified) // process like at the top + eqs, exs, err := process(simplified) // process like at the top if err != nil { // programming error? return nil, errwrap.Wrapf(err, "processing error") } - equalities = append(equalities, eq...) - exclusives = append(exclusives, ex...) + equalities = append(equalities, eqs...) + exclusives = append(exclusives, exs...) // If we removed any exclusives, then we can start over. if len(done) > 0 {