-
Notifications
You must be signed in to change notification settings - Fork 1
/
nested_search_example.py
63 lines (57 loc) · 1.75 KB
/
nested_search_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from ax import RangeParameter, ParameterType
from ax.core import (
SearchSpace,
Experiment,
OptimizationConfig,
Objective,
ObservationFeatures,
)
from ax.runners.synthetic import SyntheticRunner
from ax.modelbridge.registry import Models
from ax.metrics import BraninMetric
# Search space with two continuous inputs: x1 in [-5, 10] and x2 in [0, 15].
# Each (name, low, high) row below becomes one float RangeParameter.
branin_search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name=param_name,
            parameter_type=ParameterType.FLOAT,
            lower=low,
            upper=high,
        )
        for param_name, low, high in (("x1", -5, 10), ("x2", 0, 15))
    ]
)
# Single-objective experiment over branin_search_space: minimize the
# "branin" metric computed from parameters x1 and x2. Trials are executed
# through a SyntheticRunner.
exp = Experiment(
    name="test_branin",
    search_space=branin_search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(
            minimize=True,
            metric=BraninMetric(param_names=["x1", "x2"], name="branin"),
        )
    ),
    runner=SyntheticRunner(),
)
# Initialization phase: evaluate five Sobol-generated candidates, one per
# trial, so the GPEI model fitted in the next phase has data to work with.
sobol = Models.SOBOL(search_space=exp.search_space)
for _ in range(5):
    sobol_run = sobol.gen(n=1)
    trial = exp.new_trial(generator_run=sobol_run)
    trial.run()
    trial.mark_completed()
# Bayesian-optimization phase: each iteration refits a GPEI model on all data
# fetched so far, generates one candidate, and runs/completes it as a trial.
# best_arm ends up holding the model-predicted best arm from the most recent
# iteration that produced a prediction (None if no iteration did).
best_arm = None
for _ in range(15):
    gpei = Models.GPEI(experiment=exp, data=exp.fetch_data())
    generator_run = gpei.gen(1)
    # Ax types best_arm_predictions as Optional; unpacking None would raise
    # TypeError, so keep the previous best when no prediction is available.
    predictions = generator_run.best_arm_predictions
    if predictions is not None:
        best_arm, _ = predictions
    trial = exp.new_trial(generator_run=generator_run)
    trial.run()
    trial.mark_completed()
# Nested search: pin x2 to the value from the GPEI phase's predicted best arm
# and keep optimizing, so new candidates only vary x1. The final predicted
# best arm's parameters are stored in best_parameters.
if best_arm is None:
    raise RuntimeError("GPEI phase produced no best-arm prediction to fix x2 from.")
fixed_features = ObservationFeatures(parameters={"x2": best_arm.parameters["x2"]})

best_arm2 = None
for _ in range(15):
    gpei = Models.GPEI(experiment=exp, data=exp.fetch_data())
    generator_run = gpei.gen(
        1, search_space=branin_search_space, fixed_features=fixed_features,
    )
    # best_arm_predictions is Optional in Ax; keep the previous best rather
    # than failing to unpack None.
    predictions = generator_run.best_arm_predictions
    if predictions is not None:
        best_arm2, _ = predictions
    trial = exp.new_trial(generator_run=generator_run)
    trial.run()
    trial.mark_completed()

if best_arm2 is None:
    raise RuntimeError("Nested search produced no best-arm prediction.")
# Keep the final fetched data available instead of discarding it.
final_data = exp.fetch_data()
best_parameters = best_arm2.parameters