diff --git a/src/cmd/compile/internal/noder/decl.go b/src/cmd/compile/internal/noder/decl.go index df1ca1c505..a9522d09af 100644 --- a/src/cmd/compile/internal/noder/decl.go +++ b/src/cmd/compile/internal/noder/decl.go @@ -133,12 +133,20 @@ func (g *irgen) funcDecl(out *ir.Nodes, decl *syntax.FuncDecl) { g.target.Inits = append(g.target.Inits, fn) } - haveEmbed := g.haveEmbed + saveHaveEmbed := g.haveEmbed + saveCurDecl := g.curDecl g.curDecl = "" g.later(func() { - defer func(b bool) { g.haveEmbed = b }(g.haveEmbed) + defer func(b bool, s string) { + // Revert haveEmbed and curDecl back to what they were before + // the "later" function. + g.haveEmbed = b + g.curDecl = s + }(g.haveEmbed, g.curDecl) - g.haveEmbed = haveEmbed + // Set haveEmbed and curDecl to what they were for this funcDecl. + g.haveEmbed = saveHaveEmbed + g.curDecl = saveCurDecl if fn.Type().HasTParam() { g.topFuncIsGeneric = true } @@ -162,9 +170,10 @@ func (g *irgen) funcDecl(out *ir.Nodes, decl *syntax.FuncDecl) { func (g *irgen) typeDecl(out *ir.Nodes, decl *syntax.TypeDecl) { // Set the position for any error messages we might print (e.g. too large types). base.Pos = g.pos(decl) - assert(g.curDecl == "") + assert(ir.CurFunc != nil || g.curDecl == "") // Set g.curDecl to the type name, as context for the type params declared // during types2-to-types1 translation if this is a generic type. + saveCurDecl := g.curDecl g.curDecl = decl.Name.Value if decl.Alias { name, _ := g.def(decl.Name) @@ -225,7 +234,7 @@ func (g *irgen) typeDecl(out *ir.Nodes, decl *syntax.TypeDecl) { } types.ResumeCheckSize() - g.curDecl = "" + g.curDecl = saveCurDecl if otyp, ok := otyp.(*types2.Named); ok && otyp.NumMethods() != 0 { methods := make([]*types.Field, otyp.NumMethods()) for i := range methods { diff --git a/test/run.go b/test/run.go index 2a7f080f9d..0e35ed2c0f 100644 --- a/test/run.go +++ b/test/run.go @@ -2180,6 +2180,8 @@ var unifiedFailures = setOf( "typeparam/typeswitch4.go", // duplicate case failure due to stenciling "typeparam/issue50417b.go", // Need to handle field access on a type param "typeparam/issue50552.go", // gives missing method for instantiated type + "typeparam/absdiff2.go", // wrong assertion about closure variables + "typeparam/absdiffimp2.go", // wrong assertion about closure variables ) func setOf(keys ...string) map[string]bool { diff --git a/test/typeparam/absdiff2.go b/test/typeparam/absdiff2.go new file mode 100644 index 0000000000..8f13bad2b6 --- /dev/null +++ b/test/typeparam/absdiff2.go @@ -0,0 +1,102 @@ +// run -gcflags=-G=3 + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "math" +) + +type Numeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~complex64 | ~complex128 +} + +// numericAbs matches a struct containing a numeric type that has an Abs method. +type numericAbs[T Numeric] interface { + ~struct{ Value T } + Abs() T +} + +// AbsDifference computes the absolute value of the difference of +// a and b, where the absolute value is determined by the Abs method. +func absDifference[T Numeric, U numericAbs[T]](a, b U) T { + d := a.Value - b.Value + dt := U{Value: d} + return dt.Abs() +} + +// orderedNumeric matches numeric types that support the < operator. +type orderedNumeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 +} + +// Complex matches the two complex types, which do not have a < operator. +type Complex interface { + ~complex64 | ~complex128 +} + +// orderedAbs is a helper type that defines an Abs method for +// a struct containing an ordered numeric type. +type orderedAbs[T orderedNumeric] struct { + Value T +} + +func (a orderedAbs[T]) Abs() T { + if a.Value < 0 { + return -a.Value + } + return a.Value +} + +// complexAbs is a helper type that defines an Abs method for +// a struct containing a complex type. +type complexAbs[T Complex] struct { + Value T +} + +func (a complexAbs[T]) Abs() T { + r := float64(real(a.Value)) + i := float64(imag(a.Value)) + d := math.Sqrt(r*r + i*i) + return T(complex(d, 0)) +} + +// OrderedAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of an ordered type. +func OrderedAbsDifference[T orderedNumeric](a, b T) T { + return absDifference(orderedAbs[T]{a}, orderedAbs[T]{b}) +} + +// ComplexAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of a complex type. +func ComplexAbsDifference[T Complex](a, b T) T { + return absDifference(complexAbs[T]{a}, complexAbs[T]{b}) +} + +func main() { + if got, want := OrderedAbsDifference(1.0, -2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := OrderedAbsDifference(-1.0, 2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := OrderedAbsDifference(-20, 15), 35; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + + if got, want := ComplexAbsDifference(5.0+2.0i, 2.0-2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := ComplexAbsDifference(2.0-2.0i, 5.0+2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } +} diff --git a/test/typeparam/absdiffimp2.dir/a.go b/test/typeparam/absdiffimp2.dir/a.go new file mode 100644 index 0000000000..782e000da9 --- /dev/null +++ b/test/typeparam/absdiffimp2.dir/a.go @@ -0,0 +1,80 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package a + +import ( + "math" +) + +type Numeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~complex64 | ~complex128 +} + +// numericAbs matches a struct containing a numeric type that has an Abs method. +type numericAbs[T Numeric] interface { + ~struct{ Value T } + Abs() T +} + +// AbsDifference computes the absolute value of the difference of +// a and b, where the absolute value is determined by the Abs method. +func absDifference[T Numeric, U numericAbs[T]](a, b U) T { + d := a.Value - b.Value + dt := U{Value: d} + return dt.Abs() +} + +// orderedNumeric matches numeric types that support the < operator. +type orderedNumeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 +} + +// Complex matches the two complex types, which do not have a < operator. +type Complex interface { + ~complex64 | ~complex128 +} + +// orderedAbs is a helper type that defines an Abs method for +// a struct containing an ordered numeric type. +type orderedAbs[T orderedNumeric] struct { + Value T +} + +func (a orderedAbs[T]) Abs() T { + if a.Value < 0 { + return -a.Value + } + return a.Value +} + +// complexAbs is a helper type that defines an Abs method for +// a struct containing a complex type. +type complexAbs[T Complex] struct { + Value T +} + +func (a complexAbs[T]) Abs() T { + r := float64(real(a.Value)) + i := float64(imag(a.Value)) + d := math.Sqrt(r*r + i*i) + return T(complex(d, 0)) +} + +// OrderedAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of an ordered type. +func OrderedAbsDifference[T orderedNumeric](a, b T) T { + return absDifference(orderedAbs[T]{a}, orderedAbs[T]{b}) +} + +// ComplexAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of a complex type. +func ComplexAbsDifference[T Complex](a, b T) T { + return absDifference(complexAbs[T]{a}, complexAbs[T]{b}) +} diff --git a/test/typeparam/absdiffimp2.dir/main.go b/test/typeparam/absdiffimp2.dir/main.go new file mode 100644 index 0000000000..8eefdbdf38 --- /dev/null +++ b/test/typeparam/absdiffimp2.dir/main.go @@ -0,0 +1,29 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "a" + "fmt" +) + +func main() { + if got, want := a.OrderedAbsDifference(1.0, -2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := a.OrderedAbsDifference(-1.0, 2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := a.OrderedAbsDifference(-20, 15), 35; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + + if got, want := a.ComplexAbsDifference(5.0+2.0i, 2.0-2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := a.ComplexAbsDifference(2.0-2.0i, 5.0+2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } +} diff --git a/test/typeparam/absdiffimp2.go b/test/typeparam/absdiffimp2.go new file mode 100644 index 0000000000..76930e5e4f --- /dev/null +++ b/test/typeparam/absdiffimp2.go @@ -0,0 +1,7 @@ +// rundir -G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ignored