diff --git a/lang/types/type.go b/lang/types/type.go index e8089236..c29ab6e9 100644 --- a/lang/types/type.go +++ b/lang/types/type.go @@ -33,15 +33,20 @@ import ( "fmt" "net" "reflect" + "strconv" "strings" "github.com/purpleidea/mgmt/util" + "github.com/purpleidea/mgmt/util/disjoint" "github.com/purpleidea/mgmt/util/errwrap" ) const ( // StructTag is the key we use in struct field names for key mapping. StructTag = "lang" + + // MaxInt8 is 127. It's max uint8: ^uint8(0), then we >> 1 for max int8. + MaxInt8 = int((^uint8(0)) >> 1) ) // Basic types defined here as a convenience for use with Type.Cmp(X). @@ -71,6 +76,8 @@ const ( KindStruct KindFunc KindVariant + + KindUnification = Kind(MaxInt8) // keep this last ) // Type is the datastructure representing any type. It can be recursive for @@ -85,6 +92,20 @@ type Type struct { Ord []string Out *Type // if Kind == Func, use Map and Ord for Input, Out for Output Var *Type // if Kind == Variant, use Var only + + // unification variable (question mark, eg ?1, ?2) + Uni *Elem // if Kind == Unification (optional) use Uni only +} + +// Elem is the type used for the unification variable in the Uni field of Type. +// We create this alias here to avoid needing to write *disjoint.Elem[*Type] all +// over. This is a golang type alias. These should be created with NewElem. +type Elem = disjoint.Elem[*Type] + +// NewElem creates a new set with one element and returns the sole element (the +// representative element) of that set. +func NewElem() *Elem { + return disjoint.NewElem[*Type]() } // TypeOf takes a reflect.Type and returns an equivalent *Type. It removes any @@ -340,6 +361,14 @@ func ConfigurableTypeOf(t reflect.Type, opts ...TypeOfOption) (*Type, error) { // NewType creates the Type from the string representation. func NewType(s string) *Type { + table := make(map[uint]*Elem) + return newType(s, table) +} + +// newType creates the Type from the string representation. This private version +// takes a table so that we can collect unification variables as we see them and +// return a type with correctly unified unification variables. +func newType(s string, table map[uint]*Elem) *Type { switch s { case "bool": return &Type{ @@ -361,7 +390,7 @@ func NewType(s string) *Type { // KindList if strings.HasPrefix(s, "[]") { - val := NewType(s[len("[]"):]) + val := newType(s[len("[]"):], table) if val == nil { return nil } @@ -395,11 +424,11 @@ func NewType(s string) *Type { return nil } - key := NewType(strings.Trim(s[:found], " ")) + key := newType(strings.Trim(s[:found], " "), table) if key == nil { return nil } - val := NewType(strings.Trim(s[found+1:], " ")) + val := newType(strings.Trim(s[found+1:], " "), table) if val == nil { return nil } @@ -457,7 +486,7 @@ func NewType(s string) *Type { trim = 1 } - typ := NewType(strings.Trim(s[:found+1-trim], " ")) + typ := newType(strings.Trim(s[:found+1-trim], " "), table) if typ == nil { return nil } @@ -557,7 +586,7 @@ func NewType(s string) *Type { trim = 1 } - typ := NewType(strings.Trim(s[:found+1-trim], " ")) + typ := newType(strings.Trim(s[:found+1-trim], " "), table) if typ == nil { return nil } @@ -568,7 +597,7 @@ func NewType(s string) *Type { // return type var tail *Type if out != "" { // allow functions with no return type (in parser) - tail = NewType(out) + tail = newType(out, table) if tail == nil { return nil } @@ -589,6 +618,47 @@ func NewType(s string) *Type { } } + // KindUnification + if strings.HasPrefix(s, "?") { + // find end of number... + var length = 0 // number of digits + for i := len("?"); i < len(s); i++ { + c := s[i] + if length == 0 && c == '0' { + return nil // can't start with a zero + } + + // Check manually because strconv.ParseUint accepts ^0x. + if '0' <= c && c <= '9' { + length++ + continue + } + return nil // invalid char + } + + v := s[len("?") : len("?")+length] + n, err := strconv.ParseUint(v, 10, 32) // base 10, 32 bits + if err != nil { + return nil // programming error or overflow + } + num := uint(n) + + // XXX: Should we instead always return new unification + // variables, but call .Union() on all of the ones that have the + // same integer? Sam says they are equivalent. + uni, exists := table[num] + if !exists { + uni = NewElem() // unification variable, eg: ?1 + table[num] = uni // store + } + + // return a new type, may have an existing unification variable + return &Type{ + Kind: KindUnification, + Uni: uni, // unification variable, eg: ?1 + } + } + return nil // error (this also matches the empty string as input) } @@ -617,12 +687,21 @@ func (obj *Type) New() Value { return NewFunc(obj) case KindVariant: return NewVariant(obj) + case KindUnification: + panic("can't make new value from unification variable kind") } panic("malformed type") } // String returns the textual representation for this type. func (obj *Type) String() string { + table := make(map[*Elem]uint) + return obj.string(table) +} + +// string returns the textual representation for this type. This is a private +// helper function that is used by the real String function. +func (obj *Type) string(table map[*Elem]uint) string { switch obj.Kind { case KindBool: return "bool" @@ -637,13 +716,13 @@ func (obj *Type) String() string { if obj.Val == nil { panic("malformed list type") } - return "[]" + obj.Val.String() + return "[]" + obj.Val.string(table) case KindMap: if obj.Key == nil || obj.Val == nil { panic("malformed map type") } - return fmt.Sprintf("map{%s: %s}", obj.Key.String(), obj.Val.String()) + return fmt.Sprintf("map{%s: %s}", obj.Key.string(table), obj.Val.string(table)) case KindStruct: // {a bool; b int} if obj.Map == nil { @@ -661,7 +740,7 @@ func (obj *Type) String() string { if t == nil { panic("malformed struct field") } - s[i] = fmt.Sprintf("%s %s", k, t.String()) + s[i] = fmt.Sprintf("%s %s", k, t.string(table)) } return fmt.Sprintf("struct{%s}", strings.Join(s, "; ")) @@ -684,17 +763,37 @@ func (obj *Type) String() string { // We need to print function arg names for Copy() to use // the String() hack here and avoid erasing them here! - //s[i] = t.String() - s[i] = fmt.Sprintf("%s %s", k, t.String()) // strict + //s[i] = t.string(table) + s[i] = fmt.Sprintf("%s %s", k, t.string(table)) // strict } var out string if obj.Out != nil { - out = fmt.Sprintf(" %s", obj.Out.String()) + out = fmt.Sprintf(" %s", obj.Out.string(table)) } return fmt.Sprintf("func(%s)%s", strings.Join(s, ", "), out) case KindVariant: return "variant" + + case KindUnification: + if obj.Uni == nil { + panic("malformed unification variable") + } + + // XXX: Should we instead run .IsConnected() on the two Elem + // unification variables to determine if they should have the + // same integer representation when printing them? + num, exists := table[obj.Uni] + if !exists { + for _, n := range table { + num = max(num, n) + } + num++ // add 1 + table[obj.Uni] = num // store + } + + //fmt.Printf("?%d: %p\n", int(num), obj.Uni.Find()) // debug + return "?" + strconv.Itoa(int(num)) } panic("malformed type") @@ -702,6 +801,14 @@ func (obj *Type) String() string { // Cmp compares this type to another. func (obj *Type) Cmp(typ *Type) error { + table1 := make(map[*Elem]uint) // for obj + table2 := make(map[*Elem]uint) // for typ + return obj.cmp(typ, table1, table2) +} + +// cmp compares this type to another. This is a private helper function that is +// used by the real Cmp function. +func (obj *Type) cmp(typ *Type, table1, table2 map[*Elem]uint) error { if obj == nil || typ == nil { return fmt.Errorf("cannot compare to nil") } @@ -709,10 +816,10 @@ func (obj *Type) Cmp(typ *Type) error { // TODO: is this correct? // recurse into variants if we want base type comparisons //if obj.Kind == KindVariant { - // return obj.Var.Cmp(t) + // return obj.Var.cmp(t, table1, table2) //} //if t.Kind == KindVariant { - // return obj.Cmp(t.Var) + // return obj.cmp(t.Var, table1, table2) //} if obj.Kind != typ.Kind { @@ -732,14 +839,14 @@ func (obj *Type) Cmp(typ *Type) error { if obj.Val == nil || typ.Val == nil { panic("malformed list type") } - return obj.Val.Cmp(typ.Val) + return obj.Val.cmp(typ.Val, table1, table2) case KindMap: if obj.Key == nil || obj.Val == nil || typ.Key == nil || typ.Val == nil { panic("malformed map type") } - kerr := obj.Key.Cmp(typ.Key) - verr := obj.Val.Cmp(typ.Val) + kerr := obj.Key.cmp(typ.Key, table1, table2) + verr := obj.Val.cmp(typ.Val, table1, table2) if kerr != nil && verr != nil { return errwrap.Append(kerr, verr) // two errors } @@ -775,7 +882,7 @@ func (obj *Type) Cmp(typ *Type) error { if t1 == nil || t2 == nil { panic("malformed struct field") } - if err := t1.Cmp(t2); err != nil { + if err := t1.cmp(t2, table1, table2); err != nil { return err } } @@ -806,7 +913,7 @@ func (obj *Type) Cmp(typ *Type) error { // if t1 == nil || t2 == nil { // panic("malformed func arg") // } - // if err := t1.Cmp(t2); err != nil { + // if err := t1.cmp(t2, table1, table2); err != nil { // return err // } //} @@ -829,13 +936,13 @@ func (obj *Type) Cmp(typ *Type) error { panic("malformed func arg") } - if err := t1.Cmp(t2); err != nil { + if err := t1.cmp(t2, table1, table2); err != nil { return err } } if obj.Out != nil || typ.Out != nil { - if err := obj.Out.Cmp(typ.Out); err != nil { + if err := obj.Out.cmp(typ.Out, table1, table2); err != nil { return err } } @@ -848,6 +955,39 @@ func (obj *Type) Cmp(typ *Type) error { } // TODO: should we Cmp obj.Var with typ.Var ? -- not necessarily return nil + + // used for testing + case KindUnification: + if obj.Uni == nil || typ.Uni == nil { + panic("malformed unification variable") + } + + // If both types store and lookup variables symmetrically and in + // the same order, then the count's should also match. + // XXX: Should we instead run .IsConnected() on the two Elem + // unification variables to determine if they should have the + // same integer representation when printing them? + num1, exists := table1[obj.Uni] + if !exists { + for _, n := range table1 { + num1 = max(num1, n) + } + num1++ // add 1 + table1[obj.Uni] = num1 // store + } + + num2, exists := table2[typ.Uni] + if !exists { + for _, n := range table2 { + num2 = max(num2, n) + } + num2++ // add 1 + table2[typ.Uni] = num2 // store + } + if num1 != num2 { + return fmt.Errorf("unbalanced unification variables") + } + return nil } return fmt.Errorf("unknown kind") } @@ -1040,6 +1180,97 @@ func (obj *Type) HasVariant() bool { case KindVariant: return true // found it! + + case KindUnification: + return false // TODO: Do we want to panic here instead? + } + + panic("malformed type") +} + +// HasUni tells us if the type contains any unification variables. +func (obj *Type) HasUni() bool { + if obj == nil { + return false + } + if obj.Uni != nil { + return true // found it (by this method) + } + + switch obj.Kind { + case KindBool: + return false + case KindStr: + return false + case KindInt: + return false + case KindFloat: + return false + + case KindList: + if obj.Val == nil { + panic("malformed list type") + } + return obj.Val.HasUni() + + case KindMap: + if obj.Key == nil || obj.Val == nil { + panic("malformed map type") + } + return obj.Key.HasUni() || obj.Val.HasUni() + + case KindStruct: // {a bool; b int} + if obj.Map == nil { + panic("malformed struct type") + } + if len(obj.Map) != len(obj.Ord) { + panic("malformed struct length") + } + for _, k := range obj.Ord { + t, ok := obj.Map[k] + if !ok { + panic("malformed struct order") + } + if t == nil { + panic("malformed struct field") + } + if t.HasUni() { + return true + } + } + return false + + case KindFunc: + if obj.Map == nil { + panic("malformed func type") + } + if len(obj.Map) != len(obj.Ord) { + panic("malformed func length") + } + for _, k := range obj.Ord { + t, ok := obj.Map[k] + if !ok { + panic("malformed func order") + } + if t == nil { + panic("malformed func field") + } + if t.HasUni() { + return true + } + } + if obj.Out != nil { + if obj.Out.HasUni() { + return true + } + } + return false + + case KindVariant: + return obj.Var.HasUni() + + case KindUnification: + return true // found it! } panic("malformed type") @@ -1053,6 +1284,7 @@ func (obj *Type) HasVariant() bool { // string, and if it is compatible with the variant type it will be "variant"... // Comparing to a partial can only match "impossible" (error) or possible (nil). // This now also supports comparing a partial type to a variant type as well... +// TODO: Should we support KindUnification somehow? func (obj *Type) ComplexCmp(typ *Type) (string, error) { // match simple "placeholder" variants... skip variants w/ sub types isVariant := func(t *Type) bool { return t != nil && t.Kind == KindVariant && t.Var == nil } diff --git a/lang/types/type_test.go b/lang/types/type_test.go index 23c169ca..b8b3c57c 100644 --- a/lang/types/type_test.go +++ b/lang/types/type_test.go @@ -1498,6 +1498,270 @@ func TestTypeCopy0(t *testing.T) { } } +func TestUni0(t *testing.T) { + // good type strings + if NewType("?1") == nil { + t.Errorf("unexpected nil type") + } + if NewType("?123") == nil { + t.Errorf("unexpected nil type") + } + if NewType("[]?123") == nil { + t.Errorf("unexpected nil type") + } + if NewType("map{?123: ?123}") == nil { + t.Errorf("unexpected nil type") + } + + // bad type strings + if typ := NewType("?0"); typ != nil { + t.Errorf("expected nil type, got: %v", typ) + } + if typ := NewType("?00"); typ != nil { + t.Errorf("expected nil type, got: %v", typ) + } + if typ := NewType("?000000000000000000000"); typ != nil { + t.Errorf("expected nil type, got: %v", typ) + } +} + +func TestUni1(t *testing.T) { + // functions with named types... + testCases := map[string]*Type{ + // good type strings + "?1": { + Kind: KindUnification, + Uni: NewElem(), + }, + "?123": { + Kind: KindUnification, + Uni: NewElem(), + }, + "[]?123": { + Kind: KindList, + Val: &Type{ + Kind: KindUnification, + Uni: NewElem(), + }, + }, + + // bad type strings + "?0": nil, + "?00": nil, + "?00000": nil, + "?-1": nil, + "?-42": nil, + "?0x42": nil, // hexadecimal + "?013": nil, // octal + } + + for str, val := range testCases { // run all the tests + // for debugging + //if str != "?0" { + //continue + //} + + // check the type + typ := NewType(str) + //t.Logf("str: %+v", str) + //t.Logf("typ: %+v", typ) + + if val == nil { // catch error cases + if typ != nil { + t.Errorf("invalid type: `%s` did not match expected nil", str) + } + continue + } + + if err := typ.Cmp(val); err != nil { + t.Errorf("type: `%s` did not match expected: `%v`", str, err) + return + } + } +} + +func TestUniCmp0(t *testing.T) { + type test struct { // an individual test + name string + typ1 *Type + typ2 *Type + err bool // expected err ? + str string // expected output str + } + testCases := []test{} + + testCases = append(testCases, test{ + name: "simple ?1 compare", + typ1: NewType("?1"), + typ2: NewType("?1"), + err: false, + }) + testCases = append(testCases, test{ + name: "different ?1 compare", + typ1: NewType("?13"), // they don't need to be the same + typ2: NewType("?42"), + err: false, + }) + testCases = append(testCases, test{ + name: "duplicate type unification variables", + // the type unification variables should be the same + typ1: NewType("map{?123:?123}"), + typ2: &Type{ + Kind: KindMap, + Key: &Type{ + Kind: KindUnification, + Uni: NewElem(), + }, + Val: &Type{ + Kind: KindUnification, + Uni: NewElem(), + }, + }, + err: true, + }) + { + uni0 := NewElem() + testCases = append(testCases, test{ + name: "same type unification variables in map", + // the type unification variables should be the same + typ1: NewType("map{?123:?123}"), + typ2: &Type{ + Kind: KindMap, + Key: &Type{ + Kind: KindUnification, + Uni: uni0, + }, + Val: &Type{ + Kind: KindUnification, + Uni: uni0, + }, + }, + err: false, + }) + } + { + uni1 := NewElem() + uni2 := NewElem() + uni3 := NewElem() + // XXX: should we instead have uni0 for the return type and + // .Union() it with uni2 ? + //uni0 := NewElem() + //uni2.Union(uni0) + testCases = append(testCases, test{ + name: "duplicate type unification variables in functions", + // the type unification variables should be the same + typ1: NewType("func(?13, ?42, ?4, int) ?42"), + typ2: &Type{ + Kind: KindFunc, + Map: map[string]*Type{ + "a": { + Kind: KindUnification, + Uni: uni1, + }, + "b": { + Kind: KindUnification, + Uni: uni2, + }, + "c": { + Kind: KindUnification, + Uni: uni3, + }, + "d": TypeInt, + }, + Ord: []string{"a", "b", "c", "d"}, + Out: &Type{ + Kind: KindUnification, + Uni: uni2, // same as the second arg + }, + }, + err: false, + }) + } + { + uni1 := NewElem() + uni2 := NewElem() + // XXX: should we instead have uni0 for the return type and + // .Union() it with uni2 ? + //uni0 := NewElem() + //uni2.Union(uni0) + testCases = append(testCases, test{ + name: "duplicate type unification variables in functions unbalanced", + // the type unification variables should be the same + typ1: NewType("func(?13, ?42, ?4, int) ?42"), + typ2: &Type{ + Kind: KindFunc, + Map: map[string]*Type{ + "a": { + Kind: KindUnification, + Uni: uni1, + }, + "b": { + Kind: KindUnification, + Uni: uni2, + }, + "c": { + Kind: KindUnification, + Uni: uni1, // must not match! + }, + "d": TypeInt, + }, + Ord: []string{"a", "b", "c", "d"}, + Out: &Type{ + Kind: KindUnification, + Uni: uni2, // same as the second arg + }, + }, + err: true, + }) + } + + if testing.Short() { + t.Logf("available tests:") + } + names := []string{} + for index, tc := range testCases { // run all the tests + if tc.name == "" { + t.Errorf("test #%d: not named", index) + continue + } + if util.StrInList(tc.name, names) { + t.Errorf("test #%d: duplicate sub test name of: %s", index, tc.name) + continue + } + names = append(names, tc.name) + + testName := fmt.Sprintf("test #%d (%s)", index, tc.name) + if testing.Short() { // make listing tests easier + t.Logf("%s", testName) + continue + } + t.Run(testName, func(t *testing.T) { + typ1, typ2, err := tc.typ1, tc.typ2, tc.err + + // the reverse should probably match the forward version + err1 := typ1.Cmp(typ2) + err2 := typ2.Cmp(typ1) + + if err && err1 == nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: expected error, got nil", index) + } + if !err && err1 != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: unexpected error: %+v", index, err1) + } + if err && err2 == nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: expected error, got nil", index) + } + if !err && err2 != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: unexpected error: %+v", index, err2) + } + }) + } +} + func TestTypeOf0(t *testing.T) { // TODO: implement testing of the TypeOf function // TODO: implement testing TypeOf for struct field name mappings