Skip to content

Commit

Permalink
Remove randomization from make_uniform_length and leave that to a sep…
Browse files Browse the repository at this point in the history
…arate randomly roll call
  • Loading branch information
golmschenk committed May 15, 2024
1 parent 250f2a6 commit 63fcfc5
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/crafting_standard_datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv
if randomize:
x = randomly_roll_light_curve_observation(x)
x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length.
x = (make_uniform_length(x[0], length=length), x[1])
x = pair_array_to_tensor(x)
x = (normalize_tensor_by_modified_z_score(x[0]), x[1])
return x
Expand Down
4 changes: 2 additions & 2 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def default_light_curve_observation_post_injection_transform(
if randomize:
x = randomly_roll_light_curve_observation(x)
x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length.
x = (make_uniform_length(x[0], length=length), x[1]) # Make the fluxes a uniform length.
x = pair_array_to_tensor(x)
x = (normalize_tensor_by_modified_z_score(x[0]), x[1])
return x
Expand All @@ -331,7 +331,7 @@ def default_light_curve_post_injection_transform(
if randomize:
x = randomly_roll_light_curve(x)
x = x.fluxes
x = make_uniform_length(x, length=length, randomize=randomize)
x = make_uniform_length(x, length=length)
x = torch.tensor(x, dtype=torch.float32)
x = normalize_tensor_by_modified_z_score(x)
return x
Expand Down
6 changes: 1 addition & 5 deletions src/qusi/internal/light_curve_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,12 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
return modified_z_score


def make_uniform_length(
example: np.ndarray, length: int, *, randomize: bool = True
) -> np.ndarray:
def make_uniform_length(example: np.ndarray, length: int) -> np.ndarray:
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
raise ValueError(
f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}"
)
if randomize:
example = randomly_roll_elements(example)
if example.shape[0] == length:
pass
elif example.shape[0] > length:
Expand Down

0 comments on commit 63fcfc5

Please sign in to comment.