diff --git a/src/everystamp/everystamp.py b/src/everystamp/everystamp.py index ea7467d..587d793 100644 --- a/src/everystamp/everystamp.py +++ b/src/everystamp/everystamp.py @@ -21,6 +21,7 @@ from astroquery.skyview import SkyView # type: ignore from everystamp.cutters import make_cutout_2D, make_cutout_2D_fast, make_cutout_region from everystamp.tonemapping import lhdr, normalise +from everystamp.tonemapping.stretches import TimmermanStretch logging.basicConfig( format="[%(name)s] %(asctime)s - %(levelname)s: %(message)s", level=logging.INFO @@ -271,8 +272,8 @@ def _add_args_plot(parser): default=None, type=str, required=False, - choices=["log", "sqrt", "squared", "asinh", "sinh"], - help="Stretch an image with a certian function.", + choices=["log", "sqrt", "squared", "asinh", "sinh", "timmerman"], + help="Stretch an image with a certain function.", ) required_args.add_argument( "--cmap", @@ -830,6 +831,7 @@ def _process_args_download(args): if args.lotss_release == "dr3": print("LoTSS DR3 is not yet publicly available. Please authenticate.") import getpass + user = getpass.getpass("username:") password = getpass.getpass("password:") for ra, dec in zip(ras, decs): @@ -879,6 +881,7 @@ def _process_args_download(args): if args.mode == "both": raise ValueError("FIRST download does not support `both` (yet).") from everystamp.downloaders import FIRSTDownloader + vd = FIRSTDownloader() vd.download( ra=ra, @@ -1017,6 +1020,8 @@ def _process_args_plot(args): bp = BasicFITSPlot(args.image) elif args.style == "srtplot": bp = SRTPlot(args.image) + else: + raise ValueError("Unkonwn plot style specified.") else: # Probably an image format. bp = BasicImagePlot(args.image, wcsimage=args.wcs_image) @@ -1168,7 +1173,11 @@ def _process_args_plot(args): stretch = astropy.visualization.AsinhStretch() elif args.stretch == "sinh": stretch = astropy.visualization.SinhStretch() - bp.data = stretch(normalise(bp.data), min) + elif args.stretch == "timmerman": + stretch = TimmermanStretch() + else: + raise ValueError("Unknown stretch passed") + bp.data = stretch(normalise(bp.data)) if args.contour_image and (args.style == "normal"): bp.plot2D( @@ -1184,9 +1193,16 @@ def _process_args_plot(args): if args.image.lower().endswith("fits") and (args.style == "normal"): bp.savedata(args.image.replace(".fits", ".tonemapped.fits")) if args.contour_image: - bp.plot_noaxes(cmap=args.cmap, cmap_min=args.cmap_min, cmap_max=args.cmap_max, contour_image=args.contour_image) + bp.plot_noaxes( + cmap=args.cmap, + cmap_min=args.cmap_min, + cmap_max=args.cmap_max, + contour_image=args.contour_image, + ) else: - bp.plot_noaxes(cmap=args.cmap, cmap_min=args.cmap_min, cmap_max=args.cmap_max) + bp.plot_noaxes( + cmap=args.cmap, cmap_min=args.cmap_min, cmap_max=args.cmap_max + ) def _process_args_cutout(args): @@ -1211,7 +1227,9 @@ def _process_args_cutout(args): coords = SkyCoord(ras, decs, unit="deg") if args.region and (args.from_catalogue or args.size): - raise ValueError("Cannot specify a region, and a catalogue or size at the same time.") + raise ValueError( + "Cannot specify a region, and a catalogue or size at the same time." + ) if args.region and (args.cutout_mode == "fast"): raise NotImplementedError("Cutout mode `fast` is not supported with regions.") diff --git a/src/everystamp/tonemapping/stretches.py b/src/everystamp/tonemapping/stretches.py new file mode 100644 index 0000000..c18cf49 --- /dev/null +++ b/src/everystamp/tonemapping/stretches.py @@ -0,0 +1,93 @@ +from astropy.visualization import BaseStretch +from astropy.visualization.stretch import _prepare +from typing import Optional +import numpy as np + + +def findrms(mIn, maskSup: float = 1e-7): + """Find the rms of an array, from Cycil Tasse/kMS""" + m = mIn[np.abs(mIn) > maskSup] + rmsold = np.nanstd(m) + diff = 1e-1 + cut = 3.0 + med = np.nanmedian(m) + for i in range(10): + ind = np.where(np.abs(m - med) < rmsold * cut)[0] + rms = np.nanstd(m[ind]) + if np.abs((rms - rmsold) / rmsold) < diff: + break + rmsold = rms + return rms + + +class TimmermanStretch(BaseStretch): + """A stretch originally devised by R. Timmerman to (approximately) + conserve the visual dynamic range between background noise and source emission. + """ + + @property + def _supports_invalid_kw(self): + return True + + def __call__( + self, + values: np.ndarray, + rms_factor: float = 2.5, + peak_factor: float = 75.0, + clip: bool = True, + out: Optional[np.ndarray] = None, + invalid: Optional[float] = None, + ): + """ + Transform values using this stretch. + + Args: + values : array-like + The input values, which should already be normalized to the + [0:1] range. + clip : bool, optional + If `True` (default), values outside the [0:1] range are + clipped to the [0:1] range. + out : ndarray, optional + If specified, the output values will be placed in this array + (typically used for in-place calculations). + invalid : None or float, optional + Value to assign NaN values generated by this class. NaNs in + the input ``values`` array are not changed. This option is + generally used with matplotlib normalization classes, where + the ``invalid`` value should map to the matplotlib colormap + "under" value (i.e., any finite value < 0). If `None`, then + NaN values are not replaced. This keyword has no effect if + ``clip=True``. + rms_factor : float + Sets the minimum value of the stretch based on the rms. + peak_factor : float + Proxy for the maximum value of the stretch based on the rms. + + Returns: + result : ndarray + The transformed values. + """ + values = _prepare(values, clip=clip, out=out) + replace_invalid = not clip and invalid is not None + with np.errstate(invalid="ignore"): + if replace_invalid: + idx = values < 0 + rms = findrms(values) + rms_factor = 2.5 + vmin = -rms_factor * rms + vmax = np.nanmax((peak_factor * rms, np.nanmax(values))) + power_scaling = np.log(0.2) / np.log(2 * rms_factor * rms / (vmax - vmin)) + values = values**power_scaling + + if replace_invalid: + # Assign new NaN (i.e., NaN not in the original input + # values, but generated by this class) to the invalid value. + values[idx] = invalid + + return values + + @property + def inverse(self): + """A stretch object that performs the inverse operation.""" + raise NotImplementedError