Skip to content

Commit

Permalink
Correct many ruff issues
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Feb 15, 2024
1 parent 0b90e70 commit bd3dc87
Show file tree
Hide file tree
Showing 16 changed files with 178 additions and 121 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,9 @@ extend-exclude = ["examples"]
ignore = [
"RET504", # Subjective but, naming the returned value often seems to help readability.
"SIM108", # Subjective but, ternary operators are often too confusing.
# We don't expect use cases of the frame to need to worry about security issues.
"S608",
"S301",
"S311",
]
isort.known-first-party = ["qusi", "ramjet"]
2 changes: 1 addition & 1 deletion src/ramjet/logging/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from __future__ import annotations

import math
import multiprocessing
import queue
from abc import ABC, abstractmethod

import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import multiprocessing

import wandb
from ramjet.photometric_database.light_curve import LightCurve
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
"""
Code for a light curve collection of the TESS two minute cadence data.
"""
from pathlib import Path
from typing import Union
from __future__ import annotations

import numpy as np
from peewee import Select
from pathlib import Path
from typing import TYPE_CHECKING

from ramjet.data_interface.metadatabase import MetadatabaseModel
from ramjet.data_interface.tess_data_interface import TessDataInterface, TessFluxType
from ramjet.data_interface.tess_data_interface import (
TessDataInterface,
TessFluxType,
download_two_minute_cadence_light_curves,
load_fluxes_and_times_from_fits_file,
)
from ramjet.data_interface.tess_target_metadata_manager import TessTargetMetadata
from ramjet.data_interface.tess_two_minute_cadence_light_curve_metadata_manager import (
TessTwoMinuteCadenceLightCurveMetadata,
TessTwoMinuteCadenceLightCurveMetadataManger,
)
from ramjet.photometric_database.sql_metadata_light_curve_collection import SqlMetadataLightCurveCollection

if TYPE_CHECKING:
import numpy as np
from peewee import Select

from ramjet.data_interface.metadatabase import MetadatabaseModel


class TessTwoMinuteCadenceLightCurveCollection(SqlMetadataLightCurveCollection):
"""
Expand All @@ -24,11 +33,11 @@ class TessTwoMinuteCadenceLightCurveCollection(SqlMetadataLightCurveCollection):
tess_data_interface = TessDataInterface()
tess_two_minute_cadence_light_curve_metadata_manger = TessTwoMinuteCadenceLightCurveMetadataManger()

def __init__(self, dataset_splits: Union[list[int], None] = None, flux_type: TessFluxType = TessFluxType.PDCSAP):
def __init__(self, dataset_splits: list[int] | None = None, flux_type: TessFluxType = TessFluxType.PDCSAP):
super().__init__()
self.data_directory: Path = Path('data/tess_two_minute_cadence_light_curves')
self.label = 0
self.dataset_splits: Union[list[int], None] = dataset_splits
self.dataset_splits: list[int] | None = dataset_splits
self.flux_type: TessFluxType = flux_type

def get_sql_query(self) -> Select:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Code representing the collection of TESS two minute cadence light curves containing transits.
"""
from typing import Union
from __future__ import annotations

from peewee import Select
from typing import TYPE_CHECKING

from ramjet.data_interface.tess_transit_metadata_manager import Disposition, TessTransitMetadata
from ramjet.data_interface.tess_two_minute_cadence_light_curve_metadata_manager import (
Expand All @@ -13,13 +13,16 @@
TessTwoMinuteCadenceTargetDatasetSplitLightCurveCollection,
)

if TYPE_CHECKING:
from peewee import Select


class TessTwoMinuteCadenceConfirmedTransitLightCurveCollection(
TessTwoMinuteCadenceTargetDatasetSplitLightCurveCollection):
"""
A class representing the collection of TESS two minute cadence light curves containing transits.
"""
def __init__(self, dataset_splits: Union[list[int], None] = None):
def __init__(self, dataset_splits: list[int] | None = None):
super().__init__(dataset_splits=dataset_splits)
self.label = 1

Expand All @@ -41,7 +44,7 @@ class TessTwoMinuteCadenceConfirmedAndCandidateTransitLightCurveCollection(
"""
A class representing the collection of TESS two minute cadence light curves containing transits.
"""
def __init__(self, dataset_splits: Union[list[int], None] = None):
def __init__(self, dataset_splits: list[int] | None = None):
super().__init__(dataset_splits=dataset_splits)
self.label = 1

Expand All @@ -63,7 +66,7 @@ class TessTwoMinuteCadenceNonTransitLightCurveCollection(TessTwoMinuteCadenceTar
"""
A class representing the collection of TESS two minute cadence light curves containing transits.
"""
def __init__(self, dataset_splits: Union[list[int], None] = None):
def __init__(self, dataset_splits: list[int] | None = None):
super().__init__(dataset_splits=dataset_splits)
self.label = 0

Expand Down
7 changes: 5 additions & 2 deletions src/ramjet/photometric_database/derived/toy_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self):
ToySineWaveLightCurveCollection(),
]


class ToyRamjetDatabaseWithAuxiliary(StandardAndInjectedLightCurveDatabase):
def __init__(self):
super().__init__()
Expand All @@ -36,8 +37,10 @@ def __init__(self):
self.number_of_auxiliary_values = 2
flat_collection = ToyFlatLightCurveCollection()
sine_wave_collection = ToySineWaveLightCurveCollection()
flat_collection.load_auxiliary_information_for_path = lambda path: np.array([0, 0], dtype=np.float32)
sine_wave_collection.load_auxiliary_information_for_path = lambda path: np.array([1, 1], dtype=np.float32)
flat_collection.load_auxiliary_information_for_path = (
lambda path: np.array([0, 0], dtype=np.float32)) # noqa ARG002
sine_wave_collection.load_auxiliary_information_for_path = (
lambda path: np.array([1, 1], dtype=np.float32)) # noqa ARG002
self.training_standard_light_curve_collections = [
flat_collection,
sine_wave_collection,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from collections.abc import Iterable
from __future__ import annotations

from pathlib import Path
from typing import Union
from typing import TYPE_CHECKING

import numpy as np

from ramjet.photometric_database.light_curve import LightCurve
from ramjet.photometric_database.light_curve_collection import LightCurveCollection

if TYPE_CHECKING:
from collections.abc import Iterable


class ToyLightCurveCollection(LightCurveCollection):
"""
Expand All @@ -26,7 +30,7 @@ def __init__(self):
super().__init__()
self.label = 0

def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray):
def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002
light_curve = ToyLightCurve.flat()
return light_curve.times, light_curve.fluxes

Expand All @@ -43,7 +47,7 @@ def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray
light_curve = ToyLightCurve.flat(float(path.name))
return light_curve.times, light_curve.fluxes

def load_label_from_path(self, path: Path) -> Union[float, np.ndarray]:
def load_label_from_path(self, path: Path) -> float | np.ndarray:
label = float(path.name)
return label

Expand All @@ -53,7 +57,7 @@ def __init__(self):
super().__init__()
self.label = 1

def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray):
def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002
light_curve = ToyLightCurve.sine_wave()
return light_curve.times, light_curve.fluxes

Expand Down
20 changes: 13 additions & 7 deletions src/ramjet/photometric_database/light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
"""
from __future__ import annotations

from abc import ABC

import lightkurve.lightcurve
import numpy as np
import pandas as pd
from lightkurve.periodogram import LombScarglePeriodogram


class LightCurve(ABC):
class LightCurve:
"""
A class to represent a light curve. A light curve is a collection of data which may includes times, fluxes,
flux errors, and related values.
"""

def __init__(self):
self.data_frame: pd.DataFrame = pd.DataFrame()
self.flux_column_names: list[str] = []
Expand Down Expand Up @@ -59,7 +58,7 @@ def times(self, value: np.ndarray):
def folded_times(self):
if self.folded_times_column_name not in self.data_frame.columns:
error_message = 'Light curve has not been folded.'
raise MissingFoldedTimes(error_message)
raise MissingFoldedTimesError(error_message)
return self.data_frame[self.folded_times_column_name].values

@folded_times.setter
Expand Down Expand Up @@ -114,7 +113,11 @@ def to_lightkurve(self) -> lightkurve.lightcurve.LightCurve:
def get_variability_phase_folding_parameters(
self, minimum_period: float | None = None, maximum_period: float | None = None
) -> (float, float, float, float, float):
fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve, periodogram, folded_lightkurve_light_curve = self.get_variability_phase_folding_parameters_and_folding_lightkurve_light_curves(minimum_period=minimum_period, maximum_period=maximum_period)
(
fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve,
periodogram, folded_lightkurve_light_curve
) = self.get_variability_phase_folding_parameters_and_folding_lightkurve_light_curves(
minimum_period=minimum_period, maximum_period=maximum_period)
self._variability_period = fold_period
self._variability_period_epoch = fold_epoch
return fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase
Expand All @@ -138,11 +141,14 @@ def get_variability_phase_folding_parameters_and_folding_lightkurve_light_curves
minimum_bin_phase = binned_folded_lightkurve_light_curve.phase.value[minimum_bin_index]
maximum_bin_phase = binned_folded_lightkurve_light_curve.phase.value[maximum_bin_index]
fold_epoch = inlier_lightkurve_light_curve.time.value[0]
return fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve, periodogram, folded_lightkurve_light_curve
return (
fold_period, fold_epoch, time_bin_size, minimum_bin_phase, maximum_bin_phase, inlier_lightkurve_light_curve,
periodogram, folded_lightkurve_light_curve
)

def fold(self, period: float, epoch: float) -> None:
self.folded_times = (self.times - epoch) % period


class MissingFoldedTimes(Exception):
class MissingFoldedTimesError(Exception):
pass
24 changes: 14 additions & 10 deletions src/ramjet/photometric_database/light_curve_collection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""
Code for representing a collection of light curves.
"""
from collections.abc import Iterable
from __future__ import annotations

from pathlib import Path
from typing import Union
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
from collections.abc import Iterable


class LightCurveCollectionMethodNotImplementedError(RuntimeError):
"""
Expand All @@ -25,8 +29,8 @@ class LightCurveCollection:
:ivar paths: The default list of paths to be used if the `get_paths` method is not overridden.
"""
def __init__(self):
self.label: Union[float, list[float], np.ndarray, None] = None
self.paths: Union[list[Path], None] = None
self.label: float | list[float] | np.ndarray | None = None
self.paths: list[Path] | None = None

def get_paths(self) -> Iterable[Path]:
"""
Expand All @@ -36,7 +40,7 @@ def get_paths(self) -> Iterable[Path]:
"""
return self.paths

def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray):
def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002
"""
Loads the times and fluxes from a given light curve path.
Expand All @@ -45,7 +49,7 @@ def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray
"""
raise LightCurveCollectionMethodNotImplementedError

def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray):
def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002
"""
Loads the times and magnifications from a given path as an injectable signal.
Expand All @@ -69,7 +73,7 @@ def generate_synthetic_signal_from_real_data(fluxes: np.ndarray, times: np.ndarr
relative_times = times - np.min(times)
return normalized_fluxes, relative_times

def load_label_from_path(self, path: Path) -> Union[float, np.ndarray]:
def load_label_from_path(self, path: Path) -> float | np.ndarray: # noqa ARG002
"""
Loads the label of an example from a corresponding path.
Expand Down Expand Up @@ -101,7 +105,7 @@ def shuffle_and_split_paths(paths: list[Path], dataset_splits: list[int], number
return dataset_split_paths

def load_times_fluxes_and_flux_errors_from_path(self, path: Path
) -> (np.ndarray, np.ndarray, Union[np.ndarray, None]):
) -> (np.ndarray, np.ndarray, np.ndarray | None):
"""
Loads the times, fluxes, and flux errors of a light curve from a path to the data.
Unless overridden, defaults to using the method to load only the times and fluxes, and returns None for errors.
Expand All @@ -114,7 +118,7 @@ def load_times_fluxes_and_flux_errors_from_path(self, path: Path
return times, fluxes, flux_errors

def load_times_magnifications_and_magnification_errors_from_path(
self, path: Path) -> (np.ndarray, np.ndarray, Union[np.ndarray, None]):
self, path: Path) -> (np.ndarray, np.ndarray, np.ndarray | None):
"""
Loads the times, magnifications, and magnification_errors of a light curve from a path to the data.
Unless overridden, defaults to using the method to load only the times and magnifications,
Expand All @@ -127,7 +131,7 @@ def load_times_magnifications_and_magnification_errors_from_path(
flux_errors = None
return times, fluxes, flux_errors

def load_auxiliary_information_for_path(self, path: Path) -> np.ndarray:
def load_auxiliary_information_for_path(self, path: Path) -> np.ndarray: # noqa ARG002
"""
Loads auxiliary information for the given path.
Expand Down
Loading

0 comments on commit bd3dc87

Please sign in to comment.