Source code for genno.compat.graphviz

import re
from os import PathLike
from typing import Literal, Mapping, Optional, Set, Union

from genno.core.describe import is_list_of_keys, label


def key_label(key):
    return unwrap(str(key))


_UNWRAP_EXPR = re.compile("^<(.*)>$")


[docs] def unwrap(label: str) -> str: """Unwrap any number of paired '<' and '>' at the start/end of `label`. These characters cause errors in graphviz/dot. """ while True: result = _UNWRAP_EXPR.sub(r"\1", label) if result == label: return result else: label = result
[docs] class Visualizer: """Handle arguments for :func:`.visualize`.""" def __init__( self, data_attributes: Mapping, function_attributes: Mapping, graph_attr: Mapping, node_attr: Mapping, edge_attr: Mapping, kwargs: Mapping, ): from graphviz import Digraph # Handle arguments self.da = data_attributes self.fa = function_attributes # Store for reference below self.ga = dict(graph_attr) self.ga.setdefault("rankdir", "BT") self.ga.update(kwargs) na = dict(node_attr) na.setdefault("fontname", "helvetica") # Create the graph and tracking collections self.graph = Digraph(graph_attr=self.ga, node_attr=na, edge_attr=edge_attr) # Nodes or edges already seen self.seen: Set[str] = set() # Nodes already connected to the graph self.connected: Set[str] = set()
[docs] def get_attrs(self, kind: Literal["data", "func"], name: str, **defaults) -> dict: """Prepare attributes for a node of `kind`. If `name` is in self.da or self.fa, use those values, filling with `defaults`; otherwise, attributes are empty except for `defaults`.""" if kind == "data": result = self.da.get(name, {}).copy() result.setdefault("shape", "ellipse") else: result = self.fa.get(name, {}).copy() # Use a directional shape like [> in LR mode; otherwise a box result.setdefault("shape", "cds" if self.ga["rankdir"] == "LR" else "box") [result.setdefault(k, v) for k, v in defaults.items()] return result
[docs] def add_edge(self, a, b) -> None: """Add an edge to the graph.""" self.graph.edge(a, b) # Update the connected nodes self.connected.update((a, b))
[docs] def add_node(self, kind: Literal["data", "func"], name: str, k, v=None) -> None: """Add a data node to the graph.""" if name in self.seen: return self.seen.add(name) _label = key_label(k) if kind == "data" else unwrap(label(v[0], max_length=50)) self.graph.node(name, **self.get_attrs(kind, k, label=_label))
[docs] def process(self, dsk: Mapping, collapse_outputs: bool): """Process the dask graph `dsk`.""" from dask.core import get_dependencies, ishashable, istask from dask.dot import name # Iterate over keys, tasks in the graph for k, v in dsk.items(): # A unique "name" for the node within `g`; similar to hash(k). k_name = name(k) if istask(v): # A task # Node name for the operation, possibly distinct from its output func_name = name((k, "function")) if not collapse_outputs else k_name # Add a node for the operation self.add_node("func", func_name, k, v) # Add an edge between the operation-node and the key-node of its output if not collapse_outputs: self.add_edge(func_name, k_name) # Add edges between the operation-node and the key-nodes for each of its # inputs for dep in get_dependencies(dsk, k): dep_name = name(dep) self.add_node("data", dep_name, dep) self.add_edge(dep_name, func_name) elif ishashable(v) and v in dsk: # Simple alias of k → v self.add_edge(name(v), k_name) elif is_list_of_keys(v, dsk): # k = list of multiple keys (genno extension) for _v in v: self.add_edge(name(_v), k_name) if not collapse_outputs or k_name in self.connected: # Something else that hasn't been seen: add a node that may never be # connected self.add_node("data", k_name, k) return self.graph
[docs] def visualize( dsk: Mapping, filename: Optional[Union[str, PathLike]] = None, format: Optional[str] = None, data_attributes: Optional[Mapping] = None, function_attributes: Optional[Mapping] = None, graph_attr: Optional[Mapping] = None, node_attr: Optional[Mapping] = None, edge_attr: Optional[Mapping] = None, collapse_outputs: bool = False, **kwargs, ): """Generate a Graphviz visualization of `dsk`. This is merged and extended version of :func:`dask.base.visualize`, :func:`dask.dot.dot_graph`, and :func:`dask.dot.to_graphviz` that produces informative output for genno graphs. Parameters ---------- dsk : The graph to display. filename : Path or str, optional The name of the file to write to disk. If the file name does not have a suffix, ".png" is used by default. If `filename` is :data:`None`, no file is written, and dask communicates with :program:`dot` using only pipes. format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional Format in which to write output file, if not given by the suffix of `filename`. Default "png". data_attributes : Graphviz attributes to apply to single nodes representing keys, in addition to `node_attr`. function_attributes : Graphviz attributes to apply to single nodes representing operations or functions, in addition to `node_attr`. graph_attr : Mapping of (attribute, value) pairs for the graph. Passed directly to :class:`.graphviz.Digraph`. node_attr : Mapping of (attribute, value) pairs set for all nodes. Passed directly to :class:`.graphviz.Digraph`. edge_attr : Mapping of (attribute, value) pairs set for all edges. Passed directly to :class:`.graphviz.Digraph`. collapse_outputs : bool, optional Omit nodes for keys that are the output of intermediate calculations. kwargs : All other keyword arguments are added to `graph_attr`. Examples -------- .. _visualize-example: Prepare a computer: >>> from genno import Computer >>> from genno.testing import add_test_data >>> c = Computer() >>> add_test_data(c) >>> c.add_product("z", "x:t", "x:y") >>> c.add("y::0", itemgetter(0), "y") >>> c.add("y0", "y::0") >>> c.add("index_to", "z::indexed", "z:y", "y::0") >>> c.add_single("all", ["z::indexed", "t", "config", "x:t"]) Visualize its contents: >>> c.visualize("example.svg") This produces the output: .. image:: _static/visualize.svg :alt: Example output from graphviz.visualize. See also -------- .describe.label """ from dask.dot import graphviz_to_file # Handle arguments v = Visualizer( data_attributes or {}, function_attributes or {}, graph_attr or {}, node_attr or {}, edge_attr or {}, kwargs, ) # Process the graph graph = v.process(dsk, collapse_outputs) return graphviz_to_file(graph, None if filename is None else str(filename), format)