go/go2go: support parameterized types

Added an accessor function to go/types:

// TParams returns the type parameters of the named type t, or nil.
func (t *Named) TParams() []*TypeName { return t.tparams }

Change-Id: Ife2322c73dd6eaecaed42655a57a37541661d1ed
This commit is contained in:
Ian Lance Taylor 2020-01-24 16:43:51 -08:00 committed by Robert Griesemer
parent f0d1b476a9
commit b1322d38b6
6 changed files with 265 additions and 25 deletions

View File

@ -88,12 +88,13 @@ func Rewrite(dir string) error {
for _, tpkg := range tpkgs {
idToFunc := make(map[types.Object]*ast.FuncDecl)
idToTypeSpec := make(map[types.Object]*ast.TypeSpec)
for _, pkgfile := range tpkg.pkgfiles {
addFuncIDs(tpkg.info, pkgfile.ast, idToFunc)
addIDs(tpkg.info, pkgfile.ast, idToFunc, idToTypeSpec)
}
for _, pkgfile := range tpkg.pkgfiles {
if err := rewriteFile(dir, fset, tpkg.info, idToFunc, pkgfile.name, pkgfile.ast); err != nil {
if err := rewriteFile(dir, fset, tpkg.info, idToFunc, idToTypeSpec, pkgfile.name, pkgfile.ast); err != nil {
return err
}
}
@ -121,8 +122,9 @@ func RewriteBuffer(filename string, file []byte) ([]byte, error) {
return nil, err
}
idToFunc := make(map[types.Object]*ast.FuncDecl)
addFuncIDs(info, pf, idToFunc)
if err := rewriteAST(info, idToFunc, pf); err != nil {
idToTypeSpec := make(map[types.Object]*ast.TypeSpec)
addIDs(info, pf, idToFunc, idToTypeSpec)
if err := rewriteAST(info, idToFunc, idToTypeSpec, pf); err != nil {
return nil, err
}
var buf bytes.Buffer

View File

@ -7,6 +7,7 @@ package go2go
import (
"fmt"
"go/ast"
"go/token"
"go/types"
)
@ -68,7 +69,7 @@ func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr
objType := obj.Type()
objParam, ok := objType.(*types.TypeParam)
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam"))
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
@ -102,6 +103,68 @@ func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) {
return decl, nil
}
// instantiateType creates a new instantiation of a type.
func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) {
name, err := t.instantiatedName(tident, typeTypes)
if err != nil {
return nil, nil, err
}
spec, err := t.findTypeSpec(tident)
if err != nil {
return nil, nil, err
}
ta := newTypeArgs(typeTypes)
for i, tf := range spec.TParams.List {
for _, tn := range tf.Names {
obj, ok := t.info.Defs[tn]
if !ok {
panic(fmt.Sprintf("no object for type parameter %q", tn))
}
objType := obj.Type()
objParam, ok := objType.(*types.TypeParam)
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
}
instIdent := ast.NewIdent(name)
newSpec := &ast.TypeSpec{
Doc: spec.Doc,
Name: instIdent,
Assign: spec.Assign,
Type: t.instantiateExpr(ta, spec.Type),
Comment: spec.Comment,
}
newDecl := &ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{newSpec},
}
t.newDecls = append(t.newDecls, newDecl)
instType := t.instantiateType(ta, typ)
return instIdent, instType, nil
}
// findTypeSpec looks for the TypeSpec for id.
// FIXME: Handle imported packages.
func (t *translator) findTypeSpec(id *ast.Ident) (*ast.TypeSpec, error) {
obj, ok := t.info.Uses[id]
if !ok {
return nil, fmt.Errorf("could not find Object for %q", id.Name)
}
spec, ok := t.idToTypeSpec[obj]
if !ok {
return nil, fmt.Errorf("could not find type spec for %q", id.Name)
}
return spec, nil
}
// instantiateBlockStmt instantiates a BlockStmt.
func (t *translator) instantiateBlockStmt(ta *typeArgs, pbs *ast.BlockStmt) *ast.BlockStmt {
changed := false
@ -257,6 +320,16 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
Len: ln,
Elt: elt,
}
case *ast.StructType:
fields := t.instantiateFieldList(ta, e.Fields)
if fields == e.Fields {
return e
}
return &ast.StructType{
Struct: e.Struct,
Fields: fields,
Incomplete: e.Incomplete,
}
default:
panic(fmt.Sprintf("unimplemented Expr %T", e))
}

View File

@ -21,15 +21,29 @@ var config = printer.Config{
Tabwidth: 8,
}
// addFuncIDS finds IDs for generic functions and adds them to a map.
func addFuncIDs(info *types.Info, f *ast.File, m map[types.Object]*ast.FuncDecl) {
// addIDs finds IDs for generic functions and types and adds them to a map.
func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt map[types.Object]*ast.TypeSpec) {
for _, decl := range f.Decls {
if fd, ok := decl.(*ast.FuncDecl); ok && isParameterizedFuncDecl(fd) {
obj, ok := info.Defs[fd.Name]
if !ok {
panic(fmt.Sprintf("no types.Object for %q", fd.Name.Name))
switch decl := decl.(type) {
case *ast.FuncDecl:
if isParameterizedFuncDecl(decl) {
obj, ok := info.Defs[decl.Name]
if !ok {
panic(fmt.Sprintf("no types.Object for %q", decl.Name.Name))
}
mf[obj] = decl
}
case *ast.GenDecl:
if decl.Tok == token.TYPE {
for _, s := range decl.Specs {
ts := s.(*ast.TypeSpec)
obj, ok := info.Defs[ts.Name]
if !ok {
panic(fmt.Sprintf("no types.Object for %q", ts.Name.Name))
}
mt[obj] = ts
}
}
m[obj] = fd
}
}
}
@ -39,11 +53,18 @@ func isParameterizedFuncDecl(fd *ast.FuncDecl) bool {
return fd.Type.TParams != nil
}
// isParameterizedTypeDecl reports whether s is a parameterized type.
func isParameterizedTypeDecl(s ast.Spec) bool {
ts := s.(*ast.TypeSpec)
return ts.TParams != nil
}
// A translator is used to translate a file from Go with contracts to Go 1.
type translator struct {
info *types.Info
types map[ast.Expr]types.Type
idToFunc map[types.Object]*ast.FuncDecl
idToTypeSpec map[types.Object]*ast.TypeSpec
instantiations map[*ast.Ident][]*instantiation
newDecls []ast.Decl
typeInstantiations map[types.Type][]*typeInstantiation
@ -62,12 +83,13 @@ type instantiation struct {
// A typeInstantiation is a single instantiation of a type.
type typeInstantiation struct {
types []types.Type
decl *ast.Ident
typ types.Type
}
// rewrite rewrites the contents of one file.
func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, filename string, file *ast.File) (err error) {
if err := rewriteAST(info, idToFunc, file); err != nil {
func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, filename string, file *ast.File) (err error) {
if err := rewriteAST(info, idToFunc, idToTypeSpec, file); err != nil {
return err
}
@ -95,11 +117,12 @@ func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map
}
// rewriteAST rewrites the AST for a file.
func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, file *ast.File) (err error) {
func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) {
t := translator{
info: info,
types: make(map[ast.Expr]types.Type),
idToFunc: idToFunc,
idToTypeSpec: idToTypeSpec,
instantiations: make(map[*ast.Ident][]*instantiation),
typeInstantiations: make(map[types.Type][]*typeInstantiation),
}
@ -115,23 +138,34 @@ func (t *translator) translate(file *ast.File) {
newDecls := make([]ast.Decl, 0, len(declsToDo))
for i, decl := range declsToDo {
switch decl := decl.(type) {
case (*ast.FuncDecl):
case *ast.FuncDecl:
if !isParameterizedFuncDecl(decl) {
t.translateFuncDecl(&declsToDo[i])
newDecls = append(newDecls, decl)
}
case (*ast.GenDecl):
case *ast.GenDecl:
switch decl.Tok {
case token.TYPE:
newSpecs := make([]ast.Spec, 0, len(decl.Specs))
for j := range decl.Specs {
t.translateTypeSpec(&decl.Specs[j])
if !isParameterizedTypeDecl(decl.Specs[j]) {
t.translateTypeSpec(&decl.Specs[j])
newSpecs = append(newSpecs, decl.Specs[j])
}
}
if len(newSpecs) == 0 {
decl = nil
} else {
decl.Specs = newSpecs
}
case token.VAR, token.CONST:
for j := range decl.Specs {
t.translateValueSpec(&decl.Specs[j])
}
}
newDecls = append(newDecls, decl)
if decl != nil {
newDecls = append(newDecls, decl)
}
default:
newDecls = append(newDecls, decl)
}
@ -145,11 +179,10 @@ func (t *translator) translate(file *ast.File) {
// translateTypeSpec translates a type from Go with contracts to Go 1.
func (t *translator) translateTypeSpec(ps *ast.Spec) {
ts := (*ps).(*ast.TypeSpec)
if ts.TParams == nil {
t.translateExpr(&ts.Type)
return
if ts.TParams != nil {
panic("parameterized type")
}
panic("parameterized type")
t.translateExpr(&ts.Type)
}
// translateValueSpec translates a variable or constant from Go with
@ -192,16 +225,41 @@ func (t *translator) translateStmt(ps *ast.Stmt) {
if t.err != nil {
return
}
if *ps == nil {
return
}
switch s := (*ps).(type) {
case *ast.BlockStmt:
t.translateBlockStmt(s)
case *ast.ExprStmt:
t.translateExpr(&s.X)
case *ast.AssignStmt:
t.translateExprList(s.Lhs)
t.translateExprList(s.Rhs)
case *ast.IfStmt:
t.translateStmt(&s.Init)
t.translateExpr(&s.Cond)
t.translateBlockStmt(s.Body)
t.translateStmt(&s.Else)
case *ast.RangeStmt:
t.translateExpr(&s.Key)
t.translateExpr(&s.Value)
t.translateExpr(&s.X)
t.translateBlockStmt(s.Body)
case *ast.DeclStmt:
d := s.Decl.(*ast.GenDecl)
switch d.Tok {
case token.TYPE:
for i := range d.Specs {
t.translateTypeSpec(&d.Specs[i])
}
case token.CONST, token.VAR:
for i := range d.Specs {
t.translateValueSpec(&d.Specs[i])
}
default:
panic(fmt.Sprintf("unknown decl type %v", d.Tok))
}
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
@ -218,11 +276,15 @@ func (t *translator) translateExpr(pe *ast.Expr) {
switch e := (*pe).(type) {
case *ast.Ident:
return
case *ast.BinaryExpr:
t.translateExpr(&e.X)
t.translateExpr(&e.Y)
case *ast.CallExpr:
t.translateExprList(e.Args)
ftyp := t.lookupType(e.Fun).(*types.Signature)
if ftyp.TParams() != nil {
if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && ftyp.TParams() != nil {
t.translateFunctionInstantiation(pe)
} else if ntyp, ok := t.lookupType(e.Fun).(*types.Named); ok && ntyp.TParams() != nil {
t.translateTypeInstantiation(pe)
}
t.translateExpr(&e.Fun)
case *ast.StarExpr:
@ -231,6 +293,8 @@ func (t *translator) translateExpr(pe *ast.Expr) {
t.translateExpr(&e.X)
case *ast.ArrayType:
t.translateExpr(&e.Elt)
case *ast.StructType:
t.translateFieldList(e.Fields)
case *ast.BasicLit:
return
case *ast.CompositeLit:
@ -301,6 +365,45 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) {
*pe = instIdent
}
// translateTypeInstantiation translates an instantiated type to Go 1.
func (t *translator) translateTypeInstantiation(pe *ast.Expr) {
call := (*pe).(*ast.CallExpr)
tident, ok := call.Fun.(*ast.Ident)
if !ok {
panic("instantiated type non-ident")
}
typ := t.lookupType(call.Fun).(*types.Named)
types := make([]types.Type, 0, len(call.Args))
for _, arg := range call.Args {
types = append(types, t.lookupType(arg))
}
instantiations := t.typeInstantiations[typ]
for _, inst := range instantiations {
if t.sameTypes(types, inst.types) {
*pe = inst.decl
return
}
}
instIdent, instType, err := t.instantiateTypeDecl(tident, typ.Underlying(), call.Args, types)
if err != nil {
t.err = err
return
}
n := &typeInstantiation{
types: types,
decl: instIdent,
typ: instType,
}
t.typeInstantiations[typ] = append(instantiations, n)
*pe = instIdent
}
// sameTypes reports whether two type slices are the same.
func (t *translator) sameTypes(a, b []types.Type) bool {
if len(a) != len(b) {

View File

@ -88,6 +88,33 @@ func (t *translator) doInstantiateType(ta *typeArgs, typ types.Type) types.Type
return r
case *types.Tuple:
return t.instantiateTypeTuple(ta, typ)
case *types.Struct:
n := typ.NumFields()
fields := make([]*types.Var, n)
changed := false
tags := make([]string, n)
hasTag := false
for i := 0; i < n; i++ {
v := typ.Field(i)
instType := t.instantiateType(ta, v.Type())
if v.Type() != instType {
changed = true
}
fields[i] = types.NewVar(v.Pos(), v.Pkg(), v.Name(), instType)
tag := typ.Tag(i)
if tag != "" {
tags[i] = tag
hasTag = true
}
}
if !changed {
return typ
}
if !hasTag {
tags = nil
}
return types.NewStruct(fields, tags)
default:
panic(fmt.Sprintf("unimplemented Type %T", typ))
}

View File

@ -512,6 +512,9 @@ func NewNamed(obj *TypeName, underlying Type, methods []*Func) *Named {
// Obj returns the type name for the named type t.
func (t *Named) Obj() *TypeName { return t.obj }
// TParams returns the type parameters of the named type t, or nil.
func (t *Named) TParams() []*TypeName { return t.tparams }
// NumMethods returns the number of explicit methods whose receiver is named type t.
func (t *Named) NumMethods() int { return len(t.methods) }

32
test/gen/g003.go2 Normal file
View File

@ -0,0 +1,32 @@
// run
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"unsafe"
)
type Pair(type F1, F2) struct {
f1 F1
f2 F2
}
func main() {
p := Pair(int32, int64){1, 2}
if got, want := unsafe.Sizeof(p.f1), uintptr(4); got != want {
panic(fmt.Sprintf("unexpected f1 size == %d want %d", got, want))
}
if got, want := unsafe.Sizeof(p.f2), uintptr(8); got != want {
panic(fmt.Sprintf("unexpected f2 size == %d want %d", got, want))
}
type MyPair struct { f1 int32; f2 int64 }
mp := MyPair(p)
if mp.f1 != 1 || mp.f2 != 2 {
panic(fmt.Sprintf("mp == %#v want %#v", mp, MyPair{1, 2}))
}
}