Skip to content

Commit 279ed80

Browse files
committed
feat: add Join and AsType standard library shortcuts
Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
1 parent a6e40ab commit 279ed80

File tree

4 files changed

+122
-0
lines changed

4 files changed

+122
-0
lines changed

astype.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//go:build !go1.26
2+
3+
package errors
4+
5+
// AsType finds the first error in err's chain that matches type E,
6+
// and if so, returns that error value and true.
7+
func AsType[E error](err error) (E, bool) {
8+
var target E
9+
return target, As(err, &target)
10+
}

astype_go126.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//go:build go1.26
2+
3+
package errors
4+
5+
import (
6+
stderrors "errors"
7+
)
8+
9+
// AsType finds the first error in err's chain that matches type E,
10+
// and if so, returns that error value and true.
11+
func AsType[E error](err error) (E, bool) {
12+
return stderrors.AsType[E](err)
13+
}

unwrap.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,11 @@ func As(err error, target any) bool { return stderrors.As(err, target) }
3434
func Unwrap(err error) error {
3535
return stderrors.Unwrap(err)
3636
}
37+
38+
// Join returns an error that wraps the given errors.
39+
// Any nil error values are discarded.
40+
// Join returns nil if every value in errs is nil.
41+
// The error formats each wrapped error, separated by newlines.
42+
func Join(errs ...error) error {
43+
return stderrors.Join(errs...)
44+
}

unwrap_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package errors
33
import (
44
stderrors "errors"
55
"fmt"
6+
"io"
67
"testing"
78
)
89

@@ -173,3 +174,93 @@ func TestUnwrap(t *testing.T) {
173174
})
174175
}
175176
}
177+
178+
func TestJoin(t *testing.T) {
179+
err1 := New("err1")
180+
err2 := New("err2")
181+
182+
tests := []struct {
183+
name string
184+
errs []error
185+
want string
186+
}{
187+
{
188+
name: "two errors",
189+
errs: []error{err1, err2},
190+
want: "err1\nerr2",
191+
},
192+
{
193+
name: "nil filtered",
194+
errs: []error{err1, nil, err2},
195+
want: "err1\nerr2",
196+
},
197+
}
198+
for _, tt := range tests {
199+
t.Run(tt.name, func(t *testing.T) {
200+
err := Join(tt.errs...)
201+
if err == nil {
202+
t.Fatal("Join() = nil, want non-nil")
203+
}
204+
if got := err.Error(); got != tt.want {
205+
t.Errorf("Join().Error() = %q, want %q", got, tt.want)
206+
}
207+
})
208+
}
209+
}
210+
211+
func TestJoinNil(t *testing.T) {
212+
if err := Join(); err != nil {
213+
t.Errorf("Join() = %v, want nil", err)
214+
}
215+
if err := Join(nil, nil); err != nil {
216+
t.Errorf("Join(nil, nil) = %v, want nil", err)
217+
}
218+
}
219+
220+
func TestWrapAsType(t *testing.T) {
221+
err := customErr{msg: "test"}
222+
wrapped := Wrap(err, "wrapped")
223+
224+
tests := []struct {
225+
name string
226+
fn func(error) (customErr, bool)
227+
}{
228+
{name: "AsType", fn: AsType[customErr]},
229+
{name: "stderrors.AsType", fn: stderrors.AsType[customErr]},
230+
}
231+
232+
for _, tt := range tests {
233+
t.Run(tt.name, func(t *testing.T) {
234+
got, ok := tt.fn(wrapped)
235+
if !ok {
236+
t.Fatalf("%s[customErr]() = false, want true", tt.name)
237+
}
238+
if got != err {
239+
t.Errorf("%s[customErr]() = %v, want %v", tt.name, got, err)
240+
}
241+
})
242+
}
243+
}
244+
245+
func TestAsTypeNotFound(t *testing.T) {
246+
err := io.EOF
247+
assertNotFound := func(name string, ok bool) {
248+
t.Helper()
249+
if ok {
250+
t.Errorf("%s[customErr](io.EOF) = true, want false", name)
251+
}
252+
}
253+
254+
tests := []struct {
255+
name string
256+
fn func(error) (customErr, bool)
257+
}{
258+
{name: "AsType", fn: AsType[customErr]},
259+
{name: "stderrors.AsType", fn: stderrors.AsType[customErr]},
260+
}
261+
262+
for _, tt := range tests {
263+
_, ok := tt.fn(err)
264+
assertNotFound(tt.name, ok)
265+
}
266+
}

0 commit comments

Comments
 (0)