diff options
Diffstat (limited to 'vendor/github.com/klauspost/compress/zstd/decoder.go')
-rw-r--r-- | vendor/github.com/klauspost/compress/zstd/decoder.go | 631 |
1 files changed, 500 insertions, 131 deletions
diff --git a/vendor/github.com/klauspost/compress/zstd/decoder.go b/vendor/github.com/klauspost/compress/zstd/decoder.go index f430f58b..9fcdaac1 100644 --- a/vendor/github.com/klauspost/compress/zstd/decoder.go +++ b/vendor/github.com/klauspost/compress/zstd/decoder.go @@ -5,9 +5,13 @@ package zstd import ( - "errors" + "bytes" + "context" + "encoding/binary" "io" "sync" + + "github.com/klauspost/compress/zstd/internal/xxhash" ) // Decoder provides decoding of zstandard streams. @@ -22,12 +26,19 @@ type Decoder struct { // Unreferenced decoders, ready for use. decoders chan *blockDec - // Streams ready to be decoded. - stream chan decodeStream - // Current read position used for Reader functionality. current decoderState + // sync stream decoding + syncStream struct { + decodedFrame uint64 + br readerWrapper + enabled bool + inFrame bool + } + + frame *frameDec + // Custom dictionaries. // Always uses copies. dicts map[uint32]dict @@ -46,7 +57,10 @@ type decoderState struct { output chan decodeOutput // cancel remaining output. - cancel chan struct{} + cancel context.CancelFunc + + // crc of current frame + crc *xxhash.Digest flushed bool } @@ -81,7 +95,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { return nil, err } } - d.current.output = make(chan decodeOutput, d.o.concurrent) + d.current.crc = xxhash.New() d.current.flushed = true if r == nil { @@ -130,7 +144,7 @@ func (d *Decoder) Read(p []byte) (int, error) { break } if !d.nextBlock(n == 0) { - return n, nil + return n, d.current.err } } } @@ -162,6 +176,7 @@ func (d *Decoder) Reset(r io.Reader) error { d.drainOutput() + d.syncStream.br.r = nil if r == nil { d.current.err = ErrDecoderNilInput if len(d.current.b) > 0 { @@ -195,33 +210,39 @@ func (d *Decoder) Reset(r io.Reader) error { } return nil } - - if d.stream == nil { - d.stream = make(chan decodeStream, 1) - d.streamWg.Add(1) - go d.startStreamDecoder(d.stream) - } - // Remove current block. + d.stashDecoder() d.current.decodeOutput = decodeOutput{} d.current.err = nil - d.current.cancel = make(chan struct{}) d.current.flushed = false d.current.d = nil - d.stream <- decodeStream{ - r: r, - output: d.current.output, - cancel: d.current.cancel, + // Ensure no-one else is still running... + d.streamWg.Wait() + if d.frame == nil { + d.frame = newFrameDec(d.o) + } + + if d.o.concurrent == 1 { + return d.startSyncDecoder(r) } + + d.current.output = make(chan decodeOutput, d.o.concurrent) + ctx, cancel := context.WithCancel(context.Background()) + d.current.cancel = cancel + d.streamWg.Add(1) + go d.startStreamDecoder(ctx, r, d.current.output) + return nil } // drainOutput will drain the output until errEndOfStream is sent. func (d *Decoder) drainOutput() { if d.current.cancel != nil { - println("cancelling current") - close(d.current.cancel) + if debugDecoder { + println("cancelling current") + } + d.current.cancel() d.current.cancel = nil } if d.current.d != nil { @@ -243,12 +264,9 @@ func (d *Decoder) drainOutput() { } d.decoders <- v.d } - if v.err == errEndOfStream { - println("current flushed") - d.current.flushed = true - return - } } + d.current.output = nil + d.current.flushed = true } // WriteTo writes data to w until there's no more data to write or when an error occurs. @@ -287,7 +305,7 @@ func (d *Decoder) WriteTo(w io.Writer) (int64, error) { // DecodeAll can be used concurrently. // The Decoder concurrency limits will be respected. func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { - if d.current.err == ErrDecoderClosed { + if d.decoders == nil { return dst, ErrDecoderClosed } @@ -300,6 +318,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } frame.rawInput = nil frame.bBuf = nil + if frame.history.decoders.br != nil { + frame.history.decoders.br.in = nil + } d.decoders <- block }() frame.bBuf = input @@ -307,27 +328,31 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { for { frame.history.reset() err := frame.reset(&frame.bBuf) - if err == io.EOF { - if debugDecoder { - println("frame reset return EOF") + if err != nil { + if err == io.EOF { + if debugDecoder { + println("frame reset return EOF") + } + return dst, nil } - return dst, nil + return dst, err } if frame.DictionaryID != nil { dict, ok := d.dicts[*frame.DictionaryID] if !ok { return nil, ErrUnknownDictionary } + if debugDecoder { + println("setting dict", frame.DictionaryID) + } frame.history.setDict(&dict) } - if err != nil { - return dst, err - } - if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { + + if frame.FrameContentSize != fcsUnknown && frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { return dst, ErrDecoderSizeExceeded } - if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 { - // Never preallocate moe than 1 GB up front. + if frame.FrameContentSize < 1<<30 { + // Never preallocate more than 1 GB up front. if cap(dst)-len(dst) < int(frame.FrameContentSize) { dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)) copy(dst2, dst) @@ -368,33 +393,170 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { // If non-blocking mode is used the returned boolean will be false // if no data was available without blocking. func (d *Decoder) nextBlock(blocking bool) (ok bool) { - if d.current.d != nil { - if debugDecoder { - printf("re-adding current decoder %p", d.current.d) - } - d.decoders <- d.current.d - d.current.d = nil - } if d.current.err != nil { // Keep error state. - return blocking + return false } + d.current.b = d.current.b[:0] + // SYNC: + if d.syncStream.enabled { + if !blocking { + return false + } + ok = d.nextBlockSync() + if !ok { + d.stashDecoder() + } + return ok + } + + //ASYNC: + d.stashDecoder() if blocking { - d.current.decodeOutput = <-d.current.output + d.current.decodeOutput, ok = <-d.current.output } else { select { - case d.current.decodeOutput = <-d.current.output: + case d.current.decodeOutput, ok = <-d.current.output: default: return false } } + if !ok { + // This should not happen, so signal error state... + d.current.err = io.ErrUnexpectedEOF + return false + } + next := d.current.decodeOutput + if next.d != nil && next.d.async.newHist != nil { + d.current.crc.Reset() + } if debugDecoder { - println("got", len(d.current.b), "bytes, error:", d.current.err) + var tmp [4]byte + binary.LittleEndian.PutUint32(tmp[:], uint32(xxhash.Sum64(next.b))) + println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp) + } + + if len(next.b) > 0 { + n, err := d.current.crc.Write(next.b) + if err == nil { + if n != len(next.b) { + d.current.err = io.ErrShortWrite + } + } + } + if next.err == nil && next.d != nil && len(next.d.checkCRC) != 0 { + got := d.current.crc.Sum64() + var tmp [4]byte + binary.LittleEndian.PutUint32(tmp[:], uint32(got)) + if !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC { + if debugDecoder { + println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)") + } + d.current.err = ErrCRCMismatch + } else { + if debugDecoder { + println("CRC ok", tmp[:]) + } + } + } + + return true +} + +func (d *Decoder) nextBlockSync() (ok bool) { + if d.current.d == nil { + d.current.d = <-d.decoders + } + for len(d.current.b) == 0 { + if !d.syncStream.inFrame { + d.frame.history.reset() + d.current.err = d.frame.reset(&d.syncStream.br) + if d.current.err != nil { + return false + } + if d.frame.DictionaryID != nil { + dict, ok := d.dicts[*d.frame.DictionaryID] + if !ok { + d.current.err = ErrUnknownDictionary + return false + } else { + d.frame.history.setDict(&dict) + } + } + if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { + d.current.err = ErrDecoderSizeExceeded + return false + } + + d.syncStream.decodedFrame = 0 + d.syncStream.inFrame = true + } + d.current.err = d.frame.next(d.current.d) + if d.current.err != nil { + return false + } + d.frame.history.ensureBlock() + if debugDecoder { + println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame) + } + histBefore := len(d.frame.history.b) + d.current.err = d.current.d.decodeBuf(&d.frame.history) + + if d.current.err != nil { + println("error after:", d.current.err) + return false + } + d.current.b = d.frame.history.b[histBefore:] + if debugDecoder { + println("history after:", len(d.frame.history.b)) + } + + // Check frame size (before CRC) + d.syncStream.decodedFrame += uint64(len(d.current.b)) + if d.syncStream.decodedFrame > d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeExceeded + return false + } + + // Check FCS + if d.current.d.Last && d.frame.FrameContentSize != fcsUnknown && d.syncStream.decodedFrame != d.frame.FrameContentSize { + if debugDecoder { + printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize) + } + d.current.err = ErrFrameSizeMismatch + return false + } + + // Update/Check CRC + if d.frame.HasCheckSum { + d.frame.crc.Write(d.current.b) + if d.current.d.Last { + d.current.err = d.frame.checkCRC() + if d.current.err != nil { + println("CRC error:", d.current.err) + return false + } + } + } + d.syncStream.inFrame = !d.current.d.Last } return true } +func (d *Decoder) stashDecoder() { + if d.current.d != nil { + if debugDecoder { + printf("re-adding current decoder %p", d.current.d) + } + d.decoders <- d.current.d + d.current.d = nil + } +} + // Close will release all resources. // It is NOT possible to reuse the decoder after this. func (d *Decoder) Close() { @@ -402,10 +564,10 @@ func (d *Decoder) Close() { return } d.drainOutput() - if d.stream != nil { - close(d.stream) + if d.current.cancel != nil { + d.current.cancel() d.streamWg.Wait() - d.stream = nil + d.current.cancel = nil } if d.decoders != nil { close(d.decoders) @@ -456,100 +618,307 @@ type decodeOutput struct { err error } -type decodeStream struct { - r io.Reader - - // Blocks ready to be written to output. - output chan decodeOutput - - // cancel reading from the input - cancel chan struct{} +func (d *Decoder) startSyncDecoder(r io.Reader) error { + d.frame.history.reset() + d.syncStream.br = readerWrapper{r: r} + d.syncStream.inFrame = false + d.syncStream.enabled = true + d.syncStream.decodedFrame = 0 + return nil } -// errEndOfStream indicates that everything from the stream was read. -var errEndOfStream = errors.New("end-of-stream") - // Create Decoder: -// Spawn n block decoders. These accept tasks to decode a block. -// Create goroutine that handles stream processing, this will send history to decoders as they are available. -// Decoders update the history as they decode. -// When a block is returned: -// a) history is sent to the next decoder, -// b) content written to CRC. -// c) return data to WRITER. -// d) wait for next block to return data. -// Once WRITTEN, the decoders reused by the writer frame decoder for re-use. -func (d *Decoder) startStreamDecoder(inStream chan decodeStream) { +// ASYNC: +// Spawn 4 go routines. +// 0: Read frames and decode blocks. +// 1: Decode block and literals. Receives hufftree and seqdecs, returns seqdecs and huff tree. +// 2: Wait for recentOffsets if needed. Decode sequences, send recentOffsets. +// 3: Wait for stream history, execute sequences, send stream history. +func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) { defer d.streamWg.Done() - frame := newFrameDec(d.o) - for stream := range inStream { - if debugDecoder { - println("got new stream") + br := readerWrapper{r: r} + + var seqPrepare = make(chan *blockDec, d.o.concurrent) + var seqDecode = make(chan *blockDec, d.o.concurrent) + var seqExecute = make(chan *blockDec, d.o.concurrent) + + // Async 1: Prepare blocks... + go func() { + var hist history + var hasErr bool + for block := range seqPrepare { + if hasErr { + if block != nil { + seqDecode <- block + } + continue + } + if block.async.newHist != nil { + if debugDecoder { + println("Async 1: new history") + } + hist.reset() + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } + } + if block.err != nil || block.Type != blockTypeCompressed { + hasErr = block.err != nil + seqDecode <- block + continue + } + + remain, err := block.decodeLiterals(block.data, &hist) + block.err = err + hasErr = block.err != nil + if err == nil { + block.async.literals = hist.decoders.literals + block.async.seqData = remain + } else if debugDecoder { + println("decodeLiterals error:", err) + } + seqDecode <- block } - br := readerWrapper{r: stream.r} - decodeStream: - for { - frame.history.reset() - err := frame.reset(&br) - if debugDecoder && err != nil { - println("Frame decoder returned", err) + close(seqDecode) + }() + + // Async 2: Decode sequences... + go func() { + var hist history + var hasErr bool + + for block := range seqDecode { + if hasErr { + if block != nil { + seqExecute <- block + } + continue } - if err == nil && frame.DictionaryID != nil { - dict, ok := d.dicts[*frame.DictionaryID] - if !ok { - err = ErrUnknownDictionary - } else { - frame.history.setDict(&dict) + if block.async.newHist != nil { + if debugDecoder { + println("Async 2: new history, recent:", block.async.newHist.recentOffsets) + } + hist.decoders = block.async.newHist.decoders + hist.recentOffsets = block.async.newHist.recentOffsets + hist.windowSize = block.async.newHist.windowSize + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) } } - if err != nil { - stream.output <- decodeOutput{ - err: err, + if block.err != nil || block.Type != blockTypeCompressed { + hasErr = block.err != nil + seqExecute <- block + continue + } + + hist.decoders.literals = block.async.literals + block.err = block.prepareSequences(block.async.seqData, &hist) + if debugDecoder && block.err != nil { + println("prepareSequences returned:", block.err) + } + hasErr = block.err != nil + if block.err == nil { + block.err = block.decodeSequences(&hist) + if debugDecoder && block.err != nil { + println("decodeSequences returned:", block.err) } - break + hasErr = block.err != nil + // block.async.sequence = hist.decoders.seq[:hist.decoders.nSeqs] + block.async.seqSize = hist.decoders.seqSize } - if debugDecoder { - println("starting frame decoder") - } - - // This goroutine will forward history between frames. - frame.frameDone.Add(1) - frame.initAsync() - - go frame.startDecoder(stream.output) - decodeFrame: - // Go through all blocks of the frame. - for { - dec := <-d.decoders - select { - case <-stream.cancel: - if !frame.sendErr(dec, io.EOF) { - // To not let the decoder dangle, send it back. - stream.output <- decodeOutput{d: dec} + seqExecute <- block + } + close(seqExecute) + }() + + var wg sync.WaitGroup + wg.Add(1) + + // Async 3: Execute sequences... + frameHistCache := d.frame.history.b + go func() { + var hist history + var decodedFrame uint64 + var fcs uint64 + var hasErr bool + for block := range seqExecute { + out := decodeOutput{err: block.err, d: block} + if block.err != nil || hasErr { + hasErr = true + output <- out + continue + } + if block.async.newHist != nil { + if debugDecoder { + println("Async 3: new history") + } + hist.windowSize = block.async.newHist.windowSize + hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer + if block.async.newHist.dict != nil { + hist.setDict(block.async.newHist.dict) + } + + if cap(hist.b) < hist.allocFrameBuffer { + if cap(frameHistCache) >= hist.allocFrameBuffer { + hist.b = frameHistCache + } else { + hist.b = make([]byte, 0, hist.allocFrameBuffer) + println("Alloc history sized", hist.allocFrameBuffer) + } + } + hist.b = hist.b[:0] + fcs = block.async.fcs + decodedFrame = 0 + } + do := decodeOutput{err: block.err, d: block} + switch block.Type { + case blockTypeRLE: + if debugDecoder { + println("add rle block length:", block.RLESize) + } + + if cap(block.dst) < int(block.RLESize) { + if block.lowMem { + block.dst = make([]byte, block.RLESize) + } else { + block.dst = make([]byte, maxBlockSize) } - break decodeStream - default: } - err := frame.next(dec) - switch err { - case io.EOF: - // End of current frame, no error - println("EOF on next block") - break decodeFrame - case nil: - continue - default: - println("block decoder returned", err) - break decodeStream + block.dst = block.dst[:block.RLESize] + v := block.data[0] + for i := range block.dst { + block.dst[i] = v + } + hist.append(block.dst) + do.b = block.dst + case blockTypeRaw: + if debugDecoder { + println("add raw block length:", len(block.data)) + } + hist.append(block.data) + do.b = block.data + case blockTypeCompressed: + if debugDecoder { + println("execute with history length:", len(hist.b), "window:", hist.windowSize) + } + hist.decoders.seqSize = block.async.seqSize + hist.decoders.literals = block.async.literals + do.err = block.executeSequences(&hist) + hasErr = do.err != nil + if debugDecoder && hasErr { + println("executeSequences returned:", do.err) + } + do.b = block.dst + } + if !hasErr { + decodedFrame += uint64(len(do.b)) + if decodedFrame > fcs { + println("fcs exceeded", block.Last, fcs, decodedFrame) + do.err = ErrFrameSizeExceeded + hasErr = true + } else if block.Last && fcs != fcsUnknown && decodedFrame != fcs { + do.err = ErrFrameSizeMismatch + hasErr = true + } else { + if debugDecoder { + println("fcs ok", block.Last, fcs, decodedFrame) + } } } - // All blocks have started decoding, check if there are more frames. - println("waiting for done") - frame.frameDone.Wait() - println("done waiting...") + output <- do + } + close(output) + frameHistCache = hist.b + wg.Done() + if debugDecoder { + println("decoder goroutines finished") + } + }() + +decodeStream: + for { + frame := d.frame + if debugDecoder { + println("New frame...") + } + var historySent bool + frame.history.reset() + err := frame.reset(&br) + if debugDecoder && err != nil { + println("Frame decoder returned", err) + } + if err == nil && frame.DictionaryID != nil { + dict, ok := d.dicts[*frame.DictionaryID] + if !ok { + err = ErrUnknownDictionary + } else { + frame.history.setDict(&dict) + } + } + if err == nil && d.frame.WindowSize > d.o.maxWindowSize { + err = ErrDecoderSizeExceeded + } + if err != nil { + select { + case <-ctx.Done(): + case dec := <-d.decoders: + dec.sendErr(err) + seqPrepare <- dec + } + break decodeStream + } + + // Go through all blocks of the frame. + for { + var dec *blockDec + select { + case <-ctx.Done(): + break decodeStream + case dec = <-d.decoders: + // Once we have a decoder, we MUST return it. + } + err := frame.next(dec) + if !historySent { + h := frame.history + if debugDecoder { + println("Alloc History:", h.allocFrameBuffer) + } + dec.async.newHist = &h + dec.async.fcs = frame.FrameContentSize + historySent = true + } else { + dec.async.newHist = nil + } + if debugDecoder && err != nil { + println("next block returned error:", err) + } + dec.err = err + dec.checkCRC = nil + if dec.Last && frame.HasCheckSum && err == nil { + crc, err := frame.rawInput.readSmall(4) + if err != nil { + println("CRC missing?", err) + dec.err = err + } + var tmp [4]byte + copy(tmp[:], crc) + dec.checkCRC = tmp[:] + if debugDecoder { + println("found crc to check:", dec.checkCRC) + } + } + err = dec.err + last := dec.Last + seqPrepare <- dec + if err != nil { + break decodeStream + } + if last { + break + } } - frame.frameDone.Wait() - println("Sending EOS") - stream.output <- decodeOutput{err: errEndOfStream} } + close(seqPrepare) + wg.Wait() + d.frame.history.b = frameHistCache } |