diff --git a/lang/types/type.go b/lang/types/type.go index 089e4aa2..4a4b0128 100644 --- a/lang/types/type.go +++ b/lang/types/type.go @@ -902,22 +902,29 @@ func (obj *Type) HasVariant() bool { // a possibility against a partial type, the status will be set to the "partial" // string, and if it is compatible with the variant type it will be "variant"... // Comparing to a partial can only match "impossible" (error) or possible (nil). +// This now also supports comparing a partial type to a variant type as well... func (obj *Type) ComplexCmp(typ *Type) (string, error) { // match simple "placeholder" variants... skip variants w/ sub types - isVariant := func(t *Type) bool { return t.Kind == KindVariant && t.Var == nil } + isVariant := func(t *Type) bool { return t != nil && t.Kind == KindVariant && t.Var == nil } - if obj == nil { - return "", fmt.Errorf("can't cmp from a nil type") - } - // XXX: can we relax this to allow variants matching against partials? - if obj.HasVariant() { - return "", fmt.Errorf("only input can contain variants") - } - - if typ == nil { // match + if obj == nil && typ == nil { return "partial", nil // compatible :) } - if isVariant(typ) { // match + if isVariant(obj) && isVariant(typ) { + return "variant", nil // compatible :) + } + + if obj == nil && isVariant(typ) { // partial vs variant + return "both", nil // compatible :) + } + if isVariant(obj) && typ == nil { // variant vs partial + return "both", nil // compatible :) + } + + if obj == nil || typ == nil { // at least one is partial + return "partial", nil // compatible :) + } + if isVariant(obj) || isVariant(typ) { // at least one is variant return "variant", nil // compatible :) } @@ -937,30 +944,9 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { return "", nil case KindList: - if obj.Val == nil { - panic("malformed list type") - } - if typ.Val == nil { - return "partial", nil - } - return obj.Val.ComplexCmp(typ.Val) case KindMap: - if obj.Key == nil || obj.Val == nil { - panic("malformed map type") - } - - if typ.Key == nil && typ.Val == nil { - return "partial", nil - } - if typ.Key == nil { - return obj.Val.ComplexCmp(typ.Val) - } - if typ.Val == nil { - return obj.Key.ComplexCmp(typ.Key) - } - kstatus, kerr := obj.Key.ComplexCmp(typ.Key) vstatus, verr := obj.Val.ComplexCmp(typ.Val) if kerr != nil && verr != nil { @@ -973,19 +959,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { return "", verr } - if kstatus == "" && vstatus == "" { - return "", nil - } else if kstatus != "" && vstatus == "" { - return kstatus, nil - } else if vstatus != "" && kstatus == "" { - return vstatus, nil - } - - // optimization, redundant - //if kstatus == vstatus { // both partial or both variant... - // return kstatus, nil - //} - var isVariant, isPartial bool if kstatus == "variant" || vstatus == "variant" { isVariant = true @@ -1001,24 +974,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { if !isVariant && !isPartial { return "", nil } - if isVariant { + if isVariant && !isPartial { return "variant", nil } - if isPartial { + if isPartial && !isVariant { return "partial", nil } - //return "", fmt.Errorf("matches as both partial and variant") return "both", nil case KindStruct: // {a bool; b int} - if obj.Map == nil { - panic("malformed struct type") - } - if typ.Map == nil { - return "partial", nil - } - if len(obj.Ord) != len(typ.Ord) { return "", fmt.Errorf("struct field count differs") } @@ -1037,13 +1002,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { if !ok { panic("malformed struct order") } - if t1 == nil { - panic("malformed struct field") - } - if t2 == nil { - isPartial = true - continue - } status, err := t1.ComplexCmp(t2) if err != nil { @@ -1064,24 +1022,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { if !isVariant && !isPartial { return "", nil } - if isVariant { + if isVariant && !isPartial { return "variant", nil } - if isPartial { + if isPartial && !isVariant { return "partial", nil } - //return "", fmt.Errorf("matches as both partial and variant") return "both", nil case KindFunc: - if obj.Map == nil { - panic("malformed func type") - } - if typ.Map == nil { - return "partial", nil - } - if len(obj.Ord) != len(typ.Ord) { return "", fmt.Errorf("func arg count differs") } @@ -1102,13 +1052,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { // if !ok { // panic("malformed func order") // } - // if t1 == nil { - // panic("malformed func arg") - // } - // if t2 == nil { - // isPartial = true - // continue - // } // // status, err := t1.ComplexCmp(t2) // if err != nil { @@ -1129,14 +1072,13 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { //if !isVariant && !isPartial { // return "", nil //} - //if isVariant { + //if isVariant && !isPartial { // return "variant", nil //} - //if isPartial { + //if isPartial && !isVariant { // return "partial", nil //} // - ////return "", fmt.Errorf("matches as both partial and variant") //return "both", nil // if we're not comparing arg names, get the two lists of types @@ -1146,18 +1088,10 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { if !ok { panic("malformed func order") } - if t1 == nil { - panic("malformed func arg") - } - t2, ok := typ.Map[typ.Ord[i]] if !ok { panic("malformed func order") } - if t2 == nil { - isPartial = true - continue - } status, err := t1.ComplexCmp(t2) if err != nil { @@ -1175,45 +1109,36 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) { } } - //if obj.Out != nil && typ.Out != nil { // let a nil obj.Out in - if typ.Out != nil { // let a nil obj.Out in - status, err := obj.Out.ComplexCmp(typ.Out) - if err != nil { - return "", err - } - if status == "variant" { - isVariant = true - } - if status == "partial" { - isPartial = true - } - if status == "both" { - isVariant = true - isPartial = true - } + // NOTE: Technically, .Out could be unspecified, then this is a + // Cmp fail, not an isPartial = true, but then we'd have to + // support functions without a return value. Since we are + // functional, it is not a major problem... - } else if obj.Out != nil { - // TODO: technically, typ.Out could be unspecified, then - // this is a Cmp fail, not an isPartial = true, but then - // we'd have to support functions without a return value - // since we are functional, it is not a major problem... + status, err := obj.Out.ComplexCmp(typ.Out) + if err != nil { + return "", err + } + if status == "variant" { + isVariant = true + } + if status == "partial" { + isPartial = true + } + if status == "both" { + isVariant = true isPartial = true } - //} else if typ.Out != nil { // solve this in the above ComplexCmp instead! - // return "", fmt.Errorf("can't cmp from a nil type") - //} if !isVariant && !isPartial { return "", nil } - if isVariant { + if isVariant && !isPartial { return "variant", nil } - if isPartial { + if isPartial && !isVariant { return "partial", nil } - //return "", fmt.Errorf("matches as both partial and variant") return "both", nil } diff --git a/lang/types/type_test.go b/lang/types/type_test.go index cf0654fb..7427c053 100644 --- a/lang/types/type_test.go +++ b/lang/types/type_test.go @@ -20,7 +20,10 @@ package types import ( + "fmt" "testing" + + "github.com/purpleidea/mgmt/util" ) func TestType0(t *testing.T) { @@ -1295,6 +1298,180 @@ func TestType3(t *testing.T) { } } +func TestComplexCmp0(t *testing.T) { + type test struct { // an individual test + name string + typ1 *Type + typ2 *Type + err bool // expected err ? + str string // expected output str + } + testCases := []test{} + + { + testCases = append(testCases, test{ + name: "int vs str", + typ1: TypeInt, + typ2: TypeStr, + err: true, + str: "", + }) + } + { + testCases = append(testCases, test{ + name: "nested list vs list variant", + typ1: NewType("[][]str"), + typ2: &Type{ + Kind: KindList, + Val: TypeVariant, + }, + err: false, + str: "variant", + }) + } + { + testCases = append(testCases, test{ + name: "nil vs type", + typ1: nil, + typ2: NewType("[][]str"), + err: false, + str: "partial", + }) + } + { + testCases = append(testCases, test{ + name: "variant vs type", + typ1: TypeVariant, + typ2: NewType("[][]str"), + err: false, + str: "variant", + }) + } + { + testCases = append(testCases, test{ + name: "nil vs variant", + typ1: nil, + typ2: TypeVariant, + err: false, + str: "both", + }) + } + { + testCases = append(testCases, test{ + name: "type vs nil", + typ1: NewType("[][]str"), + typ2: nil, + err: false, + str: "partial", + }) + } + { + testCases = append(testCases, test{ + name: "type vs variant", + typ1: NewType("[][]str"), + typ2: TypeVariant, + err: false, + str: "variant", + }) + } + { + testCases = append(testCases, test{ + name: "variant vs nil", + typ1: TypeVariant, + typ2: nil, + err: false, + str: "both", + }) + } + { + // func([]int) VS func([]variant) int + testCases = append(testCases, test{ + name: "partial vs variant", + typ1: &Type{ + Kind: KindFunc, + Map: map[string]*Type{ + "ints": { + Kind: KindList, + Val: TypeInt, + }, + }, + Ord: []string{"ints"}, + Out: nil, // unspecified, it's a partial + }, + typ2: &Type{ + Kind: KindFunc, + Map: map[string]*Type{ + "ints": { + Kind: KindList, + Val: TypeVariant, // variant! + }, + }, + Ord: []string{"ints"}, + Out: TypeInt, + }, + err: false, + str: "both", + }) + } + + if testing.Short() { + t.Logf("available tests:") + } + names := []string{} + for index, tc := range testCases { // run all the tests + if tc.name == "" { + t.Errorf("test #%d: not named", index) + continue + } + if util.StrInList(tc.name, names) { + t.Errorf("test #%d: duplicate sub test name of: %s", index, tc.name) + continue + } + names = append(names, tc.name) + + testName := fmt.Sprintf("test #%d (%s)", index, tc.name) + if testing.Short() { // make listing tests easier + t.Logf("%s", testName) + continue + } + t.Run(testName, func(t *testing.T) { + typ1, typ2, err, str := tc.typ1, tc.typ2, tc.err, tc.str + + // the reverse should probably match the forward version + s1, err1 := typ1.ComplexCmp(typ2) + s2, err2 := typ2.ComplexCmp(typ1) + + if err && err1 == nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: expected error, got nil", index) + } + if !err && err1 != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: unexpected error: %+v", index, err1) + } + if err && err2 == nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: expected error, got nil", index) + } + if !err && err2 != nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: unexpected error: %+v", index, err2) + } + + if s1 != s2 { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: strings did not match: %+v != %+v", index, s1, s2) + return + } + if s1 != str { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: unexpected string: %+v != %+v", index, s1, str) + return + } + }) + } +} + func TestTypeOf0(t *testing.T) { // TODO: implement testing of the TypeOf function }