summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/vmihailenco/msgpack/v5/decode_map.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/vmihailenco/msgpack/v5/decode_map.go')
-rw-r--r--vendor/github.com/vmihailenco/msgpack/v5/decode_map.go339
1 files changed, 339 insertions, 0 deletions
diff --git a/vendor/github.com/vmihailenco/msgpack/v5/decode_map.go b/vendor/github.com/vmihailenco/msgpack/v5/decode_map.go
new file mode 100644
index 00000000..52e0526c
--- /dev/null
+++ b/vendor/github.com/vmihailenco/msgpack/v5/decode_map.go
@@ -0,0 +1,339 @@
+package msgpack
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
+)
+
+var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed")
+
+var (
+ mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
+ mapStringStringType = mapStringStringPtrType.Elem()
+)
+
+var (
+ mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
+ mapStringInterfaceType = mapStringInterfacePtrType.Elem()
+)
+
+func decodeMapValue(d *Decoder, v reflect.Value) error {
+ n, err := d.DecodeMapLen()
+ if err != nil {
+ return err
+ }
+
+ typ := v.Type()
+ if n == -1 {
+ v.Set(reflect.Zero(typ))
+ return nil
+ }
+
+ if v.IsNil() {
+ v.Set(reflect.MakeMap(typ))
+ }
+ if n == 0 {
+ return nil
+ }
+
+ return d.decodeTypedMapValue(v, n)
+}
+
+func (d *Decoder) decodeMapDefault() (interface{}, error) {
+ if d.mapDecoder != nil {
+ return d.mapDecoder(d)
+ }
+ return d.DecodeMap()
+}
+
+// DecodeMapLen decodes map length. Length is -1 when map is nil.
+func (d *Decoder) DecodeMapLen() (int, error) {
+ c, err := d.readCode()
+ if err != nil {
+ return 0, err
+ }
+
+ if msgpcode.IsExt(c) {
+ if err = d.skipExtHeader(c); err != nil {
+ return 0, err
+ }
+
+ c, err = d.readCode()
+ if err != nil {
+ return 0, err
+ }
+ }
+ return d.mapLen(c)
+}
+
+func (d *Decoder) mapLen(c byte) (int, error) {
+ if c == msgpcode.Nil {
+ return -1, nil
+ }
+ if c >= msgpcode.FixedMapLow && c <= msgpcode.FixedMapHigh {
+ return int(c & msgpcode.FixedMapMask), nil
+ }
+ if c == msgpcode.Map16 {
+ size, err := d.uint16()
+ return int(size), err
+ }
+ if c == msgpcode.Map32 {
+ size, err := d.uint32()
+ return int(size), err
+ }
+ return 0, unexpectedCodeError{code: c, hint: "map length"}
+}
+
+func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
+ mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
+ return d.decodeMapStringStringPtr(mptr)
+}
+
+func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
+ size, err := d.DecodeMapLen()
+ if err != nil {
+ return err
+ }
+ if size == -1 {
+ *ptr = nil
+ return nil
+ }
+
+ m := *ptr
+ if m == nil {
+ *ptr = make(map[string]string, min(size, maxMapSize))
+ m = *ptr
+ }
+
+ for i := 0; i < size; i++ {
+ mk, err := d.DecodeString()
+ if err != nil {
+ return err
+ }
+ mv, err := d.DecodeString()
+ if err != nil {
+ return err
+ }
+ m[mk] = mv
+ }
+
+ return nil
+}
+
+func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
+ ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
+ return d.decodeMapStringInterfacePtr(ptr)
+}
+
+func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
+ m, err := d.DecodeMap()
+ if err != nil {
+ return err
+ }
+ *ptr = m
+ return nil
+}
+
+func (d *Decoder) DecodeMap() (map[string]interface{}, error) {
+ n, err := d.DecodeMapLen()
+ if err != nil {
+ return nil, err
+ }
+
+ if n == -1 {
+ return nil, nil
+ }
+
+ m := make(map[string]interface{}, min(n, maxMapSize))
+
+ for i := 0; i < n; i++ {
+ mk, err := d.DecodeString()
+ if err != nil {
+ return nil, err
+ }
+ mv, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+ m[mk] = mv
+ }
+
+ return m, nil
+}
+
+func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) {
+ n, err := d.DecodeMapLen()
+ if err != nil {
+ return nil, err
+ }
+
+ if n == -1 {
+ return nil, nil
+ }
+
+ m := make(map[interface{}]interface{}, min(n, maxMapSize))
+
+ for i := 0; i < n; i++ {
+ mk, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+
+ mv, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+
+ m[mk] = mv
+ }
+
+ return m, nil
+}
+
+// DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values.
+// Key and value types may be different.
+func (d *Decoder) DecodeTypedMap() (interface{}, error) {
+ n, err := d.DecodeMapLen()
+ if err != nil {
+ return nil, err
+ }
+ if n <= 0 {
+ return nil, nil
+ }
+
+ key, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+
+ value, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+
+ keyType := reflect.TypeOf(key)
+ valueType := reflect.TypeOf(value)
+
+ if !keyType.Comparable() {
+ return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String())
+ }
+
+ mapType := reflect.MapOf(keyType, valueType)
+ mapValue := reflect.MakeMap(mapType)
+ mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
+
+ n--
+ if err := d.decodeTypedMapValue(mapValue, n); err != nil {
+ return nil, err
+ }
+
+ return mapValue.Interface(), nil
+}
+
+func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error {
+ typ := v.Type()
+ keyType := typ.Key()
+ valueType := typ.Elem()
+
+ for i := 0; i < n; i++ {
+ mk := reflect.New(keyType).Elem()
+ if err := d.DecodeValue(mk); err != nil {
+ return err
+ }
+
+ mv := reflect.New(valueType).Elem()
+ if err := d.DecodeValue(mv); err != nil {
+ return err
+ }
+
+ v.SetMapIndex(mk, mv)
+ }
+
+ return nil
+}
+
+func (d *Decoder) skipMap(c byte) error {
+ n, err := d.mapLen(c)
+ if err != nil {
+ return err
+ }
+ for i := 0; i < n; i++ {
+ if err := d.Skip(); err != nil {
+ return err
+ }
+ if err := d.Skip(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func decodeStructValue(d *Decoder, v reflect.Value) error {
+ c, err := d.readCode()
+ if err != nil {
+ return err
+ }
+
+ n, err := d.mapLen(c)
+ if err == nil {
+ return d.decodeStruct(v, n)
+ }
+
+ var err2 error
+ n, err2 = d.arrayLen(c)
+ if err2 != nil {
+ return err
+ }
+
+ if n <= 0 {
+ v.Set(reflect.Zero(v.Type()))
+ return nil
+ }
+
+ fields := structs.Fields(v.Type(), d.structTag)
+ if n != len(fields.List) {
+ return errArrayStruct
+ }
+
+ for _, f := range fields.List {
+ if err := f.DecodeValue(d, v); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *Decoder) decodeStruct(v reflect.Value, n int) error {
+ if n == -1 {
+ v.Set(reflect.Zero(v.Type()))
+ return nil
+ }
+
+ fields := structs.Fields(v.Type(), d.structTag)
+ for i := 0; i < n; i++ {
+ name, err := d.decodeStringTemp()
+ if err != nil {
+ return err
+ }
+
+ if f := fields.Map[name]; f != nil {
+ if err := f.DecodeValue(d, v); err != nil {
+ return err
+ }
+ continue
+ }
+
+ if d.flags&disallowUnknownFieldsFlag != 0 {
+ return fmt.Errorf("msgpack: unknown field %q", name)
+ }
+ if err := d.Skip(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}