lang: core, funcs, types: Add ctx to simple func

Plumb through the standard context.Context so that a function can be
cancelled if someone requests this. It makes it less awkward to write
simple functions that might depend on io or network access.
This commit is contained in:
James Shubin
2024-05-09 19:25:46 -04:00
parent 3b754d5324
commit 415e22abe2
51 changed files with 166 additions and 108 deletions

View File

@@ -30,6 +30,7 @@
package core
import (
"context"
"testpkg"
"github.com/purpleidea/mgmt/lang/funcs/funcgen/util"
@@ -65,25 +66,25 @@ func init() {
}
func TestpkgAllKind(input []types.Value) (types.Value, error) {
func TestpkgAllKind(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: testpkg.AllKind(input[0].Int(), input[1].Str()),
}, nil
}
func TestpkgToUpper(input []types.Value) (types.Value, error) {
func TestpkgToUpper(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: testpkg.ToUpper(input[0].Str()),
}, nil
}
func TestpkgMax(input []types.Value) (types.Value, error) {
func TestpkgMax(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: testpkg.Max(input[0].Float(), input[1].Float()),
}, nil
}
func TestpkgWithError(input []types.Value) (types.Value, error) {
func TestpkgWithError(ctx context.Context, input []types.Value) (types.Value, error) {
v, err := testpkg.WithError(input[0].Str())
if err != nil {
return nil, err
@@ -93,13 +94,13 @@ func TestpkgWithError(input []types.Value) (types.Value, error) {
}, nil
}
func TestpkgWithInt(input []types.Value) (types.Value, error) {
func TestpkgWithInt(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: testpkg.WithInt(input[0].Float(), int(input[1].Int()), input[2].Int(), int(input[3].Int()), int(input[4].Int()), input[5].Bool(), input[6].Str()),
}, nil
}
func TestpkgSuperByte(input []types.Value) (types.Value, error) {
func TestpkgSuperByte(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: string(testpkg.SuperByte([]byte(input[0].Str()), input[1].Str())),
}, nil

View File

@@ -30,6 +30,7 @@
package core
import (
"context"
{{ range $i, $func := .Packages }} {{ if not (eq .Alias "") }}{{.Alias}} {{end}}"{{.Name}}"
{{ end }}
"github.com/purpleidea/mgmt/lang/funcs/funcgen/util"
@@ -45,7 +46,7 @@ func init() {
{{ end }}
}
{{ range $i, $func := .Functions }}
{{$func.Help}}func {{$func.InternalName}}(input []types.Value) (types.Value, error) {
{{$func.Help}}func {{$func.InternalName}}(ctx context.Context, input []types.Value) (types.Value, error) {
{{- if $func.Errorful }}
v, err := {{ if not (eq $func.GolangPackage.Alias "") }}{{$func.GolangPackage.Alias}}{{else}}{{$func.GolangPackage.Name}}{{end}}.{{$func.GolangFunc}}({{$func.MakeGolangArgs}})
if err != nil {

View File

@@ -55,7 +55,7 @@ func init() {
// concatenation
RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a str, b str) str"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: input[0].Str() + input[1].Str(),
}, nil
@@ -64,7 +64,7 @@ func init() {
// addition
RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
//if l := len(input); l != 2 {
// return nil, fmt.Errorf("expected two inputs, got: %d", l)
//}
@@ -77,7 +77,7 @@ func init() {
// floating-point addition
RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: input[0].Float() + input[1].Float(),
}, nil
@@ -87,7 +87,7 @@ func init() {
// subtraction
RegisterOperator("-", &types.FuncValue{
T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.IntValue{
V: input[0].Int() - input[1].Int(),
}, nil
@@ -96,7 +96,7 @@ func init() {
// floating-point subtraction
RegisterOperator("-", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: input[0].Float() - input[1].Float(),
}, nil
@@ -106,7 +106,7 @@ func init() {
// multiplication
RegisterOperator("*", &types.FuncValue{
T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// FIXME: check for overflow?
return &types.IntValue{
V: input[0].Int() * input[1].Int(),
@@ -116,7 +116,7 @@ func init() {
// floating-point multiplication
RegisterOperator("*", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: input[0].Float() * input[1].Float(),
}, nil
@@ -127,7 +127,7 @@ func init() {
// division
RegisterOperator("/", &types.FuncValue{
T: types.NewType("func(a int, b int) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
divisor := input[1].Int()
if divisor == 0 {
return nil, fmt.Errorf("can't divide by zero")
@@ -140,7 +140,7 @@ func init() {
// floating-point division
RegisterOperator("/", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
divisor := input[1].Float()
if divisor == 0.0 {
return nil, fmt.Errorf("can't divide by zero")
@@ -154,7 +154,7 @@ func init() {
// string equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a str, b str) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Str() == input[1].Str(),
}, nil
@@ -163,7 +163,7 @@ func init() {
// bool equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() == input[1].Bool(),
}, nil
@@ -172,7 +172,7 @@ func init() {
// int equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() == input[1].Int(),
}, nil
@@ -181,7 +181,7 @@ func init() {
// floating-point equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() == input[1].Float(),
@@ -192,7 +192,7 @@ func init() {
// string in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a str, b str) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Str() != input[1].Str(),
}, nil
@@ -201,7 +201,7 @@ func init() {
// bool in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() != input[1].Bool(),
}, nil
@@ -210,7 +210,7 @@ func init() {
// int in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() != input[1].Int(),
}, nil
@@ -219,7 +219,7 @@ func init() {
// floating-point in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() != input[1].Float(),
@@ -230,7 +230,7 @@ func init() {
// less-than
RegisterOperator("<", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() < input[1].Int(),
}, nil
@@ -239,7 +239,7 @@ func init() {
// floating-point less-than
RegisterOperator("<", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() < input[1].Float(),
@@ -249,7 +249,7 @@ func init() {
// greater-than
RegisterOperator(">", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() > input[1].Int(),
}, nil
@@ -258,7 +258,7 @@ func init() {
// floating-point greater-than
RegisterOperator(">", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() > input[1].Float(),
@@ -268,7 +268,7 @@ func init() {
// less-than-equal
RegisterOperator("<=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() <= input[1].Int(),
}, nil
@@ -277,7 +277,7 @@ func init() {
// floating-point less-than-equal
RegisterOperator("<=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() <= input[1].Float(),
@@ -287,7 +287,7 @@ func init() {
// greater-than-equal
RegisterOperator(">=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() >= input[1].Int(),
}, nil
@@ -296,7 +296,7 @@ func init() {
// floating-point greater-than-equal
RegisterOperator(">=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() >= input[1].Float(),
@@ -309,7 +309,7 @@ func init() {
// short-circuit operators, and does it matter?
RegisterOperator("and", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() && input[1].Bool(),
}, nil
@@ -318,7 +318,7 @@ func init() {
// logical or
RegisterOperator("or", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() || input[1].Bool(),
}, nil
@@ -328,7 +328,7 @@ func init() {
// logical not (unary operator)
RegisterOperator("not", &types.FuncValue{
T: types.NewType("func(a bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: !input[0].Bool(),
}, nil
@@ -338,7 +338,7 @@ func init() {
// pi operator (this is an easter egg to demo a zero arg operator)
RegisterOperator("π", &types.FuncValue{
T: types.NewType("func() float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: math.Pi,
}, nil
@@ -938,7 +938,7 @@ func (obj *OperatorFunc) Stream(ctx context.Context) error {
lastOp = op
var result types.Value
result, err := fn.Call(args) // run the function
result, err := fn.Call(ctx, args) // run the function
if err != nil {
return errwrap.Wrapf(err, "problem running function")
}

View File

@@ -173,7 +173,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error {
values = append(values, x)
}
result, err := obj.Fn.Call(values) // (Value, error)
result, err := obj.Fn.Call(ctx, values) // (Value, error)
if err != nil {
return errwrap.Wrapf(err, "simple function errored")
}
@@ -244,7 +244,7 @@ func StructRegister(moduleName string, args interface{}) error {
ModuleRegister(moduleName, name, &types.FuncValue{
T: types.NewType(fmt.Sprintf("func() %s", typed.String())),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
//if args == nil {
// // programming error
// return nil, fmt.Errorf("could not convert/access our struct")

View File

@@ -602,7 +602,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error {
if obj.init.Debug {
obj.init.Logf("Calling function with: %+v", values)
}
result, err := obj.fn.Call(values) // (Value, error)
result, err := obj.fn.Call(ctx, values) // (Value, error)
if err != nil {
if obj.init.Debug {
obj.init.Logf("Function returned error: %+v", err)