Skip to content

Commit

Permalink
Merge pull request #204 from carterbox/adam-lstsq
Browse files Browse the repository at this point in the history
NEW: Add momentum acceleration to object only for lstsq method
  • Loading branch information
carterbox authored Apr 25, 2022
2 parents 993da6c + 311ebe6 commit 962ffbb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
20 changes: 18 additions & 2 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def lstsq_grad(
bposition_options,
num_batch=algorithm_options.num_batch,
psi_update_denominator=psi_update_denominator,
object_options=object_options,
)

if position_options:
Expand Down Expand Up @@ -286,6 +287,8 @@ def _update_nearplane(
position_options,
num_batch,
psi_update_denominator,
*,
object_options,
):

patches = comm.pool.map(_get_patches, nearplane, psi, scan_, op=op)
Expand Down Expand Up @@ -458,8 +461,21 @@ def _update_nearplane(
)[..., 0, 0, 0]

# (27b) Object update
psi[0] += (weighted_step_psi[0] /
probe[0].shape[-3]) * common_grad_psi[0]
dpsi = (weighted_step_psi[0] /
probe[0].shape[-3]) * common_grad_psi[0]
if object_options.use_adaptive_moment:
(
dpsi,
object_options.v,
object_options.m,
) = tike.opt.adam(
g=dpsi,
v=object_options.v,
m=object_options.m,
vdecay=object_options.vdecay,
mdecay=object_options.mdecay,
)
psi[0] = psi[0] + dpsi
psi = comm.pool.bcast([psi[0]])

if recover_probe:
Expand Down
8 changes: 6 additions & 2 deletions tests/ptycho/test_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def test_consistent_lstsq_grad(self):
'probe_options':
ProbeOptions(orthogonality_constraint=True,),
'object_options':
ObjectOptions(),
ObjectOptions(
use_adaptive_moment=True,
),
'use_mpi':
_mpi_size > 1,
},), f"{'mpi-' if _mpi_size > 1 else ''}lstsq_grad")
Expand All @@ -351,7 +353,9 @@ def test_consistent_lstsq_grad_variable_probe(self):
'probe_options':
ProbeOptions(),
'object_options':
ObjectOptions(),
ObjectOptions(
use_adaptive_moment=True,
),
'use_mpi':
_mpi_size > 1,
'eigen_probe':
Expand Down

0 comments on commit 962ffbb

Please sign in to comment.