diff --git a/src/testing/iotest/example_test.go b/src/testing/iotest/example_test.go new file mode 100644 index 0000000000..10f6bd38f7 --- /dev/null +++ b/src/testing/iotest/example_test.go @@ -0,0 +1,22 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest_test + +import ( + "errors" + "fmt" + "testing/iotest" +) + +func ExampleErrReader() { + // A reader that always returns a custom error. + r := iotest.ErrReader(errors.New("custom error")) + n, err := r.Read(nil) + fmt.Printf("n: %d\nerr: %q\n", n, err) + + // Output: + // n: 0 + // err: "custom error" +} diff --git a/src/testing/iotest/logger_test.go b/src/testing/iotest/logger_test.go index 575f37e05c..fec4467cc6 100644 --- a/src/testing/iotest/logger_test.go +++ b/src/testing/iotest/logger_test.go @@ -138,14 +138,14 @@ func TestReadLogger_errorOnRead(t *testing.T) { data := []byte("Hello, World!") p := make([]byte, len(data)) - lr := ErrReader() + lr := ErrReader(errors.New("io failure")) rl := NewReadLogger("read", lr) n, err := rl.Read(p) if err == nil { t.Fatalf("Unexpectedly succeeded to read: %v", err) } - wantLogWithHex := fmt.Sprintf("lr: read %x: %v\n", p[:n], "io") + wantLogWithHex := fmt.Sprintf("lr: read %x: io failure\n", p[:n]) if g, w := lOut.String(), wantLogWithHex; g != w { t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w) } diff --git a/src/testing/iotest/reader.go b/src/testing/iotest/reader.go index b18e912f27..bc2f72a911 100644 --- a/src/testing/iotest/reader.go +++ b/src/testing/iotest/reader.go @@ -88,16 +88,15 @@ func (r *timeoutReader) Read(p []byte) (int, error) { return r.r.Read(p) } -// ErrIO is a fake IO error. -var ErrIO = errors.New("io") - -// ErrReader returns a fake error every time it is read from. -func ErrReader() io.Reader { - return errReader(0) +// ErrReader returns an io.Reader that returns 0, err from all Read calls. +func ErrReader(err error) io.Reader { + return &alwaysErrReader{err: err} } -type errReader int - -func (r errReader) Read(p []byte) (int, error) { - return 0, ErrIO +type alwaysErrReader struct { + err error +} + +func (aer *alwaysErrReader) Read(p []byte) (int, error) { + return 0, aer.err } diff --git a/src/testing/iotest/reader_test.go b/src/testing/iotest/reader_test.go index ccba22ee29..6004e841e5 100644 --- a/src/testing/iotest/reader_test.go +++ b/src/testing/iotest/reader_test.go @@ -6,6 +6,7 @@ package iotest import ( "bytes" + "errors" "io" "testing" ) @@ -226,11 +227,25 @@ func TestDataErrReader_emptyReader(t *testing.T) { } func TestErrReader(t *testing.T) { - n, err := ErrReader().Read([]byte{}) - if err != ErrIO { - t.Errorf("ErrReader.Read(any) should have returned ErrIO, returned %v", err) + cases := []struct { + name string + err error + }{ + {"nil error", nil}, + {"non-nil error", errors.New("io failure")}, + {"io.EOF", io.EOF}, } - if n != 0 { - t.Errorf("ErrReader.Read(any) should have read 0 bytes, read %v", n) + + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + n, err := ErrReader(tt.err).Read(nil) + if err != tt.err { + t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, tt.err) + } + if n != 0 { + t.Fatalf("Byte count mismatch: got %d want 0", n) + } + }) } }