diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
index 2176f4c..6004c54 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/pylint.yml
@@ -7,9 +7,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.9", "3.10", "3.11"]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- - name: "Checkout repository"
+ - name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 907101e..53e2f44 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -16,10 +16,10 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
- python-version: ['3.9', '3.10', '3.11']
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- - name: "Checkout repository"
+ - name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
diff --git a/coverage-badge.svg b/coverage-badge.svg
index 166ad97..cb11411 100644
--- a/coverage-badge.svg
+++ b/coverage-badge.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/tensorhue/eastereggs.py b/tensorhue/eastereggs.py
index e19548a..e83ae76 100644
--- a/tensorhue/eastereggs.py
+++ b/tensorhue/eastereggs.py
@@ -2,10 +2,19 @@
from matplotlib.colors import LinearSegmentedColormap
from rich.color_triplet import ColorTriplet
from tensorhue.colors import ColorScheme
-from tensorhue.viz import viz
+from tensorhue.viz import viz, get_terminal_width
-def pride():
+def pride(width: int = None):
+ """
+ Prints a pride flag in the terminal
+
+ Args:
+ width (int, optional): The width of the pride flag. If none is specified,
+ the full width of the terminal is used.
+ """
+ if width is None:
+ width = get_terminal_width(default_width=10)
pride_colors = [
ColorTriplet(228, 3, 3),
ColorTriplet(255, 140, 0),
@@ -16,5 +25,5 @@ def pride():
]
pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride")
pride_cs = ColorScheme(colormap=pride_cm)
- arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1)
- viz(arr, colorscheme=pride_cs)
+ arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), width, axis=1)
+ viz(arr, colorscheme=pride_cs, legend=False)
diff --git a/tensorhue/viz.py b/tensorhue/viz.py
index 98bccf4..acd3cd1 100644
--- a/tensorhue/viz.py
+++ b/tensorhue/viz.py
@@ -1,3 +1,5 @@
+import os
+import sys
from rich.console import Console
import numpy as np
from tensorhue.colors import ColorScheme
@@ -18,13 +20,14 @@ def viz(tensor, *args, **kwargs):
) from e
-def _viz(self, colorscheme: ColorScheme = None):
+def _viz(self, colorscheme: ColorScheme = None, legend: bool = True):
"""
Prints a tensor using colored Unicode art representation.
Args:
colorscheme (ColorScheme, optional): The color scheme to use.
Defaults to None, which means the global default color scheme is used.
+ legend (bool, optional): Whether or not to include legend information (like the shape)
"""
if colorscheme is None:
colorscheme = PRINT_OPTS.colorscheme
@@ -32,24 +35,83 @@ def _viz(self, colorscheme: ColorScheme = None):
self = self._tensorhue_to_numpy()
shape = self.shape
- if len(shape) > 2:
+ if len(shape) == 1:
+ self = self[np.newaxis, :]
+ elif len(shape) > 2:
raise NotImplementedError(
"Visualization for tensors with more than 2 dimensions is under development. Please slice them for now."
)
- colors = colorscheme(self)[..., :3]
+ result_lines = _viz_2d(self, colorscheme)
+ if legend:
+ result_lines.append(f"[italic]shape = {shape}[/]")
+
+ c = Console(log_path=False, record=False)
+ c.print("\n".join(result_lines))
+
+
+def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
+ """
+ Constructs a list of rich-compatible strings out of a 2D numpy array.
+
+ Args:
+ array_2d (np.ndarray): The 2-dimensional numpy array
+ colorscheme (ColorScheme): The color scheme to use
+ """
result_lines = [""]
+ terminal_width = get_terminal_width()
+ shape = array_2d.shape
+
+ if shape[1] > terminal_width:
+ slice_left = (terminal_width - 5) // 2
+ slice_right = slice_left + (terminal_width - 5) % 2
+ colors_right = colorscheme(array_2d[:, -slice_right:])[..., :3]
+ else:
+ slice_left = shape[1]
+ slice_right = colors_right = False
+
+ colors_left = colorscheme(array_2d[:, :slice_left])[..., :3]
+
for y in range(0, shape[0] - 1, 2):
- for x in range(shape[-1]):
+ for x in range(slice_left):
result_lines[
-1
- ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
+ ] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]"
+ if slice_right:
+ result_lines[-1] += " ··· "
+ for x in range(slice_right):
+ result_lines[
+ -1
+ ] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]"
result_lines.append("")
if shape[0] % 2 == 1:
- for x in range(shape[1]):
- result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"
+ for x in range(slice_left):
+ result_lines[-1] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]"
+ if slice_right:
+ result_lines[-1] += " ··· "
+ for x in range(slice_right):
+ result_lines[
+ -1
+ ] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]"
+ else:
+ result_lines = result_lines[:-1]
- c = Console(log_path=False, record=False)
- c.print("\n".join(result_lines))
+ return result_lines
+
+
+def get_terminal_width(default_width: int = 100) -> int:
+ """
+ Returns the terminal width if the standard output is connected to a terminal. Otherwise, returns default_width.
+
+ Args:
+ default_width (int, optional): The default width to use if there is no terminal.
+ """
+ if sys.stdout.isatty():
+ try:
+ return os.get_terminal_size().columns
+ except OSError:
+ return default_width
+ else:
+ return default_width
diff --git a/tests/test_eastereggs.py b/tests/test_eastereggs.py
index e6b9ac2..6aee5d9 100644
--- a/tests/test_eastereggs.py
+++ b/tests/test_eastereggs.py
@@ -4,5 +4,6 @@
def test_pride_output(capsys):
pride()
captured = capsys.readouterr()
- assert len(captured.out.split("\n")) == 5
- assert captured.out.count("▀") == 30
+ out = captured.out.rstrip("\n")
+ assert len(out.split("\n")) == 3
+ assert out.count("▀") == 30
diff --git a/tests/test_viz.py b/tests/test_viz.py
new file mode 100644
index 0000000..b3ca914
--- /dev/null
+++ b/tests/test_viz.py
@@ -0,0 +1,54 @@
+import pytest
+import torch
+import numpy as np
+from tensorhue.viz import viz
+from tensorhue._torch import _tensorhue_to_numpy_torch
+
+
+@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
+def test_1d_tensor(tensor, capsys):
+ viz(tensor)
+ captured = capsys.readouterr()
+ out = captured.out.rstrip("\n")
+ assert len(out.split("\n")) == 2
+ assert out.count("▀") == 10
+ assert out.split("\n")[-1] == f"shape = {tensor.shape}"
+
+
+@pytest.mark.parametrize("tensor", [np.ones((10, 10)), _tensorhue_to_numpy_torch(torch.ones(10, 10))])
+def test_2d_tensor(tensor, capsys):
+ viz(tensor)
+ captured = capsys.readouterr()
+ out = captured.out.rstrip("\n")
+ assert len(out.split("\n")) == 6
+ assert out.count("▀") == 100 / 2
+ assert out.split("\n")[-1] == f"shape = {tensor.shape}"
+
+
+@pytest.mark.parametrize("tensor", [np.ones(200), _tensorhue_to_numpy_torch(torch.ones(200))])
+def test_1d_tensor_too_wide(tensor, capsys):
+ viz(tensor)
+ captured = capsys.readouterr()
+ out = captured.out.rstrip("\n")
+ assert out.count(" ··· ") == 1
+ assert out.count("▀") == 95
+ assert out.split("\n")[-1] == f"shape = {tensor.shape}"
+
+
+@pytest.mark.parametrize("tensor", [np.ones((10, 200)), _tensorhue_to_numpy_torch(torch.ones(10, 200))])
+def test_2d_tensor_too_wide(tensor, capsys):
+ viz(tensor)
+ captured = capsys.readouterr()
+ out = captured.out.rstrip("\n")
+ assert out.count(" ··· ") == 5
+ assert out.count("▀") == 950 / 2
+ assert out.split("\n")[-1] == f"shape = {tensor.shape}"
+
+
+@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
+def test_no_legend(tensor, capsys):
+ viz(tensor, legend=False)
+ captured = capsys.readouterr()
+ out = captured.out.rstrip("\n")
+ assert len(out.split("\n")) == 1
+ assert out.count("▀") == 10