lang: Add partial recursive support/detection to class
This adds the additional bits onto the class/include statements to support or detect class recursion. It's not currently supported, but I figured I'd commit the detection code as a variant of the recursion implementation, since I think this is correct, and it was a bit tricky for me to get it right.
This commit is contained in:
@@ -1214,6 +1214,11 @@ func (obj *StmtProg) SetScope(scope *interfaces.Scope) error {
|
||||
|
||||
// now set the child scopes (even on bind...)
|
||||
for _, x := range obj.Prog {
|
||||
// skip over *StmtClass here (essential for recursive classes)
|
||||
if _, ok := x.(*StmtClass); ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := x.SetScope(newScope); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1379,7 +1384,8 @@ func (obj *StmtClass) Output() (*interfaces.Output, error) {
|
||||
// to call a class except that it produces output instead of a value. Most of
|
||||
// the interesting logic for classes happens here or in StmtProg.
|
||||
type StmtInclude struct {
|
||||
class *StmtClass // copy of class that we're using
|
||||
class *StmtClass // copy of class that we're using
|
||||
orig *StmtInclude // original pointer to this
|
||||
|
||||
Name string
|
||||
Args []interfaces.Expr
|
||||
@@ -1400,13 +1406,20 @@ func (obj *StmtInclude) Interpolate() (interfaces.Stmt, error) {
|
||||
}
|
||||
}
|
||||
|
||||
orig := obj
|
||||
if obj.orig != nil { // preserve the original pointer (the identifier!)
|
||||
orig = obj.orig
|
||||
}
|
||||
return &StmtInclude{
|
||||
orig: orig,
|
||||
Name: obj.Name,
|
||||
Args: args,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetScope stores the scope for use in this statement.
|
||||
// SetScope stores the scope for use in this statement. Since this is the first
|
||||
// location where recursion would play an important role, this also detects and
|
||||
// handles the recursion scenario.
|
||||
func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error {
|
||||
if scope == nil {
|
||||
scope = scope.Empty()
|
||||
@@ -1421,6 +1434,30 @@ func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error {
|
||||
return fmt.Errorf("class scope of `%s` does not contain a class", obj.Name)
|
||||
}
|
||||
|
||||
// is it even possible for the signatures to match?
|
||||
if len(class.Args) != len(obj.Args) {
|
||||
return fmt.Errorf("class `%s` expected %d args but got %d", obj.Name, len(class.Args), len(obj.Args))
|
||||
}
|
||||
|
||||
if obj.class != nil {
|
||||
// possible programming error
|
||||
return fmt.Errorf("include already contains a class pointer")
|
||||
}
|
||||
|
||||
for i := len(scope.Chain) - 1; i >= 0; i-- { // reverse order
|
||||
x, ok := scope.Chain[i].(*StmtInclude)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if x == obj.orig { // look for my original self
|
||||
// scope chain found!
|
||||
obj.class = class // same pointer, don't copy
|
||||
return fmt.Errorf("recursive class `%s` found", obj.Name)
|
||||
//return nil // if recursion was supported
|
||||
}
|
||||
}
|
||||
|
||||
// helper function to keep things more logical
|
||||
cp := func(input *StmtClass) (*StmtClass, error) {
|
||||
// TODO: should we have a dedicated copy method instead? because
|
||||
@@ -1446,6 +1483,11 @@ func (obj *StmtInclude) SetScope(scope *interfaces.Scope) error {
|
||||
for i, arg := range obj.class.Args { // copy
|
||||
newScope.Variables[arg.Name] = obj.Args[i]
|
||||
}
|
||||
|
||||
// recursion detection
|
||||
newScope.Chain = append(newScope.Chain, obj.orig) // add stmt to list
|
||||
newScope.Classes[obj.Name] = copied // overwrite with new pointer
|
||||
|
||||
if err := obj.class.SetScope(newScope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user