Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MaxVar split, Part 2] Added the visualisation improvements. #234

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
=========

0.x
---
- Improved the interactive plotting (customised for the MaxVar-based acquisition methods)
- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters)

0.6.3 (2017-09-28)
------------------

Expand Down
2 changes: 2 additions & 0 deletions elfi/methods/bo/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def __init__(self, *args, delta=None, **kwargs):
kwargs['exploration_rate'] = 1 / delta

super(LCBSC, self).__init__(*args, **kwargs)
self.name = 'lcbsc'
self.label_fn = 'The Lower Confidence Bound Selection Criterion'

@property
def delta(self):
Expand Down
169 changes: 86 additions & 83 deletions elfi/methods/parameter_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
__all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI']

import logging
from collections import OrderedDict
from math import ceil

import matplotlib.pyplot as plt
import numpy as np

import elfi.client
Expand Down Expand Up @@ -89,7 +89,6 @@ def __init__(self,
model = model.model if isinstance(model, NodeReference) else model
if not model.parameter_names:
raise ValueError('Model {} defines no parameters'.format(model))

self.model = model.copy()
self.output_names = self._check_outputs(output_names)

Expand Down Expand Up @@ -161,7 +160,7 @@ def extract_result(self):
"""
raise NotImplementedError

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.

ELFI calls this method when a new batch has been computed and the state of
Expand All @@ -174,10 +173,8 @@ def update(self, batch, batch_index):
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
batch_index : int

Returns
-------
None
vis : bool, optional
Interactive visualisation of the iterations.

"""
self.state['n_batches'] += 1
Expand Down Expand Up @@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
"""
raise NotImplementedError

def infer(self, *args, vis=None, **kwargs):
def infer(self, *args, **options):
"""Set the objective and start the iterate loop until the inference is finished.

See the other arguments from the `set_objective` method.
Expand All @@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
result : Sample

"""
vis_opt = vis if isinstance(vis, dict) else {}

self.set_objective(*args, **kwargs)

vis = options.pop('vis', None)
self.set_objective(*args, **options)
while not self.finished:
self.iterate()
if vis:
self.plot_state(interactive=True, **vis_opt)

self.iterate(vis=vis)
self.batches.cancel_pending()
if vis:
self.plot_state(close=True, **vis_opt)

return self.extract_result()

def iterate(self):
"""Advance the inference by one iteration.
def iterate(self, vis=None):
"""Forward the inference one iteration.

This is a way to manually progress the inference. One iteration consists of
waiting and processing the result of the next batch in succession and possibly
Expand All @@ -272,6 +262,11 @@ def iterate(self):
will never be more batches submitted in parallel than the `max_parallel_batches`
setting allows.

Parameters
----------
vis : bool, optional
Interactive visualisation of the iterations.

Returns
-------
None
Expand All @@ -286,7 +281,7 @@ def iterate(self):
# Handle the next ready batch in succession
batch, batch_index = self.batches.wait_next()
logger.debug('Received batch %d' % batch_index)
self.update(batch, batch_index)
self.update(batch, batch_index, vis=vis)

@property
def finished(self):
Expand Down Expand Up @@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
# Reset the inference
self.batches.reset()

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.

Parameters
----------
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
vis : bool, optional
Interactive visualisation of the iterations.
batch_index : int

"""
if vis and self.state['samples'] is not None:
self.plot_state(interactive=True, **vis)

super(Rejection, self).update(batch, batch_index)
if self.state['samples'] is None:
# Lazy initialization of the outputs dict
Expand Down Expand Up @@ -584,8 +583,8 @@ def plot_state(self, **options):
displays = []
if options.get('interactive'):
from IPython import display
displays.append(
display.HTML('<span>Threshold: {}</span>'.format(self.state['threshold'])))
html_display = '<span>Threshold: {}</span>'.format(self.state['threshold'])
displays.append(display.HTML(html_display))

visin.plot_sample(
self.state['samples'],
Expand Down Expand Up @@ -651,14 +650,15 @@ def extract_result(self):
threshold=pop.threshold,
**self._extract_result_kwargs())

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the inference state with a new batch.

Parameters
----------
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
vis : bool, optional
Interactive visualisation of the iterations.
batch_index : int

"""
Expand Down Expand Up @@ -942,14 +942,16 @@ def extract_result(self):
return OptimizationResult(
x_min=batch_min, outputs=outputs, **self._extract_result_kwargs())

def update(self, batch, batch_index):
def update(self, batch, batch_index, vis=None):
"""Update the GP regression model of the target node with a new batch.

Parameters
----------
batch : dict
dict with `self.outputs` as keys and the corresponding outputs for the batch
as values
vis : bool, optional
Interactive visualisation of the iterations.
batch_index : int

"""
Expand All @@ -959,11 +961,22 @@ def update(self, batch, batch_index):
params = batch_to_arr2d(batch, self.parameter_names)
self._report_batch(batch_index, params, batch[self.target_name])

# Adding the acquisition plots.
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
options = {}
options['point_acq'] = {'x': params, 'd': batch[self.target_name]}
options['method_acq'] = self.acquisition_method.label_fn
arr_ax = self.plot_state(interactive=True, **options)

optimize = self._should_optimize()
self.target_model.update(params, batch[self.target_name], optimize)
if optimize:
self.state['last_GP_update'] = self.target_model.n_evidence

# Adding the updated gp plots.
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
self.plot_state(interactive=True, arr_ax=arr_ax, **options)

def prepare_new_batch(self, batch_index):
"""Prepare values for a new batch.

Expand Down Expand Up @@ -1040,60 +1053,51 @@ def _report_batch(self, batch_index, params, distances):
str += "{}{} at {}\n".format(fill, distances[i].item(), params[i])
logger.debug(str)

def plot_state(self, **options):
"""Plot the GP surface.
def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **options):
"""Plot the GP surface and the acquisition space.

This feature is still experimental and currently supports only 2D cases.
"""
f = plt.gcf()
if len(f.axes) < 2:
f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row')

gp = self.target_model

# Draw the GP surface
visin.draw_contour(
gp.predict_mean,
gp.bounds,
self.parameter_names,
title='GP target surface',
points=gp.X,
axes=f.axes[0],
**options)
Notes
-----
- The plots of the GP surface and the acquisition space work for the
cases when dim < 3;
- The method is experimental.

# Draw the latest acquisitions
if options.get('interactive'):
point = gp.X[-1, :]
if len(gp.X) > 1:
f.axes[1].scatter(*point, color='red')
Parameters
----------
plot_acq_pairwise : bool, optional
The option to plot the pair-wise acquisition point relationships.
arr_ax : array_like, optional
Handled implicitly upon interactive visualisation.

displays = [gp._gp]
Returns
-------
array_like
Axes for interactive visualisation.

if options.get('interactive'):
from IPython import display
displays.insert(
0,
display.HTML('<span><b>Iteration {}:</b> Acquired {} at {}</span>'.format(
len(gp.Y), gp.Y[-1][0], point)))

# Update
visin._update_interactive(displays, options)

def acq(x):
return self.acquisition_method.evaluate(x, len(gp.X))

# Draw the acquisition surface
visin.draw_contour(
acq,
gp.bounds,
self.parameter_names,
title='Acquisition surface',
points=None,
axes=f.axes[1],
**options)
Raises
------
ValueError
Unsupported dimension.

if options.get('close'):
plt.close()
"""
if plot_acq_pairwise:
if len(self.parameter_names) == 1:
raise ValueError('Can not plot the pair-wise comparison for 1 parameter.')

# Transform the acquisition points in the accepted format.
dict_pts_acq = OrderedDict()
for idx_param, name_param in enumerate(self.parameter_names):
dict_pts_acq[name_param] = self.target_model.X[:, idx_param]
vis.plot_pairs(dict_pts_acq, **options)
else:
if len(self.parameter_names) == 1:
arr_ax = vis.plot_state_1d(self, arr_ax, **options)
return arr_ax
elif len(self.parameter_names) == 2:
arr_ax = vis.plot_state_2d(self, arr_ax, **options)
return arr_ax
else:
raise ValueError('The method is supported only for 1- or 2-dimensions.')

def plot_discrepancy(self, axes=None, **kwargs):
"""Plot acquired parameters vs. resulting discrepancy.
Expand Down Expand Up @@ -1133,7 +1137,7 @@ class BOLFI(BayesianOptimization):

"""

def fit(self, n_evidence, threshold=None):
def fit(self, n_evidence, threshold=None, **options):
"""Fit the surrogate model.

Generates a regression model for the discrepancy given the parameters.
Expand All @@ -1150,9 +1154,8 @@ def fit(self, n_evidence, threshold=None):

if n_evidence is None:
raise ValueError(
'You must specify the number of evidence (n_evidence) for the fitting')

self.infer(n_evidence)
'You must specify the number of evidence( n_evidence) for the fitting')
self.infer(n_evidence, **options)
return self.extract_posterior(threshold)

def extract_posterior(self, threshold=None):
Expand Down
Loading