diff --git a/.gitignore b/.gitignore index e1d01a98..ff375543 100644 --- a/.gitignore +++ b/.gitignore @@ -118,5 +118,7 @@ core .DS_Store .nfs* +*.tiff + # MPI host files host_list diff --git a/src/broken/reg.py b/src/broken/reg.py deleted file mode 100644 index e8ed918f..00000000 --- a/src/broken/reg.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np - -def run(xp, u, mu, tau, alpha): - """Provide some kind of regularization.""" - z = fwd(xp, u) + mu / tau - # Soft-thresholding - # za = xp.sqrt(xp.sum(xp.abs(z), axis=0)) - za = xp.sqrt(xp.real(xp.sum(z*xp.conj(z), 0))) - zeros = (za <= alpha / tau) - z[:, zeros] = 0 - z[:, ~zeros] -= z[:, ~zeros] * alpha / (tau * za[~zeros]) - return z - -def fwd(xp, u): - """Forward operator for regularization (J).""" - res = xp.zeros((3, *u.shape), dtype=u.dtype, order='C') - res[0, :, :, :-1] = u[:, :, 1:] - u[:, :, :-1] - res[1, :, :-1, :] = u[:, 1:, :] - u[:, :-1, :] - res[2, :-1, :, :] = u[1:, :, :] - u[:-1, :, :] - res *= 2 / np.sqrt(3) # normalization - return res - -def adj(xp, gr): - """Adjoint operator for regularization (J^*).""" - res = xp.zeros(gr.shape[1:], gr.dtype, order='C') - res[:, :, 1:] = gr[0, :, :, 1:] - gr[0, :, :, :-1] - res[:, :, 0] = gr[0, :, :, 0] - res[:, 1:, :] += gr[1, :, 1:, :] - gr[1, :, :-1, :] - res[:, 0, :] += gr[1, :, 0, :] - res[1:, :, :] += gr[2, 1:, :, :] - gr[2, :-1, :, :] - res[0, :, :] += gr[2, 0, :, :] - res *= -2 / np.sqrt(3) # normalization - return res diff --git a/src/tike/admm/__init__.py b/src/tike/admm/__init__.py new file mode 100644 index 00000000..fa8577d8 --- /dev/null +++ b/src/tike/admm/__init__.py @@ -0,0 +1,4 @@ +from .admm import * +from .al import * +from .pal import * +from .pl import * diff --git a/src/tike/admm/admm.py b/src/tike/admm/admm.py new file mode 100644 index 00000000..1091fca8 --- /dev/null +++ b/src/tike/admm/admm.py @@ -0,0 +1,69 @@ +import logging + +import numpy as np + +import tike.align +import tike.lamino +import tike.ptycho + + +def simulate( + u, + scan, + probe, + flow, + angle, + tilt, + theta, + padded_shape, + detector_shape, +): + phi = np.exp(1j * tike.lamino.simulate( + obj=u, + tilt=tilt, + theta=theta, + )) + psi = tike.align.simulate( + original=phi, + flow=flow, + padded_shape=padded_shape, + angle=angle, + cval=1.0, + ) + data = tike.ptycho.simulate( + psi=psi, + probe=probe, + detector_shape=detector_shape, + scan=scan, + ) + return data, psi, phi + + +def print_log_line(**kwargs): + """Print keyword arguments and values on a single comma-separated line. + + The format of the line is as follows: + + ``` + foo: 003, bar: +1.234e+02, hello: world\n + ``` + + Parameters + ---------- + line: dictionary + The key value pairs to be printed. + + """ + line = [] + for k, v in kwargs.items(): + # Use special formatting for float and integers + if isinstance(v, (float, np.floating)): + line.append(f'"{k}": {v:6.3e}') + elif isinstance(v, (int, np.integer)): + line.append(f'"{k}": {v:3d}') + elif isinstance(v, str): + line.append(f'"{k}": "{v}"') + else: + line.append(f'"{k}": {v}') + # Combine all the strings and strip the last comma + print("{", ", ".join(line), "}", flush=True) diff --git a/src/tike/admm/al.py b/src/tike/admm/al.py new file mode 100644 index 00000000..053f546a --- /dev/null +++ b/src/tike/admm/al.py @@ -0,0 +1,156 @@ +import logging + +import numpy as np +import cupy as cp + +import tike.admm.subproblem +import tike.communicator + +from .admm import print_log_line + +logger = logging.getLogger(__name__) + + +def ptycho__align_lamino( + data, + psi, + scan, + probe, + theta, + tilt, + angle, + w, + flow=None, + shift=None, + niter=1, + interval=8, + folder=None, + cg_iter=4, + align_method=False, + skip_ptycho=False, +): + """Solve the joint ptycho-lamino problem using ADMM.""" + presult = { + 'psi': psi, + 'scan': scan, + 'probe': probe, + } + + u = np.zeros((w, w, w), dtype='complex64') + Hu = np.ones((len(theta), w, w), dtype='complex64') + phi = Hu + Aφ = np.ones(psi.shape, dtype='complex64') + + λ_l = np.zeros([len(theta), w, w], dtype='complex64') + ρ_l = 0.5 + + comm = tike.communicator.MPICommunicator() + + with cp.cuda.Device(comm.rank if comm.size > 1 else None): + + if not skip_ptycho: + presult, _ = tike.admm.subproblem.ptycho( + # constants + comm=comm, + data=data, + λ=None, + ρ=None, + Aφ=None, + # updated + presult=presult, + # parameters + num_iter=4 * niter, + cg_iter=cg_iter, + folder=folder, + save_result=niter + 1, + rescale=True, + rtol=1e-6, + ) + + for k in range(1, niter + 1): + logger.info(f"Start ADMM iteration {k}.") + save_result = k if k % interval == 0 else False + + ( + phi, + _, + _, + flow, + shift, + Aφ, + align_cost, + winsize, + ) = tike.admm.subproblem.align( + # constants + comm=comm, + psi=presult['psi'], + angle=angle, + Hu=Hu, + λ_l=λ_l, + ρ_l=ρ_l, + # updated + phi=phi, + λ_p=None, + ρ_p=1, + flow=flow, + shift=shift, + Aφ0=None, + # parameters + align_method=align_method, + cg_iter=cg_iter, + num_iter=4, + folder=folder, + save_result=save_result, + winsize=winsize if k > 1 else 129, + ) + + ( + u, + λ_l, + ρ_l, + Hu, + lamino_cost, + ) = tike.admm.subproblem.lamino( + # constants + comm=comm, + phi=phi, + theta=theta, + tilt=tilt, + # updated + u=u, + λ_l=λ_l, + ρ_l=ρ_l, + Hu0=Hu, + # parameters + num_iter=4, + cg_iter=cg_iter, + folder=folder, + save_result=save_result, + ) + + # Record metrics for each subproblem + ψAφ = presult['psi'] - Aφ + φHu = phi - Hu + lagrangian = ( + [np.mean(np.real(ψAφ.conj() * ψAφ))], + [ + 2 * np.mean(np.real(λ_l.conj() * φHu)) + + ρ_l * np.mean(np.real(φHu.conj() * φHu)) + ], + [align_cost], + ) + lagrangian = [comm.gather(x) for x in lagrangian] + + if comm.rank == 0: + lagrangian = [np.sum(x) for x in lagrangian] + print_log_line( + k=k, + ρ_l=ρ_l, + winsize=winsize, + align_method=align_method, + Lagrangian=np.sum(lagrangian[:2]), + ψAφ=lagrangian[0], + φHu=lagrangian[1], + align=lagrangian[2], + lamino=float(lamino_cost), + ) diff --git a/src/tike/admm/pal.py b/src/tike/admm/pal.py new file mode 100644 index 00000000..fee38495 --- /dev/null +++ b/src/tike/admm/pal.py @@ -0,0 +1,164 @@ +import logging + +import numpy as np +import cupy as cp + +import tike.admm.subproblem +import tike.communicator + +from .admm import print_log_line + +logger = logging.getLogger(__name__) + + +def ptycho_align_lamino( + data, + psi, + scan, + probe, + theta, + tilt, + angle, + w, + flow=None, + shift=None, + niter=1, + interval=8, + folder=None, + cg_iter=4, + align_method=False, +): + """Solve the joint ptycho-lamino problem using ADMM.""" + presult = { + 'psi': psi, + 'scan': scan, + 'probe': probe, + } + + u = np.zeros((w, w, w), dtype='complex64') + Hu = np.ones((len(theta), w, w), dtype='complex64') + phi = Hu + Aφ = np.ones(psi.shape, dtype='complex64') + + λ_p = np.zeros_like(psi) + ρ_p = 0.5 + + λ_l = np.zeros([len(theta), w, w], dtype='complex64') + ρ_l = 0.5 + + comm = tike.communicator.MPICommunicator() + + with cp.cuda.Device(comm.rank if comm.size > 1 else None): + for k in range(1, niter + 1): + logger.info(f"Start ADMM iteration {k}.") + save_result = k if k % interval == 0 else False + + presult, Gψ = tike.admm.subproblem.ptycho( + # constants + comm=comm, + data=data, + λ=λ_p, + ρ=ρ_p, + Aφ=Aφ, + # updated + presult=presult, + # parameters + num_iter=4, + cg_iter=cg_iter, + folder=folder, + save_result=save_result, + rescale=(k == 1), + ) + + ( + phi, + λ_p, + ρ_p, + flow, + shift, + Aφ, + align_cost, + winsize, + ) = tike.admm.subproblem.align( + # constants + comm=comm, + psi=presult['psi'], + angle=angle, + Hu=Hu, + λ_l=λ_l, + ρ_l=ρ_l, + # updated + phi=phi, + λ_p=λ_p, + ρ_p=ρ_p, + flow=flow, + shift=shift, + Aφ0=Aφ, + # parameters + align_method=align_method, + cg_iter=cg_iter, + num_iter=4, + folder=folder, + save_result=save_result, + winsize=winsize if k > 1 else 128, + ) + + ( + u, + λ_l, + ρ_l, + Hu, + lamino_cost, + ) = tike.admm.subproblem.lamino( + # constants + comm=comm, + phi=phi, + theta=theta, + tilt=tilt, + # updated + u=u, + λ_l=λ_l, + ρ_l=ρ_l, + Hu0=Hu, + # parameters + num_iter=4, + cg_iter=cg_iter, + folder=folder, + save_result=save_result, + ) + + # Record metrics for each subproblem + ψAφ = presult['psi'] - Aφ + φHu = phi - Hu + lagrangian = ( + [np.mean(np.square(data - Gψ))], + [ + 2 * np.mean(np.real(λ_p.conj() * ψAφ)) + + ρ_p * np.mean(np.real(ψAφ.conj() * ψAφ)) + ], + [ + 2 * np.mean(np.real(λ_l.conj() * φHu)) + + ρ_l * np.mean(np.real(φHu.conj() * φHu)) + ], + [presult['cost']], + [align_cost], + ) # yapf: disable + + lagrangian = [comm.gather(x) for x in lagrangian] + + if comm.rank == 0: + lagrangian = [np.mean(x) for x in lagrangian] + print_log_line( + k=k, + ρ_p=ρ_p, + ρ_l=ρ_l, + winsize=winsize, + align_method=align_method, + Lagrangian=np.sum(lagrangian[:3]), + dGψ=lagrangian[0], + ψAφ=lagrangian[1], + φHu=lagrangian[2], + ptycho=lagrangian[3], + align=lagrangian[4], + lamino=float(lamino_cost), + ) diff --git a/src/tike/admm/pl.py b/src/tike/admm/pl.py new file mode 100644 index 00000000..e3fbe810 --- /dev/null +++ b/src/tike/admm/pl.py @@ -0,0 +1,154 @@ +import logging + +import numpy as np +import cupy as cp + +import tike.admm.subproblem +import tike.align +import tike.communicator + +from .admm import print_log_line + +logger = logging.getLogger(__name__) + + +def ptycho_lamino( + data, + psi, + scan, + probe, + theta, + tilt, + angle, + w, + flow=False, + shift=None, + niter=1, + interval=8, + folder=None, + cg_iter=4, + align_method=False, +): + """Solve the joint ptycho-lamino problem using ADMM.""" + presult = { + 'psi': psi, + 'scan': scan, + 'probe': probe, + } + + u = np.zeros((w, w, w), dtype='complex64') + Hu = np.ones_like(psi) + λ_p = np.zeros_like(psi) + ρ_p = 0.5 + + comm = tike.communicator.MPICommunicator() + + with cp.cuda.Device(comm.rank if comm.size > 1 else None): + for k in range(1, niter + 1): + logger.info(f"Start ADMM iteration {k}.") + save_result = k if k % interval == 0 else False + + presult, Gψ = tike.admm.subproblem.ptycho( + comm, + # constants + data, + λ=λ_p, + ρ=ρ_p, + Aφ=Hu, + # updated + presult=presult, + # parameters + num_iter=4, + cg_iter=cg_iter, + folder=folder, + save_result=save_result, + rescale=(k == 1), + ) + + phi = tike.align.invert( + presult['psi'], + angle=angle, + flow=None, + shift=shift, + unpadded_shape=(len(theta), w, w), + cval=1.0, + ) + Hu = tike.align.invert( + Hu, + angle=angle, + flow=None, + shift=shift, + unpadded_shape=(len(theta), w, w), + cval=1.0, + ) + λ_p = tike.align.invert( + λ_p, + angle=angle, + flow=None, + shift=shift, + unpadded_shape=(len(theta), w, w), + cval=1.0, + ) + + ( + u, + λ_p, + ρ_p, + Hu, + lamino_cost, + ) = tike.admm.subproblem.lamino( + # constants + comm, + phi, + theta, + tilt, + # updated + u, + λ_p, + ρ_p, + Hu0=Hu, + # parameters + num_iter=4, + cg_iter=cg_iter, + folder=folder, + save_result=save_result, + ) + + Hu = tike.align.simulate( + Hu, + angle=angle, + flow=None, + shift=shift, + padded_shape=psi.shape, + cval=1.0, + ) + λ_p = tike.align.simulate( + λ_p, + angle=angle, + flow=None, + shift=shift, + padded_shape=psi.shape, + cval=1.0, + ) + + # Record metrics for each subproblem + ψHu = presult['psi'] - Hu + lagrangian = ( + [np.mean(np.square(data - Gψ))], + [ + 2 * np.real(λ_p.conj() * ψHu) + + ρ_p * np.linalg.norm(ψHu.ravel())**2 + ], + ) + lagrangian = [comm.gather(x) for x in lagrangian] + + if comm.rank == 0: + lagrangian = [np.sum(x) for x in lagrangian] + print_log_line( + k=k, + ρ_p=ρ_p, + Lagrangian=np.sum(lagrangian), + dGψ=lagrangian[0], + ψHu=lagrangian[1], + lamino=float(lamino_cost), + ) diff --git a/src/tike/admm/subproblem/__init__.py b/src/tike/admm/subproblem/__init__.py new file mode 100644 index 00000000..6ef448b9 --- /dev/null +++ b/src/tike/admm/subproblem/__init__.py @@ -0,0 +1,37 @@ +"""Implements subproblem formulations for ADMM. + +Each ADMM subproblem is implemented in a separate function such that the +problems are consistently implemented across different ADMM compositions. + +""" + + +def update_penalty(comm, g, h, h0, rho, diff=4): + """Increase rho when L2 error between g and h becomes too large. + + If rho is the penalty parameter associated with the constraint norm(g - h), + then rho is increased when + + norm(g - h) > diff * rho^2 * norm(h - h0) + + and decreased when + + norm(g - h) * diff < rho^2 * norm(h - h0) + + """ + r = np.linalg.norm(g - h)**2 + s = rho * rho * np.linalg.norm(h - h0)**2 + r, s = [np.sum(comm.gather(x)) for x in ([r], [s])] + if comm.rank == 0: + if (r > diff * s): + rho *= 2 + elif (r * diff < s): + rho *= 0.5 + rho = comm.broadcast(rho) + logging.info(f"Update penalty parameter ρ = {rho}.") + return rho + + +from .align import * +from .lamino import * +from .ptycho import * diff --git a/src/tike/admm/subproblem/align.py b/src/tike/admm/subproblem/align.py new file mode 100644 index 00000000..1c295a07 --- /dev/null +++ b/src/tike/admm/subproblem/align.py @@ -0,0 +1,238 @@ +import logging + +import dxchange +import numpy as np + +import tike.align +from . import update_penalty + +logger = logging.getLogger(__name__) + + +def _find_min_max(data): + mmin = np.zeros(data.shape[0], dtype='float32') + mmax = np.zeros(data.shape[0], dtype='float32') + + for k in range(data.shape[0]): + h, e = np.histogram(data[k][:], 1000) + stend = np.where(h > np.max(h) * 0.005) + st = stend[0][0] + end = stend[0][-1] + mmin[k] = e[st] + mmax[k] = e[end + 1] + + return mmin, mmax + + +def _optical_flow_tvl1(unaligned, original, num_iter=16): + """Wrap scikit-image optical_flow_tvl1 for complex values""" + from skimage.registration import optical_flow_tvl1 + pflow = [ + optical_flow_tvl1( + np.angle(original[i]), + np.angle(unaligned[i]), + num_iter=num_iter, + ) for i in range(len(original)) + ] + # rflow = [ + # optical_flow_tvl1( + # original[i].real, + # unaligned[i].real, + # num_iter=num_iter, + # ) for i in range(len(original)) + # ] + flow = np.array(pflow, dtype='float32') + #+ np.array(iflow, dtype='float32')) /2 + flow = np.moveaxis(flow, 1, -1) + return flow + + +def _center_of_mass(m, axis=None): + """Return the center of mass of m along the given axis. + + Parameters + ---------- + m : array + Values to find the center of mass from + axis : tuple(int) + The axes to find center of mass along. + + Returns + ------- + center : (..., len(axis)) array[int] + The shape of center is the shape of m with the dimensions corresponding + to axis removed plus a new dimension appended whose length is the + of length of axis in the order of axis. + + """ + centers = [] + for a in range(m.ndim) if axis is None else axis: + shape = np.ones_like(m.shape) + shape[a] = m.shape[a] + x = np.arange(1, m.shape[a] + 1).reshape(*shape).astype(m.dtype) + centers.append((m * x).sum(axis=axis) / m.sum(axis=axis) - 1) + return np.stack(centers, axis=-1) + + +def align( + # constants + comm, + psi, + angle, + Hu, + λ_l, + ρ_l, + # updated + phi, + λ_p, + ρ_p, + flow, + shift, + Aφ0=None, + # parameters + align_method=False, + cg_iter=1, + num_iter=1, + folder=None, + save_result=False, + winsize=0, +): + """ + Parameters + ---------- + psi + ptychography result. psi = A(phi) + angle + alignment rotation angle + Hu + Forward model of tomography phi = Hu + """ + + logging.info("Solve alignment subproblem.") + + save_result = False if folder is None else save_result + + aresult = tike.align.reconstruct( + unaligned=psi if λ_p is None else psi + λ_p / ρ_p, + original=phi, + flow=flow, + shift=shift, + angle=angle, + num_iter=cg_iter * num_iter, + algorithm='cgrad', + reg=Hu - λ_l / ρ_l, + rho_p=ρ_p, + rho_a=ρ_l, + cval=1.0, + ) + phi = aresult['original'] + cost = aresult['cost'] + + if align_method: + + # TODO: Try combining rotation and flow because they use the same + # interpolator + rotated = tike.align.invert( + psi if λ_p is None else psi + λ_p / ρ_p, + angle=angle, + flow=None, + shift=None, + unpadded_shape=None, + cval=1.0, + ) + padded = tike.align.simulate( + phi, + angle=None, + flow=None, + shift=None, + padded_shape=psi.shape, + cval=1.0, + ) + + if comm.rank == 0 and save_result: + dxchange.write_tiff( + np.angle(rotated), + f'{folder}/rotated-angle-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + dxchange.write_tiff( + np.angle(padded), + f'{folder}/padded-angle-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + + if align_method.lower() == 'flow': + if shift is not None: + flow = np.zeros((*rotated.shape, 2), dtype='float32') + flow[..., :] = shift[..., None, None, :] + shift = None + hi, lo = _find_min_max(np.angle(rotated)) + winsize = max(winsize - 2, 31) + logging.info("Estimate alignment using Farneback.") + fresult = tike.align.solvers.farneback( + op=None, + unaligned=np.angle(padded), + original=np.angle(rotated), + flow=flow if flow is None else -flow, + pyr_scale=0.5, + levels=4, + winsize=winsize, + num_iter=32, + hi=hi, + lo=lo, + ) + flow = -fresult['flow'] + elif align_method.lower() == 'tvl1': + logging.info("Estimate alignment using TV-L1.") + flow = -_optical_flow_tvl1( + unaligned=padded, + original=rotated, + num_iter=cg_iter, + ) + elif align_method.lower() == 'xcor': + logging.info("Estimate rigid alignment with cross correlation.") + sresult = tike.align.reconstruct( + algorithm='cross_correlation', + unaligned=padded, + original=rotated, + upsample_factor=100, + reg_weight=0.0, + ) + shift = -sresult['shift'] + else: + logging.info("Estimate rigid alignment with center of mass.") + centers = _center_of_mass(np.abs(np.angle(rotated)), axis=(-2, -1)) + # shift is defined from padded coords to rotated coords + shift = centers - np.array(rotated.shape[-2:]) / 2 + + Aφ = tike.align.simulate( + phi, + angle=angle, + flow=flow, + shift=shift, + padded_shape=psi.shape, + cval=1.0, + ) + + logger.info("Update alignment lambdas and rhos") + + if λ_p is not None: + λ_p += ρ_p * (psi - Aφ) + + if Aφ0 is not None: + ρ_p = update_penalty(comm, psi, Aφ, Aφ0, ρ_p) + + Aφ0 = Aφ + + return ( + phi, + λ_p, + ρ_p, + flow, + shift, + Aφ0, + cost, + winsize, + ) diff --git a/src/tike/admm/subproblem/lamino.py b/src/tike/admm/subproblem/lamino.py new file mode 100644 index 00000000..b20762d5 --- /dev/null +++ b/src/tike/admm/subproblem/lamino.py @@ -0,0 +1,137 @@ +import logging + +import dxchange +import numpy as np + +import tike.lamino +from . import update_penalty + +logger = logging.getLogger(__name__) + + +def lamino( + # constants + comm, + phi, + theta, + tilt, + # updated + u, + λ_l, + ρ_l, + Hu0=None, + # parameters + num_iter=1, + cg_iter=1, + folder=None, + save_result=False, +): + """Solver the laminography subproblem. + + Parameters + ---------- + phi + Exponentiated projections through the object, u. + u + Refractive indices of the object. + theta + Rotation angle of each projection, phi + tilt + The off-rotation axis angle of laminography + + """ + logger.info('Solve the laminography problem.') + + save_result = False if folder is None else save_result + + # Gather all to one process + λ_l, phi, theta = [comm.gather(x) for x in (λ_l, phi, theta)] + + cost, Hu = None, None + if comm.rank == 0: + if save_result: + # We cannot reorder phi, theta without ruining correspondence + # with data, psi, etc, but we can reorder the saved array + order = np.argsort(theta) + dxchange.write_tiff( + np.angle(phi[order]), + f'{folder}/phi-angle-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + dxchange.write_tiff( + np.abs(phi[order]), + f'{folder}/phi-abs-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + + lresult = tike.lamino.reconstruct( + data=-1j * np.log(phi + λ_l / ρ_l), + theta=theta, + tilt=tilt, + obj=u, + algorithm='cgrad', + num_iter=num_iter, + cg_iter=cg_iter, + # FIXME: Communications overhead makes 1 GPU faster than 8. + num_gpu=1, # comm.size, + upsample=2, + ) + u = lresult['obj'] + cost = lresult['cost'][-1] + + # FIXME: volume becomes too large to fit in MPI buffer. + # Used to broadcast u, now broadcast only Hu + # u = comm.broadcast(u) + Hu = np.exp(1j * tike.lamino.simulate( + obj=u, + tilt=tilt, + theta=theta, + )) + + # Separate again to multiple processes + λ_l, phi, theta, Hu = [comm.scatter(x) for x in (λ_l, phi, theta, Hu)] + + logger.info('Update laminography lambdas and rhos.') + + λ_l += ρ_l * (phi - Hu) + + if Hu0 is not None: + ρ_l = update_penalty(comm, phi, Hu, Hu0, ρ_l) + + Hu0 = Hu + + if comm.rank == 0 and save_result: + dxchange.write_tiff( + u.real, + f'{folder}/object-real-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + dxchange.write_tiff( + u.imag, + f'{folder}/object-imag-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + dxchange.write_tiff( + np.angle(Hu), + f'{folder}/Hu-angle-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + dxchange.write_tiff( + np.abs(Hu), + f'{folder}/Hu-abs-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + + return ( + u, + λ_l, + ρ_l, + Hu0, + cost, + ) diff --git a/src/tike/admm/subproblem/ptycho.py b/src/tike/admm/subproblem/ptycho.py new file mode 100644 index 00000000..0c0cb0b8 --- /dev/null +++ b/src/tike/admm/subproblem/ptycho.py @@ -0,0 +1,75 @@ +import logging + +import dxchange +import numpy as np + +import tike.ptycho + +logger = logging.getLogger(__name__) + + +def ptycho( + comm, + # constants + data, + λ, + ρ, + Aφ, + # updated + presult, + # parameters + num_iter=1, + cg_iter=1, + folder=None, + save_result=False, + rescale=False, + rtol=-1, +): + """Solve the ptychography subsproblem. + + Parameters + ---------- + + """ + logger.info("Solve the ptychography problem.") + + presult = tike.ptycho.reconstruct( + data=data, + reg=None if λ is None else λ / ρ - Aφ, + rho=ρ, + algorithm='cgrad', + num_iter=num_iter, + cg_iter=cg_iter, + recover_psi=True, + recover_probe=True, + recover_positions=False, + model='gaussian', + rescale=rescale, + rtol=rtol, + **presult, + ) + + logger.info("No update for ptychography lambdas and rhos") + + if save_result: + dxchange.write_tiff( + np.abs(presult['psi']), + f'{folder}/{comm.rank}-psi-abs-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + dxchange.write_tiff( + np.angle(presult['psi']), + f'{folder}/{comm.rank}-psi-angle-{save_result:03d}.tiff', + dtype='float32', + overwrite=True, + ) + + Gψ = tike.ptycho.simulate( + detector_shape=data.shape[-1], + probe=presult['probe'], + scan=presult['scan'], + psi=presult['psi'], + ) + + return presult, Gψ diff --git a/src/tike/admm/subproblem/reg.py b/src/tike/admm/subproblem/reg.py new file mode 100644 index 00000000..ab648f4e --- /dev/null +++ b/src/tike/admm/subproblem/reg.py @@ -0,0 +1,50 @@ +import logging + +import dxchange +import numpy as np + +import tike.regularization +from . import update_penalty + +logger = logging.getLogger(__name__) + + +def reg( + # constants + comm, + u, + # updated + omega, + dual, + penalty, + Ju0=None, + # parameters + folder=None, + save_result=False, +): + """Update omega, the regularized object.""" + op = tike.operators.Gradient() + + omega = tike.regularization.soft_threshold( + op, + x=u + dual=dual, + penalty=penalty, + alpha=alpha, + ) + + logger.info('Update regularization lambdas and rhos.') + + dual += penalty * (omega - Ju) + + if Ju0 is not None: + penalty = update_penalty(comm, omega, Ju, Ju0, penalty) + + Ju0 = Ju + + return ( + omega, + dual, + penalty, + Ju0, + ) diff --git a/src/tike/align/align.py b/src/tike/align/align.py index 2e271205..998ed2c1 100644 --- a/src/tike/align/align.py +++ b/src/tike/align/align.py @@ -32,6 +32,22 @@ def simulate( assert unaligned.dtype == 'complex64', unaligned.dtype return operator.asnumpy(unaligned) +def invert( + original, + **kwargs +): # yapf: disable + """Return original shifted by shift.""" + with Alignment() as operator: + for key, value in kwargs.items(): + if not isinstance(value, tuple) and np.ndim(value) > 0: + kwargs[key] = operator.asarray(value) + unaligned = operator.inv( + operator.asarray(original, dtype='complex64'), + **kwargs, + ) + assert unaligned.dtype == 'complex64', unaligned.dtype + return operator.asnumpy(unaligned) + def invert( original, diff --git a/src/tike/align/solvers/__init__.py b/src/tike/align/solvers/__init__.py index dcd40d4d..3ea3ced5 100644 --- a/src/tike/align/solvers/__init__.py +++ b/src/tike/align/solvers/__init__.py @@ -2,8 +2,10 @@ from .cross_correlation import cross_correlation from .farneback import farneback +from .cgrad import cgrad __all__ = [ "cross_correlation", "farneback", + "cgrad", ] diff --git a/src/tike/align/solvers/cgrad.py b/src/tike/align/solvers/cgrad.py new file mode 100644 index 00000000..d5db65ae --- /dev/null +++ b/src/tike/align/solvers/cgrad.py @@ -0,0 +1,54 @@ +import logging + +from tike.opt import conjugate_gradient + +logger = logging.getLogger(__name__) + + +def cgrad( + op, + original, + unaligned, + reg, rho_p, rho_a, + num_iter=4, + cost=None, + **kwargs +): # yapf: disable + """Recover an undistorted image from a given flow.""" + + def cost_function(original): + # yapf: disable + return ( + rho_p * op.xp.linalg.norm(op.xp.ravel( + unaligned - op.fwd( + original, + padded_shape=unaligned.shape, + **kwargs, + )))**2 + + rho_a * op.xp.linalg.norm(op.xp.ravel( + original - reg + ))**2 + ) + # yapf: enable + + def grad(original): + return (rho_p * op.adj( + op.fwd( + original, + padded_shape=unaligned.shape, + **kwargs, + ) - unaligned, + unpadded_shape=original.shape, + **kwargs, + ) + rho_a * (original - reg)) + + original, cost = conjugate_gradient( + op.xp, + x=original, + cost_function=cost_function, + grad=grad, + num_iter=num_iter, + ) + + logger.info('%10s cost is %+12.5e', 'original', cost) + return {'original': original, 'cost': cost} diff --git a/src/tike/align/solvers/cross_correlation.py b/src/tike/align/solvers/cross_correlation.py index b2bd28ed..c2cbd9d0 100644 --- a/src/tike/align/solvers/cross_correlation.py +++ b/src/tike/align/solvers/cross_correlation.py @@ -36,7 +36,7 @@ def cross_correlation( upsample_factor=1, space="real", num_iter=None, - reg_weight=1e-9, + reg_weight=0, ): """Efficient subpixel image translation alignment by cross-correlation. @@ -46,6 +46,13 @@ def cross_correlation( then refines the shift estimation by upsampling the DFT only in a small neighborhood of that estimate by means of a matrix-multiply DFT. + Parameters + ---------- + reg_weight: float [0, 1] + Determines how strongly the cross-correlation overlap matters. If C(x) is + the cross correlation function and A(x) is the overlap function, then + we choose the best alignment where (1 - reg)C + (reg)CA is a maximum. + References ---------- Stéfan van der Walt, Johannes L. Schönberger, Juan Nunez-Iglesias, @@ -81,11 +88,11 @@ def cross_correlation( # the cross_correlation is the same for multiple shifts. if reg_weight > 0: w = _area_overlap(op, cross_correlation) - w = op.xp.fft.fftshift(w) * reg_weight + w = reg_weight * op.xp.fft.fftshift(w) + (1 - reg_weight) else: - w = 0 + w = 1 - A = np.abs(cross_correlation) + w + A = np.abs(cross_correlation) * w maxima = A.reshape(A.shape[0], -1).argmax(1) maxima = np.column_stack(np.unravel_index(maxima, A[0, :, :].shape)) shifts = op.xp.array(maxima, dtype='float32') @@ -121,7 +128,7 @@ def cross_correlation( maxima = np.column_stack(np.unravel_index(maxima, A[0, :, :].shape)) maxima = maxima - dftshift shifts = shifts + maxima / upsample_factor - return {'shift': shifts.astype('float32'), 'cost': -1} + return {'shift': shifts.astype('float32', copy=False), 'cost': -1} def _upsampled_dft(op, data, ups, upsample_factor, axis_offsets): diff --git a/src/broken/communicator.py b/src/tike/communicator.py similarity index 82% rename from src/broken/communicator.py rename to src/tike/communicator.py index 7973cce4..d56df26b 100644 --- a/src/broken/communicator.py +++ b/src/tike/communicator.py @@ -29,32 +29,21 @@ def __init__(self): self.size = self.comm.Get_size() logger.info("Node {:,d} is running.".format(self.rank)) - def scatter(self, *args): + def scatter(self, arg, root=0): """Send and recieve constant data that must be divided.""" - if len(args) == 1: - arg = args[0] - if self.rank == 0: - chunks = np.array_split(arg, self.size) - else: - chunks = None - return self.comm.scatter(chunks, root=0) - out = list() - for arg in args: - if self.rank == 0: - chunks = np.array_split(arg, self.size) - else: - chunks = None - out.append(self.comm.scatter(chunks, root=0)) - return out + if self.rank == root: + chunks = np.array_split(arg, self.size) + else: + chunks = None + chunk = self.comm.scatter(chunks, root=root) + # logger.info(f"Scatter from node {root} to node {self.rank}.") + return chunk - def broadcast(self, *args): + def broadcast(self, arg, root=0): """Synchronize parameters that are the same for all processses.""" - if len(args) == 1: - return self.comm.bcast(args[0], root=0) - out = list() - for arg in args: - out.append(self.comm.bcast(arg, root=0)) - return out + copy = self.comm.bcast(arg, root=root) + # logger.info(f"Broadcast from node {root} to node {self.rank}.") + return copy def get_ptycho_slice(self, tomo_slice): """Switch to slicing for the pytchography problem.""" @@ -81,6 +70,8 @@ def get_tomo_slice(self, ptych_slice): def gather(self, arg, root=0, axis=0): """Gather arg to one node.""" arg = self.comm.gather(arg, root=root) + # logger.info( + # f"Gather from node {self.rank} to node {root} along axis {axis}.") if self.rank == root: return np.concatenate(arg, axis=axis) return None diff --git a/src/tike/lamino/lamino.py b/src/tike/lamino/lamino.py index 3d645908..4b80bc21 100644 --- a/src/tike/lamino/lamino.py +++ b/src/tike/lamino/lamino.py @@ -104,7 +104,17 @@ def reconstruct( rtol : float Terminate early if the relative decrease of the cost function is less than this amount. - + tilt : float32 [radians] + The tilt angle; the angle between the rotation axis of the object and + the light source. π / 2 for conventional tomography. 0 for a beam path + along the rotation axis. + obj : (nz, n, n) complex64 + The complex refractive index of the object. nz is the axis + corresponding to the rotation axis. + data : (ntheta, n, n) complex64 + The complex projection data of the object. + theta : array-like float32 [radians] + The projection angles; rotation around the vertical axis of the object. """ n = data.shape[2] obj = np.zeros([n, n, n], dtype='complex64') if obj is None else obj @@ -117,14 +127,14 @@ def reconstruct( **kwargs, ) as operator, Comm(num_gpu, mpi=None) as comm: # send any array-likes to device - data = np.array_split(data.astype('complex64'), + data = np.array_split(data.astype('complex64', copy=False), comm.pool.num_workers) data = comm.pool.scatter(data) - theta = np.array_split(theta.astype('float32'), + theta = np.array_split(theta.astype('float32', copy=False), comm.pool.num_workers) theta = comm.pool.scatter(theta) result = { - 'obj': comm.pool.bcast(obj.astype('complex64')), + 'obj': comm.pool.bcast(obj.astype('complex64', copy=False)), } for key, value in kwargs.items(): if np.ndim(value) > 0: diff --git a/src/tike/lamino/solvers/cgrad.py b/src/tike/lamino/solvers/cgrad.py index d5c9648f..c08deebf 100644 --- a/src/tike/lamino/solvers/cgrad.py +++ b/src/tike/lamino/solvers/cgrad.py @@ -6,22 +6,98 @@ logger = logging.getLogger(__name__) -def _estimate_step_length(obj, theta, op): +def _estimate_step_length(obj, theta, K, op): """Use norm of forward adjoint operations to estimate step length. Scaling the adjoint operation by |F*Fm| / |m| puts the step length in the proper order of magnitude. """ - logger.info('Estimate step length from forward adjoint operations.') - outnback = op.adj( - data=op.fwd(u=obj, theta=theta), - theta=theta, - overwrite=False, - ) - scaler = tike.linalg.norm(outnback) / tike.linalg.norm(obj) + logger.info('Estimate lamino step length from forward adjoint operations.') + if K is None: + outnback = op.adj( + data=op.fwd(u=obj, theta=theta), + theta=theta, + overwrite=False, + ) + else: + outnback = op.adj( + data=op.xp.conj(K) * K * op.fwd(u=obj, theta=theta), + theta=theta, + overwrite=False, + ) + scaler = tike.linalg.norm(obj) / tike.linalg.norm(outnback) # Multiply by 2 to because we prefer over-estimating the step - return 2 * scaler if op.xp.isfinite(scaler) else 1.0 + if op.xp.isfinite(scaler): + return 2 * scaler + else: + logger.warning('Lamino step length estimate is non-finite.') + return 1.0 + + +def _cost_tv( + data, + theta, + obj, + reg=0, + K=None, + penalty0=1, + penalty1=0, + op=None, + op1=None, +): + """Cost function for the regularized laminography problem. + + The cost function F(u) is a two term function as follows: + + F(u) = penalty0 * norm(K * R(u) − data)**2 + + penalty1 * norm( J(u) − reg)**2 + + where + + K = 1j / v * 2π * (ψ − dual0 / penalty0) + data = (ψ − dual0 / penalty0) * log(ψ − dual0/penalty0) + reg = ω − dual1 / penalty1 + + where ψ is the projection through the object and ω is J(obj) and v is the + wavenumber. + + """ + K = 1 if K is None else K + cost = penalty0 * tike.linalg.norm(K * op.fwd( + u=obj, + theta=theta, + ) - data)**2 + if penalty1 > 0: + cost += penalty1 * tike.linalg.norm(op1.fwd(obj) - reg)**2 + return cost + + +def _grad_tv( + data, + theta, + obj, + reg=0, + K=None, + penalty0=1, + penalty1=0, + op=None, + op1=None, +): + """Gradient for the regularized laminography problem. + + ∇F(u) = penalty0 * R_adj(conj(K) * (K * R(u) − data)) + + penalty1 * J_adj( J(u) − reg) + + """ + if K is None: + d = op.fwd(u=obj, theta=theta) - data + else: + d = op.xp.conj(K) * (K * op.fwd(u=obj, theta=theta) - data) + grad = penalty0 * op.adj(data=d, theta=theta) + if penalty1 > 0: + grad += penalty1 * op1.adj(op1.fwd(obj) - reg) + return grad def cgrad( @@ -30,15 +106,19 @@ def cgrad( data, theta, obj, cg_iter=4, step_length=1, + K=None, **kwargs ): # yapf: disable """Solve the Laminogarphy problem using the conjugate gradients method.""" + K = [None] * comm.pool.num_workers if K is None else K + step_length = comm.pool.reduce_cpu( comm.pool.map( _estimate_step_length, obj, theta, + K, op=op, )) / comm.pool.num_workers if step_length == 1 else step_length @@ -48,6 +128,7 @@ def cgrad( data, theta, obj, + K, num_iter=cg_iter, step_length=step_length, ) @@ -55,18 +136,18 @@ def cgrad( return {'obj': obj, 'cost': cost, 'step_length': step_length} -def update_obj(op, comm, data, theta, obj, num_iter=1, step_length=1): +def update_obj(op, comm, data, theta, obj, K, num_iter=1, step_length=1): """Solver the object recovery problem.""" def cost_function(obj): - cost_out = comm.pool.map(op.cost, data, theta, obj) + cost_out = comm.pool.map(_cost_tv, data, theta, obj, K, op=op) if comm.use_mpi: return comm.Allreduce_reduce(cost_out, 'cpu') else: return comm.reduce(cost_out, 'cpu') def grad(obj): - grad_list = comm.pool.map(op.grad, data, theta, obj) + grad_list = comm.pool.map(_grad_tv, data, theta, obj, K, op=op) if comm.use_mpi: return comm.Allreduce_reduce(grad_list, 'gpu') else: diff --git a/src/tike/linalg.py b/src/tike/linalg.py index b7d5e9c2..69c82722 100644 --- a/src/tike/linalg.py +++ b/src/tike/linalg.py @@ -7,6 +7,11 @@ import numpy as np +def norm1(x, axis=None, keepdims=None): + """Return the vector 1-norm of x along given axis.""" + return np.sum(np.abs(x), axis=axis, keepdims=keepdims) + + def norm(x, axis=None, keepdims=None): """Return the vector 2-norm of x along given axis.""" return np.sqrt(np.sum((x * x.conj()).real, axis=axis, keepdims=keepdims)) diff --git a/src/tike/operators/cupy/__init__.py b/src/tike/operators/cupy/__init__.py index 572188b4..345ec022 100644 --- a/src/tike/operators/cupy/__init__.py +++ b/src/tike/operators/cupy/__init__.py @@ -8,6 +8,7 @@ from .alignment import * from .convolution import * from .flow import * +from .gradient import * from .lamino import * from .operator import * from .pad import * @@ -22,6 +23,7 @@ 'Alignment', 'Bucket', 'Convolution', + 'Gradient', 'Flow', 'Lamino', 'Operator', diff --git a/src/tike/operators/cupy/alignment.py b/src/tike/operators/cupy/alignment.py index f843c1a1..7542a445 100644 --- a/src/tike/operators/cupy/alignment.py +++ b/src/tike/operators/cupy/alignment.py @@ -51,23 +51,30 @@ def fwd( unpadded_shape=None, cval=0.0, ): - return self.rotate.fwd( - unrotated=self.flow.fwd( - f=self.shift.fwd( - a=self.pad.fwd( - unpadded=unpadded, - padded_shape=padded_shape, - cval=cval, - ), - shift=shift, - cval=cval, - ), - flow=flow, + unflowed = self.shift.fwd( + a=self.pad.fwd( + unpadded=unpadded, + padded_shape=padded_shape, cval=cval, ), - angle=angle, + shift=shift, cval=cval, ) + if flow is None: + return self.rotate.fwd( + unrotated=unflowed, + angle=angle, + cval=cval, + ) + else: + if angle is not None: + flow = flow + self.rotate._make_flow(unrotated=unflowed, + angle=angle) + return self.flow.fwd( + f=unflowed, + flow=flow, + cval=cval, + ) def adj( self, @@ -79,17 +86,24 @@ def adj( padded_shape=None, cval=0.0, ): + if flow is None: + unflowed = self.rotate.adj( + rotated=rotated, + angle=angle, + cval=cval, + ) + else: + if angle is not None: + flow = flow + self.rotate._make_flow(unrotated=rotated, + angle=angle) + unflowed = self.flow.adj( + g=rotated, + flow=flow, + cval=cval, + ) return self.pad.adj( padded=self.shift.adj( - a=self.flow.adj( - g=self.rotate.adj( - rotated=rotated, - angle=angle, - cval=cval, - ), - flow=flow, - cval=cval, - ), + a=unflowed, shift=shift, cval=cval, ), diff --git a/src/tike/operators/cupy/gradient.py b/src/tike/operators/cupy/gradient.py new file mode 100644 index 00000000..eeac24f6 --- /dev/null +++ b/src/tike/operators/cupy/gradient.py @@ -0,0 +1,33 @@ +__author__ = "Viktor Nikitin, Daniel Ching" +__copyright__ = "Copyright (c) 2021, UChicago Argonne, LLC." +__docformat__ = 'restructuredtext en' + +from .operator import Operator + + +class Gradient(Operator): + """Returns the Gradient approximation of a 3D array.""" + + def fwd(self, u): + """Forward operator for regularization.""" + res = self.xp.empty((3, *u.shape), dtype=u.dtype) + res[0, :, :, :-1] = u[:, :, 1:] - u[:, :, :-1] + res[0, :, :, -1] = u[:, :, 0] - u[:, :, -1] + res[1, :, :-1, :] = u[:, 1:, :] - u[:, :-1, :] + res[1, :, -1, :] = u[:, 0, :] - u[:, -1, :] + res[2, :-1, :, :] = u[1:, :, :] - u[:-1, :, :] + res[2, -1, :, :] = u[ 0, :, :] - u[ -1, :, :] + res *= 1 / self.xp.sqrt(3) # normalization + return res + + def adj(self, g): + """Adjoint operator for regularization.""" + res = self.xp.empty(g.shape[1:], g.dtype) + res[:, :, 1:] = g[0, :, :, 1:] - g[0, :, :, :-1] + res[:, :, 0] = g[0, :, :, 0] - g[0, :, :, -1] + res[:, 1:, :] += g[1, :, 1:, :] - g[1, :, :-1, :] + res[:, 0, :] += g[1, :, 0, :] - g[1, :, -1, :] + res[1:, :, :] += g[2, 1:, :, :] - g[2, :-1, :, :] + res[0, :, :] += g[2, 0, :, :] - g[2, -1, :, :] + res *= -1 / self.xp.sqrt(3) # normalization + return res diff --git a/src/tike/operators/cupy/lamino.py b/src/tike/operators/cupy/lamino.py index 31d23318..c2878517 100644 --- a/src/tike/operators/cupy/lamino.py +++ b/src/tike/operators/cupy/lamino.py @@ -26,7 +26,7 @@ class Lamino(CachedFFT, Operator): ---------- n : int The pixel width of the cubic reconstructed grid. - tilt : float32 + tilt : float32 [radians] The tilt angle; the angle between the rotation axis of the object and the light source. π / 2 for conventional tomography. 0 for a beam path along the rotation axis. @@ -38,7 +38,7 @@ class Lamino(CachedFFT, Operator): corresponding to the rotation axis. data : (ntheta, n, n) complex64 The complex projection data of the object. - theta : array-like float32 + theta : array-like float32 [radians] The projection angles; rotation around the vertical axis of the object. """ diff --git a/src/tike/operators/cupy/pad.py b/src/tike/operators/cupy/pad.py index defcaa7e..250e226a 100644 --- a/src/tike/operators/cupy/pad.py +++ b/src/tike/operators/cupy/pad.py @@ -31,6 +31,9 @@ def fwd(self, unpadded, corner=None, padded_shape=None, cval=0.0, **kwargs): """ if padded_shape is None: padded_shape = unpadded.shape + elif (padded_shape[-1] < unpadded.shape[-1] + or padded_shape[-2] < unpadded.shape[-2]): + raise ValueError("Padded shape must be larger than unpadded.") if corner is None: corner = self.xp.tile( (((padded_shape[-2] - unpadded.shape[-2]) // 2, @@ -64,6 +67,9 @@ def adj(self, padded, corner=None, unpadded_shape=None, **kwargs): """ if unpadded_shape is None: unpadded_shape = padded.shape + elif (padded.shape[-1] < unpadded_shape[-1] + or padded.shape[-2] < unpadded_shape[-2]): + raise ValueError("Padded shape must be larger than unpadded.") if corner is None: corner = self.xp.tile( (((padded.shape[-2] - unpadded_shape[-2]) // 2, diff --git a/src/tike/operators/cupy/rotate.py b/src/tike/operators/cupy/rotate.py index e214657c..2d469f8c 100644 --- a/src/tike/operators/cupy/rotate.py +++ b/src/tike/operators/cupy/rotate.py @@ -20,6 +20,20 @@ class Rotate(Operator): original image. """ + def _make_flow(self, unrotated, angle): + """Return a flow that performs the rotation.""" + cos, sin = np.cos(angle), np.sin(angle) + shifti = (unrotated.shape[-2] - 1) / 2.0 + shiftj = (unrotated.shape[-1] - 1) / 2.0 + + i, j = self.xp.mgrid[0:unrotated.shape[-2], + 0:unrotated.shape[-1]].astype('float32') + + di = i - ((+cos * (i - shifti) + sin * (j - shiftj)) + shifti) + dj = j - ((-sin * (i - shifti) + cos * (j - shiftj)) + shiftj) + + return self.xp.stack([di, dj], axis=-1) + def _make_grid(self, unrotated, angle): """Return the points on the rotated grid.""" cos, sin = np.cos(angle), np.sin(angle) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 61ee16b8..9c43528a 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -168,6 +168,7 @@ def reconstruct( psi=None, num_gpu=1, num_iter=1, rtol=-1, model='gaussian', use_mpi=False, cost=None, times=None, eigen_probe=None, eigen_weights=None, + rescale=True, batch_size=None, initial_scan=None, position_options=None, @@ -243,11 +244,11 @@ def reconstruct( ) result = { 'psi': - comm.pool.bcast(psi.astype('complex64')), + comm.pool.bcast(psi.astype('complex64', copy=False)), 'probe': - comm.pool.bcast(probe.astype('complex64')), + comm.pool.bcast(probe.astype('complex64', copy=False)), 'eigen_probe': - comm.pool.bcast(eigen_probe.astype('complex64')) + comm.pool.bcast(eigen_probe.astype('complex64', copy=False)) if eigen_probe is not None else None, 'scan': scan, @@ -269,15 +270,16 @@ def reconstruct( if initial_scan[0] is None: initial_scan = comm.pool.map(cp.copy, scan) - result['probe'] = _rescale_obj_probe( - operator, - comm, - data, - result['psi'], - scan, - result['probe'], - num_batch=num_batch, - ) + if rescale: + result['probe'] = _rescale_obj_probe( + operator, + comm, + data, + result['psi'], + scan, + result['probe'], + num_batch=num_batch, + ) costs = [] times = [] @@ -347,7 +349,7 @@ def reconstruct( if isinstance(v, list): result[k] = v[0] return { - k: operator.asnumpy(v) if isinstance(v, cp.ndarray) else v + k: v if np.ndim(v) < 1 else operator.asnumpy(v) if isinstance(v, cp.ndarray) else v for k, v in result.items() } else: @@ -388,9 +390,9 @@ def _get_rescale(data, psi, scan, probe, num_batch, operator): rescale = n1 / n2 - logger.info("object and probe rescaled by %f", rescale) - - probe[0] *= rescale + if abs(1 - rescale) > 0.01: + logger.info("object and probe rescaled by %f", rescale) + probe[0] *= rescale return comm.pool.bcast(probe[0]) diff --git a/src/tike/ptycho/solvers/combined.py b/src/tike/ptycho/solvers/combined.py index d2a605d4..38969742 100644 --- a/src/tike/ptycho/solvers/combined.py +++ b/src/tike/ptycho/solvers/combined.py @@ -10,9 +10,16 @@ def cgrad( - op, comm, - data, probe, scan, psi, - recover_probe=True, recover_positions=False, + op, + comm, + data, + probe, + scan, + psi, + rho=None, + reg=None, + recover_probe=True, + recover_positions=False, cg_iter=4, cost=None, eigen_probe=None, @@ -22,7 +29,7 @@ def cgrad( step_length=1, probe_is_orthogonal=False, object_options=None, -): # yapf: disable +): """Solve the ptychography problem using conjugate gradient. Parameters @@ -55,6 +62,8 @@ def cgrad( psi, bscan, probe, + rho, + reg, num_iter=cg_iter, step_length=step_length, ) @@ -141,22 +150,30 @@ def f(x, d): return probe, cost -def _update_object(op, comm, data, psi, scan, probe, num_iter, step_length): +def _update_object(op, comm, data, psi, scan, probe, rho, reg, num_iter, + step_length): """Solve the object recovery problem.""" def cost_function_multi(psi, **kwargs): cost_out = comm.pool.map(op.cost, data, psi, scan, probe) if comm.use_mpi: - return comm.Allreduce_reduce(cost_out, 'cpu') + result = comm.Allreduce_reduce(cost_out, 'cpu') else: - return comm.reduce(cost_out, 'cpu') + result = comm.reduce(cost_out, 'cpu') + if reg is not None: + result += op.asnumpy(rho * op.xp.linalg.norm( + (psi[0] + reg[0]).ravel())**2) + return result def grad_multi(psi): grad_list = comm.pool.map(op.grad_psi, data, psi, scan, probe) if comm.use_mpi: - return comm.Allreduce_reduce(grad_list, 'gpu') + result = comm.Allreduce_reduce(grad_list, 'gpu') else: - return comm.reduce(grad_list, 'gpu') + result = comm.reduce(grad_list, 'gpu') + if reg is not None: + result += rho * (psi[0] + reg[0]) + return result def dir_multi(dir): """Scatter dir to all GPUs""" diff --git a/src/tike/regularization.py b/src/tike/regularization.py new file mode 100644 index 00000000..909bf85f --- /dev/null +++ b/src/tike/regularization.py @@ -0,0 +1,54 @@ +__author__ = "Viktor Nikitin, Daniel Ching" +__copyright__ = "Copyright (c) 2021, UChicago Argonne, LLC." +__docformat__ = 'restructuredtext en' + +import tike.linalg + + +def cost(op, x, dual, penalty, alpha): + """Minimization functional for regularization problem. + + Parameters + ---------- + op : operators.Gradient + The gradient operator. + x : (L, M, N) array-like + The object being regularized + dual : float + ADMM dual variable. + penalty : float + ADMM penalty parameter. + alpha : float + Some tuning parameter. + """ + grad = op.fwd(x) + cost = alpha * tike.linalg.norm1(grad) + cost += penalty * tike.linalg.norm(grad - reg + dual / penalty)**2 + return cost + + +def soft_threshold(op, x, dual, penalty, alpha): + """Soft thresholding operator for solving something. + + Parameters + ---------- + op : operators.Gradient + The gradient operator. + x : (L, M, N) array-like + The object being regularized + dual : float + ADMM dual variable. + penalty : float + ADMM penalty parameter. + alpha : float + Some tuning parameter. + + Returns + ------- + x1 : (L, M, N) array-like + The updated x. + + """ + z = op.fwd(x) + dual / penalty + za = op.xp.abs(z) + return z / za * op.xp.maximum(0, za - alpha / penalty) diff --git a/tests/operators/test_gradient.py b/tests/operators/test_gradient.py new file mode 100644 index 00000000..b090a760 --- /dev/null +++ b/tests/operators/test_gradient.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import unittest + +import numpy as np +from tike.operators import Gradient +import tike.random + +from .util import random_complex, OperatorTests + +__author__ = "Daniel Ching" +__copyright__ = "Copyright (c) 2021, UChicago Argonne, LLC." +__docformat__ = 'restructuredtext en' + + +class TestGradient(unittest.TestCase, OperatorTests): + """Test the Gradient operator.""" + + def setUp(self, shape=(8, 19, 5)): + """Load a dataset for reconstruction.""" + + self.operator = Gradient() + self.operator.__enter__() + self.xp = self.operator.xp + + np.random.seed(0) + self.m = tike.random.cupy_complex(*shape) + self.m_name = 'u' + self.d = tike.random.cupy_complex(3, *shape) + self.d_name = 'g' + self.kwargs = { } + print(self.operator) + + @unittest.skip('FIXME: This operator is not scaled.') + def test_scaled(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/operators/test_rotate.py b/tests/operators/test_rotate.py index f922f4fb..64fc5bb9 100644 --- a/tests/operators/test_rotate.py +++ b/tests/operators/test_rotate.py @@ -4,7 +4,7 @@ import unittest import numpy as np -from tike.operators import Rotate +from tike.operators import Rotate, Flow from .util import random_complex, OperatorTests @@ -33,19 +33,24 @@ def setUp(self, shape=(7, 25, 53)): } print(self.operator) - def debug_show(self): + def debug_show(self, angle=19 * np.pi / 6): import libimage import matplotlib.pyplot as plt x = self.xp.asarray(libimage.load('coins', 256), dtype='complex64') - y = self.operator.fwd(x[None], 4 * np.pi) + y = self.operator.fwd(x[None], angle) + flow = self.operator._make_flow(unrotated=x, angle=angle) + y1 = Flow().fwd(x, flow) - print(x.shape, y.shape) + print(x.shape, y.shape, y1.shape) plt.figure() plt.imshow(x.real.get()) plt.figure() plt.imshow(y[0].real.get()) + + plt.figure() + plt.imshow(y1.real.get()) plt.show() @unittest.skip('FIXME: This operator is not scaled.')