encoding/gob: cover missed cases when checking ignore depth

This change makes sure that we are properly checking the ignored field
recursion depth in decIgnoreOpFor consistently. This prevents stack
exhaustion when attempting to decode a message that contains an
extremely deeply nested struct which is ignored.

Thanks to Md Sakib Anwar of The Ohio State University (anwar.40@osu.edu)
for reporting this issue.

Fixes #69139
Fixes CVE-2024-34156

Change-Id: Iacce06be95a5892b3064f1c40fcba2e2567862d6
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/1440
Reviewed-by: Russ Cox <rsc@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-on: https://go-review.googlesource.com/c/go/+/611239
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
This commit is contained in:
Roland Shoemaker 2024-05-03 09:21:39 -04:00 committed by Gopher Robot
parent dd2019528b
commit 08c84420bc
3 changed files with 27 additions and 8 deletions

View File

@ -911,8 +911,11 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
var maxIgnoreNestingDepth = 10000
// decIgnoreOpFor returns the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp {
if depth > maxIgnoreNestingDepth {
func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp {
// Track how deep we've recursed trying to skip nested ignored fields.
dec.ignoreDepth++
defer func() { dec.ignoreDepth-- }()
if dec.ignoreDepth > maxIgnoreNestingDepth {
error_(errors.New("invalid nesting depth"))
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
@ -938,7 +941,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp,
errorf("bad data: undefined type %s", wireId.string())
case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len)
}
@ -946,15 +949,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp,
case wire.MapT != nil:
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1)
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
keyOp := dec.decIgnoreOpFor(keyId, inProgress)
elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreMap(state, *keyOp, *elemOp)
}
case wire.SliceT != nil:
elemId := wire.SliceT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreSlice(state, *elemOp)
}
@ -1115,7 +1118,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *de
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine {
engine := new(decEngine)
engine.instr = make([]decInstr, 1) // one item
op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0)
op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp))
ovfl := overflow(dec.typeString(remoteId))
engine.instr[0] = decInstr{*op, 0, nil, ovfl}
engine.numInstr = 1
@ -1160,7 +1163,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn
localField, present := srt.FieldByName(wireField.Name)
// TODO(r): anonymous names
if !present || !isExported(wireField.Name) {
op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0)
op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp))
engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl}
continue
}

View File

@ -35,6 +35,8 @@ type Decoder struct {
freeList *decoderState // list of free decoderStates; avoids reallocation
countBuf []byte // used for decoding integers while parsing messages
err error
// ignoreDepth tracks the depth of recursively parsed ignored fields
ignoreDepth int
}
// NewDecoder returns a new decoder that reads from the [io.Reader].

View File

@ -806,6 +806,8 @@ func TestIgnoreDepthLimit(t *testing.T) {
defer func() { maxIgnoreNestingDepth = oldNestingDepth }()
b := new(bytes.Buffer)
enc := NewEncoder(b)
// Nested slice
typ := reflect.TypeFor[int]()
nested := reflect.ArrayOf(1, typ)
for i := 0; i < 100; i++ {
@ -819,4 +821,16 @@ func TestIgnoreDepthLimit(t *testing.T) {
if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
}
// Nested struct
nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: typ}})
for i := 0; i < 100; i++ {
nested = reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}})
}
badStruct = reflect.New(reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}}))
enc.Encode(badStruct.Interface())
dec = NewDecoder(b)
if err := dec.Decode(&output); err == nil || err.Error() != expectedErr {
t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err)
}
}