mirror of https://github.com/golang/go.git
go/go2go: support parameterized types
Added an accessor function to go/types:
// TParams returns the type parameters of the named type t, or nil.
func (t *Named) TParams() []*TypeName { return t.tparams }
Change-Id: Ife2322c73dd6eaecaed42655a57a37541661d1ed
This commit is contained in:
parent
f0d1b476a9
commit
b1322d38b6
|
|
@ -88,12 +88,13 @@ func Rewrite(dir string) error {
|
|||
|
||||
for _, tpkg := range tpkgs {
|
||||
idToFunc := make(map[types.Object]*ast.FuncDecl)
|
||||
idToTypeSpec := make(map[types.Object]*ast.TypeSpec)
|
||||
for _, pkgfile := range tpkg.pkgfiles {
|
||||
addFuncIDs(tpkg.info, pkgfile.ast, idToFunc)
|
||||
addIDs(tpkg.info, pkgfile.ast, idToFunc, idToTypeSpec)
|
||||
}
|
||||
|
||||
for _, pkgfile := range tpkg.pkgfiles {
|
||||
if err := rewriteFile(dir, fset, tpkg.info, idToFunc, pkgfile.name, pkgfile.ast); err != nil {
|
||||
if err := rewriteFile(dir, fset, tpkg.info, idToFunc, idToTypeSpec, pkgfile.name, pkgfile.ast); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -121,8 +122,9 @@ func RewriteBuffer(filename string, file []byte) ([]byte, error) {
|
|||
return nil, err
|
||||
}
|
||||
idToFunc := make(map[types.Object]*ast.FuncDecl)
|
||||
addFuncIDs(info, pf, idToFunc)
|
||||
if err := rewriteAST(info, idToFunc, pf); err != nil {
|
||||
idToTypeSpec := make(map[types.Object]*ast.TypeSpec)
|
||||
addIDs(info, pf, idToFunc, idToTypeSpec)
|
||||
if err := rewriteAST(info, idToFunc, idToTypeSpec, pf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ package go2go
|
|||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
)
|
||||
|
||||
|
|
@ -68,7 +69,7 @@ func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr
|
|||
objType := obj.Type()
|
||||
objParam, ok := objType.(*types.TypeParam)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("%v is not a TypeParam"))
|
||||
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
|
||||
}
|
||||
ta.add(obj, objParam, astTypes[i], typeTypes[i])
|
||||
}
|
||||
|
|
@ -102,6 +103,68 @@ func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) {
|
|||
return decl, nil
|
||||
}
|
||||
|
||||
// instantiateType creates a new instantiation of a type.
|
||||
func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) {
|
||||
name, err := t.instantiatedName(tident, typeTypes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
spec, err := t.findTypeSpec(tident)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ta := newTypeArgs(typeTypes)
|
||||
for i, tf := range spec.TParams.List {
|
||||
for _, tn := range tf.Names {
|
||||
obj, ok := t.info.Defs[tn]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("no object for type parameter %q", tn))
|
||||
}
|
||||
objType := obj.Type()
|
||||
objParam, ok := objType.(*types.TypeParam)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
|
||||
}
|
||||
ta.add(obj, objParam, astTypes[i], typeTypes[i])
|
||||
}
|
||||
}
|
||||
|
||||
instIdent := ast.NewIdent(name)
|
||||
|
||||
newSpec := &ast.TypeSpec{
|
||||
Doc: spec.Doc,
|
||||
Name: instIdent,
|
||||
Assign: spec.Assign,
|
||||
Type: t.instantiateExpr(ta, spec.Type),
|
||||
Comment: spec.Comment,
|
||||
}
|
||||
newDecl := &ast.GenDecl{
|
||||
Tok: token.TYPE,
|
||||
Specs: []ast.Spec{newSpec},
|
||||
}
|
||||
t.newDecls = append(t.newDecls, newDecl)
|
||||
|
||||
instType := t.instantiateType(ta, typ)
|
||||
|
||||
return instIdent, instType, nil
|
||||
}
|
||||
|
||||
// findTypeSpec looks for the TypeSpec for id.
|
||||
// FIXME: Handle imported packages.
|
||||
func (t *translator) findTypeSpec(id *ast.Ident) (*ast.TypeSpec, error) {
|
||||
obj, ok := t.info.Uses[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not find Object for %q", id.Name)
|
||||
}
|
||||
spec, ok := t.idToTypeSpec[obj]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not find type spec for %q", id.Name)
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// instantiateBlockStmt instantiates a BlockStmt.
|
||||
func (t *translator) instantiateBlockStmt(ta *typeArgs, pbs *ast.BlockStmt) *ast.BlockStmt {
|
||||
changed := false
|
||||
|
|
@ -257,6 +320,16 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
|
|||
Len: ln,
|
||||
Elt: elt,
|
||||
}
|
||||
case *ast.StructType:
|
||||
fields := t.instantiateFieldList(ta, e.Fields)
|
||||
if fields == e.Fields {
|
||||
return e
|
||||
}
|
||||
return &ast.StructType{
|
||||
Struct: e.Struct,
|
||||
Fields: fields,
|
||||
Incomplete: e.Incomplete,
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unimplemented Expr %T", e))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,15 +21,29 @@ var config = printer.Config{
|
|||
Tabwidth: 8,
|
||||
}
|
||||
|
||||
// addFuncIDS finds IDs for generic functions and adds them to a map.
|
||||
func addFuncIDs(info *types.Info, f *ast.File, m map[types.Object]*ast.FuncDecl) {
|
||||
// addIDs finds IDs for generic functions and types and adds them to a map.
|
||||
func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt map[types.Object]*ast.TypeSpec) {
|
||||
for _, decl := range f.Decls {
|
||||
if fd, ok := decl.(*ast.FuncDecl); ok && isParameterizedFuncDecl(fd) {
|
||||
obj, ok := info.Defs[fd.Name]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("no types.Object for %q", fd.Name.Name))
|
||||
switch decl := decl.(type) {
|
||||
case *ast.FuncDecl:
|
||||
if isParameterizedFuncDecl(decl) {
|
||||
obj, ok := info.Defs[decl.Name]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("no types.Object for %q", decl.Name.Name))
|
||||
}
|
||||
mf[obj] = decl
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
if decl.Tok == token.TYPE {
|
||||
for _, s := range decl.Specs {
|
||||
ts := s.(*ast.TypeSpec)
|
||||
obj, ok := info.Defs[ts.Name]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("no types.Object for %q", ts.Name.Name))
|
||||
}
|
||||
mt[obj] = ts
|
||||
}
|
||||
}
|
||||
m[obj] = fd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -39,11 +53,18 @@ func isParameterizedFuncDecl(fd *ast.FuncDecl) bool {
|
|||
return fd.Type.TParams != nil
|
||||
}
|
||||
|
||||
// isParameterizedTypeDecl reports whether s is a parameterized type.
|
||||
func isParameterizedTypeDecl(s ast.Spec) bool {
|
||||
ts := s.(*ast.TypeSpec)
|
||||
return ts.TParams != nil
|
||||
}
|
||||
|
||||
// A translator is used to translate a file from Go with contracts to Go 1.
|
||||
type translator struct {
|
||||
info *types.Info
|
||||
types map[ast.Expr]types.Type
|
||||
idToFunc map[types.Object]*ast.FuncDecl
|
||||
idToTypeSpec map[types.Object]*ast.TypeSpec
|
||||
instantiations map[*ast.Ident][]*instantiation
|
||||
newDecls []ast.Decl
|
||||
typeInstantiations map[types.Type][]*typeInstantiation
|
||||
|
|
@ -62,12 +83,13 @@ type instantiation struct {
|
|||
// A typeInstantiation is a single instantiation of a type.
|
||||
type typeInstantiation struct {
|
||||
types []types.Type
|
||||
decl *ast.Ident
|
||||
typ types.Type
|
||||
}
|
||||
|
||||
// rewrite rewrites the contents of one file.
|
||||
func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, filename string, file *ast.File) (err error) {
|
||||
if err := rewriteAST(info, idToFunc, file); err != nil {
|
||||
func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, filename string, file *ast.File) (err error) {
|
||||
if err := rewriteAST(info, idToFunc, idToTypeSpec, file); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -95,11 +117,12 @@ func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map
|
|||
}
|
||||
|
||||
// rewriteAST rewrites the AST for a file.
|
||||
func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, file *ast.File) (err error) {
|
||||
func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) {
|
||||
t := translator{
|
||||
info: info,
|
||||
types: make(map[ast.Expr]types.Type),
|
||||
idToFunc: idToFunc,
|
||||
idToTypeSpec: idToTypeSpec,
|
||||
instantiations: make(map[*ast.Ident][]*instantiation),
|
||||
typeInstantiations: make(map[types.Type][]*typeInstantiation),
|
||||
}
|
||||
|
|
@ -115,23 +138,34 @@ func (t *translator) translate(file *ast.File) {
|
|||
newDecls := make([]ast.Decl, 0, len(declsToDo))
|
||||
for i, decl := range declsToDo {
|
||||
switch decl := decl.(type) {
|
||||
case (*ast.FuncDecl):
|
||||
case *ast.FuncDecl:
|
||||
if !isParameterizedFuncDecl(decl) {
|
||||
t.translateFuncDecl(&declsToDo[i])
|
||||
newDecls = append(newDecls, decl)
|
||||
}
|
||||
case (*ast.GenDecl):
|
||||
case *ast.GenDecl:
|
||||
switch decl.Tok {
|
||||
case token.TYPE:
|
||||
newSpecs := make([]ast.Spec, 0, len(decl.Specs))
|
||||
for j := range decl.Specs {
|
||||
t.translateTypeSpec(&decl.Specs[j])
|
||||
if !isParameterizedTypeDecl(decl.Specs[j]) {
|
||||
t.translateTypeSpec(&decl.Specs[j])
|
||||
newSpecs = append(newSpecs, decl.Specs[j])
|
||||
}
|
||||
}
|
||||
if len(newSpecs) == 0 {
|
||||
decl = nil
|
||||
} else {
|
||||
decl.Specs = newSpecs
|
||||
}
|
||||
case token.VAR, token.CONST:
|
||||
for j := range decl.Specs {
|
||||
t.translateValueSpec(&decl.Specs[j])
|
||||
}
|
||||
}
|
||||
newDecls = append(newDecls, decl)
|
||||
if decl != nil {
|
||||
newDecls = append(newDecls, decl)
|
||||
}
|
||||
default:
|
||||
newDecls = append(newDecls, decl)
|
||||
}
|
||||
|
|
@ -145,11 +179,10 @@ func (t *translator) translate(file *ast.File) {
|
|||
// translateTypeSpec translates a type from Go with contracts to Go 1.
|
||||
func (t *translator) translateTypeSpec(ps *ast.Spec) {
|
||||
ts := (*ps).(*ast.TypeSpec)
|
||||
if ts.TParams == nil {
|
||||
t.translateExpr(&ts.Type)
|
||||
return
|
||||
if ts.TParams != nil {
|
||||
panic("parameterized type")
|
||||
}
|
||||
panic("parameterized type")
|
||||
t.translateExpr(&ts.Type)
|
||||
}
|
||||
|
||||
// translateValueSpec translates a variable or constant from Go with
|
||||
|
|
@ -192,16 +225,41 @@ func (t *translator) translateStmt(ps *ast.Stmt) {
|
|||
if t.err != nil {
|
||||
return
|
||||
}
|
||||
if *ps == nil {
|
||||
return
|
||||
}
|
||||
switch s := (*ps).(type) {
|
||||
case *ast.BlockStmt:
|
||||
t.translateBlockStmt(s)
|
||||
case *ast.ExprStmt:
|
||||
t.translateExpr(&s.X)
|
||||
case *ast.AssignStmt:
|
||||
t.translateExprList(s.Lhs)
|
||||
t.translateExprList(s.Rhs)
|
||||
case *ast.IfStmt:
|
||||
t.translateStmt(&s.Init)
|
||||
t.translateExpr(&s.Cond)
|
||||
t.translateBlockStmt(s.Body)
|
||||
t.translateStmt(&s.Else)
|
||||
case *ast.RangeStmt:
|
||||
t.translateExpr(&s.Key)
|
||||
t.translateExpr(&s.Value)
|
||||
t.translateExpr(&s.X)
|
||||
t.translateBlockStmt(s.Body)
|
||||
case *ast.DeclStmt:
|
||||
d := s.Decl.(*ast.GenDecl)
|
||||
switch d.Tok {
|
||||
case token.TYPE:
|
||||
for i := range d.Specs {
|
||||
t.translateTypeSpec(&d.Specs[i])
|
||||
}
|
||||
case token.CONST, token.VAR:
|
||||
for i := range d.Specs {
|
||||
t.translateValueSpec(&d.Specs[i])
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown decl type %v", d.Tok))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unimplemented Stmt %T", s))
|
||||
}
|
||||
|
|
@ -218,11 +276,15 @@ func (t *translator) translateExpr(pe *ast.Expr) {
|
|||
switch e := (*pe).(type) {
|
||||
case *ast.Ident:
|
||||
return
|
||||
case *ast.BinaryExpr:
|
||||
t.translateExpr(&e.X)
|
||||
t.translateExpr(&e.Y)
|
||||
case *ast.CallExpr:
|
||||
t.translateExprList(e.Args)
|
||||
ftyp := t.lookupType(e.Fun).(*types.Signature)
|
||||
if ftyp.TParams() != nil {
|
||||
if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && ftyp.TParams() != nil {
|
||||
t.translateFunctionInstantiation(pe)
|
||||
} else if ntyp, ok := t.lookupType(e.Fun).(*types.Named); ok && ntyp.TParams() != nil {
|
||||
t.translateTypeInstantiation(pe)
|
||||
}
|
||||
t.translateExpr(&e.Fun)
|
||||
case *ast.StarExpr:
|
||||
|
|
@ -231,6 +293,8 @@ func (t *translator) translateExpr(pe *ast.Expr) {
|
|||
t.translateExpr(&e.X)
|
||||
case *ast.ArrayType:
|
||||
t.translateExpr(&e.Elt)
|
||||
case *ast.StructType:
|
||||
t.translateFieldList(e.Fields)
|
||||
case *ast.BasicLit:
|
||||
return
|
||||
case *ast.CompositeLit:
|
||||
|
|
@ -301,6 +365,45 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) {
|
|||
*pe = instIdent
|
||||
}
|
||||
|
||||
// translateTypeInstantiation translates an instantiated type to Go 1.
|
||||
func (t *translator) translateTypeInstantiation(pe *ast.Expr) {
|
||||
call := (*pe).(*ast.CallExpr)
|
||||
tident, ok := call.Fun.(*ast.Ident)
|
||||
if !ok {
|
||||
panic("instantiated type non-ident")
|
||||
}
|
||||
|
||||
typ := t.lookupType(call.Fun).(*types.Named)
|
||||
|
||||
types := make([]types.Type, 0, len(call.Args))
|
||||
for _, arg := range call.Args {
|
||||
types = append(types, t.lookupType(arg))
|
||||
}
|
||||
|
||||
instantiations := t.typeInstantiations[typ]
|
||||
for _, inst := range instantiations {
|
||||
if t.sameTypes(types, inst.types) {
|
||||
*pe = inst.decl
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
instIdent, instType, err := t.instantiateTypeDecl(tident, typ.Underlying(), call.Args, types)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
return
|
||||
}
|
||||
|
||||
n := &typeInstantiation{
|
||||
types: types,
|
||||
decl: instIdent,
|
||||
typ: instType,
|
||||
}
|
||||
t.typeInstantiations[typ] = append(instantiations, n)
|
||||
|
||||
*pe = instIdent
|
||||
}
|
||||
|
||||
// sameTypes reports whether two type slices are the same.
|
||||
func (t *translator) sameTypes(a, b []types.Type) bool {
|
||||
if len(a) != len(b) {
|
||||
|
|
|
|||
|
|
@ -88,6 +88,33 @@ func (t *translator) doInstantiateType(ta *typeArgs, typ types.Type) types.Type
|
|||
return r
|
||||
case *types.Tuple:
|
||||
return t.instantiateTypeTuple(ta, typ)
|
||||
case *types.Struct:
|
||||
n := typ.NumFields()
|
||||
fields := make([]*types.Var, n)
|
||||
changed := false
|
||||
tags := make([]string, n)
|
||||
hasTag := false
|
||||
for i := 0; i < n; i++ {
|
||||
v := typ.Field(i)
|
||||
instType := t.instantiateType(ta, v.Type())
|
||||
if v.Type() != instType {
|
||||
changed = true
|
||||
}
|
||||
fields[i] = types.NewVar(v.Pos(), v.Pkg(), v.Name(), instType)
|
||||
|
||||
tag := typ.Tag(i)
|
||||
if tag != "" {
|
||||
tags[i] = tag
|
||||
hasTag = true
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return typ
|
||||
}
|
||||
if !hasTag {
|
||||
tags = nil
|
||||
}
|
||||
return types.NewStruct(fields, tags)
|
||||
default:
|
||||
panic(fmt.Sprintf("unimplemented Type %T", typ))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -512,6 +512,9 @@ func NewNamed(obj *TypeName, underlying Type, methods []*Func) *Named {
|
|||
// Obj returns the type name for the named type t.
|
||||
func (t *Named) Obj() *TypeName { return t.obj }
|
||||
|
||||
// TParams returns the type parameters of the named type t, or nil.
|
||||
func (t *Named) TParams() []*TypeName { return t.tparams }
|
||||
|
||||
// NumMethods returns the number of explicit methods whose receiver is named type t.
|
||||
func (t *Named) NumMethods() int { return len(t.methods) }
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
// run
|
||||
|
||||
// Copyright 2020 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Pair(type F1, F2) struct {
|
||||
f1 F1
|
||||
f2 F2
|
||||
}
|
||||
|
||||
func main() {
|
||||
p := Pair(int32, int64){1, 2}
|
||||
if got, want := unsafe.Sizeof(p.f1), uintptr(4); got != want {
|
||||
panic(fmt.Sprintf("unexpected f1 size == %d want %d", got, want))
|
||||
}
|
||||
if got, want := unsafe.Sizeof(p.f2), uintptr(8); got != want {
|
||||
panic(fmt.Sprintf("unexpected f2 size == %d want %d", got, want))
|
||||
}
|
||||
type MyPair struct { f1 int32; f2 int64 }
|
||||
mp := MyPair(p)
|
||||
if mp.f1 != 1 || mp.f2 != 2 {
|
||||
panic(fmt.Sprintf("mp == %#v want %#v", mp, MyPair{1, 2}))
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue