diff --git a/cmd/digraph/digraph.go b/cmd/digraph/digraph.go index 6393f44a93..73835d8713 100644 --- a/cmd/digraph/digraph.go +++ b/cmd/digraph/digraph.go @@ -303,6 +303,36 @@ func (g graph) allpaths(from, to string) error { return nil } +func (g graph) somepath(from, to string) error { + type edge struct{ from, to string } + seen := make(nodeset) + var dfs func(path []edge, from string) bool + dfs = func(path []edge, from string) bool { + if !seen[from] { + seen[from] = true + if from == to { + // fmt.Println(path, len(path), cap(path)) + // Print and unwind. + for _, e := range path { + fmt.Fprintln(stdout, e.from+" "+e.to) + } + return true + } + for e := range g[from] { + if dfs(append(path, edge{from: from, to: e}), e) { + return true + } + } + } + return false + } + maxEdgesInGraph := len(g) * (len(g) - 1) + if !dfs(make([]edge, 0, maxEdgesInGraph), from) { + return fmt.Errorf("no path from %q to %q", from, to) + } + return nil +} + func parse(rd io.Reader) (graph, error) { g := make(graph) @@ -407,26 +437,8 @@ func digraph(cmd string, args []string) error { if g[to] == nil { return fmt.Errorf("no such 'to' node %q", to) } - - seen := make(nodeset) - var visit func(path nodelist, label string) bool - visit = func(path nodelist, label string) bool { - if !seen[label] { - seen[label] = true - if label == to { - append(path, label).println("\n") - return true // unwind - } - for e := range g[label] { - if visit(append(path, label), e) { - return true - } - } - } - return false - } - if !visit(make(nodelist, 0, 100), from) { - return fmt.Errorf("no path from %q to %q", args[0], args[1]) + if err := g.somepath(from, to); err != nil { + return err } case "allpaths": @@ -440,7 +452,9 @@ func digraph(cmd string, args []string) error { if g[to] == nil { return fmt.Errorf("no such 'to' node %q", to) } - g.allpaths(from, to) + if err := g.allpaths(from, to); err != nil { + return err + } case "sccs": if len(args) != 0 { diff --git a/cmd/digraph/digraph_test.go b/cmd/digraph/digraph_test.go index 3376065fe5..0f99ac34dc 100644 --- a/cmd/digraph/digraph_test.go +++ b/cmd/digraph/digraph_test.go @@ -7,6 +7,7 @@ import ( "bytes" "fmt" "reflect" + "sort" "strings" "testing" ) @@ -168,6 +169,62 @@ func TestAllpaths(t *testing.T) { } } +func TestSomepath(t *testing.T) { + for _, test := range []struct { + name string + in string + to string + // somepath is non-deterministic, so we have to provide all the + // possible options. Each option is separated with |. + wantAnyOf string + }{ + { + name: "Basic", + in: "A B\n", + to: "B", + wantAnyOf: "A B", + }, + { + name: "Basic With Cycle", + in: "A B\nB A", + to: "B", + wantAnyOf: "A B", + }, + { + name: "Two Paths", + // /-> B --\ + // A -- -> D + // \-> C --/ + in: "A B\nA C\nB D\nC D", + to: "D", + wantAnyOf: "A B\nB D|A C\nC D", + }, + } { + t.Run(test.name, func(t *testing.T) { + stdin = strings.NewReader(test.in) + stdout = new(bytes.Buffer) + if err := digraph("somepath", []string{"A", test.to}); err != nil { + t.Fatal(err) + } + + got := stdout.(fmt.Stringer).String() + lines := strings.Split(got, "\n") + sort.Strings(lines) + got = strings.Join(lines[1:], "\n") + + var oneMatch bool + for _, want := range strings.Split(test.wantAnyOf, "|") { + if got == want { + oneMatch = true + } + } + if !oneMatch { + t.Errorf("digraph(somepath, A, %s) = got %q, want any of\n%s", test.to, got, test.wantAnyOf) + } + }) + } +} + func TestSplit(t *testing.T) { for _, test := range []struct { line string