From 12a7da1eb004698840a05e4fca8504baaedfe65e Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Sat, 25 Jan 2020 18:52:05 -0800 Subject: [PATCH] go/go2go: add support for methods of parameterized types Change-Id: I308eb692612cb8d6e7321c4972e90b102466b4c2 --- src/go/go2go/go2go.go | 6 +- src/go/go2go/instantiate.go | 150 ++++++++++++++++++++++++++++-------- src/go/go2go/names.go | 6 +- src/go/go2go/rewrite.go | 40 ++++++++-- test/gen/g004.go2 | 39 ++++++++++ 5 files changed, 194 insertions(+), 47 deletions(-) create mode 100644 test/gen/g004.go2 diff --git a/src/go/go2go/go2go.go b/src/go/go2go/go2go.go index b3a938d8c6..7cc63b4ddd 100644 --- a/src/go/go2go/go2go.go +++ b/src/go/go2go/go2go.go @@ -72,7 +72,7 @@ func Rewrite(dir string) error { } tpkg, err := conf.Check(name, fset, asts, info) if err != nil { - return err + return fmt.Errorf("type checking failed for %s: %v", name, err) } tpkgs = append(tpkgs, &gpkg{ @@ -119,12 +119,12 @@ func RewriteBuffer(filename string, file []byte) ([]byte, error) { Uses: make(map[*ast.Ident]types.Object), } if _, err := conf.Check(pf.Name.Name, fset, []*ast.File{pf}, info); err != nil { - return nil, err + return nil, fmt.Errorf("type checking failed for %s: %v", pf.Name.Name, err) } idToFunc := make(map[types.Object]*ast.FuncDecl) idToTypeSpec := make(map[types.Object]*ast.TypeSpec) addIDs(info, pf, idToFunc, idToTypeSpec) - if err := rewriteAST(info, idToFunc, idToTypeSpec, pf); err != nil { + if err := rewriteAST(fset, info, idToFunc, idToTypeSpec, pf); err != nil { return nil, err } var buf bytes.Buffer diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index ca4cfde289..55d716065c 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -29,6 +29,46 @@ func newTypeArgs(typeTypes []types.Type) *typeArgs { } } +// typeArgsFromTParams builds mappings from a list of type parameters +// expressed as ast.Field values. +func typeArgsFromFields(t *translator, astTypes []ast.Expr, typeTypes []types.Type, tparams []*ast.Field) *typeArgs { + ta := newTypeArgs(typeTypes) + for i, tf := range tparams { + for _, tn := range tf.Names { + obj, ok := t.info.Defs[tn] + if !ok { + panic(fmt.Sprintf("no object for type parameter %q", tn)) + } + objType := obj.Type() + objParam, ok := objType.(*types.TypeParam) + if !ok { + panic(fmt.Sprintf("%v is not a TypeParam", objParam)) + } + ta.add(obj, objParam, astTypes[i], typeTypes[i]) + } + } + return ta +} + +// typeArgsFromTParams builds mappings from a list of type parameters +// expressed as ast.Expr values. +func typeArgsFromExprs(t *translator, astTypes []ast.Expr, typeTypes []types.Type, tparams []ast.Expr) *typeArgs { + ta := newTypeArgs(typeTypes) + for i, ti := range tparams { + obj, ok := t.info.Defs[ti.(*ast.Ident)] + if !ok { + panic(fmt.Sprintf("no object for type parameter %q", ti)) + } + objType := obj.Type() + objParam, ok := objType.(*types.TypeParam) + if !ok { + panic(fmt.Sprintf("%v is not a TypeParam", objParam)) + } + ta.add(obj, objParam, astTypes[i], typeTypes[i]) + } + return ta +} + // add adds mappings for obj to ast and typ. func (ta *typeArgs) add(obj types.Object, objParam *types.TypeParam, ast ast.Expr, typ types.Type) { ta.toAST[obj] = ast @@ -59,21 +99,7 @@ func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr return nil, err } - ta := newTypeArgs(typeTypes) - for i, tf := range decl.Type.TParams.List { - for _, tn := range tf.Names { - obj, ok := t.info.Defs[tn] - if !ok { - panic(fmt.Sprintf("no object for type parameter %q", tn)) - } - objType := obj.Type() - objParam, ok := objType.(*types.TypeParam) - if !ok { - panic(fmt.Sprintf("%v is not a TypeParam", objParam)) - } - ta.add(obj, objParam, astTypes[i], typeTypes[i]) - } - } + ta := typeArgsFromFields(t, astTypes, typeTypes, decl.Type.TParams.List) instIdent := ast.NewIdent(name) @@ -104,7 +130,7 @@ func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) { } // instantiateType creates a new instantiation of a type. -func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) { +func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ *types.Named, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) { name, err := t.instantiatedName(tident, typeTypes) if err != nil { return nil, nil, err @@ -115,21 +141,7 @@ func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astT return nil, nil, err } - ta := newTypeArgs(typeTypes) - for i, tf := range spec.TParams.List { - for _, tn := range tf.Names { - obj, ok := t.info.Defs[tn] - if !ok { - panic(fmt.Sprintf("no object for type parameter %q", tn)) - } - objType := obj.Type() - objParam, ok := objType.(*types.TypeParam) - if !ok { - panic(fmt.Sprintf("%v is not a TypeParam", objParam)) - } - ta.add(obj, objParam, astTypes[i], typeTypes[i]) - } - } + ta := typeArgsFromFields(t, astTypes, typeTypes, spec.TParams.List) instIdent := ast.NewIdent(name) @@ -146,7 +158,47 @@ func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astT } t.newDecls = append(t.newDecls, newDecl) - instType := t.instantiateType(ta, typ) + instType := t.instantiateType(ta, typ.Underlying()) + + nm := typ.NumMethods() + for i := 0; i < nm; i++ { + method := typ.Method(i) + mast := t.idToFunc[method] + if mast == nil { + panic(fmt.Sprintf("no AST for method %v", method)) + } + rtyp := mast.Recv.List[0].Type + newRtype := ast.Expr(ast.NewIdent(name)) + if p, ok := rtyp.(*ast.StarExpr); ok { + rtyp = p.X + newRtype = &ast.StarExpr{ + X: newRtype, + } + } + tparams := rtyp.(*ast.CallExpr).Args + ta := typeArgsFromExprs(t, astTypes, typeTypes, tparams) + newDecl := &ast.FuncDecl{ + Doc: mast.Doc, + Recv: &ast.FieldList{ + Opening: mast.Recv.Opening, + List: []*ast.Field{ + { + Doc: mast.Recv.List[0].Doc, + Names: []*ast.Ident{ + mast.Recv.List[0].Names[0], + }, + Type: newRtype, + Comment: mast.Recv.List[0].Comment, + }, + }, + Closing: mast.Recv.Closing, + }, + Name: mast.Name, + Type: t.instantiateExpr(ta, mast.Type).(*ast.FuncType), + Body: t.instantiateBlockStmt(ta, mast.Body), + } + t.newDecls = append(t.newDecls, newDecl) + } return instIdent, instType, nil } @@ -199,6 +251,18 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { return &ast.ExprStmt{ X: x, } + case *ast.AssignStmt: + lhs, lchanged := t.instantiateExprList(ta, s.Lhs) + rhs, rchanged := t.instantiateExprList(ta, s.Rhs) + if !lchanged && !rchanged { + return s + } + return &ast.AssignStmt{ + Lhs: lhs, + TokPos: s.TokPos, + Tok: s.Tok, + Rhs: rhs, + } case *ast.RangeStmt: key := t.instantiateExpr(ta, s.Key) value := t.instantiateExpr(ta, s.Value) @@ -216,6 +280,15 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { X: x, Body: body, } + case *ast.ReturnStmt: + results, changed := t.instantiateExprList(ta, s.Results) + if !changed { + return s + } + return &ast.ReturnStmt{ + Return: s.Return, + Results: results, + } default: panic(fmt.Sprintf("unimplemented Stmt %T", s)) } @@ -281,8 +354,8 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { Rparen: e.Rparen, } case *ast.Ident: - obj, ok := t.info.Uses[e] - if ok { + obj := t.info.ObjectOf(e) + if obj != nil { if typ, ok := ta.ast(obj); ok { return typ } @@ -297,6 +370,15 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { X: x, Sel: e.Sel, } + case *ast.StarExpr: + x := t.instantiateExpr(ta, e.X) + if x == e.X { + return e + } + r = &ast.StarExpr{ + Star: e.Star, + X: x, + } case *ast.FuncType: params := t.instantiateFieldList(ta, e.Params) results := t.instantiateFieldList(ta, e.Results) diff --git a/src/go/go2go/names.go b/src/go/go2go/names.go index ff53be163f..62f1ec3a41 100644 --- a/src/go/go2go/names.go +++ b/src/go/go2go/names.go @@ -34,9 +34,9 @@ var nameCodes = map[rune]int{ } // instantiatedName returns the name of a newly instantiated function. -func (t *translator) instantiatedName(fnident *ast.Ident, types []types.Type) (string, error) { +func (t *translator) instantiatedName(ident *ast.Ident, types []types.Type) (string, error) { var sb strings.Builder - fmt.Fprintf(&sb, "instantiate%c%s", nameSep, fnident.Name) + fmt.Fprintf(&sb, "instantiate%c%s", nameSep, ident.Name) for _, typ := range types { sb.WriteRune(nameSep) s := typ.String() @@ -50,7 +50,7 @@ func (t *translator) instantiatedName(fnident *ast.Ident, types []types.Type) (s } else { code, ok := nameCodes[r] if !ok { - panic(fmt.Sprintf("unexpected type string character %q", r)) + panic(fmt.Sprintf("unexpected type string character %q in %q", r, s)) } fmt.Fprintf(&sb, "%c%d", nameIntro, code) } diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 3423dd9694..5a6fa20d7d 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -26,7 +26,7 @@ func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt for _, decl := range f.Decls { switch decl := decl.(type) { case *ast.FuncDecl: - if isParameterizedFuncDecl(decl) { + if isParameterizedFuncDecl(decl, info) { obj, ok := info.Defs[decl.Name] if !ok { panic(fmt.Sprintf("no types.Object for %q", decl.Name.Name)) @@ -49,8 +49,26 @@ func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt } // isParameterizedFuncDecl reports whether fd is a parameterized function. -func isParameterizedFuncDecl(fd *ast.FuncDecl) bool { - return fd.Type.TParams != nil +func isParameterizedFuncDecl(fd *ast.FuncDecl, info *types.Info) bool { + if fd.Type.TParams != nil { + return true + } + if fd.Recv != nil { + rtyp := info.TypeOf(fd.Recv.List[0].Type) + if rtyp == nil { + // Already instantiated. + return false + } + if p, ok := rtyp.(*types.Pointer); ok { + rtyp = p.Elem() + } + if named, ok := rtyp.(*types.Named); ok { + if named.TParams() != nil { + return true + } + } + } + return false } // isParameterizedTypeDecl reports whether s is a parameterized type. @@ -61,6 +79,7 @@ func isParameterizedTypeDecl(s ast.Spec) bool { // A translator is used to translate a file from Go with contracts to Go 1. type translator struct { + fset *token.FileSet info *types.Info types map[ast.Expr]types.Type idToFunc map[types.Object]*ast.FuncDecl @@ -89,7 +108,7 @@ type typeInstantiation struct { // rewrite rewrites the contents of one file. func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, filename string, file *ast.File) (err error) { - if err := rewriteAST(info, idToFunc, idToTypeSpec, file); err != nil { + if err := rewriteAST(fset, info, idToFunc, idToTypeSpec, file); err != nil { return err } @@ -117,8 +136,9 @@ func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map } // rewriteAST rewrites the AST for a file. -func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) { +func rewriteAST(fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) { t := translator{ + fset: fset, info: info, types: make(map[ast.Expr]types.Type), idToFunc: idToFunc, @@ -139,7 +159,7 @@ func (t *translator) translate(file *ast.File) { for i, decl := range declsToDo { switch decl := decl.(type) { case *ast.FuncDecl: - if !isParameterizedFuncDecl(decl) { + if !isParameterizedFuncDecl(decl, t.info) { t.translateFuncDecl(&declsToDo[i]) newDecls = append(newDecls, decl) } @@ -260,6 +280,8 @@ func (t *translator) translateStmt(ps *ast.Stmt) { default: panic(fmt.Sprintf("unknown decl type %v", d.Tok)) } + case *ast.ReturnStmt: + t.translateExprList(s.Results) default: panic(fmt.Sprintf("unimplemented Stmt %T", s)) } @@ -276,9 +298,13 @@ func (t *translator) translateExpr(pe *ast.Expr) { switch e := (*pe).(type) { case *ast.Ident: return + case *ast.ParenExpr: + t.translateExpr(&e.X) case *ast.BinaryExpr: t.translateExpr(&e.X) t.translateExpr(&e.Y) + case *ast.UnaryExpr: + t.translateExpr(&e.X) case *ast.CallExpr: t.translateExprList(e.Args) if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && ftyp.TParams() != nil { @@ -388,7 +414,7 @@ func (t *translator) translateTypeInstantiation(pe *ast.Expr) { } } - instIdent, instType, err := t.instantiateTypeDecl(tident, typ.Underlying(), call.Args, types) + instIdent, instType, err := t.instantiateTypeDecl(tident, typ, call.Args, types) if err != nil { t.err = err return diff --git a/test/gen/g004.go2 b/test/gen/g004.go2 new file mode 100644 index 0000000000..899bc430d4 --- /dev/null +++ b/test/gen/g004.go2 @@ -0,0 +1,39 @@ +// run + +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "fmt" + +type Value(type T) struct { + val T +} + +func (v *Value(T)) Get() T { + return v.val +} + +func (v *Value(T)) Set(val T) { + v.val = val +} + +func main() { + var v1 Value(int) + v1.Set(1) + if got, want := v1.Get(), 1; got != want { + panic(fmt.Sprintf("Get() == %d, want %d", got, want)) + } + v1.Set(2) + if got, want := v1.Get(), 2; got != want { + panic(fmt.Sprintf("Get() == %d, want %d", got, want)) + } + + var v2 Value(string) + v2.Set("a") + if got, want := v2.Get(), "a"; got != want { + panic(fmt.Sprintf("Get() == %q, want %q", got, want)) + } +}