diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 8001d6d398..071a2f44c2 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -553,7 +553,25 @@ func (subst *subster) typ(t *types.Type) *types.Type { return subst.targs[i].Type() } } - return t + // If t is a simple typeparam T, then t has the name/symbol 'T' + // and t.Underlying() == t. + // + // However, consider the type definition: 'type P[T any] T'. We + // might use this definition so we can have a variant of type T + // that we can add new methods to. Suppose t is a reference to + // P[T]. t has the name 'P[T]', but its kind is TTYPEPARAM, + // because P[T] is defined as T. If we look at t.Underlying(), it + // is different, because the name of t.Underlying() is 'T' rather + // than 'P[T]'. But the kind of t.Underlying() is also TTYPEPARAM. + // In this case, we do the needed recursive substitution in the + // case statement below. + if t.Underlying() == t { + // t is a simple typeparam that didn't match anything in tparam + return t + } + // t is a more complex typeparam (e.g. P[T], as above, whose + // definition is just T). + assert(t.Sym() != nil) } var newsym *types.Sym @@ -591,6 +609,15 @@ func (subst *subster) typ(t *types.Type) *types.Type { var newt *types.Type switch t.Kind() { + case types.TTYPEPARAM: + if t.Sym() == newsym { + // The substitution did not change the type. + return t + } + // Substitute the underlying typeparam (e.g. T in P[T], see + // the example describing type P[T] above). + newt = subst.typ(t.Underlying()) + assert(newt != t) case types.TARRAY: elem := t.Elem() diff --git a/src/cmd/compile/internal/noder/types.go b/src/cmd/compile/internal/noder/types.go index dfcf55d9c8..96bf75d594 100644 --- a/src/cmd/compile/internal/noder/types.go +++ b/src/cmd/compile/internal/noder/types.go @@ -180,11 +180,18 @@ func (g *irgen) typ0(typ types2.Type) *types.Type { return types.NewInterface(g.tpkg(typ), append(embeddeds, methods...)) case *types2.TypeParam: - tp := types.NewTypeParam(g.tpkg(typ), g.typ(typ.Bound())) + tp := types.NewTypeParam(g.tpkg(typ)) // Save the name of the type parameter in the sym of the type. // Include the types2 subscript in the sym name sym := g.pkg(typ.Obj().Pkg()).Lookup(types2.TypeString(typ, func(*types2.Package) string { return "" })) tp.SetSym(sym) + // Set g.typs[typ] in case the bound methods reference typ. + g.typs[typ] = tp + + // TODO(danscales): we don't currently need to use the bounds + // anywhere, so eventually we can probably remove. + bound := g.typ(typ.Bound()) + *tp.Methods() = *bound.Methods() return tp case *types2.Tuple: diff --git a/src/cmd/compile/internal/types/fmt.go b/src/cmd/compile/internal/types/fmt.go index c59f62e302..e29c826bb7 100644 --- a/src/cmd/compile/internal/types/fmt.go +++ b/src/cmd/compile/internal/types/fmt.go @@ -318,7 +318,7 @@ func tconv2(b *bytes.Buffer, t *Type, verb rune, mode fmtMode, visited map[*Type } // Unless the 'L' flag was specified, if the type has a name, just print that name. - if verb != 'L' && t.Sym() != nil && t != Types[t.Kind()] && t.Kind() != TTYPEPARAM { + if verb != 'L' && t.Sym() != nil && t != Types[t.Kind()] { switch mode { case fmtTypeID, fmtTypeIDName: if verb == 'S' { diff --git a/src/cmd/compile/internal/types/type.go b/src/cmd/compile/internal/types/type.go index d76d9b409f..ffaf755345 100644 --- a/src/cmd/compile/internal/types/type.go +++ b/src/cmd/compile/internal/types/type.go @@ -1742,12 +1742,9 @@ func NewInterface(pkg *Pkg, methods []*Field) *Type { return t } -// NewTypeParam returns a new type param with the given constraint (which may -// not really be needed except for the type checker). -func NewTypeParam(pkg *Pkg, constraint *Type) *Type { +// NewTypeParam returns a new type param. +func NewTypeParam(pkg *Pkg) *Type { t := New(TTYPEPARAM) - constraint.wantEtype(TINTER) - t.methods = constraint.methods t.Extra.(*Interface).pkg = pkg t.SetHasTParam(true) return t diff --git a/test/typeparam/absdiff.go b/test/typeparam/absdiff.go new file mode 100644 index 0000000000..5dd58f14f7 --- /dev/null +++ b/test/typeparam/absdiff.go @@ -0,0 +1,99 @@ +// run -gcflags=-G=3 + +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + //"math" +) + +type Numeric interface { + type int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, uintptr, + float32, float64, + complex64, complex128 +} + +// numericAbs matches numeric types with an Abs method. +type numericAbs[T any] interface { + Numeric + Abs() T +} + +// AbsDifference computes the absolute value of the difference of +// a and b, where the absolute value is determined by the Abs method. +func absDifference[T numericAbs[T]](a, b T) T { + d := a - b + return d.Abs() +} + +// orderedNumeric matches numeric types that support the < operator. +type orderedNumeric interface { + type int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, uintptr, + float32, float64 +} + +// Complex matches the two complex types, which do not have a < operator. +type Complex interface { + type complex64, complex128 +} + +// orderedAbs is a helper type that defines an Abs method for +// ordered numeric types. +type orderedAbs[T orderedNumeric] T + +func (a orderedAbs[T]) Abs() orderedAbs[T] { + // TODO(danscales): orderedAbs[T] conversion shouldn't be needed + if a < orderedAbs[T](0) { + return -a + } + return a +} + +// complexAbs is a helper type that defines an Abs method for +// complex types. +// type complexAbs[T Complex] T + +// func (a complexAbs[T]) Abs() complexAbs[T] { +// r := float64(real(a)) +// i := float64(imag(a)) +// d := math.Sqrt(r * r + i * i) +// return complexAbs[T](complex(d, 0)) +// } + +// OrderedAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of an ordered type. +func orderedAbsDifference[T orderedNumeric](a, b T) T { + return T(absDifference(orderedAbs[T](a), orderedAbs[T](b))) +} + +// ComplexAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of a complex type. +// func complexAbsDifference[T Complex](a, b T) T { +// return T(absDifference(complexAbs[T](a), complexAbs[T](b))) +// } + +func main() { + if got, want := orderedAbsDifference(1.0, -2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := orderedAbsDifference(-1.0, 2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := orderedAbsDifference(-20, 15), 35; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + + // Still have to handle built-ins real/abs to make this work + // if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5; got != want { + // panic(fmt.Sprintf("got = %v, want = %v", got, want) + // } + // if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5; got != want { + // panic(fmt.Sprintf("got = %v, want = %v", got, want) + // } +}