From 2683f97d5d04ae0e445d570ee7d98c58c0be3709 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Tue, 1 Jul 2025 17:31:59 -0700 Subject: [PATCH 1/2] fix(global average): Analytics explorer global average fix --- src/pan3d/ui/analytics.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/pan3d/ui/analytics.py b/src/pan3d/ui/analytics.py index bc61800..da6f0cc 100644 --- a/src/pan3d/ui/analytics.py +++ b/src/pan3d/ui/analytics.py @@ -223,22 +223,20 @@ def apply_spatial_average_full_temporal(self, axis=None): """ Calculate spatial average for data for full temporal resoulution """ - if axis is None: - axis = ["X"] - ds = self.source.input active_var = self.state.color_by - select = self.get_selection_criteria(full_temporal=True) - average = ( - ds.isel(select).spatial.average(active_var, axis) - if axis is not None - else ds.isel(select).spatial.average(active_var) - ) - + # Apply spatial average + if axis is not None: + average = ds.isel(select).spatial.average(active_var, axis) + else: + average = ds.isel(select).spatial.average(active_var) + # Optionally apply temporal grouping group_by = self.state.group_by + if group_by == group_options.get(GroupBy.NONE): return average + return average.temporal.group_average( active_var, freq=group_by.lower(), weighted=True ) From 9b61a364cce7d8f266ce54124985579a16da338e Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Fri, 18 Jul 2025 09:21:48 -0700 Subject: [PATCH 2/2] fix: Expose analytics exploerer plot properties ui/analytics.py: - Removed unused cache attributes (spatial_cache, temporal_cache, zonal_cache) and CacheEntry class - Removed unused get_key() method and hashlib import - Simplified enum definitions (removed tuple syntax) - Refactored generate_plot() to use dictionary mapping for cleaner code - Removed redundant parameters from plot methods - they now get data from self.source explorers/analytics.py: - Added property relevance mapping to warn users about irrelevant property changes - Simplified method naming and code structure - Fixed view_update_force call to check existence before calling --- src/pan3d/explorers/analytics.py | 258 ++++++++++++++++++++++++++++++- src/pan3d/ui/analytics.py | 92 +++++------ 2 files changed, 293 insertions(+), 57 deletions(-) diff --git a/src/pan3d/explorers/analytics.py b/src/pan3d/explorers/analytics.py index 5e23fae..05cd502 100644 --- a/src/pan3d/explorers/analytics.py +++ b/src/pan3d/explorers/analytics.py @@ -1,3 +1,5 @@ +import warnings + import vtkmodules.vtkRenderingOpenGL2 # noqa: F401 from vtkmodules.vtkFiltersGeometry import vtkGeometryFilter @@ -13,7 +15,14 @@ vtkRenderWindowInteractor, ) -from pan3d.ui.analytics import Plotting +from pan3d.ui.analytics import ( + GroupBy, + Plotting, + PlotTypes, + group_options, + plot_options, + zonal_axes, +) from pan3d.ui.preview import RenderingSettings from pan3d.ui.vtk_view import Pan3DView from pan3d.utils.common import ControlPanel, Explorer, SummaryToolbar @@ -43,6 +52,14 @@ def __init__(self, render_window, **kwargs): class AnalyticsExplorer(Explorer): + # Define which properties are relevant for each plot type + PLOT_PROPERTY_RELEVANCE = { + PlotTypes.ZONAL: {"zonal_axis"}, + PlotTypes.ZONALTIME: {"zonal_axis", "group_by"}, + PlotTypes.GLOBAL: {"group_by"}, + PlotTypes.TEMPORAL: {"group_by", "temporal_slice"}, + } + def __init__( self, xarray=None, source=None, pipeline=None, server=None, local_rendering=None ): @@ -114,6 +131,233 @@ def ctrl(self): """Returns the Controller for the trame server.""" return self.server.controller + # ------------------------------------------------------------------------- + # Plot Control Properties + # ------------------------------------------------------------------------- + + @property + def plot_type(self): + """Get the current plot type.""" + current_value = self.state.active_plot + # Find the PlotType enum that matches the current value + for plot_enum, plot_str in plot_options.items(): + if plot_str == current_value: + return plot_enum + return PlotTypes.ZONAL # Default + + @plot_type.setter + def plot_type(self, value): + """Set the plot type. Accepts PlotTypes enum or string.""" + if isinstance(value, PlotTypes): + self.state.active_plot = plot_options.get(value) + elif isinstance(value, str): + # Validate the string is a valid plot option + if value in plot_options.values(): + self.state.active_plot = value + else: + valid_options = list(plot_options.values()) + msg = f"Invalid plot type: {value}. Valid options: {valid_options}" + raise ValueError(msg) + else: + type_name = type(value).__name__ + msg = f"plot_type must be PlotTypes enum or string, not {type_name}" + raise TypeError(msg) + # Update the plot after changing type + if self.plotting: + # First update the plot config + self.plotting.expose_plot_specific_config() + # If the plot type doesn't require manual update, generate the plot + if not self.state.show_update_button: + self.plotting.update_plot() + else: + # For plots that require update button, still generate the plot + # to provide immediate feedback when changed programmatically + figure = self.plotting.generate_plot() + self.ctrl.figure_update(figure) + # Update the view + self.ctrl.view_update() + + @property + def group_by(self): + """Get the current temporal grouping option.""" + current_value = self.state.group_by + # Find the GroupBy enum that matches the current value + for group_enum, group_str in group_options.items(): + if group_str == current_value: + return group_enum + return GroupBy.YEAR # Default + + @group_by.setter + def group_by(self, value): + """Set the temporal grouping. Accepts GroupBy enum or string.""" + if isinstance(value, GroupBy): + self.state.group_by = group_options.get(value) + elif isinstance(value, str): + # Validate the string is a valid group option + if value in group_options.values(): + self.state.group_by = value + else: + valid_options = list(group_options.values()) + msg = ( + f"Invalid group by option: {value}. Valid options: {valid_options}" + ) + raise ValueError(msg) + else: + type_name = type(value).__name__ + msg = f"group_by must be GroupBy enum or string, not {type_name}" + raise TypeError(msg) + # Check if this property is relevant for the current plot type + self._check_property_relevance("group_by") + # Trigger plot update for properties that affect the plot + self._trigger_plot_update() + # Update the view + self.ctrl.view_update() + + @property + def zonal_axis(self): + """Get the current zonal axis (Longitude or Latitude).""" + return self.state.zonal_axis + + @zonal_axis.setter + def zonal_axis(self, value): + """Set the zonal axis. Accepts 'Longitude', 'Latitude', 'X', or 'Y'.""" + if value in zonal_axes: + self.state.zonal_axis = value + elif value in zonal_axes.values(): + # Convert X/Y to Longitude/Latitude + for axis_name, axis_value in zonal_axes.items(): + if axis_value == value: + self.state.zonal_axis = axis_name + break + else: + valid_keys = list(zonal_axes.keys()) + valid_values = list(zonal_axes.values()) + msg = f"Invalid zonal axis: {value}. Valid options: {valid_keys} or {valid_values}" + raise ValueError(msg) + # Check if this property is relevant for the current plot type + self._check_property_relevance("zonal_axis") + # Trigger plot update for properties that affect the plot + self._trigger_plot_update() + # Update the view + self.ctrl.view_update() + + @property + def temporal_slice(self): + """Get the current temporal slice index.""" + return self.state.temporal_slice + + @temporal_slice.setter + def temporal_slice(self, value): + """Set the temporal slice index.""" + if not isinstance(value, int): + type_name = type(value).__name__ + msg = f"temporal_slice must be an integer, not {type_name}" + raise TypeError(msg) + if value < 0: + raise ValueError("temporal_slice must be non-negative") + max_slices = self.state.time_groups + if value > max_slices: + msg = f"temporal_slice {value} exceeds maximum {max_slices}" + raise ValueError(msg) + self.state.temporal_slice = value + # Check if this property is relevant for the current plot type + self._check_property_relevance("temporal_slice") + # Trigger plot update for properties that affect the plot + self._trigger_plot_update() + # Update the view + self.ctrl.view_update() + + @property + def analysis_variable(self): + """Get the currently selected variable for analysis.""" + return self.state.color_by + + @analysis_variable.setter + def analysis_variable(self, value): + """Set the variable to use for analysis.""" + # Validate that the variable exists in the dataset + if self.source and self.source.input: + available_vars = list(self.source.input.data_vars) + if value not in available_vars and value is not None: + msg = ( + f"Invalid variable: {value}. Available variables: {available_vars}" + ) + raise ValueError(msg) + self.state.color_by = value + # analysis_variable is used by all plot types, no need to check relevance + # Trigger plot update when variable changes + self._trigger_plot_update() + # Update the view + self.ctrl.view_update() + + @property + def figure_height(self): + """Get the figure height percentage.""" + return self.state.figure_height + + @figure_height.setter + def figure_height(self, value): + """Set the figure height percentage (0-100).""" + if not isinstance(value, (int, float)): + type_name = type(value).__name__ + msg = f"figure_height must be a number, not {type_name}" + raise TypeError(msg) + if not 0 <= value <= 100: + raise ValueError("figure_height must be between 0 and 100") + self.state.figure_height = value + # figure_height is a UI property, not plot-specific + # Update the view + self.ctrl.view_update() + + def get_current_plot(self): + """ + Get the current plot as a Plotly figure object. + + Returns: + plotly.graph_objects.Figure: The current plot figure, or None if no plot is available. + """ + if not self.plotting: + return None + + # Generate and return the current plot + return self.plotting.generate_plot() + + def _check_property_relevance(self, property_name): + """ + Check if a property is relevant for the current plot type and warn if not. + + Args: + property_name: Name of the property being set + """ + current_plot_type = self.plot_type + relevant_properties = self.PLOT_PROPERTY_RELEVANCE.get(current_plot_type, set()) + + if property_name not in relevant_properties: + plot_name = plot_options.get(current_plot_type, str(current_plot_type)) + warnings.warn( + f"Property '{property_name}' is not used by plot type '{plot_name}'. " + f"Relevant properties for this plot: {', '.join(sorted(relevant_properties)) if relevant_properties else 'none'}", + UserWarning, + stacklevel=3, + ) + + def _trigger_plot_update(self): + """ + Helper method to trigger plot updates based on the current plot type. + For plots that auto-update, calls update_plot(). + For plots that require manual update, generates and displays the plot immediately. + """ + if self.plotting: + # Check if this plot type auto-updates or requires button click + if not self.state.show_update_button: + # Auto-updating plot type (like ZONAL) + self.plotting.update_plot() + else: + # Plot types that normally require button click + # We still generate the plot for immediate feedback + figure = self.plotting.generate_plot() + self.ctrl.figure_update(figure) + # ------------------------------------------------------------------------- # UI # ------------------------------------------------------------------------- @@ -121,6 +365,14 @@ def ctrl(self): def _build_ui(self, **kwargs): self.state.trame__title = "Analytics Explorer" + # Initialize default state values for properties + self.state.setdefault("active_plot", plot_options.get(PlotTypes.ZONAL)) + self.state.setdefault("group_by", group_options.get(GroupBy.YEAR)) + self.state.setdefault("zonal_axis", next(iter(zonal_axes.keys()))) + self.state.setdefault("temporal_slice", 0) + self.state.setdefault("time_groups", 0) + self.state.setdefault("figure_height", 50) + with VAppLayout(self.server, fill_height=True) as layout: self.ui = layout # Save dialog @@ -221,7 +473,7 @@ def _build_ui(self, **kwargs): # ----------------------------------------------------- @change("color_by") - def _on_color_by_change_on(self, **kwargs): + def _on_color_by_change(self, **kwargs): super()._on_color_properties_change(**kwargs) self.plotting.update_plot() @@ -238,10 +490,8 @@ def _on_scale_change(self, scale_x, scale_y, scale_z, **_): if self.actor.visibility: self.renderer.ResetCamera() - if self.local_rendering: self.ctrl.view_update(push_camera=True) - self.ctrl.view_reset_camera() def update_rendering(self, reset_camera=False): diff --git a/src/pan3d/ui/analytics.py b/src/pan3d/ui/analytics.py index da6f0cc..a94743d 100644 --- a/src/pan3d/ui/analytics.py +++ b/src/pan3d/ui/analytics.py @@ -1,4 +1,3 @@ -import hashlib import sys from trame.decorators import TrameApp, change @@ -31,10 +30,10 @@ class PlotTypes(Enum): - ZONAL = (0,) - ZONALTIME = (1,) - GLOBAL = (2,) - TEMPORAL = (3,) + ZONAL = 0 + ZONALTIME = 1 + GLOBAL = 2 + TEMPORAL = 3 plot_options = { @@ -69,12 +68,6 @@ class GroupBy(Enum): } -class CacheEntry: - def __init__(self, data, figure): - self.data = data - self.figure = figure - - @TrameApp() class Plotting(v3.VCard): def __init__( @@ -88,9 +81,6 @@ def __init__( **kwargs, ) self.source = source - self.spatial_cache = {} - self.temporal_cache = {} - self.zonal_cache = {} # State variables controlling the UI for various types of plots self.state.figure_height = 50 @@ -195,9 +185,6 @@ def expose_plot_specific_config(self): # Update state in a single call state.update(config) - def get_key(self, selection): - return hashlib.sha256(repr(selection).encode()).hexdigest() - def get_selection_criteria(self, full_temporal=False): """ Get the xarray slicing criteria based on current selection of data @@ -255,31 +242,28 @@ def apply_temporal_average(self, active_var): ) def generate_plot(self): - state = self.state active_var = self.state.color_by - (x, y, _, t) = (self.source.x, self.source.y, self.source.z, self.source.t) if active_var is None: return None - plot_type = state.active_plot - zonal_axis = state.zonal_axis - axis = x if zonal_axes.get(zonal_axis) == "X" else y - if plot_type == plot_options.get(PlotTypes.ZONAL): - return self.zonal_average(active_var, axis) - if plot_type == plot_options.get(PlotTypes.ZONALTIME): - return self.zonal_with_time(active_var, axis, t) - if plot_type == plot_options.get(PlotTypes.GLOBAL): - return self.global_full_temporal(active_var, t) - if plot_type == plot_options.get(PlotTypes.TEMPORAL): - return self.temporal_average(active_var, x, y, t) - return None - - def zonal_average(self, active_var, axis): + plot_type = self.state.active_plot + plot_map = { + plot_options.get(PlotTypes.ZONAL): self.zonal_average, + plot_options.get(PlotTypes.ZONALTIME): self.zonal_with_time, + plot_options.get(PlotTypes.GLOBAL): self.global_full_temporal, + plot_options.get(PlotTypes.TEMPORAL): self.temporal_average, + } + + plot_func = plot_map.get(plot_type) + return plot_func(active_var) if plot_func else None + + def zonal_average(self, active_var): """ Get a plotly figure for the zonal average for current sptio-temporal selection. Average is calculated over a certain specified spatial dimension (Longitude or Latitude). """ - data = self.apply_spatial_average(axis=zonal_axes.get(self.state.zonal_axis)) + axis = zonal_axes.get(self.state.zonal_axis) + data = self.apply_spatial_average(axis=axis) to_plot = data[active_var] dim = to_plot.dims[0] plot = px.line(x=to_plot[dim].to_numpy(), y=to_plot.to_numpy()) @@ -287,23 +271,24 @@ def zonal_average(self, active_var, axis): var_long_name = data[active_var].attrs.get("long_name", active_var) var_units = data[active_var].attrs.get("units", "") x_axis_name = data[dim].attrs.get("long_name", dim) + axis_name = self.source.x if axis == "X" else self.source.y plot.update_layout( - title=f'Zonal Average for {active_var} "{var_long_name}" over {axis} (unit: {var_units})', + title=f'Zonal Average for {active_var} "{var_long_name}" over {axis_name} (unit: {var_units})', xaxis_title=x_axis_name, yaxis_title=var_long_name, ) return plot - def zonal_with_time(self, active_var, axis, t): + def zonal_with_time(self, active_var): """ Get a plotly figure for the zonal average along with current and full temporal selection. Average is calculated over a certain specified spatial dimension (Longitude or Latitude). """ - data_t = self.apply_spatial_average(axis=zonal_axes.get(self.state.zonal_axis)) - data = self.apply_spatial_average_full_temporal( - axis=zonal_axes.get(self.state.zonal_axis) - ) + axis = zonal_axes.get(self.state.zonal_axis) + data_t = self.apply_spatial_average(axis=axis) + data = self.apply_spatial_average_full_temporal(axis=axis) + t = self.source.t var_long_name = data[active_var].attrs.get("long_name", active_var) var_units = data[active_var].attrs.get("units", "") figure = make_subplots( @@ -333,12 +318,13 @@ def zonal_with_time(self, active_var, axis, t): return figure - def global_full_temporal(self, active_var, t): + def global_full_temporal(self, active_var): """ Get a plotly figure for the global average for all data with full temporal resolution. Data from spatial dimension is averaged yielding a single quantity with tempoal dimension. """ data = self.apply_spatial_average_full_temporal(axis=None) + t = self.source.t time = self.get_time_labels(data[t]) plot = px.line(x=time, y=data[active_var]) # plot.update_layout(title_text=f"Global Average for {active_var}") @@ -353,21 +339,22 @@ def global_full_temporal(self, active_var, t): ) return plot - def temporal_average(self, active_var, x, y, t): + def temporal_average(self, active_var): """ Get a time based average of data, data in temporal domain in averaged keeping spatial dimensions same """ + x, y, t = self.source.x, self.source.y, self.source.t data = self.apply_temporal_average(active_var) self.state.time_groups = len(data[t]) - slice = self.state.temporal_slice - plot = px.imshow(data[active_var][slice]) + slice_idx = self.state.temporal_slice + plot = px.imshow(data[active_var][slice_idx]) plot.update_yaxes(autorange="reversed") var_long_name = data[active_var].attrs.get("long_name", active_var) var_units = data[active_var].attrs.get("units", "") x_axis_name = data[x].attrs.get("long_name", x) y_axis_name = data[y].attrs.get("long_name", y) plot.update_layout( - title=f'Temporal average for {active_var} "{var_long_name}" (unit: {var_units}) at time {slice}/{len(data[t])}', + title=f'Temporal average for {active_var} "{var_long_name}" (unit: {var_units}) at time {slice_idx}/{len(data[t])}', xaxis_title=x_axis_name, yaxis_title=y_axis_name, coloraxis_colorbar={"orientation": "h"}, @@ -384,12 +371,11 @@ def get_time_labels(self, time_array): return np.vectorize(lambda dt: f"{dt.isoformat()}")(time_array) @change("temporal_slice") - def on_change_group(self, **kwargs): + def on_change_temporal_slice(self, **kwargs): active_var = self.state.color_by - (x, y, _, t) = (self.source.x, self.source.y, self.source.z, self.source.t) if active_var is None: return - plot = self.temporal_average(active_var, x, y, t) + plot = self.temporal_average(active_var) self.ctrl.figure_update(plot) @change( @@ -406,9 +392,9 @@ def update_plot(self, **kwargs): @change("active_plot") def on_change_active_plot(self, **kwargs): self.expose_plot_specific_config() - plot_type = self.state.active_plot - if plot_type != plot_options.get(PlotTypes.ZONAL): + # ZONAL plots update automatically, others show empty figure until update button clicked + if self.state.active_plot == plot_options.get(PlotTypes.ZONAL): + figure = self.generate_plot() + self.ctrl.figure_update(figure) + else: self.ctrl.figure_update(go.Figure()) - return - figure = self.generate_plot() - self.ctrl.figure_update(figure)