diff --git a/lang/interpret_test.go b/lang/interpret_test.go index d69f4513..09e78d71 100644 --- a/lang/interpret_test.go +++ b/lang/interpret_test.go @@ -547,6 +547,7 @@ func TestAstFunc0(t *testing.T) { func TestAstFunc1(t *testing.T) { const magicError = "# err: " const magicError1 = "err1: " + const magicErrorFailInit = "errInit: " const magicError2 = "err2: " const magicError3 = "err3: " const magicError4 = "err4: " @@ -578,10 +579,11 @@ func TestAstFunc1(t *testing.T) { } type errs struct { - fail1 bool - fail2 bool - fail3 bool - fail4 bool + fail1 bool + failInit bool + fail2 bool + fail3 bool + fail4 bool } type test struct { // an individual test name string @@ -635,6 +637,7 @@ func TestAstFunc1(t *testing.T) { // if the graph file has a magic error string, it's a failure errStr := "" fail1 := false + failInit := false fail2 := false fail3 := false fail4 := false @@ -647,6 +650,11 @@ func TestAstFunc1(t *testing.T) { str = errStr fail1 = true } + if strings.HasPrefix(str, magicErrorFailInit) { + errStr = strings.TrimPrefix(str, magicErrorFailInit) + str = errStr + failInit = true + } if strings.HasPrefix(str, magicError2) { errStr = strings.TrimPrefix(str, magicError2) str = errStr @@ -671,10 +679,11 @@ func TestAstFunc1(t *testing.T) { fail: errStr != "", expstr: str, errs: errs{ - fail1: fail1, - fail2: fail2, - fail3: fail3, - fail4: fail4, + fail1: fail1, + failInit: failInit, + fail2: fail2, + fail3: fail3, + fail4: fail4, }, }) //t.Logf("adding: %s", f + "/") @@ -709,6 +718,7 @@ func TestAstFunc1(t *testing.T) { name, path, fail, expstr, errs := tc.name, tc.path, tc.fail, strings.Trim(tc.expstr, "\n"), tc.errs src := dir + path // location of the test fail1 := errs.fail1 + failInit := errs.failInit fail2 := errs.fail2 fail3 := errs.fail3 fail4 := errs.fail4 @@ -818,11 +828,27 @@ func TestAstFunc1(t *testing.T) { }, } // some of this might happen *after* interpolate in SetScope or Unify... - if err := ast.Init(data); err != nil { + err = ast.Init(data) + if (!fail || !failInit) && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not init and validate AST: %+v", index, err) return } + if failInit && err != nil { + s := err.Error() // convert to string + if s != expstr { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: expected different error", index) + t.Logf("test #%d: err: %s", index, s) + t.Logf("test #%d: exp: %s", index, expstr) + } + return + } + if failInit && err == nil { + t.Errorf("test #%d: FAIL", index) + t.Errorf("test #%d: functions passed, expected fail", index) + return + } iast, err := ast.Interpolate() if err != nil {