diff --git a/packages/scratch-core/src/conversion/plots/data_formats.py b/packages/scratch-core/src/conversion/plots/data_formats.py index da263074..daf1d521 100644 --- a/packages/scratch-core/src/conversion/plots/data_formats.py +++ b/packages/scratch-core/src/conversion/plots/data_formats.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from container_models.base import ImageRGB +from container_models.base import FloatArray2D, ImageRGB @dataclass @@ -51,3 +51,60 @@ class StriationComparisonPlots: mark2_filtered_preview_image: ImageRGB mark1_vs_moved_mark2: ImageRGB wavelength_plot: ImageRGB + + +@dataclass +class ImpressionComparisonMetrics: + """ + Metrics for impression comparison display. + + :param area_correlation: Areal correlation coefficient (from area-based comparison). + :param cell_correlations: Grid of per-cell correlation values (shape: n_rows x n_cols). + :param cmc_score: Congruent Matching Cells score (percentage of cells above threshold). + :param sq_ref: Sq (RMS roughness) of reference surface in µm. + :param sq_comp: Sq (RMS roughness) of compared surface in µm. + :param sq_diff: Sq of difference (comp - ref) in µm. + :param has_area_results: Whether area-based results were computed. + :param has_cell_results: Whether cell/CMC-based results were computed. + """ + + area_correlation: float + cell_correlations: FloatArray2D + cmc_score: float + sq_ref: float + sq_comp: float + sq_diff: float + has_area_results: bool + has_cell_results: bool + + +@dataclass +class ImpressionComparisonPlots: + """ + Results from impression mark comparison visualization. + + Contains rendered images for both area-based and cell/CMC-based visualizations. + Fields are None when the corresponding analysis was not performed. + + :param comparison_overview: Combined overview figure with all results. + :param leveled_reference: Leveled reference surface visualization. + :param leveled_compared: Leveled compared surface visualization. + :param filtered_reference: Filtered reference surface visualization. + :param filtered_compared: Filtered compared surface visualization. + :param cell_reference: Cell-preprocessed reference visualization. + :param cell_compared: Cell-preprocessed compared visualization. + :param cell_overlay: All cells overlay visualization. + :param cell_cross_correlation: Cell-based cross-correlation heatmap. + """ + + comparison_overview: ImageRGB + # Area-based plots + leveled_reference: ImageRGB | None + leveled_compared: ImageRGB | None + filtered_reference: ImageRGB | None + filtered_compared: ImageRGB | None + # Cell/CMC-based plots + cell_reference: ImageRGB | None + cell_compared: ImageRGB | None + cell_overlay: ImageRGB | None + cell_cross_correlation: ImageRGB | None diff --git a/packages/scratch-core/src/conversion/plots/plot_impression.py b/packages/scratch-core/src/conversion/plots/plot_impression.py new file mode 100644 index 00000000..68eaff22 --- /dev/null +++ b/packages/scratch-core/src/conversion/plots/plot_impression.py @@ -0,0 +1,600 @@ +"""Impression mark comparison visualization.""" + +from datetime import datetime + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from container_models.base import FloatArray2D, ImageRGB +from conversion.data_formats import Mark +from conversion.plots.data_formats import ( + ImpressionComparisonMetrics, + ImpressionComparisonPlots, +) +from conversion.plots.utils import ( + DEFAULT_COLORMAP, + figure_to_array, + get_bounding_box, + get_col_widths, + get_figure_dimensions, + get_height_ratios, + get_metadata_dimensions, + metadata_to_table_data, + plot_depth_map_on_axes, +) + + +def plot_impression_comparison_results( + mark_reference_leveled: Mark, + mark_compared_leveled: Mark, + mark_reference_filtered: Mark, + mark_compared_filtered: Mark, + metrics: ImpressionComparisonMetrics, + metadata_reference: dict[str, str], + metadata_compared: dict[str, str], +) -> ImpressionComparisonPlots: + """ + Generate visualization results for impression mark comparison. + + Main orchestrator function that generates both area-based and cell/CMC-based + visualizations based on which results are available in the metrics. + + :param mark_reference_leveled: Reference mark after leveling. + :param mark_compared_leveled: Compared mark after leveling. + :param mark_reference_filtered: Reference mark after filtering. + :param mark_compared_filtered: Compared mark after filtering. + :param metrics: Comparison metrics including correlation values. + :param metadata_reference: Metadata dict for reference mark display. + :param metadata_compared: Metadata dict for compared mark display. + :returns: ImpressionComparisonPlots with all rendered images. + """ + # Initialize all plots as None + leveled_ref = None + leveled_comp = None + filtered_ref = None + filtered_comp = None + cell_ref = None + cell_comp = None + cell_overlay = None + cell_xcorr = None + + # Generate area-based plots if available + if metrics.has_area_results: + ( + leveled_ref, + leveled_comp, + filtered_ref, + filtered_comp, + ) = plot_area_figures( + mark_ref_leveled=mark_reference_leveled, + mark_comp_leveled=mark_compared_leveled, + mark_ref_filtered=mark_reference_filtered, + mark_comp_filtered=mark_compared_filtered, + ) + + # Generate cell/CMC-based plots if available + if metrics.has_cell_results: + ( + cell_ref, + cell_comp, + cell_overlay, + cell_xcorr, + ) = plot_cmc_figures( + mark_ref_filtered=mark_reference_filtered, + mark_comp_filtered=mark_compared_filtered, + cell_correlations=metrics.cell_correlations, + ) + + # Generate comparison overview + comparison_overview = plot_comparison_overview( + mark_reference_leveled=mark_reference_leveled, + mark_compared_leveled=mark_compared_leveled, + mark_reference_filtered=mark_reference_filtered, + mark_compared_filtered=mark_compared_filtered, + metrics=metrics, + metadata_reference=metadata_reference, + metadata_compared=metadata_compared, + ) + + return ImpressionComparisonPlots( + comparison_overview=comparison_overview, + leveled_reference=leveled_ref, + leveled_compared=leveled_comp, + filtered_reference=filtered_ref, + filtered_compared=filtered_comp, + cell_reference=cell_ref, + cell_compared=cell_comp, + cell_overlay=cell_overlay, + cell_cross_correlation=cell_xcorr, + ) + + +def plot_area_figures( + mark_ref_leveled: Mark, + mark_comp_leveled: Mark, + mark_ref_filtered: Mark, + mark_comp_filtered: Mark, +) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB]: + """ + Generate 4 area-based plots for impression comparison. + + Generates: + 1. Leveled reference surface + 2. Leveled compared surface + 3. Filtered reference surface + 4. Filtered compared surface + + :param mark_ref_leveled: Reference mark after leveling. + :param mark_comp_leveled: Compared mark after leveling. + :param mark_ref_filtered: Reference mark after filtering. + :param mark_comp_filtered: Compared mark after filtering. + :returns: Tuple of 4 ImageRGB arrays. + """ + scale_ref = mark_ref_leveled.scan_image.scale_x + scale_comp = mark_comp_leveled.scan_image.scale_x + + # 1. Leveled reference surface + leveled_ref = plot_depth_map_with_axes( + data=mark_ref_leveled.scan_image.data, + scale=scale_ref, + title="Leveled Reference Surface", + ) + + # 2. Leveled compared surface + leveled_comp = plot_depth_map_with_axes( + data=mark_comp_leveled.scan_image.data, + scale=scale_comp, + title="Leveled Compared Surface", + ) + + # 3. Filtered reference surface + filtered_ref = plot_depth_map_with_axes( + data=mark_ref_filtered.scan_image.data, + scale=scale_ref, + title="Filtered Reference Surface", + ) + + # 4. Filtered compared surface + filtered_comp = plot_depth_map_with_axes( + data=mark_comp_filtered.scan_image.data, + scale=scale_comp, + title="Filtered Compared Surface", + ) + + return leveled_ref, leveled_comp, filtered_ref, filtered_comp + + +def plot_cmc_figures( + mark_ref_filtered: Mark, + mark_comp_filtered: Mark, + cell_correlations: FloatArray2D, +) -> tuple[ImageRGB, ImageRGB, ImageRGB, ImageRGB]: + """ + Generate 4 CMC/cell-based plots for impression comparison. + + Generates: + 1. Cell-preprocessed reference + 2. Cell-preprocessed compared + 3. All cells overlay visualization + 4. Cell cross-correlation heatmap + + :param mark_ref_filtered: Reference mark after filtering. + :param mark_comp_filtered: Compared mark after filtering. + :param cell_correlations: Grid of per-cell correlation values. + :returns: Tuple of 4 ImageRGB arrays. + """ + scale = mark_ref_filtered.scan_image.scale_x + + # 1. Cell-preprocessed reference + cell_ref = plot_depth_map_with_axes( + data=mark_ref_filtered.scan_image.data, + scale=scale, + title="Cell-Preprocessed Reference", + ) + + # 2. Cell-preprocessed compared + cell_comp = plot_depth_map_with_axes( + data=mark_comp_filtered.scan_image.data, + scale=scale, + title="Cell-Preprocessed Compared", + ) + + # 3. Cell overlay visualization + cell_overlay = plot_cell_grid_overlay( + data=mark_ref_filtered.scan_image.data, + scale=scale, + cell_correlations=cell_correlations, + ) + + # 4. Cell cross-correlation heatmap + cell_xcorr = plot_cell_correlation_heatmap( + cell_correlations=cell_correlations, + ) + + return cell_ref, cell_comp, cell_overlay, cell_xcorr + + +def plot_comparison_overview( + mark_reference_leveled: Mark, + mark_compared_leveled: Mark, + mark_reference_filtered: Mark, + mark_compared_filtered: Mark, + metrics: ImpressionComparisonMetrics, + metadata_reference: dict[str, str], + metadata_compared: dict[str, str], + wrap_width: int = 25, +) -> ImageRGB: + """ + Generate the main results overview figure with dynamic sizing. + + Combines metadata tables, surface visualizations, cell grid overlay, + and cell correlation heatmap into a single overview figure. + + :param mark_reference_leveled: Reference mark after leveling. + :param mark_compared_leveled: Compared mark after leveling. + :param mark_reference_filtered: Reference mark after filtering. + :param mark_compared_filtered: Compared mark after filtering. + :param metrics: Comparison metrics including correlation values. + :param metadata_reference: Metadata dict for reference mark display. + :param metadata_compared: Metadata dict for compared mark display. + :param wrap_width: Maximum characters per line before wrapping. + :returns: RGB image as uint8 array. + """ + # Build results metadata + results_items = { + "Date report": datetime.now().strftime("%Y-%m-%d"), + "Mark type": mark_reference_leveled.mark_type.value, + "Area Correlation": f"{metrics.area_correlation:.4f}", + "CMC Score": f"{metrics.cmc_score:.1f}%", + "Sq(Ref)": f"{metrics.sq_ref:.4f} µm", + "Sq(Comp)": f"{metrics.sq_comp:.4f} µm", + "Sq(Diff)": f"{metrics.sq_diff:.4f} µm", + } + + max_metadata_rows, metadata_height_ratio = get_metadata_dimensions( + metadata_compared, metadata_reference, wrap_width + ) + height_ratios = get_height_ratios(metadata_height_ratio) + + # Adjust figure height based on content + fig_height = 14 + (max_metadata_rows * 0.12) + fig_height = max(13, min(17, fig_height)) + + fig = plt.figure(figsize=(14, fig_height)) + + gs = fig.add_gridspec( + 4, + 3, + height_ratios=height_ratios, + width_ratios=[0.35, 0.35, 0.30], + hspace=0.35, + wspace=0.25, + ) + + # Row 0: Metadata tables + ax_meta_reference = fig.add_subplot(gs[0, 0]) + _draw_metadata_box( + ax_meta_reference, + metadata_reference, + "Reference Mark (A)", + wrap_width=wrap_width, + ) + + ax_meta_compared = fig.add_subplot(gs[0, 1]) + _draw_metadata_box( + ax_meta_compared, + metadata_compared, + "Compared Mark (B)", + wrap_width=wrap_width, + ) + + # Row 1: Leveled surfaces + Results + ax_leveled_ref = fig.add_subplot(gs[1, 0]) + plot_depth_map_on_axes( + ax_leveled_ref, + fig, + mark_reference_leveled.scan_image.data, + mark_reference_leveled.scan_image.scale_x, + title="Leveled Reference Surface A", + ) + + ax_leveled_comp = fig.add_subplot(gs[1, 1]) + plot_depth_map_on_axes( + ax_leveled_comp, + fig, + mark_compared_leveled.scan_image.data, + mark_compared_leveled.scan_image.scale_x, + title="Leveled Compared Surface B", + ) + + ax_results = fig.add_subplot(gs[1, 2]) + _draw_metadata_box( + ax_results, results_items, draw_border=False, wrap_width=wrap_width + ) + + # Row 2: Filtered surfaces + ax_filtered_ref = fig.add_subplot(gs[2, 0]) + plot_depth_map_on_axes( + ax_filtered_ref, + fig, + mark_reference_filtered.scan_image.data, + mark_reference_filtered.scan_image.scale_x, + title="Filtered Reference Surface A", + ) + + ax_filtered_comp = fig.add_subplot(gs[2, 1]) + plot_depth_map_on_axes( + ax_filtered_comp, + fig, + mark_compared_filtered.scan_image.data, + mark_compared_filtered.scan_image.scale_x, + title="Filtered Compared Surface B", + ) + + # Row 2, Col 2: Cell correlation heatmap (if available) + if metrics.has_cell_results: + ax_heatmap = fig.add_subplot(gs[2, 2]) + _plot_cell_heatmap_on_axes(ax_heatmap, fig, metrics.cell_correlations) + + # Row 3: Cell grid overlay (spanning full width if cell results available) + if metrics.has_cell_results: + ax_overlay = fig.add_subplot(gs[3, :2]) + _plot_cell_overlay_on_axes( + ax_overlay, + mark_reference_filtered.scan_image.data, + mark_reference_filtered.scan_image.scale_x, + metrics.cell_correlations, + ) + + fig.tight_layout(pad=0.8, h_pad=1.2, w_pad=0.8) + fig.subplots_adjust(left=0.06, right=0.98, top=0.96, bottom=0.06) + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_depth_map_with_axes( + data: FloatArray2D, + scale: float, + title: str, +) -> ImageRGB: + """ + Plot a depth map with axes and colorbar. + + :param data: Depth data in meters. + :param scale: Pixel scale in meters. + :param title: Title for the plot. + :returns: RGB image as uint8 array. + """ + height, width = data.shape + fig_height, fig_width = get_figure_dimensions(height, width) + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + plot_depth_map_on_axes(ax, fig, data, scale, title) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_cell_grid_overlay( + data: FloatArray2D, + scale: float, + cell_correlations: FloatArray2D, +) -> ImageRGB: + """ + Plot surface with cell grid overlay showing correlation values. + + :param data: Surface data in meters. + :param scale: Pixel scale in meters. + :param cell_correlations: Grid of per-cell correlation values. + :returns: RGB image as uint8 array. + """ + height, width = data.shape + fig_height, fig_width = get_figure_dimensions(height, width) + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + _plot_cell_overlay_on_axes(ax, data, scale, cell_correlations) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +def plot_cell_correlation_heatmap( + cell_correlations: FloatArray2D, +) -> ImageRGB: + """ + Plot heatmap of per-cell correlation values. + + :param cell_correlations: Grid of per-cell correlation values. + :returns: RGB image as uint8 array. + """ + n_rows, n_cols = cell_correlations.shape + + # Calculate figure size based on grid dimensions + base_size = 6 + aspect = n_cols / n_rows + if aspect > 1: + fig_width = base_size + fig_height = base_size / aspect + 1.5 + else: + fig_height = base_size + 1.5 + fig_width = base_size * aspect + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + _plot_cell_heatmap_on_axes(ax, fig, cell_correlations) + + fig.tight_layout() + arr = figure_to_array(fig) + plt.close(fig) + return arr + + +# --- Helper functions for axes-level plotting --- + + +def _draw_metadata_box( + ax: Axes, + metadata: dict[str, str], + title: str | None = None, + draw_border: bool = True, + wrap_width: int = 25, + side_margin: float = 0.06, +) -> None: + """Draw a metadata box with key-value pairs.""" + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_xticks([]) + ax.set_yticks([]) + + for spine in ax.spines.values(): + spine.set_visible(draw_border) + spine.set_linewidth(1.5) + spine.set_edgecolor("black") + + if title: + ax.set_title(title, fontsize=14, fontweight="bold", pad=10) + + table_data = metadata_to_table_data(metadata, wrap_width=wrap_width) + col_widths = get_col_widths(side_margin, table_data) + bounding_box = get_bounding_box(side_margin, table_data) + + table = ax.table( + cellText=table_data, + cellLoc="left", + colWidths=col_widths, + loc="upper center", + edges="open", + bbox=bounding_box, + ) + + table.auto_set_font_size(False) + table.set_fontsize(10) + + for i in range(len(table_data)): + table[i, 0].set_text_props(fontweight="bold", ha="right") + table[i, 0].PAD = 0.02 + table[i, 1].set_text_props(ha="left") + table[i, 1].PAD = 0.02 + + +def _plot_cell_overlay_on_axes( + ax: Axes, + data: FloatArray2D, + scale: float, + cell_correlations: FloatArray2D, +) -> None: + """Plot surface with cell grid overlay on given axes.""" + height, width = data.shape + n_rows, n_cols = cell_correlations.shape + + # Plot the surface + extent = (0, width * scale * 1e6, 0, height * scale * 1e6) + ax.imshow( + data * 1e6, + cmap=DEFAULT_COLORMAP, + aspect="equal", + origin="lower", + extent=extent, + ) + + # Calculate cell dimensions + cell_height = height / n_rows + cell_width = width / n_cols + + # Draw grid and correlation values + for i in range(n_rows): + for j in range(n_cols): + # Cell boundaries in um + x_left = j * cell_width * scale * 1e6 + x_right = (j + 1) * cell_width * scale * 1e6 + y_bottom = (n_rows - 1 - i) * cell_height * scale * 1e6 + y_top = (n_rows - i) * cell_height * scale * 1e6 + + # Draw cell border + ax.plot( + [x_left, x_right, x_right, x_left, x_left], + [y_bottom, y_bottom, y_top, y_top, y_bottom], + "w-", + linewidth=0.5, + alpha=0.7, + ) + + # Add correlation value text + corr_val = cell_correlations[i, j] + if not np.isnan(corr_val): + x_center = (x_left + x_right) / 2 + y_center = (y_bottom + y_top) / 2 + ax.text( + x_center, + y_center, + f"{corr_val:.2f}", + ha="center", + va="center", + fontsize=8, + color="white", + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.1", facecolor="black", alpha=0.5), + ) + + ax.set_xlabel("X - Position [µm]", fontsize=11) + ax.set_ylabel("Y - Position [µm]", fontsize=11) + ax.set_title("Cell Grid with Correlation Values", fontsize=12, fontweight="bold") + ax.tick_params(labelsize=10) + + +def _plot_cell_heatmap_on_axes( + ax: Axes, + fig: Figure, + cell_correlations: FloatArray2D, +) -> None: + """Plot cell correlation heatmap on given axes.""" + from mpl_toolkits.axes_grid1 import make_axes_locatable + + n_rows, n_cols = cell_correlations.shape + + im = ax.imshow( + cell_correlations, + cmap=DEFAULT_COLORMAP, + aspect="equal", + origin="upper", + vmin=0, + vmax=1, + ) + + # Add cell value annotations + for i in range(n_rows): + for j in range(n_cols): + val = cell_correlations[i, j] + if not np.isnan(val): + text_color = "white" if val < 0.5 else "black" + ax.text( + j, + i, + f"{val:.2f}", + ha="center", + va="center", + fontsize=9, + color=text_color, + fontweight="bold", + ) + + ax.set_xlabel("Column", fontsize=11) + ax.set_ylabel("Row", fontsize=11) + ax.set_title("Cell Correlation Heatmap", fontsize=12, fontweight="bold") + ax.tick_params(labelsize=10) + + # Set tick positions + ax.set_xticks(range(n_cols)) + ax.set_yticks(range(n_rows)) + + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.1) + cbar = fig.colorbar(im, cax=cax, label="Correlation") + cbar.ax.tick_params(labelsize=10) diff --git a/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py b/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py new file mode 100644 index 00000000..eae3cd8c --- /dev/null +++ b/packages/scratch-core/tests/conversion/plot_impression/test_plot_impression.py @@ -0,0 +1,385 @@ +"""Tests for impression mark comparison visualization.""" + +import numpy as np +import pytest + +from container_models.scan_image import ScanImage +from conversion.data_formats import Mark, MarkType +from conversion.plots.data_formats import ( + ImpressionComparisonMetrics, + ImpressionComparisonPlots, +) +from conversion.plots.plot_impression import ( + plot_area_figures, + plot_cell_correlation_heatmap, + plot_cell_grid_overlay, + plot_cmc_figures, + plot_comparison_overview, + plot_depth_map_with_axes, + plot_impression_comparison_results, +) + + +@pytest.fixture +def sample_depth_data() -> np.ndarray: + """Create synthetic depth data for testing.""" + np.random.seed(42) + return np.random.randn(100, 120) * 1e-6 # Random surface in meters + + +@pytest.fixture +def sample_mark(sample_depth_data: np.ndarray) -> Mark: + """Create a sample Mark for testing.""" + scan_image = ScanImage( + data=sample_depth_data, + scale_x=1.5e-6, # 1.5 µm pixel size + scale_y=1.5e-6, + ) + return Mark( + scan_image=scan_image, + mark_type=MarkType.FIRING_PIN_IMPRESSION, + ) + + +@pytest.fixture +def sample_cell_correlations() -> np.ndarray: + """Create synthetic cell correlation grid.""" + np.random.seed(42) + return np.random.rand(4, 5) # 4x5 grid of correlations + + +@pytest.fixture +def sample_metrics(sample_cell_correlations: np.ndarray) -> ImpressionComparisonMetrics: + """Create sample metrics for testing.""" + return ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=75.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=True, + has_cell_results=True, + ) + + +@pytest.fixture +def sample_metadata_reference() -> dict[str, str]: + """Create sample metadata for reference mark.""" + return { + "Collection": "firearms", + "Firearm ID": "firearm_1", + "Specimen ID": "cartridge_1", + } + + +@pytest.fixture +def sample_metadata_compared() -> dict[str, str]: + """Create sample metadata for compared mark.""" + return { + "Collection": "firearms", + "Firearm ID": "firearm_1", + "Specimen ID": "cartridge_2", + } + + +class TestPlotDepthMapWithAxes: + """Tests for plot_depth_map_with_axes function.""" + + def test_returns_rgb_image(self, sample_depth_data: np.ndarray): + """Output should be RGB uint8 array.""" + result = plot_depth_map_with_axes( + data=sample_depth_data, + scale=1.5e-6, + title="Test Surface", + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_handles_nan_values(self): + """Should handle NaN values in data.""" + data = np.random.randn(50, 60) * 1e-6 + data[10:20, 10:20] = np.nan + result = plot_depth_map_with_axes(data=data, scale=1.5e-6, title="With NaN") + assert result.shape[2] == 3 + + +class TestPlotCellGridOverlay: + """Tests for plot_cell_grid_overlay function.""" + + def test_returns_rgb_image( + self, sample_depth_data: np.ndarray, sample_cell_correlations: np.ndarray + ): + """Output should be RGB uint8 array.""" + result = plot_cell_grid_overlay( + data=sample_depth_data, + scale=1.5e-6, + cell_correlations=sample_cell_correlations, + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + +class TestPlotCellCorrelationHeatmap: + """Tests for plot_cell_correlation_heatmap function.""" + + def test_returns_rgb_image(self, sample_cell_correlations: np.ndarray): + """Output should be RGB uint8 array.""" + result = plot_cell_correlation_heatmap( + cell_correlations=sample_cell_correlations + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_handles_different_grid_sizes(self): + """Should handle various grid sizes.""" + for rows, cols in [(2, 3), (5, 5), (3, 8)]: + correlations = np.random.rand(rows, cols) + result = plot_cell_correlation_heatmap(cell_correlations=correlations) + assert result.shape[2] == 3 + + +class TestPlotAreaFigures: + """Tests for plot_area_figures function.""" + + def test_returns_four_images(self, sample_mark: Mark): + """Should return tuple of 4 RGB images.""" + result = plot_area_figures( + mark_ref_leveled=sample_mark, + mark_comp_leveled=sample_mark, + mark_ref_filtered=sample_mark, + mark_comp_filtered=sample_mark, + ) + assert len(result) == 4 + for img in result: + assert img.ndim == 3 + assert img.shape[2] == 3 + assert img.dtype == np.uint8 + + +class TestPlotCmcFigures: + """Tests for plot_cmc_figures function.""" + + def test_returns_four_images( + self, sample_mark: Mark, sample_cell_correlations: np.ndarray + ): + """Should return tuple of 4 RGB images.""" + result = plot_cmc_figures( + mark_ref_filtered=sample_mark, + mark_comp_filtered=sample_mark, + cell_correlations=sample_cell_correlations, + ) + assert len(result) == 4 + for img in result: + assert img.ndim == 3 + assert img.shape[2] == 3 + assert img.dtype == np.uint8 + + +class TestPlotComparisonOverview: + """Tests for plot_comparison_overview function.""" + + def test_returns_rgb_image( + self, + sample_mark: Mark, + sample_metrics: ImpressionComparisonMetrics, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should return valid RGB image.""" + result = plot_comparison_overview( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=sample_metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + assert result.ndim == 3 + assert result.shape[2] == 3 + assert result.dtype == np.uint8 + + def test_handles_area_only_metrics( + self, + sample_mark: Mark, + sample_cell_correlations: np.ndarray, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should work when only area results are available.""" + metrics = ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=0.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=True, + has_cell_results=False, + ) + result = plot_comparison_overview( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + assert result.shape[2] == 3 + + +class TestPlotImpressionComparisonResults: + """Integration tests for the main orchestrator function.""" + + def test_generates_all_plots_when_both_flags_true( + self, + sample_mark: Mark, + sample_metrics: ImpressionComparisonMetrics, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should generate all plots when both area and cell results are available.""" + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=sample_metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + + assert isinstance(result, ImpressionComparisonPlots) + + # Comparison overview should always be present + assert result.comparison_overview is not None + + # Area-based plots should be present + assert result.leveled_reference is not None + assert result.leveled_compared is not None + assert result.filtered_reference is not None + assert result.filtered_compared is not None + + # Cell/CMC-based plots should be present + assert result.cell_reference is not None + assert result.cell_compared is not None + assert result.cell_overlay is not None + assert result.cell_cross_correlation is not None + + def test_only_area_plots_when_cell_flag_false( + self, + sample_mark: Mark, + sample_cell_correlations: np.ndarray, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should only generate area plots when has_cell_results is False.""" + metrics = ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=75.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=True, + has_cell_results=False, + ) + + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + + # Comparison overview always present + assert result.comparison_overview is not None + + # Area-based plots should be present + assert result.leveled_reference is not None + + # Cell/CMC-based plots should be None + assert result.cell_reference is None + assert result.cell_cross_correlation is None + + def test_only_cell_plots_when_area_flag_false( + self, + sample_mark: Mark, + sample_cell_correlations: np.ndarray, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """Should only generate cell plots when has_area_results is False.""" + metrics = ImpressionComparisonMetrics( + area_correlation=0.85, + cell_correlations=sample_cell_correlations, + cmc_score=75.0, + sq_ref=1.5, + sq_comp=1.6, + sq_diff=0.4, + has_area_results=False, + has_cell_results=True, + ) + + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + + # Comparison overview always present + assert result.comparison_overview is not None + + # Area-based plots should be None + assert result.leveled_reference is None + + # Cell/CMC-based plots should be present + assert result.cell_reference is not None + assert result.cell_cross_correlation is not None + + def test_all_outputs_are_valid_images( + self, + sample_mark: Mark, + sample_metrics: ImpressionComparisonMetrics, + sample_metadata_reference: dict[str, str], + sample_metadata_compared: dict[str, str], + ): + """All non-None outputs should be valid RGB images.""" + result = plot_impression_comparison_results( + mark_reference_leveled=sample_mark, + mark_compared_leveled=sample_mark, + mark_reference_filtered=sample_mark, + mark_compared_filtered=sample_mark, + metrics=sample_metrics, + metadata_reference=sample_metadata_reference, + metadata_compared=sample_metadata_compared, + ) + + for field_name in [ + "comparison_overview", + "leveled_reference", + "leveled_compared", + "filtered_reference", + "filtered_compared", + "cell_reference", + "cell_compared", + "cell_overlay", + "cell_cross_correlation", + ]: + img = getattr(result, field_name) + if img is not None: + assert img.ndim == 3, f"{field_name} should be 3D" + assert img.shape[2] == 3, f"{field_name} should have 3 channels" + assert img.dtype == np.uint8, f"{field_name} should be uint8"