From a588cbcd8e3ac8cb536fd3b0d0285201c3575dca Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Fri, 17 Jan 2020 20:35:51 -0800 Subject: [PATCH] go/go2go, go/types: preliminary support for rewriting and instantiation Good enough to run this program: package main import "fmt" func Print(type T)(s []T) { for _, v := range s { fmt.Println(v) } } func PrintInts(s []int) { Print(int)(s) } func main() { PrintInts([]int{1, 2}) } Change-Id: I5ac205138085a63e7075b01ca2779b7eb71f9682 --- src/go/go2go/go2go.go | 26 +++- src/go/go2go/instantiate.go | 241 ++++++++++++++++++++++++++++++++ src/go/go2go/names.go | 60 ++++++++ src/go/go2go/rewrite.go | 270 +++++++++++++++++++++++++++++++++++- src/go/types/type.go | 3 + 5 files changed, 590 insertions(+), 10 deletions(-) create mode 100644 src/go/go2go/instantiate.go create mode 100644 src/go/go2go/names.go diff --git a/src/go/go2go/go2go.go b/src/go/go2go/go2go.go index 4b9e810ed5..b4565a63b5 100644 --- a/src/go/go2go/go2go.go +++ b/src/go/go2go/go2go.go @@ -13,8 +13,8 @@ import ( "go/token" "go/types" "io" - "path/filepath" "os" + "path/filepath" "sort" "strings" ) @@ -43,8 +43,9 @@ func Rewrite(dir string) error { ast *ast.File } type gpkg struct { - tpkg *types.Package + tpkg *types.Package pkgfiles []fileAST + info *types.Info } var tpkgs []*gpkg @@ -63,13 +64,21 @@ func Rewrite(dir string) error { } conf := types.Config{Importer: importer.Default()} - var info types.Info - tpkg, err := conf.Check(name, fset, asts, &info) + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + tpkg, err := conf.Check(name, fset, asts, info) if err != nil { return err } - tpkgs = append(tpkgs, &gpkg{tpkg: tpkg, pkgfiles: pkgfiles}) + tpkgs = append(tpkgs, &gpkg{ + tpkg: tpkg, + pkgfiles: pkgfiles, + info: info, + }) } if err := checkAndRemoveGofiles(dir, gofiles); err != nil { @@ -77,8 +86,13 @@ func Rewrite(dir string) error { } for _, tpkg := range tpkgs { + idToFunc := make(map[types.Object]*ast.FuncDecl) for _, pkgfile := range tpkg.pkgfiles { - if err := rewrite(dir, fset, pkgfile.name, pkgfile.ast); err != nil { + addFuncIDs(tpkg.info, pkgfile.ast, idToFunc) + } + + for _, pkgfile := range tpkg.pkgfiles { + if err := rewrite(dir, fset, tpkg.info, idToFunc, pkgfile.name, pkgfile.ast); err != nil { return err } } diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go new file mode 100644 index 0000000000..59e6b48d43 --- /dev/null +++ b/src/go/go2go/instantiate.go @@ -0,0 +1,241 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package go2go + +import ( + "fmt" + "go/ast" + "go/types" +) + +type typemap map[types.Object]ast.Expr + +// instantiateFunction creates a new instantiation of a function. +func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, error) { + name, err := t.instantiatedName(fnident, typeTypes) + if err != nil { + return nil, err + } + + decl, err := t.findFuncDecl(fnident) + if err != nil { + return nil, err + } + + targs := make(typemap, len(decl.Type.TParams.List)) + for i, tf := range decl.Type.TParams.List { + for _, tn := range tf.Names { + obj, ok := t.info.Defs[tn] + if !ok { + panic(fmt.Sprintf("no object for type parameter %q", tn)) + } + targs[obj] = astTypes[i] + } + } + + instIdent := ast.NewIdent(name) + + newDecl := &ast.FuncDecl{ + Doc: decl.Doc, + Recv: t.instantiateFieldList(targs, decl.Recv), + Name: instIdent, + Type: t.instantiateExpr(targs, decl.Type).(*ast.FuncType), + Body: t.instantiateBlockStmt(targs, decl.Body), + } + t.newDecls = append(t.newDecls, newDecl) + + return instIdent, nil +} + +// findFuncDecl looks for the FuncDecl for id. +// FIXME: Handle imported packages. +func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) { + obj, ok := t.info.Uses[id] + if !ok { + return nil, fmt.Errorf("could not find Object for %q", id.Name) + } + decl, ok := t.idToFunc[obj] + if !ok { + return nil, fmt.Errorf("could not find function body for %q", id.Name) + } + return decl, nil +} + +// instantiateBlockStmt instantiates a BlockStmt. +func (t *translator) instantiateBlockStmt(targs typemap, pbs *ast.BlockStmt) *ast.BlockStmt { + changed := false + stmts := make([]ast.Stmt, len(pbs.List)) + for i, s := range pbs.List { + is := t.instantiateStmt(targs, s) + stmts[i] = is + if is != s { + changed = true + } + } + if !changed { + return pbs + } + return &ast.BlockStmt{ + Lbrace: pbs.Lbrace, + List: stmts, + Rbrace: pbs.Rbrace, + } +} + +// instantiateStmt instantiates a statement. +func (t *translator) instantiateStmt(targs typemap, s ast.Stmt) ast.Stmt { + switch s := s.(type) { + case *ast.BlockStmt: + return t.instantiateBlockStmt(targs, s) + case *ast.ExprStmt: + x := t.instantiateExpr(targs, s.X) + if x == s.X { + return s + } + return &ast.ExprStmt{ + X: x, + } + case *ast.RangeStmt: + key := t.instantiateExpr(targs, s.Key) + value := t.instantiateExpr(targs, s.Value) + x := t.instantiateExpr(targs, s.X) + body := t.instantiateBlockStmt(targs, s.Body) + if key == s.Key && value == s.Value && x == s.X && body == s.Body { + return s + } + return &ast.RangeStmt{ + For: s.For, + Key: key, + Value: value, + TokPos: s.TokPos, + Tok: s.Tok, + X: x, + Body: body, + } + default: + panic(fmt.Sprintf("unimplemented Stmt %T", s)) + } +} + +// instantiateFieldList instantiates a field list. +func (t *translator) instantiateFieldList(targs typemap, fl *ast.FieldList) *ast.FieldList { + if fl == nil { + return nil + } + nfl := make([]*ast.Field, len(fl.List)) + changed := false + for i, f := range fl.List { + nf := t.instantiateField(targs, f) + if nf != f { + changed = true + } + nfl[i] = nf + } + if !changed { + return fl + } + return &ast.FieldList{ + Opening: fl.Opening, + List: nfl, + Closing: fl.Closing, + } +} + +// instantiateField instantiates a field. +func (t *translator) instantiateField(targs typemap, f *ast.Field) *ast.Field { + typ := t.instantiateExpr(targs, f.Type) + if typ == f.Type { + return f + } + return &ast.Field{ + Doc: f.Doc, + Names: f.Names, + Type: typ, + Tag: f.Tag, + Comment: f.Comment, + } +} + +// instantiateExpr instantiates an expression. +func (t *translator) instantiateExpr(targs typemap, e ast.Expr) ast.Expr { + if e == nil { + return nil + } + switch e := e.(type) { + case *ast.CallExpr: + fun := t.instantiateExpr(targs, e.Fun) + args, argsChanged := t.instantiateExprList(targs, e.Args) + if fun == e.Fun && !argsChanged { + return e + } + return &ast.CallExpr{ + Fun: fun, + Lparen: e.Lparen, + Args: args, + Ellipsis: e.Ellipsis, + Rparen: e.Rparen, + } + case *ast.Ident: + obj, ok := t.info.Uses[e] + if ok { + typ, ok := targs[obj] + if ok { + return typ + } + } + return e + case *ast.SelectorExpr: + x := t.instantiateExpr(targs, e.X) + if x == e.X { + return e + } + return &ast.SelectorExpr{ + X: x, + Sel: e.Sel, + } + case *ast.FuncType: + params := t.instantiateFieldList(targs, e.Params) + results := t.instantiateFieldList(targs, e.Results) + if e.TParams == nil && params == e.Params && results == e.Results { + return e + } + return &ast.FuncType{ + Func: e.Func, + TParams: nil, + Params: params, + Results: results, + } + case *ast.ArrayType: + ln := t.instantiateExpr(targs, e.Len) + elt := t.instantiateExpr(targs, e.Elt) + if ln == e.Len && elt == e.Elt { + return e + } + return &ast.ArrayType{ + Lbrack: e.Lbrack, + Len: ln, + Elt: elt, + } + default: + panic(fmt.Sprintf("unimplemented Expr %T", e)) + } +} + +// instantiateExprList instantiates an expression list. +func (t *translator) instantiateExprList(targs typemap, el []ast.Expr) ([]ast.Expr, bool) { + nel := make([]ast.Expr, len(el)) + changed := false + for i, e := range el { + ne := t.instantiateExpr(targs, e) + if ne != e { + changed = true + } + nel[i] = ne + } + if !changed { + return el, false + } + return nel, true +} diff --git a/src/go/go2go/names.go b/src/go/go2go/names.go new file mode 100644 index 0000000000..e77c6b3109 --- /dev/null +++ b/src/go/go2go/names.go @@ -0,0 +1,60 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package go2go + +import ( + "fmt" + "go/ast" + "go/types" + "strings" + "unicode" +) + +// We use Arabic digit zero as a separator. +// Do not use this character in your own identifiers. +const nameSep = '٠' + +// We use Vai digit one to introduce a special character code. +// Do not use this character in your own identifiers. +const nameIntro = '꘡' + +var nameCodes = map[rune]int{ + ' ': 0, + '*': 1, + ';': 2, + ',': 3, + '{': 4, + '}': 5, + '[': 6, + ']': 7, + '(': 8, + ')': 9, +} + +// instantiatedName returns the name of a newly instantiated function. +func (t *translator) instantiatedName(fnident *ast.Ident, types []types.Type) (string, error) { + var sb strings.Builder + fmt.Fprintf(&sb, "_instantiate%c%s", nameSep, fnident.Name) + for _, typ := range types { + sb.WriteRune(nameSep) + s := typ.String() + + // We have to uniquely translate s into a valid Go identifier. + // This is not possible in general but we assume that + // identifiers will not contain + for _, r := range s { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { + sb.WriteRune(r) + } else { + code, ok := nameCodes[r] + if !ok { + panic(fmt.Sprintf("unexpected type string character %q", r)) + } + fmt.Fprintf(&sb, "%c%d", nameIntro, code) + } + } + } + return sb.String(), nil +} diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 3c296b5259..8f5ff38ff7 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -10,18 +10,60 @@ import ( "go/ast" "go/printer" "go/token" - "path/filepath" + "go/types" "os" + "path/filepath" "strings" ) var config = printer.Config{ - Mode: printer.UseSpaces | printer.TabIndent | printer.SourcePos, + Mode: printer.UseSpaces | printer.TabIndent | printer.SourcePos, Tabwidth: 8, } +// addFuncIDS finds IDs for instantiated functions and adds them to a map. +func addFuncIDs(info *types.Info, f *ast.File, m map[types.Object]*ast.FuncDecl) { + for _, decl := range f.Decls { + if fd, ok := decl.(*ast.FuncDecl); ok && isParameterizedFuncDecl(fd) { + obj, ok := info.Defs[fd.Name] + if !ok { + panic(fmt.Sprintf("no types.Object for %q", fd.Name.Name)) + } + m[obj] = fd + } + } +} + +// isParameterizedFuncDecl reports whether fd is a parameterized function. +func isParameterizedFuncDecl(fd *ast.FuncDecl) bool { + return fd.Type.TParams != nil +} + +// A translator is used to translate a file from Go with contracts to Go 1. +type translator struct { + info *types.Info + idToFunc map[types.Object]*ast.FuncDecl + instantiations map[*ast.Ident][]*instantiation + newDecls []ast.Decl +} + +// An instantiation is a single instantiation of a function. +type instantiation struct { + types []types.Type + decl *ast.Ident +} + // rewrite rewrites the contents of one file. -func rewrite(dir string, fset *token.FileSet, filename string, ast *ast.File) (err error) { +func rewrite(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, filename string, file *ast.File) (err error) { + t := translator{ + info: info, + idToFunc: idToFunc, + instantiations: make(map[*ast.Ident][]*instantiation), + } + if err := t.translate(file); err != nil { + return err + } + filename = filepath.Base(filename) goFile := strings.TrimSuffix(filename, filepath.Ext(filename)) + ".go" o, err := os.Create(filepath.Join(dir, goFile)) @@ -42,5 +84,225 @@ func rewrite(dir string, fset *token.FileSet, filename string, ast *ast.File) (e }() fmt.Fprintln(w, rewritePrefix) - return config.Fprint(w, fset, ast) + return config.Fprint(w, fset, file) +} + +// translate translates the AST for a file from Go with contracts to Go 1. +func (t *translator) translate(file *ast.File) error { + newDecls := make([]ast.Decl, 0, len(file.Decls)) + for i, decl := range file.Decls { + switch decl := decl.(type) { + case (*ast.FuncDecl): + if !isParameterizedFuncDecl(decl) { + if err := t.translateFuncDecl(&file.Decls[i]); err != nil { + return err + } + newDecls = append(newDecls, decl) + } + case (*ast.GenDecl): + switch decl.Tok { + case token.TYPE: + for j := range decl.Specs { + if err := t.translateTypeSpec(&decl.Specs[j]); err != nil { + return err + } + } + case token.VAR, token.CONST: + for j := range decl.Specs { + if err := t.translateValueSpec(&decl.Specs[j]); err != nil { + return err + } + } + default: + newDecls = append(newDecls, decl) + } + default: + newDecls = append(newDecls, decl) + } + } + file.Decls = append(newDecls, t.newDecls...) + return nil +} + +// translateTypeSpec translates a type from Go with contracts to Go 1. +func (t *translator) translateTypeSpec(ps *ast.Spec) error { + ts := (*ps).(*ast.TypeSpec) + if ts.TParams == nil { + return t.translateExpr(&ts.Type) + } + panic("parameterized type") +} + +// translateValueSpec translates a variable or constant from Go with +// contracts to Go 1. +func (t *translator) translateValueSpec(ps *ast.Spec) error { + vs := (*ps).(*ast.ValueSpec) + if err := t.translateExpr(&vs.Type); err != nil { + return err + } + for i := range vs.Values { + if err := t.translateExpr(&vs.Values[i]); err != nil { + return err + } + } + return nil +} + +// translateFuncDecl translates a function from Go with contracts to Go 1. +func (t *translator) translateFuncDecl(pd *ast.Decl) error { + fd := (*pd).(*ast.FuncDecl) + if fd.Type.TParams != nil { + panic("parameterized function") + } + if fd.Recv != nil { + if err := t.translateFieldList(fd.Recv); err != nil { + return err + } + } + if err := t.translateFieldList(fd.Type.Params); err != nil { + return err + } + if err := t.translateFieldList(fd.Type.Results); err != nil { + return err + } + if err := t.translateBlockStmt(fd.Body); err != nil { + return err + } + return nil +} + +// translateBlockStmt translates a block statement from Go with +// contracts to Go 1. +func (t *translator) translateBlockStmt(pbs *ast.BlockStmt) error { + for i := range pbs.List { + if err := t.translateStmt(&pbs.List[i]); err != nil { + return err + } + } + return nil +} + +// translateStmt translates a statement from Go with contracts to Go 1. +func (t *translator) translateStmt(ps *ast.Stmt) error { + switch s := (*ps).(type) { + case *ast.BlockStmt: + return t.translateBlockStmt(s) + case *ast.ExprStmt: + return t.translateExpr(&s.X) + default: + panic(fmt.Sprintf("unimplemented Stmt %T", s)) + } +} + +// translateExpr translates an expression from Go with contracts to Go 1. +func (t *translator) translateExpr(pe *ast.Expr) error { + switch e := (*pe).(type) { + case *ast.Ident: + return nil + case *ast.CallExpr: + if err := t.translateExprList(e.Args); err != nil { + return err + } + ftyp := t.info.Types[e.Fun].Type.(*types.Signature) + if ftyp.TParams() != nil { + if err := t.translateFunctionInstantiation(pe); err != nil { + return err + } + } + return t.translateExpr(&e.Fun) + case *ast.StarExpr: + return t.translateExpr(&e.X) + case *ast.SelectorExpr: + return t.translateExpr(&e.X) + case *ast.ArrayType: + return t.translateExpr(&e.Elt) + case *ast.BasicLit: + return nil + case *ast.CompositeLit: + if err := t.translateExpr(&e.Type); err != nil { + return err + } + return t.translateExprList(e.Elts) + default: + panic(fmt.Sprintf("unimplemented Expr %T", e)) + } +} + +// translateExprList translate an expression list from Go with +// contracts to Go 1. +func (t *translator) translateExprList(el []ast.Expr) error { + for i := range el { + if err := t.translateExpr(&el[i]); err != nil { + return err + } + } + return nil +} + +// translateFieldList translates a field list from Go with contracts to Go 1. +func (t *translator) translateFieldList(fl *ast.FieldList) error { + if fl == nil { + return nil + } + for _, f := range fl.List { + if err := t.translateField(f); err != nil { + return err + } + } + return nil +} + +// translateField translates a field from Go with contracts to Go 1. +func (t *translator) translateField(f *ast.Field) error { + return t.translateExpr(&f.Type) +} + +// translateFunctionInstantiation translates an instantiated function +// to Go 1. +func (t *translator) translateFunctionInstantiation(pe *ast.Expr) error { + call := (*pe).(*ast.CallExpr) + fnident, ok := call.Fun.(*ast.Ident) + if !ok { + panic("instantiated function non-ident") + } + + types := make([]types.Type, 0, len(call.Args)) + for _, arg := range call.Args { + types = append(types, t.info.Types[arg].Type) + } + + instantiations := t.instantiations[fnident] + for _, inst := range instantiations { + if t.sameTypes(types, inst.types) { + *pe = inst.decl + return nil + } + } + + instIdent, err := t.instantiateFunction(fnident, call.Args, types) + if err != nil { + return err + } + + n := &instantiation{ + types: types, + decl: instIdent, + } + t.instantiations[fnident] = append(instantiations, n) + + *pe = instIdent + return nil +} + +// sameTypes reports whether two type slices are the same. +func (t *translator) sameTypes(a, b []types.Type) bool { + if len(a) != len(b) { + return false + } + for i, x := range a { + if x != b[i] { + return false + } + } + return true } diff --git a/src/go/types/type.go b/src/go/types/type.go index 8cad041acd..83ea777f43 100644 --- a/src/go/types/type.go +++ b/src/go/types/type.go @@ -233,6 +233,9 @@ func NewSignature(recv *Var, params, results *Tuple, variadic bool) *Signature { // contain methods whose receiver type is a different interface. func (s *Signature) Recv() *Var { return s.recv } +// TParams returns the type parameters of signature s, or nil. +func (s *Signature) TParams() []*TypeName { return s.tparams } + // Params returns the parameters of signature s, or nil. func (s *Signature) Params() *Tuple { return s.params }