-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #772 from starsimhub/calib-uplift
[DRAFT] Calib uplift
- Loading branch information
Showing
10 changed files
with
992 additions
and
358 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,28 @@ | ||
{ | ||
"summary": { | ||
"timevec": 2020.0, | ||
"births_new": 46.84158415841584, | ||
"births_cumulative": 2267.7425742574255, | ||
"births_cbr": 19.93833106141367, | ||
"deaths_new": 9.673267326732674, | ||
"deaths_cumulative": 468.5742574257426, | ||
"deaths_cmr": 4.118825151847284, | ||
"sir_n_susceptible": 2440.029702970297, | ||
"sir_n_infected": 3630.3960396039606, | ||
"sir_n_recovered": 5676.693069306931, | ||
"sir_prevalence": 0.32402568798505815, | ||
"sir_new_infections": 122.34653465346534, | ||
"sir_cum_infections": 12357.0, | ||
"sis_n_susceptible": 4784.306930693069, | ||
"sis_n_infected": 6973.504950495049, | ||
"sis_prevalence": 0.5720906510271361, | ||
"sis_new_infections": 193.5742574257426, | ||
"sis_cum_infections": 19551.0, | ||
"sis_rel_sus": 0.5019197711850157, | ||
"n_alive": 11747.118811881188, | ||
"new_deaths": 10.693069306930694, | ||
"cum_deaths": 1072.0 | ||
"births_new": 48.257425742574256, | ||
"births_cumulative": 2343.3069306930693, | ||
"births_cbr": 20.43178207644599, | ||
"deaths_new": 9.712871287128714, | ||
"deaths_cumulative": 470.58415841584156, | ||
"deaths_cmr": 4.112394571867341, | ||
"randomnet_n_edges": 58901.33663366337, | ||
"mfnet_n_edges": 4004.732673267327, | ||
"maternalnet_n_edges": 0.0, | ||
"sir_n_susceptible": 2464.970297029703, | ||
"sir_n_infected": 3658.227722772277, | ||
"sir_n_recovered": 5694.504950495049, | ||
"sir_prevalence": 0.32462009407521675, | ||
"sir_new_infections": 122.81188118811882, | ||
"sir_cum_infections": 12404.0, | ||
"sis_n_susceptible": 4828.3267326732675, | ||
"sis_n_infected": 7000.19801980198, | ||
"sis_prevalence": 0.5702209995778549, | ||
"sis_new_infections": 195.01980198019803, | ||
"sis_cum_infections": 19697.0, | ||
"sis_rel_sus": 0.5033450153204474, | ||
"n_alive": 11817.70297029703, | ||
"new_deaths": 10.821782178217822, | ||
"cum_deaths": 1084.0 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
{ | ||
"time": { | ||
"initialize": 0.055, | ||
"run": 1.013 | ||
"initialize": 0.053, | ||
"run": 0.914 | ||
}, | ||
"parameters": { | ||
"n_agents": 10000, | ||
"dur": 20, | ||
"dt": 0.2 | ||
}, | ||
"cpu_performance": 0.9665005580733697 | ||
"cpu_performance": 0.8734404449460742 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
""" | ||
Test calibration | ||
""" | ||
|
||
#%% Imports and settings | ||
import sciris as sc | ||
import starsim as ss | ||
import pandas as pd | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from ax.plot.contour import plot_contour | ||
from ax.plot.trace import optimization_trace_single_method | ||
from ax.service.managed_loop import optimize | ||
from ax.utils.notebook.plotting import init_notebook_plotting, render | ||
|
||
do_plot = 1 | ||
do_save = 0 | ||
n_agents = 2e3 | ||
|
||
#%% Helper functions | ||
|
||
def make_sim(): | ||
sir = ss.SIR( | ||
beta = ss.beta(0.9), | ||
dur_inf = ss.lognorm_ex(mean=ss.dur(6)), | ||
init_prev = ss.bernoulli(0.01), | ||
) | ||
|
||
#deaths = ss.Deaths(death_rate=15) | ||
#births = ss.Births(birth_rate=15) | ||
|
||
random = ss.RandomNet(n_contacts=ss.poisson(4)) | ||
|
||
sim = ss.Sim( | ||
dt = 1, | ||
unit = 'day', | ||
n_agents = n_agents, | ||
#total_pop = 9980999, | ||
start = sc.date('2024-01-01'), | ||
stop = sc.date('2024-01-31'), | ||
diseases = sir, | ||
networks = random, | ||
#demographics = [deaths, births], | ||
) | ||
|
||
return sim | ||
|
||
|
||
def build_sim(sim, calib_pars, **kwargs): | ||
""" Modify the base simulation by applying calib_pars """ | ||
|
||
for k, v in calib_pars.items(): | ||
if k == 'beta': | ||
sim.diseases.sir.pars['beta'] = ss.beta(v) | ||
elif k == 'dur_inf': | ||
sim.diseases.sir.pars['dur_inf'] = ss.lognorm_ex(mean=ss.dur(v)), #ss.dur(v) | ||
elif k == 'n_contacts': | ||
sim.networks.randomnet.pars.n_contacts = v # Typically a Poisson distribution, but this should set the distribution parameter value appropriately | ||
else: | ||
sim.pars[k] = v # Assume sim pars | ||
|
||
return sim | ||
|
||
def eval_sim(pars): | ||
sim = make_sim() | ||
sim.init() | ||
sim = build_sim(sim, pars) | ||
sim.run() | ||
#print('pars:', pars, ' --> Final prevalence:', sim.results.sir.prevalence[-1]) | ||
fig = sim.plot() | ||
fig.suptitle(pars) | ||
fig.subplots_adjust(top=0.9) | ||
plt.show() | ||
|
||
return dict( | ||
prevalence_error = ((sim.results.sir.prevalence[-1] - 0.10)**2, None), | ||
prevalence = (sim.results.sir.prevalence[-1], None), | ||
) | ||
|
||
|
||
#%% Define the tests | ||
def test_calibration(do_plot=False): | ||
sc.heading('Testing calibration') | ||
|
||
# Define the calibration parameters | ||
calib_pars = [ | ||
dict(name='beta', type='range', bounds=[0.01, 1.0], value_type='float', log_scale=True), | ||
dict(name='dur_inf', type='range', bounds=[1, 60], value_type='float', log_scale=False), | ||
#dict(name='init_prev', type='range', bounds=[0.01, 0.30], value_type='float', log_scale=False), | ||
dict(name='n_contacts', type='range', bounds=[2, 10], value_type='int', log_scale=False), | ||
] | ||
|
||
best_pars, values, exp, model = optimize( | ||
experiment_name = 'starsim', | ||
parameters = calib_pars, | ||
evaluation_function = eval_sim, | ||
objective_name = 'prevalence_error', | ||
minimize = True, | ||
parameter_constraints = None, | ||
outcome_constraints = None, | ||
total_trials = 10, | ||
arms_per_trial = 3, | ||
) | ||
|
||
return best_pars, values, exp, model | ||
|
||
|
||
#%% Run as a script | ||
if __name__ == '__main__': | ||
|
||
T = sc.timer() | ||
do_plot = True | ||
|
||
best_pars, values, exp, model = test_calibration(do_plot=do_plot) | ||
|
||
print('best_pars:', best_pars) | ||
print('values:', values) | ||
print('exp:', exp) | ||
print('model:', model) | ||
|
||
render(plot_contour(model=model, param_x='beta', param_y='init_prev', metric_name='prevalence')) | ||
|
||
# `plot_single_method` expects a 2-d array of means, because it expects to average means from multiple | ||
# optimization runs, so we wrap out best objectives array in another array. | ||
|
||
for trial in exp.trials.values(): | ||
print(trial) | ||
print(dir(trial)) | ||
print(f"Trial {trial.index} with parameters {trial.arm.parameters} " | ||
f"has objective {trial.objective_mean}.") | ||
|
||
best_objectives = np.array( | ||
[[trial.objective_mean for trial in exp.trials.values()]] | ||
) | ||
best_objective_plot = optimization_trace_single_method( | ||
y = np.minimum.accumulate(best_objectives, axis=1), | ||
optimum = 0.10, #hartmann6.fmin, | ||
title = "Model performance vs. # of iterations", | ||
ylabel = "Prevalence", | ||
) | ||
render(best_objective_plot) | ||
|
||
plt.show() | ||
|
||
T.toc() |
Oops, something went wrong.