diff --git a/src/cmd/compile/internal/ir/stmt.go b/src/cmd/compile/internal/ir/stmt.go index 81d139cf12..0801ecdd9e 100644 --- a/src/cmd/compile/internal/ir/stmt.go +++ b/src/cmd/compile/internal/ir/stmt.go @@ -313,17 +313,19 @@ func NewJumpTableStmt(pos src.XPos, idx Node) *JumpTableStmt { // An InterfaceSwitchStmt is used to implement type switches. // Its semantics are: // -// if RuntimeType implements Descriptor.Cases[0] { -// Case, Itab = 0, itab -// } else if RuntimeType implements Descriptor.Cases[1] { -// Case, Itab = 1, itab -// ... -// } else if RuntimeType implements Descriptor.Cases[N-1] { -// Case, Itab = N-1, itab -// } else { -// Case, Itab = len(cases), nil -// } +// if RuntimeType implements Descriptor.Cases[0] { +// Case, Itab = 0, itab +// } else if RuntimeType implements Descriptor.Cases[1] { +// Case, Itab = 1, itab +// ... +// } else if RuntimeType implements Descriptor.Cases[N-1] { +// Case, Itab = N-1, itab +// } else { +// Case, Itab = len(cases), nil +// } +// // RuntimeType must be a non-nil *runtime._type. +// Hash must be the hash field of RuntimeType (or its copy loaded from an itab). // Descriptor must represent an abi.InterfaceSwitch global variable. type InterfaceSwitchStmt struct { miniStmt @@ -331,14 +333,16 @@ type InterfaceSwitchStmt struct { Case Node Itab Node RuntimeType Node + Hash Node Descriptor *obj.LSym } -func NewInterfaceSwitchStmt(pos src.XPos, case_, itab, runtimeType Node, descriptor *obj.LSym) *InterfaceSwitchStmt { +func NewInterfaceSwitchStmt(pos src.XPos, case_, itab, runtimeType, hash Node, descriptor *obj.LSym) *InterfaceSwitchStmt { n := &InterfaceSwitchStmt{ Case: case_, Itab: itab, RuntimeType: runtimeType, + Hash: hash, Descriptor: descriptor, } n.pos = pos diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go index af2e0e477e..e8f0f561d0 100644 --- a/src/cmd/compile/internal/ssagen/ssa.go +++ b/src/cmd/compile/internal/ssagen/ssa.go @@ -2025,6 +2025,7 @@ func (s *state) stmt(n ir.Node) { typs := s.f.Config.Types t := s.expr(n.RuntimeType) + h := s.expr(n.Hash) d := s.newValue1A(ssa.OpAddr, typs.BytePtr, n.Descriptor, s.sb) // Check the cache first. @@ -2061,10 +2062,9 @@ func (s *state) stmt(n ir.Node) { cache := s.newValue1(ssa.OpSelect0, typs.BytePtr, atomicLoad) s.vars[memVar] = s.newValue1(ssa.OpSelect1, types.TypeMem, atomicLoad) - // Load hash from type. - hash := s.newValue2(ssa.OpLoad, typs.UInt32, s.newValue1I(ssa.OpOffPtr, typs.UInt32Ptr, 2*s.config.PtrSize, t), s.mem()) - hash = s.newValue1(zext, typs.Uintptr, hash) - s.vars[hashVar] = hash + // Initialize hash variable. + s.vars[hashVar] = s.newValue1(zext, typs.Uintptr, h) + // Load mask from cache. mask := s.newValue2(ssa.OpLoad, typs.Uintptr, cache, s.mem()) // Jump to loop head. @@ -6703,8 +6703,13 @@ func (s *state) dottype1(pos src.XPos, src, dst *types.Type, iface, source, targ cache := s.newValue1(ssa.OpSelect0, typs.BytePtr, atomicLoad) s.vars[memVar] = s.newValue1(ssa.OpSelect1, types.TypeMem, atomicLoad) - // Load hash from type. - hash := s.newValue2(ssa.OpLoad, typs.UInt32, s.newValue1I(ssa.OpOffPtr, typs.UInt32Ptr, 2*s.config.PtrSize, typ), s.mem()) + // Load hash from type or itab. + var hash *ssa.Value + if src.IsEmptyInterface() { + hash = s.newValue2(ssa.OpLoad, typs.UInt32, s.newValue1I(ssa.OpOffPtr, typs.UInt32Ptr, 2*s.config.PtrSize, typ), s.mem()) + } else { + hash = s.newValue2(ssa.OpLoad, typs.UInt32, s.newValue1I(ssa.OpOffPtr, typs.UInt32Ptr, 2*s.config.PtrSize, itab), s.mem()) + } hash = s.newValue1(zext, typs.Uintptr, hash) s.vars[hashVar] = hash // Load mask from cache. diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go index aa04700088..1aff97d7fd 100644 --- a/src/cmd/compile/internal/walk/switch.go +++ b/src/cmd/compile/internal/walk/switch.go @@ -547,7 +547,7 @@ func walkSwitchType(sw *ir.SwitchStmt) { typeArg = itabType(srcItab) } caseVar := typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TINT]) - isw := ir.NewInterfaceSwitchStmt(base.Pos, caseVar, s.itabName, typeArg, lsym) + isw := ir.NewInterfaceSwitchStmt(base.Pos, caseVar, s.itabName, typeArg, dotHash, lsym) sw.Compiled.Append(isw) // Switch on the result of the call (or cache lookup). diff --git a/test/codegen/switch.go b/test/codegen/switch.go index 4103bf5297..980ea70561 100644 --- a/test/codegen/switch.go +++ b/test/codegen/switch.go @@ -129,11 +129,27 @@ type IJ interface { I J } +type K interface { + baz() +} // use a runtime call for type switches to interface types. func interfaceSwitch(x any) int { - // amd64:`CALL\truntime.interfaceSwitch`,`MOVL\t16\(.*\)`,`MOVQ\t8\(.*\)(.*\*8)` - // arm64:`CALL\truntime.interfaceSwitch`,`LDAR`,`MOVWU`,`MOVD\t\(R.*\)\(R.*\)` + // amd64:`CALL\truntime.interfaceSwitch`,`MOVL\t16\(AX\)`,`MOVQ\t8\(.*\)(.*\*8)` + // arm64:`CALL\truntime.interfaceSwitch`,`LDAR`,`MOVWU\t16\(R0\)`,`MOVD\t\(R.*\)\(R.*\)` + switch x.(type) { + case I: + return 1 + case J: + return 2 + default: + return 3 + } +} + +func interfaceSwitch2(x K) int { + // amd64:`CALL\truntime.interfaceSwitch`,`MOVL\t16\(AX\)`,`MOVQ\t8\(.*\)(.*\*8)` + // arm64:`CALL\truntime.interfaceSwitch`,`LDAR`,`MOVWU\t16\(R0\)`,`MOVD\t\(R.*\)\(R.*\)` switch x.(type) { case I: return 1 @@ -145,8 +161,17 @@ func interfaceSwitch(x any) int { } func interfaceCast(x any) int { - // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(.*\)`,`MOVQ\t8\(.*\)(.*\*1)` - // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU`,`MOVD\t\(R.*\)\(R.*\)` + // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(AX\)`,`MOVQ\t8\(.*\)(.*\*1)` + // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU\t16\(R0\)`,`MOVD\t\(R.*\)\(R.*\)` + if _, ok := x.(I); ok { + return 3 + } + return 5 +} + +func interfaceCast2(x K) int { + // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(AX\)`,`MOVQ\t8\(.*\)(.*\*1)` + // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU\t16\(R0\)`,`MOVD\t\(R.*\)\(R.*\)` if _, ok := x.(I); ok { return 3 } @@ -154,7 +179,7 @@ func interfaceCast(x any) int { } func interfaceConv(x IJ) I { - // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(.*\)`,`MOVQ\t8\(.*\)(.*\*1)` - // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU`,`MOVD\t\(R.*\)\(R.*\)` + // amd64:`CALL\truntime.typeAssert`,`MOVL\t16\(AX\)`,`MOVQ\t8\(.*\)(.*\*1)` + // arm64:`CALL\truntime.typeAssert`,`LDAR`,`MOVWU\t16\(R0\)`,`MOVD\t\(R.*\)\(R.*\)` return x }