From 27a3465d6da3b1ff94a4d4640dbc94c81d4a22b1 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Tue, 3 Dec 2024 13:36:48 -0800
Subject: [PATCH 1/3] Implement "configure" methods on Client (#3080)

Summary: Implements new Client methods `configure_experiment`, `configure_optimization`, and `configure_generation_strategy`. Creates a new module api.utils.instantiation that holds functions for converting from Configs to core Ax objects. These functions do not perform validation, which will live on the configs themselves and be implemented in a separate diff. Note that this diff also does not implement saving to DB, although this will happen after each of these three methods is called if a config is provided.

**I'd especially like comments on our use of SymPy** here to parse objective and constraint strings -- what we've wound up with is much less verbose and, I suspect, much less error-prone than what exists now in InstantiationBase, while also providing a more natural user experience (ex. not getting tripped up by spacing, automatically handling inequality simplification like `(x1 + x2) / 2 + 0.5 >= 0` --> `-0.5 * x1 - 0.5 * x2 <= 0.5`, etc.) without any manual string parsing on our end at all. I'm curious what people think of this strategy overall. SymPy usage occurs in `parse_objective`, `parse_parameter_constraint`, and `parse_outcome_constraint`.

Specific RFCs:
* We made the decision earlier to combine the concepts of "outcome constraint" and objective thresholds into a single concept to make things clearer for our users -- do we still stand by this decision? Seeing it in practice I think it will help our users a ton, but I want to confirm this before we get too far into implementation.
* We discussed previously that if we were using strings to represent objectives we wanted users to be able to specify optimization direction via coefficients (ex. objective="loss" vs objective="-loss") **but we did not decide which direction a positive coefficient would indicate**. In this diff I've implemented things such that a negative coefficient indicates minimization (and a bare metric name indicates maximization), but I'm happy to change -- I don't think one is better than the other, we just need to be consistent.
* To express relative outcome constraints, rather than use "%" like we do in AxClient, we ask the user to multiply their bound by the term "baseline" (ex. "qps >= 0.95 * baseline" will constrain such that the QPS is at least 95% of the baseline arm's QPS). 
To be honest we do this to make things play nice with SymPy but I also find it clearer, though Im curious what you all think Reviewed By: saitcakmak, lena-kashtelyan Differential Revision: D65826204 --- ax/preview/api/client.py | 103 ++++++- ax/preview/api/configs.py | 12 +- ax/preview/api/tests/test_client.py | 156 ++++++++++ ax/preview/api/utils/__init__.py | 5 + .../api/utils/instantiation/__init__.py | 5 + .../api/utils/instantiation/from_config.py | 146 +++++++++ .../api/utils/instantiation/from_string.py | 263 +++++++++++++++++ .../instantiation/tests/test_from_config.py | 277 ++++++++++++++++++ .../instantiation/tests/test_from_string.py | 213 ++++++++++++++ setup.py | 1 + sphinx/source/preview.rst | 16 + 11 files changed, 1175 insertions(+), 22 deletions(-) create mode 100644 ax/preview/api/tests/test_client.py create mode 100644 ax/preview/api/utils/__init__.py create mode 100644 ax/preview/api/utils/instantiation/__init__.py create mode 100644 ax/preview/api/utils/instantiation/from_config.py create mode 100644 ax/preview/api/utils/instantiation/from_string.py create mode 100644 ax/preview/api/utils/instantiation/tests/test_from_config.py create mode 100644 ax/preview/api/utils/instantiation/tests/test_from_string.py diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 1cd62189b69..adc8246dc19 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import Sequence from ax.analysis.analysis import Analysis, AnalysisCard # Used as a return type @@ -16,6 +18,8 @@ from ax.core.runner import Runner from ax.core.search_space import SearchSpace from ax.early_stopping.strategies import BaseEarlyStoppingStrategy +from ax.exceptions.core import UnsupportedError +from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.preview.api.configs import ( DatabaseConfig, @@ -26,29 +30,35 @@ from ax.preview.api.protocols.metric import IMetric from ax.preview.api.protocols.runner import IRunner from ax.preview.api.types import TOutcome, TParameterization +from ax.preview.api.utils.instantiation.from_config import experiment_from_config +from ax.preview.api.utils.instantiation.from_string import ( + optimization_config_from_string, +) +from pyre_extensions import none_throws from typing_extensions import Self class Client: + _experiment: Experiment | None = None + _generation_strategy: GenerationStrategyInterface | None = None + def __init__( self, db_config: DatabaseConfig | None = None, random_seed: int | None = None, ) -> None: """ - Many parameter are intentionally omitted from __init__ that were present - in AxClient.__init__, including: + Initialize a Client, which manages state across the lifecycle of an experiment. - generation_strategy: Now set via configure_generation_strategy or - set_generation_strategy - enforce_sequential_optimization: Now set via GenerationStrategyConfig - torch_device: Now set via GenerationStrategyConfig - verbose_logging: Omitted, user can set the logger level on their root config - suppress_storage_errors: Omitted - early_stopping_strategy: Now set via set_early_stopping_strategy - global_stopping_strategy: Global stopping is not yet supported in API + Args: + db_config: Configuration for saving to and loading from a database. If + elided the experiment will not automatically be saved to a database. 
+ random_seed: An optional integer to set the random seed for reproducibility + of the experiment's results. If not provided, the random seed will not + be set, leading to potentially different results on different runs. """ - ... + self.db_config = db_config + self.random_seed = random_seed # -------------------- Section 1: Configure -------------------------------------- def configure_experiment(self, experiment_config: ExperimentConfig) -> None: @@ -62,7 +72,17 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None: Saves to database on completion if db_config is present. """ - ... + if self._experiment is not None: + raise UnsupportedError( + "Experiment already configured. Please create a new Client if you " + "would like a new experiment." + ) + + self._experiment = experiment_from_config(config=experiment_config) + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def configure_optimization( self, @@ -72,6 +92,9 @@ def configure_optimization( objective: str, # Outcome constraints will also be parsed via SymPy # Ex: "num_layers1 <= num_layers2", "compound_a + compound_b <= 1" + # To indicate a relative constraint multiply your bound by "baseline" + # Ex: "qps >= 0.95 * baseline" will constrain such that the QPS is at least + # 95% of the baseline arm's QPS. outcome_constraints: Sequence[str] | None = None, ) -> None: """ @@ -80,9 +103,33 @@ def configure_optimization( tracking_metrics if they were were already present (i.e. they were attached via configure_metrics) or added as base Metrics. + Args: + objective: Objective is a string and allows us to express single, + scalarized, and multi-objective goals. Ex: "loss", "ne1 + ne1", + "-ne, qps" + outcome_constraints: Outcome constraints are also strings and allow us to + express a desire to have a metric clear a threshold but not be + further optimized. These constraints are expressed as inequalities. + Ex: "qps >= 100", "0.5 * ne1 + 0.5 * ne2 >= 0.95". + To indicate a relative constraint multiply your bound by "baseline" + Ex: "qps >= 0.95 * baseline" will constrain such that the QPS is at + least 95% of the baseline arm's QPS. + Note that scalarized outcome constraints cannot be relative. + + Saves to database on completion if db_config is present. """ - ... + + self._none_throws_experiment().optimization_config = ( + optimization_config_from_string( + objective_str=objective, + outcome_constraint_strs=outcome_constraints, + ) + ) + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def configure_generation_strategy( self, generation_strategy_config: GenerationStrategyConfig @@ -93,7 +140,26 @@ def configure_generation_strategy( Saves to database on completion if db_config is present. """ - ... + + generation_strategy = choose_generation_strategy( + search_space=self._none_throws_experiment().search_space, + optimization_config=self._none_throws_experiment().optimization_config, + num_trials=generation_strategy_config.num_trials, + num_initialization_trials=( + generation_strategy_config.num_initialization_trials + ), + max_parallelism_cap=generation_strategy_config.maximum_parallelism, + random_seed=self.random_seed, + ) + + # Necessary for storage implications, may be removed in the future + generation_strategy._experiment = self._none_throws_experiment() + + self._generation_strategy = generation_strategy + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... 
# -------------------- Section 1.1: Configure Automation ------------------------ def configure_runner(self, runner: IRunner) -> None: @@ -406,3 +472,12 @@ def load_from_database( The restored `AxClient`. """ ... + + def _none_throws_experiment(self) -> Experiment: + return none_throws( + self._experiment, + ( + "Experiment not set. Please call configure_experiment or load an " + "experiment before utilizing any other methods on the Client." + ), + ) diff --git a/ax/preview/api/configs.py b/ax/preview/api/configs.py index 2939a1064fd..6ff6388ce8c 100644 --- a/ax/preview/api/configs.py +++ b/ax/preview/api/configs.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional +from typing import List, Mapping, Optional, Sequence from ax.preview.api.types import TParameterValue @@ -63,18 +63,14 @@ class ChoiceParameterConfig: values: List[float] | List[int] | List[str] | List[bool] parameter_type: ParameterType is_ordered: bool | None = None - dependent_parameters: dict[TParameterValue, str] | None = None + dependent_parameters: Mapping[TParameterValue, Sequence[str]] | None = None @dataclass class ExperimentConfig: """ - ExperimentConfig allows users to specify the SearchSpace and OptimizationConfig of - an Experiment and validates their inputs jointly. - - This will also be the construct that handles transforming string-based inputs (the - objective, parameter constraints, and output constraints) into their corresponding - Ax class using SymPy. + ExperimentConfig allows users to specify the SearchSpace of an experiment along + with other metadata. """ name: str diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py new file mode 100644 index 00000000000..09f0e9e44c5 --- /dev/null +++ b/ax/preview/api/tests/test_client.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from ax.core.experiment import Experiment +from ax.core.metric import Metric +from ax.core.objective import Objective +from ax.core.optimization_config import OptimizationConfig +from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint +from ax.core.parameter import ( + ChoiceParameter, + ParameterType as CoreParameterType, + RangeParameter, +) +from ax.core.parameter_constraint import ParameterConstraint +from ax.core.search_space import SearchSpace +from ax.exceptions.core import UnsupportedError +from ax.preview.api.client import Client +from ax.preview.api.configs import ( + ChoiceParameterConfig, + ExperimentConfig, + ParameterType, + RangeParameterConfig, +) +from ax.utils.common.testutils import TestCase +from pyre_extensions import none_throws + + +class TestClient(TestCase): + def test_configure_experiment(self) -> None: + client = Client() + + float_parameter = RangeParameterConfig( + name="float_param", + parameter_type=ParameterType.FLOAT, + bounds=(0, 1), + ) + int_parameter = RangeParameterConfig( + name="int_param", + parameter_type=ParameterType.INT, + bounds=(0, 1), + ) + choice_parameter = ChoiceParameterConfig( + name="choice_param", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + ) + + experiment_config = ExperimentConfig( + name="test_experiment", + parameters=[float_parameter, int_parameter, choice_parameter], + parameter_constraints=["int_param <= float_param"], + description="test description", + owner="miles", + ) + + client.configure_experiment(experiment_config=experiment_config) + self.assertEqual( + client._experiment, + Experiment( + search_space=SearchSpace( + parameters=[ + RangeParameter( + name="float_param", + parameter_type=CoreParameterType.FLOAT, + lower=0, + upper=1, + ), + RangeParameter( + name="int_param", + parameter_type=CoreParameterType.INT, + lower=0, + upper=1, + ), + ChoiceParameter( + name="choice_param", + parameter_type=CoreParameterType.STRING, + values=["a", "b", "c"], + is_ordered=False, + sort_values=False, + ), + ], + parameter_constraints=[ + ParameterConstraint( + constraint_dict={"int_param": 1, "float_param": -1}, bound=0 + ) + ], + ), + name="test_experiment", + description="test description", + properties={"owners": ["miles"]}, + ), + ) + + with self.assertRaisesRegex(UnsupportedError, "Experiment already configured"): + client.configure_experiment(experiment_config=experiment_config) + + def test_configure_optimization(self) -> None: + client = Client() + + float_parameter = RangeParameterConfig( + name="float_param", + parameter_type=ParameterType.FLOAT, + bounds=(0, 1), + ) + int_parameter = RangeParameterConfig( + name="int_param", + parameter_type=ParameterType.INT, + bounds=(0, 1), + ) + choice_parameter = ChoiceParameterConfig( + name="choice_param", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + ) + + experiment_config = ExperimentConfig( + name="test_experiment", + parameters=[float_parameter, int_parameter, choice_parameter], + parameter_constraints=["int_param <= float_param"], + description="test description", + owner="miles", + ) + + client.configure_experiment(experiment_config=experiment_config) + + client.configure_optimization( + objective="-ne", + outcome_constraints=["qps >= 0"], + ) + + self.assertEqual( + none_throws(client._experiment).optimization_config, + OptimizationConfig( + objective=Objective(metric=Metric(name="ne"), minimize=True), + outcome_constraints=[ + OutcomeConstraint( + metric=Metric(name="qps"), + 
op=ComparisonOp.GEQ, + bound=0.0, + relative=False, + ) + ], + ), + ) + + empty_client = Client() + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + empty_client.configure_optimization( + objective="ne", + outcome_constraints=["qps >= 0"], + ) diff --git a/ax/preview/api/utils/__init__.py b/ax/preview/api/utils/__init__.py new file mode 100644 index 00000000000..4b87eb9e4d0 --- /dev/null +++ b/ax/preview/api/utils/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/ax/preview/api/utils/instantiation/__init__.py b/ax/preview/api/utils/instantiation/__init__.py new file mode 100644 index 00000000000..4b87eb9e4d0 --- /dev/null +++ b/ax/preview/api/utils/instantiation/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/ax/preview/api/utils/instantiation/from_config.py b/ax/preview/api/utils/instantiation/from_config.py new file mode 100644 index 00000000000..38ee86a6626 --- /dev/null +++ b/ax/preview/api/utils/instantiation/from_config.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import numpy as np + +from ax.core.experiment import Experiment +from ax.core.parameter import ( + ChoiceParameter, + FixedParameter, + Parameter, + ParameterType as CoreParameterType, + RangeParameter, +) +from ax.core.parameter_constraint import validate_constraint_parameters +from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError +from ax.preview.api.configs import ( + ChoiceParameterConfig, + ExperimentConfig, + ParameterScaling, + ParameterType, + RangeParameterConfig, +) +from ax.preview.api.utils.instantiation.from_string import parse_parameter_constraint + + +def parameter_from_config( + config: RangeParameterConfig | ChoiceParameterConfig, +) -> Parameter: + """ + Create a RangeParameter, ChoiceParameter, or FixedParameter from a ParameterConfig. + """ + + if isinstance(config, RangeParameterConfig): + lower, upper = config.bounds + + # TODO[mpolson64] Add support for RangeParameterConfig.step_size native to + # RangeParameter instead of converting to ChoiceParameter + if (step_size := config.step_size) is not None: + if not ( + config.scaling == ParameterScaling.LINEAR or config.scaling is None + ): + raise UserInputError( + "Non-linear parameter scaling is not supported when using " + "step_size." + ) + + if (upper - lower) % step_size != 0: + raise UserInputError( + "The range of the parameter must be evenly divisible by the " + "step size." 
+ ) + + return ChoiceParameter( + name=config.name, + parameter_type=_parameter_type_converter(config.parameter_type), + values=[*np.arange(lower, upper + step_size, step_size)], + is_ordered=True, + ) + + return RangeParameter( + name=config.name, + parameter_type=_parameter_type_converter(config.parameter_type), + lower=lower, + upper=upper, + log_scale=config.scaling == ParameterScaling.LOG, + ) + + else: + # If there is only one value, create a FixedParameter instead of a + # ChoiceParameter + if len(config.values) == 1: + return FixedParameter( + name=config.name, + parameter_type=_parameter_type_converter(config.parameter_type), + value=config.values[0], + # pyre-fixme[6] Variance issue caused by FixedParameter.dependents + # using List instead of immutable container type. + dependents=config.dependent_parameters, + ) + + return ChoiceParameter( + name=config.name, + parameter_type=_parameter_type_converter(config.parameter_type), + # pyre-fixme[6] Variance issue caused by ChoiceParameter.value using List + # instead of immutable container type. + values=config.values, + is_ordered=config.is_ordered, + # pyre-fixme[6] Variance issue caused by ChoiceParameter.dependents using + # List instead of immutable container type. + dependents=config.dependent_parameters, + ) + + +def experiment_from_config(config: ExperimentConfig) -> Experiment: + """Create an Experiment from an ExperimentConfig.""" + parameters = [ + parameter_from_config(config=parameter_config) + for parameter_config in config.parameters + ] + + constraints = [ + parse_parameter_constraint(constraint_str=constraint_str) + for constraint_str in config.parameter_constraints + ] + + # Ensure that all ParameterConstraints are valid and acting on existing parameters + for constraint in constraints: + validate_constraint_parameters( + parameters=[ + parameter + for parameter in parameters + if parameter.name in constraint.constraint_dict.keys() + ] + ) + + search_space = SearchSpace(parameters=parameters, parameter_constraints=constraints) + + return Experiment( + search_space=search_space, + name=config.name, + description=config.description, + properties={"owners": [config.owner]}, + ) + + +def _parameter_type_converter(parameter_type: ParameterType) -> CoreParameterType: + """ + Convert from an API ParameterType to a core Ax ParameterType. + """ + + if parameter_type == ParameterType.BOOL: + return CoreParameterType.BOOL + elif parameter_type == ParameterType.FLOAT: + return CoreParameterType.FLOAT + elif parameter_type == ParameterType.INT: + return CoreParameterType.INT + elif parameter_type == ParameterType.STRING: + return CoreParameterType.STRING + else: + raise UserInputError(f"Unsupported parameter type {parameter_type}.") diff --git a/ax/preview/api/utils/instantiation/from_string.py b/ax/preview/api/utils/instantiation/from_string.py new file mode 100644 index 00000000000..ac0dd827c3b --- /dev/null +++ b/ax/preview/api/utils/instantiation/from_string.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Sequence + +from ax.core.metric import Metric +from ax.core.objective import MultiObjective, Objective, ScalarizedObjective +from ax.core.optimization_config import ( + MultiObjectiveOptimizationConfig, + OptimizationConfig, +) +from ax.core.outcome_constraint import ( + ComparisonOp, + ObjectiveThreshold, + OutcomeConstraint, + ScalarizedOutcomeConstraint, +) +from ax.core.parameter_constraint import ParameterConstraint +from ax.exceptions.core import UserInputError +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.relational import GreaterThan, LessThan +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify + + +def optimization_config_from_string( + objective_str: str, outcome_constraint_strs: Sequence[str] | None = None +) -> OptimizationConfig: + """ + Create an OptimizationConfig from objective and outcome constraint strings. + + Note that outcome constraints may not be placed on the objective metric except in + the multi-objective case where they will be converted to objective thresholds. + """ + + objective = parse_objective(objective_str=objective_str) + + if outcome_constraint_strs is not None: + outcome_constraints = [ + parse_outcome_constraint(constraint_str=constraint_str) + for constraint_str in outcome_constraint_strs + ] + else: + outcome_constraints = None + + if isinstance(objective, MultiObjective): + # Convert OutcomeConstraints to ObjectiveThresholds if relevant + objective_metric_names = {metric.name for metric in objective.metrics} + true_outcome_constraints = [] + objective_thresholds: list[ObjectiveThreshold] = [] + for outcome_constraint in outcome_constraints or []: + if ( + not isinstance(outcome_constraint, ScalarizedOutcomeConstraint) + and outcome_constraint.metric.name in objective_metric_names + ): + objective_thresholds.append( + ObjectiveThreshold( + metric=outcome_constraint.metric, + bound=outcome_constraint.bound, + relative=outcome_constraint.relative, + op=outcome_constraint.op, + ) + ) + else: + true_outcome_constraints.append(outcome_constraint) + + return MultiObjectiveOptimizationConfig( + objective=objective, + outcome_constraints=true_outcome_constraints, + objective_thresholds=objective_thresholds, + ) + + # Ensure that outcome constraints are not placed on the objective metric + objective_metric_names = {metric.name for metric in objective.metrics} + for outcome_constraint in outcome_constraints or []: + if outcome_constraint.metric.name in objective_metric_names: + raise UserInputError( + "Outcome constraints may not be placed on the objective metric " + f"except in the multi-objective case, found {objective_str} and " + f"{outcome_constraint_strs}" + ) + + return OptimizationConfig( + objective=objective, + outcome_constraints=outcome_constraints, + ) + + +def parse_parameter_constraint(constraint_str: str) -> ParameterConstraint: + """ + Parse a parameter constraint string into a ParameterConstraint object using SymPy. + Currently only supports linear constraints of the form "a * x + b * y >= k" or + "a * x + b * y <= k". 
+ """ + coefficient_dict = _extract_coefficient_dict_from_inequality( + inequality_str=constraint_str + ) + + # Iterate through the coefficients to extract the parameter names and weights and + # the bound + constraint_dict = {} + bound = 0 + for term, coefficient in coefficient_dict.items(): + if term.is_symbol: + constraint_dict[term.name] = coefficient + elif term.is_number: + # Invert because we are "moving" the bound to the right hand side + bound = -1 * coefficient + else: + raise UserInputError( + "Only linear inequality parameter constraints are supported, found " + f"{constraint_str}" + ) + + return ParameterConstraint(constraint_dict=constraint_dict, bound=bound) + + +def parse_objective(objective_str: str) -> Objective: + """ + Parse an objective string into an Objective object using SymPy. + + Currently only supports linear objectives of the form "a * x + b * y" and tuples of + linear objectives. + """ + # Parse the objective string into a SymPy expression + expression = sympify(objective_str) + + if isinstance(expression, tuple): # Multi-objective + return MultiObjective( + objectives=[ + _create_single_objective(expression=term) for term in expression + ] + ) + + return _create_single_objective(expression=expression) + + +def parse_outcome_constraint(constraint_str: str) -> OutcomeConstraint: + """ + Parse an outcome constraint string into an OutcomeConstraint object using SymPy. + Currently only supports linear constraints of the form "a * x + b * y >= k" or + "a * x + b * y <= k". + + To indicate a relative constraint (i.e. performance relative to some baseline) + multiply your bound by "baseline". For example "qps >= 0.95 * baseline" will + constrain such that the QPS is at least 95% of the baseline arm's QPS. + """ + coefficient_dict = _extract_coefficient_dict_from_inequality( + inequality_str=constraint_str + ) + + # Iterate through the coefficients to extract the parameter names and weights and + # the bound + constraint_dict: dict[str, float] = {} + bound = 0 + is_relative = False + for term, coefficient in coefficient_dict.items(): + if term.is_symbol: + if term.name == "baseline": + # Invert because we are "moving" the bound to the right hand side + bound = -1 * coefficient + is_relative = True + else: + constraint_dict[term.name] = coefficient + elif term.is_number: + # Invert because we are "moving" the bound to the right hand side + bound = -1 * coefficient + else: + raise UserInputError( + "Only linear outcome constraints are supported, found " + f"{constraint_str}" + ) + + if len(constraint_dict) == 1: + term, coefficient = next(iter(constraint_dict.items())) + + return OutcomeConstraint( + metric=Metric(name=term), + op=ComparisonOp.LEQ if coefficient > 0 else ComparisonOp.GEQ, + bound=bound / coefficient, + relative=is_relative, + ) + + names, coefficients = zip(*constraint_dict.items()) + return ScalarizedOutcomeConstraint( + metrics=[Metric(name=name) for name in names], + op=ComparisonOp.LEQ, + weights=[*coefficients], + bound=bound, + relative=is_relative, + ) + + +def _create_single_objective(expression: Expr) -> Objective: + """ + Create an Objective or ScalarizedObjective from a linear SymPy expression. + + All expressions are assumed to represent maximization objectives. 
+ """ + + # If the expression is a just a Symbol it represents a single metric objective + if isinstance(expression, Symbol): + return Objective(metric=Metric(name=str(expression.name)), minimize=False) + + # If the expression is a Mul it likely represents a single metric objective but + # some additional validation is required + if isinstance(expression, Mul): + symbol, *other_symbols = expression.free_symbols + if len(other_symbols) > 0: + raise UserInputError( + f"Only linear objectives are supported, found {expression}." + ) + + # Since the objectives 1 * loss and 2 * loss are equivalent, we can just use + # the sign from the coefficient rather than its value + minimize = bool(expression.as_coefficient(symbol) < 0) + + return Objective(metric=Metric(name=str(symbol)), minimize=minimize) + + # If the expression is an Add it represents a scalarized objective + elif isinstance(expression, Add): + names, coefficients = zip(*expression.as_coefficients_dict().items()) + return ScalarizedObjective( + metrics=[Metric(name=str(name)) for name in names], + weights=[float(coefficient) for coefficient in coefficients], + minimize=False, + ) + + raise UserInputError(f"Only linear objectives are supported, found {expression}.") + + +def _extract_coefficient_dict_from_inequality( + inequality_str: str, +) -> dict[Symbol, float]: + """ + Use SymPy to parse a string into an inequality, invert if necessary to enforce a + less-than relationship, move all terms to the left side, and return the + coefficients as a dictionary. This is useful for parsing parameter and outcome + constraints. + """ + # Parse the constraint string into a SymPy inequality + inequality = sympify(inequality_str) + + # Check the SymPy object is a valid inequality + if not isinstance(inequality, GreaterThan | LessThan): + raise UserInputError(f"Expected an inequality, found {inequality_str}") + + # Move all terms to the left side of the inequality and invert if necessary to + # enforce a less-than relationship + if isinstance(inequality, LessThan): + expression = inequality.lhs - inequality.rhs + else: + expression = inequality.rhs - inequality.lhs + + return { + key: float(value) for key, value in expression.as_coefficients_dict().items() + } diff --git a/ax/preview/api/utils/instantiation/tests/test_from_config.py b/ax/preview/api/utils/instantiation/tests/test_from_config.py new file mode 100644 index 00000000000..319fa0aee0d --- /dev/null +++ b/ax/preview/api/utils/instantiation/tests/test_from_config.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from ax.core.experiment import Experiment +from ax.core.parameter import ( + ChoiceParameter, + FixedParameter, + ParameterType as CoreParameterType, + RangeParameter, +) +from ax.core.parameter_constraint import ParameterConstraint +from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError +from ax.preview.api.configs import ( + ChoiceParameterConfig, + ExperimentConfig, + ParameterScaling, + ParameterType, + RangeParameterConfig, +) +from ax.preview.api.utils.instantiation.from_config import ( + _parameter_type_converter, + experiment_from_config, + parameter_from_config, +) +from ax.utils.common.testutils import TestCase + + +class TestFromConfig(TestCase): + def test_create_range_parameter(self) -> None: + float_config = RangeParameterConfig( + name="float_param", + parameter_type=ParameterType.FLOAT, + bounds=(0, 1), + ) + + self.assertEqual( + parameter_from_config(config=float_config), + RangeParameter( + name="float_param", + parameter_type=CoreParameterType.FLOAT, + lower=0, + upper=1, + ), + ) + + float_config_with_log_scaling = RangeParameterConfig( + name="float_param_with_log_scaling", + parameter_type=ParameterType.FLOAT, + bounds=(1e-10, 1), + scaling=ParameterScaling.LOG, + ) + + self.assertEqual( + parameter_from_config(config=float_config_with_log_scaling), + RangeParameter( + name="float_param_with_log_scaling", + parameter_type=CoreParameterType.FLOAT, + lower=1e-10, + upper=1, + log_scale=True, + ), + ) + + int_config = RangeParameterConfig( + name="int_param", + parameter_type=ParameterType.INT, + bounds=(0, 1), + ) + + self.assertEqual( + parameter_from_config(config=int_config), + RangeParameter( + name="int_param", + parameter_type=CoreParameterType.INT, + lower=0, + upper=1, + ), + ) + + step_size_config = RangeParameterConfig( + name="step_size_param", + parameter_type=ParameterType.FLOAT, + bounds=(0, 100), + step_size=10, + ) + + self.assertEqual( + parameter_from_config(config=step_size_config), + ChoiceParameter( + name="step_size_param", + parameter_type=CoreParameterType.FLOAT, + values=[ + 0.0, + 10.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + 70.0, + 80.0, + 90.0, + 100.0, + ], + is_ordered=True, + ), + ) + + with self.assertRaisesRegex( + UserInputError, + "Non-linear parameter scaling is not supported when using step_size", + ): + parameter_from_config( + config=RangeParameterConfig( + name="step_size_param_with_scaling", + parameter_type=ParameterType.FLOAT, + bounds=(0, 100), + step_size=10, + scaling=ParameterScaling.LOG, + ) + ) + + def test_create_choice_parameter(self) -> None: + choice_config = ChoiceParameterConfig( + name="choice_param", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + ) + + self.assertEqual( + parameter_from_config(config=choice_config), + ChoiceParameter( + name="choice_param", + parameter_type=CoreParameterType.STRING, + values=["a", "b", "c"], + ), + ) + + choice_config_with_order = ChoiceParameterConfig( + name="choice_param_with_order", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + is_ordered=True, + ) + self.assertEqual( + parameter_from_config(config=choice_config_with_order), + ChoiceParameter( + name="choice_param_with_order", + parameter_type=CoreParameterType.STRING, + values=["a", "b", "c"], + is_ordered=True, + ), + ) + + choice_config_with_dependents = ChoiceParameterConfig( + name="choice_param_with_dependents", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + dependent_parameters={ + "a": ["a1", "a2"], 
+ "b": ["b1", "b2", "b3"], + }, + ) + self.assertEqual( + parameter_from_config(config=choice_config_with_dependents), + ChoiceParameter( + name="choice_param_with_dependents", + parameter_type=CoreParameterType.STRING, + values=["a", "b", "c"], + dependents={ + "a": ["a1", "a2"], + "b": ["b1", "b2", "b3"], + }, + ), + ) + + single_element_choice_config = ChoiceParameterConfig( + name="single_element_choice_param", + parameter_type=ParameterType.STRING, + values=["a"], + ) + self.assertEqual( + parameter_from_config(config=single_element_choice_config), + FixedParameter( + name="single_element_choice_param", + parameter_type=CoreParameterType.STRING, + value="a", + ), + ) + + def test_experiment_from_config(self) -> None: + float_parameter = RangeParameterConfig( + name="float_param", + parameter_type=ParameterType.FLOAT, + bounds=(0, 1), + ) + int_parameter = RangeParameterConfig( + name="int_param", + parameter_type=ParameterType.INT, + bounds=(0, 1), + ) + choice_parameter = ChoiceParameterConfig( + name="choice_param", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + ) + + experiment_config = ExperimentConfig( + name="test_experiment", + parameters=[float_parameter, int_parameter, choice_parameter], + parameter_constraints=["int_param <= float_param"], + description="test description", + owner="miles", + ) + + self.assertEqual( + experiment_from_config(config=experiment_config), + Experiment( + search_space=SearchSpace( + parameters=[ + RangeParameter( + name="float_param", + parameter_type=CoreParameterType.FLOAT, + lower=0, + upper=1, + ), + RangeParameter( + name="int_param", + parameter_type=CoreParameterType.INT, + lower=0, + upper=1, + ), + ChoiceParameter( + name="choice_param", + parameter_type=CoreParameterType.STRING, + values=["a", "b", "c"], + is_ordered=False, + sort_values=False, + ), + ], + parameter_constraints=[ + ParameterConstraint( + constraint_dict={"int_param": 1, "float_param": -1}, bound=0 + ) + ], + ), + name="test_experiment", + description="test description", + properties={"owners": ["miles"]}, + ), + ) + + def test_parameter_type_converter(self) -> None: + self.assertEqual( + _parameter_type_converter(parameter_type=ParameterType.BOOL), + CoreParameterType.BOOL, + ) + self.assertEqual( + _parameter_type_converter(parameter_type=ParameterType.INT), + CoreParameterType.INT, + ) + self.assertEqual( + _parameter_type_converter(parameter_type=ParameterType.FLOAT), + CoreParameterType.FLOAT, + ) + self.assertEqual( + _parameter_type_converter(parameter_type=ParameterType.STRING), + CoreParameterType.STRING, + ) + with self.assertRaisesRegex(UserInputError, "Unsupported parameter type"): + # pyre-ignore[6] Testing a bad input on purpose + _parameter_type_converter(parameter_type="bad") diff --git a/ax/preview/api/utils/instantiation/tests/test_from_string.py b/ax/preview/api/utils/instantiation/tests/test_from_string.py new file mode 100644 index 00000000000..fc897a1dc0d --- /dev/null +++ b/ax/preview/api/utils/instantiation/tests/test_from_string.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from ax.core.metric import Metric +from ax.core.objective import MultiObjective, Objective, ScalarizedObjective +from ax.core.optimization_config import ( + MultiObjectiveOptimizationConfig, + OptimizationConfig, +) +from ax.core.outcome_constraint import ( + ComparisonOp, + ObjectiveThreshold, + OutcomeConstraint, + ScalarizedOutcomeConstraint, +) +from ax.core.parameter_constraint import ParameterConstraint +from ax.exceptions.core import UserInputError +from ax.preview.api.utils.instantiation.from_string import ( + optimization_config_from_string, + parse_objective, + parse_outcome_constraint, + parse_parameter_constraint, +) +from ax.utils.common.testutils import TestCase + + +class TestFromString(TestCase): + def test_optimization_config_from_string(self) -> None: + only_objective = optimization_config_from_string(objective_str="ne") + self.assertEqual( + only_objective, + OptimizationConfig( + objective=Objective(metric=Metric(name="ne"), minimize=False), + ), + ) + + with_constraints = optimization_config_from_string( + objective_str="ne", outcome_constraint_strs=["qps >= 0"] + ) + self.assertEqual( + with_constraints, + OptimizationConfig( + objective=Objective(metric=Metric(name="ne"), minimize=False), + outcome_constraints=[ + OutcomeConstraint( + metric=Metric(name="qps"), + op=ComparisonOp.GEQ, + bound=0.0, + relative=False, + ) + ], + ), + ) + + with_constraints_and_objective_threshold = optimization_config_from_string( + objective_str="-ne, qps", + outcome_constraint_strs=["qps >= 1000", "flops <= 1000000"], + ) + self.assertEqual( + with_constraints_and_objective_threshold, + MultiObjectiveOptimizationConfig( + objective=MultiObjective( + objectives=[ + Objective(metric=Metric(name="ne"), minimize=True), + Objective(metric=Metric(name="qps"), minimize=False), + ] + ), + outcome_constraints=[ + OutcomeConstraint( + metric=Metric(name="flops"), + op=ComparisonOp.LEQ, + bound=1000000.0, + relative=False, + ) + ], + objective_thresholds=[ + ObjectiveThreshold( + metric=Metric(name="qps"), + op=ComparisonOp.GEQ, + bound=1000.0, + relative=False, + ) + ], + ), + ) + + def test_parse_paramter_constraint(self) -> None: + constraint = parse_parameter_constraint(constraint_str="x1 + x2 <= 1") + self.assertEqual( + constraint, + ParameterConstraint(constraint_dict={"x1": 1, "x2": 1}, bound=1.0), + ) + + with_coefficients = parse_parameter_constraint( + constraint_str="2 * x1 + 3 * x2 <= 1" + ) + self.assertEqual( + with_coefficients, + ParameterConstraint(constraint_dict={"x1": 2, "x2": 3}, bound=1.0), + ) + + flipped_sign = parse_parameter_constraint(constraint_str="x1 + x2 >= 1") + self.assertEqual( + flipped_sign, + ParameterConstraint(constraint_dict={"x1": -1, "x2": -1}, bound=-1.0), + ) + + weird = parse_parameter_constraint(constraint_str="x1 + x2 <= 1.5 * x3 + 2") + self.assertEqual( + weird, + ParameterConstraint( + constraint_dict={"x1": 1, "x2": 1, "x3": -1.5}, bound=2.0 + ), + ) + + with self.assertRaisesRegex(UserInputError, "Only linear"): + parse_parameter_constraint(constraint_str="x1 * x2 <= 1") + + def test_parse_objective(self) -> None: + single_objective = parse_objective(objective_str="ne") + self.assertEqual( + single_objective, Objective(metric=Metric(name="ne"), minimize=False) + ) + + maximize_single_objective = parse_objective(objective_str="-qps") + self.assertEqual( + maximize_single_objective, + Objective(metric=Metric(name="qps"), minimize=True), + ) + + scalarized_objective = parse_objective( + objective_str="0.5 * ne1 + 0.3 * ne2 + 
0.2 * ne3" + ) + self.assertEqual( + scalarized_objective, + ScalarizedObjective( + metrics=[Metric(name="ne1"), Metric(name="ne2"), Metric(name="ne3")], + weights=[0.5, 0.3, 0.2], + minimize=False, + ), + ) + + multiobjective = parse_objective(objective_str="ne, -qps") + self.assertEqual( + multiobjective, + MultiObjective( + objectives=[ + Objective(metric=Metric(name="ne"), minimize=False), + Objective(metric=Metric(name="qps"), minimize=True), + ] + ), + ) + + with self.assertRaisesRegex(UserInputError, "Only linear"): + parse_objective(objective_str="ne * qps") + + def test_parse_outcome_constraint(self) -> None: + constraint = parse_outcome_constraint(constraint_str="flops <= 1000000") + self.assertEqual( + constraint, + OutcomeConstraint( + metric=Metric(name="flops"), + op=ComparisonOp.LEQ, + bound=1000000.0, + relative=False, + ), + ) + + flipped_sign = parse_outcome_constraint(constraint_str="flops >= 1000000.0") + self.assertEqual( + flipped_sign, + OutcomeConstraint( + metric=Metric(name="flops"), + op=ComparisonOp.GEQ, + bound=1000000.0, + relative=False, + ), + ) + + relative = parse_outcome_constraint(constraint_str="flops <= 105 * baseline") + self.assertEqual( + relative, + OutcomeConstraint( + metric=Metric(name="flops"), + op=ComparisonOp.LEQ, + bound=105.0, + relative=True, + ), + ) + + scalarized = parse_outcome_constraint( + constraint_str="0.5 * flops1 + 0.3 * flops2 <= 1000000" + ) + self.assertEqual( + scalarized, + ScalarizedOutcomeConstraint( + metrics=[Metric(name="flops1"), Metric(name="flops2")], + weights=[0.5, 0.3], + op=ComparisonOp.LEQ, + bound=1000000.0, + relative=False, + ), + ) + + with self.assertRaisesRegex(UserInputError, "Expected an inequality"): + parse_outcome_constraint(constraint_str="flops == 1000000") + + with self.assertRaisesRegex(UserInputError, "Only linear"): + parse_outcome_constraint(constraint_str="flops * flops <= 1000000") diff --git a/setup.py b/setup.py index d3251af1f7c..1be63417b4a 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ # Needed for compatibility with ipywidgets >= 8.0.0 "plotly>=5.12.0", "pyre-extensions", + "sympy", ] # pytest-cov requires pytest >= 3.6 diff --git a/sphinx/source/preview.rst b/sphinx/source/preview.rst index 65e2e4bacee..9b714d625da 100644 --- a/sphinx/source/preview.rst +++ b/sphinx/source/preview.rst @@ -61,3 +61,19 @@ Types :members: :undoc-members: :show-inheritance: + +From Config +~~~~~~~~~~~ + +.. automodule:: ax.preview.api.utils.instantiation.from_config + :members: + :undoc-members: + :show-inheritance: + +From String +~~~~~~~~~~~ + +.. automodule:: ax.preview.api.utils.instantiation.from_string + :members: + :undoc-members: + :show-inheritance: From 34004623ee21975179930d1e1a12691a48e2475d Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 3 Dec 2024 13:36:48 -0800 Subject: [PATCH 2/3] Implement set_ methods (#3101) Summary: These methods are not strictly speaking "part of the API" but may be useful for developers and trusted partners. Each is fairly self explanatory. 
Reviewed By: lena-kashtelyan Differential Revision: D66304352 --- ax/preview/api/client.py | 59 +++++++++++++++++++++++------ ax/preview/api/tests/test_client.py | 48 +++++++++++++++++++++++ 2 files changed, 95 insertions(+), 12 deletions(-) diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index adc8246dc19..65f9bbf0a60 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -16,7 +16,6 @@ from ax.core.metric import Metric from ax.core.optimization_config import OptimizationConfig from ax.core.runner import Runner -from ax.core.search_space import SearchSpace from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.exceptions.core import UnsupportedError from ax.modelbridge.dispatch_utils import choose_generation_strategy @@ -41,6 +40,7 @@ class Client: _experiment: Experiment | None = None _generation_strategy: GenerationStrategyInterface | None = None + _early_stopping_strategy: BaseEarlyStoppingStrategy | None = None def __init__( self, @@ -181,51 +181,82 @@ def configure_metrics(self, metrics: Sequence[IMetric]) -> None: # -------------------- Section 1.2: Set (not API) ------------------------------- def set_experiment(self, experiment: Experiment) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing Experiment with the provided Experiment. Saves to database on completion if db_config is present. """ - ... - - def set_search_space(self, search_space: SearchSpace) -> None: - """ - Overwrite the existing SearchSpace with the provided SearchSpace. + self._experiment = experiment - Saves to database on completion if db_config is present. - """ - ... + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_optimization_config(self, optimization_config: OptimizationConfig) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing OptimizationConfig with the provided OptimizationConfig. Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().optimization_config = optimization_config + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_generation_strategy( self, generation_strategy: GenerationStrategyInterface ) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing GenerationStrategy with the provided GenerationStrategy. Saves to database on completion if db_config is present. """ - ... + self._generation_strategy = generation_strategy + none_throws( + self._generation_strategy + )._experiment = self._none_throws_experiment() + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_early_stopping_strategy( self, early_stopping_strategy: BaseEarlyStoppingStrategy ) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Overwrite the existing EarlyStoppingStrategy with the provided EarlyStoppingStrategy. Saves to database on completion if db_config is present. """ - ... 
+ self._early_stopping_strategy = early_stopping_strategy + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... def set_runner(self, runner: Runner) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers and power + users. + Attaches a Runner to the Experiment. Saves to database on completion if db_config is present. @@ -234,6 +265,10 @@ def set_runner(self, runner: Runner) -> None: def set_metrics(self, metrics: Sequence[Metric]) -> None: """ + This method is not part of the API and is provided (without guarantees of + method signature stability) for the convenience of some developers, power + users, and partners. + Finds equivallently named Metric that already exists on the Experiment and replaces it with the Metric provided, or adds the Metric provided to the Experiment as tracking metrics. diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 09f0e9e44c5..4f25db7d069 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -26,6 +26,12 @@ RangeParameterConfig, ) from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_branin_optimization_config, + get_percentile_early_stopping_strategy, +) +from ax.utils.testing.modeling_stubs import get_generation_strategy from pyre_extensions import none_throws @@ -154,3 +160,45 @@ def test_configure_optimization(self) -> None: objective="ne", outcome_constraints=["qps >= 0"], ) + + def test_set_experiment(self) -> None: + client = Client() + experiment = get_branin_experiment() + + client.set_experiment(experiment=experiment) + + self.assertEqual(client._experiment, experiment) + + def test_set_optimization_config(self) -> None: + client = Client() + optimization_config = get_branin_optimization_config() + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + client.set_optimization_config(optimization_config=optimization_config) + + client.set_experiment(experiment=get_branin_experiment()) + client.set_optimization_config( + optimization_config=optimization_config, + ) + + self.assertEqual( + none_throws(client._experiment).optimization_config, optimization_config + ) + + def test_set_generation_strategy(self) -> None: + client = Client() + client.set_experiment(experiment=get_branin_experiment()) + + generation_strategy = get_generation_strategy() + + client.set_generation_strategy(generation_strategy=generation_strategy) + self.assertEqual(client._generation_strategy, generation_strategy) + + def test_set_early_stopping_strategy(self) -> None: + client = Client() + early_stopping_strategy = get_percentile_early_stopping_strategy() + + client.set_early_stopping_strategy( + early_stopping_strategy=early_stopping_strategy + ) + self.assertEqual(client._early_stopping_strategy, early_stopping_strategy) From ba6578a25c9bffcabc8f9fee5aff9895e072779f Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Tue, 3 Dec 2024 13:36:48 -0800 Subject: [PATCH 3/3] Implement configure_runner, configure_metric (#3104) Summary: configure_runner and configure_metric allow users to attach custom Runners and Metrics to their experiment. configure_runner is fairly straightforward and just sets experiment.runner configure_metric is more complicated: given a list of IMetrics it iterates through and tries to find a metric with the same name somewhere on the experiment. 
In order it checks the Objective (single, MOO, or secularized), outcome constraints, then tracking metrics. If no metric with a matching name is found then the provided metric is added as a tracking metric. Reviewed By: lena-kashtelyan Differential Revision: D66305614 --- ax/preview/api/client.py | 111 ++++++++++++++++++--- ax/preview/api/tests/test_client.py | 148 +++++++++++++++++++++++++++- 2 files changed, 244 insertions(+), 15 deletions(-) diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 65f9bbf0a60..8ef45379b5c 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - +from logging import Logger from typing import Sequence from ax.analysis.analysis import Analysis, AnalysisCard # Used as a return type @@ -14,6 +14,7 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.metric import Metric +from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import OptimizationConfig from ax.core.runner import Runner from ax.early_stopping.strategies import BaseEarlyStoppingStrategy @@ -33,9 +34,12 @@ from ax.preview.api.utils.instantiation.from_string import ( optimization_config_from_string, ) +from ax.utils.common.logger import get_logger from pyre_extensions import none_throws from typing_extensions import Self +logger: Logger = get_logger(__name__) + class Client: _experiment: Experiment | None = None @@ -168,15 +172,16 @@ def configure_runner(self, runner: IRunner) -> None: Saves to database on completion if db_config is present. """ - ... + self._set_runner(runner=runner) def configure_metrics(self, metrics: Sequence[IMetric]) -> None: """ - Finds equivallently named Metric that already exists on the Experiment and - replaces it with the Metric provided, or adds the Metric provided to the - Experiment as tracking metrics. + Attach a class with logic for autmating fetching of a given metric by + replacing its instance with the provided Metric from metrics sequence input, + or adds the Metric provided to the Experiment as a tracking metric if that + metric was not already present. """ - ... + self._set_metrics(metrics=metrics) # -------------------- Section 1.2: Set (not API) ------------------------------- def set_experiment(self, experiment: Experiment) -> None: @@ -228,6 +233,10 @@ def set_generation_strategy( self._generation_strategy )._experiment = self._none_throws_experiment() + none_throws( + self._generation_strategy + )._experiment = self._none_throws_experiment() + if self.db_config is not None: # TODO[mpolson64] Save to database ... @@ -251,7 +260,7 @@ def set_early_stopping_strategy( # TODO[mpolson64] Save to database ... - def set_runner(self, runner: Runner) -> None: + def _set_runner(self, runner: Runner) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers and power @@ -261,19 +270,33 @@ def set_runner(self, runner: Runner) -> None: Saves to database on completion if db_config is present. """ - ... + self._none_throws_experiment().runner = runner + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... 
- def set_metrics(self, metrics: Sequence[Metric]) -> None: + def _set_metrics(self, metrics: Sequence[Metric]) -> None: """ This method is not part of the API and is provided (without guarantees of method signature stability) for the convenience of some developers, power users, and partners. - Finds equivallently named Metric that already exists on the Experiment and - replaces it with the Metric provided, or adds the Metric provided to the - Experiment as tracking metrics. + Attach a class with logic for autmating fetching of a given metric by + replacing its instance with the provided Metric from metrics sequence input, + or adds the Metric provided to the Experiment as a tracking metric if that + metric was not already present. """ - ... + # If an equivalently named Metric already exists on the Experiment, replace it + # with the Metric provided. Otherwise, add the Metric to the Experiment as a + # tracking metric. + for metric in metrics: + # Check the optimization config first + self._overwrite_metric(metric=metric) + + if self.db_config is not None: + # TODO[mpolson64] Save to database + ... # -------------------- Section 2. Conduct Experiment ---------------------------- def get_next_trials( @@ -516,3 +539,65 @@ def _none_throws_experiment(self) -> Experiment: "experiment before utilizing any other methods on the Client." ), ) + + def _overwrite_metric(self, metric: Metric) -> None: + """ + Overwrite an existing Metric on the Experiment with the provided Metric if they + share the same name. If not Metric with the same name exists, add the Metric as + a tracking metric. + + Note that this method does not save the Experiment to the database (this is + handled in self._set_metrics). + """ + + # Check the OptimizationConfig first + if ( + optimization_config := self._none_throws_experiment().optimization_config + ) is not None: + # Check the objective + if isinstance( + multi_objective := optimization_config.objective, MultiObjective + ): + for i in range(len(multi_objective.objectives)): + if metric.name == multi_objective.objectives[i].metric.name: + multi_objective._objectives[i]._metric = metric + return + + if isinstance( + scalarized_objective := optimization_config.objective, + ScalarizedObjective, + ): + for i in range(len(scalarized_objective.metrics)): + if metric.name == scalarized_objective.metrics[i].name: + scalarized_objective._metrics[i] = metric + return + + if ( + isinstance(optimization_config.objective, Objective) + and metric.name == optimization_config.objective.metric.name + ): + optimization_config.objective._metric = metric + return + + # Check the outcome constraints + for i in range(len(optimization_config.outcome_constraints)): + if ( + metric.name + == optimization_config.outcome_constraints[i].metric.name + ): + optimization_config._outcome_constraints[i]._metric = metric + return + + # Check the tracking metrics + tracking_metric_names = self._none_throws_experiment()._tracking_metrics.keys() + if metric.name in tracking_metric_names: + self._none_throws_experiment()._tracking_metrics[metric.name] = metric + return + + # If an equivalently named Metric does not exist, add it as a tracking + # metric. + self._none_throws_experiment().add_tracking_metric(metric=metric) + logger.warning( + f"Metric {metric} not found in optimization config, added as tracking " + "metric." 
+ ) diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 4f25db7d069..ddb5e42399a 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -5,9 +5,13 @@ # pyre-strict +from typing import Any, Mapping + +from ax.core.base_trial import TrialStatus + from ax.core.experiment import Experiment from ax.core.metric import Metric -from ax.core.objective import Objective +from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint from ax.core.parameter import ( @@ -25,6 +29,9 @@ ParameterType, RangeParameterConfig, ) +from ax.preview.api.protocols.metric import IMetric +from ax.preview.api.protocols.runner import IRunner +from ax.preview.api.types import TParameterization from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, @@ -32,7 +39,7 @@ get_percentile_early_stopping_strategy, ) from ax.utils.testing.modeling_stubs import get_generation_strategy -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws, override class TestClient(TestCase): @@ -161,6 +168,118 @@ def test_configure_optimization(self) -> None: outcome_constraints=["qps >= 0"], ) + def test_configure_runner(self) -> None: + client = Client() + runner = DummyRunner() + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + client.configure_runner(runner=runner) + + client.set_experiment(experiment=get_branin_experiment()) + client.configure_runner(runner=runner) + + self.assertEqual(none_throws(client._experiment).runner, runner) + + def test_configure_metric(self) -> None: + client = Client() + custom_metric = DummyMetric(name="custom") + + with self.assertRaisesRegex(AssertionError, "Experiment not set"): + client.configure_metrics(metrics=[custom_metric]) + + client.configure_experiment( + experiment_config=ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1) + ) + ], + name="foo", + ) + ) + + # Test replacing a single objective + client.configure_optimization(objective="custom") + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws( + none_throws(client._experiment).optimization_config + ).objective.metric, + ) + + # Test replacing a multi-objective + client.configure_optimization(objective="custom, foo") + client.configure_metrics(metrics=[custom_metric]) + + self.assertIn( + custom_metric, + assert_is_instance( + none_throws( + none_throws(client._experiment).optimization_config + ).objective, + MultiObjective, + ).metrics, + ) + # Test replacing a scalarized objective + client.configure_optimization(objective="custom + foo") + client.configure_metrics(metrics=[custom_metric]) + + self.assertIn( + custom_metric, + assert_is_instance( + none_throws( + none_throws(client._experiment).optimization_config + ).objective, + ScalarizedObjective, + ).metrics, + ) + + # Test replacing an outcome constraint + client.configure_optimization( + objective="foo", outcome_constraints=["custom >= 0"] + ) + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws(none_throws(client._experiment).optimization_config) + .outcome_constraints[0] + .metric, + ) + + # Test replacing a tracking metric + client.configure_optimization( + objective="foo", + 
) + none_throws(client._experiment).add_tracking_metric(metric=Metric("custom")) + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws(client._experiment).tracking_metrics[0], + ) + + # Test adding a tracking metric + client = Client() # Start a fresh Client + client.configure_experiment( + experiment_config=ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(0, 1) + ) + ], + name="foo", + ) + ) + client.configure_metrics(metrics=[custom_metric]) + + self.assertEqual( + custom_metric, + none_throws(client._experiment).tracking_metrics[0], + ) + def test_set_experiment(self) -> None: client = Client() experiment = get_branin_experiment() @@ -202,3 +321,28 @@ def test_set_early_stopping_strategy(self) -> None: early_stopping_strategy=early_stopping_strategy ) self.assertEqual(client._early_stopping_strategy, early_stopping_strategy) + + +class DummyRunner(IRunner): + @override + def run_trial( + self, trial_index: int, parameterization: TParameterization + ) -> dict[str, Any]: ... + + @override + def poll_trial( + self, trial_index: int, trial_metadata: Mapping[str, Any] + ) -> TrialStatus: ... + + @override + def stop_trial( + self, trial_index: int, trial_metadata: Mapping[str, Any] + ) -> dict[str, Any]: ... + + +class DummyMetric(IMetric): + def fetch( + self, + trial_index: int, + trial_metadata: Mapping[str, Any], + ) -> tuple[int, float | tuple[float, float]]: ...
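
For orientation, here is a minimal sketch of the configure flow these three patches add to `Client`, using only classes and method signatures that appear in the diffs above. The metric names and parameter values are illustrative only, and the `GenerationStrategyConfig` keyword arguments (`num_trials`, `num_initialization_trials`, `maximum_parallelism`, which the diff reads as attributes in `configure_generation_strategy`) plus its import path from `ax.preview.api.configs` are assumptions rather than something shown verbatim in this series.

```python
from ax.preview.api.client import Client
from ax.preview.api.configs import (
    ChoiceParameterConfig,
    ExperimentConfig,
    GenerationStrategyConfig,
    ParameterType,
    RangeParameterConfig,
)

client = Client()

# 1. Define the search space and experiment metadata.
client.configure_experiment(
    experiment_config=ExperimentConfig(
        name="example_experiment",
        parameters=[
            RangeParameterConfig(
                name="learning_rate",
                parameter_type=ParameterType.FLOAT,
                bounds=(1e-4, 1e-1),
            ),
            ChoiceParameterConfig(
                name="optimizer",
                parameter_type=ParameterType.STRING,
                values=["adam", "sgd"],
            ),
        ],
        description="illustrative example",
        owner="example_owner",
    )
)

# 2. Objectives and outcome constraints are plain strings parsed via SymPy.
#    "-ne, qps" minimizes ne and maximizes qps; "qps >= 1000" is converted to an
#    objective threshold because qps is an objective metric, and the "baseline"
#    term marks the latency constraint as relative to the baseline arm.
client.configure_optimization(
    objective="-ne, qps",
    outcome_constraints=["qps >= 1000", "latency <= 1.05 * baseline"],
)

# 3. Build a generation strategy from the configured search space and objective.
client.configure_generation_strategy(
    generation_strategy_config=GenerationStrategyConfig(
        num_trials=20,
        num_initialization_trials=5,
        maximum_parallelism=3,
    )
)
```

Note that `configure_experiment` must be called first: `configure_optimization` and `configure_generation_strategy` both go through `_none_throws_experiment` and raise if no experiment has been set.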