From bdf5209f68671a04608ba861063691d1282aeed7 Mon Sep 17 00:00:00 2001 From: James Shubin Date: Wed, 16 Jul 2025 23:46:09 -0400 Subject: [PATCH] util: errwrap: Add unwrapping for context removal It's common in many concurrent engines to have a situation where we collect errors on shutdown. Errors can either because a context closed, or because some engine error happened. The latter, can also cause the former, leading to a list of returned errors. In these scenarios, we want to filter out all the secondary context errors, unless that's all that's there. This provides a helper function to do so. --- util/errwrap/errwrap.go | 38 ++++++++++ util/errwrap/errwrap_test.go | 143 +++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) diff --git a/util/errwrap/errwrap.go b/util/errwrap/errwrap.go index 06a91150..f8259a06 100644 --- a/util/errwrap/errwrap.go +++ b/util/errwrap/errwrap.go @@ -31,6 +31,8 @@ package errwrap import ( + "context" + "github.com/hashicorp/go-multierror" "github.com/pkg/errors" ) @@ -72,6 +74,42 @@ func Join(errs []error) error { return reterr } +// WithoutContext passes through an error if it's nil or is a normal error. If +// it is a multierror, then it removes all errors of context cancellation unless +// that's all there is. If there's one error remaining, it returns that. +// Otherwise it returns a new multierror with what's left. +func WithoutContext(err error) error { + if err == nil { + return nil + } + + multiErr, ok := err.(*multierror.Error) + if !ok { + return err + } + + if len(multiErr.Errors) == 0 { + return err // unexpected + } + + errs := []error{} + for _, e := range multiErr.Errors { + if e == context.Canceled { + continue // remove + } + errs = append(errs, e) + } + + if len(errs) == 0 { + return context.Canceled + } + //if len(errs) == 1 { + // return errs[0] + //} + + return Join(errs) +} + // String returns a string representation of the error. In particular, if the // error is nil, it returns an empty string instead of panicking. func String(err error) string { diff --git a/util/errwrap/errwrap_test.go b/util/errwrap/errwrap_test.go index c5b6d7b4..6f291cc9 100644 --- a/util/errwrap/errwrap_test.go +++ b/util/errwrap/errwrap_test.go @@ -32,6 +32,7 @@ package errwrap import ( + "context" "errors" "fmt" "testing" @@ -163,6 +164,148 @@ func TestJoinErr10(t *testing.T) { } } +func TestWithoutContext1(t *testing.T) { + if reterr := WithoutContext(nil); reterr != nil { + t.Errorf("expected nil result") + } +} + +func TestWithoutContext2(t *testing.T) { + err := fmt.Errorf("err") + if reterr := WithoutContext(err); reterr != err { + t.Errorf("expected err") + } +} + +func TestWithoutContext3(t *testing.T) { + err := context.Canceled + if reterr := WithoutContext(err); reterr != err { + t.Errorf("expected err") + } +} + +func TestWithoutContext4(t *testing.T) { + err1 := fmt.Errorf("err") + err2 := context.Canceled + err := Append(err1, err2) + if reterr := WithoutContext(err); reterr != err1 { + t.Errorf("expected err") + } +} + +func TestWithoutContext5(t *testing.T) { + err1 := context.Canceled + err2 := fmt.Errorf("err") + err := Append(err1, err2) + if reterr := WithoutContext(err); reterr != err2 { + t.Errorf("expected err") + } +} + +func TestWithoutContext6(t *testing.T) { + err1 := context.Canceled + err2 := context.Canceled + err := Append(err1, err2) + if reterr := WithoutContext(err); reterr != err1 { + t.Errorf("expected err") + } + if reterr := WithoutContext(err); reterr != err2 { + t.Errorf("expected err") + } +} + +func TestWithoutContext7(t *testing.T) { + err1 := fmt.Errorf("err1") + err2 := fmt.Errorf("err2") + err := Append(err1, err2) + if reterr := WithoutContext(err); reterr.Error() != err.Error() { + t.Errorf("expected err") + } +} + +func TestWithoutContext8(t *testing.T) { + err1 := fmt.Errorf("err1") + err2 := fmt.Errorf("err2") + err3 := fmt.Errorf("err3") + err := Join([]error{err1, err2, err3}) + if reterr := WithoutContext(err); reterr.Error() != err.Error() { + t.Errorf("expected err") + } +} + +func TestWithoutContext9(t *testing.T) { + err1 := context.Canceled + err2 := fmt.Errorf("err2") + err3 := fmt.Errorf("err3") + err := Join([]error{err1, err2, err3}) + exp := Join([]error{nil, err2, err3}) + if reterr := WithoutContext(err); reterr.Error() != exp.Error() { + t.Errorf("expected err") + } +} + +func TestWithoutContext10(t *testing.T) { + err1 := fmt.Errorf("err1") + err2 := context.Canceled + err3 := fmt.Errorf("err3") + err := Join([]error{err1, err2, err3}) + exp := Join([]error{err1, nil, err3}) + if reterr := WithoutContext(err); reterr.Error() != exp.Error() { + t.Errorf("expected err") + } +} + +func TestWithoutContext11(t *testing.T) { + err1 := fmt.Errorf("err1") + err2 := fmt.Errorf("err2") + err3 := context.Canceled + err := Join([]error{err1, err2, err3}) + exp := Join([]error{err1, err2, nil}) + if reterr := WithoutContext(err); reterr.Error() != exp.Error() { + t.Errorf("expected err") + } +} + +func TestWithoutContext12(t *testing.T) { + err1 := fmt.Errorf("err1") + err2 := context.Canceled + err3 := context.Canceled + err := Join([]error{err1, err2, err3}) + if reterr := WithoutContext(err); reterr != err1 { + t.Errorf("expected err") + } +} + +func TestWithoutContext13(t *testing.T) { + err1 := context.Canceled + err2 := fmt.Errorf("err2") + err3 := context.Canceled + err := Join([]error{err1, err2, err3}) + if reterr := WithoutContext(err); reterr != err2 { + t.Errorf("expected err") + } +} + +func TestWithoutContext14(t *testing.T) { + err1 := context.Canceled + err2 := context.Canceled + err3 := fmt.Errorf("err3") + err := Join([]error{err1, err2, err3}) + if reterr := WithoutContext(err); reterr != err3 { + t.Errorf("expected err") + } +} + +func TestWithoutContext15(t *testing.T) { + err1 := context.Canceled + err2 := context.Canceled + err3 := context.Canceled + err := Join([]error{err1, err2, err3}) + if reterr := WithoutContext(err); reterr != context.Canceled { + t.Errorf("expected err") + } +} + func TestString1(t *testing.T) { var err error if String(err) != "" {