From b07cec3c93e15a2691b8034211a61099b56cfc33 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 25 Apr 2022 11:08:35 -0500 Subject: [PATCH] NEW: Add momentum acceleration to object only for lstsq method --- src/tike/ptycho/solvers/lstsq.py | 20 ++++++++++++++++++-- tests/ptycho/test_ptycho.py | 8 ++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index ad179b0f..d783074e 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -167,6 +167,7 @@ def lstsq_grad( bposition_options, num_batch=algorithm_options.num_batch, psi_update_denominator=psi_update_denominator, + object_options=object_options, ) if position_options: @@ -286,6 +287,8 @@ def _update_nearplane( position_options, num_batch, psi_update_denominator, + *, + object_options, ): patches = comm.pool.map(_get_patches, nearplane, psi, scan_, op=op) @@ -458,8 +461,21 @@ def _update_nearplane( )[..., 0, 0, 0] # (27b) Object update - psi[0] += (weighted_step_psi[0] / - probe[0].shape[-3]) * common_grad_psi[0] + dpsi = (weighted_step_psi[0] / + probe[0].shape[-3]) * common_grad_psi[0] + if object_options.use_adaptive_moment: + ( + dpsi, + object_options.v, + object_options.m, + ) = tike.opt.adam( + g=dpsi, + v=object_options.v, + m=object_options.m, + vdecay=object_options.vdecay, + mdecay=object_options.mdecay, + ) + psi[0] = psi[0] + dpsi psi = comm.pool.bcast([psi[0]]) if recover_probe: diff --git a/tests/ptycho/test_ptycho.py b/tests/ptycho/test_ptycho.py index 64649614..783f2137 100644 --- a/tests/ptycho/test_ptycho.py +++ b/tests/ptycho/test_ptycho.py @@ -325,7 +325,9 @@ def test_consistent_lstsq_grad(self): 'probe_options': ProbeOptions(orthogonality_constraint=True,), 'object_options': - ObjectOptions(), + ObjectOptions( + use_adaptive_moment=True, + ), 'use_mpi': _mpi_size > 1, },), f"{'mpi-' if _mpi_size > 1 else ''}lstsq_grad") @@ -351,7 +353,9 @@ def test_consistent_lstsq_grad_variable_probe(self): 'probe_options': ProbeOptions(), 'object_options': - ObjectOptions(), + ObjectOptions( + use_adaptive_moment=True, + ), 'use_mpi': _mpi_size > 1, 'eigen_probe':