Skip to content

Commit

Permalink
Merge pull request #319 from carterbox/probe-support
Browse files Browse the repository at this point in the history
REF: Reeenable probe support for lstsq_grad method
  • Loading branch information
carterbox authored Jun 17, 2024
2 parents 6c6e1ce + 6be8995 commit cd331bc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
46 changes: 26 additions & 20 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,28 +755,34 @@ def _precondition_nearplane_gradients(
A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
unique_probe[..., m : m + 1, :, :],
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)

# b0 = tike.ptycho.probe.finite_probe_support(
# unique_probe[..., m:m+1, :, :],
# p=probe_options.probe_support,
# radius=probe_options.probe_support_radius,
# degree=probe_options.probe_support_degree,
# )
b1 = (
probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
probe[0].shape[-3],
dtype=tike.precision.floating,
)[..., m : m + 1, None, None]
)

# b1 = probe_options.additional_probe_penalty * cp.linspace(
# 0,
# 1,
# probe[0].shape[-3],
# dtype=tike.precision.floating,
# )[..., m:m+1, None, None]

# m_probe_update = (m_probe_update -
# (b0 + b1) * probe[..., m:m+1, :, :]) / (
# (1 - alpha) * probe_update_denominator +
# alpha * probe_update_denominator.max(
# axis=(-2, -1),
# keepdims=True,
# ) + b0 + b1)
m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :]
# / (
# (1 - alpha) * probe_update_denominator
# + alpha
# * probe_update_denominator.max(
# axis=(-2, -1),
# keepdims=True,
# )
# + b0
# + b1
# )

dPO = m_probe_update[..., m:m + 1, :, :] * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))
Expand Down
14 changes: 11 additions & 3 deletions tests/ptycho/test_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,11 @@ def test_consistent_lstsq_grad(self):
probe_options=ProbeOptions(
force_orthogonality=True,
use_adaptive_moment=True,
probe_support=0.1,
),
object_options=ObjectOptions(
use_adaptive_moment=True,
),
object_options=ObjectOptions(use_adaptive_moment=True,),
)

_save_ptycho_result(
Expand Down Expand Up @@ -583,8 +586,13 @@ def test_consistent_rpie(self):
num_batch=5,
num_iter=16,
),
probe_options=ProbeOptions(force_orthogonality=True,),
object_options=ObjectOptions(smoothness_constraint=0.01,),
probe_options=ProbeOptions(
force_orthogonality=True,
probe_support=0.1,
),
object_options=ObjectOptions(
smoothness_constraint=0.01,
),
)

_save_ptycho_result(
Expand Down

0 comments on commit cd331bc

Please sign in to comment.