diff --git a/internal/imports/fix.go b/internal/imports/fix.go index 4066565192..bf2bbed5d5 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -712,6 +712,10 @@ type Resolver interface { loadPackageNames(importPaths []string, srcDir string) (map[string]string, error) // scan finds (at least) the packages satisfying refs. The returned slice is unordered. scan(refs references) ([]*pkg, error) + // loadExports returns the set of exported symbols in the package at dir. + // It returns an error if the package name in dir does not match expectPackage. + // loadExports may be called concurrently. + loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[string]bool, error) } // gopackagesResolver implements resolver for GOPATH and module workspaces using go/packages. @@ -766,6 +770,26 @@ func (r *goPackagesResolver) scan(refs references) ([]*pkg, error) { return scan, nil } +func (r *goPackagesResolver) loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[string]bool, error) { + if pkg.goPackage == nil { + return nil, fmt.Errorf("goPackage not set") + } + exports := map[string]bool{} + fset := token.NewFileSet() + for _, fname := range pkg.goPackage.CompiledGoFiles { + f, err := parser.ParseFile(fset, fname, nil, 0) + if err != nil { + return nil, fmt.Errorf("parsing %s: %v", fname, err) + } + for name := range f.Scope.Objects { + if ast.IsExported(name) { + exports[name] = true + } + } + } + return exports, nil +} + func addExternalCandidates(pass *pass, refs references, filename string) error { dirScan, err := pass.env.GetResolver().scan(refs) if err != nil { @@ -1018,6 +1042,10 @@ func (r *gopathResolver) scan(_ references) ([]*pkg, error) { return result, nil } +func (r *gopathResolver) loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[string]bool, error) { + return loadExportsFromFiles(ctx, r.env, expectPackage, pkg.dir) +} + // VendorlessPath returns the devendorized version of the import path ipath. // For example, VendorlessPath("foo/bar/vendor/a/b") returns "a/b". func VendorlessPath(ipath string) string { @@ -1031,33 +1059,11 @@ func VendorlessPath(ipath string) string { return ipath } -// loadExports returns the set of exported symbols in the package at dir. -// It returns nil on error or if the package name in dir does not match expectPackage. -func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg *pkg) (map[string]bool, error) { - if env.Debug { - env.Logf("loading exports in dir %s (seeking package %s)", pkg.dir, expectPackage) - } - if pkg.goPackage != nil { - exports := map[string]bool{} - fset := token.NewFileSet() - for _, fname := range pkg.goPackage.CompiledGoFiles { - f, err := parser.ParseFile(fset, fname, nil, 0) - if err != nil { - return nil, fmt.Errorf("parsing %s: %v", fname, err) - } - for name := range f.Scope.Objects { - if ast.IsExported(name) { - exports[name] = true - } - } - } - return exports, nil - } - +func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, expectPackage string, dir string) (map[string]bool, error) { exports := make(map[string]bool) // Look for non-test, buildable .go files which could provide exports. - all, err := ioutil.ReadDir(pkg.dir) + all, err := ioutil.ReadDir(dir) if err != nil { return nil, err } @@ -1067,7 +1073,7 @@ func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { continue } - match, err := env.buildContext().MatchFile(pkg.dir, fi.Name()) + match, err := env.buildContext().MatchFile(dir, fi.Name()) if err != nil || !match { continue } @@ -1075,7 +1081,7 @@ func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg } if len(files) == 0 { - return nil, fmt.Errorf("dir %v contains no buildable, non-test .go files", pkg.dir) + return nil, fmt.Errorf("dir %v contains no buildable, non-test .go files", dir) } fset := token.NewFileSet() @@ -1086,7 +1092,7 @@ func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg default: } - fullFile := filepath.Join(pkg.dir, fi.Name()) + fullFile := filepath.Join(dir, fi.Name()) f, err := parser.ParseFile(fset, fullFile, nil, 0) if err != nil { return nil, fmt.Errorf("parsing %s: %v", fullFile, err) @@ -1098,7 +1104,7 @@ func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg continue } if pkgName != expectPackage { - return nil, fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", pkg.dir, expectPackage, pkgName) + return nil, fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) } for name := range f.Scope.Objects { if ast.IsExported(name) { @@ -1113,7 +1119,7 @@ func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg exportList = append(exportList, k) } sort.Strings(exportList) - env.Logf("loaded exports in dir %v (package %v): %v", pkg.dir, expectPackage, strings.Join(exportList, ", ")) + env.Logf("loaded exports in dir %v (package %v): %v", dir, expectPackage, strings.Join(exportList, ", ")) } return exports, nil } @@ -1187,7 +1193,10 @@ func findImport(ctx context.Context, pass *pass, dirScan []*pkg, pkgName string, wg.Done() }() - exports, err := loadExports(ctx, pass.env, pkgName, c.pkg) + if pass.env.Debug { + pass.env.Logf("loading exports in dir %s (seeking package %s)", c.pkg.dir, pkgName) + } + exports, err := pass.env.GetResolver().loadExports(ctx, pkgName, c.pkg) if err != nil { if pass.env.Debug { pass.env.Logf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err) diff --git a/internal/imports/mod.go b/internal/imports/mod.go index 7cc6c9db99..4f088b746c 100644 --- a/internal/imports/mod.go +++ b/internal/imports/mod.go @@ -2,6 +2,7 @@ package imports import ( "bytes" + "context" "encoding/json" "io/ioutil" "os" @@ -327,6 +328,13 @@ func (r *ModuleResolver) scan(_ references) ([]*pkg, error) { return result, nil } +func (r *ModuleResolver) loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[string]bool, error) { + if err := r.init(); err != nil { + return nil, err + } + return loadExportsFromFiles(ctx, r.env, expectPackage, pkg.dir) +} + // modCacheRegexp splits a path in a module cache into module, module version, and package. var modCacheRegexp = regexp.MustCompile(`(.*)@([^/\\]*)(.*)`)