diff --git a/src/io/multi.go b/src/io/multi.go index 24ee71e4ca..909b7e4523 100644 --- a/src/io/multi.go +++ b/src/io/multi.go @@ -41,6 +41,31 @@ func (mr *multiReader) Read(p []byte) (n int, err error) { return 0, EOF } +func (mr *multiReader) WriteTo(w Writer) (sum int64, err error) { + return mr.writeToWithBuffer(w, make([]byte, 1024 * 32)) +} + +func (mr *multiReader) writeToWithBuffer(w Writer, buf []byte) (sum int64, err error) { + for i, r := range mr.readers { + var n int64 + if subMr, ok := r.(*multiReader); ok { // reuse buffer with nested multiReaders + n, err = subMr.writeToWithBuffer(w, buf) + } else { + n, err = copyBuffer(w, r, buf) + } + sum += n + if err != nil { + mr.readers = mr.readers[i:] // permit resume / retry after error + return sum, err + } + mr.readers[i] = nil // permit early GC + } + mr.readers = nil + return sum, nil +} + +var _ WriterTo = (*multiReader)(nil) + // MultiReader returns a Reader that's the logical concatenation of // the provided input readers. They're read sequentially. Once all // inputs have returned EOF, Read will return EOF. If any of the readers diff --git a/src/io/multi_test.go b/src/io/multi_test.go index e877e54571..679312c23b 100644 --- a/src/io/multi_test.go +++ b/src/io/multi_test.go @@ -63,6 +63,31 @@ func TestMultiReader(t *testing.T) { }) } +func TestMultiReaderAsWriterTo(t *testing.T) { + mr := MultiReader( + strings.NewReader("foo "), + MultiReader( // Tickle the buffer reusing codepath + strings.NewReader(""), + strings.NewReader("bar"), + ), + ) + mrAsWriterTo, ok := mr.(WriterTo) + if !ok { + t.Fatalf("expected cast to WriterTo to succeed") + } + sink := &strings.Builder{} + n, err := mrAsWriterTo.WriteTo(sink) + if err != nil { + t.Fatalf("expected no error; got %v", err) + } + if n != 7 { + t.Errorf("expected read 7 bytes; got %d", n) + } + if result := sink.String(); result != "foo bar" { + t.Errorf(`expected "foo bar"; got %q`, result) + } +} + func TestMultiWriter(t *testing.T) { sink := new(bytes.Buffer) // Hide bytes.Buffer's WriteString method: