diff --git a/pyproject.toml b/pyproject.toml index 837a6f7..33977ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ readme = "README.md" requires-python = ">=3.12" dynamic = ["version", "dependencies"] - [project.scripts] plot-srf-moment = "visualisation.sources.plot_srf_moment:app" plot-domain = "visualisation.realisation:app" @@ -25,6 +24,10 @@ plot-1d-velocity-model = "visualisation.plot_1d_velocity_model:app" plot-rupture-path = "visualisation.plot_rupture_path:app" plot-stoch = "visualisation.sources.plot_stoch:app" plot-ts = "visualisation.plot_ts:app" +plot-response-rrup = "visualisation.ims.response_rrup:app" +plot-response-spectra = "visualisation.waveforms.plot_response_spectra:app" +plot-waveform = "visualisation.waveforms.plot_waveform:app" +plot-gmm-comparison = "visualisation.ims.gmm_comparison:app" [tool.setuptools.package-dir] visualisation = "visualisation" @@ -53,7 +56,7 @@ extend-select = [ # Missing function argument type-annotation "ANN001", # Using except without specifying an exception type to catch - "BLE001" + "BLE001", ] ignore = ["D104"] @@ -62,15 +65,15 @@ convention = "numpy" [tool.ruff.lint.isort] known-first-party = [ - "source_modelling", - "visualisation", - "workflow", - "pygmt_helper", - "qcore", - "empirical", - "nshmdb", - "IM_calculation", - "mera" + "source_modelling", + "visualisation", + "workflow", + "pygmt_helper", + "qcore", + "empirical", + "nshmdb", + "IM_calculation", + "mera", ] [tool.ruff.lint.per-file-ignores] @@ -80,9 +83,7 @@ known-first-party = [ "tests/**.py" = ["D"] [tool.coverage.run] -omit = [ - "visualisation/plot_ts.py" -] +omit = ["visualisation/plot_ts.py"] [tool.numpydoc_validation] checks = [ @@ -103,7 +104,7 @@ checks = [ "YD01", ] # remember to use single quotes for regex in TOML -exclude = [ # don't report on objects that match any of these regex - '\.undocumented_method$', - '\.__repr__$', +exclude = [ # don't report on objects that match any of these regex + 
'\.undocumented_method$', + '\.__repr__$', ] diff --git a/requirements.txt b/requirements.txt index 5287e84..a825f2b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,5 @@ pytest-cov pytest-xdist typer tqdm +rpy2 +xarray[io] diff --git a/visualisation/ims/gmm_comparison.py b/visualisation/ims/gmm_comparison.py new file mode 100644 index 0000000..80031ea --- /dev/null +++ b/visualisation/ims/gmm_comparison.py @@ -0,0 +1,193 @@ +from pathlib import Path +from typing import Annotated + +import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd +import pygmt +import shapely +import typer +import xarray as xr + +from pygmt_helper import plotting +from qcore import cli +from visualisation import realisation, utils +from workflow.realisations import ( + DomainParameters, + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) + +app = typer.Typer() + + +def plot_diff( + fig: pygmt.Figure, + dset: xr.Dataset, + gmm_psa_value: pd.Series, + period: float, + cmap: str, + cmap_max: float | None, + cmap_min: float | None, + ticks: int, + reverse: bool, +) -> None: + intensity = dset["pSA"].sel(period=period, component="rotd50") + pgv_value = intensity.to_series() + diff = np.log(pgv_value) - gmm_psa_value + cmap_min = cmap_min or diff.min() + cmap_max = cmap_max or diff.max() + + cmap_limits = utils.range_for(cmap_min, cmap_max, ticks) + + latitude = intensity.latitude.to_series() + longitude = intensity.longitude.to_series() + df = pd.DataFrame({"lat": latitude, "lon": longitude, "value": diff}) + + grid: xr.DataArray = plotting.create_grid( + df, + "value", + grid_spacing="1000e/1000e", + region=tuple(fig.region), + set_water_to_nan=True, + ) + + plotting.plot_grid( + fig, + grid, + cmap, + cmap_limits, + ("red", "blue"), + reverse_cmap=reverse, + transparency=40, + plot_contours=False, + ) + + +def find_region(domain: DomainParameters) -> tuple[float, float, float, float]: + """Find an appropriate domain,""" + 
nz_region = shapely.box(166.0, -48.0, 178.5, -34.0) + region = shapely.union( + utils.polygon_nztm_to_pygmt(domain.domain.polygon), nz_region + ) + (min_x, min_y, max_x, max_y) = shapely.bounds(region) + return (min_x, max_x, min_y, max_y) + + +def generate_basemap(region: tuple[float, float, float, float]) -> pygmt.Figure: + fig: pygmt.Figure = plotting.gen_region_fig( + title=None, + region=region, + plot_kwargs=dict( + plot_kwargs=["af", "xaf+Longitude", "yaf+Latitude"], + water_color="white", + topo_cmap_min=-900, + topo_cmap_max=3100, + ), + plot_highways=False, + config_options=dict( + MAP_FRAME_TYPE="plain", + FORMAT_GEO_MAP="ddd.xx", + MAP_FRAME_PEN="thinner,black", + ), + ) + assert isinstance(fig, pygmt.Figure) + return fig + + +@cli.from_docstring(app) +def main( + realisation_ffp: Annotated[Path, typer.Argument()], + dataset: Annotated[ + Path, + typer.Argument(), + ], + period: Annotated[ + float, + typer.Argument(), + ], + output: Annotated[ + Path, + typer.Argument(), + ], + cmap: Annotated[ + str, + typer.Option(), + ] = "polar", + reverse: Annotated[ + bool, + typer.Option(is_flag=True), + ] = False, + cmap_min: Annotated[ + float | None, + typer.Option(), + ] = None, + cmap_max: Annotated[ + float | None, + typer.Option(), + ] = None, + ticks: Annotated[ + int, + typer.Option(), + ] = 10, +) -> None: + """Compare simulation results to predictions from the NSHM2022 logic tree. + + Parameters + ---------- + realisation_ffp : Path + Path to realisation. + dataset : Path + Path to xarray intensity measure dataset. + period : float + pSA period to compare against. + output : Path + The path to write the figure out to. + cmap_min : float + Colourmap minimum + cmap_max : float + Colourmap maximum + ticks : int + Number of ticks in discrete colourmap. + output : Path + Output path. + cmap : str + Colourmap to plot log residuals. Should be divering. + reverse : bool + If true, reverse the colourmap. Defaults to false. 
+ """ + + dset = xr.open_dataset(dataset, engine="h5netcdf") + domain = DomainParameters.read_from_realisation(realisation_ffp) + source_config = SourceConfig.read_from_realisation(realisation_ffp) + magnitudes = Magnitudes.read_from_realisation(realisation_ffp) + rakes = Rakes.read_from_realisation(realisation_ffp) + rupture_propagation_config = RupturePropagationConfig.read_from_realisation( + realisation_ffp + ) + + region = find_region(domain) + + fig = generate_basemap(region) + gmm_psa_value = utils.get_gmm_prediction( + dset, period, source_config, magnitudes, rakes, rupture_propagation_config + ) + + plot_diff( + fig, + dset, + gmm_psa_value, + period, + cmap, + cmap_max, + cmap_min, + ticks, + reverse, + ) + realisation.plot_domain(fig, domain, pen="1p,black,-") + realisation.plot_sources(fig, source_config) + + fig.savefig(output) diff --git a/visualisation/ims/response_rrup.py b/visualisation/ims/response_rrup.py new file mode 100644 index 0000000..fed4e47 --- /dev/null +++ b/visualisation/ims/response_rrup.py @@ -0,0 +1,599 @@ +import re +from pathlib import Path +from typing import Annotated, NamedTuple + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd +import typer +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from visualisation import utils +from visualisation.utils import ConfidenceInterval, RuptureContext, SiteProperties +from workflow.realisations import ( + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) + +app = typer.Typer() + + +def plot_nshm_fit( + ax: Axes, + realisation_ffp: Path, + site_ds: xr.Dataset, + period: float, + rrup: npt.NDArray[np.floating], + color: str | None = None, + label: str | None = None, +) -> ConfidenceInterval: + source_config = SourceConfig.read_from_realisation(realisation_ffp) + magnitudes = Magnitudes.read_from_realisation(realisation_ffp) + 
rupture_prop = RupturePropagationConfig.read_from_realisation(realisation_ffp) + rakes = Rakes.read_from_realisation(realisation_ffp) + rupture_context = utils.compute_rupture_context( + source_config, magnitudes, rakes, rupture_prop + ) + mean_vs30 = utils.mean_vs30(site_ds.vs30.values) + mean_z1pt0 = oqw.estimations.chiou_young_08_calc_z1p0(mean_vs30) + mean_z2pt5 = oqw.estimations.chiou_young_08_calc_z2p5(mean_vs30) + site_properties = SiteProperties( + vs30=mean_vs30, + vs30measured=False, + z1pt0=mean_z1pt0, + z2pt5=mean_z2pt5, + rrup=rrup, + rjb=rrup, + rx=rrup, + ry=rrup, + ) + + logic_tree_results = utils.nshm2022_logic_tree_prediction( + rupture_context, + site_properties, + period, + ) + period_str = ( + f"{period:.2f}".rstrip("0") if not period.is_integer() else f"{int(period)}.0" + ) + mean = logic_tree_results[f"pSA_{period_str}_mean"] + std = logic_tree_results[f"pSA_{period_str}_std_Total"] + fit = np.exp(mean) + ci_low = np.exp(mean - std) + ci_high = np.exp(mean + std) + ax.fill_between(rrup, ci_low, ci_high, alpha=0.3, color=color) + ax.plot(rrup, fit, c=color, label=label) + return ConfidenceInterval(fit, ci_low, ci_high) + + +def plot_simulation_fit( + ax: Axes, + rrup: np.ndarray, + psa: np.ndarray, + label: str | None, + color: str, + span: float = 1 / 3, + show_bands: bool = True, +) -> ConfidenceInterval: + """Plot LOESS fit for a subset of simulation data.""" + rrup_out = np.linspace(rrup.min(), rrup.max(), num=100) + fit, ci_low, ci_high = utils.fit_loess_r( + np.log(psa), np.log(rrup), np.log(rrup_out), span=span + ) + if show_bands: + ax.fill_between( + rrup_out, np.exp(ci_low), np.exp(ci_high), alpha=0.3, color=color + ) + ax.plot(rrup_out, np.exp(fit), c=color, label=label) + return ConfidenceInterval(fit, ci_low, ci_high) + + +def human_readable_basin_name(basin_name: str) -> str: + # NE_Otago -> NE Otago + basin_name_no_underscore = basin_name.replace("_", " ") + # Greate(r)(W)ellington -> Greate(r) (W)ellington + return 
re.sub(r"([a-z])([A-Z])", r"\1 \2", basin_name_no_underscore) + + +def _get_plotting_params(simulation_ds: xr.Dataset): + """Calculate common plotting parameters.""" + max_rrup = min(500, simulation_ds.rrup.max().item()) + nshm_rrup = np.geomspace(1e-3, max_rrup, num=100) + return max_rrup, nshm_rrup + + +def _apply_style_and_limits( + fig: Figure, + axes: np.ndarray, + period: float, + max_rrup: float, + ymin: float | None, + ymax: float | None, + is_multi_plot: bool, + ax_main: plt.Axes, + xmin: float | None = 1e-1, + xmax: float | None = None, +): + """Apply final styling, labels, and limits to all axes.""" + if is_multi_plot: + fig.supxlabel("Source to site distance, $R_{rup}$ [km]") + fig.supylabel(f"pSA({period:.2f} s) [g]") + else: + ax_main.set_xlabel("Source to site distance, $R_{rup}$ [km]") + ax_main.set_ylabel(f"pSA({period:.2f} s) [g]") + + if ymin is not None or ymax is not None: + for ax in axes.flatten(): + ax.set_ylim(bottom=ymin, top=ymax) + xmax = xmax or max_rrup + for ax in axes.flatten(): + ax.set_xlim(left=xmin, right=xmax) + + +def _plot_settings(ax: Axes) -> None: + ax.grid(True, which="both", axis="both", lw=0.3) + ax.set_yscale("log") + ax.set_xscale("log") + + +def _get_basin_stations(simulation_ds: xr.Dataset, basin: str): + """Filter the dataset for stations belonging to a specific basin.""" + # Uses xarray's .where() for cleaner filtering + return simulation_ds.where(simulation_ds.basin == basin, drop=True) + + +def plot_basin_vs_no_basin( + realisation_ffp: Path, + simulation_ds: xr.Dataset, + period: float, + component: str = "rotd50", + xmin: float | None = None, + xmax: float | None = None, + ymin: float | None = None, + ymax: float | None = None, + span: float = 1 / 3, +): + """ + Creates a single plot comparing simulation data for basin stations + vs. non-basin stations against the NSHM prediction. 
+ """ + max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) + + fig, ax = plt.subplots(constrained_layout=True) + all_axes = np.array([ax]) # For style helper consistency + _plot_settings(ax) + # 1. Plot NSHM fit and set axis scales/grid + plot_nshm_fit( + ax, + realisation_ffp, + simulation_ds, + period, + nshm_rrup, + color="tab:blue", + ) + + # 2. Split and plot simulation data + basin_ds = simulation_ds.where(simulation_ds.basin != "", drop=True) + non_basin_ds = simulation_ds.where(simulation_ds.basin == "", drop=True) + + # Basin stations + basin_pSA = basin_ds.pSA.sel(period=period, component=component).values + ax.scatter(basin_ds.rrup, basin_pSA, c="tab:red", alpha=0.3, s=5) + + plot_simulation_fit( + ax, + basin_ds.rrup.values, + basin_pSA, + label="Basin stations", + color="darkred", + span=span, + ) + + # Non-Basin stations + non_basin_pSA = non_basin_ds.pSA.sel(period=period, component=component).values + ax.scatter(non_basin_ds.rrup, non_basin_pSA, c="tab:purple", alpha=0.3, s=5) + plot_simulation_fit( + ax, + non_basin_ds.rrup.values, + non_basin_pSA, + label="Non-basin stations", + color="purple", + span=span, + ) + + ax.legend() + + # 3. Apply final styling + _apply_style_and_limits( + fig, all_axes, period, max_rrup, ymin, ymax, False, ax, xmin=xmin, xmax=xmax + ) + + return fig + + +def plot_separate_basin_subplots( + realisation_ffp: Path, + simulation_ds: xr.Dataset, + period: float, + basins: list[str], + component: str = "rotd50", + xmin: float | None = None, + xmax: float | None = None, + ymin: float | None = None, + ymax: float | None = None, + span: float = 1 / 3, +): + """ + Creates a grid of subplots: one for all stations, and one for each basin. 
+ """ + if not basins: + raise ValueError("Basins list cannot be empty for separate basin plotting.") + + max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) + + # Setup figure with 1 + N_basins plots + num_plots = 1 + len(basins) + # utils.balanced_subplot_grid is assumed to return a 2D array of axes + fig, axes_2d = utils.balanced_subplot_grid( + num_plots, + 1.0, + subplot_size=(8, 6), + clear=True, + sharex=True, + sharey=True, + constrained_layout=True, + ) + all_axes = axes_2d.flatten() + ax_all_stations = all_axes[0] + _plot_settings(ax_all_stations) + # --- A. Plot All Stations (Primary Plot) --- + plot_nshm_fit(ax_all_stations, realisation_ffp, simulation_ds, period, nshm_rrup) + + # Plot all simulation stations together + all_pSA = simulation_ds.pSA.sel(period=period, component=component).values + ax_all_stations.scatter(simulation_ds.rrup, all_pSA, c="k", alpha=0.3, s=10) + plot_simulation_fit( + ax_all_stations, + simulation_ds.rrup.values, + all_pSA, + label="Simulated stations", + color="tab:gray", + span=span, + ) + ax_all_stations.legend() + + # --- B. Plot Individual Basins --- + for ax_basin, basin in zip(all_axes[1:], basins): + _plot_settings(ax_basin) + subds = _get_basin_stations(simulation_ds, basin) + + if len(subds.station) == 0: + ax_basin.set_title(f"Basin: {human_readable_basin_name(basin)} (No data)") + continue + + plot_nshm_fit( + ax_basin, realisation_ffp, subds, period, nshm_rrup, label=None + ) # No legend for NSHM here + + basin_pSA = subds.pSA.sel(period=period, component=component).values + ax_basin.scatter(subds.rrup, basin_pSA, c="red", alpha=0.7, s=10) + ax_basin.set_title(f"Basin: {human_readable_basin_name(basin)}") + + # 3. 
Apply final styling (is_multi_plot=True) + _apply_style_and_limits( + fig, + all_axes, + period, + max_rrup, + ymin, + ymax, + True, + ax_all_stations, + xmin=xmin, + xmax=xmax, + ) + + return fig + + +def plot_combined_basin_plot( + realisation_ffp: Path, + simulation_ds: xr.Dataset, + period: float, + basins: list[str], + component: str = "rotd50", + xmin: float | None = None, + xmax: float | None = None, + ymin: float | None = None, + ymax: float | None = None, + span: float = 1 / 3, +): + """ + Creates a single plot showing all stations and then overlays each basin. + Uses a colormap to assign distinct colors to each basin plot. + """ + if not basins: + # Fall back to plotting only all stations if no basins are specified + print("Warning: No basins specified. Plotting all stations only.") + # Proceed with the rest of the function for the 'All Stations' plot + + max_rrup, nshm_rrup = _get_plotting_params(simulation_ds) + + fig, ax = plt.subplots(constrained_layout=True) + all_axes = np.array([ax]) # For style helper consistency + + # 1. Plot NSHM fit and set axis scales/grid + _plot_settings(ax) + fit, ci_low, ci_high = plot_nshm_fit( + ax, realisation_ffp, simulation_ds, period, nshm_rrup + ) + + # 2. Plot All Stations (Base Layer) + all_pSA = simulation_ds.pSA.sel(period=period, component=component).values + ax.scatter( + simulation_ds.rrup, + all_pSA, + c="k", + alpha=0.1, + s=10, + label="Simulation", + ) + plot_simulation_fit( + ax, + simulation_ds.rrup.values, + all_pSA, + label=None, + color="tab:gray", + span=span, + ) + + basin_colours = mpl.color_sequences["Dark2"] + # 3. 
Overlay Individual Basins (Scatter Plots) + log_fit = np.log(fit) + log_nshm_rrup = np.log(nshm_rrup) + + for basin, colour in zip(basins, basin_colours): + subds = _get_basin_stations(simulation_ds, basin) + + if len(subds.station) == 0: + continue + + # Use the distinct color generated from the colormap + + basin_pSA = subds.pSA.sel(period=period, component=component).values + log_rrup = np.log(subds.rrup.values) + basin_misfit = np.log(basin_pSA) - np.interp(log_rrup, log_nshm_rrup, log_fit) + mean_misfit = np.mean(basin_misfit) + + log_basin_rrup_min = log_rrup.min() + log_basin_rrup_max = log_rrup.max() + log_basin_misfit_rrup = np.linspace( + log_basin_rrup_min, log_basin_rrup_max, num=100 + ) + nshm_subline = ( + np.interp(log_basin_misfit_rrup, log_nshm_rrup, log_fit) + mean_misfit + ) + ax.scatter( + subds.rrup, + basin_pSA, + alpha=0.7, + s=10, + color=colour, + label=f"{human_readable_basin_name(basin)}", + ) + ax.plot( + np.exp(log_basin_misfit_rrup), + np.exp(nshm_subline), + color=utils.adjust_value(colour, 0.8), + ) + + ax.legend() + + # 5. 
Apply final styling (is_multi_plot=False) + _apply_style_and_limits( + fig, all_axes, period, max_rrup, ymin, ymax, False, ax, xmin=xmin, xmax=xmax + ) + + return fig + + +@app.command() +def plot_basin_split( + # Arguments + realisation_ffp: Annotated[ + Path, typer.Argument(help="Path to the NSHM fit data file.") + ], + simulation_dataset_path: Annotated[ + Path, typer.Argument(help="Path to the simulation xarray dataset (H5NetCDF).") + ], + period: Annotated[ + float, typer.Argument(help="Spectral acceleration period (T) in seconds.") + ], + # Options + save: Annotated[ + Path | None, + typer.Option("--save", "-s", help="Output path to save the figure."), + ] = None, + dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, + ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, + ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + xmin: Annotated[float | None, typer.Option(help="Minimum x-axis limit.")] = None, + xmax: Annotated[float | None, typer.Option(help="Maximum x-axis limit.")] = None, + component: Annotated[ + str, typer.Option(help="PSA component to plot (e.g., rotd50).") + ] = "rotd50", + span: Annotated[ + float, typer.Option(help="Smoothing span for the simulation fit line.") + ] = 1 / 3, +) -> None: + """ + Creates a single plot comparing simulated pSA for Basin stations against + Non-Basin stations, along with the NSHM prediction. + """ + simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") + + fig = plot_basin_vs_no_basin( + realisation_ffp, + simulation_ds, + period, + component=component, + ymin=ymin, + ymax=ymax, + xmin=xmin, + xmax=xmax, + span=span, + ) + + if save: + fig.savefig(save, dpi=dpi) + else: + plt.show() + + +# ---------------------------------------------------- +# 2. 
Command: All Stations + Separate Basin Subplots +# ---------------------------------------------------- +@app.command() +def plot_basins_separate( + # Arguments + realisation_ffp: Annotated[ + Path, typer.Argument(help="Path to the NSHM fit data file.") + ], + simulation_dataset_path: Annotated[ + Path, typer.Argument(help="Path to the simulation xarray dataset (H5NetCDF).") + ], + period: Annotated[ + float, typer.Argument(help="Spectral acceleration period (T) in seconds.") + ], + # Options + basins: Annotated[ + list[str] | None, + typer.Option( + "--basin", "-b", help="List of basins to plot in separate subplots." + ), + ] = None, + save: Annotated[ + Path | None, + typer.Option("--save", "-s", help="Output path to save the figure."), + ] = None, + dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, + ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, + ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + xmin: Annotated[float | None, typer.Option(help="Minimum x-axis limit.")] = None, + xmax: Annotated[float | None, typer.Option(help="Maximum x-axis limit.")] = None, + component: Annotated[ + str, typer.Option(help="PSA component to plot (e.g., rotd50).") + ] = "rotd50", + span: Annotated[ + float, typer.Option(help="Smoothing span for the overall simulation fit line.") + ] = 1 / 3, +) -> None: + """ + Creates a grid of plots: one showing all stations, and one separate subplot + for each specified basin, comparing to NSHM. + """ + # Note: basins will be list[str] or None. Check for empty list after parsing. 
+ basins_list = basins or [] + if not basins_list: + typer.echo( + "Error: At least one basin must be specified using --basin for this command.", + err=True, + ) + raise typer.Exit(code=1) + + simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") + + fig = plot_separate_basin_subplots( + realisation_ffp, + simulation_ds, + period, + basins=basins_list, + component=component, + ymin=ymin, + ymax=ymax, + xmin=xmin, + xmax=xmax, + span=span, + ) + + if save: + fig.savefig(save, dpi=dpi) + else: + plt.show() + + +# ---------------------------------------------------- +# 3. Command: All Stations + Basins in Combined Plot +# ---------------------------------------------------- +@app.command() +def plot_basins_combined( + # Arguments + realisation_ffp: Annotated[ + Path, typer.Argument(help="Path to the NSHM fit data file.") + ], + simulation_dataset_path: Annotated[ + Path, typer.Argument(help="Path to the simulation xarray dataset (H5NetCDF).") + ], + period: Annotated[ + float, typer.Argument(help="Spectral acceleration period (T) in seconds.") + ], + # Options + basins: Annotated[ + list[str] | None, + typer.Option( + "--basin", "-b", help="List of basins to overlay on the main plot." 
+ ), + ] = None, + save: Annotated[ + Path | None, + typer.Option("--save", "-s", help="Output path to save the figure."), + ] = None, + dpi: Annotated[int, typer.Option(help="DPI for saving the figure.")] = 300, + ymin: Annotated[float | None, typer.Option(help="Minimum y-axis limit.")] = 1e-5, + ymax: Annotated[float | None, typer.Option(help="Maximum y-axis limit.")] = 10, + xmin: Annotated[float | None, typer.Option(help="Minimum x-axis limit.")] = None, + xmax: Annotated[float | None, typer.Option(help="Maximum x-axis limit.")] = None, + component: Annotated[ + str, typer.Option(help="PSA component to plot (e.g., rotd50).") + ] = "rotd50", + span: Annotated[ + float, typer.Option(help="Smoothing span for the overall simulation fit line.") + ] = 1 / 3, +) -> None: + """ + Creates a single plot showing all simulation stations (as background), + overlaid with scatters for each specified basin, and the NSHM prediction. + """ + basins_list = basins or [] + if not basins_list: + typer.echo( + "Warning: No basins specified. 
Plotting all stations (combined) only.", + err=True, + ) + + simulation_ds = xr.open_dataset(simulation_dataset_path, engine="h5netcdf") + + fig = plot_combined_basin_plot( + realisation_ffp, + simulation_ds, + period, + basins=basins_list, + component=component, + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax, + span=span, + ) + + if save: + fig.savefig(save, dpi=dpi) + else: + plt.show() diff --git a/visualisation/plot_ts.py b/visualisation/plot_ts.py index d919845..f171d75 100644 --- a/visualisation/plot_ts.py +++ b/visualisation/plot_ts.py @@ -1,8 +1,6 @@ """Create simulation video of surface ground motion levels.""" -import functools -import io -import multiprocessing as mp +import re import shutil from pathlib import Path from typing import Annotated @@ -11,20 +9,22 @@ import cartopy.feature as cfeature import cartopy.io.img_tiles as cimgt import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pyproj +import pyvista as pv +from matplotlib.colors import LinearSegmentedColormap matplotlib.use("Agg") -import ffmpeg import matplotlib.colors as mcolors -import matplotlib.pyplot as plt -import numpy as np import shapely import tqdm import typer +import xarray as xr from matplotlib.animation import FFMpegWriter, FuncAnimation from qcore import cli, coordinates -from qcore.xyts import XYTSFile from source_modelling import srf from workflow.realisations import DomainParameters, SourceConfig @@ -195,23 +195,6 @@ def plot_cartographic_features(ax: plt.Axes, scale: str) -> list: return features -def xyts_nztm_corners(xyts_file: XYTSFile) -> np.ndarray: - """Get the corners of the XYTS file in NZTM coordinates. - - Parameters - ---------- - xyts_file : XYTSFile - The XYTS file to get the corners from. - - Returns - ------- - np.ndarray - The corners of the XYTS file in NZTM coordinates. 
- """ - corners_geo = np.array(xyts_file.corners()) - return coordinates.wgs_depth_to_nztm(corners_geo[:, ::-1])[:, ::-1] - - def map_extents( nztm_corners: np.ndarray, padding: float ) -> tuple[float, float, float, float]: @@ -311,198 +294,31 @@ def waveform_coordinates(nztm_corners: np.ndarray, nx: int, ny: int) -> np.ndarr return coords_nztm[::-1, :, :] # Reverse order to (x, y) for NZTM -def tslice_get(xyts_file: XYTSFile, index: int, downsample: int = 1) -> np.ndarray: - """Retrieve a single timeslice from an xyts file with downsampling - - Parameters - ---------- - xyts_file : XYTSFile - The xyts file to retrieve from. - index : int - The timeslice index to read from. - downsample : int - If greater than 1, downsample the array in strides of `downsample` in - the x and y direction. - - Returns - ------- - array of float32 - An array of shape (ny, nx) containing the downsampled frame data for `index`. - """ - if downsample > 1: - frame_data = xyts_file.data[index, :, ::downsample, ::downsample] - else: - frame_data = xyts_file.data[index] # shape: (3, ny, nx) - return np.linalg.norm(frame_data, axis=0) - - -def render_single_frame( - frame_index: int, - dt: float, - xyts_file_path: Path, - source_config: SourceConfig, - nztm_corners: np.ndarray, - map_extent_nztm: tuple[float, float, float, float], - xr: np.ndarray, - yr: np.ndarray, - max_motion: float, - cmap: str, - shading: str, - simple_map: bool, - scale: str, - map_quality: int, - title: str | None, - width: float, - height: float, - dpi: int, - downsample: int, -) -> bytes: - """Render a single frame of the animation. - - Parameters - ---------- - frame_index : int - The index of the frame to render. - dt : float - The time step of the simulation. - xyts_file_path : Path - The path to the XYTS file. - source_config : SourceConfig - The source configuration object. - nztm_corners : np.ndarray - The corners of the XYTS domain in NZTM coordinates. 
- map_extent_nztm : tuple[float, float, float, float] - The map extents for the figure (x_min, x_max, y_min, y_max). - xr : np.ndarray - The x coordinates of the gridpoints in NZTM coordinates. - yr : np.ndarray - The y coordinates of the gridpoints in NZTM coordinates. - max_motion : float - The maximum ground motion value for color scaling. - cmap : str - The colormap to use for the animation. - shading : str - The shading to apply to the colourmap. - simple_map : bool - If True, disable OpenStreetMap background and use a simple map. - scale : str - The scale for cartographic features. - map_quality : int - The quality of the map (lower values are lower quality). - title : str | None - The title for the animation. - width : float - The width of the figure in cm. - height : float - The height of the figure in cm. - dpi : int - The DPI for the figure. - downsample : int, optional - If greater than 1, downsample the timeslice array in strides of - `downsample` in the x and y direction. Provides a speedup for large - domains. - - Returns - ------- - bytes - The raw frame output for the frame index - """ - xyts_file = XYTSFile(xyts_file_path) - # Create a new figure for this frame - cm = 1 / 2.54 - fig = plt.figure(figsize=(width * cm, height * cm)) - ax = fig.add_subplot(1, 1, 1, projection=NZTM_CRS) - ax.set_extent(map_extent_nztm, crs=NZTM_CRS) - - # Add all static elements - if simple_map: - plot_cartographic_features(ax, scale) - plot_towns(ax, map_extent_nztm) - else: - request = cimgt.OSM(cache=True) - request._MAX_THREADS = ( - 1 # Limit to one thread because it is in a multiprocess pool. 
- ) - ax.add_image( - request, - 10, - interpolation="spline36", - regrid_shape=map_quality * 1000, - zorder=0, - ) - - ax.add_geometries( - [shapely.Polygon(nztm_corners)], - facecolor="none", - edgecolor="black", - linestyle="--", - zorder=1, - crs=NZTM_CRS, - ) - - ax.add_geometries( - [ - shapely.transform(fault.geometry, lambda coords: coords[:, ::-1]) - for fault in sorted( - source_config.source_geometries.values(), - key=lambda fault: -fault.centroid[-1], - ) - ], - facecolor="red", - edgecolor="black", - zorder=2, - crs=NZTM_CRS, - ) - - # Add the actual data for this frame - - current_data = tslice_get(xyts_file, frame_index, downsample=downsample) - pcm = ax.pcolormesh( - yr[::downsample, ::downsample], - xr[::downsample, ::downsample], - apply_cmap_with_alpha(current_data, 0, max_motion, cmap=cmap), - cmap=cmap, - vmin=0, - vmax=max_motion, - shading=shading, - zorder=3, - transform=NZTM_CRS, - ) - - # Add time text - current_time = frame_index * dt - ax.text( - 0.98, - 0.02, - f"Time: {current_time:.2f} s", - transform=ax.transAxes, - fontsize=12, - color="black", - fontweight="bold", - ha="right", - va="bottom", - bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8}, - ) - - if title: - fig.suptitle(title, fontsize=16) - - plt.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) - cbar = fig.colorbar( - pcm, - ax=ax, - orientation="vertical", - pad=0.02, - aspect=30, - shrink=0.8, - ) - cbar.set_label("Ground Motion (cm/s)") - - # Save the frame to a file - with io.BytesIO() as io_buf: - fig.savefig(io_buf, format="raw", dpi=dpi) - plt.close(fig) - return io_buf.getvalue() +def cmap_from_cpt(cpt_file: Path) -> tuple[float, float, LinearSegmentedColormap]: + line_re = r"^(?P<el>[0-9\.\-]+)\s+(?P<r>\d+)/(?P<g>\d+)/(?P<b>\d+)" + + colours = [] + elevations = [] + with open(cpt_file, "r") as f: + for line in f: + if m := re.match(line_re, line): + elevation = float(m.group("el")) + r = int(m.group("r")) + g = int(m.group("g")) + b = int(m.group("b")) + 
elevations.append(elevation) + colours.append((r, g, b)) + colours_np = np.array(colours) / 255.0 + elevations_np = np.array(elevations) + el_min = elevations_np.min() + el_max = elevations_np.max() + normalised = (elevations_np - el_min) / (el_max - el_min) + segmentdata = { + "red": [(frac, c[0], c[0]) for frac, c in zip(normalised, colours_np)], + "green": [(frac, c[1], c[1]) for frac, c in zip(normalised, colours_np)], + "blue": [(frac, c[2], c[2]) for frac, c in zip(normalised, colours_np)], + } + return el_min, el_max, LinearSegmentedColormap(cpt_file.stem, segmentdata) @cli.from_docstring(app, name="xyts") @@ -583,92 +399,163 @@ def animate_low_frequency( ) raise typer.Exit(code=1) + dem_dataset = xr.open_dataarray("dem.h5") + x_coords = dem_dataset.x.values + y_coords = dem_dataset.y.values + x, y = np.meshgrid(x_coords, y_coords) + z = dem_dataset.to_numpy() + z = np.where(np.isnan(z), -250, z) + z_max = z.max() + dem_grid = pv.StructuredGrid(x, y, z) + dem_grid["elevation"] = z.T.ravel() + source_config = SourceConfig.read_from_realisation(realisation_ffp) - xyts_file = XYTSFile(xyts_ffp) + planes = [] + lines = [] + plane_max = 0 + for fault in source_config.source_geometries.values(): + for plane in fault.planes: + bounds = plane.bounds + plane_max = max(plane_max, bounds[-1, 2]) + bounds[:, [0, 1]] = bounds[:, [1, 0]] + bounds[:, 2] = z_max + 5 + planes.extend( + [ + pv.Triangle([bounds[0], bounds[1], bounds[-1]]), + pv.Triangle([bounds[1], bounds[2], bounds[-1]]), + ] + ) + point_a = bounds[0].copy() + point_a[2] = z_max + 10 + point_b = bounds[1].copy() + point_b[2] = z_max + 10 + lines.extend([point_a, point_b]) + + cmap_min, cmap_max, cmap = cmap_from_cpt( + Path("/home/jake/tmp/palm_springs_nz_topo.cpt") + ) + xyts_dataset = xr.open_dataset(xyts_ffp) + (nt, ny, nx) = xyts_dataset.waveform.shape + proj = coordinates.SphericalProjection( + xyts_dataset.mlon, xyts_dataset.mlat, xyts_dataset.mrot + ) + dx = xyts_dataset.dx - nztm_corners = 
xyts_nztm_corners(xyts_file) - map_extent_nztm = map_extents(nztm_corners, padding) + y_sim_bounds = np.linspace(-0.5, 0.5, num=ny) * ny * (dx * 5) + x_sim_bounds = np.linspace(-0.5, 0.5, num=nx) * nx * (dx * 5) - if zoom != 1: - centre = shapely.centroid( - shapely.union_all( - [fault.geometry for fault in source_config.source_geometries.values()] - ) - ) - map_extent_nztm = zoom_extents( - map_extent_nztm, - (centre.y, centre.x), - zoom, - ) + y_sim, x_sim = np.meshgrid(y_sim_bounds, x_sim_bounds, indexing="ij") + print(y_sim.shape) - frame_count = frame_count or xyts_file.nt - xr, yr = waveform_coordinates(nztm_corners, xyts_file.nx, xyts_file.ny) + x_flat = x_sim.ravel(order="F") + y_flat = y_sim.ravel(order="F") - render_frame = functools.partial( - render_single_frame, - dt=xyts_file.dt, - shading=shading, - xyts_file_path=xyts_ffp.resolve(), - max_motion=max_motion, - cmap=cmap, - source_config=source_config, - nztm_corners=nztm_corners, - map_extent_nztm=map_extent_nztm, - xr=xr, - yr=yr, - simple_map=simple_map, - scale=scale, - map_quality=map_quality, - title=title, - width=width, - height=height, - dpi=dpi, - downsample=downsample, + points = proj.inverse(x_flat, y_flat) + lon_sim = points[:, 1] + lat_sim = points[:, 0] + + proj = pyproj.Transformer.from_crs(4326, 2193, always_xy=True) + x_nztm_flat, y_nztm_flat = proj.transform(lon_sim, lat_sim) + + x_nztm = x_nztm_flat.reshape((ny, nx), order="F") + y_nztm = y_nztm_flat.reshape((ny, nx), order="F") + corners = np.array( + [ + [x_nztm[0, 0], y_nztm[0, 0], z_max], + [x_nztm[-1, 0], y_nztm[-1, 0], z_max], + [x_nztm[-1, -1], y_nztm[-1, -1], z_max], + [x_nztm[0, -1], y_nztm[0, -1], z_max], + [x_nztm[0, 0], y_nztm[0, 0], z_max], + ] ) + z_plane = np.full((ny, nx), z_max + ((1 << 16) - 1) * 0.1) - # warm the OSM cache to speed up rendering by rendering the first frame + grid = pv.StructuredGrid(x_nztm, y_nztm, z_plane) - frames = [render_frame(0)] + grid["Ground Motion (cm/s)"] = grid.points[:, -1] + plotter 
= pv.Plotter(notebook=False, off_screen=True) + plotter.remove_all_lights() + plotter.ren_win.SetSize([1920, 1088]) + # plotter.enable_anti_aliasing() - with mp.Pool() as pool: - # Render all frames in parallel - frames.extend( - tqdm.tqdm( - pool.imap(render_frame, range(frame_start, frame_start + frame_count)), - total=frame_count, - unit="frame", - desc="Rendering frames", - initial=1, - ) - ) - cm = 1 / 2.54 - width_px = int(width * cm * dpi) - height_px = int(height * cm * dpi) - # Use ffmpeg to combine frames into video - process = ( - ffmpeg.input( - "pipe:0", format="rawvideo", pix_fmt="rgba", s=f"{width_px}x{height_px}" - ) - .output( - str(output_mp4), - pix_fmt="yuv420p", - r=fps, - vcodec="libx264", - crf=23, - vf="pad=ceil(iw/2)*2:ceil(ih/2)*2", - ) - .overwrite_output() - .run_async(pipe_stdin=True) + plotter.open_movie(output_mp4, framerate=fps, quality=10) + for plane in planes: + plotter.add_mesh(plane, color="red", lighting=False) + + plotter.add_lines(corners, connected=True, color="black", width=3) + plotter.add_lines(np.array(lines), color="black", width=2) + + plotter.add_mesh( + dem_grid, + lighting=False, + smooth_shading=False, + cmap=cmap, + clim=(cmap_min, cmap_max), + show_scalar_bar=False, ) + plotter.add_mesh( + grid, + lighting=False, + smooth_shading=False, + scalars="Ground Motion (cm/s)", + clim=[0, 100], + cmap="hot", + show_edges=False, + nan_opacity=0.0, + show_scalar_bar=True, + scalar_bar_args=dict( + title="Ground Motion\n(cm/s)\n", + vertical=True, + position_x=0.85, + position_y=0.25, + bold=True, + color="white", + height=0.5, + width=0.05, + n_labels=11, + ), + ) + + plotter.camera.tight() + # plotter.camera.set + plotter.set_background((173 / 255, 216 / 255, 230 / 255)) + text = plotter.add_text("0.00s", name="time-label", position="lower_right") + + plotter.show(auto_close=False) + + frame_chunk = 100 + i_frame = frame_start + frames = None + scalars = np.empty((ny, nx), dtype=np.float32, order="F") + time = 
xyts_dataset.time.values + dt = xyts_dataset.attrs["dt"] + for i in tqdm.trange( + frame_start, + frame_start + min(frame_count or nt - frame_start, nt - frame_start), + ): + if i >= i_frame: + next = min(i_frame + frame_chunk, nt) + frames = ( + xyts_dataset.waveform.isel(time=range(i_frame, next)) + .astype(np.float32) + .values + ) + i_frame = next + assert frames is not None + + z_geometry = frames[i - i_frame] + np.copyto(scalars, frames[i - i_frame]) - # Write the raw video data to FFmpeg's stdin - for frame in frames: - process.stdin.write(frame) + scalars[scalars < 10] = np.nan - process.stdin.close() + grid["Ground Motion (cm/s)"] = scalars.ravel(order="F") - # Wait for FFmpeg to finish - process.wait() + np.add(z_geometry, z_max, out=z_geometry) + grid.points[:, -1] = z_geometry.ravel(order="F") + if i % round(1 / dt): + text.set_text("lower_right", f"{time[i]:.2f}s") + plotter.write_frame() + plotter.close() def non_zero_data_points( diff --git a/visualisation/realisation.py b/visualisation/realisation.py index a670407..f7aa30c 100644 --- a/visualisation/realisation.py +++ b/visualisation/realisation.py @@ -17,6 +17,7 @@ from pygmt_helper import plotting from qcore import cli +from source_modelling.sources import Fault, Plane from visualisation import utils from workflow.realisations import ( DomainParameters, @@ -78,9 +79,7 @@ def plot_stations( ) -def plot_sources( - fig: pygmt.Figure, source_config: SourceConfig, **kwargs: dict[str, Any] -) -> None: +def plot_sources(fig: pygmt.Figure, source_config: SourceConfig, **kwargs: Any) -> None: """Plot the sources on the figure. Parameters @@ -91,7 +90,8 @@ def plot_sources( The source configuration to plot. **kwargs : dict Additional keyword arguments to pass to the plotting function. 
If empty, the default is - - `pen="0.3p,black"` (polygon border colour) + - `pen="0.3p,black,--"` (polygon border colour) + - trace pen is found by taking the pen and stripping the "--" Examples -------- @@ -102,16 +102,29 @@ def plot_sources( >>> plot_sources(fig, source_config) >>> source_config.show() """ - kwargs = {"pen": "0.3p,black", **(kwargs or {})} + pen = kwargs.get("pen", "0.3p,black,--") + assert isinstance(pen, str) + trace_pen = pen.removesuffix(",--") + interior_kwargs = {"pen": pen, **(kwargs or {})} for source in source_config.source_geometries.values(): - utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(source.geometry), **kwargs) + utils.plot_polygon( + fig, utils.polygon_nztm_to_pygmt(source.geometry), **interior_kwargs + ) + if isinstance(source, Fault): + trace = shapely.LineString( + np.concatenate([plane.bounds[:2] for plane in source.planes]) + ) + utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(trace), pen=trace_pen) + elif isinstance(source, Plane): + trace = shapely.LineString(source.bounds[:2]) + utils.plot_polygon(fig, utils.polygon_nztm_to_pygmt(trace), pen=trace_pen) def plot_domain( fig: pygmt.Figure, domain_parameters: DomainParameters, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> None: """Plot the domain on a figure. @@ -211,7 +224,6 @@ def plot_realisation( title: str | None = None, subtitle: str | None = None, width: float = 10, - show_geonet_stations: bool = False, show_geometry: bool = True, show_pgv_targets: bool = False, pgv_targets: list[float] | None = None, @@ -233,8 +245,6 @@ def plot_realisation( Subtitle of the plot. width : float Width of the plot in cm. - show_geonet_stations : bool - Show GeoNet stations on the plot. show_geometry : bool Show source geometry on the plot. 
show_pgv_targets : bool @@ -263,20 +273,21 @@ def plot_realisation( >>> fig.show() """ show_pgv_targets = show_pgv_targets or bool(pgv_targets) - rupture_propagation = RupturePropagationConfig.read_from_realisation( - realisation_ffp - ) domain_parameters = DomainParameters.read_from_realisation(realisation_ffp) - velocity_model_parameters = VelocityModelParameters.read_from_realisation( - realisation_ffp - ) - source_config = SourceConfig.read_from_realisation(realisation_ffp) rrup_bounding_polygons: list[shapely.Polygon] = [] if show_pgv_targets: + rupture_propagation = RupturePropagationConfig.read_from_realisation( + realisation_ffp + ) + + velocity_model_parameters = VelocityModelParameters.read_from_realisation( + realisation_ffp + ) + fault_pgv_targets = pgv_targets or [ generate_velocity_model_parameters.pgv_target( rupture_propagation, velocity_model_parameters @@ -449,7 +460,6 @@ def plot_realisation_to_file( title=title, subtitle=subtitle, width=width, - show_geonet_stations=show_geonet_stations, show_geometry=show_geometry, show_pgv_targets=show_pgv_targets, pgv_targets=pgv_targets, diff --git a/visualisation/sources/plot_srf.py b/visualisation/sources/plot_srf.py index 54d75a8..fc28dbb 100644 --- a/visualisation/sources/plot_srf.py +++ b/visualisation/sources/plot_srf.py @@ -61,9 +61,8 @@ def show_slip( region: tuple[float, float, float, float], srf_data: srf.SrfFile, annotations: bool, - projection: str = "M?", + contours: bool, realisation_ffp: Optional[Path] = None, - title: Optional[str] = None, ): """Show a slip map with optional contours. 
@@ -94,31 +93,11 @@ def show_slip( >>> show_slip(fig, region, srf_data, annotations=True, projection="M6i", title="Slip Distribution") >>> fig.show() # Displays the slip map with optional annotations """ - subtitle = utils.format_description( - srf_data.points["slip"], units="cm", compact=True - ) - title_args = [f"+t{title}+s{subtitle}".replace(" ", r"\040")] if title else [] # Compute slip limits - fig.basemap( - region=region, - projection=projection, - frame=plotting.DEFAULT_PLT_KWARGS["frame_args"] + title_args, - ) - fig.coast( - shorelines=["1/0.1p,black", "2/0.1p,black"], - resolution="f", - land="#666666", - water="skyblue", - ) - - slip_quantile = srf_data.points["slip"].quantile(0.98) + slip_quantile = srf_data.points["slip"].max() slip_cb_max = max(int(np.round(slip_quantile, -1)), 10) cmap_limits = (0, slip_cb_max, slip_cb_max / 10) - slip_stats = utils.format_description( - srf_data.points["slip"], compact=True, units="cm" - ) dx = srf_data.header.iloc[0]["len"] / srf_data.header.iloc[0]["nstk"] - subtitle = f"Slip: {slip_stats}, dx = {dx:.2f} km, {len(srf_data.header)} planes" grid_scale = min(utils.grid_scale_for_region(region), dx * 1000) for (_, segment), segment_points in zip( srf_data.header.iterrows(), srf_data.segments @@ -150,23 +129,24 @@ def show_slip( reverse_cmap=True, plot_contours=False, cb_label="Slip (cm)", - continuous_cmap=True, ) - # Plot time contours - time_grid = plotting.create_grid( - segment_points, - "tinit", - grid_spacing=f"{grid_scale}e/{grid_scale}e", - region=( - segment_points["lon"].min(), - segment_points["lon"].max(), - segment_points["lat"].min(), - segment_points["lat"].max(), - ), - set_water_to_nan=False, - ) - fig.grdcontour(levels=1, grid=time_grid, pen="0.1p") + if contours: + # Plot time contours + time_grid = plotting.create_grid( + segment_points, + "tinit", + grid_spacing=f"{grid_scale}e/{grid_scale}e", + region=( + segment_points["lon"].min(), + segment_points["lon"].max(), + segment_points["lat"].min(), 
+ segment_points["lat"].max(), + ), + set_water_to_nan=False, + ) + + fig.grdcontour(levels=1, grid=time_grid, pen="0.1p") # Plot bounds of the current segment. corners = segment_points.iloc[[0, nstk - 1, -1, (ndip - 1) * nstk]] @@ -290,8 +270,13 @@ def plot_srf( latitude_pad: Annotated[float, typer.Option()] = 0, longitude_pad: Annotated[float, typer.Option()] = 0, annotations: Annotated[bool, typer.Option()] = True, + contours: Annotated[bool, typer.Option()] = True, width: Annotated[float, typer.Option(min=0)] = 17, show_inset: bool = False, + min_lat: float | None = None, + max_lat: float | None = None, + min_lon: float | None = None, + max_lon: float | None = None, ) -> None: """Plot multi-segment rupture with slip. @@ -342,24 +327,35 @@ def plot_srf( """ srf_data = srf.read_srf(srf_ffp) region = ( - srf_data.points["lon"].min() - longitude_pad, - srf_data.points["lon"].max() + longitude_pad, - srf_data.points["lat"].min() - latitude_pad, - srf_data.points["lat"].max() + latitude_pad, + min_lon or srf_data.points["lon"].min() - longitude_pad, + max_lon or srf_data.points["lon"].max() + longitude_pad, + min_lat or srf_data.points["lat"].min() - latitude_pad, + max_lat or srf_data.points["lat"].max() + latitude_pad, + ) + use_high_res = (region[1] - region[0]) * (region[3] - region[2]) < 0.5 + fig = plotting.gen_region_fig( + title, + region, + high_res_topo=use_high_res, + projection=f"M{width}c", + subtitle=utils.format_description( + srf_data.points["slip"], units="cm", compact=True + ), + config_options=dict( + FONT_SUBTITLE="9p,Helvetica,black", + FORMAT_GEO_MAP="ddd.xx", + MAP_FRAME_TYPE="plain", + ), + plot_kwargs=dict(water_color="white", topo_cmap_min=-900, topo_cmap_max=3100), ) - - fig = pygmt.Figure() - - pygmt.config(FONT_SUBTITLE="9p,Helvetica,black", FORMAT_GEO_MAP="ddd.xx") show_slip( fig, region, srf_data, annotations, - projection=f"M{width}c", + contours, realisation_ffp=realisation_ffp, - title=title, ) if show_inset: with 
fig.inset(position=f"jTR+w{np.sqrt(width)}c", margin=0.2): diff --git a/visualisation/stations/plot_stations.py b/visualisation/stations/plot_stations.py new file mode 100644 index 0000000..e86badf --- /dev/null +++ b/visualisation/stations/plot_stations.py @@ -0,0 +1,107 @@ +from pathlib import Path + +import pandas as pd +import pygmt +import typer + +from pygmt_helper import plotting +from visualisation import realisation, utils +from workflow.realisations import DomainParameters, SourceConfig + +app = typer.Typer() + + +def plot_towns(fig: pygmt.Figure): + towns = { + "Blenheim": (173.9569444, -41.5138888), + "Christchurch": (172.6347222, -43.5313888), + "Dunedin": (170.3794444, -45.8644444), + "Greymouth": (171.2063889, -42.4502777), + "Haast": (169.0405556, -43.8808333), + "Kaikoura": (173.6802778, -42.4038888), + "Masterton": (175.658333, -40.952778), + "Napier": (176.916667, -39.483333), + "New Plymouth": (174.083333, -39.066667), + "Nelson": (173.2838889, -41.2761111), + "Palmerston North": (175.611667, -40.355000), + "Queenstown": (168.6680556, -45.0300000), + "Rakaia": (172.0230556, -43.75611111), + "Rotorua": (176.251389, -38.137778), + "Taupo": (176.069400, -38.6875), + "Tekapo": (170.4794444, -44.0069444), + "Timaru": (171.2430556, -44.3958333), + "Wellington": (174.777222, -41.288889), + "Westport": (171.5997222, -41.7575000), + } + for label, (lon, lat) in towns.items(): + fig.plot(x=lon, y=lat, style="c0.1c", fill="white", pen="0.3p,black") + fig.text( + x=lon, y=lat, text=label, justify="BC", offset="0.15c", font="6p,black" + ) + + +@app.command(help="Test") +def plot_stations( + realisation_ffp: Path, + stations_ll: Path, + stations_vs30: Path, + cmap: str, + title: str, + width: float, +): + domain_parameters = DomainParameters.read_from_realisation(realisation_ffp) + source_config = SourceConfig.read_from_realisation(realisation_ffp) + stations_vs30_df = pd.read_csv( + stations_vs30, + delimiter=r"\s+", + comment="#", + names=["name", "vs30"], 
+ header=None, + ).set_index("name") + stations_ll_df = pd.read_csv( + stations_ll, + delimiter=r"\s+", + comment="#", + names=["lon", "lat", "name"], + header=None, + ).set_index("name") + stations_vs30_df = stations_vs30_df.loc[stations_ll_df.index] + region = utils.bounding_region_for([domain_parameters.domain.polygon], 0, 0) + fig = plotting.gen_region_fig( + title, + region, + projection=f"M{width}c", + plot_kwargs=dict(water_color="white", topo_cmap_min=-900, topo_cmap_max=3100), + plot_highways=False, + ) + + realisation.plot_domain(fig, domain_parameters, pen="1p,black,-") + realisation.plot_sources(fig, source_config, fill="blue") + + pygmt.makecpt( + cmap=cmap.removesuffix("_r"), + series=[0, 1500, 100], + reverse=cmap.endswith("_r"), + ) + realisation.plot_stations( + fig, + domain_parameters, + stations_ll, + fill=stations_vs30_df["vs30"], + cmap=True, + style="i0.3c", + ) + fig.text( + x=stations_ll_df["lon"], + y=stations_ll_df["lat"], + text=stations_ll_df.index, + justify="BC", + offset="0.25c", + font="6p,black", + ) + fig.colorbar(frame="xaf+lVs30 (m/s)") + fig.savefig("map.png") + + +if __name__ == "__main__": + app() diff --git a/visualisation/utils.py b/visualisation/utils.py index 7a2694a..7f7fe4e 100644 --- a/visualisation/utils.py +++ b/visualisation/utils.py @@ -1,13 +1,31 @@ """Utility functions common to many plotting scripts.""" -from typing import Optional +from collections.abc import Sequence +from typing import Any, Literal, NamedTuple, Optional, TypedDict, Unpack import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd import pygmt import scipy as sp import shapely +import xarray as xr +from matplotlib import colors +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from rpy2.robjects import default_converter, globalenv, numpy2ri, r +from rpy2.robjects.conversion import localconverter from qcore import coordinates +from source_modelling 
import moment +from workflow.realisations import ( + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) def format_description( @@ -149,7 +167,7 @@ def _point_on_polygon(t: float, polygon: shapely.Polygon) -> shapely.Point: def _hausdorff_maximisation( polygon: shapely.Polygon, other_geom: shapely.Polygon -) -> shapely.Point: +) -> tuple[shapely.Point, float]: """Finds the point on polygon maximizing the distance to other_geom. Parameters @@ -163,6 +181,8 @@ def _hausdorff_maximisation( ------- shapely.Point Point on the polygon boundary maximizing the distance to other_geom. + float + The distance from the point to other_geom See Also -------- @@ -171,11 +191,11 @@ def _hausdorff_maximisation( def objective(t: float) -> float: # numpydoc ignore=GL08 point = _point_on_polygon(t, polygon) - return -point.distance(other_geom.exterior) # Negative because we maximize + return -point.distance(other_geom.exterior) # Negative because we maximise result = sp.optimize.minimize_scalar(objective, bounds=(0, 1), method="bounded") if result.success: - return _point_on_polygon(result.x, polygon), -result.fun + return _point_on_polygon(float(result.x), polygon), -result.fun else: raise RuntimeError("Optimisation failed") @@ -316,3 +336,357 @@ def grid_scale_for_region(region: tuple[float, float, float, float]) -> int: lon_km = (max_lon - min_lon) * 111 * np.cos(np.radians((min_lat + max_lat) / 2)) maximum_extent = max(lat_km, lon_km) return int(round(max(5, 2.5 * maximum_extent))) + + +class SubplotsKwargs(TypedDict, total=False): + sharex: bool | Literal["none", "all", "row", "col"] + sharey: bool | Literal["none", "all", "row", "col"] + subplot_kw: dict[str, Any] | None + gridspec_kw: dict[str, Any] | None + figsize: tuple[float, float] + constrained_layout: bool | dict[str, Any] + layout: Literal["constrained", "compressed", "tight"] | None + + +def balanced_subplot_grid( + n_subplots: int, + aspect: float, + subplot_size: tuple[float, float] | None = None, + 
squeeze: bool = False, + clear: bool = False, + **kwargs: Unpack[SubplotsKwargs], +) -> tuple[ + Figure, + Sequence[Axes], # Although we are really returning a numpy array, numpy does + # not support generic Axes objects in NDArray. + # numpy/numpy#24738 +]: + # This has more columns than rows, i.e. wide + height = np.sqrt(n_subplots / aspect) + rows = int(np.ceil(height)) + columns = int(np.ceil(height * aspect)) + # Ensures there are no blank rows. + rows -= max(0, (rows * columns - n_subplots) // columns) + + if subplot_size: + width, height = subplot_size + kwargs["figsize"] = (width * columns, height * rows) + + fig, axes = plt.subplots(rows, columns, **kwargs) + if n_subplots == 1: + axes = np.atleast_2d([axes]) + if squeeze: + axes = axes.squeeze() + if clear: + for ax in axes.flatten()[n_subplots:]: + ax.remove() + + return fig, axes + + +class RuptureContext(TypedDict): + mag: float + rake: float + dip: float + hypo_depth: float + ztor: float + zbot: float + + +class SiteProperties(TypedDict): + vs30measured: bool | npt.NDArray[np.bool_] + vs30: float | npt.NDArray[np.floating] + z1pt0: float | npt.NDArray[np.floating] + z2pt5: float | npt.NDArray[np.floating] + rrup: float | npt.NDArray[np.floating] + rjb: float | npt.NDArray[np.floating] + rx: float | npt.NDArray[np.floating] + ry: float | npt.NDArray[np.floating] + + +def circmean( + samples: npt.NDArray[np.floating], weights: npt.NDArray[np.floating] +) -> float: + x = np.cos(samples) + y = np.sin(samples) + z = weights * np.array([x, y]) + mean_resultant_vector = np.mean(z, axis=1) + argument = np.arctan2(mean_resultant_vector[1], mean_resultant_vector[0]) + return float(argument) + + +def compute_site_properties(sites: xr.Dataset) -> SiteProperties: + vs30 = sites.vs30.values + z1pt0 = oqw.estimations.chiou_young_08_calc_z1p0(vs30) + z2pt5 = oqw.estimations.chiou_young_08_calc_z2p5(vs30) + rrup = sites.rrup.values + rjb = sites.rjb.values + rx = rjb + ry = rjb + vs30measured = False + return 
SiteProperties( + vs30measured=vs30measured, + vs30=vs30, + z1pt0=z1pt0, + z2pt5=z2pt5, + rrup=rrup, + rjb=rjb, + rx=rx, + ry=ry, + ) + + +def get_gmm_prediction( + sites: xr.Dataset, + period: float, + source_config: SourceConfig, + magnitudes: Magnitudes, + rakes: Rakes, + rupture_propagation: RupturePropagationConfig, +) -> pd.Series: + site_properties = compute_site_properties(sites) + rupture_context = compute_rupture_context( + source_config, magnitudes, rakes, rupture_propagation + ) + gmm_df = nshm2022_logic_tree_prediction(rupture_context, site_properties, period) + gmm_df["station"] = sites.station.values + gmm_df = gmm_df.set_index("station") + gmm_psa_value = gmm_df.loc[:, gmm_df.columns.str.endswith("_mean")].squeeze() + return gmm_psa_value + + +def compute_rupture_context( + source_config: SourceConfig, + magnitudes_config: Magnitudes, + rakes_config: Rakes, + rupture_propagation: RupturePropagationConfig, +) -> RuptureContext: + moments = [] + dips = [] + rakes = [] + + for name, source in source_config.source_geometries.items(): + dips.append(source.dip) + moments.append(moment.magnitude_to_moment(magnitudes_config[name])) + rakes.append(rakes_config[name]) + + ztor = ( + min( + source_config.source_geometries.values(), key=lambda source: source.top_m + ).top_m + / 1000 + ) + zbot = ( + max( + source_config.source_geometries.values(), key=lambda source: source.bottom_m + ).bottom_m + / 1000 + ) + avg_rake = np.degrees(circmean(np.radians(rakes), np.array(moments))) + avg_dip = float(np.average(dips, weights=moments)) + avg_moment = float(np.mean(moments)) + total_moment = avg_moment * len(moments) + magnitude = moment.moment_to_magnitude(total_moment) + initial_source = source_config.source_geometries[rupture_propagation.initial_fault] + hypocentre = initial_source.fault_coordinates_to_wgs_depth_coordinates( + rupture_propagation.hypocentre + ) + hypo_depth = float(hypocentre[2]) + hypo_depth /= 1000.0 + return RuptureContext( + mag=magnitude, + 
rake=avg_rake, + dip=avg_dip, + hypo_depth=hypo_depth, + ztor=ztor, + zbot=zbot, + ) + + +def mean_vs30(site_vs30: npt.NDArray[np.floating]) -> float: + # Calculate geometric mean of site vs30 using the exponential-log form: + # exp(1/n sum vs30) + # This is as opposed to straight-forward calculation + # product(vs30) ^ (1/n) + # Which is numerically unstable for a large number of stations due to + # floating-point arithmetic overflow and imprecision at large values + # obtained by multiplication. + + return np.exp(1 / len(site_vs30) * np.sum(np.log(site_vs30))) + + +def adjust_value(colour: npt.ArrayLike, gamma: float) -> npt.NDArray[np.float64]: + """Adjust the brightness of an RGB colour. + + Parameters + ---------- + colour : npt.ArrayLike + Colour to transform (in RGB format). + gamma : float + The brightness to adjust by. Adjustment is multiplicative, so + ``gamma=1`` is equivalent to no change. + + Returns + ------- + npt.NDArray[np.float64] + A brightness adjusted equivalent of `colour` with no change in + hue or saturation. + """ + colour = np.asarray(colour) + # Naive colour brightness adjustment would simply multiply colour + # by gamma. However, this also changes the hue of the colour, + # resulting in visually incorrect results. HSV is designed for + # this manipulation. + hsv = colors.rgb_to_hsv(colour) + # HSV colour scale represents every colour as a combination of three components: + # 1. (H)ue, the quality of the colour (blue, red, green, magenta, etc). + # 2. (S)aturation, how intense that colour is at a fixed + # brightness (e.g. black has zero saturation and looks the same + # regardless of brightness). + # 3. (V)alue, the brightness of the colour + # + # We just want to adjust the brightness. To do this we adjust the value component. + hsv[-1] *= gamma + return colors.hsv_to_rgb(hsv) + + +def nice_num(x: float, round: bool) -> float: + """Find an equivalent "nice number" for `x`. 
+ + A nice number is a number that a power-of-ten multiple of 1, 2, or + 5. See: https://stackoverflow.com/a/16363437 + + Parameters + ---------- + x : float + The number to find a nice number for. + round : bool + If true, round toward the nearest nice number. Otherwise, + find the next largest. + + + Returns + ------- + float + The nearest or the next largest nice number. + """ + exponent = np.floor(np.log10(x)) + fraction = x / (10**exponent) + if round: + if fraction < 1.5: + nice_fraction = 1 + elif fraction < 3: + nice_fraction = 2 + elif fraction < 7: + nice_fraction = 5 + else: + nice_fraction = 10 + else: + if fraction <= 1: + nice_fraction = 1 + elif fraction <= 2: + nice_fraction = 2 + elif fraction <= 5: + nice_fraction = 5 + else: + nice_fraction = 10 + return nice_fraction * 10**exponent + + +class PlotDomain(NamedTuple): + """Named tuple representing plot domain with ticks.""" + + low: float + high: float + spacing: float + + +def range_for(low: float, high: float, max_ticks: int) -> tuple[float, float, float]: + """Given a plotting range and a fixed number of ticks, return its "nicest" representation. + + See: https://stackoverflow.com/a/16363437 + + Parameters + ---------- + low : float + The lower bound of the plotting range. + high : float + The upper bound of the plotting range. + max_ticks : int + An upper bound on the number of ticks desired. The number of + ticks returned is usually *less* than this value. + + Returns + ------- + PlotDomain + The lower bound, upper bound and tick spacing corresponding to + a "nice" representation of this domain. 
+ """ + range = nice_num(high - low, False) + tick_spacing = nice_num(range / (max_ticks - 1), True) + nice_min = np.floor(low / tick_spacing) * tick_spacing + nice_max = np.ceil(high / tick_spacing) * tick_spacing + return PlotDomain(nice_min, nice_max, tick_spacing) + + +def nshm2022_logic_tree_prediction( + rupture_context: RuptureContext, + site_properties: SiteProperties, + period: float, +) -> pd.DataFrame: + tect_type = oqw.constants.TectType.ACTIVE_SHALLOW + gmm_lt = oqw.constants.GMMLogicTree.NSHM2022 + rupture_df = pd.DataFrame( + {"vs30measured": False, **rupture_context, **site_properties} + ) + psa_results = oqw.run_gmm_logic_tree( + gmm_lt, tect_type, rupture_df, "pSA", periods=[period] + ) + assert isinstance(psa_results, pd.DataFrame) + for site_property in site_properties: + psa_results[site_property] = rupture_df[site_property] + return psa_results + + +class ConfidenceInterval(NamedTuple): + mean: npt.NDArray[np.floating] + std_low: npt.NDArray[np.floating] + std_high: npt.NDArray[np.floating] + + +def fit_loess_r( + y: npt.NDArray[np.floating], + x: npt.NDArray[np.floating], + x_out: npt.NDArray[np.floating], + **kwargs, +) -> ConfidenceInterval: + """ + Fit LOESS using R and return fitted values and prediction intervals. 
+ """ + loess_args = ", ".join(f"{k}={v}" for k, v in kwargs.items()) + loess_call = f"loess(y ~ x, {loess_args})" if loess_args else "loess(y ~ x)" + + with localconverter(default_converter + numpy2ri.converter): + globalenv["x"] = x + globalenv["y"] = y + globalenv["x_out"] = x_out + + r(f"fit <- {loess_call}") + r("newdat <- data.frame(x=x_out)") + r("pred <- predict(fit, newdata=newdat, se=TRUE)") + r("residual_se <- fit$s") + + fit_vals = np.asarray(r("pred$fit")) + + residual_se_eval = r("residual_se") + if isinstance(residual_se_eval, np.ndarray) and len(residual_se_eval) == 1: + residual_se = float(residual_se_eval.item()) # type: ignore[invalid-argument-type] + else: + raise ValueError( + f"Residual stderr evaluation failed, expected float found: {residual_se_eval=}" + ) + + std_low = fit_vals - residual_se + std_high = fit_vals + residual_se + + return ConfidenceInterval(fit_vals, std_low, std_high) diff --git a/visualisation/waveforms/plot_fas.py b/visualisation/waveforms/plot_fas.py new file mode 100644 index 0000000..aed0210 --- /dev/null +++ b/visualisation/waveforms/plot_fas.py @@ -0,0 +1,213 @@ +from pathlib import Path +from typing import Annotated + +import matplotlib.pyplot as plt +import re +import numpy as np +import numpy.typing as npt +import oq_wrapper as oqw +import pandas as pd +import typer +import xarray as xr +from matplotlib.axes import Axes + +from qcore import cli +from visualisation import utils +from workflow.realisations import ( + Magnitudes, + Rakes, + RupturePropagationConfig, + SourceConfig, +) + +app = typer.Typer() + + +def plot_fas( + ax: Axes, + dataset: xr.Dataset, + station: str, + component: str, + ymax: float | None = None, + ymin: float | None = None, + **kwargs, +) -> None: + """Plot a Fourier Amplitude Spectrum (FAS) from a simulation. + + Parameters + ---------- + dataset : xr.Dataset + The dataset to read station FAS. + station : str + The station to plot. + component : str + The component to plot. 
+ ymax : float or None + Max limit for y-axis. + ymin : float or None + Min limit for y-axis. + """ + fas = dataset.FAS.sel(station=station, component=component).values + freqs = dataset.frequency.values + ax.plot(freqs, fas, **kwargs) + if ymin is not None or ymax is not None: + ax.set_ylim(bottom=ymin, top=ymax) + ax.set_xlim(freqs.min(), 50) + ax.set_xscale("log") + ax.set_yscale("log") + ax.grid(visible=True, which="both", axis="both", lw=0.3) + + +def plot_fas_estimate( + ax: Axes, realisation_ffp: Path, dataset: xr.Dataset, station: str, **kwargs +) -> None: + freqs = dataset.frequency.values + vs30 = dataset.vs30.sel(station=station).item() + site_properties = utils.compute_site_properties(vs30) + source_config = SourceConfig.read_from_realisation(realisation_ffp) + magnitudes = Magnitudes.read_from_realisation(realisation_ffp) + rakes = Rakes.read_from_realisation(realisation_ffp) + rupture_propagation = RupturePropagationConfig.read_from_realisation( + realisation_ffp + ) + rupture_context = utils.compute_rupture_context( + source_config, magnitudes, rakes, rupture_propagation + ) + site = dataset.sel(station=station) + longitude = site.longitude.item() + latitude = site.latitude.item() + point = np.array([latitude, longitude, 0]) + rrup = ( + min( + source.rrup_distance(point) + for source in source_config.source_geometries.values() + ) + / 1000 + ) + rupture_df = pd.DataFrame( + { + "rrup": [rrup], + "vs30measured": False, + **site_properties, + **rupture_context, + } + ) + fas_results = oqw.run_gmm( + oqw.constants.GMM.BA_18, + oqw.constants.TectType.ACTIVE_SHALLOW, + rupture_df, + "EAS", + frequencies=freqs, + ) + emp_eas = [] + emp_eas_stddev = [] + for col in fas_results.columns: + if col.endswith("mean"): + emp_eas.append(fas_results[col].item()) + elif col.endswith("Total"): + emp_eas_stddev.append(fas_results[col].item()) + + emp_eas = np.array(emp_eas) + emp_eas_stddev = np.array(emp_eas_stddev) + ax.plot(freqs, np.exp(emp_eas), **kwargs) + 
@cli.from_docstring(app)
def plot_fas_cli(
    realisation_ffp: Annotated[Path, typer.Argument()],
    dataset_path: Annotated[Path, typer.Argument()],
    stations: Annotated[list[str], typer.Argument()],
    title: str | None = None,
    save: Path | None = None,
    dpi: int = 300,
    width: float = 20,
    height: float = 15,
    ymin: float | None = 1e-5,
    ymax: float | None = 1,
    component: str = "geom",
) -> None:
    """Plot a station Fourier Amplitude Spectrum (FAS).

    Parameters
    ----------
    realisation_ffp : Path
        Path to the realisation file used for the empirical EAS estimate.
    dataset_path : Path
        Path to HDF5 FAS dataset.
    stations : list of str
        The stations to plot (one subplot per station).
    title : str, optional
        The title of the figure.
    save : Path, optional
        If given, save the figure to the supplied file.
    dpi : int, optional
        Figure DPI (higher is better quality). Only applies if saving
        the figure to a file.
    width : float, optional
        The width of each subplot, in centimetres.
    height : float, optional
        The height of each subplot, in centimetres.
    ymin : float, optional
        Min limit for the y-axis.
    ymax : float, optional
        Max limit for the y-axis.
    component : str, optional
        The FAS component to plot.
    """
    dset = xr.open_dataset(dataset_path, engine="h5netcdf")
    cm = 1 / 2.54  # matplotlib figure sizes are in inches

    fig, axes = utils.balanced_subplot_grid(
        len(stations),
        subplot_size=(width * cm, height * cm),
        aspect=3 / 2,
        sharex=True,
        sharey=True,
        clear=True,
        constrained_layout=True,
    )
    for station, ax in zip(stations, axes.flatten()):
        # Empirical estimate first (mean ± one standard deviation band),
        # then the simulated FAS on the same axis.
        plot_fas_estimate(
            ax,
            realisation_ffp,
            dset,
            station,
            label="EAS (BA18; μ ± σ)",
            color="blue",
        )
        plot_fas(
            ax,
            dset,
            station,
            component,
            ymin=ymin,
            ymax=ymax,
            label="EAS (Simulation)",
            color="k",
        )
        station_data = dset.sel(station=station)
        vs30 = station_data.vs30.item()
        basin = str(station_data.basin.item())
        if not basin:
            basin = "No Basin"
        else:
            # Split CamelCase basin names into words,
            # e.g. "CanterburyBasin" -> "Canterbury Basin".
            basin = re.sub("[A-Z]", r" \g<0>", basin).lstrip()
        lat = station_data.latitude.item()
        lon = station_data.longitude.item()
        pga = station_data.PGA.sel(component="rotd50").item()
        pgv = station_data.PGV.sel(component="rotd50").item()
        ax.set_title(
            f"{station}\n({lat:.3f}, {lon:.3f}) - PGA: {pga:.2g} g - PGV: {pgv:.0f} cm/s - Vs30: {vs30:.0f} m/s - Basin: {basin}"
        )

    if axes.size > 1:
        # Shared axis labels for the whole subplot grid.
        fig.supylabel(f"EAS [{component}]")
        fig.supxlabel("Frequency [Hz]")
    else:
        ax = axes.flatten().item()
        ax.set_ylabel(f"EAS [{component}]")
        ax.set_xlabel("Frequency [Hz]")

    # A single legend (first subplot) suffices: every subplot has the
    # same two series.
    ax = axes.flatten()[0]
    ax.legend()

    if title:
        fig.suptitle(title)

    if save:
        fig.savefig(save, dpi=dpi)
    else:
        fig.show()
        plt.show()


if __name__ == "__main__":
    app()
+ """ + labels = labels or stations + for label, station in zip(labels, stations): + spectra = dataset.pSA.sel(station=station, component=component).values + periods = dataset.period.values + + ax.plot(periods, spectra, label=label) + + +@cli.from_docstring(app) +def plot_spectra_cli( + dataset_paths: Annotated[list[Path], typer.Argument()], + scenarios: Annotated[list[str] | None, typer.Option("--scenario")] = None, + stations: Annotated[list[str] | None, typer.Option("--station")] = None, + title: str | None = None, + save: Path | None = None, + dpi: int = 300, + width: float = 20, + height: float = 15, + ymin: float | None = 1e-5, + ymax: float | None = 1, + component: str = "rotd50", +) -> None: + """Plot a station spectra. + + Parameters + ---------- + dataset_path : Path + Path to HDF5 spectra dataset. + station : str + The station to plot. + units : Units + The units to plot in. + """ + if not stations: + raise ValueError("Require at least one station to plot.") + elif len(dataset_paths) > 1 and ( + scenarios is None or len(scenarios) != len(dataset_paths) + ): + raise ValueError( + "Require a label for each dataset, if more than one is provided." 
+ ) + + cm = 1 / 2.54 + fig, ax = plt.subplots(figsize=(width * cm, height * cm)) + setup_plot_styling(ax, component, ymin=ymin, ymax=ymax) + for i, dataset_path in enumerate(dataset_paths): + dset = xr.open_dataset(dataset_path, engine="h5netcdf") + for station in stations: + label = station + if scenarios is not None and len(scenarios) > 1 and len(stations) > 1: + scenario = scenarios[i] + label = f"{station} ({scenario})" + elif scenarios is not None and len(stations) == 1: + label = scenarios[i] + + plot_spectra( + fig, + ax, + dset, + [station], + component, + ymin=ymin, + ymax=ymax, + labels=[label], + ) + + if len(stations) * len(dataset_paths) > 1: + ax.legend() + + if title: + fig.suptitle(title) + + fig.tight_layout() + + if save: + fig.savefig(save, dpi=dpi) + else: + fig.show() + plt.show() diff --git a/visualisation/waveforms/plot_waveform.py b/visualisation/waveforms/plot_waveform.py new file mode 100644 index 0000000..84629c9 --- /dev/null +++ b/visualisation/waveforms/plot_waveform.py @@ -0,0 +1,192 @@ +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import scipy as sp +import typer +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.gridspec import GridSpecFromSubplotSpec + +from qcore import cli + +app = typer.Typer() + + +class Units(StrEnum): + G = "g" + CMS = "cm/s" + CMS2 = "cm/s^2" + + +_UNIT_CONVERSION_TABLE = {Units.G: 981.0, Units.CMS: 1.0, Units.CMS2: 1.0} + + +def plot_waveform( + dataset: xr.Dataset, + station: str, + ax_x: Axes, + ax_y: Axes, + ax_z: Axes, + dataset_units: Units, + plot_units: Units | None = None, + ylim: float | None = None, + **kwargs, +) -> None: + """Plot a waveform from a simulation. + + Parameters + ---------- + dataset : xr.Dataset + The dataset to read station waveforms. It is assumed that waveform array is in cm/s^2. 
@cli.from_docstring(app)
def plot_waveform_cli(
    dataset_path: Annotated[Path, typer.Argument()],
    dataset_units: Annotated[Units, typer.Argument()],
    stations: Annotated[list[str], typer.Argument()],
    plot_units: Annotated[Units | None, typer.Option()] = None,
    title: str | None = None,
    save: Path | None = None,
    dpi: int = 300,
    width: float = 20,
    height: float = 15,
    ylim: float | None = None,
    rows: float | None = None,
    columns: float | None = None,
) -> None:
    """Plot station waveforms.

    Parameters
    ----------
    dataset_path : Path
        Path to HDF5 waveform dataset.
    dataset_units : Units
        The units of the dataset.
    stations : list of str
        The stations to plot (one three-component panel per station).
    plot_units : Units, optional
        The units to plot in. If not given, will assume it is the same
        as the dataset units.
    title : str, optional
        The title of the figure.
    save : Path, optional
        If given, save the figure to the supplied file.
    dpi : int, optional
        Figure DPI (higher is better quality). Only applies if saving
        the figure to a file.
    width : float, optional
        The figure width, in centimetres.
    height : float, optional
        The figure height, in centimetres.
    ylim : float, optional
        The maximum value for the y-axis.
    rows : float, optional
        Number of rows in the station grid. Derived from the station
        count if not given.
    columns : float, optional
        Number of columns in the station grid. Derived from the station
        count if not given.

    Raises
    ------
    ValueError
        If no stations are given, or the requested grid is too small to
        fit every station.
    """
    if not stations:
        raise ValueError("Require at least one station to plot.")

    dset = xr.open_dataset(dataset_path, engine="h5netcdf")
    cm = 1 / 2.54  # matplotlib figure sizes are in inches
    n = len(stations)
    # Fill in whichever grid dimensions were not supplied so the grid
    # always fits every station. (Previously, supplying only one of
    # rows/columns left the other as None and crashed in unravel_index;
    # float values also crashed, since unravel_index needs an int shape.)
    if rows is None and columns is None:
        rows = int(np.sqrt(n))
        columns = int(np.ceil(n / rows))
    elif rows is None:
        columns = int(columns)
        rows = int(np.ceil(n / columns))
    elif columns is None:
        rows = int(rows)
        columns = int(np.ceil(n / rows))
    else:
        rows, columns = int(rows), int(columns)
    if rows * columns < n:
        raise ValueError(f"A {rows}x{columns} grid cannot fit {n} stations.")

    mosaic: list[list[str | None]] = [[None] * columns for _ in range(rows)]
    for i, station in enumerate(stations):
        row, column = np.unravel_index(i, (rows, columns))
        mosaic[row][column] = station

    fig, station_axes = plt.subplot_mosaic(
        mosaic, figsize=(width * cm, height * cm), sharex=True, sharey=True
    )
    # Each mosaic cell is only a placeholder: remove the axis and reuse
    # its subplot spec for a 3-row (one per component) sub-grid below.
    for ax in station_axes.values():
        ax.remove()
    all_station_axes = []  # for rescaling every axis at the end
    for station, big_ax in station_axes.items():
        if not station:
            continue  # defensive: skip any unnamed placeholder entries
        gs = GridSpecFromSubplotSpec(
            nrows=3, ncols=1, subplot_spec=big_ax.get_subplotspec(), hspace=0.1
        )
        axes = [fig.add_subplot(spec) for spec in gs]

        # Only the bottom component axis keeps its x tick labels.
        for upper in axes[:-1]:
            upper.set_xticklabels([])

        all_station_axes.extend(axes)
        plot_waveform(
            dset,
            station,
            *axes,
            dataset_units,
            plot_units,
            ylim=ylim,
        )
        axes[0].set_title(station)

    # Apply symmetric, shared y-limits across every station/component so
    # amplitudes are visually comparable.
    extremes = []
    for ax in all_station_axes:
        extremes.extend(ax.get_ylim())
    abs_max = max(abs(lim) for lim in extremes)
    global_ylim = (-abs_max, abs_max)
    for ax in all_station_axes:
        ax.set_ylim(global_ylim)
    if title:
        fig.suptitle(title)

    fig.tight_layout()

    if save:
        fig.savefig(save, dpi=dpi)
    else:
        fig.show()
        plt.show()


if __name__ == "__main__":
    app()