Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,13 @@ poetry.toml
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python

# Devenv
.devenv*
devenv.local.nix
devenv.lock
# direnv
.direnv

# pre-commit
.pre-commit-config.yaml
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ diffimg
cartopy
matplotlib
numpy
ffmpeg-python
pygmt
pygmt_helper @ git+https://github.com/ucgmsim/pygmt_helper.git
source_modelling @ git+https://github.com/ucgmsim/source_modelling.git
Expand Down
218 changes: 134 additions & 84 deletions visualisation/plot_ts.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Create simulation video of surface ground motion levels."""

import functools
import io
import multiprocessing as mp
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Annotated

Expand All @@ -16,6 +14,7 @@

matplotlib.use("Agg")

import ffmpeg
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -312,25 +311,52 @@ def waveform_coordinates(nztm_corners: np.ndarray, nx: int, ny: int) -> np.ndarr
return coords_nztm[::-1, :, :] # Reverse order to (x, y) for NZTM


def tslice_get(xyts_file: XYTSFile, index: int, downsample: int = 1) -> np.ndarray:
"""Retrieve a single timeslice from an xyts file with downsampling

Parameters
----------
xyts_file : XYTSFile
The xyts file to retrieve from.
index : int
The timeslice index to read from.
downsample : int
If greater than 1, downsample the array in strides of `downsample` in
the x and y direction.

Returns
-------
array of float32
An array of shape (ny, nx) containing the downsampled frame data for `index`.
"""
if downsample > 1:
frame_data = xyts_file.data[index, :, ::downsample, ::downsample]
else:
frame_data = xyts_file.data[index] # shape: (3, ny, nx)
return np.linalg.norm(frame_data, axis=0)


def render_single_frame(
frame_index: int,
dt: float,
ground_motion_magnitude: np.ndarray,
max_motion: float,
cmap: str,
xyts_file_path: Path,
source_config: SourceConfig,
nztm_corners: np.ndarray,
map_extent_nztm: tuple[float, float, float, float],
xr: np.ndarray,
yr: np.ndarray,
max_motion: float,
cmap: str,
shading: str,
simple_map: bool,
scale: str,
map_quality: int,
title: str | None,
width: float,
height: float,
dpi: int,
) -> str:
downsample: int,
) -> bytes:
"""Render a single frame of the animation.

Parameters
Expand All @@ -339,12 +365,8 @@ def render_single_frame(
The index of the frame to render.
dt : float
The time step of the simulation.
ground_motion_magnitude : np.ndarray
The ground motion magnitude data.
max_motion : float
The maximum ground motion value for color scaling.
cmap : str
The colormap to use for the animation.
xyts_file_path : Path
The path to the XYTS file.
source_config : SourceConfig
The source configuration object.
nztm_corners : np.ndarray
Expand All @@ -355,6 +377,12 @@ def render_single_frame(
The x coordinates of the gridpoints in NZTM coordinates.
yr : np.ndarray
The y coordinates of the gridpoints in NZTM coordinates.
max_motion : float
The maximum ground motion value for color scaling.
cmap : str
The colormap to use for the animation.
shading : str
The shading to apply to the colourmap.
simple_map : bool
If True, disable OpenStreetMap background and use a simple map.
scale : str
Expand All @@ -369,12 +397,17 @@ def render_single_frame(
The height of the figure in cm.
dpi : int
The DPI for the figure.
downsample : int, optional
If greater than 1, downsample the timeslice array in strides of
`downsample` in the x and y direction. Provides a speedup for large
domains.

Returns
-------
str
The filename of the saved frame.
bytes
The raw frame output for the frame index
"""
xyts_file = XYTSFile(xyts_file_path)
# Create a new figure for this frame
cm = 1 / 2.54
fig = plt.figure(figsize=(width * cm, height * cm))
Expand Down Expand Up @@ -422,17 +455,18 @@ def render_single_frame(
)

# Add the actual data for this frame
current_data = ground_motion_magnitude[frame_index, :, :]

current_data = tslice_get(xyts_file, frame_index, downsample=downsample)
pcm = ax.pcolormesh(
xr,
yr,
yr[::downsample, ::downsample],
xr[::downsample, ::downsample],
apply_cmap_with_alpha(current_data, 0, max_motion, cmap=cmap),
cmap=cmap,
vmin=0,
vmax=max_motion,
shading="gouraud",
shading=shading,
zorder=3,
rasterized=True,
transform=NZTM_CRS,
)

# Add time text
Expand All @@ -455,20 +489,24 @@ def render_single_frame(

plt.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
cbar = fig.colorbar(
pcm, ax=ax, orientation="vertical", pad=0.02, aspect=30, shrink=0.8
pcm,
ax=ax,
orientation="vertical",
pad=0.02,
aspect=30,
shrink=0.8,
)
cbar.set_label("Ground Motion (cm/s)")

# Save the frame to a file
frame_filename = f"frame_{frame_index:04d}.png"
plt.savefig(frame_filename, dpi=dpi)
plt.close(fig)

return frame_filename
with io.BytesIO() as io_buf:
fig.savefig(io_buf, format="raw", dpi=dpi)
plt.close(fig)
return io_buf.getvalue()


@cli.from_docstring(app, name="xyts")
def animate_low_frequency_mpl_nztm(
def animate_low_frequency(
realisation_ffp: Annotated[Path, typer.Argument(exists=True, dir_okay=False)],
xyts_ffp: Annotated[Path, typer.Argument(exists=True, dir_okay=False)],
output_mp4: Annotated[
Expand All @@ -480,6 +518,7 @@ def animate_low_frequency_mpl_nztm(
scale: Annotated[str, typer.Option()] = "10m",
shading: Annotated[str, typer.Option()] = "gouraud",
frame_count: Annotated[int | None, typer.Option()] = None,
frame_start: Annotated[int, typer.Option()] = 0,
width: Annotated[float, typer.Option()] = 30.0,
height: Annotated[float, typer.Option()] = 30.0,
dpi: Annotated[int, typer.Option()] = 150,
Expand All @@ -488,6 +527,7 @@ def animate_low_frequency_mpl_nztm(
zoom: Annotated[float, typer.Option()] = 1,
simple_map: Annotated[bool, typer.Option()] = False,
map_quality: Annotated[int, typer.Option()] = 4,
downsample: Annotated[int, typer.Option()] = 1,
) -> None:
"""Render low-frequency output as a 2D video of ground motions.

Expand All @@ -511,6 +551,8 @@ def animate_low_frequency_mpl_nztm(
The shading method for `plt.pcolormesh`, by default "gouraud".
frame_count : int | None, optional
The number of frames to display in the animation, by default None (uses all frames).
frame_start : int, optional
The frame to start the animation on. Defaults to zero.
width : float, optional
The width of the figure in cm, by default 30.
height : float, optional
Expand All @@ -529,9 +571,13 @@ def animate_low_frequency_mpl_nztm(
map_quality : int, optional
The quality of the map, by default 4. Has no effect if using a
simple map. Lower values have lower quality but render faster.
downsample : int, optional
If greater than 1, downsample the timeslice array in strides of
`downsample` in the x and y direction. Provides a speedup for large
domains.
"""
ffmpeg = shutil.which("ffmpeg")
if not ffmpeg:
have_ffmpeg = shutil.which("ffmpeg")
if not have_ffmpeg:
print(
"You must have ffmpeg installed. See https://ffmpeg.org/download.html.",
)
Expand All @@ -540,8 +586,6 @@ def animate_low_frequency_mpl_nztm(
source_config = SourceConfig.read_from_realisation(realisation_ffp)
xyts_file = XYTSFile(xyts_ffp)

ground_motion_magnitude = np.linalg.norm(xyts_file.data, axis=1)

nztm_corners = xyts_nztm_corners(xyts_file)
map_extent_nztm = map_extents(nztm_corners, padding)

Expand All @@ -560,65 +604,71 @@ def animate_low_frequency_mpl_nztm(
frame_count = frame_count or xyts_file.nt
xr, yr = waveform_coordinates(nztm_corners, xyts_file.nx, xyts_file.ny)

with tempfile.TemporaryDirectory() as temp_dir:
render_frame = functools.partial(
render_single_frame,
dt=xyts_file.dt,
ground_motion_magnitude=ground_motion_magnitude,
max_motion=max_motion,
cmap=cmap,
source_config=source_config,
nztm_corners=nztm_corners,
map_extent_nztm=map_extent_nztm,
xr=xr,
yr=yr,
simple_map=simple_map,
scale=scale,
map_quality=map_quality,
title=title,
width=width,
height=height,
dpi=dpi,
)
render_frame = functools.partial(
render_single_frame,
dt=xyts_file.dt,
shading=shading,
xyts_file_path=xyts_ffp.resolve(),
max_motion=max_motion,
cmap=cmap,
source_config=source_config,
nztm_corners=nztm_corners,
map_extent_nztm=map_extent_nztm,
xr=xr,
yr=yr,
simple_map=simple_map,
scale=scale,
map_quality=map_quality,
title=title,
width=width,
height=height,
dpi=dpi,
downsample=downsample,
)

# warm the OSM cache to speed up rendering by rendering the first frame
os.chdir(temp_dir)
# warm the OSM cache to speed up rendering by rendering the first frame

render_frame(0)
frames = [render_frame(0)]

with mp.Pool() as pool:
# Render all frames in parallel
_ = list(
tqdm.tqdm(
pool.imap(render_frame, range(1, frame_count)),
total=frame_count,
unit="frame",
desc="Rendering frames",
initial=1,
)
with mp.Pool() as pool:
# Render all frames in parallel
frames.extend(
tqdm.tqdm(
pool.imap(render_frame, range(frame_start, frame_start + frame_count)),
total=frame_count,
unit="frame",
desc="Rendering frames",
initial=1,
)

# Use ffmpeg to combine frames into video

ffmpeg_cmd = [
ffmpeg,
"-y", # Overwrite output file if it exists
"-framerate",
str(fps),
"-i",
"frame_%04d.png",
"-c:v",
"libx264",
"-vf",
"pad=ceil(iw/2)*2:ceil(ih/2)*2",
"-pix_fmt",
"yuv420p",
"-crf",
"23", # Quality setting (lower is better)
)
cm = 1 / 2.54
width_px = int(width * cm * dpi)
height_px = int(height * cm * dpi)
# Use ffmpeg to combine frames into video
process = (
ffmpeg.input(
"pipe:0", format="rawvideo", pix_fmt="rgba", s=f"{width_px}x{height_px}"
)
.output(
str(output_mp4),
]
pix_fmt="yuv420p",
r=fps,
vcodec="libx264",
crf=23,
vf="pad=ceil(iw/2)*2:ceil(ih/2)*2",
)
.overwrite_output()
.run_async(pipe_stdin=True)
)

# Write the raw video data to FFmpeg's stdin
for frame in frames:
process.stdin.write(frame)

process.stdin.close()

subprocess.run(ffmpeg_cmd, check=True)
# Wait for FFmpeg to finish
process.wait()


def non_zero_data_points(
Expand Down