diff --git a/src/cmd/compile/internal/ir/stmt.go b/src/cmd/compile/internal/ir/stmt.go index 6a82df58f8..10f8b5e394 100644 --- a/src/cmd/compile/internal/ir/stmt.go +++ b/src/cmd/compile/internal/ir/stmt.go @@ -170,6 +170,17 @@ type CaseClause struct { miniStmt Var *Name // declared variable for this case in type switch List Nodes // list of expressions for switch, early select + + // RTypes is a list of RType expressions, which are copied to the + // corresponding OEQ nodes that are emitted when switch statements + // are desugared. RTypes[i] must be non-nil if the emitted + // comparison for List[i] will be a mixed interface/concrete + // comparison; see reflectdata.CompareRType for details. + // + // Because mixed interface/concrete switch cases are rare, we allow + // len(RTypes) < len(List). Missing entries are implicitly nil. + RTypes Nodes + Body Nodes } diff --git a/src/cmd/compile/internal/noder/reader.go b/src/cmd/compile/internal/noder/reader.go index 32276e7553..aa2cccf86b 100644 --- a/src/cmd/compile/internal/noder/reader.go +++ b/src/cmd/compile/internal/noder/reader.go @@ -1487,7 +1487,7 @@ func (r *reader) switchStmt(label *types.Sym) ir.Node { r.openScope() pos := r.pos() - var cases []ir.Node + var cases, rtypes []ir.Node if iface != nil { cases = make([]ir.Node, r.Len()) if len(cases) == 0 { @@ -1498,9 +1498,30 @@ func (r *reader) switchStmt(label *types.Sym) ir.Node { } } else { cases = r.exprList() + + tagType := types.Types[types.TBOOL] + if tag != nil { + tagType = tag.Type() + } + for i, cas := range cases { + if cas.Op() == ir.ONIL { + continue // never needs rtype + } + if tagType.IsInterface() != cas.Type().IsInterface() { + typ := tagType + if typ.IsInterface() { + typ = cas.Type() + } + for len(rtypes) < i { + rtypes = append(rtypes, nil) + } + rtypes = append(rtypes, reflectdata.TypePtr(typ)) + } + } } clause := ir.NewCaseStmt(pos, cases, nil) + clause.RTypes = rtypes if ident != nil { pos := r.pos() diff --git a/src/cmd/compile/internal/reflectdata/helpers.go b/src/cmd/compile/internal/reflectdata/helpers.go index 66f1864474..5edb495a81 100644 --- a/src/cmd/compile/internal/reflectdata/helpers.go +++ b/src/cmd/compile/internal/reflectdata/helpers.go @@ -84,9 +84,7 @@ func AppendElemRType(pos src.XPos, n *ir.CallExpr) ir.Node { func CompareRType(pos src.XPos, n *ir.BinaryExpr) ir.Node { assertOp2(n, ir.OEQ, ir.ONE) base.AssertfAt(n.X.Type().IsInterface() != n.Y.Type().IsInterface(), n.Pos(), "expect mixed interface and non-interface, have %L and %L", n.X, n.Y) - // TODO(mdempsky): Need to propagate RType from OSWITCH/OCASE - // clauses to emitted OEQ nodes. - if haveRType(n, n.RType, "RType", false) { + if haveRType(n, n.RType, "RType", true) { return n.RType } typ := n.X.Type() diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go index 6cac8f2937..257903c0b3 100644 --- a/src/cmd/compile/internal/walk/switch.go +++ b/src/cmd/compile/internal/walk/switch.go @@ -85,8 +85,12 @@ func walkSwitchExpr(sw *ir.SwitchStmt) { defaultGoto = jmp } - for _, n1 := range ncase.List { - s.Add(ncase.Pos(), n1, jmp) + for i, n1 := range ncase.List { + var rtype ir.Node + if i < len(ncase.RTypes) { + rtype = ncase.RTypes[i] + } + s.Add(ncase.Pos(), n1, rtype, jmp) } // Process body. @@ -124,11 +128,12 @@ type exprSwitch struct { type exprClause struct { pos src.XPos lo, hi ir.Node + rtype ir.Node // *runtime._type for OEQ node jmp ir.Node } -func (s *exprSwitch) Add(pos src.XPos, expr, jmp ir.Node) { - c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp} +func (s *exprSwitch) Add(pos src.XPos, expr, rtype, jmp ir.Node) { + c := exprClause{pos: pos, lo: expr, hi: expr, rtype: rtype, jmp: jmp} if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL { s.clauses = append(s.clauses, c) return @@ -233,7 +238,7 @@ func (s *exprSwitch) flush() { // Add length case to outer switch. cas := ir.NewBasicLit(pos, constant.MakeInt64(runLen(run))) jmp := ir.NewBranchStmt(pos, ir.OGOTO, label) - outer.Add(pos, cas, jmp) + outer.Add(pos, cas, nil, jmp) } s.done.Append(ir.NewLabelStmt(s.pos, outerLabel)) outer.Emit(&s.done) @@ -342,7 +347,9 @@ func (c *exprClause) test(exprname ir.Node) ir.Node { } } - return ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo) + n := ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo) + n.RType = c.rtype + return n } func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {