diff --git a/pyproject.toml b/pyproject.toml index 993ea461..9545a90e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,5 +110,9 @@ extend-exclude = ["examples"] ignore = [ "RET504", # Subjective but, naming the returned value often seems to help readability. "SIM108", # Subjective but, ternary operators are often too confusing. + # We don't expect use cases of the frame to need to worry about security issues. + "S608", + "S301", + "S311", ] isort.known-first-party = ["qusi", "ramjet"] \ No newline at end of file diff --git a/src/ramjet/logging/wandb_logger.py b/src/ramjet/logging/wandb_logger.py index eefb4a4d..bc249587 100644 --- a/src/ramjet/logging/wandb_logger.py +++ b/src/ramjet/logging/wandb_logger.py @@ -4,13 +4,13 @@ from __future__ import annotations import math +import multiprocessing import queue from abc import ABC, abstractmethod import plotly import plotly.graph_objects as go from plotly.subplots import make_subplots -import multiprocessing import wandb from ramjet.photometric_database.light_curve import LightCurve diff --git a/src/ramjet/photometric_database/derived/tess_two_minute_cadence_light_curve_collection.py b/src/ramjet/photometric_database/derived/tess_two_minute_cadence_light_curve_collection.py index 75fa1797..60c31449 100644 --- a/src/ramjet/photometric_database/derived/tess_two_minute_cadence_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/tess_two_minute_cadence_light_curve_collection.py @@ -1,14 +1,17 @@ """ Code for a light curve collection of the TESS two minute cadence data. """ -from pathlib import Path -from typing import Union +from __future__ import annotations -import numpy as np -from peewee import Select +from pathlib import Path +from typing import TYPE_CHECKING -from ramjet.data_interface.metadatabase import MetadatabaseModel -from ramjet.data_interface.tess_data_interface import TessDataInterface, TessFluxType +from ramjet.data_interface.tess_data_interface import ( + TessDataInterface, + TessFluxType, + download_two_minute_cadence_light_curves, + load_fluxes_and_times_from_fits_file, +) from ramjet.data_interface.tess_target_metadata_manager import TessTargetMetadata from ramjet.data_interface.tess_two_minute_cadence_light_curve_metadata_manager import ( TessTwoMinuteCadenceLightCurveMetadata, @@ -16,6 +19,12 @@ ) from ramjet.photometric_database.sql_metadata_light_curve_collection import SqlMetadataLightCurveCollection +if TYPE_CHECKING: + import numpy as np + from peewee import Select + + from ramjet.data_interface.metadatabase import MetadatabaseModel + class TessTwoMinuteCadenceLightCurveCollection(SqlMetadataLightCurveCollection): """ @@ -24,11 +33,11 @@ class TessTwoMinuteCadenceLightCurveCollection(SqlMetadataLightCurveCollection): tess_data_interface = TessDataInterface() tess_two_minute_cadence_light_curve_metadata_manger = TessTwoMinuteCadenceLightCurveMetadataManger() - def __init__(self, dataset_splits: Union[list[int], None] = None, flux_type: TessFluxType = TessFluxType.PDCSAP): + def __init__(self, dataset_splits: list[int] | None = None, flux_type: TessFluxType = TessFluxType.PDCSAP): super().__init__() self.data_directory: Path = Path('data/tess_two_minute_cadence_light_curves') self.label = 0 - self.dataset_splits: Union[list[int], None] = dataset_splits + self.dataset_splits: list[int] | None = dataset_splits self.flux_type: TessFluxType = flux_type def get_sql_query(self) -> Select: diff --git a/src/ramjet/photometric_database/derived/tess_two_minute_cadence_transit_light_curve_collections.py b/src/ramjet/photometric_database/derived/tess_two_minute_cadence_transit_light_curve_collections.py index 7e184ef9..ad93e402 100644 --- a/src/ramjet/photometric_database/derived/tess_two_minute_cadence_transit_light_curve_collections.py +++ b/src/ramjet/photometric_database/derived/tess_two_minute_cadence_transit_light_curve_collections.py @@ -1,9 +1,9 @@ """ Code representing the collection of TESS two minute cadence light curves containing transits. """ -from typing import Union +from __future__ import annotations -from peewee import Select +from typing import TYPE_CHECKING from ramjet.data_interface.tess_transit_metadata_manager import Disposition, TessTransitMetadata from ramjet.data_interface.tess_two_minute_cadence_light_curve_metadata_manager import ( @@ -13,13 +13,16 @@ TessTwoMinuteCadenceTargetDatasetSplitLightCurveCollection, ) +if TYPE_CHECKING: + from peewee import Select + class TessTwoMinuteCadenceConfirmedTransitLightCurveCollection( TessTwoMinuteCadenceTargetDatasetSplitLightCurveCollection): """ A class representing the collection of TESS two minute cadence light curves containing transits. """ - def __init__(self, dataset_splits: Union[list[int], None] = None): + def __init__(self, dataset_splits: list[int] | None = None): super().__init__(dataset_splits=dataset_splits) self.label = 1 @@ -41,7 +44,7 @@ class TessTwoMinuteCadenceConfirmedAndCandidateTransitLightCurveCollection( """ A class representing the collection of TESS two minute cadence light curves containing transits. """ - def __init__(self, dataset_splits: Union[list[int], None] = None): + def __init__(self, dataset_splits: list[int] | None = None): super().__init__(dataset_splits=dataset_splits) self.label = 1 @@ -63,7 +66,7 @@ class TessTwoMinuteCadenceNonTransitLightCurveCollection(TessTwoMinuteCadenceTar """ A class representing the collection of TESS two minute cadence light curves containing transits. """ - def __init__(self, dataset_splits: Union[list[int], None] = None): + def __init__(self, dataset_splits: list[int] | None = None): super().__init__(dataset_splits=dataset_splits) self.label = 0 diff --git a/src/ramjet/photometric_database/derived/toy_database.py b/src/ramjet/photometric_database/derived/toy_database.py index 96cc4efc..be7389bb 100644 --- a/src/ramjet/photometric_database/derived/toy_database.py +++ b/src/ramjet/photometric_database/derived/toy_database.py @@ -27,6 +27,7 @@ def __init__(self): ToySineWaveLightCurveCollection(), ] + class ToyRamjetDatabaseWithAuxiliary(StandardAndInjectedLightCurveDatabase): def __init__(self): super().__init__() @@ -36,8 +37,10 @@ def __init__(self): self.number_of_auxiliary_values = 2 flat_collection = ToyFlatLightCurveCollection() sine_wave_collection = ToySineWaveLightCurveCollection() - flat_collection.load_auxiliary_information_for_path = lambda path: np.array([0, 0], dtype=np.float32) - sine_wave_collection.load_auxiliary_information_for_path = lambda path: np.array([1, 1], dtype=np.float32) + flat_collection.load_auxiliary_information_for_path = ( + lambda path: np.array([0, 0], dtype=np.float32)) # noqa ARG002 + sine_wave_collection.load_auxiliary_information_for_path = ( + lambda path: np.array([1, 1], dtype=np.float32)) # noqa ARG002 self.training_standard_light_curve_collections = [ flat_collection, sine_wave_collection, diff --git a/src/ramjet/photometric_database/derived/toy_light_curve_collection.py b/src/ramjet/photometric_database/derived/toy_light_curve_collection.py index 46ab4440..d4667406 100644 --- a/src/ramjet/photometric_database/derived/toy_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/toy_light_curve_collection.py @@ -1,12 +1,16 @@ -from collections.abc import Iterable +from __future__ import annotations + from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING import numpy as np from ramjet.photometric_database.light_curve import LightCurve from ramjet.photometric_database.light_curve_collection import LightCurveCollection +if TYPE_CHECKING: + from collections.abc import Iterable + class ToyLightCurveCollection(LightCurveCollection): """ @@ -26,7 +30,7 @@ def __init__(self): super().__init__() self.label = 0 - def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): + def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002 light_curve = ToyLightCurve.flat() return light_curve.times, light_curve.fluxes @@ -43,7 +47,7 @@ def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray light_curve = ToyLightCurve.flat(float(path.name)) return light_curve.times, light_curve.fluxes - def load_label_from_path(self, path: Path) -> Union[float, np.ndarray]: + def load_label_from_path(self, path: Path) -> float | np.ndarray: label = float(path.name) return label @@ -53,7 +57,7 @@ def __init__(self): super().__init__() self.label = 1 - def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): + def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002 light_curve = ToyLightCurve.sine_wave() return light_curve.times, light_curve.fluxes diff --git a/src/ramjet/photometric_database/light_curve.py b/src/ramjet/photometric_database/light_curve.py index eb34a092..21e33e34 100644 --- a/src/ramjet/photometric_database/light_curve.py +++ b/src/ramjet/photometric_database/light_curve.py @@ -3,19 +3,18 @@ """ from __future__ import annotations -from abc import ABC - import lightkurve.lightcurve import numpy as np import pandas as pd from lightkurve.periodogram import LombScarglePeriodogram -class LightCurve(ABC): +class LightCurve: """ A class to represent a light curve. A light curve is a collection of data which may includes times, fluxes, flux errors, and related values. """ + def __init__(self): self.data_frame: pd.DataFrame = pd.DataFrame() self.flux_column_names: list[str] = [] @@ -59,7 +58,7 @@ def times(self, value: np.ndarray): def folded_times(self): if self.folded_times_column_name not in self.data_frame.columns: error_message = 'Light curve has not been folded.' - raise MissingFoldedTimes(error_message) + raise MissingFoldedTimesError(error_message) return self.data_frame[self.folded_times_column_name].values @folded_times.setter @@ -114,7 +113,11 @@ def to_lightkurve(self) -> lightkurve.lightcurve.LightCurve: def get_variability_phase_folding_parameters( self, minimum_period: float | None = None, maximum_period: float | None = None ) -> (float, float, float, float, float): - fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve, periodogram, folded_lightkurve_light_curve = self.get_variability_phase_folding_parameters_and_folding_lightkurve_light_curves(minimum_period=minimum_period, maximum_period=maximum_period) + ( + fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve, + periodogram, folded_lightkurve_light_curve + ) = self.get_variability_phase_folding_parameters_and_folding_lightkurve_light_curves( + minimum_period=minimum_period, maximum_period=maximum_period) self._variability_period = fold_period self._variability_period_epoch = fold_epoch return fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase @@ -138,11 +141,14 @@ def get_variability_phase_folding_parameters_and_folding_lightkurve_light_curves minimum_bin_phase = binned_folded_lightkurve_light_curve.phase.value[minimum_bin_index] maximum_bin_phase = binned_folded_lightkurve_light_curve.phase.value[maximum_bin_index] fold_epoch = inlier_lightkurve_light_curve.time.value[0] - return fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve, periodogram, folded_lightkurve_light_curve + return ( + fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve, + periodogram, folded_lightkurve_light_curve + ) def fold(self, period: float, epoch: float) -> None: self.folded_times = (self.times - epoch) % period -class MissingFoldedTimes(Exception): +class MissingFoldedTimesError(Exception): pass diff --git a/src/ramjet/photometric_database/light_curve_collection.py b/src/ramjet/photometric_database/light_curve_collection.py index d22955d4..5051f4cf 100644 --- a/src/ramjet/photometric_database/light_curve_collection.py +++ b/src/ramjet/photometric_database/light_curve_collection.py @@ -1,12 +1,16 @@ """ Code for representing a collection of light curves. """ -from collections.abc import Iterable +from __future__ import annotations + from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING import numpy as np +if TYPE_CHECKING: + from collections.abc import Iterable + class LightCurveCollectionMethodNotImplementedError(RuntimeError): """ @@ -25,8 +29,8 @@ class LightCurveCollection: :ivar paths: The default list of paths to be used if the `get_paths` method is not overridden. """ def __init__(self): - self.label: Union[float, list[float], np.ndarray, None] = None - self.paths: Union[list[Path], None] = None + self.label: float | list[float] | np.ndarray | None = None + self.paths: list[Path] | None = None def get_paths(self) -> Iterable[Path]: """ @@ -36,7 +40,7 @@ def get_paths(self) -> Iterable[Path]: """ return self.paths - def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): + def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002 """ Loads the times and fluxes from a given light curve path. @@ -45,7 +49,7 @@ def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray """ raise LightCurveCollectionMethodNotImplementedError - def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray): + def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002 """ Loads the times and magnifications from a given path as an injectable signal. @@ -69,7 +73,7 @@ def generate_synthetic_signal_from_real_data(fluxes: np.ndarray, times: np.ndarr relative_times = times - np.min(times) return normalized_fluxes, relative_times - def load_label_from_path(self, path: Path) -> Union[float, np.ndarray]: + def load_label_from_path(self, path: Path) -> float | np.ndarray: # noqa ARG002 """ Loads the label of an example from a corresponding path. @@ -101,7 +105,7 @@ def shuffle_and_split_paths(paths: list[Path], dataset_splits: list[int], number return dataset_split_paths def load_times_fluxes_and_flux_errors_from_path(self, path: Path - ) -> (np.ndarray, np.ndarray, Union[np.ndarray, None]): + ) -> (np.ndarray, np.ndarray, np.ndarray | None): """ Loads the times, fluxes, and flux errors of a light curve from a path to the data. Unless overridden, defaults to using the method to load only the times and fluxes, and returns None for errors. @@ -114,7 +118,7 @@ def load_times_fluxes_and_flux_errors_from_path(self, path: Path return times, fluxes, flux_errors def load_times_magnifications_and_magnification_errors_from_path( - self, path: Path) -> (np.ndarray, np.ndarray, Union[np.ndarray, None]): + self, path: Path) -> (np.ndarray, np.ndarray, np.ndarray | None): """ Loads the times, magnifications, and magnification_errors of a light curve from a path to the data. Unless overridden, defaults to using the method to load only the times and magnifications, @@ -127,7 +131,7 @@ def load_times_magnifications_and_magnification_errors_from_path( flux_errors = None return times, fluxes, flux_errors - def load_auxiliary_information_for_path(self, path: Path) -> np.ndarray: + def load_auxiliary_information_for_path(self, path: Path) -> np.ndarray: # noqa ARG002 """ Loads auxiliary information for the given path. diff --git a/src/ramjet/photometric_database/light_curve_database.py b/src/ramjet/photometric_database/light_curve_database.py index 1007eb05..58d4732a 100644 --- a/src/ramjet/photometric_database/light_curve_database.py +++ b/src/ramjet/photometric_database/light_curve_database.py @@ -1,8 +1,8 @@ """Code for a base generalized database for photometric data to be subclassed.""" +from __future__ import annotations + import shutil -from abc import ABC from pathlib import Path -from typing import Union import numpy as np import numpy.typing as npt @@ -19,22 +19,27 @@ def preprocess_times(light_curve_array: np.ndarray) -> None: light_curve_array[:, 0] = calculate_time_differences(times) -def make_times_and_fluxes_array_uniform_length(arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], length: int, randomize: bool = True) -> (np.ndarray, np.ndarray): +def make_times_and_fluxes_array_uniform_length(arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], + length: int, *, randomize: bool = True) -> (np.ndarray, np.ndarray): times, fluxes = arrays light_curve_array = np.stack([times, fluxes], axis=-1) uniform_length_light_curve_array = make_uniform_length(light_curve_array, length=length, randomize=randomize) return uniform_length_light_curve_array[:, 0], uniform_length_light_curve_array[:, 1] -def make_fluxes_and_label_array_uniform_length(arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], length: int, randomize: bool = True) -> (np.ndarray, np.ndarray): +def make_fluxes_and_label_array_uniform_length(arrays: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]], + length: int, *, randomize: bool = True) -> (np.ndarray, np.ndarray): times, label = arrays uniform_length_times = make_uniform_length(times, length=length, randomize=randomize) return uniform_length_times, label -def make_uniform_length(example: np.ndarray, length: int, randomize: bool = True) -> np.ndarray: +def make_uniform_length(example: np.ndarray, length: int, *, randomize: bool = True) -> np.ndarray: """Makes the example a specific length, by clipping those too large and repeating those too small.""" - assert len(example.shape) in [1, 2] # Only tested for 1D and 2D cases. + allowed_channels_dimension = [1, 2] + if len(example.shape) not in allowed_channels_dimension: # Only tested for 1D and 2D cases. + msg = f'Expected one of {allowed_channels_dimension}, but got {len(example.shape)}.' + raise ValueError(msg) if randomize: example = randomly_roll_elements(example) if example.shape[0] == length: @@ -50,7 +55,7 @@ def make_uniform_length(example: np.ndarray, length: int, randomize: bool = True return example -class LightCurveDatabase(ABC): +class LightCurveDatabase: """A base generalized database for photometric data to be subclassed.""" def __init__(self, data_directory='data'): @@ -95,18 +100,27 @@ def normalize_fluxes(self, light_curve: np.ndarray) -> None: """ if self.include_time_as_channel: if self.include_flux_errors_as_channel: - assert light_curve.shape[1] == 3 + expected_channels = 3 + if light_curve.shape[1] != expected_channels: + msg = f'Expected light curve channels shape of 3, found {light_curve.shape[1]}.' + raise ValueError(msg) light_curve[:, 1], light_curve[:, 2] = normalize_on_percentiles_with_errors( light_curve[:, 1], light_curve[:, 2]) else: - assert light_curve.shape[1] == 2 + expected_channels = 2 + if light_curve.shape[1] != expected_channels: + msg = f'Expected light curve channels shape of 2, found {light_curve.shape[1]}.' + raise ValueError(msg) light_curve[:, 1] = normalize_on_percentiles(light_curve[:, 1]) else: - assert light_curve.shape[1] == 1 + expected_channels = 1 + if light_curve.shape[1] != expected_channels: + msg = f'Expected light curve channels shape of 1, found {light_curve.shape[1]}.' + raise ValueError(msg) light_curve[:, 0] = normalize_on_percentiles(light_curve[:, 0]) - def build_light_curve_array(self, fluxes: np.ndarray, times: Union[np.ndarray, None] = None, - flux_errors: Union[np.ndarray, None] = None): + def build_light_curve_array(self, fluxes: np.ndarray, times: np.ndarray | None = None, + flux_errors: np.ndarray | None = None): """ Builds the light curve array based on the components required for the specific database setup. @@ -125,7 +139,7 @@ def build_light_curve_array(self, fluxes: np.ndarray, times: Union[np.ndarray, N light_curve = np.expand_dims(fluxes, axis=-1) return light_curve - def preprocess_light_curve(self, light_curve: np.ndarray, evaluation_mode: bool = False) -> np.ndarray: + def preprocess_light_curve(self, light_curve: np.ndarray, *, evaluation_mode: bool = False) -> np.ndarray: """ Preprocessing for the light curve. @@ -213,7 +227,7 @@ def randomly_roll_elements(example: np.ndarray) -> np.ndarray: return example -def extract_shuffled_chunk_and_remainder(array_to_extract_from: Union[list, np.ndarray], chunk_ratio: float, +def extract_shuffled_chunk_and_remainder(array_to_extract_from: list | np.ndarray, chunk_ratio: float, chunk_to_extract_index: int = 0) -> (np.ndarray, np.ndarray): """ Shuffles an array, extracts a chunk of the data, and returns the chunk and remainder of the array. diff --git a/src/ramjet/photometric_database/microlensing_signal_generator.py b/src/ramjet/photometric_database/microlensing_signal_generator.py index bd49b7c5..c5871847 100644 --- a/src/ramjet/photometric_database/microlensing_signal_generator.py +++ b/src/ramjet/photometric_database/microlensing_signal_generator.py @@ -14,7 +14,8 @@ try: from muLAn.models.vbb.vbb import vbbmagU except ModuleNotFoundError: - vbbmagU = None + def vbbmagU(_s, _q, _rho, _xi, _yi, _accuracy): # noqa + raise ModuleNotFoundError class MagnificationSignal: @@ -29,7 +30,7 @@ class MagnificationSignal: > The distribution for tE and rho are based on the MOA observations > No parallax effect is considered """ - tE_list: np.ndarray = None + einstein_crossing_time_list: np.ndarray = None rho_list: np.ndarray = None def __init__(self): @@ -49,21 +50,22 @@ def load_moa_meta_data_to_class_attributes(self): """ Loads the MOA meta data defining microlensing to class attributes. If already loaded, does nothing. """ - if self.tE_list is None: + if self.einstein_crossing_time_list is None: microlensing_meta_data_path = Path(__file__).parent.joinpath( 'microlensing_signal_meta_data/candlist_RADec.dat.txt') microlensing_meta_data_path.parent.mkdir(parents=True, exist_ok=True) if not microlensing_meta_data_path.exists(): candidate_list_csv_url = 'https://exoplanetarchive.ipac.caltech.edu/data/ExoData/MOA/candlist_RADec.dat' - response = requests.get(candidate_list_csv_url) + response = requests.get(candidate_list_csv_url, timeout=600) with open(microlensing_meta_data_path, 'wb') as csv_file: csv_file.write(response.content) data = pd.read_csv(microlensing_meta_data_path, header=None, delim_whitespace=True, comment='#', usecols=[19, 36], names=['tE', 'rho']) - self.tE_list: np.ndarray = data['tE'].values + self.einstein_crossing_time_list: np.ndarray = data['tE'].values self.rho_list: np.ndarray = data['rho'].values - bad_indexes = np.argwhere(self.tE_list > 6000) - self.tE_list = np.delete(self.tE_list, bad_indexes) + bad_einstein_crossing_time = 6000 + bad_indexes = np.argwhere(self.einstein_crossing_time_list > bad_einstein_crossing_time) + self.einstein_crossing_time_list = np.delete(self.einstein_crossing_time_list, bad_indexes) self.rho_list = np.delete(self.rho_list, bad_indexes) def getting_random_values(self): @@ -76,8 +78,8 @@ def getting_random_values(self): u0_list = np.linspace(-0.1, 0, 1000) self.u0 = np.random.choice(u0_list) - index = np.random.choice(np.arange(self.tE_list.shape[0])) - self.tE = float(self.tE_list[index]) + index = np.random.choice(np.arange(self.einstein_crossing_time_list.shape[0])) + self.tE = float(self.einstein_crossing_time_list[index]) self.rho = float(self.rho_list[index]) s_list = np.linspace(0.7, 1.3, 100) @@ -149,13 +151,13 @@ def calculating_magnification_from_vbb(timeseries, lens_params): # Get parameters t0 = lens_params['t0'] u0 = lens_params['u0'] - tE = lens_params['tE'] + einstein_crossing_time = lens_params['tE'] rho = lens_params['rho'] q = lens_params['q'] alpha = lens_params['alpha'] s = lens_params['s'] - tau = (timeseries - t0) / tE + tau = (timeseries - t0) / einstein_crossing_time cos_alpha = np.cos(alpha) sin_alpha = np.sin(alpha) @@ -169,13 +171,3 @@ def calculating_magnification_from_vbb(timeseries, lens_params): accuracy = 1.e-3 # Absolute mag accuracy (mag+/-accuracy) magnification = np.array([vbbmagU(s, q, rho, x[i], y[i], accuracy) for i in range(len(x))]) return magnification - - -if __name__ == '__main__': - import time - - start_time = time.time() - random_signal = MagnificationSignal.generate_randomly_based_on_moa_observations() - print("--- %s seconds ---" % (time.time() - start_time)) - random_signal.plot_magnification() - print("Done") diff --git a/src/ramjet/photometric_database/sql_metadata_light_curve_collection.py b/src/ramjet/photometric_database/sql_metadata_light_curve_collection.py index 23994a1d..696df291 100644 --- a/src/ramjet/photometric_database/sql_metadata_light_curve_collection.py +++ b/src/ramjet/photometric_database/sql_metadata_light_curve_collection.py @@ -1,20 +1,26 @@ """ Code for a light curve collection that stores its metadata in the SQL database. """ +from __future__ import annotations + import random -from collections.abc import Iterable +from abc import abstractmethod from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING from uuid import uuid4 from peewee import Case, Field, Select -from ramjet.data_interface.metadatabase import MetadatabaseModel from ramjet.photometric_database.light_curve_collection import ( LightCurveCollection, LightCurveCollectionMethodNotImplementedError, ) +if TYPE_CHECKING: + from collections.abc import Iterable + + from ramjet.data_interface.metadatabase import MetadatabaseModel + class SqlMetadataLightCurveCollection(LightCurveCollection): """ @@ -39,6 +45,7 @@ def sql_count(self) -> int: """ return self.get_sql_query().count() + @abstractmethod def get_path_from_model(self, model: MetadatabaseModel) -> Path: """ Gets the light curve path from the SQL database model. @@ -74,7 +81,7 @@ def order_by_uuid_with_random_start(select_query: Select, uuid_field: Field) -> @staticmethod def order_by_dataset_split_with_random_start(select_query: Select, dataset_split_field: Field, - available_dataset_splits: Union[list[int], None]) -> Select: + available_dataset_splits: list[int] | None) -> Select: """ Applies an "order by" on a query using a passed dataset_split field. The "order by" starts at a random dataset_split out of the passed available options, then loops back to the minimum dataset_split to include all diff --git a/src/ramjet/photometric_database/standard_and_injected_light_curve_database.py b/src/ramjet/photometric_database/standard_and_injected_light_curve_database.py index 6097ecef..e700b82b 100644 --- a/src/ramjet/photometric_database/standard_and_injected_light_curve_database.py +++ b/src/ramjet/photometric_database/standard_and_injected_light_curve_database.py @@ -1,13 +1,14 @@ """ An abstract class allowing for any number and combination of standard and injectable/injectee light curve collections. """ +from __future__ import annotations + from functools import partial -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable import numpy as np import numpy.typing as npt -from ramjet.logging.wandb_logger import WandbLoggableInjection, WandbLogger from ramjet.photometric_database.light_curve import LightCurve from ramjet.photometric_database.light_curve_database import LightCurveDatabase from ramjet.photometric_database.light_curve_dataset_manipulations import ( @@ -17,6 +18,7 @@ ) if TYPE_CHECKING: + from ramjet.logging.wandb_logger import WandbLoggableInjection, WandbLogger from ramjet.photometric_database.light_curve_collection import LightCurveCollection @@ -54,17 +56,17 @@ def inject_signal_into_light_curve( class StandardAndInjectedLightCurveDatabase(LightCurveDatabase): """ - An abstract class allowing for any number and combination of standard and injectable/injectee light curve collections - to be used for training. + An abstract class allowing for any number and combination of standard and injectable/injectee light curve + collections to be used for training. """ def __init__(self): super().__init__() self.training_standard_light_curve_collections: list[LightCurveCollection] = [] - self.training_injectee_light_curve_collection: Union[LightCurveCollection, None] = None + self.training_injectee_light_curve_collection: LightCurveCollection | None = None self.training_injectable_light_curve_collections: list[LightCurveCollection] = [] self.validation_standard_light_curve_collections: list[LightCurveCollection] = [] - self.validation_injectee_light_curve_collection: Union[LightCurveCollection, None] = None + self.validation_injectee_light_curve_collection: LightCurveCollection | None = None self.validation_injectable_light_curve_collections: list[LightCurveCollection] = [] self.inference_light_curve_collections: list[LightCurveCollection] = [] self.shuffle_buffer_size = 10000 @@ -73,7 +75,7 @@ def __init__(self): self.out_of_bounds_injection_handling: OutOfBoundsInjectionHandlingMethod = \ OutOfBoundsInjectionHandlingMethod.ERROR self.baseline_flux_estimation_method = BaselineFluxEstimationMethod.MEDIAN - self.logger: Optional[WandbLogger] = None + self.logger: WandbLogger | None = None @property def number_of_input_channels(self) -> int: @@ -89,7 +91,7 @@ def number_of_input_channels(self) -> int: channels += 1 return channels - def add_logging_queues_to_map_function(self, preprocess_map_function: Callable, name: Optional[str]) -> Callable: + def add_logging_queues_to_map_function(self, preprocess_map_function: Callable, name: str | None) -> Callable: """ Adds logging queues to the map functions. @@ -105,7 +107,7 @@ def add_logging_queues_to_map_function(self, preprocess_map_function: Callable, def inject_signal_into_light_curve(self, light_curve_fluxes: np.ndarray, light_curve_times: np.ndarray, signal_magnifications: np.ndarray, signal_times: np.ndarray, - wandb_loggable_injection: Optional[WandbLoggableInjection] = None) -> np.ndarray: + wandb_loggable_injection: WandbLoggableInjection | None = None) -> np.ndarray: """ Injects a synthetic magnification signal into real light curve fluxes. @@ -118,9 +120,11 @@ def inject_signal_into_light_curve(self, light_curve_fluxes: np.ndarray, light_c """ out_of_bounds_injection_handling_method = self.out_of_bounds_injection_handling baseline_flux_estimation_method = self.baseline_flux_estimation_method - fluxes_with_injected_signal, offset_signal_times, signal_fluxes = inject_signal_into_light_curve_with_intermediates( - light_curve_times, light_curve_fluxes, signal_times, signal_magnifications, - out_of_bounds_injection_handling_method, baseline_flux_estimation_method) + fluxes_with_injected_signal, offset_signal_times, signal_fluxes = ( + inject_signal_into_light_curve_with_intermediates( + light_curve_times, light_curve_fluxes, signal_times, signal_magnifications, + out_of_bounds_injection_handling_method, baseline_flux_estimation_method) + ) if wandb_loggable_injection is not None: wandb_loggable_injection.aligned_injectee_light_curve = LightCurve.from_times_and_fluxes( light_curve_times, light_curve_fluxes) @@ -131,7 +135,7 @@ def inject_signal_into_light_curve(self, light_curve_fluxes: np.ndarray, light_c return fluxes_with_injected_signal -def expand_label_to_training_dimensions(label: Union[int, list[int], tuple[int], np.ndarray]) -> np.ndarray: +def expand_label_to_training_dimensions(label: int | list[int] | tuple[int] | np.ndarray) -> np.ndarray: """ Expand the label to the appropriate dimensions for training. diff --git a/src/ramjet/photometric_database/tess_ffi_light_curve.py b/src/ramjet/photometric_database/tess_ffi_light_curve.py index 00f8d7a4..78de275b 100644 --- a/src/ramjet/photometric_database/tess_ffi_light_curve.py +++ b/src/ramjet/photometric_database/tess_ffi_light_curve.py @@ -56,7 +56,7 @@ class TessFfiPickleIndex(Enum): QUALITY_FLAG = 11 -class AdaptIntermittentException(Exception): +class AdaptIntermittentError(Exception): pass @@ -75,10 +75,10 @@ def __init__(self): TessFfiColumnName.RAW_FLUX.value] @classmethod - @retry(retry=retry_if_exception_type(AdaptIntermittentException), + @retry(retry=retry_if_exception_type(AdaptIntermittentError), wait=wait_random_exponential(multiplier=0.1, max=20), stop=stop_after_attempt(20), reraise=True) def from_path(cls, path: Path, column_names_to_load: list[TessFfiColumnName] | None = None, - remove_bad_quality_data: bool = True) -> TessFfiLightCurve: + *, remove_bad_quality_data: bool = True) -> TessFfiLightCurve: """ Creates an FFI TESS light curve from a path to one of Brian Powell's pickle files. @@ -89,26 +89,26 @@ def from_path(cls, path: Path, column_names_to_load: list[TessFfiColumnName] | N :param remove_bad_quality_data: Removes data with quality problem flags (e.g., non-zero quality flags). :return: The light curve. """ + light_curve = cls() + light_curve.time_column_name = TessFfiColumnName.TIME__BTJD.value + if column_names_to_load is None: + column_names_to_load = list(TessFfiColumnName) try: - light_curve = cls() - light_curve.time_column_name = TessFfiColumnName.TIME__BTJD.value - if column_names_to_load is None: - column_names_to_load = list(TessFfiColumnName) with path.open('rb') as pickle_file: light_curve_data_dictionary = pickle.load(pickle_file) - if remove_bad_quality_data: - quality_flag_values = light_curve_data_dictionary[TessFfiPickleIndex.QUALITY_FLAG.value] - for column_name in column_names_to_load: - pickle_index = TessFfiPickleIndex[column_name.name] - column_values = light_curve_data_dictionary[pickle_index.value] - if remove_bad_quality_data: - column_values = column_values[quality_flag_values == 0] - light_curve.data_frame[column_name.value] = column_values - light_curve.tic_id, light_curve.sector = light_curve.get_tic_id_and_sector_from_file_path(path) - return light_curve except (pickle.UnpicklingError, OSError, IsADirectoryError) as error: error_message = f'Errored on path {path}.' - raise AdaptIntermittentException(error_message) from error + raise AdaptIntermittentError(error_message) from error + if remove_bad_quality_data: + quality_flag_values = light_curve_data_dictionary[TessFfiPickleIndex.QUALITY_FLAG.value] + for column_name in column_names_to_load: + pickle_index = TessFfiPickleIndex[column_name.name] + column_values = light_curve_data_dictionary[pickle_index.value] + if remove_bad_quality_data: + column_values = column_values[quality_flag_values == 0] + light_curve.data_frame[column_name.value] = column_values + light_curve.tic_id, light_curve.sector = light_curve.get_tic_id_and_sector_from_file_path(path) + return light_curve @staticmethod def get_tic_id_and_sector_from_file_path(path: Path | str) -> (int, int | None): @@ -179,7 +179,7 @@ def get_magnitude_from_file(file_path: Path | str) -> float: @classmethod def load_fluxes_and_times_from_pickle_file( cls, file_path: Path | str, flux_column_name: TessFfiColumnName = TessFfiColumnName.CORRECTED_FLUX, - remove_bad_quality_data: bool = True + *, remove_bad_quality_data: bool = True ) -> (np.ndarray, np.ndarray): """ Loads the fluxes and times from one of Brian Powell's FFI pickle files. @@ -197,7 +197,10 @@ def load_fluxes_and_times_from_pickle_file( remove_bad_quality_data=remove_bad_quality_data) fluxes = light_curve.data_frame[flux_column_name.value] times = light_curve.data_frame[TessFfiColumnName.TIME__BTJD.value] - assert times.shape == fluxes.shape + if times.shape != fluxes.shape: + error_message = f'Times and fluxes arrays must have the same shape, but have shapes ' \ + f'{times.shape} and {fluxes.shape}.' + raise ValueError(error_message) return fluxes, times @classmethod @@ -219,7 +222,10 @@ def load_fluxes_flux_errors_and_times_from_pickle_file( fluxes = light_curve.data_frame[flux_column_name.value] flux_errors = light_curve.data_frame[TessFfiColumnName.FLUX_ERROR.value] times = light_curve.data_frame[TessFfiColumnName.TIME__BTJD.value] - assert times.shape == fluxes.shape + if times.shape != fluxes.shape: + error_message = f'Times and fluxes arrays must have the same shape, but have shapes ' \ + f'{times.shape} and {fluxes.shape}.' + raise ValueError(error_message) return fluxes, flux_errors, times @@ -250,8 +256,10 @@ def filter_function(var_type_string): def separation_to_nearest_gcvs_rr_lyrae_within_separation( sky_coord: SkyCoord, - maximum_separation: Angle = Angle(21, unit=units.arcsecond) + maximum_separation: Angle | None = None ) -> Angle | None: + if maximum_separation is None: + maximum_separation = Angle(21, unit=units.arcsecond) gcvs_region_table_list = Vizier(columns=['**'], catalog='B/gcvs/gcvs_cat', row_limit=-1 ).query_region(sky_coord, radius=maximum_separation) if len(gcvs_region_table_list) == 0: diff --git a/src/ramjet/photometric_database/tess_light_curve.py b/src/ramjet/photometric_database/tess_light_curve.py index b68e5e43..b5c4173b 100644 --- a/src/ramjet/photometric_database/tess_light_curve.py +++ b/src/ramjet/photometric_database/tess_light_curve.py @@ -123,9 +123,9 @@ def load_tic_rows_from_mast_for_list(cls, light_curves: list[TessLightCurve]) -> for light_curve in light_curves: light_curve_tic_row_data_frame = tic_row_data_frame[tic_row_data_frame['ID'] == str(light_curve.tic_id)] if light_curve_tic_row_data_frame.shape[0] == 0: - light_curve._tic_row = MissingTicRow + light_curve._tic_row = MissingTicRow # noqa SLF001 else: - light_curve._tic_row = light_curve_tic_row_data_frame.iloc[0] + light_curve._tic_row = light_curve_tic_row_data_frame.iloc[0] # noqa SLF001 class MissingTicRow: diff --git a/src/ramjet/photometric_database/tess_target.py b/src/ramjet/photometric_database/tess_target.py index 0aadde2a..587a7c2d 100644 --- a/src/ramjet/photometric_database/tess_target.py +++ b/src/ramjet/photometric_database/tess_target.py @@ -62,7 +62,7 @@ def get_radius_from_gaia(gaia_source_id: int) -> float: radius = query_results_data_frame['radius_val'].iloc[0] return radius - def calculate_transiting_body_radius(self, transit_depth: float, allow_unknown_contamination_ratio: bool = False + def calculate_transiting_body_radius(self, transit_depth: float, *, allow_unknown_contamination_ratio: bool = False ) -> float: """ Calculates the radius of a transiting body based on the target parameters and the transit depth. @@ -87,7 +87,7 @@ def retrieve_nearby_tic_targets(self) -> pd.DataFrame: :return: The data frame of nearby targets. """ csv_url = f'https://exofop.ipac.caltech.edu/tess/download_nearbytarget.php?id={self.tic_id}&output=csv' - csv_string = requests.get(csv_url).content.decode('utf-8') + csv_string = requests.get(csv_url, timeout=600).content.decode('utf-8') if 'Distance Err' not in csv_string: # Correct ExoFOP bug where distance error column header is missing. csv_string = csv_string.replace('Distance(pc)', 'Distance (pc),Distance Err (pc)') data_frame = pd.read_csv(io.StringIO(csv_string), index_col=False) diff --git a/src/ramjet/photometric_database/tess_two_minute_cadence_light_curve.py b/src/ramjet/photometric_database/tess_two_minute_cadence_light_curve.py index 8b9e3b4f..47206a84 100644 --- a/src/ramjet/photometric_database/tess_two_minute_cadence_light_curve.py +++ b/src/ramjet/photometric_database/tess_two_minute_cadence_light_curve.py @@ -101,17 +101,16 @@ def from_identifier(cls, identifier: Any) -> TessMissionLightCurve: integer_types = (int, np.integer) if isinstance(identifier, Path): return cls.from_path(path=identifier) - elif isinstance(identifier, tuple) and (isinstance(identifier[0], integer_types) and + if isinstance(identifier, tuple) and (isinstance(identifier[0], integer_types) and isinstance(identifier[1], integer_types)): tic_id = identifier[0] sector = identifier[1] return cls.from_mast(tic_id=tic_id, sector=sector) - elif isinstance(identifier, str): + if isinstance(identifier, str): tic_id, sector = cls.get_tic_id_and_sector_from_identifier_string(identifier) return cls.from_mast(tic_id=tic_id, sector=sector) - else: - error_message = f'{identifier} does not match a known type to infer the light curve identifier from.' - raise ValueError(error_message) + error_message = f'{identifier} does not match a known type to infer the light curve identifier from.' + raise TypeError(error_message) @staticmethod def get_tic_id_and_sector_from_file_path(file_path: Path) -> (int, int | None):