diff --git a/gopls/internal/lsp/mod/diagnostics.go b/gopls/internal/lsp/mod/diagnostics.go index 05189070cb..e455ba3f71 100644 --- a/gopls/internal/lsp/mod/diagnostics.go +++ b/gopls/internal/lsp/mod/diagnostics.go @@ -186,22 +186,16 @@ func ModVulnerabilityDiagnostics(ctx context.Context, snapshot source.Snapshot, mod *govulncheck.Module vuln *govulncheck.Vuln } - affecting := make(map[string][]modVuln) - nonaffecting := make(map[string][]modVuln) - for _, v := range vs { - for _, m := range v.Modules { - if v.IsCalled() { - affecting[m.Path] = append(affecting[m.Path], modVuln{mod: m, vuln: v}) - } else { - nonaffecting[m.Path] = append(nonaffecting[m.Path], modVuln{mod: m, vuln: v}) - } + vulnsByModule := make(map[string][]modVuln) + for _, vuln := range vs { + for _, mod := range vuln.Modules { + vulnsByModule[mod.Path] = append(vulnsByModule[mod.Path], modVuln{mod, vuln}) } } for _, req := range pm.File.Require { - affectingVulns, ok := affecting[req.Mod.Path] - nonaffectingVulns, ok2 := nonaffecting[req.Mod.Path] - if !ok && !ok2 { + vulns := vulnsByModule[req.Mod.Path] + if len(vulns) == 0 { continue } rng, err := pm.Mapper.OffsetRange(req.Syntax.Start.Byte, req.Syntax.End.Byte) @@ -213,7 +207,7 @@ func ModVulnerabilityDiagnostics(ctx context.Context, snapshot source.Snapshot, // Fixes will include only the upgrades for warning level diagnostics. var fixes []source.SuggestedFix var warning, info []string - for _, mv := range nonaffectingVulns { + for _, mv := range vulns { mod, vuln := mv.mod, mv.vuln // Only show the diagnostic if the vulnerability was calculated // for the module at the current version. @@ -224,18 +218,12 @@ func ModVulnerabilityDiagnostics(ctx context.Context, snapshot source.Snapshot, if semver.IsValid(mod.FoundVersion) && semver.Compare(req.Mod.Version, mod.FoundVersion) != 0 { continue } - info = append(info, vuln.OSV.ID) - } - for _, mv := range affectingVulns { - mod, vuln := mv.mod, mv.vuln - // Only show the diagnostic if the vulnerability was calculated - // for the module at the current version. - if semver.IsValid(mod.FoundVersion) && semver.Compare(req.Mod.Version, mod.FoundVersion) != 0 { - continue + if !vuln.IsCalled() { + info = append(info, vuln.OSV.ID) + } else { + warning = append(warning, vuln.OSV.ID) } - warning = append(warning, vuln.OSV.ID) // Upgrade to the exact version we offer the user, not the most recent. - // TODO(hakim): Produce fixes only for affecting vulnerabilities (if len(v.Trace) > 0) if fixedVersion := mod.FixedVersion; semver.IsValid(fixedVersion) && semver.Compare(req.Mod.Version, fixedVersion) < 0 { cmd, err := getUpgradeCodeAction(fh, req, fixedVersion) if err != nil {