Skip to content

Commit

Permalink
Implement trial status marking (facebook#3119)
Browse files Browse the repository at this point in the history
Summary:

As titled. These methods are super thin wrappers around the existing methods in core Ax. A user needs these 3 methods to indicate something has gone wrong with a trial and will use them in conjunction with complete_trial (which includes  optional data attaching) to have full control over trial status.

Differential Revision: D66507188
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Nov 26, 2024
1 parent f63442b commit de7d041
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 3 deletions.
18 changes: 15 additions & 3 deletions ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,11 @@ def mark_trial_failed(self, trial_index: int) -> None:
Saves to database on completion if db_config is present.
"""
...
self._none_throws_experiment().trials[trial_index].mark_failed()

if self._db_config is not None:
# TODO[mpolson64] Save to database
...

def mark_trial_abandoned(self, trial_index: int) -> None:
"""
Expand All @@ -476,7 +480,11 @@ def mark_trial_abandoned(self, trial_index: int) -> None:
Saves to database on completion if db_config is present.
"""
...
self._none_throws_experiment().trials[trial_index].mark_abandoned()

if self._db_config is not None:
# TODO[mpolson64] Save to database
...

def mark_trial_early_stopped(self, trial_index: int) -> None:
"""
Expand All @@ -486,7 +494,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None:
Saves to database on completion if db_config is present.
"""
...
self._none_throws_experiment().trials[trial_index].mark_early_stopped()

if self._db_config is not None:
# TODO[mpolson64] Save to database
...

def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None:
"""
Expand Down
66 changes: 66 additions & 0 deletions ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,72 @@ def test_get_next_trials(self) -> None:
trials = client.get_next_trials(maximum_trials=2)
self.assertEqual(len(trials), 1)

def test_mark_trial_failed(self) -> None:
client = Client()

client.configure_experiment(
ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo")

trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0]
client.mark_trial_failed(trial_index=trial_index)
self.assertEqual(
none_throws(client._experiment).trials[trial_index].status,
TrialStatus.FAILED,
)

def test_mark_trial_abandoned(self) -> None:
client = Client()

client.configure_experiment(
ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo")

trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0]
client.mark_trial_abandoned(trial_index=trial_index)
self.assertEqual(
none_throws(client._experiment).trials[trial_index].status,
TrialStatus.ABANDONED,
)

def test_mark_trial_early_stopped(self) -> None:
client = Client()

client.configure_experiment(
ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo")

trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0]
client.mark_trial_early_stopped(trial_index=trial_index)
self.assertEqual(
none_throws(client._experiment).trials[trial_index].status,
TrialStatus.EARLY_STOPPED,
)


class DummyRunner(IRunner):
@override
Expand Down

0 comments on commit de7d041

Please sign in to comment.