From 9b06f6c9dcf948585f106a3f071d3f8392cf5b00 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 09:39:58 +0000 Subject: [PATCH 1/9] new parameter interface in stan code --- .../stan/data/estimate_infections_params.stan | 4 + inst/stan/data/estimate_secondary_params.stan | 2 + inst/stan/data/gaussian_process.stan | 2 - inst/stan/data/observation_model.stan | 4 - inst/stan/data/params.stan | 13 ++ inst/stan/data/rt.stan | 2 - .../data/simulation_observation_model.stan | 2 - inst/stan/estimate_infections.stan | 62 +++++---- inst/stan/estimate_secondary.stan | 46 +++++-- inst/stan/functions/gaussian_process.stan | 7 +- inst/stan/functions/observation_model.stan | 23 ++-- inst/stan/functions/params.stan | 53 ++++++++ inst/stan/functions/rt.stan | 21 ++- inst/stan/simulate_infections.stan | 121 ++++++++++-------- inst/stan/simulate_secondary.stan | 102 ++++++++------- 15 files changed, 285 insertions(+), 179 deletions(-) create mode 100644 inst/stan/data/estimate_infections_params.stan create mode 100644 inst/stan/data/estimate_secondary_params.stan create mode 100644 inst/stan/data/params.stan create mode 100644 inst/stan/functions/params.stan diff --git a/inst/stan/data/estimate_infections_params.stan b/inst/stan/data/estimate_infections_params.stan new file mode 100644 index 000000000..3351f5ea3 --- /dev/null +++ b/inst/stan/data/estimate_infections_params.stan @@ -0,0 +1,4 @@ +int alpha_id; // parameter id of alpha (GP magnitude) +int R0_id; // parameter id of R0 +int frac_obs_id; // parameter id of frac_obs +int rep_phi_id; // parameter id of rep_phi diff --git a/inst/stan/data/estimate_secondary_params.stan b/inst/stan/data/estimate_secondary_params.stan new file mode 100644 index 000000000..736ce31df --- /dev/null +++ b/inst/stan/data/estimate_secondary_params.stan @@ -0,0 +1,2 @@ +int frac_obs_id; // parameter id of frac_obs +int rep_phi_id; // parameter id of rep_phi diff --git a/inst/stan/data/gaussian_process.stan 
b/inst/stan/data/gaussian_process.stan index 8154ffdfe..7990dba8a 100644 --- a/inst/stan/data/gaussian_process.stan +++ b/inst/stan/data/gaussian_process.stan @@ -4,8 +4,6 @@ real ls_sdlog; // sdlog for gp lengthscale prior real ls_min; // Lower bound for the lengthscale real ls_max; // Upper bound for the lengthscale - real alpha_mean; // mean of the alpha gp kernal parameter - real alpha_sd; // standard deviation of the alpha gp kernal parameter int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern real nu; // smoothness parameter for Matern kernel (used if gp_type = 2) real w0; // fundamental frequency for periodic kernel (used if gp_type = 1) diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 671004ef4..ea0780998 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -1,11 +1,7 @@ array[t - seeding_time] int day_of_week; // day of the week indicator (1 - 7) int model_type; // type of model: 0 = poisson otherwise negative binomial - real phi_mean; // Mean and sd of the normal prior for the - real phi_sd; // reporting process int week_effect; // length of week effect int obs_scale; // logical controlling scaling of observations - real obs_scale_mean; // mean scaling factor for observations - real obs_scale_sd; // standard deviation of observation scaling real obs_weight; // weight given to observation in log density int likelihood; // Should the likelihood be included in the model int return_likelihood; // Should the likehood be returned by the model diff --git a/inst/stan/data/params.stan b/inst/stan/data/params.stan new file mode 100644 index 000000000..5ac81a1c4 --- /dev/null +++ b/inst/stan/data/params.stan @@ -0,0 +1,13 @@ +int n_params_variable; // number of variable (estimated) parameters +int n_params_fixed; // number of fixed parameters +vector[n_params_variable] params_lower; // lower bounds of the priors +vector[n_params_variable] params_upper; // upper bounds of the 
priors + +array[n_params_fixed + n_params_variable] int params_fixed_lookup; // fixed parameter lookup +array[n_params_fixed + n_params_variable] int params_variable_lookup; // variable parameter lookup + +vector[n_params_fixed] params_value; // fixed parameter values + +array[n_params_variable] int prior_dist; // 0 = lognormal; 1 = gamma; 2 = normal +int prior_dist_params_length; // total number of parameters across all prior distributions +vector[prior_dist_params_length] prior_dist_params; diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index 11b1989ae..b736f1ade 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -1,8 +1,6 @@ int estimate_r; // should the reproduction no be estimated (1 = yes) real prior_infections; // prior for initial infections real prior_growth; // prior on initial growth rate - real r_mean; // prior mean of reproduction number - real r_sd; // prior standard deviation of reproduction number int bp_n; // no of breakpoints (0 = no breakpoints) array[t - seeding_time] int breakpoints; // when do breakpoints occur int future_fixed; // is underlying future Rt assumed to be fixed diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index c8cab6b35..2b83cab66 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -2,7 +2,5 @@ int week_effect; // should a day of the week effect be estimated array[n, week_effect] real day_of_week_simplex; int obs_scale; - array[n, obs_scale] real frac_obs; int model_type; - array[n, model_type] real rep_phi; // overdispersion of the reporting process int trunc_id; // id of truncation diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 303b6b0e4..1914703f5 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -7,6 +7,7 @@ functions { #include functions/infections.stan #include 
functions/observation_model.stan #include functions/generated_quantities.stan +#include functions/params.stan } data { @@ -16,6 +17,8 @@ data { #include data/rt.stan #include data/backcalc.stan #include data/observation_model.stan +#include data/params.stan +#include data/estimate_infections_params.stan } transformed data { @@ -27,9 +30,6 @@ transformed data { ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from ); matrix[noise_terms, gp_type == 1 ? 2*M : M] PHI = setup_gp(M, L, noise_terms, gp_type == 1, w0); // basis function - // Rt - real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); - real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); array[delay_types] int delay_type_max; profile("assign max") { @@ -41,12 +41,11 @@ transformed data { } parameters { + vector[n_params_variable] params; // gaussian process array[fixed ? 0 : 1] real rescaled_rho; // length scale of noise GP - array[fixed ? 0 : 1] real alpha; // scale of noise GP vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise // Rt - vector[estimate_r] log_R; // baseline reproduction number estimate (log) array[estimate_r] real initial_infections; // seed infections array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 1 : 0] real bp_sd; // standard deviation of breakpoint effect @@ -54,8 +53,6 @@ parameters { // observation model vector[delay_params_length] delay_params; // delay parameters simplex[week_effect] day_of_week_simplex; // day of week reporting effect - array[obs_scale_sd > 0 ? 
1 : 0] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process } transformed parameters { @@ -69,8 +66,12 @@ transformed parameters { // GP in noise - spectral densities profile("update gp") { if (!fixed) { + real alpha = get_param( + alpha_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); noise = update_gp( - PHI, M, L, alpha[1], rescaled_rho, eta, gp_type, nu + PHI, M, L, alpha, rescaled_rho, eta, gp_type, nu ); } } @@ -85,9 +86,12 @@ transformed parameters { 1, 1, 0 ); } - profile("R") { + profile("R0") { + real R0 = get_param( + R0_id, params_fixed_lookup, params_variable_lookup, params_value, params + ); R = update_Rt( - ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary + ot_h, R0, noise, breakpoints, bp_effects, stationary ); } profile("infections") { @@ -133,9 +137,11 @@ transformed parameters { // scaling of reported cases by fraction observed if (obs_scale) { profile("scale") { - reports = scale_obs( - reports, obs_scale_sd > 0 ? 
frac_obs[1] : obs_scale_mean + real frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value, + params ); + reports = scale_obs(reports, frac_obs); } } @@ -162,7 +168,7 @@ model { // priors for noise GP if (!fixed) { profile("gp lp") { - gaussian_process_lp(alpha[1], eta, alpha_mean, alpha_sd); + gaussian_process_lp(eta); if (gp_type != 3) { lengthscale_lp(rescaled_rho[1], ls_meanlog, ls_sdlog, ls_min, ls_max); } @@ -177,29 +183,33 @@ model { ); } + // parameter priors + profile("param lp") { + params_lp( + params, prior_dist, prior_dist_params, params_lower, params_upper + ); + } + if (estimate_r) { // priors on Rt profile("rt lp") { rt_lp( - log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n, - seeding_time, r_logmean, r_logsd, prior_infections, prior_growth + initial_infections, initial_growth, bp_effects, bp_sd, bp_n, + seeding_time, prior_infections, prior_growth ); } } - // prior observation scaling - if (obs_scale_sd > 0) { - profile("scale lp") { - frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; - } - } - // observed reports from mean of reports (update likelihood) if (likelihood) { profile("report lp") { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); report_lp( - cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type, - obs_weight, accumulate + cases, cases_time, obs_reports, rep_phi, model_type, obs_weight, + accumulate ); } } @@ -213,6 +223,10 @@ generated quantities { vector[fixed ? 
0 : 1] rho; profile("generated quantities") { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); if (!fixed && gp_type != 3) { vector[noise_terms] x = linspaced_vector(noise_terms, 1, noise_terms); rho[1] = rescaled_rho[1] * sd(x); diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 8f5081fb0..057617f77 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -4,6 +4,7 @@ functions { #include functions/delays.stan #include functions/observation_model.stan #include functions/secondary.stan +#include functions/params.stan } data { @@ -16,6 +17,8 @@ data { #include data/secondary.stan #include data/delays.stan #include data/observation_model.stan +#include data/params.stan +#include data/estimate_secondary_params.stan } transformed data{ @@ -29,8 +32,7 @@ parameters{ // observation model vector[delay_params_length] delay_params; simplex[week_effect] day_of_week_simplex; // day of week reporting effect - array[obs_scale] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process + vector[n_params_variable] params; } transformed parameters { @@ -43,7 +45,11 @@ transformed parameters { // scaling of primary reports by fraction observed if (obs_scale) { - scaled = scale_obs(primary, obs_scale_sd > 0 ? 
frac_obs[1] : obs_scale_mean); + real frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); + scaled = scale_obs(primary, frac_obs); } else { scaled = primary; } @@ -89,15 +95,21 @@ model { delay_dist, delay_weight ); - // prior primary report scaling - if (obs_scale) { - frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; - } + // parameter priors + profile("param lp") { + params_lp( + params, prior_dist, prior_dist_params, params_lower, params_upper + ); + } // observed secondary reports from mean of secondary reports (update likelihood) if (likelihood) { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); report_lp( obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t], - rep_phi, phi_mean, phi_sd, model_type, 1, accumulate + rep_phi, model_type, 1, accumulate ); } } @@ -105,11 +117,17 @@ model { generated quantities { array[t - burn_in] int sim_secondary; vector[return_likelihood > 1 ? 
t - burn_in : 0] log_lik; - // simulate secondary reports - sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); - // log likelihood of model - if (return_likelihood) { - log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, model_type, obs_weight); + { + real rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, params_value, + params + ); + // simulate secondary reports + sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); + // log likelihood of model + if (return_likelihood) { + log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], + rep_phi, model_type, obs_weight); + } } } diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index ab5ee5eb7..e35906f02 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -208,13 +208,8 @@ void lengthscale_lp(real rho, real ls_meanlog, real ls_sdlog, /** * Priors for Gaussian process (excluding length scale) * - * @param alpha Scaling parameter * @param eta Vector of noise terms - * @param alpha_mean Mean of alpha - * @param alpha_sd Standard deviation of alpha */ -void gaussian_process_lp(real alpha, vector eta, real alpha_mean, - real alpha_sd) { - alpha ~ normal(alpha_mean, alpha_sd) T[0,]; +void gaussian_process_lp(vector eta) { eta ~ std_normal(); } diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index aa10ccfda..b34983771 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -97,16 +97,14 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, * @param cases Array of integer observed cases. * @param cases_time Array of integer time indices for observed cases. * @param reports Vector of expected reports. - * @param rep_phi Array of real values for reporting overdispersion. 
- * @param phi_mean Real value for mean of reporting overdispersion prior. - * @param phi_sd Real value for standard deviation of reporting overdispersion prior. + * @param rep_phi Real value for reporting overdispersion. * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). * @param weight Real value for weighting the log density contribution. * @param accumulate Integer flag indicating whether to accumulate reports (1) or not (0). */ void report_lp(array[] int cases, array[] int cases_time, vector reports, - array[] real rep_phi, real phi_mean, real phi_sd, - int model_type, real weight, int accumulate) { + real rep_phi, int model_type, real weight, + int accumulate) { int n = num_elements(cases_time) - accumulate; // number of observations vector[n] obs_reports; // reports at observation time array[n] int obs_cases; // observed cases at observation time @@ -130,10 +128,7 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports, obs_cases = cases; } if (model_type) { - real dispersion = inv_square(phi_sd > 0 ? rep_phi[model_type] : phi_mean); - if (phi_sd > 0) { - rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,]; - } + real dispersion = inv_square(rep_phi); if (weight == 1) { obs_cases ~ neg_binomial_2(obs_reports, dispersion); } else { @@ -164,7 +159,7 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports, * @return A vector of log likelihoods for each time point. 
*/ vector report_log_lik(array[] int cases, vector reports, - array[] real rep_phi, int model_type, real weight) { + real rep_phi, int model_type, real weight) { int t = num_elements(reports); vector[t] log_lik; @@ -174,7 +169,7 @@ vector report_log_lik(array[] int cases, vector reports, log_lik[i] = poisson_lpmf(cases[i] | reports[i]) * weight; } } else { - real dispersion = inv_square(rep_phi[model_type]); + real dispersion = inv_square(rep_phi); for (i in 1:t) { log_lik[i] = neg_binomial_2_lpmf(cases[i] | reports[i], dispersion) * weight; } @@ -188,17 +183,17 @@ vector report_log_lik(array[] int cases, vector reports, * This function generates random samples of reported cases based on the specified model type. * * @param reports Vector of expected reports. - * @param rep_phi Array of real values for reporting overdispersion. + * @param rep_phi Real value for reporting overdispersion. * @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial). * * @return An array of integer sampled reports. 
*/ -array[] int report_rng(vector reports, array[] real rep_phi, int model_type) { +array[] int report_rng(vector reports, real rep_phi, int model_type) { int t = num_elements(reports); array[t] int sampled_reports; real dispersion = 1e5; if (model_type) { - dispersion = inv_square(rep_phi[model_type]); + dispersion = inv_square(rep_phi); } for (s in 1:t) { diff --git a/inst/stan/functions/params.stan b/inst/stan/functions/params.stan new file mode 100644 index 000000000..3861106c2 --- /dev/null +++ b/inst/stan/functions/params.stan @@ -0,0 +1,53 @@ +real get_param(int id, + array[] int params_fixed_lookup, + array[] int params_variable_lookup, + vector params_value, vector params) { + if (id == 0) { + return 0; // parameter not used + } else if (params_fixed_lookup[id]) { + return params_value[params_fixed_lookup[id]]; + } else { + return params[params_variable_lookup[id]]; + } +} + +vector get_param(int id, + array[] int params_fixed_lookup, + array[] int params_variable_lookup, + vector params_value, matrix params) { + int n_samples = rows(params); + if (id == 0) { + return rep_vector(0, n_samples) ; // parameter not used + } else if (params_fixed_lookup[id]) { + return rep_vector(params_value[params_fixed_lookup[id]], n_samples); + } else { + return params[, params_variable_lookup[id]]; + } +} + +void params_lp(vector params, array[] int prior_dist, + vector prior_dist_params, vector params_lower, + vector params_upper) { + int params_id = 1; + int num_params = num_elements(params); + for (id in 1:num_params) { + if (prior_dist[id] == 0) { // lognormal + params[id] ~ + lognormal(prior_dist_params[params_id], prior_dist_params[params_id + 1]) + T[params_lower[id], params_upper[id]]; + params_id += 2; + } else if (prior_dist[id] == 1) { + params[id] ~ + gamma(prior_dist_params[params_id], prior_dist_params[params_id + 1]) + T[params_lower[id], params_upper[id]]; + params_id += 2; + } else if (prior_dist[id] == 2) { + params[id] ~ + 
normal(prior_dist_params[params_id], prior_dist_params[params_id + 1]) + T[params_lower[id], params_upper[id]]; + params_id += 2; + } else { + reject("dist must be <= 2"); + } + } +} diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index ad2d877b1..e5ebb30f4 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -4,7 +4,7 @@ * process. * * @param t Length of the time series - * @param log_R Logarithm of the base reproduction number + * @param R0 Initial reproduction number * @param noise Vector of Gaussian process noise values * @param bps Array of breakpoint indices * @param bp_effects Vector of breakpoint effects @@ -12,19 +12,19 @@ * (1) or non-stationary (0) * @return A vector of length t containing the updated Rt values */ -vector update_Rt(int t, real log_R, vector noise, array[] int bps, +vector update_Rt(int t, real R0, vector noise, array[] int bps, vector bp_effects, int stationary) { // define control parameters int bp_n = num_elements(bp_effects); int gp_n = num_elements(noise); // initialise intercept - vector[t] R = rep_vector(log_R, t); + vector[t] logR = rep_vector(log(R0), t); //initialise breakpoints + rw if (bp_n) { vector[bp_n + 1] bp0; bp0[1] = 0; bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects); - R = R + bp0[bps]; + logR = logR + bp0[bps]; } //initialise gaussian process if (gp_n) { @@ -39,32 +39,27 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, gp[2:(gp_n + 1)] = noise; gp = cumulative_sum(gp); } - R = R + gp; + logR = logR + gp; } - return exp(R); + return exp(logR); } /** * Calculate the log-probability of the reproduction number (Rt) priors * - * @param log_R Logarithm of the base reproduction number * @param initial_infections Array of initial infection values * @param initial_growth Array of initial growth rates * @param bp_effects Vector of breakpoint effects * @param bp_sd Array of breakpoint standard deviations * @param bp_n Number of breakpoints * @param seeding_time 
Time point at which seeding occurs - * @param r_logmean Log-mean of the prior distribution for the base reproduction number - * @param r_logsd Log-standard deviation of the prior distribution for the base reproduction number * @param prior_infections Prior mean for initial infections * @param prior_growth Prior mean for initial growth rates */ -void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, +void rt_lp(array[] real initial_infections, array[] real initial_growth, vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time, - real r_logmean, real r_logsd, real prior_infections, - real prior_growth) { - log_R ~ normal(r_logmean, r_logsd); + real prior_infections, real prior_growth) { //breakpoint effects on Rt if (bp_n > 0) { bp_sd[1] ~ normal(0, 0.1) T[0,]; diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 245f80c49..3e8131994 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -7,6 +7,7 @@ functions { #include functions/infections.stan #include functions/observation_model.stan #include functions/generated_quantities.stan +#include functions/params.stan } data { @@ -21,6 +22,10 @@ data { #include data/simulation_delays.stan // observation model #include data/simulation_observation_model.stan + // parameters +#include data/params.stan +#include data/estimate_infections_params.stan + matrix[n, n_params_variable] params; // parameters } transformed data { @@ -36,66 +41,76 @@ generated quantities { matrix[n, t - seeding_time] reports; // observed cases array[n, t - seeding_time] int imputed_reports; matrix[n, t - seeding_time - 1] r; - for (i in 1:n) { - // generate infections from Rt trace - vector[delay_type_max[gt_id] + 1] gt_rev_pmf; - gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, 
delay_dist, - 1, 1, 0 + { + vector[n] rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, + params_value, params ); - - infections[i] = to_row_vector(generate_infections( - to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - initial_growth[i], pop, future_time - )); - - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + vector[n] frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + for (i in 1:n) { + // generate infections from Rt trace + vector[delay_type_max[gt_id] + 1] gt_rev_pmf; + gt_rev_pmf = get_delay_rev_pmf( + gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, delay_types_groups, delay_max, delay_np_pmf, delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 - ); - // convolve from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) + 1, 1, 0 ); - } else { - reports[i] = to_row_vector( - infections[i, (seeding_time + 1):t] - ); - } - // weekly reporting effect - if (week_effect > 1) { - reports[i] = to_row_vector( - day_of_week_effect(to_vector(reports[i]), day_of_week, - to_vector(day_of_week_simplex[i]))); - } - // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 + infections[i] = to_row_vector(generate_infections( + to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], + initial_growth[i], pop, future_time + )); + + if (delay_id) { + vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( + 
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 0 + ); + // convolve from latent infections to mean of observations + reports[i] = to_row_vector(convolve_to_report( + to_vector(infections[i]), delay_rev_pmf, seeding_time) + ); + } else { + reports[i] = to_row_vector( + infections[i, (seeding_time + 1):t] + ); + } + + // weekly reporting effect + if (week_effect > 1) { + reports[i] = to_row_vector( + day_of_week_effect(to_vector(reports[i]), day_of_week, + to_vector(day_of_week_simplex[i]))); + } + // truncate near time cases to observed reports + if (trunc_id) { + vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( + trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 1 + ); + reports[i] = to_row_vector(truncate_obs( + to_vector(reports[i]), trunc_rev_cmf, 0) + ); + } + // scale observations + if (obs_scale) { + reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i])); + } + // simulate reported cases + imputed_reports[i] = report_rng( + to_vector(reports[i]), rep_phi[i], model_type ); - reports[i] = to_row_vector(truncate_obs( - to_vector(reports[i]), trunc_rev_cmf, 0) + r[i] = to_row_vector( + calculate_growth(to_vector(infections[i]), seeding_time + 1) ); } - // scale observations - if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1])); - } - // simulate reported cases - imputed_reports[i] = report_rng( - to_vector(reports[i]), rep_phi[i], model_type - ); - r[i] = to_row_vector( - calculate_growth(to_vector(infections[i]), seeding_time + 1) - ); } } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index 8bd4386f1..ab75ba040 100644 --- 
a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -4,6 +4,7 @@ functions { #include functions/delays.stan #include functions/observation_model.stan #include functions/secondary.stan +#include functions/params.stan } data { @@ -16,10 +17,11 @@ data { array[t - h] int obs; // observed secondary data matrix[n, t] primary; // observed primary data #include data/secondary.stan - // delay from infection to report #include data/simulation_delays.stan - // observation model #include data/simulation_observation_model.stan +#include data/params.stan +#include data/estimate_secondary_params.stan + matrix[n, n_params_variable] params; // parameters } transformed data { @@ -31,56 +33,66 @@ transformed data { generated quantities { array[n, all_dates ? t : h] int sim_secondary; - for (i in 1:n) { - vector[t] secondary; - vector[t] scaled; - vector[t] convolved = rep_vector(1e-5, t); + { + vector[n] rep_phi = get_param( + rep_phi_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + vector[n] frac_obs = get_param( + frac_obs_id, params_fixed_lookup, params_variable_lookup, + params_value, params + ); + for (i in 1:n) { + vector[t] secondary; + vector[t] scaled; + vector[t] convolved = rep_vector(1e-5, t); - if (obs_scale) { - scaled = scale_obs(to_vector(primary[i]), frac_obs[i, 1]); - } else { - scaled = to_vector(primary[i]); - } + if (obs_scale) { + scaled = scale_obs(to_vector(primary[i]), frac_obs[i]); + } else { + scaled = to_vector(primary[i]); + } - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 0 + if (delay_id) { + vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + 
delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 0 + ); + convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); + } else { + convolved = convolved + scaled; + } + + // calculate secondary reports from primary + secondary = calculate_secondary( + scaled, convolved, obs, cumulative, historic, primary_hist_additive, + current, primary_current_additive, t - h + 1 ); - convolved = convolved + convolve_to_report(scaled, delay_rev_pmf, 0); - } else { - convolved = convolved + scaled; - } - // calculate secondary reports from primary - secondary = calculate_secondary( - scaled, convolved, obs, cumulative, historic, primary_hist_additive, - current, primary_current_additive, t - h + 1 - ); + // weekly reporting effect + if (week_effect > 1) { + secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); + } - // weekly reporting effect - if (week_effect > 1) { - secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); - } + // truncate near time cases to observed reports + if (trunc_id) { + vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( + trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, + 0, 1, 1 + ); + secondary = truncate_obs( + secondary, trunc_rev_cmf, 0 + ); + } - // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, - 0, 1, 1 - ); - secondary = truncate_obs( - secondary, trunc_rev_cmf, 0 + // simulate secondary reports + sim_secondary[i] = report_rng( + tail(secondary, 
all_dates ? t : h), rep_phi[i], model_type ); } - - // simulate secondary reports - sim_secondary[i] = report_rng( - tail(secondary, all_dates ? t : h), rep_phi[i], model_type - ); } } From d09f0cb391a6dfeda0e74c9d27f3f4b854a368d4 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 09:40:24 +0000 Subject: [PATCH 2/9] adapt R code to new param interface --- R/create.R | 164 +++++++++++++++++++++++++++++----------- R/dist_spec.R | 51 +++++++++++++ R/estimate_secondary.R | 11 ++- R/extract.R | 13 +--- R/opts.R | 136 ++++++++++++++++++++++----------- R/simulate_infections.R | 27 ++++--- R/simulate_secondary.R | 24 +++--- 7 files changed, 302 insertions(+), 124 deletions(-) diff --git a/R/create.R b/R/create.R index 1309fd508..42231785f 100644 --- a/R/create.R +++ b/R/create.R @@ -315,8 +315,6 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, # map settings to underlying gp stan requirements rt_data <- list( - r_mean = rt$prior$mean, - r_sd = rt$prior$sd, estimate_r = as.numeric(rt$use_rt), bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0), breakpoints = breakpoints, @@ -429,8 +427,6 @@ create_gp_data <- function(gp = gp_opts(), data) { ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd), ls_min = gp$ls_min, ls_max = gp$ls_max, - alpha_mean = gp$alpha_mean, - alpha_sd = gp$alpha_sd, gp_type = data.table::fcase( gp$kernel == "se", 0, gp$kernel == "periodic", 1, @@ -477,13 +473,9 @@ create_gp_data <- function(gp = gp_opts(), data) { create_obs_model <- function(obs = obs_opts(), dates) { data <- list( model_type = as.numeric(obs$family == "negbin"), - phi_mean = obs$phi$mean, - phi_sd = obs$phi$sd, week_effect = ifelse(obs$week_effect, obs$week_length, 1), obs_weight = obs$weight, - obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1), - obs_scale_mean = obs$scale$mean, - obs_scale_sd = obs$scale$sd, + obs_scale = as.integer(obs$scale != Fixed(1)), accumulate = obs$accumulate, likelihood = as.numeric(obs$likelihood), 
return_likelihood = as.numeric(obs$return_likelihood) @@ -584,15 +576,30 @@ create_stan_data <- function(data, seeding_time, ) ) + # parameters + stan_data <- c( + stan_data, + create_stan_params( + alpha = gp$alpha, + R0 = rt$prior, + frac_obs = obs$scale, + rep_phi = obs$phi, + lower_bounds = c( + alpha = 0, + R0 = 0, + frac_obs = 0, + rep_phi = 0 + ) + ) + ) + # rescale mean shifted prior for back calculation if observation scaling is # used - if (stan_data$obs_scale == 1) { - stan_data$shifted_cases <- - stan_data$shifted_cases / stan_data$obs_scale_mean - stan_data$prior_infections <- log( - exp(stan_data$prior_infections) / stan_data$obs_scale_mean - ) - } + stan_data$shifted_cases <- + stan_data$shifted_cases / mean(obs$scale) + stan_data$prior_infections <- log( + exp(stan_data$prior_infections) / mean(obs$scale) + ) return(stan_data) } @@ -642,34 +649,15 @@ create_initial_conditions <- function(data) { out$rescaled_rho < data$ls_min, data$ls_min + 0.001, default = out$rescaled_rho )) - - out$alpha <- array( - truncnorm::rtruncnorm( - 1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd - ) - ) } else { out$eta <- array(numeric(0)) out$rescaled_rho <- array(numeric(0)) - out$alpha <- array(numeric(0)) - } - if (data$model_type == 1) { - out$rep_phi <- array( - truncnorm::rtruncnorm( - 1, - a = 0, mean = data$phi_mean, sd = data$phi_sd - ) - ) } if (data$estimate_r == 1) { out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2)) if (data$seeding_time > 1) { out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02)) } - out$log_R <- array(rnorm( - n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd), - sd = convert_to_logsd(data$r_mean, data$r_sd) - )) } if (data$bp_n > 0) { @@ -679,20 +667,17 @@ create_initial_conditions <- function(data) { out$bp_sd <- array(numeric(0)) out$bp_effects <- array(numeric(0)) } - if (data$obs_scale_sd > 0) { - out$frac_obs <- array(truncnorm::rtruncnorm(1, - a = 0, b = 1, - mean = data$obs_scale_mean, - sd 
= data$obs_scale_sd - )) - } else { - out$frac_obs <- array(numeric(0)) - } if (data$week_effect > 0) { out$day_of_week_simplex <- array( rep(1 / data$week_effect, data$week_effect) ) } + out$params <- array(truncnorm::rtruncnorm( + data$n_params_variable, + a = data$params_lower, + b = data$params_upper, + mean = 0, sd = 1 + )) return(out) } return(init_fun) @@ -872,3 +857,94 @@ create_stan_delays <- function(..., time_points = 1L) { return(ret) } + +##' Create parameters for stan +##' +##' @param ... Named parameter distributions. The names are assigned to IDs +##' @param lower_bounds Named vector of lower bounds for any parameter(s). The names +##' have to correspond to the names given to the parameters passed. +##' If `NULL` (default) no parameters are given a lower bound. +##' @return A list of variables as expected by the stan model +##' @importFrom data.table fcase +##' @keywords internal +create_stan_params <- function(..., lower_bounds = NULL) { + params <- list(...) + + ## set IDs of any parameter that is NULL to 0 and remove + null_params <- vapply(params, is.null, logical(1)) + null_ids <- rep(0, sum(null_params)) + if (length(null_ids) > 0) { + names(null_ids) <- paste(names(null_params)[null_params], "id", sep = "_") + params <- params[!null_params] + } + + ## initialise variables + params_fixed_lookup <- rep(0L, length(params)) + params_variable_lookup <- rep(0L, length(params)) + + ## identify fixed/variable parameters + fixed <- vapply(params, get_distribution, character(1)) == "fixed" + params_fixed_lookup[fixed] <- seq_along(which(fixed)) + params_variable_lookup[!fixed] <- seq_along(which(!fixed)) + + ## lower bounds + params_lower <- rep(-Inf, length(params[!fixed])) + names(params_lower) <- names(params[!fixed]) + lower_bounds <- lower_bounds[names(params_lower)] + params_lower[names(lower_bounds)] <- lower_bounds + + ## upper bounds + params_upper <- vapply(params[!fixed], max, numeric(1)) + + ## prior distributions + prior_dist_name <-
vapply(params[!fixed], get_distribution, character(1)) + prior_dist <- fcase( + prior_dist_name == "lognormal", 0L, + prior_dist_name == "gamma", 1L, + prior_dist_name == "normal", 2L + ) + ## parameters + prior_dist_params <- lapply(params[!fixed], get_parameters) + prior_dist_params_lengths <- lengths(prior_dist_params) + + ## check none of the parameters are uncertain + prior_uncertain <- vapply(prior_dist_params, function(x) { + !all(vapply(x, is.numeric, logical(1))) + }, logical(1)) + if (any(prior_uncertain)) { + uncertain_priors <- names(params[!fixed])[prior_uncertain] # nolint: object_usage_linter + cli_abort( + c( + "!" = "Parameter prior distribution{?s} for {.var {uncertain_priors}} + cannot have uncertain parameters." + ) + ) + } + + prior_dist_params <- unlist(prior_dist_params) + if (is.null(prior_dist_params)) { + prior_dist_params <- numeric(0) + } + + ## extract distributions and parameters + ret <- list( + n_params_variable = length(params) - sum(fixed), + n_params_fixed = sum(fixed), + params_lower = array(params_lower), + params_upper = array(params_upper), + params_fixed_lookup = array(params_fixed_lookup), + params_variable_lookup = array(params_variable_lookup), + params_value = array(vapply( + params[fixed], \(x) get_parameters(x)$value, numeric(1) + )), + prior_dist = array(prior_dist), + prior_dist_params_length = sum(prior_dist_params_lengths), + prior_dist_params = array(prior_dist_params) + ) + ids <- seq_along(params) + if (length(ids) > 0) { + names(ids) <- paste(names(params), "id", sep = "_") + } + ret <- c(ret, as.list(ids), as.list(null_ids)) + return(ret) +} diff --git a/R/dist_spec.R b/R/dist_spec.R index 1186a6438..37d368362 100644 --- a/R/dist_spec.R +++ b/R/dist_spec.R @@ -125,6 +125,57 @@ discrete_pmf <- function(distribution = c(e1, e2) } +##' Compares two delay distributions +##' +##' @param e1 The first delay distribution (of type ) to +##' combine. 
+##' +##' @param e2 The second delay distribution (of type ) to +##' combine. +##' @method == dist_spec +##' @return TRUE or FALSE +##' @export +##' @examples +##' Fixed(1) == Normal(1, 0.5) +## nolint start: cyclocomp_linter +`==.dist_spec` <- function(e1, e2) { + ## both must have same number of distributions + if (ndist(e1) != ndist(e2)) return(FALSE) + ## loop over constituent distributions + for (i in seq_len(ndist(e1))) { + ## distributions need to be the same + if (get_distribution(e1, i) != get_distribution(e2, i)) return(FALSE) + if (get_distribution(e1, i) == "nonparametric") { + ## if nonparametric then PMFs need to be the same + if (!identical(get_pmf(e1, i), get_pmf(e2, i))) return(FALSE) + } else { + ## if parametric then all parameters need to be the same + params1 <- get_parameters(e1, i) + params2 <- get_parameters(e2, i) + for (param in names(params1)) { + ## all parameters must be the same type + if ((is(params1[[param]], "dist_spec") && + is(params2[[param]], "dist_spec")) || + (is.numeric(params1[[param]]) && is.numeric(params2[[param]]))) { + ## if parameters are the same type they need to be same value + if (!(params1[[param]] == params2[[param]])) return(FALSE) + } else { + return(FALSE) + } + } + } + } + return(TRUE) +} +## nolint end: cyclocomp_linter + +##' @rdname equals-.dist_spec +##' @method != dist_spec +##' @export +`!=.dist_spec` <- function(e1, e2) { + !(e1 == e2) +} + #' Combines multiple delay distributions for further processing #' #' @description `r lifecycle::badge("experimental")` diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 1af53fbb5..a316c0fb2 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -239,6 +239,15 @@ estimate_secondary <- function(data, # observation model data stan_data <- c(stan_data, create_obs_model(obs, dates = reports$date)) + stan_data <- c(stan_data, create_stan_params( + frac_obs = obs$scale, + rep_phi = obs$phi, + lower_bounds = c( + frac_obs = 0, + rep_phi = 0 + 
) + )) + # update data to use specified priors rather than defaults stan_data <- update_secondary_args(stan_data, priors = priors, verbose = verbose @@ -663,7 +672,7 @@ forecast_secondary <- function(estimate, # allocate empty parameters data <- allocate_empty( - data, c("frac_obs", "delay_params", "rep_phi"), + data, c("params", "delay_params"), n = data$n ) data$all_dates <- as.integer(all_dates) diff --git a/R/extract.R b/R/extract.R index 3c6d04489..cf8f74a9c 100644 --- a/R/extract.R +++ b/R/extract.R @@ -46,10 +46,12 @@ extract_parameter <- function(param, samples, dates) { #' value #' @keywords internal extract_static_parameter <- function(param, samples) { + id <- samples[[paste(param, "id", sep = "_")]] + lookup <- samples[["params_variable_lookup"]][id] data.table::data.table( parameter = param, - sample = seq_along(samples[[param]]), - value = samples[[param]] + sample = seq_along(samples[["params"]][, lookup]), + value = samples[["params"]][, lookup] ) } @@ -239,16 +241,9 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, } if (data$model_type == 1) { out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples) - out$reporting_overdispersion <- out$reporting_overdispersion[, - value := value.V1][, - value.V1 := NULL - ] } if ("obs_scale_sd" %in% names(data) && data$obs_scale_sd > 0) { out$fraction_observed <- extract_static_parameter("frac_obs", samples) - out$fraction_observed <- out$fraction_observed[, value := value.V1][, - value.V1 := NULL - ] } return(out) } diff --git a/R/opts.R b/R/opts.R index 2627b14aa..e8533e00d 100644 --- a/R/opts.R +++ b/R/opts.R @@ -297,9 +297,10 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001, #' reproduction number. Custom settings can be supplied which override the #' defaults. #' -#' @param prior List containing named numeric elements "mean" and "sd". The -#' mean and standard deviation of the log normal Rt prior. 
Defaults to mean of -#' 1 and standard deviation of 1. +#' @param prior A `` giving the prior of the initial reproduction +#' number. Ignored if `use_rt` is `FALSE`. Defaults to a LogNormal distribution +#' with mean of 1 and standard deviation of 1: `LogNormal(mean = 1, sd = 1)`. +#' A lower limit of 0 will be enforced automatically. #' #' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate #' infections and hence reported cases. @@ -343,7 +344,7 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001, #' #' # add a weekly random walk #' rt_opts(rw = 7) -rt_opts <- function(prior = list(mean = 1, sd = 1), +rt_opts <- function(prior = LogNormal(mean = 1, sd = 1), use_rt = TRUE, rw = 0, use_breakpoints = TRUE, @@ -351,7 +352,6 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), gp_on = c("R_t-1", "R0"), pop = 0) { rt <- list( - prior = prior, use_rt = use_rt, rw = rw, use_breakpoints = use_breakpoints, @@ -365,15 +365,37 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), rt$use_breakpoints <- TRUE } - if (!("mean" %in% names(rt$prior) && "sd" %in% names(rt$prior))) { - cli_abort( + if (is.list(prior) && !is(prior, "dist_spec")) { + cli_warn( c( - "!" = "{.var prior} must have both {.var mean} and {.var sd} - specified.", - "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" + "!" = "Specifying {.var prior} as a list is deprecated.", + "i" = "Use a {.cls dist_spec} instead." ) ) + if (!("mean" %in% names(prior) && "sd" %in% names(prior))) { + cli_abort( + c( + "!" = "{.var prior} must have both {.var mean} and {.var sd} + specified.", + "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" + ) + ) + } + prior <- LogNormal(mean = prior$mean, sd = prior$sd) } + + if (rt$use_rt) { + rt$prior <- prior + } else { + if (!missing(prior)) { + cli_warn( + c( + "!" = "Rt {.var prior} is ignored if {.var use_rt} is FALSE."
+ ) + ) + } + } + attr(rt, "class") <- c("rt_opts", class(rt)) return(rt) } @@ -453,14 +475,17 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"), #' scale. Updated in [create_gp_data()] to be the length of the input data if #' this is smaller. #' -#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter -#' of the Gaussian process kernel. Should be approximately the expected standard -#' deviation of the Gaussian process (logged Rt in case of the renewal model, -#' logged infections in case of the nonmechanistic model). +#' @param alpha A `` giving the prior distribution of the magnitude +#' parameter of the Gaussian process kernel. Should be approximately the +#' expected standard deviation of the Gaussian process (logged Rt in case of +#' the renewal model, logged infections in case of the nonmechanistic model). +#' Defaults to a half-normal distribution with mean 0 and sd 0.01: +#' `Normal(mean = 0, sd = 0.01)` (a lower limit of 0 will be enforced +#' automatically to ensure positivity) #' -#' @param alpha_sd Numeric, defaults to 0.01. The standard deviation of the -#' magnitude parameter of the Gaussian process kernel. Can be tuned to adjust -#' how far alpha is allowed to deviate form its prior mean (`alpha_mean`). +#' @param alpha_mean Deprecated; use `alpha` instead. +#' +#' @param alpha_sd Deprecated; use `alpha` instead. #' #' @param kernel Character string, the type of kernel required. 
Currently #' supporting the Matern kernel ("matern"), squared exponential kernel ("se"), @@ -508,18 +533,28 @@ gp_opts <- function(basis_prop = 0.2, ls_sd = 7, ls_min = 0, ls_max = 60, - alpha_mean = 0, - alpha_sd = 0.01, + alpha = Normal(mean = 0, sd = 0.01), kernel = c("matern", "se", "ou", "periodic"), matern_order = 3 / 2, matern_type, - w0 = 1.0) { + w0 = 1.0, + alpha_mean, alpha_sd) { if (!missing(matern_type)) { lifecycle::deprecate_warn( "1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)" ) } + if (!missing(alpha_mean)) { + lifecycle::deprecate_warn( + "1.7.0", "gp_opts(alpha_mean)", "gp_opts(alpha)" + ) + } + if (!missing(alpha_sd)) { + lifecycle::deprecate_warn( + "1.7.0", "gp_opts(alpha_sd)", "gp_opts(alpha)" + ) + } if (!missing(matern_type)) { if (!missing(matern_order) && matern_type != matern_order) { @@ -557,8 +592,7 @@ gp_opts <- function(basis_prop = 0.2, ls_sd = ls_sd, ls_min = ls_min, ls_max = ls_max, - alpha_mean = alpha_mean, - alpha_sd = alpha_sd, + alpha = alpha, kernel = kernel, matern_order = matern_order, w0 = w0 @@ -575,13 +609,12 @@ gp_opts <- function(basis_prop = 0.2, #' model. Custom settings can be supplied which override the defaults. #' @param family Character string defining the observation model. Options are #' Negative binomial ("negbin"), the default, and Poisson. -#' @param phi Overdispersion parameter of the reporting process, used only if -#' `familiy` is "negbin". Can be supplied either as a single numeric value -#' (fixed overdispersion) or a list with numeric elements mean (`mean`) and -#' standard deviation (`sd`) defining a normally distributed prior. -#' Internally parameterised such that the overdispersion is one over the -#' square of this prior overdispersion. Defaults to a list with elements -#' `mean = 0` and `sd = 0.25`. +#' @param phi A `` specifying a prior on the overdispersion parameter +#' of the reporting process, used only if `family` is "negbin".
Internally +#' parameterised such that the overdispersion is one over the square of this +#' prior overdispersion phi. Defaults to a half-normal distribution with mean +#' of 0 and standard deviation of 0.25: `Normal(mean = 0, sd = 0.25)`. A lower +#' limit of zero will be enforced automatically. #' @param weight Numeric, defaults to 1. Weight to give the observed data in the #' log density. #' @param week_effect Logical defaulting to `TRUE`. Should a day of the week @@ -589,11 +622,12 @@ gp_opts <- function(basis_prop = 0.2, #' @param week_length Numeric assumed length of the week in days, defaulting to #' 7 days. This can be modified if data aggregated over a period other than a #' week or if data has a non-weekly periodicity. -#' @param scale Scaling factor to be applied to map latent infections (convolved -#' to date of report). Can be supplied either as a single numeric value (fixed -#' scale) or a list with numeric elements mean (`mean`) and standard deviation -#' (`sd`) defining a normally distributed scaling factor. Defaults to 1, i.e. -#' no scaling. +#' @param scale A `` specifying a prior on the scaling factor to be +#' applied to map latent infections (convolved to date of report). Defaults +#' to a fixed value of 1, i.e. no scaling: `Fixed(1)`. A lower limit of zero +#' will be enforced automatically. If setting to a prior distribution and no +#' overreporting is expected, it might be sensible to set a maximum of 1 via +#' the `max` option when declaring the distribution. #' @param na Character. Options are "missing" (the default) and "accumulate". #' This determines how NA values in the data are interpreted. 
If set to #' "missing", any NA values in the observation data set will be interpreted as @@ -621,11 +655,11 @@ gp_opts <- function(basis_prop = 0.2, #' # Scale reported data #' obs_opts(scale = list(mean = 0.2, sd = 0.02)) obs_opts <- function(family = c("negbin", "poisson"), - phi = list(mean = 0, sd = 0.25), + phi = Normal(mean = 0, sd = 0.25), weight = 1, week_effect = TRUE, week_length = 7, - scale = 1, + scale = Fixed(1), na = c("missing", "accumulate"), likelihood = TRUE, return_likelihood = FALSE) { @@ -672,16 +706,32 @@ obs_opts <- function(family = c("negbin", "poisson"), for (param in c("phi", "scale")) { if (is.numeric(obs[[param]])) { - obs[[param]] <- list(mean = obs[[param]], sd = 0) - } - if (!(all(c("mean", "sd") %in% names(obs[[param]])))) { - cli_abort( + cli_warn( c( - "!" = "Both a {.var mean} and {.var sd} are needed if specifying - {.strong {param}} as list.", - "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" + "!" = "Specifying {.var {param}} as a numeric value is deprecated.", + "i" = "Use a {.cls dist_spec} instead using {.fn Fixed()}." + ) + ) + obs[[param]] <- Fixed(obs[[param]]) + } else if (is.list(obs[[param]]) && !is(obs[[param]], "dist_spec")) { + cli_warn( + c( + "!" = "Specifying {.var {param}} as a list is deprecated.", + "i" = "Use a {.cls dist_spec} instead." ) ) + if (!(all(c("mean", "sd") %in% names(obs[[param]])))) { + cli_abort( + c( + "!" = "Both a {.var mean} and {.var sd} are needed if specifying + {.var {param}} as list.", + "i" = "Did you forget to specify {.var mean} and/or {.var sd}?" 
+ ) + ) + } + obs[[param]] <- Normal(mean = obs[[param]]$mean, sd = obs[[param]]$sd) + } else { + assert_class(obs[[param]], "dist_spec") } } diff --git a/R/simulate_infections.R b/R/simulate_infections.R index 3170220e7..1d9175cc5 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -152,7 +152,7 @@ simulate_infections <- function(estimates, R, initial_infections, obs, dates = R$date )) - if (data$obs_scale_sd > 0) { + if (get_distribution(obs$scale) != "fixed") { cli_abort( c( "!" = "Cannot simulate from uncertain observation scaling.", @@ -160,16 +160,9 @@ simulate_infections <- function(estimates, R, initial_infections, ) ) } - if (data$obs_scale) { - data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1)) - } else { - data$frac_obs <- array(dim = c(1, 0)) - } - data$obs_scale_mean <- NULL - data$obs_scale_sd <- NULL if (obs$family == "negbin") { - if (data$phi_sd > 0) { + if (get_distribution(obs$phi) != "fixed") { cli_abort( c( "!" = "Cannot simulate from uncertain overdispersion.", @@ -177,12 +170,18 @@ simulate_infections <- function(estimates, R, initial_infections, ) ) } - data$rep_phi <- array(data$phi_mean, dim = c(1, 1)) } else { - data$rep_phi <- array(dim = c(1, 0)) + obs$phi <- NULL } - data$phi_mean <- NULL - data$phi_sd <- NULL + + data <- c(data, create_stan_params( + alpha = NULL, + R0 = NULL, + frac_obs = obs$scale, + rep_phi = obs$phi + )) + ## set empty params matrix - variable parameters not supported here + data$params <- array(dim = c(1, 0)) ## day of week effect if (is.null(day_of_week_effect)) { @@ -436,7 +435,7 @@ forecast_infections <- function(estimates, ## allocate empty parameters data <- allocate_empty( - data, c("frac_obs", "delay_params", "rep_phi"), + data, c("delay_params", "params"), n = data$n ) diff --git a/R/simulate_secondary.R b/R/simulate_secondary.R index 0bcb82314..df112fbcd 100644 --- a/R/simulate_secondary.R +++ b/R/simulate_secondary.R @@ -94,7 +94,7 @@ simulate_secondary <- function(primary, 
obs, dates = primary$date )) - if (data$obs_scale_sd > 0) { + if (get_distribution(obs$scale) != "fixed") { cli_abort( c( "!" = "Cannot simulate from uncertain observation scaling.", @@ -102,16 +102,9 @@ simulate_secondary <- function(primary, ) ) } - if (data$obs_scale) { - data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1)) - } else { - data$frac_obs <- array(dim = c(1, 0)) - } - data$obs_scale_mean <- NULL - data$obs_scale_sd <- NULL if (obs$family == "negbin") { - if (data$phi_sd > 0) { + if (get_distribution(obs$phi) != "fixed") { cli_abort( c( "!" = "Cannot simulate from uncertain overdispersion.", @@ -119,12 +112,17 @@ simulate_secondary <- function(primary, ) ) } - data$rep_phi <- array(data$phi_mean, dim = c(1, 1)) } else { - data$rep_phi <- array(dim = c(1, 0)) + obs$phi <- NULL } - data$phi_mean <- NULL - data$phi_sd <- NULL + + data <- c(data, create_stan_params( + frac_obs = obs$scale, + rep_phi = obs$phi + )) + + ## set empty params matrix - variable parameters not supported here + data$params <- array(dim = c(1, 0)) ## day of week effect if (is.null(day_of_week_effect)) { From 7f491477bb21b4e0890992cfd625cc061daf9428 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 09:40:37 +0000 Subject: [PATCH 3/9] render docs --- NAMESPACE | 2 ++ man/create_stan_params.Rd | 22 ++++++++++++++++++++++ man/equals-.dist_spec.Rd | 27 +++++++++++++++++++++++++++ man/gp_opts.Rd | 26 +++++++++++++++----------- man/obs_opts.Rd | 28 ++++++++++++++-------------- man/rt_opts.Rd | 9 +++++---- 6 files changed, 85 insertions(+), 29 deletions(-) create mode 100644 man/create_stan_params.Rd create mode 100644 man/equals-.dist_spec.Rd diff --git a/NAMESPACE b/NAMESPACE index 8b17878e3..1a3a8294a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,8 @@ # Generated by roxygen2: do not edit by hand +S3method("!=",dist_spec) S3method("+",dist_spec) +S3method("==",dist_spec) S3method(c,dist_spec) S3method(collapse,dist_spec) S3method(collapse,multi_dist_spec) 
diff --git a/man/create_stan_params.Rd b/man/create_stan_params.Rd new file mode 100644 index 000000000..6a2e11bdc --- /dev/null +++ b/man/create_stan_params.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/create.R +\name{create_stan_params} +\alias{create_stan_params} +\title{Create parameters for stan} +\usage{ +create_stan_params(..., lower_bounds = NULL) +} +\arguments{ +\item{...}{Named delay distributions. The names are assigned to IDs} + +\item{lower_bounds}{Named vector of lower bounds for any delay(s). The names +have to correspond to the names given to the delay distributions passed. +If \code{NULL} (default) no parameters are given a lower bound.} +} +\value{ +A list of variables as expected by the stan model +} +\description{ +Create parameters for stan +} +\keyword{internal} diff --git a/man/equals-.dist_spec.Rd b/man/equals-.dist_spec.Rd new file mode 100644 index 000000000..879c0331d --- /dev/null +++ b/man/equals-.dist_spec.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dist_spec.R +\name{==.dist_spec} +\alias{==.dist_spec} +\alias{!=.dist_spec} +\title{Compares two delay distributions} +\usage{ +\method{==}{dist_spec}(e1, e2) + +\method{!=}{dist_spec}(e1, e2) +} +\arguments{ +\item{e1}{The first delay distribution (of type ) to +combine.} + +\item{e2}{The second delay distribution (of type ) to +combine.} +} +\value{ +TRUE or FALSE +} +\description{ +Compares two delay distributions +} +\examples{ +Fixed(1) == Normal(1, 0.5) +} diff --git a/man/gp_opts.Rd b/man/gp_opts.Rd index 4b21c2494..3bbe91930 100644 --- a/man/gp_opts.Rd +++ b/man/gp_opts.Rd @@ -11,12 +11,13 @@ gp_opts( ls_sd = 7, ls_min = 0, ls_max = 60, - alpha_mean = 0, - alpha_sd = 0.01, + alpha = Normal(mean = 0, sd = 0.01), kernel = c("matern", "se", "ou", "periodic"), matern_order = 3/2, matern_type, - w0 = 1 + w0 = 1, + alpha_mean, + alpha_sd ) } \arguments{ @@ -45,14 +46,13 @@ 
process length scale will be used with recommended parameters scale. Updated in \code{\link[=create_gp_data]{create_gp_data()}} to be the length of the input data if this is smaller.} -\item{alpha_mean}{Numeric, defaults to 0. The mean of the magnitude parameter -of the Gaussian process kernel. Should be approximately the expected standard -deviation of the Gaussian process (logged Rt in case of the renewal model, -logged infections in case of the nonmechanistic model).} - -\item{alpha_sd}{Numeric, defaults to 0.01. The standard deviation of the -magnitude parameter of the Gaussian process kernel. Can be tuned to adjust -how far alpha is allowed to deviate form its prior mean (\code{alpha_mean}).} +\item{alpha}{A \verb{} giving the prior distribution of the magnitude +parameter of the Gaussian process kernel. Should be approximately the +expected standard deviation of the Gaussian process (logged Rt in case of +the renewal model, logged infections in case of the nonmechanistic model). +Defaults to a half-normal distribution with mean 0 and sd 0.01: +\code{Normal(mean = 0, sd = 0.01)} (a lower limit of 0 will be enforced +automatically to ensure positivity)} \item{kernel}{Character string, the type of kernel required. Currently supporting the Matern kernel ("matern"), squared exponential kernel ("se"), @@ -69,6 +69,10 @@ Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported.} \item{w0}{Numeric, defaults to 1.0. Fundamental frequency for periodic kernel. 
They are only used if \code{kernel} is set to "periodic".} + +\item{alpha_mean}{Deprecated; use \code{alpha} instead.} + +\item{alpha_sd}{Deprecated; use \code{alpha} instead.} } \value{ A \verb{} object of settings defining the Gaussian process diff --git a/man/obs_opts.Rd b/man/obs_opts.Rd index 36f1d9ed2..0bf4579a0 100644 --- a/man/obs_opts.Rd +++ b/man/obs_opts.Rd @@ -6,11 +6,11 @@ \usage{ obs_opts( family = c("negbin", "poisson"), - phi = list(mean = 0, sd = 0.25), + phi = Normal(mean = 0, sd = 0.25), weight = 1, week_effect = TRUE, week_length = 7, - scale = 1, + scale = Fixed(1), na = c("missing", "accumulate"), likelihood = TRUE, return_likelihood = FALSE @@ -20,13 +20,12 @@ obs_opts( \item{family}{Character string defining the observation model. Options are Negative binomial ("negbin"), the default, and Poisson.} -\item{phi}{Overdispersion parameter of the reporting process, used only if -\code{familiy} is "negbin". Can be supplied either as a single numeric value -(fixed overdispersion) or a list with numeric elements mean (\code{mean}) and -standard deviation (\code{sd}) defining a normally distributed prior. -Internally parameterised such that the overdispersion is one over the -square of this prior overdispersion. Defaults to a list with elements -\code{mean = 0} and \code{sd = 0.25}.} +\item{phi}{A \verb{} specifying a prior on the overdispersion parameter +of the reporting process, used only if \code{family} is "negbin". Internally +parameterised such that the overdispersion is one over the square of this +prior overdispersion phi. Defaults to a half-normal distribution with mean +of 0 and standard deviation of 0.25: \code{Normal(mean = 0, sd = 0.25)}. A lower +limit of zero will be enforced automatically.} \item{weight}{Numeric, defaults to 1. Weight to give the observed data in the log density.} @@ -38,11 +37,12 @@ effect be used in the observation model.} 7 days.
This can be modified if data aggregated over a period other than a week or if data has a non-weekly periodicity.} -\item{scale}{Scaling factor to be applied to map latent infections (convolved -to date of report). Can be supplied either as a single numeric value (fixed -scale) or a list with numeric elements mean (\code{mean}) and standard deviation -(\code{sd}) defining a normally distributed scaling factor. Defaults to 1, i.e. -no scaling.} +\item{scale}{A \verb{} specifying a prior on the scaling factor to be +applied to map latent infections (convolved to date of report). Defaults +to a fixed value of 1, i.e. no scaling: \code{Fixed(1)}. A lower limit of zero +will be enforced automatically. If setting to a prior distribution and no +overreporting is expected, it might be sensible to set a maximum of 1 via +the \code{max} option when declaring the distribution.} \item{na}{Character. Options are "missing" (the default) and "accumulate". This determines how NA values in the data are interpreted. If set to diff --git a/man/rt_opts.Rd b/man/rt_opts.Rd index 24774d891..8d796d583 100644 --- a/man/rt_opts.Rd +++ b/man/rt_opts.Rd @@ -5,7 +5,7 @@ \title{Time-Varying Reproduction Number Options} \usage{ rt_opts( - prior = list(mean = 1, sd = 1), + prior = LogNormal(mean = 1, sd = 1), use_rt = TRUE, rw = 0, use_breakpoints = TRUE, @@ -15,9 +15,10 @@ rt_opts( ) } \arguments{ -\item{prior}{List containing named numeric elements "mean" and "sd". The -mean and standard deviation of the log normal Rt prior. Defaults to mean of -1 and standard deviation of 1.} +\item{prior}{A \verb{} giving the prior of the initial reproduction +number. Ignored if \code{use_rt} is \code{FALSE}. Defaults to a LogNormal distribution +with mean of 1 and standard deviation of 1: \code{LogNormal(mean = 1, sd = 1)}. +A lower limit of 0 will be enforced automatically.} \item{use_rt}{Logical, defaults to \code{TRUE}.
Should Rt be used to generate infections and hence reported cases.} From a24dcaad4f06fa5b320d0db7a4ed737b252f8b6c Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 09:40:46 +0000 Subject: [PATCH 4/9] update tests --- tests/testthat/test-create_gp_data.R | 2 +- tests/testthat/test-create_obs_model.R | 40 +---------------- tests/testthat/test-create_rt_date.R | 5 --- tests/testthat/test-create_stan_params.R | 53 +++++++++++++++++++++++ tests/testthat/test-estimate_secondary.R | 25 ++++++----- tests/testthat/test-gp_opts.R | 2 +- tests/testthat/test-obs_opts.R | 4 +- tests/testthat/test-rt_opts.R | 23 ++++++---- tests/testthat/test-simulate-infections.R | 4 +- tests/testthat/test-simulate-secondary.R | 4 +- tests/testthat/test-stan-rt.R | 22 +++++----- 11 files changed, 102 insertions(+), 82 deletions(-) create mode 100644 tests/testthat/test-create_stan_params.R diff --git a/tests/testthat/test-create_gp_data.R b/tests/testthat/test-create_gp_data.R index b1f7cc765..d33dccd76 100644 --- a/tests/testthat/test-create_gp_data.R +++ b/tests/testthat/test-create_gp_data.R @@ -11,7 +11,7 @@ test_that("create_gp_data returns correct default values when GP is disabled", { expect_equal(gp_data$ls_sdlog, convert_to_logsd(21, 7)) expect_equal(gp_data$ls_min, 0) expect_equal(gp_data$ls_max, 3.54, tolerance = 0.01) - expect_equal(gp_data$alpha_sd, 0.01) + expect_equal(gp_data$alpha, NULL) expect_equal(gp_data$gp_type, 2) # Default to Matern expect_equal(gp_data$nu, 3 / 2) expect_equal(gp_data$w0, 1.0) diff --git a/tests/testthat/test-create_obs_model.R b/tests/testthat/test-create_obs_model.R index b704ca54d..66db06992 100644 --- a/tests/testthat/test-create_obs_model.R +++ b/tests/testthat/test-create_obs_model.R @@ -3,10 +3,9 @@ dates <- seq(as.Date("2020-03-15"), by = "days", length.out = 15) test_that("create_obs_model works with default settings", { obs <- create_obs_model(dates = dates) - expect_equal(length(obs), 12) + expect_equal(length(obs), 8) 
expect_equal(names(obs), c( - "model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight", - "obs_scale", "obs_scale_mean", "obs_scale_sd", "accumulate", + "model_type", "week_effect", "obs_weight", "obs_scale", "accumulate", "likelihood", "return_likelihood", "day_of_week" )) expect_equal(obs$model_type, 1) @@ -15,8 +14,6 @@ test_that("create_obs_model works with default settings", { expect_equal(obs$likelihood, 1) expect_equal(obs$return_likelihood, 0) expect_equal(obs$day_of_week, c(7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7)) - expect_equal(obs$obs_scale_mean, 1) - expect_equal(obs$obs_scale_sd, 0) }) test_that("create_obs_model can be used with a Poisson model", { @@ -24,24 +21,6 @@ test_that("create_obs_model can be used with a Poisson model", { expect_equal(obs$model_type, 0) }) -test_that("create_obs_model can be used with a scaling", { - obs <- create_obs_model( - dates = dates, - obs = obs_opts(scale = list(mean = 0.4, sd = 0.01)) - ) - expect_equal(obs$obs_scale_mean, 0.4) - expect_equal(obs$obs_scale_sd, 0.01) -}) - -test_that("create_obs_model can be used with fixed scaling", { - obs <- create_obs_model( - dates = dates, - obs = obs_opts(scale = 0.4) - ) - expect_equal(obs$obs_scale_mean, 0.4) - expect_equal(obs$obs_scale_sd, 0) -}) - test_that("create_obs_model can be used with no week effect", { obs <- create_obs_model(dates = dates, obs = obs_opts(week_effect = FALSE)) expect_equal(obs$week_effect, 1) @@ -52,18 +31,3 @@ test_that("create_obs_model can be used with a custom week length", { obs <- create_obs_model(dates = dates, obs = obs_opts(week_length = 3)) expect_equal(obs$day_of_week, c(3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2)) }) - -test_that("create_obs_model can be used with a user set phi", { - obs <- create_obs_model( - dates = dates, obs = obs_opts(phi = list(mean = 10, sd = 0.1)) - ) - expect_equal(obs$phi_mean, 10) - expect_equal(obs$phi_sd, 0.1) - obs <- create_obs_model( - dates = dates, - obs = obs_opts(phi = 0.5) - ) - 
expect_equal(obs$phi_mean, 0.5) - expect_equal(obs$phi_sd, 0) - expect_error(obs_opts(phi = c("Hi", "World"))) -}) diff --git a/tests/testthat/test-create_rt_date.R b/tests/testthat/test-create_rt_date.R index 748ae80d4..3fc47b1cc 100644 --- a/tests/testthat/test-create_rt_date.R +++ b/tests/testthat/test-create_rt_date.R @@ -2,8 +2,6 @@ test_that("create_rt_data returns expected default values", { result <- create_rt_data() expect_type(result, "list") - expect_equal(result$r_mean, 1) - expect_equal(result$r_sd, 1) expect_equal(result$estimate_r, 1) expect_equal(result$bp_n, 0) expect_equal(result$breakpoints, numeric(0)) @@ -24,7 +22,6 @@ test_that("create_rt_data handles NULL rt input correctly", { test_that("create_rt_data handles custom rt_opts correctly", { custom_rt <- rt_opts( - prior = list(mean = 2, sd = 0.5), use_rt = FALSE, rw = 0, use_breakpoints = FALSE, @@ -35,8 +32,6 @@ test_that("create_rt_data handles custom rt_opts correctly", { result <- create_rt_data(rt = custom_rt, horizon = 7) - expect_equal(result$r_mean, 2) - expect_equal(result$r_sd, 0.5) expect_equal(result$estimate_r, 0) expect_equal(result$pop, 1000000) expect_equal(result$stationary, 1) diff --git a/tests/testthat/test-create_stan_params.R b/tests/testthat/test-create_stan_params.R new file mode 100644 index 000000000..f0b8a7b4b --- /dev/null +++ b/tests/testthat/test-create_stan_params.R @@ -0,0 +1,53 @@ +test_that("create_stan_params can be used with a scaling", { + obs <- obs_opts(scale = Normal(mean = 0.4, sd = 0.01)) + params <- create_stan_params( + frac_obs = obs$scale, lower_bounds = c(frac_obs = 0) + ) + expect_equal(params$prior_dist, array(2L)) + expect_equal(params$prior_dist_params, array(c(0.4, 0.01))) + expect_equal(params$params_lower, array(0)) + expect_equal(params$frac_obs_id, 1L) +}) + +test_that("create_stan_params can be used with fixed scaling", { + obs <- obs_opts(scale = Fixed(0.4)) + params <- create_stan_params( + frac_obs = obs$scale + ) + 
expect_equal(params$params_value, array(0.4)) + expect_equal(length(params$prior_dist_params), 0L) +}) + +test_that("create_stan_params can be used with a user set phi", { + obs <- obs_opts( + phi = Normal(mean = 10, sd = 0.1) + ) + params <- create_stan_params( + phi = obs$phi + ) + expect_equal(params$prior_dist, array(2L)) + expect_equal(params$prior_dist_params, array(c(10, 0.1))) + expect_equal(params$phi_id, 1L) +}) + +test_that("create_stan_params can be used with fixed phi", { + obs <- obs_opts(phi = Fixed(0.5)) + params <- create_stan_params( + phi = obs$phi + ) + expect_equal(params$params_value, array(0.5)) + expect_equal(length(params$prior_dist_params), 0L) +}) + +test_that("create_stan_params can be used with NULL parameters", { + params <- create_stan_params( + test = NULL + ) + expect_equal(params$test_id, 0) +}) + +test_that("create_stan_params warns about uncertain parameters", { + expect_error(create_stan_params( + test = Normal(mean = 0, sd = Normal(1, 1)) + ), "cannot have uncertain parameters") +}) diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index 9b46e7641..f0445053b 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -23,14 +23,16 @@ inc_cases[ # fit model to example data specifying a weak prior for fraction reported # with a secondary case inc <- estimate_secondary(inc_cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts( + scale = Normal(mean = 0.2, sd = 0.2, max = 1), week_effect = FALSE + ), verbose = FALSE ) # extract posterior variables of interest params <- c( "meanlog" = "delay_params[1]", "sdlog" = "delay_params[2]", - "scaling" = "frac_obs[1]" + "scaling" = "params[1]" ) inc_posterior <- inc$posterior[variable %in% params] @@ -58,7 +60,7 @@ prev <- estimate_secondary(prev_cases[1:100], secondary = secondary_opts(type = "prevalence"), obs = obs_opts( week_effect = FALSE, - scale 
= list(mean = 0.4, sd = 0.1) + scale = Normal(mean = 0.4, sd = 0.1) ), verbose = FALSE ) @@ -90,7 +92,7 @@ test_that("estimate_secondary successfully returns estimates when passed NA valu delays = delay_opts( LogNormal(meanlog = 1.8, sdlog = 0.5, max = 30) ), - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) prev_cases_na <- data.table::copy(prev_cases) @@ -100,7 +102,7 @@ test_that("estimate_secondary successfully returns estimates when passed NA valu delays = delay_opts( LogNormal(mean = 1.8, sd = 0.5, max = 30) ), - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE ) expect_true(is.list(inc_na$data)) @@ -122,7 +124,8 @@ test_that("estimate_secondary successfully returns estimates when accumulating t ) ), obs = obs_opts( - scale = list(mean = 0.4, sd = 0.05), week_effect = FALSE, na = "accumulate" + scale = Normal(mean = 0.4, sd = 0.05), week_effect = FALSE, + na = "accumulate" ), verbose = FALSE ) expect_true(is.list(inc_weekly$data)) @@ -130,7 +133,7 @@ test_that("estimate_secondary successfully returns estimates when accumulating t test_that("estimate_secondary works when only estimating scaling", { inc <- estimate_secondary(inc_cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), delay = delay_opts(), verbose = FALSE ) @@ -159,7 +162,7 @@ test_that("estimate_secondary can recover simulated parameters with the skip_on_os("windows") output <- capture.output(suppressMessages(suppressWarnings( inc_cmdstanr <- estimate_secondary(inc_cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), verbose = FALSE, stan = 
stan_opts(backend = "cmdstanr") ) ))) @@ -212,7 +215,7 @@ test_that("estimate_secondary works with weigh_delay_priors = TRUE", { ) inc_weigh <- estimate_secondary( inc_cases[1:60], delays = delay_opts(delays), - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), weigh_delay_priors = TRUE, verbose = FALSE ) expect_s3_class(inc_weigh, "estimate_secondary") @@ -222,7 +225,7 @@ test_that("estimate_secondary works with filter_leading_zeros set", { modified_data <- inc_cases[1:10, secondary := 0] out <- estimate_secondary( modified_data, - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), filter_leading_zeros = TRUE, verbose = FALSE @@ -236,7 +239,7 @@ test_that("estimate_secondary works with zero_threshold set", { modified_data <- inc_cases[sample(1:30, 10), primary := 0] out <- estimate_secondary( modified_data, - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE), zero_threshold = 10, verbose = FALSE diff --git a/tests/testthat/test-gp_opts.R b/tests/testthat/test-gp_opts.R index cd848b75c..47e4f6186 100644 --- a/tests/testthat/test-gp_opts.R +++ b/tests/testthat/test-gp_opts.R @@ -6,7 +6,7 @@ test_that("gp_opts returns correct default values", { expect_equal(gp$ls_sd, 7) expect_equal(gp$ls_min, 0) expect_equal(gp$ls_max, 60) - expect_equal(gp$alpha_sd, 0.01) + expect_equal(gp$alpha, Normal(0, 0.01)) expect_equal(gp$kernel, "matern") expect_equal(gp$matern_order, 3 / 2) expect_equal(gp$w0, 1.0) diff --git a/tests/testthat/test-obs_opts.R b/tests/testthat/test-obs_opts.R index a9cb17db0..dd9ea7b93 100644 --- a/tests/testthat/test-obs_opts.R +++ b/tests/testthat/test-obs_opts.R @@ -6,7 +6,7 @@ test_that("obs_opts returns expected default values", { expect_equal(result$weight, 1) expect_true(result$week_effect) 
expect_equal(result$week_length, 7L) - expect_equal(result$scale, list(mean = 1, sd = 0)) + expect_equal(result$scale, Normal(mean = 1, sd = 0)) expect_equal(result$accumulate, 0) expect_true(result$likelihood) expect_false(result$return_likelihood) @@ -28,4 +28,4 @@ test_that("obs_opts returns expected messages", { test_that("obs_opts behaves as expected for user specified na treatment", { # If user explicitly specifies NA as missing, then don't throw message expect_false(obs_opts(na = "missing")$na_as_missing_default_used) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-rt_opts.R b/tests/testthat/test-rt_opts.R index 0a39be027..726a1e3ac 100644 --- a/tests/testthat/test-rt_opts.R +++ b/tests/testthat/test-rt_opts.R @@ -2,7 +2,7 @@ test_that("rt_opts returns expected default values", { result <- rt_opts() expect_s3_class(result, "rt_opts") - expect_equal(result$prior, list(mean = 1, sd = 1)) + expect_equal(result$prior, LogNormal(mean = 1, sd = 1)) expect_true(result$use_rt) expect_equal(result$rw, 0) expect_true(result$use_breakpoints) @@ -12,17 +12,17 @@ test_that("rt_opts returns expected default values", { }) test_that("rt_opts handles custom inputs correctly", { - result <- rt_opts( - prior = list(mean = 2, sd = 0.5), + result <- suppressWarnings(rt_opts( + prior = Normal(mean = 2, sd = 0.5), use_rt = FALSE, rw = 7, use_breakpoints = FALSE, future = "project", gp_on = "R0", pop = 1000000 - ) + )) - expect_equal(result$prior, list(mean = 2, sd = 0.5)) + expect_null(result$prior) expect_false(result$use_rt) expect_equal(result$rw, 7) expect_true(result$use_breakpoints) # Should be TRUE when rw > 0 @@ -37,10 +37,15 @@ test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", { }) test_that("rt_opts throws error for invalid prior", { - expect_error(rt_opts(prior = list(mean = 1)), - "must have both") - expect_error(rt_opts(prior = list(sd = 1)), - "must have both") + ## deprecated + expect_error( + suppressWarnings(rt_opts(prior = 
list(mean = 1))), + "must have both" + ) + expect_error( + suppressWarnings(rt_opts(prior = list(sd = 1))), + "must have both" + ) }) test_that("rt_opts validates gp_on argument", { diff --git a/tests/testthat/test-simulate-infections.R b/tests/testthat/test-simulate-infections.R index 0314806de..3c977d689 100644 --- a/tests/testthat/test-simulate-infections.R +++ b/tests/testthat/test-simulate-infections.R @@ -30,7 +30,7 @@ test_that("simulate_infections works as expected with additional parameters", { sim <- test_simulate_infections( generation_time = gt_opts(fix_parameters(example_generation_time)), delays = delay_opts(fix_parameters(example_reporting_delay)), - obs = obs_opts(family = "negbin", phi = list(mean = 0.5, sd = 0)), + obs = obs_opts(family = "negbin", phi = Normal(mean = 0.5, sd = 0)), seeding_time = 10 ) expect_equal(nrow(sim), 2 * nrow(R)) @@ -49,7 +49,7 @@ test_that("simulate_infections fails with uncertain parameters", { expect_error( test_simulate_infections( generation_time = gt_opts(Fixed(1)), - obs = obs_opts(scale = list(mean = 1, sd = 1)) + obs = obs_opts(scale = Normal(mean = 1, sd = 1)) ), "uncertain" ) diff --git a/tests/testthat/test-simulate-secondary.R b/tests/testthat/test-simulate-secondary.R index f78c91de7..d30e7bcf3 100644 --- a/tests/testthat/test-simulate-secondary.R +++ b/tests/testthat/test-simulate-secondary.R @@ -22,7 +22,7 @@ test_that("simulate_secondary works as expected with additional parameters", { set.seed(123) sim <- test_simulate_secondary( delays = delay_opts(fix_parameters(example_reporting_delay)), - obs = obs_opts(family = "negbin", phi = list(mean = 0.5, sd = 0)) + obs = obs_opts(family = "negbin", phi = Fixed(0.5)) ) expect_equal(nrow(sim), nrow(cases)) expect_snapshot_output(sim) @@ -36,7 +36,7 @@ test_that("simulate_secondary fails with uncertain parameters", { ) expect_error( test_simulate_secondary( - obs = obs_opts(scale = list(mean = 1, sd = 1)) + obs = obs_opts(scale = Normal(mean = 1, sd = 1)) ), 
"uncertain" ) diff --git a/tests/testthat/test-stan-rt.R b/tests/testthat/test-stan-rt.R index 1b4c40153..7d5fcb17b 100644 --- a/tests/testthat/test-stan-rt.R +++ b/tests/testthat/test-stan-rt.R @@ -4,57 +4,57 @@ skip_on_os("windows") # Test update_Rt test_that("update_Rt works to produce multiple Rt estimates with a static gaussian process", { expect_equal( - update_Rt(10, log(1.2), rep(0, 9), rep(10, 0), numeric(0), 0), + update_Rt(10, 1.2, rep(0, 9), rep(10, 0), numeric(0), 0), rep(1.2, 10) ) }) test_that("update_Rt works to produce multiple Rt estimates with a non-static gaussian process", { expect_equal( - round(update_Rt(10, log(1.2), rep(0.1, 9), rep(10, 0), numeric(0), 0), 2), + round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 2), c(1.20, 1.33, 1.47, 1.62, 1.79, 1.98, 2.19, 2.42, 2.67, 2.95) ) }) test_that("update_Rt works to produce multiple Rt estimates with a non-static stationary gaussian process", { expect_equal( - round(update_Rt(10, log(1.2), rep(0.1, 10), rep(10, 0), numeric(0), 1), 3), + round(update_Rt(10, 1.2, rep(0.1, 10), rep(10, 0), numeric(0), 1), 3), c(1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326) ) }) test_that("update_Rt works when Rt is fixed", { expect_equal( - round(update_Rt(10, log(1.2), numeric(0), rep(10, 0), numeric(0), 0), 2), + round(update_Rt(10, 1.2, numeric(0), rep(10, 0), numeric(0), 0), 2), rep(1.2, 10) ) expect_equal( - round(update_Rt(10, log(1.2), numeric(0), rep(10, 0), numeric(0), 1), 2), + round(update_Rt(10, 1.2, numeric(0), rep(10, 0), numeric(0), 1), 2), rep(1.2, 10) ) }) test_that("update_Rt works when Rt is fixed but a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2), + round(update_Rt(5, 1.2, numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2), + round(update_Rt(5, 1.2, numeric(0), c(1, 1, 2, 2, 2), 0.1, 
1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2), + round(update_Rt(5, 1.2, numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2), c(1.2, 1.33, rep(1.47, 3)) ) }) test_that("update_Rt works when Rt is variable and a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), + round(update_Rt(5, 1.2, rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), + round(update_Rt(5, 1.2, rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), + round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.20, 1.33, 1.62, 1.79, 1.98) ) }) From b148f30c8e07eaf57151bd8db5bcdf686a9ee682 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 13:48:45 +0000 Subject: [PATCH 5/9] update examples and other code snippets --- R/create.R | 2 +- R/epinow.R | 2 +- R/estimate_infections.R | 2 +- R/estimate_secondary.R | 4 ++-- R/opts.R | 4 ++-- R/regional_epinow.R | 2 +- R/simulate_infections.R | 4 ++-- data-raw/estimate-infections.R | 4 ++-- inst/dev/benchmark-functions.R | 2 +- inst/dev/recover-synthetic/rt.R | 12 ++++++------ man/create_obs_model.Rd | 2 +- man/epinow.Rd | 2 +- man/estimate_infections.Rd | 2 +- man/estimate_secondary.Rd | 4 ++-- man/forecast_infections.Rd | 4 ++-- man/obs_opts.Rd | 2 +- man/regional_epinow.Rd | 2 +- man/rt_opts.Rd | 2 +- touchstone/script.R | 10 +++++----- vignettes/EpiNow2.Rmd.orig | 4 ++-- vignettes/epinow.Rmd.orig | 2 +- vignettes/estimate_infections_options.Rmd.orig | 2 +- vignettes/estimate_infections_workflow.Rmd.orig | 4 ++-- 23 files changed, 40 insertions(+), 40 deletions(-) diff --git a/R/create.R b/R/create.R index 42231785f..a3b142dbf 100644 --- a/R/create.R +++ b/R/create.R @@ -464,7 
+464,7 @@ create_gp_data <- function(gp = gp_opts(), data) { #' #' # Applying a observation scaling to the data #' create_obs_model( -#' obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates +#' obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates #' ) #' #' # Apply a custom week week length diff --git a/R/epinow.R b/R/epinow.R index 5baa06093..6afb211f6 100644 --- a/R/epinow.R +++ b/R/epinow.R @@ -65,7 +65,7 @@ #' out <- epinow( #' data = reported_cases, #' generation_time = gt_opts(generation_time), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)), +#' rt = rt_opts(prior = Normal(mean = 2, sd = 0.1)), #' delays = delay_opts(incubation_period + reporting_delay) #' ) #' # summary of the latest estimates diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 7af5b2f1a..b23a1f23c 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -101,7 +101,7 @@ #' def <- estimate_infections(reported_cases, #' generation_time = gt_opts(generation_time), #' delays = delay_opts(incubation_period + reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)) +#' rt = rt_opts(prior = Normal(mean = 2, sd = 0.1)) #' ) #' # real time estimates #' summary(def) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index a316c0fb2..1be744195 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -95,7 +95,7 @@ #' # fit model to example data specifying a weak prior for fraction reported #' # with a secondary case #' inc <- estimate_secondary(cases[1:60], -#' obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE) +#' obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE) #' ) #' plot(inc, primary = TRUE) #' @@ -123,7 +123,7 @@ #' secondary = secondary_opts(type = "prevalence"), #' obs = obs_opts( #' week_effect = FALSE, -#' scale = list(mean = 0.4, sd = 0.1) +#' scale = Normal(mean = 0.4, sd = 0.1) #' ) #' ) #' plot(prev, primary = TRUE) diff --git a/R/opts.R b/R/opts.R index 
e8533e00d..5f5a8db8c 100644 --- a/R/opts.R +++ b/R/opts.R @@ -340,7 +340,7 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001, #' rt_opts() #' #' # add a custom length scale -#' rt_opts(prior = list(mean = 2, sd = 1)) +#' rt_opts(prior = Normal(mean = 2, sd = 1)) #' #' # add a weekly random walk #' rt_opts(rw = 7) @@ -653,7 +653,7 @@ gp_opts <- function(basis_prop = 0.2, #' obs_opts(week_effect = TRUE) #' #' # Scale reported data -#' obs_opts(scale = list(mean = 0.2, sd = 0.02)) +#' obs_opts(scale = Normal(mean = 0.2, sd = 0.02)) obs_opts <- function(family = c("negbin", "poisson"), phi = Normal(mean = 0, sd = 0.25), weight = 1, diff --git a/R/regional_epinow.R b/R/regional_epinow.R index 09d305031..a5f08d062 100644 --- a/R/regional_epinow.R +++ b/R/regional_epinow.R @@ -79,7 +79,7 @@ #' data = cases, #' generation_time = gt_opts(example_generation_time), #' delays = delay_opts(example_incubation_period + example_reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.2)), +#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), #' stan = stan_opts( #' samples = 100, warmup = 200 #' ), diff --git a/R/simulate_infections.R b/R/simulate_infections.R index 1d9175cc5..f3b66cc3f 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -277,8 +277,8 @@ simulate_infections <- function(estimates, R, initial_infections, #' est <- estimate_infections(reported_cases, #' generation_time = generation_time_opts(example_generation_time), #' delays = delay_opts(example_incubation_period + example_reporting_delay), -#' rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7), -#' obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)), +#' rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1), rw = 7), +#' obs = obs_opts(scale = Normal(mean = 0.1, sd = 0.01)), #' gp = NULL, horizon = 0 #' ) #' diff --git a/data-raw/estimate-infections.R b/data-raw/estimate-infections.R index 24f1215c1..165a767d5 100644 --- a/data-raw/estimate-infections.R +++ 
b/data-raw/estimate-infections.R @@ -14,7 +14,7 @@ reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10L) example_estimate_infections <- estimate_infections(reported_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1)), stan = stan_opts(samples = 200, control = list(adapt_delta = 0.95)) ) @@ -28,7 +28,7 @@ example_regional_epinow <- regional_epinow( generation_time = gt_opts(example_generation_time), data = cases, delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), stan = stan_opts(samples = 200, control = list(adapt_delta = 0.95)) ) diff --git a/inst/dev/benchmark-functions.R b/inst/dev/benchmark-functions.R index 3e884e72b..a0409ef94 100644 --- a/inst/dev/benchmark-functions.R +++ b/inst/dev/benchmark-functions.R @@ -17,7 +17,7 @@ create_profiles <- function(dir = file.path("inst", "stan"), data = reported_cases, generation_time = gt_opts(fixed_generation_time), delays = delay_opts(delays), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), stan = stan_opts( samples = 1000, chains = 2, object = compiled_model, cores = 2 diff --git a/inst/dev/recover-synthetic/rt.R b/inst/dev/recover-synthetic/rt.R index a20428701..fd223ae3e 100644 --- a/inst/dev/recover-synthetic/rt.R +++ b/inst/dev/recover-synthetic/rt.R @@ -7,14 +7,14 @@ old_opts <- options() options(mc.cores = 4) #' get example delays -obs <- obs_opts(scale = list(mean = 0.1, sd = 0.025), return_likelihood = TRUE) +obs <- obs_opts(scale = Normal(mean = 0.1, sd = 0.025), return_likelihood = TRUE) # fit model to data to recover realistic parameter estimates and define settings # shared simulation settings init <- estimate_infections(example_confirmed[1:100], 
generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 14), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1), rw = 14), gp = NULL, horizon = 0, obs = obs ) @@ -59,7 +59,7 @@ for (method in c("nuts")) { estimate_infections(sim_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.25)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.25)), stan = stanopts, obs = obs, horizon = 0 @@ -90,7 +90,7 @@ for (method in c("nuts")) { generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), rt = rt_opts( - prior = list(mean = 2, sd = 0.25), + prior = LogNormal(mean = 2, sd = 0.25), rw = 7 ), gp = NULL, @@ -109,7 +109,7 @@ for (method in c("nuts")) { generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), rt = rt_opts( - prior = list(mean = 2, sd = 0.25), rw = 14, gp_on = "R0" + prior = LogNormal(mean = 2, sd = 0.25), rw = 14, gp_on = "R0" ), stan = stanopts, obs = obs, @@ -130,7 +130,7 @@ for (method in c("nuts")) { example_incubation_period + example_reporting_delay ), rt = rt_opts( - prior = list(mean = 2, sd = 0.25), + prior = LogNormal(mean = 2, sd = 0.25), rw = 1 ), gp = NULL, diff --git a/man/create_obs_model.Rd b/man/create_obs_model.Rd index 736385fcf..b2c743fde 100644 --- a/man/create_obs_model.Rd +++ b/man/create_obs_model.Rd @@ -32,7 +32,7 @@ create_obs_model(obs_opts(family = "poisson"), dates = dates) # Applying a observation scaling to the data create_obs_model( - obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates + obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates ) # Apply a custom week week length diff --git a/man/epinow.Rd b/man/epinow.Rd index 
6046dea56..04b4a58c0 100644 --- a/man/epinow.Rd +++ b/man/epinow.Rd @@ -160,7 +160,7 @@ reported_cases <- example_confirmed[1:40] out <- epinow( data = reported_cases, generation_time = gt_opts(generation_time), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.1)), delays = delay_opts(incubation_period + reporting_delay) ) # summary of the latest estimates diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index cab919f85..e3fa91689 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -146,7 +146,7 @@ reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10) def <- estimate_infections(reported_cases, generation_time = gt_opts(generation_time), delays = delay_opts(incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1)) + rt = rt_opts(prior = Normal(mean = 2, sd = 0.1)) ) # real time estimates summary(def) diff --git a/man/estimate_secondary.Rd b/man/estimate_secondary.Rd index 06ff2e848..359c79653 100644 --- a/man/estimate_secondary.Rd +++ b/man/estimate_secondary.Rd @@ -136,7 +136,7 @@ cases <- convolve_and_scale(cases, type = "incidence") # fit model to example data specifying a weak prior for fraction reported # with a secondary case inc <- estimate_secondary(cases[1:60], - obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE) + obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE) ) plot(inc, primary = TRUE) @@ -164,7 +164,7 @@ prev <- estimate_secondary(cases[1:100], secondary = secondary_opts(type = "prevalence"), obs = obs_opts( week_effect = FALSE, - scale = list(mean = 0.4, sd = 0.1) + scale = Normal(mean = 0.4, sd = 0.1) ) ) plot(prev, primary = TRUE) diff --git a/man/forecast_infections.Rd b/man/forecast_infections.Rd index e9c4fbde4..24c5b2c5d 100644 --- a/man/forecast_infections.Rd +++ b/man/forecast_infections.Rd @@ -65,8 +65,8 @@ reported_cases <- example_confirmed[1:50] est <- 
estimate_infections(reported_cases, generation_time = generation_time_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7), - obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.1), rw = 7), + obs = obs_opts(scale = Normal(mean = 0.1, sd = 0.01)), gp = NULL, horizon = 0 ) diff --git a/man/obs_opts.Rd b/man/obs_opts.Rd index 0bf4579a0..27b37b2e9 100644 --- a/man/obs_opts.Rd +++ b/man/obs_opts.Rd @@ -76,5 +76,5 @@ obs_opts() obs_opts(week_effect = TRUE) # Scale reported data -obs_opts(scale = list(mean = 0.2, sd = 0.02)) +obs_opts(scale = Normal(mean = 0.2, sd = 0.02)) } diff --git a/man/regional_epinow.Rd b/man/regional_epinow.Rd index 4a1c754f7..6eca0c83a 100644 --- a/man/regional_epinow.Rd +++ b/man/regional_epinow.Rd @@ -156,7 +156,7 @@ def <- regional_epinow( data = cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + example_reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = LogNormal(mean = 2, sd = 0.2)), stan = stan_opts( samples = 100, warmup = 200 ), diff --git a/man/rt_opts.Rd b/man/rt_opts.Rd index 8d796d583..99e3d1e05 100644 --- a/man/rt_opts.Rd +++ b/man/rt_opts.Rd @@ -70,7 +70,7 @@ defaults. 
rt_opts() # add a custom length scale -rt_opts(prior = list(mean = 2, sd = 1)) +rt_opts(prior = Normal(mean = 2, sd = 1)) # add a weekly random walk rt_opts(rw = 7) diff --git a/touchstone/script.R b/touchstone/script.R index a0c37b300..daf426cdb 100644 --- a/touchstone/script.R +++ b/touchstone/script.R @@ -11,7 +11,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), delays = delay_opts(fixed_delays), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -27,7 +27,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(example_generation_time), delays = delays, - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -42,7 +42,7 @@ touchstone::benchmark_run( no_delays = { epinow( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -58,7 +58,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), delays = delay_opts(fixed_delays), - rt = rt_opts(prior = list(mean = 2, sd = 0.2), gp_on = "R0"), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2), gp_on = "R0"), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -74,7 +74,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), delays = delay_opts(fixed_delays), - rt = rt_opts(prior = list(mean = 2, sd = 0.2), rw = 7), + rt = rt_opts(prior = 
Normal(mean = 2, sd = 0.2), rw = 7), gp = NULL, stan = stan_opts( cores = 2, samples = 500, chains = 2, diff --git a/vignettes/EpiNow2.Rmd.orig b/vignettes/EpiNow2.Rmd.orig index 0b27fa285..f71463cde 100644 --- a/vignettes/EpiNow2.Rmd.orig +++ b/vignettes/EpiNow2.Rmd.orig @@ -94,7 +94,7 @@ estimates <- epinow( data = reported_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), stan = stan_opts(cores = 4, control = list(adapt_delta = 0.99)), verbose = interactive() ) @@ -148,7 +148,7 @@ estimates <- regional_epinow( data = reported_cases, generation_time = gt_opts(example_generation_time), delays = delay_opts(example_incubation_period + reporting_delay), - rt = rt_opts(prior = list(mean = 2, sd = 0.2), rw = 7), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2), rw = 7), gp = NULL, stan = stan_opts(cores = 4, warmup = 250, samples = 1000) ) diff --git a/vignettes/epinow.Rmd.orig b/vignettes/epinow.Rmd.orig index e8da3777f..bdcc72621 100644 --- a/vignettes/epinow.Rmd.orig +++ b/vignettes/epinow.Rmd.orig @@ -40,7 +40,7 @@ options(mc.cores = 4) reported_cases <- example_confirmed[1:60] reporting_delay <- LogNormal(mean = 2, sd = 1, max = 10) delay <- example_incubation_period + reporting_delay -rt_prior <- list(mean = 2, sd = 0.1) +rt_prior <- Normal(mean = 2, sd = 0.1) ``` We can then run the `epinow()` function with the same arguments as `estimate_infections()`. diff --git a/vignettes/estimate_infections_options.Rmd.orig b/vignettes/estimate_infections_options.Rmd.orig index bdfaaa138..939434814 100644 --- a/vignettes/estimate_infections_options.Rmd.orig +++ b/vignettes/estimate_infections_options.Rmd.orig @@ -97,7 +97,7 @@ example_generation_time Lastly we need to choose a prior for the initial value of the reproduction number. 
This is assumed by the model to be normally distributed and we can set the mean and the standard deviation. We decide to set the mean to 2 and the standard deviation to 1. ```{r initial_r} -rt_prior <- list(mean = 2, sd = 0.1) +rt_prior <- Normal(mean = 2, sd = 0.1) ``` # Running the model diff --git a/vignettes/estimate_infections_workflow.Rmd.orig b/vignettes/estimate_infections_workflow.Rmd.orig index ca1297472..2c8f48a19 100644 --- a/vignettes/estimate_infections_workflow.Rmd.orig +++ b/vignettes/estimate_infections_workflow.Rmd.orig @@ -196,7 +196,7 @@ In _EpiNow2_ we can specify the proportion of infections that we expect to be ob For example, if we think that 40% (with standard deviation 1%) of infections end up in the data as observations we could specify. ```{r results = 'hide'} -obs_scale <- list(mean = 0.4, sd = 0.01) +obs_scale <- Normal(mean = 0.4, sd = 0.01) obs_opts(scale = obs_scale) ``` @@ -209,7 +209,7 @@ It can be changed using the `rt_opts()` function. For example, if the user believes that at the very start of the data the reproduction number was 2, with uncertainty in this belief represented by a standard deviation of 1, they would use ```{r results = 'hide'} -rt_prior <- list(mean = 2, sd = 1) +rt_prior <- LogNormal(mean = 2, sd = 1) rt_opts(prior = rt_prior) ``` From 77d96970ab0cf5805b334c946dd8a86c3b5aad55 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 13:59:21 +0000 Subject: [PATCH 6/9] add news item --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index ed92b8f02..be3b3c92f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,7 @@ - A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs. - A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk. 
+- All parameters have been changed to the new parameter interface. By @sbfnk in # and reviewed by @. # EpiNow2 1.6.1 From 2a2d403d6fa4f0eee8457ce65c21b18de4c92ca9 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 14:59:52 +0000 Subject: [PATCH 7/9] add progressr to lintr workflow --- .github/workflows/lint-only-changed-files.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/lint-only-changed-files.yaml b/.github/workflows/lint-only-changed-files.yaml index 3d67b098d..a77c2d01d 100644 --- a/.github/workflows/lint-only-changed-files.yaml +++ b/.github/workflows/lint-only-changed-files.yaml @@ -29,6 +29,7 @@ jobs: any::gh any::lintr any::purrr + progressr - name: Add lintr options run: | From deedb9e2c276860daef0d54b890b8c47fa267ec2 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 15:35:17 +0000 Subject: [PATCH 8/9] switch benchmarks back to previous syntax otherwise they won't work on main --- inst/dev/benchmark-functions.R | 2 +- touchstone/script.R | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/inst/dev/benchmark-functions.R b/inst/dev/benchmark-functions.R index a0409ef94..3e884e72b 100644 --- a/inst/dev/benchmark-functions.R +++ b/inst/dev/benchmark-functions.R @@ -17,7 +17,7 @@ create_profiles <- function(dir = file.path("inst", "stan"), data = reported_cases, generation_time = gt_opts(fixed_generation_time), delays = delay_opts(delays), - rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), + rt = rt_opts(prior = list(mean = 2, sd = 0.2)), stan = stan_opts( samples = 1000, chains = 2, object = compiled_model, cores = 2 diff --git a/touchstone/script.R b/touchstone/script.R index daf426cdb..a0c37b300 100644 --- a/touchstone/script.R +++ b/touchstone/script.R @@ -11,7 +11,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), delays = delay_opts(fixed_delays), - rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), 
+ rt = rt_opts(prior = list(mean = 2, sd = 0.2)), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -27,7 +27,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(example_generation_time), delays = delays, - rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), + rt = rt_opts(prior = list(mean = 2, sd = 0.2)), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -42,7 +42,7 @@ touchstone::benchmark_run( no_delays = { epinow( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), - rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), + rt = rt_opts(prior = list(mean = 2, sd = 0.2)), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -58,7 +58,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), delays = delay_opts(fixed_delays), - rt = rt_opts(prior = Normal(mean = 2, sd = 0.2), gp_on = "R0"), + rt = rt_opts(prior = list(mean = 2, sd = 0.2), gp_on = "R0"), stan = stan_opts( cores = 2, samples = 500, chains = 2, control = list(adapt_delta = 0.95)), @@ -74,7 +74,7 @@ touchstone::benchmark_run( data = reported_cases, generation_time = generation_time_opts(fixed_generation_time), delays = delay_opts(fixed_delays), - rt = rt_opts(prior = Normal(mean = 2, sd = 0.2), rw = 7), + rt = rt_opts(prior = list(mean = 2, sd = 0.2), rw = 7), gp = NULL, stan = stan_opts( cores = 2, samples = 500, chains = 2, From 333de4628ef95a45e721c11658d1f9d1ad64364f Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 26 Nov 2024 16:44:35 +0000 Subject: [PATCH 9/9] change benchmark back it won't work anyway as the R code has changed too much --- inst/dev/benchmark-functions.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inst/dev/benchmark-functions.R b/inst/dev/benchmark-functions.R index 3e884e72b..a0409ef94 
100644 --- a/inst/dev/benchmark-functions.R +++ b/inst/dev/benchmark-functions.R @@ -17,7 +17,7 @@ create_profiles <- function(dir = file.path("inst", "stan"), data = reported_cases, generation_time = gt_opts(fixed_generation_time), delays = delay_opts(delays), - rt = rt_opts(prior = list(mean = 2, sd = 0.2)), + rt = rt_opts(prior = Normal(mean = 2, sd = 0.2)), stan = stan_opts( samples = 1000, chains = 2, object = compiled_model, cores = 2