Skip to content

Commit

Permalink
use id in graphviz instead of module string; better parameter naming
Browse files Browse the repository at this point in the history
  • Loading branch information
n28div committed Nov 28, 2024
1 parent bea8949 commit c3519c7
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions cirkit/symbolic/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def from_hmm(
def plot(
self,
out_path: str | PathLike[str] | None = None,
graph_direction: str = "vertical",
orientation: str = "vertical",
node_shape: str = "box",
label_font: str = "times italic bold",
label_size: str = "21pt",
Expand All @@ -867,7 +867,7 @@ def plot(
sum_color: str | Callable[[SumLayer], str] = "#607d8b",
product_label: str | Callable[[ProductLayer], str] = "⊙",
product_color: str | Callable[[ProductLayer], str] = "#24a5af",
input_label: str | Callable[[InputLayer], str] = lambda l: str(list(l.scope)),
input_label: str | Callable[[InputLayer], str] = lambda l: " ".join(map(str, l.scope)),
input_color: str | Callable[[InputLayer], str] = "#ffbd2a",
) -> graphviz.Digraph:
"""Plot the current symbolic circuit using graphviz.
Expand All @@ -884,7 +884,7 @@ def plot(
'xdot1.4', 'cmapx_np', 'dot', 'tiff', 'ps2', 'gd2', 'gv', 'ps', 'jpg',
'imap_np', 'wbmp', 'vml', 'eps', 'xdot1.2', 'pov', 'pict', 'ismap', 'exr'}.
See https://graphviz.org/docs/outputs/ for more.
graph_direction (str, optional): Direction of the graph. "vertical" puts the root
orientation (str, optional): Orientation of the graph. "vertical" puts the root
node at the top, "horizontal" at left. Defaults to "vertical".
node_shape (str, optional): Default shape for a node in the graph. Defaults to "box".
See https://graphviz.org/doc/info/shapes.html for the supported shapes.
Expand Down Expand Up @@ -923,7 +923,7 @@ def plot(
if fmt not in graphviz.FORMATS:
raise ValueError(f"Supported formats are {graphviz.FORMATS}.")

if graph_direction not in ["vertical", "horizontal"]:
if orientation not in ["vertical", "horizontal"]:
raise ValueError("Supported graph directions are only 'vertical' and 'horizontal'.")

dot: graphviz.Digraph = graphviz.Digraph(
Expand All @@ -938,34 +938,34 @@ def plot(
engine="dot",
)

dot.graph_attr["rankdir"] = "BT" if graph_direction == "vertical" else "LR"
dot.graph_attr["rankdir"] = "BT" if orientation == "vertical" else "LR"

for layer in self.layers:
match layer:
case HadamardLayer():
dot.node(
str(layer),
str(id(layer)),
product_label if isinstance(product_label, str) else product_label(layer),
color=product_color
if isinstance(product_color, str)
else product_color(layer),
)
case SumLayer():
dot.node(
str(layer),
str(id(layer)),
sum_label if isinstance(sum_label, str) else sum_label(layer),
color=sum_color if isinstance(sum_color, str) else sum_color(layer),
)
case InputLayer():
dot.node(
str(layer),
str(id(layer)),
input_label if isinstance(input_label, str) else input_label(layer),
color=input_color if isinstance(input_color, str) else input_color(layer),
)

for node, inputs in self.layers_inputs.items():
for i in inputs:
dot.edge(str(i), str(node))
dot.edge(str(id(i)), str(id(node)))

if out_path is not None:
out_path: Path = Path(out_path).with_suffix("")
Expand Down

0 comments on commit c3519c7

Please sign in to comment.