Skip to content

Commit

Permalink
Move transforms to transforms file
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed May 15, 2024
1 parent 60bff1d commit 250f2a6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 53 deletions.
52 changes: 1 addition & 51 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from qusi.internal.light_curve_transforms import (
from_light_curve_observation_to_fluxes_array_and_label_array,
pair_array_to_tensor,
pair_array_to_tensor, normalize_tensor_by_modified_z_score, make_uniform_length,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -337,56 +337,6 @@ def default_light_curve_post_injection_transform(
return x


def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
"""
Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median
absolute deviation.
:param tensor: The tensor to normalize.
:return: The normalized tensor.
"""
median = torch.median(tensor)
deviation_from_median = tensor - median
absolute_deviation_from_median = torch.abs(deviation_from_median)
median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median)
if median_absolute_deviation_from_median != 0:
modified_z_score = (
0.6745 * deviation_from_median / median_absolute_deviation_from_median
)
else:
modified_z_score = torch.zeros_like(tensor)
return modified_z_score


def make_uniform_length(
example: np.ndarray, length: int, *, randomize: bool = True
) -> np.ndarray:
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
raise ValueError(
f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}"
)
if randomize:
example = randomly_roll_elements(example)
if example.shape[0] == length:
pass
elif example.shape[0] > length:
example = example[:length]
else:
elements_to_repeat = length - example.shape[0]
if len(example.shape) == 1:
example = np.pad(example, (0, elements_to_repeat), mode="wrap")
else:
example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap")
return example


def randomly_roll_elements(example: np.ndarray) -> np.ndarray:
"""Randomly rolls the elements."""
example = np.roll(example, np.random.randint(example.shape[0]), axis=0)
return example


class OutOfBoundsInjectionHandlingMethod(Enum):
"""
An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal.
Expand Down
46 changes: 46 additions & 0 deletions src/qusi/internal/light_curve_transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
import numpy.typing as npt
import torch
Expand Down Expand Up @@ -38,3 +40,47 @@ def randomly_roll_elements(example: np.ndarray) -> np.ndarray:
"""Randomly rolls the elements."""
example = np.roll(example, np.random.randint(example.shape[0]), axis=0)
return example


def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
"""
Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median
absolute deviation.
:param tensor: The tensor to normalize.
:return: The normalized tensor.
"""
median = torch.median(tensor)
deviation_from_median = tensor - median
absolute_deviation_from_median = torch.abs(deviation_from_median)
median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median)
if median_absolute_deviation_from_median != 0:
modified_z_score = (
0.6745 * deviation_from_median / median_absolute_deviation_from_median
)
else:
modified_z_score = torch.zeros_like(tensor)
return modified_z_score


def make_uniform_length(
example: np.ndarray, length: int, *, randomize: bool = True
) -> np.ndarray:
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
raise ValueError(
f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}"
)
if randomize:
example = randomly_roll_elements(example)
if example.shape[0] == length:
pass
elif example.shape[0] > length:
example = example[:length]
else:
elements_to_repeat = length - example.shape[0]
if len(example.shape) == 1:
example = np.pad(example, (0, elements_to_repeat), mode="wrap")
else:
example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap")
return example
5 changes: 3 additions & 2 deletions src/qusi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
"""
from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve
from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \
default_light_curve_observation_post_injection_transform, make_uniform_length
default_light_curve_observation_post_injection_transform
from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \
randomly_roll_light_curve_observation
from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \
pair_array_to_tensor
pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score

__all__ = [
'default_light_curve_post_injection_transform',
'default_light_curve_observation_post_injection_transform',
'from_light_curve_observation_to_fluxes_array_and_label_array',
'make_uniform_length',
'normalize_tensor_by_modified_z_score',
'pair_array_to_tensor',
'randomly_roll_light_curve',
'randomly_roll_light_curve_observation',
Expand Down

0 comments on commit 250f2a6

Please sign in to comment.