go/go2go: add support for methods of parameterized types

Change-Id: I308eb692612cb8d6e7321c4972e90b102466b4c2
This commit is contained in:
Ian Lance Taylor 2020-01-25 18:52:05 -08:00 committed by Robert Griesemer
parent 1bcfb0add1
commit 12a7da1eb0
5 changed files with 194 additions and 47 deletions

View File

@ -72,7 +72,7 @@ func Rewrite(dir string) error {
}
tpkg, err := conf.Check(name, fset, asts, info)
if err != nil {
return err
return fmt.Errorf("type checking failed for %s: %v", name, err)
}
tpkgs = append(tpkgs, &gpkg{
@ -119,12 +119,12 @@ func RewriteBuffer(filename string, file []byte) ([]byte, error) {
Uses: make(map[*ast.Ident]types.Object),
}
if _, err := conf.Check(pf.Name.Name, fset, []*ast.File{pf}, info); err != nil {
return nil, err
return nil, fmt.Errorf("type checking failed for %s: %v", pf.Name.Name, err)
}
idToFunc := make(map[types.Object]*ast.FuncDecl)
idToTypeSpec := make(map[types.Object]*ast.TypeSpec)
addIDs(info, pf, idToFunc, idToTypeSpec)
if err := rewriteAST(info, idToFunc, idToTypeSpec, pf); err != nil {
if err := rewriteAST(fset, info, idToFunc, idToTypeSpec, pf); err != nil {
return nil, err
}
var buf bytes.Buffer

View File

@ -29,6 +29,46 @@ func newTypeArgs(typeTypes []types.Type) *typeArgs {
}
}
// typeArgsFromTParams builds mappings from a list of type parameters
// expressed as ast.Field values.
func typeArgsFromFields(t *translator, astTypes []ast.Expr, typeTypes []types.Type, tparams []*ast.Field) *typeArgs {
ta := newTypeArgs(typeTypes)
for i, tf := range tparams {
for _, tn := range tf.Names {
obj, ok := t.info.Defs[tn]
if !ok {
panic(fmt.Sprintf("no object for type parameter %q", tn))
}
objType := obj.Type()
objParam, ok := objType.(*types.TypeParam)
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
}
return ta
}
// typeArgsFromTParams builds mappings from a list of type parameters
// expressed as ast.Expr values.
func typeArgsFromExprs(t *translator, astTypes []ast.Expr, typeTypes []types.Type, tparams []ast.Expr) *typeArgs {
ta := newTypeArgs(typeTypes)
for i, ti := range tparams {
obj, ok := t.info.Defs[ti.(*ast.Ident)]
if !ok {
panic(fmt.Sprintf("no object for type parameter %q", ti))
}
objType := obj.Type()
objParam, ok := objType.(*types.TypeParam)
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
return ta
}
// add adds mappings for obj to ast and typ.
func (ta *typeArgs) add(obj types.Object, objParam *types.TypeParam, ast ast.Expr, typ types.Type) {
ta.toAST[obj] = ast
@ -59,21 +99,7 @@ func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr
return nil, err
}
ta := newTypeArgs(typeTypes)
for i, tf := range decl.Type.TParams.List {
for _, tn := range tf.Names {
obj, ok := t.info.Defs[tn]
if !ok {
panic(fmt.Sprintf("no object for type parameter %q", tn))
}
objType := obj.Type()
objParam, ok := objType.(*types.TypeParam)
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
}
ta := typeArgsFromFields(t, astTypes, typeTypes, decl.Type.TParams.List)
instIdent := ast.NewIdent(name)
@ -104,7 +130,7 @@ func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) {
}
// instantiateType creates a new instantiation of a type.
func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) {
func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ *types.Named, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) {
name, err := t.instantiatedName(tident, typeTypes)
if err != nil {
return nil, nil, err
@ -115,21 +141,7 @@ func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astT
return nil, nil, err
}
ta := newTypeArgs(typeTypes)
for i, tf := range spec.TParams.List {
for _, tn := range tf.Names {
obj, ok := t.info.Defs[tn]
if !ok {
panic(fmt.Sprintf("no object for type parameter %q", tn))
}
objType := obj.Type()
objParam, ok := objType.(*types.TypeParam)
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
}
ta := typeArgsFromFields(t, astTypes, typeTypes, spec.TParams.List)
instIdent := ast.NewIdent(name)
@ -146,7 +158,47 @@ func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astT
}
t.newDecls = append(t.newDecls, newDecl)
instType := t.instantiateType(ta, typ)
instType := t.instantiateType(ta, typ.Underlying())
nm := typ.NumMethods()
for i := 0; i < nm; i++ {
method := typ.Method(i)
mast := t.idToFunc[method]
if mast == nil {
panic(fmt.Sprintf("no AST for method %v", method))
}
rtyp := mast.Recv.List[0].Type
newRtype := ast.Expr(ast.NewIdent(name))
if p, ok := rtyp.(*ast.StarExpr); ok {
rtyp = p.X
newRtype = &ast.StarExpr{
X: newRtype,
}
}
tparams := rtyp.(*ast.CallExpr).Args
ta := typeArgsFromExprs(t, astTypes, typeTypes, tparams)
newDecl := &ast.FuncDecl{
Doc: mast.Doc,
Recv: &ast.FieldList{
Opening: mast.Recv.Opening,
List: []*ast.Field{
{
Doc: mast.Recv.List[0].Doc,
Names: []*ast.Ident{
mast.Recv.List[0].Names[0],
},
Type: newRtype,
Comment: mast.Recv.List[0].Comment,
},
},
Closing: mast.Recv.Closing,
},
Name: mast.Name,
Type: t.instantiateExpr(ta, mast.Type).(*ast.FuncType),
Body: t.instantiateBlockStmt(ta, mast.Body),
}
t.newDecls = append(t.newDecls, newDecl)
}
return instIdent, instType, nil
}
@ -199,6 +251,18 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
return &ast.ExprStmt{
X: x,
}
case *ast.AssignStmt:
lhs, lchanged := t.instantiateExprList(ta, s.Lhs)
rhs, rchanged := t.instantiateExprList(ta, s.Rhs)
if !lchanged && !rchanged {
return s
}
return &ast.AssignStmt{
Lhs: lhs,
TokPos: s.TokPos,
Tok: s.Tok,
Rhs: rhs,
}
case *ast.RangeStmt:
key := t.instantiateExpr(ta, s.Key)
value := t.instantiateExpr(ta, s.Value)
@ -216,6 +280,15 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
X: x,
Body: body,
}
case *ast.ReturnStmt:
results, changed := t.instantiateExprList(ta, s.Results)
if !changed {
return s
}
return &ast.ReturnStmt{
Return: s.Return,
Results: results,
}
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
@ -281,8 +354,8 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
Rparen: e.Rparen,
}
case *ast.Ident:
obj, ok := t.info.Uses[e]
if ok {
obj := t.info.ObjectOf(e)
if obj != nil {
if typ, ok := ta.ast(obj); ok {
return typ
}
@ -297,6 +370,15 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
X: x,
Sel: e.Sel,
}
case *ast.StarExpr:
x := t.instantiateExpr(ta, e.X)
if x == e.X {
return e
}
r = &ast.StarExpr{
Star: e.Star,
X: x,
}
case *ast.FuncType:
params := t.instantiateFieldList(ta, e.Params)
results := t.instantiateFieldList(ta, e.Results)

View File

@ -34,9 +34,9 @@ var nameCodes = map[rune]int{
}
// instantiatedName returns the name of a newly instantiated function.
func (t *translator) instantiatedName(fnident *ast.Ident, types []types.Type) (string, error) {
func (t *translator) instantiatedName(ident *ast.Ident, types []types.Type) (string, error) {
var sb strings.Builder
fmt.Fprintf(&sb, "instantiate%c%s", nameSep, fnident.Name)
fmt.Fprintf(&sb, "instantiate%c%s", nameSep, ident.Name)
for _, typ := range types {
sb.WriteRune(nameSep)
s := typ.String()
@ -50,7 +50,7 @@ func (t *translator) instantiatedName(fnident *ast.Ident, types []types.Type) (s
} else {
code, ok := nameCodes[r]
if !ok {
panic(fmt.Sprintf("unexpected type string character %q", r))
panic(fmt.Sprintf("unexpected type string character %q in %q", r, s))
}
fmt.Fprintf(&sb, "%c%d", nameIntro, code)
}

View File

@ -26,7 +26,7 @@ func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt
for _, decl := range f.Decls {
switch decl := decl.(type) {
case *ast.FuncDecl:
if isParameterizedFuncDecl(decl) {
if isParameterizedFuncDecl(decl, info) {
obj, ok := info.Defs[decl.Name]
if !ok {
panic(fmt.Sprintf("no types.Object for %q", decl.Name.Name))
@ -49,8 +49,26 @@ func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt
}
// isParameterizedFuncDecl reports whether fd is a parameterized function.
func isParameterizedFuncDecl(fd *ast.FuncDecl) bool {
return fd.Type.TParams != nil
func isParameterizedFuncDecl(fd *ast.FuncDecl, info *types.Info) bool {
if fd.Type.TParams != nil {
return true
}
if fd.Recv != nil {
rtyp := info.TypeOf(fd.Recv.List[0].Type)
if rtyp == nil {
// Already instantiated.
return false
}
if p, ok := rtyp.(*types.Pointer); ok {
rtyp = p.Elem()
}
if named, ok := rtyp.(*types.Named); ok {
if named.TParams() != nil {
return true
}
}
}
return false
}
// isParameterizedTypeDecl reports whether s is a parameterized type.
@ -61,6 +79,7 @@ func isParameterizedTypeDecl(s ast.Spec) bool {
// A translator is used to translate a file from Go with contracts to Go 1.
type translator struct {
fset *token.FileSet
info *types.Info
types map[ast.Expr]types.Type
idToFunc map[types.Object]*ast.FuncDecl
@ -89,7 +108,7 @@ type typeInstantiation struct {
// rewrite rewrites the contents of one file.
func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, filename string, file *ast.File) (err error) {
if err := rewriteAST(info, idToFunc, idToTypeSpec, file); err != nil {
if err := rewriteAST(fset, info, idToFunc, idToTypeSpec, file); err != nil {
return err
}
@ -117,8 +136,9 @@ func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map
}
// rewriteAST rewrites the AST for a file.
func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) {
func rewriteAST(fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) {
t := translator{
fset: fset,
info: info,
types: make(map[ast.Expr]types.Type),
idToFunc: idToFunc,
@ -139,7 +159,7 @@ func (t *translator) translate(file *ast.File) {
for i, decl := range declsToDo {
switch decl := decl.(type) {
case *ast.FuncDecl:
if !isParameterizedFuncDecl(decl) {
if !isParameterizedFuncDecl(decl, t.info) {
t.translateFuncDecl(&declsToDo[i])
newDecls = append(newDecls, decl)
}
@ -260,6 +280,8 @@ func (t *translator) translateStmt(ps *ast.Stmt) {
default:
panic(fmt.Sprintf("unknown decl type %v", d.Tok))
}
case *ast.ReturnStmt:
t.translateExprList(s.Results)
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
@ -276,9 +298,13 @@ func (t *translator) translateExpr(pe *ast.Expr) {
switch e := (*pe).(type) {
case *ast.Ident:
return
case *ast.ParenExpr:
t.translateExpr(&e.X)
case *ast.BinaryExpr:
t.translateExpr(&e.X)
t.translateExpr(&e.Y)
case *ast.UnaryExpr:
t.translateExpr(&e.X)
case *ast.CallExpr:
t.translateExprList(e.Args)
if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && ftyp.TParams() != nil {
@ -388,7 +414,7 @@ func (t *translator) translateTypeInstantiation(pe *ast.Expr) {
}
}
instIdent, instType, err := t.instantiateTypeDecl(tident, typ.Underlying(), call.Args, types)
instIdent, instType, err := t.instantiateTypeDecl(tident, typ, call.Args, types)
if err != nil {
t.err = err
return

39
test/gen/g004.go2 Normal file
View File

@ -0,0 +1,39 @@
// run
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "fmt"
type Value(type T) struct {
val T
}
func (v *Value(T)) Get() T {
return v.val
}
func (v *Value(T)) Set(val T) {
v.val = val
}
func main() {
var v1 Value(int)
v1.Set(1)
if got, want := v1.Get(), 1; got != want {
panic(fmt.Sprintf("Get() == %d, want %d", got, want))
}
v1.Set(2)
if got, want := v1.Get(), 2; got != want {
panic(fmt.Sprintf("Get() == %d, want %d", got, want))
}
var v2 Value(string)
v2.Set("a")
if got, want := v2.Get(), "a"; got != want {
panic(fmt.Sprintf("Get() == %q, want %q", got, want))
}
}