diff --git a/lang/funcs/simplepoly/simplepoly.go b/lang/funcs/simplepoly/simplepoly.go index 06164e41..03a7deb4 100644 --- a/lang/funcs/simplepoly/simplepoly.go +++ b/lang/funcs/simplepoly/simplepoly.go @@ -32,6 +32,15 @@ const ( // API or not. If we don't use it, then these simple functions are // wrapped with the struct below. DirectInterface = false // XXX: fix any bugs and set to true! + + // AllowSimplePolyVariantDefinitions specifies whether we're allowed to + // include the `variant` type in definitons for simple poly functions. + // Long term, it's probably better to have this be false because it adds + // complexity into this simple poly API, and the root of which is the + // argComplexCmp which is only moderately powerful, but I figured I'd + // try and allow this for now because I liked how elegant the definition + // of the len() function was. + AllowSimplePolyVariantDefinitions = true ) // RegisteredFuncs maps a function name to the corresponding static, pure funcs. @@ -56,10 +65,16 @@ func Register(name string, fns []*types.FuncValue) { // check for uniqueness in type signatures typs := []*types.Type{} - for _, f := range fns { + for i, f := range fns { if f.T == nil { panic(fmt.Sprintf("polyfunc %s contains a nil type signature", name)) } + if f.T.Kind != types.KindFunc { // even when this includes a variant + panic(fmt.Sprintf("polyfunc %s must be of kind func", name)) + } + if !AllowSimplePolyVariantDefinitions && f.T.HasVariant() { + panic(fmt.Sprintf("polyfunc %s contains a variant type signature at index: %d", name, i)) + } typs = append(typs, f.T) } @@ -190,7 +205,7 @@ func (obj *WrappedFunc) Unify(expr interfaces.Expr) ([]interfaces.Invariant, err if cfavInvar.Func != expr { continue } - // cfavInvar.Expr is the ExprCall! + // cfavInvar.Expr is the ExprCall! (the return pointer) // cfavInvar.Args are the args that ExprCall uses! // any number of args are permitted @@ -279,9 +294,80 @@ func (obj *WrappedFunc) Unify(expr interfaces.Expr) ([]interfaces.Invariant, err return true // possible } + argComplexCmp := func(typ *types.Type) (*types.Type, bool) { + if !typ.HasVariant() { + return typ, argCmp(typ) + } + + mapped := make(map[string]*types.Type) + ordered := []string{} + out := typ.Out + if len(cfavInvar.Args) != len(typ.Ord) { + return nil, false // arg length differs + } + for i, x := range cfavInvar.Args { + name := typ.Ord[i] + if t, err := x.Type(); err == nil { + if _, err := t.ComplexCmp(typ.Map[typ.Ord[i]]); err != nil { + return nil, false // impossible! + } + mapped[name] = t // found it + } + + // is the type already known as solved? + if t, exists := solved[x]; exists { // alternate way to lookup type + if _, err := t.ComplexCmp(typ.Map[typ.Ord[i]]); err != nil { + return nil, false // impossible! + } + // check it matches the above type + if oldT, exists := mapped[name]; exists && t.Cmp(oldT) != nil { + return nil, false // impossible! + } + mapped[name] = t // found it + } + if _, exists := mapped[name]; !exists { + // impossible, but for a + // different reason: we don't + // have enough information to + // plausibly allow this type to + // pass through, because we'd + // leave a variant in, so skip + // it. We'll probably fail in + // the end with a misleading + // "only recursive solutions + // left" error, but it just + // means we can't solve this! + return nil, false + } + ordered = append(ordered, name) + } + + // if we happen to know the type of the return expr + if t, exists := solved[cfavInvar.Expr]; exists { + if out != nil && t.Cmp(out) != nil { + return nil, false // inconsistent! + } + out = t // learn! + } + + return &types.Type{ + Kind: types.KindFunc, + Map: mapped, + Ord: ordered, + Out: out, + }, true // possible + } + var invariants []interfaces.Invariant var invar interfaces.Invariant + // add the relationship to the returned value + invar = &interfaces.EqualityInvariant{ + Expr1: cfavInvar.Expr, + Expr2: dummyOut, + } + invariants = append(invariants, invar) + ors := []interfaces.Invariant{} // solve only one from this list for _, f := range obj.Fns { // operator func types typ := f.T @@ -293,9 +379,18 @@ func (obj *WrappedFunc) Unify(expr interfaces.Expr) ([]interfaces.Invariant, err return nil, fmt.Errorf("type must be a kind of func") } - if !argCmp(typ) { // filter out impossible types + // filter out impossible types, and on success, + // use the replacement type that we found here! + // this is because the input might be a variant + // and after processing this, we get a concrete + // type that can be substituted in here instead + if typ, ok = argComplexCmp(typ); !ok { continue // not a possible match } + if typ.HasVariant() { + // programming error + return nil, fmt.Errorf("a variant type snuck through: %+v", typ) + } invars, err := buildInvar(typ) if err != nil { diff --git a/lang/interpret_test/TestAstFunc2/tricky-unification0.output b/lang/interpret_test/TestAstFunc2/tricky-unification0.output new file mode 100644 index 00000000..932f3fed --- /dev/null +++ b/lang/interpret_test/TestAstFunc2/tricky-unification0.output @@ -0,0 +1 @@ +Vertex: test[len is: 4] diff --git a/lang/interpret_test/TestAstFunc2/tricky-unification0/main.mcl b/lang/interpret_test/TestAstFunc2/tricky-unification0/main.mcl new file mode 100644 index 00000000..8173619c --- /dev/null +++ b/lang/interpret_test/TestAstFunc2/tricky-unification0/main.mcl @@ -0,0 +1,3 @@ +import "fmt" +$ints = [13, 42, 0, -37,] +test fmt.printf("len is: %d", len($ints)) {} # len is 4 diff --git a/lang/interpret_test/TestAstFunc2/tricky-unification1.output b/lang/interpret_test/TestAstFunc2/tricky-unification1.output new file mode 100644 index 00000000..932f3fed --- /dev/null +++ b/lang/interpret_test/TestAstFunc2/tricky-unification1.output @@ -0,0 +1 @@ +Vertex: test[len is: 4] diff --git a/lang/interpret_test/TestAstFunc2/tricky-unification1/main.mcl b/lang/interpret_test/TestAstFunc2/tricky-unification1/main.mcl new file mode 100644 index 00000000..78aa4a8b --- /dev/null +++ b/lang/interpret_test/TestAstFunc2/tricky-unification1/main.mcl @@ -0,0 +1,4 @@ +import "fmt" +$ints = [13, 42, 0, -37,] +$l int = len($ints) # return type is known statically! +test fmt.printf("len is: %d", $l) {} # len is 4