util: errwrap: Add unwrapping for context removal

It's common in many concurrent engines to have a situation where we
collect errors on shutdown. Errors can either because a context closed,
or because some engine error happened. The latter, can also cause the
former, leading to a list of returned errors. In these scenarios, we
want to filter out all the secondary context errors, unless that's all
that's there. This provides a helper function to do so.
This commit is contained in:
James Shubin
2025-07-16 23:46:09 -04:00
parent 299b49bb17
commit bdf5209f68
2 changed files with 181 additions and 0 deletions

View File

@@ -31,6 +31,8 @@
package errwrap package errwrap
import ( import (
"context"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -72,6 +74,42 @@ func Join(errs []error) error {
return reterr return reterr
} }
// WithoutContext passes through an error if it's nil or is a normal error. If
// it is a multierror, then it removes all errors of context cancellation unless
// that's all there is. If there's one error remaining, it returns that.
// Otherwise it returns a new multierror with what's left.
func WithoutContext(err error) error {
if err == nil {
return nil
}
multiErr, ok := err.(*multierror.Error)
if !ok {
return err
}
if len(multiErr.Errors) == 0 {
return err // unexpected
}
errs := []error{}
for _, e := range multiErr.Errors {
if e == context.Canceled {
continue // remove
}
errs = append(errs, e)
}
if len(errs) == 0 {
return context.Canceled
}
//if len(errs) == 1 {
// return errs[0]
//}
return Join(errs)
}
// String returns a string representation of the error. In particular, if the // String returns a string representation of the error. In particular, if the
// error is nil, it returns an empty string instead of panicking. // error is nil, it returns an empty string instead of panicking.
func String(err error) string { func String(err error) string {

View File

@@ -32,6 +32,7 @@
package errwrap package errwrap
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
@@ -163,6 +164,148 @@ func TestJoinErr10(t *testing.T) {
} }
} }
func TestWithoutContext1(t *testing.T) {
if reterr := WithoutContext(nil); reterr != nil {
t.Errorf("expected nil result")
}
}
func TestWithoutContext2(t *testing.T) {
err := fmt.Errorf("err")
if reterr := WithoutContext(err); reterr != err {
t.Errorf("expected err")
}
}
func TestWithoutContext3(t *testing.T) {
err := context.Canceled
if reterr := WithoutContext(err); reterr != err {
t.Errorf("expected err")
}
}
func TestWithoutContext4(t *testing.T) {
err1 := fmt.Errorf("err")
err2 := context.Canceled
err := Append(err1, err2)
if reterr := WithoutContext(err); reterr != err1 {
t.Errorf("expected err")
}
}
func TestWithoutContext5(t *testing.T) {
err1 := context.Canceled
err2 := fmt.Errorf("err")
err := Append(err1, err2)
if reterr := WithoutContext(err); reterr != err2 {
t.Errorf("expected err")
}
}
func TestWithoutContext6(t *testing.T) {
err1 := context.Canceled
err2 := context.Canceled
err := Append(err1, err2)
if reterr := WithoutContext(err); reterr != err1 {
t.Errorf("expected err")
}
if reterr := WithoutContext(err); reterr != err2 {
t.Errorf("expected err")
}
}
func TestWithoutContext7(t *testing.T) {
err1 := fmt.Errorf("err1")
err2 := fmt.Errorf("err2")
err := Append(err1, err2)
if reterr := WithoutContext(err); reterr.Error() != err.Error() {
t.Errorf("expected err")
}
}
func TestWithoutContext8(t *testing.T) {
err1 := fmt.Errorf("err1")
err2 := fmt.Errorf("err2")
err3 := fmt.Errorf("err3")
err := Join([]error{err1, err2, err3})
if reterr := WithoutContext(err); reterr.Error() != err.Error() {
t.Errorf("expected err")
}
}
func TestWithoutContext9(t *testing.T) {
err1 := context.Canceled
err2 := fmt.Errorf("err2")
err3 := fmt.Errorf("err3")
err := Join([]error{err1, err2, err3})
exp := Join([]error{nil, err2, err3})
if reterr := WithoutContext(err); reterr.Error() != exp.Error() {
t.Errorf("expected err")
}
}
func TestWithoutContext10(t *testing.T) {
err1 := fmt.Errorf("err1")
err2 := context.Canceled
err3 := fmt.Errorf("err3")
err := Join([]error{err1, err2, err3})
exp := Join([]error{err1, nil, err3})
if reterr := WithoutContext(err); reterr.Error() != exp.Error() {
t.Errorf("expected err")
}
}
func TestWithoutContext11(t *testing.T) {
err1 := fmt.Errorf("err1")
err2 := fmt.Errorf("err2")
err3 := context.Canceled
err := Join([]error{err1, err2, err3})
exp := Join([]error{err1, err2, nil})
if reterr := WithoutContext(err); reterr.Error() != exp.Error() {
t.Errorf("expected err")
}
}
func TestWithoutContext12(t *testing.T) {
err1 := fmt.Errorf("err1")
err2 := context.Canceled
err3 := context.Canceled
err := Join([]error{err1, err2, err3})
if reterr := WithoutContext(err); reterr != err1 {
t.Errorf("expected err")
}
}
func TestWithoutContext13(t *testing.T) {
err1 := context.Canceled
err2 := fmt.Errorf("err2")
err3 := context.Canceled
err := Join([]error{err1, err2, err3})
if reterr := WithoutContext(err); reterr != err2 {
t.Errorf("expected err")
}
}
func TestWithoutContext14(t *testing.T) {
err1 := context.Canceled
err2 := context.Canceled
err3 := fmt.Errorf("err3")
err := Join([]error{err1, err2, err3})
if reterr := WithoutContext(err); reterr != err3 {
t.Errorf("expected err")
}
}
func TestWithoutContext15(t *testing.T) {
err1 := context.Canceled
err2 := context.Canceled
err3 := context.Canceled
err := Join([]error{err1, err2, err3})
if reterr := WithoutContext(err); reterr != context.Canceled {
t.Errorf("expected err")
}
}
func TestString1(t *testing.T) { func TestString1(t *testing.T) {
var err error var err error
if String(err) != "" { if String(err) != "" {