Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions examples/python/LBFGSBPC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#
# SPDX-License-Identifier: Apache-2.0
#
# Class implementing the LBFGSB-PC algorithm in stir
#
# Authors: Kris Thielemans
#
# Based on Georg Schramm's
# https://github.com/SyneRBI/PETRIC-MaGeZ/blob/a690205b2e3ec874e621ed2a32a802cd0bed4c1d/simulation_src/sim_stochastic_grad.py
# but with using diag(H.1) as preconditioner at the moment, as per Tsai's paper (see ref in the class doc)
#
# Copyright 2025 University College London

import numpy as np
import numpy.typing as npt
import stir
from scipy.optimize import fmin_l_bfgs_b
from typing import Callable, Optional, List

# import matplotlib.pyplot as plt


class LBFGSBPC:
"""Implementation of the LBFGSB-PC Algorithm

See
Tsai et al,
Fast Quasi-Newton Algorithms for Penalized Reconstruction in Emission Tomography and Further Improvements via Preconditioning
IEEE TRANSACTIONS ON MEDICAL IMAGING, VOL. 37, NO. 4, APRIL 2018
DOI: 10.1109/TMI.2017.2786865

WARNING: it maximises the objective function (as required by sirf.STIR.ObjectiveFunction).
WARNING: the implementation uses asarray(), which means you need SIRF 3.9. You should be able to just replace it with as_array() otherwise.

This implementation is NOT a CIL.Algorithm, but it behaves somewhat as one.
"""

def __init__(
self,
objfun: stir.GeneralisedObjectiveFunction3DFloat,
initial: stir.FloatVoxelsOnCartesianGrid,
update_objective_interval: int = 0,
):
self.trunc_filter = stir.TruncateToCylindricalFOVImageProcessor3DFloat()
self.objfun = objfun
self.initial = initial.clone()
self.trunc_filter.apply(self.initial)
self.shape = initial.shape()
self.output = None
self.update_objective_interval = update_objective_interval

precon = initial.get_empty_copy()
objfun.accumulate_Hessian_times_input(precon, initial, initial * 0 + 1)
precon *= -1
# self.Dinv_STIR = precon.maximum(1).power(-0.5)
self.Dinv = np.power(np.maximum(precon.as_array(), 1), -0.5)
self.Dinv_STIR = precon
self.Dinv_STIR.fill(self.Dinv)
self.trunc_filter.apply(self.Dinv_STIR)
# plt.figure()
# plt.imshow(self.Dinv_STIR.as_array()[self.shape[0] // 2, :, :])
self.Dinv = self.Dinv_STIR.as_array().ravel()
self.tmp_for_value = initial.get_empty_copy()
self.tmp1_for_gradient = initial.get_empty_copy()
self.tmp2_for_gradient = initial.get_empty_copy()

def precond_objfun_value(self, z: npt.ArrayLike) -> float:
self.tmp_for_value.fill(
np.reshape(z.astype(np.float32) * self.Dinv, self.shape)
)
return -self.objfun.compute_value(self.tmp_for_value)

def precond_objfun_gradient(self, z: npt.ArrayLike) -> np.ndarray:
self.tmp1_for_gradient.fill(
np.reshape(z.astype(np.float32) * self.Dinv, self.shape)
)
self.objfun.compute_gradient(self.tmp2_for_gradient, self.tmp1_for_gradient)
return self.tmp2_for_gradient.as_array().ravel() * self.Dinv * -1

def callback(self, x):
if (
self.update_objective_interval > 0
and self.iter % self.update_objective_interval == 0
):
self.loss.append(-self.precond_objfun_value(x))
self.iterations.append(self.iter)
self.iter += 1

def process(
self, iterations=None, callbacks: Optional[List[Callable]] = None, verbose=0
) -> None:
r"""run upto :code:`iterations` with callbacks.

Parameters
-----------
iterations: int, default is None
Number of iterations to run.
callbacks: list of callables, default is Defaults to self.callback
List of callables which are passed the current Algorithm object each iteration. Defaults to :code:`[ProgressCallback(verbose)]`.
verbose: 0=quiet, 1=info, 2=debug
Passed to the default callback to determine the verbosity of the printed output.
"""
if iterations is None:
raise ValueError("`missing argument `iterations`")
precond_init = self.initial / self.Dinv_STIR
self.trunc_filter.apply(precond_init)
precond_init = precond_init.as_array().ravel()
bounds = precond_init.size * [(0, None)]
self.iter = 0
self.loss = []
self.iterations = []
# TODO not really required, but it differs from the first value reported by fmin_l_bfgs_b. Not sure why...
self.callback(precond_init)
self.iter = 0 # set back again
res = fmin_l_bfgs_b(
self.precond_objfun_value,
precond_init,
self.precond_objfun_gradient,
maxiter=iterations,
bounds=bounds,
m=20,
callback=self.callback,
factr=0,
pgtol=0,
)
# store result (use name "x" for CIL compatibility)
self.x = self.tmp_for_value.get_empty_copy()
self.x.fill(np.reshape(res[0].astype(np.float32) * self.Dinv, self.shape))

def run(
self, **kwargs
) -> None: # CIL alias, would need to callback and verbose keywords etc
self.process(**kwargs)

def get_output(self) -> stir.FloatVoxelsOnCartesianGrid:
return self.x
82 changes: 82 additions & 0 deletions examples/python/recon_demo-LBFGSBPC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Demo of how to use STIR from python to reconstruct some data
# To run in "normal" Python, you would type the following in the command line
# execfile('recon_demo.py')
# In ipython, you can use
# %run recon_demo.py

# Copyright 2012-06-05 - 2013 Kris Thielemans
# Copyright 2015 University College London

# This file is part of STIR.
#
# SPDX-License-Identifier: Apache-2.0
#
# See STIR/LICENSE.txt for details

import stir
import stirextra
import matplotlib.pyplot as plt
import os
from LBFGSBPC import LBFGSBPC

stir.Verbosity.set(0)

# Switch 'interactive' mode on for plt.
# Without it, the python shell will wait after every plt.show() for you
# to close the window.
try:
plt.ion()
except:
print("Enabling interactive-mode for plotting failed. Continuing.")

# go to directory with input files
os.chdir("../recon_demo")

# initialise reconstruction object
# we will do this here via a .par file
OSEM_recon = stir.OSMAPOSLReconstruction3DFloat("recon_demo_OSEM.par")
# set filenames to save subset sensitivities (for illustration purposes)
poissonobj = OSEM_recon.get_objective_function()

# %% run initial OSEM

# get initial image
OSEM_target = stir.FloatVoxelsOnCartesianGrid.read_from_file("init.hv")
# we will just fill the whole array with 1 here
OSEM_target.fill(1)

s = OSEM_recon.set_up(OSEM_target)
if not s.succeeded():
raise RuntimeError("set-up failed")

OSEM_recon.reconstruct(OSEM_target)

# %% add prior/penalty and remove subsets

poissonobj.set_num_subsets(1)
penalty = stir.GibbsRelativeDifferencePenalty3DFloat()
penalty.set_penalisation_factor(1)
poissonobj.set_prior_sptr(penalty)

s = poissonobj.set_up(OSEM_target)

# %% Run reconstruction
recon2 = LBFGSBPC(poissonobj, initial=OSEM_target, update_objective_interval=2)
recon2.process(iterations=15)


# %% make some plots
npimage = recon2.get_output().as_array()
plt.figure()
plt.plot(OSEM_target.as_array()[10, 30, :], label="OSEM")
plt.plot(npimage[10, 30, :], label="LBFGSBPC")
plt.legend()

plt.figure()
plt.imshow(npimage[10, :, :])

plt.figure()
plt.plot(recon2.iterations, recon2.loss)

# %% Keep figures open until user closes them
plt.show(block=True)
Loading