diff --git a/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 b/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 index 47715ed23f..412a946f73 100644 --- a/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 +++ b/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 @@ -17,7 +17,7 @@ import ( var ints = []int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586} var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN(), math.NaN(), math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8} -var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} +var strs = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} func TestSortOrderedInts(t *testing.T) { testOrdered(t, ints, sort.Ints) @@ -28,7 +28,7 @@ func TestSortOrderedFloat64s(t *testing.T) { } func TestSortOrderedStrings(t *testing.T) { - testOrdered(t, strings, sort.Strings) + testOrdered(t, strs, sort.Strings) } func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter func([]Elem)) { diff --git a/src/cmd/go2go/testdata/go2path/src/orderedmap/orderedmap.go2 b/src/cmd/go2go/testdata/go2path/src/orderedmap/orderedmap.go2 new file mode 100644 index 0000000000..684250f41d --- /dev/null +++ b/src/cmd/go2go/testdata/go2path/src/orderedmap/orderedmap.go2 @@ -0,0 +1,136 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package orderedmap provides an ordered map, implemented as a binary tree. +package orderedmap + +// FIXME: This should probably be container/orderedmap. + +import ( + "context" + + "chans" + "contracts" +) + +// Map is an ordered map. +type Map(type K, V) struct { + root *node(K, V) + compare func(K, K) int +} + +// node is the type of a node in the binary tree. +type node(type K, V) struct { + key K + val V + left, right *node(K, V) +} + +// New returns a new map. It takes a comparison function that compares two +// keys and returns < 0 if the first is less, == 0 if they are equal, +// > 0 if the first is greater. +func New(type K, V)(compare func(K, K) int) *Map(K, V) { + return &Map(K, V){compare: compare} +} + +// NewOrdered returns a new map whose key is an ordered type. +// This is like New, but does not require providing a compare function. +// The map compare function uses the obvious key ordering. +func NewOrdered(type K, V contracts.Ordered(K))() *Map(K, V) { + return New(K, V)(func(k1, k2 K) int { + switch { + case k1 < k2: + return -1 + case k1 == k2: + return 0 + default: + return -1 + } + }) +} + +// find looks up key in the map, returning either a pointer to the slot of the +// node holding key, or a pointer to the slot where should a node would go. +func (m *Map(K, V)) find(key K) **node(K, V) { + pn := &m.root + for *pn != nil { + switch cmp := m.compare(key, (*pn).key); { + case cmp < 0: + pn = &(*pn).left + case cmp > 0: + pn = &(*pn).right + default: + return pn + } + } + return pn +} + +// Insert inserts a new key/value into the map. +// If the key is already present, the value is replaced. +// Reports whether this is a new key. +func (m *Map(K, V)) Insert(key K, val V) bool { + pn := m.find(key) + if *pn != nil { + (*pn).val = val + return false + } + *pn = &node(K, V){key: key, val: val} + return true +} + +// Find returns the value associated with a key, or the zero value +// if not present. The found result reports whether the key was found. +func (m *Map(K, V)) Find(key K) (V, bool) { + pn := m.find(key) + if *pn == nil { + var zero V + return zero, false + } + return (*pn).val, true +} + +// keyValue is a pair of key and value used while iterating. +type keyValue(type K, V) struct { + key K + val V +} + +// iterate returns an iterator that traverses the map. +func (m *Map(K, V)) Iterate() *Iterator(K, V) { + sender, receiver := chans.Ranger(keyValue(K, V))() + var f func(*node(K, V)) bool + f = func(n *node(K, V)) bool { + if n == nil { + return true + } + // Stop the traversal if Send fails, which means that + // nothing is listening to the receiver. + return f(n.left) && + sender.Send(context.Background(), keyValue(K, V){n.key, n.val}) && + f(n.right) + } + go func() { + f(m.root) + sender.Close() + }() + return &Iterator(K, V){receiver} +} + +// Iterator is used to iterate over the map. +type Iterator(type K, V) struct { + r *chans.Receiver(keyValue(K, V)) +} + +// Next returns the next key and value pair, and a boolean that reports +// whether they are valid. If not valid, we have reached the end of the map. +func (it *Iterator(K, V)) Next() (K, V, bool) { + keyval, ok := it.r.Next(context.Background()) + if !ok { + var zerok K + var zerov V + return zerok, zerov, false + } + return keyval.key, keyval.val, true +} diff --git a/src/cmd/go2go/testdata/go2path/src/orderedmap/orderedmap_test.go2 b/src/cmd/go2go/testdata/go2path/src/orderedmap/orderedmap_test.go2 new file mode 100644 index 0000000000..b51ac0a7f3 --- /dev/null +++ b/src/cmd/go2go/testdata/go2path/src/orderedmap/orderedmap_test.go2 @@ -0,0 +1,63 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package orderedmap + +import ( + "bytes" + "testing" + + "slices" +) + +func TestMap(t *testing.T) { + m := New([]byte, int)(bytes.Compare) + + if _, found := m.Find([]byte("a")); found { + t.Errorf("unexpectedly found %q in empty map", []byte("a")) + } + if !m.Insert([]byte("a"), 'a') { + t.Errorf("key %q unexpectedly already present", []byte("a")) + } + if !m.Insert([]byte("c"), 'c') { + t.Errorf("key %q unexpectedly already present", []byte("c")) + } + if !m.Insert([]byte("b"), 'b') { + t.Errorf("key %q unexpectedly already present", []byte("b")) + } + if m.Insert([]byte("c"), 'x') { + t.Errorf("key %q unexpectedly not present", []byte("c")) + } + + if v, found := m.Find([]byte("a")); !found { + t.Errorf("did not find %q", []byte("a")) + } else if v != 'a' { + t.Errorf("key %q returned wrong value %c, expected %c", []byte("a"), v, 'a') + } + if v, found := m.Find([]byte("c")); !found { + t.Errorf("did not find %q", []byte("c")) + } else if v != 'x' { + t.Errorf("key %q returned wrong value %c, expected %c", []byte("c"), v, 'x') + } + + if _, found := m.Find([]byte("d")); found { + t.Errorf("unexpectedly found %q", []byte("d")) + } + + gather := func(it *Iterator([]byte, int)) []int { + var r []int + for { + _, v, ok := it.Next() + if !ok { + return r + } + r = append(r, v) + } + } + got := gather(m.Iterate()) + want := []int{'a', 'b', 'x'} + if !slices.Equal(got, want) { + t.Errorf("Iterate returned %v, want %v", got, want) + } +} diff --git a/src/go/go2go/go2go.go b/src/go/go2go/go2go.go index 1c47560680..6d44ebaa15 100644 --- a/src/go/go2go/go2go.go +++ b/src/go/go2go/go2go.go @@ -92,7 +92,7 @@ func rewriteFilesInPath(importer *Importer, importPath, dir string, go2files []s } if !strings.HasSuffix(pkg.Name, "_test") { - importer.record(pkgfiles, importPath, tpkg) + importer.record(pkgfiles, importPath, tpkg, asts) } rpkgs = append(rpkgs, tpkg) @@ -101,7 +101,7 @@ func rewriteFilesInPath(importer *Importer, importPath, dir string, go2files []s for _, tpkg := range tpkgs { for i, pkgfile := range tpkg { - if err := rewriteFile(dir, fset, importer, pkgfile.name, pkgfile.ast, i == 0); err != nil { + if err := rewriteFile(dir, fset, importer, importPath, pkgfile.name, pkgfile.ast, i == 0); err != nil { return nil, err } } @@ -128,7 +128,7 @@ func RewriteBuffer(importer *Importer, filename string, file []byte) ([]byte, er return nil, fmt.Errorf("type checking failed for %s\n%v", pf.Name.Name, merr) } importer.addIDs(pf) - if err := rewriteAST(fset, importer, pf, true); err != nil { + if err := rewriteAST(fset, importer, "", pf, true); err != nil { return nil, err } var buf bytes.Buffer diff --git a/src/go/go2go/importer.go b/src/go/go2go/importer.go index 49a166887e..ce50fa66f9 100644 --- a/src/go/go2go/importer.go +++ b/src/go/go2go/importer.go @@ -15,6 +15,7 @@ import ( "log" "os" "path/filepath" + "sort" "strings" ) @@ -35,6 +36,9 @@ type Importer struct { // Map from import path to package information. packages map[string]*types.Package + // Map from import path to list of import paths that it imports. + imports map[string][]string + // Map from Object to AST function declaration for // parameterized functions. idToFunc map[types.Object]*ast.FuncDecl @@ -59,6 +63,7 @@ func NewImporter(tmpdir string) *Importer { info: info, translated: make(map[string]string), packages: make(map[string]*types.Package), + imports: make(map[string][]string), idToFunc: make(map[types.Object]*ast.FuncDecl), idToTypeSpec: make(map[types.Object]*ast.TypeSpec), } @@ -216,13 +221,42 @@ func (imp *Importer) localImport(importPath, dir string) (*types.Package, error) // record records information for a package, for use when working // with packages that import this one. -func (imp *Importer) record(pkgfiles []namedAST, importPath string, tpkg *types.Package) { +func (imp *Importer) record(pkgfiles []namedAST, importPath string, tpkg *types.Package, asts []*ast.File) { if importPath != "" { imp.packages[importPath] = tpkg } - for _, nast := range pkgfiles { - imp.addIDs(nast.ast) - } + imp.imports[importPath] = imp.collectImports(asts) + for _, nast := range pkgfiles { + imp.addIDs(nast.ast) + } +} + +// collectImports returns all the imports paths imported by any of the ASTs. +func (imp *Importer) collectImports(asts []*ast.File) []string { + m := make(map[string]bool) + for _, a := range asts { + for _, decl := range a.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { + continue + } + for _, spec := range gen.Specs { + imp := spec.(*ast.ImportSpec) + if imp.Name != nil { + // We don't try to handle import aliases. + continue + } + path := strings.TrimPrefix(strings.TrimSuffix(imp.Path.Value, `"`), `"`) + m[path] = true + } + } + } + s := make([]string, 0, len(m)) + for p := range m { + s = append(s, p) + } + sort.Strings(s) + return s } // addIDs finds IDs for generic functions and types and adds them to a map. @@ -269,3 +303,35 @@ func (imp *Importer) lookupTypeSpec(obj types.Object) (*ast.TypeSpec, bool) { ts, ok := imp.idToTypeSpec[obj] return ts, ok } + +// transitiveImports returns all the transitive imports of an import path. +func (imp *Importer) transitiveImports(path string) []string { + return imp.gatherTransitiveImports(path, make(map[string]bool)) +} + +// gatherTransitiveImports returns all the transitive imports of an import path, +// using a map to avoid duplicate work. +func (imp *Importer) gatherTransitiveImports(path string, m map[string]bool) []string { + imports := imp.imports[path] + if len(imports) == 0 { + return nil + } + var r []string + for _, im := range imports { + r = append(r, im) + if !m[im] { + m[im] = true + r = append(r, imp.gatherTransitiveImports(im, m)...) + } + } + dup := make(map[string]bool) + for _, p := range r { + dup[p] = true + } + r = make([]string, 0, len(dup)) + for p := range dup { + r = append(r, p) + } + sort.Strings(r) + return r +} diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index 5eba7355b9..f508eb9606 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -171,6 +171,8 @@ func (t *translator) instantiateTypeDecl(qid qualifiedIdent, typ *types.Named, a instType := t.instantiateType(ta, typ.Underlying()) + t.setType(instIdent, instType) + nm := typ.NumMethods() for i := 0; i < nm; i++ { method := typ.Method(i) diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 2d519a1488..f9a65c14a9 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -13,6 +13,8 @@ import ( "go/types" "os" "path/filepath" + "sort" + "strconv" "strings" ) @@ -79,8 +81,8 @@ type typeInstantiation struct { } // rewrite rewrites the contents of one file. -func rewriteFile(dir string, fset *token.FileSet, importer *Importer, filename string, file *ast.File, addImportableName bool) (err error) { - if err := rewriteAST(fset, importer, file, addImportableName); err != nil { +func rewriteFile(dir string, fset *token.FileSet, importer *Importer, importPath, filename string, file *ast.File, addImportableName bool) (err error) { + if err := rewriteAST(fset, importer, importPath, file, addImportableName); err != nil { return err } @@ -108,7 +110,7 @@ func rewriteFile(dir string, fset *token.FileSet, importer *Importer, filename s } // rewriteAST rewrites the AST for a file. -func rewriteAST(fset *token.FileSet, importer *Importer, file *ast.File, addImportableName bool) (err error) { +func rewriteAST(fset *token.FileSet, importer *Importer, importPath string, file *ast.File, addImportableName bool) (err error) { t := translator{ fset: fset, importer: importer, @@ -118,6 +120,63 @@ func rewriteAST(fset *token.FileSet, importer *Importer, file *ast.File, addImpo } t.translate(file) + // Add all the transitive imports. This is more than we need, + // but we're not trying to be elegant here. + imps := make(map[string]bool) + + for _, p := range importer.transitiveImports(importPath) { + imps[p] = true + } + + decls := make([]ast.Decl, 0, len(file.Decls)) + var specs []ast.Spec + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { + decls = append(decls, decl) + continue + } + for _, spec := range gen.Specs { + imp := spec.(*ast.ImportSpec) + if imp.Name != nil { + specs = append(specs, imp) + } + // We picked up Go 2 imports above, but we still + // need to pick up Go 1 imports here. + path := strings.TrimPrefix(strings.TrimSuffix(imp.Path.Value, `"`), `"`) + if imps[path] { + continue + } + imps[path] = true + for _, p := range importer.transitiveImports(path) { + imps[p] = true + } + } + } + file.Decls = decls + + paths := make([]string, 0, len(imps)) + for p := range imps { + paths = append(paths, p) + } + sort.Strings(paths) + + for _, p := range paths { + specs = append(specs, ast.Spec(&ast.ImportSpec{ + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(p), + }, + })) + } + if len(specs) > 0 { + first := &ast.GenDecl{ + Tok: token.IMPORT, + Specs: specs, + } + file.Decls = append([]ast.Decl{first}, file.Decls...) + } + // Add a name that other packages can reference to avoid an error // about an unused package. if addImportableName { @@ -622,7 +681,7 @@ func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr, typeList = make([]types.Type, 0, len(argList)) for _, arg := range argList { if at := t.lookupType(arg); at == nil { - panic(fmt.Sprintf("no type found for %T %v", arg, arg)) + panic(fmt.Sprintf("%s: no type found for %T %v", t.fset.Position(arg.Pos()), arg, arg)) } else { typeList = append(typeList, at) }