Skip to content

Commit

Permalink
have learn_loop read from MongoDB
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwest-uw committed Nov 22, 2024
1 parent adcfeb4 commit c90b50e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 32 deletions.
19 changes: 16 additions & 3 deletions src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,16 @@ def process_photometry_features(
print('\n Loaded ', self.test_metadata.shape[0],
' samples! \n')

def load_features(self, path_to_file: str, feature_extractor: str ='Bazin',
screen=False, survey='DES', sample=None ):
def load_features(
self,
path_to_file: str = None,
mongo_query: dict = None,
feature_extractor: str ='Bazin',
screen=False,
survey='DES',
sample=None,
location="filesystem",
):
"""Load features according to the chosen feature extraction method.
Populates properties: data, features, feature_list, header
Expand All @@ -392,7 +400,12 @@ def load_features(self, path_to_file: str, feature_extractor: str ='Bazin',
else, read independent files for 'train' and 'test'.
Default is None.
"""
features_data = load_external_features(path_to_file, location="filesystem")
features_data = load_external_features(
filename=path_to_file,
mongo_query=mongo_query,
feature_extractor=feature_extractor,
location=location,
)
if feature_extractor == "photometry":
self.process_photometry_features(features_data, screen=screen, survey=survey, sample=sample)
else:
Expand Down
22 changes: 19 additions & 3 deletions src/resspect/feature_handling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def save_features(
else:
raise ValueError("filename must be provided if saving to the filesystem.")
elif location == "mongodb":
with open("~/homework/mongodb_test.pass") as f:
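# TODO: temporary local workaround -- the credentials file path below is hard-coded to a developer machine; read it from configuration instead.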
with open("/Users/maxwest/homework/mongodb_test.pass") as f:
MONGO_URI = f.readline().strip("\n")
client = MongoClient(MONGO_URI)
db = client[MONGODB_NAME]
Expand All @@ -42,8 +43,10 @@ def save_features(
def load_external_features(
filename: str = None,
location: str = "filesystem",
mongo_query: dict = None,
feature_extractor: str = "Malanchev",
):
"Load features from a .csv file."
"Load features from a .csv file or download them from a MongoDB instance."
data = None
if location == "filesystem":
if filename is not None:
Expand All @@ -59,6 +62,19 @@ def load_external_features(
data = pd.read_csv(filename, sep=' ', index_col=False)
else:
raise ValueError("filename must be provided if reading from the filesystem.")
elif location == "mongodb":
with open("/Users/maxwest/homework/mongodb_test.pass") as f:
MONGO_URI = f.readline().strip("\n")
client = MongoClient(MONGO_URI)
db = client[MONGODB_NAME]
collection = db[MONGO_COLLECTION_NAMES[feature_extractor]]

cursor = collection.find(mongo_query)
data_dicts = []
for element in cursor:
data_dicts.append(element)
# TODO: consider dropping the MongoDB `_id` column before building the DataFrame.
data = pd.DataFrame(data_dicts)
else:
raise NotImplementedError("Alternative storage method implementation tbd.")
raise ValueError("location must either be 'filesystem' or 'location'")
return data
40 changes: 24 additions & 16 deletions src/resspect/learn_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,33 @@ def load_features(database_class: DataBase, config: LoopConfiguration) -> DataBa
config: `resspect.loop_configuration.LoopConfiguration`
The configuration elements of the learn loop.
"""
if isinstance(config.path_to_features, str):
if config.path_to_features is not None:
if isinstance(config.path_to_features, str):
database_class.load_features(
path_to_file=config.path_to_features,
feature_extractor=config.features_method,
survey=config.survey,
)
else:
features_set_names = ['train', 'test', 'validation', 'pool']
for sample_name in features_set_names:
if sample_name in config.path_to_features.keys():
database_class.load_features(
config.path_to_features[sample_name],
feature_extractor=config.features_method,
survey=config.survey,
sample=sample_name
)
else:
logging.warning(f'Path to {sample_name} not given.'
f' Proceeding without this sample')
else:
database_class.load_features(
path_to_file=config.path_to_features,
mongo_query=config.features_query,
feature_extractor=config.features_method,
survey=config.survey
survey=config.survey,
location="mongodb",
)
else:
features_set_names = ['train', 'test', 'validation', 'pool']
for sample_name in features_set_names:
if sample_name in config.path_to_features.keys():
database_class.load_features(
config.path_to_features[sample_name],
feature_extractor=config.features_method,
survey=config.survey,
sample=sample_name
)
else:
logging.warning(f'Path to {sample_name} not given.'
f' Proceeding without this sample')

database_class.build_samples(
initial_training=config.training,
Expand Down
24 changes: 14 additions & 10 deletions src/resspect/loop_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ class LoopConfiguration(BaseConfiguration):
"""
nloops: int
strategy: str
path_to_features: str
output_metrics_file: str
output_queried_file: str
path_to_features: str = None
features_query: dict = None
features_method: str = 'Bazin'
classifier: str = 'RandomForest'
training: str = 'original'
Expand All @@ -121,16 +122,19 @@ class LoopConfiguration(BaseConfiguration):

def __post_init__(self):
# file checking
if isinstance(self.path_to_features, str):
if not path.isfile(self.path_to_features):
raise ValueError("`path_to_features` must be an existing file.")
elif isinstance(self.path_to_features, dict):
for key in self.path_to_features.keys():
if not path.isfile(self.path_to_features[key]):
raise ValueError(f"path for '{key}' does not exist.")
if self.path_to_features is not None:
if isinstance(self.path_to_features, str):
if not path.isfile(self.path_to_features):
raise ValueError("`path_to_features` must be an existing file.")
elif isinstance(self.path_to_features, dict):
for key in self.path_to_features.keys():
if not path.isfile(self.path_to_features[key]):
raise ValueError(f"path for '{key}' does not exist.")
else:
raise ValueError("`path_to_features` must be a str or dict.")
else:
raise ValueError("`path_to_features` must be a str or dict.")

if self.features_query is None:
raise ValueError("Must provide either features file or MongoDB query.")
if isinstance(self.pretrained_model_path, str):
if not path.isfile(self.pretrained_model_path):
raise ValueError("`pretrained_model_path` must be an existing file.")
Expand Down

0 comments on commit c90b50e

Please sign in to comment.