Skip to content

Commit

Permalink
Merge pull request #206 from carterbox/online
Browse files Browse the repository at this point in the history
API: Enable online reconstruction
  • Loading branch information
carterbox authored May 20, 2022
2 parents 40c47b0 + 1576dda commit 443267c
Show file tree
Hide file tree
Showing 14 changed files with 1,742 additions and 1,119 deletions.
5 changes: 0 additions & 5 deletions docs/source/api/ptycho.solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,3 @@ solvers
.. autosummary::
:nosignatures:
:recursive:

adam_grad
cgrad
lstsq_grad
rpie
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

# bibtex setting
bibtex_bibfiles = [
'zrefs.bib',
'bibtex/zrefs.bib',
]

# extlinks settings
Expand Down
1,520 changes: 952 additions & 568 deletions docs/source/examples/ptycho.ipynb

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions src/tike/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ def put_batch(y, x, b, n):
x[b[n]] = y


def momentum(g, v, m, vdecay=None, mdecay=0.9):
"""Add momentum to the gradient direction.
Parameters
----------
g : vector
The current gradient.
m : vector
The momentum.
eps : float
A tiny constant to prevent zero division.
"""
m = 0 if m is None else m
m = mdecay * m + (1 - mdecay) * g
return m, None, m


def adagrad(g, v=None, eps=1e-6):
"""Return the adaptive gradient algorithm direction.
Expand Down
51 changes: 50 additions & 1 deletion src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,34 @@ def __post_init__(self):
dtype='float32',
)

def append(self, new_scan):
self.initial_scan = np.append(
self.initial_scan,
values=new_scan,
axis=-2,
)
if self.use_adaptive_moment:
self._momentum = np.pad(
self._momentum,
pad_width=(
(0, len(new_scan)),
(0, 0),
),
mode='constant',
)

def empty(self):
new = PositionOptions(
np.empty((0, 2)),
use_adaptive_moment=self.use_adaptive_moment,
vdecay=self.vdecay,
mdecay=self.mdecay,
use_position_regularization=self.use_position_regularization,
)
if self.use_adaptive_moment:
new._momentum = np.empty((0, 4))
return new

def split(self, indices):
"""Split the PositionOption meta-data along indices."""
new = PositionOptions(
Expand All @@ -53,13 +81,34 @@ def split(self, indices):
new._momentum = self._momentum[..., indices, :]
return new

def join(self, other, indices):
def insert(self, other, indices):
"""Replace the PositionOption meta-data with other data."""
self.initial_scan[..., indices, :] = other.initial_scan
if self.use_adaptive_moment:
self._momentum[..., indices, :] = other._momentum
return self

def join(self, other, indices):
"""Replace the PositionOption meta-data with other data."""
len_scan = self.initial_scan.shape[-2]
max_index = max(indices.max() + 1, len_scan)
new_initial_scan = np.empty(
(*self.initial_scan.shape[:-2], max_index, 2),
dtype=self.initial_scan.dtype,
)
new_initial_scan[..., :len_scan, :] = self.initial_scan
new_initial_scan[..., indices, :] = other.initial_scan
self.initial_scan = new_initial_scan
if self.use_adaptive_moment:
new_momentum = np.empty(
(*self.initial_scan.shape[:-2], max_index, 4),
dtype=self.initial_scan.dtype,
)
new_momentum[..., :len_scan, :] = self._momentum
new_momentum[..., indices, :] = other._momentum
self._momentum = new_momentum
return self

def copy_to_device(self):
"""Copy to the current GPU memory."""
self.initial_scan = cp.asarray(self.initial_scan)
Expand Down
Loading

0 comments on commit 443267c

Please sign in to comment.