Skip to content

Commit

Permalink
Merge pull request #175 from carterbox/full-adjoint
Browse files Browse the repository at this point in the history
NEW: Perform ptycho adjoints for probe and obj simultaneously
  • Loading branch information
carterbox authored Nov 19, 2021
2 parents 6c72766 + a257f93 commit 58b9a23
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 77 deletions.
5 changes: 3 additions & 2 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
displayName: Force remove previous environments

- script: >
conda create --quiet --yes
mamba create --quiet --yes
-n tike
--channel conda-forge
--file requirements.txt
Expand Down Expand Up @@ -93,7 +93,7 @@ jobs:
displayName: Force remove previous environments

- script: >
conda create --quiet --yes
mamba create --quiet --yes
-n tike
--channel conda-forge
--file requirements.txt
Expand Down Expand Up @@ -122,6 +122,7 @@ jobs:
- script: |
source activate tike
export OMPI_MCA_opal_cuda_support=true
mpirun -n 2 pytest tests/test_comm.py -vs
mpirun -n 2 pytest tests/test_ptycho.py -k TestPtychoRecon -vs
mpirun -n 2 pytest tests/test_lamino.py -k bucket -vs
Expand Down
40 changes: 40 additions & 0 deletions src/tike/operators/cupy/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,43 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False):
patches = patches.conj()
patches *= nearplane[..., self.pad:self.end, self.pad:self.end]
return patches

def adj_all(self, nearplane, scan, probe, psi, overwrite=False):
"""Peform adj and adj_probe at the same time."""
assert probe.shape[:-4] == scan.shape[:-2]
assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
assert nearplane.shape[:-3] == scan.shape[:-1], (nearplane.shape,
scan.shape)

patches = self.patch.fwd(
# Could be xp.empty if scan positions are all in bounds
patches=self.xp.zeros(
(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
self.probe_shape, self.probe_shape),
dtype='complex64',
),
images=psi,
positions=scan,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
)
patches = patches.reshape((*scan.shape[:-1], nearplane.shape[-3],
self.probe_shape, self.probe_shape))
patches = patches.conj()
patches *= nearplane[..., self.pad:self.end, self.pad:self.end]

if not overwrite:
nearplane = nearplane.copy()
nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()

return self.patch.adj(
patches=nearplane.reshape(
(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:])),
images=self.xp.zeros((*scan.shape[:-2], self.nz, self.n),
dtype='complex64'),
positions=scan,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
), patches
32 changes: 26 additions & 6 deletions src/tike/operators/cupy/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,18 @@ class Ptycho(Operator):
"""

def __init__(self, detector_shape, probe_shape, nz, n,
ntheta=1, model='gaussian',
propagation=Propagation,
diffraction=Convolution,
**kwargs): # noqa: D102 yapf: disable
def __init__(
self,
detector_shape,
probe_shape,
nz,
n,
ntheta=1,
model='gaussian',
propagation=Propagation,
diffraction=Convolution,
**kwargs,
):
"""Please see help(Ptycho) for more info."""
self.propagation = propagation(
detector_shape=detector_shape,
Expand All @@ -63,7 +70,6 @@ def __init__(self, detector_shape, probe_shape, nz, n,
detector_shape=detector_shape,
nz=nz,
n=n,
model=model,
**kwargs,
)
# TODO: Replace these with @property functions
Expand Down Expand Up @@ -177,3 +183,17 @@ def grad_probe(self, data, psi, scan, probe, mode=None):
axis=0,
keepdims=True,
)

def adj_all(self, farplane, probe, scan, psi, overwrite=False):
"""Please see help(Ptycho) for more info."""
apsi, aprobe = self.diffraction.adj_all(
nearplane=self.propagation.adj(
farplane,
overwrite=overwrite,
)[..., 0, :, :, :],
probe=probe[..., 0, :, :, :],
scan=scan,
overwrite=True,
psi=psi,
)
return apsi, aprobe[..., None, :, :, :]
21 changes: 2 additions & 19 deletions src/tike/ptycho/solvers/epie.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,7 @@ def epie(

def _update_wavefront(data, varying_probe, scan, psi, op=None):

# Compute the diffraction patterns for all of the probe modes at once.
# We need access to all of the modes of a position to solve the phase
# problem. The Ptycho operator doesn't do this natively, so it's messy.
patches = cp.zeros(data.shape, dtype='complex64')
patches = op.diffraction.patch.fwd(
patches=patches,
images=psi,
positions=scan,
patch_width=varying_probe.shape[-1],
)
patches = patches.reshape(*scan.shape[:-1], 1, 1, op.detector_shape,
op.detector_shape)

nearplane = cp.tile(patches, reps=(1, 1, varying_probe.shape[-3], 1, 1))
pad, end = op.diffraction.pad, op.diffraction.end
nearplane[..., pad:end, pad:end] *= varying_probe

# Solve the farplane phase problem ----------------------------------------
farplane = op.propagation.fwd(nearplane, overwrite=True)
farplane = op.fwd(probe=varying_probe, scan=scan, psi=psi)
intensity = cp.sum(
cp.square(cp.abs(farplane)),
axis=list(range(1, farplane.ndim - 2)),
Expand All @@ -183,6 +165,7 @@ def _update_wavefront(data, varying_probe, scan, psi, op=None):

farplane = op.propagation.adj(farplane, overwrite=True)

pad, end = op.diffraction.pad, op.diffraction.end
return farplane[..., pad:end, pad:end], cost


Expand Down
20 changes: 1 addition & 19 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,25 +529,7 @@ def _update_nearplane(op, comm, nearplane, psi, scan_, probe, unique_probe,

def _update_wavefront(data, varying_probe, scan, psi, op):

# Compute the diffraction patterns for all of the probe modes at once.
# We need access to all of the modes of a position to solve the phase
# problem. The Ptycho operator doesn't do this natively, so it's messy.
patches = cp.zeros(data.shape, dtype='complex64')
patches = op.diffraction.patch.fwd(
patches=patches,
images=psi,
positions=scan,
patch_width=varying_probe.shape[-1],
)
patches = patches.reshape(*scan.shape[:-1], 1, 1, op.detector_shape,
op.detector_shape)

nearplane = cp.tile(patches, reps=(1, 1, varying_probe.shape[-3], 1, 1))
pad, end = op.diffraction.pad, op.diffraction.end
nearplane[..., pad:end, pad:end] *= varying_probe

# Solve the farplane phase problem ----------------------------------------
farplane = op.propagation.fwd(nearplane, overwrite=True)
farplane = op.fwd(probe=varying_probe, scan=scan, psi=psi)
intensity = cp.sum(
cp.square(cp.abs(farplane)),
axis=list(range(1, farplane.ndim - 2)),
Expand Down
39 changes: 39 additions & 0 deletions src/tike/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,42 @@ def plot_trajectories(theta, v, h, t):
plt.xlabel('time [s]')
plt.ylim([0, 1.])
return ax1a, ax1b


def plot_cost_convergence(costs, times):
"""Plot a twined plot of cost vs iteration/cumulative-time
The plot is a semi-log line plot with two lines. One line shows cost as a
function of iteration (bottom horizontal); one line shows cost as a
function of cumulative wall-time (top horizontal).
Parameters
----------
costs : (NUM_ITER, ) array-like
The objective cost at each iteration.
times : (NUM_ITER, ) array-like
The wall-time for each iteration in seconds.
Returns
-------
fig : matplotlib.figure.Figure
ax1 : matplotlib.axes._subplots.AxesSubplot
ax2 : matplotlib.axes._subplots.AxesSubplot
"""
fig, ax1 = plt.subplots()

color = 'black'
ax1.semilogy()
ax1.set_xlabel('iteration', color=color)
ax1.set_ylabel('objective')
ax1.plot(costs, linestyle='--', color=color)
ax1.tick_params(axis='x', labelcolor=color)

ax2 = ax1.twiny()

color = 'red'
ax2.set_xlabel('wall-time [s]', color=color)
ax2.plot(np.cumsum(times), costs, color=color)
ax2.tick_params(axis='x', labelcolor=color)

return fig, ax1, ax2
42 changes: 40 additions & 2 deletions tests/operators/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def setUp(self):
'scan': self.xp.asarray(scan, dtype='float32'),
'psi': self.xp.asarray(original, dtype='complex64')
}
self.kwargs2 = {
'scan': self.xp.asarray(scan, dtype='float32'),
}

self.d = self.xp.asarray(nearplane, dtype='complex64')
self.d_name = 'nearplane'
Expand All @@ -76,8 +79,8 @@ def test_adjoint_probe(self):
print()
print('<Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(), a.imag.item()))
print('< d, F*d> = {:.6f}{:+.6f}j'.format(b.real.item(), b.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5)
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5, atol=0)

def test_adj_probe_time(self):
"""Time the adjoint operation."""
Expand All @@ -90,6 +93,41 @@ def test_adj_probe_time(self):
def test_scaled(self):
pass

def test_adjoint_all(self):
"""Check that the adjoint operator is correct."""
d = self.operator.fwd(
**{
self.m_name: self.m,
self.m1_name: self.m1
},
**self.kwargs2,
)
assert d.shape == self.d.shape
m, m1 = self.operator.adj_all(
**{
self.d_name: self.d,
self.m_name: self.m,
self.m1_name: self.m1
},
**self.kwargs2,
)
assert m.shape == self.m.shape
assert m1.shape == self.m1.shape
a = inner_complex(d, self.d)
b = inner_complex(self.m, m)
c = inner_complex(self.m1, m1)
print()
print('< Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(),
a.imag.item()))
print('< d0, F*d0> = {:.6f}{:+.6f}j'.format(b.real.item(),
b.imag.item()))
print('< d1, F*d1> = {:.6f}{:+.6f}j'.format(c.real.item(),
c.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-5, atol=0)


if __name__ == '__main__':
unittest.main()
46 changes: 42 additions & 4 deletions tests/operators/test_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def setUp(self, ntheta=3, pw=15, nscan=27):
'scan': self.xp.asarray(scan, dtype='float32'),
'psi': self.xp.asarray(original, dtype='complex64')
}
self.kwargs2 = {
'scan': self.xp.asarray(scan, dtype='float32'),
}

self.d = self.xp.asarray(farplane, dtype='complex64')
self.d_name = 'farplane'
Expand All @@ -76,10 +79,10 @@ def test_adjoint_probe(self):
a = inner_complex(d, self.d)
b = inner_complex(self.m1, m)
print()
print('<Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(), a.imag.item()))
print('< d, F*d> = {:.6f}{:+.6f}j'.format(b.real.item(), b.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5)
print('<Fm, m> = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item()))
print('< d, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5, atol=0)

def test_adj_probe_time(self):
"""Time the adjoint operation."""
Expand All @@ -92,6 +95,41 @@ def test_adj_probe_time(self):
def test_scaled(self):
pass

def test_adjoint_all(self):
"""Check that the adjoint operator is correct."""
d = self.operator.fwd(
**{
self.m_name: self.m,
self.m1_name: self.m1
},
**self.kwargs2,
)
assert d.shape == self.d.shape
m, m1 = self.operator.adj_all(
**{
self.d_name: self.d,
self.m_name: self.m,
self.m1_name: self.m1
},
**self.kwargs2,
)
assert m.shape == self.m.shape
assert m1.shape == self.m1.shape
a = inner_complex(d, self.d)
b = inner_complex(self.m, m)
c = inner_complex(self.m1, m1)
print()
print('< Fm, m> = {:.5g}{:+.5g}j'.format(a.real.item(),
a.imag.item()))
print('< d0, F*d0> = {:.5g}{:+.5g}j'.format(b.real.item(),
b.imag.item()))
print('< d1, F*d1> = {:.5g}{:+.5g}j'.format(c.real.item(),
c.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.real, c.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-5, atol=0)


if __name__ == '__main__':
unittest.main()
19 changes: 10 additions & 9 deletions tests/operators/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def test_adjoint(self):
a = inner_complex(d, self.d)
b = inner_complex(self.m, m)
print()
print('<Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(), a.imag.item()))
print('< d, F*d> = {:.6f}{:+.6f}j'.format(b.real.item(), b.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5)
print('<Fm, m> = {:.5g}{:+.5g}j'.format(a.real.item(), a.imag.item()))
print('< d, F*d> = {:.5g}{:+.5g}j'.format(b.real.item(), b.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5, atol=0)
self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5, atol=0)

def test_scaled(self):
"""Check that the adjoint operator is scaled."""
Expand All @@ -60,11 +60,12 @@ def test_scaled(self):
a = inner_complex(m, m)
b = inner_complex(self.m, self.m)
print()
print('<F*Fm, F*Fm> = {:.6f}{:+.6f}j'.format(a.real.item(),
a.imag.item()))
print('< m, m> = {:.6f}{:+.6f}j'.format(b.real.item(),
b.imag.item()))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5)
# NOTE: Inner product with self is real-only magnitude of self
print('<F*Fm, F*Fm> = {:.5g}{:+.5g}j'.format(a.real.item(),
0))
print('< m, m> = {:.5g}{:+.5g}j'.format(b.real.item(),
0))
self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5, atol=0)

def test_fwd_time(self):
"""Time the forward operation."""
Expand Down
Loading

0 comments on commit 58b9a23

Please sign in to comment.