diff --git a/src/encoding/asn1/asn1.go b/src/encoding/asn1/asn1.go index 73193c3407..1a01838938 100644 --- a/src/encoding/asn1/asn1.go +++ b/src/encoding/asn1/asn1.go @@ -536,7 +536,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i // a number of ASN.1 values from the given byte slice and returns them as a // slice of Go values of the given type. func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err error) { - expectedTag, compoundType, ok := getUniversalType(elemType) + matchAny, expectedTag, compoundType, ok := getUniversalType(elemType) if !ok { err = StructuralError{"unknown Go type for slice"} return @@ -562,7 +562,7 @@ func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type t.tag = TagUTCTime } - if t.class != ClassUniversal || t.isCompound != compoundType || t.tag != expectedTag { + if !matchAny && (t.class != ClassUniversal || t.isCompound != compoundType || t.tag != expectedTag) { err = StructuralError{"sequence tag mismatch"} return } @@ -617,23 +617,6 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam return } - // Deal with raw values. - if fieldType == rawValueType { - var t tagAndLength - t, offset, err = parseTagAndLength(bytes, offset) - if err != nil { - return - } - if invalidLength(offset, t.length, len(bytes)) { - err = SyntaxError{"data truncated"} - return - } - result := RawValue{t.class, t.tag, t.isCompound, bytes[offset : offset+t.length], bytes[initOffset : offset+t.length]} - offset += t.length - v.Set(reflect.ValueOf(result)) - return - } - // Deal with the ANY type. if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 { var t tagAndLength @@ -682,11 +665,6 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } return } - universalTag, compoundType, ok1 := getUniversalType(fieldType) - if !ok1 { - err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)} - return - } t, offset, err := parseTagAndLength(bytes, offset) if err != nil { @@ -702,7 +680,9 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam return } if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) { - if t.length > 0 { + if fieldType == rawValueType { + // The inner element should not be parsed for RawValues. + } else if t.length > 0 { t, offset, err = parseTagAndLength(bytes, offset) if err != nil { return @@ -727,6 +707,12 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam } } + matchAny, universalTag, compoundType, ok1 := getUniversalType(fieldType) + if !ok1 { + err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)} + return + } + // Special case for strings: all the ASN.1 string types map to the Go // type string. getUniversalType returns the tag for PrintableString // when it sees a string, so if we see a different string type on the @@ -752,21 +738,25 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam universalTag = TagSet } + matchAnyClassAndTag := matchAny expectedClass := ClassUniversal expectedTag := universalTag if !params.explicit && params.tag != nil { expectedClass = ClassContextSpecific expectedTag = *params.tag + matchAnyClassAndTag = false } if !params.explicit && params.application && params.tag != nil { expectedClass = ClassApplication expectedTag = *params.tag + matchAnyClassAndTag = false } // We have unwrapped any explicit tagging at this point. - if t.class != expectedClass || t.tag != expectedTag || t.isCompound != compoundType { + if !matchAnyClassAndTag && (t.class != expectedClass || t.tag != expectedTag) || + (!matchAny && t.isCompound != compoundType) { // Tags don't match. Again, it could be an optional element. ok := setDefaultValue(v, params) if ok { @@ -785,6 +775,10 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam // We deal with the structures defined in this package first. switch fieldType { + case rawValueType: + result := RawValue{t.class, t.tag, t.isCompound, innerBytes, bytes[initOffset:offset]} + v.Set(reflect.ValueOf(result)) + return case objectIdentifierType: newSlice, err1 := parseObjectIdentifier(innerBytes) v.Set(reflect.MakeSlice(v.Type(), len(newSlice), len(newSlice))) diff --git a/src/encoding/asn1/asn1_test.go b/src/encoding/asn1/asn1_test.go index 355ff8c41c..7ff9c05cc0 100644 --- a/src/encoding/asn1/asn1_test.go +++ b/src/encoding/asn1/asn1_test.go @@ -1033,3 +1033,60 @@ func TestNull(t *testing.T) { t.Errorf("Expected Unmarshal of NullBytes to yield %v, got %v", NullRawValue, unmarshaled) } } + +func TestExplicitTagRawValueStruct(t *testing.T) { + type foo struct { + A RawValue `asn1:"optional,explicit,tag:5"` + B []byte `asn1:"optional,explicit,tag:6"` + } + before := foo{B: []byte{1, 2, 3}} + derBytes, err := Marshal(before) + if err != nil { + t.Fatal(err) + } + + var after foo + if rest, err := Unmarshal(derBytes, &after); err != nil || len(rest) != 0 { + t.Fatal(err) + } + + got := fmt.Sprintf("%#v", after) + want := fmt.Sprintf("%#v", before) + if got != want { + t.Errorf("got %s, want %s (DER: %x)", got, want, derBytes) + } +} + +func TestTaggedRawValue(t *testing.T) { + type taggedRawValue struct { + A RawValue `asn1:"tag:5"` + } + type untaggedRawValue struct { + A RawValue + } + const isCompound = 0x20 + const tag = 5 + + tests := []struct { + shouldMatch bool + derBytes []byte + }{ + {false, []byte{0x30, 3, TagInteger, 1, 1}}, + {true, []byte{0x30, 3, (ClassContextSpecific << 6) | tag, 1, 1}}, + {true, []byte{0x30, 3, (ClassContextSpecific << 6) | tag | isCompound, 1, 1}}, + {false, []byte{0x30, 3, (ClassApplication << 6) | tag | isCompound, 1, 1}}, + } + + for i, test := range tests { + var tagged taggedRawValue + if _, err := Unmarshal(test.derBytes, &tagged); (err == nil) != test.shouldMatch { + t.Errorf("#%d: unexpected result parsing %x: %s", i, test.derBytes, err) + } + + // An untagged RawValue should accept anything. + var untagged untaggedRawValue + if _, err := Unmarshal(test.derBytes, &untagged); err != nil { + t.Errorf("#%d: unexpected failure parsing %x with untagged RawValue: %s", i, test.derBytes, err) + } + } +} diff --git a/src/encoding/asn1/common.go b/src/encoding/asn1/common.go index cd93b27ecb..3e4dfd1679 100644 --- a/src/encoding/asn1/common.go +++ b/src/encoding/asn1/common.go @@ -136,36 +136,38 @@ func parseFieldParameters(str string) (ret fieldParameters) { // Given a reflected Go type, getUniversalType returns the default tag number // and expected compound flag. -func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { +func getUniversalType(t reflect.Type) (matchAny bool, tagNumber int, isCompound, ok bool) { switch t { + case rawValueType: + return true, -1, false, true case objectIdentifierType: - return TagOID, false, true + return false, TagOID, false, true case bitStringType: - return TagBitString, false, true + return false, TagBitString, false, true case timeType: - return TagUTCTime, false, true + return false, TagUTCTime, false, true case enumeratedType: - return TagEnum, false, true + return false, TagEnum, false, true case bigIntType: - return TagInteger, false, true + return false, TagInteger, false, true } switch t.Kind() { case reflect.Bool: - return TagBoolean, false, true + return false, TagBoolean, false, true case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return TagInteger, false, true + return false, TagInteger, false, true case reflect.Struct: - return TagSequence, true, true + return false, TagSequence, true, true case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { - return TagOctetString, false, true + return false, TagOctetString, false, true } if strings.HasSuffix(t.Name(), "SET") { - return TagSet, true, true + return false, TagSet, true, true } - return TagSequence, true, true + return false, TagSequence, true, true case reflect.String: - return TagPrintableString, false, true + return false, TagPrintableString, false, true } - return 0, false, false + return false, 0, false, false } diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go index b081431200..7f8119e9ae 100644 --- a/src/encoding/asn1/marshal.go +++ b/src/encoding/asn1/marshal.go @@ -556,8 +556,8 @@ func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) { return t, nil } - tag, isCompound, ok := getUniversalType(v.Type()) - if !ok { + matchAny, tag, isCompound, ok := getUniversalType(v.Type()) + if !ok || matchAny { return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} }