Implement get_next_trials (facebook#3107)
Summary:

As titled. `get_next_trials` takes in a maximum number of trials to generate and, optionally, a dict of parameter values to hold fixed, and generates up to that number of trials (as parallelism is available).

If the Experiment is not set, or the Experiment is set but the OptimizationConfig is not, an exception is raised. If the GenerationStrategy is not set, one is chosen automatically.
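
For illustration, a minimal sketch of this guard behavior, mirroring the new test's assertions (a bare `Client()` with no configuration is the only assumption):

```python
from ax.preview.api.client import Client

client = Client()

# With no experiment configured, get_next_trials fails an internal
# none_throws check, which surfaces as an AssertionError with the
# message "Experiment not set".
try:
    client.get_next_trials()
except AssertionError as error:
    print(error)
```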

Once generated, trials are attached to the experiment and immediately marked RUNNING.
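
A hedged end-to-end sketch of typical usage, adapted from the new test below (the experiment name, parameter names, and `demo_metric` objective are illustrative):

```python
from ax.preview.api.client import Client
from ax.preview.api.configs import (
    ExperimentConfig,
    ParameterType,
    RangeParameterConfig,
)

client = Client()
client.configure_experiment(
    ExperimentConfig(
        name="demo",
        parameters=[
            RangeParameterConfig(
                name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
            ),
            RangeParameterConfig(
                name="x2", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
            ),
        ],
    )
)
client.configure_optimization(objective="demo_metric")

# No GenerationStrategy was configured, so one is chosen automatically.
# Generate up to two trials, holding x1 fixed at 0.5; each trial is
# attached to the experiment and marked RUNNING as it is created.
trials = client.get_next_trials(maximum_trials=2, fixed_parameters={"x1": 0.5})
for trial_index, parameters in trials.items():
    print(trial_index, parameters)
```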

Differential Revision: D66367189
mpolson64 authored and facebook-github-bot committed Nov 26, 2024
1 parent 91c29ff commit f63442b
Showing 2 changed files with 143 additions and 14 deletions.
100 changes: 86 additions & 14 deletions ax/preview/api/client.py
@@ -5,6 +5,7 @@

# pyre-strict

from logging import Logger
from typing import Sequence

from ax.analysis.analysis import Analysis, AnalysisCard # Used as a return type
@@ -15,10 +16,14 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.core.utils import get_pending_observation_features_based_on_trial_status
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.modelbridge.dispatch_utils import choose_generation_strategy

from ax.preview.api.configs import (
@@ -34,10 +39,15 @@
create_experiment,
create_optimization_config,
)
from pyre_extensions import none_throws
from ax.utils.common.logger import get_logger
from ax.utils.common.random import with_rng_seed
from pyre_extensions import assert_is_instance, none_throws
from typing_extensions import Self


logger: Logger = get_logger(__name__)


class Client:
_experiment: Experiment | None = None
_generation_strategy: GenerationStrategyInterface | None = None
@@ -61,8 +71,8 @@ def __init__(
early_stopping_strategy: Now set via set_early_stopping_strategy
global_stopping_strategy: Global stopping is not yet supported in API
"""
self.db_config = db_config
self.random_seed = random_seed
self._db_config = db_config
self._random_seed = random_seed

# -------------------- Section 1: Configure --------------------------------------
def configure_experiment(self, experiment_config: ExperimentConfig) -> None:
@@ -84,7 +94,7 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None:

self._experiment = create_experiment(config=experiment_config)

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -127,7 +137,7 @@ def configure_optimization(
outcome_constraint_strs=outcome_constraints,
)

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -149,15 +159,15 @@ def configure_generation_strategy(
generation_strategy_config.num_initialization_trials
),
max_parallelism_cap=generation_strategy_config.maximum_parallelism,
random_seed=self.random_seed,
random_seed=self._random_seed,
)

# Necessary for storage implications, may be removed in the future
generation_strategy._experiment = self._none_throws_experiment()

self._generation_strategy = generation_strategy

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -191,7 +201,7 @@ def set_experiment(self, experiment: Experiment) -> None:
"""
self._experiment = experiment

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -207,7 +217,7 @@ def set_optimization_config(self, optimization_config: OptimizationConfig) -> None:
"""
self._none_throws_experiment().optimization_config = optimization_config

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -229,7 +239,7 @@ def set_generation_strategy(
self._generation_strategy
)._experiment = self._none_throws_experiment()

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -248,7 +258,7 @@ def set_early_stopping_strategy(
"""
self._early_stopping_strategy = early_stopping_strategy

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -264,7 +274,7 @@ def set_runner(self, runner: Runner) -> None:
"""
self._none_throws_experiment().runner = runner

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -285,7 +295,7 @@ def set_metrics(self, metrics: Sequence[Metric]) -> None:
# Check the optimization config first
self._overwrite_metric(metric=metric)

if self.db_config is not None:
if self._db_config is not None:
# TODO[mpolson64] Save to database
...

@@ -307,7 +317,69 @@ def get_next_trials(
Returns:
A mapping of trial index to parameterization.
"""
...

if self._none_throws_experiment().optimization_config is None:
raise UnsupportedError(
"OptimizationConfig not set. Please call configure_optimization before "
"generating trials."
)

# If no GenerationStrategy is set, configure a default one
if self._generation_strategy is None:
self.configure_generation_strategy(
generation_strategy_config=GenerationStrategyConfig()
)

trials: list[Trial] = []
with with_rng_seed(seed=self._random_seed):
for i in range(maximum_trials):
try:
# Would prefer to use gen_for_multiple_trials_with_multiple_models
# directly but it currently lacks support for fixed_features
generator_run = none_throws(
self._generation_strategy
)._gen_multiple(
experiment=self._none_throws_experiment(),
num_generator_runs=1,
n=1,
pending_observations=(
get_pending_observation_features_based_on_trial_status(
experiment=self._none_throws_experiment()
)
),
fixed_features=(
# pyre-fixme[6]: Type narrowing broken because core Ax
# TParameterization is dict not Mapping
ObservationFeatures(parameters=fixed_parameters)
if fixed_parameters is not None
else None
),
)[0]
except MaxParallelismReachedException:
logger.info(
f"Maximum parallelism reached. Returning {i} trials instead of "
f"requested quantity {maximum_trials}."
)
break

trial = assert_is_instance(
self._none_throws_experiment()
.new_trial(
generator_run=generator_run,
)
.mark_running(no_runner_required=True),
Trial,
)

trials.append(trial)

if self._db_config is not None:
# TODO[mpolson64] Save trial and update generation strategy
...

# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
# None, but we do not allow this in the API.
return {trial.index: none_throws(trial.arm).parameters for trial in trials}

def complete_trial(
self,
57 changes: 57 additions & 0 deletions ax/preview/api/tests/test_client.py
@@ -26,6 +26,7 @@
from ax.preview.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
GenerationStrategyConfig,
ParameterType,
RangeParameterConfig,
)
@@ -322,6 +323,62 @@ def test_set_early_stopping_strategy(self) -> None:
)
self.assertEqual(client._early_stopping_strategy, early_stopping_strategy)

def test_get_next_trials(self) -> None:
client = Client()

with self.assertRaisesRegex(AssertionError, "Experiment not set"):
client.get_next_trials()

client.configure_experiment(
ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
RangeParameterConfig(
name="x2", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)

with self.assertRaisesRegex(UnsupportedError, "OptimizationConfig not set"):
client.get_next_trials()

client.configure_optimization(objective="foo")
client.configure_generation_strategy(
generation_strategy_config=GenerationStrategyConfig(
# Set this to a large number so test runs fast
num_initialization_trials=999,
maximum_parallelism=5,
)
)

# Test can generate one trial
trials = client.get_next_trials()
self.assertEqual(len(trials), 1)
self.assertEqual({*trials[0].keys()}, {"x1", "x2"})
for parameter in ["x1", "x2"]:
value = assert_is_instance(trials[0][parameter], float)
self.assertGreaterEqual(value, -1.0)
self.assertLessEqual(value, 1.0)

# Test can generate multiple trials
trials = client.get_next_trials(maximum_trials=2)
self.assertEqual(len(trials), 2)

# Test respects fixed features
trials = client.get_next_trials(maximum_trials=1, fixed_parameters={"x1": 0.5})
value = assert_is_instance(trials[3]["x1"], float)
self.assertEqual(value, 0.5)

# Test respects max parallelism
# Returns 1 even though we asked for 2 because maximum parallelism has been
# reached.
trials = client.get_next_trials(maximum_trials=2)
self.assertEqual(len(trials), 1)


class DummyRunner(IRunner):
@override
