Skip to content

Commit

Permalink
Merge pull request #2 from epistoteles/feature-1d-and-dynamic-sizing
Browse files Browse the repository at this point in the history
Support 1D tensors and dynamic sizing to terminal width
  • Loading branch information
epistoteles authored Jun 21, 2024
2 parents ffe95b5 + a565763 commit 97d043d
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- name: "Checkout repository"
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python-version: ['3.9', '3.10', '3.11']
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- name: "Checkout repository"
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion coverage-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 13 additions & 4 deletions tensorhue/eastereggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
from matplotlib.colors import LinearSegmentedColormap
from rich.color_triplet import ColorTriplet
from tensorhue.colors import ColorScheme
from tensorhue.viz import viz
from tensorhue.viz import viz, get_terminal_width


def pride():
def pride(width: int = None):
"""
Prints a pride flag in the terminal
Args:
width (int, optional): The width of the pride flag. If none is specified,
the full width of the terminal is used.
"""
if width is None:
width = get_terminal_width(default_width=10)
pride_colors = [
ColorTriplet(228, 3, 3),
ColorTriplet(255, 140, 0),
Expand All @@ -16,5 +25,5 @@ def pride():
]
pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride")
pride_cs = ColorScheme(colormap=pride_cm)
arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1)
viz(arr, colorscheme=pride_cs)
arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), width, axis=1)
viz(arr, colorscheme=pride_cs, legend=False)
80 changes: 71 additions & 9 deletions tensorhue/viz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import sys
from rich.console import Console
import numpy as np
from tensorhue.colors import ColorScheme
Expand All @@ -18,38 +20,98 @@ def viz(tensor, *args, **kwargs):
) from e


def _viz(self, colorscheme: ColorScheme = None):
def _viz(self, colorscheme: ColorScheme = None, legend: bool = True):
"""
Prints a tensor using colored Unicode art representation.
Args:
colorscheme (ColorScheme, optional): The color scheme to use.
Defaults to None, which means the global default color scheme is used.
legend (bool, optional): Whether or not to include legend information (like the shape)
"""
if colorscheme is None:
colorscheme = PRINT_OPTS.colorscheme

self = self._tensorhue_to_numpy()
shape = self.shape

if len(shape) > 2:
if len(shape) == 1:
self = self[np.newaxis, :]
elif len(shape) > 2:
raise NotImplementedError(
"Visualization for tensors with more than 2 dimensions is under development. Please slice them for now."
)

colors = colorscheme(self)[..., :3]
result_lines = _viz_2d(self, colorscheme)

if legend:
result_lines.append(f"[italic]shape = {shape}[/]")

c = Console(log_path=False, record=False)
c.print("\n".join(result_lines))


def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
"""
Constructs a list of rich-compatible strings out of a 2D numpy array.
Args:
array_2d (np.ndarray): The 2-dimensional numpy array
colorscheme (ColorScheme): The color scheme to use
"""
result_lines = [""]
terminal_width = get_terminal_width()
shape = array_2d.shape

if shape[1] > terminal_width:
slice_left = (terminal_width - 5) // 2
slice_right = slice_left + (terminal_width - 5) % 2
colors_right = colorscheme(array_2d[:, -slice_right:])[..., :3]
else:
slice_left = shape[1]
slice_right = colors_right = False

colors_left = colorscheme(array_2d[:, :slice_left])[..., :3]

for y in range(0, shape[0] - 1, 2):
for x in range(shape[-1]):
for x in range(slice_left):
result_lines[
-1
] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]"
if slice_right:
result_lines[-1] += " ··· "
for x in range(slice_right):
result_lines[
-1
] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]"
result_lines.append("")

if shape[0] % 2 == 1:
for x in range(shape[1]):
result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"
for x in range(slice_left):
result_lines[-1] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]"
if slice_right:
result_lines[-1] += " ··· "
for x in range(slice_right):
result_lines[
-1
] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]"
else:
result_lines = result_lines[:-1]

c = Console(log_path=False, record=False)
c.print("\n".join(result_lines))
return result_lines


def get_terminal_width(default_width: int = 100) -> int:
"""
Returns the terminal width if the standard output is connected to a terminal. Otherwise, returns default_width.
Args:
default_width (int, optional): The default width to use if there is no terminal.
"""
if sys.stdout.isatty():
try:
return os.get_terminal_size().columns
except OSError:
return default_width
else:
return default_width
5 changes: 3 additions & 2 deletions tests/test_eastereggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
def test_pride_output(capsys):
pride()
captured = capsys.readouterr()
assert len(captured.out.split("\n")) == 5
assert captured.out.count("▀") == 30
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 3
assert out.count("▀") == 30
54 changes: 54 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch
import numpy as np
from tensorhue.viz import viz
from tensorhue._torch import _tensorhue_to_numpy_torch


@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
def test_1d_tensor(tensor, capsys):
viz(tensor)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 2
assert out.count("▀") == 10
assert out.split("\n")[-1] == f"shape = {tensor.shape}"


@pytest.mark.parametrize("tensor", [np.ones((10, 10)), _tensorhue_to_numpy_torch(torch.ones(10, 10))])
def test_2d_tensor(tensor, capsys):
viz(tensor)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 6
assert out.count("▀") == 100 / 2
assert out.split("\n")[-1] == f"shape = {tensor.shape}"


@pytest.mark.parametrize("tensor", [np.ones(200), _tensorhue_to_numpy_torch(torch.ones(200))])
def test_1d_tensor_too_wide(tensor, capsys):
viz(tensor)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert out.count(" ··· ") == 1
assert out.count("▀") == 95
assert out.split("\n")[-1] == f"shape = {tensor.shape}"


@pytest.mark.parametrize("tensor", [np.ones((10, 200)), _tensorhue_to_numpy_torch(torch.ones(10, 200))])
def test_2d_tensor_too_wide(tensor, capsys):
viz(tensor)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert out.count(" ··· ") == 5
assert out.count("▀") == 950 / 2
assert out.split("\n")[-1] == f"shape = {tensor.shape}"


@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
def test_no_legend(tensor, capsys):
viz(tensor, legend=False)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 1
assert out.count("▀") == 10

0 comments on commit 97d043d

Please sign in to comment.