go/go2go, go/types: preliminary support for rewriting and instantiation

Good enough to run this program:

package main

import "fmt"

func Print(type T)(s []T) {
	for _, v := range s {
		fmt.Println(v)
	}
}

func PrintInts(s []int) {
	Print(int)(s)
}

func main() {
	PrintInts([]int{1, 2})
}

Change-Id: I5ac205138085a63e7075b01ca2779b7eb71f9682
This commit is contained in:
Ian Lance Taylor 2020-01-17 20:35:51 -08:00 committed by Robert Griesemer
parent 10c3db1727
commit a588cbcd8e
5 changed files with 590 additions and 10 deletions

View File

@ -13,8 +13,8 @@ import (
"go/token"
"go/types"
"io"
"path/filepath"
"os"
"path/filepath"
"sort"
"strings"
)
@ -43,8 +43,9 @@ func Rewrite(dir string) error {
ast *ast.File
}
type gpkg struct {
tpkg *types.Package
tpkg *types.Package
pkgfiles []fileAST
info *types.Info
}
var tpkgs []*gpkg
@ -63,13 +64,21 @@ func Rewrite(dir string) error {
}
conf := types.Config{Importer: importer.Default()}
var info types.Info
tpkg, err := conf.Check(name, fset, asts, &info)
info := &types.Info{
Types: make(map[ast.Expr]types.TypeAndValue),
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
}
tpkg, err := conf.Check(name, fset, asts, info)
if err != nil {
return err
}
tpkgs = append(tpkgs, &gpkg{tpkg: tpkg, pkgfiles: pkgfiles})
tpkgs = append(tpkgs, &gpkg{
tpkg: tpkg,
pkgfiles: pkgfiles,
info: info,
})
}
if err := checkAndRemoveGofiles(dir, gofiles); err != nil {
@ -77,8 +86,13 @@ func Rewrite(dir string) error {
}
for _, tpkg := range tpkgs {
idToFunc := make(map[types.Object]*ast.FuncDecl)
for _, pkgfile := range tpkg.pkgfiles {
if err := rewrite(dir, fset, pkgfile.name, pkgfile.ast); err != nil {
addFuncIDs(tpkg.info, pkgfile.ast, idToFunc)
}
for _, pkgfile := range tpkg.pkgfiles {
if err := rewrite(dir, fset, tpkg.info, idToFunc, pkgfile.name, pkgfile.ast); err != nil {
return err
}
}

241
src/go/go2go/instantiate.go Normal file
View File

@ -0,0 +1,241 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package go2go
import (
"fmt"
"go/ast"
"go/types"
)
type typemap map[types.Object]ast.Expr
// instantiateFunction creates a new instantiation of a function.
func (t *translator) instantiateFunction(fnident *ast.Ident, astTypes []ast.Expr, typeTypes []types.Type) (*ast.Ident, error) {
name, err := t.instantiatedName(fnident, typeTypes)
if err != nil {
return nil, err
}
decl, err := t.findFuncDecl(fnident)
if err != nil {
return nil, err
}
targs := make(typemap, len(decl.Type.TParams.List))
for i, tf := range decl.Type.TParams.List {
for _, tn := range tf.Names {
obj, ok := t.info.Defs[tn]
if !ok {
panic(fmt.Sprintf("no object for type parameter %q", tn))
}
targs[obj] = astTypes[i]
}
}
instIdent := ast.NewIdent(name)
newDecl := &ast.FuncDecl{
Doc: decl.Doc,
Recv: t.instantiateFieldList(targs, decl.Recv),
Name: instIdent,
Type: t.instantiateExpr(targs, decl.Type).(*ast.FuncType),
Body: t.instantiateBlockStmt(targs, decl.Body),
}
t.newDecls = append(t.newDecls, newDecl)
return instIdent, nil
}
// findFuncDecl looks for the FuncDecl for id.
// FIXME: Handle imported packages.
func (t *translator) findFuncDecl(id *ast.Ident) (*ast.FuncDecl, error) {
obj, ok := t.info.Uses[id]
if !ok {
return nil, fmt.Errorf("could not find Object for %q", id.Name)
}
decl, ok := t.idToFunc[obj]
if !ok {
return nil, fmt.Errorf("could not find function body for %q", id.Name)
}
return decl, nil
}
// instantiateBlockStmt instantiates a BlockStmt.
func (t *translator) instantiateBlockStmt(targs typemap, pbs *ast.BlockStmt) *ast.BlockStmt {
changed := false
stmts := make([]ast.Stmt, len(pbs.List))
for i, s := range pbs.List {
is := t.instantiateStmt(targs, s)
stmts[i] = is
if is != s {
changed = true
}
}
if !changed {
return pbs
}
return &ast.BlockStmt{
Lbrace: pbs.Lbrace,
List: stmts,
Rbrace: pbs.Rbrace,
}
}
// instantiateStmt instantiates a statement.
func (t *translator) instantiateStmt(targs typemap, s ast.Stmt) ast.Stmt {
switch s := s.(type) {
case *ast.BlockStmt:
return t.instantiateBlockStmt(targs, s)
case *ast.ExprStmt:
x := t.instantiateExpr(targs, s.X)
if x == s.X {
return s
}
return &ast.ExprStmt{
X: x,
}
case *ast.RangeStmt:
key := t.instantiateExpr(targs, s.Key)
value := t.instantiateExpr(targs, s.Value)
x := t.instantiateExpr(targs, s.X)
body := t.instantiateBlockStmt(targs, s.Body)
if key == s.Key && value == s.Value && x == s.X && body == s.Body {
return s
}
return &ast.RangeStmt{
For: s.For,
Key: key,
Value: value,
TokPos: s.TokPos,
Tok: s.Tok,
X: x,
Body: body,
}
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
}
// instantiateFieldList instantiates a field list.
func (t *translator) instantiateFieldList(targs typemap, fl *ast.FieldList) *ast.FieldList {
if fl == nil {
return nil
}
nfl := make([]*ast.Field, len(fl.List))
changed := false
for i, f := range fl.List {
nf := t.instantiateField(targs, f)
if nf != f {
changed = true
}
nfl[i] = nf
}
if !changed {
return fl
}
return &ast.FieldList{
Opening: fl.Opening,
List: nfl,
Closing: fl.Closing,
}
}
// instantiateField instantiates a field.
func (t *translator) instantiateField(targs typemap, f *ast.Field) *ast.Field {
typ := t.instantiateExpr(targs, f.Type)
if typ == f.Type {
return f
}
return &ast.Field{
Doc: f.Doc,
Names: f.Names,
Type: typ,
Tag: f.Tag,
Comment: f.Comment,
}
}
// instantiateExpr instantiates an expression.
func (t *translator) instantiateExpr(targs typemap, e ast.Expr) ast.Expr {
if e == nil {
return nil
}
switch e := e.(type) {
case *ast.CallExpr:
fun := t.instantiateExpr(targs, e.Fun)
args, argsChanged := t.instantiateExprList(targs, e.Args)
if fun == e.Fun && !argsChanged {
return e
}
return &ast.CallExpr{
Fun: fun,
Lparen: e.Lparen,
Args: args,
Ellipsis: e.Ellipsis,
Rparen: e.Rparen,
}
case *ast.Ident:
obj, ok := t.info.Uses[e]
if ok {
typ, ok := targs[obj]
if ok {
return typ
}
}
return e
case *ast.SelectorExpr:
x := t.instantiateExpr(targs, e.X)
if x == e.X {
return e
}
return &ast.SelectorExpr{
X: x,
Sel: e.Sel,
}
case *ast.FuncType:
params := t.instantiateFieldList(targs, e.Params)
results := t.instantiateFieldList(targs, e.Results)
if e.TParams == nil && params == e.Params && results == e.Results {
return e
}
return &ast.FuncType{
Func: e.Func,
TParams: nil,
Params: params,
Results: results,
}
case *ast.ArrayType:
ln := t.instantiateExpr(targs, e.Len)
elt := t.instantiateExpr(targs, e.Elt)
if ln == e.Len && elt == e.Elt {
return e
}
return &ast.ArrayType{
Lbrack: e.Lbrack,
Len: ln,
Elt: elt,
}
default:
panic(fmt.Sprintf("unimplemented Expr %T", e))
}
}
// instantiateExprList instantiates an expression list.
func (t *translator) instantiateExprList(targs typemap, el []ast.Expr) ([]ast.Expr, bool) {
nel := make([]ast.Expr, len(el))
changed := false
for i, e := range el {
ne := t.instantiateExpr(targs, e)
if ne != e {
changed = true
}
nel[i] = ne
}
if !changed {
return el, false
}
return nel, true
}

60
src/go/go2go/names.go Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package go2go
import (
"fmt"
"go/ast"
"go/types"
"strings"
"unicode"
)
// We use Arabic digit zero as a separator.
// Do not use this character in your own identifiers.
const nameSep = '٠'
// We use Vai digit one to introduce a special character code.
// Do not use this character in your own identifiers.
const nameIntro = '꘡'
var nameCodes = map[rune]int{
' ': 0,
'*': 1,
';': 2,
',': 3,
'{': 4,
'}': 5,
'[': 6,
']': 7,
'(': 8,
')': 9,
}
// instantiatedName returns the name of a newly instantiated function.
func (t *translator) instantiatedName(fnident *ast.Ident, types []types.Type) (string, error) {
var sb strings.Builder
fmt.Fprintf(&sb, "_instantiate%c%s", nameSep, fnident.Name)
for _, typ := range types {
sb.WriteRune(nameSep)
s := typ.String()
// We have to uniquely translate s into a valid Go identifier.
// This is not possible in general but we assume that
// identifiers will not contain
for _, r := range s {
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
sb.WriteRune(r)
} else {
code, ok := nameCodes[r]
if !ok {
panic(fmt.Sprintf("unexpected type string character %q", r))
}
fmt.Fprintf(&sb, "%c%d", nameIntro, code)
}
}
}
return sb.String(), nil
}

View File

@ -10,18 +10,60 @@ import (
"go/ast"
"go/printer"
"go/token"
"path/filepath"
"go/types"
"os"
"path/filepath"
"strings"
)
var config = printer.Config{
Mode: printer.UseSpaces | printer.TabIndent | printer.SourcePos,
Mode: printer.UseSpaces | printer.TabIndent | printer.SourcePos,
Tabwidth: 8,
}
// addFuncIDS finds IDs for instantiated functions and adds them to a map.
func addFuncIDs(info *types.Info, f *ast.File, m map[types.Object]*ast.FuncDecl) {
for _, decl := range f.Decls {
if fd, ok := decl.(*ast.FuncDecl); ok && isParameterizedFuncDecl(fd) {
obj, ok := info.Defs[fd.Name]
if !ok {
panic(fmt.Sprintf("no types.Object for %q", fd.Name.Name))
}
m[obj] = fd
}
}
}
// isParameterizedFuncDecl reports whether fd is a parameterized function.
func isParameterizedFuncDecl(fd *ast.FuncDecl) bool {
return fd.Type.TParams != nil
}
// A translator is used to translate a file from Go with contracts to Go 1.
type translator struct {
info *types.Info
idToFunc map[types.Object]*ast.FuncDecl
instantiations map[*ast.Ident][]*instantiation
newDecls []ast.Decl
}
// An instantiation is a single instantiation of a function.
type instantiation struct {
types []types.Type
decl *ast.Ident
}
// rewrite rewrites the contents of one file.
func rewrite(dir string, fset *token.FileSet, filename string, ast *ast.File) (err error) {
func rewrite(dir string, fset *token.FileSet, info *types.Info, idToFunc map[types.Object]*ast.FuncDecl, filename string, file *ast.File) (err error) {
t := translator{
info: info,
idToFunc: idToFunc,
instantiations: make(map[*ast.Ident][]*instantiation),
}
if err := t.translate(file); err != nil {
return err
}
filename = filepath.Base(filename)
goFile := strings.TrimSuffix(filename, filepath.Ext(filename)) + ".go"
o, err := os.Create(filepath.Join(dir, goFile))
@ -42,5 +84,225 @@ func rewrite(dir string, fset *token.FileSet, filename string, ast *ast.File) (e
}()
fmt.Fprintln(w, rewritePrefix)
return config.Fprint(w, fset, ast)
return config.Fprint(w, fset, file)
}
// translate translates the AST for a file from Go with contracts to Go 1.
func (t *translator) translate(file *ast.File) error {
newDecls := make([]ast.Decl, 0, len(file.Decls))
for i, decl := range file.Decls {
switch decl := decl.(type) {
case (*ast.FuncDecl):
if !isParameterizedFuncDecl(decl) {
if err := t.translateFuncDecl(&file.Decls[i]); err != nil {
return err
}
newDecls = append(newDecls, decl)
}
case (*ast.GenDecl):
switch decl.Tok {
case token.TYPE:
for j := range decl.Specs {
if err := t.translateTypeSpec(&decl.Specs[j]); err != nil {
return err
}
}
case token.VAR, token.CONST:
for j := range decl.Specs {
if err := t.translateValueSpec(&decl.Specs[j]); err != nil {
return err
}
}
default:
newDecls = append(newDecls, decl)
}
default:
newDecls = append(newDecls, decl)
}
}
file.Decls = append(newDecls, t.newDecls...)
return nil
}
// translateTypeSpec translates a type from Go with contracts to Go 1.
func (t *translator) translateTypeSpec(ps *ast.Spec) error {
ts := (*ps).(*ast.TypeSpec)
if ts.TParams == nil {
return t.translateExpr(&ts.Type)
}
panic("parameterized type")
}
// translateValueSpec translates a variable or constant from Go with
// contracts to Go 1.
func (t *translator) translateValueSpec(ps *ast.Spec) error {
vs := (*ps).(*ast.ValueSpec)
if err := t.translateExpr(&vs.Type); err != nil {
return err
}
for i := range vs.Values {
if err := t.translateExpr(&vs.Values[i]); err != nil {
return err
}
}
return nil
}
// translateFuncDecl translates a function from Go with contracts to Go 1.
func (t *translator) translateFuncDecl(pd *ast.Decl) error {
fd := (*pd).(*ast.FuncDecl)
if fd.Type.TParams != nil {
panic("parameterized function")
}
if fd.Recv != nil {
if err := t.translateFieldList(fd.Recv); err != nil {
return err
}
}
if err := t.translateFieldList(fd.Type.Params); err != nil {
return err
}
if err := t.translateFieldList(fd.Type.Results); err != nil {
return err
}
if err := t.translateBlockStmt(fd.Body); err != nil {
return err
}
return nil
}
// translateBlockStmt translates a block statement from Go with
// contracts to Go 1.
func (t *translator) translateBlockStmt(pbs *ast.BlockStmt) error {
for i := range pbs.List {
if err := t.translateStmt(&pbs.List[i]); err != nil {
return err
}
}
return nil
}
// translateStmt translates a statement from Go with contracts to Go 1.
func (t *translator) translateStmt(ps *ast.Stmt) error {
switch s := (*ps).(type) {
case *ast.BlockStmt:
return t.translateBlockStmt(s)
case *ast.ExprStmt:
return t.translateExpr(&s.X)
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
}
// translateExpr translates an expression from Go with contracts to Go 1.
func (t *translator) translateExpr(pe *ast.Expr) error {
switch e := (*pe).(type) {
case *ast.Ident:
return nil
case *ast.CallExpr:
if err := t.translateExprList(e.Args); err != nil {
return err
}
ftyp := t.info.Types[e.Fun].Type.(*types.Signature)
if ftyp.TParams() != nil {
if err := t.translateFunctionInstantiation(pe); err != nil {
return err
}
}
return t.translateExpr(&e.Fun)
case *ast.StarExpr:
return t.translateExpr(&e.X)
case *ast.SelectorExpr:
return t.translateExpr(&e.X)
case *ast.ArrayType:
return t.translateExpr(&e.Elt)
case *ast.BasicLit:
return nil
case *ast.CompositeLit:
if err := t.translateExpr(&e.Type); err != nil {
return err
}
return t.translateExprList(e.Elts)
default:
panic(fmt.Sprintf("unimplemented Expr %T", e))
}
}
// translateExprList translate an expression list from Go with
// contracts to Go 1.
func (t *translator) translateExprList(el []ast.Expr) error {
for i := range el {
if err := t.translateExpr(&el[i]); err != nil {
return err
}
}
return nil
}
// translateFieldList translates a field list from Go with contracts to Go 1.
func (t *translator) translateFieldList(fl *ast.FieldList) error {
if fl == nil {
return nil
}
for _, f := range fl.List {
if err := t.translateField(f); err != nil {
return err
}
}
return nil
}
// translateField translates a field from Go with contracts to Go 1.
func (t *translator) translateField(f *ast.Field) error {
return t.translateExpr(&f.Type)
}
// translateFunctionInstantiation translates an instantiated function
// to Go 1.
func (t *translator) translateFunctionInstantiation(pe *ast.Expr) error {
call := (*pe).(*ast.CallExpr)
fnident, ok := call.Fun.(*ast.Ident)
if !ok {
panic("instantiated function non-ident")
}
types := make([]types.Type, 0, len(call.Args))
for _, arg := range call.Args {
types = append(types, t.info.Types[arg].Type)
}
instantiations := t.instantiations[fnident]
for _, inst := range instantiations {
if t.sameTypes(types, inst.types) {
*pe = inst.decl
return nil
}
}
instIdent, err := t.instantiateFunction(fnident, call.Args, types)
if err != nil {
return err
}
n := &instantiation{
types: types,
decl: instIdent,
}
t.instantiations[fnident] = append(instantiations, n)
*pe = instIdent
return nil
}
// sameTypes reports whether two type slices are the same.
func (t *translator) sameTypes(a, b []types.Type) bool {
if len(a) != len(b) {
return false
}
for i, x := range a {
if x != b[i] {
return false
}
}
return true
}

View File

@ -233,6 +233,9 @@ func NewSignature(recv *Var, params, results *Tuple, variadic bool) *Signature {
// contain methods whose receiver type is a different interface.
func (s *Signature) Recv() *Var { return s.recv }
// TParams returns the type parameters of signature s, or nil.
func (s *Signature) TParams() []*TypeName { return s.tparams }
// Params returns the parameters of signature s, or nil.
func (s *Signature) Params() *Tuple { return s.params }