diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go index 63f8e2aa87..100e41df24 100644 --- a/src/encoding/xml/marshal.go +++ b/src/encoding/xml/marshal.go @@ -578,12 +578,14 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat // 3. type name var start StartElement - // Historic behaviour: elements use the default name space - // they are contained in by default. - start.Name.Space = p.defaultNS + // explicitNS records whether the element's name + // space has been explicitly set (for example an + // and XMLName field). + explicitNS := false if startTemplate != nil { start.Name = startTemplate.Name + explicitNS = true start.Attr = append(start.Attr, startTemplate.Attr...) } else if tinfo.xmlname != nil { xmlname := tinfo.xmlname @@ -592,11 +594,13 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" { start.Name = v } + explicitNS = true } if start.Name.Local == "" && finfo != nil { start.Name.Local = finfo.name if finfo.xmlns != "" { start.Name.Space = finfo.xmlns + explicitNS = true } } if start.Name.Local == "" { @@ -606,9 +610,12 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat } start.Name.Local = name } - // Historic behaviour: an element that's in a namespace sets - // the default namespace for all elements contained within it. - start.setDefaultNamespace() + + // defaultNS records the default name space as set by a xmlns="..." + // attribute. We don't set p.defaultNS because we want to let + // the attribute writing code (in p.defineNS) be solely responsible + // for maintaining that. + defaultNS := p.defaultNS // Attributes for i := range tinfo.fields { @@ -616,81 +623,26 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat if finfo.flags&fAttr == 0 { continue } - fv := finfo.value(val) - name := Name{Space: finfo.xmlns, Local: finfo.name} - - if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) { - continue - } - - if fv.Kind() == reflect.Interface && fv.IsNil() { - continue - } - - if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { - attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name) - if err != nil { - return err - } - if attr.Name.Local != "" { - start.Attr = append(start.Attr, attr) - } - continue - } - - if fv.CanAddr() { - pv := fv.Addr() - if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { - attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) - if err != nil { - return err - } - if attr.Name.Local != "" { - start.Attr = append(start.Attr, attr) - } - continue - } - } - - if fv.CanInterface() && fv.Type().Implements(textMarshalerType) { - text, err := fv.Interface().(encoding.TextMarshaler).MarshalText() - if err != nil { - return err - } - start.Attr = append(start.Attr, Attr{name, string(text)}) - continue - } - - if fv.CanAddr() { - pv := fv.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() - if err != nil { - return err - } - start.Attr = append(start.Attr, Attr{name, string(text)}) - continue - } - } - - // Dereference or skip nil pointer, interface values. - switch fv.Kind() { - case reflect.Ptr, reflect.Interface: - if fv.IsNil() { - continue - } - fv = fv.Elem() - } - - s, b, err := p.marshalSimple(fv.Type(), fv) + attr, add, err := p.fieldAttr(finfo, val) if err != nil { return err } - if b != nil { - s = string(b) + if !add { + continue + } + start.Attr = append(start.Attr, attr) + if attr.Name.Space == "" && attr.Name.Local == "xmlns" { + defaultNS = attr.Value } - start.Attr = append(start.Attr, Attr{name, s}) } + if !explicitNS { + // Historic behavior: elements use the default name space + // they are contained in by default. + start.Name.Space = defaultNS + } + // Historic behaviour: an element that's in a namespace sets + // the default namespace for all elements contained within it. + start.setDefaultNamespace() if err := p.writeStart(&start); err != nil { return err @@ -719,6 +671,64 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat return p.cachedWriteError() } +// fieldAttr returns the attribute of the given field and +// whether it should actually be added as an attribute; +// val holds the value containing the field. +func (p *printer) fieldAttr(finfo *fieldInfo, val reflect.Value) (Attr, bool, error) { + fv := finfo.value(val) + name := Name{Space: finfo.xmlns, Local: finfo.name} + if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) { + return Attr{}, false, nil + } + if fv.Kind() == reflect.Interface && fv.IsNil() { + return Attr{}, false, nil + } + if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { + attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + return attr, attr.Name.Local != "", err + } + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { + attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + return attr, attr.Name.Local != "", err + } + } + if fv.CanInterface() && fv.Type().Implements(textMarshalerType) { + text, err := fv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return Attr{}, false, err + } + return Attr{name, string(text)}, true, nil + } + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return Attr{}, false, err + } + return Attr{name, string(text)}, true, nil + } + } + // Dereference or skip nil pointer, interface values. + switch fv.Kind() { + case reflect.Ptr, reflect.Interface: + if fv.IsNil() { + return Attr{}, false, nil + } + fv = fv.Elem() + } + s, b, err := p.marshalSimple(fv.Type(), fv) + if err != nil { + return Attr{}, false, err + } + if b != nil { + s = string(b) + } + return Attr{name, s}, true, nil +} + // defaultStart returns the default start element to use, // given the reflect type, field info, and start template. func (p *printer) defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement { diff --git a/src/encoding/xml/marshal_test.go b/src/encoding/xml/marshal_test.go index 394855782e..4c478ddded 100644 --- a/src/encoding/xml/marshal_test.go +++ b/src/encoding/xml/marshal_test.go @@ -174,6 +174,11 @@ type XMLNameWithTag struct { Value string `xml:",chardata"` } +type XMLNameWithNSTag struct { + XMLName Name `xml:"ns InXMLNameWithNSTag"` + Value string `xml:",chardata"` +} + type XMLNameWithoutTag struct { XMLName Name Value string `xml:",chardata"` @@ -302,8 +307,7 @@ func (m *MyMarshalerTest) MarshalXML(e *Encoder, start StartElement) error { return nil } -type MyMarshalerAttrTest struct { -} +type MyMarshalerAttrTest struct{} var _ MarshalerAttr = (*MyMarshalerAttrTest)(nil) @@ -311,10 +315,22 @@ func (m *MyMarshalerAttrTest) MarshalXMLAttr(name Name) (Attr, error) { return Attr{name, "hello world"}, nil } +type MyMarshalerValueAttrTest struct{} + +var _ MarshalerAttr = MyMarshalerValueAttrTest{} + +func (m MyMarshalerValueAttrTest) MarshalXMLAttr(name Name) (Attr, error) { + return Attr{name, "hello world"}, nil +} + type MarshalerStruct struct { Foo MyMarshalerAttrTest `xml:",attr"` } +type MarshalerValueStruct struct { + Foo MyMarshalerValueAttrTest `xml:",attr"` +} + type InnerStruct struct { XMLName Name `xml:"testns outer"` } @@ -350,6 +366,34 @@ type NestedAndComment struct { Comment string `xml:",comment"` } +type XMLNSFieldStruct struct { + Ns string `xml:"xmlns,attr"` + Body string +} + +type NamedXMLNSFieldStruct struct { + XMLName struct{} `xml:"testns test"` + Ns string `xml:"xmlns,attr"` + Body string +} + +type XMLNSFieldStructWithOmitEmpty struct { + Ns string `xml:"xmlns,attr,omitempty"` + Body string +} + +type NamedXMLNSFieldStructWithEmptyNamespace struct { + XMLName struct{} `xml:"test"` + Ns string `xml:"xmlns,attr"` + Body string +} + +type RecursiveXMLNSFieldStruct struct { + Ns string `xml:"xmlns,attr"` + Body *RecursiveXMLNSFieldStruct `xml:",omitempty"` + Text string `xml:",omitempty"` +} + func ifaceptr(x interface{}) interface{} { return &x } @@ -989,6 +1033,10 @@ var marshalTests = []struct { ExpectXML: ``, Value: &MarshalerStruct{}, }, + { + ExpectXML: ``, + Value: &MarshalerValueStruct{}, + }, { ExpectXML: ``, Value: &OuterStruct{IntAttr: 10}, @@ -1013,6 +1061,39 @@ var marshalTests = []struct { ExpectXML: ``, Value: &NestedAndComment{AB: make([]string, 2), Comment: "test"}, }, + { + ExpectXML: `hello world`, + Value: &XMLNSFieldStruct{Ns: "http://example.com/ns", Body: "hello world"}, + }, + { + ExpectXML: `hello world`, + Value: &NamedXMLNSFieldStruct{Ns: "http://example.com/ns", Body: "hello world"}, + }, + { + ExpectXML: `hello world`, + Value: &NamedXMLNSFieldStruct{Ns: "", Body: "hello world"}, + }, + { + ExpectXML: `hello world`, + Value: &XMLNSFieldStructWithOmitEmpty{Body: "hello world"}, + }, + { + // The xmlns attribute must be ignored because the + // element is in the empty namespace, so it's not possible + // to set the default namespace to something non-empty. + ExpectXML: `hello world`, + Value: &NamedXMLNSFieldStructWithEmptyNamespace{Ns: "foo", Body: "hello world"}, + MarshalOnly: true, + }, + { + ExpectXML: `hello world`, + Value: &RecursiveXMLNSFieldStruct{ + Ns: "foo", + Body: &RecursiveXMLNSFieldStruct{ + Text: "hello world", + }, + }, + }, } func TestMarshal(t *testing.T) { @@ -1235,6 +1316,100 @@ func TestMarshalFlush(t *testing.T) { } } +var encodeElementTests = []struct { + desc string + value interface{} + start StartElement + expectXML string +}{{ + desc: "simple string", + value: "hello", + start: StartElement{ + Name: Name{Local: "a"}, + }, + expectXML: `hello`, +}, { + desc: "string with added attributes", + value: "hello", + start: StartElement{ + Name: Name{Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "x"}, + Value: "y", + }, { + Name: Name{Local: "foo"}, + Value: "bar", + }}, + }, + expectXML: `hello`, +}, { + desc: "start element with default name space", + value: struct { + Foo XMLNameWithNSTag + }{ + Foo: XMLNameWithNSTag{ + Value: "hello", + }, + }, + start: StartElement{ + Name: Name{Space: "ns", Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "xmlns"}, + // "ns" is the name space defined in XMLNameWithNSTag + Value: "ns", + }}, + }, + expectXML: `hello`, +}, { + desc: "start element in name space with different default name space", + value: struct { + Foo XMLNameWithNSTag + }{ + Foo: XMLNameWithNSTag{ + Value: "hello", + }, + }, + start: StartElement{ + Name: Name{Space: "ns2", Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "xmlns"}, + // "ns" is the name space defined in XMLNameWithNSTag + Value: "ns", + }}, + }, + expectXML: `hello`, +}, { + desc: "XMLMarshaler with start element with default name space", + value: &MyMarshalerTest{}, + start: StartElement{ + Name: Name{Space: "ns2", Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "xmlns"}, + // "ns" is the name space defined in XMLNameWithNSTag + Value: "ns", + }}, + }, + expectXML: `hello world`, +}} + +func TestEncodeElement(t *testing.T) { + for idx, test := range encodeElementTests { + var buf bytes.Buffer + enc := NewEncoder(&buf) + err := enc.EncodeElement(test.value, test.start) + if err != nil { + t.Fatalf("enc.EncodeElement: %v", err) + } + err = enc.Flush() + if err != nil { + t.Fatalf("enc.Flush: %v", err) + } + if got, want := buf.String(), test.expectXML; got != want { + t.Errorf("#%d(%s): EncodeElement(%#v, %#v):\nhave %#q\nwant %#q", idx, test.desc, test.value, test.start, got, want) + } + } +} + func BenchmarkMarshal(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { diff --git a/src/encoding/xml/xml.go b/src/encoding/xml/xml.go index 3090750c48..ffab4a70c9 100644 --- a/src/encoding/xml/xml.go +++ b/src/encoding/xml/xml.go @@ -91,6 +91,12 @@ func (e *StartElement) setDefaultNamespace() { // or was just using the default namespace. return } + // Don't add a default name space if there's already one set. + for _, attr := range e.Attr { + if attr.Name.Space == "" && attr.Name.Local == "xmlns" { + return + } + } e.Attr = append(e.Attr, Attr{ Name: Name{ Local: "xmlns",