From 81e6f75aa491c7bfa11780dd622587cf5fff3c8f Mon Sep 17 00:00:00 2001 From: Wim Date: Mon, 2 May 2022 00:10:54 +0200 Subject: Update dependencies (#1822) --- vendor/github.com/rs/xid/id.go | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) (limited to 'vendor/github.com/rs/xid/id.go') diff --git a/vendor/github.com/rs/xid/id.go b/vendor/github.com/rs/xid/id.go index f1db1a18..1f536b41 100644 --- a/vendor/github.com/rs/xid/id.go +++ b/vendor/github.com/rs/xid/id.go @@ -47,7 +47,6 @@ import ( "crypto/rand" "database/sql/driver" "encoding/binary" - "errors" "fmt" "hash/crc32" "io/ioutil" @@ -73,9 +72,6 @@ const ( ) var ( - // ErrInvalidID is returned when trying to unmarshal an invalid ID - ErrInvalidID = errors.New("xid: invalid ID") - // objectIDCounter is atomically incremented when generating a new ObjectId // using NewObjectId() function. It's used as a counter part of an id. // This id is initialized with a random value. @@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error { return ErrInvalidID } } - decode(id, text) + if !decode(id, text) { + return ErrInvalidID + } return nil } @@ -253,11 +251,15 @@ func (id *ID) UnmarshalJSON(b []byte) error { *id = nilID return nil } + // Check the slice length to prevent panic on passing it to UnmarshalText() + if len(b) < 2 { + return ErrInvalidID + } return id.UnmarshalText(b[1 : len(b)-1]) } -// decode by unrolling the stdlib base32 algorithm + removing all safe checks -func decode(id *ID, src []byte) { +// decode by unrolling the stdlib base32 algorithm + customized safe check. +func decode(id *ID, src []byte) bool { _ = src[19] _ = id[11] @@ -273,6 +275,16 @@ func decode(id *ID, src []byte) { id[2] = dec[src[3]]<<4 | dec[src[4]]>>1 id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4 id[0] = dec[src[0]]<<3 | dec[src[1]]>>2 + + // Validate that there are no discarer bits (padding) in src that would + // cause the string-encoded id not to equal src. + var check [4]byte + + check[3] = encoding[(id[11]<<4)&0x1F] + check[2] = encoding[(id[11]>>1)&0x1F] + check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F] + check[0] = encoding[id[10]>>3] + return bytes.Equal([]byte(src[16:20]), check[:]) } // Time returns the timestamp part of the id. -- cgit v1.2.3