diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 55aee9b6ff..51ef46c7e7 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -367,7 +367,24 @@ func (subst *subster) node(n ir.Node) ir.Node { } ir.EditChildren(m, edit) - if x.Op() == ir.OXDOT { + switch x.Op() { + case ir.OLITERAL: + t := m.Type() + if t != x.Type() { + // types2 will give us a constant with a type T, + // if an untyped constant is used with another + // operand of type T (in a provably correct way). + // When we substitute in the type args during + // stenciling, we now know the real type of the + // constant. We may then need to change the + // BasicLit.val to be the correct type (e.g. + // convert an int64Val constant to a floatVal + // constant). + m.SetType(types.UntypedInt) // use any untyped type for DefaultLit to work + m = typecheck.DefaultLit(m, t) + } + + case ir.OXDOT: // A method value/call via a type param will have been left as an // OXDOT. When we see this during stenciling, finish the // typechecking, now that we have the instantiated receiver type. @@ -377,8 +394,8 @@ func (subst *subster) node(n ir.Node) ir.Node { m.SetTypecheck(0) // m will transform to an OCALLPART typecheck.Expr(m) - } - if x.Op() == ir.OCALL { + + case ir.OCALL: call := m.(*ir.CallExpr) if call.X.Op() == ir.OTYPE { // Do typechecking on a conversion, now that we @@ -419,9 +436,8 @@ func (subst *subster) node(n ir.Node) ir.Node { // instantiation to be called. base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST or builtin with CALL") } - } - if x.Op() == ir.OCLOSURE { + case ir.OCLOSURE: x := x.(*ir.ClosureExpr) // Need to save/duplicate x.Func.Nname, // x.Func.Nname.Ntype, x.Func.Dcl, x.Func.ClosureVars, and diff --git a/test/typeparam/double.go b/test/typeparam/double.go index 1f7a26c7f4..ce78ec9748 100644 --- a/test/typeparam/double.go +++ b/test/typeparam/double.go @@ -16,6 +16,7 @@ type Number interface { } type MySlice []int +type MyFloatSlice []float64 type _SliceOf[E any] interface { type []E @@ -29,6 +30,15 @@ func _DoubleElems[S _SliceOf[E], E Number](s S) S { return r } +// Test use of untyped constant in an expression with a generically-typed parameter +func _DoubleElems2[S _SliceOf[E], E Number](s S) S { + r := make(S, len(s)) + for i, v := range s { + r[i] = v * 2 + } + return r +} + func main() { arg := MySlice{1, 2, 3} want := MySlice{2, 4, 6} @@ -47,4 +57,16 @@ func main() { if !reflect.DeepEqual(got, want) { panic(fmt.Sprintf("got %s, want %s", got, want)) } + + farg := MyFloatSlice{1.2, 2.0, 3.5} + fwant := MyFloatSlice{2.4, 4.0, 7.0} + fgot := _DoubleElems(farg) + if !reflect.DeepEqual(fgot, fwant) { + panic(fmt.Sprintf("got %s, want %s", fgot, fwant)) + } + + fgot = _DoubleElems2(farg) + if !reflect.DeepEqual(fgot, fwant) { + panic(fmt.Sprintf("got %s, want %s", fgot, fwant)) + } } diff --git a/test/typeparam/list.go b/test/typeparam/list.go index 64230060de..579078f02f 100644 --- a/test/typeparam/list.go +++ b/test/typeparam/list.go @@ -17,13 +17,13 @@ type Ordered interface { string } -// List is a linked list of ordered values of type T. -type list[T Ordered] struct { - next *list[T] +// _List is a linked list of ordered values of type T. +type _List[T Ordered] struct { + next *_List[T] val T } -func (l *list[T]) largest() T { +func (l *_List[T]) Largest() T { var max T for p := l; p != nil; p = p.next { if p.val > max { @@ -33,33 +33,71 @@ func (l *list[T]) largest() T { return max } +type OrderedNum interface { + type int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, uintptr, + float32, float64 +} + +// _ListNum is a linked _List of ordered numeric values of type T. +type _ListNum[T OrderedNum] struct { + next *_ListNum[T] + val T +} + +const Clip = 5 + +// clippedLargest returns the largest in the list of OrderNums, but a max of 5. +// Test use of untyped constant in an expression with a generically-typed parameter +func (l *_ListNum[T]) ClippedLargest() T { + var max T + for p := l; p != nil; p = p.next { + if p.val > max && p.val < Clip { + max = p.val + } + } + return max +} func main() { - i3 := &list[int]{nil, 1} - i2 := &list[int]{i3, 3} - i1 := &list[int]{i2, 2} - if got, want := i1.largest(), 3; got != want { + i3 := &_List[int]{nil, 1} + i2 := &_List[int]{i3, 3} + i1 := &_List[int]{i2, 2} + if got, want := i1.Largest(), 3; got != want { panic(fmt.Sprintf("got %d, want %d", got, want)) } - b3 := &list[byte]{nil, byte(1)} - b2 := &list[byte]{b3, byte(3)} - b1 := &list[byte]{b2, byte(2)} - if got, want := b1.largest(), byte(3); got != want { + b3 := &_List[byte]{nil, byte(1)} + b2 := &_List[byte]{b3, byte(3)} + b1 := &_List[byte]{b2, byte(2)} + if got, want := b1.Largest(), byte(3); got != want { panic(fmt.Sprintf("got %d, want %d", got, want)) } - f3 := &list[float64]{nil, 13.5} - f2 := &list[float64]{f3, 1.2} - f1 := &list[float64]{f2, 4.5} - if got, want := f1.largest(), 13.5; got != want { + f3 := &_List[float64]{nil, 13.5} + f2 := &_List[float64]{f3, 1.2} + f1 := &_List[float64]{f2, 4.5} + if got, want := f1.Largest(), 13.5; got != want { panic(fmt.Sprintf("got %f, want %f", got, want)) } - s3 := &list[string]{nil, "dd"} - s2 := &list[string]{s3, "aa"} - s1 := &list[string]{s2, "bb"} - if got, want := s1.largest(), "dd"; got != want { + s3 := &_List[string]{nil, "dd"} + s2 := &_List[string]{s3, "aa"} + s1 := &_List[string]{s2, "bb"} + if got, want := s1.Largest(), "dd"; got != want { panic(fmt.Sprintf("got %s, want %s", got, want)) } + + j3 := &_ListNum[int]{nil, 1} + j2 := &_ListNum[int]{j3, 32} + j1 := &_ListNum[int]{j2, 2} + if got, want := j1.ClippedLargest(), 2; got != want { + panic(fmt.Sprintf("got %d, want %d", got, want)) + } + g3 := &_ListNum[float64]{nil, 13.5} + g2 := &_ListNum[float64]{g3, 1.2} + g1 := &_ListNum[float64]{g2, 4.5} + if got, want := g1.ClippedLargest(), 4.5; got != want { + panic(fmt.Sprintf("got %f, want %f", got, want)) + } }