diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index 36fdf2c..45165a2 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -14,7 +14,7 @@ from e3sm_quickview.assets import ASSETS from e3sm_quickview.components import doc, file_browser, css, toolbars, dialogs, drawers from e3sm_quickview.pipeline import EAMVisSource -from e3sm_quickview.utils import compute, js, constants, cli +from e3sm_quickview.utils import compute, cli from e3sm_quickview.view_manager import ViewManager @@ -45,13 +45,11 @@ def __init__(self, server=None): "variables_selected": [], # Control 'Load Variables' button availability "variables_loaded": False, - # Level controls - "midpoint_idx": 0, + # Dynamic type-color mapping (populated when data loads) + "variable_types": [], + # Dimension arrays (will be populated dynamically) "midpoints": [], - "interface_idx": 0, "interfaces": [], - # Time controls - "time_idx": 0, "timestamps": [], # Fields summaries "fields_avgs": {}, @@ -208,18 +206,19 @@ def _build_ui(self, **_): @property def selected_variables(self): - vars_per_type = {n: [] for n in "smi"} + from collections import defaultdict + + vars_per_type = defaultdict(list) for var in self.state.variables_selected: - type = var[0] - name = var[1:] - vars_per_type[type].append(name) + type = self.source.varmeta[var].dimensions + vars_per_type[type].append(var) - return vars_per_type + return dict(vars_per_type) @property def selected_variable_names(self): # Remove var type (first char) - return [var[1:] for var in self.state.variables_selected] + return [var for var in self.state.variables_selected] # ------------------------------------------------------------------------- # Methods connected to UI @@ -369,30 +368,42 @@ async def data_loading_open(self, simulation, connectivity): self.state.variables_filter = "" self.state.variables_listing = [ *( - {"name": name, "type": "surface", "id": f"s{name}"} - for name in self.source.surface_vars - ), - *( - {"name": name, "type": "interface", "id": f"i{name}"} - for name in self.source.interface_vars - ), - *( - {"name": name, "type": "midpoint", "id": f"m{name}"} - for name in self.source.midpoint_vars + { + "name": var.name, + "type": str(var.dimensions), + "id": f"{var.name}", + } + for _, var in self.source.varmeta.items() ), ] + # Build dynamic type-color mapping + from e3sm_quickview.utils.colors import get_type_color + + dim_types = sorted( + set(str(var.dimensions) for var in self.source.varmeta.values()) + ) + self.state.variable_types = [ + {"name": t, "color": get_type_color(i)} + for i, t in enumerate(dim_types) + ] + # Update Layer/Time values and ui layout n_cols = 0 available_tracks = [] - for name in ["midpoints", "interfaces", "timestamps"]: - values = getattr(self.source, name) - self.state[name] = values - - if len(values) > 1: + for name, dim in self.source.dimmeta.items(): + values = dim.data + # Convert to list for JSON serialization + self.state[name] = ( + values.tolist() + if hasattr(values, "tolist") + else list(values) + if values is not None + else [] + ) + if values is not None and len(values) > 1: n_cols += 1 - available_tracks.append(constants.TRACK_ENTRIES[name]) - + available_tracks.append({"title": name, "value": name}) self.state.toolbar_slider_cols = 12 / n_cols if n_cols else 12 self.state.animation_tracks = available_tracks self.state.animation_track = ( @@ -401,6 +412,20 @@ async def data_loading_open(self, simulation, connectivity): else None ) + from functools import partial + + # Initialize dynamic index variables for each dimension + for track in available_tracks: + dim_name = track["value"] + index_var = f"{dim_name}_idx" + if "time" in index_var: + self.state[index_var] = 50 + else: + self.state[index_var] = 0 + self.state.change(index_var)( + partial(self._on_slicing_change, dim_name, index_var) + ) + @controller.set("file_selection_cancel") def data_loading_hide(self): self.state.active_tools = [ @@ -414,11 +439,10 @@ async def _data_load_variables(self): """Called at 'Load Variables' button click""" vars_to_show = self.selected_variables - self.source.LoadVariables( - vars_to_show["s"], # surfaces - vars_to_show["m"], # midpoints - vars_to_show["i"], # interfaces - ) + # Flatten the list of lists + flattened_vars = [var for var_list in vars_to_show.values() for var in var_list] + + self.source.LoadVariables(flattened_vars) # Trigger source update + compute avg with self.state: @@ -455,30 +479,42 @@ async def _on_projection(self, projection, **_): await asyncio.sleep(0.1) self.view_manager.reset_camera() - @change("active_tools") + @change("active_tools", "animation_tracks") def _on_toolbar_change(self, active_tools, **_): top_padding = 0 for name in active_tools: - top_padding += toolbars.SIZES.get(name, 0) + if name == "select-slice-time": + track_count = len(self.state.animation_tracks or []) + rows_needed = max(1, (track_count + 2) // 3) # 3 sliders per row + top_padding += 70 * rows_needed + else: + top_padding += toolbars.SIZES.get(name, 0) self.state.top_padding = top_padding + def _on_slicing_change(self, var, ind_var, **_): + self.source.UpdateSlicing(var, self.state[ind_var]) + self.source.UpdatePipeline() + + self.view_manager.update_color_range() + self.view_manager.render() + + # Update avg computation + # Get area variable to calculate weighted average + data = self.source.views["atmosphere_data"] + self.state.fields_avgs = compute.extract_avgs( + data, self.selected_variable_names + ) + @change( "variables_loaded", - "time_idx", - "midpoint_idx", - "interface_idx", "crop_longitude", "crop_latitude", "projection", ) - def _on_time_change( + def _on_downstream_change( self, variables_loaded, - time_idx, - timestamps, - midpoint_idx, - interface_idx, crop_longitude, crop_latitude, projection, @@ -487,12 +523,9 @@ def _on_time_change( if not variables_loaded: return - time_value = timestamps[time_idx] if len(timestamps) else 0.0 - self.source.UpdateLev(midpoint_idx, interface_idx) self.source.ApplyClipping(crop_longitude, crop_latitude) self.source.UpdateProjection(projection[0]) - self.source.UpdateTimeStep(time_idx) - self.source.UpdatePipeline(time_value) + self.source.UpdatePipeline() self.view_manager.update_color_range() self.view_manager.render() diff --git a/src/e3sm_quickview/components/drawers.py b/src/e3sm_quickview/components/drawers.py index 1325e64..19fa1a5 100644 --- a/src/e3sm_quickview/components/drawers.py +++ b/src/e3sm_quickview/components/drawers.py @@ -68,20 +68,26 @@ def __init__(self, load_variables=None): with self: with html.Div(style="position:fixed;top:0;width: 500px;"): - with v3.VCardActions(key="variables_selected.length"): - for name, color in [ - ("surfaces", "success"), - ("interfaces", "info"), - ("midpoints", "warning"), - ]: - v3.VChip( - js.var_title(name), - color=color, - v_show=js.var_count(name), - size="small", - closable=True, - click_close=js.var_remove(name), - ) + with v3.VCardActions( + key="variables_selected.length", + classes="flex-wrap", + style="overflow-y: auto; max-height: 100px;", + ): + v3.VChip( + "{{ variables_selected.filter(id => variables_listing.find(v => v.id === id)?.type === vtype.name).length }} {{ vtype.name }}", + v_for="(vtype, idx) in variable_types", + key="idx", + color=("vtype.color",), + v_show=( + "variables_selected.filter(id => variables_listing.find(v => v.id === id)?.type === vtype.name).length", + ), + size="small", + closable=True, + click_close=( + "variables_selected = variables_selected.filter(id => variables_listing.find(v => v.id === id)?.type !== vtype.name)", + ), + classes="ma-1", + ) v3.VSpacer() v3.VBtn( diff --git a/src/e3sm_quickview/components/toolbars.py b/src/e3sm_quickview/components/toolbars.py index 119e4a8..2a3362e 100644 --- a/src/e3sm_quickview/components/toolbars.py +++ b/src/e3sm_quickview/components/toolbars.py @@ -2,9 +2,10 @@ from trame.app import asynchronous from trame.decorators import change -from trame.widgets import html, vuetify3 as v3 +from trame.widgets import html, vuetify3 as v3, client -from e3sm_quickview.utils import js, constants + +from e3sm_quickview.utils import js DENSITY = { "adjust-layout": "compact", @@ -16,7 +17,7 @@ SIZES = { "adjust-layout": 49, "adjust-databounds": 65, - "select-slice-time": 65, + "select-slice-time": 70, "animation-controls": 49, } @@ -232,82 +233,60 @@ def __init__(self): ) -class DataSelection(v3.VToolbar): +class DataSelection(html.Div): def __init__(self): - super().__init__(**to_kwargs("select-slice-time")) + style = to_kwargs("select-slice-time") + # Use style instead of d-flex class to avoid !important override of v-show + # Add background color to match VToolbar appearance + style["style"] = ( + "display: flex; align-items: center; background: rgb(var(--v-theme-surface));" + ) + super().__init__(**style) with self: - v3.VIcon("mdi-tune-variant", classes="ml-3 opacity-50") - with v3.VRow(classes="ma-0 pr-2 align-center", dense=True): - # midpoint layer - with v3.VCol( - cols=("toolbar_slider_cols", 4), - v_show="midpoints.length > 1", - ): - with v3.VRow(classes="mx-2 my-0"): - v3.VLabel( - "Layer Midpoints", - classes="text-subtitle-2", - ) - v3.VSpacer() - v3.VLabel( - "{{ parseFloat(midpoints[midpoint_idx] || 0).toFixed(2) }} hPa (k={{ midpoint_idx }})", - classes="text-body-2", - ) - v3.VSlider( - v_model=("midpoint_idx", 0), - min=0, - max=("Math.max(0, midpoints.length - 1)",), - step=1, - density="compact", - hide_details=True, - ) + v3.VIcon("mdi-tune-variant", classes="ml-3 mr-2 opacity-50") - # interface layer + with v3.VRow(classes="ma-0 pr-2 flex-wrap flex-grow-1", dense=True): + # Debug: Show animation_tracks array + # html.Div("Animation Tracks: {{ JSON.stringify(animation_tracks) }}", classes="col-12") + # Each track gets a column (3 per row) with v3.VCol( - cols=("toolbar_slider_cols", 4), - v_show="interfaces.length > 1", + cols=4, + v_for="(track, idx) in animation_tracks", + key="idx", + classes="pa-2", ): - with v3.VRow(classes="mx-2 my-0"): - v3.VLabel( - "Layer Interfaces", - classes="text-subtitle-2", - ) - v3.VSpacer() - v3.VLabel( - "{{ parseFloat(interfaces[interface_idx] || 0).toFixed(2) }} hPa (k={{interface_idx}})", - classes="text-body-2", - ) - v3.VSlider( - v_model=("interface_idx", 0), - min=0, - max=("Math.max(0, interfaces.length - 1)",), - step=1, - density="compact", - hide_details=True, - ) + with client.Getter(name=("track.value",), value_name="t_values"): + with client.Getter( + name=("track.value + '_idx'",), value_name="t_idx" + ): + with v3.VRow(classes="ma-0 align-center", dense=True): + v3.VLabel( + "{{track.title}}", + classes="text-subtitle-2", + ) + v3.VSpacer() + v3.VLabel( + "{{ parseFloat(t_values[t_idx]).toFixed(2) }} hPa (k={{ t_idx }})", + classes="text-body-2", + ) + v3.VSlider( + model_value=("t_idx",), + update_modelValue=( + self.on_update_slider, + "[track.value, $event]", + ), + min=0, + # max=100,#("get(track.value).length - 1",), + max=("t_values.length - 1",), + step=1, + density="compact", + hide_details=True, + ) - # time - with v3.VCol( - cols=("toolbar_slider_cols", 4), - v_show="timestamps.length > 1", - ): - self.state.setdefault("time_value", 80.50) - with v3.VRow(classes="mx-2 my-0"): - v3.VLabel("Time", classes="text-subtitle-2") - v3.VSpacer() - v3.VLabel( - "{{ parseFloat(timestamps[time_idx]).toFixed(2) }} (t={{time_idx}})", - classes="text-body-2", - ) - v3.VSlider( - v_model=("time_idx", 0), - min=0, - max=("Math.max(0, timestamps.length - 1)",), - step=1, - density="compact", - hide_details=True, - ) + def on_update_slider(self, dimension, index, *_, **__): + with self.state: + self.state[f"{dimension}_idx"] = index class Animation(v3.VToolbar): @@ -382,7 +361,7 @@ def _on_animation_track_change(self, animation_track, **_): @change("animation_step") def _on_animation_step(self, animation_track, animation_step, **_): if animation_track: - self.state[constants.TRACK_STEPS[animation_track]] = animation_step + self.state[f"{animation_track}_idx"] = animation_step @change("animation_play") def _on_animation_play(self, animation_play, **_): diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index d4cbbd8..f7cb479 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -1,5 +1,5 @@ import fnmatch -import numpy as np +import json import os @@ -11,10 +11,10 @@ LegacyVTKReader, ) -from paraview import servermanager as sm -from paraview.vtk.numpy_interface import dataset_adapter as dsa from vtkmodules.vtkCommonCore import vtkLogger +from collections import defaultdict + # Define a VTK error observer class ErrorObserver: @@ -39,17 +39,10 @@ def __init__(self): self.data_file = None self.conn_file = None - self.midpoints = [] - self.interfaces = [] - # List of all available variables - self.surface_vars = [] - self.midpoint_vars = [] - self.interface_vars = [] - # List of selected variables - self.surface_vars_sel = [] - self.interface_vars_sel = [] - self.midpoint_vars_sel = [] + self.varmeta = None + self.dimmeta = None + self.slicing = defaultdict(int) self.data = None self.globe = None @@ -77,34 +70,6 @@ def __init__(self): except Exception as e: print("Error loading plugin :", e) - def UpdateLev(self, lev, ilev): - if not self.valid: - return - - if self.data is None: - return - - # Handle NaN, None, or invalid values - try: - if lev is None or (isinstance(lev, float) and np.isnan(lev)): - lev_idx = 0 - else: - lev_idx = int(lev) - except (ValueError, TypeError): - lev_idx = 0 - - try: - if ilev is None or (isinstance(ilev, float) and np.isnan(ilev)): - ilev_idx = 0 - else: - ilev_idx = int(ilev) - except (ValueError, TypeError): - ilev_idx = 0 - - if self.data.MiddleLayer != lev_idx or self.data.InterfaceLayer != ilev_idx: - self.data.MiddleLayer = lev_idx - self.data.InterfaceLayer = ilev_idx - def ApplyClipping(self, cliplong, cliplat): if not self.valid: return @@ -175,7 +140,16 @@ def UpdatePipeline(self, time=0.0): self.views["continents"] = OutputPort(cont_proj, 0) self.views["grid_lines"] = OutputPort(grid_proj, 0) - def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=False): + def UpdateSlicing(self, dimension, slice): + if self.slicing.get(dimension) == slice: + return + else: + self.slicing[dimension] = slice + if self.data is not None: + x = json.dumps(self.slicing) + self.data.Slicing = x + + def Update(self, data_file, conn_file, force_reload=False): # Check if we need to reload if ( not force_reload @@ -188,17 +162,21 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal self.conn_file = conn_file if self.data is None: - data = EAMSliceDataReader( + data = EAMSliceDataReader( # noqa: F821 registrationName="AtmosReader", ConnectivityFile=conn_file, DataFile=data_file, ) - data.MiddleLayer = midpoint - data.InterfaceLayer = interface self.data = data vtk_obj = data.GetClientSideObject() vtk_obj.AddObserver("ErrorEvent", self.observer) vtk_obj.GetExecutive().AddObserver("ErrorEvent", self.observer) + self.varmeta = vtk_obj.GetVariables() + self.dimmeta = vtk_obj.GetDimensions() + + for dim in self.dimmeta.keys(): + self.slicing[dim] = 0 + self.observer.clear() else: self.data.DataFile = data_file @@ -215,19 +193,6 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal "and are compatible" ) - data_wrapped = dsa.WrapDataObject(sm.Fetch(self.data)) - self.midpoints = data_wrapped.FieldData["lev"].tolist() - self.interfaces = data_wrapped.FieldData["ilev"].tolist() - - self.surface_vars = list( - np.asarray(self.data.GetProperty("SurfaceVariablesInfo"))[::2] - ) - self.midpoint_vars = list( - np.asarray(self.data.GetProperty("MidpointVariablesInfo"))[::2] - ) - self.interface_vars = list( - np.asarray(self.data.GetProperty("InterfaceVariablesInfo"))[::2] - ) # Ensure TimestepValues is always a list timestep_values = self.data.TimestepValues if isinstance(timestep_values, (list, tuple)): @@ -244,7 +209,7 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal ) # Step 1: Extract and transform atmospheric data - atmos_extract = EAMTransformAndExtract( + atmos_extract = EAMTransformAndExtract( # noqa: F821 registrationName="AtmosExtract", Input=self.data ) atmos_extract.LongitudeRange = [-180.0, 180.0] @@ -253,7 +218,7 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal self.extents = atmos_extract.GetDataInformation().GetBounds() # Step 2: Apply map projection to atmospheric data - atmos_proj = EAMProject( + atmos_proj = EAMProject( # noqa: F821 registrationName="AtmosProj", Input=OutputPort(atmos_extract, 0) ) atmos_proj.Projection = self.projection @@ -278,14 +243,14 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal self.globe = cont_contour # Step 4: Extract and transform continent data - cont_extract = EAMTransformAndExtract( + cont_extract = EAMTransformAndExtract( # noqa: F821 registrationName="ContExtract", Input=self.globe ) cont_extract.LongitudeRange = [-180.0, 180.0] cont_extract.LatitudeRange = [-90.0, 90.0] # Step 5: Apply map projection to continents - cont_proj = EAMProject( + cont_proj = EAMProject( # noqa: F821 registrationName="ContProj", Input=OutputPort(cont_extract, 0) ) cont_proj.Projection = self.projection @@ -293,11 +258,11 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal cont_proj.UpdatePipeline() # Step 6: Generate lat/lon grid lines - grid_gen = EAMGridLines(registrationName="GridGen") + grid_gen = EAMGridLines(registrationName="GridGen") # noqa: F821 grid_gen.UpdatePipeline() # Step 7: Apply map projection to grid lines - grid_proj = EAMProject( + grid_proj = EAMProject( # noqa: F821 registrationName="GridProj", Input=OutputPort(grid_gen, 0) ) grid_proj.Projection = self.projection @@ -319,15 +284,10 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal return self.valid - def LoadVariables(self, surf, mid, intf): + def LoadVariables(self, vars): if not self.valid: return - self.data.SurfaceVariables = surf - self.data.MidpointVariables = mid - self.data.InterfaceVariables = intf - self.vars["surface"] = surf - self.vars["midpoint"] = mid - self.vars["interface"] = intf + self.data.Variables = vars if __name__ == "__main__": diff --git a/src/e3sm_quickview/plugins/eam_reader.py b/src/e3sm_quickview/plugins/eam_reader.py index 2e1173e..ab7fa60 100644 --- a/src/e3sm_quickview/plugins/eam_reader.py +++ b/src/e3sm_quickview/plugins/eam_reader.py @@ -1,10 +1,10 @@ -from paraview.util.vtkAlgorithm import * +from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase +from paraview.util.vtkAlgorithm import smproxy, smproperty from vtkmodules.numpy_interface import dataset_adapter as dsa from vtkmodules.vtkCommonCore import vtkPoints, vtkDataArraySelection from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkCellArray from vtkmodules.util import vtkConstants, numpy_support -from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase -from paraview import print_error +from paraview import print_error, print_warning try: import netCDF4 @@ -32,38 +32,83 @@ class EAMConstants: PS0 = float(1e5) -from enum import Enum # noqa: E402 +class DimMeta: + """Simple class to store dimension metadata.""" + + def __init__(self, name, size, data=None): + self.name = name + self.size = size + self.long_name = None + self.units = None + self.data = data # Store the actual dimension coordinate values + + def __getitem__(self, key): + """Dict-like access to attributes.""" + return getattr(self, key, None) + + def __setitem__(self, key, value): + """Dict-like setting of attributes.""" + setattr(self, key, value) + def update_from_variable(self, var_info): + """Update metadata from netCDF variable info - only long_name and units.""" + try: + self.long_name = var_info.getncattr("long_name") + except AttributeError: + pass -class VarType(Enum): - _1D = 1 - _2D = 2 - _3Dm = 3 - _3Di = 4 + try: + self.units = var_info.getncattr("units") + except AttributeError: + pass + + def __repr__(self): + return f"DimMeta(name='{self.name}', size={self.size}, long_name='{self.long_name}')" class VarMeta: + """Simple class to store variable metadata.""" + def __init__(self, name, info, horizontal_dim=None): self.name = name - self.type = None - self.transpose = False + self.dimensions = info.dimensions # Store dimensions for slicing self.fillval = np.nan + self.long_name = None + + # Extract metadata from info + self._extract_metadata(info) + + def _extract_metadata(self, info): + """Helper to extract metadata attributes from netCDF variable.""" + # Try to get fill value from either _FillValue or missing_value + for fillattr in ["_FillValue", "missing_value"]: + value = self._get_attr(info, fillattr) + if value is not None: + self.fillval = value + break + + # Get long_name if available + long_name = self._get_attr(info, "long_name") + if long_name is not None: + self.long_name = long_name + + def _get_attr(self, info, attr_name): + """Safely get an attribute from netCDF variable info.""" + try: + return info.getncattr(attr_name) + except (AttributeError, KeyError): + return None - dims = info.dimensions + def __getitem__(self, key): + """Dict-like access to attributes.""" + return getattr(self, key, None) - if len(dims) == 1: - self.type = VarType._1D - elif len(dims) == 2: - self.type = VarType._2D - elif len(dims) == 3: - if "lev" in dims: - self.type = VarType._3Dm - elif "ilev" in dims: - self.type = VarType._3Di + def __setitem__(self, key, value): + """Dict-like setting of attributes.""" + setattr(self, key, value) - # Use dynamic horizontal dimension - if horizontal_dim and len(dims) > 1 and horizontal_dim in dims[1]: - self.transpose = True + def __repr__(self): + return f"VarMeta(name='{self.name}', dimensions={self.dimensions})" def compare(data, arrays, dim): @@ -134,9 +179,6 @@ def _markmodified(*args, **kwars): return _markmodified -import traceback # noqa: E402 - - @smproxy.reader( name="EAMSliceSource", label="EAM Slice Data Reader", @@ -168,20 +210,14 @@ def _markmodified(*args, **kwars): ) @smproperty.xml( """ - - - """ -) -@smproperty.xml( - """ - - + + JSON representing dimension slices (e.g. {"lev": 0, "ilev": 1}) + """ ) class EAMSliceSource(VTKPythonAlgorithmBase): @@ -200,33 +236,18 @@ def __init__(self): # Variables for dimension sliders self._time = 0 - self._lev = 0 - self._ilev = 0 - # Arrays to store field names in netCDF file - self._info_vars = [] # 1D info variables - self._surface_vars = [] # 2D surface variables - self._interface_vars = [] # 3D interface layer variables - self._midpoint_vars = [] # 3D midpoint layer variables + # Dictionaries to store metadata objects + self._variables = {} # Will store VarMeta objects by name + self._dimensions = {} # Will store DimMeta objects by name self._timeSteps = [] + # Dictionary to store dimension slices + self._slices = {} # vtkDataArraySelection to allow users choice for fields # to fetch from the netCDF data set - self._info_selection = vtkDataArraySelection() - self._surface_selection = vtkDataArraySelection() - self._interface_selection = vtkDataArraySelection() - self._midpoint_selection = vtkDataArraySelection() - # Cache for non temporal variables - # Store { names : data } - self._info_vars_cache = {} + self._variable_selection = vtkDataArraySelection() # Add observers for the selection arrays - self._info_selection.AddObserver("ModifiedEvent", createModifiedCallback(self)) - self._surface_selection.AddObserver( - "ModifiedEvent", createModifiedCallback(self) - ) - self._interface_selection.AddObserver( - "ModifiedEvent", createModifiedCallback(self) - ) - self._midpoint_selection.AddObserver( + self._variable_selection.AddObserver( "ModifiedEvent", createModifiedCallback(self) ) # Flag for area var to calculate averages @@ -246,12 +267,12 @@ def __init__(self): self._cached_ncells2D = None # Special variable caching - self._cached_lev = None - self._cached_ilev = None + # self._cached_lev = None + # self._cached_ilev = None self._cached_area = None - + # Dynamic dimension detection - self._horizontal_dim = None # From connectivity file + self._horizontal_dim = None self._data_horizontal_dim = None # Matched in data file def __del__(self): @@ -302,39 +323,37 @@ def _get_var_dataset(self): # Method to clear all the variable names def _clear(self): - self._info_vars.clear() - self._surface_vars.clear() - self._interface_vars.clear() - self._midpoint_vars.clear() + self._variables.clear() + # Clear special variable cache when metadata changes - self._cached_lev = None - self._cached_ilev = None self._cached_area = None + # Clear dimension detection - self._horizontal_dim = None self._data_horizontal_dim = None def _identify_horizontal_dimension(self, meshdata, vardata): """Identify horizontal dimension from connectivity and match with data file.""" if self._horizontal_dim and self._data_horizontal_dim: return # Already identified - + # Get first dimension from connectivity file conn_dims = list(meshdata.dimensions.keys()) if not conn_dims: print_error("No dimensions found in connectivity file") return - + self._horizontal_dim = conn_dims[0] conn_size = meshdata.dimensions[self._horizontal_dim].size - + # Match dimension in data file by size for dim_name, dim_obj in vardata.dimensions.items(): if dim_obj.size == conn_size: self._data_horizontal_dim = dim_name return - - print_error(f"Could not match horizontal dimension size {conn_size} in data file") + + print_error( + f"Could not match horizontal dimension size {conn_size} in data file" + ) def _clear_geometry_cache(self): """Clear cached geometry data.""" @@ -344,6 +363,11 @@ def _clear_geometry_cache(self): self._cached_offsets = None self._cached_ncells2D = None + ''' + Disable the derivation of lev/ilev for the new approach -- the new approach + relies on the identified dimensions from the data file and connectivity files. + We could reintroduce this later if required. + def _get_cached_lev(self, vardata): """Get cached lev array or compute and cache it.""" if self._cached_lev is None: @@ -359,6 +383,7 @@ def _get_cached_ilev(self, vardata): vardata, EAMConstants.ILEV, EAMConstants.HYAI, EAMConstants.HYBI ) return self._cached_ilev + ''' def _get_cached_area(self, vardata): """Get cached area array or load and cache it.""" @@ -371,27 +396,31 @@ def _get_cached_area(self, vardata): self._cached_area[mask] = np.nan return self._cached_area - def _load_2d_variable(self, vardata, varmeta, timeInd): - """Load 2D variable data with optimized operations.""" - # Get data without unnecessary copy - data = vardata[varmeta.name][:].data[timeInd].flatten() - data = np.where(data == varmeta.fillval, np.nan, data) - return data - - def _load_3d_slice(self, vardata, varmeta, timeInd, start_idx, end_idx): - """Load a slice of 3D variable data with optimized operations.""" - # Load full 3D data for time step - if not varmeta.transpose: - data = vardata[varmeta.name][:].data[timeInd].flatten()[start_idx:end_idx] - else: - data = ( - vardata[varmeta.name][:] - .data[timeInd] - .transpose() - .flatten()[start_idx:end_idx] - ) - data = np.where(data == varmeta.fillval, np.nan, data) - return data + def _load_variable(self, vardata, varmeta, timeInd): + """Load variable data with dimension-based slicing.""" + try: + # Build slice tuple based on variable's dimensions and user-selected slices + slice_tuple = [] + for dim in varmeta.dimensions: + if dim == self._data_horizontal_dim: + continue + # elif dim == "time": + # # Use timeInd for time dimension + # slice_tuple.append(timeInd) + elif hasattr(self, "_slices") and dim in self._slices: + # Use user-specified slice for this dimension + slice_tuple.append(self._slices[dim]) + else: + # Use all data for unspecified dimensions + slice_tuple.append(slice(None)) + # Get data with proper slicing + data = vardata[varmeta.name][tuple(slice_tuple)].data.flatten() + data = np.where(data == varmeta.fillval, np.nan, data) + return data + except Exception as e: + print_error(f"Error loading variable {varmeta.name}: {e}") + # Return empty array on error + return np.array([]) def _get_enabled_arrays(self, var_list, selection_obj): """Get list of enabled variable names from selection object.""" @@ -409,12 +438,12 @@ def _build_geometry(self, meshdata): dims = meshdata.dimensions mvars = np.array(list(meshdata.variables.keys())) - + # Use the identified horizontal dimension if not self._horizontal_dim: print_error("Horizontal dimension not identified in connectivity file") return - + ncells2D = dims[self._horizontal_dim].size self._cached_ncells2D = ncells2D @@ -461,72 +490,73 @@ def _build_geometry(self, meshdata): def _populate_variable_metadata(self): if self._DataFileName is None or self._ConnFileName is None: return - + meshdata = self._get_mesh_dataset() vardata = self._get_var_dataset() - + # Identify horizontal dimensions first self._identify_horizontal_dimension(meshdata, vardata) - + if not self._data_horizontal_dim: print_error("Could not detect horizontal dimension in data file") return - + # Clear existing selection arrays BEFORE adding new ones - self._surface_selection.RemoveAllArrays() - self._midpoint_selection.RemoveAllArrays() - self._interface_selection.RemoveAllArrays() - - # Define dimension sets dynamically based on detected dimension - dims1 = set([self._data_horizontal_dim]) - dims2 = set(['time', self._data_horizontal_dim]) - dims3m = set(['time', 'lev', self._data_horizontal_dim]) - dims3i = set(['time', 'ilev', self._data_horizontal_dim]) + self._variable_selection.RemoveAllArrays() + # First pass: collect dimensions used by valid variables + all_dimensions = set() for name, info in vardata.variables.items(): dims = set(info.dimensions) - if not (dims == dims1 or dims == dims2 or dims == dims3m or dims == dims3i): + if self._data_horizontal_dim not in dims: continue varmeta = VarMeta(name, info, self._data_horizontal_dim) - if varmeta.type == VarType._1D: - self._info_vars.append(varmeta) - if "area" in name: - self._areavar = varmeta - elif varmeta.type == VarType._2D: - self._surface_vars.append(varmeta) - self._surface_selection.AddArray(name) - elif varmeta.type == VarType._3Dm: - self._midpoint_vars.append(varmeta) - self._midpoint_selection.AddArray(name) - elif varmeta.type == VarType._3Di: - self._interface_vars.append(varmeta) - self._interface_selection.AddArray(name) - try: - fillval = info.getncattr("_FillValue") - varmeta.fillval = fillval - except Exception: - try: - fillval = info.getncattr("missing_value") - varmeta.fillval = fillval - except Exception: - pass - self._surface_selection.DisableAllArrays() - self._interface_selection.DisableAllArrays() - self._midpoint_selection.DisableAllArrays() + if len(dims) == 1 and "area" in name: + self._areavar = varmeta + if len(dims) > 1: + all_dimensions.update(dims) + self._variables[name] = varmeta + self._variable_selection.AddArray(name) + + # Remove the horizontal dimension from sliceable dimensions + all_dimensions.discard(self._data_horizontal_dim) + + # Second pass: only populate _dimensions for dimensions that are: + # 1. Used by at least one valid variable + # 2. Have arity > 1 + self._dimensions.clear() + for dim_name in all_dimensions: + if dim_name in vardata.dimensions: + dim_obj = vardata.dimensions[dim_name] + if dim_obj.size > 1: + dim_meta = DimMeta(dim_name, dim_obj.size) + if dim_name in vardata.variables: + dim_var = vardata.variables[dim_name] + try: + dim_meta.data = vardata[dim_name][:].data + except Exception: + pass + dim_meta.update_from_variable(dim_var) + self._dimensions[dim_name] = dim_meta + + # Initialize slices for relevant dimensions + for dim in self._dimensions: + if dim not in self._slices: + self._slices[dim] = 0 + + self._variable_selection.DisableAllArrays() # Clear old timestamps before adding new ones self._timeSteps.clear() - timesteps = vardata["time"][:].data.flatten() - self._timeSteps.extend(timesteps) + if "time" in vardata.variables: + timesteps = vardata["time"][:].data.flatten() + self._timeSteps.extend(timesteps) def SetDataFileName(self, fname): if fname is not None and fname != "None": if fname != self._DataFileName: self._DataFileName = fname self._dirty = True - self._surface_update = True - self._midpoint_update = True - self._interface_update = True self._clear() # Close old dataset if filename changed if self._cached_var_filename != fname and self._var_dataset is not None: @@ -542,9 +572,6 @@ def SetConnFileName(self, fname): if fname != self._ConnFileName: self._ConnFileName = fname self._dirty = True - self._surface_update = True - self._midpoint_update = True - self._interface_update = True self._clear() # Clear dimension cache # Close old dataset if filename changed if self._cached_mesh_filename != fname and self._mesh_dataset is not None: @@ -559,23 +586,73 @@ def SetConnFileName(self, fname): self._populate_variable_metadata() self.Modified() - def SetMiddleLayer(self, lev): - if self._lev != lev: - self._lev = lev - self._midpoint_update = True - self.Modified() + def SetSlicing(self, slice_str): + # Parse JSON string containing dimension slices and update self._slices + # Initialize _slices if not already done + if not hasattr(self, "_slices"): + self._slices = {} - def SetInterfaceLayer(self, ilev): - if self._ilev != ilev: - self._ilev = ilev - self._interface_update = True - self.Modified() + # Initialize dimensions if not already done + if not hasattr(self, "_dimensions"): + self._dimensions = {} + + if slice_str and slice_str.strip(): # Check for non-empty string + try: + import json + + slice_dict = json.loads(slice_str) + + # Validate and update slices for provided dimensions + invalid_slices = [] + for dim, slice_val in slice_dict.items(): + # Check if dimension exists + if dim in self._dimensions: + dim_meta = self._dimensions[dim] + dim_size = dim_meta.size + # Validate slice index + if isinstance(slice_val, int): + if slice_val < 0 or slice_val >= dim_size: + # Include dimension long name if available + dim_display = f"{dim}" + if dim_meta.long_name: + dim_display += f" ({dim_meta.long_name})" + invalid_slices.append( + f"{dim_display}={slice_val} (valid range: 0-{dim_size - 1})" + ) + else: + self._slices[dim] = slice_val + else: + print_error( + f"Slice value for '{dim}' must be an integer, got {type(slice_val).__name__}" + ) + else: + # Store the slice anyway for dimensions we haven't seen yet + # (might be populated later) + self._slices[dim] = slice_val + if self._dimensions: # Only warn if we have dimension info + print_warning(f"Dimension '{dim}' not found in data file") + + if invalid_slices: + print_error(f"Invalid slice indices: {', '.join(invalid_slices)}") + else: + self.Modified() + + except (json.JSONDecodeError, ValueError) as e: + print_error(f"Invalid JSON for slicing: {e}") + except Exception as e: + print_error(f"Error setting slices: {e}") def SetCalculateAverages(self, calcavg): if self._avg != calcavg: self._avg = calcavg self.Modified() + def GetVariables(self): + return self._variables + + def GetDimensions(self): + return self._dimensions + @smproperty.doublevector( name="TimestepValues", information_only="1", si_class="vtkSITimeStepsProperty" ) @@ -587,17 +664,9 @@ def GetTimestepValues(self): # load. To expose that in ParaView, simply use the # smproperty.dataarrayselection(). # This method **must** return a `vtkDataArraySelection` instance. - @smproperty.dataarrayselection(name="Surface Variables") + @smproperty.dataarrayselection(name="Variables") def GetSurfaceVariables(self): - return self._surface_selection - - @smproperty.dataarrayselection(name="Midpoint Variables") - def GetMidpointVariables(self): - return self._midpoint_selection - - @smproperty.dataarrayselection(name="Interface Variables") - def GetInterfaceVariables(self): - return self._interface_selection + return self._variable_selection def RequestInformation(self, request, inInfo, outInfo): executive = self.GetExecutive() @@ -659,14 +728,14 @@ def RequestData(self, request, inInfo, outInfo): # Ensure dimensions are identified self._identify_horizontal_dimension(meshdata, vardata) - + if not self._horizontal_dim or not self._data_horizontal_dim: print_error("Could not identify required dimensions from files") return 0 - + # Build geometry if not cached self._build_geometry(meshdata) - + if self._cached_points is None: print_error("Could not build geometry from connectivity file") return 0 @@ -687,89 +756,19 @@ def RequestData(self, request, inInfo, outInfo): self._dirty = False - # Use cached ncells2D - ncells2D = self._cached_ncells2D - # Needed to drop arrays from cached VTK Object to_remove = set() last_num_arrays = output_mesh.CellData.GetNumberOfArrays() for i in range(last_num_arrays): to_remove.add(output_mesh.CellData.GetArrayName(i)) - for varmeta in self._surface_vars: - if self._surface_selection.ArrayIsEnabled(varmeta.name): - if output_mesh.CellData.HasArray(varmeta.name): - to_remove.remove(varmeta.name) - if ( - not output_mesh.CellData.HasArray(varmeta.name) - or self._surface_update - ): - data = self._load_2d_variable(vardata, varmeta, timeInd) - output_mesh.CellData.append(data, varmeta.name) - self._surface_update = False - - try: - lev_field_name = "lev" - has_lev_field = output_mesh.FieldData.HasArray(lev_field_name) - lev = self._get_cached_lev(vardata) - if lev is not None: - if not has_lev_field: - output_mesh.FieldData.append(lev, lev_field_name) - if self._lev >= vardata.dimensions[lev_field_name].size: - print_error( - f"User provided input for middle layer {self._lev} larger than actual data {len(lev) - 1}" - ) - lstart = self._lev * ncells2D - lend = lstart + ncells2D - - for varmeta in self._midpoint_vars: - if self._midpoint_selection.ArrayIsEnabled(varmeta.name): - if output_mesh.CellData.HasArray(varmeta.name): - to_remove.remove(varmeta.name) - if ( - not output_mesh.CellData.HasArray(varmeta.name) - or self._midpoint_update - ): - data = self._load_3d_slice( - vardata, varmeta, timeInd, lstart, lend - ) - output_mesh.CellData.append(data, varmeta.name) - self._midpoint_update = False - except Exception as e: - print_error("Error occurred while processing middle layer variables :", e) - traceback.print_exc() - - try: - ilev_field_name = "ilev" - has_ilev_field = output_mesh.FieldData.HasArray(ilev_field_name) - ilev = self._get_cached_ilev(vardata) - if ilev is not None: - if not has_ilev_field: - output_mesh.FieldData.append(ilev, ilev_field_name) - if self._ilev >= vardata.dimensions[ilev_field_name].size: - print_error( - f"User provided input for middle layer {self._ilev} larger than actual data {len(ilev) - 1}" - ) - ilstart = self._ilev * ncells2D - ilend = ilstart + ncells2D - for varmeta in self._interface_vars: - if self._interface_selection.ArrayIsEnabled(varmeta.name): - if output_mesh.CellData.HasArray(varmeta.name): - to_remove.remove(varmeta.name) - if ( - not output_mesh.CellData.HasArray(varmeta.name) - or self._interface_update - ): - data = self._load_3d_slice( - vardata, varmeta, timeInd, ilstart, ilend - ) - output_mesh.CellData.append(data, varmeta.name) - self._interface_update = False - except Exception as e: - print_error( - "Error occurred while processing interface layer variables :", e - ) - traceback.print_exc() + for name, varmeta in self._variables.items(): + if self._variable_selection.ArrayIsEnabled(name): + if output_mesh.CellData.HasArray(name): + to_remove.remove(name) + if not output_mesh.CellData.HasArray(name) or self._surface_update: + data = self._load_variable(vardata, varmeta, timeInd) + output_mesh.CellData.append(data, name) area_var_name = "area" if self._areavar and not output_mesh.CellData.HasArray(area_var_name): diff --git a/src/e3sm_quickview/utils/colors.py b/src/e3sm_quickview/utils/colors.py new file mode 100644 index 0000000..b302ed7 --- /dev/null +++ b/src/e3sm_quickview/utils/colors.py @@ -0,0 +1,18 @@ +TYPE_COLORS = [ + "success", + "info", + "warning", + "error", + "purple", + "cyan", + "teal", + "indigo", + "pink", + "amber", + "lime", + "deep-purple", +] + + +def get_type_color(index: int) -> str: + return TYPE_COLORS[index % len(TYPE_COLORS)] diff --git a/src/e3sm_quickview/utils/constants.py b/src/e3sm_quickview/utils/constants.py index cf43c8f..278ed3a 100644 --- a/src/e3sm_quickview/utils/constants.py +++ b/src/e3sm_quickview/utils/constants.py @@ -1,16 +1,4 @@ VAR_HEADERS = [ {"title": "Name", "align": "start", "key": "name", "sortable": True}, {"title": "Type", "align": "start", "key": "type", "sortable": True}, -] - -TRACK_STEPS = { - "timestamps": "time_idx", - "interfaces": "interface_idx", - "midpoints": "midpoint_idx", -} - -TRACK_ENTRIES = { - "timestamps": {"title": "Time", "value": "timestamps"}, - "midpoints": {"title": "Layer Midpoints", "value": "midpoints"}, - "interfaces": {"title": "Layer Interfaces", "value": "interfaces"}, -} +] \ No newline at end of file diff --git a/src/e3sm_quickview/utils/js.py b/src/e3sm_quickview/utils/js.py index 441936e..7c2e51e 100644 --- a/src/e3sm_quickview/utils/js.py +++ b/src/e3sm_quickview/utils/js.py @@ -1,16 +1,2 @@ -def var_count(name): - return f"variables_selected.filter((v) => v[0] === '{name[0]}').length" - - -def var_remove(name): - return ( - f"variables_selected = variables_selected.filter((v) => v[0] !== '{name[0]}')" - ) - - -def var_title(name): - return " ".join(["{{", var_count(name), "}}", name.capitalize()]) - - def is_active(name): return f"active_tools.includes('{name}')" diff --git a/src/e3sm_quickview/view_manager.py b/src/e3sm_quickview/view_manager.py index 7e26461..6a3bd3a 100644 --- a/src/e3sm_quickview/view_manager.py +++ b/src/e3sm_quickview/view_manager.py @@ -9,7 +9,7 @@ from paraview import simple -from e3sm_quickview.components import view +from e3sm_quickview.components import view as tview from e3sm_quickview.utils.color import get_cached_colorbar_image, COLORBAR_CACHE from e3sm_quickview.presets import COLOR_BLIND_SAFE @@ -41,12 +41,6 @@ def auto_size_to_col(size): "flow": None, } -TYPE_COLOR = { - "s": "success", - "i": "info", - "m": "warning", -} - def lut_name(element): return element.get("name").lower() @@ -235,7 +229,7 @@ def _build_ui(self): dense=True, classes="ma-0 pa-0 bg-black opacity-90 d-flex align-center", ): - view.create_size_menu(self.name, self.config) + tview.create_size_menu(self.name, self.config) with html.Div( self.variable_name, classes="text-subtitle-2 pr-2", @@ -302,7 +296,7 @@ def _build_ui(self): self.view, interactive_ratio=1, ctx_name=self.name ) - view.create_bottom_bar(self.config, self.update_color_preset) + tview.create_bottom_bar(self.config, self.update_color_preset) class ViewManager(TrameComponent): @@ -422,21 +416,25 @@ def build_auto_layout(self, variables=None): # Create UI based on variables self.state.swap_groups = {} + # Build a lookup from type name to color from state.variable_types + type_to_color = {vt["name"]: vt["color"] for vt in self.state.variable_types} with DivLayout(self.server, template_name="auto_layout") as self.ui: if self.state.layout_grouped: with v3.VCol(classes="pa-1"): - for var_type in "smi": + for var_type in variables.keys(): var_names = variables[var_type] total_size = len(var_names) if total_size == 0: continue + # Look up color from variable_types to match chip colors + border_color = type_to_color.get(str(var_type), "primary") with v3.VAlert( border="start", classes="pr-1 py-1 pl-3 mb-1", variant="flat", - border_color=TYPE_COLOR[var_type], + border_color=border_color, ): with v3.VRow(dense=True): for name in var_names: @@ -467,7 +465,7 @@ def build_auto_layout(self, variables=None): else: all_names = [name for names in variables.values() for name in names] with v3.VRow(dense=True, classes="pa-2"): - for var_type in "smi": + for var_type in variables.keys(): var_names = variables[var_type] for name in var_names: view = self.get_view(name, var_type) @@ -503,7 +501,7 @@ def build_auto_layout(self, variables=None): existed_order = set() order_max = 0 orders_to_update = [] - for var_type in "smi": + for var_type in variables.keys(): var_names = variables[var_type] for name in var_names: config = self.get_view(name, var_type).config