Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4cce5ae
Merge pull request #477 from ptycho/hotfixes
daurer Feb 10, 2023
6949de8
initial implementation of multislice ePIE
daurer Aug 23, 2023
97a51e4
code runs but probably still bug for slices > 1
daurer Aug 23, 2023
b35c3de
save slices infomation and fix the update loop
yiranlus Oct 24, 2023
67eb25a
object as product of all slices at each iteration
yiranlus Oct 26, 2023
6a7e8f9
iterating over pods to allow for modes
kahntm Oct 31, 2023
8da41aa
swapped loops
kahntm Nov 1, 2023
c9e71b1
added the 3PIE article to the engine
kahntm Nov 1, 2023
d4b178b
renamed the engine to match the algorithm name in the literature
kahntm Nov 1, 2023
aa97276
renamed the file to match the engine and algorithm name
kahntm Nov 1, 2023
a4b83ec
python convention for class names
kahntm Nov 2, 2023
55a7b26
file name as class name
kahntm Nov 2, 2023
7a65a77
added semi-functioning switching on of slices at arbitrary iterations
kahntm Nov 5, 2023
90058d1
allow non equal slice spacing
yiranlus Nov 6, 2023
391cc4f
add object regularisation
yiranlus Jan 26, 2024
e412764
Merge pull request #473 from ptycho/dev
bjoernenders Feb 5, 2024
381065c
Merge branch 'multi-slice-epie' of https://github.com/ptycho/ptypy in…
kahntm Mar 11, 2024
c601059
Merge pull request #542 from ptycho/dev
daurer Mar 11, 2024
36b9370
Wrap nccl.get_unique_id in try/except (#549)
daurer Mar 21, 2024
e76ea2d
Merge pull request #561 from ptycho/dev
daurer Aug 29, 2024
4dea7a9
Merge pull request #576 from ptycho/dev
daurer Sep 5, 2024
1681bec
Merge pull request #610 from ptycho/dev
daurer May 9, 2025
cdf33cc
Add option to provide defocus for nearfield data (#620)
daurer Jul 18, 2025
f3e160a
Merge remote-tracking branch 'origin/master' into multi-slice-epie
ltang320 Sep 10, 2025
a1180a1
add GPU version 3pie file
ltang320 Sep 15, 2025
e22d520
add threepin gpu version test
ltang320 Sep 15, 2025
5550e01
Remove test three_pie_cupy and start a new branch
ltang320 Sep 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions ptypy/custom/ThreePIE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
"""
A simple implementation of Multislice for the
ePIE algorithm.

authors: Benedikt J. Daurer and more...
"""
from ptypy.engines import stochastic
from ptypy.engines import register
from ptypy.core import geometry
from ptypy.utils import Param
from ptypy.utils.verbose import logger
from ptypy import io
import numpy as np

@register()
class ThreePIE(stochastic.EPIE):
"""
An extension of EPIE to include multislice

Defaults:

[name]
default = ThreePIE
type = str
help =
doc =

[number_of_slices]
default = 2
type = int
help = The number of slices
doc = Defines how many slices are used for the multi-slice object.

[slice_thickness]
default = 1e-6
type = float, list, tuple
help = Thickness of a single slice in meters
doc = A single float value or a list of float values. If a single value is used, all the slice will be assumed to be of the same thickness.

[slice_start_iteration]
default = 0
type = int, list, tuple
help = iteration number to start using a specific slice
doc =

[fslices]
default = slices.h5
type = str
help = File path for the slice data
doc =

"""
def __init__(self, ptycho_parent, pars=None):
super(ThreePIE, self).__init__(ptycho_parent, pars)
self.article = dict(
title='{Ptychographic transmission microscopy in three dimensions using a multi-slice approach',
author='A. M. Maiden et al.',
journal='J. Opt. Soc. Am. A',
volume=29,
year=2012,
page=1606,
doi='10.1364/JOSAA.29.001606',
comment='The 3PIE reconstruction algorithm',
)
self.ptycho.citations.add_article(**self.article)

def engine_initialize(self):
super().engine_initialize()

# Create a list of objects and exit waves (one for each slice)
self._object = [None] * self.p.number_of_slices
self._probe = [None] * self.p.number_of_slices
self._exits = [None] * self.p.number_of_slices
for i in range(self.p.number_of_slices):
self._object[i] = self.ob.copy(self.ob.ID + "_o_" + str(i))
self._probe[i] = self.pr.copy(self.pr.ID + "_p_" + str(i))
self._exits[i] = self.pr.copy(self.pr.ID + "_e_" + str(i))

# ToDo:
# - allow for non equal slice spacing
# - allow for start_slice_update at a freely chosen iteration
# for each slice separately - works, but not if the
# most downstream slice is switched off

if isinstance(self.p.slice_start_iteration, int):
self.p.slice_start_iteration = np.ones(self.p.number_of_slices) * self.p.slice_start_iteration
#if ĺen(self.p.slice_start_iteration) != self.p.number_of_slices:
# logger.info(f'dimension of given slice_start_iteration ({ĺen(self.p.slice_start_iteration)}) does not match number of slices ({self.p.number_of_slices})')

scan = list(self.ptycho.model.scans.values())[0]
geom = scan.geometries[0]
g = Param()
g.energy = geom.energy
g.distance = self.p.slice_thickness
g.psize = geom.resolution
g.shape = geom.shape
g.propagation = "nearfield"

self.fw = []
self.bw = []
if type(self.p.slice_thickness) in [list, tuple]:
assert(len(self.p.slice_thickness) == self.p.number_of_slices-1)
for thickness in self.p.slice_thickness:
g.distance = thickness
G = geometry.Geo(owner=None, pars=g)
self.fw.append(G.propagator.fw)
self.bw.append(G.propagator.bw)
else:
g.distance = self.p.slice_thickness
G = geometry.Geo(owner=None, pars=g)
self.fw = [G.propagator.fw for i in range(self.p.number_of_slices-1)]
self.bw = [G.propagator.bw for i in range(self.p.number_of_slices-1)]

def engine_iterate(self, num=1):
"""
Compute one iteration.
"""
vieworder = list(self.di.views.keys())
vieworder.sort()
rng = np.random.default_rng()

for it in range(num):

error_dct = {}
rng.shuffle(vieworder)

for name in vieworder:
view = self.di.views[name]
if not view.active:
continue

# Multislice update
error_dct[name] = self.multislice_update(view)

self.curiter += 1

return error_dct

def engine_finalize(self):
self.ob.fill(self._object[0])
for i in range(1, self.p.number_of_slices):
self.ob *= self._object[i]

# Save the slices
slices_info = Param()
slices_info.number_of_slices = self.p.number_of_slices
slices_info.slice_thickness = self.p.slice_thickness
slices_info.objects = {ob.ID: {ID: S._to_dict() for ID, S in ob.storages.items()}
for ob in self._object}
slices_info.slice_start_iteration = self.p.slice_start_iteration

header = {'description': 'multi-slices result details.'}

h5opt = io.h5options['UNSUPPORTED']
io.h5options['UNSUPPORTED'] = 'ignore'
logger.info(f'Saving to {self.p.fslices}')
io.h5write(self.p.fslices, header=header, content=slices_info)
io.h5options['UNSUPPORTED'] = h5opt

return super().engine_finalize()

def multislice_update(self, view):
"""
Performs one 'iteration' of 3PIE (multislice ePIE) for a single view.
Based on https://doi.org/10.1364/JOSAA.29.001606
"""

for i in range(self.p.number_of_slices-1):
for name, pod in view.pods.items():
# exit wave for this slice
if self.curiter >= self.p.slice_start_iteration[i]:
self._exits[i][pod.pr_view] = self._probe[i][pod.pr_view] * self._object[i][pod.ob_view]
else:
self._exits[i][pod.pr_view] = self._probe[i][pod.pr_view] * 1.
# incident wave for next slice
self._probe[i+1][pod.pr_view] = self.fw[i](self._exits[i][pod.pr_view])

for name, pod in view.pods.items():
# Exit wave for last slice
if self.curiter >= self.p.slice_start_iteration[-1]:
self._exits[-1][pod.pr_view] = self._probe[-1][pod.pr_view] * self._object[-1][pod.ob_view]
else:
self._exits[-1][pod.pr_view] = self._probe[-1][pod.pr_view] * 1.
# Save final state into pod (need for ptypy fourier update)
pod.probe = self._probe[-1][pod.pr_view]
pod.object = self._object[-1][pod.ob_view]
pod.exit = self._exits[-1][pod.pr_view]

# Fourier update
error = self.fourier_update(view)

# Object/probe update for the last slice
if self.curiter >= self.p.slice_start_iteration[-1]:
self.object_update(view, {pod.ID:self._exits[-1][pod.pr_view] for name, pod in view.pods.items()})
self.probe_update(view, {pod.ID:self._exits[-1][pod.pr_view] for name, pod in view.pods.items()})
for name, pod in view.pods.items():
self._object[-1][pod.ob_view] = pod.object
self._probe[-1][pod.pr_view] = pod.probe
else:
for name, pod in view.pods.items():
self._probe[-1][pod.pr_view] = pod.exit * 1.

# Object/probe update for other slices (backwards)
for i in range(self.p.number_of_slices-2, -1, -1):
if self.curiter >= self.p.slice_start_iteration[i]:

for name, pod in view.pods.items():
# Backwards propagation of the probe
pod.exit = self.bw[i](self._probe[i+1][pod.pr_view])
# Save state into pods
pod.probe = self._probe[i][pod.pr_view]
pod.object = self._object[i][pod.ob_view]

# Actual object/probe update
self.object_update(view, {pod.ID:self._exits[i][pod.pr_view] for name, pod in view.pods.items()})
self.probe_update(view, {pod.ID:self._exits[i][pod.pr_view] for name, pod in view.pods.items()})
for name, pod in view.pods.items():
self._object[i][pod.ob_view] = pod.object
self._probe[i][pod.pr_view] = pod.probe
else:
for name, pod in view.pods.items():
self._probe[i][pod.pr_view] = self.bw[i](self._probe[i+1][pod.pr_view])

# set the object as the product of all slices for better live plotting
self.ob.fill(self._object[0])
for i in range(1, self.p.number_of_slices):
self.ob *= self._object[i]

return error
34 changes: 33 additions & 1 deletion ptypy/custom/threepie.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class ThreePIE(stochastic.EPIE):
help = File path for the slice data
doc =

[object_regularization_rate]
default = 0.0
type = float
help = regularization rate for object slices
doc =

"""
def __init__(self, ptycho_parent, pars=None):
super(ThreePIE, self).__init__(ptycho_parent, pars)
Expand Down Expand Up @@ -227,4 +233,30 @@ def multislice_update(self, view):
for i in range(1, self.p.number_of_slices):
self.ob *= self._object[i]

return error
if self.p.object_regularization_rate > 0:
self.apply_object_regularization()
Comment on lines +236 to +237
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably make sense to move this a bit higher up to make sure that self.ob is calculated for plotting after the regulariser has been applied...


return error

def apply_object_regularization(self):
# single mode implementation
# only valide for slices with identical thickness
assert(self.p.number_of_slices > 1)
assert(isinstance(self.p.slice_thickness, float))

shape = self._object[0].S["Sscan_00G00"].data.shape[1:]
psize = self._object[0].S["Sscan_00G00"].psize[0]
kz = np.fft.fftfreq(self.p.number_of_slices, self.p.slice_thickness)[..., np.newaxis, np.newaxis]
ky = np.fft.fftfreq(shape[0], psize)[..., np.newaxis]
kx = np.fft.fftfreq(shape[1], psize)

# calculate the weight array
w = 1 - 2*np.arctan2(self.p.object_regularization_rate**2 * kz**2, kx**2+ky**2+np.spacing(1))/np.pi
Comment on lines +247 to +254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part is basically just calculating some weights w and does not depend on any current update so can be moved into a separate function, e.g. initialize_regularizer and called once in the constructor. In this current implementation the weights are re-calculated for every iteration which seems unnecessary.


current_object = np.fft.ifftn(np.fft.fftn([self._object[i].S["Sscan_00G00"].data[0,...] for i in range(len(self._object))]) * w)

print("object shape", self._object[0].S["Sscan_00G00"].data.shape)
print("w shape", w.shape)
print("current shape", current_object.shape)
for i in range(len(self._object)):
self._object[i].S["Sscan_00G00"].data[0, ...] = current_object[i, ...]
24 changes: 24 additions & 0 deletions ptypy/experiment/hdf5_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,12 @@ class Hdf5Loader(PtyScan):
type = list, ndarray
help = This is the array or list with the re-ordered indices.

[nearfield_defocus]
default = None
type = float
help = Distance from sample to focus (for nearfield only)
doc = If set, magnification will be calculated automatically and applied to detector distance and pixelsize

"""

def __init__(self, pars=None, swmr=False, **kwargs):
Expand Down Expand Up @@ -378,6 +384,10 @@ def __init__(self, pars=None, swmr=False, **kwargs):
if self.p.electron_data:
self.meta.energy = u.m2keV(u.electron_wavelength(self.meta.energy))

# For nearfield data, manipulate distance and psize
if self.p.nearfield_defocus:
self._prepare_nearfield()

# it's much better to have this logic here than in load!
if (self._ismapped and (self._scantype == 'arb')):
log(3, "This scan looks to be a mapped arbitrary trajectory scan.")
Expand Down Expand Up @@ -590,6 +600,20 @@ def _prepare_meta_info(self):
assert self.pad.size == 4, "self.p.padding needs to of size 4"
log(3, "Padding the detector frames by {}".format(self.p.padding))

def _prepare_nearfield(self):
"""
Calculate magnification and modify distance and psize
"""
defocus = self.p.nearfield_defocus
mag = self.meta.distance / defocus
dist_eff = (self.meta.distance - defocus) / mag
psize_eff = self.info.psize / mag
log(3, f"Nearfield: With defocus {defocus} m the magmification is {mag}")
log(3, f"Nearfield: The effective detector distance is {dist_eff}")
log(3, f"Nearfield: The effective pixel size is {psize_eff}")
self.meta.distance = dist_eff
self.info.psize = psize_eff

def _prepare_center(self):
"""
define how data should be loaded (center, cropping)
Expand Down
65 changes: 65 additions & 0 deletions templates/misc/moonflower_ePIE_multislice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
This script is a test for ptychographic reconstruction in the absence
of actual data. It uses the test Scan class
`ptypy.core.data.MoonFlowerScan` to provide "data".
"""
from ptypy.core import Ptycho
from ptypy import utils as u
from ptypy.custom import ePIE_multislice

import tempfile
tmpdir = tempfile.gettempdir()

p = u.Param()

# for verbose output
p.verbose_level = "info"

# set home path
p.io = u.Param()
p.io.home = "/".join([tmpdir, "ptypy"])

# saving intermediate results
p.io.autosave = u.Param(active=False)

# opens plotting GUI if interaction set to active)
p.io.autoplot = u.Param(active=True)
p.io.interaction = u.Param(active=True)

# max 200 frames (128x128px) of diffraction data
p.scans = u.Param()
p.scans.MF = u.Param()
# now you have to specify which ScanModel to use with scans.XX.name,
# just as you have to give 'name' for engines and PtyScan subclasses.
p.scans.MF.name = 'GradFull'
p.scans.MF.data= u.Param()
p.scans.MF.data.name = 'MoonFlowerScan'
p.scans.MF.data.shape = 128
p.scans.MF.data.num_frames = 200
p.scans.MF.data.save = None

# position distance in fraction of illumination frame
p.scans.MF.data.density = 0.2
# total number of photon in empty beam
p.scans.MF.data.photons = 1e8
# Gaussian FWHM of possible detector blurring
p.scans.MF.data.psf = 0.

# attach a reconstrucion engine
p.engines = u.Param()
p.engines.engine00 = u.Param()
p.engines.engine00.name = 'ePIE_multislice'
p.engines.engine00.numiter = 200
p.engines.engine00.probe_center_tol = None
p.engines.engine00.compute_log_likelihood = True
p.engines.engine00.object_norm_is_global = True
p.engines.engine00.alpha = 1
p.engines.engine00.beta = 1
p.engines.engine00.probe_update_start = 0
p.engines.engine00.number_of_slices = 2
p.engines.engine00.slice_thickness = 60e-9

# prepare and run
if __name__ == "__main__":
P = Ptycho(p,level=5)

Loading