diff --git a/go/analysis/passes/buildssa/buildssa.go b/go/analysis/passes/buildssa/buildssa.go index 4ec0e73ff2..02b7b18b3f 100644 --- a/go/analysis/passes/buildssa/buildssa.go +++ b/go/analysis/passes/buildssa/buildssa.go @@ -48,8 +48,7 @@ func run(pass *analysis.Pass) (interface{}, error) { // Some Analyzers may need GlobalDebug, in which case we'll have // to set it globally, but let's wait till we need it. - // Monomorphize at least until type parameters are available. - mode := ssa.InstantiateGenerics + mode := ssa.BuilderMode(0) prog := ssa.NewProgram(pass.Fset, mode) diff --git a/go/analysis/passes/nilness/nilness.go b/go/analysis/passes/nilness/nilness.go index 8db18c73ad..61fa30a523 100644 --- a/go/analysis/passes/nilness/nilness.go +++ b/go/analysis/passes/nilness/nilness.go @@ -62,7 +62,6 @@ var Analyzer = &analysis.Analyzer{ func run(pass *analysis.Pass) (interface{}, error) { ssainput := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) - // TODO(48525): ssainput.SrcFuncs is missing fn._Instances(). runFunc will be skipped. for _, fn := range ssainput.SrcFuncs { runFunc(pass, fn) } @@ -307,9 +306,9 @@ func nilnessOf(stack []fact, v ssa.Value) nilness { return isnonnil case *ssa.Const: if v.IsNil() { - return isnil + return isnil // nil or zero value of a pointer-like type } else { - return isnonnil + return unknown // non-pointer } } diff --git a/go/analysis/passes/nilness/testdata/src/a/a.go b/go/analysis/passes/nilness/testdata/src/a/a.go index aa7f9a8f85..0629e08d89 100644 --- a/go/analysis/passes/nilness/testdata/src/a/a.go +++ b/go/analysis/passes/nilness/testdata/src/a/a.go @@ -209,3 +209,10 @@ func f13() { var d *Y print(d.value) // want "nil dereference in field selection" } + +func f14() { + var x struct{ f string } + if x == struct{ f string }{} { // we don't catch this tautology as we restrict to reference types + print(x) + } +} diff --git a/go/analysis/passes/nilness/testdata/src/c/c.go b/go/analysis/passes/nilness/testdata/src/c/c.go index 2b2036595a..c9a05a714f 100644 --- a/go/analysis/passes/nilness/testdata/src/c/c.go +++ b/go/analysis/passes/nilness/testdata/src/c/c.go @@ -2,7 +2,7 @@ package c func instantiated[X any](x *X) int { if x == nil { - print(*x) // not reported until _Instances are added to SrcFuncs + print(*x) // want "nil dereference in load" } return 1 } diff --git a/go/callgraph/callgraph.go b/go/callgraph/callgraph.go index 352ce0c76e..905623753d 100644 --- a/go/callgraph/callgraph.go +++ b/go/callgraph/callgraph.go @@ -37,6 +37,8 @@ package callgraph // import "golang.org/x/tools/go/callgraph" // More generally, we could eliminate "uninteresting" nodes such as // nodes from packages we don't care about. +// TODO(zpavlinovic): decide how callgraphs handle calls to and from generic function bodies. + import ( "fmt" "go/token" diff --git a/go/callgraph/cha/cha.go b/go/callgraph/cha/cha.go index 170040426b..7075a73cbe 100644 --- a/go/callgraph/cha/cha.go +++ b/go/callgraph/cha/cha.go @@ -22,6 +22,8 @@ // partial programs, such as libraries without a main or test function. package cha // import "golang.org/x/tools/go/callgraph/cha" +// TODO(zpavlinovic): update CHA for how it handles generic function bodies. + import ( "go/types" diff --git a/go/callgraph/cha/testdata/generics.go b/go/callgraph/cha/testdata/generics.go index 79250a56ca..0323c7582b 100644 --- a/go/callgraph/cha/testdata/generics.go +++ b/go/callgraph/cha/testdata/generics.go @@ -41,5 +41,9 @@ func f(h func(), g func(I), k func(A), a A, b B) { // f --> instantiated[main.A] // f --> instantiated[main.A] // f --> instantiated[main.B] +// instantiated --> (*A).Foo +// instantiated --> (*B).Foo +// instantiated --> (A).Foo +// instantiated --> (B).Foo // instantiated[main.A] --> (A).Foo // instantiated[main.B] --> (B).Foo diff --git a/go/callgraph/static/static.go b/go/callgraph/static/static.go index c7fae75bbd..62d2364bf2 100644 --- a/go/callgraph/static/static.go +++ b/go/callgraph/static/static.go @@ -6,6 +6,8 @@ // only static call edges. package static // import "golang.org/x/tools/go/callgraph/static" +// TODO(zpavlinovic): update static for how it handles generic function bodies. + import ( "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go index c0b5775907..d1831983ad 100644 --- a/go/callgraph/vta/utils.go +++ b/go/callgraph/vta/utils.go @@ -9,6 +9,7 @@ import ( "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" + "golang.org/x/tools/internal/typeparams" ) func canAlias(n1, n2 node) bool { @@ -117,19 +118,27 @@ func functionUnderPtr(t types.Type) types.Type { } // sliceArrayElem returns the element type of type `t` that is -// expected to be a (pointer to) array or slice, consistent with +// expected to be a (pointer to) array, slice or string, consistent with // the ssa.Index and ssa.IndexAddr instructions. Panics otherwise. func sliceArrayElem(t types.Type) types.Type { - u := t.Underlying() - - if p, ok := u.(*types.Pointer); ok { - u = p.Elem().Underlying() + switch u := t.Underlying().(type) { + case *types.Pointer: + return u.Elem().Underlying().(*types.Array).Elem() + case *types.Array: + return u.Elem() + case *types.Slice: + return u.Elem() + case *types.Basic: + return types.Typ[types.Byte] + case *types.Interface: // type param. + terms, err := typeparams.InterfaceTermSet(u) + if err != nil || len(terms) == 0 { + panic(t) + } + return sliceArrayElem(terms[0].Type()) // Element types must match. + default: + panic(t) } - - if a, ok := u.(*types.Array); ok { - return a.Elem() - } - return u.(*types.Slice).Elem() } // siteCallees computes a set of callees for call site `c` given program `callgraph`. diff --git a/go/callgraph/vta/vta.go b/go/callgraph/vta/vta.go index 9839bd3f3c..5839360033 100644 --- a/go/callgraph/vta/vta.go +++ b/go/callgraph/vta/vta.go @@ -54,6 +54,8 @@ // reaching the node representing the call site to create a set of callees. package vta +// TODO(zpavlinovic): update VTA for how it handles generic function bodies and instantiation wrappers. + import ( "go/types" diff --git a/go/pointer/analysis.go b/go/pointer/analysis.go index 35ad8abdb1..e3c85ede4f 100644 --- a/go/pointer/analysis.go +++ b/go/pointer/analysis.go @@ -16,6 +16,7 @@ import ( "runtime" "runtime/debug" "sort" + "strings" "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" @@ -377,12 +378,27 @@ func (a *analysis) callEdge(caller *cgnode, site *callsite, calleeid nodeid) { fmt.Fprintf(a.log, "\tcall edge %s -> %s\n", site, callee) } - // Warn about calls to non-intrinsic external functions. + // Warn about calls to functions that are handled unsoundly. // TODO(adonovan): de-dup these messages. - if fn := callee.fn; fn.Blocks == nil && a.findIntrinsic(fn) == nil { + fn := callee.fn + + // Warn about calls to non-intrinsic external functions. + if fn.Blocks == nil && a.findIntrinsic(fn) == nil { a.warnf(site.pos(), "unsound call to unknown intrinsic: %s", fn) a.warnf(fn.Pos(), " (declared here)") } + + // Warn about calls to generic function bodies. + if fn.TypeParams().Len() > 0 && len(fn.TypeArgs()) == 0 { + a.warnf(site.pos(), "unsound call to generic function body: %s (build with ssa.InstantiateGenerics)", fn) + a.warnf(fn.Pos(), " (declared here)") + } + + // Warn about calls to instantiation wrappers of generics functions. + if fn.Origin() != nil && strings.HasPrefix(fn.Synthetic, "instantiation wrapper ") { + a.warnf(site.pos(), "unsound call to instantiation wrapper of generic: %s (build with ssa.InstantiateGenerics)", fn) + a.warnf(fn.Pos(), " (declared here)") + } } // dumpSolution writes the PTS solution to the specified file. diff --git a/go/pointer/api.go b/go/pointer/api.go index 9a4cc0af4a..64de110035 100644 --- a/go/pointer/api.go +++ b/go/pointer/api.go @@ -28,7 +28,11 @@ type Config struct { // dependencies of any main package may still affect the // analysis result, because they contribute runtime types and // thus methods. + // // TODO(adonovan): investigate whether this is desirable. + // + // Calls to generic functions will be unsound unless packages + // are built using the ssa.InstantiateGenerics builder mode. Mains []*ssa.Package // Reflection determines whether to handle reflection diff --git a/go/pointer/doc.go b/go/pointer/doc.go index d41346e699..aca343b88e 100644 --- a/go/pointer/doc.go +++ b/go/pointer/doc.go @@ -358,6 +358,14 @@ A. Control-flow joins would merge interfaces ({T1}, {V1}) and ({T2}, type-unsafe combination (T1,V2). Treating the value and its concrete type as inseparable makes the analysis type-safe.) +Type parameters: + +Type parameters are not directly supported by the analysis. +Calls to generic functions will be left as if they had empty bodies. +Users of the package are expected to use the ssa.InstantiateGenerics +builder mode when building code that uses or depends on code +containing generics. + reflect.Value: A reflect.Value is modelled very similar to an interface{}, i.e. as diff --git a/go/pointer/gen.go b/go/pointer/gen.go index 09705948d9..bee656b623 100644 --- a/go/pointer/gen.go +++ b/go/pointer/gen.go @@ -14,9 +14,11 @@ import ( "fmt" "go/token" "go/types" + "strings" "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" + "golang.org/x/tools/internal/typeparams" ) var ( @@ -978,7 +980,10 @@ func (a *analysis) genInstr(cgn *cgnode, instr ssa.Instruction) { a.sizeof(instr.Type())) case *ssa.Index: - a.copy(a.valueNode(instr), 1+a.valueNode(instr.X), a.sizeof(instr.Type())) + _, isstring := typeparams.CoreType(instr.X.Type()).(*types.Basic) + if !isstring { + a.copy(a.valueNode(instr), 1+a.valueNode(instr.X), a.sizeof(instr.Type())) + } case *ssa.Select: recv := a.valueOffsetNode(instr, 2) // instr : (index, recvOk, recv0, ... recv_n-1) @@ -1202,6 +1207,19 @@ func (a *analysis) genFunc(cgn *cgnode) { return } + if fn.TypeParams().Len() > 0 && len(fn.TypeArgs()) == 0 { + // Body of generic function. + // We'll warn about calls to such functions at the end. + return + } + + if strings.HasPrefix(fn.Synthetic, "instantiation wrapper ") { + // instantiation wrapper of a generic function. + // These may contain type coercions which are not currently supported. + // We'll warn about calls to such functions at the end. + return + } + if a.log != nil { fmt.Fprintln(a.log, "; Creating nodes for local values") } diff --git a/go/pointer/pointer_test.go b/go/pointer/pointer_test.go index 05fe981f86..1fa54f6e8f 100644 --- a/go/pointer/pointer_test.go +++ b/go/pointer/pointer_test.go @@ -240,9 +240,14 @@ func doOneInput(t *testing.T, input, fpath string) bool { // Find all calls to the built-in print(x). Analytically, // print is a no-op, but it's a convenient hook for testing // the PTS of an expression, so our tests use it. + // Exclude generic bodies as these should be dead code for pointer. + // Instance of generics are included. probes := make(map[*ssa.CallCommon]bool) for fn := range ssautil.AllFunctions(prog) { - // TODO(taking): Switch to a more principled check like fn.declaredPackage() == mainPkg if _Origin is exported. + if isGenericBody(fn) { + continue // skip generic bodies + } + // TODO(taking): Switch to a more principled check like fn.declaredPackage() == mainPkg if Origin is exported. if fn.Pkg == mainpkg || (fn.Pkg == nil && mainFiles[prog.Fset.File(fn.Pos())]) { for _, b := range fn.Blocks { for _, instr := range b.Instrs { @@ -656,6 +661,15 @@ func TestInput(t *testing.T) { } } +// isGenericBody returns true if fn is the body of a generic function. +func isGenericBody(fn *ssa.Function) bool { + sig := fn.Signature + if typeparams.ForSignature(sig).Len() > 0 || typeparams.RecvTypeParams(sig).Len() > 0 { + return fn.Synthetic == "" + } + return false +} + // join joins the elements of multiset with " | "s. func join(set map[string]int) string { var buf bytes.Buffer diff --git a/go/pointer/reflect.go b/go/pointer/reflect.go index efb11b0009..3762dd8d40 100644 --- a/go/pointer/reflect.go +++ b/go/pointer/reflect.go @@ -1024,7 +1024,7 @@ func ext۰reflect۰ChanOf(a *analysis, cgn *cgnode) { var dir reflect.ChanDir // unknown if site := cgn.callersite; site != nil { if c, ok := site.instr.Common().Args[0].(*ssa.Const); ok { - v, _ := constant.Int64Val(c.Value) + v := c.Int64() if 0 <= v && v <= int64(reflect.BothDir) { dir = reflect.ChanDir(v) } @@ -1751,8 +1751,7 @@ func ext۰reflect۰rtype۰InOut(a *analysis, cgn *cgnode, out bool) { index := -1 if site := cgn.callersite; site != nil { if c, ok := site.instr.Common().Args[0].(*ssa.Const); ok { - v, _ := constant.Int64Val(c.Value) - index = int(v) + index = int(c.Int64()) } } a.addConstraint(&rtypeInOutConstraint{ diff --git a/go/pointer/util.go b/go/pointer/util.go index 5fec1fc4ed..17728aa06a 100644 --- a/go/pointer/util.go +++ b/go/pointer/util.go @@ -8,12 +8,13 @@ import ( "bytes" "fmt" "go/types" - exec "golang.org/x/sys/execabs" "log" "os" "runtime" "time" + exec "golang.org/x/sys/execabs" + "golang.org/x/tools/container/intsets" ) @@ -125,7 +126,7 @@ func (a *analysis) flatten(t types.Type) []*fieldInfo { // Debuggability hack: don't remove // the named type from interfaces as // they're very verbose. - fl = append(fl, &fieldInfo{typ: t}) + fl = append(fl, &fieldInfo{typ: t}) // t may be a type param } else { fl = a.flatten(u) } diff --git a/go/ssa/TODO b/go/ssa/TODO new file mode 100644 index 0000000000..6c35253c73 --- /dev/null +++ b/go/ssa/TODO @@ -0,0 +1,16 @@ +-*- text -*- + +SSA Generics to-do list +=========================== + +DOCUMENTATION: +- Read me for internals + +TYPE PARAMETERIZED GENERIC FUNCTIONS: +- sanity.go updates. +- Check source functions going to generics. +- Tests, tests, tests... + +USAGE: +- Back fill users for handling ssa.InstantiateGenerics being off. + diff --git a/go/ssa/builder.go b/go/ssa/builder.go index 8ec8f6e310..04deb7b063 100644 --- a/go/ssa/builder.go +++ b/go/ssa/builder.go @@ -101,6 +101,9 @@ package ssa // // This is a low level operation for creating functions that do not exist in // the source. Use with caution. +// +// TODO(taking): Use consistent terminology for "concrete". +// TODO(taking): Use consistent terminology for "monomorphization"/"instantiate"/"expand". import ( "fmt" @@ -272,7 +275,7 @@ func (b *builder) exprN(fn *Function, e ast.Expr) Value { return fn.emit(&c) case *ast.IndexExpr: - mapt := fn.typeOf(e.X).Underlying().(*types.Map) + mapt := coreType(fn.typeOf(e.X)).(*types.Map) // ,ok must be a map. lookup := &Lookup{ X: b.expr(fn, e.X), Index: emitConv(fn, b.expr(fn, e.Index), mapt.Key()), @@ -309,7 +312,7 @@ func (b *builder) builtin(fn *Function, obj *types.Builtin, args []ast.Expr, typ typ = fn.typ(typ) switch obj.Name() { case "make": - switch typ.Underlying().(type) { + switch ct := coreType(typ).(type) { case *types.Slice: n := b.expr(fn, args[1]) m := n @@ -319,7 +322,7 @@ func (b *builder) builtin(fn *Function, obj *types.Builtin, args []ast.Expr, typ if m, ok := m.(*Const); ok { // treat make([]T, n, m) as new([m]T)[:n] cap := m.Int64() - at := types.NewArray(typ.Underlying().(*types.Slice).Elem(), cap) + at := types.NewArray(ct.Elem(), cap) alloc := emitNew(fn, at, pos) alloc.Comment = "makeslice" v := &Slice{ @@ -370,6 +373,8 @@ func (b *builder) builtin(fn *Function, obj *types.Builtin, args []ast.Expr, typ // We must still evaluate the value, though. (If it // was side-effect free, the whole call would have // been constant-folded.) + // + // Type parameters are always non-constant so use Underlying. t := deref(fn.typeOf(args[0])).Underlying() if at, ok := t.(*types.Array); ok { b.expr(fn, args[0]) // for effects only @@ -465,27 +470,27 @@ func (b *builder) addr(fn *Function, e ast.Expr, escaping bool) lvalue { return &lazyAddress{addr: emit, t: fld.Type(), pos: e.Sel.Pos(), expr: e.Sel} case *ast.IndexExpr: + xt := fn.typeOf(e.X) + elem, mode := indexType(xt) var x Value var et types.Type - switch t := fn.typeOf(e.X).Underlying().(type) { - case *types.Array: + switch mode { + case ixArrVar: // array, array|slice, array|*array, or array|*array|slice. x = b.addr(fn, e.X, escaping).address(fn) - et = types.NewPointer(t.Elem()) - case *types.Pointer: // *array + et = types.NewPointer(elem) + case ixVar: // *array, slice, *array|slice x = b.expr(fn, e.X) - et = types.NewPointer(t.Elem().Underlying().(*types.Array).Elem()) - case *types.Slice: - x = b.expr(fn, e.X) - et = types.NewPointer(t.Elem()) - case *types.Map: + et = types.NewPointer(elem) + case ixMap: + mt := coreType(xt).(*types.Map) return &element{ m: b.expr(fn, e.X), - k: emitConv(fn, b.expr(fn, e.Index), t.Key()), - t: t.Elem(), + k: emitConv(fn, b.expr(fn, e.Index), mt.Key()), + t: mt.Elem(), pos: e.Lbrack, } default: - panic("unexpected container type in IndexExpr: " + t.String()) + panic("unexpected container type in IndexExpr: " + xt.String()) } index := b.expr(fn, e.Index) if isUntyped(index.Type()) { @@ -562,7 +567,7 @@ func (b *builder) assign(fn *Function, loc lvalue, e ast.Expr, isZero bool, sb * } if _, ok := loc.(*address); ok { - if isInterface(loc.typ()) { + if isNonTypeParamInterface(loc.typ()) { // e.g. var x interface{} = T{...} // Can't in-place initialize an interface value. // Fall back to copying. @@ -632,18 +637,19 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { case *ast.FuncLit: fn2 := &Function{ - name: fmt.Sprintf("%s$%d", fn.Name(), 1+len(fn.AnonFuncs)), - Signature: fn.typeOf(e.Type).Underlying().(*types.Signature), - pos: e.Type.Func, - parent: fn, - Pkg: fn.Pkg, - Prog: fn.Prog, - syntax: e, - _Origin: nil, // anon funcs do not have an origin. - _TypeParams: fn._TypeParams, // share the parent's type parameters. - _TypeArgs: fn._TypeArgs, // share the parent's type arguments. - info: fn.info, - subst: fn.subst, // share the parent's type substitutions. + name: fmt.Sprintf("%s$%d", fn.Name(), 1+len(fn.AnonFuncs)), + Signature: fn.typeOf(e.Type).(*types.Signature), + pos: e.Type.Func, + parent: fn, + anonIdx: int32(len(fn.AnonFuncs)), + Pkg: fn.Pkg, + Prog: fn.Prog, + syntax: e, + topLevelOrigin: nil, // use anonIdx to lookup an anon instance's origin. + typeparams: fn.typeparams, // share the parent's type parameters. + typeargs: fn.typeargs, // share the parent's type arguments. + info: fn.info, + subst: fn.subst, // share the parent's type substitutions. } fn.AnonFuncs = append(fn.AnonFuncs, fn2) b.created.Add(fn2) @@ -745,14 +751,20 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { case *ast.SliceExpr: var low, high, max Value var x Value - switch fn.typeOf(e.X).Underlying().(type) { + xtyp := fn.typeOf(e.X) + switch coreType(xtyp).(type) { case *types.Array: // Potentially escaping. x = b.addr(fn, e.X, true).address(fn) case *types.Basic, *types.Slice, *types.Pointer: // *array x = b.expr(fn, e.X) default: - panic("unreachable") + // coreType exception? + if isBytestring(xtyp) { + x = b.expr(fn, e.X) // bytestring is handled as string and []byte. + } else { + panic("unexpected sequence type in SliceExpr") + } } if e.Low != nil { low = b.expr(fn, e.Low) @@ -780,7 +792,7 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { case *types.Builtin: return &Builtin{name: obj.Name(), sig: fn.instanceType(e).(*types.Signature)} case *types.Nil: - return nilConst(fn.instanceType(e)) + return zeroConst(fn.instanceType(e)) } // Package-level func or var? if v := fn.Prog.packageLevelMember(obj); v != nil { @@ -788,7 +800,7 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { return emitLoad(fn, g) // var (address) } callee := v.(*Function) // (func) - if len(callee._TypeParams) > 0 { + if callee.typeparams.Len() > 0 { targs := fn.subst.types(instanceArgs(fn.info, e)) callee = fn.Prog.needsInstance(callee, targs, b.created) } @@ -822,11 +834,32 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { wantAddr := isPointer(rt) escaping := true v := b.receiver(fn, e.X, wantAddr, escaping, sel) - if isInterface(rt) { - // If v has interface type I, + + if types.IsInterface(rt) { + // If v may be an interface type I (after instantiating), // we must emit a check that v is non-nil. - // We use: typeassert v.(I). - emitTypeAssert(fn, v, rt, token.NoPos) + if recv, ok := sel.recv.(*typeparams.TypeParam); ok { + // Emit a nil check if any possible instantiation of the + // type parameter is an interface type. + if len(typeSetOf(recv)) > 0 { + // recv has a concrete term its typeset. + // So it cannot be instantiated as an interface. + // + // Example: + // func _[T interface{~int; Foo()}] () { + // var v T + // _ = v.Foo // <-- MethodVal + // } + } else { + // rt may be instantiated as an interface. + // Emit nil check: typeassert (any(v)).(any). + emitTypeAssert(fn, emitConv(fn, v, tEface), tEface, token.NoPos) + } + } else { + // non-type param interface + // Emit nil check: typeassert v.(I). + emitTypeAssert(fn, v, rt, token.NoPos) + } } if targs := receiverTypeArgs(obj); len(targs) > 0 { // obj is generic. @@ -863,9 +896,17 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { return b.expr(fn, e.X) // Handle instantiation within the *Ident or *SelectorExpr cases. } // not a generic instantiation. - switch t := fn.typeOf(e.X).Underlying().(type) { - case *types.Array: - // Non-addressable array (in a register). + xt := fn.typeOf(e.X) + switch et, mode := indexType(xt); mode { + case ixVar: + // Addressable slice/array; use IndexAddr and Load. + return b.addr(fn, e, false).load(fn) + + case ixArrVar, ixValue: + // An array in a register, a string or a combined type that contains + // either an [_]array (ixArrVar) or string (ixValue). + + // Note: for ixArrVar and coreType(xt)==nil can be IndexAddr and Load. index := b.expr(fn, e.Index) if isUntyped(index.Type()) { index = emitConv(fn, index, tInt) @@ -875,38 +916,20 @@ func (b *builder) expr0(fn *Function, e ast.Expr, tv types.TypeAndValue) Value { Index: index, } v.setPos(e.Lbrack) - v.setType(t.Elem()) + v.setType(et) return fn.emit(v) - case *types.Map: + case ixMap: + ct := coreType(xt).(*types.Map) v := &Lookup{ X: b.expr(fn, e.X), - Index: emitConv(fn, b.expr(fn, e.Index), t.Key()), + Index: emitConv(fn, b.expr(fn, e.Index), ct.Key()), } v.setPos(e.Lbrack) - v.setType(t.Elem()) + v.setType(ct.Elem()) return fn.emit(v) - - case *types.Basic: // => string - // Strings are not addressable. - index := b.expr(fn, e.Index) - if isUntyped(index.Type()) { - index = emitConv(fn, index, tInt) - } - v := &Lookup{ - X: b.expr(fn, e.X), - Index: index, - } - v.setPos(e.Lbrack) - v.setType(tByte) - return fn.emit(v) - - case *types.Slice, *types.Pointer: // *array - // Addressable slice/array; use IndexAddr and Load. - return b.addr(fn, e, false).load(fn) - default: - panic("unexpected container type in IndexExpr: " + t.String()) + panic("unexpected container type in IndexExpr: " + xt.String()) } case *ast.CompositeLit, *ast.StarExpr: @@ -967,14 +990,14 @@ func (b *builder) setCallFunc(fn *Function, e *ast.CallExpr, c *CallCommon) { wantAddr := isPointer(recv) escaping := true v := b.receiver(fn, selector.X, wantAddr, escaping, sel) - if isInterface(recv) { + if types.IsInterface(recv) { // Invoke-mode call. - c.Value = v + c.Value = v // possibly type param c.Method = obj } else { // "Call"-mode call. callee := fn.Prog.originFunc(obj) - if len(callee._TypeParams) > 0 { + if callee.typeparams.Len() > 0 { callee = fn.Prog.needsInstance(callee, receiverTypeArgs(obj), b.created) } c.Value = callee @@ -1065,7 +1088,7 @@ func (b *builder) emitCallArgs(fn *Function, sig *types.Signature, e *ast.CallEx st := sig.Params().At(np).Type().(*types.Slice) vt := st.Elem() if len(varargs) == 0 { - args = append(args, nilConst(st)) + args = append(args, zeroConst(st)) } else { // Replace a suffix of args with a slice containing it. at := types.NewArray(vt, int64(len(varargs))) @@ -1097,7 +1120,7 @@ func (b *builder) setCall(fn *Function, e *ast.CallExpr, c *CallCommon) { b.setCallFunc(fn, e, c) // Then append the other actual parameters. - sig, _ := fn.typeOf(e.Fun).Underlying().(*types.Signature) + sig, _ := coreType(fn.typeOf(e.Fun)).(*types.Signature) if sig == nil { panic(fmt.Sprintf("no signature for call of %s", e.Fun)) } @@ -1230,8 +1253,32 @@ func (b *builder) arrayLen(fn *Function, elts []ast.Expr) int64 { // literal has type *T behaves like &T{}. // In that case, addr must hold a T, not a *T. func (b *builder) compLit(fn *Function, addr Value, e *ast.CompositeLit, isZero bool, sb *storebuf) { - typ := deref(fn.typeOf(e)) - switch t := typ.Underlying().(type) { + typ := deref(fn.typeOf(e)) // type with name [may be type param] + t := deref(coreType(typ)).Underlying() // core type for comp lit case + // Computing typ and t is subtle as these handle pointer types. + // For example, &T{...} is valid even for maps and slices. + // Also typ should refer to T (not *T) while t should be the core type of T. + // + // To show the ordering to take into account, consider the composite literal + // expressions `&T{f: 1}` and `{f: 1}` within the expression `[]S{{f: 1}}` here: + // type N struct{f int} + // func _[T N, S *N]() { + // _ = &T{f: 1} + // _ = []S{{f: 1}} + // } + // For `&T{f: 1}`, we compute `typ` and `t` as: + // typeOf(&T{f: 1}) == *T + // deref(*T) == T (typ) + // coreType(T) == N + // deref(N) == N + // N.Underlying() == struct{f int} (t) + // For `{f: 1}` in `[]S{{f: 1}}`, we compute `typ` and `t` as: + // typeOf({f: 1}) == S + // deref(S) == S (typ) + // coreType(S) == *N + // deref(*N) == N + // N.Underlying() == struct{f int} (t) + switch t := t.(type) { case *types.Struct: if !isZero && len(e.Elts) != t.NumFields() { // memclear @@ -1259,6 +1306,7 @@ func (b *builder) compLit(fn *Function, addr Value, e *ast.CompositeLit, isZero X: addr, Field: fieldIndex, } + faddr.setPos(pos) faddr.setType(types.NewPointer(sf.Type())) fn.emit(faddr) b.assign(fn, &address{addr: faddr, pos: pos, expr: e}, e, isZero, sb) @@ -1529,7 +1577,7 @@ func (b *builder) typeSwitchStmt(fn *Function, s *ast.TypeSwitchStmt, label *lbl casetype = fn.typeOf(cond) var condv Value if casetype == tUntypedNil { - condv = emitCompare(fn, token.EQL, x, nilConst(x.Type()), cond.Pos()) + condv = emitCompare(fn, token.EQL, x, zeroConst(x.Type()), cond.Pos()) ti = x } else { yok := emitTypeTest(fn, x, casetype, cc.Case) @@ -1612,7 +1660,7 @@ func (b *builder) selectStmt(fn *Function, s *ast.SelectStmt, label *lblock) { case *ast.SendStmt: // ch<- i ch := b.expr(fn, comm.Chan) - chtyp := fn.typ(ch.Type()).Underlying().(*types.Chan) + chtyp := coreType(fn.typ(ch.Type())).(*types.Chan) st = &SelectState{ Dir: types.SendOnly, Chan: ch, @@ -1669,9 +1717,8 @@ func (b *builder) selectStmt(fn *Function, s *ast.SelectStmt, label *lblock) { vars = append(vars, varIndex, varOk) for _, st := range states { if st.Dir == types.RecvOnly { - chtyp := fn.typ(st.Chan.Type()).Underlying().(*types.Chan) - tElem := chtyp.Elem() - vars = append(vars, anonVar(tElem)) + chtyp := coreType(fn.typ(st.Chan.Type())).(*types.Chan) + vars = append(vars, anonVar(chtyp.Elem())) } } sel.setType(types.NewTuple(vars...)) @@ -1835,6 +1882,8 @@ func (b *builder) rangeIndexed(fn *Function, x Value, tv types.Type, pos token.P // elimination if x is pure, static unrolling, etc. // Ranging over a nil *array may have >0 iterations. // We still generate code for x, in case it has effects. + // + // TypeParams do not have constant length. Use underlying instead of core type. length = intConst(arr.Len()) } else { // length = len(x). @@ -1867,7 +1916,7 @@ func (b *builder) rangeIndexed(fn *Function, x Value, tv types.Type, pos token.P k = emitLoad(fn, index) if tv != nil { - switch t := x.Type().Underlying().(type) { + switch t := coreType(x.Type()).(type) { case *types.Array: instr := &Index{ X: x, @@ -1937,11 +1986,9 @@ func (b *builder) rangeIter(fn *Function, x Value, tk, tv types.Type, pos token. emitJump(fn, loop) fn.currentBlock = loop - _, isString := x.Type().Underlying().(*types.Basic) - okv := &Next{ Iter: it, - IsString: isString, + IsString: isBasic(coreType(x.Type())), } okv.setType(types.NewTuple( varOk, @@ -1991,7 +2038,7 @@ func (b *builder) rangeChan(fn *Function, x Value, tk types.Type, pos token.Pos) } recv.setPos(pos) recv.setType(types.NewTuple( - newVar("k", x.Type().Underlying().(*types.Chan).Elem()), + newVar("k", coreType(x.Type()).(*types.Chan).Elem()), varOk, )) ko := fn.emit(recv) @@ -2035,7 +2082,7 @@ func (b *builder) rangeStmt(fn *Function, s *ast.RangeStmt, label *lblock) { var k, v Value var loop, done *BasicBlock - switch rt := x.Type().Underlying().(type) { + switch rt := coreType(x.Type()).(type) { case *types.Slice, *types.Array, *types.Pointer: // *array k, v, loop, done = b.rangeIndexed(fn, x, tv, s.For) @@ -2113,11 +2160,11 @@ start: b.expr(fn, s.X) case *ast.SendStmt: + chtyp := coreType(fn.typeOf(s.Chan)).(*types.Chan) fn.emit(&Send{ Chan: b.expr(fn, s.Chan), - X: emitConv(fn, b.expr(fn, s.Value), - fn.typeOf(s.Chan).Underlying().(*types.Chan).Elem()), - pos: s.Arrow, + X: emitConv(fn, b.expr(fn, s.Value), chtyp.Elem()), + pos: s.Arrow, }) case *ast.IncDecStmt: @@ -2295,11 +2342,9 @@ func (b *builder) buildFunctionBody(fn *Function) { var functype *ast.FuncType switch n := fn.syntax.(type) { case nil: - // TODO(taking): Temporarily this can be the body of a generic function. if fn.Params != nil { return // not a Go source function. (Synthetic, or from object file.) } - // fn.Params == nil is handled within body == nil case. case *ast.FuncDecl: functype = n.Type recvField = n.Recv @@ -2331,6 +2376,13 @@ func (b *builder) buildFunctionBody(fn *Function) { } return } + + // Build instantiation wrapper around generic body? + if fn.topLevelOrigin != nil && fn.subst == nil { + buildInstantiationWrapper(fn) + return + } + if fn.Prog.mode&LogSource != 0 { defer logStack("build function %s @ %s", fn, fn.Prog.Fset.Position(fn.pos))() } @@ -2435,7 +2487,17 @@ func (p *Package) build() { // TODO(adonovan): ideally belongs in memberFromObject, but // that would require package creation in topological order. for name, mem := range p.Members { - if ast.IsExported(name) && !isGeneric(mem) { + isGround := func(m Member) bool { + switch m := m.(type) { + case *Type: + named, _ := m.Type().(*types.Named) + return named == nil || typeparams.ForNamed(named) == nil + case *Function: + return m.typeparams.Len() == 0 + } + return true // *NamedConst, *Global + } + if ast.IsExported(name) && isGround(mem) { p.Prog.needMethodsOf(mem.Type(), &p.created) } } diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go new file mode 100644 index 0000000000..dda53e1541 --- /dev/null +++ b/go/ssa/builder_generic_test.go @@ -0,0 +1,664 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa_test + +import ( + "fmt" + "go/parser" + "go/token" + "reflect" + "sort" + "testing" + + "golang.org/x/tools/go/expect" + "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/internal/typeparams" +) + +// TestGenericBodies tests that bodies of generic functions and methods containing +// different constructs can be built in BuilderMode(0). +// +// Each test specifies the contents of package containing a single go file. +// Each call print(arg0, arg1, ...) to the builtin print function +// in ssa is correlated a comment at the end of the line of the form: +// +// //@ types(a, b, c) +// +// where a, b and c are the types of the arguments to the print call +// serialized using go/types.Type.String(). +// See x/tools/go/expect for details on the syntax. +func TestGenericBodies(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestGenericBodies requires type parameters") + } + for _, test := range []struct { + pkg string // name of the package. + contents string // contents of the Go package. + }{ + { + pkg: "p", + contents: ` + package p + + func f(x int) { + var i interface{} + print(i, 0) //@ types("interface{}", int) + print() //@ types() + print(x) //@ types(int) + } + `, + }, + { + pkg: "q", + contents: ` + package q + + func f[T any](x T) { + print(x) //@ types(T) + } + `, + }, + { + pkg: "r", + contents: ` + package r + + func f[T ~int]() { + var x T + print(x) //@ types(T) + } + `, + }, + { + pkg: "s", + contents: ` + package s + + func a[T ~[4]byte](x T) { + for k, v := range x { + print(x, k, v) //@ types(T, int, byte) + } + } + func b[T ~*[4]byte](x T) { + for k, v := range x { + print(x, k, v) //@ types(T, int, byte) + } + } + func c[T ~[]byte](x T) { + for k, v := range x { + print(x, k, v) //@ types(T, int, byte) + } + } + func d[T ~string](x T) { + for k, v := range x { + print(x, k, v) //@ types(T, int, rune) + } + } + func e[T ~map[int]string](x T) { + for k, v := range x { + print(x, k, v) //@ types(T, int, string) + } + } + func f[T ~chan string](x T) { + for v := range x { + print(x, v) //@ types(T, string) + } + } + + func From() { + type A [4]byte + print(a[A]) //@ types("func(x s.A)") + + type B *[4]byte + print(b[B]) //@ types("func(x s.B)") + + type C []byte + print(c[C]) //@ types("func(x s.C)") + + type D string + print(d[D]) //@ types("func(x s.D)") + + type E map[int]string + print(e[E]) //@ types("func(x s.E)") + + type F chan string + print(f[F]) //@ types("func(x s.F)") + } + `, + }, + { + pkg: "t", + contents: ` + package t + + func f[S any, T ~chan S](x T) { + for v := range x { + print(x, v) //@ types(T, S) + } + } + + func From() { + type F chan string + print(f[string, F]) //@ types("func(x t.F)") + } + `, + }, + { + pkg: "u", + contents: ` + package u + + func fibonacci[T ~chan int](c, quit T) { + x, y := 0, 1 + for { + select { + case c <- x: + x, y = y, x+y + case <-quit: + print(c, quit, x, y) //@ types(T, T, int, int) + return + } + } + } + func start[T ~chan int](c, quit T) { + go func() { + for i := 0; i < 10; i++ { + print(<-c) //@ types(int) + } + quit <- 0 + }() + } + func From() { + type F chan int + c := make(F) + quit := make(F) + print(start[F], c, quit) //@ types("func(c u.F, quit u.F)", "u.F", "u.F") + print(fibonacci[F], c, quit) //@ types("func(c u.F, quit u.F)", "u.F", "u.F") + } + `, + }, + { + pkg: "v", + contents: ` + package v + + func f[T ~struct{ x int; y string }](i int) T { + u := []T{ T{0, "lorem"}, T{1, "ipsum"}} + return u[i] + } + func From() { + type S struct{ x int; y string } + print(f[S]) //@ types("func(i int) v.S") + } + `, + }, + { + pkg: "w", + contents: ` + package w + + func f[T ~[4]int8](x T, l, h int) []int8 { + return x[l:h] + } + func g[T ~*[4]int16](x T, l, h int) []int16 { + return x[l:h] + } + func h[T ~[]int32](x T, l, h int) T { + return x[l:h] + } + func From() { + type F [4]int8 + type G *[4]int16 + type H []int32 + print(f[F](F{}, 0, 0)) //@ types("[]int8") + print(g[G](nil, 0, 0)) //@ types("[]int16") + print(h[H](nil, 0, 0)) //@ types("w.H") + } + `, + }, + { + pkg: "x", + contents: ` + package x + + func h[E any, T ~[]E](x T, l, h int) []E { + s := x[l:h] + print(s) //@ types("T") + return s + } + func From() { + type H []int32 + print(h[int32, H](nil, 0, 0)) //@ types("[]int32") + } + `, + }, + { + pkg: "y", + contents: ` + package y + + // Test "make" builtin with different forms on core types and + // when capacities are constants or variable. + func h[E any, T ~[]E](m, n int) { + print(make(T, 3)) //@ types(T) + print(make(T, 3, 5)) //@ types(T) + print(make(T, m)) //@ types(T) + print(make(T, m, n)) //@ types(T) + } + func i[K comparable, E any, T ~map[K]E](m int) { + print(make(T)) //@ types(T) + print(make(T, 5)) //@ types(T) + print(make(T, m)) //@ types(T) + } + func j[E any, T ~chan E](m int) { + print(make(T)) //@ types(T) + print(make(T, 6)) //@ types(T) + print(make(T, m)) //@ types(T) + } + func From() { + type H []int32 + h[int32, H](3, 4) + type I map[int8]H + i[int8, H, I](5) + type J chan I + j[I, J](6) + } + `, + }, + { + pkg: "z", + contents: ` + package z + + func h[T ~[4]int](x T) { + print(len(x), cap(x)) //@ types(int, int) + } + func i[T ~[4]byte | []int | ~chan uint8](x T) { + print(len(x), cap(x)) //@ types(int, int) + } + func j[T ~[4]int | any | map[string]int]() { + print(new(T)) //@ types("*T") + } + func k[T ~[4]int | any | map[string]int](x T) { + print(x) //@ types(T) + panic(x) + } + `, + }, + { + pkg: "a", + contents: ` + package a + + func f[E any, F ~func() E](x F) { + print(x, x()) //@ types(F, E) + } + func From() { + type T func() int + f[int, T](func() int { return 0 }) + f[int, func() int](func() int { return 1 }) + } + `, + }, + { + pkg: "b", + contents: ` + package b + + func f[E any, M ~map[string]E](m M) { + y, ok := m["lorem"] + print(m, y, ok) //@ types(M, E, bool) + } + func From() { + type O map[string][]int + f(O{"lorem": []int{0, 1, 2, 3}}) + } + `, + }, + { + pkg: "c", + contents: ` + package c + + func a[T interface{ []int64 | [5]int64 }](x T) int64 { + print(x, x[2], x[3]) //@ types(T, int64, int64) + x[2] = 5 + return x[3] + } + func b[T interface{ []byte | string }](x T) byte { + print(x, x[3]) //@ types(T, byte) + return x[3] + } + func c[T interface{ []byte }](x T) byte { + print(x, x[2], x[3]) //@ types(T, byte, byte) + x[2] = 'b' + return x[3] + } + func d[T interface{ map[int]int64 }](x T) int64 { + print(x, x[2], x[3]) //@ types(T, int64, int64) + x[2] = 43 + return x[3] + } + func e[T ~string](t T) { + print(t, t[0]) //@ types(T, uint8) + } + func f[T ~string|[]byte](t T) { + print(t, t[0]) //@ types(T, uint8) + } + func g[T []byte](t T) { + print(t, t[0]) //@ types(T, byte) + } + func h[T ~[4]int|[]int](t T) { + print(t, t[0]) //@ types(T, int) + } + func i[T ~[4]int|*[4]int|[]int](t T) { + print(t, t[0]) //@ types(T, int) + } + func j[T ~[4]int|*[4]int|[]int](t T) { + print(t, &t[0]) //@ types(T, "*int") + } + `, + }, + { + pkg: "d", + contents: ` + package d + + type MyInt int + type Other int + type MyInterface interface{ foo() } + + // ChangeType tests + func ct0(x int) { v := MyInt(x); print(x, v) /*@ types(int, "d.MyInt")*/ } + func ct1[T MyInt | Other, S int ](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ } + func ct2[T int, S MyInt | int ](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ } + func ct3[T MyInt | Other, S MyInt | int ](x S) { v := T(x) ; print(x, v) /*@ types(S, T)*/ } + + // Convert tests + func co0[T int | int8](x MyInt) { v := T(x); print(x, v) /*@ types("d.MyInt", T)*/} + func co1[T int | int8](x T) { v := MyInt(x); print(x, v) /*@ types(T, "d.MyInt")*/ } + func co2[S, T int | int8](x T) { v := S(x); print(x, v) /*@ types(T, S)*/ } + + // MakeInterface tests + func mi0[T MyInterface](x T) { v := MyInterface(x); print(x, v) /*@ types(T, "d.MyInterface")*/ } + + // NewConst tests + func nc0[T any]() { v := (*T)(nil); print(v) /*@ types("*T")*/} + + // SliceToArrayPointer + func sl0[T *[4]int | *[2]int](x []int) { v := T(x); print(x, v) /*@ types("[]int", T)*/ } + func sl1[T *[4]int | *[2]int, S []int](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ } + `, + }, + { + pkg: "e", + contents: ` + package e + + func c[T interface{ foo() string }](x T) { + print(x, x.foo, x.foo()) /*@ types(T, "func() string", string)*/ + } + `, + }, + { + pkg: "f", + contents: `package f + + func eq[T comparable](t T, i interface{}) bool { + return t == i + } + `, + }, + { + pkg: "g", + contents: `package g + type S struct{ f int } + func c[P *S]() []P { return []P{{f: 1}} } + `, + }, + { + pkg: "h", + contents: `package h + func sign[bytes []byte | string](s bytes) (bool, bool) { + neg := false + if len(s) > 0 && (s[0] == '-' || s[0] == '+') { + neg = s[0] == '-' + s = s[1:] + } + return !neg, len(s) > 0 + }`, + }, + { + pkg: "i", + contents: `package i + func digits[bytes []byte | string](s bytes) bool { + for _, c := range []byte(s) { + if c < '0' || '9' < c { + return false + } + } + return true + }`, + }, + } { + test := test + t.Run(test.pkg, func(t *testing.T) { + // Parse + conf := loader.Config{ParserMode: parser.ParseComments} + fname := test.pkg + ".go" + f, err := conf.ParseFile(fname, test.contents) + if err != nil { + t.Fatalf("parse: %v", err) + } + conf.CreateFromFiles(test.pkg, f) + + // Load + lprog, err := conf.Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + + // Create and build SSA + prog := ssa.NewProgram(lprog.Fset, ssa.SanityCheckFunctions) + for _, info := range lprog.AllPackages { + if info.TransitivelyErrorFree { + prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) + } + } + p := prog.Package(lprog.Package(test.pkg).Pkg) + p.Build() + + // Collect calls to the builtin print function. + probes := make(map[*ssa.CallCommon]bool) + for _, mem := range p.Members { + if fn, ok := mem.(*ssa.Function); ok { + for _, bb := range fn.Blocks { + for _, i := range bb.Instrs { + if i, ok := i.(ssa.CallInstruction); ok { + call := i.Common() + if b, ok := call.Value.(*ssa.Builtin); ok && b.Name() == "print" { + probes[i.Common()] = true + } + } + } + } + } + } + + // Collect all notes in f, i.e. comments starting with "//@ types". + notes, err := expect.ExtractGo(prog.Fset, f) + if err != nil { + t.Errorf("expect.ExtractGo: %v", err) + } + + // Matches each probe with a note that has the same line. + sameLine := func(x, y token.Pos) bool { + xp := prog.Fset.Position(x) + yp := prog.Fset.Position(y) + return xp.Filename == yp.Filename && xp.Line == yp.Line + } + expectations := make(map[*ssa.CallCommon]*expect.Note) + for call := range probes { + var match *expect.Note + for _, note := range notes { + if note.Name == "types" && sameLine(call.Pos(), note.Pos) { + match = note // first match is good enough. + break + } + } + if match != nil { + expectations[call] = match + } else { + t.Errorf("Unmatched probe: %v", call) + } + } + + // Check each expectation. + for call, note := range expectations { + var args []string + for _, a := range call.Args { + args = append(args, a.Type().String()) + } + if got, want := fmt.Sprint(args), fmt.Sprint(note.Args); got != want { + t.Errorf("Arguments to print() were expected to be %q. got %q", want, got) + } + } + }) + } +} + +// TestInstructionString tests serializing instructions via Instruction.String(). +func TestInstructionString(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestInstructionString requires type parameters") + } + // Tests (ssa.Instruction).String(). Instructions are from a single go file. + // The Instructions tested are those that match a comment of the form: + // + // //@ instrs(f, kind, strs...) + // + // where f is the name of the function, kind is the type of the instructions matched + // within the function, and tests that the String() value for all of the instructions + // matched of String() is strs (in some order). + // See x/tools/go/expect for details on the syntax. + + const contents = ` + package p + + //@ instrs("f", "*ssa.TypeAssert") + //@ instrs("f", "*ssa.Call", "print(nil:interface{}, 0:int)") + func f(x int) { // non-generic smoke test. + var i interface{} + print(i, 0) + } + + //@ instrs("h", "*ssa.Alloc", "local T (u)") + //@ instrs("h", "*ssa.FieldAddr", "&t0.x [#0]") + func h[T ~struct{ x string }]() T { + u := T{"lorem"} + return u + } + + //@ instrs("c", "*ssa.TypeAssert", "typeassert t0.(interface{})") + //@ instrs("c", "*ssa.Call", "invoke x.foo()") + func c[T interface{ foo() string }](x T) { + _ = x.foo + _ = x.foo() + } + + //@ instrs("d", "*ssa.TypeAssert", "typeassert t0.(interface{})") + //@ instrs("d", "*ssa.Call", "invoke x.foo()") + func d[T interface{ foo() string; comparable }](x T) { + _ = x.foo + _ = x.foo() + } + ` + + // Parse + conf := loader.Config{ParserMode: parser.ParseComments} + const fname = "p.go" + f, err := conf.ParseFile(fname, contents) + if err != nil { + t.Fatalf("parse: %v", err) + } + conf.CreateFromFiles("p", f) + + // Load + lprog, err := conf.Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + + // Create and build SSA + prog := ssa.NewProgram(lprog.Fset, ssa.SanityCheckFunctions) + for _, info := range lprog.AllPackages { + if info.TransitivelyErrorFree { + prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) + } + } + p := prog.Package(lprog.Package("p").Pkg) + p.Build() + + // Collect all notes in f, i.e. comments starting with "//@ instr". + notes, err := expect.ExtractGo(prog.Fset, f) + if err != nil { + t.Errorf("expect.ExtractGo: %v", err) + } + + // Expectation is a {function, type string} -> {want, matches} + // where matches is all Instructions.String() that match the key. + // Each expecation is that some permutation of matches is wants. + type expKey struct { + function string + kind string + } + type expValue struct { + wants []string + matches []string + } + expectations := make(map[expKey]*expValue) + for _, note := range notes { + if note.Name == "instrs" { + if len(note.Args) < 2 { + t.Error("Had @instrs annotation without at least 2 arguments") + continue + } + fn, kind := fmt.Sprint(note.Args[0]), fmt.Sprint(note.Args[1]) + var wants []string + for _, arg := range note.Args[2:] { + wants = append(wants, fmt.Sprint(arg)) + } + expectations[expKey{fn, kind}] = &expValue{wants, nil} + } + } + + // Collect all Instructions that match the expectations. + for _, mem := range p.Members { + if fn, ok := mem.(*ssa.Function); ok { + for _, bb := range fn.Blocks { + for _, i := range bb.Instrs { + kind := fmt.Sprintf("%T", i) + if e := expectations[expKey{fn.Name(), kind}]; e != nil { + e.matches = append(e.matches, i.String()) + } + } + } + } + } + + // Check each expectation. + for key, value := range expectations { + if _, ok := p.Members[key.function]; !ok { + t.Errorf("Expectation on %s does not match a member in %s", key.function, p.Pkg.Name()) + } + got, want := value.matches, value.wants + sort.Strings(got) + sort.Strings(want) + if !reflect.DeepEqual(want, got) { + t.Errorf("Within %s wanted instructions of kind %s: %q. got %q", key.function, key.kind, want, got) + } + } +} diff --git a/go/ssa/builder_go120_test.go b/go/ssa/builder_go120_test.go index 84bdd4c41a..04fb11a2d2 100644 --- a/go/ssa/builder_go120_test.go +++ b/go/ssa/builder_go120_test.go @@ -25,6 +25,10 @@ func TestBuildPackageGo120(t *testing.T) { importer types.Importer }{ {"slice to array", "package p; var s []byte; var _ = ([4]byte)(s)", nil}, + {"slice to zero length array", "package p; var s []byte; var _ = ([0]byte)(s)", nil}, + {"slice to zero length array type parameter", "package p; var s []byte; func f[T ~[0]byte]() { tmp := (T)(s); var z T; _ = tmp == z}", nil}, + {"slice to non-zero length array type parameter", "package p; var s []byte; func h[T ~[1]byte | [4]byte]() { tmp := T(s); var z T; _ = tmp == z}", nil}, + {"slice to maybe-zero length array type parameter", "package p; var s []byte; func g[T ~[0]byte | [4]byte]() { tmp := T(s); var z T; _ = tmp == z}", nil}, } for _, tc := range tests { diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go index 6fc844187d..a80d8d5ab7 100644 --- a/go/ssa/builder_test.go +++ b/go/ssa/builder_test.go @@ -226,6 +226,18 @@ func TestRuntimeTypes(t *testing.T) { nil, }, } + + if typeparams.Enabled { + tests = append(tests, []struct { + input string + want []string + }{ + // MakeInterface does not create runtime type for parameterized types. + {`package N; var g interface{}; func f[S any]() { var v []S; g = v }; `, + nil, + }, + }...) + } for _, test := range tests { // Parse the file. fset := token.NewFileSet() diff --git a/go/ssa/const.go b/go/ssa/const.go index dc182d9616..3468eac7e1 100644 --- a/go/ssa/const.go +++ b/go/ssa/const.go @@ -12,65 +12,73 @@ import ( "go/token" "go/types" "strconv" + "strings" + + "golang.org/x/tools/internal/typeparams" ) // NewConst returns a new constant of the specified value and type. // val must be valid according to the specification of Const.Value. func NewConst(val constant.Value, typ types.Type) *Const { + if val == nil { + switch soleTypeKind(typ) { + case types.IsBoolean: + val = constant.MakeBool(false) + case types.IsInteger: + val = constant.MakeInt64(0) + case types.IsString: + val = constant.MakeString("") + } + } return &Const{typ, val} } +// soleTypeKind returns a BasicInfo for which constant.Value can +// represent all zero values for the types in the type set. +// +// types.IsBoolean for false is a representative. +// types.IsInteger for 0 +// types.IsString for "" +// 0 otherwise. +func soleTypeKind(typ types.Type) types.BasicInfo { + // State records the set of possible zero values (false, 0, ""). + // Candidates (perhaps all) are eliminated during the type-set + // iteration, which executes at least once. + state := types.IsBoolean | types.IsInteger | types.IsString + typeSetOf(typ).underIs(func(t types.Type) bool { + var c types.BasicInfo + if t, ok := t.(*types.Basic); ok { + c = t.Info() + } + if c&types.IsNumeric != 0 { // int/float/complex + c = types.IsInteger + } + state = state & c + return state != 0 + }) + return state +} + // intConst returns an 'int' constant that evaluates to i. // (i is an int64 in case the host is narrower than the target.) func intConst(i int64) *Const { return NewConst(constant.MakeInt64(i), tInt) } -// nilConst returns a nil constant of the specified type, which may -// be any reference type, including interfaces. -func nilConst(typ types.Type) *Const { - return NewConst(nil, typ) -} - // stringConst returns a 'string' constant that evaluates to s. func stringConst(s string) *Const { return NewConst(constant.MakeString(s), tString) } -// zeroConst returns a new "zero" constant of the specified type, -// which must not be an array or struct type: the zero values of -// aggregates are well-defined but cannot be represented by Const. +// zeroConst returns a new "zero" constant of the specified type. func zeroConst(t types.Type) *Const { - switch t := t.(type) { - case *types.Basic: - switch { - case t.Info()&types.IsBoolean != 0: - return NewConst(constant.MakeBool(false), t) - case t.Info()&types.IsNumeric != 0: - return NewConst(constant.MakeInt64(0), t) - case t.Info()&types.IsString != 0: - return NewConst(constant.MakeString(""), t) - case t.Kind() == types.UnsafePointer: - fallthrough - case t.Kind() == types.UntypedNil: - return nilConst(t) - default: - panic(fmt.Sprint("zeroConst for unexpected type:", t)) - } - case *types.Pointer, *types.Slice, *types.Interface, *types.Chan, *types.Map, *types.Signature: - return nilConst(t) - case *types.Named: - return NewConst(zeroConst(t.Underlying()).Value, t) - case *types.Array, *types.Struct, *types.Tuple: - panic(fmt.Sprint("zeroConst applied to aggregate:", t)) - } - panic(fmt.Sprint("zeroConst: unexpected ", t)) + return NewConst(nil, t) } func (c *Const) RelString(from *types.Package) string { var s string if c.Value == nil { - s = "nil" + s = zeroString(c.typ, from) } else if c.Value.Kind() == constant.String { s = constant.StringVal(c.Value) const max = 20 @@ -85,6 +93,44 @@ func (c *Const) RelString(from *types.Package) string { return s + ":" + relType(c.Type(), from) } +// zeroString returns the string representation of the "zero" value of the type t. +func zeroString(t types.Type, from *types.Package) string { + switch t := t.(type) { + case *types.Basic: + switch { + case t.Info()&types.IsBoolean != 0: + return "false" + case t.Info()&types.IsNumeric != 0: + return "0" + case t.Info()&types.IsString != 0: + return `""` + case t.Kind() == types.UnsafePointer: + fallthrough + case t.Kind() == types.UntypedNil: + return "nil" + default: + panic(fmt.Sprint("zeroString for unexpected type:", t)) + } + case *types.Pointer, *types.Slice, *types.Interface, *types.Chan, *types.Map, *types.Signature: + return "nil" + case *types.Named: + return zeroString(t.Underlying(), from) + case *types.Array, *types.Struct: + return relType(t, from) + "{}" + case *types.Tuple: + // Tuples are not normal values. + // We are currently format as "(t[0], ..., t[n])". Could be something else. + components := make([]string, t.Len()) + for i := 0; i < t.Len(); i++ { + components[i] = zeroString(t.At(i).Type(), from) + } + return "(" + strings.Join(components, ", ") + ")" + case *typeparams.TypeParam: + return "*new(" + relType(t, from) + ")" + } + panic(fmt.Sprint("zeroString: unexpected ", t)) +} + func (c *Const) Name() string { return c.RelString(nil) } @@ -107,9 +153,26 @@ func (c *Const) Pos() token.Pos { return token.NoPos } -// IsNil returns true if this constant represents a typed or untyped nil value. +// IsNil returns true if this constant represents a typed or untyped nil value +// with an underlying reference type: pointer, slice, chan, map, function, or +// *basic* interface. +// +// Note: a type parameter whose underlying type is a basic interface is +// considered a reference type. func (c *Const) IsNil() bool { - return c.Value == nil + return c.Value == nil && nillable(c.typ) +} + +// nillable reports whether *new(T) == nil is legal for type T. +func nillable(t types.Type) bool { + switch t := t.Underlying().(type) { + case *types.Pointer, *types.Slice, *types.Chan, *types.Map, *types.Signature: + return true + case *types.Interface: + return len(typeSetOf(t)) == 0 // basic interface. + default: + return false + } } // TODO(adonovan): move everything below into golang.org/x/tools/go/ssa/interp. @@ -149,14 +212,16 @@ func (c *Const) Uint64() uint64 { // Float64 returns the numeric value of this constant truncated to fit // a float64. func (c *Const) Float64() float64 { - f, _ := constant.Float64Val(c.Value) + x := constant.ToFloat(c.Value) // (c.Value == nil) => x.Kind() == Unknown + f, _ := constant.Float64Val(x) return f } // Complex128 returns the complex value of this constant truncated to // fit a complex128. func (c *Const) Complex128() complex128 { - re, _ := constant.Float64Val(constant.Real(c.Value)) - im, _ := constant.Float64Val(constant.Imag(c.Value)) + x := constant.ToComplex(c.Value) // (c.Value == nil) => x.Kind() == Unknown + re, _ := constant.Float64Val(constant.Real(x)) + im, _ := constant.Float64Val(constant.Imag(x)) return complex(re, im) } diff --git a/go/ssa/const_test.go b/go/ssa/const_test.go new file mode 100644 index 0000000000..131fe1aced --- /dev/null +++ b/go/ssa/const_test.go @@ -0,0 +1,104 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa_test + +import ( + "go/ast" + "go/constant" + "go/parser" + "go/token" + "go/types" + "math/big" + "strings" + "testing" + + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/internal/typeparams" +) + +func TestConstString(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestConstString requires type parameters.") + } + + const source = ` + package P + + type Named string + + func fn() (int, bool, string) + func gen[T int]() {} + ` + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", source, 0) + if err != nil { + t.Fatal(err) + } + + var conf types.Config + pkg, err := conf.Check("P", fset, []*ast.File{f}, nil) + if err != nil { + t.Fatal(err) + } + + for _, test := range []struct { + expr string // type expression + constant interface{} // constant value + want string // expected String() value + }{ + {"int", int64(0), "0:int"}, + {"int64", int64(0), "0:int64"}, + {"float32", int64(0), "0:float32"}, + {"float32", big.NewFloat(1.5), "1.5:float32"}, + {"bool", false, "false:bool"}, + {"string", "", `"":string`}, + {"Named", "", `"":P.Named`}, + {"struct{x string}", nil, "struct{x string}{}:struct{x string}"}, + {"[]int", nil, "nil:[]int"}, + {"[3]int", nil, "[3]int{}:[3]int"}, + {"*int", nil, "nil:*int"}, + {"interface{}", nil, "nil:interface{}"}, + {"interface{string}", nil, `"":interface{string}`}, + {"interface{int|int64}", nil, "0:interface{int|int64}"}, + {"interface{bool}", nil, "false:interface{bool}"}, + {"interface{bool|int}", nil, "nil:interface{bool|int}"}, + {"interface{int|string}", nil, "nil:interface{int|string}"}, + {"interface{bool|string}", nil, "nil:interface{bool|string}"}, + {"interface{struct{x string}}", nil, "nil:interface{struct{x string}}"}, + {"interface{int|int64}", int64(1), "1:interface{int|int64}"}, + {"interface{~bool}", true, "true:interface{~bool}"}, + {"interface{Named}", "lorem ipsum", `"lorem ipsum":interface{P.Named}`}, + {"func() (int, bool, string)", nil, "nil:func() (int, bool, string)"}, + } { + // Eval() expr for its type. + tv, err := types.Eval(fset, pkg, 0, test.expr) + if err != nil { + t.Fatalf("Eval(%s) failed: %v", test.expr, err) + } + var val constant.Value + if test.constant != nil { + val = constant.Make(test.constant) + } + c := ssa.NewConst(val, tv.Type) + got := strings.ReplaceAll(c.String(), " | ", "|") // Accept both interface{a | b} and interface{a|b}. + if got != test.want { + t.Errorf("ssa.NewConst(%v, %s).String() = %v, want %v", val, tv.Type, got, test.want) + } + } + + // Test tuples + fn := pkg.Scope().Lookup("fn") + tup := fn.Type().(*types.Signature).Results() + if got, want := ssa.NewConst(nil, tup).String(), `(0, false, ""):(int, bool, string)`; got != want { + t.Errorf("ssa.NewConst(%v, %s).String() = %v, want %v", nil, tup, got, want) + } + + // Test type-param + gen := pkg.Scope().Lookup("gen") + tp := typeparams.ForSignature(gen.Type().(*types.Signature)).At(0) + if got, want := ssa.NewConst(nil, tp).String(), "0:T"; got != want { + t.Errorf("ssa.NewConst(%v, %s).String() = %v, want %v", nil, tup, got, want) + } +} diff --git a/go/ssa/coretype.go b/go/ssa/coretype.go new file mode 100644 index 0000000000..54bc4a8e6d --- /dev/null +++ b/go/ssa/coretype.go @@ -0,0 +1,256 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa + +import ( + "go/types" + + "golang.org/x/tools/internal/typeparams" +) + +// Utilities for dealing with core types. + +// coreType returns the core type of T or nil if T does not have a core type. +// +// See https://go.dev/ref/spec#Core_types for the definition of a core type. +func coreType(T types.Type) types.Type { + U := T.Underlying() + if _, ok := U.(*types.Interface); !ok { + return U // for non-interface types, + } + + terms, err := _NormalTerms(U) + if len(terms) == 0 || err != nil { + // len(terms) -> empty type set of interface. + // err != nil => U is invalid, exceeds complexity bounds, or has an empty type set. + return nil // no core type. + } + + U = terms[0].Type().Underlying() + var identical int // i in [0,identical) => Identical(U, terms[i].Type().Underlying()) + for identical = 1; identical < len(terms); identical++ { + if !types.Identical(U, terms[identical].Type().Underlying()) { + break + } + } + + if identical == len(terms) { + // https://go.dev/ref/spec#Core_types + // "There is a single type U which is the underlying type of all types in the type set of T" + return U + } + ch, ok := U.(*types.Chan) + if !ok { + return nil // no core type as identical < len(terms) and U is not a channel. + } + // https://go.dev/ref/spec#Core_types + // "the type chan E if T contains only bidirectional channels, or the type chan<- E or + // <-chan E depending on the direction of the directional channels present." + for chans := identical; chans < len(terms); chans++ { + curr, ok := terms[chans].Type().Underlying().(*types.Chan) + if !ok { + return nil + } + if !types.Identical(ch.Elem(), curr.Elem()) { + return nil // channel elements are not identical. + } + if ch.Dir() == types.SendRecv { + // ch is bidirectional. We can safely always use curr's direction. + ch = curr + } else if curr.Dir() != types.SendRecv && ch.Dir() != curr.Dir() { + // ch and curr are not bidirectional and not the same direction. + return nil + } + } + return ch +} + +// isBytestring returns true if T has the same terms as interface{[]byte | string}. +// These act like a coreType for some operations: slice expressions, append and copy. +// +// See https://go.dev/ref/spec#Core_types for the details on bytestring. +func isBytestring(T types.Type) bool { + U := T.Underlying() + if _, ok := U.(*types.Interface); !ok { + return false + } + + tset := typeSetOf(U) + if len(tset) != 2 { + return false + } + hasBytes, hasString := false, false + tset.underIs(func(t types.Type) bool { + switch { + case isString(t): + hasString = true + case isByteSlice(t): + hasBytes = true + } + return hasBytes || hasString + }) + return hasBytes && hasString +} + +// _NormalTerms returns a slice of terms representing the normalized structural +// type restrictions of a type, if any. +// +// For all types other than *types.TypeParam, *types.Interface, and +// *types.Union, this is just a single term with Tilde() == false and +// Type() == typ. For *types.TypeParam, *types.Interface, and *types.Union, see +// below. +// +// Structural type restrictions of a type parameter are created via +// non-interface types embedded in its constraint interface (directly, or via a +// chain of interface embeddings). For example, in the declaration type +// T[P interface{~int; m()}] int the structural restriction of the type +// parameter P is ~int. +// +// With interface embedding and unions, the specification of structural type +// restrictions may be arbitrarily complex. For example, consider the +// following: +// +// type A interface{ ~string|~[]byte } +// +// type B interface{ int|string } +// +// type C interface { ~string|~int } +// +// type T[P interface{ A|B; C }] int +// +// In this example, the structural type restriction of P is ~string|int: A|B +// expands to ~string|~[]byte|int|string, which reduces to ~string|~[]byte|int, +// which when intersected with C (~string|~int) yields ~string|int. +// +// _NormalTerms computes these expansions and reductions, producing a +// "normalized" form of the embeddings. A structural restriction is normalized +// if it is a single union containing no interface terms, and is minimal in the +// sense that removing any term changes the set of types satisfying the +// constraint. It is left as a proof for the reader that, modulo sorting, there +// is exactly one such normalized form. +// +// Because the minimal representation always takes this form, _NormalTerms +// returns a slice of tilde terms corresponding to the terms of the union in +// the normalized structural restriction. An error is returned if the type is +// invalid, exceeds complexity bounds, or has an empty type set. In the latter +// case, _NormalTerms returns ErrEmptyTypeSet. +// +// _NormalTerms makes no guarantees about the order of terms, except that it +// is deterministic. +// +// This is a copy of x/exp/typeparams.NormalTerms which x/tools cannot depend on. +// TODO(taking): Remove this copy when possible. +func _NormalTerms(typ types.Type) ([]*typeparams.Term, error) { + switch typ := typ.(type) { + case *typeparams.TypeParam: + return typeparams.StructuralTerms(typ) + case *typeparams.Union: + return typeparams.UnionTermSet(typ) + case *types.Interface: + return typeparams.InterfaceTermSet(typ) + default: + return []*typeparams.Term{typeparams.NewTerm(false, typ)}, nil + } +} + +// typeSetOf returns the type set of typ. Returns an empty typeset on an error. +func typeSetOf(typ types.Type) typeSet { + terms, err := _NormalTerms(typ) + if err != nil { + return nil + } + return terms +} + +type typeSet []*typeparams.Term // type terms of the type set + +// underIs calls f with the underlying types of the specific type terms +// of s and reports whether all calls to f returned true. If there are +// no specific terms, underIs returns the result of f(nil). +func (s typeSet) underIs(f func(types.Type) bool) bool { + if len(s) == 0 { + return f(nil) + } + for _, t := range s { + u := t.Type().Underlying() + if !f(u) { + return false + } + } + return true +} + +// indexType returns the element type and index mode of a IndexExpr over a type. +// It returns (nil, invalid) if the type is not indexable; this should never occur in a well-typed program. +func indexType(typ types.Type) (types.Type, indexMode) { + switch U := typ.Underlying().(type) { + case *types.Array: + return U.Elem(), ixArrVar + case *types.Pointer: + if arr, ok := U.Elem().Underlying().(*types.Array); ok { + return arr.Elem(), ixVar + } + case *types.Slice: + return U.Elem(), ixVar + case *types.Map: + return U.Elem(), ixMap + case *types.Basic: + return tByte, ixValue // must be a string + case *types.Interface: + terms, err := _NormalTerms(U) + if len(terms) == 0 || err != nil { + return nil, ixInvalid // no underlying terms or error is empty. + } + + elem, mode := indexType(terms[0].Type()) + for i := 1; i < len(terms) && mode != ixInvalid; i++ { + e, m := indexType(terms[i].Type()) + if !types.Identical(elem, e) { // if type checked, just a sanity check + return nil, ixInvalid + } + // Update the mode to the most constrained address type. + mode = mode.meet(m) + } + if mode != ixInvalid { + return elem, mode + } + } + return nil, ixInvalid +} + +// An indexMode specifies the (addressing) mode of an index operand. +// +// Addressing mode of an index operation is based on the set of +// underlying types. +// Hasse diagram of the indexMode meet semi-lattice: +// +// ixVar ixMap +// | | +// ixArrVar | +// | | +// ixValue | +// \ / +// ixInvalid +type indexMode byte + +const ( + ixInvalid indexMode = iota // index is invalid + ixValue // index is a computed value (not addressable) + ixArrVar // like ixVar, but index operand contains an array + ixVar // index is an addressable variable + ixMap // index is a map index expression (acts like a variable on lhs, commaok on rhs of an assignment) +) + +// meet is the address type that is constrained by both x and y. +func (x indexMode) meet(y indexMode) indexMode { + if (x == ixMap || y == ixMap) && x != y { + return ixInvalid + } + // Use int representation and return min. + if x < y { + return y + } + return x +} diff --git a/go/ssa/coretype_test.go b/go/ssa/coretype_test.go new file mode 100644 index 0000000000..c4ed290fd8 --- /dev/null +++ b/go/ssa/coretype_test.go @@ -0,0 +1,105 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa + +import ( + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "golang.org/x/tools/internal/typeparams" +) + +func TestCoreType(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestCoreType requires type parameters.") + } + + const source = ` + package P + + type Named int + + type A any + type B interface{~int} + type C interface{int} + type D interface{Named} + type E interface{~int|interface{Named}} + type F interface{~int|~float32} + type G interface{chan int|interface{chan int}} + type H interface{chan int|chan float32} + type I interface{chan<- int|chan int} + type J interface{chan int|chan<- int} + type K interface{<-chan int|chan int} + type L interface{chan int|<-chan int} + type M interface{chan int|chan Named} + type N interface{<-chan int|chan<- int} + type O interface{chan int|bool} + type P struct{ Named } + type Q interface{ Foo() } + type R interface{ Foo() ; Named } + type S interface{ Foo() ; ~int } + + type T interface{chan int|interface{chan int}|<-chan int} +` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "hello.go", source, 0) + if err != nil { + t.Fatal(err) + } + + var conf types.Config + pkg, err := conf.Check("P", fset, []*ast.File{f}, nil) + if err != nil { + t.Fatal(err) + } + + for _, test := range []struct { + expr string // type expression of Named type + want string // expected core type (or "" if none) + }{ + {"Named", "int"}, // Underlying type is not interface. + {"A", ""}, // Interface has no terms. + {"B", "int"}, // Tilde term. + {"C", "int"}, // Non-tilde term. + {"D", "int"}, // Named term. + {"E", "int"}, // Identical underlying types. + {"F", ""}, // Differing underlying types. + {"G", "chan int"}, // Identical Element types. + {"H", ""}, // Element type int has differing underlying type to float32. + {"I", "chan<- int"}, // SendRecv followed by SendOnly + {"J", "chan<- int"}, // SendOnly followed by SendRecv + {"K", "<-chan int"}, // RecvOnly followed by SendRecv + {"L", "<-chan int"}, // SendRecv followed by RecvOnly + {"M", ""}, // Element type int is not *identical* to Named. + {"N", ""}, // Differing channel directions + {"O", ""}, // A channel followed by a non-channel. + {"P", "struct{P.Named}"}, // Embedded type. + {"Q", ""}, // interface type with no terms and functions + {"R", "int"}, // interface type with both terms and functions. + {"S", "int"}, // interface type with a tilde term + {"T", "<-chan int"}, // Prefix of 2 terms that are identical before switching to channel. + } { + // Eval() expr for its type. + tv, err := types.Eval(fset, pkg, 0, test.expr) + if err != nil { + t.Fatalf("Eval(%s) failed: %v", test.expr, err) + } + + ct := coreType(tv.Type) + var got string + if ct == nil { + got = "" + } else { + got = ct.String() + } + if got != test.want { + t.Errorf("coreType(%s) = %v, want %v", test.expr, got, test.want) + } + } +} diff --git a/go/ssa/create.go b/go/ssa/create.go index 345d9acfbb..ccb20e7968 100644 --- a/go/ssa/create.go +++ b/go/ssa/create.go @@ -91,37 +91,31 @@ func memberFromObject(pkg *Package, obj types.Object, syntax ast.Node) { } // Collect type parameters if this is a generic function/method. - var tparams []*typeparams.TypeParam - for i, rtparams := 0, typeparams.RecvTypeParams(sig); i < rtparams.Len(); i++ { - tparams = append(tparams, rtparams.At(i)) - } - for i, sigparams := 0, typeparams.ForSignature(sig); i < sigparams.Len(); i++ { - tparams = append(tparams, sigparams.At(i)) + var tparams *typeparams.TypeParamList + if rtparams := typeparams.RecvTypeParams(sig); rtparams.Len() > 0 { + tparams = rtparams + } else if sigparams := typeparams.ForSignature(sig); sigparams.Len() > 0 { + tparams = sigparams } fn := &Function{ - name: name, - object: obj, - Signature: sig, - syntax: syntax, - pos: obj.Pos(), - Pkg: pkg, - Prog: pkg.Prog, - _TypeParams: tparams, - info: pkg.info, + name: name, + object: obj, + Signature: sig, + syntax: syntax, + pos: obj.Pos(), + Pkg: pkg, + Prog: pkg.Prog, + typeparams: tparams, + info: pkg.info, } pkg.created.Add(fn) if syntax == nil { fn.Synthetic = "loaded from gc object file" } - if len(tparams) > 0 { + if tparams.Len() > 0 { fn.Prog.createInstanceSet(fn) } - if len(tparams) > 0 && syntax != nil { - fn.Synthetic = "generic function" - // TODO(taking): Allow for the function to be built once type params are supported. - fn.syntax = nil // Treating as an external function temporarily. - } pkg.objects[obj] = fn if sig.Recv() == nil { diff --git a/go/ssa/emit.go b/go/ssa/emit.go index f6537acc97..b041491b6e 100644 --- a/go/ssa/emit.go +++ b/go/ssa/emit.go @@ -121,9 +121,9 @@ func emitCompare(f *Function, op token.Token, x, y Value, pos token.Pos) Value { if types.Identical(xt, yt) { // no conversion necessary - } else if _, ok := xt.(*types.Interface); ok { + } else if isNonTypeParamInterface(x.Type()) { y = emitConv(f, y, x.Type()) - } else if _, ok := yt.(*types.Interface); ok { + } else if isNonTypeParamInterface(y.Type()) { x = emitConv(f, x, y.Type()) } else if _, ok := x.(*Const); ok { x = emitConv(f, x, y.Type()) @@ -166,6 +166,32 @@ func isValuePreserving(ut_src, ut_dst types.Type) bool { return false } +// isSliceToArrayPointer reports whether ut_src is a slice type +// that can be converted to a pointer to an array type ut_dst. +// Precondition: neither argument is a named type. +func isSliceToArrayPointer(ut_src, ut_dst types.Type) bool { + if slice, ok := ut_src.(*types.Slice); ok { + if ptr, ok := ut_dst.(*types.Pointer); ok { + if arr, ok := ptr.Elem().Underlying().(*types.Array); ok { + return types.Identical(slice.Elem(), arr.Elem()) + } + } + } + return false +} + +// isSliceToArray reports whether ut_src is a slice type +// that can be converted to an array type ut_dst. +// Precondition: neither argument is a named type. +func isSliceToArray(ut_src, ut_dst types.Type) bool { + if slice, ok := ut_src.(*types.Slice); ok { + if arr, ok := ut_dst.(*types.Array); ok { + return types.Identical(slice.Elem(), arr.Elem()) + } + } + return false +} + // emitConv emits to f code to convert Value val to exactly type typ, // and returns the converted value. Implicit conversions are required // by language assignability rules in assignments, parameter passing, @@ -180,17 +206,25 @@ func emitConv(f *Function, val Value, typ types.Type) Value { ut_dst := typ.Underlying() ut_src := t_src.Underlying() + dst_types := typeSetOf(ut_dst) + src_types := typeSetOf(ut_src) + // Just a change of type, but not value or representation? - if isValuePreserving(ut_src, ut_dst) { + preserving := src_types.underIs(func(s types.Type) bool { + return dst_types.underIs(func(d types.Type) bool { + return s != nil && d != nil && isValuePreserving(s, d) // all (s -> d) are value preserving. + }) + }) + if preserving { c := &ChangeType{X: val} c.setType(typ) return f.emit(c) } // Conversion to, or construction of a value of, an interface type? - if _, ok := ut_dst.(*types.Interface); ok { + if isNonTypeParamInterface(typ) { // Assignment from one interface type to another? - if _, ok := ut_src.(*types.Interface); ok { + if isNonTypeParamInterface(t_src) { c := &ChangeInterface{X: val} c.setType(typ) return f.emit(c) @@ -198,7 +232,7 @@ func emitConv(f *Function, val Value, typ types.Type) Value { // Untyped nil constant? Return interface-typed nil constant. if ut_src == tUntypedNil { - return nilConst(typ) + return zeroConst(typ) } // Convert (non-nil) "untyped" literals to their default type. @@ -213,7 +247,7 @@ func emitConv(f *Function, val Value, typ types.Type) Value { // Conversion of a compile-time constant value? if c, ok := val.(*Const); ok { - if _, ok := ut_dst.(*types.Basic); ok || c.IsNil() { + if isBasic(ut_dst) || c.Value == nil { // Conversion of a compile-time constant to // another constant type results in a new // constant of the destination type and @@ -227,41 +261,30 @@ func emitConv(f *Function, val Value, typ types.Type) Value { } // Conversion from slice to array pointer? - if slice, ok := ut_src.(*types.Slice); ok { - switch t := ut_dst.(type) { - case *types.Pointer: - ptr := t - if arr, ok := ptr.Elem().Underlying().(*types.Array); ok && types.Identical(slice.Elem(), arr.Elem()) { - c := &SliceToArrayPointer{X: val} - // TODO(taking): Check if this should be ut_dst or ptr. - c.setType(ptr) - return f.emit(c) - } - case *types.Array: - arr := t - if arr.Len() == 0 { - return zeroValue(f, arr) - } - if types.Identical(slice.Elem(), arr.Elem()) { - c := &SliceToArrayPointer{X: val} - c.setType(types.NewPointer(arr)) - x := f.emit(c) - unOp := &UnOp{ - Op: token.MUL, - X: x, - CommaOk: false, - } - unOp.setType(typ) - return f.emit(unOp) - } - } + slice2ptr := src_types.underIs(func(s types.Type) bool { + return dst_types.underIs(func(d types.Type) bool { + return s != nil && d != nil && isSliceToArrayPointer(s, d) // all (s->d) are slice to array pointer conversion. + }) + }) + if slice2ptr { + c := &SliceToArrayPointer{X: val} + c.setType(typ) + return f.emit(c) } + + // Conversion from slice to array? + slice2array := src_types.underIs(func(s types.Type) bool { + return dst_types.underIs(func(d types.Type) bool { + return s != nil && d != nil && isSliceToArray(s, d) // all (s->d) are slice to array conversion. + }) + }) + if slice2array { + return emitSliceToArray(f, val, typ) + } + // A representation-changing conversion? - // At least one of {ut_src,ut_dst} must be *Basic. - // (The other may be []byte or []rune.) - _, ok1 := ut_src.(*types.Basic) - _, ok2 := ut_dst.(*types.Basic) - if ok1 || ok2 { + // All of ut_src or ut_dst is basic, byte slice, or rune slice? + if isBasicConvTypes(src_types) || isBasicConvTypes(dst_types) { c := &Convert{X: val} c.setType(typ) return f.emit(c) @@ -270,6 +293,33 @@ func emitConv(f *Function, val Value, typ types.Type) Value { panic(fmt.Sprintf("in %s: cannot convert %s (%s) to %s", f, val, val.Type(), typ)) } +// emitTypeCoercion emits to f code to coerce the type of a +// Value v to exactly type typ, and returns the coerced value. +// +// Requires that coercing v.Typ() to typ is a value preserving change. +// +// Currently used only when v.Type() is a type instance of typ or vice versa. +// A type v is a type instance of a type t if there exists a +// type parameter substitution σ s.t. σ(v) == t. Example: +// +// σ(func(T) T) == func(int) int for σ == [T ↦ int] +// +// This happens in instantiation wrappers for conversion +// from an instantiation to a parameterized type (and vice versa) +// with σ substituting f.typeparams by f.typeargs. +func emitTypeCoercion(f *Function, v Value, typ types.Type) Value { + if types.Identical(v.Type(), typ) { + return v // no coercion needed + } + // TODO(taking): for instances should we record which side is the instance? + c := &ChangeType{ + X: v, + } + c.setType(typ) + f.emit(c) + return c +} + // emitStore emits to f an instruction to store value val at location // addr, applying implicit conversions as required by assignability rules. func emitStore(f *Function, addr, val Value, pos token.Pos) *Store { @@ -378,7 +428,7 @@ func emitTailCall(f *Function, call *Call) { // value of a field. func emitImplicitSelections(f *Function, v Value, indices []int, pos token.Pos) Value { for _, index := range indices { - fld := deref(v.Type()).Underlying().(*types.Struct).Field(index) + fld := coreType(deref(v.Type())).(*types.Struct).Field(index) if isPointer(v.Type()) { instr := &FieldAddr{ @@ -412,7 +462,7 @@ func emitImplicitSelections(f *Function, v Value, indices []int, pos token.Pos) // field's value. // Ident id is used for position and debug info. func emitFieldSelection(f *Function, v Value, index int, wantAddr bool, id *ast.Ident) Value { - fld := deref(v.Type()).Underlying().(*types.Struct).Field(index) + fld := coreType(deref(v.Type())).(*types.Struct).Field(index) if isPointer(v.Type()) { instr := &FieldAddr{ X: v, @@ -438,6 +488,48 @@ func emitFieldSelection(f *Function, v Value, index int, wantAddr bool, id *ast. return v } +// emitSliceToArray emits to f code to convert a slice value to an array value. +// +// Precondition: all types in type set of typ are arrays and convertible to all +// types in the type set of val.Type(). +func emitSliceToArray(f *Function, val Value, typ types.Type) Value { + // Emit the following: + // if val == nil && len(typ) == 0 { + // ptr = &[0]T{} + // } else { + // ptr = SliceToArrayPointer(val) + // } + // v = *ptr + + ptype := types.NewPointer(typ) + p := &SliceToArrayPointer{X: val} + p.setType(ptype) + ptr := f.emit(p) + + nilb := f.newBasicBlock("slicetoarray.nil") + nonnilb := f.newBasicBlock("slicetoarray.nonnil") + done := f.newBasicBlock("slicetoarray.done") + + cond := emitCompare(f, token.EQL, ptr, zeroConst(ptype), token.NoPos) + emitIf(f, cond, nilb, nonnilb) + f.currentBlock = nilb + + zero := f.addLocal(typ, token.NoPos) + emitJump(f, done) + f.currentBlock = nonnilb + + emitJump(f, done) + f.currentBlock = done + + phi := &Phi{Edges: []Value{zero, ptr}, Comment: "slicetoarray"} + phi.pos = val.Pos() + phi.setType(typ) + x := f.emit(phi) + unOp := &UnOp{Op: token.MUL, X: x} + unOp.setType(typ) + return f.emit(unOp) +} + // zeroValue emits to f code to produce a zero value of type t, // and returns it. func zeroValue(f *Function, t types.Type) Value { diff --git a/go/ssa/func.go b/go/ssa/func.go index c598ff836d..57f5f718f7 100644 --- a/go/ssa/func.go +++ b/go/ssa/func.go @@ -251,7 +251,10 @@ func buildReferrers(f *Function) { } // mayNeedRuntimeTypes returns all of the types in the body of fn that might need runtime types. +// +// EXCLUSIVE_LOCKS_ACQUIRED(meth.Prog.methodsMu) func mayNeedRuntimeTypes(fn *Function) []types.Type { + // Collect all types that may need rtypes, i.e. those that flow into an interface. var ts []types.Type for _, bb := range fn.Blocks { for _, instr := range bb.Instrs { @@ -260,7 +263,21 @@ func mayNeedRuntimeTypes(fn *Function) []types.Type { } } } - return ts + + // Types that contain a parameterized type are considered to not be runtime types. + if fn.typeparams.Len() == 0 { + return ts // No potentially parameterized types. + } + // Filter parameterized types, in place. + fn.Prog.methodsMu.Lock() + defer fn.Prog.methodsMu.Unlock() + filtered := ts[:0] + for _, t := range ts { + if !fn.Prog.parameterized.isParameterized(t) { + filtered = append(filtered, t) + } + } + return filtered } // finishBody() finalizes the contents of the function after SSA code generation of its body. @@ -518,8 +535,8 @@ func (fn *Function) declaredPackage() *Package { switch { case fn.Pkg != nil: return fn.Pkg // non-generic function - case fn._Origin != nil: - return fn._Origin.Pkg // instance of a named generic function + case fn.topLevelOrigin != nil: + return fn.topLevelOrigin.Pkg // instance of a named generic function case fn.parent != nil: return fn.parent.declaredPackage() // instance of an anonymous [generic] function default: diff --git a/go/ssa/instantiate.go b/go/ssa/instantiate.go index 049b53487d..f73594bb41 100644 --- a/go/ssa/instantiate.go +++ b/go/ssa/instantiate.go @@ -18,7 +18,7 @@ import ( // // This is an experimental interface! It may change without warning. func (prog *Program) _Instances(fn *Function) []*Function { - if len(fn._TypeParams) == 0 { + if fn.typeparams.Len() == 0 || len(fn.typeargs) > 0 { return nil } @@ -29,7 +29,7 @@ func (prog *Program) _Instances(fn *Function) []*Function { // A set of instantiations of a generic function fn. type instanceSet struct { - fn *Function // len(fn._TypeParams) > 0 and len(fn._TypeArgs) == 0. + fn *Function // fn.typeparams.Len() > 0 and len(fn.typeargs) == 0. instances map[*typeList]*Function // canonical type arguments to an instance. syntax *ast.FuncDecl // fn.syntax copy for instantiating after fn is done. nil on synthetic packages. info *types.Info // fn.pkg.info copy for building after fn is done.. nil on synthetic packages. @@ -56,7 +56,7 @@ func (insts *instanceSet) list() []*Function { // // EXCLUSIVE_LOCKS_ACQUIRED(prog.methodMu) func (prog *Program) createInstanceSet(fn *Function) { - assert(len(fn._TypeParams) > 0 && len(fn._TypeArgs) == 0, "Can only create instance sets for generic functions") + assert(fn.typeparams.Len() > 0 && len(fn.typeargs) == 0, "Can only create instance sets for generic functions") prog.methodsMu.Lock() defer prog.methodsMu.Unlock() @@ -73,7 +73,7 @@ func (prog *Program) createInstanceSet(fn *Function) { } } -// needsInstance returns an Function that that is the instantiation of fn with the type arguments targs. +// needsInstance returns a Function that is the instantiation of fn with the type arguments targs. // // Any CREATEd instance is added to cr. // @@ -82,41 +82,45 @@ func (prog *Program) needsInstance(fn *Function, targs []types.Type, cr *creator prog.methodsMu.Lock() defer prog.methodsMu.Unlock() - return prog.instances[fn].lookupOrCreate(targs, cr) + return prog.lookupOrCreateInstance(fn, targs, cr) +} + +// lookupOrCreateInstance returns a Function that is the instantiation of fn with the type arguments targs. +// +// Any CREATEd instance is added to cr. +// +// EXCLUSIVE_LOCKS_REQUIRED(prog.methodMu) +func (prog *Program) lookupOrCreateInstance(fn *Function, targs []types.Type, cr *creator) *Function { + return prog.instances[fn].lookupOrCreate(targs, &prog.parameterized, cr) } // lookupOrCreate returns the instantiation of insts.fn using targs. -// If the instantiation is reported, this is added to cr. -func (insts *instanceSet) lookupOrCreate(targs []types.Type, cr *creator) *Function { +// If the instantiation is created, this is added to cr. +func (insts *instanceSet) lookupOrCreate(targs []types.Type, parameterized *tpWalker, cr *creator) *Function { if insts.instances == nil { insts.instances = make(map[*typeList]*Function) } + fn := insts.fn + prog := fn.Prog + // canonicalize on a tuple of targs. Sig is not unique. // // func A[T any]() { // var x T // fmt.Println("%T", x) // } - key := insts.fn.Prog.canon.List(targs) + key := prog.canon.List(targs) if inst, ok := insts.instances[key]; ok { return inst } + // CREATE instance/instantiation wrapper var syntax ast.Node if insts.syntax != nil { syntax = insts.syntax } - instance := createInstance(insts.fn, targs, insts.info, syntax, cr) - insts.instances[key] = instance - return instance -} -// createInstance returns an CREATEd instantiation of fn using targs. -// -// Function is added to cr. -func createInstance(fn *Function, targs []types.Type, info *types.Info, syntax ast.Node, cr *creator) *Function { - prog := fn.Prog var sig *types.Signature var obj *types.Func if recv := fn.Signature.Recv(); recv != nil { @@ -137,25 +141,36 @@ func createInstance(fn *Function, targs []types.Type, info *types.Info, syntax a sig = prog.canon.Type(instance).(*types.Signature) } + var synthetic string + var subst *subster + + concrete := !parameterized.anyParameterized(targs) + + if prog.mode&InstantiateGenerics != 0 && concrete { + synthetic = fmt.Sprintf("instance of %s", fn.Name()) + subst = makeSubster(prog.ctxt, fn.typeparams, targs, false) + } else { + synthetic = fmt.Sprintf("instantiation wrapper of %s", fn.Name()) + } + name := fmt.Sprintf("%s%s", fn.Name(), targs) // may not be unique - synthetic := fmt.Sprintf("instantiation of %s", fn.Name()) instance := &Function{ - name: name, - object: obj, - Signature: sig, - Synthetic: synthetic, - _Origin: fn, - pos: obj.Pos(), - Pkg: nil, - Prog: fn.Prog, - _TypeParams: fn._TypeParams, - _TypeArgs: targs, - info: info, // on synthetic packages info is nil. - subst: makeSubster(prog.ctxt, fn._TypeParams, targs, false), - } - if prog.mode&InstantiateGenerics != 0 { - instance.syntax = syntax // otherwise treat instance as an external function. + name: name, + object: obj, + Signature: sig, + Synthetic: synthetic, + syntax: syntax, + topLevelOrigin: fn, + pos: obj.Pos(), + Pkg: nil, + Prog: fn.Prog, + typeparams: fn.typeparams, // share with origin + typeargs: targs, + info: insts.info, // on synthetic packages info is nil. + subst: subst, } + cr.Add(instance) + insts.instances[key] = instance return instance } diff --git a/go/ssa/instantiate_test.go b/go/ssa/instantiate_test.go index 0da8c63042..cd33e7e659 100644 --- a/go/ssa/instantiate_test.go +++ b/go/ssa/instantiate_test.go @@ -4,19 +4,52 @@ package ssa -// Note: Tests use unexported functions. +// Note: Tests use unexported method _Instances. import ( "bytes" + "fmt" "go/types" "reflect" "sort" + "strings" "testing" "golang.org/x/tools/go/loader" "golang.org/x/tools/internal/typeparams" ) +// loadProgram creates loader.Program out of p. +func loadProgram(p string) (*loader.Program, error) { + // Parse + var conf loader.Config + f, err := conf.ParseFile("", p) + if err != nil { + return nil, fmt.Errorf("parse: %v", err) + } + conf.CreateFromFiles("p", f) + + // Load + lprog, err := conf.Load() + if err != nil { + return nil, fmt.Errorf("Load: %v", err) + } + return lprog, nil +} + +// buildPackage builds and returns ssa representation of package pkg of lprog. +func buildPackage(lprog *loader.Program, pkg string, mode BuilderMode) *Package { + prog := NewProgram(lprog.Fset, mode) + + for _, info := range lprog.AllPackages { + prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) + } + + p := prog.Package(lprog.Package(pkg).Pkg) + p.Build() + return p +} + // TestNeedsInstance ensures that new method instances can be created via needsInstance, // that TypeArgs are as expected, and can be accessed via _Instances. func TestNeedsInstance(t *testing.T) { @@ -45,30 +78,15 @@ func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) // func init func() // var init$guard bool - // Parse - var conf loader.Config - f, err := conf.ParseFile("", input) - if err != nil { - t.Fatalf("parse: %v", err) - } - conf.CreateFromFiles("p", f) - - // Load - lprog, err := conf.Load() - if err != nil { - t.Fatalf("Load: %v", err) + lprog, err := loadProgram(input) + if err != err { + t.Fatal(err) } for _, mode := range []BuilderMode{BuilderMode(0), InstantiateGenerics} { // Create and build SSA - prog := NewProgram(lprog.Fset, mode) - - for _, info := range lprog.AllPackages { - prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable) - } - - p := prog.Package(lprog.Package("p").Pkg) - p.Build() + p := buildPackage(lprog, "p", mode) + prog := p.Prog ptr := p.Type("Pointer").Type().(*types.Named) if ptr.NumMethods() != 1 { @@ -88,11 +106,11 @@ func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) if len(cr) != 1 { t.Errorf("Expected first instance to create a function. got %d created functions", len(cr)) } - if instance._Origin != meth { - t.Errorf("Expected Origin of %s to be %s. got %s", instance, meth, instance._Origin) + if instance.Origin() != meth { + t.Errorf("Expected Origin of %s to be %s. got %s", instance, meth, instance.Origin()) } - if len(instance._TypeArgs) != 1 || !types.Identical(instance._TypeArgs[0], intSliceTyp) { - t.Errorf("Expected TypeArgs of %s to be %v. got %v", instance, []types.Type{intSliceTyp}, instance._TypeArgs) + if len(instance.TypeArgs()) != 1 || !types.Identical(instance.TypeArgs()[0], intSliceTyp) { + t.Errorf("Expected TypeArgs of %s to be %v. got %v", instance, []types.Type{intSliceTyp}, instance.typeargs) } instances := prog._Instances(meth) if want := []*Function{instance}; !reflect.DeepEqual(instances, want) { @@ -126,3 +144,218 @@ func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) } } } + +// TestCallsToInstances checks that calles of calls to generic functions, +// without monomorphization, are wrappers around the origin generic function. +func TestCallsToInstances(t *testing.T) { + if !typeparams.Enabled { + return + } + const input = ` +package p + +type I interface { + Foo() +} + +type A int +func (a A) Foo() {} + +type J[T any] interface{ Bar() T } +type K[T any] struct{ J[T] } + +func Id[T any] (t T) T { + return t +} + +func Lambda[T I]() func() func(T) { + return func() func(T) { + return T.Foo + } +} + +func NoOp[T any]() {} + +func Bar[T interface { Foo(); ~int | ~string }, U any] (t T, u U) { + Id[U](u) + Id[T](t) +} + +func Make[T any]() interface{} { + NoOp[K[T]]() + return nil +} + +func entry(i int, a A) int { + Lambda[A]()()(a) + + x := Make[int]() + if j, ok := x.(interface{ Bar() int }); ok { + print(j) + } + + Bar[A, int](a, i) + + return Id[int](i) +} +` + lprog, err := loadProgram(input) + if err != err { + t.Fatal(err) + } + + p := buildPackage(lprog, "p", SanityCheckFunctions) + prog := p.Prog + + for _, ti := range []struct { + orig string + instance string + tparams string + targs string + chTypeInstrs int // number of ChangeType instructions in f's body + }{ + {"Id", "Id[int]", "[T]", "[int]", 2}, + {"Lambda", "Lambda[p.A]", "[T]", "[p.A]", 1}, + {"Make", "Make[int]", "[T]", "[int]", 0}, + {"NoOp", "NoOp[p.K[T]]", "[T]", "[p.K[T]]", 0}, + } { + test := ti + t.Run(test.instance, func(t *testing.T) { + f := p.Members[test.orig].(*Function) + if f == nil { + t.Fatalf("origin function not found") + } + + i := instanceOf(f, test.instance, prog) + if i == nil { + t.Fatalf("instance not found") + } + + // for logging on failures + var body strings.Builder + i.WriteTo(&body) + t.Log(body.String()) + + if len(i.Blocks) != 1 { + t.Fatalf("body has more than 1 block") + } + + if instrs := changeTypeInstrs(i.Blocks[0]); instrs != test.chTypeInstrs { + t.Errorf("want %v instructions; got %v", test.chTypeInstrs, instrs) + } + + if test.tparams != tparams(i) { + t.Errorf("want %v type params; got %v", test.tparams, tparams(i)) + } + + if test.targs != targs(i) { + t.Errorf("want %v type arguments; got %v", test.targs, targs(i)) + } + }) + } +} + +func instanceOf(f *Function, name string, prog *Program) *Function { + for _, i := range prog._Instances(f) { + if i.Name() == name { + return i + } + } + return nil +} + +func tparams(f *Function) string { + tplist := f.TypeParams() + var tps []string + for i := 0; i < tplist.Len(); i++ { + tps = append(tps, tplist.At(i).String()) + } + return fmt.Sprint(tps) +} + +func targs(f *Function) string { + var tas []string + for _, ta := range f.TypeArgs() { + tas = append(tas, ta.String()) + } + return fmt.Sprint(tas) +} + +func changeTypeInstrs(b *BasicBlock) int { + cnt := 0 + for _, i := range b.Instrs { + if _, ok := i.(*ChangeType); ok { + cnt++ + } + } + return cnt +} + +func TestInstanceUniqueness(t *testing.T) { + if !typeparams.Enabled { + return + } + const input = ` +package p + +func H[T any](t T) { + print(t) +} + +func F[T any](t T) { + H[T](t) + H[T](t) + H[T](t) +} + +func G[T any](t T) { + H[T](t) + H[T](t) +} + +func Foo[T any, S any](t T, s S) { + Foo[S, T](s, t) + Foo[T, S](t, s) +} +` + lprog, err := loadProgram(input) + if err != err { + t.Fatal(err) + } + + p := buildPackage(lprog, "p", SanityCheckFunctions) + prog := p.Prog + + for _, test := range []struct { + orig string + instances string + }{ + {"H", "[p.H[T] p.H[T]]"}, + {"Foo", "[p.Foo[S T] p.Foo[T S]]"}, + } { + t.Run(test.orig, func(t *testing.T) { + f := p.Members[test.orig].(*Function) + if f == nil { + t.Fatalf("origin function not found") + } + + instances := prog._Instances(f) + sort.Slice(instances, func(i, j int) bool { return instances[i].Name() < instances[j].Name() }) + + if got := fmt.Sprintf("%v", instances); !reflect.DeepEqual(got, test.instances) { + t.Errorf("got %v instances, want %v", got, test.instances) + } + }) + } +} + +// instancesStr returns a sorted slice of string +// representation of instances. +func instancesStr(instances []*Function) []string { + var is []string + for _, i := range instances { + is = append(is, fmt.Sprintf("%v", i)) + } + sort.Strings(is) + return is +} diff --git a/go/ssa/interp/interp.go b/go/ssa/interp/interp.go index 2b21aad708..58cac46424 100644 --- a/go/ssa/interp/interp.go +++ b/go/ssa/interp/interp.go @@ -51,7 +51,6 @@ import ( "os" "reflect" "runtime" - "strings" "sync/atomic" "golang.org/x/tools/go/ssa" @@ -335,7 +334,17 @@ func visitInstr(fr *frame, instr ssa.Instruction) continuation { } case *ssa.Index: - fr.env[instr] = fr.get(instr.X).(array)[asInt64(fr.get(instr.Index))] + x := fr.get(instr.X) + idx := fr.get(instr.Index) + + switch x := x.(type) { + case array: + fr.env[instr] = x[asInt64(idx)] + case string: + fr.env[instr] = x[asInt64(idx)] + default: + panic(fmt.Sprintf("unexpected x type in Index: %T", x)) + } case *ssa.Lookup: fr.env[instr] = lookup(instr, fr.get(instr.X), fr.get(instr.Index)) @@ -506,13 +515,15 @@ func callSSA(i *interpreter, caller *frame, callpos token.Pos, fn *ssa.Function, return ext(fr, args) } if fn.Blocks == nil { - var reason string // empty by default - if strings.HasPrefix(fn.Synthetic, "instantiation") { - reason = " (interp requires ssa.BuilderMode to include InstantiateGenerics on generics)" - } - panic("no code for function: " + name + reason) + panic("no code for function: " + name) } } + + // generic function body? + if fn.TypeParams().Len() > 0 && len(fn.TypeArgs()) == 0 { + panic("interp requires ssa.BuilderMode to include InstantiateGenerics to execute generics") + } + fr.env = make(map[ssa.Value]value) fr.block = fn.Blocks[0] fr.locals = make([]value, len(fn.Locals)) diff --git a/go/ssa/interp/interp_test.go b/go/ssa/interp/interp_test.go index 51a74015c9..c893d83e75 100644 --- a/go/ssa/interp/interp_test.go +++ b/go/ssa/interp/interp_test.go @@ -127,12 +127,14 @@ var testdataTests = []string{ "width32.go", "fixedbugs/issue52342.go", - "fixedbugs/issue55086.go", } func init() { if typeparams.Enabled { testdataTests = append(testdataTests, "fixedbugs/issue52835.go") + testdataTests = append(testdataTests, "fixedbugs/issue55086.go") + testdataTests = append(testdataTests, "typeassert.go") + testdataTests = append(testdataTests, "zeros.go") } } diff --git a/go/ssa/interp/ops.go b/go/ssa/interp/ops.go index 8f031384f0..188899d69f 100644 --- a/go/ssa/interp/ops.go +++ b/go/ssa/interp/ops.go @@ -34,9 +34,10 @@ type exitPanic int // constValue returns the value of the constant with the // dynamic type tag appropriate for c.Type(). func constValue(c *ssa.Const) value { - if c.IsNil() { - return zero(c.Type()) // typed nil + if c.Value == nil { + return zero(c.Type()) // typed zero } + // c is not a type parameter so it's underlying type is basic. if t, ok := c.Type().Underlying().(*types.Basic); ok { // TODO(adonovan): eliminate untyped constants from SSA form. @@ -307,7 +308,7 @@ func slice(x, lo, hi, max value) value { panic(fmt.Sprintf("slice: unexpected X type: %T", x)) } -// lookup returns x[idx] where x is a map or string. +// lookup returns x[idx] where x is a map. func lookup(instr *ssa.Lookup, x, idx value) value { switch x := x.(type) { // map or string case map[value]value, *hashmap: @@ -327,8 +328,6 @@ func lookup(instr *ssa.Lookup, x, idx value) value { v = tuple{v, ok} } return v - case string: - return x[asInt64(idx)] } panic(fmt.Sprintf("unexpected x type in Lookup: %T", x)) } diff --git a/go/ssa/interp/testdata/boundmeth.go b/go/ssa/interp/testdata/boundmeth.go index 69937f9d3c..47b9406859 100644 --- a/go/ssa/interp/testdata/boundmeth.go +++ b/go/ssa/interp/testdata/boundmeth.go @@ -123,7 +123,8 @@ func nilInterfaceMethodValue() { r := fmt.Sprint(recover()) // runtime panic string varies across toolchains if r != "interface conversion: interface is nil, not error" && - r != "runtime error: invalid memory address or nil pointer dereference" { + r != "runtime error: invalid memory address or nil pointer dereference" && + r != "method value: interface is nil" { panic("want runtime panic from nil interface method value, got " + r) } }() diff --git a/go/ssa/interp/testdata/slice2array.go b/go/ssa/interp/testdata/slice2array.go index 43c0543eab..84e6b73300 100644 --- a/go/ssa/interp/testdata/slice2array.go +++ b/go/ssa/interp/testdata/slice2array.go @@ -19,7 +19,7 @@ func main() { { var s []int - a:= ([0]int)(s) + a := ([0]int)(s) if a != [0]int{} { panic("zero len array is not equal") } @@ -31,6 +31,20 @@ func main() { if !threeToFourDoesPanic() { panic("panic expected from threeToFourDoesPanic()") } + + if !fourPanicsWhileOneDoesNot[[4]int]() { + panic("panic expected from fourPanicsWhileOneDoesNot[[4]int]()") + } + if fourPanicsWhileOneDoesNot[[1]int]() { + panic("no panic expected from fourPanicsWhileOneDoesNot[[1]int]()") + } + + if !fourPanicsWhileZeroDoesNot[[4]int]() { + panic("panic expected from fourPanicsWhileZeroDoesNot[[4]int]()") + } + if fourPanicsWhileZeroDoesNot[[0]int]() { + panic("no panic expected from fourPanicsWhileZeroDoesNot[[0]int]()") + } } func emptyToEmptyDoesNotPanic() (raised bool) { @@ -53,4 +67,26 @@ func threeToFourDoesPanic() (raised bool) { s := make([]int, 3, 5) _ = ([4]int)(s) return false -} \ No newline at end of file +} + +func fourPanicsWhileOneDoesNot[T [1]int | [4]int]() (raised bool) { + defer func() { + if e := recover(); e != nil { + raised = true + } + }() + s := make([]int, 3, 5) + _ = T(s) + return false +} + +func fourPanicsWhileZeroDoesNot[T [0]int | [4]int]() (raised bool) { + defer func() { + if e := recover(); e != nil { + raised = true + } + }() + var s []int + _ = T(s) + return false +} diff --git a/go/ssa/interp/testdata/typeassert.go b/go/ssa/interp/testdata/typeassert.go new file mode 100644 index 0000000000..792a7558f6 --- /dev/null +++ b/go/ssa/interp/testdata/typeassert.go @@ -0,0 +1,32 @@ +// Tests of type asserts. +// Requires type parameters. +package typeassert + +type fooer interface{ foo() string } + +type X int + +func (_ X) foo() string { return "x" } + +func f[T fooer](x T) func() string { + return x.foo +} + +func main() { + if f[X](0)() != "x" { + panic("f[X]() != 'x'") + } + + p := false + func() { + defer func() { + if recover() != nil { + p = true + } + }() + f[fooer](nil) // panics on x.foo when T is an interface and nil. + }() + if !p { + panic("f[fooer] did not panic") + } +} diff --git a/go/ssa/interp/testdata/zeros.go b/go/ssa/interp/testdata/zeros.go new file mode 100644 index 0000000000..509c78a36e --- /dev/null +++ b/go/ssa/interp/testdata/zeros.go @@ -0,0 +1,45 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Test interpretation on zero values with type params. +package zeros + +func assert(cond bool, msg string) { + if !cond { + panic(msg) + } +} + +func tp0[T int | string | float64]() T { return T(0) } + +func tpFalse[T ~bool]() T { return T(false) } + +func tpEmptyString[T string | []byte]() T { return T("") } + +func tpNil[T *int | []byte]() T { return T(nil) } + +func main() { + // zero values + var zi int + var zf float64 + var zs string + + assert(zi == int(0), "zero value of int is int(0)") + assert(zf == float64(0), "zero value of float64 is float64(0)") + assert(zs != string(0), "zero value of string is not string(0)") + + assert(zi == tp0[int](), "zero value of int is int(0)") + assert(zf == tp0[float64](), "zero value of float64 is float64(0)") + assert(zs != tp0[string](), "zero value of string is not string(0)") + + assert(zf == -0.0, "constant -0.0 is converted to 0.0") + + assert(!tpFalse[bool](), "zero value of bool is false") + + assert(tpEmptyString[string]() == zs, `zero value of string is string("")`) + assert(len(tpEmptyString[[]byte]()) == 0, `[]byte("") is empty`) + + assert(tpNil[*int]() == nil, "nil is nil") + assert(tpNil[[]byte]() == nil, "nil is nil") +} diff --git a/go/ssa/lift.go b/go/ssa/lift.go index c350481db7..945536bbbf 100644 --- a/go/ssa/lift.go +++ b/go/ssa/lift.go @@ -44,6 +44,8 @@ import ( "go/types" "math/big" "os" + + "golang.org/x/tools/internal/typeparams" ) // If true, show diagnostic information at each step of lifting. @@ -381,10 +383,9 @@ type newPhiMap map[*BasicBlock][]newPhi // // fresh is a source of fresh ids for phi nodes. func liftAlloc(df domFrontier, alloc *Alloc, newPhis newPhiMap, fresh *int) bool { - // Don't lift aggregates into registers, because we don't have - // a way to express their zero-constants. + // TODO(taking): zero constants of aggregated types can now be lifted. switch deref(alloc.Type()).Underlying().(type) { - case *types.Array, *types.Struct: + case *types.Array, *types.Struct, *typeparams.TypeParam: return false } diff --git a/go/ssa/lvalue.go b/go/ssa/lvalue.go index 455b1e50fa..51122b8e85 100644 --- a/go/ssa/lvalue.go +++ b/go/ssa/lvalue.go @@ -56,12 +56,12 @@ func (a *address) typ() types.Type { } // An element is an lvalue represented by m[k], the location of an -// element of a map or string. These locations are not addressable +// element of a map. These locations are not addressable // since pointers cannot be formed from them, but they do support -// load(), and in the case of maps, store(). +// load() and store(). type element struct { - m, k Value // map or string - t types.Type // map element type or string byte type + m, k Value // map + t types.Type // map element type pos token.Pos // source position of colon ({k:v}) or lbrack (m[k]=v) } @@ -86,7 +86,7 @@ func (e *element) store(fn *Function, v Value) { } func (e *element) address(fn *Function) Value { - panic("map/string elements are not addressable") + panic("map elements are not addressable") } func (e *element) typ() types.Type { diff --git a/go/ssa/methods.go b/go/ssa/methods.go index 6954e17b77..4185618cdd 100644 --- a/go/ssa/methods.go +++ b/go/ssa/methods.go @@ -27,8 +27,8 @@ func (prog *Program) MethodValue(sel *types.Selection) *Function { panic(fmt.Sprintf("MethodValue(%s) kind != MethodVal", sel)) } T := sel.Recv() - if isInterface(T) { - return nil // abstract method (interface) + if types.IsInterface(T) { + return nil // abstract method (interface, possibly type param) } if prog.mode&LogSource != 0 { defer logStack("MethodValue %s %v", T, sel)() @@ -76,7 +76,7 @@ type methodSet struct { // EXCLUSIVE_LOCKS_REQUIRED(prog.methodsMu) func (prog *Program) createMethodSet(T types.Type) *methodSet { if prog.mode&SanityCheckFunctions != 0 { - if isInterface(T) || prog.parameterized.isParameterized(T) { + if types.IsInterface(T) || prog.parameterized.isParameterized(T) { panic("type is interface or parameterized") } } @@ -107,9 +107,9 @@ func (prog *Program) addMethod(mset *methodSet, sel *types.Selection, cr *creato fn = makeWrapper(prog, sel, cr) } else { fn = prog.originFunc(obj) - if len(fn._TypeParams) > 0 { // instantiate + if fn.typeparams.Len() > 0 { // instantiate targs := receiverTypeArgs(obj) - fn = prog.instances[fn].lookupOrCreate(targs, cr) + fn = prog.lookupOrCreateInstance(fn, targs, cr) } } if fn.Signature.Recv() == nil { @@ -190,7 +190,7 @@ func (prog *Program) needMethods(T types.Type, skip bool, cr *creator) { tmset := prog.MethodSets.MethodSet(T) - if !skip && !isInterface(T) && tmset.Len() > 0 { + if !skip && !types.IsInterface(T) && tmset.Len() > 0 { // Create methods of T. mset := prog.createMethodSet(T) if !mset.complete { diff --git a/go/ssa/parameterized.go b/go/ssa/parameterized.go index 956718cd72..b11413c818 100644 --- a/go/ssa/parameterized.go +++ b/go/ssa/parameterized.go @@ -111,3 +111,12 @@ func (w *tpWalker) isParameterized(typ types.Type) (res bool) { return false } + +func (w *tpWalker) anyParameterized(ts []types.Type) bool { + for _, t := range ts { + if w.isParameterized(t) { + return true + } + } + return false +} diff --git a/go/ssa/print.go b/go/ssa/print.go index b8e53923a1..9aa6809789 100644 --- a/go/ssa/print.go +++ b/go/ssa/print.go @@ -232,7 +232,7 @@ func (v *MakeChan) String() string { } func (v *FieldAddr) String() string { - st := deref(v.X.Type()).Underlying().(*types.Struct) + st := coreType(deref(v.X.Type())).(*types.Struct) // Be robust against a bad index. name := "?" if 0 <= v.Field && v.Field < st.NumFields() { @@ -242,7 +242,7 @@ func (v *FieldAddr) String() string { } func (v *Field) String() string { - st := v.X.Type().Underlying().(*types.Struct) + st := coreType(v.X.Type()).(*types.Struct) // Be robust against a bad index. name := "?" if 0 <= v.Field && v.Field < st.NumFields() { diff --git a/go/ssa/sanity.go b/go/ssa/sanity.go index 7d71302756..3fb3f394e8 100644 --- a/go/ssa/sanity.go +++ b/go/ssa/sanity.go @@ -132,9 +132,9 @@ func (s *sanity) checkInstr(idx int, instr Instruction) { case *ChangeType: case *SliceToArrayPointer: case *Convert: - if _, ok := instr.X.Type().Underlying().(*types.Basic); !ok { - if _, ok := instr.Type().Underlying().(*types.Basic); !ok { - s.errorf("convert %s -> %s: at least one type must be basic", instr.X.Type(), instr.Type()) + if from := instr.X.Type(); !isBasicConvTypes(typeSetOf(from)) { + if to := instr.Type(); !isBasicConvTypes(typeSetOf(to)) { + s.errorf("convert %s -> %s: at least one type must be basic (or all basic, []byte, or []rune)", from, to) } } @@ -403,7 +403,7 @@ func (s *sanity) checkFunction(fn *Function) bool { // - check transient fields are nil // - warn if any fn.Locals do not appear among block instructions. - // TODO(taking): Sanity check _Origin, _TypeParams, and _TypeArgs. + // TODO(taking): Sanity check origin, typeparams, and typeargs. s.fn = fn if fn.Prog == nil { s.errorf("nil Prog") @@ -420,16 +420,19 @@ func (s *sanity) checkFunction(fn *Function) bool { strings.HasPrefix(fn.Synthetic, "bound ") || strings.HasPrefix(fn.Synthetic, "thunk ") || strings.HasSuffix(fn.name, "Error") || - strings.HasPrefix(fn.Synthetic, "instantiation") || - (fn.parent != nil && len(fn._TypeArgs) > 0) /* anon fun in instance */ { + strings.HasPrefix(fn.Synthetic, "instance ") || + strings.HasPrefix(fn.Synthetic, "instantiation ") || + (fn.parent != nil && len(fn.typeargs) > 0) /* anon fun in instance */ { // ok } else { s.errorf("nil Pkg") } } if src, syn := fn.Synthetic == "", fn.Syntax() != nil; src != syn { - if strings.HasPrefix(fn.Synthetic, "instantiation") && fn.Prog.mode&InstantiateGenerics != 0 { - // ok + if len(fn.typeargs) > 0 && fn.Prog.mode&InstantiateGenerics != 0 { + // ok (instantiation with InstantiateGenerics on) + } else if fn.topLevelOrigin != nil && len(fn.typeargs) > 0 { + // ok (we always have the syntax set for instantiation) } else { s.errorf("got fromSource=%t, hasSyntax=%t; want same values", src, syn) } @@ -494,6 +497,9 @@ func (s *sanity) checkFunction(fn *Function) bool { if anon.Parent() != fn { s.errorf("AnonFuncs[%d]=%s but %s.Parent()=%s", i, anon, anon, anon.Parent()) } + if i != int(anon.anonIdx) { + s.errorf("AnonFuncs[%d]=%s but %s.anonIdx=%d", i, anon, anon, anon.anonIdx) + } } s.fn = nil return !s.insane diff --git a/go/ssa/ssa.go b/go/ssa/ssa.go index cbc638c81a..698cb16507 100644 --- a/go/ssa/ssa.go +++ b/go/ssa/ssa.go @@ -294,16 +294,15 @@ type Node interface { // // Type() returns the function's Signature. // -// A function is generic iff it has a non-empty TypeParams list and an -// empty TypeArgs list. TypeParams lists the type parameters of the -// function's Signature or the receiver's type parameters for a method. -// -// The instantiation of a generic function is a concrete function. These -// are a list of n>0 TypeParams and n TypeArgs. An instantiation will -// have a generic Origin function. There is at most one instantiation -// of each origin type per Identical() type list. Instantiations do not -// belong to any Pkg. The generic function and the instantiations will -// share the same source Pos for the functions and the instructions. +// A generic function is a function or method that has uninstantiated type +// parameters (TypeParams() != nil). Consider a hypothetical generic +// method, (*Map[K,V]).Get. It may be instantiated with all ground +// (non-parameterized) types as (*Map[string,int]).Get or with +// parameterized types as (*Map[string,U]).Get, where U is a type parameter. +// In both instantiations, Origin() refers to the instantiated generic +// method, (*Map[K,V]).Get, TypeParams() refers to the parameters [K,V] of +// the generic method. TypeArgs() refers to [string,U] or [string,int], +// respectively, and is nil in the generic method. type Function struct { name string object types.Object // a declared *types.Func or one of its wrappers @@ -324,10 +323,11 @@ type Function struct { AnonFuncs []*Function // anonymous functions directly beneath this one referrers []Instruction // referring instructions (iff Parent() != nil) built bool // function has completed both CREATE and BUILD phase. + anonIdx int32 // position of a nested function in parent's AnonFuncs. fn.Parent()!=nil => fn.Parent().AnonFunc[fn.anonIdx] == fn. - _Origin *Function // the origin function if this the instantiation of a generic function. nil if Parent() != nil. - _TypeParams []*typeparams.TypeParam // the type paramaters of this function. len(TypeParams) == len(_TypeArgs) => runtime function - _TypeArgs []types.Type // type arguments for for an instantiation. len(_TypeArgs) != 0 => instantiation + typeparams *typeparams.TypeParamList // type parameters of this function. typeparams.Len() > 0 => generic or instance of generic function + typeargs []types.Type // type arguments that instantiated typeparams. len(typeargs) > 0 => instance of generic function + topLevelOrigin *Function // the origin function if this is an instance of a source function. nil if Parent()!=nil. // The following fields are set transiently during building, // then cleared. @@ -337,7 +337,7 @@ type Function struct { targets *targets // linked stack of branch targets lblocks map[types.Object]*lblock // labelled blocks info *types.Info // *types.Info to build from. nil for wrappers. - subst *subster // type substitution cache + subst *subster // non-nil => expand generic body using this type substitution of ground types } // BasicBlock represents an SSA basic block. @@ -409,26 +409,28 @@ type Parameter struct { referrers []Instruction } -// A Const represents the value of a constant expression. +// A Const represents a value known at build time. // -// The underlying type of a constant may be any boolean, numeric, or -// string type. In addition, a Const may represent the nil value of -// any reference type---interface, map, channel, pointer, slice, or -// function---but not "untyped nil". +// Consts include true constants of boolean, numeric, and string types, as +// defined by the Go spec; these are represented by a non-nil Value field. // -// All source-level constant expressions are represented by a Const -// of the same type and value. -// -// Value holds the value of the constant, independent of its Type(), -// using go/constant representation, or nil for a typed nil value. +// Consts also include the "zero" value of any type, of which the nil values +// of various pointer-like types are a special case; these are represented +// by a nil Value field. // // Pos() returns token.NoPos. // -// Example printed form: +// Example printed forms: // -// 42:int -// "hello":untyped string -// 3+4i:MyComplex +// 42:int +// "hello":untyped string +// 3+4i:MyComplex +// nil:*int +// nil:[]string +// [3]int{}:[3]int +// struct{x string}{}:struct{x string} +// 0:interface{int|int64} +// nil:interface{bool|int} // no go/constant representation type Const struct { typ types.Type Value constant.Value @@ -603,9 +605,17 @@ type UnOp struct { // - between (possibly named) pointers to identical base types. // - from a bidirectional channel to a read- or write-channel, // optionally adding/removing a name. +// - between a type (t) and an instance of the type (tσ), i.e. +// Type() == σ(X.Type()) (or X.Type()== σ(Type())) where +// σ is the type substitution of Parent().TypeParams by +// Parent().TypeArgs. // // This operation cannot fail dynamically. // +// Type changes may to be to or from a type parameter (or both). All +// types in the type set of X.Type() have a value-preserving type +// change to all types in the type set of Type(). +// // Pos() returns the ast.CallExpr.Lparen, if the instruction arose // from an explicit conversion in the source. // @@ -631,6 +641,10 @@ type ChangeType struct { // // A conversion may imply a type name change also. // +// Conversions may to be to or from a type parameter. All types in +// the type set of X.Type() can be converted to all types in the type +// set of Type(). +// // This operation cannot fail dynamically. // // Conversions of untyped string/number/bool constants to a specific @@ -670,6 +684,11 @@ type ChangeInterface struct { // Pos() returns the ast.CallExpr.Lparen, if the instruction arose // from an explicit conversion in the source. // +// Conversion may to be to or from a type parameter. All types in +// the type set of X.Type() must be a slice types that can be converted to +// all types in the type set of Type() which must all be pointer to array +// types. +// // Example printed form: // // t1 = slice to array pointer *[4]byte <- []byte (t0) @@ -809,7 +828,9 @@ type Slice struct { // // Pos() returns the position of the ast.SelectorExpr.Sel for the // field, if explicit in the source. For implicit selections, returns -// the position of the inducing explicit selection. +// the position of the inducing explicit selection. If produced for a +// struct literal S{f: e}, it returns the position of the colon; for +// S{e} it returns the start of expression e. // // Example printed form: // @@ -817,7 +838,7 @@ type Slice struct { type FieldAddr struct { register X Value // *struct - Field int // field is X.Type().Underlying().(*types.Pointer).Elem().Underlying().(*types.Struct).Field(Field) + Field int // field is typeparams.CoreType(X.Type().Underlying().(*types.Pointer).Elem()).(*types.Struct).Field(Field) } // The Field instruction yields the Field of struct X. @@ -836,14 +857,14 @@ type FieldAddr struct { type Field struct { register X Value // struct - Field int // index into X.Type().(*types.Struct).Fields + Field int // index into typeparams.CoreType(X.Type()).(*types.Struct).Fields } // The IndexAddr instruction yields the address of the element at // index Index of collection X. Index is an integer expression. // -// The elements of maps and strings are not addressable; use Lookup or -// MapUpdate instead. +// The elements of maps and strings are not addressable; use Lookup (map), +// Index (string), or MapUpdate instead. // // Dynamically, this instruction panics if X evaluates to a nil *array // pointer. @@ -858,11 +879,13 @@ type Field struct { // t2 = &t0[t1] type IndexAddr struct { register - X Value // slice or *array, + X Value // *array, slice or type parameter with types array, *array, or slice. Index Value // numeric index } -// The Index instruction yields element Index of array X. +// The Index instruction yields element Index of collection X, an array, +// string or type parameter containing an array, a string, a pointer to an, +// array or a slice. // // Pos() returns the ast.IndexExpr.Lbrack for the index operation, if // explicit in the source. @@ -872,13 +895,12 @@ type IndexAddr struct { // t2 = t0[t1] type Index struct { register - X Value // array + X Value // array, string or type parameter with types array, *array, slice, or string. Index Value // integer index } -// The Lookup instruction yields element Index of collection X, a map -// or string. Index is an integer expression if X is a string or the -// appropriate key type if X is a map. +// The Lookup instruction yields element Index of collection map X. +// Index is the appropriate key type. // // If CommaOk, the result is a 2-tuple of the value above and a // boolean indicating the result of a map membership test for the key. @@ -892,8 +914,8 @@ type Index struct { // t5 = t3[t4],ok type Lookup struct { register - X Value // string or map - Index Value // numeric or key-typed index + X Value // map + Index Value // key-typed index CommaOk bool // return a value,ok pair } @@ -1337,9 +1359,10 @@ type anInstruction struct { // 2. "invoke" mode: when Method is non-nil (IsInvoke), a CallCommon // represents a dynamically dispatched call to an interface method. // In this mode, Value is the interface value and Method is the -// interface's abstract method. Note: an abstract method may be -// shared by multiple interfaces due to embedding; Value.Type() -// provides the specific interface used for this call. +// interface's abstract method. The interface value may be a type +// parameter. Note: an abstract method may be shared by multiple +// interfaces due to embedding; Value.Type() provides the specific +// interface used for this call. // // Value is implicitly supplied to the concrete method implementation // as the receiver parameter; in other words, Args[0] holds not the @@ -1378,7 +1401,7 @@ func (c *CallCommon) Signature() *types.Signature { if c.Method != nil { return c.Method.Type().(*types.Signature) } - return c.Value.Type().Underlying().(*types.Signature) + return coreType(c.Value.Type()).(*types.Signature) } // StaticCallee returns the callee if this is a trivially static @@ -1469,6 +1492,29 @@ func (v *Function) Referrers() *[]Instruction { return nil } +// TypeParams are the function's type parameters if generic or the +// type parameters that were instantiated if fn is an instantiation. +// +// TODO(taking): declare result type as *types.TypeParamList +// after we drop support for go1.17. +func (fn *Function) TypeParams() *typeparams.TypeParamList { + return fn.typeparams +} + +// TypeArgs are the types that TypeParams() were instantiated by to create fn +// from fn.Origin(). +func (fn *Function) TypeArgs() []types.Type { return fn.typeargs } + +// Origin is the function fn is an instantiation of. Returns nil if fn is not +// an instantiation. +func (fn *Function) Origin() *Function { + if fn.parent != nil && len(fn.typeargs) > 0 { + // Nested functions are BUILT at a different time than there instances. + return fn.parent.Origin().AnonFuncs[fn.anonIdx] + } + return fn.topLevelOrigin +} + func (v *Parameter) Type() types.Type { return v.typ } func (v *Parameter) Name() string { return v.name } func (v *Parameter) Object() types.Object { return v.object } diff --git a/go/ssa/stdlib_test.go b/go/ssa/stdlib_test.go index 7e02f97a7e..8b9f4238da 100644 --- a/go/ssa/stdlib_test.go +++ b/go/ssa/stdlib_test.go @@ -21,12 +21,10 @@ import ( "testing" "time" - "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" "golang.org/x/tools/internal/testenv" - "golang.org/x/tools/internal/typeparams/genericfeatures" ) func bytesAllocated() uint64 { @@ -51,22 +49,6 @@ func TestStdlib(t *testing.T) { if err != nil { t.Fatal(err) } - var nonGeneric int - for i := 0; i < len(pkgs); i++ { - pkg := pkgs[i] - inspect := inspector.New(pkg.Syntax) - features := genericfeatures.ForPackage(inspect, pkg.TypesInfo) - // Skip standard library packages that use generics. This won't be - // sufficient if any standard library packages start _importing_ packages - // that use generics. - if features != 0 { - t.Logf("skipping package %q which uses generics", pkg.PkgPath) - continue - } - pkgs[nonGeneric] = pkg - nonGeneric++ - } - pkgs = pkgs[:nonGeneric] t1 := time.Now() alloc1 := bytesAllocated() diff --git a/go/ssa/subst.go b/go/ssa/subst.go index b29130ea0c..396626befc 100644 --- a/go/ssa/subst.go +++ b/go/ssa/subst.go @@ -18,6 +18,8 @@ import ( // // Not concurrency-safe. type subster struct { + // TODO(zpavlinovic): replacements can contain type params + // when generating instances inside of a generic function body. replacements map[*typeparams.TypeParam]types.Type // values should contain no type params cache map[types.Type]types.Type // cache of subst results ctxt *typeparams.Context @@ -27,17 +29,17 @@ type subster struct { // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. // targs should not contain any types in tparams. -func makeSubster(ctxt *typeparams.Context, tparams []*typeparams.TypeParam, targs []types.Type, debug bool) *subster { - assert(len(tparams) == len(targs), "makeSubster argument count must match") +func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster { + assert(tparams.Len() == len(targs), "makeSubster argument count must match") subst := &subster{ - replacements: make(map[*typeparams.TypeParam]types.Type, len(tparams)), + replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()), cache: make(map[types.Type]types.Type), ctxt: ctxt, debug: debug, } - for i, tpar := range tparams { - subst.replacements[tpar] = targs[i] + for i := 0; i < tparams.Len(); i++ { + subst.replacements[tparams.At(i)] = targs[i] } if subst.debug { if err := subst.wellFormed(); err != nil { @@ -331,9 +333,9 @@ func (subst *subster) named(t *types.Named) types.Type { // type N[A any] func() A // func Foo[T](g N[T]) {} // To instantiate Foo[string], one goes through {T->string}. To get the type of g - // one subsitutes T with string in {N with TypeArgs == {T} and TypeParams == {A} } - // to get {N with TypeArgs == {string} and TypeParams == {A} }. - assert(targs.Len() == tparams.Len(), "TypeArgs().Len() must match TypeParams().Len() if present") + // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } + // to get {N with TypeArgs == {string} and typeparams == {A} }. + assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") for i, n := 0, targs.Len(); i < n; i++ { inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion insts[i] = inst diff --git a/go/ssa/subst_test.go b/go/ssa/subst_test.go index fe84adcc3d..5fa8827000 100644 --- a/go/ssa/subst_test.go +++ b/go/ssa/subst_test.go @@ -99,12 +99,8 @@ var _ L[int] = Fn0[L[int]](nil) } T := tv.Type.(*types.Named) - var tparams []*typeparams.TypeParam - for i, l := 0, typeparams.ForNamed(T); i < l.Len(); i++ { - tparams = append(tparams, l.At(i)) - } - subst := makeSubster(typeparams.NewContext(), tparams, targs, true) + subst := makeSubster(typeparams.NewContext(), typeparams.ForNamed(T), targs, true) sub := subst.typ(T.Underlying()) if got := sub.String(); got != test.want { t.Errorf("subst{%v->%v}.typ(%s) = %v, want %v", test.expr, test.args, T.Underlying(), got, test.want) diff --git a/go/ssa/testdata/valueforexpr.go b/go/ssa/testdata/valueforexpr.go index da76f13a39..243ec614f6 100644 --- a/go/ssa/testdata/valueforexpr.go +++ b/go/ssa/testdata/valueforexpr.go @@ -1,3 +1,4 @@ +//go:build ignore // +build ignore package main diff --git a/go/ssa/util.go b/go/ssa/util.go index 80c7d5cbec..c30e74c7fc 100644 --- a/go/ssa/util.go +++ b/go/ssa/util.go @@ -49,7 +49,56 @@ func isPointer(typ types.Type) bool { return ok } -func isInterface(T types.Type) bool { return types.IsInterface(T) } +// isNonTypeParamInterface reports whether t is an interface type but not a type parameter. +func isNonTypeParamInterface(t types.Type) bool { + return !typeparams.IsTypeParam(t) && types.IsInterface(t) +} + +// isBasic reports whether t is a basic type. +func isBasic(t types.Type) bool { + _, ok := t.(*types.Basic) + return ok +} + +// isString reports whether t is exactly a string type. +func isString(t types.Type) bool { + return isBasic(t) && t.(*types.Basic).Info()&types.IsString != 0 +} + +// isByteSlice reports whether t is []byte. +func isByteSlice(t types.Type) bool { + if b, ok := t.(*types.Slice); ok { + e, _ := b.Elem().(*types.Basic) + return e != nil && e.Kind() == types.Byte + } + return false +} + +// isRuneSlice reports whether t is []rune. +func isRuneSlice(t types.Type) bool { + if b, ok := t.(*types.Slice); ok { + e, _ := b.Elem().(*types.Basic) + return e != nil && e.Kind() == types.Rune + } + return false +} + +// isBasicConvType returns true when a type set can be +// one side of a Convert operation. This is when: +// - All are basic, []byte, or []rune. +// - At least 1 is basic. +// - At most 1 is []byte or []rune. +func isBasicConvTypes(tset typeSet) bool { + basics := 0 + all := tset.underIs(func(t types.Type) bool { + if isBasic(t) { + basics++ + return true + } + return isByteSlice(t) || isRuneSlice(t) + }) + return all && basics >= 1 && len(tset)-basics <= 1 +} // deref returns a pointer's element type; otherwise it returns typ. func deref(typ types.Type) types.Type { @@ -113,7 +162,7 @@ func nonbasicTypes(ts []types.Type) []types.Type { added := make(map[types.Type]bool) // additionally filter duplicates var filtered []types.Type for _, T := range ts { - if _, basic := T.(*types.Basic); !basic { + if !isBasic(T) { if !added[T] { added[T] = true filtered = append(filtered, T) @@ -123,22 +172,6 @@ func nonbasicTypes(ts []types.Type) []types.Type { return filtered } -// isGeneric returns true if a package-level member is generic. -func isGeneric(m Member) bool { - switch m := m.(type) { - case *NamedConst, *Global: - return false - case *Type: - // lifted from types.isGeneric. - named, _ := m.Type().(*types.Named) - return named != nil && named.Obj() != nil && typeparams.NamedTypeArgs(named) == nil && typeparams.ForNamed(named) != nil - case *Function: - return len(m._TypeParams) != len(m._TypeArgs) - default: - panic("unreachable") - } -} - // receiverTypeArgs returns the type arguments to a function's reciever. // Returns an empty list if obj does not have a reciever or its reciever does not have type arguments. func receiverTypeArgs(obj *types.Func) []types.Type { diff --git a/go/ssa/wrappers.go b/go/ssa/wrappers.go index 3f2267c8a1..228daf6158 100644 --- a/go/ssa/wrappers.go +++ b/go/ssa/wrappers.go @@ -120,19 +120,19 @@ func makeWrapper(prog *Program, sel *selection, cr *creator) *Function { // address of implicit C field. var c Call - if r := recvType(obj); !isInterface(r) { // concrete method + if r := recvType(obj); !types.IsInterface(r) { // concrete method if !isPointer(r) { v = emitLoad(fn, v) } callee := prog.originFunc(obj) - if len(callee._TypeParams) > 0 { - callee = prog.instances[callee].lookupOrCreate(receiverTypeArgs(obj), cr) + if callee.typeparams.Len() > 0 { + callee = prog.lookupOrCreateInstance(callee, receiverTypeArgs(obj), cr) } c.Call.Value = callee c.Call.Args = append(c.Call.Args, v) } else { c.Call.Method = obj - c.Call.Value = emitLoad(fn, v) + c.Call.Value = emitLoad(fn, v) // interface (possibly a typeparam) } for _, arg := range fn.Params[1:] { c.Call.Args = append(c.Call.Args, arg) @@ -208,16 +208,16 @@ func makeBound(prog *Program, obj *types.Func, cr *creator) *Function { createParams(fn, 0) var c Call - if !isInterface(recvType(obj)) { // concrete + if !types.IsInterface(recvType(obj)) { // concrete callee := prog.originFunc(obj) - if len(callee._TypeParams) > 0 { - callee = prog.instances[callee].lookupOrCreate(targs, cr) + if callee.typeparams.Len() > 0 { + callee = prog.lookupOrCreateInstance(callee, targs, cr) } c.Call.Value = callee c.Call.Args = []Value{fv} } else { - c.Call.Value = fv c.Call.Method = obj + c.Call.Value = fv // interface (possibly a typeparam) } for _, arg := range fn.Params { c.Call.Args = append(c.Call.Args, arg) @@ -324,3 +324,63 @@ func toSelection(sel *types.Selection) *selection { indirect: sel.Indirect(), } } + +// -- instantiations -------------------------------------------------- + +// buildInstantiationWrapper creates a body for an instantiation +// wrapper fn. The body calls the original generic function, +// bracketed by ChangeType conversions on its arguments and results. +func buildInstantiationWrapper(fn *Function) { + orig := fn.topLevelOrigin + sig := fn.Signature + + fn.startBody() + if sig.Recv() != nil { + fn.addParamObj(sig.Recv()) + } + createParams(fn, 0) + + // Create body. Add a call to origin generic function + // and make type changes between argument and parameters, + // as well as return values. + var c Call + c.Call.Value = orig + if res := orig.Signature.Results(); res.Len() == 1 { + c.typ = res.At(0).Type() + } else { + c.typ = res + } + + // parameter of instance becomes an argument to the call + // to the original generic function. + argOffset := 0 + for i, arg := range fn.Params { + var typ types.Type + if i == 0 && sig.Recv() != nil { + typ = orig.Signature.Recv().Type() + argOffset = 1 + } else { + typ = orig.Signature.Params().At(i - argOffset).Type() + } + c.Call.Args = append(c.Call.Args, emitTypeCoercion(fn, arg, typ)) + } + + results := fn.emit(&c) + var ret Return + switch res := sig.Results(); res.Len() { + case 0: + // no results, do nothing. + case 1: + ret.Results = []Value{emitTypeCoercion(fn, results, res.At(0).Type())} + default: + for i := 0; i < sig.Results().Len(); i++ { + v := emitExtract(fn, results, i) + ret.Results = append(ret.Results, emitTypeCoercion(fn, v, res.At(i).Type())) + } + } + + fn.emit(&ret) + fn.currentBlock = nil + + fn.finishBody() +}