Skip to content

Commit

Permalink
Merge pull request #285 from carterbox/optimize
Browse files Browse the repository at this point in the history
REF: Migrate RPIE and LSTSQ methods to new Stream manager
  • Loading branch information
carterbox authored Dec 19, 2023
2 parents 9d2465a + 5753f60 commit b2dc831
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 191 deletions.
63 changes: 45 additions & 18 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,18 @@ def constrain_variable_probe(comm, variable_probe, weights):


def _get_update(R, eigen_probe, weights, batches, *, batch_index, c, m):
# (..., POSI, 1, 1, 1, 1) to match other arrays
weights = weights[..., batches[batch_index], c:c + 1, m:m + 1, None, None]
eigen_probe = eigen_probe[..., c - 1:c, m:m + 1, :, :]
"""
Parameters
----------
R : (B, 1, 1, H, W)
eigen_probe (1, C, M, H, W)
weights : (B, C, M)
"""
lo = batches[batch_index][0]
hi = lo + len(batches[batch_index])
# (POSI, 1, 1, 1, 1) to match other arrays
weights = weights[lo:hi, c:c + 1, m:m + 1, None, None]
eigen_probe = eigen_probe[:, c - 1:c, m:m + 1, :, :]
norm_weights = tike.linalg.norm(weights, axis=-5, keepdims=True)**2

if np.all(norm_weights == 0):
Expand All @@ -336,22 +345,30 @@ def _get_update(R, eigen_probe, weights, batches, *, batch_index, c, m):


def _get_d(patches, diff, eigen_probe, update, *, β, c, m):
eigen_probe[..., c - 1:c, m:m + 1, :, :] += β * update / tike.linalg.mnorm(
"""
Parameters
----------
patches : (B, 1, 1, H, W)
diff : (B, 1, M, H, W)
eigen_probe (1, C, M, H, W)
update : (1, 1, 1, H, W)
"""
eigen_probe[:, c - 1:c, m:m + 1, :, :] += β * update / tike.linalg.mnorm(
update,
axis=(-2, -1),
keepdims=True,
)
eigen_probe[..., c - 1:c, m:m + 1, :, :] /= tike.linalg.mnorm(
eigen_probe[..., c - 1:c, m:m + 1, :, :],
eigen_probe[:, c - 1:c, m:m + 1, :, :] /= tike.linalg.mnorm(
eigen_probe[:, c - 1:c, m:m + 1, :, :],
axis=(-2, -1),
keepdims=True,
)
assert np.all(np.isfinite(eigen_probe))

# Determine new eigen_weights for the updated eigen probe
phi = patches * eigen_probe[..., c - 1:c, m:m + 1, :, :]
phi = patches * eigen_probe[:, c - 1:c, m:m + 1, :, :]
n = np.mean(
np.real(diff[..., m:m + 1, :, :] * phi.conj()),
np.real(diff[:, :, m:m + 1, :, :] * phi.conj()),
axis=(-1, -2),
keepdims=False,
)
Expand All @@ -361,15 +378,25 @@ def _get_d(patches, diff, eigen_probe, update, *, β, c, m):


def _get_weights_mean(n, d, d_mean, weights, batches, *, batch_index, c, m):
"""
Parameters
----------
n : (B, 1, 1)
d : (B, 1, 1)
d_mean : (1, 1, 1)
weights : (B, C, M)
"""
lo = batches[batch_index][0]
hi = lo + len(batches[batch_index])
# yapf: disable
weight_update = (
n / (d + 0.1 * d_mean)
).reshape(*weights[..., batches[batch_index], c:c + 1, m:m + 1].shape)
).reshape(*weights[lo:hi, c:c + 1, m:m + 1].shape)
# yapf: enable
assert np.all(np.isfinite(weight_update))

# (33) The sum of all previous steps constrained to zero-mean
weights[..., batches[batch_index], c:c + 1, m:m + 1] += weight_update
weights[lo:hi, c:c + 1, m:m + 1] += weight_update
return weights


Expand Down Expand Up @@ -398,17 +425,17 @@ def update_eigen_probe(
----------
comm : :py:class:`tike.communicators.Comm`
An object which manages communications between both GPUs and nodes.
R : (..., POSI, 1, 1, WIDE, HIGH) complex64
R : (POSI, 1, 1, WIDE, HIGH) complex64
Residual probe updates; what's left after subtracting the shared probe
update from the varying probe updates for each position
patches : (..., POSI, 1, 1, WIDE, HIGH) complex64
diff : (..., POSI, 1, SHARED, WIDE, HIGH) complex64
eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64
patches : (POSI, 1, 1, WIDE, HIGH) complex64
diff : (POSI, 1, SHARED, WIDE, HIGH) complex64
eigen_probe : (1, EIGEN, SHARED, WIDE, HIGH) complex64
The eigen probe being updated.
β : float
A relaxation constant that controls how quickly the eigen probe modes
are updated. Recommended to be < 1 for mini-batch updates.
weights : (..., POSI, EIGEN, SHARED) float32
weights : (POSI, EIGEN, SHARED) float32
A vector whose elements are sums of the previous optimal updates for
each posiiton.
Expand All @@ -425,7 +452,7 @@ def update_eigen_probe(
assert R[0].shape[-3] == R[0].shape[-4] == 1
assert 1 == eigen_probe[0].shape[-5]
assert R[0].shape[:-5] == eigen_probe[0].shape[:-5] == weights[0].shape[:-3]
assert weights[0][..., batches[0][batch_index], :, :].shape[-3] == R[0].shape[-5]
assert weights[0][batches[0][batch_index], :, :].shape[-3] == R[0].shape[-5]
assert R[0].shape[-2:] == eigen_probe[0].shape[-2:]

update = comm.pool.map(
Expand Down Expand Up @@ -919,8 +946,8 @@ def rescale_probe_using_fixed_intensity_photons(
if probe_power_fraction is None:
probe_power_fraction = probe_photons / cp.sum(probe_photons)

probe = probe * cp.sqrt(probe_power_fraction * Nphotons / probe_photons)[..., None,
None]
probe = probe * cp.sqrt(
probe_power_fraction * Nphotons / probe_photons)[..., None, None]

return probe

Expand Down
2 changes: 1 addition & 1 deletion src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def _get_rescale(
probe,
streams,
*,
operator,
operator: tike.operators.Ptycho,
):

sums = cp.zeros((2,), dtype=cp.double)
Expand Down
Loading

0 comments on commit b2dc831

Please sign in to comment.