lang: core, funcs, types: Add ctx to simple func

Plumb through the standard context.Context so that a function can be
cancelled if someone requests this. It makes it less awkward to write
simple functions that might depend on io or network access.
This commit is contained in:
James Shubin
2024-05-09 19:25:46 -04:00
parent 3b754d5324
commit 415e22abe2
51 changed files with 166 additions and 108 deletions

View File

@@ -30,6 +30,8 @@
package core
import (
"context"
"github.com/purpleidea/mgmt/lang/funcs"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
@@ -48,7 +50,7 @@ func init() {
}
// Concat concatenates two strings together.
func Concat(input []types.Value) (types.Value, error) {
func Concat(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: input[0].Str() + input[1].Str(),
}, nil

View File

@@ -30,6 +30,7 @@
package convert
import (
"context"
"strconv"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -45,7 +46,7 @@ func init() {
// FormatBool converts a boolean to a string representation that can be consumed
// by ParseBool. This value will be `"true"` or `"false"`.
func FormatBool(input []types.Value) (types.Value, error) {
func FormatBool(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: strconv.FormatBool(input[0].Bool()),
}, nil

View File

@@ -30,6 +30,7 @@
package convert
import (
"context"
"fmt"
"strconv"
@@ -48,7 +49,7 @@ func init() {
// it an invalid value. Valid values match what is accepted by the golang
// strconv.ParseBool function. It's recommended to use the strings `true` or
// `false` if you are undecided about what string representation to choose.
func ParseBool(input []types.Value) (types.Value, error) {
func ParseBool(ctx context.Context, input []types.Value) (types.Value, error) {
s := input[0].Str()
b, err := strconv.ParseBool(s)
if err != nil {

View File

@@ -30,6 +30,8 @@
package convert
import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
)
@@ -42,7 +44,7 @@ func init() {
}
// ToFloat converts an integer to a float.
func ToFloat(input []types.Value) (types.Value, error) {
func ToFloat(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: float64(input[0].Int()),
}, nil

View File

@@ -30,13 +30,14 @@
package convert
import (
"context"
"testing"
"github.com/purpleidea/mgmt/lang/types"
)
func testToFloat(t *testing.T, input int64, expected float64) {
got, err := ToFloat([]types.Value{&types.IntValue{V: input}})
got, err := ToFloat(context.Background(), []types.Value{&types.IntValue{V: input}})
if err != nil {
t.Error(err)
return

View File

@@ -30,6 +30,8 @@
package convert
import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
)
@@ -42,7 +44,7 @@ func init() {
}
// ToInt converts a float to an integer.
func ToInt(input []types.Value) (types.Value, error) {
func ToInt(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.IntValue{
V: int64(input[0].Float()),
}, nil

View File

@@ -30,6 +30,7 @@
package convert
import (
"context"
"testing"
"github.com/purpleidea/mgmt/lang/types"
@@ -37,7 +38,7 @@ import (
func testToInt(t *testing.T, input float64, expected int64) {
got, err := ToInt([]types.Value{&types.FloatValue{V: input}})
got, err := ToInt(context.Background(), []types.Value{&types.FloatValue{V: input}})
if err != nil {
t.Error(err)
return

View File

@@ -30,6 +30,7 @@
package convert
import (
"context"
"strconv"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -49,7 +50,7 @@ func init() {
}
// IntToStr converts an integer to a string.
func IntToStr(input []types.Value) (types.Value, error) {
func IntToStr(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: strconv.Itoa(int(input[0].Int())),
}, nil

View File

@@ -30,6 +30,7 @@
package coredatetime
import (
"context"
"fmt"
"time"
@@ -48,7 +49,7 @@ func init() {
// has to be defined like specified by the golang "time" package. The time is
// the number of seconds since the epoch, and matches what comes from our Now
// function. Golang documentation: https://golang.org/pkg/time/#Time.Format
func Format(input []types.Value) (types.Value, error) {
func Format(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int()
if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -32,6 +32,7 @@
package coredatetime
import (
"context"
"testing"
"github.com/purpleidea/mgmt/lang/types"
@@ -41,7 +42,7 @@ func TestFormat(t *testing.T) {
inputVal := &types.IntValue{V: 1443158163}
inputFormat := &types.StrValue{V: "2006"}
val, err := Format([]types.Value{inputVal, inputFormat})
val, err := Format(context.Background(), []types.Value{inputVal, inputFormat})
if err != nil {
t.Error(err)
}

View File

@@ -30,6 +30,7 @@
package coredatetime
import (
"context"
"fmt"
"time"
@@ -47,7 +48,7 @@ func init() {
// Hour returns the hour of the day corresponding to the input time. The time is
// the number of seconds since the epoch, and matches what comes from our Now
// function.
func Hour(input []types.Value) (types.Value, error) {
func Hour(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int()
if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -30,6 +30,7 @@
package coredatetime
import (
"context"
"fmt"
"time"
@@ -41,7 +42,7 @@ func init() {
// FIXME: consider renaming this to printf, and add in a format string?
simple.ModuleRegister(ModuleName, "print", &types.FuncValue{
T: types.NewType("func(a int) str"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int()
if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -30,6 +30,7 @@
package coredatetime
import (
"context"
"fmt"
"strings"
"time"
@@ -48,7 +49,7 @@ func init() {
// Weekday returns the lowercased day of the week corresponding to the input
// time. The time is the number of seconds since the epoch, and matches what
// comes from our Now function.
func Weekday(input []types.Value) (types.Value, error) {
func Weekday(ctx context.Context, input []types.Value) (types.Value, error) {
epochDelta := input[0].Int()
if epochDelta < 0 {
return nil, fmt.Errorf("epoch delta must be positive")

View File

@@ -418,11 +418,14 @@ func (obj *provisioner) Register(moduleName string) error {
// Build a few separately...
simple.ModuleRegister(moduleName, "cli_password", &types.FuncValue{
T: types.NewType("func() str"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
if obj.localArgs == nil {
// programming error
return nil, fmt.Errorf("could not convert/access our struct")
}
// TODO: plumb through the password lookup here instead?
//localArgs := *obj.localArgs // optional
return &types.StrValue{
V: obj.password,

View File

@@ -30,6 +30,8 @@
package coreexample
import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
)
@@ -40,7 +42,7 @@ const Answer = 42
func init() {
simple.ModuleRegister(ModuleName, "answer", &types.FuncValue{
T: types.NewType("func() int"),
V: func([]types.Value) (types.Value, error) {
V: func(context.Context, []types.Value) (types.Value, error) {
return &types.IntValue{V: Answer}, nil
},
})

View File

@@ -30,6 +30,7 @@
package coreexample
import (
"context"
"fmt"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -39,7 +40,7 @@ import (
func init() {
simple.ModuleRegister(ModuleName, "errorbool", &types.FuncValue{
T: types.NewType("func(a bool) str"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
if input[0].Bool() {
return nil, fmt.Errorf("we errored on request")
}

View File

@@ -30,6 +30,7 @@
package coreexample
import (
"context"
"fmt"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -39,7 +40,7 @@ import (
func init() {
simple.ModuleRegister(ModuleName, "int2str", &types.FuncValue{
T: types.NewType("func(a int) str"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: fmt.Sprintf("%d", input[0].Int()),
}, nil

View File

@@ -30,6 +30,8 @@
package corenested
import (
"context"
coreexample "github.com/purpleidea/mgmt/lang/core/example"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
@@ -43,7 +45,7 @@ func init() {
}
// Hello returns some string. This is just to test nesting.
func Hello(input []types.Value) (types.Value, error) {
func Hello(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: "Hello!",
}, nil

View File

@@ -30,6 +30,8 @@
package coreexample
import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
)
@@ -42,7 +44,7 @@ func init() {
}
// Plus returns y + z.
func Plus(input []types.Value) (types.Value, error) {
func Plus(ctx context.Context, input []types.Value) (types.Value, error) {
y, z := input[0].Str(), input[1].Str()
return &types.StrValue{
V: y + z,

View File

@@ -30,6 +30,7 @@
package coreexample
import (
"context"
"strconv"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -39,7 +40,7 @@ import (
func init() {
simple.ModuleRegister(ModuleName, "str2int", &types.FuncValue{
T: types.NewType("func(a str) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
var i int64
if val, err := strconv.ParseInt(input[0].Str(), 10, 64); err == nil {
i = val

View File

@@ -766,7 +766,7 @@ func (obj *MapFunc) replaceSubGraph(subgraphInput interfaces.Func) error {
outputListFunc := structs.SimpleFnToDirectFunc(
"mapOutputList",
&types.FuncValue{
V: func(args []types.Value) (types.Value, error) {
V: func(_ context.Context, args []types.Value) (types.Value, error) {
listValue := &types.ListValue{
V: args,
T: obj.outputListType,
@@ -788,7 +788,7 @@ func (obj *MapFunc) replaceSubGraph(subgraphInput interfaces.Func) error {
inputElemFunc := structs.SimpleFnToDirectFunc(
fmt.Sprintf("mapInputElem[%d]", i),
&types.FuncValue{
V: func(args []types.Value) (types.Value, error) {
V: func(_ context.Context, args []types.Value) (types.Value, error) {
if len(args) != 1 {
return nil, fmt.Errorf("inputElemFunc: expected a single argument")
}

View File

@@ -30,6 +30,7 @@
package core
import (
"context"
"fmt"
"github.com/purpleidea/mgmt/lang/funcs/simplepoly"
@@ -56,7 +57,7 @@ func init() {
// Len returns the number of elements in a list or the number of key pairs in a
// map. It can operate on either of these types.
func Len(input []types.Value) (types.Value, error) {
func Len(ctx context.Context, input []types.Value) (types.Value, error) {
var length int
switch k := input[0].Type().Kind; k {
case types.KindStr:

View File

@@ -30,6 +30,7 @@
package coremath
import (
"context"
"fmt"
"github.com/purpleidea/mgmt/lang/funcs/simplepoly"
@@ -57,8 +58,8 @@ func init() {
// in a sig field, like how we demonstrate in the implementation of FortyTwo. If
// the API doesn't change, then this is an example of how to build this as a
// wrapper.
func fortyTwo(sig *types.Type) func([]types.Value) (types.Value, error) {
return func(input []types.Value) (types.Value, error) {
func fortyTwo(sig *types.Type) func(context.Context, []types.Value) (types.Value, error) {
return func(ctx context.Context, input []types.Value) (types.Value, error) {
return FortyTwo(sig, input)
}
}

View File

@@ -30,6 +30,8 @@
package coremath
import (
"context"
"github.com/purpleidea/mgmt/lang/funcs/simple"
"github.com/purpleidea/mgmt/lang/types"
)
@@ -42,7 +44,7 @@ func init() {
}
// Minus1 takes an int and subtracts one from it.
func Minus1(input []types.Value) (types.Value, error) {
func Minus1(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: check for overflow
return &types.IntValue{
V: input[0].Int() - 1,

View File

@@ -30,6 +30,7 @@
package coremath
import (
"context"
"fmt"
"math"
@@ -54,7 +55,7 @@ func init() {
// both of KindInt or both of KindFloat, and it will return the same kind. If
// you pass in a divisor of zero, this will error, eg: mod(x, 0) = NaN.
// TODO: consider returning zero instead of erroring?
func Mod(input []types.Value) (types.Value, error) {
func Mod(ctx context.Context, input []types.Value) (types.Value, error) {
var x, y float64
var float bool
k := input[0].Type().Kind

View File

@@ -30,6 +30,7 @@
package coremath
import (
"context"
"fmt"
"math"
@@ -45,7 +46,7 @@ func init() {
}
// Pow returns x ^ y, the base-x exponential of y.
func Pow(input []types.Value) (types.Value, error) {
func Pow(ctx context.Context, input []types.Value) (types.Value, error) {
x, y := input[0].Float(), input[1].Float()
// FIXME: check for overflow
z := math.Pow(x, y)

View File

@@ -30,6 +30,7 @@
package coremath
import (
"context"
"fmt"
"math"
@@ -45,7 +46,7 @@ func init() {
}
// Sqrt returns sqrt(x), the square root of x.
func Sqrt(input []types.Value) (types.Value, error) {
func Sqrt(ctx context.Context, input []types.Value) (types.Value, error) {
x := input[0].Float()
y := math.Sqrt(x)
if math.IsNaN(y) {

View File

@@ -30,6 +30,7 @@
package coremath
import (
"context"
"fmt"
"math"
"testing"
@@ -40,7 +41,7 @@ import (
func testSqrtSuccess(input, sqrt float64) error {
inputVal := &types.FloatValue{V: input}
val, err := Sqrt([]types.Value{inputVal})
val, err := Sqrt(context.Background(), []types.Value{inputVal})
if err != nil {
return err
}
@@ -52,7 +53,7 @@ func testSqrtSuccess(input, sqrt float64) error {
func testSqrtError(input float64) error {
inputVal := &types.FloatValue{V: input}
_, err := Sqrt([]types.Value{inputVal})
_, err := Sqrt(context.Background(), []types.Value{inputVal})
if err == nil {
return fmt.Errorf("expected error for input %f, got nil", input)
}

View File

@@ -30,6 +30,7 @@
package corenet
import (
"context"
"net"
"strings"
@@ -45,7 +46,7 @@ func init() {
}
// CidrToIP returns the IP from a CIDR address
func CidrToIP(input []types.Value) (types.Value, error) {
func CidrToIP(ctx context.Context, input []types.Value) (types.Value, error) {
cidr := input[0].Str()
ip, _, err := net.ParseCIDR(strings.TrimSpace(cidr))
if err != nil {

View File

@@ -30,6 +30,7 @@
package corenet
import (
"context"
"fmt"
"testing"
@@ -61,7 +62,7 @@ func TestCidrToIP(t *testing.T) {
for _, ts := range cidrtests {
test := ts
t.Run(test.name, func(t *testing.T) {
output, err := CidrToIP([]types.Value{&types.StrValue{V: test.input}})
output, err := CidrToIP(context.Background(), []types.Value{&types.StrValue{V: test.input}})
expectedStr := &types.StrValue{V: test.expected}
if test.err != nil && err.Error() != test.err.Error() {

View File

@@ -30,6 +30,7 @@
package corenet
import (
"context"
"fmt"
"net"
"strings"
@@ -51,7 +52,7 @@ func init() {
// MacFmt takes a MAC address with hyphens and converts it to a format with
// colons.
func MacFmt(input []types.Value) (types.Value, error) {
func MacFmt(ctx context.Context, input []types.Value) (types.Value, error) {
mac := input[0].Str()
// Check if the MAC address is valid.
@@ -70,7 +71,7 @@ func MacFmt(input []types.Value) (types.Value, error) {
// OldMacFmt takes a MAC address with colons and converts it to a format with
// hyphens. This is the old deprecated style that nobody likes.
func OldMacFmt(input []types.Value) (types.Value, error) {
func OldMacFmt(ctx context.Context, input []types.Value) (types.Value, error) {
mac := input[0].Str()
// Check if the MAC address is valid.

View File

@@ -30,6 +30,7 @@
package corenet
import (
"context"
"testing"
"github.com/purpleidea/mgmt/lang/types"
@@ -51,7 +52,7 @@ func TestMacFmt(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
m, err := MacFmt([]types.Value{&types.StrValue{V: tt.in}})
m, err := MacFmt(context.Background(), []types.Value{&types.StrValue{V: tt.in}})
if (err != nil) != tt.wantErr {
t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -81,7 +82,7 @@ func TestOldMacFmt(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
m, err := OldMacFmt([]types.Value{&types.StrValue{V: tt.in}})
m, err := OldMacFmt(context.Background(), []types.Value{&types.StrValue{V: tt.in}})
if (err != nil) != tt.wantErr {
t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@@ -30,6 +30,7 @@
package coreos
import (
"context"
"os"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -47,7 +48,7 @@ func init() {
// return different values depending on how this is deployed, so don't expect a
// result on your deploy client to behave the same as a server receiving code.
// FIXME: Sanitize any command-line secrets we might pass in by cli.
func Args([]types.Value) (types.Value, error) {
func Args(context.Context, []types.Value) (types.Value, error) {
values := []types.Value{}
for _, s := range os.Args {
values = append(values, &types.StrValue{V: s})

View File

@@ -30,6 +30,7 @@
package coreos
import (
"context"
"fmt"
"strings"
@@ -51,7 +52,7 @@ func init() {
// ParseDistroUID parses a distro UID into its component values. If it cannot
// parse correctly, all the struct fields have the zero values.
// NOTE: The UID pattern is subject to change.
func ParseDistroUID(input []types.Value) (types.Value, error) {
func ParseDistroUID(ctx context.Context, input []types.Value) (types.Value, error) {
fn := func(distro, version, arch string) (types.Value, error) {
st := types.NewStruct(types.NewType(structDistroUID))
if err := st.Set("distro", &types.StrValue{V: distro}); err != nil {

View File

@@ -30,6 +30,7 @@
package coreos
import (
"context"
"os"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -54,8 +55,9 @@ func init() {
// IsDebian detects if the os family is debian.
// TODO: Detect OS changes.
func IsDebian(input []types.Value) (types.Value, error) {
func IsDebian(ctx context.Context, input []types.Value) (types.Value, error) {
exists := true
// TODO: use ctx around io operations
_, err := os.Stat("/etc/debian_version")
if os.IsNotExist(err) {
exists = false
@@ -67,8 +69,9 @@ func IsDebian(input []types.Value) (types.Value, error) {
// IsRedHat detects if the os family is redhat.
// TODO: Detect OS changes.
func IsRedHat(input []types.Value) (types.Value, error) {
func IsRedHat(ctx context.Context, input []types.Value) (types.Value, error) {
exists := true
// TODO: use ctx around io operations
_, err := os.Stat("/etc/redhat-release")
if os.IsNotExist(err) {
exists = false
@@ -80,8 +83,9 @@ func IsRedHat(input []types.Value) (types.Value, error) {
// IsArchLinux detects if the os family is archlinux.
// TODO: Detect OS changes.
func IsArchLinux(input []types.Value) (types.Value, error) {
func IsArchLinux(ctx context.Context, input []types.Value) (types.Value, error) {
exists := true
// TODO: use ctx around io operations
_, err := os.Stat("/etc/arch-release")
if os.IsNotExist(err) {
exists = false

View File

@@ -30,6 +30,7 @@
package core
import (
"context"
"fmt"
"github.com/purpleidea/mgmt/lang/funcs/simplepoly"
@@ -52,7 +53,7 @@ func init() {
// Panic returns an error when it receives a non-empty string or a true boolean.
// The error should cause the function engine to shutdown. If there's no error,
// it returns false.
func Panic(input []types.Value) (types.Value, error) {
func Panic(ctx context.Context, input []types.Value) (types.Value, error) {
switch k := input[0].Type().Kind; k {
case types.KindBool:
if input[0].Bool() {

View File

@@ -30,6 +30,7 @@
package coreregexp
import (
"context"
"regexp"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -45,7 +46,7 @@ func init() {
}
// Match matches whether a string matches the regexp pattern.
func Match(input []types.Value) (types.Value, error) {
func Match(ctx context.Context, input []types.Value) (types.Value, error) {
pattern := input[0].Str()
s := input[1].Str()

View File

@@ -30,6 +30,7 @@
package coreregexp
import (
"context"
"testing"
"github.com/purpleidea/mgmt/lang/types"
@@ -76,7 +77,7 @@ func TestMatch0(t *testing.T) {
for i, x := range values {
pattern := &types.StrValue{V: x.pattern}
s := &types.StrValue{V: x.s}
val, err := Match([]types.Value{pattern, s})
val, err := Match(context.Background(), []types.Value{pattern, s})
if err != nil {
t.Errorf("test index %d failed with: %+v", i, err)
}

View File

@@ -30,6 +30,7 @@
package corestrings
import (
"context"
"strings"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -45,7 +46,7 @@ func init() {
// Split splits the input string using the separator and returns the segments as
// a list.
func Split(input []types.Value) (types.Value, error) {
func Split(ctx context.Context, input []types.Value) (types.Value, error) {
str, sep := input[0].Str(), input[1].Str()
segments := strings.Split(str, sep)

View File

@@ -30,6 +30,7 @@
package corestrings
import (
"context"
"fmt"
"testing"
@@ -40,7 +41,7 @@ import (
func testSplit(input, sep string, output []string) error {
inputVal, sepVal := &types.StrValue{V: input}, &types.StrValue{V: sep}
val, err := Split([]types.Value{inputVal, sepVal})
val, err := Split(context.Background(), []types.Value{inputVal, sepVal})
if err != nil {
return err
}

View File

@@ -30,6 +30,7 @@
package corestrings
import (
"context"
"strings"
"github.com/purpleidea/mgmt/lang/funcs/simple"
@@ -44,7 +45,7 @@ func init() {
}
// ToLower turns a string to lowercase.
func ToLower(input []types.Value) (types.Value, error) {
func ToLower(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: strings.ToLower(input[0].Str()),
}, nil

View File

@@ -30,6 +30,7 @@
package corestrings
import (
"context"
"testing"
"github.com/purpleidea/mgmt/lang/types"
@@ -37,7 +38,7 @@ import (
func testToLower(t *testing.T, input, expected string) {
inputStr := &types.StrValue{V: input}
value, err := ToLower([]types.Value{inputStr})
value, err := ToLower(context.Background(), []types.Value{inputStr})
if err != nil {
t.Error(err)
return

View File

@@ -30,6 +30,7 @@
package coresys
import (
"context"
"os"
"strings"
@@ -58,7 +59,7 @@ func init() {
// GetEnv gets environment variable by name or returns empty string if non
// existing.
func GetEnv(input []types.Value) (types.Value, error) {
func GetEnv(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: os.Getenv(input[0].Str()),
}, nil
@@ -66,7 +67,7 @@ func GetEnv(input []types.Value) (types.Value, error) {
// DefaultEnv gets environment variable by name or returns default if non
// existing.
func DefaultEnv(input []types.Value) (types.Value, error) {
func DefaultEnv(ctx context.Context, input []types.Value) (types.Value, error) {
value, exists := os.LookupEnv(input[0].Str())
if !exists {
value = input[1].Str()
@@ -77,7 +78,7 @@ func DefaultEnv(input []types.Value) (types.Value, error) {
}
// HasEnv returns true if environment variable exists.
func HasEnv(input []types.Value) (types.Value, error) {
func HasEnv(ctx context.Context, input []types.Value) (types.Value, error) {
_, exists := os.LookupEnv(input[0].Str())
return &types.BoolValue{
V: exists,
@@ -85,7 +86,7 @@ func HasEnv(input []types.Value) (types.Value, error) {
}
// Env returns a map of all keys and their values.
func Env(input []types.Value) (types.Value, error) {
func Env(ctx context.Context, input []types.Value) (types.Value, error) {
environ := make(map[types.Value]types.Value)
for _, keyval := range os.Environ() {
if i := strings.IndexRune(keyval, '='); i != -1 {

View File

@@ -390,7 +390,7 @@ func (obj *TemplateFunc) Init(init *interfaces.Init) error {
}
// run runs a template and returns the result.
func (obj *TemplateFunc) run(templateText string, vars types.Value) (string, error) {
func (obj *TemplateFunc) run(ctx context.Context, templateText string, vars types.Value) (string, error) {
// see: https://golang.org/pkg/text/template/#FuncMap for more info
// note: we can override any other functions by adding them here...
funcMap := map[string]interface{}{
@@ -425,7 +425,7 @@ func (obj *TemplateFunc) run(templateText string, vars types.Value) (string, err
// parameter types. Functions meant to apply to arguments of
// arbitrary type can use parameters of type interface{} or of
// type reflect.Value.
f, err := wrap(name, fn) // wrap it so that it meets API expectations
f, err := wrap(ctx, name, fn) // wrap it so that it meets API expectations
if err != nil {
obj.init.Logf("warning, skipping function named: `%s`, err: %v", name, err)
continue
@@ -538,7 +538,7 @@ func (obj *TemplateFunc) Stream(ctx context.Context) error {
vars = nil
}
result, err := obj.run(tmpl, vars)
result, err := obj.run(ctx, tmpl, vars)
if err != nil {
return err // no errwrap needed b/c helper func
}
@@ -585,7 +585,7 @@ func safename(name string) string {
// function API with what is expected from the reflection API. It returns a
// version that includes the optional second error return value so that our
// functions can return errors without causing a panic.
func wrap(name string, fn *types.FuncValue) (_ interface{}, reterr error) {
func wrap(ctx context.Context, name string, fn *types.FuncValue) (_ interface{}, reterr error) {
defer func() {
// catch unhandled panics
if r := recover(); r != nil {
@@ -633,7 +633,7 @@ func wrap(name string, fn *types.FuncValue) (_ interface{}, reterr error) {
innerArgs = append(innerArgs, v)
}
result, err := fn.Call(innerArgs) // call it
result, err := fn.Call(ctx, innerArgs) // call it
if err != nil { // function errored :(
// errwrap is a better way to report errors, if allowed!
r := reflect.ValueOf(errwrap.Wrapf(err, "function `%s` errored", name))

View File

@@ -123,7 +123,7 @@ func init() {
simple.ModuleRegister(ModuleName, OneInstanceBFuncName, &types.FuncValue{
T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) {
V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceBMutex.Lock()
if oneInstanceBFlag {
panic("should not get called twice")
@@ -135,7 +135,7 @@ func init() {
})
simple.ModuleRegister(ModuleName, OneInstanceDFuncName, &types.FuncValue{
T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) {
V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceDMutex.Lock()
if oneInstanceDFlag {
panic("should not get called twice")
@@ -147,7 +147,7 @@ func init() {
})
simple.ModuleRegister(ModuleName, OneInstanceFFuncName, &types.FuncValue{
T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) {
V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceFMutex.Lock()
if oneInstanceFFlag {
panic("should not get called twice")
@@ -159,7 +159,7 @@ func init() {
})
simple.ModuleRegister(ModuleName, OneInstanceHFuncName, &types.FuncValue{
T: types.NewType("func() str"),
V: func([]types.Value) (types.Value, error) {
V: func(context.Context, []types.Value) (types.Value, error) {
oneInstanceHMutex.Lock()
if oneInstanceHFlag {
panic("should not get called twice")

View File

@@ -30,6 +30,7 @@
package core
import (
"context"
"testpkg"
"github.com/purpleidea/mgmt/lang/funcs/funcgen/util"
@@ -65,25 +66,25 @@ func init() {
}
func TestpkgAllKind(input []types.Value) (types.Value, error) {
func TestpkgAllKind(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: testpkg.AllKind(input[0].Int(), input[1].Str()),
}, nil
}
func TestpkgToUpper(input []types.Value) (types.Value, error) {
func TestpkgToUpper(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: testpkg.ToUpper(input[0].Str()),
}, nil
}
func TestpkgMax(input []types.Value) (types.Value, error) {
func TestpkgMax(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: testpkg.Max(input[0].Float(), input[1].Float()),
}, nil
}
func TestpkgWithError(input []types.Value) (types.Value, error) {
func TestpkgWithError(ctx context.Context, input []types.Value) (types.Value, error) {
v, err := testpkg.WithError(input[0].Str())
if err != nil {
return nil, err
@@ -93,13 +94,13 @@ func TestpkgWithError(input []types.Value) (types.Value, error) {
}, nil
}
func TestpkgWithInt(input []types.Value) (types.Value, error) {
func TestpkgWithInt(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: testpkg.WithInt(input[0].Float(), int(input[1].Int()), input[2].Int(), int(input[3].Int()), int(input[4].Int()), input[5].Bool(), input[6].Str()),
}, nil
}
func TestpkgSuperByte(input []types.Value) (types.Value, error) {
func TestpkgSuperByte(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: string(testpkg.SuperByte([]byte(input[0].Str()), input[1].Str())),
}, nil

View File

@@ -30,6 +30,7 @@
package core
import (
"context"
{{ range $i, $func := .Packages }} {{ if not (eq .Alias "") }}{{.Alias}} {{end}}"{{.Name}}"
{{ end }}
"github.com/purpleidea/mgmt/lang/funcs/funcgen/util"
@@ -45,7 +46,7 @@ func init() {
{{ end }}
}
{{ range $i, $func := .Functions }}
{{$func.Help}}func {{$func.InternalName}}(input []types.Value) (types.Value, error) {
{{$func.Help}}func {{$func.InternalName}}(ctx context.Context, input []types.Value) (types.Value, error) {
{{- if $func.Errorful }}
v, err := {{ if not (eq $func.GolangPackage.Alias "") }}{{$func.GolangPackage.Alias}}{{else}}{{$func.GolangPackage.Name}}{{end}}.{{$func.GolangFunc}}({{$func.MakeGolangArgs}})
if err != nil {

View File

@@ -55,7 +55,7 @@ func init() {
// concatenation
RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a str, b str) str"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.StrValue{
V: input[0].Str() + input[1].Str(),
}, nil
@@ -64,7 +64,7 @@ func init() {
// addition
RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
//if l := len(input); l != 2 {
// return nil, fmt.Errorf("expected two inputs, got: %d", l)
//}
@@ -77,7 +77,7 @@ func init() {
// floating-point addition
RegisterOperator("+", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: input[0].Float() + input[1].Float(),
}, nil
@@ -87,7 +87,7 @@ func init() {
// subtraction
RegisterOperator("-", &types.FuncValue{
T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.IntValue{
V: input[0].Int() - input[1].Int(),
}, nil
@@ -96,7 +96,7 @@ func init() {
// floating-point subtraction
RegisterOperator("-", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: input[0].Float() - input[1].Float(),
}, nil
@@ -106,7 +106,7 @@ func init() {
// multiplication
RegisterOperator("*", &types.FuncValue{
T: types.NewType("func(a int, b int) int"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// FIXME: check for overflow?
return &types.IntValue{
V: input[0].Int() * input[1].Int(),
@@ -116,7 +116,7 @@ func init() {
// floating-point multiplication
RegisterOperator("*", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: input[0].Float() * input[1].Float(),
}, nil
@@ -127,7 +127,7 @@ func init() {
// division
RegisterOperator("/", &types.FuncValue{
T: types.NewType("func(a int, b int) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
divisor := input[1].Int()
if divisor == 0 {
return nil, fmt.Errorf("can't divide by zero")
@@ -140,7 +140,7 @@ func init() {
// floating-point division
RegisterOperator("/", &types.FuncValue{
T: types.NewType("func(a float, b float) float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
divisor := input[1].Float()
if divisor == 0.0 {
return nil, fmt.Errorf("can't divide by zero")
@@ -154,7 +154,7 @@ func init() {
// string equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a str, b str) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Str() == input[1].Str(),
}, nil
@@ -163,7 +163,7 @@ func init() {
// bool equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() == input[1].Bool(),
}, nil
@@ -172,7 +172,7 @@ func init() {
// int equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() == input[1].Int(),
}, nil
@@ -181,7 +181,7 @@ func init() {
// floating-point equality
RegisterOperator("==", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() == input[1].Float(),
@@ -192,7 +192,7 @@ func init() {
// string in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a str, b str) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Str() != input[1].Str(),
}, nil
@@ -201,7 +201,7 @@ func init() {
// bool in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() != input[1].Bool(),
}, nil
@@ -210,7 +210,7 @@ func init() {
// int in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() != input[1].Int(),
}, nil
@@ -219,7 +219,7 @@ func init() {
// floating-point in-equality
RegisterOperator("!=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() != input[1].Float(),
@@ -230,7 +230,7 @@ func init() {
// less-than
RegisterOperator("<", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() < input[1].Int(),
}, nil
@@ -239,7 +239,7 @@ func init() {
// floating-point less-than
RegisterOperator("<", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() < input[1].Float(),
@@ -249,7 +249,7 @@ func init() {
// greater-than
RegisterOperator(">", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() > input[1].Int(),
}, nil
@@ -258,7 +258,7 @@ func init() {
// floating-point greater-than
RegisterOperator(">", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() > input[1].Float(),
@@ -268,7 +268,7 @@ func init() {
// less-than-equal
RegisterOperator("<=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() <= input[1].Int(),
}, nil
@@ -277,7 +277,7 @@ func init() {
// floating-point less-than-equal
RegisterOperator("<=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() <= input[1].Float(),
@@ -287,7 +287,7 @@ func init() {
// greater-than-equal
RegisterOperator(">=", &types.FuncValue{
T: types.NewType("func(a int, b int) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Int() >= input[1].Int(),
}, nil
@@ -296,7 +296,7 @@ func init() {
// floating-point greater-than-equal
RegisterOperator(">=", &types.FuncValue{
T: types.NewType("func(a float, b float) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
// TODO: should we do an epsilon check?
return &types.BoolValue{
V: input[0].Float() >= input[1].Float(),
@@ -309,7 +309,7 @@ func init() {
// short-circuit operators, and does it matter?
RegisterOperator("and", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() && input[1].Bool(),
}, nil
@@ -318,7 +318,7 @@ func init() {
// logical or
RegisterOperator("or", &types.FuncValue{
T: types.NewType("func(a bool, b bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: input[0].Bool() || input[1].Bool(),
}, nil
@@ -328,7 +328,7 @@ func init() {
// logical not (unary operator)
RegisterOperator("not", &types.FuncValue{
T: types.NewType("func(a bool) bool"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.BoolValue{
V: !input[0].Bool(),
}, nil
@@ -338,7 +338,7 @@ func init() {
// pi operator (this is an easter egg to demo a zero arg operator)
RegisterOperator("π", &types.FuncValue{
T: types.NewType("func() float"),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
return &types.FloatValue{
V: math.Pi,
}, nil
@@ -938,7 +938,7 @@ func (obj *OperatorFunc) Stream(ctx context.Context) error {
lastOp = op
var result types.Value
result, err := fn.Call(args) // run the function
result, err := fn.Call(ctx, args) // run the function
if err != nil {
return errwrap.Wrapf(err, "problem running function")
}

View File

@@ -173,7 +173,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error {
values = append(values, x)
}
result, err := obj.Fn.Call(values) // (Value, error)
result, err := obj.Fn.Call(ctx, values) // (Value, error)
if err != nil {
return errwrap.Wrapf(err, "simple function errored")
}
@@ -244,7 +244,7 @@ func StructRegister(moduleName string, args interface{}) error {
ModuleRegister(moduleName, name, &types.FuncValue{
T: types.NewType(fmt.Sprintf("func() %s", typed.String())),
V: func(input []types.Value) (types.Value, error) {
V: func(ctx context.Context, input []types.Value) (types.Value, error) {
//if args == nil {
// // programming error
// return nil, fmt.Errorf("could not convert/access our struct")

View File

@@ -602,7 +602,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error {
if obj.init.Debug {
obj.init.Logf("Calling function with: %+v", values)
}
result, err := obj.fn.Call(values) // (Value, error)
result, err := obj.fn.Call(ctx, values) // (Value, error)
if err != nil {
if obj.init.Debug {
obj.init.Logf("Function returned error: %+v", err)

View File

@@ -30,6 +30,7 @@
package types
import (
"context"
"errors"
"fmt"
"net"
@@ -230,7 +231,7 @@ func ValueOf(v reflect.Value) (Value, error) {
return nil, fmt.Errorf("cannot only represent functions with one output value")
}
f := func(args []Value) (Value, error) {
f := func(ctx context.Context, args []Value) (Value, error) {
in := []reflect.Value{}
for _, x := range args {
// TODO: should we build this method instead?
@@ -239,6 +240,7 @@ func ValueOf(v reflect.Value) (Value, error) {
in = append(in, v)
}
// FIXME: can we pass in ctx ?
// FIXME: can we trap panic's ?
out := value.Call(in) // []reflect.Value
if len(out) != 1 { // TODO: panic, b/c already checked in TypeOf?
@@ -1207,7 +1209,7 @@ func (obj *StructValue) Lookup(k string) (value Value, exists bool) {
// Func nodes.
type FuncValue struct {
Base
V func([]Value) (Value, error)
V func(context.Context, []Value) (Value, error)
T *Type // contains ordered field types, arg names are a bonus part
}
@@ -1217,7 +1219,7 @@ func NewFunc(t *Type) *FuncValue {
if t.Kind != KindFunc {
return nil // sanity check
}
v := func([]Value) (Value, error) {
v := func(context.Context, []Value) (Value, error) {
// You were not supposed to call the temporary function, you
// were supposed to replace it with a real implementation!
return nil, fmt.Errorf("nil function")
@@ -1301,7 +1303,7 @@ func (obj *FuncValue) Value() interface{} {
// Call runs the function value and returns its result. It returns an error if
// something goes wrong during execution, and panic's if you call this with
// inappropriate input types, or if it returns an inappropriate output type.
func (obj *FuncValue) Call(args []Value) (Value, error) {
func (obj *FuncValue) Call(ctx context.Context, args []Value) (Value, error) {
// cmp input args type to obj.T
length := len(obj.T.Ord)
if length != len(args) {
@@ -1313,7 +1315,7 @@ func (obj *FuncValue) Call(args []Value) (Value, error) {
}
}
result, err := obj.V(args) // call it
result, err := obj.V(ctx, args) // call it
if result == nil {
if err == nil {
return nil, fmt.Errorf("function returned nil result")