Skip to content

Commit

Permalink
allow skipping attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffckerr committed Nov 18, 2024
1 parent a56875c commit d77cede
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 15 additions & 5 deletions starsim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,20 @@ def modules(self):
self.analyzers(),
)

def init(self, **kwargs):
""" Perform all initializations for the sim """
def init(self, force=False, **kwargs):
"""
Perform all initializations for the sim
Args:
force (bool): whether to overwrite sim attributes even if they already exist
kwargs (dict): passed to ss.People()
"""
# Validation and initialization -- this is "pre"
ss.set_seed(self.pars.rand_seed) # Reset the seed before the population is created -- shouldn't matter if only using Dist objects
self.pars.validate() # Validate parameters
self.init_time() # Initialize time
self.init_people(**kwargs) # Initialize the people
self.init_sim_attrs()
self.init_sim_attrs(force=force)
self.init_mods_pre()

# Final initializations -- this is "post"
Expand Down Expand Up @@ -183,11 +188,16 @@ def init_people(self, verbose=None, **kwargs):
self.people.link_sim(self)
return self.people

def init_sim_attrs(self):
def init_sim_attrs(self, force=False):
""" Move initialized modules to the sim """
keys = ['label', 'demographics', 'networks', 'diseases', 'interventions', 'analyzers', 'connectors']
for key in keys:
setattr(self, key, self.pars.pop(key))
orig = getattr(self, key, None)
if not force and orig is not None:
warnmsg = f'Skipping key "{key}" in parameters since already present in sim and force=False'
ss.warn(warnmsg)
else:
setattr(self, key, self.pars.pop(key))
return

def init_mods_pre(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_multi_timestep(do_plot=False):


def test_mixed_timesteps():
sc.heading('Test behavior of different commbinations of timesteps')
sc.heading('Test behavior of different combinations of timesteps')

siskw = dict(dur_inf=ss.dur(50, 'day'), beta=ss.beta(0.01, 'day'), waning=ss.rate(0.005, 'day'))
kw = dict(n_agents=1000, start='2001-01-01', stop='2001-07-01', networks='random', copy_inputs=False, verbose=0)
Expand All @@ -168,7 +168,7 @@ def test_mixed_timesteps():

msim = ss.parallel(sim1, sim2, sim3, sim4)

# Check that al results are close
# Check that all results are close
threshold = 0.01
summary = msim.summarize()
for key,res in summary.items():
Expand Down

0 comments on commit d77cede

Please sign in to comment.