lang: types: Improve ComplexCmp function
This improves the ComplexCmp function so that it can compare partial types to variant types. As a result of this improvement, it actually ended up simplifying the code significantly. This also added a test suite for this function. This fix was important for tricky type unification problems.
This commit is contained in:
@@ -902,22 +902,29 @@ func (obj *Type) HasVariant() bool {
|
|||||||
// a possibility against a partial type, the status will be set to the "partial"
|
// a possibility against a partial type, the status will be set to the "partial"
|
||||||
// string, and if it is compatible with the variant type it will be "variant"...
|
// string, and if it is compatible with the variant type it will be "variant"...
|
||||||
// Comparing to a partial can only match "impossible" (error) or possible (nil).
|
// Comparing to a partial can only match "impossible" (error) or possible (nil).
|
||||||
|
// This now also supports comparing a partial type to a variant type as well...
|
||||||
func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
||||||
// match simple "placeholder" variants... skip variants w/ sub types
|
// match simple "placeholder" variants... skip variants w/ sub types
|
||||||
isVariant := func(t *Type) bool { return t.Kind == KindVariant && t.Var == nil }
|
isVariant := func(t *Type) bool { return t != nil && t.Kind == KindVariant && t.Var == nil }
|
||||||
|
|
||||||
if obj == nil {
|
if obj == nil && typ == nil {
|
||||||
return "", fmt.Errorf("can't cmp from a nil type")
|
|
||||||
}
|
|
||||||
// XXX: can we relax this to allow variants matching against partials?
|
|
||||||
if obj.HasVariant() {
|
|
||||||
return "", fmt.Errorf("only input can contain variants")
|
|
||||||
}
|
|
||||||
|
|
||||||
if typ == nil { // match
|
|
||||||
return "partial", nil // compatible :)
|
return "partial", nil // compatible :)
|
||||||
}
|
}
|
||||||
if isVariant(typ) { // match
|
if isVariant(obj) && isVariant(typ) {
|
||||||
|
return "variant", nil // compatible :)
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj == nil && isVariant(typ) { // partial vs variant
|
||||||
|
return "both", nil // compatible :)
|
||||||
|
}
|
||||||
|
if isVariant(obj) && typ == nil { // variant vs partial
|
||||||
|
return "both", nil // compatible :)
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj == nil || typ == nil { // at least one is partial
|
||||||
|
return "partial", nil // compatible :)
|
||||||
|
}
|
||||||
|
if isVariant(obj) || isVariant(typ) { // at least one is variant
|
||||||
return "variant", nil // compatible :)
|
return "variant", nil // compatible :)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -937,30 +944,9 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
return "", nil
|
return "", nil
|
||||||
|
|
||||||
case KindList:
|
case KindList:
|
||||||
if obj.Val == nil {
|
|
||||||
panic("malformed list type")
|
|
||||||
}
|
|
||||||
if typ.Val == nil {
|
|
||||||
return "partial", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return obj.Val.ComplexCmp(typ.Val)
|
return obj.Val.ComplexCmp(typ.Val)
|
||||||
|
|
||||||
case KindMap:
|
case KindMap:
|
||||||
if obj.Key == nil || obj.Val == nil {
|
|
||||||
panic("malformed map type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if typ.Key == nil && typ.Val == nil {
|
|
||||||
return "partial", nil
|
|
||||||
}
|
|
||||||
if typ.Key == nil {
|
|
||||||
return obj.Val.ComplexCmp(typ.Val)
|
|
||||||
}
|
|
||||||
if typ.Val == nil {
|
|
||||||
return obj.Key.ComplexCmp(typ.Key)
|
|
||||||
}
|
|
||||||
|
|
||||||
kstatus, kerr := obj.Key.ComplexCmp(typ.Key)
|
kstatus, kerr := obj.Key.ComplexCmp(typ.Key)
|
||||||
vstatus, verr := obj.Val.ComplexCmp(typ.Val)
|
vstatus, verr := obj.Val.ComplexCmp(typ.Val)
|
||||||
if kerr != nil && verr != nil {
|
if kerr != nil && verr != nil {
|
||||||
@@ -973,19 +959,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
return "", verr
|
return "", verr
|
||||||
}
|
}
|
||||||
|
|
||||||
if kstatus == "" && vstatus == "" {
|
|
||||||
return "", nil
|
|
||||||
} else if kstatus != "" && vstatus == "" {
|
|
||||||
return kstatus, nil
|
|
||||||
} else if vstatus != "" && kstatus == "" {
|
|
||||||
return vstatus, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// optimization, redundant
|
|
||||||
//if kstatus == vstatus { // both partial or both variant...
|
|
||||||
// return kstatus, nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
var isVariant, isPartial bool
|
var isVariant, isPartial bool
|
||||||
if kstatus == "variant" || vstatus == "variant" {
|
if kstatus == "variant" || vstatus == "variant" {
|
||||||
isVariant = true
|
isVariant = true
|
||||||
@@ -1001,24 +974,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
if !isVariant && !isPartial {
|
if !isVariant && !isPartial {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
if isVariant {
|
if isVariant && !isPartial {
|
||||||
return "variant", nil
|
return "variant", nil
|
||||||
}
|
}
|
||||||
if isPartial {
|
if isPartial && !isVariant {
|
||||||
return "partial", nil
|
return "partial", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//return "", fmt.Errorf("matches as both partial and variant")
|
|
||||||
return "both", nil
|
return "both", nil
|
||||||
|
|
||||||
case KindStruct: // {a bool; b int}
|
case KindStruct: // {a bool; b int}
|
||||||
if obj.Map == nil {
|
|
||||||
panic("malformed struct type")
|
|
||||||
}
|
|
||||||
if typ.Map == nil {
|
|
||||||
return "partial", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(obj.Ord) != len(typ.Ord) {
|
if len(obj.Ord) != len(typ.Ord) {
|
||||||
return "", fmt.Errorf("struct field count differs")
|
return "", fmt.Errorf("struct field count differs")
|
||||||
}
|
}
|
||||||
@@ -1037,13 +1002,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
panic("malformed struct order")
|
panic("malformed struct order")
|
||||||
}
|
}
|
||||||
if t1 == nil {
|
|
||||||
panic("malformed struct field")
|
|
||||||
}
|
|
||||||
if t2 == nil {
|
|
||||||
isPartial = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := t1.ComplexCmp(t2)
|
status, err := t1.ComplexCmp(t2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1064,24 +1022,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
if !isVariant && !isPartial {
|
if !isVariant && !isPartial {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
if isVariant {
|
if isVariant && !isPartial {
|
||||||
return "variant", nil
|
return "variant", nil
|
||||||
}
|
}
|
||||||
if isPartial {
|
if isPartial && !isVariant {
|
||||||
return "partial", nil
|
return "partial", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//return "", fmt.Errorf("matches as both partial and variant")
|
|
||||||
return "both", nil
|
return "both", nil
|
||||||
|
|
||||||
case KindFunc:
|
case KindFunc:
|
||||||
if obj.Map == nil {
|
|
||||||
panic("malformed func type")
|
|
||||||
}
|
|
||||||
if typ.Map == nil {
|
|
||||||
return "partial", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(obj.Ord) != len(typ.Ord) {
|
if len(obj.Ord) != len(typ.Ord) {
|
||||||
return "", fmt.Errorf("func arg count differs")
|
return "", fmt.Errorf("func arg count differs")
|
||||||
}
|
}
|
||||||
@@ -1102,13 +1052,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
// if !ok {
|
// if !ok {
|
||||||
// panic("malformed func order")
|
// panic("malformed func order")
|
||||||
// }
|
// }
|
||||||
// if t1 == nil {
|
|
||||||
// panic("malformed func arg")
|
|
||||||
// }
|
|
||||||
// if t2 == nil {
|
|
||||||
// isPartial = true
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
//
|
//
|
||||||
// status, err := t1.ComplexCmp(t2)
|
// status, err := t1.ComplexCmp(t2)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
@@ -1129,14 +1072,13 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
//if !isVariant && !isPartial {
|
//if !isVariant && !isPartial {
|
||||||
// return "", nil
|
// return "", nil
|
||||||
//}
|
//}
|
||||||
//if isVariant {
|
//if isVariant && !isPartial {
|
||||||
// return "variant", nil
|
// return "variant", nil
|
||||||
//}
|
//}
|
||||||
//if isPartial {
|
//if isPartial && !isVariant {
|
||||||
// return "partial", nil
|
// return "partial", nil
|
||||||
//}
|
//}
|
||||||
//
|
//
|
||||||
////return "", fmt.Errorf("matches as both partial and variant")
|
|
||||||
//return "both", nil
|
//return "both", nil
|
||||||
|
|
||||||
// if we're not comparing arg names, get the two lists of types
|
// if we're not comparing arg names, get the two lists of types
|
||||||
@@ -1146,18 +1088,10 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
panic("malformed func order")
|
panic("malformed func order")
|
||||||
}
|
}
|
||||||
if t1 == nil {
|
|
||||||
panic("malformed func arg")
|
|
||||||
}
|
|
||||||
|
|
||||||
t2, ok := typ.Map[typ.Ord[i]]
|
t2, ok := typ.Map[typ.Ord[i]]
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("malformed func order")
|
panic("malformed func order")
|
||||||
}
|
}
|
||||||
if t2 == nil {
|
|
||||||
isPartial = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := t1.ComplexCmp(t2)
|
status, err := t1.ComplexCmp(t2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1175,8 +1109,11 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//if obj.Out != nil && typ.Out != nil { // let a nil obj.Out in
|
// NOTE: Technically, .Out could be unspecified, then this is a
|
||||||
if typ.Out != nil { // let a nil obj.Out in
|
// Cmp fail, not an isPartial = true, but then we'd have to
|
||||||
|
// support functions without a return value. Since we are
|
||||||
|
// functional, it is not a major problem...
|
||||||
|
|
||||||
status, err := obj.Out.ComplexCmp(typ.Out)
|
status, err := obj.Out.ComplexCmp(typ.Out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1192,28 +1129,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
|
|||||||
isPartial = true
|
isPartial = true
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if obj.Out != nil {
|
|
||||||
// TODO: technically, typ.Out could be unspecified, then
|
|
||||||
// this is a Cmp fail, not an isPartial = true, but then
|
|
||||||
// we'd have to support functions without a return value
|
|
||||||
// since we are functional, it is not a major problem...
|
|
||||||
isPartial = true
|
|
||||||
}
|
|
||||||
//} else if typ.Out != nil { // solve this in the above ComplexCmp instead!
|
|
||||||
// return "", fmt.Errorf("can't cmp from a nil type")
|
|
||||||
//}
|
|
||||||
|
|
||||||
if !isVariant && !isPartial {
|
if !isVariant && !isPartial {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
if isVariant {
|
if isVariant && !isPartial {
|
||||||
return "variant", nil
|
return "variant", nil
|
||||||
}
|
}
|
||||||
if isPartial {
|
if isPartial && !isVariant {
|
||||||
return "partial", nil
|
return "partial", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//return "", fmt.Errorf("matches as both partial and variant")
|
|
||||||
return "both", nil
|
return "both", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,10 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/purpleidea/mgmt/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestType0(t *testing.T) {
|
func TestType0(t *testing.T) {
|
||||||
@@ -1295,6 +1298,180 @@ func TestType3(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestComplexCmp0(t *testing.T) {
|
||||||
|
type test struct { // an individual test
|
||||||
|
name string
|
||||||
|
typ1 *Type
|
||||||
|
typ2 *Type
|
||||||
|
err bool // expected err ?
|
||||||
|
str string // expected output str
|
||||||
|
}
|
||||||
|
testCases := []test{}
|
||||||
|
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "int vs str",
|
||||||
|
typ1: TypeInt,
|
||||||
|
typ2: TypeStr,
|
||||||
|
err: true,
|
||||||
|
str: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "nested list vs list variant",
|
||||||
|
typ1: NewType("[][]str"),
|
||||||
|
typ2: &Type{
|
||||||
|
Kind: KindList,
|
||||||
|
Val: TypeVariant,
|
||||||
|
},
|
||||||
|
err: false,
|
||||||
|
str: "variant",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "nil vs type",
|
||||||
|
typ1: nil,
|
||||||
|
typ2: NewType("[][]str"),
|
||||||
|
err: false,
|
||||||
|
str: "partial",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "variant vs type",
|
||||||
|
typ1: TypeVariant,
|
||||||
|
typ2: NewType("[][]str"),
|
||||||
|
err: false,
|
||||||
|
str: "variant",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "nil vs variant",
|
||||||
|
typ1: nil,
|
||||||
|
typ2: TypeVariant,
|
||||||
|
err: false,
|
||||||
|
str: "both",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "type vs nil",
|
||||||
|
typ1: NewType("[][]str"),
|
||||||
|
typ2: nil,
|
||||||
|
err: false,
|
||||||
|
str: "partial",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "type vs variant",
|
||||||
|
typ1: NewType("[][]str"),
|
||||||
|
typ2: TypeVariant,
|
||||||
|
err: false,
|
||||||
|
str: "variant",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "variant vs nil",
|
||||||
|
typ1: TypeVariant,
|
||||||
|
typ2: nil,
|
||||||
|
err: false,
|
||||||
|
str: "both",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// func([]int) VS func([]variant) int
|
||||||
|
testCases = append(testCases, test{
|
||||||
|
name: "partial vs variant",
|
||||||
|
typ1: &Type{
|
||||||
|
Kind: KindFunc,
|
||||||
|
Map: map[string]*Type{
|
||||||
|
"ints": {
|
||||||
|
Kind: KindList,
|
||||||
|
Val: TypeInt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Ord: []string{"ints"},
|
||||||
|
Out: nil, // unspecified, it's a partial
|
||||||
|
},
|
||||||
|
typ2: &Type{
|
||||||
|
Kind: KindFunc,
|
||||||
|
Map: map[string]*Type{
|
||||||
|
"ints": {
|
||||||
|
Kind: KindList,
|
||||||
|
Val: TypeVariant, // variant!
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Ord: []string{"ints"},
|
||||||
|
Out: TypeInt,
|
||||||
|
},
|
||||||
|
err: false,
|
||||||
|
str: "both",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if testing.Short() {
|
||||||
|
t.Logf("available tests:")
|
||||||
|
}
|
||||||
|
names := []string{}
|
||||||
|
for index, tc := range testCases { // run all the tests
|
||||||
|
if tc.name == "" {
|
||||||
|
t.Errorf("test #%d: not named", index)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if util.StrInList(tc.name, names) {
|
||||||
|
t.Errorf("test #%d: duplicate sub test name of: %s", index, tc.name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
names = append(names, tc.name)
|
||||||
|
|
||||||
|
testName := fmt.Sprintf("test #%d (%s)", index, tc.name)
|
||||||
|
if testing.Short() { // make listing tests easier
|
||||||
|
t.Logf("%s", testName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
typ1, typ2, err, str := tc.typ1, tc.typ2, tc.err, tc.str
|
||||||
|
|
||||||
|
// the reverse should probably match the forward version
|
||||||
|
s1, err1 := typ1.ComplexCmp(typ2)
|
||||||
|
s2, err2 := typ2.ComplexCmp(typ1)
|
||||||
|
|
||||||
|
if err && err1 == nil {
|
||||||
|
t.Errorf("test #%d: FAIL", index)
|
||||||
|
t.Errorf("test #%d: expected error, got nil", index)
|
||||||
|
}
|
||||||
|
if !err && err1 != nil {
|
||||||
|
t.Errorf("test #%d: FAIL", index)
|
||||||
|
t.Errorf("test #%d: unexpected error: %+v", index, err1)
|
||||||
|
}
|
||||||
|
if err && err2 == nil {
|
||||||
|
t.Errorf("test #%d: FAIL", index)
|
||||||
|
t.Errorf("test #%d: expected error, got nil", index)
|
||||||
|
}
|
||||||
|
if !err && err2 != nil {
|
||||||
|
t.Errorf("test #%d: FAIL", index)
|
||||||
|
t.Errorf("test #%d: unexpected error: %+v", index, err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s1 != s2 {
|
||||||
|
t.Errorf("test #%d: FAIL", index)
|
||||||
|
t.Errorf("test #%d: strings did not match: %+v != %+v", index, s1, s2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s1 != str {
|
||||||
|
t.Errorf("test #%d: FAIL", index)
|
||||||
|
t.Errorf("test #%d: unexpected string: %+v != %+v", index, s1, str)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTypeOf0(t *testing.T) {
|
func TestTypeOf0(t *testing.T) {
|
||||||
// TODO: implement testing of the TypeOf function
|
// TODO: implement testing of the TypeOf function
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user