Skip to content

Commit

Permalink
Add workaround for nans in times during injection
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Jun 3, 2024
1 parent d423639 commit b9114fd
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def inject_light_curve(
) -> LightCurveObservation:
(
fluxes_with_injected_signal,
injected_light_curve_times,
_,
_,
) = inject_signal_into_light_curve_with_intermediates(
Expand All @@ -208,9 +209,10 @@ def inject_light_curve(
baseline_flux_estimation_method=BaselineFluxEstimationMethod.MEDIAN,
)
injected_light_curve = LightCurve.new(
times=injectee_observation.light_curve.times,
times=injected_light_curve_times,
fluxes=fluxes_with_injected_signal,
)
# TODO: Quickly hacked in times with nans removed. Should be handled elsewhere.
injected_observation = LightCurveObservation.new(
light_curve=injected_light_curve, label=injectable_observation.label
)
Expand Down Expand Up @@ -383,11 +385,19 @@ def inject_signal_into_light_curve_with_intermediates(
for scaling the signal magnifications.
:return: The fluxes with the injected signal, the offset signal times, and the signal flux.
"""
minimum_light_curve_time = np.min(light_curve_times)
light_curve_nan_flux_indexes = np.isnan(light_curve_times) | np.isnan(light_curve_fluxes)
light_curve_times = light_curve_times[~light_curve_nan_flux_indexes]
light_curve_fluxes = light_curve_fluxes[~light_curve_nan_flux_indexes]
signal_nan_flux_indexes = np.isnan(signal_times) | np.isnan(signal_magnifications)
signal_times = signal_times[~signal_nan_flux_indexes]
signal_magnifications = signal_magnifications[~signal_nan_flux_indexes]
# TODO: Remove quick hack of removing nans and add a more proper handling.

minimum_light_curve_time = np.nanmin(light_curve_times)
relative_light_curve_times = light_curve_times - minimum_light_curve_time
relative_signal_times = signal_times - np.min(signal_times)
signal_time_length = np.max(relative_signal_times)
light_curve_time_length = np.max(relative_light_curve_times)
relative_signal_times = signal_times - np.nanmin(signal_times)
signal_time_length = np.nanmax(relative_signal_times)
light_curve_time_length = np.nanmax(relative_light_curve_times)
time_length_difference = light_curve_time_length - signal_time_length
signal_start_offset = (
np.random.random() * time_length_difference
Expand Down Expand Up @@ -452,7 +462,7 @@ def inject_signal_into_light_curve_with_intermediates(
)
interpolated_signal_fluxes = signal_flux_interpolator(light_curve_times)
fluxes_with_injected_signal = light_curve_fluxes + interpolated_signal_fluxes
return fluxes_with_injected_signal, offset_signal_times, signal_fluxes
return fluxes_with_injected_signal, light_curve_times, offset_signal_times, signal_fluxes


def move_path_to_nvme(path: Path) -> Path:
Expand Down

0 comments on commit b9114fd

Please sign in to comment.