Skip to content

Commit

Permalink
fix bug in PS matching
Browse files Browse the repository at this point in the history
  • Loading branch information
sprivite committed Jul 3, 2024
1 parent 0abb69b commit c9bdc2c
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pybalance/propensity/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def __init__(
self.hyperparam_space = self.DEFAULT_HYPERPARAM_SPACE
self.verbose = verbose

if caliper is not None and caliper <= 0:
raise ValueError("Caliper, if defined, must be greater than 0.")

self.matching_data = matching_data.copy()
self.target, self.pool = split_target_pool(matching_data)
if isinstance(objective, str):
Expand Down Expand Up @@ -172,10 +175,6 @@ def match(self) -> MatchingData:
X, y = self._preprocess_data_for_sklearn(self.matching_data)
hyperparams = self._get_hyperparams(self.max_iter)
for i, (model, params) in enumerate(hyperparams):
if (time.time() - t0) > self.time_limit:
logger.warning("Time limit exceeded. Stopping early.")
break

clf = model(**params)
logger.info(
f'Training model {str(clf).split("(")[0]} (iter {i + 1}/{self.max_iter}, {(time.time() - t0)/60:.3f} min) ...'
Expand All @@ -194,7 +193,6 @@ def match(self) -> MatchingData:
headers=self.matching_data.headers,
population_col=self.matching_data.population_col,
)

score = self.balance_calculator.distance(pool)

if score < self.best_score:
Expand All @@ -203,6 +201,11 @@ def match(self) -> MatchingData:
clf, params, match, score, ps_pool, ps_target, solution_time
)

# Check the time limit at the end of the loop body (not the start) so that at least one iteration always completes
if (time.time() - t0) > self.time_limit:
logger.warning("Time limit exceeded. Stopping early.")
break

return self.get_best_match()

def _train_preprocessors(self, matching_data):
Expand All @@ -228,6 +231,7 @@ def _preprocess_data_for_sklearn(self, matching_data):

def _get_hyperparams(self, n_iter):
hyperparams = []
n_iter += 1
n_models = len(self.hyperparam_space)
for model, params in self.hyperparam_space.items():
hyperparams.extend(
Expand All @@ -239,7 +243,7 @@ def _get_hyperparams(self, n_iter):

np.random.shuffle(hyperparams)

return hyperparams
return hyperparams[:n_iter]

def get_propensity_score(
self, clf: BaseEstimator = None, matching_data: MatchingData = None
Expand Down

0 comments on commit c9bdc2c

Please sign in to comment.