diff --git a/src/compress/flate/inflate.go b/src/compress/flate/inflate.go index faa33cc6e9..d2b471f715 100644 --- a/src/compress/flate/inflate.go +++ b/src/compress/flate/inflate.go @@ -629,10 +629,7 @@ func (f *decompressor) dataBlock() { nr, err := io.ReadFull(f.r, f.buf[0:4]) f.roffset += int64(nr) if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - f.err = err + f.err = noEOF(err) return } n := int(f.buf[0]) | int(f.buf[1])<<8 @@ -665,10 +662,7 @@ func (f *decompressor) copyData() { f.copyLen -= cnt f.dict.writeMark(cnt) if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - f.err = err + f.err = noEOF(err) return } @@ -690,13 +684,18 @@ func (f *decompressor) finishBlock() { f.step = (*decompressor).nextBlock } +// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF. +func noEOF(e error) error { + if e == io.EOF { + return io.ErrUnexpectedEOF + } + return e +} + func (f *decompressor) moreBits() error { c, err := f.r.ReadByte() if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return err + return noEOF(err) } f.roffset++ f.b |= uint32(c) << f.nb @@ -711,25 +710,37 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) { // cases, the chunks slice will be 0 for the invalid sequence, leading it // satisfy the n == 0 check below. n := uint(h.min) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b for { - for f.nb < n { - if err := f.moreBits(); err != nil { - return 0, err + for nb < n { + c, err := f.r.ReadByte() + if err != nil { + f.b = b + f.nb = nb + return 0, noEOF(err) } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 } - chunk := h.chunks[f.b&(huffmanNumChunks-1)] + chunk := h.chunks[b&(huffmanNumChunks-1)] n = uint(chunk & huffmanCountMask) if n > huffmanChunkBits { - chunk = h.links[chunk>>huffmanValueShift][(f.b>>huffmanChunkBits)&h.linkMask] + chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask] n = uint(chunk & huffmanCountMask) } - if n <= f.nb { + if n <= nb { if n == 0 { + f.b = b + f.nb = nb f.err = CorruptInputError(f.roffset) return 0, f.err } - f.b >>= n - f.nb -= n + f.b = b >> (n & 31) + f.nb = nb - n return int(chunk >> huffmanValueShift), nil } }