diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index f320cd5..b0f435c 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -32,7 +32,7 @@ ) from qusi.internal.light_curve_transforms import ( from_light_curve_observation_to_fluxes_array_and_label_array, - pair_array_to_tensor, + pair_array_to_tensor, normalize_tensor_by_modified_z_score, make_uniform_length, ) if TYPE_CHECKING: @@ -337,56 +337,6 @@ def default_light_curve_post_injection_transform( return x -def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: - """ - Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median - absolute deviation. - - :param tensor: The tensor to normalize. - :return: The normalized tensor. - """ - median = torch.median(tensor) - deviation_from_median = tensor - median - absolute_deviation_from_median = torch.abs(deviation_from_median) - median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) - if median_absolute_deviation_from_median != 0: - modified_z_score = ( - 0.6745 * deviation_from_median / median_absolute_deviation_from_median - ) - else: - modified_z_score = torch.zeros_like(tensor) - return modified_z_score - - -def make_uniform_length( - example: np.ndarray, length: int, *, randomize: bool = True -) -> np.ndarray: - """Makes the example a specific length, by clipping those too large and repeating those too small.""" - if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. - raise ValueError( - f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" - ) - if randomize: - example = randomly_roll_elements(example) - if example.shape[0] == length: - pass - elif example.shape[0] > length: - example = example[:length] - else: - elements_to_repeat = length - example.shape[0] - if len(example.shape) == 1: - example = np.pad(example, (0, elements_to_repeat), mode="wrap") - else: - example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") - return example - - -def randomly_roll_elements(example: np.ndarray) -> np.ndarray: - """Randomly rolls the elements.""" - example = np.roll(example, np.random.randint(example.shape[0]), axis=0) - return example - - class OutOfBoundsInjectionHandlingMethod(Enum): """ An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal. diff --git a/src/qusi/internal/light_curve_transforms.py b/src/qusi/internal/light_curve_transforms.py index dae294d..f5afa28 100644 --- a/src/qusi/internal/light_curve_transforms.py +++ b/src/qusi/internal/light_curve_transforms.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import numpy.typing as npt import torch @@ -38,3 +40,47 @@ def randomly_roll_elements(example: np.ndarray) -> np.ndarray: """Randomly rolls the elements.""" example = np.roll(example, np.random.randint(example.shape[0]), axis=0) return example + + +def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: + """ + Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median + absolute deviation. + + :param tensor: The tensor to normalize. + :return: The normalized tensor. + """ + median = torch.median(tensor) + deviation_from_median = tensor - median + absolute_deviation_from_median = torch.abs(deviation_from_median) + median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) + if median_absolute_deviation_from_median != 0: + modified_z_score = ( + 0.6745 * deviation_from_median / median_absolute_deviation_from_median + ) + else: + modified_z_score = torch.zeros_like(tensor) + return modified_z_score + + +def make_uniform_length( + example: np.ndarray, length: int, *, randomize: bool = True +) -> np.ndarray: + """Makes the example a specific length, by clipping those too large and repeating those too small.""" + if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. + raise ValueError( + f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" + ) + if randomize: + example = randomly_roll_elements(example) + if example.shape[0] == length: + pass + elif example.shape[0] > length: + example = example[:length] + else: + elements_to_repeat = length - example.shape[0] + if len(example.shape) == 1: + example = np.pad(example, (0, elements_to_repeat), mode="wrap") + else: + example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") + return example diff --git a/src/qusi/transform.py b/src/qusi/transform.py index 3f92951..ba3b863 100644 --- a/src/qusi/transform.py +++ b/src/qusi/transform.py @@ -3,17 +3,18 @@ """ from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ - default_light_curve_observation_post_injection_transform, make_uniform_length + default_light_curve_observation_post_injection_transform from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ randomly_roll_light_curve_observation from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ - pair_array_to_tensor + pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score __all__ = [ 'default_light_curve_post_injection_transform', 'default_light_curve_observation_post_injection_transform', 'from_light_curve_observation_to_fluxes_array_and_label_array', 'make_uniform_length', + 'normalize_tensor_by_modified_z_score', 'pair_array_to_tensor', 'randomly_roll_light_curve', 'randomly_roll_light_curve_observation',