lang: types: Improve ComplexCmp function

This improves the ComplexCmp function so that it can compare partial
types to variant types. As a result of this improvement, it actually
ended up simplifying the code significantly. This also added a test
suite for this function. This fix was important for tricky type
unification problems.
This commit is contained in:
James Shubin
2019-06-02 16:48:41 -04:00
parent 99d3ef42e9
commit 4c6d304e60
2 changed files with 219 additions and 117 deletions

View File

@@ -902,22 +902,29 @@ func (obj *Type) HasVariant() bool {
// a possibility against a partial type, the status will be set to the "partial" // a possibility against a partial type, the status will be set to the "partial"
// string, and if it is compatible with the variant type it will be "variant"... // string, and if it is compatible with the variant type it will be "variant"...
// Comparing to a partial can only match "impossible" (error) or possible (nil). // Comparing to a partial can only match "impossible" (error) or possible (nil).
// This now also supports comparing a partial type to a variant type as well...
func (obj *Type) ComplexCmp(typ *Type) (string, error) { func (obj *Type) ComplexCmp(typ *Type) (string, error) {
// match simple "placeholder" variants... skip variants w/ sub types // match simple "placeholder" variants... skip variants w/ sub types
isVariant := func(t *Type) bool { return t.Kind == KindVariant && t.Var == nil } isVariant := func(t *Type) bool { return t != nil && t.Kind == KindVariant && t.Var == nil }
if obj == nil { if obj == nil && typ == nil {
return "", fmt.Errorf("can't cmp from a nil type")
}
// XXX: can we relax this to allow variants matching against partials?
if obj.HasVariant() {
return "", fmt.Errorf("only input can contain variants")
}
if typ == nil { // match
return "partial", nil // compatible :) return "partial", nil // compatible :)
} }
if isVariant(typ) { // match if isVariant(obj) && isVariant(typ) {
return "variant", nil // compatible :)
}
if obj == nil && isVariant(typ) { // partial vs variant
return "both", nil // compatible :)
}
if isVariant(obj) && typ == nil { // variant vs partial
return "both", nil // compatible :)
}
if obj == nil || typ == nil { // at least one is partial
return "partial", nil // compatible :)
}
if isVariant(obj) || isVariant(typ) { // at least one is variant
return "variant", nil // compatible :) return "variant", nil // compatible :)
} }
@@ -937,30 +944,9 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
return "", nil return "", nil
case KindList: case KindList:
if obj.Val == nil {
panic("malformed list type")
}
if typ.Val == nil {
return "partial", nil
}
return obj.Val.ComplexCmp(typ.Val) return obj.Val.ComplexCmp(typ.Val)
case KindMap: case KindMap:
if obj.Key == nil || obj.Val == nil {
panic("malformed map type")
}
if typ.Key == nil && typ.Val == nil {
return "partial", nil
}
if typ.Key == nil {
return obj.Val.ComplexCmp(typ.Val)
}
if typ.Val == nil {
return obj.Key.ComplexCmp(typ.Key)
}
kstatus, kerr := obj.Key.ComplexCmp(typ.Key) kstatus, kerr := obj.Key.ComplexCmp(typ.Key)
vstatus, verr := obj.Val.ComplexCmp(typ.Val) vstatus, verr := obj.Val.ComplexCmp(typ.Val)
if kerr != nil && verr != nil { if kerr != nil && verr != nil {
@@ -973,19 +959,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
return "", verr return "", verr
} }
if kstatus == "" && vstatus == "" {
return "", nil
} else if kstatus != "" && vstatus == "" {
return kstatus, nil
} else if vstatus != "" && kstatus == "" {
return vstatus, nil
}
// optimization, redundant
//if kstatus == vstatus { // both partial or both variant...
// return kstatus, nil
//}
var isVariant, isPartial bool var isVariant, isPartial bool
if kstatus == "variant" || vstatus == "variant" { if kstatus == "variant" || vstatus == "variant" {
isVariant = true isVariant = true
@@ -1001,24 +974,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
if !isVariant && !isPartial { if !isVariant && !isPartial {
return "", nil return "", nil
} }
if isVariant { if isVariant && !isPartial {
return "variant", nil return "variant", nil
} }
if isPartial { if isPartial && !isVariant {
return "partial", nil return "partial", nil
} }
//return "", fmt.Errorf("matches as both partial and variant")
return "both", nil return "both", nil
case KindStruct: // {a bool; b int} case KindStruct: // {a bool; b int}
if obj.Map == nil {
panic("malformed struct type")
}
if typ.Map == nil {
return "partial", nil
}
if len(obj.Ord) != len(typ.Ord) { if len(obj.Ord) != len(typ.Ord) {
return "", fmt.Errorf("struct field count differs") return "", fmt.Errorf("struct field count differs")
} }
@@ -1037,13 +1002,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
if !ok { if !ok {
panic("malformed struct order") panic("malformed struct order")
} }
if t1 == nil {
panic("malformed struct field")
}
if t2 == nil {
isPartial = true
continue
}
status, err := t1.ComplexCmp(t2) status, err := t1.ComplexCmp(t2)
if err != nil { if err != nil {
@@ -1064,24 +1022,16 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
if !isVariant && !isPartial { if !isVariant && !isPartial {
return "", nil return "", nil
} }
if isVariant { if isVariant && !isPartial {
return "variant", nil return "variant", nil
} }
if isPartial { if isPartial && !isVariant {
return "partial", nil return "partial", nil
} }
//return "", fmt.Errorf("matches as both partial and variant")
return "both", nil return "both", nil
case KindFunc: case KindFunc:
if obj.Map == nil {
panic("malformed func type")
}
if typ.Map == nil {
return "partial", nil
}
if len(obj.Ord) != len(typ.Ord) { if len(obj.Ord) != len(typ.Ord) {
return "", fmt.Errorf("func arg count differs") return "", fmt.Errorf("func arg count differs")
} }
@@ -1102,13 +1052,6 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
// if !ok { // if !ok {
// panic("malformed func order") // panic("malformed func order")
// } // }
// if t1 == nil {
// panic("malformed func arg")
// }
// if t2 == nil {
// isPartial = true
// continue
// }
// //
// status, err := t1.ComplexCmp(t2) // status, err := t1.ComplexCmp(t2)
// if err != nil { // if err != nil {
@@ -1129,14 +1072,13 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
//if !isVariant && !isPartial { //if !isVariant && !isPartial {
// return "", nil // return "", nil
//} //}
//if isVariant { //if isVariant && !isPartial {
// return "variant", nil // return "variant", nil
//} //}
//if isPartial { //if isPartial && !isVariant {
// return "partial", nil // return "partial", nil
//} //}
// //
////return "", fmt.Errorf("matches as both partial and variant")
//return "both", nil //return "both", nil
// if we're not comparing arg names, get the two lists of types // if we're not comparing arg names, get the two lists of types
@@ -1146,18 +1088,10 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
if !ok { if !ok {
panic("malformed func order") panic("malformed func order")
} }
if t1 == nil {
panic("malformed func arg")
}
t2, ok := typ.Map[typ.Ord[i]] t2, ok := typ.Map[typ.Ord[i]]
if !ok { if !ok {
panic("malformed func order") panic("malformed func order")
} }
if t2 == nil {
isPartial = true
continue
}
status, err := t1.ComplexCmp(t2) status, err := t1.ComplexCmp(t2)
if err != nil { if err != nil {
@@ -1175,45 +1109,36 @@ func (obj *Type) ComplexCmp(typ *Type) (string, error) {
} }
} }
//if obj.Out != nil && typ.Out != nil { // let a nil obj.Out in // NOTE: Technically, .Out could be unspecified, then this is a
if typ.Out != nil { // let a nil obj.Out in // Cmp fail, not an isPartial = true, but then we'd have to
status, err := obj.Out.ComplexCmp(typ.Out) // support functions without a return value. Since we are
if err != nil { // functional, it is not a major problem...
return "", err
}
if status == "variant" {
isVariant = true
}
if status == "partial" {
isPartial = true
}
if status == "both" {
isVariant = true
isPartial = true
}
} else if obj.Out != nil { status, err := obj.Out.ComplexCmp(typ.Out)
// TODO: technically, typ.Out could be unspecified, then if err != nil {
// this is a Cmp fail, not an isPartial = true, but then return "", err
// we'd have to support functions without a return value }
// since we are functional, it is not a major problem... if status == "variant" {
isVariant = true
}
if status == "partial" {
isPartial = true
}
if status == "both" {
isVariant = true
isPartial = true isPartial = true
} }
//} else if typ.Out != nil { // solve this in the above ComplexCmp instead!
// return "", fmt.Errorf("can't cmp from a nil type")
//}
if !isVariant && !isPartial { if !isVariant && !isPartial {
return "", nil return "", nil
} }
if isVariant { if isVariant && !isPartial {
return "variant", nil return "variant", nil
} }
if isPartial { if isPartial && !isVariant {
return "partial", nil return "partial", nil
} }
//return "", fmt.Errorf("matches as both partial and variant")
return "both", nil return "both", nil
} }

View File

@@ -20,7 +20,10 @@
package types package types
import ( import (
"fmt"
"testing" "testing"
"github.com/purpleidea/mgmt/util"
) )
func TestType0(t *testing.T) { func TestType0(t *testing.T) {
@@ -1295,6 +1298,180 @@ func TestType3(t *testing.T) {
} }
} }
func TestComplexCmp0(t *testing.T) {
type test struct { // an individual test
name string
typ1 *Type
typ2 *Type
err bool // expected err ?
str string // expected output str
}
testCases := []test{}
{
testCases = append(testCases, test{
name: "int vs str",
typ1: TypeInt,
typ2: TypeStr,
err: true,
str: "",
})
}
{
testCases = append(testCases, test{
name: "nested list vs list variant",
typ1: NewType("[][]str"),
typ2: &Type{
Kind: KindList,
Val: TypeVariant,
},
err: false,
str: "variant",
})
}
{
testCases = append(testCases, test{
name: "nil vs type",
typ1: nil,
typ2: NewType("[][]str"),
err: false,
str: "partial",
})
}
{
testCases = append(testCases, test{
name: "variant vs type",
typ1: TypeVariant,
typ2: NewType("[][]str"),
err: false,
str: "variant",
})
}
{
testCases = append(testCases, test{
name: "nil vs variant",
typ1: nil,
typ2: TypeVariant,
err: false,
str: "both",
})
}
{
testCases = append(testCases, test{
name: "type vs nil",
typ1: NewType("[][]str"),
typ2: nil,
err: false,
str: "partial",
})
}
{
testCases = append(testCases, test{
name: "type vs variant",
typ1: NewType("[][]str"),
typ2: TypeVariant,
err: false,
str: "variant",
})
}
{
testCases = append(testCases, test{
name: "variant vs nil",
typ1: TypeVariant,
typ2: nil,
err: false,
str: "both",
})
}
{
// func([]int) VS func([]variant) int
testCases = append(testCases, test{
name: "partial vs variant",
typ1: &Type{
Kind: KindFunc,
Map: map[string]*Type{
"ints": {
Kind: KindList,
Val: TypeInt,
},
},
Ord: []string{"ints"},
Out: nil, // unspecified, it's a partial
},
typ2: &Type{
Kind: KindFunc,
Map: map[string]*Type{
"ints": {
Kind: KindList,
Val: TypeVariant, // variant!
},
},
Ord: []string{"ints"},
Out: TypeInt,
},
err: false,
str: "both",
})
}
if testing.Short() {
t.Logf("available tests:")
}
names := []string{}
for index, tc := range testCases { // run all the tests
if tc.name == "" {
t.Errorf("test #%d: not named", index)
continue
}
if util.StrInList(tc.name, names) {
t.Errorf("test #%d: duplicate sub test name of: %s", index, tc.name)
continue
}
names = append(names, tc.name)
testName := fmt.Sprintf("test #%d (%s)", index, tc.name)
if testing.Short() { // make listing tests easier
t.Logf("%s", testName)
continue
}
t.Run(testName, func(t *testing.T) {
typ1, typ2, err, str := tc.typ1, tc.typ2, tc.err, tc.str
// the reverse should probably match the forward version
s1, err1 := typ1.ComplexCmp(typ2)
s2, err2 := typ2.ComplexCmp(typ1)
if err && err1 == nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: expected error, got nil", index)
}
if !err && err1 != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: unexpected error: %+v", index, err1)
}
if err && err2 == nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: expected error, got nil", index)
}
if !err && err2 != nil {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: unexpected error: %+v", index, err2)
}
if s1 != s2 {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: strings did not match: %+v != %+v", index, s1, s2)
return
}
if s1 != str {
t.Errorf("test #%d: FAIL", index)
t.Errorf("test #%d: unexpected string: %+v != %+v", index, s1, str)
return
}
})
}
}
func TestTypeOf0(t *testing.T) { func TestTypeOf0(t *testing.T) {
// TODO: implement testing of the TypeOf function // TODO: implement testing of the TypeOf function
} }