diff --git a/cmd/digraph/digraph.go b/cmd/digraph/digraph.go index 9d4713ea9c..9d42b83c3d 100644 --- a/cmd/digraph/digraph.go +++ b/cmd/digraph/digraph.go @@ -35,6 +35,8 @@ The support commands are: all strongly connected components (one per line) scc the set of nodes nodes strongly connected to the specified one + focus + the subgraph containing all directed paths that pass through the specified node Input format: @@ -91,6 +93,7 @@ import ( "os" "sort" "strconv" + "strings" "unicode" "unicode/utf8" ) @@ -121,6 +124,8 @@ The support commands are: all strongly connected components (one per line) scc the set of nodes nodes strongly connected to the specified one + focus + the subgraph containing all directed paths that pass through the specified node `) os.Exit(2) } @@ -152,7 +157,7 @@ func (l nodelist) println(sep string) { fmt.Fprintln(stdout) } -type nodeset map[string]bool +type nodeset map[string]bool // TODO(deklerk): change bool to struct to reduce memory footprint func (s nodeset) sort() nodelist { nodes := make(nodelist, len(s)) @@ -498,6 +503,36 @@ func digraph(cmd string, args []string) error { } } + case "focus": + if len(args) != 1 { + return fmt.Errorf("usage: digraph focus ") + } + node := args[0] + if g[node] == nil { + return fmt.Errorf("no such node %q", node) + } + + edges := make(map[string]struct{}) + for from := range g.reachableFrom(nodeset{node: true}) { + for to := range g[from] { + edges[fmt.Sprintf("%s %s", from, to)] = struct{}{} + } + } + + gtrans := g.transpose() + for from := range gtrans.reachableFrom(nodeset{node: true}) { + for to := range gtrans[from] { + edges[fmt.Sprintf("%s %s", to, from)] = struct{}{} + } + } + + edgesSorted := make([]string, len(edges)) + for e := range edges { + edgesSorted = append(edgesSorted, e) + } + sort.Strings(edgesSorted) + fmt.Fprintln(stdout, strings.Join(edgesSorted, "\n")) + default: return fmt.Errorf("no such command %q", cmd) } diff --git a/cmd/digraph/digraph_test.go b/cmd/digraph/digraph_test.go index cf4d3b5e1a..1746fcaa69 100644 --- a/cmd/digraph/digraph_test.go +++ b/cmd/digraph/digraph_test.go @@ -283,3 +283,64 @@ func TestQuotedLength(t *testing.T) { } } } + +func TestFocus(t *testing.T) { + for _, test := range []struct { + name string + in string + focus string + want string + }{ + { + name: "Basic", + in: "A B", + focus: "B", + want: "A B\n", + }, + { + name: "Some Nodes Not Included", + // C does not have a path involving B, and should not be included + // in the output. + in: "A B\nA C", + focus: "B", + want: "A B\n", + }, + { + name: "Cycle In Path", + // A <-> B -> C + in: "A B\nB A\nB C", + focus: "C", + want: "A B\nB A\nB C\n", + }, + { + name: "Cycle Out Of Path", + // C <- A <->B + in: "A B\nB A\nB C", + focus: "C", + want: "A B\nB A\nB C\n", + }, + { + name: "Complex", + // Paths in and out from focus. + // /-> F + // /-> B -> D -- + // A -- \-> E + // \-> C + in: "A B\nA C\nB D\nD F\nD E", + focus: "D", + want: "A B\nB D\nD E\nD F\n", + }, + } { + t.Run(test.name, func(t *testing.T) { + stdin = strings.NewReader(test.in) + stdout = new(bytes.Buffer) + if err := digraph("focus", []string{test.focus}); err != nil { + t.Fatal(err) + } + got := stdout.(fmt.Stringer).String() + if got != test.want { + t.Errorf("digraph(focus, %s) = got %q, want %q", test.focus, got, test.want) + } + }) + } +}