diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index b212e3494c0..fa03192d987 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -467,7 +467,11 @@ def mark_trial_failed(self, trial_index: int) -> None: Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().trials[trial_index].mark_failed() + + if self._db_config is not None: + # TODO[mpolson64] Save to database + ... def mark_trial_abandoned(self, trial_index: int) -> None: """ @@ -476,7 +480,11 @@ def mark_trial_abandoned(self, trial_index: int) -> None: Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().trials[trial_index].mark_abandoned() + + if self._db_config is not None: + # TODO[mpolson64] Save to database + ... def mark_trial_early_stopped(self, trial_index: int) -> None: """ @@ -486,7 +494,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None: Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().trials[trial_index].mark_early_stopped() + + if self._db_config is not None: + # TODO[mpolson64] Save to database + ... def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: """ diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index b747c6d4916..116ee32be86 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -379,6 +379,72 @@ def test_get_next_trials(self) -> None: trials = client.get_next_trials(maximum_trials=2) self.assertEqual(len(trials), 1) + def test_mark_trial_failed(self) -> None: + client = Client() + + client.configure_experiment( + ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1) + ), + ], + name="foo", + ) + ) + client.configure_optimization(objective="foo") + + trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0] + client.mark_trial_failed(trial_index=trial_index) + self.assertEqual( + none_throws(client._experiment).trials[trial_index].status, + TrialStatus.FAILED, + ) + + def test_mark_trial_abandoned(self) -> None: + client = Client() + + client.configure_experiment( + ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1) + ), + ], + name="foo", + ) + ) + client.configure_optimization(objective="foo") + + trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0] + client.mark_trial_abandoned(trial_index=trial_index) + self.assertEqual( + none_throws(client._experiment).trials[trial_index].status, + TrialStatus.ABANDONED, + ) + + def test_mark_trial_early_stopped(self) -> None: + client = Client() + + client.configure_experiment( + ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1) + ), + ], + name="foo", + ) + ) + client.configure_optimization(objective="foo") + + trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0] + client.mark_trial_early_stopped(trial_index=trial_index) + self.assertEqual( + none_throws(client._experiment).trials[trial_index].status, + TrialStatus.EARLY_STOPPED, + ) + class DummyRunner(IRunner): @override