diff --git a/.gitignore b/.gitignore index ad4a1f1..0aa3ae3 100644 --- a/.gitignore +++ b/.gitignore @@ -174,3 +174,13 @@ poetry.toml pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python + +# Devenv +.devenv* +devenv.local.nix +devenv.lock +# direnv +.direnv + +# pre-commit +.pre-commit-config.yaml diff --git a/requirements.txt b/requirements.txt index a70d53b..5287e84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ diffimg cartopy matplotlib numpy +ffmpeg-python pygmt pygmt_helper @ git+https://github.com/ucgmsim/pygmt_helper.git source_modelling @ git+https://github.com/ucgmsim/source_modelling.git diff --git a/visualisation/plot_ts.py b/visualisation/plot_ts.py index 447592a..d919845 100644 --- a/visualisation/plot_ts.py +++ b/visualisation/plot_ts.py @@ -1,11 +1,9 @@ """Create simulation video of surface ground motion levels.""" import functools +import io import multiprocessing as mp -import os import shutil -import subprocess -import tempfile from pathlib import Path from typing import Annotated @@ -16,6 +14,7 @@ matplotlib.use("Agg") +import ffmpeg import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np @@ -312,17 +311,43 @@ def waveform_coordinates(nztm_corners: np.ndarray, nx: int, ny: int) -> np.ndarr return coords_nztm[::-1, :, :] # Reverse order to (x, y) for NZTM +def tslice_get(xyts_file: XYTSFile, index: int, downsample: int = 1) -> np.ndarray: + """Retrieve a single timeslice from an xyts file with downsampling + + Parameters + ---------- + xyts_file : XYTSFile + The xyts file to retrieve from. + index : int + The timeslice index to read from. + downsample : int + If greater than 1, downsample the array in strides of `downsample` in + the x and y direction. + + Returns + ------- + array of float32 + An array of shape (ny, nx) containing the downsampled frame data for `index`. + """ + if downsample > 1: + frame_data = xyts_file.data[index, :, ::downsample, ::downsample] + else: + frame_data = xyts_file.data[index] # shape: (3, ny, nx) + return np.linalg.norm(frame_data, axis=0) + + def render_single_frame( frame_index: int, dt: float, - ground_motion_magnitude: np.ndarray, - max_motion: float, - cmap: str, + xyts_file_path: Path, source_config: SourceConfig, nztm_corners: np.ndarray, map_extent_nztm: tuple[float, float, float, float], xr: np.ndarray, yr: np.ndarray, + max_motion: float, + cmap: str, + shading: str, simple_map: bool, scale: str, map_quality: int, @@ -330,7 +355,8 @@ def render_single_frame( width: float, height: float, dpi: int, -) -> str: + downsample: int, +) -> bytes: """Render a single frame of the animation. Parameters @@ -339,12 +365,8 @@ def render_single_frame( The index of the frame to render. dt : float The time step of the simulation. - ground_motion_magnitude : np.ndarray - The ground motion magnitude data. - max_motion : float - The maximum ground motion value for color scaling. - cmap : str - The colormap to use for the animation. + xyts_file_path : Path + The path to the XYTS file. source_config : SourceConfig The source configuration object. nztm_corners : np.ndarray @@ -355,6 +377,12 @@ def render_single_frame( The x coordinates of the gridpoints in NZTM coordinates. yr : np.ndarray The y coordinates of the gridpoints in NZTM coordinates. + max_motion : float + The maximum ground motion value for color scaling. + cmap : str + The colormap to use for the animation. + shading : str + The shading to apply to the colourmap. simple_map : bool If True, disable OpenStreetMap background and use a simple map. scale : str @@ -369,12 +397,17 @@ def render_single_frame( The height of the figure in cm. dpi : int The DPI for the figure. + downsample : int, optional + If greater than 1, downsample the timeslice array in strides of + `downsample` in the x and y direction. Provides a speedup for large + domains. Returns ------- - str - The filename of the saved frame. + bytes + The raw frame output for the frame index """ + xyts_file = XYTSFile(xyts_file_path) # Create a new figure for this frame cm = 1 / 2.54 fig = plt.figure(figsize=(width * cm, height * cm)) @@ -422,17 +455,18 @@ def render_single_frame( ) # Add the actual data for this frame - current_data = ground_motion_magnitude[frame_index, :, :] + + current_data = tslice_get(xyts_file, frame_index, downsample=downsample) pcm = ax.pcolormesh( - xr, - yr, + yr[::downsample, ::downsample], + xr[::downsample, ::downsample], apply_cmap_with_alpha(current_data, 0, max_motion, cmap=cmap), cmap=cmap, vmin=0, vmax=max_motion, - shading="gouraud", + shading=shading, zorder=3, - rasterized=True, + transform=NZTM_CRS, ) # Add time text @@ -455,20 +489,24 @@ def render_single_frame( plt.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) cbar = fig.colorbar( - pcm, ax=ax, orientation="vertical", pad=0.02, aspect=30, shrink=0.8 + pcm, + ax=ax, + orientation="vertical", + pad=0.02, + aspect=30, + shrink=0.8, ) cbar.set_label("Ground Motion (cm/s)") # Save the frame to a file - frame_filename = f"frame_{frame_index:04d}.png" - plt.savefig(frame_filename, dpi=dpi) - plt.close(fig) - - return frame_filename + with io.BytesIO() as io_buf: + fig.savefig(io_buf, format="raw", dpi=dpi) + plt.close(fig) + return io_buf.getvalue() @cli.from_docstring(app, name="xyts") -def animate_low_frequency_mpl_nztm( +def animate_low_frequency( realisation_ffp: Annotated[Path, typer.Argument(exists=True, dir_okay=False)], xyts_ffp: Annotated[Path, typer.Argument(exists=True, dir_okay=False)], output_mp4: Annotated[ @@ -480,6 +518,7 @@ def animate_low_frequency_mpl_nztm( scale: Annotated[str, typer.Option()] = "10m", shading: Annotated[str, typer.Option()] = "gouraud", frame_count: Annotated[int | None, typer.Option()] = None, + frame_start: Annotated[int, typer.Option()] = 0, width: Annotated[float, typer.Option()] = 30.0, height: Annotated[float, typer.Option()] = 30.0, dpi: Annotated[int, typer.Option()] = 150, @@ -488,6 +527,7 @@ def animate_low_frequency_mpl_nztm( zoom: Annotated[float, typer.Option()] = 1, simple_map: Annotated[bool, typer.Option()] = False, map_quality: Annotated[int, typer.Option()] = 4, + downsample: Annotated[int, typer.Option()] = 1, ) -> None: """Render low-frequency output as a 2D video of ground motions. @@ -511,6 +551,8 @@ def animate_low_frequency_mpl_nztm( The shading method for `plt.pcolormesh`, by default "gouraud". frame_count : int | None, optional The number of frames to display in the animation, by default None (uses all frames). + frame_start : int, optional + The frame to start the animation on. Defaults to zero. width : float, optional The width of the figure in cm, by default 30. height : float, optional @@ -529,9 +571,13 @@ def animate_low_frequency_mpl_nztm( map_quality : int, optional The quality of the map, by default 4. Has no effect if using a simple map. Lower values have lower quality but render faster. + downsample : int, optional + If greater than 1, downsample the timeslice array in strides of + `downsample` in the x and y direction. Provides a speedup for large + domains. """ - ffmpeg = shutil.which("ffmpeg") - if not ffmpeg: + have_ffmpeg = shutil.which("ffmpeg") + if not have_ffmpeg: print( "You must have ffmpeg installed. See https://ffmpeg.org/download.html.", ) @@ -540,8 +586,6 @@ def animate_low_frequency_mpl_nztm( source_config = SourceConfig.read_from_realisation(realisation_ffp) xyts_file = XYTSFile(xyts_ffp) - ground_motion_magnitude = np.linalg.norm(xyts_file.data, axis=1) - nztm_corners = xyts_nztm_corners(xyts_file) map_extent_nztm = map_extents(nztm_corners, padding) @@ -560,65 +604,71 @@ def animate_low_frequency_mpl_nztm( frame_count = frame_count or xyts_file.nt xr, yr = waveform_coordinates(nztm_corners, xyts_file.nx, xyts_file.ny) - with tempfile.TemporaryDirectory() as temp_dir: - render_frame = functools.partial( - render_single_frame, - dt=xyts_file.dt, - ground_motion_magnitude=ground_motion_magnitude, - max_motion=max_motion, - cmap=cmap, - source_config=source_config, - nztm_corners=nztm_corners, - map_extent_nztm=map_extent_nztm, - xr=xr, - yr=yr, - simple_map=simple_map, - scale=scale, - map_quality=map_quality, - title=title, - width=width, - height=height, - dpi=dpi, - ) + render_frame = functools.partial( + render_single_frame, + dt=xyts_file.dt, + shading=shading, + xyts_file_path=xyts_ffp.resolve(), + max_motion=max_motion, + cmap=cmap, + source_config=source_config, + nztm_corners=nztm_corners, + map_extent_nztm=map_extent_nztm, + xr=xr, + yr=yr, + simple_map=simple_map, + scale=scale, + map_quality=map_quality, + title=title, + width=width, + height=height, + dpi=dpi, + downsample=downsample, + ) - # warm the OSM cache to speed up rendering by rendering the first frame - os.chdir(temp_dir) + # warm the OSM cache to speed up rendering by rendering the first frame - render_frame(0) + frames = [render_frame(0)] - with mp.Pool() as pool: - # Render all frames in parallel - _ = list( - tqdm.tqdm( - pool.imap(render_frame, range(1, frame_count)), - total=frame_count, - unit="frame", - desc="Rendering frames", - initial=1, - ) + with mp.Pool() as pool: + # Render all frames in parallel + frames.extend( + tqdm.tqdm( + pool.imap(render_frame, range(frame_start, frame_start + frame_count)), + total=frame_count, + unit="frame", + desc="Rendering frames", + initial=1, ) - - # Use ffmpeg to combine frames into video - - ffmpeg_cmd = [ - ffmpeg, - "-y", # Overwrite output file if it exists - "-framerate", - str(fps), - "-i", - "frame_%04d.png", - "-c:v", - "libx264", - "-vf", - "pad=ceil(iw/2)*2:ceil(ih/2)*2", - "-pix_fmt", - "yuv420p", - "-crf", - "23", # Quality setting (lower is better) + ) + cm = 1 / 2.54 + width_px = int(width * cm * dpi) + height_px = int(height * cm * dpi) + # Use ffmpeg to combine frames into video + process = ( + ffmpeg.input( + "pipe:0", format="rawvideo", pix_fmt="rgba", s=f"{width_px}x{height_px}" + ) + .output( str(output_mp4), - ] + pix_fmt="yuv420p", + r=fps, + vcodec="libx264", + crf=23, + vf="pad=ceil(iw/2)*2:ceil(ih/2)*2", + ) + .overwrite_output() + .run_async(pipe_stdin=True) + ) + + # Write the raw video data to FFmpeg's stdin + for frame in frames: + process.stdin.write(frame) + + process.stdin.close() - subprocess.run(ffmpeg_cmd, check=True) + # Wait for FFmpeg to finish + process.wait() def non_zero_data_points(