diff --git a/src/go/go2go/go2go.go b/src/go/go2go/go2go.go index ca9225de97..7bdb71b07c 100644 --- a/src/go/go2go/go2go.go +++ b/src/go/go2go/go2go.go @@ -27,13 +27,13 @@ const rewritePrefix = "// Code generated by go2go; DO NOT EDIT.\n\n" // them as a single package. It writes out a .go file with any // polymorphic code rewritten into normal code. func Rewrite(importer *Importer, dir string) error { - _, err := rewriteToPkgs(importer, dir) + _, err := rewriteToPkgs(importer, "", dir) return err } // rewriteToPkgs rewrites the contents of a single directory, // and returns the types.Packages that it computes. -func rewriteToPkgs(importer *Importer, dir string) ([]*types.Package, error) { +func rewriteToPkgs(importer *Importer, importPath, dir string) ([]*types.Package, error) { go2files, gofiles, err := go2Files(dir) if err != nil { return nil, err @@ -43,7 +43,7 @@ func rewriteToPkgs(importer *Importer, dir string) ([]*types.Package, error) { return nil, err } - return RewriteFiles(importer, dir, go2files) + return rewriteFilesInPath(importer, importPath, dir, go2files) } // namedAST holds a file name and the AST parsed from that file. @@ -54,6 +54,11 @@ type namedAST struct { // rewriteFiles rewrites a set of .go2 files in dir. func RewriteFiles(importer *Importer, dir string, go2files []string) ([]*types.Package, error) { + return rewriteFilesInPath(importer, "", dir, go2files) +} + +// rewriteFilesInPath rewrites a set of .go2 files in dir for importPath. +func rewriteFilesInPath(importer *Importer, importPath, dir string, go2files []string) ([]*types.Package, error) { fset := token.NewFileSet() pkgs, err := parseFiles(dir, go2files, fset) if err != nil { @@ -62,7 +67,7 @@ func RewriteFiles(importer *Importer, dir string, go2files []string) ([]*types.P var rpkgs []*types.Package var tpkgs [][]namedAST - for name, pkg := range pkgs { + for _, pkg := range pkgs { pkgfiles := make([]namedAST, 0, len(pkg.Files)) for n, f := range pkg.Files { pkgfiles = append(pkgfiles, namedAST{n, f}) @@ -77,12 +82,14 @@ func RewriteFiles(importer *Importer, dir string, go2files []string) ([]*types.P } conf := types.Config{Importer: importer} - tpkg, err := conf.Check(name, fset, asts, importer.info) + tpkg, err := conf.Check(pkg.Name, fset, asts, importer.info) if err != nil { - return nil, fmt.Errorf("type checking failed for %s: %v", name, err) + return nil, fmt.Errorf("type checking failed for %s: %v", pkg.Name, err) } - importer.register(pkgfiles, tpkg) + if !strings.HasSuffix(pkg.Name, "_test") { + importer.record(pkgfiles, importPath, tpkg) + } rpkgs = append(rpkgs, tpkg) tpkgs = append(tpkgs, pkgfiles) @@ -191,7 +198,7 @@ func checkGoFile(dir, f string) error { } // parseFiles parses a list of .go2 files. -func parseFiles(dir string, go2files []string, fset *token.FileSet) (map[string]*ast.Package, error) { +func parseFiles(dir string, go2files []string, fset *token.FileSet) ([]*ast.Package, error) { pkgs := make(map[string]*ast.Package) for _, go2f := range go2files { filename := filepath.Join(dir, go2f) @@ -211,5 +218,14 @@ func parseFiles(dir string, go2files []string, fset *token.FileSet) (map[string] } pkg.Files[filename] = pf } - return pkgs, nil + + rpkgs := make([]*ast.Package, 0, len(pkgs)) + for _, pkg := range pkgs { + rpkgs = append(rpkgs, pkg) + } + sort.Slice(rpkgs, func(i, j int) bool { + return rpkgs[i].Name < rpkgs[j].Name + }) + + return rpkgs, nil } diff --git a/src/go/go2go/importer.go b/src/go/go2go/importer.go index 5ee69d572c..49a166887e 100644 --- a/src/go/go2go/importer.go +++ b/src/go/go2go/importer.go @@ -151,7 +151,7 @@ func (imp *Importer) ImportFrom(importPath, dir string, mode types.ImportMode) ( imp.translated[importPath] = tdir - tpkgs, err := rewriteToPkgs(imp, tdir) + tpkgs, err := rewriteToPkgs(imp, importPath, tdir) if err != nil { return nil, err } @@ -184,6 +184,25 @@ func (imp *Importer) findFromPath(gopath, dir string) string { return "" } +// Register registers a package under an import path. +// This is for tests that use directives like //compiledir. +func (imp *Importer) Register(importPath string, tpkgs []*types.Package) error { + switch len(tpkgs) { + case 1: + imp.packages[importPath] = tpkgs[0] + return nil + case 2: + if strings.HasSuffix(tpkgs[0].Name(), "_test") { + imp.packages[importPath] = tpkgs[1] + return nil + } else if strings.HasSuffix(tpkgs[1].Name(), "_test") { + imp.packages[importPath] = tpkgs[0] + return nil + } + } + return fmt.Errorf("unexpected number of packages (%d) for %q", len(tpkgs), importPath) +} + // localImport handles a local import such as // import "./a" // This is for tests that use directives like //compiledir. @@ -195,13 +214,15 @@ func (imp *Importer) localImport(importPath, dir string) (*types.Package, error) return tpkg, nil } -// register records information for a package, for use when working +// record records information for a package, for use when working // with packages that import this one. -func (imp *Importer) register(pkgfiles []namedAST, tpkg *types.Package) { - imp.packages[tpkg.Path()] = tpkg - for _, nast := range pkgfiles { - imp.addIDs(nast.ast) +func (imp *Importer) record(pkgfiles []namedAST, importPath string, tpkg *types.Package) { + if importPath != "" { + imp.packages[importPath] = tpkg } + for _, nast := range pkgfiles { + imp.addIDs(nast.ast) + } } // addIDs finds IDs for generic functions and types and adds them to a map. diff --git a/test/run.go b/test/run.go index aa05bda709..b6afcebb96 100644 --- a/test/run.go +++ b/test/run.go @@ -233,10 +233,13 @@ func compileFile(runcmd runCmd, longname string, flags []string) (out []byte, er func compileInDir(runcmd runCmd, dir string, ft fileType, importer *go2go.Importer, flags []string, localImports bool, names ...string) (out []byte, err error) { gofiles := names if ft == go2Files { - _, err := go2go.RewriteFiles(importer, dir, names) + tpkgs, err := go2go.RewriteFiles(importer, dir, names) if err != nil { return nil, err } + if err := importer.Register(tpkgs[0].Path(), tpkgs); err != nil { + return nil, err + } gofiles = make([]string, len(names)) for i, name := range names {