lang: Add partial recursive support/detection to class

This adds the additional bits onto the class/include statements to
support or detect class recursion. It's not currently supported, but
I figured I'd commit the detection code as a variant of the recursion
implementation, since I think this is correct, and it was a bit tricky
for me to get it right.
This commit is contained in:
James Shubin
2018-06-17 17:14:45 -04:00
parent c62b8a5d4f
commit 05f6ba7297
3 changed files with 101 additions and 2 deletions

View File

@@ -69,6 +69,8 @@ type Scope struct {
Variables map[string]Expr Variables map[string]Expr
//Functions map[string]??? // TODO: do we want a separate namespace for user defined functions? //Functions map[string]??? // TODO: do we want a separate namespace for user defined functions?
Classes map[string]Stmt Classes map[string]Stmt
Chain []Stmt // chain of previously seen stmt's
} }
// Empty returns the zero, empty value for the scope, with all the internal // Empty returns the zero, empty value for the scope, with all the internal
@@ -78,6 +80,7 @@ func (obj *Scope) Empty() *Scope {
Variables: make(map[string]Expr), Variables: make(map[string]Expr),
//Functions: ???, //Functions: ???,
Classes: make(map[string]Stmt), Classes: make(map[string]Stmt),
Chain: []Stmt{},
} }
} }
@@ -88,6 +91,7 @@ func (obj *Scope) Empty() *Scope {
func (obj *Scope) Copy() *Scope { func (obj *Scope) Copy() *Scope {
variables := make(map[string]Expr) variables := make(map[string]Expr)
classes := make(map[string]Stmt) classes := make(map[string]Stmt)
chain := []Stmt{}
if obj != nil { // allow copying nil scopes if obj != nil { // allow copying nil scopes
for k, v := range obj.Variables { // copy for k, v := range obj.Variables { // copy
variables[k] = v // we don't copy the expr's! variables[k] = v // we don't copy the expr's!
@@ -95,10 +99,14 @@ func (obj *Scope) Copy() *Scope {
for k, v := range obj.Classes { // copy for k, v := range obj.Classes { // copy
classes[k] = v // we don't copy the StmtClass! classes[k] = v // we don't copy the StmtClass!
} }
for _, x := range obj.Chain { // copy
chain = append(chain, x) // we don't copy the Stmt pointer!
}
} }
return &Scope{ return &Scope{
Variables: variables, Variables: variables,
Classes: classes, Classes: classes,
Chain: chain,
} }
} }

View File

@@ -800,6 +800,55 @@ func TestInterpretMany(t *testing.T) {
// graph: graph, // graph: graph,
// }) // })
//} //}
// TODO: remove this test if we ever support recursive classes
{
values = append(values, test{
name: "recursive classes fail 1",
code: `
$max = 3
include c1(0) # start at zero
class c1($count) {
if $count == $max {
test "done" {
stringptr => printf("count is %d", $count),
}
} else {
include c1($count + 1) # recursion not supported atm
}
}
`,
fail: true,
})
}
// TODO: remove this test if we ever support recursive classes
{
values = append(values, test{
name: "recursive classes fail 2",
code: `
$max = 5
include c1(0) # start at zero
class c1($count) {
if $count == $max {
test "done" {
stringptr => printf("count is %d", $count),
}
} else {
include c2($count + 1) # recursion not supported atm
}
}
class c2($count) {
if $count == $max {
test "done" {
stringptr => printf("count is %d", $count),
}
} else {
include c1($count + 1) # recursion not supported atm
}
}
`,
fail: true,
})
}
for index, test := range values { // run all the tests for index, test := range values { // run all the tests
name, code, fail, exp := test.name, test.code, test.fail, test.graph name, code, fail, exp := test.name, test.code, test.fail, test.graph

View File

@@ -1214,6 +1214,11 @@ func (obj *StmtProg) SetScope(scope *interfaces.Scope) error {
// now set the child scopes (even on bind...) // now set the child scopes (even on bind...)
for _, x := range obj.Prog { for _, x := range obj.Prog {
// skip over *StmtClass here (essential for recursive classes)
if _, ok := x.(*StmtClass); ok {
continue
}
if err := x.SetScope(newScope); err != nil { if err := x.SetScope(newScope); err != nil {
return err return err
} }
@@ -1379,7 +1384,8 @@ func (obj *StmtClass) Output() (*interfaces.Output, error) {
// to call a class except that it produces output instead of a value. Most of // to call a class except that it produces output instead of a value. Most of
// the interesting logic for classes happens here or in StmtProg. // the interesting logic for classes happens here or in StmtProg.
type StmtInclude struct { type StmtInclude struct {
class *StmtClass // copy of class that we're using class *StmtClass // copy of class that we're using
orig *StmtInclude // original pointer to this
Name string Name string
Args []interfaces.Expr Args []interfaces.Expr
@@ -1400,13 +1406,20 @@ func (obj *StmtInclude) Interpolate() (interfaces.Stmt, error) {
} }
} }
orig := obj
if obj.orig != nil { // preserve the original pointer (the identifier!)
orig = obj.orig
}
return &StmtInclude{ return &StmtInclude{
orig: orig,
Name: obj.Name, Name: obj.Name,
Args: args, Args: args,
}, nil }, nil
} }
// SetScope stores the scope for use in this statement. // SetScope stores the scope for use in this statement. Since this is the first
// location where recursion would play an important role, this also detects and
// handles the recursion scenario.
func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error { func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error {
if scope == nil { if scope == nil {
scope = scope.Empty() scope = scope.Empty()
@@ -1421,6 +1434,30 @@ func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error {
return fmt.Errorf("class scope of `%s` does not contain a class", obj.Name) return fmt.Errorf("class scope of `%s` does not contain a class", obj.Name)
} }
// is it even possible for the signatures to match?
if len(class.Args) != len(obj.Args) {
return fmt.Errorf("class `%s` expected %d args but got %d", obj.Name, len(class.Args), len(obj.Args))
}
if obj.class != nil {
// possible programming error
return fmt.Errorf("include already contains a class pointer")
}
for i := len(scope.Chain) - 1; i >= 0; i-- { // reverse order
x, ok := scope.Chain[i].(*StmtInclude)
if !ok {
continue
}
if x == obj.orig { // look for my original self
// scope chain found!
obj.class = class // same pointer, don't copy
return fmt.Errorf("recursive class `%s` found", obj.Name)
//return nil // if recursion was supported
}
}
// helper function to keep things more logical // helper function to keep things more logical
cp := func(input *StmtClass) (*StmtClass, error) { cp := func(input *StmtClass) (*StmtClass, error) {
// TODO: should we have a dedicated copy method instead? because // TODO: should we have a dedicated copy method instead? because
@@ -1446,6 +1483,11 @@ func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error {
for i, arg := range obj.class.Args { // copy for i, arg := range obj.class.Args { // copy
newScope.Variables[arg.Name] = obj.Args[i] newScope.Variables[arg.Name] = obj.Args[i]
} }
// recursion detection
newScope.Chain = append(newScope.Chain, obj.orig) // add stmt to list
newScope.Classes[obj.Name] = copied // overwrite with new pointer
if err := obj.class.SetScope(newScope); err != nil { if err := obj.class.SetScope(newScope); err != nil {
return err return err
} }