diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py
index 710f93e..06ab65a 100644
--- a/src/ptychi/api/options/base.py
+++ b/src/ptychi/api/options/base.py
@@ -545,7 +545,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions):
     method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD
     """The method to use for incoherent_mode orthogonalization."""
-
+
+    sort_by_occupancy: bool = False
+    """Keep the shared probe modes sorted so that the mode with the highest
+    occupancy is the 0th shared mode."""
 
 @dataclasses.dataclass
 class ProbeOrthogonalizeOPRModesOptions(FeatureOptions):
@@ -651,25 +653,26 @@ def get_non_data_fields(self) -> dict:
 
 @dataclasses.dataclass
 class SynthesisDictLearnProbeOptions(Options):
-
-    d_mat: Union[ndarray, Tensor] = None
+
+    enabled: bool = False
+    """Whether the synthesis dictionary learning (SDL) probe representation is enabled."""
+
+    enabled_shared: bool = False
+    """Whether the sparse code update is applied to the shared probe modes."""
+
+    thresholding_type_shared: str = 'hard'
+    """Choose between 'hard' or 'soft' thresholding."""
+
+    dictionary_matrix: Union[ndarray, Tensor] = None
     """The synthesis sparse dictionary matrix; contains the basis functions that
     will be used to represent the probe via the sparse code weights."""
 
-    d_mat_conj_transpose: Union[ndarray, Tensor] = None
-    """Conjugate transpose of the synthesis sparse dictionary matrix."""
-
-    d_mat_pinv: Union[ndarray, Tensor] = None
+    dictionary_matrix_pinv: Union[ndarray, Tensor] = None
     """Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix."""
 
-    probe_sparse_code: Union[ndarray, Tensor] = None
-    """Sparse code weights vector."""
+    sparse_code_probe_shared: Union[ndarray, Tensor] = None
+    """Sparse code weights vector for the shared modes."""
 
-    probe_sparse_code_nnz: float = None
+    sparse_code_probe_shared_nnz: float = None
     """Number of non-zeros we will keep when enforcing sparsity constraint on
-    the sparse code weights vector probe_sparse_code."""
-
-    enabled: bool = False
+    the shared sparse code weights vector sparse_code_probe_shared."""
 
 @dataclasses.dataclass
 class PositionCorrectionOptions(Options):
diff --git a/src/ptychi/api/options/lsqml.py b/src/ptychi/api/options/lsqml.py
index 2a2f063..31c6364 100644
--- a/src/ptychi/api/options/lsqml.py
+++ b/src/ptychi/api/options/lsqml.py
@@ -100,6 +100,10 @@ class LSQMLObjectOptions(base.ObjectOptions):
     propagation always uses all probe modes regardless of this option.
     """
 
+@dataclasses.dataclass
+class LSQMLProbeExperimentalOptions(base.Options):
+    sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(
+        default_factory=base.SynthesisDictLearnProbeOptions
+    )
+
 @dataclasses.dataclass
 class LSQMLProbeOptions(base.ProbeOptions):
@@ -107,8 +111,9 @@ class LSQMLProbeOptions(base.ProbeOptions):
     """
     A scaler for the solved optimal step size (beta_LSQ in PtychoShelves).
     """
-
-
+    experimental: LSQMLProbeExperimentalOptions = dataclasses.field(
+        default_factory=LSQMLProbeExperimentalOptions
+    )
+
 
 @dataclasses.dataclass
 class LSQMLProbePositionOptions(base.ProbePositionOptions):
     pass
diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py
index b730b64..7c627f5 100644
--- a/src/ptychi/data_structures/probe.py
+++ b/src/ptychi/data_structures/probe.py
@@ -222,6 +222,13 @@ def constrain_incoherent_modes_orthogonality(self):
         probe = self.data
 
+        if self.options.orthogonalize_incoherent_modes.sort_by_occupancy:
+            # Occupancy of each shared mode: its power as a fraction of the total power.
+            shared_occupancy = torch.sum(torch.abs(probe[0, ...]) ** 2, (-2, -1)) / torch.sum(
+                torch.abs(probe[0, ...]) ** 2
+            )
+            _, sort_indices = torch.sort(shared_occupancy, dim=0, descending=True)
+            probe[0, ...] = probe[0, sort_indices, ...]
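+            # E.g., occupancies [0.1, 0.7, 0.2] give sort_indices [1, 2, 0], so the
+            # mode carrying 70% of the power becomes mode 0 after reordering.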
 
         norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1))
 
         if self.orthogonalize_incoherent_modes_method == "gs":
@@ -470,31 +477,60 @@
     def __init__(self, name = "probe", options = None, *args, **kwargs):
         super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs)
 
-        dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H = self.get_dictionary()
+        dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary()
         self.register_buffer("dictionary_matrix", dictionary_matrix)
         self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv)
-        self.register_buffer("dictionary_matrix_H", dictionary_matrix_H)
-
-        probe_sparse_code_nnz = torch.tensor( self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.uint32 )
-        self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz )
-        sparse_code_probe = self.get_sparse_code_weights()
-        self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe))
-
+        sparse_code_probe_shared_nnz = torch.tensor(
+            self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz,
+            dtype=torch.uint32,
+        )
+        sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights()
+        self.register_buffer("sparse_code_probe_shared_nnz", sparse_code_probe_shared_nnz)
+        self.register_parameter(
+            "sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)
+        )
+
         self.build_optimizer()
 
     def get_dictionary(self):
-        dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.d_mat, dtype=torch.complex64 )
-        dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_pinv, dtype=torch.complex64 )
-        dictionary_matrix_H = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_conj_transpose, dtype=torch.complex64 )
-        return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H
-
-    def get_sparse_code_weights(self):
-        sz = self.data.shape
-        probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3]))
-        probe_vec = torch.swapaxes( probe_vec, 0, -1)
-        sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec
-        return sparse_code_probe
+        dictionary_matrix = torch.tensor(
+            self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64
+        )
+        dictionary_matrix_pinv = torch.tensor(
+            self.options.experimental.sdl_probe_options.dictionary_matrix_pinv,
+            dtype=torch.complex64,
+        )
+        return dictionary_matrix, dictionary_matrix_pinv
+
+    def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions):
+        """Get the sparse code weights of the probe at each scan position.
+
+        Parameters
+        ----------
+        probe_vs_scanpositions : Tensor
+            A (n_pos, 1, h, w) tensor giving the probe at each scan position.
+
+        Returns
+        -------
+        Tensor
+            A tensor giving the sparse code weights of the probe at each scan position.
+        """
+        sz = probe_vs_scanpositions.shape
+        probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2] * sz[3]))
+        sparse_code_vs_scanpositions = torch.einsum(
+            "ij,klj->ikl", self.dictionary_matrix_pinv, probe_vec
+        )
+
+        return sparse_code_vs_scanpositions
+
+    def get_sparse_code_probe_shared_weights(self):
+        """Project the shared probe modes onto the dictionary to obtain their sparse code weights."""
+        probe_shared = self.data[0, ...]
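+        # Each shared mode is flattened to a length h*w vector and fitted with the
+        # pseudoinverse: code = D^+ @ p, so that D @ code is the least-squares
+        # reconstruction of the mode from the dictionary's basis functions.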
+        sz = probe_shared.shape
+        probe_vec = torch.reshape(probe_shared, (sz[0], sz[1] * sz[2]))
+        sparse_code_probe_shared = self.dictionary_matrix_pinv @ probe_vec.T
+
+        return sparse_code_probe_shared.T
 
     def generate(self):
         """Generate the probe using the sparse code, and set the
@@ -505,27 +541,159 @@
         Tensor
             A (n_opr_modes, n_modes, h, w) tensor giving the generated probe.
         """
-        probe_vec = self.dictionary_matrix @ self.sparse_code_probe
-        probe_vec = torch.swapaxes( probe_vec, 0, -1)
-        probe = torch.reshape(probe_vec, *[self.data[0,...].shape])
-        probe = probe[None,...]
-
-        # we only use sparse codes for the shared modes, not the OPRs
-        probe = torch.cat((probe, self.data[1:,...]), 0)
-
-        self.set_data(probe)
-        return probe
-
+
+        if self.options.experimental.sdl_probe_options.enabled_shared:
+            sz = self.data.shape
+            probe = torch.zeros(sz, dtype=torch.complex64, device=self.data.device)
+
+            # Synthesize the shared modes from the dictionary and the sparse code.
+            probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T
+
+            probe[0, ...] = torch.reshape(probe_shared.T, sz[1:])
+            # We only use sparse codes for the shared modes, not the OPR modes.
+            probe[1:, 0, ...] = self.data[1:, 0, ...]
+
+            self.set_data(probe)
+
+        else:
+            probe = self.data
+
+        return probe
 
     def build_optimizer(self):
         if self.optimizable and self.optimizer_class is None:
             raise ValueError(
                 "Parameter {} is optimizable but no optimizer is specified.".format(self.name)
             )
        if self.optimizable:
-            self.optimizer = self.optimizer_class([self.sparse_code_probe], **self.optimizer_params)
+            self.optimizer = self.optimizer_class(
+                [self.sparse_code_probe_shared], **self.optimizer_params
+            )
+
+    def set_sparse_code_probe_shared(self, data):
+        """
+        Set the sparse code weights for the shared probe.
+
+        Parameters
+        ----------
+        data : Tensor
+            A (n_modes, n_dict_bases) tensor giving the sparse code weights for the shared probe.
+        """
+        self.sparse_code_probe_shared.data = data
 
-    def set_sparse_code(self, data):
-        self.sparse_code_probe.data = data
+    def initialize_grad_sparse_code_probe_shared(self):
+        """
+        Initialize the gradient of the shared probe's sparse code weights with zeros.
+        """
+        self.sparse_code_probe_shared.grad = torch.zeros_like(self.sparse_code_probe_shared.data)
+
+    def set_gradient_sparse_code_probe_shared(self, grad):
+        """
+        Set the gradient of the sparse code weights update for the shared probe.
+
+        Parameters
+        ----------
+        grad : Tensor
+            A (n_modes, n_dict_bases) tensor giving the gradient of the sparse code
+            weights for the shared probe.
+        """
+        self.sparse_code_probe_shared.grad = grad
+
+    def set_sparse_code_weights_vs_scanpositions(
+        self, sparse_code_vs_scanpositions: Tensor, indices: tuple | Tensor = None
+    ):
+        """
+        Set the sparse code weights of the probe at the given scan positions.
+
+        Parameters
+        ----------
+        sparse_code_vs_scanpositions : Tensor
+            A (n_pos, n_opr_modes, n_scpm) tensor giving the sparse code weights of
+            the probe at the given scan positions.
+        indices : tuple | Tensor
+            The indices to apply to the sparse code weights.
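+
+        Notes
+        -----
+        Not yet implemented; the statements after the ``raise`` below sketch the
+        intended behavior.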
+ """ + raise NotImplementedError("This method is not implemented yet.") + if indices is None: + indices = slice(None) + self.sparse_code_weights_vs_scanpositions[indices] = sparse_code_vs_scanpositions + + def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, obj_patches): + nr = chi.shape[-2] + nc = chi.shape[-1] + nrnc = nr * nc + n_scpm = chi.shape[-3] + n_spos = chi.shape[-4] + + obj_patches = torch.reshape(obj_patches, (n_spos, nrnc)) + chi = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2, 0, 1) + + # get sparse code update direction + delta_sparse_code = torch.einsum( + "ijk,kl->lij", + torch.reshape(delta_p_i, (n_spos, n_scpm, nrnc)), + self.dictionary_matrix.conj(), + ) + + # compute optimal step length for sparse code update + dict_delta_sparse_code = torch.einsum( + "ij,jkl->ikl", self.dictionary_matrix, delta_sparse_code + ) + + denom = (torch.abs(dict_delta_sparse_code) ** 2) * obj_patches.swapaxes(0, -1)[..., None] + denom = torch.einsum("ij,jik->ik", torch.conj(obj_patches), denom) + + numer = torch.conj(dict_delta_sparse_code) * chi + numer = torch.einsum("ij,jik->ik", torch.conj(obj_patches), numer) + + # real is used to throw away small imag part due to numerical precision errors + optimal_step_sparse_code = (numer / denom).real + + optimal_delta_sparse_code = optimal_step_sparse_code[None, ...] * delta_sparse_code + + # enforce sparsity constraint on sparse code + abs_sparse_code = torch.abs(optimal_delta_sparse_code) + abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + + sel = abs_sparse_code_sorted[0][self.sparse_code_probe_shared_nnz, ...] + sparse_code_mask = abs_sparse_code >= sel[None, ...] + + # hard or soft thresholding + if self.options.experimental.sdl_probe_options.thresholding_type_shared == "hard": + optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask + elif self.options.experimental.sdl_probe_options.thresholding_type_shared == "soft": + optimal_delta_sparse_code = ( + (abs_sparse_code - sel[None, ...]) + * sparse_code_mask + * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) + ) + + delta_p_i = torch.einsum( + "ij,jlk->ilk", self.dictionary_matrix, optimal_delta_sparse_code + ).permute(1, 2, 0) + + delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc)) + + return delta_p_i, optimal_delta_sparse_code + + def get_grad(self) -> torch.Tensor: + """Get the gradient of the sparse code weights for the shared probe. + This method overrides the method in the base class, which returns + the `.grad` attribute of the tensor. + + Returns + ------- + Tensor + The gradient of the sparse code weights for the shared probe. + """ + return self.sparse_code_probe_shared.grad + + def set_grad(self, grad: torch.Tensor): + """Set the gradient of the sparse code weights for the shared probe. + This method overrides the method in the base class, which sets the `.grad` + attribute of the tensor. 
+ """ + self.set_gradient_sparse_code_probe_shared(grad) class DIPProbe(Probe): diff --git a/src/ptychi/maths.py b/src/ptychi/maths.py index d19efc0..292aa86 100644 --- a/src/ptychi/maths.py +++ b/src/ptychi/maths.py @@ -213,6 +213,10 @@ def orthogonalize_svd( def project(a, b, dim=None): """Return complex vector projection of a onto b for along given axis.""" projected_length = inner(a, b, dim=dim, keepdims=True) / inner(b, b, dim=dim, keepdims=True) + + # if the inner product of b with itself has any zeros: + projected_length = torch.nan_to_num(projected_length, nan=0.0) + return projected_length * b def inner(x, y, dim=None, keepdims=False): diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index 4a4fc31..27b22de 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -642,6 +642,13 @@ def __init__( ) self.forward_model = None self.build_forward_model() + + @property + def use_sparse_probe_shared_update(self): + return ( + self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared + ) def build_forward_model(self): self.forward_model = fm.PlanarPtychographyForwardModel( diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 5a793c8..9120159 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -254,20 +254,34 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): self._record_object_slice_gradient(i_slice, delta_o_precond, add_to_existing=False) else: self._record_object_slice_gradient(i_slice, delta_o_comb, add_to_existing=False) + + if self.use_sparse_probe_shared_update and self.parameter_group.probe.optimization_enabled(self.current_epoch): + ( + delta_p_i_before_adj_shift, delta_p_i, _ + ) = self.calculate_probe_update_direction_sparse_code_probe_shared( + indices, chi, obj_patches, i_slice + ) + else: + # Calculate probe update direction (dense representation) + delta_p_i_before_adj_shift = self._calculate_probe_update_direction( + chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None + ) # Eq. 24a + delta_p_i = self.adjoint_shift_probe_update_direction( + indices, delta_p_i_before_adj_shift, first_mode_only=True + ) - # Calculate probe update direction. - delta_p_i_before_adj_shift = self._calculate_probe_update_direction( - chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None - ) # Eq. 24a - delta_p_i = self.adjoint_shift_probe_update_direction( - indices, delta_p_i_before_adj_shift, first_mode_only=True - ) delta_p_hat = self._precondition_probe_update_direction(delta_p_i) # Eq. 25a self._record_probe_gradient(delta_p_hat) # Calculate update vectors for OPR modes and weights. if i_slice == 0: if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch): + + if self.use_sparse_probe_shared_update: + apply_updates = True + else: + apply_updates = False + self.parameter_group.opr_mode_weights.update_variable_probe( self.parameter_group.probe, indices, @@ -277,7 +291,7 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): obj_patches, self.current_epoch, probe_mode_index=0, - apply_updates=False, + apply_updates=apply_updates, ) # Update buffered data for momentum acceleration. @@ -315,6 +329,27 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): # Set chi to conjugate-modulated wavefield. 
 
             # Update buffered data for momentum acceleration.
@@ -315,6 +329,27 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions):
         # Set chi to conjugate-modulated wavefield.
         chi = delta_p_i_before_adj_shift
 
+    def calculate_probe_update_direction_sparse_code_probe_shared(
+        self, indices, chi, obj_patches, i_slice=None
+    ):
+        """Calculate the probe update direction using the sparse code representation
+        of the shared probe modes.
+        """
+        delta_p_i_before_adj_shift = self._calculate_probe_update_direction(
+            chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
+        )
+        delta_p_i = self.adjoint_shift_probe_update_direction(
+            indices, delta_p_i_before_adj_shift, first_mode_only=True
+        )
+        chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(
+            indices, chi, first_mode_only=True
+        )
+        delta_p_i, optimal_delta_sparse_code_vs_spos = (
+            self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared(
+                delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...]
+            )
+        )
+        # Average the per-position sparse code updates and store them as the gradient.
+        self.parameter_group.probe.set_grad(optimal_delta_sparse_code_vs_spos.mean(1).T)
+        delta_p_i_before_adj_shift = self.forward_model.shift_unique_probes(
+            indices, delta_p_i, first_mode_only=True
+        )
+        return delta_p_i_before_adj_shift, delta_p_i, optimal_delta_sparse_code_vs_spos
 
     @timer()
     def apply_reconstruction_parameter_updates(self, indices: torch.Tensor):
@@ -343,7 +378,10 @@ def apply_reconstruction_parameter_updates(self, indices: torch.Tensor):
         alpha_p_i = self.reconstructor_buffers.alpha_probe_all_pos[indices]
         if self.parameter_group.probe.optimization_enabled(self.current_epoch):
             self._apply_probe_update(alpha_p_i, -self.parameter_group.probe.get_grad()[0])
-
+        # Update the shared probe sparse code if enabled.
+        if self.use_sparse_probe_shared_update:
+            self._apply_probe_sparse_code_shared_updates()
+
         # Update probe positions.
         if self.parameter_group.probe_positions.optimization_enabled(self.current_epoch):
             self.parameter_group.probe_positions.step_optimizer()
@@ -749,6 +787,11 @@ def _apply_probe_update(self, alpha_p_i, delta_p_hat, probe_mode_index=None):
         alpha_p_mean = torch.mean(alpha_p_i)
         self.parameter_group.probe.set_grad(-delta_p_hat * alpha_p_mean, slicer=(0, mode_slicer))
         self.parameter_group.probe.optimizer.step()
+
+    def _apply_probe_sparse_code_shared_updates(self):
+        """Add the buffered sparse code update to the shared probe's sparse code weights."""
+        sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights()
+        sparse_code_probe_shared = (
+            sparse_code_probe_shared + self.parameter_group.probe.sparse_code_probe_shared.grad
+        )
+        self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared)
diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py
index ac032ac..048fcd7 100644
--- a/src/ptychi/reconstructors/pie.py
+++ b/src/ptychi/reconstructors/pie.py
@@ -135,50 +135,24 @@ def compute_updates(
         delta_p_i = None
         if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)):
-            if (self.parameter_group.probe.representation == "sparse_code"):
-                # TODO: move this into SynthesisDictLearnProbe class
-                rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2]
-                n_scpm = delta_exwv_i.shape[-3]
-                n_spos = delta_exwv_i.shape[-4]
-
-                obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc ))
-                abs2_obj_patches = torch.abs(obj_patches_vec) ** 2
-
-                z = torch.sum(abs2_obj_patches, dim = 0)
-                z_max = torch.max(z)
-                w = self.parameter_group.probe.options.alpha * (z_max - z)
-                z_plus_w = torch.swapaxes(z + w, 0, 1)
-
-                delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True)
-                delta_exwv = torch.sum(delta_exwv, 0)
-                delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T
-
-
-                denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix))
-                numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv
-
-                delta_sparse_code = torch.linalg.solve(denom, numer)
-
-                delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code
-                delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2]))
-                delta_p_i = torch.tile(delta_p, (n_spos,1,1,1))
-
-                # sparse code update
-                sparse_code = self.parameter_group.probe.get_sparse_code_weights()
-                sparse_code = sparse_code + delta_sparse_code
-
-                # Enforce sparsity constraint on sparse code
-                abs_sparse_code = torch.abs(sparse_code)
-                sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)
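+            # The same step weighting as the dense branch below is applied first;
+            # the thresholded sparse code update direction is then computed by the
+            # probe object itself.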
+            if self.use_sparse_probe_shared_update:
+
+                # Calculate probe update direction using the sparse code representation.
+
+                step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...]))
+                delta_p_i = step_weight * delta_exwv_i  # get delta p at each position
 
-                sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :]
+                # Undo subpixel shift in probe update directions.
+                delta_p_i = self.adjoint_shift_probe_update_direction(indices, delta_p_i, first_mode_only=True)
 
-                # hard thresholding:
-                sparse_code = sparse_code * (abs_sparse_code >= sel)
+                chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(
+                    indices, delta_exwv_i, first_mode_only=True
+                )
 
-                #(TODO: soft thresholding option)
+                # get_probe_update_direction_sparse_code_probe_shared returns the
+                # thresholded update direction and the sparse code update; only the
+                # former is used here.
+                delta_p_i, _ = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared(
+                    delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...]
+                )
 
-                # Update the new sparse code in the probe class
-                self.parameter_group.probe.set_sparse_code(sparse_code)
             else:
                 step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...]))
                 delta_p_i = step_weight * delta_exwv_i  # get delta p at each position