summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/rs/xid/id.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/rs/xid/id.go')
-rw-r--r--vendor/github.com/rs/xid/id.go26
1 files changed, 19 insertions, 7 deletions
diff --git a/vendor/github.com/rs/xid/id.go b/vendor/github.com/rs/xid/id.go
index f1db1a18..1f536b41 100644
--- a/vendor/github.com/rs/xid/id.go
+++ b/vendor/github.com/rs/xid/id.go
@@ -47,7 +47,6 @@ import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
- "errors"
"fmt"
"hash/crc32"
"io/ioutil"
@@ -73,9 +72,6 @@ const (
)
var (
- // ErrInvalidID is returned when trying to unmarshal an invalid ID
- ErrInvalidID = errors.New("xid: invalid ID")
-
// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
@@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
return ErrInvalidID
}
}
- decode(id, text)
+ if !decode(id, text) {
+ return ErrInvalidID
+ }
return nil
}
@@ -253,11 +251,15 @@ func (id *ID) UnmarshalJSON(b []byte) error {
*id = nilID
return nil
}
+ // Check the slice length to prevent panic on passing it to UnmarshalText()
+ if len(b) < 2 {
+ return ErrInvalidID
+ }
return id.UnmarshalText(b[1 : len(b)-1])
}
-// decode by unrolling the stdlib base32 algorithm + removing all safe checks
-func decode(id *ID, src []byte) {
+// decode by unrolling the stdlib base32 algorithm + customized safe check.
+func decode(id *ID, src []byte) bool {
_ = src[19]
_ = id[11]
@@ -273,6 +275,16 @@ func decode(id *ID, src []byte) {
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
+
+ // Validate that there are no discarer bits (padding) in src that would
+ // cause the string-encoded id not to equal src.
+ var check [4]byte
+
+ check[3] = encoding[(id[11]<<4)&0x1F]
+ check[2] = encoding[(id[11]>>1)&0x1F]
+ check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
+ check[0] = encoding[id[10]>>3]
+ return bytes.Equal([]byte(src[16:20]), check[:])
}
// Time returns the timestamp part of the id.