diff --git a/src/errors/wrap.go b/src/errors/wrap.go
index e4a5ca3..aef3749 100644
--- a/src/errors/wrap.go
+++ b/src/errors/wrap.go
@@ -4,9 +4,7 @@
package errors
-import (
- "internal/reflectlite"
-)
+import "internal/reflectlite"
// Unwrap returns the result of calling the Unwrap method on err, if err's
// type contains an Unwrap method returning error.
@@ -47,7 +45,7 @@
return err == target
}
- isComparable := reflectlite.TypeOf(target).Comparable()
+ isComparable := reflectlite.TypeComparable(target)
return is(err, target, isComparable)
}
@@ -56,10 +54,33 @@
if targetComparable && err == target {
return true
}
- if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) {
- return true
- }
switch x := err.(type) {
+ case interface {
+ Is(error) bool
+ Unwrap() error
+ }:
+ if x.Is(target) {
+ return true
+ }
+ err = x.Unwrap()
+ if err == nil {
+ return false
+ }
+ case interface {
+ Is(error) bool
+ Unwrap() []error
+ }:
+ if x.Is(target) {
+ return true
+ }
+ for _, err := range x.Unwrap() {
+ if is(err, target, targetComparable) {
+ return true
+ }
+ }
+ return false
+ case interface{ Is(error) bool }:
+ return x.Is(target)
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
diff --git a/src/errors/wrap_test.go b/src/errors/wrap_test.go
index 81c795a..6f0cfc3 100644
--- a/src/errors/wrap_test.go
+++ b/src/errors/wrap_test.go
@@ -23,6 +23,12 @@
poser := &poser{"either 1 or 3", func(err error) bool {
return err == err1 || err == err3
}}
+ poserWrap := poserWrapped{err: err1, f: func(err error) bool {
+ return err == err3
+ }}
+ poserMulti := poserMultiErr{errs: []error{err1}, f: func(err error) bool {
+ return err == err3
+ }}
testCases := []struct {
err error
@@ -42,6 +48,10 @@
{poser, err3, true},
{poser, erra, false},
{poser, errb, false},
+ {poserWrap, err1, true},
+ {poserWrap, err3, true},
+ {poserMulti, err1, true},
+ {poserMulti, err3, true},
{errorUncomparable{}, errorUncomparable{}, true},
{errorUncomparable{}, &errorUncomparable{}, false},
{&errorUncomparable{}, errorUncomparable{}, true},
@@ -92,6 +102,24 @@
return true
}
+// poserWrapped is a test error that implements both Is and the
+// single-error form of Unwrap. Is delegates to f; Unwrap yields err.
+type poserWrapped struct {
+	err error
+	f   func(error) bool
+}
+
+func (w poserWrapped) Error() string {
+	return "poserWrapped"
+}
+
+// Is reports whether the predicate f accepts target.
+func (w poserWrapped) Is(target error) bool {
+	return w.f(target)
+}
+
+// Unwrap returns the single wrapped error.
+func (w poserWrapped) Unwrap() error {
+	return w.err
+}
+
+// poserMultiErr is a test error that implements both Is and the
+// multi-error form of Unwrap. Is delegates to f; Unwrap yields errs.
+type poserMultiErr struct {
+	errs []error
+	f    func(error) bool
+}
+
+func (m poserMultiErr) Error() string {
+	return "poserMultiErr"
+}
+
+// Is reports whether the predicate f accepts target.
+func (m poserMultiErr) Is(target error) bool {
+	return m.f(target)
+}
+
+// Unwrap returns the wrapped errors (the slice itself, not a copy).
+func (m poserMultiErr) Unwrap() []error {
+	return m.errs
+}
+
func TestAs(t *testing.T) {
var errT errorT
var errP *fs.PathError
@@ -367,6 +395,75 @@
}
}
+// BenchmarkIsDirectHit measures errors.Is when err and target are the
+// very same error value.
+func BenchmarkIsDirectHit(b *testing.B) {
+	e := errors.New("x")
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(e, e); !found {
+			b.Fatal("Is failed")
+		}
+	}
+}
+
+// BenchmarkIsDirectMiss measures errors.Is on two distinct errors created
+// by errors.New, neither wrapping the other, so the call must report false.
+func BenchmarkIsDirectMiss(b *testing.B) {
+	e := errors.New("x")
+	other := errors.New("y")
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(e, other); found {
+			b.Fatal("Is succeeded")
+		}
+	}
+}
+
+// BenchmarkIsWrappedHit measures errors.Is when the target sits beneath
+// three nested wrapped layers.
+func BenchmarkIsWrappedHit(b *testing.B) {
+	want := errors.New("x")
+	inner := wrapped{"wrap", want}
+	chain := wrapped{"wrap", wrapped{"wrap", inner}}
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(chain, want); !found {
+			b.Fatal("Is failed")
+		}
+	}
+}
+
+// BenchmarkIsWrappedMiss measures errors.Is walking three nested wrapped
+// layers whose innermost error is not the target.
+func BenchmarkIsWrappedMiss(b *testing.B) {
+	want := errors.New("x")
+	inner := wrapped{"wrap", errors.New("y")}
+	chain := wrapped{"wrap", wrapped{"wrap", inner}}
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(chain, want); found {
+			b.Fatal("Is succeeded")
+		}
+	}
+}
+
+// BenchmarkIsJoinHit measures errors.Is over a nested multiErr tree in
+// which the target appears inside the inner multiErr.
+func BenchmarkIsJoinHit(b *testing.B) {
+	want := errors.New("x")
+	nested := multiErr{errorT{"b"}, want}
+	tree := multiErr{errorT{"a"}, nested, errorT{"c"}}
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(tree, want); !found {
+			b.Fatal("Is failed")
+		}
+	}
+}
+
+// BenchmarkIsJoinMiss measures errors.Is over a nested multiErr tree that
+// does not contain the target.
+func BenchmarkIsJoinMiss(b *testing.B) {
+	want := errors.New("x")
+	nested := multiErr{errorT{"b"}, errors.New("y")}
+	tree := multiErr{errorT{"a"}, nested, errorT{"c"}}
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(tree, want); found {
+			b.Fatal("Is succeeded")
+		}
+	}
+}
+
+// BenchmarkIsUncomparableTarget measures errors.Is with an errorUncomparable
+// target. NOTE(review): errorUncomparable is defined elsewhere in this file;
+// as its name suggests it is presumably not ==-comparable, so the match must
+// come from its Is method — confirm.
+func BenchmarkIsUncomparableTarget(b *testing.B) {
+	var (
+		e    = errorUncomparable{}
+		want = errorUncomparable{}
+	)
+	for n := 0; n < b.N; n++ {
+		if found := errors.Is(e, want); !found {
+			b.Fatal("Is failed")
+		}
+	}
+}
+
func BenchmarkAs(b *testing.B) {
err := multiErr{multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}}
for i := 0; i < b.N; i++ {
diff --git a/src/internal/reflectlite/type.go b/src/internal/reflectlite/type.go
index 88cc50d..1cacbc0 100644
--- a/src/internal/reflectlite/type.go
+++ b/src/internal/reflectlite/type.go
@@ -387,6 +387,12 @@
return toType(abi.TypeOf(i))
}
+// TypeComparable reports whether values of the dynamic type of i are
+// comparable, that is, whether == may be applied to them without panicking.
+// For non-nil i it is equivalent to TypeOf(i).Comparable(), but it reads the
+// type's Equal function directly instead of going through TypeOf.
+//
+// A nil i has no dynamic type. TypeComparable reports true for it, because
+// comparing a nil interface value with == never panics. TypeOf(nil) is nil,
+// so TypeOf(nil).Comparable() would panic instead.
+func TypeComparable(i any) bool {
+	t := abi.TypeOf(i)
+	// abi.TypeOf yields a nil *Type for a nil interface; check it before
+	// reading Equal to avoid a nil-pointer dereference.
+	return t == nil || t.Equal != nil
+}
+
func (t rtype) Implements(u Type) bool {
if u == nil {
panic("reflect: nil type passed to Type.Implements")