diff --git a/pybalance/propensity/matcher.py b/pybalance/propensity/matcher.py index 260ea56..a0f7bdf 100644 --- a/pybalance/propensity/matcher.py +++ b/pybalance/propensity/matcher.py @@ -108,6 +108,9 @@ def __init__( self.hyperparam_space = self.DEFAULT_HYPERPARAM_SPACE self.verbose = verbose + if caliper is not None and caliper <= 0: + raise ValueError("Caliper, if defined, must be greater than 0.") + self.matching_data = matching_data.copy() self.target, self.pool = split_target_pool(matching_data) if isinstance(objective, str): @@ -172,10 +175,6 @@ def match(self) -> MatchingData: X, y = self._preprocess_data_for_sklearn(self.matching_data) hyperparams = self._get_hyperparams(self.max_iter) for i, (model, params) in enumerate(hyperparams): - if (time.time() - t0) > self.time_limit: - logger.warning("Time limit exceeded. Stopping early.") - break - clf = model(**params) logger.info( f'Training model {str(clf).split("(")[0]} (iter {i + 1}/{self.max_iter}, {(time.time() - t0)/60:.3f} min) ...' @@ -194,7 +193,6 @@ def match(self) -> MatchingData: headers=self.matching_data.headers, population_col=self.matching_data.population_col, ) - score = self.balance_calculator.distance(pool) if score < self.best_score: @@ -203,6 +201,11 @@ def match(self) -> MatchingData: clf, params, match, score, ps_pool, ps_target, solution_time ) + # Put break at the end to ensure at least one iteration + if (time.time() - t0) > self.time_limit: + logger.warning("Time limit exceeded. Stopping early.") + break + return self.get_best_match() def _train_preprocessors(self, matching_data): @@ -228,6 +231,7 @@ def _preprocess_data_for_sklearn(self, matching_data): def _get_hyperparams(self, n_iter): hyperparams = [] + n_iter += 1 n_models = len(self.hyperparam_space) for model, params in self.hyperparam_space.items(): hyperparams.extend( @@ -239,7 +243,7 @@ def _get_hyperparams(self, n_iter): np.random.shuffle(hyperparams) - return hyperparams + return hyperparams[:n_iter] def get_propensity_score( self, clf: BaseEstimator = None, matching_data: MatchingData = None