summary refs log tree commit diff stats
path: root/vendor/github.com/klauspost/compress/s2/decode.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/klauspost/compress/s2/decode.go')
-rw-r--r--  vendor/github.com/klauspost/compress/s2/decode.go  289
1 files changed, 279 insertions, 10 deletions
diff --git a/vendor/github.com/klauspost/compress/s2/decode.go b/vendor/github.com/klauspost/compress/s2/decode.go
index 9e7fce88..2aba9e27 100644
--- a/vendor/github.com/klauspost/compress/s2/decode.go
+++ b/vendor/github.com/klauspost/compress/s2/decode.go
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"io/ioutil"
+ "runtime"
+ "sync"
)
var (
@@ -169,6 +171,14 @@ func ReaderSkippableCB(id uint8, fn func(r io.Reader) error) ReaderOption {
}
}
+// ReaderIgnoreCRC will make the reader skip CRC calculation and checks.
+func ReaderIgnoreCRC() ReaderOption {
+ return func(r *Reader) error {
+ r.ignoreCRC = true
+ return nil
+ }
+}
+
// Reader is an io.Reader that can read Snappy-compressed bytes.
type Reader struct {
r io.Reader
@@ -191,18 +201,19 @@ type Reader struct {
paramsOK bool
snappyFrame bool
ignoreStreamID bool
+ ignoreCRC bool
}
// ensureBufferSize will ensure that the buffer can take at least n bytes.
// If false is returned the buffer exceeds maximum allowed size.
func (r *Reader) ensureBufferSize(n int) bool {
- if len(r.buf) >= n {
- return true
- }
if n > r.maxBufSize {
r.err = ErrCorrupt
return false
}
+ if cap(r.buf) >= n {
+ return true
+ }
// Realloc buffer.
r.buf = make([]byte, n)
return true
@@ -220,6 +231,7 @@ func (r *Reader) Reset(reader io.Reader) {
r.err = nil
r.i = 0
r.j = 0
+ r.blockStart = 0
r.readHeader = r.ignoreStreamID
}
@@ -344,7 +356,7 @@ func (r *Reader) Read(p []byte) (int, error) {
r.err = err
return 0, r.err
}
- if crc(r.decoded[:n]) != checksum {
+ if !r.ignoreCRC && crc(r.decoded[:n]) != checksum {
r.err = ErrCRC
return 0, r.err
}
@@ -385,7 +397,7 @@ func (r *Reader) Read(p []byte) (int, error) {
if !r.readFull(r.decoded[:n], false) {
return 0, r.err
}
- if crc(r.decoded[:n]) != checksum {
+ if !r.ignoreCRC && crc(r.decoded[:n]) != checksum {
r.err = ErrCRC
return 0, r.err
}
@@ -435,6 +447,259 @@ func (r *Reader) Read(p []byte) (int, error) {
}
}
+// DecodeConcurrent will decode the full stream to w.
+// This function should not be combined with reading, seeking or other operations.
+// Up to 'concurrent' goroutines will be used.
+// If <= 0, runtime.NumCPU will be used.
+// On success the number of decompressed bytes and a nil error are returned.
+// This is mainly intended for bigger streams.
+func (r *Reader) DecodeConcurrent(w io.Writer, concurrent int) (written int64, err error) {
+ if r.i > 0 || r.j > 0 || r.blockStart > 0 {
+ return 0, errors.New("DecodeConcurrent called after ")
+ }
+ if concurrent <= 0 {
+ concurrent = runtime.NumCPU()
+ }
+
+ // Write to output
+ var errMu sync.Mutex
+ var aErr error
+ setErr := func(e error) (ok bool) {
+ errMu.Lock()
+ defer errMu.Unlock()
+ if e == nil {
+ return aErr == nil
+ }
+ if aErr == nil {
+ aErr = e
+ }
+ return false
+ }
+ hasErr := func() (ok bool) {
+ errMu.Lock()
+ v := aErr != nil
+ errMu.Unlock()
+ return v
+ }
+
+ var aWritten int64
+ toRead := make(chan []byte, concurrent)
+ writtenBlocks := make(chan []byte, concurrent)
+ queue := make(chan chan []byte, concurrent)
+ reUse := make(chan chan []byte, concurrent)
+ for i := 0; i < concurrent; i++ {
+ toRead <- make([]byte, 0, r.maxBufSize)
+ writtenBlocks <- make([]byte, 0, r.maxBufSize)
+ reUse <- make(chan []byte, 1)
+ }
+ // Writer
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for toWrite := range queue {
+ entry := <-toWrite
+ reUse <- toWrite
+ if hasErr() {
+ writtenBlocks <- entry
+ continue
+ }
+ n, err := w.Write(entry)
+ want := len(entry)
+ writtenBlocks <- entry
+ if err != nil {
+ setErr(err)
+ continue
+ }
+ if n != want {
+ setErr(io.ErrShortWrite)
+ continue
+ }
+ aWritten += int64(n)
+ }
+ }()
+
+ // Reader
+ defer func() {
+ close(queue)
+ if r.err != nil {
+ err = r.err
+ setErr(r.err)
+ }
+ wg.Wait()
+ if err == nil {
+ err = aErr
+ }
+ written = aWritten
+ }()
+
+ for !hasErr() {
+ if !r.readFull(r.buf[:4], true) {
+ if r.err == io.EOF {
+ r.err = nil
+ }
+ return 0, r.err
+ }
+ chunkType := r.buf[0]
+ if !r.readHeader {
+ if chunkType != chunkTypeStreamIdentifier {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ r.readHeader = true
+ }
+ chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
+
+ // The chunk types are specified at
+ // https://github.com/google/snappy/blob/master/framing_format.txt
+ switch chunkType {
+ case chunkTypeCompressedData:
+ r.blockStart += int64(r.j)
+ // Section 4.2. Compressed data (chunk type 0x00).
+ if chunkLen < checksumSize {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ if chunkLen > r.maxBufSize {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ orgBuf := <-toRead
+ buf := orgBuf[:chunkLen]
+
+ if !r.readFull(buf, false) {
+ return 0, r.err
+ }
+
+ checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
+ buf = buf[checksumSize:]
+
+ n, err := DecodedLen(buf)
+ if err != nil {
+ r.err = err
+ return 0, r.err
+ }
+ if r.snappyFrame && n > maxSnappyBlockSize {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+
+ if n > r.maxBlock {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ wg.Add(1)
+
+ decoded := <-writtenBlocks
+ entry := <-reUse
+ queue <- entry
+ go func() {
+ defer wg.Done()
+ decoded = decoded[:n]
+ _, err := Decode(decoded, buf)
+ toRead <- orgBuf
+ if err != nil {
+ writtenBlocks <- decoded
+ setErr(err)
+ return
+ }
+ if !r.ignoreCRC && crc(decoded) != checksum {
+ writtenBlocks <- decoded
+ setErr(ErrCRC)
+ return
+ }
+ entry <- decoded
+ }()
+ continue
+
+ case chunkTypeUncompressedData:
+
+ // Section 4.3. Uncompressed data (chunk type 0x01).
+ if chunkLen < checksumSize {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ if chunkLen > r.maxBufSize {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ // Grab write buffer
+ orgBuf := <-writtenBlocks
+ buf := orgBuf[:checksumSize]
+ if !r.readFull(buf, false) {
+ return 0, r.err
+ }
+ checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
+ // Read content.
+ n := chunkLen - checksumSize
+
+ if r.snappyFrame && n > maxSnappyBlockSize {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ if n > r.maxBlock {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ // Read uncompressed
+ buf = orgBuf[:n]
+ if !r.readFull(buf, false) {
+ return 0, r.err
+ }
+
+ if !r.ignoreCRC && crc(buf) != checksum {
+ r.err = ErrCRC
+ return 0, r.err
+ }
+ entry := <-reUse
+ queue <- entry
+ entry <- buf
+ continue
+
+ case chunkTypeStreamIdentifier:
+ // Section 4.1. Stream identifier (chunk type 0xff).
+ if chunkLen != len(magicBody) {
+ r.err = ErrCorrupt
+ return 0, r.err
+ }
+ if !r.readFull(r.buf[:len(magicBody)], false) {
+ return 0, r.err
+ }
+ if string(r.buf[:len(magicBody)]) != magicBody {
+ if string(r.buf[:len(magicBody)]) != magicBodySnappy {
+ r.err = ErrCorrupt
+ return 0, r.err
+ } else {
+ r.snappyFrame = true
+ }
+ } else {
+ r.snappyFrame = false
+ }
+ continue
+ }
+
+ if chunkType <= 0x7f {
+ // Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
+ // fmt.Printf("ERR chunktype: 0x%x\n", chunkType)
+ r.err = ErrUnsupported
+ return 0, r.err
+ }
+ // Section 4.4 Padding (chunk type 0xfe).
+ // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
+ if chunkLen > maxChunkSize {
+ // fmt.Printf("ERR chunkLen: 0x%x\n", chunkLen)
+ r.err = ErrUnsupported
+ return 0, r.err
+ }
+
+ // fmt.Printf("skippable: ID: 0x%x, len: 0x%x\n", chunkType, chunkLen)
+ if !r.skippable(r.buf, chunkLen, false, chunkType) {
+ return 0, r.err
+ }
+ }
+ return 0, r.err
+}
+
// Skip will skip n bytes forward in the decompressed output.
// For larger skips this consumes less CPU and is faster than reading output and discarding it.
// CRC is not checked on skipped blocks.
@@ -699,8 +964,16 @@ func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) {
case io.SeekCurrent:
offset += r.blockStart + int64(r.i)
case io.SeekEnd:
- offset = -offset
+ if offset > 0 {
+ return 0, errors.New("seek after end of file")
+ }
+ offset = r.index.TotalUncompressed + offset
+ }
+
+ if offset < 0 {
+ return 0, errors.New("seek before start of file")
}
+
c, u, err := r.index.Find(offset)
if err != nil {
return r.blockStart + int64(r.i), err
@@ -712,10 +985,6 @@ func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) {
return 0, err
}
- if offset < 0 {
- offset = r.index.TotalUncompressed + offset
- }
-
r.i = r.j // Remove rest of current block.
if u < offset {
// Forward inside block