diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 2fbbcd2c3a1..d513a62e73b 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -16,7 +16,6 @@ from ax.core.metric import Metric from ax.core.optimization_config import OptimizationConfig from ax.core.runner import Runner -from ax.core.search_space import SearchSpace from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.exceptions.core import UnsupportedError from ax.modelbridge.dispatch_utils import choose_generation_strategy @@ -41,6 +40,7 @@ class Client: _experiment: Experiment | None = None _generation_strategy: GenerationStrategyInterface | None = None + _early_stopping_strategy: BaseEarlyStoppingStrategy | None = None def __init__( self, @@ -180,51 +180,79 @@ def configure_metrics(self, metrics: Sequence[IMetric]) -> None: # -------------------- Section 1.2: Set (not API) ------------------------------- def set_experiment(self, experiment: Experiment) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing Experiment with the provided Experiment. Saves to database on completion if db_config is present. """ - ... - - def set_search_space(self, search_space: SearchSpace) -> None: - """ - Overwrite the existing SearchSpace with the provided SearchSpace. + self._experiment = experiment - Saves to database on completion if db_config is present. - """ - ... + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_optimization_config(self, optimization_config: OptimizationConfig) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing OptimizationConfig with the provided OptimizationConfig. Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().optimization_config = optimization_config + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_generation_strategy( self, generation_strategy: GenerationStrategyInterface ) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing GenerationStrategy with the provided GenerationStrategy. Saves to database on completion if db_config is present. """ - ... + self._generation_strategy = generation_strategy + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_early_stopping_strategy( self, early_stopping_strategy: BaseEarlyStoppingStrategy ) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing EarlyStoppingStrategy with the provided EarlyStoppingStrategy. Saves to database on completion if db_config is present. """ - ... + self._early_stopping_strategy = early_stopping_strategy + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_runner(self, runner: Runner) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Attaches a Runner to the Experiment. Saves to database on completion if db_config is present. @@ -233,6 +261,10 @@ def set_runner(self, runner: Runner) -> None: def set_metrics(self, metrics: Sequence[Metric]) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Finds equivallently named Metric that already exists on the Experiment and replaces it with the Metric provided, or adds the Metric provided to the Experiment as tracking metrics. diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 4767a57cb3c..11960a53365 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -26,6 +26,12 @@ RangeParameterConfig, ) from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_branin_optimization_config, + get_percentile_early_stopping_strategy, +) +from ax.utils.testing.modeling_stubs import get_generation_strategy from pyre_extensions import none_throws @@ -154,3 +160,45 @@ def test_configure_optimization(self) -> None: objective="ne", outcome_constraints=["qps >= 0"], ) + + def test_set_experiment(self) -> None: + client = Client() + experiment = get_branin_experiment() + + client.set_experiment(experiment=experiment) + + self.assertEqual(client._experiment, experiment) + + def test_set_optimization_config(self) -> None: + client = Client() + optimization_config = get_branin_optimization_config() + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + client.set_optimization_config(optimization_config=optimization_config) + + client.set_experiment(experiment=get_branin_experiment()) + client.set_optimization_config( + optimization_config=optimization_config, + ) + + self.assertEqual( + none_throws(client._experiment).optimization_config, optimization_config + ) + + def test_set_generation_strategy(self) -> None: + client = Client() + client.set_experiment(experiment=get_branin_experiment()) + + generation_strategy = get_generation_strategy() + + client.set_generation_strategy(generation_strategy=generation_strategy) + self.assertEqual(client._generation_strategy, generation_strategy) + + def test_set_early_stopping_strategy(self) -> None: + client = Client() + early_stopping_strategy = get_percentile_early_stopping_strategy() + + client.set_early_stopping_strategy( + early_stopping_strategy=early_stopping_strategy + ) + self.assertEqual(client._early_stopping_strategy, early_stopping_strategy)