Skip to content

Commit

Permalink
fix: use `stixel_index` (not a set literal) as the `stixel_specific_x_names` key, iterate feature-importance calculation per ensemble with lazy-loading dump, and clarify lazyloading error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Oct 26, 2024
1 parent cf079e6 commit 55e7dfe
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 34 deletions.
68 changes: 36 additions & 32 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,39 +1054,43 @@ def calculate_feature_importances(self):
"""
# generate feature importance dict
feature_importance_list = []

for index, ensemble_row in self.ensemble_df[
self.ensemble_df["stixel_checklist_count"] >= self.stixel_training_size_threshold
].iterrows():
if ensemble_row["stixel_checklist_count"] < self.stixel_training_size_threshold:
continue

try:
stixel_index = ensemble_row["unique_stixel_id"]
the_model = self.model_dict[f"{stixel_index}_model"]
x_names = self.stixel_specific_x_names[{stixel_index}]

if isinstance(the_model, dummy_model1):
importance_dict = dict(zip(self.x_names, [1 / len(self.x_names)] * len(self.x_names)))
elif isinstance(the_model, Hurdle):
if "feature_importances_" in the_model.__dir__():
importance_dict = dict(zip(x_names, the_model.feature_importances_))
else:
if isinstance(the_model.classifier, dummy_model1):
importance_dict = dict(zip(self.x_names, [1 / len(self.x_names)] * len(self.x_names)))

for ensemble_id in self.ensemble_df['ensemble_index'].unique():
for index, ensemble_row in self.ensemble_df[self.ensemble_df['ensemble_index']==ensemble_id][
self.ensemble_df["stixel_checklist_count"] >= self.stixel_training_size_threshold
].iterrows():
if ensemble_row["stixel_checklist_count"] < self.stixel_training_size_threshold:
continue

Check warning on line 1063 in stemflow/model/AdaSTEM.py

View check run for this annotation

Codecov / codecov/patch

stemflow/model/AdaSTEM.py#L1063

Added line #L1063 was not covered by tests

try:
stixel_index = ensemble_row["unique_stixel_id"]
the_model = self.model_dict[f"{stixel_index}_model"]
x_names = self.stixel_specific_x_names[stixel_index]

if isinstance(the_model, dummy_model1):
importance_dict = dict(zip(self.x_names, [1 / len(self.x_names)] * len(self.x_names)))
elif isinstance(the_model, Hurdle):
if "feature_importances_" in the_model.__dir__():
importance_dict = dict(zip(x_names, the_model.feature_importances_))
else:
importance_dict = dict(zip(x_names, the_model.classifier.feature_importances_))
else:
importance_dict = dict(zip(x_names, the_model.feature_importances_))

importance_dict["stixel_index"] = stixel_index
feature_importance_list.append(importance_dict)
if isinstance(the_model.classifier, dummy_model1):
importance_dict = dict(zip(self.x_names, [1 / len(self.x_names)] * len(self.x_names)))
else:
importance_dict = dict(zip(x_names, the_model.classifier.feature_importances_))
else:
importance_dict = dict(zip(x_names, the_model.feature_importances_))

except Exception as e:
warnings.warn(f"{e}")
# print(e)
continue
importance_dict["stixel_index"] = stixel_index
feature_importance_list.append(importance_dict)

except Exception as e:
warnings.warn(f"{e}")
# print(e)
continue

if self.lazy_loading:
self.model_dict.dump_ensemble(ensemble_id)

self.feature_importances_ = (
pd.DataFrame(feature_importance_list).set_index("stixel_index").reset_index(drop=False).fillna(0)
)
Expand Down Expand Up @@ -1161,7 +1165,7 @@ def assign_feature_importances_by_points(
for var_name in [self.Spatio1, self.Spatio2, self.Temporal1]:
if var_name not in Sample_ST_df.columns:
raise KeyError(f"{var_name} not found in Sample_ST_df.columns")

partial_assign_func = partial(
assign_function,
ensemble_df=self.ensemble_df,
Expand All @@ -1171,7 +1175,7 @@ def assign_feature_importances_by_points(
Spatio2=self.Spatio2,
feature_importances_=self.feature_importances_,
)

# assign input spatio-temporal points to stixels
if n_jobs > 1:
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator")

Check warning on line 1181 in stemflow/model/AdaSTEM.py

View check run for this annotation

Codecov / codecov/patch

stemflow/model/AdaSTEM.py#L1181

Added line #L1181 was not covered by tests
Expand Down
5 changes: 3 additions & 2 deletions stemflow/utils/lazyloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def load_ensemble(self, ensemble_id, force=False):
if ((not force) and (ensemble_id not in self.ensemble_models)) or force:
ensemble_path = os.path.join(self.directory, f"ensemble_{ensemble_id}_dict.pkl")
if not os.path.exists(ensemble_path):
raise FileNotFoundError(f"Ensemble file for ID {ensemble_id} not found.")
raise FileNotFoundError(f"Ensemble file for ID {ensemble_id} not found at {ensemble_path}.")

Check warning on line 173 in stemflow/utils/lazyloading.py

View check run for this annotation

Codecov / codecov/patch

stemflow/utils/lazyloading.py#L173

Added line #L173 was not covered by tests

loaded_ensemble = joblib.load(ensemble_path)
if ensemble_id in self.ensemble_models:
Expand All @@ -192,8 +192,9 @@ def delete_ensemble(self, ensemble_id):
for key in self.ensemble_models[ensemble_id]:
del self.key_to_ensemble[key]
del self.ensemble_models[ensemble_id]

Check warning on line 194 in stemflow/utils/lazyloading.py

View check run for this annotation

Codecov / codecov/patch

stemflow/utils/lazyloading.py#L192-L194

Added lines #L192 - L194 were not covered by tests

ensemble_path = os.path.join(self.directory, f"ensemble_{ensemble_id}_dict.pkl")
if os.path.exists(ensemble_path):
os.remove(ensemble_path)

Check warning on line 198 in stemflow/utils/lazyloading.py

View check run for this annotation

Codecov / codecov/patch

stemflow/utils/lazyloading.py#L196-L198

Added lines #L196 - L198 were not covered by tests
else:
raise ValueError(f'Ensemble {ensemble_id} does not exist.')
raise ValueError(f'Ensemble {ensemble_id} not found on disk.')

Check warning on line 200 in stemflow/utils/lazyloading.py

View check run for this annotation

Codecov / codecov/patch

stemflow/utils/lazyloading.py#L200

Added line #L200 was not covered by tests

0 comments on commit 55e7dfe

Please sign in to comment.