From cd5b45b4f9c920d70ade8481e46d74ec820375a8 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 20 May 2022 13:52:12 -0500 Subject: [PATCH] STUB: Prototype regularize object scaling Want to remove ambiguity of object probe multiplication and smooth noise around edges where uncertain. --- src/tike/ptycho/object.py | 3 +++ src/tike/ptycho/solvers/lstsq.py | 16 +++++++++++----- src/tike/ptycho/solvers/rpie.py | 7 +++++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 4319cec1..c9128bcf 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -56,6 +56,9 @@ class ObjectOptions: clip_magnitude: bool = True """Whether to force the object magnitude to remain <= 1.""" + lasso_penalty: float = 0.0 + """Weight of the penalty to keep object coefficients near 1 + 0j.""" + def copy_to_device(self, comm): """Copy to the current GPU memory.""" if self.v is not None: diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 3929e194..9b6d3037 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -380,6 +380,7 @@ def _update_nearplane( psi_update_denominator, probe_update_denominator, patches, + psi, op=op, m=m, recover_psi=recover_psi, @@ -642,6 +643,7 @@ def _precondition_nearplane_gradients( psi_update_denominator, probe_update_denominator, patches, + psi, *, op, m, @@ -656,11 +658,15 @@ def _precondition_nearplane_gradients( eps = 1e-9 / (diff.shape[-2] * diff.shape[-1]) if recover_psi: - common_grad_psi /= ((1 - alpha) * psi_update_denominator + - alpha * psi_update_denominator.max( - axis=(-2, -1), - keepdims=True, - )) + + b = cp.complex64(1.0 + 0.0j) + + common_grad_psi = (common_grad_psi + b - psi) / ( + (1 - alpha) * psi_update_denominator + + alpha * psi_update_denominator.max( + axis=(-2, -1), + keepdims=True, + ) + b) dOP = op.diffraction.patch.fwd( patches=cp.zeros(patches.shape, dtype='complex64')[..., 0, 0, :, :], images=common_grad_psi, diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index b8f7e34f..7766aea5 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -196,6 +196,7 @@ def _update_nearplane( position_options=None, *, probe_options=None, + object_options=None, ): patches = comm.pool.map(_get_patches, nearplane_, psi, scan_, op=op) @@ -233,12 +234,14 @@ def _update_nearplane( psi_update_denominator = comm.reduce(psi_update_denominator, 'gpu')[0] - psi[0] += step_length * psi_update_numerator / ( + b = cp.complex64(1.0 + 0.0j) + + psi[0] += step_length * (psi_update_numerator + object_options.lasso_penalty * (b - psi[0])) / ( (1 - alpha) * psi_update_denominator + alpha * psi_update_denominator.max( axis=(-2, -1), keepdims=True, - )) + ) + object_options.lasso_penalty) psi = comm.pool.bcast([psi[0]])