summary refs log tree commit diff stats
path: root/vendor/github.com/klauspost/compress/zstd/decoder.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/klauspost/compress/zstd/decoder.go')
-rw-r--r-- vendor/github.com/klauspost/compress/zstd/decoder.go 631
1 files changed, 500 insertions, 131 deletions
diff --git a/vendor/github.com/klauspost/compress/zstd/decoder.go b/vendor/github.com/klauspost/compress/zstd/decoder.go
index f430f58b..9fcdaac1 100644
--- a/vendor/github.com/klauspost/compress/zstd/decoder.go
+++ b/vendor/github.com/klauspost/compress/zstd/decoder.go
@@ -5,9 +5,13 @@
package zstd
import (
- "errors"
+ "bytes"
+ "context"
+ "encoding/binary"
"io"
"sync"
+
+ "github.com/klauspost/compress/zstd/internal/xxhash"
)
// Decoder provides decoding of zstandard streams.
@@ -22,12 +26,19 @@ type Decoder struct {
// Unreferenced decoders, ready for use.
decoders chan *blockDec
- // Streams ready to be decoded.
- stream chan decodeStream
-
// Current read position used for Reader functionality.
current decoderState
+ // sync stream decoding
+ syncStream struct {
+ decodedFrame uint64
+ br readerWrapper
+ enabled bool
+ inFrame bool
+ }
+
+ frame *frameDec
+
// Custom dictionaries.
// Always uses copies.
dicts map[uint32]dict
@@ -46,7 +57,10 @@ type decoderState struct {
output chan decodeOutput
// cancel remaining output.
- cancel chan struct{}
+ cancel context.CancelFunc
+
+ // crc of current frame
+ crc *xxhash.Digest
flushed bool
}
@@ -81,7 +95,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
return nil, err
}
}
- d.current.output = make(chan decodeOutput, d.o.concurrent)
+ d.current.crc = xxhash.New()
d.current.flushed = true
if r == nil {
@@ -130,7 +144,7 @@ func (d *Decoder) Read(p []byte) (int, error) {
break
}
if !d.nextBlock(n == 0) {
- return n, nil
+ return n, d.current.err
}
}
}
@@ -162,6 +176,7 @@ func (d *Decoder) Reset(r io.Reader) error {
d.drainOutput()
+ d.syncStream.br.r = nil
if r == nil {
d.current.err = ErrDecoderNilInput
if len(d.current.b) > 0 {
@@ -195,33 +210,39 @@ func (d *Decoder) Reset(r io.Reader) error {
}
return nil
}
-
- if d.stream == nil {
- d.stream = make(chan decodeStream, 1)
- d.streamWg.Add(1)
- go d.startStreamDecoder(d.stream)
- }
-
// Remove current block.
+ d.stashDecoder()
d.current.decodeOutput = decodeOutput{}
d.current.err = nil
- d.current.cancel = make(chan struct{})
d.current.flushed = false
d.current.d = nil
- d.stream <- decodeStream{
- r: r,
- output: d.current.output,
- cancel: d.current.cancel,
+ // Ensure no-one else is still running...
+ d.streamWg.Wait()
+ if d.frame == nil {
+ d.frame = newFrameDec(d.o)
+ }
+
+ if d.o.concurrent == 1 {
+ return d.startSyncDecoder(r)
}
+
+ d.current.output = make(chan decodeOutput, d.o.concurrent)
+ ctx, cancel := context.WithCancel(context.Background())
+ d.current.cancel = cancel
+ d.streamWg.Add(1)
+ go d.startStreamDecoder(ctx, r, d.current.output)
+
return nil
}
// drainOutput will drain the output until errEndOfStream is sent.
func (d *Decoder) drainOutput() {
if d.current.cancel != nil {
- println("cancelling current")
- close(d.current.cancel)
+ if debugDecoder {
+ println("cancelling current")
+ }
+ d.current.cancel()
d.current.cancel = nil
}
if d.current.d != nil {
@@ -243,12 +264,9 @@ func (d *Decoder) drainOutput() {
}
d.decoders <- v.d
}
- if v.err == errEndOfStream {
- println("current flushed")
- d.current.flushed = true
- return
- }
}
+ d.current.output = nil
+ d.current.flushed = true
}
// WriteTo writes data to w until there's no more data to write or when an error occurs.
@@ -287,7 +305,7 @@ func (d *Decoder) WriteTo(w io.Writer) (int64, error) {
// DecodeAll can be used concurrently.
// The Decoder concurrency limits will be respected.
func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
- if d.current.err == ErrDecoderClosed {
+ if d.decoders == nil {
return dst, ErrDecoderClosed
}
@@ -300,6 +318,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
frame.rawInput = nil
frame.bBuf = nil
+ if frame.history.decoders.br != nil {
+ frame.history.decoders.br.in = nil
+ }
d.decoders <- block
}()
frame.bBuf = input
@@ -307,27 +328,31 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
for {
frame.history.reset()
err := frame.reset(&frame.bBuf)
- if err == io.EOF {
- if debugDecoder {
- println("frame reset return EOF")
+ if err != nil {
+ if err == io.EOF {
+ if debugDecoder {
+ println("frame reset return EOF")
+ }
+ return dst, nil
}
- return dst, nil
+ return dst, err
}
if frame.DictionaryID != nil {
dict, ok := d.dicts[*frame.DictionaryID]
if !ok {
return nil, ErrUnknownDictionary
}
+ if debugDecoder {
+ println("setting dict", frame.DictionaryID)
+ }
frame.history.setDict(&dict)
}
- if err != nil {
- return dst, err
- }
- if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
+
+ if frame.FrameContentSize != fcsUnknown && frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
return dst, ErrDecoderSizeExceeded
}
- if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 {
- // Never preallocate moe than 1 GB up front.
+ if frame.FrameContentSize < 1<<30 {
+ // Never preallocate more than 1 GB up front.
if cap(dst)-len(dst) < int(frame.FrameContentSize) {
dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize))
copy(dst2, dst)
@@ -368,33 +393,170 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
// If non-blocking mode is used the returned boolean will be false
// if no data was available without blocking.
func (d *Decoder) nextBlock(blocking bool) (ok bool) {
- if d.current.d != nil {
- if debugDecoder {
- printf("re-adding current decoder %p", d.current.d)
- }
- d.decoders <- d.current.d
- d.current.d = nil
- }
if d.current.err != nil {
// Keep error state.
- return blocking
+ return false
}
+ d.current.b = d.current.b[:0]
+ // SYNC:
+ if d.syncStream.enabled {
+ if !blocking {
+ return false
+ }
+ ok = d.nextBlockSync()
+ if !ok {
+ d.stashDecoder()
+ }
+ return ok
+ }
+
+ //ASYNC:
+ d.stashDecoder()
if blocking {
- d.current.decodeOutput = <-d.current.output
+ d.current.decodeOutput, ok = <-d.current.output
} else {
select {
- case d.current.decodeOutput = <-d.current.output:
+ case d.current.decodeOutput, ok = <-d.current.output:
default:
return false
}
}
+ if !ok {
+ // This should not happen, so signal error state...
+ d.current.err = io.ErrUnexpectedEOF
+ return false
+ }
+ next := d.current.decodeOutput
+ if next.d != nil && next.d.async.newHist != nil {
+ d.current.crc.Reset()
+ }
if debugDecoder {
- println("got", len(d.current.b), "bytes, error:", d.current.err)
+ var tmp [4]byte
+ binary.LittleEndian.PutUint32(tmp[:], uint32(xxhash.Sum64(next.b)))
+ println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp)
+ }
+
+ if len(next.b) > 0 {
+ n, err := d.current.crc.Write(next.b)
+ if err == nil {
+ if n != len(next.b) {
+ d.current.err = io.ErrShortWrite
+ }
+ }
+ }
+ if next.err == nil && next.d != nil && len(next.d.checkCRC) != 0 {
+ got := d.current.crc.Sum64()
+ var tmp [4]byte
+ binary.LittleEndian.PutUint32(tmp[:], uint32(got))
+ if !bytes.Equal(tmp[:], next.d.checkCRC) && !ignoreCRC {
+ if debugDecoder {
+ println("CRC Check Failed:", tmp[:], " (got) !=", next.d.checkCRC, "(on stream)")
+ }
+ d.current.err = ErrCRCMismatch
+ } else {
+ if debugDecoder {
+ println("CRC ok", tmp[:])
+ }
+ }
+ }
+
+ return true
+}
+
+func (d *Decoder) nextBlockSync() (ok bool) {
+ if d.current.d == nil {
+ d.current.d = <-d.decoders
+ }
+ for len(d.current.b) == 0 {
+ if !d.syncStream.inFrame {
+ d.frame.history.reset()
+ d.current.err = d.frame.reset(&d.syncStream.br)
+ if d.current.err != nil {
+ return false
+ }
+ if d.frame.DictionaryID != nil {
+ dict, ok := d.dicts[*d.frame.DictionaryID]
+ if !ok {
+ d.current.err = ErrUnknownDictionary
+ return false
+ } else {
+ d.frame.history.setDict(&dict)
+ }
+ }
+ if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize {
+ d.current.err = ErrDecoderSizeExceeded
+ return false
+ }
+
+ d.syncStream.decodedFrame = 0
+ d.syncStream.inFrame = true
+ }
+ d.current.err = d.frame.next(d.current.d)
+ if d.current.err != nil {
+ return false
+ }
+ d.frame.history.ensureBlock()
+ if debugDecoder {
+ println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame)
+ }
+ histBefore := len(d.frame.history.b)
+ d.current.err = d.current.d.decodeBuf(&d.frame.history)
+
+ if d.current.err != nil {
+ println("error after:", d.current.err)
+ return false
+ }
+ d.current.b = d.frame.history.b[histBefore:]
+ if debugDecoder {
+ println("history after:", len(d.frame.history.b))
+ }
+
+ // Check frame size (before CRC)
+ d.syncStream.decodedFrame += uint64(len(d.current.b))
+ if d.syncStream.decodedFrame > d.frame.FrameContentSize {
+ if debugDecoder {
+ printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize)
+ }
+ d.current.err = ErrFrameSizeExceeded
+ return false
+ }
+
+ // Check FCS
+ if d.current.d.Last && d.frame.FrameContentSize != fcsUnknown && d.syncStream.decodedFrame != d.frame.FrameContentSize {
+ if debugDecoder {
+ printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize)
+ }
+ d.current.err = ErrFrameSizeMismatch
+ return false
+ }
+
+ // Update/Check CRC
+ if d.frame.HasCheckSum {
+ d.frame.crc.Write(d.current.b)
+ if d.current.d.Last {
+ d.current.err = d.frame.checkCRC()
+ if d.current.err != nil {
+ println("CRC error:", d.current.err)
+ return false
+ }
+ }
+ }
+ d.syncStream.inFrame = !d.current.d.Last
}
return true
}
+func (d *Decoder) stashDecoder() {
+ if d.current.d != nil {
+ if debugDecoder {
+ printf("re-adding current decoder %p", d.current.d)
+ }
+ d.decoders <- d.current.d
+ d.current.d = nil
+ }
+}
+
// Close will release all resources.
// It is NOT possible to reuse the decoder after this.
func (d *Decoder) Close() {
@@ -402,10 +564,10 @@ func (d *Decoder) Close() {
return
}
d.drainOutput()
- if d.stream != nil {
- close(d.stream)
+ if d.current.cancel != nil {
+ d.current.cancel()
d.streamWg.Wait()
- d.stream = nil
+ d.current.cancel = nil
}
if d.decoders != nil {
close(d.decoders)
@@ -456,100 +618,307 @@ type decodeOutput struct {
err error
}
-type decodeStream struct {
- r io.Reader
-
- // Blocks ready to be written to output.
- output chan decodeOutput
-
- // cancel reading from the input
- cancel chan struct{}
+func (d *Decoder) startSyncDecoder(r io.Reader) error {
+ d.frame.history.reset()
+ d.syncStream.br = readerWrapper{r: r}
+ d.syncStream.inFrame = false
+ d.syncStream.enabled = true
+ d.syncStream.decodedFrame = 0
+ return nil
}
-// errEndOfStream indicates that everything from the stream was read.
-var errEndOfStream = errors.New("end-of-stream")
-
// Create Decoder:
-// Spawn n block decoders. These accept tasks to decode a block.
-// Create goroutine that handles stream processing, this will send history to decoders as they are available.
-// Decoders update the history as they decode.
-// When a block is returned:
-// a) history is sent to the next decoder,
-// b) content written to CRC.
-// c) return data to WRITER.
-// d) wait for next block to return data.
-// Once WRITTEN, the decoders reused by the writer frame decoder for re-use.
-func (d *Decoder) startStreamDecoder(inStream chan decodeStream) {
+// ASYNC:
+// Spawn 4 go routines.
+// 0: Read frames and decode blocks.
+// 1: Decode block and literals. Receives hufftree and seqdecs, returns seqdecs and huff tree.
+// 2: Wait for recentOffsets if needed. Decode sequences, send recentOffsets.
+// 3: Wait for stream history, execute sequences, send stream history.
+func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) {
defer d.streamWg.Done()
- frame := newFrameDec(d.o)
- for stream := range inStream {
- if debugDecoder {
- println("got new stream")
+ br := readerWrapper{r: r}
+
+ var seqPrepare = make(chan *blockDec, d.o.concurrent)
+ var seqDecode = make(chan *blockDec, d.o.concurrent)
+ var seqExecute = make(chan *blockDec, d.o.concurrent)
+
+ // Async 1: Prepare blocks...
+ go func() {
+ var hist history
+ var hasErr bool
+ for block := range seqPrepare {
+ if hasErr {
+ if block != nil {
+ seqDecode <- block
+ }
+ continue
+ }
+ if block.async.newHist != nil {
+ if debugDecoder {
+ println("Async 1: new history")
+ }
+ hist.reset()
+ if block.async.newHist.dict != nil {
+ hist.setDict(block.async.newHist.dict)
+ }
+ }
+ if block.err != nil || block.Type != blockTypeCompressed {
+ hasErr = block.err != nil
+ seqDecode <- block
+ continue
+ }
+
+ remain, err := block.decodeLiterals(block.data, &hist)
+ block.err = err
+ hasErr = block.err != nil
+ if err == nil {
+ block.async.literals = hist.decoders.literals
+ block.async.seqData = remain
+ } else if debugDecoder {
+ println("decodeLiterals error:", err)
+ }
+ seqDecode <- block
}
- br := readerWrapper{r: stream.r}
- decodeStream:
- for {
- frame.history.reset()
- err := frame.reset(&br)
- if debugDecoder && err != nil {
- println("Frame decoder returned", err)
+ close(seqDecode)
+ }()
+
+ // Async 2: Decode sequences...
+ go func() {
+ var hist history
+ var hasErr bool
+
+ for block := range seqDecode {
+ if hasErr {
+ if block != nil {
+ seqExecute <- block
+ }
+ continue
}
- if err == nil && frame.DictionaryID != nil {
- dict, ok := d.dicts[*frame.DictionaryID]
- if !ok {
- err = ErrUnknownDictionary
- } else {
- frame.history.setDict(&dict)
+ if block.async.newHist != nil {
+ if debugDecoder {
+ println("Async 2: new history, recent:", block.async.newHist.recentOffsets)
+ }
+ hist.decoders = block.async.newHist.decoders
+ hist.recentOffsets = block.async.newHist.recentOffsets
+ hist.windowSize = block.async.newHist.windowSize
+ if block.async.newHist.dict != nil {
+ hist.setDict(block.async.newHist.dict)
}
}
- if err != nil {
- stream.output <- decodeOutput{
- err: err,
+ if block.err != nil || block.Type != blockTypeCompressed {
+ hasErr = block.err != nil
+ seqExecute <- block
+ continue
+ }
+
+ hist.decoders.literals = block.async.literals
+ block.err = block.prepareSequences(block.async.seqData, &hist)
+ if debugDecoder && block.err != nil {
+ println("prepareSequences returned:", block.err)
+ }
+ hasErr = block.err != nil
+ if block.err == nil {
+ block.err = block.decodeSequences(&hist)
+ if debugDecoder && block.err != nil {
+ println("decodeSequences returned:", block.err)
}
- break
+ hasErr = block.err != nil
+ // block.async.sequence = hist.decoders.seq[:hist.decoders.nSeqs]
+ block.async.seqSize = hist.decoders.seqSize
}
- if debugDecoder {
- println("starting frame decoder")
- }
-
- // This goroutine will forward history between frames.
- frame.frameDone.Add(1)
- frame.initAsync()
-
- go frame.startDecoder(stream.output)
- decodeFrame:
- // Go through all blocks of the frame.
- for {
- dec := <-d.decoders
- select {
- case <-stream.cancel:
- if !frame.sendErr(dec, io.EOF) {
- // To not let the decoder dangle, send it back.
- stream.output <- decodeOutput{d: dec}
+ seqExecute <- block
+ }
+ close(seqExecute)
+ }()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ // Async 3: Execute sequences...
+ frameHistCache := d.frame.history.b
+ go func() {
+ var hist history
+ var decodedFrame uint64
+ var fcs uint64
+ var hasErr bool
+ for block := range seqExecute {
+ out := decodeOutput{err: block.err, d: block}
+ if block.err != nil || hasErr {
+ hasErr = true
+ output <- out
+ continue
+ }
+ if block.async.newHist != nil {
+ if debugDecoder {
+ println("Async 3: new history")
+ }
+ hist.windowSize = block.async.newHist.windowSize
+ hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer
+ if block.async.newHist.dict != nil {
+ hist.setDict(block.async.newHist.dict)
+ }
+
+ if cap(hist.b) < hist.allocFrameBuffer {
+ if cap(frameHistCache) >= hist.allocFrameBuffer {
+ hist.b = frameHistCache
+ } else {
+ hist.b = make([]byte, 0, hist.allocFrameBuffer)
+ println("Alloc history sized", hist.allocFrameBuffer)
+ }
+ }
+ hist.b = hist.b[:0]
+ fcs = block.async.fcs
+ decodedFrame = 0
+ }
+ do := decodeOutput{err: block.err, d: block}
+ switch block.Type {
+ case blockTypeRLE:
+ if debugDecoder {
+ println("add rle block length:", block.RLESize)
+ }
+
+ if cap(block.dst) < int(block.RLESize) {
+ if block.lowMem {
+ block.dst = make([]byte, block.RLESize)
+ } else {
+ block.dst = make([]byte, maxBlockSize)
}
- break decodeStream
- default:
}
- err := frame.next(dec)
- switch err {
- case io.EOF:
- // End of current frame, no error
- println("EOF on next block")
- break decodeFrame
- case nil:
- continue
- default:
- println("block decoder returned", err)
- break decodeStream
+ block.dst = block.dst[:block.RLESize]
+ v := block.data[0]
+ for i := range block.dst {
+ block.dst[i] = v
+ }
+ hist.append(block.dst)
+ do.b = block.dst
+ case blockTypeRaw:
+ if debugDecoder {
+ println("add raw block length:", len(block.data))
+ }
+ hist.append(block.data)
+ do.b = block.data
+ case blockTypeCompressed:
+ if debugDecoder {
+ println("execute with history length:", len(hist.b), "window:", hist.windowSize)
+ }
+ hist.decoders.seqSize = block.async.seqSize
+ hist.decoders.literals = block.async.literals
+ do.err = block.executeSequences(&hist)
+ hasErr = do.err != nil
+ if debugDecoder && hasErr {
+ println("executeSequences returned:", do.err)
+ }
+ do.b = block.dst
+ }
+ if !hasErr {
+ decodedFrame += uint64(len(do.b))
+ if decodedFrame > fcs {
+ println("fcs exceeded", block.Last, fcs, decodedFrame)
+ do.err = ErrFrameSizeExceeded
+ hasErr = true
+ } else if block.Last && fcs != fcsUnknown && decodedFrame != fcs {
+ do.err = ErrFrameSizeMismatch
+ hasErr = true
+ } else {
+ if debugDecoder {
+ println("fcs ok", block.Last, fcs, decodedFrame)
+ }
}
}
- // All blocks have started decoding, check if there are more frames.
- println("waiting for done")
- frame.frameDone.Wait()
- println("done waiting...")
+ output <- do
+ }
+ close(output)
+ frameHistCache = hist.b
+ wg.Done()
+ if debugDecoder {
+ println("decoder goroutines finished")
+ }
+ }()
+
+decodeStream:
+ for {
+ frame := d.frame
+ if debugDecoder {
+ println("New frame...")
+ }
+ var historySent bool
+ frame.history.reset()
+ err := frame.reset(&br)
+ if debugDecoder && err != nil {
+ println("Frame decoder returned", err)
+ }
+ if err == nil && frame.DictionaryID != nil {
+ dict, ok := d.dicts[*frame.DictionaryID]
+ if !ok {
+ err = ErrUnknownDictionary
+ } else {
+ frame.history.setDict(&dict)
+ }
+ }
+ if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
+ err = ErrDecoderSizeExceeded
+ }
+ if err != nil {
+ select {
+ case <-ctx.Done():
+ case dec := <-d.decoders:
+ dec.sendErr(err)
+ seqPrepare <- dec
+ }
+ break decodeStream
+ }
+
+ // Go through all blocks of the frame.
+ for {
+ var dec *blockDec
+ select {
+ case <-ctx.Done():
+ break decodeStream
+ case dec = <-d.decoders:
+ // Once we have a decoder, we MUST return it.
+ }
+ err := frame.next(dec)
+ if !historySent {
+ h := frame.history
+ if debugDecoder {
+ println("Alloc History:", h.allocFrameBuffer)
+ }
+ dec.async.newHist = &h
+ dec.async.fcs = frame.FrameContentSize
+ historySent = true
+ } else {
+ dec.async.newHist = nil
+ }
+ if debugDecoder && err != nil {
+ println("next block returned error:", err)
+ }
+ dec.err = err
+ dec.checkCRC = nil
+ if dec.Last && frame.HasCheckSum && err == nil {
+ crc, err := frame.rawInput.readSmall(4)
+ if err != nil {
+ println("CRC missing?", err)
+ dec.err = err
+ }
+ var tmp [4]byte
+ copy(tmp[:], crc)
+ dec.checkCRC = tmp[:]
+ if debugDecoder {
+ println("found crc to check:", dec.checkCRC)
+ }
+ }
+ err = dec.err
+ last := dec.Last
+ seqPrepare <- dec
+ if err != nil {
+ break decodeStream
+ }
+ if last {
+ break
+ }
}
- frame.frameDone.Wait()
- println("Sending EOS")
- stream.output <- decodeOutput{err: errEndOfStream}
}
+ close(seqPrepare)
+ wg.Wait()
+ d.frame.history.b = frameHistCache
}