Implement configure_runner, configure_metric (facebook#3104)
Summary:

configure_runner and configure_metric allow users to attach custom Runners and Metrics to their experiment.

configure_runner is fairly straightforward: it simply sets experiment.runner.

configure_metric is more involved: given a sequence of IMetrics, it iterates through them and tries to find a metric with the same name somewhere on the experiment. In order, it checks the Objective (single, multi-objective, or scalarized), then the outcome constraints, then the tracking metrics. If no metric with a matching name is found, the provided metric is added as a tracking metric.
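
As a quick illustration, here is a minimal, hypothetical usage sketch. MyRunner and MyMetric are illustrative stand-ins (not part of this diff); their signatures follow the IRunner/IMetric stubs used in the tests below, and the Branin test experiment's objective metric is named "branin", so the custom metric is swapped into the objective rather than added as a tracking metric.

from typing import Any, Mapping

from ax.core.base_trial import TrialStatus
from ax.preview.api.client import Client
from ax.preview.api.protocols.metric import IMetric
from ax.preview.api.protocols.runner import IRunner
from ax.preview.api.types import TParameterization
from ax.utils.testing.core_stubs import get_branin_experiment


class MyRunner(IRunner):  # hypothetical deployment hook
    def run_trial(
        self, trial_index: int, parameterization: TParameterization
    ) -> dict[str, Any]:
        # Return metadata used later to poll or stop the trial.
        return {"job_id": f"job-{trial_index}"}

    def poll_trial(
        self, trial_index: int, trial_metadata: Mapping[str, Any]
    ) -> TrialStatus:
        return TrialStatus.COMPLETED

    def stop_trial(
        self, trial_index: int, trial_metadata: Mapping[str, Any]
    ) -> dict[str, Any]:
        return {}


class MyMetric(IMetric):  # hypothetical data source for the "branin" metric
    def fetch(
        self, trial_index: int, trial_metadata: Mapping[str, Any]
    ) -> tuple[int, float | tuple[float, float]]:
        # (progression, mean) -- no SEM in this sketch.
        return 0, 42.0


client = Client()
client.set_experiment(experiment=get_branin_experiment())
client.configure_runner(runner=MyRunner())
# "branin" matches the objective metric's name, so the metric is replaced in
# the objective; an unrecognized name would be added as a tracking metric.
client.configure_metrics(metrics=[MyMetric(name="branin")])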

Reviewed By: lena-kashtelyan

Differential Revision: D66305614
mpolson64 authored and facebook-github-bot committed Dec 2, 2024
1 parent 1f62a1d commit d3b27e5
Showing 2 changed files with 242 additions and 13 deletions.
107 changes: 96 additions & 11 deletions ax/preview/api/client.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict

from logging import Logger
from typing import Sequence

from ax.analysis.analysis import Analysis, AnalysisCard # Used as a return type
@@ -14,6 +14,7 @@
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
@@ -33,9 +34,12 @@
from ax.preview.api.utils.instantiation.from_string import (
optimization_config_from_string,
)
from ax.utils.common.logger import get_logger
from pyre_extensions import none_throws
from typing_extensions import Self

logger: Logger = get_logger(__name__)


class Client:
_experiment: Experiment | None = None
@@ -168,15 +172,16 @@ def configure_runner(self, runner: IRunner) -> None:
Saves to database on completion if db_config is present.
"""
...
self.set_runner(runner=runner)

def configure_metrics(self, metrics: Sequence[IMetric]) -> None:
"""
Finds an equivalently named Metric that already exists on the Experiment and
replaces it with the provided Metric, or adds the provided Metric to the
Experiment as a tracking metric.
Attach a class with logic for automating the fetching of a given metric by
replacing its existing instance on the Experiment with the provided Metric
from the metrics sequence, or add the provided Metric to the Experiment as a
tracking metric if no equivalently named metric was already present.
"""
...
self.set_metrics(metrics=metrics)

# -------------------- Section 1.2: Set (not API) -------------------------------
def set_experiment(self, experiment: Experiment) -> None:
@@ -228,6 +233,10 @@ def set_generation_strategy(
self._generation_strategy
)._experiment = self._none_throws_experiment()

none_throws(
self._generation_strategy
)._experiment = self._none_throws_experiment()

if self.db_config is not None:
# TODO[mpolson64] Save to database
...
@@ -261,19 +270,33 @@ def set_runner(self, runner: Runner) -> None:
Saves to database on completion if db_config is present.
"""
...
self._none_throws_experiment().runner = runner

if self.db_config is not None:
# TODO[mpolson64] Save to database
...

def set_metrics(self, metrics: Sequence[Metric]) -> None:
"""
This method is not part of the API and is provided (without guarantees of
method signature stability) for the convenience of some developers, power
users, and partners.
Finds an equivalently named Metric that already exists on the Experiment and
replaces it with the provided Metric, or adds the provided Metric to the
Experiment as a tracking metric.
Attach a class with logic for automating the fetching of a given metric by
replacing its existing instance on the Experiment with the provided Metric
from the metrics sequence, or add the provided Metric to the Experiment as a
tracking metric if no equivalently named metric was already present.
"""
...
# If an equivalently named Metric already exists on the Experiment, replace it
# with the Metric provided. Otherwise, add the Metric to the Experiment as a
# tracking metric.
for metric in metrics:
# Check the optimization config first
self._overwrite_metric(metric=metric)

if self.db_config is not None:
# TODO[mpolson64] Save to database
...

# -------------------- Section 2. Conduct Experiment ----------------------------
def get_next_trials(
@@ -516,3 +539,65 @@ def _none_throws_experiment(self) -> Experiment:
"experiment before utilizing any other methods on the Client."
),
)

def _overwrite_metric(self, metric: Metric) -> None:
"""
Overwrite an existing Metric on the Experiment with the provided Metric if they
share the same name. If no Metric with the same name exists, add the Metric as
a tracking metric.
Note that this method does not save the Experiment to the database (this is
handled in self.set_metrics).
"""

# Check the OptimizationConfig first
if (
optimization_config := self._none_throws_experiment().optimization_config
) is not None:
# Check the objective
if isinstance(
multi_objective := optimization_config.objective, MultiObjective
):
for i in range(len(multi_objective.objectives)):
if metric.name == multi_objective.objectives[i].metric.name:
multi_objective._objectives[i]._metric = metric
return

if isinstance(
scalarized_objective := optimization_config.objective,
ScalarizedObjective,
):
for i in range(len(scalarized_objective.metrics)):
if metric.name == scalarized_objective.metrics[i].name:
scalarized_objective._metrics[i] = metric
return

if (
isinstance(optimization_config.objective, Objective)
and metric.name == optimization_config.objective.metric.name
):
optimization_config.objective._metric = metric
return

# Check the outcome constraints
for i in range(len(optimization_config.outcome_constraints)):
if (
metric.name
== optimization_config.outcome_constraints[i].metric.name
):
optimization_config._outcome_constraints[i]._metric = metric
return

# Check the tracking metrics
tracking_metric_names = self._none_throws_experiment()._tracking_metrics.keys()
if metric.name in tracking_metric_names:
self._none_throws_experiment()._tracking_metrics[metric.name] = metric
return

# If an equivalently named Metric does not exist, add it as a tracking
# metric.
self._none_throws_experiment().add_tracking_metric(metric=metric)
logger.warning(
f"Metric {metric} not found in optimization config, added as tracking "
"metric."
)
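
To make the matching order in _overwrite_metric concrete, below is a small self-contained sketch of the same resolution logic in plain Python (the Fake* classes and the "qps"/"latency"/"memory" names are simplified stand-ins, not Ax types): objective metrics are checked first, then outcome constraints, then tracking metrics, and an unmatched metric is added as a new tracking metric.

from dataclasses import dataclass, field


@dataclass
class FakeMetric:
    name: str


@dataclass
class FakeExperiment:
    # Simplified stand-ins for the objective's metrics, the outcome
    # constraints' metrics, and the tracking metrics keyed by name.
    objective_metrics: list[FakeMetric]
    constraint_metrics: list[FakeMetric]
    tracking_metrics: dict[str, FakeMetric] = field(default_factory=dict)


def overwrite_metric(experiment: FakeExperiment, metric: FakeMetric) -> None:
    # 1. Objective (single, multi-objective, and scalarized objectives all
    #    reduce to "a list of metrics" in this sketch).
    for i, existing in enumerate(experiment.objective_metrics):
        if existing.name == metric.name:
            experiment.objective_metrics[i] = metric
            return
    # 2. Outcome constraints.
    for i, existing in enumerate(experiment.constraint_metrics):
        if existing.name == metric.name:
            experiment.constraint_metrics[i] = metric
            return
    # 3. Tracking metrics: replace in place if the name is already tracked,
    #    otherwise fall through and add it as a new tracking metric.
    experiment.tracking_metrics[metric.name] = metric


exp = FakeExperiment(
    objective_metrics=[FakeMetric("qps")],
    constraint_metrics=[FakeMetric("latency")],
)
overwrite_metric(exp, FakeMetric("qps"))     # replaces the objective metric
overwrite_metric(exp, FakeMetric("memory"))  # added as a new tracking metric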
148 changes: 146 additions & 2 deletions ax/preview/api/tests/test_client.py
@@ -5,9 +5,13 @@

# pyre-strict

from typing import Any, Mapping

from ax.core.base_trial import TrialStatus

from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.core.parameter import (
@@ -25,14 +29,17 @@
ParameterType,
RangeParameterConfig,
)
from ax.preview.api.protocols.metric import IMetric
from ax.preview.api.protocols.runner import IRunner
from ax.preview.api.types import TParameterization
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_optimization_config,
get_percentile_early_stopping_strategy,
)
from ax.utils.testing.modeling_stubs import get_generation_strategy
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws, override


class TestClient(TestCase):
@@ -161,6 +168,118 @@ def test_configure_optimization(self) -> None:
outcome_constraints=["qps >= 0"],
)

def test_configure_runner(self) -> None:
client = Client()
runner = DummyRunner()

with self.assertRaisesRegex(AssertionError, "Experiment not set"):
client.configure_runner(runner=runner)

client.set_experiment(experiment=get_branin_experiment())
client.configure_runner(runner=runner)

self.assertEqual(none_throws(client._experiment).runner, runner)

def test_configure_metric(self) -> None:
client = Client()
custom_metric = DummyMetric(name="custom")

with self.assertRaisesRegex(AssertionError, "Experiment not set"):
client.configure_metrics(metrics=[custom_metric])

client.configure_experiment(
experiment_config=ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
)
],
name="foo",
)
)

# Test replacing a single objective
client.configure_optimization(objective="custom")
client.configure_metrics(metrics=[custom_metric])

self.assertEqual(
custom_metric,
none_throws(
none_throws(client._experiment).optimization_config
).objective.metric,
)

# Test replacing a multi-objective
client.configure_optimization(objective="custom, foo")
client.configure_metrics(metrics=[custom_metric])

self.assertIn(
custom_metric,
assert_is_instance(
none_throws(
none_throws(client._experiment).optimization_config
).objective,
MultiObjective,
).metrics,
)
# Test replacing a scalarized objective
client.configure_optimization(objective="custom + foo")
client.configure_metrics(metrics=[custom_metric])

self.assertIn(
custom_metric,
assert_is_instance(
none_throws(
none_throws(client._experiment).optimization_config
).objective,
ScalarizedObjective,
).metrics,
)

# Test replacing an outcome constraint
client.configure_optimization(
objective="foo", outcome_constraints=["custom >= 0"]
)
client.configure_metrics(metrics=[custom_metric])

self.assertEqual(
custom_metric,
none_throws(none_throws(client._experiment).optimization_config)
.outcome_constraints[0]
.metric,
)

# Test replacing a tracking metric
client.configure_optimization(
objective="foo",
)
none_throws(client._experiment).add_tracking_metric(metric=Metric("custom"))
client.configure_metrics(metrics=[custom_metric])

self.assertEqual(
custom_metric,
none_throws(client._experiment).tracking_metrics[0],
)

# Test adding a tracking metric
client = Client() # Start a fresh Client
client.configure_experiment(
experiment_config=ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1)
)
],
name="foo",
)
)
client.configure_metrics(metrics=[custom_metric])

self.assertEqual(
custom_metric,
none_throws(client._experiment).tracking_metrics[0],
)

def test_set_experiment(self) -> None:
client = Client()
experiment = get_branin_experiment()
@@ -202,3 +321,28 @@ def test_set_early_stopping_strategy(self) -> None:
early_stopping_strategy=early_stopping_strategy
)
self.assertEqual(client._early_stopping_strategy, early_stopping_strategy)


class DummyRunner(IRunner):
@override
def run_trial(
self, trial_index: int, parameterization: TParameterization
) -> dict[str, Any]: ...

@override
def poll_trial(
self, trial_index: int, trial_metadata: Mapping[str, Any]
) -> TrialStatus: ...

@override
def stop_trial(
self, trial_index: int, trial_metadata: Mapping[str, Any]
) -> dict[str, Any]: ...


class DummyMetric(IMetric):
def fetch(
self,
trial_index: int,
trial_metadata: Mapping[str, Any],
) -> tuple[int, float | tuple[float, float]]: ...
