diff options
Diffstat (limited to 'vendor/github.com/go-asn1-ber/asn1-ber/ber.go')
-rw-r--r-- | vendor/github.com/go-asn1-ber/asn1-ber/ber.go | 198 |
1 files changed, 143 insertions, 55 deletions
diff --git a/vendor/github.com/go-asn1-ber/asn1-ber/ber.go b/vendor/github.com/go-asn1-ber/asn1-ber/ber.go index 1e186cb8..4fd7a66e 100644 --- a/vendor/github.com/go-asn1-ber/asn1-ber/ber.go +++ b/vendor/github.com/go-asn1-ber/asn1-ber/ber.go @@ -8,6 +8,8 @@ import ( "math" "os" "reflect" + "time" + "unicode/utf8" ) // MaxPacketLengthBytes specifies the maximum allowed packet size when calling ReadPacket or DecodePacket. Set to 0 for @@ -143,20 +145,20 @@ var TypeMap = map[Type]string{ TypeConstructed: "Constructed", } -var Debug bool = false +var Debug = false func PrintBytes(out io.Writer, buf []byte, indent string) { - data_lines := make([]string, (len(buf)/30)+1) - num_lines := make([]string, (len(buf)/30)+1) + dataLines := make([]string, (len(buf)/30)+1) + numLines := make([]string, (len(buf)/30)+1) for i, b := range buf { - data_lines[i/30] += fmt.Sprintf("%02x ", b) - num_lines[i/30] += fmt.Sprintf("%02d ", (i+1)%100) + dataLines[i/30] += fmt.Sprintf("%02x ", b) + numLines[i/30] += fmt.Sprintf("%02d ", (i+1)%100) } - for i := 0; i < len(data_lines); i++ { - out.Write([]byte(indent + data_lines[i] + "\n")) - out.Write([]byte(indent + num_lines[i] + "\n\n")) + for i := 0; i < len(dataLines); i++ { + _, _ = out.Write([]byte(indent + dataLines[i] + "\n")) + _, _ = out.Write([]byte(indent + numLines[i] + "\n\n")) } } @@ -169,20 +171,20 @@ func PrintPacket(p *Packet) { } func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) { - indent_str := "" + indentStr := "" - for len(indent_str) != indent { - indent_str += " " + for len(indentStr) != indent { + indentStr += " " } - class_str := ClassMap[p.ClassType] + classStr := ClassMap[p.ClassType] - tagtype_str := TypeMap[p.TagType] + tagTypeStr := TypeMap[p.TagType] - tag_str := fmt.Sprintf("0x%02X", p.Tag) + tagStr := fmt.Sprintf("0x%02X", p.Tag) if p.ClassType == ClassUniversal { - tag_str = tagMap[p.Tag] + tagStr = tagMap[p.Tag] } value := fmt.Sprint(p.Value) @@ -192,10 +194,10 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) { description = p.Description + ": " } - fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indent_str, description, class_str, tagtype_str, tag_str, p.Data.Len(), value) + _, _ = fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagTypeStr, tagStr, p.Data.Len(), value) if printBytes { - PrintBytes(out, p.Bytes(), indent_str) + PrintBytes(out, p.Bytes(), indentStr) } for _, child := range p.Children { @@ -203,7 +205,7 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) { } } -// ReadPacket reads a single Packet from the reader +// ReadPacket reads a single Packet from the reader. func ReadPacket(reader io.Reader) (*Packet, error) { p, _, err := readPacket(reader) if err != nil { @@ -239,7 +241,7 @@ func encodeInteger(i int64) []byte { var j int for ; n > 0; n-- { - out[j] = (byte(i >> uint((n-1)*8))) + out[j] = byte(i >> uint((n-1)*8)) j++ } @@ -271,7 +273,7 @@ func DecodePacket(data []byte) *Packet { } // DecodePacketErr decodes the given bytes into a single Packet -// If a decode error is encountered, nil is returned +// If a decode error is encountered, nil is returned. func DecodePacketErr(data []byte) (*Packet, error) { p, _, err := readPacket(bytes.NewBuffer(data)) if err != nil { @@ -280,7 +282,7 @@ func DecodePacketErr(data []byte) (*Packet, error) { return p, nil } -// readPacket reads a single Packet from the reader, returning the number of bytes read +// readPacket reads a single Packet from the reader, returning the number of bytes read. func readPacket(reader io.Reader) (*Packet, int, error) { identifier, length, read, err := readHeader(reader) if err != nil { @@ -342,7 +344,7 @@ func readPacket(reader io.Reader) (*Packet, int, error) { if MaxPacketLengthBytes > 0 && int64(length) > MaxPacketLengthBytes { return nil, read, fmt.Errorf("length %d greater than maximum %d", length, MaxPacketLengthBytes) } - content := make([]byte, length, length) + content := make([]byte, length) if length > 0 { _, err := io.ReadFull(reader, content) if err != nil { @@ -377,22 +379,42 @@ func readPacket(reader io.Reader) (*Packet, int, error) { case TagObjectDescriptor: case TagExternal: case TagRealFloat: + p.Value, err = ParseReal(content) case TagEnumerated: p.Value, _ = ParseInt64(content) case TagEmbeddedPDV: case TagUTF8String: - p.Value = DecodeString(content) + val := DecodeString(content) + if !utf8.Valid([]byte(val)) { + err = errors.New("invalid UTF-8 string") + } else { + p.Value = val + } case TagRelativeOID: case TagSequence: case TagSet: case TagNumericString: case TagPrintableString: - p.Value = DecodeString(content) + val := DecodeString(content) + if err = isPrintableString(val); err == nil { + p.Value = val + } case TagT61String: case TagVideotexString: case TagIA5String: + val := DecodeString(content) + for i, c := range val { + if c >= 0x7F { + err = fmt.Errorf("invalid character for IA5String at pos %d: %c", i, c) + break + } + } + if err == nil { + p.Value = val + } case TagUTCTime: case TagGeneralizedTime: + p.Value, err = ParseGeneralizedTime(content) case TagGraphicString: case TagVisibleString: case TagGeneralString: @@ -404,7 +426,24 @@ func readPacket(reader io.Reader) (*Packet, int, error) { p.Data.Write(content) } - return p, read, nil + return p, read, err +} + +func isPrintableString(val string) error { + for i, c := range val { + switch { + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c >= '0' && c <= '9': + default: + switch c { + case '\'', '(', ')', '+', ',', '-', '.', '=', '/', ':', '?', ' ': + default: + return fmt.Errorf("invalid character in position %d", i) + } + } + } + return nil } func (p *Packet) Bytes() []byte { @@ -422,77 +461,99 @@ func (p *Packet) AppendChild(child *Packet) { p.Children = append(p.Children, child) } -func Encode(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet { +func Encode(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet { p := new(Packet) - p.ClassType = ClassType - p.TagType = TagType - p.Tag = Tag + p.ClassType = classType + p.TagType = tagType + p.Tag = tag p.Data = new(bytes.Buffer) p.Children = make([]*Packet, 0, 2) - p.Value = Value - p.Description = Description + p.Value = value + p.Description = description - if Value != nil { - v := reflect.ValueOf(Value) + if value != nil { + v := reflect.ValueOf(value) - if ClassType == ClassUniversal { - switch Tag { + if classType == ClassUniversal { + switch tag { case TagOctetString: sv, ok := v.Interface().(string) if ok { p.Data.Write([]byte(sv)) } + case TagEnumerated: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } + case TagEmbeddedPDV: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } + } + } else if classType == ClassContext { + switch tag { + case TagEnumerated: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } + case TagEmbeddedPDV: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } } } } - return p } -func NewSequence(Description string) *Packet { - return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, Description) +func NewSequence(description string) *Packet { + return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, description) } -func NewBoolean(ClassType Class, TagType Type, Tag Tag, Value bool, Description string) *Packet { +func NewBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet { intValue := int64(0) - if Value { + if value { intValue = 1 } - p := Encode(ClassType, TagType, Tag, nil, Description) + p := Encode(classType, tagType, tag, nil, description) - p.Value = Value + p.Value = value p.Data.Write(encodeInteger(intValue)) return p } -// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet -func NewLDAPBoolean(Value bool, Description string) *Packet { +// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet. +func NewLDAPBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet { intValue := int64(0) - if Value { + if value { intValue = 255 } - p := Encode(ClassUniversal, TypePrimitive, TagBoolean, nil, Description) + p := Encode(classType, tagType, tag, nil, description) - p.Value = Value + p.Value = value p.Data.Write(encodeInteger(intValue)) return p } -func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet { - p := Encode(ClassType, TagType, Tag, nil, Description) +func NewInteger(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) - p.Value = Value - switch v := Value.(type) { + p.Value = value + switch v := value.(type) { case int: p.Data.Write(encodeInteger(int64(v))) case uint: @@ -522,11 +583,38 @@ func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Descr return p } -func NewString(ClassType Class, TagType Type, Tag Tag, Value, Description string) *Packet { - p := Encode(ClassType, TagType, Tag, nil, Description) +func NewString(classType Class, tagType Type, tag Tag, value, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) + + p.Value = value + p.Data.Write([]byte(value)) - p.Value = Value - p.Data.Write([]byte(Value)) + return p +} +func NewGeneralizedTime(classType Class, tagType Type, tag Tag, value time.Time, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) + var s string + if value.Nanosecond() != 0 { + s = value.Format(`20060102150405.000000000Z`) + } else { + s = value.Format(`20060102150405Z`) + } + p.Value = s + p.Data.Write([]byte(s)) + return p +} + +func NewReal(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) + + switch v := value.(type) { + case float64: + p.Data.Write(encodeFloat(v)) + case float32: + p.Data.Write(encodeFloat(float64(v))) + default: + panic(fmt.Sprintf("Invalid type %T, expected float{64|32}", v)) + } return p } |