Skip to content

Commit

Permalink
Add support for inference with uncorrelated and correlated gaussian n…
Browse files Browse the repository at this point in the history
…oise
  • Loading branch information
JBjoernskov committed Jan 9, 2024
1 parent 71fc2b7 commit 185f29f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
12 changes: 9 additions & 3 deletions twin4build/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def run_emcee_estimation(self,
assert n_cores>=1, "The argument \"n_cores\" must be larger than or equal to 1"
assert fac_walker>=2, "The argument \"fac_walker\" must be larger than or equal to 2"
allowed_priors = ["uniform", "gaussian"]
allowed_walker_initializations = ["uniform", "gaussian"]
allowed_walker_initializations = ["uniform", "gaussian", "ball"]
assert prior in allowed_priors, f"The \"prior\" argument must be one of the following: {', '.join(allowed_priors)} - \"{prior}\" was provided."
assert walker_initialization in allowed_walker_initializations, f"The \"walker_initialization\" argument must be one of the following: {', '.join(allowed_walker_initializations)} - \"{walker_initialization}\" was provided."
assert np.all(self.x0>=self.lb), "The provided x0 must be larger than the provided lower bound lb"
Expand All @@ -137,9 +137,9 @@ def run_emcee_estimation(self,
logprior = self.gaussian_logprior

ndim = len(self.flat_attr_list)
self.n_par = 0
self.n_par_map = {}
if assume_uncorrelated_noise==False:
self.n_par = 0
self.n_par_map = {}
# Get number of gaussian process parameters
for j, measuring_device in enumerate(self.targetMeasuringDevices):
source_component = [cp.connectsSystemThrough.connectsSystem for cp in measuring_device.connectsAt][0]
Expand Down Expand Up @@ -167,6 +167,12 @@ def run_emcee_estimation(self,
ub = np.resize(self.ub,(x0_start.shape))
x0_start[x0_start<self.lb] = lb[x0_start<self.lb]
x0_start[x0_start>self.ub] = ub[x0_start>self.ub]
elif walker_initialization=="ball":
x0_start = np.random.uniform(low=self.x0-1e-5, high=self.ub+1e-5, size=(n_temperature, n_walkers, ndim))
lb = np.resize(self.lb,(x0_start.shape))
ub = np.resize(self.ub,(x0_start.shape))
x0_start[x0_start<self.lb] = lb[x0_start<self.lb]
x0_start[x0_start>self.ub] = ub[x0_start>self.ub]

print(f"Number of cores: {n_cores}")
print(f"Number of estimated parameters: {ndim}")
Expand Down
4 changes: 2 additions & 2 deletions twin4build/estimator/tests/test_estimator_wbypass.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def test_estimator():
# Options for the PTEMCEE estimation algorithm. If the options argument is not supplied or None is supplied, default options are applied.
options = {"n_sample": 10000, #This is a test file, and we therefore only sample 2. Typically, we need at least 1000 samples before the chain converges.
"n_temperature": 1, #Number of parallel chains/temperatures.
"fac_walker": 4, #Scaling factor for the number of ensemble walkers per chain. Minimum is 2.
"fac_walker": 8, #Scaling factor for the number of ensemble walkers per chain. Minimum is 2.
"prior": "uniform", #Prior distribution - "gaussian" is also implemented
"walker_initialization": "gaussian",#Initialization of parameters - "gaussian" is also implemented
"walker_initialization": "ball",#Initialization of parameters - "gaussian" is also implemented
"n_cores": 8,
"assume_uncorrelated_noise": False
}
Expand Down
2 changes: 1 addition & 1 deletion twin4build/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def _sim_func_gaussian_process(self, model, theta, stepSize, startTime, endTime,
y_model = np.zeros((len(self.dateTimeSteps), len(self.targetMeasuringDevices)))
y = np.zeros((len(self.dateTimeSteps), len(self.targetMeasuringDevices)))
standardDeviation = model.chain_log["standardDeviation"]
n_samples = 1000
n_samples = 100
n_prev = 0
for j, measuring_device in enumerate(self.targetMeasuringDevices):
source_component = [cp.connectsSystemThrough.connectsSystem for cp in measuring_device.connectsAt][0]
Expand Down

0 comments on commit 185f29f

Please sign in to comment.