diff --git a/src/go/go2go/go2go.go b/src/go/go2go/go2go.go index 1c1cf17476..b3a938d8c6 100644 --- a/src/go/go2go/go2go.go +++ b/src/go/go2go/go2go.go @@ -88,12 +88,13 @@ func Rewrite(dir string) error { for _, tpkg := range tpkgs { idToFunc := make(map[types.Object]*ast.FuncDecl) + idToTypeSpec := make(map[types.Object]*ast.TypeSpec) for _, pkgfile := range tpkg.pkgfiles { - addFuncIDs(tpkg.info, pkgfile.ast, idToFunc) + addIDs(tpkg.info, pkgfile.ast, idToFunc, idToTypeSpec) } for _, pkgfile := range tpkg.pkgfiles { - if err := rewriteFile(dir, fset, tpkg.info, idToFunc, pkgfile.name, pkgfile.ast); err != nil { + if err := rewriteFile(dir, fset, tpkg.info, idToFunc, idToTypeSpec, pkgfile.name, pkgfile.ast); err != nil { return err } } @@ -121,8 +122,9 @@ func RewriteBuffer(filename string, file []byte) ([]byte, error) { return nil, err } idToFunc := make(map[types.Object]*ast.FuncDecl) - addFuncIDs(info, pf, idToFunc) - if err := rewriteAST(info, idToFunc, pf); err != nil { + idToTypeSpec := make(map[types.Object]*ast.TypeSpec) + addIDs(info, pf, idToFunc, idToTypeSpec) + if err := rewriteAST(info, idToFunc, idToTypeSpec, pf); err != nil { return nil, err } var buf bytes.Buffer diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index d21c5fac70..ca4cfde289 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -7,6 +7,7 @@ package go2go import ( "fmt" "go/ast" + "go/token" "go/types" ) @@ -68,7 +69,7 @@ func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr objType := obj.Type() objParam, ok := objType.(*types.TypeParam) if !ok { - panic(fmt.Sprintf("%v is not a TypeParam")) + panic(fmt.Sprintf("%v is not a TypeParam", objParam)) } ta.add(obj, objParam, astTypes[i], typeTypes[i]) } @@ -102,6 +103,68 @@ func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) { return decl, nil } +// instantiateType creates a new instantiation of a type. +func (t *translator) instantiateTypeDecl(tident *ast.Ident, typ types.Type, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, types.Type, error) { + name, err := t.instantiatedName(tident, typeTypes) + if err != nil { + return nil, nil, err + } + + spec, err := t.findTypeSpec(tident) + if err != nil { + return nil, nil, err + } + + ta := newTypeArgs(typeTypes) + for i, tf := range spec.TParams.List { + for _, tn := range tf.Names { + obj, ok := t.info.Defs[tn] + if !ok { + panic(fmt.Sprintf("no object for type parameter %q", tn)) + } + objType := obj.Type() + objParam, ok := objType.(*types.TypeParam) + if !ok { + panic(fmt.Sprintf("%v is not a TypeParam", objParam)) + } + ta.add(obj, objParam, astTypes[i], typeTypes[i]) + } + } + + instIdent := ast.NewIdent(name) + + newSpec := &ast.TypeSpec{ + Doc: spec.Doc, + Name: instIdent, + Assign: spec.Assign, + Type: t.instantiateExpr(ta, spec.Type), + Comment: spec.Comment, + } + newDecl := &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{newSpec}, + } + t.newDecls = append(t.newDecls, newDecl) + + instType := t.instantiateType(ta, typ) + + return instIdent, instType, nil +} + +// findTypeSpec looks for the TypeSpec for id. +// FIXME: Handle imported packages. +func (t *translator) findTypeSpec(id *ast.Ident) (*ast.TypeSpec, error) { + obj, ok := t.info.Uses[id] + if !ok { + return nil, fmt.Errorf("could not find Object for %q", id.Name) + } + spec, ok := t.idToTypeSpec[obj] + if !ok { + return nil, fmt.Errorf("could not find type spec for %q", id.Name) + } + return spec, nil +} + // instantiateBlockStmt instantiates a BlockStmt. func (t *translator) instantiateBlockStmt(ta *typeArgs, pbs *ast.BlockStmt) *ast.BlockStmt { changed := false @@ -257,6 +320,16 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { Len: ln, Elt: elt, } + case *ast.StructType: + fields := t.instantiateFieldList(ta, e.Fields) + if fields == e.Fields { + return e + } + return &ast.StructType{ + Struct: e.Struct, + Fields: fields, + Incomplete: e.Incomplete, + } default: panic(fmt.Sprintf("unimplemented Expr %T", e)) } diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 6b9f74a144..3423dd9694 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -21,15 +21,29 @@ var config = printer.Config{ Tabwidth: 8, } -// addFuncIDS finds IDs for generic functions and adds them to a map. -func addFuncIDs(info *types.Info, f *ast.File, m map[types.Object]*ast.FuncDecl) { +// addIDs finds IDs for generic functions and types and adds them to a map. +func addIDs(info *types.Info, f *ast.File, mf map[types.Object]*ast.FuncDecl, mt map[types.Object]*ast.TypeSpec) { for _, decl := range f.Decls { - if fd, ok := decl.(*ast.FuncDecl); ok && isParameterizedFuncDecl(fd) { - obj, ok := info.Defs[fd.Name] - if !ok { - panic(fmt.Sprintf("no types.Object for %q", fd.Name.Name)) + switch decl := decl.(type) { + case *ast.FuncDecl: + if isParameterizedFuncDecl(decl) { + obj, ok := info.Defs[decl.Name] + if !ok { + panic(fmt.Sprintf("no types.Object for %q", decl.Name.Name)) + } + mf[obj] = decl + } + case *ast.GenDecl: + if decl.Tok == token.TYPE { + for _, s := range decl.Specs { + ts := s.(*ast.TypeSpec) + obj, ok := info.Defs[ts.Name] + if !ok { + panic(fmt.Sprintf("no types.Object for %q", ts.Name.Name)) + } + mt[obj] = ts + } } - m[obj] = fd } } } @@ -39,11 +53,18 @@ func isParameterizedFuncDecl(fd *ast.FuncDecl) bool { return fd.Type.TParams != nil } +// isParameterizedTypeDecl reports whether s is a parameterized type. +func isParameterizedTypeDecl(s ast.Spec) bool { + ts := s.(*ast.TypeSpec) + return ts.TParams != nil +} + // A translator is used to translate a file from Go with contracts to Go 1. type translator struct { info *types.Info types map[ast.Expr]types.Type idToFunc map[types.Object]*ast.FuncDecl + idToTypeSpec map[types.Object]*ast.TypeSpec instantiations map[*ast.Ident][]*instantiation newDecls []ast.Decl typeInstantiations map[types.Type][]*typeInstantiation @@ -62,12 +83,13 @@ type instantiation struct { // A typeInstantiation is a single instantiation of a type. type typeInstantiation struct { types []types.Type + decl *ast.Ident typ types.Type } // rewrite rewrites the contents of one file. -func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, filename string, file *ast.File) (err error) { - if err := rewriteAST(info, idToFunc, file); err != nil { +func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, filename string, file *ast.File) (err error) { + if err := rewriteAST(info, idToFunc, idToTypeSpec, file); err != nil { return err } @@ -95,11 +117,12 @@ func rewriteFile(dir string, fset *token.FileSet, info *types.Info, idToFunc map } // rewriteAST rewrites the AST for a file. -func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, file *ast.File) (err error) { +func rewriteAST(info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, idToTypeSpec map[types.Object]*ast.TypeSpec, file *ast.File) (err error) { t := translator{ info: info, types: make(map[ast.Expr]types.Type), idToFunc: idToFunc, + idToTypeSpec: idToTypeSpec, instantiations: make(map[*ast.Ident][]*instantiation), typeInstantiations: make(map[types.Type][]*typeInstantiation), } @@ -115,23 +138,34 @@ func (t *translator) translate(file *ast.File) { newDecls := make([]ast.Decl, 0, len(declsToDo)) for i, decl := range declsToDo { switch decl := decl.(type) { - case (*ast.FuncDecl): + case *ast.FuncDecl: if !isParameterizedFuncDecl(decl) { t.translateFuncDecl(&declsToDo[i]) newDecls = append(newDecls, decl) } - case (*ast.GenDecl): + case *ast.GenDecl: switch decl.Tok { case token.TYPE: + newSpecs := make([]ast.Spec, 0, len(decl.Specs)) for j := range decl.Specs { - t.translateTypeSpec(&decl.Specs[j]) + if !isParameterizedTypeDecl(decl.Specs[j]) { + t.translateTypeSpec(&decl.Specs[j]) + newSpecs = append(newSpecs, decl.Specs[j]) + } + } + if len(newSpecs) == 0 { + decl = nil + } else { + decl.Specs = newSpecs } case token.VAR, token.CONST: for j := range decl.Specs { t.translateValueSpec(&decl.Specs[j]) } } - newDecls = append(newDecls, decl) + if decl != nil { + newDecls = append(newDecls, decl) + } default: newDecls = append(newDecls, decl) } @@ -145,11 +179,10 @@ func (t *translator) translate(file *ast.File) { // translateTypeSpec translates a type from Go with contracts to Go 1. func (t *translator) translateTypeSpec(ps *ast.Spec) { ts := (*ps).(*ast.TypeSpec) - if ts.TParams == nil { - t.translateExpr(&ts.Type) - return + if ts.TParams != nil { + panic("parameterized type") } - panic("parameterized type") + t.translateExpr(&ts.Type) } // translateValueSpec translates a variable or constant from Go with @@ -192,16 +225,41 @@ func (t *translator) translateStmt(ps *ast.Stmt) { if t.err != nil { return } + if *ps == nil { + return + } switch s := (*ps).(type) { case *ast.BlockStmt: t.translateBlockStmt(s) case *ast.ExprStmt: t.translateExpr(&s.X) + case *ast.AssignStmt: + t.translateExprList(s.Lhs) + t.translateExprList(s.Rhs) + case *ast.IfStmt: + t.translateStmt(&s.Init) + t.translateExpr(&s.Cond) + t.translateBlockStmt(s.Body) + t.translateStmt(&s.Else) case *ast.RangeStmt: t.translateExpr(&s.Key) t.translateExpr(&s.Value) t.translateExpr(&s.X) t.translateBlockStmt(s.Body) + case *ast.DeclStmt: + d := s.Decl.(*ast.GenDecl) + switch d.Tok { + case token.TYPE: + for i := range d.Specs { + t.translateTypeSpec(&d.Specs[i]) + } + case token.CONST, token.VAR: + for i := range d.Specs { + t.translateValueSpec(&d.Specs[i]) + } + default: + panic(fmt.Sprintf("unknown decl type %v", d.Tok)) + } default: panic(fmt.Sprintf("unimplemented Stmt %T", s)) } @@ -218,11 +276,15 @@ func (t *translator) translateExpr(pe *ast.Expr) { switch e := (*pe).(type) { case *ast.Ident: return + case *ast.BinaryExpr: + t.translateExpr(&e.X) + t.translateExpr(&e.Y) case *ast.CallExpr: t.translateExprList(e.Args) - ftyp := t.lookupType(e.Fun).(*types.Signature) - if ftyp.TParams() != nil { + if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && ftyp.TParams() != nil { t.translateFunctionInstantiation(pe) + } else if ntyp, ok := t.lookupType(e.Fun).(*types.Named); ok && ntyp.TParams() != nil { + t.translateTypeInstantiation(pe) } t.translateExpr(&e.Fun) case *ast.StarExpr: @@ -231,6 +293,8 @@ func (t *translator) translateExpr(pe *ast.Expr) { t.translateExpr(&e.X) case *ast.ArrayType: t.translateExpr(&e.Elt) + case *ast.StructType: + t.translateFieldList(e.Fields) case *ast.BasicLit: return case *ast.CompositeLit: @@ -301,6 +365,45 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) { *pe = instIdent } +// translateTypeInstantiation translates an instantiated type to Go 1. +func (t *translator) translateTypeInstantiation(pe *ast.Expr) { + call := (*pe).(*ast.CallExpr) + tident, ok := call.Fun.(*ast.Ident) + if !ok { + panic("instantiated type non-ident") + } + + typ := t.lookupType(call.Fun).(*types.Named) + + types := make([]types.Type, 0, len(call.Args)) + for _, arg := range call.Args { + types = append(types, t.lookupType(arg)) + } + + instantiations := t.typeInstantiations[typ] + for _, inst := range instantiations { + if t.sameTypes(types, inst.types) { + *pe = inst.decl + return + } + } + + instIdent, instType, err := t.instantiateTypeDecl(tident, typ.Underlying(), call.Args, types) + if err != nil { + t.err = err + return + } + + n := &typeInstantiation{ + types: types, + decl: instIdent, + typ: instType, + } + t.typeInstantiations[typ] = append(instantiations, n) + + *pe = instIdent +} + // sameTypes reports whether two type slices are the same. func (t *translator) sameTypes(a, b []types.Type) bool { if len(a) != len(b) { diff --git a/src/go/go2go/types.go b/src/go/go2go/types.go index 3aed203578..2986047559 100644 --- a/src/go/go2go/types.go +++ b/src/go/go2go/types.go @@ -88,6 +88,33 @@ func (t *translator) doInstantiateType(ta *typeArgs, typ types.Type) types.Type return r case *types.Tuple: return t.instantiateTypeTuple(ta, typ) + case *types.Struct: + n := typ.NumFields() + fields := make([]*types.Var, n) + changed := false + tags := make([]string, n) + hasTag := false + for i := 0; i < n; i++ { + v := typ.Field(i) + instType := t.instantiateType(ta, v.Type()) + if v.Type() != instType { + changed = true + } + fields[i] = types.NewVar(v.Pos(), v.Pkg(), v.Name(), instType) + + tag := typ.Tag(i) + if tag != "" { + tags[i] = tag + hasTag = true + } + } + if !changed { + return typ + } + if !hasTag { + tags = nil + } + return types.NewStruct(fields, tags) default: panic(fmt.Sprintf("unimplemented Type %T", typ)) } diff --git a/src/go/types/type.go b/src/go/types/type.go index 2088823aaf..cce68b60d5 100644 --- a/src/go/types/type.go +++ b/src/go/types/type.go @@ -512,6 +512,9 @@ func NewNamed(obj *TypeName, underlying Type, methods []*Func) *Named { // Obj returns the type name for the named type t. func (t *Named) Obj() *TypeName { return t.obj } +// TParams returns the type parameters of the named type t, or nil. +func (t *Named) TParams() []*TypeName { return t.tparams } + // NumMethods returns the number of explicit methods whose receiver is named type t. func (t *Named) NumMethods() int { return len(t.methods) } diff --git a/test/gen/g003.go2 b/test/gen/g003.go2 new file mode 100644 index 0000000000..e11d3411c4 --- /dev/null +++ b/test/gen/g003.go2 @@ -0,0 +1,32 @@ +// run + +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "unsafe" +) + +type Pair(type F1, F2) struct { + f1 F1 + f2 F2 +} + +func main() { + p := Pair(int32, int64){1, 2} + if got, want := unsafe.Sizeof(p.f1), uintptr(4); got != want { + panic(fmt.Sprintf("unexpected f1 size == %d want %d", got, want)) + } + if got, want := unsafe.Sizeof(p.f2), uintptr(8); got != want { + panic(fmt.Sprintf("unexpected f2 size == %d want %d", got, want)) + } + type MyPair struct { f1 int32; f2 int64 } + mp := MyPair(p) + if mp.f1 != 1 || mp.f2 != 2 { + panic(fmt.Sprintf("mp == %#v want %#v", mp, MyPair{1, 2})) + } +}