Skip to content

Commit

Permalink
Implement set_ methods (facebook#3101)
Browse files Browse the repository at this point in the history
Summary:

These methods are not strictly speaking "part of the API" but may be useful for developers and trusted partners. Each is fairly self explanatory.

Differential Revision: D66304352
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Nov 26, 2024
1 parent 5474fb9 commit e81faca
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 12 deletions.
56 changes: 44 additions & 12 deletions ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ax.core.metric import Metric
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.dispatch_utils import choose_generation_strategy
Expand All @@ -41,6 +40,7 @@
class Client:
_experiment: Experiment | None = None
_generation_strategy: GenerationStrategyInterface | None = None
_early_stopping_strategy: BaseEarlyStoppingStrategy | None = None

def __init__(
self,
Expand Down Expand Up @@ -180,51 +180,79 @@ def configure_metrics(self, metrics: Sequence[IMetric]) -> None:
# -------------------- Section 1.2: Set (not API) -------------------------------
def set_experiment(self, experiment: Experiment) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Overwrite the existing Experiment with the provided Experiment.
Saves to database on completion if db_config is present.
"""
...

def set_search_space(self, search_space: SearchSpace) -> None:
"""
Overwrite the existing SearchSpace with the provided SearchSpace.
self._experiment = experiment

Saves to database on completion if db_config is present.
"""
...
if self.db_config is not None:
# TODO[mpolson64] Save to database
...

def set_optimization_config(self, optimization_config: OptimizationConfig) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Overwrite the existing OptimizationConfig with the provided OptimizationConfig.
Saves to database on completion if db_config is present.
"""
...
self._none_throws_experiment().optimization_config = optimization_config

if self.db_config is not None:
# TODO[mpolson64] Save to database
...

def set_generation_strategy(
self, generation_strategy: GenerationStrategyInterface
) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Overwrite the existing GenerationStrategy with the provided GenerationStrategy.
Saves to database on completion if db_config is present.
"""
...
self._generation_strategy = generation_strategy

if self.db_config is not None:
# TODO[mpolson64] Save to database
...

def set_early_stopping_strategy(
self, early_stopping_strategy: BaseEarlyStoppingStrategy
) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Overwrite the existing EarlyStoppingStrategy with the provided
EarlyStoppingStrategy.
Saves to database on completion if db_config is present.
"""
...
self._early_stopping_strategy = early_stopping_strategy

if self.db_config is not None:
# TODO[mpolson64] Save to database
...

def set_runner(self, runner: Runner) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Attaches a Runner to the Experiment.
Saves to database on completion if db_config is present.
Expand All @@ -233,6 +261,10 @@ def set_runner(self, runner: Runner) -> None:

def set_metrics(self, metrics: Sequence[Metric]) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Finds equivallently named Metric that already exists on the Experiment and
replaces it with the Metric provided, or adds the Metric provided to the
Experiment as tracking metrics.
Expand Down
48 changes: 48 additions & 0 deletions ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
RangeParameterConfig,
)
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_optimization_config,
get_percentile_early_stopping_strategy,
)
from ax.utils.testing.modeling_stubs import get_generation_strategy
from pyre_extensions import none_throws


Expand Down Expand Up @@ -154,3 +160,45 @@ def test_configure_optimization(self) -> None:
objective="ne",
outcome_constraints=["qps >= 0"],
)

def test_set_experiment(self) -> None:
client = Client()
experiment = get_branin_experiment()

client.set_experiment(experiment=experiment)

self.assertEqual(client._experiment, experiment)

def test_set_optimization_config(self) -> None:
client = Client()
optimization_config = get_branin_optimization_config()

with self.assertRaisesRegex(AssertionError, "Experiment not set"):
client.set_optimization_config(optimization_config=optimization_config)

client.set_experiment(experiment=get_branin_experiment())
client.set_optimization_config(
optimization_config=optimization_config,
)

self.assertEqual(
none_throws(client._experiment).optimization_config, optimization_config
)

def test_set_generation_strategy(self) -> None:
client = Client()
client.set_experiment(experiment=get_branin_experiment())

generation_strategy = get_generation_strategy()

client.set_generation_strategy(generation_strategy=generation_strategy)
self.assertEqual(client._generation_strategy, generation_strategy)

def test_set_early_stopping_strategy(self) -> None:
client = Client()
early_stopping_strategy = get_percentile_early_stopping_strategy()

client.set_early_stopping_strategy(
early_stopping_strategy=early_stopping_strategy
)
self.assertEqual(client._early_stopping_strategy, early_stopping_strategy)

0 comments on commit e81faca

Please sign in to comment.