diff --git a/src/encoding/base32/base32.go b/src/encoding/base32/base32.go index de95df0043..e921887285 100644 --- a/src/encoding/base32/base32.go +++ b/src/encoding/base32/base32.go @@ -402,7 +402,13 @@ func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { // and returns the extended buffer. // If the input is malformed, it returns the partially decoded src and an error. func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) { - n := enc.DecodedLen(len(src)) + // Compute the output size without padding to avoid over allocating. + n := len(src) + for n > 0 && rune(src[n-1]) == enc.padChar { + n-- + } + n = decodedLen(n, NoPadding) + dst = slices.Grow(dst, n) n, err := enc.Decode(dst[len(dst):][:n], src) return dst[:len(dst)+n], err @@ -567,7 +573,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader { // DecodedLen returns the maximum length in bytes of the decoded data // corresponding to n bytes of base32-encoded data. func (enc *Encoding) DecodedLen(n int) int { - if enc.padChar == NoPadding { + return decodedLen(n, enc.padChar) +} + +func decodedLen(n int, padChar rune) int { + if padChar == NoPadding { return n/8*5 + n%8*5/8 } return n / 8 * 5 diff --git a/src/encoding/base32/base32_test.go b/src/encoding/base32/base32_test.go index 0132744507..33638adeac 100644 --- a/src/encoding/base32/base32_test.go +++ b/src/encoding/base32/base32_test.go @@ -110,6 +110,13 @@ func TestDecode(t *testing.T) { dst, err := StdEncoding.AppendDecode([]byte("lead"), []byte(p.encoded)) testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded) + + dst2, err := StdEncoding.AppendDecode(dst[:0:len(p.decoded)], []byte(p.encoded)) + testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded) + if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] { + t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst)) + } } } diff --git a/src/encoding/base64/base64.go b/src/encoding/base64/base64.go index 802ef14c38..9445cbd4ef 100644 --- a/src/encoding/base64/base64.go +++ b/src/encoding/base64/base64.go @@ -411,7 +411,13 @@ func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err err // and returns the extended buffer. // If the input is malformed, it returns the partially decoded src and an error. func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) { - n := enc.DecodedLen(len(src)) + // Compute the output size without padding to avoid over allocating. + n := len(src) + for n > 0 && rune(src[n-1]) == enc.padChar { + n-- + } + n = decodedLen(n, NoPadding) + dst = slices.Grow(dst, n) n, err := enc.Decode(dst[len(dst):][:n], src) return dst[:len(dst)+n], err @@ -643,7 +649,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader { // DecodedLen returns the maximum length in bytes of the decoded data // corresponding to n bytes of base64-encoded data. func (enc *Encoding) DecodedLen(n int) int { - if enc.padChar == NoPadding { + return decodedLen(n, enc.padChar) +} + +func decodedLen(n int, padChar rune) int { + if padChar == NoPadding { // Unpadded data may end with partial block of 2-3 characters. return n/4*3 + n%4*6/8 } diff --git a/src/encoding/base64/base64_test.go b/src/encoding/base64/base64_test.go index 4d7437b919..6dfdaef1f1 100644 --- a/src/encoding/base64/base64_test.go +++ b/src/encoding/base64/base64_test.go @@ -167,6 +167,13 @@ func TestDecode(t *testing.T) { dst, err := tt.enc.AppendDecode([]byte("lead"), []byte(encoded)) testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded) + + dst2, err := tt.enc.AppendDecode(dst[:0:len(p.decoded)], []byte(encoded)) + testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded) + if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] { + t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst)) + } } } }