From bea8949ade752716f22fd881cfd6d2f16bd47c27 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Thu, 28 Nov 2024 13:55:55 +0100 Subject: [PATCH] support plotting of a circuit using graphviz --- cirkit/symbolic/circuit.py | 127 +++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/cirkit/symbolic/circuit.py b/cirkit/symbolic/circuit.py index 3bd9c760..5475fe54 100644 --- a/cirkit/symbolic/circuit.py +++ b/cirkit/symbolic/circuit.py @@ -5,6 +5,10 @@ from enum import IntEnum, auto from functools import cached_property from typing import Any, Protocol, cast +from os import PathLike +from pathlib import Path + +import graphviz from cirkit.symbolic.layers import ( HadamardLayer, @@ -851,6 +855,129 @@ def from_hmm( return cls(num_channels, layers, in_layers, [layers[-1]]) + def plot( + self, + out_path: str | PathLike[str] | None = None, + graph_direction: str = "vertical", + node_shape: str = "box", + label_font: str = "times italic bold", + label_size: str = "21pt", + label_color: str = "white", + sum_label: str | Callable[[SumLayer], str] = "+", + sum_color: str | Callable[[SumLayer], str] = "#607d8b", + product_label: str | Callable[[ProductLayer], str] = "⊙", + product_color: str | Callable[[ProductLayer], str] = "#24a5af", + input_label: str | Callable[[InputLayer], str] = lambda l: str(list(l.scope)), + input_color: str | Callable[[InputLayer], str] = "#ffbd2a", + ) -> graphviz.Digraph: + """Plot the current symbolic circuit using graphviz. + A graphviz object is returned, which can be visualized in jupyter notebooks. + + Args: + out_path ( str | PathLike[str] | None, optional): The output path where the plot is save + If it is None, the plot is not saved to a file. Defaults to None. + The Output file format is deduce from the path. Possible formats are: + {'jp2', 'plain-ext', 'sgi', 'x11', 'pic', 'jpeg', 'imap', 'psd', 'pct', + 'json', 'jpe', 'tif', 'tga', 'gif', 'tk', 'xlib', 'vmlz', 'json0', 'vrml', + 'gd', 'xdot', 'plain', 'cmap', 'canon', 'cgimage', 'fig', 'svg', 'dot_json', + 'bmp', 'png', 'cmapx', 'pdf', 'webp', 'ico', 'xdot_json', 'gtk', 'svgz', + 'xdot1.4', 'cmapx_np', 'dot', 'tiff', 'ps2', 'gd2', 'gv', 'ps', 'jpg', + 'imap_np', 'wbmp', 'vml', 'eps', 'xdot1.2', 'pov', 'pict', 'ismap', 'exr'}. + See https://graphviz.org/docs/outputs/ for more. + graph_direction (str, optional): Direction of the graph. "vertical" puts the root + node at the top, "horizontal" at left. Defaults to "vertical". + node_shape (str, optional): Default shape for a node in the graph. Defaults to "box". + See https://graphviz.org/doc/info/shapes.html for the supported shapes. + label_font (str, optional): Font used to render labels. Defaults to "times italic bold". + See https://graphviz.org/faq/font/ for the available fonts. + label_size (str, optional): Size of the font for labels in points. Defaults to 21pt. + label_color (str, optional): Color for the labels in the nodes. Defaults to "white". + See https://graphviz.org/docs/attr-types/color/ for supported color. + sum_label (str | Callable[[SumLayer], str], optional): Either a string or a function. + If a function is provided, then it must take as input a sum layer and returns a string + that will be used as label. Defaults to "+". + sum_color (str | Callable[[SumLayer], str], optional): Either a string or a function. + If a function is provided, then it must take as input a sum layer and returns a string + that will be used as color for the sum node. Defaults to "#607d8b". + product_label (str | Callable[[ProductLayer], str], optional): Either a string or a function. + If a function is provided, then it must take as input a product layer and returns a string + that will be used as label. Defaults to "⊙". + product_color (str | Callable[[ProductLayer], str], optional): Either a string or a function. + If a function is provided, then it must take as input a product layer and returns a string + that will be used as color for the product node. Defaults to "#24a5af". + input_label (_type_, optional): Either a string or a function. + If a function is provided, then it must take as input an input layer and returns a string + that will be used as label. Defaults to using the scope of the layer. + input_color (str | Callable[[ProductLayer], str], optional): Either a string or a function. + If a function is provided, then it must take as input an input layer and returns a string + that will be used as color for the input layer node. Defaults to "#ffbd2a". + + Raises: + ValueError: The format is not among the supported ones. + ValueError: The direction is not among the supported ones. + + Returns: + graphviz.Digraph: _description_ + """ + fmt: str = Path(out_path).suffix.replace(".", "") + if fmt not in graphviz.FORMATS: + raise ValueError(f"Supported formats are {graphviz.FORMATS}.") + + if graph_direction not in ["vertical", "horizontal"]: + raise ValueError("Supported graph directions are only 'vertical' and 'horizontal'.") + + dot: graphviz.Digraph = graphviz.Digraph( + format=fmt, + node_attr={ + "shape": node_shape, + "style": "filled", + "fontcolor": label_color, + "fontsize": label_size, + "fontname": label_font, + }, + engine="dot", + ) + + dot.graph_attr["rankdir"] = "BT" if graph_direction == "vertical" else "LR" + + for layer in self.layers: + match layer: + case HadamardLayer(): + dot.node( + str(layer), + product_label if isinstance(product_label, str) else product_label(layer), + color=product_color + if isinstance(product_color, str) + else product_color(layer), + ) + case SumLayer(): + dot.node( + str(layer), + sum_label if isinstance(sum_label, str) else sum_label(layer), + color=sum_color if isinstance(sum_color, str) else sum_color(layer), + ) + case InputLayer(): + dot.node( + str(layer), + input_label if isinstance(input_label, str) else input_label(layer), + color=input_color if isinstance(input_color, str) else input_color(layer), + ) + + for node, inputs in self.layers_inputs.items(): + for i in inputs: + dot.edge(str(i), str(node)) + + if out_path is not None: + out_path: Path = Path(out_path).with_suffix("") + + if fmt == "dot": + with open(out_path, "w", encoding="utf8") as f: + f.write(dot.source) + else: + dot.format = fmt + dot.render(out_path) + + return dot def are_compatible(sc1: Circuit, sc2: Circuit) -> bool: """Check if two symbolic circuits are compatible.