diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 65f9bbf0a60..fc451be2c7f 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - +from logging import Logger from typing import Sequence from ax.analysis.analysis import Analysis, AnalysisCard # Used as a return type @@ -14,6 +14,7 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.metric import Metric +from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import OptimizationConfig from ax.core.runner import Runner from ax.early_stopping.strategies import BaseEarlyStoppingStrategy @@ -33,9 +34,12 @@ from ax.preview.api.utils.instantiation.from_string import ( optimization_config_from_string, ) +from ax.utils.common.logger import get_logger from pyre_extensions import none_throws from typing_extensions import Self +logger: Logger = get_logger(__name__) + class Client: _experiment: Experiment | None = None @@ -168,15 +172,16 @@ def configure_runner(self, runner: IRunner) -> None: Saves to database on completion if db_config is present. """ - ... + self.set_runner(runner=runner) def configure_metrics(self, metrics: Sequence[IMetric]) -> None: """ - Finds equivallently named Metric that already exists on the Experiment and - replaces it with the Metric provided, or adds the Metric provided to the - Experiment as tracking metrics. + Attach a class with logic for autmating fetching of a given metric by + replacing its instance with the provided Metric from metrics sequence input, + or adds the Metric provided to the Experiment as a tracking metric if that + metric was not already present. """ - ... + self.set_metrics(metrics=metrics) # -------------------- Section 1.2: Set (not API) ------------------------------- def set_experiment(self, experiment: Experiment) -> None: @@ -228,6 +233,10 @@ def set_generation_strategy( self._generation_strategy )._experiment = self._none_throws_experiment() + none_throws( + self._generation_strategy + )._experiment = self._none_throws_experiment() + if self.db_config is not None: # TODO[mpolson64] Save to database ... @@ -261,7 +270,11 @@ def set_runner(self, runner: Runner) -> None: Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().runner = runner + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_metrics(self, metrics: Sequence[Metric]) -> None: """ @@ -269,11 +282,21 @@ def set_metrics(self, metrics: Sequence[Metric]) -> None: method signature stability) for the convenience of some developers, power users, and partners. - Finds equivallently named Metric that already exists on the Experiment and - replaces it with the Metric provided, or adds the Metric provided to the - Experiment as tracking metrics. + Attach a class with logic for autmating fetching of a given metric by + replacing its instance with the provided Metric from metrics sequence input, + or adds the Metric provided to the Experiment as a tracking metric if that + metric was not already present. """ - ... + # If an equivalently named Metric already exists on the Experiment, replace it + # with the Metric provided. Otherwise, add the Metric to the Experiment as a + # tracking metric. + for metric in metrics: + # Check the optimization config first + self._overwrite_metric(metric=metric) + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... # -------------------- Section 2. Conduct Experiment ---------------------------- def get_next_trials( @@ -516,3 +539,65 @@ def _none_throws_experiment(self) -> Experiment: "experiment before utilizing any other methods on the Client." ), ) + + def _overwrite_metric(self, metric: Metric) -> None: + """ + Overwrite an existing Metric on the Experiment with the provided Metric if they + share the same name. If not Metric with the same name exists, add the Metric as + a tracking metric. + + Note that this method does not save the Experiment to the database (this is + handled in self.set_metrics). + """ + + # Check the OptimizationConfig first + if ( + optimization_config := self._none_throws_experiment().optimization_config + ) is not None: + # Check the objective + if isinstance( + multi_objective := optimization_config.objective, MultiObjective + ): + for i in range(len(multi_objective.objectives)): + if metric.name == multi_objective.objectives[i].metric.name: + multi_objective._objectives[i]._metric = metric + return + + if isinstance( + scalarized_objective := optimization_config.objective, + ScalarizedObjective, + ): + for i in range(len(scalarized_objective.metrics)): + if metric.name == scalarized_objective.metrics[i].name: + scalarized_objective._metrics[i] = metric + return + + if ( + isinstance(optimization_config.objective, Objective) + and metric.name == optimization_config.objective.metric.name + ): + optimization_config.objective._metric = metric + return + + # Check the outcome constraints + for i in range(len(optimization_config.outcome_constraints)): + if ( + metric.name + == optimization_config.outcome_constraints[i].metric.name + ): + optimization_config._outcome_constraints[i]._metric = metric + return + + # Check the tracking metrics + tracking_metric_names = self._none_throws_experiment()._tracking_metrics.keys() + if metric.name in tracking_metric_names: + self._none_throws_experiment()._tracking_metrics[metric.name] = metric + return + + # If an equivalently named Metric does not exist, add it as a tracking + # metric. + self._none_throws_experiment().add_tracking_metric(metric=metric) + logger.warning( + f"Metric {metric} not found in optimization config, added as tracking " + "metric." + ) diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 4f25db7d069..ddb5e42399a 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -5,9 +5,13 @@ # pyre-strict +from typing import Any, Mapping + +from ax.core.base_trial import TrialStatus + from ax.core.experiment import Experiment from ax.core.metric import Metric -from ax.core.objective import Objective +from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint from ax.core.parameter import ( @@ -25,6 +29,9 @@ ParameterType, RangeParameterConfig, ) +from ax.preview.api.protocols.metric import IMetric +from ax.preview.api.protocols.runner import IRunner +from ax.preview.api.types import TParameterization from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, @@ -32,7 +39,7 @@ get_percentile_early_stopping_strategy, ) from ax.utils.testing.modeling_stubs import get_generation_strategy -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws, override class TestClient(TestCase): @@ -161,6 +168,118 @@ def test_configure_optimization(self) -> None: outcome_constraints=["qps >= 0"], ) + def test_configure_runner(self) -> None: + client = Client() + runner = DummyRunner() + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + client.configure_runner(runner=runner) + + client.set_experiment(experiment=get_branin_experiment()) + client.configure_runner(runner=runner) + + self.assertEqual(none_throws(client._experiment).runner, runner) + + def test_configure_metric(self) -> None: + client = Client() + custom_metric = DummyMetric(name="custom") + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + client.configure_metrics(metrics=[custom_metric]) + + client.configure_experiment( + experiment_config=ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1) + ) + ], + name="foo", + ) + ) + + # Test replacing a single objective + client.configure_optimization(objective="custom") + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws( + none_throws(client._experiment).optimization_config + ).objective.metric, + ) + + # Test replacing a multi-objective + client.configure_optimization(objective="custom, foo") + client.configure_metrics(metrics=[custom_metric]) + + self.assertIn( + custom_metric, + assert_is_instance( + none_throws( + none_throws(client._experiment).optimization_config + ).objective, + MultiObjective, + ).metrics, + ) + # Test replacing a scalarized objective + client.configure_optimization(objective="custom + foo") + client.configure_metrics(metrics=[custom_metric]) + + self.assertIn( + custom_metric, + assert_is_instance( + none_throws( + none_throws(client._experiment).optimization_config + ).objective, + ScalarizedObjective, + ).metrics, + ) + + # Test replacing an outcome constraint + client.configure_optimization( + objective="foo", outcome_constraints=["custom >= 0"] + ) + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws(none_throws(client._experiment).optimization_config) + .outcome_constraints[0] + .metric, + ) + + # Test replacing a tracking metric + client.configure_optimization( + objective="foo", + ) + none_throws(client._experiment).add_tracking_metric(metric=Metric("custom")) + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws(client._experiment).tracking_metrics[0], + ) + + # Test adding a tracking metric + client = Client() # Start a fresh Client + client.configure_experiment( + experiment_config=ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1) + ) + ], + name="foo", + ) + ) + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws(client._experiment).tracking_metrics[0], + ) + def test_set_experiment(self) -> None: client = Client() experiment = get_branin_experiment() @@ -202,3 +321,28 @@ def test_set_early_stopping_strategy(self) -> None: early_stopping_strategy=early_stopping_strategy ) self.assertEqual(client._early_stopping_strategy, early_stopping_strategy) + + +class DummyRunner(IRunner): + @override + def run_trial( + self, trial_index: int, parameterization: TParameterization + ) -> dict[str, Any]: ... + + @override + def poll_trial( + self, trial_index: int, trial_metadata: Mapping[str, Any] + ) -> TrialStatus: ... + + @override + def stop_trial( + self, trial_index: int, trial_metadata: Mapping[str, Any] + ) -> dict[str, Any]: ... + + +class DummyMetric(IMetric): + def fetch( + self, + trial_index: int, + trial_metadata: Mapping[str, Any], + ) -> tuple[int, float | tuple[float, float]]: ...