diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 2176f4c..6004c54 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -7,9 +7,9 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - - name: "Checkout repository" + - name: Checkout repository uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 907101e..53e2f44 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,10 +16,10 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ['3.9', '3.10', '3.11'] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - - name: "Checkout repository" + - name: Checkout repository uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/coverage-badge.svg b/coverage-badge.svg index 166ad97..cb11411 100644 --- a/coverage-badge.svg +++ b/coverage-badge.svg @@ -1 +1 @@ -coverage: 88.95%coverage88.95% \ No newline at end of file +coverage: 90.28%coverage90.28% \ No newline at end of file diff --git a/tensorhue/eastereggs.py b/tensorhue/eastereggs.py index e19548a..e83ae76 100644 --- a/tensorhue/eastereggs.py +++ b/tensorhue/eastereggs.py @@ -2,10 +2,19 @@ from matplotlib.colors import LinearSegmentedColormap from rich.color_triplet import ColorTriplet from tensorhue.colors import ColorScheme -from tensorhue.viz import viz +from tensorhue.viz import viz, get_terminal_width -def pride(): +def pride(width: int = None): + """ + Prints a pride flag in the terminal + + Args: + width (int, optional): The width of the pride flag. If none is specified, + the full width of the terminal is used. + """ + if width is None: + width = get_terminal_width(default_width=10) pride_colors = [ ColorTriplet(228, 3, 3), ColorTriplet(255, 140, 0), @@ -16,5 +25,5 @@ def pride(): ] pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride") pride_cs = ColorScheme(colormap=pride_cm) - arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1) - viz(arr, colorscheme=pride_cs) + arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), width, axis=1) + viz(arr, colorscheme=pride_cs, legend=False) diff --git a/tensorhue/viz.py b/tensorhue/viz.py index 98bccf4..acd3cd1 100644 --- a/tensorhue/viz.py +++ b/tensorhue/viz.py @@ -1,3 +1,5 @@ +import os +import sys from rich.console import Console import numpy as np from tensorhue.colors import ColorScheme @@ -18,13 +20,14 @@ def viz(tensor, *args, **kwargs): ) from e -def _viz(self, colorscheme: ColorScheme = None): +def _viz(self, colorscheme: ColorScheme = None, legend: bool = True): """ Prints a tensor using colored Unicode art representation. Args: colorscheme (ColorScheme, optional): The color scheme to use. Defaults to None, which means the global default color scheme is used. + legend (bool, optional): Whether or not to include legend information (like the shape) """ if colorscheme is None: colorscheme = PRINT_OPTS.colorscheme @@ -32,24 +35,83 @@ def _viz(self, colorscheme: ColorScheme = None): self = self._tensorhue_to_numpy() shape = self.shape - if len(shape) > 2: + if len(shape) == 1: + self = self[np.newaxis, :] + elif len(shape) > 2: raise NotImplementedError( "Visualization for tensors with more than 2 dimensions is under development. Please slice them for now." ) - colors = colorscheme(self)[..., :3] + result_lines = _viz_2d(self, colorscheme) + if legend: + result_lines.append(f"[italic]shape = {shape}[/]") + + c = Console(log_path=False, record=False) + c.print("\n".join(result_lines)) + + +def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]: + """ + Constructs a list of rich-compatible strings out of a 2D numpy array. + + Args: + array_2d (np.ndarray): The 2-dimensional numpy array + colorscheme (ColorScheme): The color scheme to use + """ result_lines = [""] + terminal_width = get_terminal_width() + shape = array_2d.shape + + if shape[1] > terminal_width: + slice_left = (terminal_width - 5) // 2 + slice_right = slice_left + (terminal_width - 5) % 2 + colors_right = colorscheme(array_2d[:, -slice_right:])[..., :3] + else: + slice_left = shape[1] + slice_right = colors_right = False + + colors_left = colorscheme(array_2d[:, :slice_left])[..., :3] + for y in range(0, shape[0] - 1, 2): - for x in range(shape[-1]): + for x in range(slice_left): result_lines[ -1 - ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]" + ] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]" + if slice_right: + result_lines[-1] += " ··· " + for x in range(slice_right): + result_lines[ + -1 + ] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]" result_lines.append("") if shape[0] % 2 == 1: - for x in range(shape[1]): - result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]" + for x in range(slice_left): + result_lines[-1] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]" + if slice_right: + result_lines[-1] += " ··· " + for x in range(slice_right): + result_lines[ + -1 + ] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]" + else: + result_lines = result_lines[:-1] - c = Console(log_path=False, record=False) - c.print("\n".join(result_lines)) + return result_lines + + +def get_terminal_width(default_width: int = 100) -> int: + """ + Returns the terminal width if the standard output is connected to a terminal. Otherwise, returns default_width. + + Args: + default_width (int, optional): The default width to use if there is no terminal. + """ + if sys.stdout.isatty(): + try: + return os.get_terminal_size().columns + except OSError: + return default_width + else: + return default_width diff --git a/tests/test_eastereggs.py b/tests/test_eastereggs.py index e6b9ac2..6aee5d9 100644 --- a/tests/test_eastereggs.py +++ b/tests/test_eastereggs.py @@ -4,5 +4,6 @@ def test_pride_output(capsys): pride() captured = capsys.readouterr() - assert len(captured.out.split("\n")) == 5 - assert captured.out.count("▀") == 30 + out = captured.out.rstrip("\n") + assert len(out.split("\n")) == 3 + assert out.count("▀") == 30 diff --git a/tests/test_viz.py b/tests/test_viz.py new file mode 100644 index 0000000..b3ca914 --- /dev/null +++ b/tests/test_viz.py @@ -0,0 +1,54 @@ +import pytest +import torch +import numpy as np +from tensorhue.viz import viz +from tensorhue._torch import _tensorhue_to_numpy_torch + + +@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))]) +def test_1d_tensor(tensor, capsys): + viz(tensor) + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + assert len(out.split("\n")) == 2 + assert out.count("▀") == 10 + assert out.split("\n")[-1] == f"shape = {tensor.shape}" + + +@pytest.mark.parametrize("tensor", [np.ones((10, 10)), _tensorhue_to_numpy_torch(torch.ones(10, 10))]) +def test_2d_tensor(tensor, capsys): + viz(tensor) + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + assert len(out.split("\n")) == 6 + assert out.count("▀") == 100 / 2 + assert out.split("\n")[-1] == f"shape = {tensor.shape}" + + +@pytest.mark.parametrize("tensor", [np.ones(200), _tensorhue_to_numpy_torch(torch.ones(200))]) +def test_1d_tensor_too_wide(tensor, capsys): + viz(tensor) + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + assert out.count(" ··· ") == 1 + assert out.count("▀") == 95 + assert out.split("\n")[-1] == f"shape = {tensor.shape}" + + +@pytest.mark.parametrize("tensor", [np.ones((10, 200)), _tensorhue_to_numpy_torch(torch.ones(10, 200))]) +def test_2d_tensor_too_wide(tensor, capsys): + viz(tensor) + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + assert out.count(" ··· ") == 5 + assert out.count("▀") == 950 / 2 + assert out.split("\n")[-1] == f"shape = {tensor.shape}" + + +@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))]) +def test_no_legend(tensor, capsys): + viz(tensor, legend=False) + captured = capsys.readouterr() + out = captured.out.rstrip("\n") + assert len(out.split("\n")) == 1 + assert out.count("▀") == 10