diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go index da343587a4..0d34d5aa1e 100644 --- a/src/encoding/asn1/marshal.go +++ b/src/encoding/asn1/marshal.go @@ -80,26 +80,6 @@ func (m multiEncoder) Encode(dst []byte) { } } -type octetSorter [][]byte - -func (s octetSorter) Len() int { - return len(s) -} - -func (s octetSorter) Swap(i, j int) { - s[i], s[j] = s[j], s[i] -} - -func (s octetSorter) Less(i, j int) bool { - // Since we are using bytes.Compare to compare TLV encodings we - // don't need to right pad s[i] and s[j] to the same length as - // suggested in X690. If len(s[i]) < len(s[j]) the length octet of - // s[i], which is the first determining byte, will inherently be - // smaller than the length octet of s[j]. This lets us skip the - // padding step. - return bytes.Compare(s[i], s[j]) < 0 -} - type setEncoder []encoder func (s setEncoder) Len() int { @@ -125,7 +105,15 @@ func (s setEncoder) Encode(dst []byte) { e.Encode(l[i]) } - sort.Sort(octetSorter(l)) + sort.Slice(l, func(i, j int) bool { + // Since we are using bytes.Compare to compare TLV encodings we + // don't need to right pad s[i] and s[j] to the same length as + // suggested in X690. If len(s[i]) < len(s[j]) the length octet of + // s[i], which is the first determining byte, will inherently be + // smaller than the length octet of s[j]. This lets us skip the + // padding step. + return bytes.Compare(l[i], l[j]) < 0 + }) var off int for _, b := range l { @@ -677,6 +665,15 @@ func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) { tag = TagSet } + // makeField can be called for a slice that should be treated as a SET + // but doesn't have params.set set, for instance when using a slice + // with the SET type name suffix. In this case getUniversalType returns + // TagSet, but makeBody doesn't know about that so will treat the slice + // as a sequence. To work around this we set params.set. + if tag == TagSet && !params.set { + params.set = true + } + t := new(taggedEncoder) t.body, err = makeBody(v, params) diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go index 5aa1da68b0..529052285f 100644 --- a/src/encoding/asn1/marshal_test.go +++ b/src/encoding/asn1/marshal_test.go @@ -347,6 +347,32 @@ func TestSetEncoder(t *testing.T) { t.Error("Unmarshal returned extra garbage") } if !reflect.DeepEqual(expectedOrder, resultStruct.Strings) { - t.Errorf("Unexpected SET content. got: %s, want: %s", resultStruct.Strings, resultStruct.Strings) + t.Errorf("Unexpected SET content. got: %s, want: %s", resultStruct.Strings, expectedOrder) + } +} + +func TestSetEncoderSETSliceSuffix(t *testing.T) { + type testSetSET []string + testSet := testSetSET{"a", "aa", "b", "bb", "c", "cc"} + + // Expected ordering of the SET should be: + // a, b, c, aa, bb, cc + + output, err := Marshal(testSet) + if err != nil { + t.Errorf("%v", err) + } + + expectedOrder := testSetSET{"a", "b", "c", "aa", "bb", "cc"} + var resultSet testSetSET + rest, err := Unmarshal(output, &resultSet) + if err != nil { + t.Errorf("%v", err) + } + if len(rest) != 0 { + t.Error("Unmarshal returned extra garbage") + } + if !reflect.DeepEqual(expectedOrder, resultSet) { + t.Errorf("Unexpected SET content. got: %s, want: %s", resultSet, expectedOrder) } }