From 6d8daefe60bc24fa82610d219af8d8d57ea51fe4 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 9 Dec 2025 15:16:23 +1300 Subject: [PATCH 01/41] feat: add alpine plot utilities --- visualisation/ims/response_rrup.py | 332 ++++++++++++++++++ visualisation/realisation.py | 19 +- visualisation/stations/plot_stations.py | 107 ++++++ visualisation/utils.py | 148 +++++++- visualisation/waveforms/plot_fas.py | 213 +++++++++++ .../waveforms/plot_response_spectra.py | 104 ++++++ visualisation/waveforms/plot_waveform.py | 192 ++++++++++ 7 files changed, 1103 insertions(+), 12 deletions(-) create mode 100644 visualisation/ims/response_rrup.py create mode 100644 visualisation/stations/plot_stations.py create mode 100644 visualisation/waveforms/plot_fas.py create mode 100644 visualisation/waveforms/plot_response_spectra.py create mode 100644 visualisation/waveforms/plot_waveform.py diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py new file mode 100644 index 0000000..7cd7f9a --- /dev/null +++ b/visualisation/ims/response_rrup.py @@ -0,0 +1,332 @@ +from pathlib import Path +from typing import NamedTuple + +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd +import typer +import xarray as xr +from matplotlib.axes import Axes +from rpy2.robjects import default_converter, globalenv, numpy2ri, r +from rpy2.robjects.conversion import localconverter + +from visualisation import utils +from visualisation.utils import RuptureContext, SiteProperties +from workflow.realisations import ( + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) + +app = typer.Typer() + + +def nshm2022_logic_tree_prediction( + rupture_context: RuptureContext, + site_properties: SiteProperties, + period: float, + rrup: npt.NDArray[np.floating], +) -> pd.DataFrame: + tect_type = oqw.constants.TectType.ACTIVE_SHALLOW + gmm_lt = oqw.constants.GMMLogicTree.NSHM2022 + rupture_df = pd.DataFrame( + {"rrup": rrup, "vs30measured": False, **rupture_context, **site_properties} + ) + for dist_metric in ["rjb", "rx", "ry"]: + rupture_df[dist_metric] = rupture_df["rrup"] + + psa_results = oqw.run_gmm_logic_tree( + gmm_lt, tect_type, rupture_df, "pSA", periods=[period] + ) + assert isinstance(psa_results, pd.DataFrame) + psa_results["rrup"] = rupture_df["rrup"] + return psa_results + + +class LowessFit(NamedTuple): + mean: npt.NDArray[np.floating] + std_low: npt.NDArray[np.floating] + std_high: npt.NDArray[np.floating] + + +def fit_loess_r( + y: npt.NDArray[np.floating], + x: npt.NDArray[np.floating], + x_out: npt.NDArray[np.floating], + **kwargs, +) -> LowessFit: + """ + Fit LOESS using R and return fitted values and prediction intervals. + """ + loess_args = ", ".join(f"{k}={v}" for k, v in kwargs.items()) + loess_call = f"loess(y ~ x, {loess_args})" if loess_args else "loess(y ~ x)" + + with localconverter(default_converter + numpy2ri.converter): + globalenv["x"] = x + globalenv["y"] = y + globalenv["x_out"] = x_out + + r(f"fit <- {loess_call}") + r("newdat <- data.frame(x=x_out)") + r("pred <- predict(fit, newdata=newdat, se=TRUE)") + r("residual_se <- fit$s") + + fit_vals = r("pred$fit") + if not isinstance(fit_vals, np.ndarray): + raise ValueError( + f"Residual stderr evaluation failed, expected float found: {fit_vals=}" + ) + + residual_se_eval = r("residual_se") + if isinstance(residual_se_eval, np.ndarray): + residual_se = residual_se_eval.item() + else: + raise ValueError( + f"Residual stderr evaluation failed, expected float found: {residual_se_eval=}" + ) + + std_low = fit_vals - residual_se + std_high = fit_vals + residual_se + + return LowessFit(fit_vals, std_low, std_high) + + +def plot_nshm_fit( + ax: Axes, + realisation_ffp: Path, + site_ds: xr.Dataset, + period: float, + rrup: npt.NDArray[np.floating], + color: str | None = None, + label: str | None = None, +) -> None: + source_config = SourceConfig.read_from_realisation(realisation_ffp) + magnitudes = Magnitudes.read_from_realisation(realisation_ffp) + rupture_prop = RupturePropagationConfig.read_from_realisation(realisation_ffp) + rakes = Rakes.read_from_realisation(realisation_ffp) + rupture_context = utils.compute_rupture_context( + source_config, magnitudes, rakes, rupture_prop + ) + site_properties = utils.compute_site_properties(site_ds.vs30.values) + logic_tree_results = nshm2022_logic_tree_prediction( + rupture_context, site_properties, period, rrup + ) + period_str = ( + f"{period:.2f}".rstrip("0") if not period.is_integer() else f"{int(period)}.0" + ) + mean = logic_tree_results[f"pSA_{period_str}_mean"] + std = logic_tree_results[f"pSA_{period_str}_std_Total"] + + ax.fill_between( + rrup, np.exp(mean - std), np.exp(mean + std), alpha=0.3, color=color + ) + ax.plot(rrup, np.exp(mean), c=color, label=label) + + +def plot_simulation_fit( + ax: Axes, + rrup: np.ndarray, + psa: np.ndarray, + label: str, + color: str, + span: float = 1 / 3, +) -> None: + """Plot LOESS fit for a subset of simulation data.""" + rrup_out = np.linspace(rrup.min(), rrup.max(), num=100) + fit, ci_low, ci_high = fit_loess_r( + np.log(psa), np.log(rrup), np.log(rrup_out), span=span + ) + ax.fill_between(rrup_out, np.exp(ci_low), np.exp(ci_high), alpha=0.3, color=color) + ax.plot(rrup_out, np.exp(fit), c=color, label=label) + + +# ------------------------ +# Compare per-basin subplots +# ------------------------ +def compare_sim_to_nshm_subplots( + realisation_ffp: Path, + simulation_ds: xr.Dataset, + period: float, + basins: list[str] | None = None, + component: str = "rotd50", + ymin: float | None = None, + ymax: float | None = None, + basin_vs_no_basin: bool = False, +): + """ + Create subplots: first shows all stations, then one subplot per basin. + """ + # Determine basins to plot + plot_basins = basins or [] + + fig, axes = utils.balanced_subplot_grid( + 1 + len(plot_basins), + 1.0, + subplot_size=(8, 6), + clear=False, + sharex=True, + sharey=True, + constrained_layout=True, + ) + max_rrup = min(500, simulation_ds.rrup.max().item()) + nshm_rrup = np.geomspace(1e-3, max_rrup, num=100) + # --- First subplot: all stations --- + ax = axes[0, 0] + ax.grid(True, which="both", axis="both", lw=0.3) + plot_nshm_fit( + ax, + realisation_ffp, + simulation_ds, + period, + nshm_rrup, + color="tab:blue", + label="NSHM logic tree prediction", + ) + if basin_vs_no_basin: + basin_ds = simulation_ds.where(simulation_ds.basin != "") + non_basin_ds = simulation_ds.where(simulation_ds.basin == "") + ax.scatter( + basin_ds.rrup, + basin_ds.pSA.sel(period=period, component=component).values, + c="tab:red", + alpha=0.3, + s=5, + ) + ax.scatter( + non_basin_ds.rrup, + non_basin_ds.pSA.sel(period=period, component=component).values, + c="tab:purple", + alpha=0.3, + s=5, + ) + plot_simulation_fit( + ax, + basin_ds.rrup.values, + basin_ds.pSA.sel(period=period, component=component).values, + label="Basin stations", + color="darkred", + ) + plot_simulation_fit( + ax, + non_basin_ds.rrup.values, + non_basin_ds.pSA.sel(period=period, component=component).values, + label="Non-basin stations", + color="purple", + ) + else: + ax.scatter( + simulation_ds.rrup, + simulation_ds.pSA.sel(period=period, component=component).values, + c="k", + alpha=0.3, + s=10, + ) + plot_simulation_fit( + ax, + simulation_ds.rrup.values, + simulation_ds.pSA.sel(period=period, component=component).values, + label="Simulated Stations", + color="tab:gray", + ) + # Plot NSHM + ax.legend() + + ax.set_yscale("log") + ax.set_xscale("log") + + # --- Per-basin subplots --- + if plot_basins: + for i, basin in enumerate(plot_basins): + row, col = np.unravel_index(i + 1, axes.shape) + ax = axes[row, col] + subds = simulation_ds.sel( + station=[ + s + for s, b in zip( + simulation_ds.station.values, simulation_ds.basin.values + ) + if b == basin + ] + ) + if len(subds.station) == 0: + continue + ax.grid(True, which="both", axis="both", lw=0.3) + plot_nshm_fit( + ax, + realisation_ffp, + subds, + period, + nshm_rrup, + color="tab:blue", + ) + ax.scatter( + subds.rrup, + subds.pSA.sel(period=period, component=component).values, + c="red", + alpha=0.7, + s=10, + ) + + ax.set_title(f"Basin: {basin}") + ax.set_yscale("log") + ax.set_xscale("log") + + # --- Axis labels --- + if plot_basins: + fig.supxlabel("Source to site distance, $R_{rup}$ [km]") + fig.supylabel(f"pSA({period:.2f} s) [g]") + else: + ax.set_xlabel("Source to site distance, $R_{rup}$ [km]") + ax.set_ylabel(f"pSA({period:.2f} s) [g]") + + if ymin is not None or ymax is not None: + for ax in axes.flatten(): + ax.set_ylim(bottom=ymin, top=ymax) + ax.set_xlim(left=1e-1, right=max_rrup) + return fig + + +# ------------------------ +# CLI +# ------------------------ +@app.command() +def compare_sim_per_basin( + realisation_ffp: Path, + simulation_dataset_path: Path, + period: float, + basins: list[str] | None = None, + save: Path | None = None, + dpi: int = 300, + ymin: float | None = 1e-5, + ymax: float | None = 10, + component: str = "rotd50", + compare_basin: bool = False, +) -> None: + """ + Compare simulation dataset results to NSHM with subplots per basin. + First subplot is all stations. + """ + simulation_ds = xr.open_dataset(simulation_dataset_path) + + fig = compare_sim_to_nshm_subplots( + realisation_ffp, + simulation_ds, + period, + basins=basins, + component=component, + ymin=ymin, + ymax=ymax, + basin_vs_no_basin=compare_basin, + ) + + if save: + fig.savefig(save, dpi=dpi) + else: + plt.show() + + +if __name__ == "__main__": + app() diff --git a/visualisation/realisation.py b/visualisation/realisation.py index a670407..eb3d6cb 100644 --- a/visualisation/realisation.py +++ b/visualisation/realisation.py @@ -211,7 +211,6 @@ def plot_realisation( title: str | None = None, subtitle: str | None = None, width: float = 10, - show_geonet_stations: bool = False, show_geometry: bool = True, show_pgv_targets: bool = False, pgv_targets: list[float] | None = None, @@ -233,8 +232,6 @@ def plot_realisation( Subtitle of the plot. width : float Width of the plot in cm. - show_geonet_stations : bool - Show GeoNet stations on the plot. show_geometry : bool Show source geometry on the plot. show_pgv_targets : bool @@ -263,20 +260,21 @@ def plot_realisation( >>> fig.show() """ show_pgv_targets = show_pgv_targets or bool(pgv_targets) - rupture_propagation = RupturePropagationConfig.read_from_realisation( - realisation_ffp - ) domain_parameters = DomainParameters.read_from_realisation(realisation_ffp) - velocity_model_parameters = VelocityModelParameters.read_from_realisation( - realisation_ffp - ) - source_config = SourceConfig.read_from_realisation(realisation_ffp) rrup_bounding_polygons: list[shapely.Polygon] = [] if show_pgv_targets: + rupture_propagation = RupturePropagationConfig.read_from_realisation( + realisation_ffp + ) + + velocity_model_parameters = VelocityModelParameters.read_from_realisation( + realisation_ffp + ) + fault_pgv_targets = pgv_targets or [ generate_velocity_model_parameters.pgv_target( rupture_propagation, velocity_model_parameters @@ -449,7 +447,6 @@ def plot_realisation_to_file( title=title, subtitle=subtitle, width=width, - show_geonet_stations=show_geonet_stations, show_geometry=show_geometry, show_pgv_targets=show_pgv_targets, pgv_targets=pgv_targets, diff --git a/visualisation/stations/plot_stations.py b/visualisation/stations/plot_stations.py new file mode 100644 index 0000000..e86badf --- /dev/null +++ b/visualisation/stations/plot_stations.py @@ -0,0 +1,107 @@ +from pathlib import Path + +import pandas as pd +import pygmt +import typer + +from pygmt_helper import plotting +from visualisation import realisation, utils +from workflow.realisations import DomainParameters, SourceConfig + +app = typer.Typer() + + +def plot_towns(fig: pygmt.Figure): + towns = { + "Blenheim": (173.9569444, -41.5138888), + "Christchurch": (172.6347222, -43.5313888), + "Dunedin": (170.3794444, -45.8644444), + "Greymouth": (171.2063889, -42.4502777), + "Haast": (169.0405556, -43.8808333), + "Kaikoura": (173.6802778, -42.4038888), + "Masterton": (175.658333, -40.952778), + "Napier": (176.916667, -39.483333), + "New Plymouth": (174.083333, -39.066667), + "Nelson": (173.2838889, -41.2761111), + "Palmerston North": (175.611667, -40.355000), + "Queenstown": (168.6680556, -45.0300000), + "Rakaia": (172.0230556, -43.75611111), + "Rotorua": (176.251389, -38.137778), + "Taupo": (176.069400, -38.6875), + "Tekapo": (170.4794444, -44.0069444), + "Timaru": (171.2430556, -44.3958333), + "Wellington": (174.777222, -41.288889), + "Westport": (171.5997222, -41.7575000), + } + for label, (lon, lat) in towns.items(): + fig.plot(x=lon, y=lat, style="c0.1c", fill="white", pen="0.3p,black") + fig.text( + x=lon, y=lat, text=label, justify="BC", offset="0.15c", font="6p,black" + ) + + +@app.command(help="Test") +def plot_stations( + realisation_ffp: Path, + stations_ll: Path, + stations_vs30: Path, + cmap: str, + title: str, + width: float, +): + domain_parameters = DomainParameters.read_from_realisation(realisation_ffp) + source_config = SourceConfig.read_from_realisation(realisation_ffp) + stations_vs30_df = pd.read_csv( + stations_vs30, + delimiter=r"\s+", + comment="#", + names=["name", "vs30"], + header=None, + ).set_index("name") + stations_ll_df = pd.read_csv( + stations_ll, + delimiter=r"\s+", + comment="#", + names=["lon", "lat", "name"], + header=None, + ).set_index("name") + stations_vs30_df = stations_vs30_df.loc[stations_ll_df.index] + region = utils.bounding_region_for([domain_parameters.domain.polygon], 0, 0) + fig = plotting.gen_region_fig( + title, + region, + projection=f"M{width}c", + plot_kwargs=dict(water_color="white", topo_cmap_min=-900, topo_cmap_max=3100), + plot_highways=False, + ) + + realisation.plot_domain(fig, domain_parameters, pen="1p,black,-") + realisation.plot_sources(fig, source_config, fill="blue") + + pygmt.makecpt( + cmap=cmap.removesuffix("_r"), + series=[0, 1500, 100], + reverse=cmap.endswith("_r"), + ) + realisation.plot_stations( + fig, + domain_parameters, + stations_ll, + fill=stations_vs30_df["vs30"], + cmap=True, + style="i0.3c", + ) + fig.text( + x=stations_ll_df["lon"], + y=stations_ll_df["lat"], + text=stations_ll_df.index, + justify="BC", + offset="0.25c", + font="6p,black", + ) + fig.colorbar(frame="xaf+lVs30 (m/s)") + fig.savefig("map.png") + + +if __name__ == "__main__": + app() diff --git a/visualisation/utils.py b/visualisation/utils.py index 7a2694a..602db55 100644 --- a/visualisation/utils.py +++ b/visualisation/utils.py @@ -1,13 +1,25 @@ """Utility functions common to many plotting scripts.""" -from typing import Optional +from typing import Any, Literal, Optional, TypedDict, Unpack import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw import pygmt import scipy as sp import shapely +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure from qcore import coordinates +from source_modelling import moment +from workflow.realisations import ( + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) def format_description( @@ -316,3 +328,137 @@ def grid_scale_for_region(region: tuple[float, float, float, float]) -> int: lon_km = (max_lon - min_lon) * 111 * np.cos(np.radians((min_lat + max_lat) / 2)) maximum_extent = max(lat_km, lon_km) return int(round(max(5, 2.5 * maximum_extent))) + + +class SubplotsKwargs(TypedDict, total=False): + sharex: bool | Literal["none", "all", "row", "col"] + sharey: bool | Literal["none", "all", "row", "col"] + subplot_kw: dict[str, Any] | None + gridspec_kw: dict[str, Any] | None + figsize: tuple[float, float] + constrained_layout: bool | dict[str, Any] + layout: Literal["constrained", "compressed", "tight"] | None + + +def balanced_subplot_grid( + n_subplots: int, + aspect: float, + subplot_size: tuple[float, float] | None = None, + squeeze: bool = False, + clear: bool = False, + **kwargs: Unpack[SubplotsKwargs], +) -> tuple[Figure, npt.NDArray[Axes]]: + # This has more columns than rows, i.e. wide + height = np.sqrt(n_subplots / aspect) + rows = int(np.ceil(height)) + columns = int(np.ceil(height * aspect)) + # Ensures there are no blank rows. + rows -= max(0, (rows * columns - n_subplots) // columns) + + if subplot_size: + width, height = subplot_size + kwargs["figsize"] = (width * columns, height * rows) + + fig, axes = plt.subplots(rows, columns, **kwargs) + if n_subplots == 1: + axes = np.atleast_2d([axes]) + if squeeze: + axes = axes.squeeze() + if clear: + for ax in axes.flatten()[n_subplots:]: + ax.remove() + + return fig, axes + + +class RuptureContext(TypedDict): + mag: float + rake: float + dip: float + hypo_depth: float + ztor: float + zbot: float + + +class SiteProperties(TypedDict): + vs30: float + z1pt0: float + z2pt5: float + + +def circmean( + samples: npt.NDArray[np.floating], weights: npt.NDArray[np.floating] +) -> float: + x = np.cos(samples) + y = np.sin(samples) + z = weights * np.array([x, y]) + mean_resultant_vector = np.mean(z, axis=1) + argument = np.arctan2(mean_resultant_vector[1], mean_resultant_vector[0]) + return float(argument) + + +def compute_rupture_context( + source_config: SourceConfig, + magnitudes_config: Magnitudes, + rakes_config: Rakes, + rupture_propagation: RupturePropagationConfig, +) -> RuptureContext: + moments = [] + dips = [] + rakes = [] + + for name, source in source_config.source_geometries.items(): + dips.append(source.dip) + moments.append(moment.magnitude_to_moment(magnitudes_config[name])) + rakes.append(rakes_config[name]) + + ztor = ( + min( + source_config.source_geometries.values(), key=lambda source: source.top_m + ).top_m + / 1000 + ) + zbot = ( + max( + source_config.source_geometries.values(), key=lambda source: source.bottom_m + ).bottom_m + / 1000 + ) + avg_rake = np.degrees(circmean(np.radians(rakes), np.array(moments))) + avg_dip = float(np.average(dips, weights=moments)) + avg_moment = float(np.mean(moments)) + total_moment = avg_moment * len(moments) + magnitude = moment.moment_to_magnitude(total_moment) + initial_source = source_config.source_geometries[rupture_propagation.initial_fault] + hypocentre = initial_source.fault_coordinates_to_wgs_depth_coordinates( + rupture_propagation.hypocentre + ) + hypo_depth = float(hypocentre[2]) + hypo_depth /= 1000.0 + return RuptureContext( + mag=magnitude, + rake=avg_rake, + dip=avg_dip, + hypo_depth=hypo_depth, + ztor=ztor, + zbot=zbot, + ) + + +def compute_site_properties( + site_vs30: npt.NDArray[np.floating] | np.floating, +) -> SiteProperties: + # Calculate geometric mean of site vs30 using the exponential-log form: + # exp(1/n sum vs30) + # This is as opposed to straight-forward calculation + # product(vs30) ^ (1/n) + # Which is numerically unstable for a large number of stations due to + # floating-point arithmetic overflow and inprecision at large values + # obtained by multiplication. + if isinstance(site_vs30, np.ndarray): + vs30 = np.exp(1 / len(site_vs30) * np.sum(np.log(site_vs30))) + else: + vs30 = site_vs30 + z1pt0 = oqw.estimations.chiou_young_14_calc_z1p0(vs30) + z2pt5 = oqw.estimations.campbell_bozorgina_14_calc_z2p5(vs30) + return SiteProperties(vs30=vs30, z1pt0=z1pt0, z2pt5=z2pt5) diff --git a/visualisation/waveforms/plot_fas.py b/visualisation/waveforms/plot_fas.py new file mode 100644 index 0000000..aed0210 --- /dev/null +++ b/visualisation/waveforms/plot_fas.py @@ -0,0 +1,213 @@ +from pathlib import Path +from typing import Annotated + +import matplotlib.pyplot as plt +import re +import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd +import typer +import xarray as xr +from matplotlib.axes import Axes + +from qcore import cli +from visualisation import utils +from workflow.realisations import ( + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) + +app = typer.Typer() + + +def plot_fas( + ax: Axes, + dataset: xr.Dataset, + station: str, + component: str, + ymax: float | None = None, + ymin: float | None = None, + **kwargs, +) -> None: + """Plot a Fourier Amplitude Spectrum (FAS) from a simulation. + + Parameters + ---------- + dataset : xr.Dataset + The dataset to read station FAS. + station : str + The station to plot. + component : str + The component to plot. + ymax : float or None + Max limit for y-axis. + ymin : float or None + Min limit for y-axis. + """ + fas = dataset.FAS.sel(station=station, component=component).values + freqs = dataset.frequency.values + ax.plot(freqs, fas, **kwargs) + if ymin is not None or ymax is not None: + ax.set_ylim(bottom=ymin, top=ymax) + ax.set_xlim(freqs.min(), 50) + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(visible=True, which="both", axis="both", lw=0.3) + + +def plot_fas_estimate( + ax: Axes, realisation_ffp: Path, dataset: xr.Dataset, station: str, **kwargs +) -> None: + freqs = dataset.frequency.values + vs30 = dataset.vs30.sel(station=station).item() + site_properties = utils.compute_site_properties(vs30) + source_config = SourceConfig.read_from_realisation(realisation_ffp) + magnitudes = Magnitudes.read_from_realisation(realisation_ffp) + rakes = Rakes.read_from_realisation(realisation_ffp) + rupture_propagation = RupturePropagationConfig.read_from_realisation( + realisation_ffp + ) + rupture_context = utils.compute_rupture_context( + source_config, magnitudes, rakes, rupture_propagation + ) + site = dataset.sel(station=station) + longitude = site.longitude.item() + latitude = site.latitude.item() + point = np.array([latitude, longitude, 0]) + rrup = ( + min( + source.rrup_distance(point) + for source in source_config.source_geometries.values() + ) + / 1000 + ) + rupture_df = pd.DataFrame( + { + "rrup": [rrup], + "vs30measured": False, + **site_properties, + **rupture_context, + } + ) + fas_results = oqw.run_gmm( + oqw.constants.GMM.BA_18, + oqw.constants.TectType.ACTIVE_SHALLOW, + rupture_df, + "EAS", + frequencies=freqs, + ) + emp_eas = [] + emp_eas_stddev = [] + for col in fas_results.columns: + if col.endswith("mean"): + emp_eas.append(fas_results[col].item()) + elif col.endswith("Total"): + emp_eas_stddev.append(fas_results[col].item()) + + emp_eas = np.array(emp_eas) + emp_eas_stddev = np.array(emp_eas_stddev) + ax.plot(freqs, np.exp(emp_eas), **kwargs) + colour = kwargs.get("c") or kwargs.get("color") + ax.fill_between( + freqs, + np.exp(emp_eas - emp_eas_stddev), + np.exp(emp_eas + emp_eas_stddev), + color=colour, + alpha=0.3, + ) + + +@cli.from_docstring(app) +def plot_fas_cli( + realisation_ffp: Annotated[Path, typer.Argument()], + dataset_path: Annotated[Path, typer.Argument()], + stations: Annotated[list[str], typer.Argument()], + title: str | None = None, + save: Path | None = None, + dpi: int = 300, + width: float = 20, + height: float = 15, + ymin: float | None = 1e-5, + ymax: float | None = 1, + component: str = "geom", +) -> None: + """Plot a station Fourier Amplitude Spectrum (FAS). + + Parameters + ---------- + dataset_path : Path + Path to HDF5 FAS dataset. + station : str + The station to plot. + """ + dset = xr.open_dataset(dataset_path, engine="h5netcdf") + cm = 1 / 2.54 + + fig, axes = utils.balanced_subplot_grid( + len(stations), + subplot_size=(width * cm, height * cm), + aspect=3 / 2, + sharex=True, + sharey=True, + clear=True, + constrained_layout=True, + ) + for station, ax in zip(stations, axes.flatten()): + plot_fas_estimate( + ax, + realisation_ffp, + dset, + station, + label="EAS (BA18; μ ± σ)", + color="blue", + ) + plot_fas( + ax, + dset, + station, + component, + ymin=ymin, + ymax=ymax, + label="EAS (Simulation)", + color="k", + ) + station_data = dset.sel(station=station) + vs30 = station_data.vs30.item() + basin = str(station_data.basin.item()) + if not basin: + basin = 'No Basin' + else: + basin = re.sub('[A-Z]', r' \g<0>', basin).lstrip() + lat = station_data.latitude.item() + lon = station_data.longitude.item() + pga = station_data.PGA.sel(component='rotd50').item() + pgv = station_data.PGV.sel(component='rotd50').item() + lon = station_data.longitude.item() + ax.set_title(f'{station}\n({lat:.3f}, {lon:.3f}) - PGA: {pga:.2g} g - PGV: {pgv:.0f} cm/s - Vs30: {vs30:.0f} m/s - Basin: {basin}') + + if axes.size > 1: + fig.supylabel(f"EAS [{component}]") + fig.supxlabel("Frequency [Hz]") + else: + ax = axes.flatten().item() + ax.set_ylabel(f"EAS [{component}]") + ax.set_xlabel("Frequency [Hz]") + + ax = axes.flatten()[0] + ax.legend() + + if title: + fig.suptitle(title) + + if save: + fig.savefig(save, dpi=dpi) + else: + fig.show() + plt.show() + + +if __name__ == "__main__": + app() diff --git a/visualisation/waveforms/plot_response_spectra.py b/visualisation/waveforms/plot_response_spectra.py new file mode 100644 index 0000000..e7391f9 --- /dev/null +++ b/visualisation/waveforms/plot_response_spectra.py @@ -0,0 +1,104 @@ +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import matplotlib.pyplot as plt +import numpy.typing as npt +import scipy as sp +import typer +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from qcore import cli + +app = typer.Typer() + + +def plot_spectra( + dataset: xr.Dataset, + station: str, + component: str, + ymax: float | None = None, + ymin: float | None = None, + **kwargs, +) -> tuple[Figure, list[Axes]]: + """Plot a spectra from a simulation. + + Parameters + ---------- + dataset : xr.Dataset + The dataset to read station spectras. It is assumed that spectra array is in cm/s^2. + station : str + The station to plot. + units : Units + The units to plot with. + ymax : float or None + Max limit for y-axis. + ymin : float or None + Min limit for y-axis. + """ + spectra = dataset.pSA.sel(station=station, component=component).values + periods = dataset.period.values + fig, ax = plt.subplots(**kwargs) + ax.plot(periods, spectra) + ax.grid() + ax.set_ylabel(f"pSA [{component}, g]") + if ymin is not None or ymax is not None: + ax.set_ylim(bottom=ymin, top=ymax) + ax.set_xlabel("Period [s]") + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(visible=True, which="both", axis="both", lw=0.3) + return fig, ax + + +@cli.from_docstring(app) +def plot_spectra_cli( + dataset_path: Annotated[Path, typer.Argument()], + station: Annotated[str, typer.Argument()], + title: str | None = None, + save: Path | None = None, + dpi: int = 300, + width: float = 20, + height: float = 15, + ymin: float | None = 1e-5, + ymax: float | None = 1, + component: str = "rotd50", +) -> None: + """Plot a station spectra. + + Parameters + ---------- + dataset_path : Path + Path to HDF5 spectra dataset. + station : str + The station to plot. + units : Units + The units to plot in. + """ + dset = xr.open_dataset(dataset_path, engine="h5netcdf") + cm = 1 / 2.54 + fig, _ = plot_spectra( + dset, + station, + component, + ymin=ymin, + ymax=ymax, + figsize=(width * cm, height * cm), + ) + + if title: + fig.suptitle(title) + + fig.tight_layout() + + if save: + fig.savefig(save, dpi=dpi) + else: + fig.show() + plt.show() + + +if __name__ == "__main__": + app() diff --git a/visualisation/waveforms/plot_waveform.py b/visualisation/waveforms/plot_waveform.py new file mode 100644 index 0000000..84629c9 --- /dev/null +++ b/visualisation/waveforms/plot_waveform.py @@ -0,0 +1,192 @@ +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import scipy as sp +import typer +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.gridspec import GridSpecFromSubplotSpec + +from qcore import cli + +app = typer.Typer() + + +class Units(StrEnum): + G = "g" + CMS = "cm/s" + CMS2 = "cm/s^2" + + +_UNIT_CONVERSION_TABLE = {Units.G: 981.0, Units.CMS: 1.0, Units.CMS2: 1.0} + + +def plot_waveform( + dataset: xr.Dataset, + station: str, + ax_x: Axes, + ax_y: Axes, + ax_z: Axes, + dataset_units: Units, + plot_units: Units | None = None, + ylim: float | None = None, + **kwargs, +) -> None: + """Plot a waveform from a simulation. + + Parameters + ---------- + dataset : xr.Dataset + The dataset to read station waveforms. It is assumed that waveform array is in cm/s^2. + station : str + The station to plot. + dataset_units : Units + The units of the dataset. + plot_units : Units, optional + The units to plot in. If not given, will assume it is the same + as the dataset units. + ylim : float or None + Limit for y-axis. + + Returns + ------- + Figure + The matplotlib figure containing the waveform plot. + Array of axes + The figure axes. + """ + if ylim is not None: + ylim = abs(ylim) + waveform = dataset.waveform.sel(station=station).values + time = dataset.time.values + dt = time[1] - time[0] + plot_units = plot_units or dataset_units + conversion = ( + _UNIT_CONVERSION_TABLE[dataset_units] / _UNIT_CONVERSION_TABLE[plot_units] + ) + waveform *= conversion + if plot_units == Units.CMS: + waveform = sp.integrate.cumulative_trapezoid(waveform, dx=dt, initial=0) + axes = [ax_x, ax_y, ax_z] + for i, component in enumerate(dataset.component): + axes[i].plot(time, waveform[i]) + axes[i].grid() + axes[i].set_ylabel(f"{str(component.item())} [{plot_units}]") + if ylim is not None: + axes[i].set_ylim(bottom=-ylim, top=ylim) + axes[-1].set_xlabel("time [s]") + + +@cli.from_docstring(app) +def plot_waveform_cli( + dataset_path: Annotated[Path, typer.Argument()], + dataset_units: Annotated[Units, typer.Argument()], + stations: Annotated[list[str], typer.Argument()], + plot_units: Annotated[Units | None, typer.Option()] = None, + title: str | None = None, + save: Path | None = None, + dpi: int = 300, + width: float = 20, + height: float = 15, + ylim: float | None = None, + rows: float | None = None, + columns: float | None = None, +) -> None: + """Plot a station waveform. + + Parameters + ---------- + dataset_path : Path + Path to HDF5 waveform dataset. + station : str + The station to plot. + dataset_units : Units + The units of the dataset. + plot_units : Units, optional + The units to plot in. If not given, will assume it is the same + as the dataset units. + title : str, optional + The title of the figure. + save : Path, optional + If given, save the figure to the supplied file. + dpi : int, optional + Figure DPI (higher is better quality). Only applies if saving + the figure to a file. + width : float, optional + The figure width, in centimetres. + height : float, optional + The figure height, in centimetres. + ylim : float, optional + The maximum value for the y-axis. + """ + dset = xr.open_dataset(dataset_path, engine="h5netcdf") + cm = 1 / 2.54 + if not (rows or columns): + n = len(stations) + rows = int(np.sqrt(n)) + columns = int(np.ceil(n / rows)) + + mosaic: list[list[str | None]] = [[None] * columns for _ in range(rows)] + for i, station in enumerate(stations): + row, column = np.unravel_index(i, (rows, columns)) + mosaic[row][column] = station + + fig, station_axes = plt.subplot_mosaic( + mosaic, figsize=(width * cm, height * cm), sharex=True, sharey=True + ) + for ax in station_axes.values(): + ax.remove() + # For rescaling axes at the end + all_station_axes = [] + for station, big_ax in station_axes.items(): + if not station: + continue # Skips None placeholders + gs = GridSpecFromSubplotSpec( + nrows=3, ncols=1, subplot_spec=big_ax.get_subplotspec(), hspace=0.1 + ) + axes = [] + for spec in gs: + axes.append(fig.add_subplot(spec)) + + for other in axes[:-1]: + other.set_xticklabels([]) + + all_station_axes.extend(axes) + plot_waveform( + dset, + station, + *axes, + dataset_units, + plot_units, + ylim=ylim, + figsize=(width * cm, height * cm), + ) + axes[0].set_title(station) + + maxes = [] + + for ax in all_station_axes: + maxes.extend(ax.get_ylim()) + abs_max = max(abs(lim) for lim in maxes) + global_ylim = (-abs_max, abs_max) + for ax in all_station_axes: + ax.set_ylim(global_ylim) + if title: + fig.suptitle(title) + + fig.tight_layout() + + if save: + fig.savefig(save, dpi=dpi) + else: + fig.show() + plt.show() + + +if __name__ == "__main__": + app() From 400afaa54900a23d9a3cf8cce5570d725738c8db Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 9 Dec 2025 15:23:54 +1300 Subject: [PATCH 02/41] feat(response-rrup): add script as executable in package --- pyproject.toml | 31 +++++++++++++++--------------- visualisation/ims/response_rrup.py | 4 ---- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 837a6f7..bce50be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ plot-1d-velocity-model = "visualisation.plot_1d_velocity_model:app" plot-rupture-path = "visualisation.plot_rupture_path:app" plot-stoch = "visualisation.sources.plot_stoch:app" plot-ts = "visualisation.plot_ts:app" +response-rrup = "visualisation.ims.response_rrup:app" [tool.setuptools.package-dir] visualisation = "visualisation" @@ -53,7 +54,7 @@ extend-select = [ # Missing function argument type-annotation "ANN001", # Using except without specifying an exception type to catch - "BLE001" + "BLE001", ] ignore = ["D104"] @@ -62,15 +63,15 @@ convention = "numpy" [tool.ruff.lint.isort] known-first-party = [ - "source_modelling", - "visualisation", - "workflow", - "pygmt_helper", - "qcore", - "empirical", - "nshmdb", - "IM_calculation", - "mera" + "source_modelling", + "visualisation", + "workflow", + "pygmt_helper", + "qcore", + "empirical", + "nshmdb", + "IM_calculation", + "mera", ] [tool.ruff.lint.per-file-ignores] @@ -80,9 +81,7 @@ known-first-party = [ "tests/**.py" = ["D"] [tool.coverage.run] -omit = [ - "visualisation/plot_ts.py" -] +omit = ["visualisation/plot_ts.py"] [tool.numpydoc_validation] checks = [ @@ -103,7 +102,7 @@ checks = [ "YD01", ] # remember to use single quotes for regex in TOML -exclude = [ # don't report on objects that match any of these regex - '\.undocumented_method$', - '\.__repr__$', +exclude = [ # don't report on objects that match any of these regex + '\.undocumented_method$', + '\.__repr__$', ] diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 7cd7f9a..785cdf8 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -326,7 +326,3 @@ def compare_sim_per_basin( fig.savefig(save, dpi=dpi) else: plt.show() - - -if __name__ == "__main__": - app() From c579ad1502e4c3672bb0765a6d1547ccc2d3d539 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 9 Dec 2025 15:29:53 +1300 Subject: [PATCH 03/41] deps: add missing rpy2 dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 5287e84..9baeafa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ pytest-cov pytest-xdist typer tqdm +rpy2 From 4e28aabebd44d4236bd14b1a5d21a75289d693f5 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 9 Dec 2025 15:33:04 +1300 Subject: [PATCH 04/41] deps: add fixed xarray dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9baeafa..a825f2b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ pytest-xdist typer tqdm rpy2 +xarray[io] From 01c664209e24cffd1dd3ee70b27dfe92e972d9a7 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 9 Dec 2025 15:34:57 +1300 Subject: [PATCH 05/41] fix(response-rrup): explicitly use h5netcdf --- visualisation/ims/response_rrup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 785cdf8..66eeb93 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -309,7 +309,7 @@ def compare_sim_per_basin( Compare simulation dataset results to NSHM with subplots per basin. First subplot is all stations. """ - simulation_ds = xr.open_dataset(simulation_dataset_path) + simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") fig = compare_sim_to_nshm_subplots( realisation_ffp, From 6591194c38b9df09e265f29a28c70f17c620965d Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 9 Dec 2025 15:51:57 +1300 Subject: [PATCH 06/41] fix(response-rrup): legend capitalisation --- visualisation/ims/response_rrup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 66eeb93..294f6e6 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -228,7 +228,7 @@ def compare_sim_to_nshm_subplots( ax, simulation_ds.rrup.values, simulation_ds.pSA.sel(period=period, component=component).values, - label="Simulated Stations", + label="Simulated stations", color="tab:gray", ) # Plot NSHM From 2b61b26e1c5852a0a9e9ec463f44051bf56e09cf Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Wed, 10 Dec 2025 11:27:18 +1300 Subject: [PATCH 07/41] fix(response-rrup): expose `--span` --- visualisation/ims/response_rrup.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 294f6e6..785eef6 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -155,6 +155,8 @@ def compare_sim_to_nshm_subplots( ymin: float | None = None, ymax: float | None = None, basin_vs_no_basin: bool = False, + all_in_one: bool = False, + span: float = 1 / 3, ): """ Create subplots: first shows all stations, then one subplot per basin. @@ -208,6 +210,7 @@ def compare_sim_to_nshm_subplots( basin_ds.pSA.sel(period=period, component=component).values, label="Basin stations", color="darkred", + span=span, ) plot_simulation_fit( ax, @@ -215,6 +218,7 @@ def compare_sim_to_nshm_subplots( non_basin_ds.pSA.sel(period=period, component=component).values, label="Non-basin stations", color="purple", + span=span, ) else: ax.scatter( @@ -230,6 +234,7 @@ def compare_sim_to_nshm_subplots( simulation_ds.pSA.sel(period=period, component=component).values, label="Simulated stations", color="tab:gray", + span=span, ) # Plot NSHM ax.legend() @@ -237,8 +242,19 @@ def compare_sim_to_nshm_subplots( ax.set_yscale("log") ax.set_xscale("log") - # --- Per-basin subplots --- - if plot_basins: + if plot_basins and all_in_one: + for basin in plot_basins: + subds = simulation_ds.sel( + station=[ + s + for s, b in zip( + simulation_ds.station.values, simulation_ds.basin.values + ) + if b == basin + ] + ) + + elif plot_basins: for i, basin in enumerate(plot_basins): row, col = np.unravel_index(i + 1, axes.shape) ax = axes[row, col] @@ -304,6 +320,8 @@ def compare_sim_per_basin( ymax: float | None = 10, component: str = "rotd50", compare_basin: bool = False, + all_in_one: bool = False, + span: float = 1 / 3, ) -> None: """ Compare simulation dataset results to NSHM with subplots per basin. @@ -320,6 +338,7 @@ def compare_sim_per_basin( ymin=ymin, ymax=ymax, basin_vs_no_basin=compare_basin, + span=span, ) if save: From d0adcfba0ad71aff3e3c9bb52d91f713954e8cb3 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:11:33 +1300 Subject: [PATCH 08/41] fix(response-rrup): add basin subplots in all-in-one --- visualisation/ims/response_rrup.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 785eef6..999359d 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -253,7 +253,23 @@ def compare_sim_to_nshm_subplots( if b == basin ] ) - + if len(subds.station) == 0: + continue + ax.scatter( + subds.rrup, + subds.pSA.sel(period=period, component=component).values, + alpha=0.7, + s=10, + label=f"{basin}", + ) + plot_nshm_fit( + ax, + realisation_ffp, + subds, + period, + nshm_rrup, + color="tab:blue", + ) elif plot_basins: for i, basin in enumerate(plot_basins): row, col = np.unravel_index(i + 1, axes.shape) From 3df84f36366cf68e3e53292cc6ca3c2791f41b71 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:28:43 +1300 Subject: [PATCH 09/41] fix(response-rrup): make basin names human readable --- visualisation/ims/response_rrup.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 999359d..eb2bbcf 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -1,3 +1,4 @@ +import re from pathlib import Path from typing import NamedTuple @@ -143,6 +144,13 @@ def plot_simulation_fit( ax.plot(rrup_out, np.exp(fit), c=color, label=label) +def human_readable_basin_name(basin_name: str) -> str: + # NE_Otago -> NE Otago + basin_name_no_underscore = basin_name.replace("_", " ") + # Greate(r)(W)ellington -> Greate(r) (W)ellington + return re.sub(r"([a-z])([A-Z])", r"\1 \2", basin_name_no_underscore) + + # ------------------------ # Compare per-basin subplots # ------------------------ @@ -260,7 +268,7 @@ def compare_sim_to_nshm_subplots( subds.pSA.sel(period=period, component=component).values, alpha=0.7, s=10, - label=f"{basin}", + label=f"{human_readable_basin(basin)}", ) plot_nshm_fit( ax, @@ -302,7 +310,7 @@ def compare_sim_to_nshm_subplots( s=10, ) - ax.set_title(f"Basin: {basin}") + ax.set_title(f"Basin: {human_readable_name(basin)}") ax.set_yscale("log") ax.set_xscale("log") From e76a6c53d82c04ae2b70e7e7d3cda7735eb7a360 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:28:54 +1300 Subject: [PATCH 10/41] deps: pin to python < 3.14 because of numba --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bce50be..a76e62f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "visualisation" authors = [{ name = "QuakeCoRE" }] description = "Visualisation repository for plotting scripts." readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.12,<3.14" dynamic = ["version", "dependencies"] From e5a98afa1029f56f1e1ff4b6e6bcdeb2d5e5bbf0 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:35:01 +1300 Subject: [PATCH 11/41] fix(response-rrup): incorrect basin name function --- visualisation/ims/response_rrup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index eb2bbcf..b23f4a8 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -268,7 +268,7 @@ def compare_sim_to_nshm_subplots( subds.pSA.sel(period=period, component=component).values, alpha=0.7, s=10, - label=f"{human_readable_basin(basin)}", + label=f"{human_readable_basin_name(basin)}", ) plot_nshm_fit( ax, @@ -310,7 +310,7 @@ def compare_sim_to_nshm_subplots( s=10, ) - ax.set_title(f"Basin: {human_readable_name(basin)}") + ax.set_title(f"Basin: {human_readable_basin_name(basin)}") ax.set_yscale("log") ax.set_xscale("log") From a60058c5b6e54a203566f9972af845c69394a285 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:36:17 +1300 Subject: [PATCH 12/41] fix(response-rrup): do not plot per-basin NSHM plots --- visualisation/ims/response_rrup.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index b23f4a8..7a98fc0 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -270,14 +270,7 @@ def compare_sim_to_nshm_subplots( s=10, label=f"{human_readable_basin_name(basin)}", ) - plot_nshm_fit( - ax, - realisation_ffp, - subds, - period, - nshm_rrup, - color="tab:blue", - ) + # No NSHM fit because of the overall fit plot elif plot_basins: for i, basin in enumerate(plot_basins): row, col = np.unravel_index(i + 1, axes.shape) From bf449e05746dabe26b5979d8bc51e487c73e0193 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:40:32 +1300 Subject: [PATCH 13/41] fix(response-rrup): do not create several plots in the all-in-one output --- visualisation/ims/response_rrup.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 7a98fc0..1286be6 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -172,19 +172,24 @@ def compare_sim_to_nshm_subplots( # Determine basins to plot plot_basins = basins or [] - fig, axes = utils.balanced_subplot_grid( - 1 + len(plot_basins), - 1.0, - subplot_size=(8, 6), - clear=False, - sharex=True, - sharey=True, - constrained_layout=True, - ) + if all_in_one: + fig, ax = plt.subplots(constrained_layout=True) + else: + fig, axes = utils.balanced_subplot_grid( + 1 + len(plot_basins), + 1.0, + subplot_size=(8, 6), + clear=False, + sharex=True, + sharey=True, + constrained_layout=True, + ) + ax = axes[0, 0] + max_rrup = min(500, simulation_ds.rrup.max().item()) nshm_rrup = np.geomspace(1e-3, max_rrup, num=100) # --- First subplot: all stations --- - ax = axes[0, 0] + ax.grid(True, which="both", axis="both", lw=0.3) plot_nshm_fit( ax, From df27db49c8ebb1368b77d83f337947bd7e33301a Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:49:40 +1300 Subject: [PATCH 14/41] refactor(response-rrup): remove the omni function --- visualisation/ims/response_rrup.py | 578 ++++++++++++++++++++--------- 1 file changed, 407 insertions(+), 171 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 1286be6..fc44760 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -1,6 +1,6 @@ import re from pathlib import Path -from typing import NamedTuple +from typing import Annotated, NamedTuple import matplotlib.pyplot as plt import numpy as np @@ -151,215 +151,451 @@ def human_readable_basin_name(basin_name: str) -> str: return re.sub(r"([a-z])([A-Z])", r"\1 \2", basin_name_no_underscore) -# ------------------------ -# Compare per-basin subplots -# ------------------------ -def compare_sim_to_nshm_subplots( +def _get_plotting_params(simulation_ds: xr.Dataset): + """Calculate common plotting parameters.""" + max_rrup = min(500, simulation_ds.rrup.max().item()) + nshm_rrup = np.geomspace(1e-3, max_rrup, num=100) + return max_rrup, nshm_rrup + + +def _apply_style_and_limits( + fig, + axes: np.ndarray, + period: float, + max_rrup: float, + ymin: float | None, + ymax: float | None, + is_multi_plot: bool, + ax_main: plt.Axes, +): + """Apply final styling, labels, and limits to all axes.""" + if is_multi_plot: + fig.supxlabel("Source to site distance, $R_{rup}$ [km]") + fig.supylabel(f"pSA({period:.2f} s) [g]") + else: + ax_main.set_xlabel("Source to site distance, $R_{rup}$ [km]") + ax_main.set_ylabel(f"pSA({period:.2f} s) [g]") + + if ymin is not None or ymax is not None: + for ax in axes.flatten(): + ax.set_ylim(bottom=ymin, top=ymax) + + ax_main.set_xlim(left=1e-1, right=max_rrup) + + +def _plot_nshm_fit_and_settings( + ax: plt.Axes, + realisation_ffp: Path, + data_ds: xr.Dataset, + period: float, + nshm_rrup: np.ndarray, + label: str | None = "NSHM logic tree prediction", +): + """Plots the common NSHM fit and initial axis settings (log scale, grid).""" + ax.grid(True, which="both", axis="both", lw=0.3) + plot_nshm_fit( + ax, + realisation_ffp, + data_ds, + period, + nshm_rrup, + color="tab:blue", + label=label, + ) + ax.set_yscale("log") + ax.set_xscale("log") + + +def _get_basin_stations(simulation_ds: xr.Dataset, basin: str): + """Filter the dataset for stations belonging to a specific basin.""" + # Uses xarray's .where() for cleaner filtering + return simulation_ds.where(simulation_ds.basin == basin, drop=True) + + +def plot_basin_vs_no_basin( realisation_ffp: Path, simulation_ds: xr.Dataset, period: float, - basins: list[str] | None = None, component: str = "rotd50", ymin: float | None = None, ymax: float | None = None, - basin_vs_no_basin: bool = False, - all_in_one: bool = False, span: float = 1 / 3, ): """ - Create subplots: first shows all stations, then one subplot per basin. + Creates a single plot comparing simulation data for basin stations + vs. non-basin stations against the NSHM prediction. """ - # Determine basins to plot - plot_basins = basins or [] + max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) - if all_in_one: - fig, ax = plt.subplots(constrained_layout=True) - else: - fig, axes = utils.balanced_subplot_grid( - 1 + len(plot_basins), - 1.0, - subplot_size=(8, 6), - clear=False, - sharex=True, - sharey=True, - constrained_layout=True, - ) - ax = axes[0, 0] + fig, ax = plt.subplots(constrained_layout=True) + all_axes = np.array([ax]) # For style helper consistency - max_rrup = min(500, simulation_ds.rrup.max().item()) - nshm_rrup = np.geomspace(1e-3, max_rrup, num=100) - # --- First subplot: all stations --- + # 1. Plot NSHM fit and set axis scales/grid + _plot_nshm_fit_and_settings(ax, realisation_ffp, simulation_ds, period, nshm_rrup) - ax.grid(True, which="both", axis="both", lw=0.3) - plot_nshm_fit( + # 2. Split and plot simulation data + basin_ds = simulation_ds.where(simulation_ds.basin != "", drop=True) + non_basin_ds = simulation_ds.where(simulation_ds.basin == "", drop=True) + + # Basin stations + basin_pSA = basin_ds.pSA.sel(period=period, component=component).values + ax.scatter(basin_ds.rrup, basin_pSA, c="tab:red", alpha=0.3, s=5) + plot_simulation_fit( ax, - realisation_ffp, - simulation_ds, - period, - nshm_rrup, - color="tab:blue", - label="NSHM logic tree prediction", + basin_ds.rrup.values, + basin_pSA, + label="Basin stations", + color="darkred", + span=span, ) - if basin_vs_no_basin: - basin_ds = simulation_ds.where(simulation_ds.basin != "") - non_basin_ds = simulation_ds.where(simulation_ds.basin == "") - ax.scatter( - basin_ds.rrup, - basin_ds.pSA.sel(period=period, component=component).values, - c="tab:red", - alpha=0.3, - s=5, - ) - ax.scatter( - non_basin_ds.rrup, - non_basin_ds.pSA.sel(period=period, component=component).values, - c="tab:purple", - alpha=0.3, - s=5, - ) - plot_simulation_fit( - ax, - basin_ds.rrup.values, - basin_ds.pSA.sel(period=period, component=component).values, - label="Basin stations", - color="darkred", - span=span, - ) - plot_simulation_fit( - ax, - non_basin_ds.rrup.values, - non_basin_ds.pSA.sel(period=period, component=component).values, - label="Non-basin stations", - color="purple", - span=span, - ) - else: - ax.scatter( - simulation_ds.rrup, - simulation_ds.pSA.sel(period=period, component=component).values, - c="k", - alpha=0.3, - s=10, - ) - plot_simulation_fit( - ax, - simulation_ds.rrup.values, - simulation_ds.pSA.sel(period=period, component=component).values, - label="Simulated stations", - color="tab:gray", - span=span, - ) - # Plot NSHM + + # Non-Basin stations + non_basin_pSA = non_basin_ds.pSA.sel(period=period, component=component).values + ax.scatter(non_basin_ds.rrup, non_basin_pSA, c="tab:purple", alpha=0.3, s=5) + plot_simulation_fit( + ax, + non_basin_ds.rrup.values, + non_basin_pSA, + label="Non-basin stations", + color="purple", + span=span, + ) + ax.legend() - ax.set_yscale("log") - ax.set_xscale("log") + # 3. Apply final styling + _apply_style_and_limits(fig, all_axes, period, max_rrup, ymin, ymax, False, ax) - if plot_basins and all_in_one: - for basin in plot_basins: - subds = simulation_ds.sel( - station=[ - s - for s, b in zip( - simulation_ds.station.values, simulation_ds.basin.values - ) - if b == basin - ] - ) - if len(subds.station) == 0: - continue - ax.scatter( - subds.rrup, - subds.pSA.sel(period=period, component=component).values, - alpha=0.7, - s=10, - label=f"{human_readable_basin_name(basin)}", - ) - # No NSHM fit because of the overall fit plot - elif plot_basins: - for i, basin in enumerate(plot_basins): - row, col = np.unravel_index(i + 1, axes.shape) - ax = axes[row, col] - subds = simulation_ds.sel( - station=[ - s - for s, b in zip( - simulation_ds.station.values, simulation_ds.basin.values - ) - if b == basin - ] - ) - if len(subds.station) == 0: - continue - ax.grid(True, which="both", axis="both", lw=0.3) - plot_nshm_fit( - ax, - realisation_ffp, - subds, - period, - nshm_rrup, - color="tab:blue", - ) - ax.scatter( - subds.rrup, - subds.pSA.sel(period=period, component=component).values, - c="red", - alpha=0.7, - s=10, - ) + return fig - ax.set_title(f"Basin: {human_readable_basin_name(basin)}") - ax.set_yscale("log") - ax.set_xscale("log") - # --- Axis labels --- - if plot_basins: - fig.supxlabel("Source to site distance, $R_{rup}$ [km]") - fig.supylabel(f"pSA({period:.2f} s) [g]") - else: - ax.set_xlabel("Source to site distance, $R_{rup}$ [km]") - ax.set_ylabel(f"pSA({period:.2f} s) [g]") +def plot_separate_basin_subplots( + realisation_ffp: Path, + simulation_ds: xr.Dataset, + period: float, + basins: list[str], + component: str = "rotd50", + ymin: float | None = None, + ymax: float | None = None, + span: float = 1 / 3, +): + """ + Creates a grid of subplots: one for all stations, and one for each basin. + """ + if not basins: + raise ValueError("Basins list cannot be empty for separate basin plotting.") + + max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) + + # Setup figure with 1 + N_basins plots + num_plots = 1 + len(basins) + # utils.balanced_subplot_grid is assumed to return a 2D array of axes + fig, axes_2d = utils.balanced_subplot_grid( + num_plots, + 1.0, + subplot_size=(8, 6), + clear=False, + sharex=True, + sharey=True, + constrained_layout=True, + ) + all_axes = axes_2d.flatten() + ax_all_stations = all_axes[0] + + # --- A. Plot All Stations (Primary Plot) --- + _plot_nshm_fit_and_settings( + ax_all_stations, realisation_ffp, simulation_ds, period, nshm_rrup + ) + + # Plot all simulation stations together + all_pSA = simulation_ds.pSA.sel(period=period, component=component).values + ax_all_stations.scatter(simulation_ds.rrup, all_pSA, c="k", alpha=0.3, s=10) + plot_simulation_fit( + ax_all_stations, + simulation_ds.rrup.values, + all_pSA, + label="Simulated stations", + color="tab:gray", + span=span, + ) + ax_all_stations.set_title("All Stations (Combined)") + ax_all_stations.legend() + + # --- B. Plot Individual Basins --- + for i, basin in enumerate(basins): + ax_basin = all_axes[i + 1] + subds = _get_basin_stations(simulation_ds, basin) + + if len(subds.station) == 0: + ax_basin.set_title(f"Basin: {human_readable_basin_name(basin)} (No data)") + continue + + _plot_nshm_fit_and_settings( + ax_basin, realisation_ffp, subds, period, nshm_rrup, label=None + ) # No legend for NSHM here + + basin_pSA = subds.pSA.sel(period=period, component=component).values + ax_basin.scatter(subds.rrup, basin_pSA, c="red", alpha=0.7, s=10) + ax_basin.set_title(f"Basin: {human_readable_basin_name(basin)}") + + # 3. Apply final styling (is_multi_plot=True) + _apply_style_and_limits( + fig, all_axes, period, max_rrup, ymin, ymax, True, ax_all_stations + ) - if ymin is not None or ymax is not None: - for ax in axes.flatten(): - ax.set_ylim(bottom=ymin, top=ymax) - ax.set_xlim(left=1e-1, right=max_rrup) return fig -# ------------------------ -# CLI -# ------------------------ -@app.command() -def compare_sim_per_basin( +def plot_combined_basin_plot( realisation_ffp: Path, - simulation_dataset_path: Path, + simulation_ds: xr.Dataset, period: float, - basins: list[str] | None = None, - save: Path | None = None, - dpi: int = 300, - ymin: float | None = 1e-5, - ymax: float | None = 10, + basins: list[str], component: str = "rotd50", - compare_basin: bool = False, - all_in_one: bool = False, + ymin: float | None = None, + ymax: float | None = None, span: float = 1 / 3, +): + """ + Creates a single plot showing all stations and then overlays each basin. + """ + if not basins: + # Fall back to plotting only all stations if no basins are specified + print("Warning: No basins specified. Plotting all stations only.") + + max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) + + fig, ax = plt.subplots(constrained_layout=True) + all_axes = np.array([ax]) # For style helper consistency + + # 1. Plot NSHM fit and set axis scales/grid + _plot_nshm_fit_and_settings(ax, realisation_ffp, simulation_ds, period, nshm_rrup) + + # 2. Plot All Stations (Base Layer) + all_pSA = simulation_ds.pSA.sel(period=period, component=component).values + ax.scatter( + simulation_ds.rrup, + all_pSA, + c="k", + alpha=0.1, + s=10, + label="All Simulated Stations", + ) + plot_simulation_fit( + ax, + simulation_ds.rrup.values, + all_pSA, + label="Overall Fit", + color="tab:gray", + span=span, + ) + + # 3. Overlay Individual Basins + # We use a color cycle to differentiate the basin scatters + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + + for i, basin in enumerate(basins): + subds = _get_basin_stations(simulation_ds, basin) + + if len(subds.station) == 0: + continue + + color = colors[i % len(colors)] + + basin_pSA = subds.pSA.sel(period=period, component=component).values + ax.scatter( + subds.rrup, + basin_pSA, + alpha=0.7, + s=10, + color=color, + label=f"{human_readable_basin_name(basin)} Stations", + ) + + ax.legend() + ax.set_title("Combined View: All Stations & Individual Basins") + + # 4. Apply final styling (is_multi_plot=False) + _apply_style_and_limits(fig, all_axes, period, max_rrup, ymin, ymax, False, ax) + + return fig + + +@app.command() +def plot_basin_split( + # Arguments + realisation_ffp: Annotated[ + Path, typer.Argument(help="Path to the NSHM fit data file.") + ], + simulation_dataset_path: Annotated[ + Path, typer.Argument(help="Path to the simulation xarray dataset (H5NetCDF).") + ], + period: Annotated[ + float, typer.Argument(help="Spectral acceleration period (T) in seconds.") + ], + # Options + save: Annotated[ + Path | None, + typer.Option("--save", "-s", help="Output path to save the figure."), + ] = None, + dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, + ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, + ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + component: Annotated[ + str, typer.Option(help="PSA component to plot (e.g., rotd50).") + ] = "rotd50", + span: Annotated[ + float, typer.Option(help="Smoothing span for the simulation fit line.") + ] = 1 / 3, ) -> None: """ - Compare simulation dataset results to NSHM with subplots per basin. - First subplot is all stations. + Creates a single plot comparing simulated pSA for Basin stations against + Non-Basin stations, along with the NSHM prediction. """ simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") - fig = compare_sim_to_nshm_subplots( + fig = plot_basin_vs_no_basin( + realisation_ffp, + simulation_ds, + period, + component=component, + ymin=ymin, + ymax=ymax, + span=span, + ) + + if save: + fig.savefig(save, dpi=dpi) + else: + plt.show() + + +# ---------------------------------------------------- +# 2. Command: All Stations + Separate Basin Subplots +# ---------------------------------------------------- +@app.command() +def plot_basins_separate( + # Arguments + realisation_ffp: Annotated[ + Path, typer.Argument(help="Path to the NSHM fit data file.") + ], + simulation_dataset_path: Annotated[ + Path, typer.Argument(help="Path to the simulation xarray dataset (H5NetCDF).") + ], + period: Annotated[ + float, typer.Argument(help="Spectral acceleration period (T) in seconds.") + ], + # Options + basins: Annotated[ + list[str] | None, + typer.Option( + "--basin", "-b", help="List of basins to plot in separate subplots." + ), + ] = None, + save: Annotated[ + Path | None, + typer.Option("--save", "-s", help="Output path to save the figure."), + ] = None, + dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, + ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, + ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + component: Annotated[ + str, typer.Option(help="PSA component to plot (e.g., rotd50).") + ] = "rotd50", + span: Annotated[ + float, typer.Option(help="Smoothing span for the overall simulation fit line.") + ] = 1 / 3, +) -> None: + """ + Creates a grid of plots: one showing all stations, and one separate subplot + for each specified basin, comparing to NSHM. + """ + # Note: basins will be list[str] or None. Check for empty list after parsing. + basins_list = basins or [] + if not basins_list: + typer.echo( + "Error: At least one basin must be specified using --basin for this command.", + err=True, + ) + raise typer.Exit(code=1) + + simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") + + fig = plot_separate_basin_subplots( + realisation_ffp, + simulation_ds, + period, + basins=basins_list, + component=component, + ymin=ymin, + ymax=ymax, + span=span, + ) + + if save: + fig.savefig(save, dpi=dpi) + else: + plt.show() + + +# ---------------------------------------------------- +# 3. Command: All Stations + Basins in Combined Plot +# ---------------------------------------------------- +@app.command() +def plot_basins_combined( + # Arguments + realisation_ffp: Annotated[ + Path, typer.Argument(help="Path to the NSHM fit data file.") + ], + simulation_dataset_path: Annotated[ + Path, typer.Argument(help="Path to the simulation xarray dataset (H5NetCDF).") + ], + period: Annotated[ + float, typer.Argument(help="Spectral acceleration period (T) in seconds.") + ], + # Options + basins: Annotated[ + list[str] | None, + typer.Option( + "--basin", "-b", help="List of basins to overlay on the main plot." + ), + ] = None, + save: Annotated[ + Path | None, + typer.Option("--save", "-s", help="Output path to save the figure."), + ] = None, + dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, + ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, + ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + component: Annotated[ + str, typer.Option(help="PSA component to plot (e.g., rotd50).") + ] = "rotd50", + span: Annotated[ + float, typer.Option(help="Smoothing span for the overall simulation fit line.") + ] = 1 / 3, +) -> None: + """ + Creates a single plot showing all simulation stations (as background), + overlaid with scatters for each specified basin, and the NSHM prediction. + """ + basins_list = basins or [] + if not basins_list: + typer.echo( + "Warning: No basins specified. Plotting all stations (combined) only.", + err=True, + ) + + simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") + + fig = plot_combined_basin_plot( realisation_ffp, simulation_ds, period, - basins=basins, + basins=basins_list, component=component, ymin=ymin, ymax=ymax, - basin_vs_no_basin=compare_basin, span=span, ) From 6f6f0c68a5542a91d126859caceef57832627f02 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:56:27 +1300 Subject: [PATCH 15/41] fix: remove titles, plot per-basin fit --- visualisation/ims/response_rrup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index fc44760..e90fed8 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -318,7 +318,6 @@ def plot_separate_basin_subplots( color="tab:gray", span=span, ) - ax_all_stations.set_title("All Stations (Combined)") ax_all_stations.legend() # --- B. Plot Individual Basins --- @@ -412,9 +411,11 @@ def plot_combined_basin_plot( color=color, label=f"{human_readable_basin_name(basin)} Stations", ) + plot_simulation_fit( + ax, subds.rrup, basin_pSA, label=None, color=color, span=span + ) ax.legend() - ax.set_title("Combined View: All Stations & Individual Basins") # 4. Apply final styling (is_multi_plot=False) _apply_style_and_limits(fig, all_axes, period, max_rrup, ymin, ymax, False, ax) From 89fdcd4d37e5deca766da7bb3d48312cc8ea2fb3 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 10:58:38 +1300 Subject: [PATCH 16/41] fix(response-rrup): rrup values for plot simulation fit --- visualisation/ims/response_rrup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index e90fed8..07b7659 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -131,7 +131,7 @@ def plot_simulation_fit( ax: Axes, rrup: np.ndarray, psa: np.ndarray, - label: str, + label: str | None, color: str, span: float = 1 / 3, ) -> None: @@ -412,7 +412,7 @@ def plot_combined_basin_plot( label=f"{human_readable_basin_name(basin)} Stations", ) plot_simulation_fit( - ax, subds.rrup, basin_pSA, label=None, color=color, span=span + ax, subds.rrup.values, basin_pSA, label=None, color=color, span=span ) ax.legend() From b44957d2355d4e0f405dab8e57b475b575789e86 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 11:05:43 +1300 Subject: [PATCH 17/41] fix(response-rrup): split LOESS fit from data scatter --- visualisation/ims/response_rrup.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 07b7659..f526b74 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -411,6 +411,17 @@ def plot_combined_basin_plot( color=color, label=f"{human_readable_basin_name(basin)} Stations", ) + + for i, basin in enumerate(basins): + subds = _get_basin_stations(simulation_ds, basin) + + if len(subds.station) == 0: + continue + + color = colors[i % len(colors)] + + basin_pSA = subds.pSA.sel(period=period, component=component).values + plot_simulation_fit( ax, subds.rrup.values, basin_pSA, label=None, color=color, span=span ) From ae4bd9489a0ddb82268ee9a5d080191ffa927d2f Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 11:11:56 +1300 Subject: [PATCH 18/41] fix(response-rrup): clear first --- visualisation/ims/response_rrup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index f526b74..ee4d275 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -294,7 +294,7 @@ def plot_separate_basin_subplots( num_plots, 1.0, subplot_size=(8, 6), - clear=False, + clear=True, sharex=True, sharey=True, constrained_layout=True, From 7ba98667cc272c46ad62ef0606221b398742497b Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 11:22:47 +1300 Subject: [PATCH 19/41] feat: add response spectra plot --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a76e62f..0f8964a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ plot-1d-velocity-model = "visualisation.plot_1d_velocity_model:app" plot-rupture-path = "visualisation.plot_rupture_path:app" plot-stoch = "visualisation.sources.plot_stoch:app" plot-ts = "visualisation.plot_ts:app" -response-rrup = "visualisation.ims.response_rrup:app" +plot-response-rrup = "visualisation.ims.response_rrup:app" +plot-response-spectra = "visualisation.waveforms.plot_response_spectra:app" [tool.setuptools.package-dir] visualisation = "visualisation" From 3f9a81b6a420dc0ef04fd6e3760070e80ee13cff Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 11:43:29 +1300 Subject: [PATCH 20/41] feat(plot-response-spectra): plot several stations at once --- .../waveforms/plot_response_spectra.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/visualisation/waveforms/plot_response_spectra.py b/visualisation/waveforms/plot_response_spectra.py index e7391f9..4a25bb4 100644 --- a/visualisation/waveforms/plot_response_spectra.py +++ b/visualisation/waveforms/plot_response_spectra.py @@ -16,13 +16,14 @@ def plot_spectra( + fig: Figure, + ax: Axes, dataset: xr.Dataset, - station: str, + stations: list[str], component: str, ymax: float | None = None, ymin: float | None = None, - **kwargs, -) -> tuple[Figure, list[Axes]]: +) -> None: """Plot a spectra from a simulation. Parameters @@ -38,25 +39,28 @@ def plot_spectra( ymin : float or None Min limit for y-axis. """ - spectra = dataset.pSA.sel(station=station, component=component).values - periods = dataset.period.values - fig, ax = plt.subplots(**kwargs) - ax.plot(periods, spectra) - ax.grid() ax.set_ylabel(f"pSA [{component}, g]") - if ymin is not None or ymax is not None: - ax.set_ylim(bottom=ymin, top=ymax) ax.set_xlabel("Period [s]") ax.set_xscale("log") ax.set_yscale("log") ax.grid(visible=True, which="both", axis="both", lw=0.3) - return fig, ax + if ymin is not None or ymax is not None: + ax.set_ylim(bottom=ymin, top=ymax) + + for station in stations: + spectra = dataset.pSA.sel(station=station, component=component).values + periods = dataset.period.values + + ax.plot(periods, spectra, label=station) + + if len(stations) > 1: + ax.legend() @cli.from_docstring(app) def plot_spectra_cli( dataset_path: Annotated[Path, typer.Argument()], - station: Annotated[str, typer.Argument()], + stations: Annotated[list[str], typer.Argument()], title: str | None = None, save: Path | None = None, dpi: int = 300, @@ -79,13 +83,15 @@ def plot_spectra_cli( """ dset = xr.open_dataset(dataset_path, engine="h5netcdf") cm = 1 / 2.54 - fig, _ = plot_spectra( + fig, ax = plt.subplots(figsize=(width * cm, height * cm)) + plot_spectra( + fig, + ax, dset, - station, + stations, component, ymin=ymin, ymax=ymax, - figsize=(width * cm, height * cm), ) if title: From 754bb15d29b46645f6a246b40f07eea51aed9b91 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 11:51:47 +1300 Subject: [PATCH 21/41] fix(response-rrup): use span = 1 for basin subplots --- visualisation/ims/response_rrup.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index ee4d275..5ad5b50 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -423,7 +423,12 @@ def plot_combined_basin_plot( basin_pSA = subds.pSA.sel(period=period, component=component).values plot_simulation_fit( - ax, subds.rrup.values, basin_pSA, label=None, color=color, span=span + ax, + subds.rrup.values, + basin_pSA, + label=None, + color=color, + span=1, # for each basin only show smooth line ) ax.legend() From 70b02352d626f43e6ffe015e585934bce73d58f4 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 12:00:49 +1300 Subject: [PATCH 22/41] feat(response-rrup): allow xmin/xmax settings --- visualisation/ims/response_rrup.py | 33 +++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 5ad5b50..81d4011 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -163,6 +163,8 @@ def _apply_style_and_limits( axes: np.ndarray, period: float, max_rrup: float, + xmin: float | None, + xmax: float | None, ymin: float | None, ymax: float | None, is_multi_plot: bool, @@ -179,6 +181,9 @@ def _apply_style_and_limits( if ymin is not None or ymax is not None: for ax in axes.flatten(): ax.set_ylim(bottom=ymin, top=ymax) + if xmin is not None or xmax is not None: + for ax in axes.flatten(): + ax.set_xlim(bottom=xmin, top=xmax) ax_main.set_xlim(left=1e-1, right=max_rrup) @@ -217,6 +222,8 @@ def plot_basin_vs_no_basin( simulation_ds: xr.Dataset, period: float, component: str = "rotd50", + xmin: float | None = None, + xmax: float | None = None, ymin: float | None = None, ymax: float | None = None, span: float = 1 / 3, @@ -264,7 +271,9 @@ def plot_basin_vs_no_basin( ax.legend() # 3. Apply final styling - _apply_style_and_limits(fig, all_axes, period, max_rrup, ymin, ymax, False, ax) + _apply_style_and_limits( + fig, all_axes, period, max_rrup, xmin, xmax, ymin, ymax, False, ax + ) return fig @@ -275,6 +284,8 @@ def plot_separate_basin_subplots( period: float, basins: list[str], component: str = "rotd50", + xmin: float | None = None, + xmax: float | None = None, ymin: float | None = None, ymax: float | None = None, span: float = 1 / 3, @@ -339,7 +350,7 @@ def plot_separate_basin_subplots( # 3. Apply final styling (is_multi_plot=True) _apply_style_and_limits( - fig, all_axes, period, max_rrup, ymin, ymax, True, ax_all_stations + fig, all_axes, period, max_rrup, xmin, xmax, ymin, ymax, True, ax_all_stations ) return fig @@ -351,6 +362,8 @@ def plot_combined_basin_plot( period: float, basins: list[str], component: str = "rotd50", + xmin: float | None = None, + xmax: float | None = None, ymin: float | None = None, ymax: float | None = None, span: float = 1 / 3, @@ -434,7 +447,9 @@ def plot_combined_basin_plot( ax.legend() # 4. Apply final styling (is_multi_plot=False) - _apply_style_and_limits(fig, all_axes, period, max_rrup, ymin, ymax, False, ax) + _apply_style_and_limits( + fig, all_axes, period, max_rrup, xmin, xmax, ymin, ymax, False, ax + ) return fig @@ -459,6 +474,8 @@ def plot_basin_split( dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + xmin: Annotated[float | None, typer.Option(help="Minimum x-axis limit.")] = None, + xmax: Annotated[float | None, typer.Option(help="Maximum x-axis limit.")] = None, component: Annotated[ str, typer.Option(help="PSA component to plot (e.g., rotd50).") ] = "rotd50", @@ -479,6 +496,8 @@ def plot_basin_split( component=component, ymin=ymin, ymax=ymax, + xmin=xmin, + xmax=xmax, span=span, ) @@ -517,6 +536,8 @@ def plot_basins_separate( dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + xmin: Annotated[float | None, typer.Option(help="Minimum x-axis limit.")] = None, + xmax: Annotated[float | None, typer.Option(help="Maximum x-axis limit.")] = None, component: Annotated[ str, typer.Option(help="PSA component to plot (e.g., rotd50).") ] = "rotd50", @@ -547,6 +568,8 @@ def plot_basins_separate( component=component, ymin=ymin, ymax=ymax, + xmin=xmin, + xmax=xmax, span=span, ) @@ -585,6 +608,8 @@ def plot_basins_combined( dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + xmin: Annotated[float | None, typer.Option(help="Minimum x-axis limit.")] = None, + xmax: Annotated[float | None, typer.Option(help="Maximum x-axis limit.")] = None, component: Annotated[ str, typer.Option(help="PSA component to plot (e.g., rotd50).") ] = "rotd50", @@ -611,6 +636,8 @@ def plot_basins_combined( period, basins=basins_list, component=component, + xmin=xmin, + xmax=xmax, ymin=ymin, ymax=ymax, span=span, From 0556e00a925838cab0f50e96029351e4c98db63e Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 12:04:26 +1300 Subject: [PATCH 23/41] fix(response-rrup): correctly set x-lim --- visualisation/ims/response_rrup.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 81d4011..5ea5076 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -163,12 +163,12 @@ def _apply_style_and_limits( axes: np.ndarray, period: float, max_rrup: float, - xmin: float | None, - xmax: float | None, ymin: float | None, ymax: float | None, is_multi_plot: bool, ax_main: plt.Axes, + xmin: float | None = 1e-1, + xmax: float | None = None, ): """Apply final styling, labels, and limits to all axes.""" if is_multi_plot: @@ -181,11 +181,9 @@ def _apply_style_and_limits( if ymin is not None or ymax is not None: for ax in axes.flatten(): ax.set_ylim(bottom=ymin, top=ymax) - if xmin is not None or xmax is not None: - for ax in axes.flatten(): - ax.set_xlim(bottom=xmin, top=xmax) - - ax_main.set_xlim(left=1e-1, right=max_rrup) + xmax = xmax or max_rrup + for ax in axes.flatten(): + ax.set_xlim(left=xmin, right=xmax) def _plot_nshm_fit_and_settings( From 9dc9a642480a63c79c3abb227de703f9a2573a7a Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 12:08:27 +1300 Subject: [PATCH 24/41] fix(response-rrup): correct xmin/xmax passing --- visualisation/ims/response_rrup.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 5ea5076..c0e8a54 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -270,7 +270,7 @@ def plot_basin_vs_no_basin( # 3. Apply final styling _apply_style_and_limits( - fig, all_axes, period, max_rrup, xmin, xmax, ymin, ymax, False, ax + fig, all_axes, period, max_rrup, ymin, ymax, False, ax, xmin=xmin, xmax=xmax ) return fig @@ -348,7 +348,16 @@ def plot_separate_basin_subplots( # 3. Apply final styling (is_multi_plot=True) _apply_style_and_limits( - fig, all_axes, period, max_rrup, xmin, xmax, ymin, ymax, True, ax_all_stations + fig, + all_axes, + period, + max_rrup, + ymin, + ymax, + True, + ax_all_stations, + xmin=xmin, + xmax=xmax, ) return fig @@ -446,7 +455,7 @@ def plot_combined_basin_plot( # 4. Apply final styling (is_multi_plot=False) _apply_style_and_limits( - fig, all_axes, period, max_rrup, xmin, xmax, ymin, ymax, False, ax + fig, all_axes, period, max_rrup, ymin, ymax, False, ax, xmin=xmin, xmax=xmax ) return fig From 21ccb4dc6a23cd33b2ae69ea355bc0f75dd4fb9f Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 12:17:03 +1300 Subject: [PATCH 25/41] fix(response-rrup): use viridis for more basin colours --- visualisation/ims/response_rrup.py | 36 +++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index c0e8a54..60802c3 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -369,18 +369,20 @@ def plot_combined_basin_plot( period: float, basins: list[str], component: str = "rotd50", - xmin: float | None = None, - xmax: float | None = None, - ymin: float | None = None, - ymax: float | None = None, + xmin: Union[float, None] = None, + xmax: Union[float, None] = None, + ymin: Union[float, None] = None, + ymax: Union[float, None] = None, span: float = 1 / 3, ): """ Creates a single plot showing all stations and then overlays each basin. + Uses a colormap to assign distinct colors to each basin plot. """ if not basins: # Fall back to plotting only all stations if no basins are specified print("Warning: No basins specified. Plotting all stations only.") + # Proceed with the rest of the function for the 'All Stations' plot max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) @@ -409,18 +411,26 @@ def plot_combined_basin_plot( span=span, ) - # 3. Overlay Individual Basins - # We use a color cycle to differentiate the basin scatters - prop_cycle = plt.rcParams["axes.prop_cycle"] - colors = prop_cycle.by_key()["color"] + # --- FIX: Generate distinct colors using a Colormap --- + # Choose a colormap, e.g., 'viridis'. Other good choices: 'plasma', 'cividis'. + # N is the number of basins we need colors for. + N = len(basins) + cmap = cm.get_cmap("viridis") + + # Generate N colors by sampling the colormap evenly + # We sample from a range (0.1 to 0.9) to avoid the very darkest/lightest ends of the map + basin_colors = [cmap(i) for i in np.linspace(0.1, 0.9, N)] + # --- END FIX --- + # 3. Overlay Individual Basins (Scatter Plots) for i, basin in enumerate(basins): subds = _get_basin_stations(simulation_ds, basin) if len(subds.station) == 0: continue - color = colors[i % len(colors)] + # Use the distinct color generated from the colormap + color = basin_colors[i] basin_pSA = subds.pSA.sel(period=period, component=component).values ax.scatter( @@ -432,13 +442,16 @@ def plot_combined_basin_plot( label=f"{human_readable_basin_name(basin)} Stations", ) + # 4. Overlay Individual Basins (Fit Lines) + # This loop is separate to ensure the fit lines are drawn *on top* of all scatter points for i, basin in enumerate(basins): subds = _get_basin_stations(simulation_ds, basin) if len(subds.station) == 0: continue - color = colors[i % len(colors)] + # Use the distinct color generated from the colormap (same as the scatter plot) + color = basin_colors[i] basin_pSA = subds.pSA.sel(period=period, component=component).values @@ -452,8 +465,9 @@ def plot_combined_basin_plot( ) ax.legend() + # - # 4. Apply final styling (is_multi_plot=False) + # 5. Apply final styling (is_multi_plot=False) _apply_style_and_limits( fig, all_axes, period, max_rrup, ymin, ymax, False, ax, xmin=xmin, xmax=xmax ) From ee57aaf41f031ee04fb8d99d5cc6b677b286b342 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 12:18:49 +1300 Subject: [PATCH 26/41] fix(response-rrup): no Union --- visualisation/ims/response_rrup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 60802c3..87b4c04 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -369,10 +369,10 @@ def plot_combined_basin_plot( period: float, basins: list[str], component: str = "rotd50", - xmin: Union[float, None] = None, - xmax: Union[float, None] = None, - ymin: Union[float, None] = None, - ymax: Union[float, None] = None, + xmin: float | None = None, + xmax: float | None = None, + ymin: float | None = None, + ymax: float | None = None, span: float = 1 / 3, ): """ From c86884d08953af425e92d15d9ed2de7ed410be82 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 12:20:47 +1300 Subject: [PATCH 27/41] fix(response-rrup): missing colourmap import --- visualisation/ims/response_rrup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 87b4c04..f42eb01 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -9,6 +9,7 @@ import pandas as pd import typer import xarray as xr +from matplotlib import colormaps as cm from matplotlib.axes import Axes from rpy2.robjects import default_converter, globalenv, numpy2ri, r from rpy2.robjects.conversion import localconverter From a8c54c10ca39bcf5b84f134a3515e2f8f0bfc7b3 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 13:43:02 +1300 Subject: [PATCH 28/41] fix(response-rrup): use Dark2 set for colours --- visualisation/ims/response_rrup.py | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index f42eb01..3f2c910 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Annotated, NamedTuple +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt @@ -9,7 +10,6 @@ import pandas as pd import typer import xarray as xr -from matplotlib import colormaps as cm from matplotlib.axes import Axes from rpy2.robjects import default_converter, globalenv, numpy2ri, r from rpy2.robjects.conversion import localconverter @@ -415,23 +415,17 @@ def plot_combined_basin_plot( # --- FIX: Generate distinct colors using a Colormap --- # Choose a colormap, e.g., 'viridis'. Other good choices: 'plasma', 'cividis'. # N is the number of basins we need colors for. - N = len(basins) - cmap = cm.get_cmap("viridis") - # Generate N colors by sampling the colormap evenly - # We sample from a range (0.1 to 0.9) to avoid the very darkest/lightest ends of the map - basin_colors = [cmap(i) for i in np.linspace(0.1, 0.9, N)] # --- END FIX --- - + basin_colours = mpl.color_sequences["Dark2"] # 3. Overlay Individual Basins (Scatter Plots) - for i, basin in enumerate(basins): + for basin, colour in zip(basins, basin_colours): subds = _get_basin_stations(simulation_ds, basin) if len(subds.station) == 0: continue # Use the distinct color generated from the colormap - color = basin_colors[i] basin_pSA = subds.pSA.sel(period=period, component=component).values ax.scatter( @@ -439,29 +433,16 @@ def plot_combined_basin_plot( basin_pSA, alpha=0.7, s=10, - color=color, + color=colour, label=f"{human_readable_basin_name(basin)} Stations", ) - # 4. Overlay Individual Basins (Fit Lines) - # This loop is separate to ensure the fit lines are drawn *on top* of all scatter points - for i, basin in enumerate(basins): - subds = _get_basin_stations(simulation_ds, basin) - - if len(subds.station) == 0: - continue - - # Use the distinct color generated from the colormap (same as the scatter plot) - color = basin_colors[i] - - basin_pSA = subds.pSA.sel(period=period, component=component).values - plot_simulation_fit( ax, subds.rrup.values, basin_pSA, label=None, - color=color, + color=colour, span=1, # for each basin only show smooth line ) From 64ff9996b26e20ebd1a178f3a8b0b3f66c84dbe0 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 13:46:06 +1300 Subject: [PATCH 29/41] fix(response-rrup): make basins stand out more --- visualisation/ims/response_rrup.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 3f2c910..e9044ff 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -135,13 +135,17 @@ def plot_simulation_fit( label: str | None, color: str, span: float = 1 / 3, + show_bands: bool = True, ) -> None: """Plot LOESS fit for a subset of simulation data.""" rrup_out = np.linspace(rrup.min(), rrup.max(), num=100) fit, ci_low, ci_high = fit_loess_r( np.log(psa), np.log(rrup), np.log(rrup_out), span=span ) - ax.fill_between(rrup_out, np.exp(ci_low), np.exp(ci_high), alpha=0.3, color=color) + if show_bands: + ax.fill_between( + rrup_out, np.exp(ci_low), np.exp(ci_high), alpha=0.3, color=color + ) ax.plot(rrup_out, np.exp(fit), c=color, label=label) @@ -442,8 +446,9 @@ def plot_combined_basin_plot( subds.rrup.values, basin_pSA, label=None, - color=colour, + color='red', span=1, # for each basin only show smooth line + show_bands=False ) ax.legend() From 4efc91fd937e43891f3acfd070b7c7bb16be004d Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 13:49:42 +1300 Subject: [PATCH 30/41] fix(response-rrup): fix legend labelling --- visualisation/ims/response_rrup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index e9044ff..71da3dc 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -405,13 +405,13 @@ def plot_combined_basin_plot( c="k", alpha=0.1, s=10, - label="All Simulated Stations", + label="Simulation", ) plot_simulation_fit( ax, simulation_ds.rrup.values, all_pSA, - label="Overall Fit", + label="Simulation fit", color="tab:gray", span=span, ) @@ -438,7 +438,7 @@ def plot_combined_basin_plot( alpha=0.7, s=10, color=colour, - label=f"{human_readable_basin_name(basin)} Stations", + label=f"{human_readable_basin_name(basin)}", ) plot_simulation_fit( @@ -446,9 +446,9 @@ def plot_combined_basin_plot( subds.rrup.values, basin_pSA, label=None, - color='red', + color="red", span=1, # for each basin only show smooth line - show_bands=False + show_bands=False, ) ax.legend() From 26146138f0e4e0b6d265ede35b5933d7d29655ab Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Thu, 11 Dec 2025 13:51:54 +1300 Subject: [PATCH 31/41] fix(response-rrup): remove simulation fit --- visualisation/ims/response_rrup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 71da3dc..b48fd4b 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -411,7 +411,7 @@ def plot_combined_basin_plot( ax, simulation_ds.rrup.values, all_pSA, - label="Simulation fit", + label=None, color="tab:gray", span=span, ) From 65689f4950ece807f4ee4b72d49ea7f82fb20620 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Fri, 12 Dec 2025 10:16:14 +1300 Subject: [PATCH 32/41] refactor(response-rrup): use NSHM subtrend line --- visualisation/ims/response_rrup.py | 107 ++++++++++++++--------------- 1 file changed, 52 insertions(+), 55 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index b48fd4b..77cc3e0 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -11,6 +11,7 @@ import typer import xarray as xr from matplotlib.axes import Axes +from matplotlib.figure import Figure from rpy2.robjects import default_converter, globalenv, numpy2ri, r from rpy2.robjects.conversion import localconverter @@ -48,7 +49,7 @@ def nshm2022_logic_tree_prediction( return psa_results -class LowessFit(NamedTuple): +class ConfidenceInterval(NamedTuple): mean: npt.NDArray[np.floating] std_low: npt.NDArray[np.floating] std_high: npt.NDArray[np.floating] @@ -59,7 +60,7 @@ def fit_loess_r( x: npt.NDArray[np.floating], x_out: npt.NDArray[np.floating], **kwargs, -) -> LowessFit: +) -> ConfidenceInterval: """ Fit LOESS using R and return fitted values and prediction intervals. """ @@ -93,7 +94,7 @@ def fit_loess_r( std_low = fit_vals - residual_se std_high = fit_vals + residual_se - return LowessFit(fit_vals, std_low, std_high) + return ConfidenceInterval(fit_vals, std_low, std_high) def plot_nshm_fit( @@ -104,7 +105,7 @@ def plot_nshm_fit( rrup: npt.NDArray[np.floating], color: str | None = None, label: str | None = None, -) -> None: +) -> ConfidenceInterval: source_config = SourceConfig.read_from_realisation(realisation_ffp) magnitudes = Magnitudes.read_from_realisation(realisation_ffp) rupture_prop = RupturePropagationConfig.read_from_realisation(realisation_ffp) @@ -121,11 +122,12 @@ def plot_nshm_fit( ) mean = logic_tree_results[f"pSA_{period_str}_mean"] std = logic_tree_results[f"pSA_{period_str}_std_Total"] - - ax.fill_between( - rrup, np.exp(mean - std), np.exp(mean + std), alpha=0.3, color=color - ) - ax.plot(rrup, np.exp(mean), c=color, label=label) + fit = np.exp(mean) + ci_low = np.exp(mean - std) + ci_high = np.exp(mean + std) + ax.fill_between(rrup, ci_low, ci_high, alpha=0.3, color=color) + ax.plot(rrup, fit, c=color, label=label) + return ConfidenceInterval(mean, ci_low, ci_high) def plot_simulation_fit( @@ -136,7 +138,7 @@ def plot_simulation_fit( color: str, span: float = 1 / 3, show_bands: bool = True, -) -> None: +) -> ConfidenceInterval: """Plot LOESS fit for a subset of simulation data.""" rrup_out = np.linspace(rrup.min(), rrup.max(), num=100) fit, ci_low, ci_high = fit_loess_r( @@ -147,6 +149,7 @@ def plot_simulation_fit( rrup_out, np.exp(ci_low), np.exp(ci_high), alpha=0.3, color=color ) ax.plot(rrup_out, np.exp(fit), c=color, label=label) + return ConfidenceInterval(fit, ci_low, ci_high) def human_readable_basin_name(basin_name: str) -> str: @@ -164,7 +167,7 @@ def _get_plotting_params(simulation_ds: xr.Dataset): def _apply_style_and_limits( - fig, + fig: Figure, axes: np.ndarray, period: float, max_rrup: float, @@ -191,25 +194,8 @@ def _apply_style_and_limits( ax.set_xlim(left=xmin, right=xmax) -def _plot_nshm_fit_and_settings( - ax: plt.Axes, - realisation_ffp: Path, - data_ds: xr.Dataset, - period: float, - nshm_rrup: np.ndarray, - label: str | None = "NSHM logic tree prediction", -): - """Plots the common NSHM fit and initial axis settings (log scale, grid).""" +def _plot_settings(ax: Axes) -> None: ax.grid(True, which="both", axis="both", lw=0.3) - plot_nshm_fit( - ax, - realisation_ffp, - data_ds, - period, - nshm_rrup, - color="tab:blue", - label=label, - ) ax.set_yscale("log") ax.set_xscale("log") @@ -239,9 +225,16 @@ def plot_basin_vs_no_basin( fig, ax = plt.subplots(constrained_layout=True) all_axes = np.array([ax]) # For style helper consistency - + _plot_settings(ax) # 1. Plot NSHM fit and set axis scales/grid - _plot_nshm_fit_and_settings(ax, realisation_ffp, simulation_ds, period, nshm_rrup) + plot_nshm_fit( + ax, + realisation_ffp, + simulation_ds, + period, + nshm_rrup, + color="tab:blue", + ) # 2. Split and plot simulation data basin_ds = simulation_ds.where(simulation_ds.basin != "", drop=True) @@ -250,6 +243,7 @@ def plot_basin_vs_no_basin( # Basin stations basin_pSA = basin_ds.pSA.sel(period=period, component=component).values ax.scatter(basin_ds.rrup, basin_pSA, c="tab:red", alpha=0.3, s=5) + plot_simulation_fit( ax, basin_ds.rrup.values, @@ -315,11 +309,9 @@ def plot_separate_basin_subplots( ) all_axes = axes_2d.flatten() ax_all_stations = all_axes[0] - + _plot_settings(ax_all_stations) # --- A. Plot All Stations (Primary Plot) --- - _plot_nshm_fit_and_settings( - ax_all_stations, realisation_ffp, simulation_ds, period, nshm_rrup - ) + plot_nshm_fit(ax_all_stations, realisation_ffp, simulation_ds, period, nshm_rrup) # Plot all simulation stations together all_pSA = simulation_ds.pSA.sel(period=period, component=component).values @@ -335,15 +327,15 @@ def plot_separate_basin_subplots( ax_all_stations.legend() # --- B. Plot Individual Basins --- - for i, basin in enumerate(basins): - ax_basin = all_axes[i + 1] + for ax_basin, basin in zip(all_axes[1:], basins): + _plot_settings(ax_basin) subds = _get_basin_stations(simulation_ds, basin) if len(subds.station) == 0: ax_basin.set_title(f"Basin: {human_readable_basin_name(basin)} (No data)") continue - _plot_nshm_fit_and_settings( + plot_nshm_fit( ax_basin, realisation_ffp, subds, period, nshm_rrup, label=None ) # No legend for NSHM here @@ -395,7 +387,10 @@ def plot_combined_basin_plot( all_axes = np.array([ax]) # For style helper consistency # 1. Plot NSHM fit and set axis scales/grid - _plot_nshm_fit_and_settings(ax, realisation_ffp, simulation_ds, period, nshm_rrup) + _plot_settings(ax) + fit, ci_low, ci_high = plot_nshm_fit( + ax, realisation_ffp, simulation_ds, period, nshm_rrup + ) # 2. Plot All Stations (Base Layer) all_pSA = simulation_ds.pSA.sel(period=period, component=component).values @@ -416,13 +411,11 @@ def plot_combined_basin_plot( span=span, ) - # --- FIX: Generate distinct colors using a Colormap --- - # Choose a colormap, e.g., 'viridis'. Other good choices: 'plasma', 'cividis'. - # N is the number of basins we need colors for. - - # --- END FIX --- basin_colours = mpl.color_sequences["Dark2"] # 3. Overlay Individual Basins (Scatter Plots) + log_fit = np.log(fit) + log_nshm_rrup = np.log(nshm_rrup) + for basin, colour in zip(basins, basin_colours): subds = _get_basin_stations(simulation_ds, basin) @@ -432,6 +425,20 @@ def plot_combined_basin_plot( # Use the distinct color generated from the colormap basin_pSA = subds.pSA.sel(period=period, component=component).values + + basin_misfit = np.log(basin_pSA) - np.interp( + basin_pSA.rrup.values, log_nshm_rrup, log_fit + ) + mean_misfit = np.mean(basin_misfit) + + log_basin_rrup_min = np.log(subds.rrup.min().item()) + log_basin_rrup_max = np.log(subds.rrup.max().item()) + log_basin_misfit_rrup = np.linspace( + log_basin_rrup_min, log_basin_rrup_max, num=100 + ) + nshm_subline = ( + np.interp(log_basin_misfit_rrup, log_nshm_rrup, log_fit) + mean_misfit + ) ax.scatter( subds.rrup, basin_pSA, @@ -440,19 +447,9 @@ def plot_combined_basin_plot( color=colour, label=f"{human_readable_basin_name(basin)}", ) - - plot_simulation_fit( - ax, - subds.rrup.values, - basin_pSA, - label=None, - color="red", - span=1, # for each basin only show smooth line - show_bands=False, - ) + ax.plot(np.exp(log_basin_misfit_rrup), np.exp(nshm_subline), color=colour) ax.legend() - # # 5. Apply final styling (is_multi_plot=False) _apply_style_and_limits( From 744a842c73373b539bc8a62012054a548666e59b Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Fri, 12 Dec 2025 10:19:20 +1300 Subject: [PATCH 33/41] fix(response-rrup): correct rrup references --- visualisation/ims/response_rrup.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index 77cc3e0..e3fe74b 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -425,14 +425,12 @@ def plot_combined_basin_plot( # Use the distinct color generated from the colormap basin_pSA = subds.pSA.sel(period=period, component=component).values - - basin_misfit = np.log(basin_pSA) - np.interp( - basin_pSA.rrup.values, log_nshm_rrup, log_fit - ) + log_rrup = np.log(subds.rrup.values) + basin_misfit = np.log(basin_pSA) - np.interp(log_rrup, log_nshm_rrup, log_fit) mean_misfit = np.mean(basin_misfit) - log_basin_rrup_min = np.log(subds.rrup.min().item()) - log_basin_rrup_max = np.log(subds.rrup.max().item()) + log_basin_rrup_min = log_rrup.min() + log_basin_rrup_max = log_rrup.max() log_basin_misfit_rrup = np.linspace( log_basin_rrup_min, log_basin_rrup_max, num=100 ) From a60adf32592441fa5bb58c76a39d2e2a99ab47e7 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Mon, 15 Dec 2025 09:20:18 +1300 Subject: [PATCH 34/41] feat: add gmm comparison plots --- pyproject.toml | 1 + visualisation/ims/gmm_comparison.py | 194 ++++++++++++++++++++ visualisation/ims/response_rrup.py | 86 +-------- visualisation/realisation.py | 22 ++- visualisation/utils.py | 263 ++++++++++++++++++++++++++-- 5 files changed, 464 insertions(+), 102 deletions(-) create mode 100644 visualisation/ims/gmm_comparison.py diff --git a/pyproject.toml b/pyproject.toml index 0f8964a..dbb26ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ plot-stoch = "visualisation.sources.plot_stoch:app" plot-ts = "visualisation.plot_ts:app" plot-response-rrup = "visualisation.ims.response_rrup:app" plot-response-spectra = "visualisation.waveforms.plot_response_spectra:app" +plot-gmm-comparison = "visualisation.ims.gmm_comparison:app" [tool.setuptools.package-dir] visualisation = "visualisation" diff --git a/visualisation/ims/gmm_comparison.py b/visualisation/ims/gmm_comparison.py new file mode 100644 index 0000000..2fdb383 --- /dev/null +++ b/visualisation/ims/gmm_comparison.py @@ -0,0 +1,194 @@ +from pathlib import Path +from typing import Annotated + +import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd +import pygmt +import shapely +import typer +import xarray as xr + +from pygmt_helper import plotting +from qcore import cli +from visualisation import realisation, utils +from workflow.realisations import ( + DomainParameters, + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) + +app = typer.Typer() + + +def plot_diff( + fig: pygmt.Figure, + dset: xr.Dataset, + gmm_psa_value: pd.Series, + period: float, + cmap: str, + cmap_max: float | None, + cmap_min: float | None, + ticks: int, + reverse: bool, +) -> None: + intensity = dset["pSA"].sel(period=period, component="rotd50") + pgv_value = intensity.to_series() + breakpoint() + diff = np.log(pgv_value) - gmm_psa_value + cmap_min = cmap_min or diff.min() + cmap_max = cmap_max or diff.max() + + cmap_limits = utils.range_for(cmap_min, cmap_max, ticks) + + latitude = intensity.latitude.to_series() + longitude = intensity.longitude.to_series() + df = pd.DataFrame({"lat": latitude, "lon": longitude, "value": diff}) + + grid: xr.DataArray = plotting.create_grid( + df, + "value", + grid_spacing="1000e/1000e", + region=tuple(fig.region), + set_water_to_nan=True, + ) + + plotting.plot_grid( + fig, + grid, + cmap, + cmap_limits, + ("red", "blue"), + reverse_cmap=reverse, + transparency=40, + plot_contours=False, + ) + + +def find_region(domain: DomainParameters) -> tuple[float, float, float, float]: + """Find an appropriate domain,""" + nz_region = shapely.box(166.0, -48.0, 178.5, -34.0) + region = shapely.union( + utils.polygon_nztm_to_pygmt(domain.domain.polygon), nz_region + ) + (min_x, min_y, max_x, max_y) = shapely.bounds(region) + return (min_x, max_x, min_y, max_y) + + +def generate_basemap(region: tuple[float, float, float, float]) -> pygmt.Figure: + fig: pygmt.Figure = plotting.gen_region_fig( + title=None, + region=region, + plot_kwargs=dict( + plot_kwargs=["af", "xaf+Longitude", "yaf+Latitude"], + water_color="white", + topo_cmap_min=-900, + topo_cmap_max=3100, + ), + plot_highways=False, + config_options=dict( + MAP_FRAME_TYPE="plain", + FORMAT_GEO_MAP="ddd.xx", + MAP_FRAME_PEN="thinner,black", + ), + ) + assert isinstance(fig, pygmt.Figure) + return fig + + +@cli.from_docstring(app) +def main( + realisation_ffp: Annotated[Path, typer.Argument()], + dataset: Annotated[ + Path, + typer.Argument(), + ], + period: Annotated[ + float, + typer.Argument(), + ], + output: Annotated[ + Path, + typer.Argument(), + ], + cmap: Annotated[ + str, + typer.Option(), + ] = "polar", + reverse: Annotated[ + bool, + typer.Option(is_flag=True), + ] = False, + cmap_min: Annotated[ + float | None, + typer.Option(), + ] = None, + cmap_max: Annotated[ + float | None, + typer.Option(), + ] = None, + ticks: Annotated[ + int, + typer.Option(), + ] = 10, +) -> None: + """Compare simulation results to predictions from the NSHM2022 logic tree. + + Parameters + ---------- + realisation_ffp : Path + Path to realisation. + dataset : Path + Path to xarray intensity measure dataset. + period : float + pSA period to compare against. + output : Path + The path to write the figure out to. + cmap_min : float + Colourmap minimum + cmap_max : float + Colourmap maximum + ticks : int + Number of ticks in discrete colourmap. + output : Path + Output path. + cmap : str + Colourmap to plot log residuals. Should be divering. + reverse : bool + If true, reverse the colourmap. Defaults to false. + """ + + dset = xr.open_dataset(dataset, engine="h5netcdf") + domain = DomainParameters.read_from_realisation(realisation_ffp) + source_config = SourceConfig.read_from_realisation(realisation_ffp) + magnitudes = Magnitudes.read_from_realisation(realisation_ffp) + rakes = Rakes.read_from_realisation(realisation_ffp) + rupture_propagation_config = RupturePropagationConfig.read_from_realisation( + realisation_ffp + ) + + region = find_region(domain) + + fig = generate_basemap(region) + gmm_psa_value = utils.get_gmm_prediction( + dset, period, source_config, magnitudes, rakes, rupture_propagation_config + ) + + plot_diff( + fig, + dset, + gmm_psa_value, + period, + cmap, + cmap_max, + cmap_min, + ticks, + reverse, + ) + realisation.plot_domain(fig, domain) + realisation.plot_sources(fig, source_config) + + fig.savefig(output) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index e3fe74b..eca8ec7 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -12,11 +12,9 @@ import xarray as xr from matplotlib.axes import Axes from matplotlib.figure import Figure -from rpy2.robjects import default_converter, globalenv, numpy2ri, r -from rpy2.robjects.conversion import localconverter from visualisation import utils -from visualisation.utils import RuptureContext, SiteProperties +from visualisation.utils import ConfidenceInterval, RuptureContext, SiteProperties from workflow.realisations import ( Magnitudes, Rakes, @@ -27,76 +25,6 @@ app = typer.Typer() -def nshm2022_logic_tree_prediction( - rupture_context: RuptureContext, - site_properties: SiteProperties, - period: float, - rrup: npt.NDArray[np.floating], -) -> pd.DataFrame: - tect_type = oqw.constants.TectType.ACTIVE_SHALLOW - gmm_lt = oqw.constants.GMMLogicTree.NSHM2022 - rupture_df = pd.DataFrame( - {"rrup": rrup, "vs30measured": False, **rupture_context, **site_properties} - ) - for dist_metric in ["rjb", "rx", "ry"]: - rupture_df[dist_metric] = rupture_df["rrup"] - - psa_results = oqw.run_gmm_logic_tree( - gmm_lt, tect_type, rupture_df, "pSA", periods=[period] - ) - assert isinstance(psa_results, pd.DataFrame) - psa_results["rrup"] = rupture_df["rrup"] - return psa_results - - -class ConfidenceInterval(NamedTuple): - mean: npt.NDArray[np.floating] - std_low: npt.NDArray[np.floating] - std_high: npt.NDArray[np.floating] - - -def fit_loess_r( - y: npt.NDArray[np.floating], - x: npt.NDArray[np.floating], - x_out: npt.NDArray[np.floating], - **kwargs, -) -> ConfidenceInterval: - """ - Fit LOESS using R and return fitted values and prediction intervals. - """ - loess_args = ", ".join(f"{k}={v}" for k, v in kwargs.items()) - loess_call = f"loess(y ~ x, {loess_args})" if loess_args else "loess(y ~ x)" - - with localconverter(default_converter + numpy2ri.converter): - globalenv["x"] = x - globalenv["y"] = y - globalenv["x_out"] = x_out - - r(f"fit <- {loess_call}") - r("newdat <- data.frame(x=x_out)") - r("pred <- predict(fit, newdata=newdat, se=TRUE)") - r("residual_se <- fit$s") - - fit_vals = r("pred$fit") - if not isinstance(fit_vals, np.ndarray): - raise ValueError( - f"Residual stderr evaluation failed, expected float found: {fit_vals=}" - ) - - residual_se_eval = r("residual_se") - if isinstance(residual_se_eval, np.ndarray): - residual_se = residual_se_eval.item() - else: - raise ValueError( - f"Residual stderr evaluation failed, expected float found: {residual_se_eval=}" - ) - - std_low = fit_vals - residual_se - std_high = fit_vals + residual_se - - return ConfidenceInterval(fit_vals, std_low, std_high) - - def plot_nshm_fit( ax: Axes, realisation_ffp: Path, @@ -114,7 +42,7 @@ def plot_nshm_fit( source_config, magnitudes, rakes, rupture_prop ) site_properties = utils.compute_site_properties(site_ds.vs30.values) - logic_tree_results = nshm2022_logic_tree_prediction( + logic_tree_results = utils.nshm2022_logic_tree_prediction( rupture_context, site_properties, period, rrup ) period_str = ( @@ -127,7 +55,7 @@ def plot_nshm_fit( ci_high = np.exp(mean + std) ax.fill_between(rrup, ci_low, ci_high, alpha=0.3, color=color) ax.plot(rrup, fit, c=color, label=label) - return ConfidenceInterval(mean, ci_low, ci_high) + return ConfidenceInterval(fit, ci_low, ci_high) def plot_simulation_fit( @@ -141,7 +69,7 @@ def plot_simulation_fit( ) -> ConfidenceInterval: """Plot LOESS fit for a subset of simulation data.""" rrup_out = np.linspace(rrup.min(), rrup.max(), num=100) - fit, ci_low, ci_high = fit_loess_r( + fit, ci_low, ci_high = utils.fit_loess_r( np.log(psa), np.log(rrup), np.log(rrup_out), span=span ) if show_bands: @@ -445,7 +373,11 @@ def plot_combined_basin_plot( color=colour, label=f"{human_readable_basin_name(basin)}", ) - ax.plot(np.exp(log_basin_misfit_rrup), np.exp(nshm_subline), color=colour) + ax.plot( + np.exp(log_basin_misfit_rrup), + np.exp(nshm_subline), + color=utils.adjust_value(colour, 0.8), + ) ax.legend() diff --git a/visualisation/realisation.py b/visualisation/realisation.py index eb3d6cb..1d070fd 100644 --- a/visualisation/realisation.py +++ b/visualisation/realisation.py @@ -17,6 +17,7 @@ from pygmt_helper import plotting from qcore import cli +from source_modelling.sources import Fault, Plane from visualisation import utils from workflow.realisations import ( DomainParameters, @@ -78,9 +79,7 @@ def plot_stations( ) -def plot_sources( - fig: pygmt.Figure, source_config: SourceConfig, **kwargs: dict[str, Any] -) -> None: +def plot_sources(fig: pygmt.Figure, source_config: SourceConfig, **kwargs: Any) -> None: """Plot the sources on the figure. Parameters @@ -91,7 +90,8 @@ def plot_sources( The source configuration to plot. **kwargs : dict Additional keyword arguments to pass to the plotting function. If empty, the default is - - `pen="0.3p,black"` (polygon border colour) + - `pen="0.3p,black,--"` (polygon border colour) + - trace pen is found by taking the pen and stripping the "--" Examples -------- @@ -102,16 +102,24 @@ def plot_sources( >>> plot_sources(fig, source_config) >>> source_config.show() """ - kwargs = {"pen": "0.3p,black", **(kwargs or {})} + pen = kwargs.get("pen", "0.3p,black,--") + assert isinstance(pen, str) + trace_pen = pen.removesuffix(",--") + interior_kwargs = {"pen": "0.3p,black", **(kwargs or {})} for source in source_config.source_geometries.values(): - utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(source.geometry), **kwargs) + utils.plot_polygon( + fig, utils.polygon_nztm_to_pygmt(source.geometry), **interior_kwargs + ) + if isinstance(source, Plane | Fault): + trace = shapely.LineString(source.bounds[:2]) + utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(trace), pen=trace_pen) def plot_domain( fig: pygmt.Figure, domain_parameters: DomainParameters, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> None: """Plot the domain on a figure. diff --git a/visualisation/utils.py b/visualisation/utils.py index 602db55..e985d05 100644 --- a/visualisation/utils.py +++ b/visualisation/utils.py @@ -1,16 +1,22 @@ """Utility functions common to many plotting scripts.""" -from typing import Any, Literal, Optional, TypedDict, Unpack +from collections.abc import Sequence +from typing import Any, Literal, NamedTuple, Optional, TypedDict, Unpack import numpy as np import numpy.typing as npt import oq_wrapper as oqw +import pandas as pd import pygmt import scipy as sp import shapely +import xarray as xr +from matplotlib import colors from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.figure import Figure +from rpy2.robjects import default_converter, globalenv, numpy2ri, r +from rpy2.robjects.conversion import localconverter from qcore import coordinates from source_modelling import moment @@ -161,7 +167,7 @@ def _point_on_polygon(t: float, polygon: shapely.Polygon) -> shapely.Point: def _hausdorff_maximisation( polygon: shapely.Polygon, other_geom: shapely.Polygon -) -> shapely.Point: +) -> tuple[shapely.Point, float]: """Finds the point on polygon maximizing the distance to other_geom. Parameters @@ -175,6 +181,8 @@ def _hausdorff_maximisation( ------- shapely.Point Point on the polygon boundary maximizing the distance to other_geom. + float + The distance from the point to other_geom See Also -------- @@ -183,11 +191,11 @@ def _hausdorff_maximisation( def objective(t: float) -> float: # numpydoc ignore=GL08 point = _point_on_polygon(t, polygon) - return -point.distance(other_geom.exterior) # Negative because we maximize + return -point.distance(other_geom.exterior) # Negative because we maximise result = sp.optimize.minimize_scalar(objective, bounds=(0, 1), method="bounded") if result.success: - return _point_on_polygon(result.x, polygon), -result.fun + return _point_on_polygon(float(result.x), polygon), -result.fun else: raise RuntimeError("Optimisation failed") @@ -347,7 +355,12 @@ def balanced_subplot_grid( squeeze: bool = False, clear: bool = False, **kwargs: Unpack[SubplotsKwargs], -) -> tuple[Figure, npt.NDArray[Axes]]: +) -> tuple[ + Figure, + Sequence[Axes], # Although we are really returning a numpy array, numpy does + # not support generic Axes objects in NDArray. + # numpy/numpy#24738 +]: # This has more columns than rows, i.e. wide height = np.sqrt(n_subplots / aspect) rows = int(np.ceil(height)) @@ -381,9 +394,14 @@ class RuptureContext(TypedDict): class SiteProperties(TypedDict): - vs30: float - z1pt0: float - z2pt5: float + vs30measured: bool | npt.NDArray[np.bool_] + vs30: float | npt.NDArray[np.floating] + z1pt0: float | npt.NDArray[np.floating] + z2pt5: float | npt.NDArray[np.floating] + rrup: float | npt.NDArray[np.floating] + rjb: float | npt.NDArray[np.floating] + rx: float | npt.NDArray[np.floating] + ry: float | npt.NDArray[np.floating] def circmean( @@ -397,6 +415,45 @@ def circmean( return float(argument) +def compute_site_properties(sites: xr.Dataset) -> SiteProperties: + vs30 = sites.vs30.values + z1pt0 = oqw.estimations.chiou_young_08_calc_z1p0(vs30) + z2pt5 = oqw.estimations.chiou_young_08_calc_z2p5(vs30) + rrup = sites.rrup.values + rjb = sites.rjb.values + rx = rjb + ry = rjb + vs30measured = False + return SiteProperties( + vs30measured=vs30measured, + vs30=vs30, + z1pt0=z1pt0, + z2pt5=z2pt5, + rrup=rrup, + rjb=rjb, + rx=rx, + ry=ry, + ) + + +def get_gmm_prediction( + sites: xr.Dataset, + period: float, + source_config: SourceConfig, + magnitudes: Magnitudes, + rakes: Rakes, + rupture_propagation: RupturePropagationConfig, +) -> pd.Series: + site_properties = compute_site_properties(sites) + rupture_context = compute_rupture_context( + source_config, magnitudes, rakes, rupture_propagation + ) + breakpoint() + gmm_df = nshm2022_logic_tree_prediction(rupture_context, site_properties, period) + gmm_psa_value = gmm_df.loc[:, gmm_df.columns.str.endswith("_mean")].squeeze() + return gmm_psa_value + + def compute_rupture_context( source_config: SourceConfig, magnitudes_config: Magnitudes, @@ -445,20 +502,190 @@ def compute_rupture_context( ) -def compute_site_properties( - site_vs30: npt.NDArray[np.floating] | np.floating, -) -> SiteProperties: +def mean_vs30(site_vs30: npt.NDArray[np.floating]) -> float: # Calculate geometric mean of site vs30 using the exponential-log form: # exp(1/n sum vs30) # This is as opposed to straight-forward calculation # product(vs30) ^ (1/n) # Which is numerically unstable for a large number of stations due to - # floating-point arithmetic overflow and inprecision at large values + # floating-point arithmetic overflow and imprecision at large values # obtained by multiplication. - if isinstance(site_vs30, np.ndarray): - vs30 = np.exp(1 / len(site_vs30) * np.sum(np.log(site_vs30))) + + return np.exp(1 / len(site_vs30) * np.sum(np.log(site_vs30))) + + +def adjust_value(colour: npt.ArrayLike, gamma: float) -> npt.NDArray[np.float64]: + """Adjust the brightness of an RGB colour. + + Parameters + ---------- + colour : npt.ArrayLike + Colour to transform (in RGB format). + gamma : float + The brightness to adjust by. Adjustment is multiplicative, so + ``gamma=1`` is equivalent to no change. + + Returns + ------- + npt.NDArray[np.float64] + A brightness adjusted equivalent of `colour` with no change in + hue or saturation. + """ + colour = np.asarray(colour) + # Naive colour brightness adjustment would simply multiply colour + # by gamma. However, this also changes the hue of the colour, + # resulting in visually incorrect results. HSV is designed for + # this manipulation. + hsv = colors.rgb_to_hsv(colour) + # HSV colour scale represents every colour as a combination of three components: + # 1. (H)ue, the quality of the colour (blue, red, green, magenta, etc). + # 2. (S)aturation, how intense that colour is at a fixed + # brightness (e.g. black has zero saturation and looks the same + # regardless of brightness). + # 3. (V)alue, the brightness of the colour + # + # We just want to adjust the brightness. To do this we adjust the value component. + hsv[-1] *= gamma + return colors.hsv_to_rgb(hsv) + + +def nice_num(x: float, round: bool) -> float: + """Find an equivalent "nice number" for `x`. + + A nice number is a number that a power-of-ten multiple of 1, 2, or + 5. See: https://stackoverflow.com/a/16363437 + + Parameters + ---------- + x : float + The number to find a nice number for. + round : bool + If true, round toward the nearest nice number. Otherwise, + find the next largest. + + + Returns + ------- + float + The nearest or the next largest nice number. + """ + exponent = np.floor(np.log10(x)) + fraction = x / (10**exponent) + if round: + if fraction < 1.5: + nice_fraction = 1 + elif fraction < 3: + nice_fraction = 2 + elif fraction < 7: + nice_fraction = 5 + else: + nice_fraction = 10 else: - vs30 = site_vs30 - z1pt0 = oqw.estimations.chiou_young_14_calc_z1p0(vs30) - z2pt5 = oqw.estimations.campbell_bozorgina_14_calc_z2p5(vs30) - return SiteProperties(vs30=vs30, z1pt0=z1pt0, z2pt5=z2pt5) + if fraction <= 1: + nice_fraction = 1 + elif fraction <= 2: + nice_fraction = 2 + elif fraction <= 5: + nice_fraction = 5 + else: + nice_fraction = 10 + return nice_fraction * 10**exponent + + +class PlotDomain(NamedTuple): + """Named tuple representing plot domain with ticks.""" + + low: float + high: float + spacing: float + + +def range_for(low: float, high: float, max_ticks: int) -> tuple[float, float, float]: + """Given a plotting range and a fixed number of ticks, return its "nicest" representation. + + See: https://stackoverflow.com/a/16363437 + + Parameters + ---------- + low : float + The lower bound of the plotting range. + high : float + The upper bound of the plotting range. + max_ticks : int + An upper bound on the number of ticks desired. The number of + ticks returned is usually *less* than this value. + + Returns + ------- + PlotDomain + The lower bound, upper bound and tick spacing corresponding to + a "nice" representation of this domain. + """ + range = nice_num(high - low, False) + tick_spacing = nice_num(range / (max_ticks - 1), True) + nice_min = np.floor(low / tick_spacing) * tick_spacing + nice_max = np.ceil(high / tick_spacing) * tick_spacing + return PlotDomain(nice_min, nice_max, tick_spacing) + + +def nshm2022_logic_tree_prediction( + rupture_context: RuptureContext, + site_properties: SiteProperties, + period: float, +) -> pd.DataFrame: + tect_type = oqw.constants.TectType.ACTIVE_SHALLOW + gmm_lt = oqw.constants.GMMLogicTree.NSHM2022 + rupture_df = pd.DataFrame( + {"vs30measured": False, **rupture_context, **site_properties} + ) + psa_results = oqw.run_gmm_logic_tree( + gmm_lt, tect_type, rupture_df, "pSA", periods=[period] + ) + assert isinstance(psa_results, pd.DataFrame) + for site_property in site_properties: + psa_results[site_property] = rupture_df[site_property] + return psa_results + + +class ConfidenceInterval(NamedTuple): + mean: npt.NDArray[np.floating] + std_low: npt.NDArray[np.floating] + std_high: npt.NDArray[np.floating] + + +def fit_loess_r( + y: npt.NDArray[np.floating], + x: npt.NDArray[np.floating], + x_out: npt.NDArray[np.floating], + **kwargs, +) -> ConfidenceInterval: + """ + Fit LOESS using R and return fitted values and prediction intervals. + """ + loess_args = ", ".join(f"{k}={v}" for k, v in kwargs.items()) + loess_call = f"loess(y ~ x, {loess_args})" if loess_args else "loess(y ~ x)" + + with localconverter(default_converter + numpy2ri.converter): + globalenv["x"] = x + globalenv["y"] = y + globalenv["x_out"] = x_out + + r(f"fit <- {loess_call}") + r("newdat <- data.frame(x=x_out)") + r("pred <- predict(fit, newdata=newdat, se=TRUE)") + r("residual_se <- fit$s") + + fit_vals = np.asarray(r("pred$fit")) + + residual_se_eval = r("residual_se") + if isinstance(residual_se_eval, np.ndarray) and len(residual_se_eval) == 1: + residual_se = float(residual_se_eval.item()) # type: ignore[invalid-argument-type] + else: + raise ValueError( + f"Residual stderr evaluation failed, expected float found: {residual_se_eval=}" + ) + + std_low = fit_vals - residual_se + std_high = fit_vals + residual_se + + return ConfidenceInterval(fit_vals, std_low, std_high) From f1b835f84f298c8c80f99d7e98706f53f4ec2130 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:26:10 +1300 Subject: [PATCH 35/41] feat(plot-response-spectra): add scenario station comparison plots --- .../waveforms/plot_response_spectra.py | 81 ++++++++++++------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/visualisation/waveforms/plot_response_spectra.py b/visualisation/waveforms/plot_response_spectra.py index 4a25bb4..9a8921d 100644 --- a/visualisation/waveforms/plot_response_spectra.py +++ b/visualisation/waveforms/plot_response_spectra.py @@ -15,6 +15,18 @@ app = typer.Typer() +def setup_plot_styling( + ax: Axes, component: str, ymax: float | None, ymin: float | None +) -> None: + ax.set_ylabel(f"pSA [{component}, g]") + ax.set_xlabel("Period [s]") + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(visible=True, which="both", axis="both", lw=0.3) + if ymin is not None or ymax is not None: + ax.set_ylim(bottom=ymin, top=ymax) + + def plot_spectra( fig: Figure, ax: Axes, @@ -23,6 +35,7 @@ def plot_spectra( component: str, ymax: float | None = None, ymin: float | None = None, + labels: list[str] | None = None, ) -> None: """Plot a spectra from a simulation. @@ -39,28 +52,19 @@ def plot_spectra( ymin : float or None Min limit for y-axis. """ - ax.set_ylabel(f"pSA [{component}, g]") - ax.set_xlabel("Period [s]") - ax.set_xscale("log") - ax.set_yscale("log") - ax.grid(visible=True, which="both", axis="both", lw=0.3) - if ymin is not None or ymax is not None: - ax.set_ylim(bottom=ymin, top=ymax) - - for station in stations: + labels = labels or stations + for label, station in zip(labels, stations): spectra = dataset.pSA.sel(station=station, component=component).values periods = dataset.period.values - ax.plot(periods, spectra, label=station) - - if len(stations) > 1: - ax.legend() + ax.plot(periods, spectra, label=label) @cli.from_docstring(app) def plot_spectra_cli( - dataset_path: Annotated[Path, typer.Argument()], - stations: Annotated[list[str], typer.Argument()], + dataset_paths: Annotated[list[Path], typer.Argument()], + scenarios: Annotated[list[str] | None, typer.Option("--scenario")] = None, + stations: Annotated[list[str] | None, typer.Option("--station")] = None, title: str | None = None, save: Path | None = None, dpi: int = 300, @@ -81,18 +85,41 @@ def plot_spectra_cli( units : Units The units to plot in. """ - dset = xr.open_dataset(dataset_path, engine="h5netcdf") + if not stations: + raise ValueError("Require at least one station to plot.") + elif len(dataset_paths) > 1 and ( + scenarios is None or len(scenarios) != len(dataset_paths) + ): + raise ValueError( + "Require a label for each dataset, if more than one is provided." + ) + cm = 1 / 2.54 fig, ax = plt.subplots(figsize=(width * cm, height * cm)) - plot_spectra( - fig, - ax, - dset, - stations, - component, - ymin=ymin, - ymax=ymax, - ) + setup_plot_styling(ax, component, ymin=ymin, ymax=ymax) + for i, dataset_path in enumerate(dataset_paths): + dset = xr.open_dataset(dataset_path, engine="h5netcdf") + for station in stations: + label = station + if scenarios is not None and len(scenarios) > 1 and len(stations) > 1: + scenario = scenarios[i] + label = f"{station} ({scenario})" + elif scenarios is not None and len(stations) == 1: + label = scenarios[i] + + plot_spectra( + fig, + ax, + dset, + [station], + component, + ymin=ymin, + ymax=ymax, + labels=[label], + ) + + if len(stations) * len(dataset_paths) > 1: + ax.legend() if title: fig.suptitle(title) @@ -104,7 +131,3 @@ def plot_spectra_cli( else: fig.show() plt.show() - - -if __name__ == "__main__": - app() From 0c6683d99a007c9f3dae75ca6b3ab1747fef3a11 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:26:33 +1300 Subject: [PATCH 36/41] fix(gmm_comparison): dashed outline for gmm comparison --- visualisation/ims/gmm_comparison.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/visualisation/ims/gmm_comparison.py b/visualisation/ims/gmm_comparison.py index 2fdb383..80031ea 100644 --- a/visualisation/ims/gmm_comparison.py +++ b/visualisation/ims/gmm_comparison.py @@ -37,7 +37,6 @@ def plot_diff( ) -> None: intensity = dset["pSA"].sel(period=period, component="rotd50") pgv_value = intensity.to_series() - breakpoint() diff = np.log(pgv_value) - gmm_psa_value cmap_min = cmap_min or diff.min() cmap_max = cmap_max or diff.max() @@ -188,7 +187,7 @@ def main( ticks, reverse, ) - realisation.plot_domain(fig, domain) + realisation.plot_domain(fig, domain, pen="1p,black,-") realisation.plot_sources(fig, source_config) fig.savefig(output) From 17126e48a89fd8860a7c66f51345d92b0a99abcb Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:27:23 +1300 Subject: [PATCH 37/41] fix(response_rrup): compute site properties --- visualisation/ims/response_rrup.py | 19 +++++++++++++++++-- visualisation/utils.py | 3 ++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py index eca8ec7..fed4e47 100644 --- a/visualisation/ims/response_rrup.py +++ b/visualisation/ims/response_rrup.py @@ -41,9 +41,24 @@ def plot_nshm_fit( rupture_context = utils.compute_rupture_context( source_config, magnitudes, rakes, rupture_prop ) - site_properties = utils.compute_site_properties(site_ds.vs30.values) + mean_vs30 = utils.mean_vs30(site_ds.vs30.values) + mean_z1pt0 = oqw.estimations.chiou_young_08_calc_z1p0(mean_vs30) + mean_z2pt5 = oqw.estimations.chiou_young_08_calc_z2p5(mean_vs30) + site_properties = SiteProperties( + vs30=mean_vs30, + vs30measured=False, + z1pt0=mean_z1pt0, + z2pt5=mean_z2pt5, + rrup=rrup, + rjb=rrup, + rx=rrup, + ry=rrup, + ) + logic_tree_results = utils.nshm2022_logic_tree_prediction( - rupture_context, site_properties, period, rrup + rupture_context, + site_properties, + period, ) period_str = ( f"{period:.2f}".rstrip("0") if not period.is_integer() else f"{int(period)}.0" diff --git a/visualisation/utils.py b/visualisation/utils.py index e985d05..7f7fe4e 100644 --- a/visualisation/utils.py +++ b/visualisation/utils.py @@ -448,8 +448,9 @@ def get_gmm_prediction( rupture_context = compute_rupture_context( source_config, magnitudes, rakes, rupture_propagation ) - breakpoint() gmm_df = nshm2022_logic_tree_prediction(rupture_context, site_properties, period) + gmm_df["station"] = sites.station.values + gmm_df = gmm_df.set_index("station") gmm_psa_value = gmm_df.loc[:, gmm_df.columns.str.endswith("_mean")].squeeze() return gmm_psa_value From 6e2c777ebb7cbba27a7f10cbf839f09f8102aa33 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:27:44 +1300 Subject: [PATCH 38/41] fix(realisation): handle fault and plane plotting better --- visualisation/realisation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/visualisation/realisation.py b/visualisation/realisation.py index 1d070fd..f7aa30c 100644 --- a/visualisation/realisation.py +++ b/visualisation/realisation.py @@ -105,13 +105,18 @@ def plot_sources(fig: pygmt.Figure, source_config: SourceConfig, **kwargs: Any) pen = kwargs.get("pen", "0.3p,black,--") assert isinstance(pen, str) trace_pen = pen.removesuffix(",--") - interior_kwargs = {"pen": "0.3p,black", **(kwargs or {})} + interior_kwargs = {"pen": pen, **(kwargs or {})} for source in source_config.source_geometries.values(): utils.plot_polygon( fig, utils.polygon_nztm_to_pygmt(source.geometry), **interior_kwargs ) - if isinstance(source, Plane | Fault): + if isinstance(source, Fault): + trace = shapely.LineString( + np.concatenate([plane.bounds[:2] for plane in source.planes]) + ) + utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(trace), pen=trace_pen) + elif isinstance(source, Plane): trace = shapely.LineString(source.bounds[:2]) utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(trace), pen=trace_pen) From faa404eb58387b3e3245733fe56a23fecb8eaaa7 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:29:13 +1300 Subject: [PATCH 39/41] refactor(plot-srf): use pygmt helper for grid generation --- visualisation/sources/plot_srf.py | 92 +++++++++++++++---------------- 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/visualisation/sources/plot_srf.py b/visualisation/sources/plot_srf.py index 54d75a8..fc28dbb 100644 --- a/visualisation/sources/plot_srf.py +++ b/visualisation/sources/plot_srf.py @@ -61,9 +61,8 @@ def show_slip( region: tuple[float, float, float, float], srf_data: srf.SrfFile, annotations: bool, - projection: str = "M?", + contours: bool, realisation_ffp: Optional[Path] = None, - title: Optional[str] = None, ): """Show a slip map with optional contours. @@ -94,31 +93,11 @@ def show_slip( >>> show_slip(fig, region, srf_data, annotations=True, projection="M6i", title="Slip Distribution") >>> fig.show() # Displays the slip map with optional annotations """ - subtitle = utils.format_description( - srf_data.points["slip"], units="cm", compact=True - ) - title_args = [f"+t{title}+s{subtitle}".replace(" ", r"\040")] if title else [] # Compute slip limits - fig.basemap( - region=region, - projection=projection, - frame=plotting.DEFAULT_PLT_KWARGS["frame_args"] + title_args, - ) - fig.coast( - shorelines=["1/0.1p,black", "2/0.1p,black"], - resolution="f", - land="#666666", - water="skyblue", - ) - - slip_quantile = srf_data.points["slip"].quantile(0.98) + slip_quantile = srf_data.points["slip"].max() slip_cb_max = max(int(np.round(slip_quantile, -1)), 10) cmap_limits = (0, slip_cb_max, slip_cb_max / 10) - slip_stats = utils.format_description( - srf_data.points["slip"], compact=True, units="cm" - ) dx = srf_data.header.iloc[0]["len"] / srf_data.header.iloc[0]["nstk"] - subtitle = f"Slip: {slip_stats}, dx = {dx:.2f} km, {len(srf_data.header)} planes" grid_scale = min(utils.grid_scale_for_region(region), dx * 1000) for (_, segment), segment_points in zip( srf_data.header.iterrows(), srf_data.segments @@ -150,23 +129,24 @@ def show_slip( reverse_cmap=True, plot_contours=False, cb_label="Slip (cm)", - continuous_cmap=True, ) - # Plot time contours - time_grid = plotting.create_grid( - segment_points, - "tinit", - grid_spacing=f"{grid_scale}e/{grid_scale}e", - region=( - segment_points["lon"].min(), - segment_points["lon"].max(), - segment_points["lat"].min(), - segment_points["lat"].max(), - ), - set_water_to_nan=False, - ) - fig.grdcontour(levels=1, grid=time_grid, pen="0.1p") + if contours: + # Plot time contours + time_grid = plotting.create_grid( + segment_points, + "tinit", + grid_spacing=f"{grid_scale}e/{grid_scale}e", + region=( + segment_points["lon"].min(), + segment_points["lon"].max(), + segment_points["lat"].min(), + segment_points["lat"].max(), + ), + set_water_to_nan=False, + ) + + fig.grdcontour(levels=1, grid=time_grid, pen="0.1p") # Plot bounds of the current segment. corners = segment_points.iloc[[0, nstk - 1, -1, (ndip - 1) * nstk]] @@ -290,8 +270,13 @@ def plot_srf( latitude_pad: Annotated[float, typer.Option()] = 0, longitude_pad: Annotated[float, typer.Option()] = 0, annotations: Annotated[bool, typer.Option()] = True, + contours: Annotated[bool, typer.Option()] = True, width: Annotated[float, typer.Option(min=0)] = 17, show_inset: bool = False, + min_lat: float | None = None, + max_lat: float | None = None, + min_lon: float | None = None, + max_lon: float | None = None, ) -> None: """Plot multi-segment rupture with slip. @@ -342,24 +327,35 @@ def plot_srf( """ srf_data = srf.read_srf(srf_ffp) region = ( - srf_data.points["lon"].min() - longitude_pad, - srf_data.points["lon"].max() + longitude_pad, - srf_data.points["lat"].min() - latitude_pad, - srf_data.points["lat"].max() + latitude_pad, + min_lon or srf_data.points["lon"].min() - longitude_pad, + max_lon or srf_data.points["lon"].max() + longitude_pad, + min_lat or srf_data.points["lat"].min() - latitude_pad, + max_lat or srf_data.points["lat"].max() + latitude_pad, + ) + use_high_res = (region[1] - region[0]) * (region[3] - region[2]) < 0.5 + fig = plotting.gen_region_fig( + title, + region, + high_res_topo=use_high_res, + projection=f"M{width}c", + subtitle=utils.format_description( + srf_data.points["slip"], units="cm", compact=True + ), + config_options=dict( + FONT_SUBTITLE="9p,Helvetica,black", + FORMAT_GEO_MAP="ddd.xx", + MAP_FRAME_TYPE="plain", + ), + plot_kwargs=dict(water_color="white", topo_cmap_min=-900, topo_cmap_max=3100), ) - - fig = pygmt.Figure() - - pygmt.config(FONT_SUBTITLE="9p,Helvetica,black", FORMAT_GEO_MAP="ddd.xx") show_slip( fig, region, srf_data, annotations, - projection=f"M{width}c", + contours, realisation_ffp=realisation_ffp, - title=title, ) if show_inset: with fig.inset(position=f"jTR+w{np.sqrt(width)}c", margin=0.2): From 037b0ddbce7181681123dc8a599a13904b06b259 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:30:11 +1300 Subject: [PATCH 40/41] refactor(plot-ts): make plot-ts work for new xyts format --- visualisation/plot_ts.py | 465 +++++++++++++++------------------------ 1 file changed, 176 insertions(+), 289 deletions(-) diff --git a/visualisation/plot_ts.py b/visualisation/plot_ts.py index d919845..f171d75 100644 --- a/visualisation/plot_ts.py +++ b/visualisation/plot_ts.py @@ -1,8 +1,6 @@ """Create simulation video of surface ground motion levels.""" -import functools -import io -import multiprocessing as mp +import re import shutil from pathlib import Path from typing import Annotated @@ -11,20 +9,22 @@ import cartopy.feature as cfeature import cartopy.io.img_tiles as cimgt import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pyproj +import pyvista as pv +from matplotlib.colors import LinearSegmentedColormap matplotlib.use("Agg") -import ffmpeg import matplotlib.colors as mcolors -import matplotlib.pyplot as plt -import numpy as np import shapely import tqdm import typer +import xarray as xr from matplotlib.animation import FFMpegWriter, FuncAnimation from qcore import cli, coordinates -from qcore.xyts import XYTSFile from source_modelling import srf from workflow.realisations import DomainParameters, SourceConfig @@ -195,23 +195,6 @@ def plot_cartographic_features(ax: plt.Axes, scale: str) -> list: return features -def xyts_nztm_corners(xyts_file: XYTSFile) -> np.ndarray: - """Get the corners of the XYTS file in NZTM coordinates. - - Parameters - ---------- - xyts_file : XYTSFile - The XYTS file to get the corners from. - - Returns - ------- - np.ndarray - The corners of the XYTS file in NZTM coordinates. - """ - corners_geo = np.array(xyts_file.corners()) - return coordinates.wgs_depth_to_nztm(corners_geo[:, ::-1])[:, ::-1] - - def map_extents( nztm_corners: np.ndarray, padding: float ) -> tuple[float, float, float, float]: @@ -311,198 +294,31 @@ def waveform_coordinates(nztm_corners: np.ndarray, nx: int, ny: int) -> np.ndarr return coords_nztm[::-1, :, :] # Reverse order to (x, y) for NZTM -def tslice_get(xyts_file: XYTSFile, index: int, downsample: int = 1) -> np.ndarray: - """Retrieve a single timeslice from an xyts file with downsampling - - Parameters - ---------- - xyts_file : XYTSFile - The xyts file to retrieve from. - index : int - The timeslice index to read from. - downsample : int - If greater than 1, downsample the array in strides of `downsample` in - the x and y direction. - - Returns - ------- - array of float32 - An array of shape (ny, nx) containing the downsampled frame data for `index`. - """ - if downsample > 1: - frame_data = xyts_file.data[index, :, ::downsample, ::downsample] - else: - frame_data = xyts_file.data[index] # shape: (3, ny, nx) - return np.linalg.norm(frame_data, axis=0) - - -def render_single_frame( - frame_index: int, - dt: float, - xyts_file_path: Path, - source_config: SourceConfig, - nztm_corners: np.ndarray, - map_extent_nztm: tuple[float, float, float, float], - xr: np.ndarray, - yr: np.ndarray, - max_motion: float, - cmap: str, - shading: str, - simple_map: bool, - scale: str, - map_quality: int, - title: str | None, - width: float, - height: float, - dpi: int, - downsample: int, -) -> bytes: - """Render a single frame of the animation. - - Parameters - ---------- - frame_index : int - The index of the frame to render. - dt : float - The time step of the simulation. - xyts_file_path : Path - The path to the XYTS file. - source_config : SourceConfig - The source configuration object. - nztm_corners : np.ndarray - The corners of the XYTS domain in NZTM coordinates. - map_extent_nztm : tuple[float, float, float, float] - The map extents for the figure (x_min, x_max, y_min, y_max). - xr : np.ndarray - The x coordinates of the gridpoints in NZTM coordinates. - yr : np.ndarray - The y coordinates of the gridpoints in NZTM coordinates. - max_motion : float - The maximum ground motion value for color scaling. - cmap : str - The colormap to use for the animation. - shading : str - The shading to apply to the colourmap. - simple_map : bool - If True, disable OpenStreetMap background and use a simple map. - scale : str - The scale for cartographic features. - map_quality : int - The quality of the map (lower values are lower quality). - title : str | None - The title for the animation. - width : float - The width of the figure in cm. - height : float - The height of the figure in cm. - dpi : int - The DPI for the figure. - downsample : int, optional - If greater than 1, downsample the timeslice array in strides of - `downsample` in the x and y direction. Provides a speedup for large - domains. - - Returns - ------- - bytes - The raw frame output for the frame index - """ - xyts_file = XYTSFile(xyts_file_path) - # Create a new figure for this frame - cm = 1 / 2.54 - fig = plt.figure(figsize=(width * cm, height * cm)) - ax = fig.add_subplot(1, 1, 1, projection=NZTM_CRS) - ax.set_extent(map_extent_nztm, crs=NZTM_CRS) - - # Add all static elements - if simple_map: - plot_cartographic_features(ax, scale) - plot_towns(ax, map_extent_nztm) - else: - request = cimgt.OSM(cache=True) - request._MAX_THREADS = ( - 1 # Limit to one thread because it is in a multiprocess pool. - ) - ax.add_image( - request, - 10, - interpolation="spline36", - regrid_shape=map_quality * 1000, - zorder=0, - ) - - ax.add_geometries( - [shapely.Polygon(nztm_corners)], - facecolor="none", - edgecolor="black", - linestyle="--", - zorder=1, - crs=NZTM_CRS, - ) - - ax.add_geometries( - [ - shapely.transform(fault.geometry, lambda coords: coords[:, ::-1]) - for fault in sorted( - source_config.source_geometries.values(), - key=lambda fault: -fault.centroid[-1], - ) - ], - facecolor="red", - edgecolor="black", - zorder=2, - crs=NZTM_CRS, - ) - - # Add the actual data for this frame - - current_data = tslice_get(xyts_file, frame_index, downsample=downsample) - pcm = ax.pcolormesh( - yr[::downsample, ::downsample], - xr[::downsample, ::downsample], - apply_cmap_with_alpha(current_data, 0, max_motion, cmap=cmap), - cmap=cmap, - vmin=0, - vmax=max_motion, - shading=shading, - zorder=3, - transform=NZTM_CRS, - ) - - # Add time text - current_time = frame_index * dt - ax.text( - 0.98, - 0.02, - f"Time: {current_time:.2f} s", - transform=ax.transAxes, - fontsize=12, - color="black", - fontweight="bold", - ha="right", - va="bottom", - bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8}, - ) - - if title: - fig.suptitle(title, fontsize=16) - - plt.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) - cbar = fig.colorbar( - pcm, - ax=ax, - orientation="vertical", - pad=0.02, - aspect=30, - shrink=0.8, - ) - cbar.set_label("Ground Motion (cm/s)") - - # Save the frame to a file - with io.BytesIO() as io_buf: - fig.savefig(io_buf, format="raw", dpi=dpi) - plt.close(fig) - return io_buf.getvalue() +def cmap_from_cpt(cpt_file: Path) -> tuple[float, float, LinearSegmentedColormap]: + line_re = r"^(?P[0 -9\.\-]+)\s+(?P\d+)/(?P\d+)/(?P\d+)" + + colours = [] + elevations = [] + with open(cpt_file, "r") as f: + for line in f: + if m := re.match(line_re, line): + elevation = float(m.group("el")) + r = int(m.group("r")) + g = int(m.group("g")) + b = int(m.group("b")) + elevations.append(elevation) + colours.append((r, g, b)) + colours_np = np.array(colours) / 255.0 + elevations_np = np.array(elevations) + el_min = elevations_np.min() + el_max = elevations_np.max() + normalised = (elevations_np - el_min) / (el_max - el_min) + segmentdata = { + "red": [(frac, c[0], c[0]) for frac, c in zip(normalised, colours_np)], + "green": [(frac, c[1], c[1]) for frac, c in zip(normalised, colours_np)], + "blue": [(frac, c[2], c[2]) for frac, c in zip(normalised, colours_np)], + } + return el_min, el_max, LinearSegmentedColormap(cpt_file.stem, segmentdata) @cli.from_docstring(app, name="xyts") @@ -583,92 +399,163 @@ def animate_low_frequency( ) raise typer.Exit(code=1) + dem_dataset = xr.open_dataarray("dem.h5") + x_coords = dem_dataset.x.values + y_coords = dem_dataset.y.values + x, y = np.meshgrid(x_coords, y_coords) + z = dem_dataset.to_numpy() + z = np.where(np.isnan(z), -250, z) + z_max = z.max() + dem_grid = pv.StructuredGrid(x, y, z) + dem_grid["elevation"] = z.T.ravel() + source_config = SourceConfig.read_from_realisation(realisation_ffp) - xyts_file = XYTSFile(xyts_ffp) + planes = [] + lines = [] + plane_max = 0 + for fault in source_config.source_geometries.values(): + for plane in fault.planes: + bounds = plane.bounds + plane_max = max(plane_max, bounds[-1, 2]) + bounds[:, [0, 1]] = bounds[:, [1, 0]] + bounds[:, 2] = z_max + 5 + planes.extend( + [ + pv.Triangle([bounds[0], bounds[1], bounds[-1]]), + pv.Triangle([bounds[1], bounds[2], bounds[-1]]), + ] + ) + point_a = bounds[0].copy() + point_a[2] = z_max + 10 + point_b = bounds[1].copy() + point_b[2] = z_max + 10 + lines.extend([point_a, point_b]) + + cmap_min, cmap_max, cmap = cmap_from_cpt( + Path("/home/jake/tmp/palm_springs_nz_topo.cpt") + ) + xyts_dataset = xr.open_dataset(xyts_ffp) + (nt, ny, nx) = xyts_dataset.waveform.shape + proj = coordinates.SphericalProjection( + xyts_dataset.mlon, xyts_dataset.mlat, xyts_dataset.mrot + ) + dx = xyts_dataset.dx - nztm_corners = xyts_nztm_corners(xyts_file) - map_extent_nztm = map_extents(nztm_corners, padding) + y_sim_bounds = np.linspace(-0.5, 0.5, num=ny) * ny * (dx * 5) + x_sim_bounds = np.linspace(-0.5, 0.5, num=nx) * nx * (dx * 5) - if zoom != 1: - centre = shapely.centroid( - shapely.union_all( - [fault.geometry for fault in source_config.source_geometries.values()] - ) - ) - map_extent_nztm = zoom_extents( - map_extent_nztm, - (centre.y, centre.x), - zoom, - ) + y_sim, x_sim = np.meshgrid(y_sim_bounds, x_sim_bounds, indexing="ij") + print(y_sim.shape) - frame_count = frame_count or xyts_file.nt - xr, yr = waveform_coordinates(nztm_corners, xyts_file.nx, xyts_file.ny) + x_flat = x_sim.ravel(order="F") + y_flat = y_sim.ravel(order="F") - render_frame = functools.partial( - render_single_frame, - dt=xyts_file.dt, - shading=shading, - xyts_file_path=xyts_ffp.resolve(), - max_motion=max_motion, - cmap=cmap, - source_config=source_config, - nztm_corners=nztm_corners, - map_extent_nztm=map_extent_nztm, - xr=xr, - yr=yr, - simple_map=simple_map, - scale=scale, - map_quality=map_quality, - title=title, - width=width, - height=height, - dpi=dpi, - downsample=downsample, + points = proj.inverse(x_flat, y_flat) + lon_sim = points[:, 1] + lat_sim = points[:, 0] + + proj = pyproj.Transformer.from_crs(4326, 2193, always_xy=True) + x_nztm_flat, y_nztm_flat = proj.transform(lon_sim, lat_sim) + + x_nztm = x_nztm_flat.reshape((ny, nx), order="F") + y_nztm = y_nztm_flat.reshape((ny, nx), order="F") + corners = np.array( + [ + [x_nztm[0, 0], y_nztm[0, 0], z_max], + [x_nztm[-1, 0], y_nztm[-1, 0], z_max], + [x_nztm[-1, -1], y_nztm[-1, -1], z_max], + [x_nztm[0, -1], y_nztm[0, -1], z_max], + [x_nztm[0, 0], y_nztm[0, 0], z_max], + ] ) + z_plane = np.full((ny, nx), z_max + ((1 << 16) - 1) * 0.1) - # warm the OSM cache to speed up rendering by rendering the first frame + grid = pv.StructuredGrid(x_nztm, y_nztm, z_plane) - frames = [render_frame(0)] + grid["Ground Motion (cm/s)"] = grid.points[:, -1] + plotter = pv.Plotter(notebook=False, off_screen=True) + plotter.remove_all_lights() + plotter.ren_win.SetSize([1920, 1088]) + # plotter.enable_anti_aliasing() - with mp.Pool() as pool: - # Render all frames in parallel - frames.extend( - tqdm.tqdm( - pool.imap(render_frame, range(frame_start, frame_start + frame_count)), - total=frame_count, - unit="frame", - desc="Rendering frames", - initial=1, - ) - ) - cm = 1 / 2.54 - width_px = int(width * cm * dpi) - height_px = int(height * cm * dpi) - # Use ffmpeg to combine frames into video - process = ( - ffmpeg.input( - "pipe:0", format="rawvideo", pix_fmt="rgba", s=f"{width_px}x{height_px}" - ) - .output( - str(output_mp4), - pix_fmt="yuv420p", - r=fps, - vcodec="libx264", - crf=23, - vf="pad=ceil(iw/2)*2:ceil(ih/2)*2", - ) - .overwrite_output() - .run_async(pipe_stdin=True) + plotter.open_movie(output_mp4, framerate=fps, quality=10) + for plane in planes: + plotter.add_mesh(plane, color="red", lighting=False) + + plotter.add_lines(corners, connected=True, color="black", width=3) + plotter.add_lines(np.array(lines), color="black", width=2) + + plotter.add_mesh( + dem_grid, + lighting=False, + smooth_shading=False, + cmap=cmap, + clim=(cmap_min, cmap_max), + show_scalar_bar=False, ) + plotter.add_mesh( + grid, + lighting=False, + smooth_shading=False, + scalars="Ground Motion (cm/s)", + clim=[0, 100], + cmap="hot", + show_edges=False, + nan_opacity=0.0, + show_scalar_bar=True, + scalar_bar_args=dict( + title="Ground Motion\n(cm/s)\n", + vertical=True, + position_x=0.85, + position_y=0.25, + bold=True, + color="white", + height=0.5, + width=0.05, + n_labels=11, + ), + ) + + plotter.camera.tight() + # plotter.camera.set + plotter.set_background((173 / 255, 216 / 255, 230 / 255)) + text = plotter.add_text("0.00s", name="time-label", position="lower_right") + + plotter.show(auto_close=False) + + frame_chunk = 100 + i_frame = frame_start + frames = None + scalars = np.empty((ny, nx), dtype=np.float32, order="F") + time = xyts_dataset.time.values + dt = xyts_dataset.attrs["dt"] + for i in tqdm.trange( + frame_start, + frame_start + min(frame_count or nt - frame_start, nt - frame_start), + ): + if i >= i_frame: + next = min(i_frame + frame_chunk, nt) + frames = ( + xyts_dataset.waveform.isel(time=range(i_frame, next)) + .astype(np.float32) + .values + ) + i_frame = next + assert frames is not None + + z_geometry = frames[i - i_frame] + np.copyto(scalars, frames[i - i_frame]) - # Write the raw video data to FFmpeg's stdin - for frame in frames: - process.stdin.write(frame) + scalars[scalars < 10] = np.nan - process.stdin.close() + grid["Ground Motion (cm/s)"] = scalars.ravel(order="F") - # Wait for FFmpeg to finish - process.wait() + np.add(z_geometry, z_max, out=z_geometry) + grid.points[:, -1] = z_geometry.ravel(order="F") + if i % round(1 / dt): + text.set_text("lower_right", f"{time[i]:.2f}s") + plotter.write_frame() + plotter.close() def non_zero_data_points( From cb6fb35c541997dc18f4fec04f1f29901fd0e3a8 Mon Sep 17 00:00:00 2001 From: Jake Faulkner Date: Tue, 27 Jan 2026 09:30:31 +1300 Subject: [PATCH 41/41] ci: add plot-waveform utility --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dbb26ac..33977ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,9 @@ name = "visualisation" authors = [{ name = "QuakeCoRE" }] description = "Visualisation repository for plotting scripts." readme = "README.md" -requires-python = ">=3.12,<3.14" +requires-python = ">=3.12" dynamic = ["version", "dependencies"] - [project.scripts] plot-srf-moment = "visualisation.sources.plot_srf_moment:app" plot-domain = "visualisation.realisation:app" @@ -27,6 +26,7 @@ plot-stoch = "visualisation.sources.plot_stoch:app" plot-ts = "visualisation.plot_ts:app" plot-response-rrup = "visualisation.ims.response_rrup:app" plot-response-spectra = "visualisation.waveforms.plot_response_spectra:app" +plot-waveform = "visualisation.waveforms.plot_waveform:app" plot-gmm-comparison = "visualisation.ims.gmm_comparison:app" [tool.setuptools.package-dir]