Skip to content

Commit

Permalink
support plotting of a circuit using graphviz
Browse files Browse the repository at this point in the history
  • Loading branch information
n28div committed Nov 28, 2024
1 parent 9efd59a commit bea8949
Showing 1 changed file with 127 additions and 0 deletions.
127 changes: 127 additions & 0 deletions cirkit/symbolic/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from enum import IntEnum, auto
from functools import cached_property
from typing import Any, Protocol, cast
from os import PathLike
from pathlib import Path

import graphviz

from cirkit.symbolic.layers import (
HadamardLayer,
Expand Down Expand Up @@ -851,6 +855,129 @@ def from_hmm(

return cls(num_channels, layers, in_layers, [layers[-1]])

def plot(
self,
out_path: str | PathLike[str] | None = None,
graph_direction: str = "vertical",
node_shape: str = "box",
label_font: str = "times italic bold",
label_size: str = "21pt",
label_color: str = "white",
sum_label: str | Callable[[SumLayer], str] = "+",
sum_color: str | Callable[[SumLayer], str] = "#607d8b",
product_label: str | Callable[[ProductLayer], str] = "⊙",
product_color: str | Callable[[ProductLayer], str] = "#24a5af",
input_label: str | Callable[[InputLayer], str] = lambda l: str(list(l.scope)),
input_color: str | Callable[[InputLayer], str] = "#ffbd2a",
) -> graphviz.Digraph:
"""Plot the current symbolic circuit using graphviz.
A graphviz object is returned, which can be visualized in jupyter notebooks.
Args:
out_path ( str | PathLike[str] | None, optional): The output path where the plot is save
If it is None, the plot is not saved to a file. Defaults to None.
The Output file format is deduce from the path. Possible formats are:
{'jp2', 'plain-ext', 'sgi', 'x11', 'pic', 'jpeg', 'imap', 'psd', 'pct',
'json', 'jpe', 'tif', 'tga', 'gif', 'tk', 'xlib', 'vmlz', 'json0', 'vrml',
'gd', 'xdot', 'plain', 'cmap', 'canon', 'cgimage', 'fig', 'svg', 'dot_json',
'bmp', 'png', 'cmapx', 'pdf', 'webp', 'ico', 'xdot_json', 'gtk', 'svgz',
'xdot1.4', 'cmapx_np', 'dot', 'tiff', 'ps2', 'gd2', 'gv', 'ps', 'jpg',
'imap_np', 'wbmp', 'vml', 'eps', 'xdot1.2', 'pov', 'pict', 'ismap', 'exr'}.
See https://graphviz.org/docs/outputs/ for more.
graph_direction (str, optional): Direction of the graph. "vertical" puts the root
node at the top, "horizontal" at left. Defaults to "vertical".
node_shape (str, optional): Default shape for a node in the graph. Defaults to "box".
See https://graphviz.org/doc/info/shapes.html for the supported shapes.
label_font (str, optional): Font used to render labels. Defaults to "times italic bold".
See https://graphviz.org/faq/font/ for the available fonts.
label_size (str, optional): Size of the font for labels in points. Defaults to 21pt.
label_color (str, optional): Color for the labels in the nodes. Defaults to "white".
See https://graphviz.org/docs/attr-types/color/ for supported color.
sum_label (str | Callable[[SumLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a sum layer and returns a string
that will be used as label. Defaults to "+".
sum_color (str | Callable[[SumLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a sum layer and returns a string
that will be used as color for the sum node. Defaults to "#607d8b".
product_label (str | Callable[[ProductLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a product layer and returns a string
that will be used as label. Defaults to "⊙".
product_color (str | Callable[[ProductLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input a product layer and returns a string
that will be used as color for the product node. Defaults to "#24a5af".
input_label (_type_, optional): Either a string or a function.
If a function is provided, then it must take as input an input layer and returns a string
that will be used as label. Defaults to using the scope of the layer.
input_color (str | Callable[[ProductLayer], str], optional): Either a string or a function.
If a function is provided, then it must take as input an input layer and returns a string
that will be used as color for the input layer node. Defaults to "#ffbd2a".
Raises:
ValueError: The format is not among the supported ones.
ValueError: The direction is not among the supported ones.
Returns:
graphviz.Digraph: _description_
"""
fmt: str = Path(out_path).suffix.replace(".", "")
if fmt not in graphviz.FORMATS:
raise ValueError(f"Supported formats are {graphviz.FORMATS}.")

if graph_direction not in ["vertical", "horizontal"]:
raise ValueError("Supported graph directions are only 'vertical' and 'horizontal'.")

dot: graphviz.Digraph = graphviz.Digraph(
format=fmt,
node_attr={
"shape": node_shape,
"style": "filled",
"fontcolor": label_color,
"fontsize": label_size,
"fontname": label_font,
},
engine="dot",
)

dot.graph_attr["rankdir"] = "BT" if graph_direction == "vertical" else "LR"

for layer in self.layers:
match layer:
case HadamardLayer():
dot.node(
str(layer),
product_label if isinstance(product_label, str) else product_label(layer),
color=product_color
if isinstance(product_color, str)
else product_color(layer),
)
case SumLayer():
dot.node(
str(layer),
sum_label if isinstance(sum_label, str) else sum_label(layer),
color=sum_color if isinstance(sum_color, str) else sum_color(layer),
)
case InputLayer():
dot.node(
str(layer),
input_label if isinstance(input_label, str) else input_label(layer),
color=input_color if isinstance(input_color, str) else input_color(layer),
)

for node, inputs in self.layers_inputs.items():
for i in inputs:
dot.edge(str(i), str(node))

if out_path is not None:
out_path: Path = Path(out_path).with_suffix("")

if fmt == "dot":
with open(out_path, "w", encoding="utf8") as f:
f.write(dot.source)
else:
dot.format = fmt
dot.render(out_path)

return dot

def are_compatible(sc1: Circuit, sc2: Circuit) -> bool:
"""Check if two symbolic circuits are compatible.
Expand Down

0 comments on commit bea8949

Please sign in to comment.