From 5d5d2494e98dcce8169ac081322799ee82f3b255 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Tue, 26 Nov 2024 14:34:46 -0800
Subject: [PATCH] Implement configure_runner, configure_metric (#3104)

Summary:
configure_runner and configure_metric allow users to attach custom Runners and
Metrics to their experiment.

configure_runner is fairly straightforward: it just sets experiment.runner.

configure_metric is more involved: given a list of IMetrics, it iterates through
them and tries to find a metric with the same name somewhere on the experiment.
In order, it checks the Objective (single, MOO, or scalarized), then the outcome
constraints, then the tracking metrics. If no metric with a matching name is
found, the provided metric is added as a tracking metric.
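
As a rough sketch of the intended call pattern (illustrative only; MyRunner and
MyMetric stand in for user-defined IRunner/IMetric subclasses, and the import
paths shown are assumed rather than taken from this diff):

    from ax.preview.api.client import Client
    from ax.preview.api.configs import (
        ExperimentConfig,
        ParameterType,
        RangeParameterConfig,
    )

    client = Client()
    client.configure_experiment(
        experiment_config=ExperimentConfig(
            parameters=[
                RangeParameterConfig(
                    name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
                )
            ],
            name="foo",
        )
    )
    client.configure_optimization(objective="custom")

    # Attach a custom deployment mechanism.
    client.configure_runner(runner=MyRunner())

    # "custom" matches the objective metric's name, so that metric is replaced
    # in place; a metric whose name matches nothing on the experiment would be
    # added as a tracking metric instead.
    client.configure_metrics(metrics=[MyMetric(name="custom")])
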
Reviewed By: lena-kashtelyan

Differential Revision: D66305614
---
 ax/preview/api/client.py            |  80 ++++++++++++++-
 ax/preview/api/tests/test_client.py | 148 +++++++++++++++++++++++++++-
 2 files changed, 222 insertions(+), 6 deletions(-)

diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py
index d513a62e73b..37fd2f55a7c 100644
--- a/ax/preview/api/client.py
+++ b/ax/preview/api/client.py
@@ -14,6 +14,7 @@
 from ax.core.experiment import Experiment
 from ax.core.generation_strategy_interface import GenerationStrategyInterface
 from ax.core.metric import Metric
+from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
 from ax.core.optimization_config import OptimizationConfig
 from ax.core.runner import Runner
 from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
@@ -167,7 +168,7 @@ def configure_runner(self, runner: IRunner) -> None:
 
         Saves to database on completion if db_config is present.
         """
-        ...
+        self.set_runner(runner=runner)
 
     def configure_metrics(self, metrics: Sequence[IMetric]) -> None:
         """
@@ -175,7 +176,7 @@ def configure_metrics(self, metrics: Sequence[IMetric]) -> None:
         replaces it with the Metric provided, or adds the Metric provided to the
         Experiment as tracking metrics.
         """
-        ...
+        self.set_metrics(metrics=metrics)
 
     # -------------------- Section 1.2: Set (not API) -------------------------------
     def set_experiment(self, experiment: Experiment) -> None:
@@ -224,6 +225,10 @@ def set_generation_strategy(
         """
         self._generation_strategy = generation_strategy
 
+        none_throws(
+            self._generation_strategy
+        )._experiment = self._none_throws_experiment()
+
         if self.db_config is not None:
             # TODO[mpolson64] Save to database
             ...
@@ -257,7 +262,11 @@ def set_runner(self, runner: Runner) -> None:
 
         Saves to database on completion if db_config is present.
         """
-        ...
+        self._none_throws_experiment().runner = runner
+
+        if self.db_config is not None:
+            # TODO[mpolson64] Save to database
+            ...
 
     def set_metrics(self, metrics: Sequence[Metric]) -> None:
         """
@@ -269,7 +278,16 @@ def set_metrics(self, metrics: Sequence[Metric]) -> None:
         replaces it with the Metric provided, or adds the Metric provided to the
         Experiment as tracking metrics.
         """
-        ...
+        # If an equivalently named Metric already exists on the Experiment, replace it
+        # with the Metric provided. Otherwise, add the Metric to the Experiment as a
+        # tracking metric.
+        for metric in metrics:
+            # Check the optimization config first
+            self._overwrite_metric(metric=metric)
+
+        if self.db_config is not None:
+            # TODO[mpolson64] Save to database
+            ...
 
     # -------------------- Section 2. Conduct Experiment ----------------------------
     def get_next_trials(
@@ -512,3 +530,57 @@ def _none_throws_experiment(self) -> Experiment:
                 "experiment before utilizing any other methods on the Client."
             ),
         )
+
+    def _overwrite_metric(self, metric: Metric) -> None:
+        """
+        Overwrite an existing Metric on the Experiment with the provided Metric if they
+        share the same name. If no Metric with the same name exists, add the Metric as
+        a tracking metric.
+        """
+
+        # Check the OptimizationConfig first
+        if (
+            optimization_config := self._none_throws_experiment().optimization_config
+        ) is not None:
+            # Check the objective
+            if isinstance(
+                multi_objective := optimization_config.objective, MultiObjective
+            ):
+                for i in range(len(multi_objective.objectives)):
+                    if metric.name == multi_objective.objectives[i].metric.name:
+                        multi_objective._objectives[i]._metric = metric
+                        return
+
+            if isinstance(
+                scalarized_objective := optimization_config.objective,
+                ScalarizedObjective,
+            ):
+                for i in range(len(scalarized_objective.metrics)):
+                    if metric.name == scalarized_objective.metrics[i].name:
+                        scalarized_objective._metrics[i] = metric
+                        return
+
+            if (
+                isinstance(optimization_config.objective, Objective)
+                and metric.name == optimization_config.objective.metric.name
+            ):
+                optimization_config.objective._metric = metric
+                return
+
+            # Check the outcome constraints
+            for i in range(len(optimization_config.outcome_constraints)):
+                if (
+                    metric.name
+                    == optimization_config.outcome_constraints[i].metric.name
+                ):
+                    optimization_config._outcome_constraints[i]._metric = metric
+                    return
+
+        # Check the tracking metrics
+        if metric.name in self._none_throws_experiment()._tracking_metrics:
+            self._none_throws_experiment()._tracking_metrics[metric.name] = metric
+            return
+
+        # If an equivalently named Metric does not exist, add it as a tracking
+        # metric.
+        self._none_throws_experiment().add_tracking_metric(metric=metric)
diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py
index 11960a53365..09a16ab77ce 100644
--- a/ax/preview/api/tests/test_client.py
+++ b/ax/preview/api/tests/test_client.py
@@ -5,9 +5,13 @@
 
 # pyre-strict
 
+from typing import Any, Mapping
+
+from ax.core.base_trial import TrialStatus
+
 from ax.core.experiment import Experiment
 from ax.core.metric import Metric
-from ax.core.objective import Objective
+from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
 from ax.core.optimization_config import OptimizationConfig
 from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
 from ax.core.parameter import (
@@ -25,6 +29,9 @@
     ParameterType,
     RangeParameterConfig,
 )
+from ax.preview.api.protocols.metric import IMetric
+from ax.preview.api.protocols.runner import IRunner
+from ax.preview.api.types import TParameterization
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.core_stubs import (
     get_branin_experiment,
@@ -32,7 +39,7 @@
     get_percentile_early_stopping_strategy,
 )
 from ax.utils.testing.modeling_stubs import get_generation_strategy
-from pyre_extensions import none_throws
+from pyre_extensions import assert_is_instance, none_throws, override
 
 
 class TestClient(TestCase):
@@ -161,6 +168,118 @@ def test_configure_optimization(self) -> None:
             outcome_constraints=["qps >= 0"],
         )
 
+    def test_configure_runner(self) -> None:
+        client = Client()
+        runner = DummyRunner()
+
+        with self.assertRaisesRegex(AssertionError, "Experiment not set"):
+            client.configure_runner(runner=runner)
+
+        client.set_experiment(experiment=get_branin_experiment())
+        client.configure_runner(runner=runner)
+
+        self.assertEqual(none_throws(client._experiment).runner, runner)
+
+    def test_configure_metric(self) -> None:
+        client = Client()
+        custom_metric = DummyMetric(name="custom")
+
+        with self.assertRaisesRegex(AssertionError, "Experiment not set"):
+            client.configure_metrics(metrics=[custom_metric])
+
+        client.configure_experiment(
+            experiment_config=ExperimentConfig(
+                parameters=[
+                    RangeParameterConfig(
+                        name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
+                    )
+                ],
+                name="foo",
+            )
+        )
+
+        # Test replacing a single objective
+        client.configure_optimization(objective="custom")
+        client.configure_metrics(metrics=[custom_metric])
+
+        self.assertEqual(
+            custom_metric,
+            none_throws(
+                none_throws(client._experiment).optimization_config
+            ).objective.metric,
+        )
+
+        # Test replacing a multi-objective
+        client.configure_optimization(objective="custom, foo")
+        client.configure_metrics(metrics=[custom_metric])
+
+        self.assertIn(
+            custom_metric,
+            assert_is_instance(
+                none_throws(
+                    none_throws(client._experiment).optimization_config
+                ).objective,
+                MultiObjective,
+            ).metrics,
+        )
+        # Test replacing a scalarized objective
+        client.configure_optimization(objective="custom + foo")
+        client.configure_metrics(metrics=[custom_metric])
+
+        self.assertIn(
+            custom_metric,
+            assert_is_instance(
+                none_throws(
+                    none_throws(client._experiment).optimization_config
+                ).objective,
+                ScalarizedObjective,
+            ).metrics,
+        )
+
+        # Test replacing an outcome constraint
+        client.configure_optimization(
+            objective="foo", outcome_constraints=["custom >= 0"]
+        )
+        client.configure_metrics(metrics=[custom_metric])
+
+        self.assertEqual(
+            custom_metric,
+            none_throws(none_throws(client._experiment).optimization_config)
+            .outcome_constraints[0]
+            .metric,
+        )
+
+        # Test replacing a tracking metric
+        client.configure_optimization(
+            objective="foo",
+        )
+        none_throws(client._experiment).add_tracking_metric(metric=Metric("custom"))
+        client.configure_metrics(metrics=[custom_metric])
+
+        self.assertEqual(
+            custom_metric,
+            none_throws(client._experiment).tracking_metrics[0],
+        )
+
+        # Test adding a tracking metric
+        client = Client()  # Start a fresh Client
+        client.configure_experiment(
+            experiment_config=ExperimentConfig(
+                parameters=[
+                    RangeParameterConfig(
+                        name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
+                    )
+                ],
+                name="foo",
+            )
+        )
+        client.configure_metrics(metrics=[custom_metric])
+
+        self.assertEqual(
+            custom_metric,
+            none_throws(client._experiment).tracking_metrics[0],
+        )
+
     def test_set_experiment(self) -> None:
         client = Client()
         experiment = get_branin_experiment()
@@ -202,3 +321,28 @@ def test_set_early_stopping_strategy(self) -> None:
             early_stopping_strategy=early_stopping_strategy
         )
         self.assertEqual(client._early_stopping_strategy, early_stopping_strategy)
+
+
+class DummyRunner(IRunner):
+    @override
+    def run_trial(
+        self, trial_index: int, parameterization: TParameterization
+    ) -> dict[str, Any]: ...
+
+    @override
+    def poll_trial(
+        self, trial_index: int, trial_metadata: Mapping[str, Any]
+    ) -> TrialStatus: ...
+
+    @override
+    def stop_trial(
+        self, trial_index: int, trial_metadata: Mapping[str, Any]
+    ) -> dict[str, Any]: ...
+
+
+class DummyMetric(IMetric):
+    def fetch(
+        self,
+        trial_index: int,
+        trial_metadata: Mapping[str, Any],
+    ) -> tuple[int, float | tuple[float, float]]: ...
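
For context, a minimal sketch of what concrete IRunner and IMetric
implementations might look like, mirroring the signatures stubbed out by
DummyRunner and DummyMetric in the tests above. This is illustrative only and
not part of the patch; the class names and the trivial behaviors are
hypothetical.

    from typing import Any, Mapping

    from ax.core.base_trial import TrialStatus
    from ax.preview.api.protocols.metric import IMetric
    from ax.preview.api.protocols.runner import IRunner
    from ax.preview.api.types import TParameterization


    class EchoRunner(IRunner):
        # Hypothetical runner: "deployment" just echoes the parameterization
        # back as trial metadata.
        def run_trial(
            self, trial_index: int, parameterization: TParameterization
        ) -> dict[str, Any]:
            return {"parameterization": dict(parameterization)}

        def poll_trial(
            self, trial_index: int, trial_metadata: Mapping[str, Any]
        ) -> TrialStatus:
            # Every trial "completes" immediately in this toy example.
            return TrialStatus.COMPLETED

        def stop_trial(
            self, trial_index: int, trial_metadata: Mapping[str, Any]
        ) -> dict[str, Any]:
            return {"stopped": True}


    class ConstantMetric(IMetric):
        # Hypothetical metric that always reports the same value.
        def fetch(
            self,
            trial_index: int,
            trial_metadata: Mapping[str, Any],
        ) -> tuple[int, float | tuple[float, float]]:
            # Returns (progression, value); the second element may also be a
            # (mean, SEM) tuple per the annotated return type.
            return 0, 42.0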