Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions src/everystamp/everystamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from astroquery.skyview import SkyView # type: ignore
from everystamp.cutters import make_cutout_2D, make_cutout_2D_fast, make_cutout_region
from everystamp.tonemapping import lhdr, normalise
from everystamp.tonemapping.stretches import TimmermanStretch

logging.basicConfig(
format="[%(name)s] %(asctime)s - %(levelname)s: %(message)s", level=logging.INFO
Expand Down Expand Up @@ -271,8 +272,8 @@ def _add_args_plot(parser):
default=None,
type=str,
required=False,
choices=["log", "sqrt", "squared", "asinh", "sinh"],
help="Stretch an image with a certian function.",
choices=["log", "sqrt", "squared", "asinh", "sinh", "timmerman"],
help="Stretch an image with a certain function.",
)
required_args.add_argument(
"--cmap",
Expand Down Expand Up @@ -830,6 +831,7 @@ def _process_args_download(args):
if args.lotss_release == "dr3":
print("LoTSS DR3 is not yet publicly available. Please authenticate.")
import getpass

user = getpass.getpass("username:")
password = getpass.getpass("password:")
for ra, dec in zip(ras, decs):
Expand Down Expand Up @@ -879,6 +881,7 @@ def _process_args_download(args):
if args.mode == "both":
raise ValueError("FIRST download does not support `both` (yet).")
from everystamp.downloaders import FIRSTDownloader

vd = FIRSTDownloader()
vd.download(
ra=ra,
Expand Down Expand Up @@ -1017,6 +1020,8 @@ def _process_args_plot(args):
bp = BasicFITSPlot(args.image)
elif args.style == "srtplot":
bp = SRTPlot(args.image)
else:
raise ValueError("Unkonwn plot style specified.")
else:
# Probably an image format.
bp = BasicImagePlot(args.image, wcsimage=args.wcs_image)
Expand Down Expand Up @@ -1168,7 +1173,11 @@ def _process_args_plot(args):
stretch = astropy.visualization.AsinhStretch()
elif args.stretch == "sinh":
stretch = astropy.visualization.SinhStretch()
bp.data = stretch(normalise(bp.data), min)
elif args.stretch == "timmerman":
stretch = TimmermanStretch()
else:
raise ValueError("Unknown stretch passed")
bp.data = stretch(normalise(bp.data))

if args.contour_image and (args.style == "normal"):
bp.plot2D(
Expand All @@ -1184,9 +1193,16 @@ def _process_args_plot(args):
if args.image.lower().endswith("fits") and (args.style == "normal"):
bp.savedata(args.image.replace(".fits", ".tonemapped.fits"))
if args.contour_image:
bp.plot_noaxes(cmap=args.cmap, cmap_min=args.cmap_min, cmap_max=args.cmap_max, contour_image=args.contour_image)
bp.plot_noaxes(
cmap=args.cmap,
cmap_min=args.cmap_min,
cmap_max=args.cmap_max,
contour_image=args.contour_image,
)
else:
bp.plot_noaxes(cmap=args.cmap, cmap_min=args.cmap_min, cmap_max=args.cmap_max)
bp.plot_noaxes(
cmap=args.cmap, cmap_min=args.cmap_min, cmap_max=args.cmap_max
)


def _process_args_cutout(args):
Expand All @@ -1211,7 +1227,9 @@ def _process_args_cutout(args):
coords = SkyCoord(ras, decs, unit="deg")

if args.region and (args.from_catalogue or args.size):
raise ValueError("Cannot specify a region, and a catalogue or size at the same time.")
raise ValueError(
"Cannot specify a region, and a catalogue or size at the same time."
)
if args.region and (args.cutout_mode == "fast"):
raise NotImplementedError("Cutout mode `fast` is not supported with regions.")

Expand Down
93 changes: 93 additions & 0 deletions src/everystamp/tonemapping/stretches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from astropy.visualization import BaseStretch
from astropy.visualization.stretch import _prepare
from typing import Optional
import numpy as np


def findrms(mIn, maskSup: float = 1e-7):
"""Find the rms of an array, from Cycil Tasse/kMS"""
m = mIn[np.abs(mIn) > maskSup]
rmsold = np.nanstd(m)
diff = 1e-1
cut = 3.0
med = np.nanmedian(m)
for i in range(10):
ind = np.where(np.abs(m - med) < rmsold * cut)[0]
rms = np.nanstd(m[ind])
if np.abs((rms - rmsold) / rmsold) < diff:
break
rmsold = rms
return rms


class TimmermanStretch(BaseStretch):
"""A stretch originally devised by R. Timmerman to (approximately)
conserve the visual dynamic range between background noise and source emission.
"""

@property
def _supports_invalid_kw(self):
return True

def __call__(
self,
values: np.ndarray,
rms_factor: float = 2.5,
peak_factor: float = 75.0,
clip: bool = True,
out: Optional[np.ndarray] = None,
invalid: Optional[float] = None,
):
"""
Transform values using this stretch.

Args:
values : array-like
The input values, which should already be normalized to the
[0:1] range.
clip : bool, optional
If `True` (default), values outside the [0:1] range are
clipped to the [0:1] range.
out : ndarray, optional
If specified, the output values will be placed in this array
(typically used for in-place calculations).
invalid : None or float, optional
Value to assign NaN values generated by this class. NaNs in
the input ``values`` array are not changed. This option is
generally used with matplotlib normalization classes, where
the ``invalid`` value should map to the matplotlib colormap
"under" value (i.e., any finite value < 0). If `None`, then
NaN values are not replaced. This keyword has no effect if
``clip=True``.
rms_factor : float
Sets the minimum value of the stretch based on the rms.
peak_factor : float
Proxy for the maximum value of the stretch based on the rms.

Returns:
result : ndarray
The transformed values.
"""
values = _prepare(values, clip=clip, out=out)
replace_invalid = not clip and invalid is not None
with np.errstate(invalid="ignore"):
if replace_invalid:
idx = values < 0
rms = findrms(values)
rms_factor = 2.5
vmin = -rms_factor * rms
vmax = np.nanmax((peak_factor * rms, np.nanmax(values)))
power_scaling = np.log(0.2) / np.log(2 * rms_factor * rms / (vmax - vmin))
values = values**power_scaling

if replace_invalid:
# Assign new NaN (i.e., NaN not in the original input
# values, but generated by this class) to the invalid value.
values[idx] = invalid

return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
raise NotImplementedError
Loading