diff --git a/lang/ast/structs.go b/lang/ast/structs.go index 7c63bda6..9237e253 100644 --- a/lang/ast/structs.go +++ b/lang/ast/structs.go @@ -589,13 +589,13 @@ func (obj *StmtRes) Unify() ([]interfaces.Invariant, error) { invarListStr := &interfaces.EqualsInvariant{ Expr: obj.Name, - Type: types.NewType("[]str"), + Type: types.TypeListStr, } // Optimization: If we know it's a []str, no need for exclusives! if expr, ok := obj.Name.(*ExprList); ok { typ, err := expr.Type() - if err == nil && typ.Cmp(types.NewType("[]str")) == nil { + if err == nil && typ.Cmp(types.TypeListStr) == nil { invariants = append(invariants, invarListStr) return invariants, nil } @@ -707,7 +707,7 @@ func (obj *StmtRes) Output(table map[interfaces.Func]types.Value) (*interfaces.O name := nameValue.Str() // must not panic names = append(names, name) - case types.NewType("[]str").Cmp(nameValue.Type()) == nil: + case types.TypeListStr.Cmp(nameValue.Type()) == nil: for _, x := range nameValue.List() { // must not panic name := x.Str() // must not panic names = append(names, name) @@ -880,7 +880,7 @@ func (obj *StmtRes) edges(table map[interfaces.Func]types.Value, resName string) name := nameValue.Str() // must not panic names = append(names, name) - case types.NewType("[]str").Cmp(nameValue.Type()) == nil: + case types.TypeListStr.Cmp(nameValue.Type()) == nil: for _, x := range nameValue.List() { // must not panic name := x.Str() // must not panic names = append(names, name) @@ -1953,7 +1953,7 @@ func (obj *StmtResMeta) Unify(kind string) ([]interfaces.Invariant, error) { invar = static(types.TypeBool) case "sema": - invar = static(types.NewType("[]str")) + invar = static(types.TypeListStr) case "rewatch": invar = static(types.TypeBool) @@ -2235,14 +2235,14 @@ func (obj *StmtEdge) Unify() ([]interfaces.Invariant, error) { // display something nicer if v.Type().Kind == types.KindStr { p1 = engine.Repr(k1, v.Str()) - } else if v.Type().Cmp(types.NewType("[]str")) == nil { + } else if v.Type().Cmp(types.TypeListStr) == nil { p1 = engine.Repr(k1, v.String()) } } if v, err := obj.EdgeHalfList[1].Name.Value(); err == nil { if v.Type().Kind == types.KindStr { p2 = engine.Repr(k2, v.Str()) - } else if v.Type().Cmp(types.NewType("[]str")) == nil { + } else if v.Type().Cmp(types.TypeListStr) == nil { p2 = engine.Repr(k2, v.String()) } } @@ -2319,7 +2319,7 @@ func (obj *StmtEdge) Output(table map[interfaces.Func]types.Value) (*interfaces. name := nameValue1.Str() // must not panic names1 = append(names1, name) - case types.NewType("[]str").Cmp(nameValue1.Type()) == nil: + case types.TypeListStr.Cmp(nameValue1.Type()) == nil: for _, x := range nameValue1.List() { // must not panic name := x.Str() // must not panic names1 = append(names1, name) @@ -2344,7 +2344,7 @@ func (obj *StmtEdge) Output(table map[interfaces.Func]types.Value) (*interfaces. name := nameValue2.Str() // must not panic names2 = append(names2, name) - case types.NewType("[]str").Cmp(nameValue2.Type()) == nil: + case types.TypeListStr.Cmp(nameValue2.Type()) == nil: for _, x := range nameValue2.List() { // must not panic name := x.Str() // must not panic names2 = append(names2, name) @@ -2501,13 +2501,13 @@ func (obj *StmtEdgeHalf) Unify() ([]interfaces.Invariant, error) { invarListStr := &interfaces.EqualsInvariant{ Expr: obj.Name, - Type: types.NewType("[]str"), + Type: types.TypeListStr, } // Optimization: If we know it's a []str, no need for exclusives! if expr, ok := obj.Name.(*ExprList); ok { typ, err := expr.Type() - if err == nil && typ.Cmp(types.NewType("[]str")) == nil { + if err == nil && typ.Cmp(types.TypeListStr) == nil { invariants = append(invariants, invarListStr) return invariants, nil } diff --git a/lang/core/os/args_func.go b/lang/core/os/args_func.go index 61f425f3..023afbf9 100644 --- a/lang/core/os/args_func.go +++ b/lang/core/os/args_func.go @@ -54,6 +54,6 @@ func Args([]types.Value) (types.Value, error) { } return &types.ListValue{ V: values, - T: types.NewType("[]str"), + T: types.TypeListStr, }, nil } diff --git a/lang/core/strings/split_func.go b/lang/core/strings/split_func.go index 6b9ec55d..ac556e0a 100644 --- a/lang/core/strings/split_func.go +++ b/lang/core/strings/split_func.go @@ -50,7 +50,7 @@ func Split(input []types.Value) (types.Value, error) { segments := strings.Split(str, sep) - listVal := types.NewList(types.NewType("[]str")) + listVal := types.NewList(types.TypeListStr) for _, segment := range segments { listVal.Add(&types.StrValue{ diff --git a/lang/core/world/schedule_func.go b/lang/core/world/schedule_func.go index 8df08af2..6799a4bf 100644 --- a/lang/core/world/schedule_func.go +++ b/lang/core/world/schedule_func.go @@ -149,7 +149,7 @@ func (obj *ScheduleFunc) Unify(expr interfaces.Expr) ([]interfaces.Invariant, er // return type of []string invar = &interfaces.EqualsInvariant{ Expr: dummyOut, - Type: types.NewType("[]str"), + Type: types.TypeListStr, } invariants = append(invariants, invar) @@ -345,7 +345,7 @@ func (obj *ScheduleFunc) Polymorphisms(partialType *types.Type, partialValues [] var typ *types.Type if tOut := partialType.Out; tOut != nil { - if err := tOut.Cmp(types.NewType("[]str")); err != nil { + if err := tOut.Cmp(types.TypeListStr); err != nil { return nil, errwrap.Wrapf(err, "return type must be a list of strings") } } @@ -428,7 +428,7 @@ func (obj *ScheduleFunc) Build(typ *types.Type) (*types.Type, error) { return nil, fmt.Errorf("invalid input type") } - if err := typ.Out.Cmp(types.NewType("[]str")); err != nil { + if err := typ.Out.Cmp(types.TypeListStr); err != nil { return nil, errwrap.Wrapf(err, "return type must be a list of strings") } diff --git a/lang/funcs/funcgen/config.go b/lang/funcs/funcgen/config.go index 11d2c515..85e3c472 100644 --- a/lang/funcs/funcgen/config.go +++ b/lang/funcs/funcgen/config.go @@ -77,7 +77,7 @@ func (obj *arg) ToMcl() (string, error) { case "float64": return fmt.Sprintf("%s%s", prefix, types.TypeFloat.String()), nil case "[]string": - return fmt.Sprintf("%s%s", prefix, types.NewType("[]str").String()), nil + return fmt.Sprintf("%s%s", prefix, types.TypeListStr.String()), nil default: return "", fmt.Errorf("cannot convert %v to mcl", obj.Type) } diff --git a/lang/types/type.go b/lang/types/type.go index e8073c1f..e8089236 100644 --- a/lang/types/type.go +++ b/lang/types/type.go @@ -50,6 +50,7 @@ var ( TypeStr = NewType("str") TypeInt = NewType("int") TypeFloat = NewType("float") + TypeListStr = NewType("[]str") TypeVariant = NewType("variant") ) diff --git a/lang/unification/solvers/unification_test.go b/lang/unification/solvers/unification_test.go index ad281533..a7d9a65c 100644 --- a/lang/unification/solvers/unification_test.go +++ b/lang/unification/solvers/unification_test.go @@ -125,7 +125,7 @@ func TestUnification1(t *testing.T) { v1: types.TypeStr, v2: types.TypeStr, v3: types.TypeStr, - expr: types.NewType("[]str"), + expr: types.TypeListStr, }, }) }