Merged
20 changes: 13 additions & 7 deletions src/ptychi/data_structures/probe.py
@@ -467,12 +467,12 @@ def __init__(self, name = "probe", options = None, *args, **kwargs):
self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv)
self.register_buffer("dictionary_matrix_H", dictionary_matrix_H)

probe_sparse_code_nnz = torch.tensor( self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.uint32 )
self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz )

sparse_code_probe = self.get_initial_weights()
sparse_code_probe = self.get_sparse_code_weights()
self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe))

self.build_optimizer()

def get_dictionary(self):
@@ -481,8 +481,9 @@ def get_dictionary(self):
dictionary_matrix_H = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_conj_transpose, dtype=torch.complex64 )
return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H

def get_initial_weights(self):
probe_vec = torch.reshape( self.data, ( self.data.shape[1], self.data.shape[2] * self.data.shape[3] ))
def get_sparse_code_weights(self):
sz = self.data.shape
probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3]))
probe_vec = torch.swapaxes( probe_vec, 0, -1)
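# Project the vectorized probe modes onto the dictionary via its pseudoinverse to get the initial sparse-code coefficients.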
sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec
return sparse_code_probe
@@ -498,8 +499,13 @@ def generate(self):
"""
probe_vec = self.dictionary_matrix @ self.sparse_code_probe
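# probe_vec holds one vectorized shared probe mode per column; reshape back to (n_modes, H, W) below and prepend the OPR-mode axis.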
probe_vec = torch.swapaxes( probe_vec, 0, -1)
probe = torch.reshape( probe_vec, ( self.data.shape[1], self.data.shape[2], self.data.shape[3] ))[ None, ... ]
self.tensor.data = torch.stack([probe.real, probe.imag], dim=-1)
probe = torch.reshape(probe_vec, self.data[0, ...].shape)
probe = probe[None,...]

# we only use sparse codes for the shared modes, not the OPRs
probe = torch.cat((probe, self.data[1:,...]), 0)

self.set_data(probe)
return probe

def build_optimizer(self):
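For reference, a minimal sketch of the probe synthesis that generate() performs above, assuming the dictionary D has shape (H*W, n_atoms) and the sparse code has shape (n_atoms, n_modes); the function and variable names are illustrative only:

import torch

def synthesize_probe(D, sparse_code, n_modes, H, W):
    # Each column of D @ sparse_code is one vectorized shared probe mode.
    probe_vec = D @ sparse_code                   # (H*W, n_modes)
    probe_vec = torch.swapaxes(probe_vec, 0, -1)  # (n_modes, H*W)
    # Restore the 2D mode shape and add a leading OPR-mode axis.
    return torch.reshape(probe_vec, (n_modes, H, W))[None, ...]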
187 changes: 102 additions & 85 deletions src/ptychi/reconstructors/pie.py
@@ -57,8 +57,6 @@ def build_loss_tracker(self):
return super().build_loss_tracker()

def check_inputs(self, *args, **kwargs):
if self.parameter_group.object.is_multislice:
raise NotImplementedError("EPIEReconstructor only supports 2D objects.")
for var in self.parameter_group.get_optimizable_parameters():
if "lr" not in var.optimizer_params.keys():
raise ValueError(
@@ -94,97 +92,116 @@ def compute_updates(
obj_patches = self.forward_model.intermediate_variables["obj_patches"]
psi = self.forward_model.intermediate_variables["psi"]
psi_far = self.forward_model.intermediate_variables["psi_far"]
unique_probes = self.forward_model.intermediate_variables.shifted_unique_probes[0]

unique_probes = self.forward_model.intermediate_variables.shifted_unique_probes
psi_prime = self.replace_propagated_exit_wave_magnitude(psi_far, y_true)
# Do not swap magnitude for bad pixels.
psi_prime = torch.where(
valid_pixel_mask.repeat(psi_prime.shape[0], probe.n_modes, 1, 1), psi_prime, psi_far
)
psi_prime = self.forward_model.free_space_propagator.propagate_backward(psi_prime)

delta_o = None
if object_.optimization_enabled(self.current_epoch):
step_weight = self.calculate_object_step_weight(unique_probes)
delta_o_patches = step_weight * (psi_prime - psi)
delta_o_patches = delta_o_patches.sum(1, keepdim=True)
delta_o = ip.place_patches_integer(
torch.zeros_like(object_.get_slice(0)),
positions.round().int() + object_.pos_origin_coords,
delta_o_patches[:, 0],
op="add",
)
# Add slice dimension.
delta_o = delta_o.unsqueeze(0)

delta_pos = None
if probe_positions.optimization_enabled(self.current_epoch) and object_.optimizable:
delta_pos = torch.zeros_like(probe_positions.data)
delta_pos[indices] = probe_positions.position_correction.get_update(
psi_prime - psi,
obj_patches,
delta_o_patches,
self.forward_model.intermediate_variables.shifted_unique_probes[0],
object_.optimizer_params["lr"],
)

delta_p_i = None
if probe.optimization_enabled(self.current_epoch) and self.parameter_group.probe.representation == "sparse_code":
rc = psi_prime.shape[-1] * psi_prime.shape[-2]
n_scpm = psi_prime.shape[-3]
n_pos = psi_prime.shape[-4]

psi_prime_vec = torch.reshape(psi_prime, (n_pos, n_scpm, rc))

probe_vec = torch.reshape(self.parameter_group.probe.data[0, ...], (n_scpm , rc))

obj_patches_vec = torch.reshape(obj_patches, (n_pos, 1, rc ))

conj_obj_patches = torch.conj(obj_patches_vec)
abs2_obj_patches = torch.abs(obj_patches_vec) ** 2

z = torch.sum(abs2_obj_patches, dim = 0)
z_max = torch.max(z)
w = 0.9 * (z_max - z)

sum_spos_conjT_s_psi = torch.sum(conj_obj_patches * psi_prime_vec, 0)
sum_spos_conjT_s_psi = torch.swapaxes(sum_spos_conjT_s_psi, 0, 1)

w_phi = torch.swapaxes(w * probe_vec, 0, 1)
z_plus_w = torch.swapaxes(z + w, 0, 1)

numer = self.parameter_group.probe.dictionary_matrix_H @ (sum_spos_conjT_s_psi + w_phi)
denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix))

sparse_code = torch.linalg.solve(denom, numer)

# Enforce sparsity constraint on sparse code
abs_sparse_code = torch.abs(sparse_code)
sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)

sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :]
delta_exwv_i = psi_prime - psi
delta_o = torch.zeros_like(object_.data)
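# Exit-wave update direction at the deepest slice; the loop below walks from the last slice to the first, back-propagating this update slice by slice.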

for i_slice in range(object_.n_slices - 1, -1, -1):

if object_.optimization_enabled(self.current_epoch):
step_weight = self.calculate_object_step_weight(unique_probes[i_slice])
delta_o_patches = step_weight * delta_exwv_i
delta_o_patches = delta_o_patches.sum(1, keepdim=True)
delta_o_i = ip.place_patches_integer(
torch.zeros_like(object_.get_slice(0)),
positions.round().int() + object_.pos_origin_coords,
delta_o_patches[:, 0],
op="add",
)

delta_o[i_slice, ...] = delta_o_i

delta_pos = None
if (probe_positions.optimization_enabled(self.current_epoch)
and object_.optimizable
and i_slice == self.parameter_group.probe_positions.get_slice_for_correction(object_.n_slices)
):
delta_pos = torch.zeros_like(probe_positions.data)
delta_pos[indices] = probe_positions.position_correction.get_update(
delta_exwv_i,
obj_patches[:, i_slice : i_slice + 1, ...],
delta_o_patches,
self.forward_model.intermediate_variables.shifted_unique_probes[i_slice],
object_.optimizer_params["lr"],
)

delta_p_i = None
if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)):
if (self.parameter_group.probe.representation == "sparse_code"):
# TODO: move this into SynthesisDictLearnProbe class
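# Vectorized shapes: rc = pixels per probe mode, n_scpm = number of shared probe modes, n_spos = scan positions in the batch.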
rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2]
n_scpm = delta_exwv_i.shape[-3]
n_spos = delta_exwv_i.shape[-4]

sparse_code = sparse_code * (abs_sparse_code >= sel)

# Update sparse code in probe object
self.parameter_group.probe.set_sparse_code(sparse_code)
else:
step_weight = self.calculate_probe_step_weight(obj_patches)
delta_p_i = step_weight * (psi_prime - psi) # get delta p at each position
delta_p_i = self.adjoint_shift_probe_update_direction(indices, delta_p_i, first_mode_only=True)

# Calculate and apply opr mode updates
if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch):
opr_mode_weights.update_variable_probe(
probe,
indices,
psi_prime - psi,
delta_p_i,
delta_p_i.mean(0),
obj_patches,
self.current_epoch,
probe_mode_index=0,
)
obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc ))
abs2_obj_patches = torch.abs(obj_patches_vec) ** 2

z = torch.sum(abs2_obj_patches, dim = 0)
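# rPIE-like relaxation: w = alpha * (z_max - z), with z the object intensity summed over scan positions.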
z_max = torch.max(z)
w = self.parameter_group.probe.options.alpha * (z_max - z)
z_plus_w = torch.swapaxes(z + w, 0, 1)

delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True)
delta_exwv = torch.sum(delta_exwv, 0)
delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T

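# Solve the regularized normal equations D^H diag(z + w) D * delta_s = D^H delta_exwv for the sparse-code update.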
denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix))
numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv

delta_sparse_code = torch.linalg.solve(denom, numer)

delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code
delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-2], delta_exwv_i.shape[-1]))
delta_p_i = torch.tile(delta_p, (n_spos,1,1,1))

# sparse code update
sparse_code = self.parameter_group.probe.get_sparse_code_weights()
sparse_code = sparse_code + delta_sparse_code

# Enforce sparsity constraint on sparse code
abs_sparse_code = torch.abs(sparse_code)
sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)

sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :]

# hard thresholding:
sparse_code = sparse_code * (abs_sparse_code >= sel)

# TODO: add a soft-thresholding option

# Store the updated sparse code in the probe object
self.parameter_group.probe.set_sparse_code(sparse_code)
else:
step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...]))
delta_p_i = step_weight * delta_exwv_i # get delta p at each position

# Undo subpixel shift in probe update directions.
delta_p_i = self.adjoint_shift_probe_update_direction(indices, delta_p_i, first_mode_only=True)

# Calculate and apply opr mode updates
if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch):
opr_mode_weights.update_variable_probe(
probe,
indices,
delta_exwv_i,
delta_p_i,
delta_p_i.mean(0),
obj_patches,
self.current_epoch,
probe_mode_index=0,
)

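# Chain rule: modulate by the conjugate of this slice's object patches and back-propagate the exit-wave update to the previous slice.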
if i_slice > 0:
delta_exwv_i = delta_exwv_i * obj_patches[:, i_slice : i_slice + 1,...].conj()
delta_exwv_i = self.forward_model.propagate_to_previous_slice(delta_exwv_i, slice_index=i_slice)

return (delta_o, delta_p_i, delta_pos), y

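For reference, a minimal sketch of the hard-thresholding applied to the updated sparse code above; hard_threshold is an illustrative name and nnz corresponds to probe_sparse_code_nnz:

import torch

def hard_threshold(sparse_code, nnz):
    # Zero out, per column (probe mode), every coefficient whose magnitude
    # is smaller than the (nnz + 1)-th largest one.
    abs_code = torch.abs(sparse_code)
    thresh = torch.sort(abs_code, dim=0, descending=True).values[nnz, :]
    return sparse_code * (abs_code >= thresh)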
2 changes: 1 addition & 1 deletion tests/test_2d_ptycho_rpie_synthesisdictlearn.py
@@ -65,7 +65,7 @@ def test_2d_ptycho_rpie_synthesisdictlearn(self):
options.probe_position_options.optimizable = False

options.reconstructor_options.batch_size = round(data.shape[0] * 0.1)
options.reconstructor_options.num_epochs = 50
options.reconstructor_options.num_epochs = 32
options.reconstructor_options.allow_nondeterministic_algorithms = False

task = PtychographyTask(options)
98 changes: 98 additions & 0 deletions tests/test_multislice_ptycho_rpie_synthesisdictlearn.py
@@ -0,0 +1,98 @@
import argparse
import os

import torch
import numpy as np

import ptychi.api as api
from ptychi.api.task import PtychographyTask
from ptychi.utils import get_suggested_object_size, get_default_complex_dtype
from ptychi.utils import generate_initial_opr_mode_weights

import test_utils as tutils


class TestMultislicePtychoRPIESDL(tutils.TungstenDataTester):
@tutils.TungstenDataTester.wrap_recon_tester(name="test_multislice_ptycho_rpie_synthesisdictlearn")
def test_multislice_ptycho_rpie_synthesisdictlearn(self):
self.setup_ptychi(cpu_only=False)

data, probe, pixel_size_m, positions_px = self.load_tungsten_data(additional_opr_modes=3)

npz_dict_file = np.load(
os.path.join(
self.get_ci_input_data_dir(), "zernike2D_dictionaries", "testing_sdl_dictionary.npz"
)
)
D = npz_dict_file["D"]
D_pinv = npz_dict_file["D_pinv"]
npz_dict_file.close()

options = api.RPIEOptions()
options.data_options.data = data

Nslices = 2
options.object_options.initial_guess = torch.ones(
[Nslices, *get_suggested_object_size(positions_px, probe.shape[-2:], extra=100)],
dtype=get_default_complex_dtype(),
)
options.object_options.pixel_size_m = pixel_size_m
options.object_options.slice_spacings_m = (1e-5 / ( Nslices - 1)) * np.array( [1] * (Nslices - 1)).astype('float32')
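# Nslices - 1 equal slice spacings; 10 um total propagation distance between the first and last slice.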
options.object_options.optimizable = True
options.object_options.optimizer = api.Optimizers.SGD
options.object_options.step_size = 1e-2
options.object_options.alpha = 5e-1

options.object_options.multislice_regularization.enabled = True
options.object_options.multislice_regularization.weight = 0.01
options.object_options.multislice_regularization.unwrap_phase = True
options.object_options.multislice_regularization.unwrap_image_grad_method = api.enums.ImageGradientMethods.FOURIER_DIFFERENTIATION
options.object_options.multislice_regularization.unwrap_image_integration_method = api.enums.ImageIntegrationMethods.FOURIER

options.probe_options.initial_guess = probe
options.probe_options.optimizable = True
options.probe_options.optimizer = api.Optimizers.SGD
options.probe_options.orthogonalize_incoherent_modes.enabled = True
options.probe_options.step_size = 1e-0
options.probe_options.alpha = 9e-1

options.probe_options.experimental.sdl_probe_options.enabled = True
options.probe_options.experimental.sdl_probe_options.d_mat = np.asarray(
D, dtype=np.complex64
)
options.probe_options.experimental.sdl_probe_options.d_mat_conj_transpose = np.conj(
options.probe_options.experimental.sdl_probe_options.d_mat
).T
options.probe_options.experimental.sdl_probe_options.d_mat_pinv = D_pinv
options.probe_options.experimental.sdl_probe_options.probe_sparse_code_nnz = np.round(
0.50 * D.shape[-1]
)
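# Keep roughly half of the dictionary atoms non-zero in each probe mode's sparse code.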

options.probe_position_options.position_x_px = positions_px[:, 1]
options.probe_position_options.position_y_px = positions_px[:, 0]
options.probe_position_options.optimizable = False

options.opr_mode_weight_options.optimizable = True
options.opr_mode_weight_options.initial_weights = generate_initial_opr_mode_weights( len(positions_px), probe.shape[0] )
options.opr_mode_weight_options.optimization_plan.stride = 1
options.opr_mode_weight_options.update_relaxation = 1e-2

options.reconstructor_options.batch_size = round(data.shape[0] * 0.1)
options.reconstructor_options.num_epochs = 32
options.reconstructor_options.allow_nondeterministic_algorithms = False

task = PtychographyTask(options)
task.run()

recon = task.get_data_to_cpu("object", as_numpy=True)
return recon


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--generate-gold", action="store_true")
args = parser.parse_args()

tester = TestMultislicePtychoRPIESDL()
tester.setup_method(name="", generate_data=False, generate_gold=args.generate_gold, debug=True)
tester.test_multislice_ptycho_rpie_synthesisdictlearn()