From 2e9cb244a3fdd7a105e3293ec91b5162fdf7eac0 Mon Sep 17 00:00:00 2001 From: Laurens Weijs Date: Fri, 6 Feb 2026 14:07:48 +0100 Subject: [PATCH 1/2] first iteration on plot_impressions --- .../src/conversion/plots/data_formats.py | 65 +- .../src/conversion/plots/plot_impression.py | 640 ++++++++++++++++++ .../plot_impression/test_plot_impression.py | 366 ++++++++++ 3 files changed, 1070 insertions(+), 1 deletion(-) create mode 100644 packages/scratch-core/src/conversion/plots/plot_impression.py create mode 100644 packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py diff --git a/packages/scratch-core/src/conversion/plots/data_formats.py b/packages/scratch-core/src/conversion/plots/data_formats.py index da263074..0e322c91 100644 --- a/packages/scratch-core/src/conversion/plots/data_formats.py +++ b/packages/scratch-core/src/conversion/plots/data_formats.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from container_models.base import ImageRGB +from container_models.base import FloatArray2D, ImageRGB @dataclass @@ -51,3 +51,66 @@ class StriationComparisonPlots: mark2_filtered_preview_image: ImageRGB mark1_vs_moved_mark2: ImageRGB wavelength_plot: ImageRGB + + +@dataclass +class ImpressionComparisonMetrics: + """ + Metrics for impression comparison display. + + Equivalent to MATLAB results_table structure from GenerateAdditionalNISTFigures.m. + + :param area_correlation: Areal correlation coefficient (from area-based comparison). + :param cell_correlations: Grid of per-cell correlation values (shape: n_rows x n_cols). + :param cmc_score: Congruent Matching Cells score (percentage of cells above threshold). + :param sq_ref: Sq (RMS roughness) of reference surface in µm. + :param sq_comp: Sq (RMS roughness) of compared surface in µm. + :param sq_diff: Sq of difference (comp - ref) in µm. + :param has_area_results: Whether area-based results were computed. + :param has_cell_results: Whether cell/CMC-based results were computed. + """ + + area_correlation: float + cell_correlations: FloatArray2D + cmc_score: float + sq_ref: float + sq_comp: float + sq_diff: float + has_area_results: bool + has_cell_results: bool + + +@dataclass +class ImpressionComparisonPlots: + """ + Results from impression mark comparison visualization. + + Contains rendered images for both area-based and cell/CMC-based visualizations. + Fields are None when the corresponding analysis was not performed. + + :param leveled_reference: Leveled reference surface visualization. + :param leveled_compared: Leveled compared surface visualization. + :param filtered_reference: Filtered reference surface visualization. + :param filtered_compared: Filtered compared surface visualization. + :param difference_map: Difference map (compared - reference) visualization. + :param area_cross_correlation: Cross-correlation surface visualization. + :param cell_reference: Cell-preprocessed reference visualization. + :param cell_compared: Cell-preprocessed compared visualization. + :param cell_overlay: All cells overlay visualization. + :param cell_cross_correlation: Cell-based cross-correlation visualization. + :param cell_correlation_histogram: Histogram of per-cell correlations. + """ + + # Area-based plots + leveled_reference: ImageRGB | None + leveled_compared: ImageRGB | None + filtered_reference: ImageRGB | None + filtered_compared: ImageRGB | None + difference_map: ImageRGB | None + area_cross_correlation: ImageRGB | None + # Cell/CMC-based plots + cell_reference: ImageRGB | None + cell_compared: ImageRGB | None + cell_overlay: ImageRGB | None + cell_cross_correlation: ImageRGB | None + cell_correlation_histogram: ImageRGB | None diff --git a/packages/scratch-core/src/conversion/plots/plot_impression.py b/packages/scratch-core/src/conversion/plots/plot_impression.py new file mode 100644 index 00000000..9bbec16f --- /dev/null +++ b/packages/scratch-core/src/conversion/plots/plot_impression.py @@ -0,0 +1,640 @@ +""" +Impression mark comparison visualization. + +Translates MATLAB functions: +- GenerateAdditionalNISTFigures.m (orchestrator) +- PlotResultsAreaNIST.m (area-based correlation plots) +- PlotResultsCmcNIST.m (cell/CMC-based correlation plots) +""" + +import matplotlib.pyplot as plt +import numpy as np +from scipy.signal import correlate2d + +from container_models.base import FloatArray2D, ImageRGB +from conversion.data_formats import Mark +from conversion.plots.data_formats import ( + ImpressionComparisonMetrics, + ImpressionComparisonPlots, +) +from conversion.plots.utils import ( + DEFAULT_COLORMAP, + figure_to_array, + get_figure_dimensions, + plot_depth_map_on_axes, +) + + +def plot_impression_comparison_results( + mark_reference_leveled: Mark, + mark_compared_leveled: Mark, + mark_reference_filtered: Mark, + mark_compared_filtered: Mark, + metrics: ImpressionComparisonMetrics, + _metadata_reference: dict[str, str], + _metadata_compared: dict[str, str], +) -> ImpressionComparisonPlots: + """ + Generate visualization results for impression mark comparison. + + Main orchestrator function equivalent to MATLAB GenerateAdditionalNISTFigures.m. + Generates both area-based and cell/CMC-based visualizations based on which + results are available in the metrics. + + :param mark_reference_leveled: Reference mark after leveling. + :param mark_compared_leveled: Compared mark after leveling. + :param mark_reference_filtered: Reference mark after filtering. + :param mark_compared_filtered: Compared mark after filtering. + :param metrics: Comparison metrics including correlation values. + :param _metadata_reference: Metadata dict for reference mark display (reserved for future use). + :param _metadata_compared: Metadata dict for compared mark display (reserved for future use). + :returns: ImpressionComparisonPlots with all rendered images. + """ + # Initialize all plots as None + leveled_ref = None + leveled_comp = None + filtered_ref = None + filtered_comp = None + difference_map = None + area_xcorr = None + cell_ref = None + cell_comp = None + cell_overlay = None + cell_xcorr = None + cell_histogram = None + + # Generate area-based plots if available + if metrics.has_area_results: + ( + leveled_ref, + leveled_comp, + filtered_ref, + filtered_comp, + difference_map, + area_xcorr, + ) = plot_area_figures( + mark_ref_leveled=mark_reference_leveled, + mark_comp_leveled=mark_compared_leveled, + mark_ref_filtered=mark_reference_filtered, + mark_comp_filtered=mark_compared_filtered, + correlation_value=metrics.area_correlation, + ) + + # Generate cell/CMC-based plots if available + if metrics.has_cell_results: + ( + cell_ref, + cell_comp, + cell_overlay, + cell_xcorr, + cell_histogram, + ) = plot_cmc_figures( + mark_ref_filtered=mark_reference_filtered, + mark_comp_filtered=mark_compared_filtered, + cell_correlations=metrics.cell_correlations, + ) + + return ImpressionComparisonPlots( + leveled_reference=leveled_ref, + leveled_compared=leveled_comp, + filtered_reference=filtered_ref, + filtered_compared=filtered_comp, + difference_map=difference_map, + area_cross_correlation=area_xcorr, + cell_reference=cell_ref, + cell_compared=cell_comp, + cell_overlay=cell_overlay, + cell_cross_correlation=cell_xcorr, + cell_correlation_histogram=cell_histogram, + ) + + +def plot_area_figures( + mark_ref_leveled: Mark, + mark_comp_leveled: Mark, + mark_ref_filtered: Mark, + mark_comp_filtered: Mark, + correlation_value: float, +) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB, ImageRGB, ImageRGB]: + """ + Generate 6 area-based plots for impression comparison. + + Equivalent to MATLAB PlotResultsAreaNIST.m. + Generates: + 1. Leveled reference surface + 2. Leveled compared surface + 3. Filtered reference surface + 4. Filtered compared surface + 5. Difference map (compared - reference) + 6. Cross-correlation surface + + :param mark_ref_leveled: Reference mark after leveling. + :param mark_comp_leveled: Compared mark after leveling. + :param mark_ref_filtered: Reference mark after filtering. + :param mark_comp_filtered: Compared mark after filtering. + :param correlation_value: Areal correlation coefficient. + :returns: Tuple of 6 ImageRGB arrays. + """ + scale_ref = mark_ref_leveled.scan_image.scale_x + scale_comp = mark_comp_leveled.scan_image.scale_x + + # 1. Leveled reference surface + leveled_ref = plot_depth_map_with_axes( + data=mark_ref_leveled.scan_image.data, + scale=scale_ref, + title="Leveled Reference Surface", + ) + + # 2. Leveled compared surface + leveled_comp = plot_depth_map_with_axes( + data=mark_comp_leveled.scan_image.data, + scale=scale_comp, + title="Leveled Compared Surface", + ) + + # 3. Filtered reference surface + filtered_ref = plot_depth_map_with_axes( + data=mark_ref_filtered.scan_image.data, + scale=scale_ref, + title="Filtered Reference Surface", + ) + + # 4. Filtered compared surface + filtered_comp = plot_depth_map_with_axes( + data=mark_comp_filtered.scan_image.data, + scale=scale_comp, + title="Filtered Compared Surface", + ) + + # 5. Difference map + diff_map = plot_difference_map( + data_ref=mark_ref_filtered.scan_image.data, + data_comp=mark_comp_filtered.scan_image.data, + scale=scale_ref, + ) + + # 6. Cross-correlation surface + xcorr = plot_cross_correlation_surface( + data_ref=mark_ref_filtered.scan_image.data, + data_comp=mark_comp_filtered.scan_image.data, + scale=scale_ref, + correlation_value=correlation_value, + ) + + return leveled_ref, leveled_comp, filtered_ref, filtered_comp, diff_map, xcorr + + +def plot_cmc_figures( + mark_ref_filtered: Mark, + mark_comp_filtered: Mark, + cell_correlations: FloatArray2D, +) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB, ImageRGB]: + """ + Generate 5 CMC/cell-based plots for impression comparison. + + Equivalent to MATLAB PlotResultsCmcNIST.m. + Generates: + 1. Cell-preprocessed reference + 2. Cell-preprocessed compared + 3. All cells overlay visualization + 4. Cell cross-correlation heatmap + 5. Cell correlation histogram + + :param mark_ref_filtered: Reference mark after filtering. + :param mark_comp_filtered: Compared mark after filtering. + :param cell_correlations: Grid of per-cell correlation values. + :returns: Tuple of 5 ImageRGB arrays. + """ + scale = mark_ref_filtered.scan_image.scale_x + + # 1. Cell-preprocessed reference + cell_ref = plot_depth_map_with_axes( + data=mark_ref_filtered.scan_image.data, + scale=scale, + title="Cell-Preprocessed Reference", + ) + + # 2. Cell-preprocessed compared + cell_comp = plot_depth_map_with_axes( + data=mark_comp_filtered.scan_image.data, + scale=scale, + title="Cell-Preprocessed Compared", + ) + + # 3. Cell overlay visualization + cell_overlay = plot_cell_grid_overlay( + data=mark_ref_filtered.scan_image.data, + scale=scale, + cell_correlations=cell_correlations, + ) + + # 4. Cell cross-correlation heatmap + cell_xcorr = plot_cell_correlation_heatmap( + cell_correlations=cell_correlations, + ) + + # 5. Cell correlation histogram + cell_histogram = plot_correlation_histogram( + cell_correlations=cell_correlations, + ) + + return cell_ref, cell_comp, cell_overlay, cell_xcorr, cell_histogram + + +def plot_depth_map_with_axes( + data: FloatArray2D, + scale: float, + title: str, +) -> ImageRGB: + """ + Plot a depth map with axes and colorbar. + + :param data: Depth data in meters. + :param scale: Pixel scale in meters. + :param title: Title for the plot. + :returns: RGB image as uint8 array. + """ + height, width = data.shape + fig_height, fig_width = get_figure_dimensions(height, width) + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + plot_depth_map_on_axes(ax, fig, data, scale, title) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_difference_map( + data_ref: FloatArray2D, + data_comp: FloatArray2D, + scale: float, +) -> ImageRGB: + """ + Plot the difference map between two surfaces. + + :param data_ref: Reference surface data in meters. + :param data_comp: Compared surface data in meters. + :param scale: Pixel scale in meters. + :returns: RGB image as uint8 array. + """ + # Compute difference (handle NaN values) + diff = data_comp - data_ref + + # Compute Sq of difference (RMS of valid values) + valid_diff = diff[~np.isnan(diff)] + sq_diff = np.sqrt(np.mean(valid_diff**2)) * 1e6 if len(valid_diff) > 0 else 0.0 + + height, width = diff.shape + fig_height, fig_width = get_figure_dimensions(height, width) + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + extent = (0, width * scale * 1e6, 0, height * scale * 1e6) + im = ax.imshow( + diff * 1e6, + cmap="RdBu_r", # Diverging colormap centered at 0 + aspect="equal", + origin="lower", + extent=extent, + ) + + # Center colormap at 0 + vmax = np.nanmax(np.abs(diff * 1e6)) + im.set_clim(-vmax, vmax) + + ax.set_xlabel("X - Position [um]", fontsize=11) + ax.set_ylabel("Y - Position [um]", fontsize=11) + ax.set_title( + f"Difference Map (Sq = {sq_diff:.4f} um)", fontsize=12, fontweight="bold" + ) + ax.tick_params(labelsize=10) + + from mpl_toolkits.axes_grid1 import make_axes_locatable + + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = fig.colorbar(im, cax=cax, label="Difference [um]") + cbar.ax.tick_params(labelsize=10) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_cross_correlation_surface( + data_ref: FloatArray2D, + data_comp: FloatArray2D, + scale: float, + correlation_value: float, +) -> ImageRGB: + """ + Plot the 2D cross-correlation surface. + + :param data_ref: Reference surface data in meters. + :param data_comp: Compared surface data in meters. + :param scale: Pixel scale in meters. + :param correlation_value: Pre-computed correlation coefficient. + :returns: RGB image as uint8 array. + """ + # Replace NaN with 0 for correlation computation + ref_clean = np.nan_to_num(data_ref, nan=0.0) + comp_clean = np.nan_to_num(data_comp, nan=0.0) + + # Normalize for correlation + ref_norm = ref_clean - np.mean(ref_clean) + comp_norm = comp_clean - np.mean(comp_clean) + + # Compute 2D cross-correlation (use 'same' mode for same-size output) + xcorr = correlate2d(ref_norm, comp_norm, mode="same", boundary="fill", fillvalue=0) + + # Normalize to correlation coefficient scale + norm_factor = np.sqrt(np.sum(ref_norm**2) * np.sum(comp_norm**2)) + if norm_factor > 0: + xcorr = xcorr / norm_factor + + height, width = xcorr.shape + fig_height, fig_width = get_figure_dimensions(height, width) + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + # Create extent in lag coordinates (centered at 0) + half_h = height // 2 + half_w = width // 2 + extent_um = ( + -half_w * scale * 1e6, + half_w * scale * 1e6, + -half_h * scale * 1e6, + half_h * scale * 1e6, + ) + + im = ax.imshow( + xcorr, + cmap=DEFAULT_COLORMAP, + aspect="equal", + origin="lower", + extent=extent_um, + ) + + ax.set_xlabel("X - Lag [um]", fontsize=11) + ax.set_ylabel("Y - Lag [um]", fontsize=11) + ax.set_title( + f"Cross-Correlation (Max = {correlation_value:.4f})", + fontsize=12, + fontweight="bold", + ) + ax.tick_params(labelsize=10) + + # Mark the peak + peak_idx = np.unravel_index(np.argmax(xcorr), xcorr.shape) + peak_y = (peak_idx[0] - half_h) * scale * 1e6 + peak_x = (peak_idx[1] - half_w) * scale * 1e6 + ax.plot(peak_x, peak_y, "r+", markersize=15, markeredgewidth=2) + + from mpl_toolkits.axes_grid1 import make_axes_locatable + + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = fig.colorbar(im, cax=cax, label="Correlation") + cbar.ax.tick_params(labelsize=10) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_cell_grid_overlay( + data: FloatArray2D, + scale: float, + cell_correlations: FloatArray2D, +) -> ImageRGB: + """ + Plot surface with cell grid overlay showing correlation values. + + :param data: Surface data in meters. + :param scale: Pixel scale in meters. + :param cell_correlations: Grid of per-cell correlation values. + :returns: RGB image as uint8 array. + """ + height, width = data.shape + n_rows, n_cols = cell_correlations.shape + + fig_height, fig_width = get_figure_dimensions(height, width) + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + # Plot the surface + extent = (0, width * scale * 1e6, 0, height * scale * 1e6) + ax.imshow( + data * 1e6, + cmap=DEFAULT_COLORMAP, + aspect="equal", + origin="lower", + extent=extent, + ) + + # Calculate cell dimensions + cell_height = height / n_rows + cell_width = width / n_cols + + # Draw grid and correlation values + for i in range(n_rows): + for j in range(n_cols): + # Cell boundaries in um + x_left = j * cell_width * scale * 1e6 + x_right = (j + 1) * cell_width * scale * 1e6 + y_bottom = (n_rows - 1 - i) * cell_height * scale * 1e6 + y_top = (n_rows - i) * cell_height * scale * 1e6 + + # Draw cell border + ax.plot( + [x_left, x_right, x_right, x_left, x_left], + [y_bottom, y_bottom, y_top, y_top, y_bottom], + "w-", + linewidth=0.5, + alpha=0.7, + ) + + # Add correlation value text + corr_val = cell_correlations[i, j] + if not np.isnan(corr_val): + x_center = (x_left + x_right) / 2 + y_center = (y_bottom + y_top) / 2 + ax.text( + x_center, + y_center, + f"{corr_val:.2f}", + ha="center", + va="center", + fontsize=8, + color="white", + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.1", facecolor="black", alpha=0.5), + ) + + ax.set_xlabel("X - Position [um]", fontsize=11) + ax.set_ylabel("Y - Position [um]", fontsize=11) + ax.set_title("Cell Grid with Correlation Values", fontsize=12, fontweight="bold") + ax.tick_params(labelsize=10) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_cell_correlation_heatmap( + cell_correlations: FloatArray2D, +) -> ImageRGB: + """ + Plot heatmap of per-cell correlation values. + + :param cell_correlations: Grid of per-cell correlation values. + :returns: RGB image as uint8 array. + """ + n_rows, n_cols = cell_correlations.shape + + # Calculate figure size based on grid dimensions + base_size = 6 + aspect = n_cols / n_rows + if aspect > 1: + fig_width = base_size + fig_height = base_size / aspect + 1.5 + else: + fig_height = base_size + 1.5 + fig_width = base_size * aspect + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + im = ax.imshow( + cell_correlations, + cmap=DEFAULT_COLORMAP, + aspect="equal", + origin="upper", + vmin=0, + vmax=1, + ) + + # Add cell value annotations + for i in range(n_rows): + for j in range(n_cols): + val = cell_correlations[i, j] + if not np.isnan(val): + text_color = "white" if val < 0.5 else "black" + ax.text( + j, + i, + f"{val:.2f}", + ha="center", + va="center", + fontsize=9, + color=text_color, + fontweight="bold", + ) + + ax.set_xlabel("Column", fontsize=11) + ax.set_ylabel("Row", fontsize=11) + ax.set_title("Cell Correlation Heatmap", fontsize=12, fontweight="bold") + ax.tick_params(labelsize=10) + + # Set tick positions + ax.set_xticks(range(n_cols)) + ax.set_yticks(range(n_rows)) + + from mpl_toolkits.axes_grid1 import make_axes_locatable + + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.1) + cbar = fig.colorbar(im, cax=cax, label="Correlation") + cbar.ax.tick_params(labelsize=10) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_correlation_histogram( + cell_correlations: FloatArray2D, + threshold: float = 0.5, +) -> ImageRGB: + """ + Plot histogram of per-cell correlation values. + + :param cell_correlations: Grid of per-cell correlation values. + :param threshold: CMC threshold to mark on histogram. + :returns: RGB image as uint8 array. + """ + # Flatten and remove NaN values + valid_correlations = cell_correlations.flatten() + valid_correlations = valid_correlations[~np.isnan(valid_correlations)] + + # Count cells above threshold + n_above = np.sum(valid_correlations >= threshold) + n_total = len(valid_correlations) + cmc_score = (n_above / n_total * 100) if n_total > 0 else 0.0 + + fig, ax = plt.subplots(figsize=(8, 5)) + + # Create histogram + n_bins = 20 + _, _, patches = ax.hist( + valid_correlations, + bins=n_bins, + range=(0, 1), + color="steelblue", + edgecolor="white", + alpha=0.8, + ) + + # Color bars above threshold differently (patches is BarContainer for single input) + for patch in patches: # type: ignore[union-attr] + bin_center = patch.get_x() + patch.get_width() / 2 + if bin_center >= threshold: + patch.set_facecolor("forestgreen") + + # Add threshold line + ax.axvline( + threshold, + color="red", + linestyle="--", + linewidth=2, + label=f"CMC Threshold = {threshold:.2f}", + ) + + ax.set_xlabel("Correlation Coefficient", fontsize=11) + ax.set_ylabel("Number of Cells", fontsize=11) + ax.set_title( + f"Cell Correlation Distribution (CMC = {cmc_score:.1f}%)", + fontsize=12, + fontweight="bold", + ) + ax.tick_params(labelsize=10) + ax.set_xlim(0, 1) + ax.legend(loc="upper left", fontsize=10) + ax.grid(True, alpha=0.3, axis="y") + + # Add statistics annotation + stats_text = ( + f"N = {n_total}\n" + f"Mean = {np.mean(valid_correlations):.3f}\n" + f"Std = {np.std(valid_correlations):.3f}\n" + f"Above threshold: {n_above}/{n_total}" + ) + ax.text( + 0.98, + 0.95, + stats_text, + transform=ax.transAxes, + ha="right", + va="top", + fontsize=9, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr diff --git a/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py b/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py new file mode 100644 index 00000000..2b77012f --- /dev/null +++ b/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py @@ -0,0 +1,366 @@ +"""Tests for impression mark comparison visualization.""" + +import numpy as np +import pytest + +from container_models.scan_image import ScanImage +from conversion.data_formats import Mark, MarkType +from conversion.plots.data_formats import ( + ImpressionComparisonMetrics, + ImpressionComparisonPlots, +) +from conversion.plots.plot_impression import ( + plot_area_figures, + plot_cell_correlation_heatmap, + plot_cell_grid_overlay, + plot_cmc_figures, + plot_correlation_histogram, + plot_cross_correlation_surface, + plot_depth_map_with_axes, + plot_difference_map, + plot_impression_comparison_results, +) + + +@pytest.fixture +def sample_depth_data() -> np.ndarray: + """Create synthetic depth data for testing.""" + np.random.seed(42) + return np.random.randn(100, 120) * 1e-6 # Random surface in meters + + +@pytest.fixture +def sample_mark(sample_depth_data: np.ndarray) -> Mark: + """Create a sample Mark for testing.""" + scan_image = ScanImage( + data=sample_depth_data, + scale_x=1.5e-6, # 1.5 µm pixel size + scale_y=1.5e-6, + ) + return Mark( + scan_image=scan_image, + mark_type=MarkType.FIRING_PIN_IMPRESSION, + ) + + +@pytest.fixture +def sample_cell_correlations() -> np.ndarray: + """Create synthetic cell correlation grid.""" + np.random.seed(42) + return np.random.rand(4, 5) # 4x5 grid of correlations + + +@pytest.fixture +def sample_metrics(sample_cell_correlations: np.ndarray) -> ImpressionComparisonMetrics: + """Create sample metrics for testing.""" + return ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=75.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=True, + has_cell_results=True, + ) + + +class TestPlotDepthMapWithAxes: + """Tests for plot_depth_map_with_axes function.""" + + def test_returns_rgb_image(self, sample_depth_data: np.ndarray): + """Output should be RGB uint8 array.""" + result = plot_depth_map_with_axes( + data=sample_depth_data, + scale=1.5e-6, + title="Test Surface", + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_handles_nan_values(self): + """Should handle NaN values in data.""" + data = np.random.randn(50, 60) * 1e-6 + data[10:20, 10:20] = np.nan + result = plot_depth_map_with_axes(data=data, scale=1.5e-6, title="With NaN") + assert result.shape[2] == 3 + + +class TestPlotDifferenceMap: + """Tests for plot_difference_map function.""" + + def test_returns_rgb_image(self, sample_depth_data: np.ndarray): + """Output should be RGB uint8 array.""" + data_comp = ( + sample_depth_data + np.random.randn(*sample_depth_data.shape) * 0.1e-6 + ) + result = plot_difference_map( + data_ref=sample_depth_data, + data_comp=data_comp, + scale=1.5e-6, + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_identical_surfaces_show_zero_difference( + self, sample_depth_data: np.ndarray + ): + """Identical surfaces should produce a valid difference map.""" + result = plot_difference_map( + data_ref=sample_depth_data, + data_comp=sample_depth_data.copy(), + scale=1.5e-6, + ) + assert result.shape[2] == 3 + + +class TestPlotCrossCorrelationSurface: + """Tests for plot_cross_correlation_surface function.""" + + def test_returns_rgb_image(self, sample_depth_data: np.ndarray): + """Output should be RGB uint8 array.""" + result = plot_cross_correlation_surface( + data_ref=sample_depth_data, + data_comp=sample_depth_data, + scale=1.5e-6, + correlation_value=0.95, + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + +class TestPlotCellGridOverlay: + """Tests for plot_cell_grid_overlay function.""" + + def test_returns_rgb_image( + self, sample_depth_data: np.ndarray, sample_cell_correlations: np.ndarray + ): + """Output should be RGB uint8 array.""" + result = plot_cell_grid_overlay( + data=sample_depth_data, + scale=1.5e-6, + cell_correlations=sample_cell_correlations, + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + +class TestPlotCellCorrelationHeatmap: + """Tests for plot_cell_correlation_heatmap function.""" + + def test_returns_rgb_image(self, sample_cell_correlations: np.ndarray): + """Output should be RGB uint8 array.""" + result = plot_cell_correlation_heatmap( + cell_correlations=sample_cell_correlations + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_handles_different_grid_sizes(self): + """Should handle various grid sizes.""" + for rows, cols in [(2, 3), (5, 5), (3, 8)]: + correlations = np.random.rand(rows, cols) + result = plot_cell_correlation_heatmap(cell_correlations=correlations) + assert result.shape[2] == 3 + + +class TestPlotCorrelationHistogram: + """Tests for plot_correlation_histogram function.""" + + def test_returns_rgb_image(self, sample_cell_correlations: np.ndarray): + """Output should be RGB uint8 array.""" + result = plot_correlation_histogram(cell_correlations=sample_cell_correlations) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_custom_threshold(self, sample_cell_correlations: np.ndarray): + """Should accept custom threshold.""" + result = plot_correlation_histogram( + cell_correlations=sample_cell_correlations, + threshold=0.7, + ) + assert result.shape[2] == 3 + + def test_handles_nan_values(self): + """Should handle NaN values in correlations.""" + correlations = np.array([[0.5, np.nan], [0.8, 0.3]]) + result = plot_correlation_histogram(cell_correlations=correlations) + assert result.shape[2] == 3 + + +class TestPlotAreaFigures: + """Tests for plot_area_figures function.""" + + def test_returns_six_images(self, sample_mark: Mark): + """Should return tuple of 6 RGB images.""" + result = plot_area_figures( + mark_ref_leveled=sample_mark, + mark_comp_leveled=sample_mark, + mark_ref_filtered=sample_mark, + mark_comp_filtered=sample_mark, + correlation_value=0.85, + ) + assert len(result) == 6 + for img in result: + assert img.ndim == 3 + assert img.shape[2] == 3 + assert img.dtype == np.uint8 + + +class TestPlotCmcFigures: + """Tests for plot_cmc_figures function.""" + + def test_returns_five_images( + self, sample_mark: Mark, sample_cell_correlations: np.ndarray + ): + """Should return tuple of 5 RGB images.""" + result = plot_cmc_figures( + mark_ref_filtered=sample_mark, + mark_comp_filtered=sample_mark, + cell_correlations=sample_cell_correlations, + ) + assert len(result) == 5 + for img in result: + assert img.ndim == 3 + assert img.shape[2] == 3 + assert img.dtype == np.uint8 + + +class TestPlotImpressionComparisonResults: + """Integration tests for the main orchestrator function.""" + + def test_generates_all_plots_when_both_flags_true( + self, sample_mark: Mark, sample_metrics: ImpressionComparisonMetrics + ): + """Should generate all plots when both area and cell results are available.""" + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=sample_metrics, + _metadata_reference={"Case": "Test"}, + _metadata_compared={"Case": "Test"}, + ) + + assert isinstance(result, ImpressionComparisonPlots) + + # Area-based plots should be present + assert result.leveled_reference is not None + assert result.leveled_compared is not None + assert result.filtered_reference is not None + assert result.filtered_compared is not None + assert result.difference_map is not None + assert result.area_cross_correlation is not None + + # Cell/CMC-based plots should be present + assert result.cell_reference is not None + assert result.cell_compared is not None + assert result.cell_overlay is not None + assert result.cell_cross_correlation is not None + assert result.cell_correlation_histogram is not None + + def test_only_area_plots_when_cell_flag_false( + self, sample_mark: Mark, sample_cell_correlations: np.ndarray + ): + """Should only generate area plots when has_cell_results is False.""" + metrics = ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=75.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=True, + has_cell_results=False, + ) + + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=metrics, + _metadata_reference={}, + _metadata_compared={}, + ) + + # Area-based plots should be present + assert result.leveled_reference is not None + assert result.area_cross_correlation is not None + + # Cell/CMC-based plots should be None + assert result.cell_reference is None + assert result.cell_correlation_histogram is None + + def test_only_cell_plots_when_area_flag_false( + self, sample_mark: Mark, sample_cell_correlations: np.ndarray + ): + """Should only generate cell plots when has_area_results is False.""" + metrics = ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=75.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=False, + has_cell_results=True, + ) + + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=metrics, + _metadata_reference={}, + _metadata_compared={}, + ) + + # Area-based plots should be None + assert result.leveled_reference is None + assert result.area_cross_correlation is None + + # Cell/CMC-based plots should be present + assert result.cell_reference is not None + assert result.cell_correlation_histogram is not None + + def test_all_outputs_are_valid_images( + self, sample_mark: Mark, sample_metrics: ImpressionComparisonMetrics + ): + """All non-None outputs should be valid RGB images.""" + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=sample_metrics, + _metadata_reference={}, + _metadata_compared={}, + ) + + for field_name in [ + "leveled_reference", + "leveled_compared", + "filtered_reference", + "filtered_compared", + "difference_map", + "area_cross_correlation", + "cell_reference", + "cell_compared", + "cell_overlay", + "cell_cross_correlation", + "cell_correlation_histogram", + ]: + img = getattr(result, field_name) + if img is not None: + assert img.ndim == 3, f"{field_name} should be 3D" + assert img.shape[2] == 3, f"{field_name} should have 3 channels" + assert img.dtype == np.uint8, f"{field_name} should be uint8" From c0f34a8bb22b0acdabe47b590c21a363d5f30ffd Mon Sep 17 00:00:00 2001 From: Laurens Weijs Date: Fri, 6 Feb 2026 16:45:31 +0100 Subject: [PATCH 2/2] Remove plots that are not wanted and generate also an overview picture --- .../src/conversion/plots/data_formats.py | 12 +- .../src/conversion/plots/plot_impression.py | 546 ++++++++---------- .../plot_impression/test_plot_impression.py | 223 +++---- 3 files changed, 377 insertions(+), 404 deletions(-) diff --git a/packages/scratch-core/src/conversion/plots/data_formats.py b/packages/scratch-core/src/conversion/plots/data_formats.py index 0e322c91..daf1d521 100644 --- a/packages/scratch-core/src/conversion/plots/data_formats.py +++ b/packages/scratch-core/src/conversion/plots/data_formats.py @@ -58,8 +58,6 @@ class ImpressionComparisonMetrics: """ Metrics for impression comparison display. - Equivalent to MATLAB results_table structure from GenerateAdditionalNISTFigures.m. - :param area_correlation: Areal correlation coefficient (from area-based comparison). :param cell_correlations: Grid of per-cell correlation values (shape: n_rows x n_cols). :param cmc_score: Congruent Matching Cells score (percentage of cells above threshold). @@ -88,29 +86,25 @@ class ImpressionComparisonPlots: Contains rendered images for both area-based and cell/CMC-based visualizations. Fields are None when the corresponding analysis was not performed. + :param comparison_overview: Combined overview figure with all results. :param leveled_reference: Leveled reference surface visualization. :param leveled_compared: Leveled compared surface visualization. :param filtered_reference: Filtered reference surface visualization. :param filtered_compared: Filtered compared surface visualization. - :param difference_map: Difference map (compared - reference) visualization. - :param area_cross_correlation: Cross-correlation surface visualization. :param cell_reference: Cell-preprocessed reference visualization. :param cell_compared: Cell-preprocessed compared visualization. :param cell_overlay: All cells overlay visualization. - :param cell_cross_correlation: Cell-based cross-correlation visualization. - :param cell_correlation_histogram: Histogram of per-cell correlations. + :param cell_cross_correlation: Cell-based cross-correlation heatmap. """ + comparison_overview: ImageRGB # Area-based plots leveled_reference: ImageRGB | None leveled_compared: ImageRGB | None filtered_reference: ImageRGB | None filtered_compared: ImageRGB | None - difference_map: ImageRGB | None - area_cross_correlation: ImageRGB | None # Cell/CMC-based plots cell_reference: ImageRGB | None cell_compared: ImageRGB | None cell_overlay: ImageRGB | None cell_cross_correlation: ImageRGB | None - cell_correlation_histogram: ImageRGB | None diff --git a/packages/scratch-core/src/conversion/plots/plot_impression.py b/packages/scratch-core/src/conversion/plots/plot_impression.py index 9bbec16f..68eaff22 100644 --- a/packages/scratch-core/src/conversion/plots/plot_impression.py +++ b/packages/scratch-core/src/conversion/plots/plot_impression.py @@ -1,15 +1,11 @@ -""" -Impression mark comparison visualization. +"""Impression mark comparison visualization.""" -Translates MATLAB functions: -- GenerateAdditionalNISTFigures.m (orchestrator) -- PlotResultsAreaNIST.m (area-based correlation plots) -- PlotResultsCmcNIST.m (cell/CMC-based correlation plots) -""" +from datetime import datetime import matplotlib.pyplot as plt import numpy as np -from scipy.signal import correlate2d +from matplotlib.axes import Axes +from matplotlib.figure import Figure from container_models.base import FloatArray2D, ImageRGB from conversion.data_formats import Mark @@ -20,7 +16,12 @@ from conversion.plots.utils import ( DEFAULT_COLORMAP, figure_to_array, + get_bounding_box, + get_col_widths, get_figure_dimensions, + get_height_ratios, + get_metadata_dimensions, + metadata_to_table_data, plot_depth_map_on_axes, ) @@ -31,23 +32,22 @@ def plot_impression_comparison_results( mark_reference_filtered: Mark, mark_compared_filtered: Mark, metrics: ImpressionComparisonMetrics, - _metadata_reference: dict[str, str], - _metadata_compared: dict[str, str], + metadata_reference: dict[str, str], + metadata_compared: dict[str, str], ) -> ImpressionComparisonPlots: """ Generate visualization results for impression mark comparison. - Main orchestrator function equivalent to MATLAB GenerateAdditionalNISTFigures.m. - Generates both area-based and cell/CMC-based visualizations based on which - results are available in the metrics. + Main orchestrator function that generates both area-based and cell/CMC-based + visualizations based on which results are available in the metrics. :param mark_reference_leveled: Reference mark after leveling. :param mark_compared_leveled: Compared mark after leveling. :param mark_reference_filtered: Reference mark after filtering. :param mark_compared_filtered: Compared mark after filtering. :param metrics: Comparison metrics including correlation values. - :param _metadata_reference: Metadata dict for reference mark display (reserved for future use). - :param _metadata_compared: Metadata dict for compared mark display (reserved for future use). + :param metadata_reference: Metadata dict for reference mark display. + :param metadata_compared: Metadata dict for compared mark display. :returns: ImpressionComparisonPlots with all rendered images. """ # Initialize all plots as None @@ -55,13 +55,10 @@ def plot_impression_comparison_results( leveled_comp = None filtered_ref = None filtered_comp = None - difference_map = None - area_xcorr = None cell_ref = None cell_comp = None cell_overlay = None cell_xcorr = None - cell_histogram = None # Generate area-based plots if available if metrics.has_area_results: @@ -70,14 +67,11 @@ def plot_impression_comparison_results( leveled_comp, filtered_ref, filtered_comp, - difference_map, - area_xcorr, ) = plot_area_figures( mark_ref_leveled=mark_reference_leveled, mark_comp_leveled=mark_compared_leveled, mark_ref_filtered=mark_reference_filtered, mark_comp_filtered=mark_compared_filtered, - correlation_value=metrics.area_correlation, ) # Generate cell/CMC-based plots if available @@ -87,25 +81,33 @@ def plot_impression_comparison_results( cell_comp, cell_overlay, cell_xcorr, - cell_histogram, ) = plot_cmc_figures( mark_ref_filtered=mark_reference_filtered, mark_comp_filtered=mark_compared_filtered, cell_correlations=metrics.cell_correlations, ) + # Generate comparison overview + comparison_overview = plot_comparison_overview( + mark_reference_leveled=mark_reference_leveled, + mark_compared_leveled=mark_compared_leveled, + mark_reference_filtered=mark_reference_filtered, + mark_compared_filtered=mark_compared_filtered, + metrics=metrics, + metadata_reference=metadata_reference, + metadata_compared=metadata_compared, + ) + return ImpressionComparisonPlots( + comparison_overview=comparison_overview, leveled_reference=leveled_ref, leveled_compared=leveled_comp, filtered_reference=filtered_ref, filtered_compared=filtered_comp, - difference_map=difference_map, - area_cross_correlation=area_xcorr, cell_reference=cell_ref, cell_compared=cell_comp, cell_overlay=cell_overlay, cell_cross_correlation=cell_xcorr, - cell_correlation_histogram=cell_histogram, ) @@ -114,26 +116,21 @@ def plot_area_figures( mark_comp_leveled: Mark, mark_ref_filtered: Mark, mark_comp_filtered: Mark, - correlation_value: float, -) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB, ImageRGB, ImageRGB]: +) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB]: """ - Generate 6 area-based plots for impression comparison. + Generate 4 area-based plots for impression comparison. - Equivalent to MATLAB PlotResultsAreaNIST.m. Generates: 1. Leveled reference surface 2. Leveled compared surface 3. Filtered reference surface 4. Filtered compared surface - 5. Difference map (compared - reference) - 6. Cross-correlation surface :param mark_ref_leveled: Reference mark after leveling. :param mark_comp_leveled: Compared mark after leveling. :param mark_ref_filtered: Reference mark after filtering. :param mark_comp_filtered: Compared mark after filtering. - :param correlation_value: Areal correlation coefficient. - :returns: Tuple of 6 ImageRGB arrays. + :returns: Tuple of 4 ImageRGB arrays. """ scale_ref = mark_ref_leveled.scan_image.scale_x scale_comp = mark_comp_leveled.scan_image.scale_x @@ -166,44 +163,27 @@ def plot_area_figures( title="Filtered Compared Surface", ) - # 5. Difference map - diff_map = plot_difference_map( - data_ref=mark_ref_filtered.scan_image.data, - data_comp=mark_comp_filtered.scan_image.data, - scale=scale_ref, - ) - - # 6. Cross-correlation surface - xcorr = plot_cross_correlation_surface( - data_ref=mark_ref_filtered.scan_image.data, - data_comp=mark_comp_filtered.scan_image.data, - scale=scale_ref, - correlation_value=correlation_value, - ) - - return leveled_ref, leveled_comp, filtered_ref, filtered_comp, diff_map, xcorr + return leveled_ref, leveled_comp, filtered_ref, filtered_comp def plot_cmc_figures( mark_ref_filtered: Mark, mark_comp_filtered: Mark, cell_correlations: FloatArray2D, -) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB, ImageRGB]: +) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB]: """ - Generate 5 CMC/cell-based plots for impression comparison. + Generate 4 CMC/cell-based plots for impression comparison. - Equivalent to MATLAB PlotResultsCmcNIST.m. Generates: 1. Cell-preprocessed reference 2. Cell-preprocessed compared 3. All cells overlay visualization 4. Cell cross-correlation heatmap - 5. Cell correlation histogram :param mark_ref_filtered: Reference mark after filtering. :param mark_comp_filtered: Compared mark after filtering. :param cell_correlations: Grid of per-cell correlation values. - :returns: Tuple of 5 ImageRGB arrays. + :returns: Tuple of 4 ImageRGB arrays. """ scale = mark_ref_filtered.scan_image.scale_x @@ -233,12 +213,146 @@ def plot_cmc_figures( cell_correlations=cell_correlations, ) - # 5. Cell correlation histogram - cell_histogram = plot_correlation_histogram( - cell_correlations=cell_correlations, + return cell_ref, cell_comp, cell_overlay, cell_xcorr + + +def plot_comparison_overview( + mark_reference_leveled: Mark, + mark_compared_leveled: Mark, + mark_reference_filtered: Mark, + mark_compared_filtered: Mark, + metrics: ImpressionComparisonMetrics, + metadata_reference: dict[str, str], + metadata_compared: dict[str, str], + wrap_width: int = 25, +) -> ImageRGB: + """ + Generate the main results overview figure with dynamic sizing. + + Combines metadata tables, surface visualizations, cell grid overlay, + and cell correlation heatmap into a single overview figure. + + :param mark_reference_leveled: Reference mark after leveling. + :param mark_compared_leveled: Compared mark after leveling. + :param mark_reference_filtered: Reference mark after filtering. + :param mark_compared_filtered: Compared mark after filtering. + :param metrics: Comparison metrics including correlation values. + :param metadata_reference: Metadata dict for reference mark display. + :param metadata_compared: Metadata dict for compared mark display. + :param wrap_width: Maximum characters per line before wrapping. + :returns: RGB image as uint8 array. + """ + # Build results metadata + results_items = { + "Date report": datetime.now().strftime("%Y-%m-%d"), + "Mark type": mark_reference_leveled.mark_type.value, + "Area Correlation": f"{metrics.area_correlation:.4f}", + "CMC Score": f"{metrics.cmc_score:.1f}%", + "Sq(Ref)": f"{metrics.sq_ref:.4f} µm", + "Sq(Comp)": f"{metrics.sq_comp:.4f} µm", + "Sq(Diff)": f"{metrics.sq_diff:.4f} µm", + } + + max_metadata_rows, metadata_height_ratio = get_metadata_dimensions( + metadata_compared, metadata_reference, wrap_width + ) + height_ratios = get_height_ratios(metadata_height_ratio) + + # Adjust figure height based on content + fig_height = 14 + (max_metadata_rows * 0.12) + fig_height = max(13, min(17, fig_height)) + + fig = plt.figure(figsize=(14, fig_height)) + + gs = fig.add_gridspec( + 4, + 3, + height_ratios=height_ratios, + width_ratios=[0.35, 0.35, 0.30], + hspace=0.35, + wspace=0.25, + ) + + # Row 0: Metadata tables + ax_meta_reference = fig.add_subplot(gs[0, 0]) + _draw_metadata_box( + ax_meta_reference, + metadata_reference, + "Reference Mark (A)", + wrap_width=wrap_width, ) - return cell_ref, cell_comp, cell_overlay, cell_xcorr, cell_histogram + ax_meta_compared = fig.add_subplot(gs[0, 1]) + _draw_metadata_box( + ax_meta_compared, + metadata_compared, + "Compared Mark (B)", + wrap_width=wrap_width, + ) + + # Row 1: Leveled surfaces + Results + ax_leveled_ref = fig.add_subplot(gs[1, 0]) + plot_depth_map_on_axes( + ax_leveled_ref, + fig, + mark_reference_leveled.scan_image.data, + mark_reference_leveled.scan_image.scale_x, + title="Leveled Reference Surface A", + ) + + ax_leveled_comp = fig.add_subplot(gs[1, 1]) + plot_depth_map_on_axes( + ax_leveled_comp, + fig, + mark_compared_leveled.scan_image.data, + mark_compared_leveled.scan_image.scale_x, + title="Leveled Compared Surface B", + ) + + ax_results = fig.add_subplot(gs[1, 2]) + _draw_metadata_box( + ax_results, results_items, draw_border=False, wrap_width=wrap_width + ) + + # Row 2: Filtered surfaces + ax_filtered_ref = fig.add_subplot(gs[2, 0]) + plot_depth_map_on_axes( + ax_filtered_ref, + fig, + mark_reference_filtered.scan_image.data, + mark_reference_filtered.scan_image.scale_x, + title="Filtered Reference Surface A", + ) + + ax_filtered_comp = fig.add_subplot(gs[2, 1]) + plot_depth_map_on_axes( + ax_filtered_comp, + fig, + mark_compared_filtered.scan_image.data, + mark_compared_filtered.scan_image.scale_x, + title="Filtered Compared Surface B", + ) + + # Row 2, Col 2: Cell correlation heatmap (if available) + if metrics.has_cell_results: + ax_heatmap = fig.add_subplot(gs[2, 2]) + _plot_cell_heatmap_on_axes(ax_heatmap, fig, metrics.cell_correlations) + + # Row 3: Cell grid overlay (spanning full width if cell results available) + if metrics.has_cell_results: + ax_overlay = fig.add_subplot(gs[3, :2]) + _plot_cell_overlay_on_axes( + ax_overlay, + mark_reference_filtered.scan_image.data, + mark_reference_filtered.scan_image.scale_x, + metrics.cell_correlations, + ) + + fig.tight_layout(pad=0.8, h_pad=1.2, w_pad=0.8) + fig.subplots_adjust(left=0.06, right=0.98, top=0.96, bottom=0.06) + arr = figure_to_array(fig) + plt.close(fig) + return arr def plot_depth_map_with_axes( @@ -266,57 +380,24 @@ def plot_depth_map_with_axes( return arr -def plot_difference_map( - data_ref: FloatArray2D, - data_comp: FloatArray2D, +def plot_cell_grid_overlay( + data: FloatArray2D, scale: float, + cell_correlations: FloatArray2D, ) -> ImageRGB: """ - Plot the difference map between two surfaces. + Plot surface with cell grid overlay showing correlation values. - :param data_ref: Reference surface data in meters. - :param data_comp: Compared surface data in meters. + :param data: Surface data in meters. :param scale: Pixel scale in meters. + :param cell_correlations: Grid of per-cell correlation values. :returns: RGB image as uint8 array. """ - # Compute difference (handle NaN values) - diff = data_comp - data_ref - - # Compute Sq of difference (RMS of valid values) - valid_diff = diff[~np.isnan(diff)] - sq_diff = np.sqrt(np.mean(valid_diff**2)) * 1e6 if len(valid_diff) > 0 else 0.0 - - height, width = diff.shape + height, width = data.shape fig_height, fig_width = get_figure_dimensions(height, width) - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - extent = (0, width * scale * 1e6, 0, height * scale * 1e6) - im = ax.imshow( - diff * 1e6, - cmap="RdBu_r", # Diverging colormap centered at 0 - aspect="equal", - origin="lower", - extent=extent, - ) - - # Center colormap at 0 - vmax = np.nanmax(np.abs(diff * 1e6)) - im.set_clim(-vmax, vmax) - - ax.set_xlabel("X - Position [um]", fontsize=11) - ax.set_ylabel("Y - Position [um]", fontsize=11) - ax.set_title( - f"Difference Map (Sq = {sq_diff:.4f} um)", fontsize=12, fontweight="bold" - ) - ax.tick_params(labelsize=10) - - from mpl_toolkits.axes_grid1 import make_axes_locatable - - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im, cax=cax, label="Difference [um]") - cbar.ax.tick_params(labelsize=10) + _plot_cell_overlay_on_axes(ax, data, scale, cell_correlations) fig.tight_layout() arr = figure_to_array(fig) @@ -324,107 +405,95 @@ def plot_difference_map( return arr -def plot_cross_correlation_surface( - data_ref: FloatArray2D, - data_comp: FloatArray2D, - scale: float, - correlation_value: float, +def plot_cell_correlation_heatmap( + cell_correlations: FloatArray2D, ) -> ImageRGB: """ - Plot the 2D cross-correlation surface. + Plot heatmap of per-cell correlation values. - :param data_ref: Reference surface data in meters. - :param data_comp: Compared surface data in meters. - :param scale: Pixel scale in meters. - :param correlation_value: Pre-computed correlation coefficient. + :param cell_correlations: Grid of per-cell correlation values. :returns: RGB image as uint8 array. """ - # Replace NaN with 0 for correlation computation - ref_clean = np.nan_to_num(data_ref, nan=0.0) - comp_clean = np.nan_to_num(data_comp, nan=0.0) + n_rows, n_cols = cell_correlations.shape - # Normalize for correlation - ref_norm = ref_clean - np.mean(ref_clean) - comp_norm = comp_clean - np.mean(comp_clean) + # Calculate figure size based on grid dimensions + base_size = 6 + aspect = n_cols / n_rows + if aspect > 1: + fig_width = base_size + fig_height = base_size / aspect + 1.5 + else: + fig_height = base_size + 1.5 + fig_width = base_size * aspect - # Compute 2D cross-correlation (use 'same' mode for same-size output) - xcorr = correlate2d(ref_norm, comp_norm, mode="same", boundary="fill", fillvalue=0) + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - # Normalize to correlation coefficient scale - norm_factor = np.sqrt(np.sum(ref_norm**2) * np.sum(comp_norm**2)) - if norm_factor > 0: - xcorr = xcorr / norm_factor + _plot_cell_heatmap_on_axes(ax, fig, cell_correlations) - height, width = xcorr.shape - fig_height, fig_width = get_figure_dimensions(height, width) + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - # Create extent in lag coordinates (centered at 0) - half_h = height // 2 - half_w = width // 2 - extent_um = ( - -half_w * scale * 1e6, - half_w * scale * 1e6, - -half_h * scale * 1e6, - half_h * scale * 1e6, - ) +# --- Helper functions for axes-level plotting --- - im = ax.imshow( - xcorr, - cmap=DEFAULT_COLORMAP, - aspect="equal", - origin="lower", - extent=extent_um, - ) - ax.set_xlabel("X - Lag [um]", fontsize=11) - ax.set_ylabel("Y - Lag [um]", fontsize=11) - ax.set_title( - f"Cross-Correlation (Max = {correlation_value:.4f})", - fontsize=12, - fontweight="bold", +def _draw_metadata_box( + ax: Axes, + metadata: dict[str, str], + title: str | None = None, + draw_border: bool = True, + wrap_width: int = 25, + side_margin: float = 0.06, +) -> None: + """Draw a metadata box with key-value pairs.""" + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xticks([]) + ax.set_yticks([]) + + for spine in ax.spines.values(): + spine.set_visible(draw_border) + spine.set_linewidth(1.5) + spine.set_edgecolor("black") + + if title: + ax.set_title(title, fontsize=14, fontweight="bold", pad=10) + + table_data = metadata_to_table_data(metadata, wrap_width=wrap_width) + col_widths = get_col_widths(side_margin, table_data) + bounding_box = get_bounding_box(side_margin, table_data) + + table = ax.table( + cellText=table_data, + cellLoc="left", + colWidths=col_widths, + loc="upper center", + edges="open", + bbox=bounding_box, ) - ax.tick_params(labelsize=10) - # Mark the peak - peak_idx = np.unravel_index(np.argmax(xcorr), xcorr.shape) - peak_y = (peak_idx[0] - half_h) * scale * 1e6 - peak_x = (peak_idx[1] - half_w) * scale * 1e6 - ax.plot(peak_x, peak_y, "r+", markersize=15, markeredgewidth=2) + table.auto_set_font_size(False) + table.set_fontsize(10) - from mpl_toolkits.axes_grid1 import make_axes_locatable - - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im, cax=cax, label="Correlation") - cbar.ax.tick_params(labelsize=10) + for i in range(len(table_data)): + table[i, 0].set_text_props(fontweight="bold", ha="right") + table[i, 0].PAD = 0.02 + table[i, 1].set_text_props(ha="left") + table[i, 1].PAD = 0.02 - fig.tight_layout() - arr = figure_to_array(fig) - plt.close(fig) - return arr - -def plot_cell_grid_overlay( +def _plot_cell_overlay_on_axes( + ax: Axes, data: FloatArray2D, scale: float, cell_correlations: FloatArray2D, -) -> ImageRGB: - """ - Plot surface with cell grid overlay showing correlation values. - - :param data: Surface data in meters. - :param scale: Pixel scale in meters. - :param cell_correlations: Grid of per-cell correlation values. - :returns: RGB image as uint8 array. - """ +) -> None: + """Plot surface with cell grid overlay on given axes.""" height, width = data.shape n_rows, n_cols = cell_correlations.shape - fig_height, fig_width = get_figure_dimensions(height, width) - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - # Plot the surface extent = (0, width * scale * 1e6, 0, height * scale * 1e6) ax.imshow( @@ -474,40 +543,22 @@ def plot_cell_grid_overlay( bbox=dict(boxstyle="round,pad=0.1", facecolor="black", alpha=0.5), ) - ax.set_xlabel("X - Position [um]", fontsize=11) - ax.set_ylabel("Y - Position [um]", fontsize=11) + ax.set_xlabel("X - Position [µm]", fontsize=11) + ax.set_ylabel("Y - Position [µm]", fontsize=11) ax.set_title("Cell Grid with Correlation Values", fontsize=12, fontweight="bold") ax.tick_params(labelsize=10) - fig.tight_layout() - arr = figure_to_array(fig) - plt.close(fig) - return arr - -def plot_cell_correlation_heatmap( +def _plot_cell_heatmap_on_axes( + ax: Axes, + fig: Figure, cell_correlations: FloatArray2D, -) -> ImageRGB: - """ - Plot heatmap of per-cell correlation values. +) -> None: + """Plot cell correlation heatmap on given axes.""" + from mpl_toolkits.axes_grid1 import make_axes_locatable - :param cell_correlations: Grid of per-cell correlation values. - :returns: RGB image as uint8 array. - """ n_rows, n_cols = cell_correlations.shape - # Calculate figure size based on grid dimensions - base_size = 6 - aspect = n_cols / n_rows - if aspect > 1: - fig_width = base_size - fig_height = base_size / aspect + 1.5 - else: - fig_height = base_size + 1.5 - fig_width = base_size * aspect - - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - im = ax.imshow( cell_correlations, cmap=DEFAULT_COLORMAP, @@ -543,98 +594,7 @@ def plot_cell_correlation_heatmap( ax.set_xticks(range(n_cols)) ax.set_yticks(range(n_rows)) - from mpl_toolkits.axes_grid1 import make_axes_locatable - divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.1) cbar = fig.colorbar(im, cax=cax, label="Correlation") cbar.ax.tick_params(labelsize=10) - - fig.tight_layout() - arr = figure_to_array(fig) - plt.close(fig) - return arr - - -def plot_correlation_histogram( - cell_correlations: FloatArray2D, - threshold: float = 0.5, -) -> ImageRGB: - """ - Plot histogram of per-cell correlation values. - - :param cell_correlations: Grid of per-cell correlation values. - :param threshold: CMC threshold to mark on histogram. - :returns: RGB image as uint8 array. - """ - # Flatten and remove NaN values - valid_correlations = cell_correlations.flatten() - valid_correlations = valid_correlations[~np.isnan(valid_correlations)] - - # Count cells above threshold - n_above = np.sum(valid_correlations >= threshold) - n_total = len(valid_correlations) - cmc_score = (n_above / n_total * 100) if n_total > 0 else 0.0 - - fig, ax = plt.subplots(figsize=(8, 5)) - - # Create histogram - n_bins = 20 - _, _, patches = ax.hist( - valid_correlations, - bins=n_bins, - range=(0, 1), - color="steelblue", - edgecolor="white", - alpha=0.8, - ) - - # Color bars above threshold differently (patches is BarContainer for single input) - for patch in patches: # type: ignore[union-attr] - bin_center = patch.get_x() + patch.get_width() / 2 - if bin_center >= threshold: - patch.set_facecolor("forestgreen") - - # Add threshold line - ax.axvline( - threshold, - color="red", - linestyle="--", - linewidth=2, - label=f"CMC Threshold = {threshold:.2f}", - ) - - ax.set_xlabel("Correlation Coefficient", fontsize=11) - ax.set_ylabel("Number of Cells", fontsize=11) - ax.set_title( - f"Cell Correlation Distribution (CMC = {cmc_score:.1f}%)", - fontsize=12, - fontweight="bold", - ) - ax.tick_params(labelsize=10) - ax.set_xlim(0, 1) - ax.legend(loc="upper left", fontsize=10) - ax.grid(True, alpha=0.3, axis="y") - - # Add statistics annotation - stats_text = ( - f"N = {n_total}\n" - f"Mean = {np.mean(valid_correlations):.3f}\n" - f"Std = {np.std(valid_correlations):.3f}\n" - f"Above threshold: {n_above}/{n_total}" - ) - ax.text( - 0.98, - 0.95, - stats_text, - transform=ax.transAxes, - ha="right", - va="top", - fontsize=9, - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), - ) - - fig.tight_layout() - arr = figure_to_array(fig) - plt.close(fig) - return arr diff --git a/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py b/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py index 2b77012f..eae3cd8c 100644 --- a/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py +++ b/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py @@ -14,10 +14,8 @@ plot_cell_correlation_heatmap, plot_cell_grid_overlay, plot_cmc_figures, - plot_correlation_histogram, - plot_cross_correlation_surface, + plot_comparison_overview, plot_depth_map_with_axes, - plot_difference_map, plot_impression_comparison_results, ) @@ -65,6 +63,26 @@ def sample_metrics(sample_cell_correlations: np.ndarray) -> ImpressionComparison ) +@pytest.fixture +def sample_metadata_reference() -> dict[str, str]: + """Create sample metadata for reference mark.""" + return { + "Collection": "firearms", + "Firearm ID": "firearm_1", + "Specimen ID": "cartridge_1", + } + + +@pytest.fixture +def sample_metadata_compared() -> dict[str, str]: + """Create sample metadata for compared mark.""" + return { + "Collection": "firearms", + "Firearm ID": "firearm_1", + "Specimen ID": "cartridge_2", + } + + class TestPlotDepthMapWithAxes: """Tests for plot_depth_map_with_axes function.""" @@ -87,51 +105,6 @@ def test_handles_nan_values(self): assert result.shape[2] == 3 -class TestPlotDifferenceMap: - """Tests for plot_difference_map function.""" - - def test_returns_rgb_image(self, sample_depth_data: np.ndarray): - """Output should be RGB uint8 array.""" - data_comp = ( - sample_depth_data + np.random.randn(*sample_depth_data.shape) * 0.1e-6 - ) - result = plot_difference_map( - data_ref=sample_depth_data, - data_comp=data_comp, - scale=1.5e-6, - ) - assert result.ndim == 3 - assert result.shape[2] == 3 - assert result.dtype == np.uint8 - - def test_identical_surfaces_show_zero_difference( - self, sample_depth_data: np.ndarray - ): - """Identical surfaces should produce a valid difference map.""" - result = plot_difference_map( - data_ref=sample_depth_data, - data_comp=sample_depth_data.copy(), - scale=1.5e-6, - ) - assert result.shape[2] == 3 - - -class TestPlotCrossCorrelationSurface: - """Tests for plot_cross_correlation_surface function.""" - - def test_returns_rgb_image(self, sample_depth_data: np.ndarray): - """Output should be RGB uint8 array.""" - result = plot_cross_correlation_surface( - data_ref=sample_depth_data, - data_comp=sample_depth_data, - scale=1.5e-6, - correlation_value=0.95, - ) - assert result.ndim == 3 - assert result.shape[2] == 3 - assert result.dtype == np.uint8 - - class TestPlotCellGridOverlay: """Tests for plot_cell_grid_overlay function.""" @@ -169,44 +142,18 @@ def test_handles_different_grid_sizes(self): assert result.shape[2] == 3 -class TestPlotCorrelationHistogram: - """Tests for plot_correlation_histogram function.""" - - def test_returns_rgb_image(self, sample_cell_correlations: np.ndarray): - """Output should be RGB uint8 array.""" - result = plot_correlation_histogram(cell_correlations=sample_cell_correlations) - assert result.ndim == 3 - assert result.shape[2] == 3 - assert result.dtype == np.uint8 - - def test_custom_threshold(self, sample_cell_correlations: np.ndarray): - """Should accept custom threshold.""" - result = plot_correlation_histogram( - cell_correlations=sample_cell_correlations, - threshold=0.7, - ) - assert result.shape[2] == 3 - - def test_handles_nan_values(self): - """Should handle NaN values in correlations.""" - correlations = np.array([[0.5, np.nan], [0.8, 0.3]]) - result = plot_correlation_histogram(cell_correlations=correlations) - assert result.shape[2] == 3 - - class TestPlotAreaFigures: """Tests for plot_area_figures function.""" - def test_returns_six_images(self, sample_mark: Mark): - """Should return tuple of 6 RGB images.""" + def test_returns_four_images(self, sample_mark: Mark): + """Should return tuple of 4 RGB images.""" result = plot_area_figures( mark_ref_leveled=sample_mark, mark_comp_leveled=sample_mark, mark_ref_filtered=sample_mark, mark_comp_filtered=sample_mark, - correlation_value=0.85, ) - assert len(result) == 6 + assert len(result) == 4 for img in result: assert img.ndim == 3 assert img.shape[2] == 3 @@ -216,27 +163,85 @@ def test_returns_six_images(self, sample_mark: Mark): class TestPlotCmcFigures: """Tests for plot_cmc_figures function.""" - def test_returns_five_images( + def test_returns_four_images( self, sample_mark: Mark, sample_cell_correlations: np.ndarray ): - """Should return tuple of 5 RGB images.""" + """Should return tuple of 4 RGB images.""" result = plot_cmc_figures( mark_ref_filtered=sample_mark, mark_comp_filtered=sample_mark, cell_correlations=sample_cell_correlations, ) - assert len(result) == 5 + assert len(result) == 4 for img in result: assert img.ndim == 3 assert img.shape[2] == 3 assert img.dtype == np.uint8 +class TestPlotComparisonOverview: + """Tests for plot_comparison_overview function.""" + + def test_returns_rgb_image( + self, + sample_mark: Mark, + sample_metrics: ImpressionComparisonMetrics, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should return valid RGB image.""" + result = plot_comparison_overview( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=sample_metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_handles_area_only_metrics( + self, + sample_mark: Mark, + sample_cell_correlations: np.ndarray, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should work when only area results are available.""" + metrics = ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=0.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=True, + has_cell_results=False, + ) + result = plot_comparison_overview( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + assert result.shape[2] == 3 + + class TestPlotImpressionComparisonResults: """Integration tests for the main orchestrator function.""" def test_generates_all_plots_when_both_flags_true( - self, sample_mark: Mark, sample_metrics: ImpressionComparisonMetrics + self, + sample_mark: Mark, + sample_metrics: ImpressionComparisonMetrics, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], ): """Should generate all plots when both area and cell results are available.""" result = plot_impression_comparison_results( @@ -245,29 +250,33 @@ def test_generates_all_plots_when_both_flags_true( mark_reference_filtered=sample_mark, mark_compared_filtered=sample_mark, metrics=sample_metrics, - _metadata_reference={"Case": "Test"}, - _metadata_compared={"Case": "Test"}, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, ) assert isinstance(result, ImpressionComparisonPlots) + # Comparison overview should always be present + assert result.comparison_overview is not None + # Area-based plots should be present assert result.leveled_reference is not None assert result.leveled_compared is not None assert result.filtered_reference is not None assert result.filtered_compared is not None - assert result.difference_map is not None - assert result.area_cross_correlation is not None # Cell/CMC-based plots should be present assert result.cell_reference is not None assert result.cell_compared is not None assert result.cell_overlay is not None assert result.cell_cross_correlation is not None - assert result.cell_correlation_histogram is not None def test_only_area_plots_when_cell_flag_false( - self, sample_mark: Mark, sample_cell_correlations: np.ndarray + self, + sample_mark: Mark, + sample_cell_correlations: np.ndarray, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], ): """Should only generate area plots when has_cell_results is False.""" metrics = ImpressionComparisonMetrics( @@ -287,20 +296,26 @@ def test_only_area_plots_when_cell_flag_false( mark_reference_filtered=sample_mark, mark_compared_filtered=sample_mark, metrics=metrics, - _metadata_reference={}, - _metadata_compared={}, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, ) + # Comparison overview always present + assert result.comparison_overview is not None + # Area-based plots should be present assert result.leveled_reference is not None - assert result.area_cross_correlation is not None # Cell/CMC-based plots should be None assert result.cell_reference is None - assert result.cell_correlation_histogram is None + assert result.cell_cross_correlation is None def test_only_cell_plots_when_area_flag_false( - self, sample_mark: Mark, sample_cell_correlations: np.ndarray + self, + sample_mark: Mark, + sample_cell_correlations: np.ndarray, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], ): """Should only generate cell plots when has_area_results is False.""" metrics = ImpressionComparisonMetrics( @@ -320,20 +335,26 @@ def test_only_cell_plots_when_area_flag_false( mark_reference_filtered=sample_mark, mark_compared_filtered=sample_mark, metrics=metrics, - _metadata_reference={}, - _metadata_compared={}, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, ) + # Comparison overview always present + assert result.comparison_overview is not None + # Area-based plots should be None assert result.leveled_reference is None - assert result.area_cross_correlation is None # Cell/CMC-based plots should be present assert result.cell_reference is not None - assert result.cell_correlation_histogram is not None + assert result.cell_cross_correlation is not None def test_all_outputs_are_valid_images( - self, sample_mark: Mark, sample_metrics: ImpressionComparisonMetrics + self, + sample_mark: Mark, + sample_metrics: ImpressionComparisonMetrics, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], ): """All non-None outputs should be valid RGB images.""" result = plot_impression_comparison_results( @@ -342,22 +363,20 @@ def test_all_outputs_are_valid_images( mark_reference_filtered=sample_mark, mark_compared_filtered=sample_mark, metrics=sample_metrics, - _metadata_reference={}, - _metadata_compared={}, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, ) for field_name in [ + "comparison_overview", "leveled_reference", "leveled_compared", "filtered_reference", "filtered_compared", - "difference_map", - "area_cross_correlation", "cell_reference", "cell_compared", "cell_overlay", "cell_cross_correlation", - "cell_correlation_histogram", ]: img = getattr(result, field_name) if img is not None: