summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/golang/protobuf/proto/extensions.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/golang/protobuf/proto/extensions.go')
-rw-r--r--vendor/github.com/golang/protobuf/proto/extensions.go204
1 files changed, 80 insertions, 124 deletions
diff --git a/vendor/github.com/golang/protobuf/proto/extensions.go b/vendor/github.com/golang/protobuf/proto/extensions.go
index eaad2183..816a3b9d 100644
--- a/vendor/github.com/golang/protobuf/proto/extensions.go
+++ b/vendor/github.com/golang/protobuf/proto/extensions.go
@@ -38,6 +38,7 @@ package proto
import (
"errors"
"fmt"
+ "io"
"reflect"
"strconv"
"sync"
@@ -91,14 +92,29 @@ func (n notLocker) Unlock() {}
// extendable returns the extendableProto interface for the given generated proto message.
// If the proto message has the old extension format, it returns a wrapper that implements
// the extendableProto interface.
-func extendable(p interface{}) (extendableProto, bool) {
- if ep, ok := p.(extendableProto); ok {
- return ep, ok
- }
- if ep, ok := p.(extendableProtoV1); ok {
- return extensionAdapter{ep}, ok
+func extendable(p interface{}) (extendableProto, error) {
+ switch p := p.(type) {
+ case extendableProto:
+ if isNilPtr(p) {
+ return nil, fmt.Errorf("proto: nil %T is not extendable", p)
+ }
+ return p, nil
+ case extendableProtoV1:
+ if isNilPtr(p) {
+ return nil, fmt.Errorf("proto: nil %T is not extendable", p)
+ }
+ return extensionAdapter{p}, nil
}
- return nil, false
+ // Don't allocate a specific error containing %T:
+ // this is the hot path for Clone and MarshalText.
+ return nil, errNotExtendable
+}
+
+var errNotExtendable = errors.New("proto: not an extendable proto.Message")
+
+func isNilPtr(x interface{}) bool {
+ v := reflect.ValueOf(x)
+ return v.Kind() == reflect.Ptr && v.IsNil()
}
// XXX_InternalExtensions is an internal representation of proto extensions.
@@ -143,9 +159,6 @@ func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Loc
return e.p.extensionMap, &e.p.mu
}
-var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
-var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
-
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
type ExtensionDesc struct {
@@ -179,8 +192,8 @@ type Extension struct {
// SetRawExtension is for testing only.
func SetRawExtension(base Message, id int32, b []byte) {
- epb, ok := extendable(base)
- if !ok {
+ epb, err := extendable(base)
+ if err != nil {
return
}
extmap := epb.extensionsWrite()
@@ -205,7 +218,7 @@ func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
pbi = ea.extendableProtoV1
}
if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
- return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
+ return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
}
// Check the range.
if !isExtensionField(pb, extension.Field) {
@@ -250,85 +263,11 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
return prop
}
-// encode encodes any unmarshaled (unencoded) extensions in e.
-func encodeExtensions(e *XXX_InternalExtensions) error {
- m, mu := e.extensionsRead()
- if m == nil {
- return nil // fast path
- }
- mu.Lock()
- defer mu.Unlock()
- return encodeExtensionsMap(m)
-}
-
-// encode encodes any unmarshaled (unencoded) extensions in e.
-func encodeExtensionsMap(m map[int32]Extension) error {
- for k, e := range m {
- if e.value == nil || e.desc == nil {
- // Extension is only in its encoded form.
- continue
- }
-
- // We don't skip extensions that have an encoded form set,
- // because the extension value may have been mutated after
- // the last time this function was called.
-
- et := reflect.TypeOf(e.desc.ExtensionType)
- props := extensionProperties(e.desc)
-
- p := NewBuffer(nil)
- // If e.value has type T, the encoder expects a *struct{ X T }.
- // Pass a *T with a zero field and hope it all works out.
- x := reflect.New(et)
- x.Elem().Set(reflect.ValueOf(e.value))
- if err := props.enc(p, props, toStructPointer(x)); err != nil {
- return err
- }
- e.enc = p.buf
- m[k] = e
- }
- return nil
-}
-
-func extensionsSize(e *XXX_InternalExtensions) (n int) {
- m, mu := e.extensionsRead()
- if m == nil {
- return 0
- }
- mu.Lock()
- defer mu.Unlock()
- return extensionsMapSize(m)
-}
-
-func extensionsMapSize(m map[int32]Extension) (n int) {
- for _, e := range m {
- if e.value == nil || e.desc == nil {
- // Extension is only in its encoded form.
- n += len(e.enc)
- continue
- }
-
- // We don't skip extensions that have an encoded form set,
- // because the extension value may have been mutated after
- // the last time this function was called.
-
- et := reflect.TypeOf(e.desc.ExtensionType)
- props := extensionProperties(e.desc)
-
- // If e.value has type T, the encoder expects a *struct{ X T }.
- // Pass a *T with a zero field and hope it all works out.
- x := reflect.New(et)
- x.Elem().Set(reflect.ValueOf(e.value))
- n += props.size(props, toStructPointer(x))
- }
- return
-}
-
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb Message, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
- epb, ok := extendable(pb)
- if !ok {
+ epb, err := extendable(pb)
+ if err != nil {
return false
}
extmap, mu := epb.extensionsRead()
@@ -336,15 +275,15 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool {
return false
}
mu.Lock()
- _, ok = extmap[extension.Field]
+ _, ok := extmap[extension.Field]
mu.Unlock()
return ok
}
// ClearExtension removes the given extension from pb.
func ClearExtension(pb Message, extension *ExtensionDesc) {
- epb, ok := extendable(pb)
- if !ok {
+ epb, err := extendable(pb)
+ if err != nil {
return
}
// TODO: Check types, field numbers, etc.?
@@ -352,16 +291,26 @@ func ClearExtension(pb Message, extension *ExtensionDesc) {
delete(extmap, extension.Field)
}
-// GetExtension parses and returns the given extension of pb.
-// If the extension is not present and has no default value it returns ErrMissingExtension.
+// GetExtension retrieves a proto2 extended field from pb.
+//
+// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
+// then GetExtension parses the encoded field and returns a Go value of the specified type.
+// If the field is not present, then the default value is returned (if one is specified),
+// otherwise ErrMissingExtension is reported.
+//
+// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
+// then GetExtension returns the raw encoded bytes of the field extension.
func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
- epb, ok := extendable(pb)
- if !ok {
- return nil, errors.New("proto: not an extendable proto")
+ epb, err := extendable(pb)
+ if err != nil {
+ return nil, err
}
- if err := checkExtensionTypes(epb, extension); err != nil {
- return nil, err
+ if extension.ExtendedType != nil {
+ // can only check type if this is a complete descriptor
+ if err := checkExtensionTypes(epb, extension); err != nil {
+ return nil, err
+ }
}
emap, mu := epb.extensionsRead()
@@ -388,6 +337,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
return e.value, nil
}
+ if extension.ExtensionType == nil {
+ // incomplete descriptor
+ return e.enc, nil
+ }
+
v, err := decodeExtension(e.enc, extension)
if err != nil {
return nil, err
@@ -405,6 +359,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
// defaultExtensionValue returns the default value for extension.
// If no default for an extension is defined ErrMissingExtension is returned.
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
+ if extension.ExtensionType == nil {
+ // incomplete descriptor, so no default
+ return nil, ErrMissingExtension
+ }
+
t := reflect.TypeOf(extension.ExtensionType)
props := extensionProperties(extension)
@@ -439,31 +398,28 @@ func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
// decodeExtension decodes an extension encoded in b.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
- o := NewBuffer(b)
-
t := reflect.TypeOf(extension.ExtensionType)
-
- props := extensionProperties(extension)
+ unmarshal := typeUnmarshaler(t, extension.Tag)
// t is a pointer to a struct, pointer to basic type or a slice.
- // Allocate a "field" to store the pointer/slice itself; the
- // pointer/slice will be stored here. We pass
- // the address of this field to props.dec.
- // This passes a zero field and a *t and lets props.dec
- // interpret it as a *struct{ x t }.
+ // Allocate space to store the pointer/slice.
value := reflect.New(t).Elem()
+ var err error
for {
- // Discard wire type and field number varint. It isn't needed.
- if _, err := o.DecodeVarint(); err != nil {
- return nil, err
+ x, n := decodeVarint(b)
+ if n == 0 {
+ return nil, io.ErrUnexpectedEOF
}
+ b = b[n:]
+ wire := int(x) & 7
- if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
+ b, err = unmarshal(b, valToPointer(value.Addr()), wire)
+ if err != nil {
return nil, err
}
- if o.index >= len(o.buf) {
+ if len(b) == 0 {
break
}
}
@@ -473,9 +429,9 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
- epb, ok := extendable(pb)
- if !ok {
- return nil, errors.New("proto: not an extendable proto")
+ epb, err := extendable(pb)
+ if err != nil {
+ return nil, err
}
extensions = make([]interface{}, len(es))
for i, e := range es {
@@ -494,9 +450,9 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
// just the Field field, which defines the extension's field number.
func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
- epb, ok := extendable(pb)
- if !ok {
- return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
+ epb, err := extendable(pb)
+ if err != nil {
+ return nil, err
}
registeredExtensions := RegisteredExtensions(pb)
@@ -523,9 +479,9 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
- epb, ok := extendable(pb)
- if !ok {
- return errors.New("proto: not an extendable proto")
+ epb, err := extendable(pb)
+ if err != nil {
+ return err
}
if err := checkExtensionTypes(epb, extension); err != nil {
return err
@@ -550,8 +506,8 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
// ClearAllExtensions clears all extensions from pb.
func ClearAllExtensions(pb Message) {
- epb, ok := extendable(pb)
- if !ok {
+ epb, err := extendable(pb)
+ if err != nil {
return
}
m := epb.extensionsWrite()