Skip to content

Commit db4cda2

Browse files
committed
testing/iotest: correct ErrReader signature and remove exported error
Corrects ErrReader's signature to what was accepted in the approved proposal, and also removes an exported ErrIO which wasn't part of the proposal and is unnecessary. The new signature allows users to customize their own errors. While here, started examples, with ErrReader leading the way. Updates #38781 Change-Id: Ia7f84721f11061343cfef8b1adc2b7b69bc3f43c Reviewed-on: https://go-review.googlesource.com/c/go/+/248898 Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com> Run-TryBot: Ian Lance Taylor <iant@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Ian Lance Taylor <iant@golang.org>
1 parent 77a11c0 commit db4cda2

File tree

4 files changed

+53
-17
lines changed

4 files changed

+53
-17
lines changed

src/testing/iotest/example_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright 2020 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package iotest_test
6+
7+
import (
8+
"errors"
9+
"fmt"
10+
"testing/iotest"
11+
)
12+
13+
func ExampleErrReader() {
14+
// A reader that always returns a custom error.
15+
r := iotest.ErrReader(errors.New("custom error"))
16+
n, err := r.Read(nil)
17+
fmt.Printf("n: %d\nerr: %q\n", n, err)
18+
19+
// Output:
20+
// n: 0
21+
// err: "custom error"
22+
}

src/testing/iotest/logger_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,14 @@ func TestReadLogger_errorOnRead(t *testing.T) {
138138
data := []byte("Hello, World!")
139139
p := make([]byte, len(data))
140140

141-
lr := ErrReader()
141+
lr := ErrReader(errors.New("io failure"))
142142
rl := NewReadLogger("read", lr)
143143
n, err := rl.Read(p)
144144
if err == nil {
145145
t.Fatalf("Unexpectedly succeeded to read: %v", err)
146146
}
147147

148-
wantLogWithHex := fmt.Sprintf("lr: read %x: %v\n", p[:n], "io")
148+
wantLogWithHex := fmt.Sprintf("lr: read %x: io failure\n", p[:n])
149149
if g, w := lOut.String(), wantLogWithHex; g != w {
150150
t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
151151
}

src/testing/iotest/reader.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,15 @@ func (r *timeoutReader) Read(p []byte) (int, error) {
8888
return r.r.Read(p)
8989
}
9090

91-
// ErrIO is a fake IO error.
92-
var ErrIO = errors.New("io")
93-
94-
// ErrReader returns a fake error every time it is read from.
95-
func ErrReader() io.Reader {
96-
return errReader(0)
91+
// ErrReader returns an io.Reader that returns 0, err from all Read calls.
92+
func ErrReader(err error) io.Reader {
93+
return &alwaysErrReader{err: err}
9794
}
9895

99-
type errReader int
96+
type alwaysErrReader struct {
97+
err error
98+
}
10099

101-
func (r errReader) Read(p []byte) (int, error) {
102-
return 0, ErrIO
100+
func (aer *alwaysErrReader) Read(p []byte) (int, error) {
101+
return 0, aer.err
103102
}

src/testing/iotest/reader_test.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package iotest
66

77
import (
88
"bytes"
9+
"errors"
910
"io"
1011
"testing"
1112
)
@@ -226,11 +227,25 @@ func TestDataErrReader_emptyReader(t *testing.T) {
226227
}
227228

228229
func TestErrReader(t *testing.T) {
229-
n, err := ErrReader().Read([]byte{})
230-
if err != ErrIO {
231-
t.Errorf("ErrReader.Read(any) should have returned ErrIO, returned %v", err)
232-
}
233-
if n != 0 {
234-
t.Errorf("ErrReader.Read(any) should have read 0 bytes, read %v", n)
230+
cases := []struct {
231+
name string
232+
err error
233+
}{
234+
{"nil error", nil},
235+
{"non-nil error", errors.New("io failure")},
236+
{"io.EOF", io.EOF},
237+
}
238+
239+
for _, tt := range cases {
240+
tt := tt
241+
t.Run(tt.name, func(t *testing.T) {
242+
n, err := ErrReader(tt.err).Read(nil)
243+
if err != tt.err {
244+
t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, tt.err)
245+
}
246+
if n != 0 {
247+
t.Fatalf("Byte count mismatch: got %d want 0", n)
248+
}
249+
})
235250
}
236251
}

0 commit comments

Comments
 (0)