go/go2go: improve handling of self-referential types

Also, resolve types when looking up instantiations.

Change-Id: I8b2e976d9c0d313fe3c1dd9dafce41dcb59b33bf
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go2-dev/+/767864
Reviewed-by: Ian Lance Taylor <iant@google.com>
This commit is contained in:
Ian Lance Taylor 2020-06-10 17:01:32 -07:00 committed by Robert Griesemer
parent 73c4e1f380
commit 3df5cb5c52
4 changed files with 266 additions and 65 deletions

View File

@ -134,28 +134,24 @@ func (t *translator) findFuncDecl(qid qualifiedIdent) (*ast.FuncDecl, error) {
// It returns nil if the ID is not found.
func (t *translator) findTypesObject(qid qualifiedIdent) types.Object {
if qid.pkg == nil {
return t.importer.info.Uses[qid.ident]
if obj := t.importer.info.ObjectOf(qid.ident); obj != nil {
return obj
}
return t.tpkg.Scope().Lookup(qid.ident.Name)
} else {
return qid.pkg.Scope().Lookup(qid.ident.Name)
}
}
// instantiateType creates a new instantiation of a type.
func (t *translator) instantiateTypeDecl(qid qualifiedIdent, typ *types.Named, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) {
name, err := t.instantiatedName(qid, typeTypes)
if err != nil {
return nil, nil, err
}
func (t *translator) instantiateTypeDecl(qid qualifiedIdent, typ *types.Named, astTypes []ast.Expr, typeTypes []types.Type, instIdent *ast.Ident) (types.Type, error) {
spec, err := t.findTypeSpec(qid)
if err != nil {
return nil, nil, err
return nil, err
}
ta := typeArgsFromFields(t, astTypes, typeTypes, spec.TParams.List)
instIdent := ast.NewIdent(name)
newSpec := &ast.TypeSpec{
Doc: spec.Doc,
Name: instIdent,
@ -181,7 +177,7 @@ func (t *translator) instantiateTypeDecl(qid qualifiedIdent, typ *types.Named, a
panic(fmt.Sprintf("no AST for method %v", method))
}
rtyp := mast.Recv.List[0].Type
newRtype := ast.Expr(ast.NewIdent(name))
newRtype := ast.Expr(ast.NewIdent(instIdent.Name))
if p, ok := rtyp.(*ast.StarExpr); ok {
rtyp = p.X
newRtype = &ast.StarExpr{
@ -213,7 +209,7 @@ func (t *translator) instantiateTypeDecl(qid qualifiedIdent, typ *types.Named, a
t.newDecls = append(t.newDecls, newDecl)
}
return instIdent, instType, nil
return instType, nil
}
// findTypeSpec looks for the TypeSpec for qid.

View File

@ -93,9 +93,10 @@ type instantiation struct {
// A typeInstantiation is a single instantiation of a type.
type typeInstantiation struct {
types []types.Type
decl *ast.Ident
typ types.Type
types []types.Type
decl *ast.Ident
typ types.Type
inProgress bool
}
// rewrite rewrites the contents of one file.
@ -530,12 +531,12 @@ func (t *translator) translateExpr(pe *ast.Expr) {
t.translateExpr(&e.Type)
case *ast.CallExpr:
t.translateExprList(e.Args)
t.translateExpr(&e.Fun)
if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && len(ftyp.TParams()) > 0 {
t.translateFunctionInstantiation(pe)
} else if ntyp, ok := t.lookupType(e.Fun).(*types.Named); ok && len(ntyp.TParams()) > 0 && len(ntyp.TArgs()) == 0 {
t.translateTypeInstantiation(pe)
}
t.translateExpr(&e.Fun)
case *ast.StarExpr:
t.translateExpr(&e.X)
case *ast.UnaryExpr:
@ -633,6 +634,9 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) {
call := (*pe).(*ast.CallExpr)
qid := t.instantiatedIdent(call)
argList, typeList, typeArgs := t.instantiationTypes(call)
if t.err != nil {
return
}
var instIdent *ast.Ident
key := qid.String()
@ -678,26 +682,57 @@ func (t *translator) translateTypeInstantiation(pe *ast.Expr) {
panic("no type arguments for type")
}
var seen *typeInstantiation
instantiations := t.typeInstantiations[typ]
for _, inst := range instantiations {
if t.sameTypes(typeList, inst.types) {
if inst.inProgress {
panic(fmt.Sprintf("%s: circular type instantiation", t.fset.Position((*pe).Pos())))
}
if inst.decl == nil {
// This can happen if we've instantiated
// the type in instantiateType.
seen = inst
break
}
*pe = inst.decl
return
}
}
instIdent, instType, err := t.instantiateTypeDecl(qid, typ, argList, typeList)
name, err := t.instantiatedName(qid, typeList)
if err != nil {
t.err = err
return
}
instIdent := ast.NewIdent(name)
if seen != nil {
seen.decl = instIdent
seen.inProgress = true
} else {
seen = &typeInstantiation{
types: typeList,
decl: instIdent,
typ: nil,
inProgress: true,
}
t.typeInstantiations[typ] = append(instantiations, seen)
}
defer func() {
seen.inProgress = false
}()
instType, err := t.instantiateTypeDecl(qid, typ, argList, typeList, instIdent)
if err != nil {
t.err = err
return
}
n := &typeInstantiation{
types: typeList,
decl: instIdent,
typ: instType,
if seen.typ == nil {
seen.typ = instType
}
t.typeInstantiations[typ] = append(instantiations, n)
*pe = instIdent
}
@ -744,35 +779,104 @@ func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr,
}
typeArgs = true
} else {
for _, typ := range inferred.Targs {
arg := ast.NewIdent(typ.String())
if named, ok := typ.(*types.Named); ok {
if len(named.TArgs()) > 0 {
var narg *ast.Ident
typ, narg = t.lookupInstantiatedType(named)
if narg != nil {
arg = ast.NewIdent(narg.Name)
}
}
if named.Obj().Pkg() == t.tpkg {
fields := strings.Split(arg.Name, ".")
if len(fields) > 1 {
arg = ast.NewIdent(fields[1])
}
}
}
typeList = append(typeList, typ)
argList = append(argList, arg)
t.setType(arg, typ)
}
typeList, argList = t.typeListToASTList(inferred.Targs)
}
typeList = t.resolveTypes(typeList)
return
}
// lookupInstantiatedType looks for an existing instantiation of an
// instantiated type.
func (t *translator) lookupInstantiatedType(typ *types.Named) (types.Type, *ast.Ident) {
copyType := func(typ *types.Named, newName string) types.Type {
nm := typ.NumMethods()
methods := make([]*types.Func, 0, nm)
for i := 0; i < nm; i++ {
methods = append(methods, typ.Method(i))
}
obj := typ.Obj()
obj = types.NewTypeName(obj.Pos(), obj.Pkg(), newName, nil)
nt := types.NewNamed(obj, typ.Underlying(), methods)
nt.SetTArgs(typ.TArgs())
return nt
}
ntype := t.typeWithoutArgs(typ)
targs := t.resolveTypes(typ.TArgs())
instantiations := t.typeInstantiations[ntype]
var seen *typeInstantiation
for _, inst := range instantiations {
if t.sameTypes(targs, inst.types) {
if inst.inProgress {
panic(fmt.Sprintf("instantiation for %v in progress", typ))
}
if inst.decl == nil {
// This can happen if we've instantiated
// the type in instantiateType.
seen = inst
break
}
if inst.typ == nil {
panic(fmt.Sprintf("no type for instantiation entry for %v", typ))
}
if instNamed, ok := inst.typ.(*types.Named); ok {
return copyType(instNamed, inst.decl.Name), inst.decl
}
return inst.typ, inst.decl
}
}
typeList, argList := t.typeListToASTList(targs)
qid := qualifiedIdent{ident: ast.NewIdent(typ.Obj().Name())}
if typPkg := typ.Obj().Pkg(); typPkg != t.tpkg {
qid.pkg = typPkg
}
name, err := t.instantiatedName(qid, typeList)
if err != nil {
t.err = err
return nil, nil
}
instIdent := ast.NewIdent(name)
if seen != nil {
seen.decl = instIdent
seen.inProgress = true
} else {
seen = &typeInstantiation{
types: targs,
decl: instIdent,
typ: nil,
inProgress: true,
}
t.typeInstantiations[ntype] = append(instantiations, seen)
}
defer func() {
seen.inProgress = false
}()
instType, err := t.instantiateTypeDecl(qid, typ, argList, typeList, instIdent)
if err != nil {
t.err = err
return nil, nil
}
if seen.typ == nil {
seen.typ = instType
}
if instNamed, ok := instType.(*types.Named); ok {
return copyType(instNamed, instIdent.Name), instIdent
}
return instType, instIdent
}
// typeWithoutArgs takes a named type with arguments and returns the
// same type without arguments.
func (t *translator) typeWithoutArgs(typ *types.Named) *types.Named {
name := typ.Obj().Name()
fields := strings.Split(name, ".")
if len(fields) > 2 {
@ -787,26 +891,39 @@ func (t *translator) lookupInstantiatedType(typ *types.Named) (types.Type, *ast.
if nobj == nil {
panic(fmt.Sprintf("can't find %q in scope of package %q", name, tpkg.Name()))
}
return nobj.Type().(*types.Named)
}
targs := typ.TArgs()
instantiations := t.typeInstantiations[nobj.Type()]
for _, inst := range instantiations {
if t.sameTypes(targs, inst.types) {
newName := inst.decl.Name
nm := typ.NumMethods()
methods := make([]*types.Func, 0, nm)
for i := 0; i < nm; i++ {
methods = append(methods, typ.Method(i))
// typeListToASTList returns an AST list for a type list,
// as well as an updated type list.
func (t *translator) typeListToASTList(typeList []types.Type) ([]types.Type, []ast.Expr) {
newTypeList := make([]types.Type, 0, len(typeList))
argList := make([]ast.Expr, 0, len(typeList))
for _, typ := range typeList {
arg := ast.NewIdent(typ.String())
if named, ok := typ.(*types.Named); ok {
if len(named.TArgs()) > 0 {
var narg *ast.Ident
typ, narg = t.lookupInstantiatedType(named)
if t.err != nil {
return nil, nil
}
if narg != nil {
arg = ast.NewIdent(narg.Name)
}
}
if named.Obj().Pkg() == t.tpkg {
fields := strings.Split(arg.Name, ".")
if len(fields) > 1 {
arg = ast.NewIdent(fields[1])
}
}
obj := typ.Obj()
obj = types.NewTypeName(obj.Pos(), obj.Pkg(), newName, nil)
nt := types.NewNamed(obj, typ.Underlying(), methods)
nt.SetTArgs(targs)
return nt, inst.decl
}
newTypeList = append(newTypeList, typ)
argList = append(argList, arg)
t.setType(arg, typ)
}
panic(fmt.Sprintf("did not find instantiation for %v %v\n", typ, typ.Underlying()))
return newTypeList, argList
}
// sameTypes reports whether two type slices are the same.

View File

@ -13,8 +13,8 @@ import (
// lookupType returns the types.Type for an AST expression.
// Returns nil if the type is not known.
func (t *translator) lookupType(e ast.Expr) types.Type {
if typ, ok := t.importer.info.Types[e]; ok {
return typ.Type
if typ := t.importer.info.TypeOf(e); typ != nil {
return typ
}
if typ, ok := t.types[e]; ok {
return typ
@ -43,20 +43,32 @@ func (t *translator) setType(e ast.Expr, nt types.Type) {
// instantiateType instantiates typ using ta.
func (t *translator) instantiateType(ta *typeArgs, typ types.Type) types.Type {
var inProgress *typeInstantiation
if insts, ok := t.typeInstantiations[typ]; ok {
for _, inst := range insts {
if t.sameTypes(ta.types, inst.types) {
if inst.typ == nil {
inProgress = inst
break
}
return inst.typ
}
}
}
ityp := t.doInstantiateType(ta, typ)
typinst := &typeInstantiation{
types: ta.types,
typ: ityp,
if inProgress != nil {
if inProgress.typ == nil {
inProgress.typ = ityp
}
} else {
typinst := &typeInstantiation{
types: ta.types,
typ: ityp,
}
t.typeInstantiations[typ] = append(t.typeInstantiations[typ], typinst)
}
t.typeInstantiations[typ] = append(t.typeInstantiations[typ], typinst)
return ityp
}
@ -230,3 +242,31 @@ func (t *translator) instantiateTypeTuple(ta *typeArgs, tuple *types.Tuple) *typ
}
return types.NewTuple(vars...)
}
// resolveType resolves an instantiated type into its underlying type.
func (t *translator) resolveType(typ types.Type) types.Type {
named, ok := typ.(*types.Named)
if !ok || len(named.TArgs()) == 0 {
return typ
}
ta := newTypeArgs(named.TArgs())
named = t.typeWithoutArgs(named)
return t.instantiateType(ta, named)
}
// resolveTypes resolves a list of types into their underlying types.
func (t *translator) resolveTypes(typeList []types.Type) []types.Type {
ntl := make([]types.Type, len(typeList))
changed := false
for i, typ := range typeList {
ntyp := t.resolveType(typ)
if ntyp != typ {
changed = true
}
ntl[i] = ntyp
}
if !changed {
return typeList
}
return ntl
}

48
test/gen/g013.go2 Normal file
View File

@ -0,0 +1,48 @@
// run
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
type Gen(type A) func() (A, bool)
func combine(type T1, T2, T)(g1 Gen(T1), g2 Gen(T2), join func(T1, T2) T) Gen(T) {
return func() (T, bool) {
var t T
t1, ok := g1()
if !ok {
return t, false
}
t2, ok := g2()
if !ok {
return t, false
}
return join(t1, t2), true
}
}
type Pair(type A, B) struct {
A A
B B
}
func NewPair(type A, B)(a A, b B) Pair(A, B) { return Pair(A, B){a, b} }
func Combine2(type A, B)(ga Gen(A), gb Gen(B)) Gen(Pair(A, B)) {
return combine(ga, gb, NewPair(A, B))
}
func main() {
var g1 Gen(int) = func() (int, bool) { return 3, true }
var g2 Gen(string) = func() (string, bool) { return "x", false }
gc := combine(g1, g2, NewPair(int, string))
gc2 := Combine2(g1, g2)
if got, ok := gc(); ok {
panic(got)
}
if got2, ok := gc2(); ok {
panic(got2)
}
}