diff --git a/src/pkg/encoding/json/encode.go b/src/pkg/encoding/json/encode.go index f951250e98..7d6c71d7a9 100644 --- a/src/pkg/encoding/json/encode.go +++ b/src/pkg/encoding/json/encode.go @@ -305,7 +305,7 @@ func (e *encodeState) reflectValue(v reflect.Value) { valueEncoder(v)(e, v, false) } -type encoderFunc func(e *encodeState, v reflect.Value, _ bool) +type encoderFunc func(e *encodeState, v reflect.Value, quoted bool) var encoderCache struct { sync.RWMutex @@ -316,11 +316,10 @@ func valueEncoder(v reflect.Value) encoderFunc { if !v.IsValid() { return invalidValueEncoder } - t := v.Type() - return typeEncoder(t, v) + return typeEncoder(v.Type()) } -func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func typeEncoder(t reflect.Type) encoderFunc { encoderCache.RLock() f := encoderCache.m[t] encoderCache.RUnlock() @@ -346,7 +345,7 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { // Compute fields without lock. // Might duplicate effort but won't hold other computations back. - f = newTypeEncoder(t, vx) + f = newTypeEncoder(t, true) wg.Done() encoderCache.Lock() encoderCache.m[t] = f @@ -354,38 +353,33 @@ func typeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { return f } -// newTypeEncoder constructs an encoderFunc for a type. -// The provided vx is an example value of type t. It's the first seen -// value of that type and should not be used to encode. It may be -// zero. -func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { - if !vx.IsValid() { - vx = reflect.New(t).Elem() - } +var ( + marshalerType = reflect.TypeOf(new(Marshaler)).Elem() + textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +) - _, ok := vx.Interface().(Marshaler) - if ok { +// newTypeEncoder constructs an encoderFunc for a type. +// The returned encoder only checks CanAddr when allowAddr is true. +func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { + if t.Implements(marshalerType) { return marshalerEncoder } - if vx.Kind() != reflect.Ptr && vx.CanAddr() { - _, ok = vx.Addr().Interface().(Marshaler) - if ok { - return addrMarshalerEncoder + if t.Kind() != reflect.Ptr && allowAddr { + if reflect.PtrTo(t).Implements(marshalerType) { + return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } } - _, ok = vx.Interface().(encoding.TextMarshaler) - if ok { + if t.Implements(textMarshalerType) { return textMarshalerEncoder } - if vx.Kind() != reflect.Ptr && vx.CanAddr() { - _, ok = vx.Addr().Interface().(encoding.TextMarshaler) - if ok { - return addrTextMarshalerEncoder + if t.Kind() != reflect.Ptr && allowAddr { + if reflect.PtrTo(t).Implements(textMarshalerType) { + return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) } } - switch vx.Kind() { + switch t.Kind() { case reflect.Bool: return boolEncoder case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -401,15 +395,15 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { case reflect.Interface: return interfaceEncoder case reflect.Struct: - return newStructEncoder(t, vx) + return newStructEncoder(t) case reflect.Map: - return newMapEncoder(t, vx) + return newMapEncoder(t) case reflect.Slice: - return newSliceEncoder(t, vx) + return newSliceEncoder(t) case reflect.Array: - return newArrayEncoder(t, vx) + return newArrayEncoder(t) case reflect.Ptr: - return newPtrEncoder(t, vx) + return newPtrEncoder(t) default: return unsupportedTypeEncoder } @@ -593,27 +587,19 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { } e.string(f.name) e.WriteByte(':') - if tenc := se.fieldEncs[i]; tenc != nil { - tenc(e, fv, f.quoted) - } else { - // Slower path. - e.reflectValue(fv) - } + se.fieldEncs[i](e, fv, f.quoted) } e.WriteByte('}') } -func newStructEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func newStructEncoder(t reflect.Type) encoderFunc { fields := cachedTypeFields(t) se := &structEncoder{ fields: fields, fieldEncs: make([]encoderFunc, len(fields)), } for i, f := range fields { - vxf := fieldByIndex(vx, f.index) - if vxf.IsValid() { - se.fieldEncs[i] = typeEncoder(vxf.Type(), vxf) - } + se.fieldEncs[i] = typeEncoder(typeByIndex(t, f.index)) } return se.encode } @@ -641,11 +627,11 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { e.WriteByte('}') } -func newMapEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func newMapEncoder(t reflect.Type) encoderFunc { if t.Key().Kind() != reflect.String { return unsupportedTypeEncoder } - me := &mapEncoder{typeEncoder(vx.Type().Elem(), reflect.Value{})} + me := &mapEncoder{typeEncoder(t.Elem())} return me.encode } @@ -684,12 +670,12 @@ func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) { se.arrayEnc(e, v, false) } -func newSliceEncoder(t reflect.Type, vx reflect.Value) encoderFunc { +func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. - if vx.Type().Elem().Kind() == reflect.Uint8 { + if t.Elem().Kind() == reflect.Uint8 { return encodeByteSlice } - enc := &sliceEncoder{newArrayEncoder(t, vx)} + enc := &sliceEncoder{newArrayEncoder(t)} return enc.encode } @@ -709,8 +695,8 @@ func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) { e.WriteByte(']') } -func newArrayEncoder(t reflect.Type, vx reflect.Value) encoderFunc { - enc := &arrayEncoder{typeEncoder(t.Elem(), reflect.Value{})} +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := &arrayEncoder{typeEncoder(t.Elem())} return enc.encode } @@ -726,8 +712,27 @@ func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, _ bool) { pe.elemEnc(e, v.Elem(), false) } -func newPtrEncoder(t reflect.Type, vx reflect.Value) encoderFunc { - enc := &ptrEncoder{typeEncoder(t.Elem(), reflect.Value{})} +func newPtrEncoder(t reflect.Type) encoderFunc { + enc := &ptrEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type condAddrEncoder struct { + canAddrEnc, elseEnc encoderFunc +} + +func (ce *condAddrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) { + if v.CanAddr() { + ce.canAddrEnc(e, v, quoted) + } else { + ce.elseEnc(e, v, quoted) + } +} + +// newCondAddrEncoder returns an encoder that checks whether its value +// CanAddr and delegates to canAddrEnc if so, else to elseEnc. +func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { + enc := &condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} return enc.encode } @@ -763,6 +768,16 @@ func fieldByIndex(v reflect.Value, index []int) reflect.Value { return v } +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, i := range index { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + t = t.Field(i).Type + } + return t +} + // stringValues is a slice of reflect.Value holding *reflect.StringValue. // It implements the methods to sort by string. type stringValues []reflect.Value diff --git a/src/pkg/encoding/json/encode_test.go b/src/pkg/encoding/json/encode_test.go index 7052e1db7c..9395db7cb6 100644 --- a/src/pkg/encoding/json/encode_test.go +++ b/src/pkg/encoding/json/encode_test.go @@ -401,3 +401,27 @@ func TestStringBytes(t *testing.T) { t.Errorf("encodings differ at %#q vs %#q", enc, encBytes) } } + +func TestIssue6458(t *testing.T) { + type Foo struct { + M RawMessage + } + x := Foo{RawMessage(`"foo"`)} + + b, err := Marshal(&x) + if err != nil { + t.Fatal(err) + } + if want := `{"M":"foo"}`; string(b) != want { + t.Errorf("Marshal(&x) = %#q; want %#q", b, want) + } + + b, err = Marshal(x) + if err != nil { + t.Fatal(err) + } + + if want := `{"M":"ImZvbyI="}`; string(b) != want { + t.Errorf("Marshal(x) = %#q; want %#q", b, want) + } +}