lang: core, funcs, types: Add ctx to simple func

Plumb through the standard context.Context so that a function can be
cancelled if someone requests this. It makes it less awkward to write
simple functions that might depend on io or network access.
This commit is contained in:
James Shubin
2024-05-09 19:25:46 -04:00
parent 3b754d5324
commit 415e22abe2
51 changed files with 166 additions and 108 deletions

View File

@@ -30,6 +30,8 @@
package core package core
import ( import (
"context"
"github.com/purpleidea/mgmt/lang/funcs" "github.com/purpleidea/mgmt/lang/funcs"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -48,7 +50,7 @@ func init() {
} }
// Concat concatenates two strings together. // Concat concatenates two strings together.
func Concat(input []types.Value) (types.Value, error) { func Concat(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: input[0].Str() + input[1].Str(), V: input[0].Str() + input[1].Str(),
}, nil }, nil

View File

@@ -30,6 +30,7 @@
package convert package convert
import ( import (
"context"
"strconv" "strconv"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -45,7 +46,7 @@ func init() {
// FormatBool converts a boolean to a string representation that can be consumed // FormatBool converts a boolean to a string representation that can be consumed
// by ParseBool. This value will be `"true"` or `"false"`. // by ParseBool. This value will be `"true"` or `"false"`.
func FormatBool(input []types.Value) (types.Value, error) { func FormatBool(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: strconv.FormatBool(input[0].Bool()), V: strconv.FormatBool(input[0].Bool()),
}, nil }, nil

View File

@@ -30,6 +30,7 @@
package convert package convert
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
@@ -48,7 +49,7 @@ func init() {
// it an invalid value. Valid values match what is accepted by the golang // it an invalid value. Valid values match what is accepted by the golang
// strconv.ParseBool function. It's recommended to use the strings `true` or // strconv.ParseBool function. It's recommended to use the strings `true` or
// `false` if you are undecided about what string representation to choose. // `false` if you are undecided about what string representation to choose.
func ParseBool(input []types.Value) (types.Value, error) { func ParseBool(ctx context.Context, input []types.Value) (types.Value, error) {
s := input[0].Str() s := input[0].Str()
b, err := strconv.ParseBool(s) b, err := strconv.ParseBool(s)
if err != nil { if err != nil {

View File

@@ -30,6 +30,8 @@
package convert package convert
import ( import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
) )
@@ -42,7 +44,7 @@ func init() {
} }
// ToFloat converts an integer to a float. // ToFloat converts an integer to a float.
func ToFloat(input []types.Value) (types.Value, error) { func ToFloat(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: float64(input[0].Int()), V: float64(input[0].Int()),
}, nil }, nil

View File

@@ -30,13 +30,14 @@
package convert package convert
import ( import (
"context"
"testing" "testing"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
) )
func testToFloat(t *testing.T, input int64, expected float64) { func testToFloat(t *testing.T, input int64, expected float64) {
got, err := ToFloat([]types.Value{&types.IntValue{V: input}}) got, err := ToFloat(context.Background(), []types.Value{&types.IntValue{V: input}})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@@ -30,6 +30,8 @@
package convert package convert
import ( import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
) )
@@ -42,7 +44,7 @@ func init() {
} }
// ToInt converts a float to an integer. // ToInt converts a float to an integer.
func ToInt(input []types.Value) (types.Value, error) { func ToInt(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.IntValue{ return &types.IntValue{
V: int64(input[0].Float()), V: int64(input[0].Float()),
}, nil }, nil

View File

@@ -30,6 +30,7 @@
package convert package convert
import ( import (
"context"
"testing" "testing"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -37,7 +38,7 @@ import (
func testToInt(t *testing.T, input float64, expected int64) { func testToInt(t *testing.T, input float64, expected int64) {
got, err := ToInt([]types.Value{&types.FloatValue{V: input}}) got, err := ToInt(context.Background(), []types.Value{&types.FloatValue{V: input}})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@@ -30,6 +30,7 @@
package convert package convert
import ( import (
"context"
"strconv" "strconv"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -49,7 +50,7 @@ func init() {
} }
// IntToStr converts an integer to a string. // IntToStr converts an integer to a string.
func IntToStr(input []types.Value) (types.Value, error) { func IntToStr(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: strconv.Itoa(int(input[0].Int())), V: strconv.Itoa(int(input[0].Int())),
}, nil }, nil

View File

@@ -30,6 +30,7 @@
package coredatetime package coredatetime
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -48,7 +49,7 @@ func init() {
// has to be defined like specified by the golang "time" package. The time is // has to be defined like specified by the golang "time" package. The time is
// the number of seconds since the epoch, and matches what comes from our Now // the number of seconds since the epoch, and matches what comes from our Now
// function. Golang documentation: https://golang.org/pkg/time/#Time.Format // function. Golang documentation: https://golang.org/pkg/time/#Time.Format
func Format(input []types.Value) (types.Value, error) { func Format(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int() epochDelta := input[0].Int()
if epochDelta < 0 { if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive") return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -32,6 +32,7 @@
package coredatetime package coredatetime
import ( import (
"context"
"testing" "testing"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -41,7 +42,7 @@ func TestFormat(t *testing.T) {
inputVal := &types.IntValue{V: 1443158163} inputVal := &types.IntValue{V: 1443158163}
inputFormat := &types.StrValue{V: "2006"} inputFormat := &types.StrValue{V: "2006"}
val, err := Format([]types.Value{inputVal, inputFormat}) val, err := Format(context.Background(), []types.Value{inputVal, inputFormat})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }

View File

@@ -30,6 +30,7 @@
package coredatetime package coredatetime
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -47,7 +48,7 @@ func init() {
// Hour returns the hour of the day corresponding to the input time. The time is // Hour returns the hour of the day corresponding to the input time. The time is
// the number of seconds since the epoch, and matches what comes from our Now // the number of seconds since the epoch, and matches what comes from our Now
// function. // function.
func Hour(input []types.Value) (types.Value, error) { func Hour(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int() epochDelta := input[0].Int()
if epochDelta < 0 { if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive") return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -30,6 +30,7 @@
package coredatetime package coredatetime
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -41,7 +42,7 @@ func init() {
// FIXME: consider renaming this to printf, and add in a format string? // FIXME: consider renaming this to printf, and add in a format string?
simple.ModuleRegister(ModuleName, "print", &types.FuncValue{ simple.ModuleRegister(ModuleName, "print", &types.FuncValue{
T: types.NewType("func(a int) str"), T: types.NewType("func(a int) str"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int() epochDelta := input[0].Int()
if epochDelta < 0 { if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive") return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -30,6 +30,7 @@
package coredatetime package coredatetime
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -48,7 +49,7 @@ func init() {
// Weekday returns the lowercased day of the week corresponding to the input // Weekday returns the lowercased day of the week corresponding to the input
// time. The time is the number of seconds since the epoch, and matches what // time. The time is the number of seconds since the epoch, and matches what
// comes from our Now function. // comes from our Now function.
func Weekday(input []types.Value) (types.Value, error) { func Weekday(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int() epochDelta := input[0].Int()
if epochDelta < 0 { if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive") return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -418,11 +418,14 @@ func (obj *provisioner) Register(moduleName string) error {
// Build a few separately... // Build a few separately...
simple.ModuleRegister(moduleName, "cli_password", &types.FuncValue{ simple.ModuleRegister(moduleName, "cli_password", &types.FuncValue{
T: types.NewType("func() str"), T: types.NewType("func() str"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
if obj.localArgs == nil { if obj.localArgs == nil {
// programming error // programming error
return nil, fmt.Errorf("could not convert/access our struct") return nil, fmt.Errorf("could not convert/access our struct")
} }
// TODO: plumb through the password lookup here instead?
//localArgs := *obj.localArgs // optional //localArgs := *obj.localArgs // optional
return &types.StrValue{ return &types.StrValue{
V: obj.password, V: obj.password,

View File

@@ -30,6 +30,8 @@
package coreexample package coreexample
import ( import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
) )
@@ -40,7 +42,7 @@ const Answer = 42
func init() { func init() {
simple.ModuleRegister(ModuleName, "answer", &types.FuncValue{ simple.ModuleRegister(ModuleName, "answer", &types.FuncValue{
T: types.NewType("func() int"), T: types.NewType("func() int"),
V: func([]types.Value) (types.Value, error) { V: func(context.Context, []types.Value) (types.Value, error) {
return &types.IntValue{V: Answer}, nil return &types.IntValue{V: Answer}, nil
}, },
}) })

View File

@@ -30,6 +30,7 @@
package coreexample package coreexample
import ( import (
"context"
"fmt" "fmt"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -39,7 +40,7 @@ import (
func init() { func init() {
simple.ModuleRegister(ModuleName, "errorbool", &types.FuncValue{ simple.ModuleRegister(ModuleName, "errorbool", &types.FuncValue{
T: types.NewType("func(a bool) str"), T: types.NewType("func(a bool) str"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
if input[0].Bool() { if input[0].Bool() {
return nil, fmt.Errorf("we errored on request") return nil, fmt.Errorf("we errored on request")
} }

View File

@@ -30,6 +30,7 @@
package coreexample package coreexample
import ( import (
"context"
"fmt" "fmt"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -39,7 +40,7 @@ import (
func init() { func init() {
simple.ModuleRegister(ModuleName, "int2str", &types.FuncValue{ simple.ModuleRegister(ModuleName, "int2str", &types.FuncValue{
T: types.NewType("func(a int) str"), T: types.NewType("func(a int) str"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: fmt.Sprintf("%d", input[0].Int()), V: fmt.Sprintf("%d", input[0].Int()),
}, nil }, nil

View File

@@ -30,6 +30,8 @@
package corenested package corenested
import ( import (
"context"
coreexample "github.com/purpleidea/mgmt/lang/core/example" coreexample "github.com/purpleidea/mgmt/lang/core/example"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -43,7 +45,7 @@ func init() {
} }
// Hello returns some string. This is just to test nesting. // Hello returns some string. This is just to test nesting.
func Hello(input []types.Value) (types.Value, error) { func Hello(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: "Hello!", V: "Hello!",
}, nil }, nil

View File

@@ -30,6 +30,8 @@
package coreexample package coreexample
import ( import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
) )
@@ -42,7 +44,7 @@ func init() {
} }
// Plus returns y + z. // Plus returns y + z.
func Plus(input []types.Value) (types.Value, error) { func Plus(ctx context.Context, input []types.Value) (types.Value, error) {
y, z := input[0].Str(), input[1].Str() y, z := input[0].Str(), input[1].Str()
return &types.StrValue{ return &types.StrValue{
V: y + z, V: y + z,

View File

@@ -30,6 +30,7 @@
package coreexample package coreexample
import ( import (
"context"
"strconv" "strconv"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -39,7 +40,7 @@ import (
func init() { func init() {
simple.ModuleRegister(ModuleName, "str2int", &types.FuncValue{ simple.ModuleRegister(ModuleName, "str2int", &types.FuncValue{
T: types.NewType("func(a str) int"), T: types.NewType("func(a str) int"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
var i int64 var i int64
if val, err := strconv.ParseInt(input[0].Str(), 10, 64); err == nil { if val, err := strconv.ParseInt(input[0].Str(), 10, 64); err == nil {
i = val i = val

View File

@@ -766,7 +766,7 @@ func (obj *MapFunc) replaceSubGraph(subgraphInput interfaces.Func) error {
outputListFunc := structs.SimpleFnToDirectFunc( outputListFunc := structs.SimpleFnToDirectFunc(
"mapOutputList", "mapOutputList",
&types.FuncValue{ &types.FuncValue{
V: func(args []types.Value) (types.Value, error) { V: func(_ context.Context, args []types.Value) (types.Value, error) {
listValue := &types.ListValue{ listValue := &types.ListValue{
V: args, V: args,
T: obj.outputListType, T: obj.outputListType,
@@ -788,7 +788,7 @@ func (obj *MapFunc) replaceSubGraph(subgraphInput interfaces.Func) error {
inputElemFunc := structs.SimpleFnToDirectFunc( inputElemFunc := structs.SimpleFnToDirectFunc(
fmt.Sprintf("mapInputElem[%d]", i), fmt.Sprintf("mapInputElem[%d]", i),
&types.FuncValue{ &types.FuncValue{
V: func(args []types.Value) (types.Value, error) { V: func(_ context.Context, args []types.Value) (types.Value, error) {
if len(args) != 1 { if len(args) != 1 {
return nil, fmt.Errorf("inputElemFunc: expected a single argument") return nil, fmt.Errorf("inputElemFunc: expected a single argument")
} }

View File

@@ -30,6 +30,7 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"github.com/purpleidea/mgmt/lang/funcs/simplepoly" "github.com/purpleidea/mgmt/lang/funcs/simplepoly"
@@ -56,7 +57,7 @@ func init() {
// Len returns the number of elements in a list or the number of key pairs in a // Len returns the number of elements in a list or the number of key pairs in a
// map. It can operate on either of these types. // map. It can operate on either of these types.
func Len(input []types.Value) (types.Value, error) { func Len(ctx context.Context, input []types.Value) (types.Value, error) {
var length int var length int
switch k := input[0].Type().Kind; k { switch k := input[0].Type().Kind; k {
case types.KindStr: case types.KindStr:

View File

@@ -30,6 +30,7 @@
package coremath package coremath
import ( import (
"context"
"fmt" "fmt"
"github.com/purpleidea/mgmt/lang/funcs/simplepoly" "github.com/purpleidea/mgmt/lang/funcs/simplepoly"
@@ -57,8 +58,8 @@ func init() {
// in a sig field, like how we demonstrate in the implementation of FortyTwo. If // in a sig field, like how we demonstrate in the implementation of FortyTwo. If
// the API doesn't change, then this is an example of how to build this as a // the API doesn't change, then this is an example of how to build this as a
// wrapper. // wrapper.
func fortyTwo(sig *types.Type) func([]types.Value) (types.Value, error) { func fortyTwo(sig *types.Type) func(context.Context, []types.Value) (types.Value, error) {
return func(input []types.Value) (types.Value, error) { return func(ctx context.Context, input []types.Value) (types.Value, error) {
return FortyTwo(sig, input) return FortyTwo(sig, input)
} }
} }

View File

@@ -30,6 +30,8 @@
package coremath package coremath
import ( import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
) )
@@ -42,7 +44,7 @@ func init() {
} }
// Minus1 takes an int and subtracts one from it. // Minus1 takes an int and subtracts one from it.
func Minus1(input []types.Value) (types.Value, error) { func Minus1(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: check for overflow // TODO: check for overflow
return &types.IntValue{ return &types.IntValue{
V: input[0].Int() - 1, V: input[0].Int() - 1,

View File

@@ -30,6 +30,7 @@
package coremath package coremath
import ( import (
"context"
"fmt" "fmt"
"math" "math"
@@ -54,7 +55,7 @@ func init() {
// both of KindInt or both of KindFloat, and it will return the same kind. If // both of KindInt or both of KindFloat, and it will return the same kind. If
// you pass in a divisor of zero, this will error, eg: mod(x, 0) = NaN. // you pass in a divisor of zero, this will error, eg: mod(x, 0) = NaN.
// TODO: consider returning zero instead of erroring? // TODO: consider returning zero instead of erroring?
func Mod(input []types.Value) (types.Value, error) { func Mod(ctx context.Context, input []types.Value) (types.Value, error) {
var x, y float64 var x, y float64
var float bool var float bool
k := input[0].Type().Kind k := input[0].Type().Kind

View File

@@ -30,6 +30,7 @@
package coremath package coremath
import ( import (
"context"
"fmt" "fmt"
"math" "math"
@@ -45,7 +46,7 @@ func init() {
} }
// Pow returns x ^ y, the base-x exponential of y. // Pow returns x ^ y, the base-x exponential of y.
func Pow(input []types.Value) (types.Value, error) { func Pow(ctx context.Context, input []types.Value) (types.Value, error) {
x, y := input[0].Float(), input[1].Float() x, y := input[0].Float(), input[1].Float()
// FIXME: check for overflow // FIXME: check for overflow
z := math.Pow(x, y) z := math.Pow(x, y)

View File

@@ -30,6 +30,7 @@
package coremath package coremath
import ( import (
"context"
"fmt" "fmt"
"math" "math"
@@ -45,7 +46,7 @@ func init() {
} }
// Sqrt returns sqrt(x), the square root of x. // Sqrt returns sqrt(x), the square root of x.
func Sqrt(input []types.Value) (types.Value, error) { func Sqrt(ctx context.Context, input []types.Value) (types.Value, error) {
x := input[0].Float() x := input[0].Float()
y := math.Sqrt(x) y := math.Sqrt(x)
if math.IsNaN(y) { if math.IsNaN(y) {

View File

@@ -30,6 +30,7 @@
package coremath package coremath
import ( import (
"context"
"fmt" "fmt"
"math" "math"
"testing" "testing"
@@ -40,7 +41,7 @@ import (
func testSqrtSuccess(input, sqrt float64) error { func testSqrtSuccess(input, sqrt float64) error {
inputVal := &types.FloatValue{V: input} inputVal := &types.FloatValue{V: input}
val, err := Sqrt([]types.Value{inputVal}) val, err := Sqrt(context.Background(), []types.Value{inputVal})
if err != nil { if err != nil {
return err return err
} }
@@ -52,7 +53,7 @@ func testSqrtSuccess(input, sqrt float64) error {
func testSqrtError(input float64) error { func testSqrtError(input float64) error {
inputVal := &types.FloatValue{V: input} inputVal := &types.FloatValue{V: input}
_, err := Sqrt([]types.Value{inputVal}) _, err := Sqrt(context.Background(), []types.Value{inputVal})
if err == nil { if err == nil {
return fmt.Errorf("expected error for input %f, got nil", input) return fmt.Errorf("expected error for input %f, got nil", input)
} }

View File

@@ -30,6 +30,7 @@
package corenet package corenet
import ( import (
"context"
"net" "net"
"strings" "strings"
@@ -45,7 +46,7 @@ func init() {
} }
// CidrToIP returns the IP from a CIDR address // CidrToIP returns the IP from a CIDR address
func CidrToIP(input []types.Value) (types.Value, error) { func CidrToIP(ctx context.Context, input []types.Value) (types.Value, error) {
cidr := input[0].Str() cidr := input[0].Str()
ip, _, err := net.ParseCIDR(strings.TrimSpace(cidr)) ip, _, err := net.ParseCIDR(strings.TrimSpace(cidr))
if err != nil { if err != nil {

View File

@@ -30,6 +30,7 @@
package corenet package corenet
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
@@ -61,7 +62,7 @@ func TestCidrToIP(t *testing.T) {
for _, ts := range cidrtests { for _, ts := range cidrtests {
test := ts test := ts
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
output, err := CidrToIP([]types.Value{&types.StrValue{V: test.input}}) output, err := CidrToIP(context.Background(), []types.Value{&types.StrValue{V: test.input}})
expectedStr := &types.StrValue{V: test.expected} expectedStr := &types.StrValue{V: test.expected}
if test.err != nil && err.Error() != test.err.Error() { if test.err != nil && err.Error() != test.err.Error() {

View File

@@ -30,6 +30,7 @@
package corenet package corenet
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strings" "strings"
@@ -51,7 +52,7 @@ func init() {
// MacFmt takes a MAC address with hyphens and converts it to a format with // MacFmt takes a MAC address with hyphens and converts it to a format with
// colons. // colons.
func MacFmt(input []types.Value) (types.Value, error) { func MacFmt(ctx context.Context, input []types.Value) (types.Value, error) {
mac := input[0].Str() mac := input[0].Str()
// Check if the MAC address is valid. // Check if the MAC address is valid.
@@ -70,7 +71,7 @@ func MacFmt(input []types.Value) (types.Value, error) {
// OldMacFmt takes a MAC address with colons and converts it to a format with // OldMacFmt takes a MAC address with colons and converts it to a format with
// hyphens. This is the old deprecated style that nobody likes. // hyphens. This is the old deprecated style that nobody likes.
func OldMacFmt(input []types.Value) (types.Value, error) { func OldMacFmt(ctx context.Context, input []types.Value) (types.Value, error) {
mac := input[0].Str() mac := input[0].Str()
// Check if the MAC address is valid. // Check if the MAC address is valid.

View File

@@ -30,6 +30,7 @@
package corenet package corenet
import ( import (
"context"
"testing" "testing"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -51,7 +52,7 @@ func TestMacFmt(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
m, err := MacFmt([]types.Value{&types.StrValue{V: tt.in}}) m, err := MacFmt(context.Background(), []types.Value{&types.StrValue{V: tt.in}})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -81,7 +82,7 @@ func TestOldMacFmt(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
m, err := OldMacFmt([]types.Value{&types.StrValue{V: tt.in}}) m, err := OldMacFmt(context.Background(), []types.Value{&types.StrValue{V: tt.in}})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@@ -30,6 +30,7 @@
package coreos package coreos
import ( import (
"context"
"os" "os"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -47,7 +48,7 @@ func init() {
// return different values depending on how this is deployed, so don't expect a // return different values depending on how this is deployed, so don't expect a
// result on your deploy client to behave the same as a server receiving code. // result on your deploy client to behave the same as a server receiving code.
// FIXME: Sanitize any command-line secrets we might pass in by cli. // FIXME: Sanitize any command-line secrets we might pass in by cli.
func Args([]types.Value) (types.Value, error) { func Args(context.Context, []types.Value) (types.Value, error) {
values := []types.Value{} values := []types.Value{}
for _, s := range os.Args { for _, s := range os.Args {
values = append(values, &types.StrValue{V: s}) values = append(values, &types.StrValue{V: s})

View File

@@ -30,6 +30,7 @@
package coreos package coreos
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@@ -51,7 +52,7 @@ func init() {
// ParseDistroUID parses a distro UID into its component values. If it cannot // ParseDistroUID parses a distro UID into its component values. If it cannot
// parse correctly, all the struct fields have the zero values. // parse correctly, all the struct fields have the zero values.
// NOTE: The UID pattern is subject to change. // NOTE: The UID pattern is subject to change.
func ParseDistroUID(input []types.Value) (types.Value, error) { func ParseDistroUID(ctx context.Context, input []types.Value) (types.Value, error) {
fn := func(distro, version, arch string) (types.Value, error) { fn := func(distro, version, arch string) (types.Value, error) {
st := types.NewStruct(types.NewType(structDistroUID)) st := types.NewStruct(types.NewType(structDistroUID))
if err := st.Set("distro", &types.StrValue{V: distro}); err != nil { if err := st.Set("distro", &types.StrValue{V: distro}); err != nil {

View File

@@ -30,6 +30,7 @@
package coreos package coreos
import ( import (
"context"
"os" "os"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -54,8 +55,9 @@ func init() {
// IsDebian detects if the os family is debian. // IsDebian detects if the os family is debian.
// TODO: Detect OS changes. // TODO: Detect OS changes.
func IsDebian(input []types.Value) (types.Value, error) { func IsDebian(ctx context.Context, input []types.Value) (types.Value, error) {
exists := true exists := true
// TODO: use ctx around io operations
_, err := os.Stat("/etc/debian_version") _, err := os.Stat("/etc/debian_version")
if os.IsNotExist(err) { if os.IsNotExist(err) {
exists = false exists = false
@@ -67,8 +69,9 @@ func IsDebian(input []types.Value) (types.Value, error) {
// IsRedHat detects if the os family is redhat. // IsRedHat detects if the os family is redhat.
// TODO: Detect OS changes. // TODO: Detect OS changes.
func IsRedHat(input []types.Value) (types.Value, error) { func IsRedHat(ctx context.Context, input []types.Value) (types.Value, error) {
exists := true exists := true
// TODO: use ctx around io operations
_, err := os.Stat("/etc/redhat-release") _, err := os.Stat("/etc/redhat-release")
if os.IsNotExist(err) { if os.IsNotExist(err) {
exists = false exists = false
@@ -80,8 +83,9 @@ func IsRedHat(input []types.Value) (types.Value, error) {
// IsArchLinux detects if the os family is archlinux. // IsArchLinux detects if the os family is archlinux.
// TODO: Detect OS changes. // TODO: Detect OS changes.
func IsArchLinux(input []types.Value) (types.Value, error) { func IsArchLinux(ctx context.Context, input []types.Value) (types.Value, error) {
exists := true exists := true
// TODO: use ctx around io operations
_, err := os.Stat("/etc/arch-release") _, err := os.Stat("/etc/arch-release")
if os.IsNotExist(err) { if os.IsNotExist(err) {
exists = false exists = false

View File

@@ -30,6 +30,7 @@
package core package core
import ( import (
"context"
"fmt" "fmt"
"github.com/purpleidea/mgmt/lang/funcs/simplepoly" "github.com/purpleidea/mgmt/lang/funcs/simplepoly"
@@ -52,7 +53,7 @@ func init() {
// Panic returns an error when it receives a non-empty string or a true boolean. // Panic returns an error when it receives a non-empty string or a true boolean.
// The error should cause the function engine to shutdown. If there's no error, // The error should cause the function engine to shutdown. If there's no error,
// it returns false. // it returns false.
func Panic(input []types.Value) (types.Value, error) { func Panic(ctx context.Context, input []types.Value) (types.Value, error) {
switch k := input[0].Type().Kind; k { switch k := input[0].Type().Kind; k {
case types.KindBool: case types.KindBool:
if input[0].Bool() { if input[0].Bool() {

View File

@@ -30,6 +30,7 @@
package coreregexp package coreregexp
import ( import (
"context"
"regexp" "regexp"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -45,7 +46,7 @@ func init() {
} }
// Match matches whether a string matches the regexp pattern. // Match matches whether a string matches the regexp pattern.
func Match(input []types.Value) (types.Value, error) { func Match(ctx context.Context, input []types.Value) (types.Value, error) {
pattern := input[0].Str() pattern := input[0].Str()
s := input[1].Str() s := input[1].Str()

View File

@@ -30,6 +30,7 @@
package coreregexp package coreregexp
import ( import (
"context"
"testing" "testing"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -76,7 +77,7 @@ func TestMatch0(t *testing.T) {
for i, x := range values { for i, x := range values {
pattern := &types.StrValue{V: x.pattern} pattern := &types.StrValue{V: x.pattern}
s := &types.StrValue{V: x.s} s := &types.StrValue{V: x.s}
val, err := Match([]types.Value{pattern, s}) val, err := Match(context.Background(), []types.Value{pattern, s})
if err != nil { if err != nil {
t.Errorf("test index %d failed with: %+v", i, err) t.Errorf("test index %d failed with: %+v", i, err)
} }

View File

@@ -30,6 +30,7 @@
package corestrings package corestrings
import ( import (
"context"
"strings" "strings"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -45,7 +46,7 @@ func init() {
// Split splits the input string using the separator and returns the segments as // Split splits the input string using the separator and returns the segments as
// a list. // a list.
func Split(input []types.Value) (types.Value, error) { func Split(ctx context.Context, input []types.Value) (types.Value, error) {
str, sep := input[0].Str(), input[1].Str() str, sep := input[0].Str(), input[1].Str()
segments := strings.Split(str, sep) segments := strings.Split(str, sep)

View File

@@ -30,6 +30,7 @@
package corestrings package corestrings
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
@@ -40,7 +41,7 @@ import (
func testSplit(input, sep string, output []string) error { func testSplit(input, sep string, output []string) error {
inputVal, sepVal := &types.StrValue{V: input}, &types.StrValue{V: sep} inputVal, sepVal := &types.StrValue{V: input}, &types.StrValue{V: sep}
val, err := Split([]types.Value{inputVal, sepVal}) val, err := Split(context.Background(), []types.Value{inputVal, sepVal})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -30,6 +30,7 @@
package corestrings package corestrings
import ( import (
"context"
"strings" "strings"
"github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -44,7 +45,7 @@ func init() {
} }
// ToLower turns a string to lowercase. // ToLower turns a string to lowercase.
func ToLower(input []types.Value) (types.Value, error) { func ToLower(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: strings.ToLower(input[0].Str()), V: strings.ToLower(input[0].Str()),
}, nil }, nil

View File

@@ -30,6 +30,7 @@
package corestrings package corestrings
import ( import (
"context"
"testing" "testing"
"github.com/purpleidea/mgmt/lang/types" "github.com/purpleidea/mgmt/lang/types"
@@ -37,7 +38,7 @@ import (
func testToLower(t *testing.T, input, expected string) { func testToLower(t *testing.T, input, expected string) {
inputStr := &types.StrValue{V: input} inputStr := &types.StrValue{V: input}
value, err := ToLower([]types.Value{inputStr}) value, err := ToLower(context.Background(), []types.Value{inputStr})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@@ -30,6 +30,7 @@
package coresys package coresys
import ( import (
"context"
"os" "os"
"strings" "strings"
@@ -58,7 +59,7 @@ func init() {
// GetEnv gets environment variable by name or returns empty string if non // GetEnv gets environment variable by name or returns empty string if non
// existing. // existing.
func GetEnv(input []types.Value) (types.Value, error) { func GetEnv(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: os.Getenv(input[0].Str()), V: os.Getenv(input[0].Str()),
}, nil }, nil
@@ -66,7 +67,7 @@ func GetEnv(input []types.Value) (types.Value, error) {
// DefaultEnv gets environment variable by name or returns default if non // DefaultEnv gets environment variable by name or returns default if non
// existing. // existing.
func DefaultEnv(input []types.Value) (types.Value, error) { func DefaultEnv(ctx context.Context, input []types.Value) (types.Value, error) {
value, exists := os.LookupEnv(input[0].Str()) value, exists := os.LookupEnv(input[0].Str())
if !exists { if !exists {
value = input[1].Str() value = input[1].Str()
@@ -77,7 +78,7 @@ func DefaultEnv(input []types.Value) (types.Value, error) {
} }
// HasEnv returns true if environment variable exists. // HasEnv returns true if environment variable exists.
func HasEnv(input []types.Value) (types.Value, error) { func HasEnv(ctx context.Context, input []types.Value) (types.Value, error) {
_, exists := os.LookupEnv(input[0].Str()) _, exists := os.LookupEnv(input[0].Str())
return &types.BoolValue{ return &types.BoolValue{
V: exists, V: exists,
@@ -85,7 +86,7 @@ func HasEnv(input []types.Value) (types.Value, error) {
} }
// Env returns a map of all keys and their values. // Env returns a map of all keys and their values.
func Env(input []types.Value) (types.Value, error) { func Env(ctx context.Context, input []types.Value) (types.Value, error) {
environ := make(map[types.Value]types.Value) environ := make(map[types.Value]types.Value)
for _, keyval := range os.Environ() { for _, keyval := range os.Environ() {
if i := strings.IndexRune(keyval, '='); i != -1 { if i := strings.IndexRune(keyval, '='); i != -1 {

View File

@@ -390,7 +390,7 @@ func (obj *TemplateFunc) Init(init *interfaces.Init) error {
} }
// run runs a template and returns the result. // run runs a template and returns the result.
func (obj *TemplateFunc) run(templateText string, vars types.Value) (string, error) { func (obj *TemplateFunc) run(ctx context.Context, templateText string, vars types.Value) (string, error) {
// see: https://golang.org/pkg/text/template/#FuncMap for more info // see: https://golang.org/pkg/text/template/#FuncMap for more info
// note: we can override any other functions by adding them here... // note: we can override any other functions by adding them here...
funcMap := map[string]interface{}{ funcMap := map[string]interface{}{
@@ -425,7 +425,7 @@ func (obj *TemplateFunc) run(templateText string, vars types.Value) (string, err
// parameter types. Functions meant to apply to arguments of // parameter types. Functions meant to apply to arguments of
// arbitrary type can use parameters of type interface{} or of // arbitrary type can use parameters of type interface{} or of
// type reflect.Value. // type reflect.Value.
f, err := wrap(name, fn) // wrap it so that it meets API expectations f, err := wrap(ctx, name, fn) // wrap it so that it meets API expectations
if err != nil { if err != nil {
obj.init.Logf("warning, skipping function named: `%s`, err: %v", name, err) obj.init.Logf("warning, skipping function named: `%s`, err: %v", name, err)
continue continue
@@ -538,7 +538,7 @@ func (obj *TemplateFunc) Stream(ctx context.Context) error {
vars = nil vars = nil
} }
result, err := obj.run(tmpl, vars) result, err := obj.run(ctx, tmpl, vars)
if err != nil { if err != nil {
return err // no errwrap needed b/c helper func return err // no errwrap needed b/c helper func
} }
@@ -585,7 +585,7 @@ func safename(name string) string {
// function API with what is expected from the reflection API. It returns a // function API with what is expected from the reflection API. It returns a
// version that includes the optional second error return value so that our // version that includes the optional second error return value so that our
// functions can return errors without causing a panic. // functions can return errors without causing a panic.
func wrap(name string, fn *types.FuncValue) (_ interface{}, reterr error) { func wrap(ctx context.Context, name string, fn *types.FuncValue) (_ interface{}, reterr error) {
defer func() { defer func() {
// catch unhandled panics // catch unhandled panics
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -633,8 +633,8 @@ func wrap(name string, fn *types.FuncValue) (_ interface{}, reterr error) {
innerArgs = append(innerArgs, v) innerArgs = append(innerArgs, v)
} }
result, err := fn.Call(innerArgs) // call it result, err := fn.Call(ctx, innerArgs) // call it
if err != nil { // function errored :( if err != nil { // function errored :(
// errwrap is a better way to report errors, if allowed! // errwrap is a better way to report errors, if allowed!
r := reflect.ValueOf(errwrap.Wrapf(err, "function `%s` errored", name)) r := reflect.ValueOf(errwrap.Wrapf(err, "function `%s` errored", name))
if !r.Type().ConvertibleTo(errorType) { // for fun! if !r.Type().ConvertibleTo(errorType) { // for fun!

View File

@@ -123,7 +123,7 @@ func init() {
simple.ModuleRegister(ModuleName, OneInstanceBFuncName, &types.FuncValue{ simple.ModuleRegister(ModuleName, OneInstanceBFuncName, &types.FuncValue{
T: types.NewType("func() str"), T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) { V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceBMutex.Lock() oneInstanceBMutex.Lock()
if oneInstanceBFlag { if oneInstanceBFlag {
panic("should not get called twice") panic("should not get called twice")
@@ -135,7 +135,7 @@ func init() {
}) })
simple.ModuleRegister(ModuleName, OneInstanceDFuncName, &types.FuncValue{ simple.ModuleRegister(ModuleName, OneInstanceDFuncName, &types.FuncValue{
T: types.NewType("func() str"), T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) { V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceDMutex.Lock() oneInstanceDMutex.Lock()
if oneInstanceDFlag { if oneInstanceDFlag {
panic("should not get called twice") panic("should not get called twice")
@@ -147,7 +147,7 @@ func init() {
}) })
simple.ModuleRegister(ModuleName, OneInstanceFFuncName, &types.FuncValue{ simple.ModuleRegister(ModuleName, OneInstanceFFuncName, &types.FuncValue{
T: types.NewType("func() str"), T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) { V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceFMutex.Lock() oneInstanceFMutex.Lock()
if oneInstanceFFlag { if oneInstanceFFlag {
panic("should not get called twice") panic("should not get called twice")
@@ -159,7 +159,7 @@ func init() {
}) })
simple.ModuleRegister(ModuleName, OneInstanceHFuncName, &types.FuncValue{ simple.ModuleRegister(ModuleName, OneInstanceHFuncName, &types.FuncValue{
T: types.NewType("func() str"), T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) { V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceHMutex.Lock() oneInstanceHMutex.Lock()
if oneInstanceHFlag { if oneInstanceHFlag {
panic("should not get called twice") panic("should not get called twice")

View File

@@ -30,6 +30,7 @@
package core package core
import ( import (
"context"
"testpkg" "testpkg"
"github.com/purpleidea/mgmt/lang/funcs/funcgen/util" "github.com/purpleidea/mgmt/lang/funcs/funcgen/util"
@@ -65,25 +66,25 @@ func init() {
} }
func TestpkgAllKind(input []types.Value) (types.Value, error) { func TestpkgAllKind(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: testpkg.AllKind(input[0].Int(), input[1].Str()), V: testpkg.AllKind(input[0].Int(), input[1].Str()),
}, nil }, nil
} }
func TestpkgToUpper(input []types.Value) (types.Value, error) { func TestpkgToUpper(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: testpkg.ToUpper(input[0].Str()), V: testpkg.ToUpper(input[0].Str()),
}, nil }, nil
} }
func TestpkgMax(input []types.Value) (types.Value, error) { func TestpkgMax(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: testpkg.Max(input[0].Float(), input[1].Float()), V: testpkg.Max(input[0].Float(), input[1].Float()),
}, nil }, nil
} }
func TestpkgWithError(input []types.Value) (types.Value, error) { func TestpkgWithError(ctx context.Context, input []types.Value) (types.Value, error) {
v, err := testpkg.WithError(input[0].Str()) v, err := testpkg.WithError(input[0].Str())
if err != nil { if err != nil {
return nil, err return nil, err
@@ -93,13 +94,13 @@ func TestpkgWithError(input []types.Value) (types.Value, error) {
}, nil }, nil
} }
func TestpkgWithInt(input []types.Value) (types.Value, error) { func TestpkgWithInt(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: testpkg.WithInt(input[0].Float(), int(input[1].Int()), input[2].Int(), int(input[3].Int()), int(input[4].Int()), input[5].Bool(), input[6].Str()), V: testpkg.WithInt(input[0].Float(), int(input[1].Int()), input[2].Int(), int(input[3].Int()), int(input[4].Int()), input[5].Bool(), input[6].Str()),
}, nil }, nil
} }
func TestpkgSuperByte(input []types.Value) (types.Value, error) { func TestpkgSuperByte(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: string(testpkg.SuperByte([]byte(input[0].Str()), input[1].Str())), V: string(testpkg.SuperByte([]byte(input[0].Str()), input[1].Str())),
}, nil }, nil

View File

@@ -30,6 +30,7 @@
package core package core
import ( import (
"context"
{{ range $i, $func := .Packages }} {{ if not (eq .Alias "") }}{{.Alias}} {{end}}"{{.Name}}" {{ range $i, $func := .Packages }} {{ if not (eq .Alias "") }}{{.Alias}} {{end}}"{{.Name}}"
{{ end }} {{ end }}
"github.com/purpleidea/mgmt/lang/funcs/funcgen/util" "github.com/purpleidea/mgmt/lang/funcs/funcgen/util"
@@ -45,7 +46,7 @@ func init() {
{{ end }} {{ end }}
} }
{{ range $i, $func := .Functions }} {{ range $i, $func := .Functions }}
{{$func.Help}}func {{$func.InternalName}}(input []types.Value) (types.Value, error) { {{$func.Help}}func {{$func.InternalName}}(ctx context.Context, input []types.Value) (types.Value, error) {
{{- if $func.Errorful }} {{- if $func.Errorful }}
v, err := {{ if not (eq $func.GolangPackage.Alias "") }}{{$func.GolangPackage.Alias}}{{else}}{{$func.GolangPackage.Name}}{{end}}.{{$func.GolangFunc}}({{$func.MakeGolangArgs}}) v, err := {{ if not (eq $func.GolangPackage.Alias "") }}{{$func.GolangPackage.Alias}}{{else}}{{$func.GolangPackage.Name}}{{end}}.{{$func.GolangFunc}}({{$func.MakeGolangArgs}})
if err != nil { if err != nil {

View File

@@ -55,7 +55,7 @@ func init() {
// concatenation // concatenation
RegisterOperator("+", &types.FuncValue{ RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a str, b str) str"), T: types.NewType("func(a str, b str) str"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{ return &types.StrValue{
V: input[0].Str() + input[1].Str(), V: input[0].Str() + input[1].Str(),
}, nil }, nil
@@ -64,7 +64,7 @@ func init() {
// addition // addition
RegisterOperator("+", &types.FuncValue{ RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a int, b int) int"), T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
//if l := len(input); l != 2 { //if l := len(input); l != 2 {
// return nil, fmt.Errorf("expected two inputs, got: %d", l) // return nil, fmt.Errorf("expected two inputs, got: %d", l)
//} //}
@@ -77,7 +77,7 @@ func init() {
// floating-point addition // floating-point addition
RegisterOperator("+", &types.FuncValue{ RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a float, b float) float"), T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: input[0].Float() + input[1].Float(), V: input[0].Float() + input[1].Float(),
}, nil }, nil
@@ -87,7 +87,7 @@ func init() {
// subtraction // subtraction
RegisterOperator("-", &types.FuncValue{ RegisterOperator("-", &types.FuncValue{
T: types.NewType("func(a int, b int) int"), T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.IntValue{ return &types.IntValue{
V: input[0].Int() - input[1].Int(), V: input[0].Int() - input[1].Int(),
}, nil }, nil
@@ -96,7 +96,7 @@ func init() {
// floating-point subtraction // floating-point subtraction
RegisterOperator("-", &types.FuncValue{ RegisterOperator("-", &types.FuncValue{
T: types.NewType("func(a float, b float) float"), T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: input[0].Float() - input[1].Float(), V: input[0].Float() - input[1].Float(),
}, nil }, nil
@@ -106,7 +106,7 @@ func init() {
// multiplication // multiplication
RegisterOperator("*", &types.FuncValue{ RegisterOperator("*", &types.FuncValue{
T: types.NewType("func(a int, b int) int"), T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// FIXME: check for overflow? // FIXME: check for overflow?
return &types.IntValue{ return &types.IntValue{
V: input[0].Int() * input[1].Int(), V: input[0].Int() * input[1].Int(),
@@ -116,7 +116,7 @@ func init() {
// floating-point multiplication // floating-point multiplication
RegisterOperator("*", &types.FuncValue{ RegisterOperator("*", &types.FuncValue{
T: types.NewType("func(a float, b float) float"), T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: input[0].Float() * input[1].Float(), V: input[0].Float() * input[1].Float(),
}, nil }, nil
@@ -127,7 +127,7 @@ func init() {
// division // division
RegisterOperator("/", &types.FuncValue{ RegisterOperator("/", &types.FuncValue{
T: types.NewType("func(a int, b int) float"), T: types.NewType("func(a int, b int) float"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
divisor := input[1].Int() divisor := input[1].Int()
if divisor == 0 { if divisor == 0 {
return nil, fmt.Errorf("can't divide by zero") return nil, fmt.Errorf("can't divide by zero")
@@ -140,7 +140,7 @@ func init() {
// floating-point division // floating-point division
RegisterOperator("/", &types.FuncValue{ RegisterOperator("/", &types.FuncValue{
T: types.NewType("func(a float, b float) float"), T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
divisor := input[1].Float() divisor := input[1].Float()
if divisor == 0.0 { if divisor == 0.0 {
return nil, fmt.Errorf("can't divide by zero") return nil, fmt.Errorf("can't divide by zero")
@@ -154,7 +154,7 @@ func init() {
// string equality // string equality
RegisterOperator("==", &types.FuncValue{ RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a str, b str) bool"), T: types.NewType("func(a str, b str) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Str() == input[1].Str(), V: input[0].Str() == input[1].Str(),
}, nil }, nil
@@ -163,7 +163,7 @@ func init() {
// bool equality // bool equality
RegisterOperator("==", &types.FuncValue{ RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"), T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Bool() == input[1].Bool(), V: input[0].Bool() == input[1].Bool(),
}, nil }, nil
@@ -172,7 +172,7 @@ func init() {
// int equality // int equality
RegisterOperator("==", &types.FuncValue{ RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"), T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Int() == input[1].Int(), V: input[0].Int() == input[1].Int(),
}, nil }, nil
@@ -181,7 +181,7 @@ func init() {
// floating-point equality // floating-point equality
RegisterOperator("==", &types.FuncValue{ RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"), T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check? // TODO: should we do an epsilon check?
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Float() == input[1].Float(), V: input[0].Float() == input[1].Float(),
@@ -192,7 +192,7 @@ func init() {
// string in-equality // string in-equality
RegisterOperator("!=", &types.FuncValue{ RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a str, b str) bool"), T: types.NewType("func(a str, b str) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Str() != input[1].Str(), V: input[0].Str() != input[1].Str(),
}, nil }, nil
@@ -201,7 +201,7 @@ func init() {
// bool in-equality // bool in-equality
RegisterOperator("!=", &types.FuncValue{ RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"), T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Bool() != input[1].Bool(), V: input[0].Bool() != input[1].Bool(),
}, nil }, nil
@@ -210,7 +210,7 @@ func init() {
// int in-equality // int in-equality
RegisterOperator("!=", &types.FuncValue{ RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"), T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Int() != input[1].Int(), V: input[0].Int() != input[1].Int(),
}, nil }, nil
@@ -219,7 +219,7 @@ func init() {
// floating-point in-equality // floating-point in-equality
RegisterOperator("!=", &types.FuncValue{ RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"), T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check? // TODO: should we do an epsilon check?
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Float() != input[1].Float(), V: input[0].Float() != input[1].Float(),
@@ -230,7 +230,7 @@ func init() {
// less-than // less-than
RegisterOperator("<", &types.FuncValue{ RegisterOperator("<", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"), T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Int() < input[1].Int(), V: input[0].Int() < input[1].Int(),
}, nil }, nil
@@ -239,7 +239,7 @@ func init() {
// floating-point less-than // floating-point less-than
RegisterOperator("<", &types.FuncValue{ RegisterOperator("<", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"), T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check? // TODO: should we do an epsilon check?
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Float() < input[1].Float(), V: input[0].Float() < input[1].Float(),
@@ -249,7 +249,7 @@ func init() {
// greater-than // greater-than
RegisterOperator(">", &types.FuncValue{ RegisterOperator(">", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"), T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Int() > input[1].Int(), V: input[0].Int() > input[1].Int(),
}, nil }, nil
@@ -258,7 +258,7 @@ func init() {
// floating-point greater-than // floating-point greater-than
RegisterOperator(">", &types.FuncValue{ RegisterOperator(">", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"), T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check? // TODO: should we do an epsilon check?
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Float() > input[1].Float(), V: input[0].Float() > input[1].Float(),
@@ -268,7 +268,7 @@ func init() {
// less-than-equal // less-than-equal
RegisterOperator("<=", &types.FuncValue{ RegisterOperator("<=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"), T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Int() <= input[1].Int(), V: input[0].Int() <= input[1].Int(),
}, nil }, nil
@@ -277,7 +277,7 @@ func init() {
// floating-point less-than-equal // floating-point less-than-equal
RegisterOperator("<=", &types.FuncValue{ RegisterOperator("<=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"), T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check? // TODO: should we do an epsilon check?
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Float() <= input[1].Float(), V: input[0].Float() <= input[1].Float(),
@@ -287,7 +287,7 @@ func init() {
// greater-than-equal // greater-than-equal
RegisterOperator(">=", &types.FuncValue{ RegisterOperator(">=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"), T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Int() >= input[1].Int(), V: input[0].Int() >= input[1].Int(),
}, nil }, nil
@@ -296,7 +296,7 @@ func init() {
// floating-point greater-than-equal // floating-point greater-than-equal
RegisterOperator(">=", &types.FuncValue{ RegisterOperator(">=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"), T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check? // TODO: should we do an epsilon check?
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Float() >= input[1].Float(), V: input[0].Float() >= input[1].Float(),
@@ -309,7 +309,7 @@ func init() {
// short-circuit operators, and does it matter? // short-circuit operators, and does it matter?
RegisterOperator("and", &types.FuncValue{ RegisterOperator("and", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"), T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Bool() && input[1].Bool(), V: input[0].Bool() && input[1].Bool(),
}, nil }, nil
@@ -318,7 +318,7 @@ func init() {
// logical or // logical or
RegisterOperator("or", &types.FuncValue{ RegisterOperator("or", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"), T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: input[0].Bool() || input[1].Bool(), V: input[0].Bool() || input[1].Bool(),
}, nil }, nil
@@ -328,7 +328,7 @@ func init() {
// logical not (unary operator) // logical not (unary operator)
RegisterOperator("not", &types.FuncValue{ RegisterOperator("not", &types.FuncValue{
T: types.NewType("func(a bool) bool"), T: types.NewType("func(a bool) bool"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{ return &types.BoolValue{
V: !input[0].Bool(), V: !input[0].Bool(),
}, nil }, nil
@@ -338,7 +338,7 @@ func init() {
// pi operator (this is an easter egg to demo a zero arg operator) // pi operator (this is an easter egg to demo a zero arg operator)
RegisterOperator("π", &types.FuncValue{ RegisterOperator("π", &types.FuncValue{
T: types.NewType("func() float"), T: types.NewType("func() float"),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{ return &types.FloatValue{
V: math.Pi, V: math.Pi,
}, nil }, nil
@@ -938,7 +938,7 @@ func (obj *OperatorFunc) Stream(ctx context.Context) error {
lastOp = op lastOp = op
var result types.Value var result types.Value
result, err := fn.Call(args) // run the function result, err := fn.Call(ctx, args) // run the function
if err != nil { if err != nil {
return errwrap.Wrapf(err, "problem running function") return errwrap.Wrapf(err, "problem running function")
} }

View File

@@ -173,7 +173,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error {
values = append(values, x) values = append(values, x)
} }
result, err := obj.Fn.Call(values) // (Value, error) result, err := obj.Fn.Call(ctx, values) // (Value, error)
if err != nil { if err != nil {
return errwrap.Wrapf(err, "simple function errored") return errwrap.Wrapf(err, "simple function errored")
} }
@@ -244,7 +244,7 @@ func StructRegister(moduleName string, args interface{}) error {
ModuleRegister(moduleName, name, &types.FuncValue{ ModuleRegister(moduleName, name, &types.FuncValue{
T: types.NewType(fmt.Sprintf("func() %s", typed.String())), T: types.NewType(fmt.Sprintf("func() %s", typed.String())),
V: func(input []types.Value) (types.Value, error) { V: func(ctx context.Context, input []types.Value) (types.Value, error) {
//if args == nil { //if args == nil {
// // programming error // // programming error
// return nil, fmt.Errorf("could not convert/access our struct") // return nil, fmt.Errorf("could not convert/access our struct")

View File

@@ -602,7 +602,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error {
if obj.init.Debug { if obj.init.Debug {
obj.init.Logf("Calling function with: %+v", values) obj.init.Logf("Calling function with: %+v", values)
} }
result, err := obj.fn.Call(values) // (Value, error) result, err := obj.fn.Call(ctx, values) // (Value, error)
if err != nil { if err != nil {
if obj.init.Debug { if obj.init.Debug {
obj.init.Logf("Function returned error: %+v", err) obj.init.Logf("Function returned error: %+v", err)

View File

@@ -30,6 +30,7 @@
package types package types
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -230,7 +231,7 @@ func ValueOf(v reflect.Value) (Value, error) {
return nil, fmt.Errorf("cannot only represent functions with one output value") return nil, fmt.Errorf("cannot only represent functions with one output value")
} }
f := func(args []Value) (Value, error) { f := func(ctx context.Context, args []Value) (Value, error) {
in := []reflect.Value{} in := []reflect.Value{}
for _, x := range args { for _, x := range args {
// TODO: should we build this method instead? // TODO: should we build this method instead?
@@ -239,6 +240,7 @@ func ValueOf(v reflect.Value) (Value, error) {
in = append(in, v) in = append(in, v)
} }
// FIXME: can we pass in ctx ?
// FIXME: can we trap panic's ? // FIXME: can we trap panic's ?
out := value.Call(in) // []reflect.Value out := value.Call(in) // []reflect.Value
if len(out) != 1 { // TODO: panic, b/c already checked in TypeOf? if len(out) != 1 { // TODO: panic, b/c already checked in TypeOf?
@@ -1207,7 +1209,7 @@ func (obj *StructValue) Lookup(k string) (value Value, exists bool) {
// Func nodes. // Func nodes.
type FuncValue struct { type FuncValue struct {
Base Base
V func([]Value) (Value, error) V func(context.Context, []Value) (Value, error)
T *Type // contains ordered field types, arg names are a bonus part T *Type // contains ordered field types, arg names are a bonus part
} }
@@ -1217,7 +1219,7 @@ func NewFunc(t *Type) *FuncValue {
if t.Kind != KindFunc { if t.Kind != KindFunc {
return nil // sanity check return nil // sanity check
} }
v := func([]Value) (Value, error) { v := func(context.Context, []Value) (Value, error) {
// You were not supposed to call the temporary function, you // You were not supposed to call the temporary function, you
// were supposed to replace it with a real implementation! // were supposed to replace it with a real implementation!
return nil, fmt.Errorf("nil function") return nil, fmt.Errorf("nil function")
@@ -1301,7 +1303,7 @@ func (obj *FuncValue) Value() interface{} {
// Call runs the function value and returns its result. It returns an error if // Call runs the function value and returns its result. It returns an error if
// something goes wrong during execution, and panic's if you call this with // something goes wrong during execution, and panic's if you call this with
// inappropriate input types, or if it returns an inappropriate output type. // inappropriate input types, or if it returns an inappropriate output type.
func (obj *FuncValue) Call(args []Value) (Value, error) { func (obj *FuncValue) Call(ctx context.Context, args []Value) (Value, error) {
// cmp input args type to obj.T // cmp input args type to obj.T
length := len(obj.T.Ord) length := len(obj.T.Ord)
if length != len(args) { if length != len(args) {
@@ -1313,7 +1315,7 @@ func (obj *FuncValue) Call(args []Value) (Value, error) {
} }
} }
result, err := obj.V(args) // call it result, err := obj.V(ctx, args) // call it
if result == nil { if result == nil {
if err == nil { if err == nil {
return nil, fmt.Errorf("function returned nil result") return nil, fmt.Errorf("function returned nil result")