From 552fbd940d0cc586666300a3ca9a2d2d1cf8e368 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Mon, 17 Nov 2025 10:09:44 -0800 Subject: [PATCH 1/6] feat: adding partial changes for supporing general ESM reader --- src/e3sm_quickview/app.py | 1 + src/e3sm_quickview/pipeline.py | 20 +- src/e3sm_quickview/plugins/eam_reader.py | 458 ++++++++++++----------- 3 files changed, 234 insertions(+), 245 deletions(-) diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index 36fdf2c..f02d0ad 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -381,6 +381,7 @@ async def data_loading_open(self, simulation, connectivity): for name in self.source.midpoint_vars ), ] + print(self.state.variables_listing) # Update Layer/Time values and ui layout n_cols = 0 diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index d4cbbd8..c7e5664 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -39,18 +39,10 @@ def __init__(self): self.data_file = None self.conn_file = None - self.midpoints = [] - self.interfaces = [] - # List of all available variables - self.surface_vars = [] - self.midpoint_vars = [] - self.interface_vars = [] - # List of selected variables - self.surface_vars_sel = [] - self.interface_vars_sel = [] - self.midpoint_vars_sel = [] - + self.variables = None + self.dimensions = None + self.data = None self.globe = None self.projection = "Cyl. Equidistant" @@ -193,8 +185,6 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal ConnectivityFile=conn_file, DataFile=data_file, ) - data.MiddleLayer = midpoint - data.InterfaceLayer = interface self.data = data vtk_obj = data.GetClientSideObject() vtk_obj.AddObserver("ErrorEvent", self.observer) @@ -325,10 +315,6 @@ def LoadVariables(self, surf, mid, intf): self.data.SurfaceVariables = surf self.data.MidpointVariables = mid self.data.InterfaceVariables = intf - self.vars["surface"] = surf - self.vars["midpoint"] = mid - self.vars["interface"] = intf - if __name__ == "__main__": e = EAMVisSource() diff --git a/src/e3sm_quickview/plugins/eam_reader.py b/src/e3sm_quickview/plugins/eam_reader.py index 2e1173e..d526b93 100644 --- a/src/e3sm_quickview/plugins/eam_reader.py +++ b/src/e3sm_quickview/plugins/eam_reader.py @@ -4,7 +4,7 @@ from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkCellArray from vtkmodules.util import vtkConstants, numpy_support from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase -from paraview import print_error +from paraview import print_error, print_warning try: import netCDF4 @@ -34,36 +34,80 @@ class EAMConstants: from enum import Enum # noqa: E402 - -class VarType(Enum): - _1D = 1 - _2D = 2 - _3Dm = 3 - _3Di = 4 - +class DimMeta: + """Simple class to store dimension metadata.""" + def __init__(self, name, size, data=None): + self.name = name + self.size = size + self.long_name = None + self.units = None + self.data = data # Store the actual dimension coordinate values + + def __getitem__(self, key): + """Dict-like access to attributes.""" + return getattr(self, key, None) + + def __setitem__(self, key, value): + """Dict-like setting of attributes.""" + setattr(self, key, value) + + def update_from_variable(self, var_info): + """Update metadata from netCDF variable info - only long_name and units.""" + try: + self.long_name = var_info.getncattr('long_name') + except AttributeError: + pass + + try: + self.units = var_info.getncattr('units') + except AttributeError: + pass + + def __repr__(self): + return f"DimMeta(name='{self.name}', size={self.size}, long_name='{self.long_name}')" class VarMeta: + """Simple class to store variable metadata.""" def __init__(self, name, info, horizontal_dim=None): self.name = name - self.type = None - self.transpose = False + self.dimensions = info.dimensions # Store dimensions for slicing self.fillval = np.nan - - dims = info.dimensions - - if len(dims) == 1: - self.type = VarType._1D - elif len(dims) == 2: - self.type = VarType._2D - elif len(dims) == 3: - if "lev" in dims: - self.type = VarType._3Dm - elif "ilev" in dims: - self.type = VarType._3Di - - # Use dynamic horizontal dimension - if horizontal_dim and len(dims) > 1 and horizontal_dim in dims[1]: - self.transpose = True + self.long_name = None + + # Extract metadata from info + self._extract_metadata(info) + + def _extract_metadata(self, info): + """Helper to extract metadata attributes from netCDF variable.""" + # Try to get fill value from either _FillValue or missing_value + for fillattr in ['_FillValue', 'missing_value']: + value = self._get_attr(info, fillattr) + if value is not None: + self.fillval = value + break + + # Get long_name if available + long_name = self._get_attr(info, 'long_name') + if long_name is not None: + self.long_name = long_name + + def _get_attr(self, info, attr_name): + """Safely get an attribute from netCDF variable info.""" + try: + return info.getncattr(attr_name) + except (AttributeError, KeyError): + return None + + def __getitem__(self, key): + """Dict-like access to attributes.""" + return getattr(self, key, None) + + def __setitem__(self, key, value): + """Dict-like setting of attributes.""" + setattr(self, key, value) + + def __repr__(self): + return f"VarMeta(name='{self.name}', dimensions={self.dimensions})" def compare(data, arrays, dim): @@ -168,20 +212,14 @@ def _markmodified(*args, **kwars): ) @smproperty.xml( """ - - - """ -) -@smproperty.xml( - """ - - + + JSON representing dimension slices (e.g. {"lev": 0, "ilev": 1}) + """ ) class EAMSliceSource(VTKPythonAlgorithmBase): @@ -200,33 +238,19 @@ def __init__(self): # Variables for dimension sliders self._time = 0 - self._lev = 0 - self._ilev = 0 - # Arrays to store field names in netCDF file - self._info_vars = [] # 1D info variables - self._surface_vars = [] # 2D surface variables - self._interface_vars = [] # 3D interface layer variables - self._midpoint_vars = [] # 3D midpoint layer variables + # Dictionaries to store metadata objects + self._variables = {} # Will store VarMeta objects by name + self._dimensions = {} # Will store DimMeta objects by name self._timeSteps = [] + # Dictionary to store dimension slices + self._slices = {} + # vtkDataArraySelection to allow users choice for fields # to fetch from the netCDF data set - self._info_selection = vtkDataArraySelection() - self._surface_selection = vtkDataArraySelection() - self._interface_selection = vtkDataArraySelection() - self._midpoint_selection = vtkDataArraySelection() - # Cache for non temporal variables - # Store { names : data } - self._info_vars_cache = {} + self._variable_selection = vtkDataArraySelection() # Add observers for the selection arrays - self._info_selection.AddObserver("ModifiedEvent", createModifiedCallback(self)) - self._surface_selection.AddObserver( - "ModifiedEvent", createModifiedCallback(self) - ) - self._interface_selection.AddObserver( - "ModifiedEvent", createModifiedCallback(self) - ) - self._midpoint_selection.AddObserver( + self._variable_selection.AddObserver( "ModifiedEvent", createModifiedCallback(self) ) # Flag for area var to calculate averages @@ -246,12 +270,12 @@ def __init__(self): self._cached_ncells2D = None # Special variable caching - self._cached_lev = None - self._cached_ilev = None + #self._cached_lev = None + #self._cached_ilev = None self._cached_area = None # Dynamic dimension detection - self._horizontal_dim = None # From connectivity file + self._horizontal_dim = None self._data_horizontal_dim = None # Matched in data file def __del__(self): @@ -302,16 +326,12 @@ def _get_var_dataset(self): # Method to clear all the variable names def _clear(self): - self._info_vars.clear() - self._surface_vars.clear() - self._interface_vars.clear() - self._midpoint_vars.clear() + self._variables.clear() + # Clear special variable cache when metadata changes - self._cached_lev = None - self._cached_ilev = None self._cached_area = None + # Clear dimension detection - self._horizontal_dim = None self._data_horizontal_dim = None def _identify_horizontal_dimension(self, meshdata, vardata): @@ -344,6 +364,11 @@ def _clear_geometry_cache(self): self._cached_offsets = None self._cached_ncells2D = None + ''' + Disable the derivation of lev/ilev for the new approach -- the new approach + relies on the identified dimensions from the data file and connectivity files. + We could reintroduce this later if required. + def _get_cached_lev(self, vardata): """Get cached lev array or compute and cache it.""" if self._cached_lev is None: @@ -359,6 +384,7 @@ def _get_cached_ilev(self, vardata): vardata, EAMConstants.ILEV, EAMConstants.HYAI, EAMConstants.HYBI ) return self._cached_ilev + ''' def _get_cached_area(self, vardata): """Get cached area array or load and cache it.""" @@ -371,27 +397,32 @@ def _get_cached_area(self, vardata): self._cached_area[mask] = np.nan return self._cached_area - def _load_2d_variable(self, vardata, varmeta, timeInd): - """Load 2D variable data with optimized operations.""" - # Get data without unnecessary copy - data = vardata[varmeta.name][:].data[timeInd].flatten() - data = np.where(data == varmeta.fillval, np.nan, data) - return data - - def _load_3d_slice(self, vardata, varmeta, timeInd, start_idx, end_idx): - """Load a slice of 3D variable data with optimized operations.""" - # Load full 3D data for time step - if not varmeta.transpose: - data = vardata[varmeta.name][:].data[timeInd].flatten()[start_idx:end_idx] - else: - data = ( - vardata[varmeta.name][:] - .data[timeInd] - .transpose() - .flatten()[start_idx:end_idx] - ) - data = np.where(data == varmeta.fillval, np.nan, data) - return data + def _load_variable(self, vardata, varmeta, timeInd): + """Load variable data with dimension-based slicing.""" + try: + # Build slice tuple based on variable's dimensions and user-selected slices + slice_tuple = [] + for dim in varmeta.dimensions: + if dim == self._data_horizontal_dim: + continue + elif dim == "time": + # Use timeInd for time dimension + slice_tuple.append(timeInd) + elif hasattr(self, '_slices') and dim in self._slices: + # Use user-specified slice for this dimension + slice_tuple.append(self._slices[dim]) + else: + # Use all data for unspecified dimensions + slice_tuple.append(slice(None)) + + # Get data with proper slicing + data = vardata[varmeta.name][tuple(slice_tuple)].data.flatten() + data = np.where(data == varmeta.fillval, np.nan, data) + return data + except Exception as e: + print_error(f"Error loading variable {varmeta.name}: {e}") + # Return empty array on error + return np.array([]) def _get_enabled_arrays(self, var_list, selection_obj): """Get list of enabled variable names from selection object.""" @@ -473,60 +504,59 @@ def _populate_variable_metadata(self): return # Clear existing selection arrays BEFORE adding new ones - self._surface_selection.RemoveAllArrays() - self._midpoint_selection.RemoveAllArrays() - self._interface_selection.RemoveAllArrays() + self._variable_selection.RemoveAllArrays() + + # Collect all unique dimensions that need slicing + all_dimensions = set() + + # Store dimension metadata + self._dimensions.clear() + for dim_name, dim_obj in vardata.dimensions.items(): + # Create DimMeta object + dim_meta = DimMeta(dim_name, dim_obj.size) + + # Try to load dimension coordinate variable if it exists + if dim_name in vardata.variables: + dim_var = vardata.variables[dim_name] + # Load dimension data + try: + dim_meta.data = vardata[dim_name][:].data + except: + pass + # Update metadata from variable attributes + dim_meta.update_from_variable(dim_var) + + self._dimensions[dim_name] = dim_meta - # Define dimension sets dynamically based on detected dimension - dims1 = set([self._data_horizontal_dim]) - dims2 = set(['time', self._data_horizontal_dim]) - dims3m = set(['time', 'lev', self._data_horizontal_dim]) - dims3i = set(['time', 'ilev', self._data_horizontal_dim]) - for name, info in vardata.variables.items(): dims = set(info.dimensions) - if not (dims == dims1 or dims == dims2 or dims == dims3m or dims == dims3i): + if not self._data_horizontal_dim in dims: continue varmeta = VarMeta(name, info, self._data_horizontal_dim) - if varmeta.type == VarType._1D: - self._info_vars.append(varmeta) - if "area" in name: + if len(dims) == 1 and "area" in name: self._areavar = varmeta - elif varmeta.type == VarType._2D: - self._surface_vars.append(varmeta) - self._surface_selection.AddArray(name) - elif varmeta.type == VarType._3Dm: - self._midpoint_vars.append(varmeta) - self._midpoint_selection.AddArray(name) - elif varmeta.type == VarType._3Di: - self._interface_vars.append(varmeta) - self._interface_selection.AddArray(name) - try: - fillval = info.getncattr("_FillValue") - varmeta.fillval = fillval - except Exception: - try: - fillval = info.getncattr("missing_value") - varmeta.fillval = fillval - except Exception: - pass - self._surface_selection.DisableAllArrays() - self._interface_selection.DisableAllArrays() - self._midpoint_selection.DisableAllArrays() + if len(dims) > 1: + all_dimensions.update(dims) + self._variables[name] = varmeta # Store by name as key + self._variable_selection.AddArray(name) + + # Initialize slices for all dimensions at once + for dim in all_dimensions: + if dim not in self._slices: # Only set if not already set + self._slices[dim] = 0 + self._variable_selection.DisableAllArrays() # Clear old timestamps before adding new ones self._timeSteps.clear() - timesteps = vardata["time"][:].data.flatten() - self._timeSteps.extend(timesteps) + if "time" in vardata.variables: + timesteps = vardata["time"][:].data.flatten() + self._timeSteps.extend(timesteps) def SetDataFileName(self, fname): if fname is not None and fname != "None": if fname != self._DataFileName: self._DataFileName = fname self._dirty = True - self._surface_update = True - self._midpoint_update = True - self._interface_update = True self._clear() # Close old dataset if filename changed if self._cached_var_filename != fname and self._var_dataset is not None: @@ -542,9 +572,6 @@ def SetConnFileName(self, fname): if fname != self._ConnFileName: self._ConnFileName = fname self._dirty = True - self._surface_update = True - self._midpoint_update = True - self._interface_update = True self._clear() # Clear dimension cache # Close old dataset if filename changed if self._cached_mesh_filename != fname and self._mesh_dataset is not None: @@ -559,23 +586,70 @@ def SetConnFileName(self, fname): self._populate_variable_metadata() self.Modified() - def SetMiddleLayer(self, lev): - if self._lev != lev: - self._lev = lev - self._midpoint_update = True - self.Modified() - - def SetInterfaceLayer(self, ilev): - if self._ilev != ilev: - self._ilev = ilev - self._interface_update = True - self.Modified() + def SetSlicing(self, slice_str): + # Parse JSON string containing dimension slices and update self._slices + # Initialize _slices if not already done + if not hasattr(self, '_slices'): + self._slices = {} + + # Initialize dimensions if not already done + if not hasattr(self, '_dimensions'): + self._dimensions = {} + + if slice_str and slice_str.strip(): # Check for non-empty string + try: + import json + slice_dict = json.loads(slice_str) + + # Validate and update slices for provided dimensions + invalid_slices = [] + for dim, slice_val in slice_dict.items(): + # Check if dimension exists + if dim in self._dimensions: + dim_meta = self._dimensions[dim] + dim_size = dim_meta.size + # Validate slice index + if isinstance(slice_val, int): + if slice_val < 0 or slice_val >= dim_size: + # Include dimension long name if available + dim_display = f"{dim}" + if dim_meta.long_name: + dim_display += f" ({dim_meta.long_name})" + invalid_slices.append( + f"{dim_display}={slice_val} (valid range: 0-{dim_size-1})" + ) + else: + self._slices[dim] = slice_val + else: + print_error(f"Slice value for '{dim}' must be an integer, got {type(slice_val).__name__}") + else: + # Store the slice anyway for dimensions we haven't seen yet + # (might be populated later) + self._slices[dim] = slice_val + if self._dimensions: # Only warn if we have dimension info + print_warning(f"Dimension '{dim}' not found in data file") + + if invalid_slices: + print_error(f"Invalid slice indices: {', '.join(invalid_slices)}") + else: + self.Modified() + + except (json.JSONDecodeError, ValueError) as e: + print_error(f"Invalid JSON for slicing: {e}") + except Exception as e: + print_error(f"Error setting slices: {e}") def SetCalculateAverages(self, calcavg): if self._avg != calcavg: self._avg = calcavg self.Modified() + def GetVariables(self): + return self._variables + + def GetDimensions(self): + return self._dimensions + @smproperty.doublevector( name="TimestepValues", information_only="1", si_class="vtkSITimeStepsProperty" ) @@ -587,17 +661,9 @@ def GetTimestepValues(self): # load. To expose that in ParaView, simply use the # smproperty.dataarrayselection(). # This method **must** return a `vtkDataArraySelection` instance. - @smproperty.dataarrayselection(name="Surface Variables") + @smproperty.dataarrayselection(name="Variables") def GetSurfaceVariables(self): - return self._surface_selection - - @smproperty.dataarrayselection(name="Midpoint Variables") - def GetMidpointVariables(self): - return self._midpoint_selection - - @smproperty.dataarrayselection(name="Interface Variables") - def GetInterfaceVariables(self): - return self._interface_selection + return self._variable_selection def RequestInformation(self, request, inInfo, outInfo): executive = self.GetExecutive() @@ -696,81 +762,17 @@ def RequestData(self, request, inInfo, outInfo): for i in range(last_num_arrays): to_remove.add(output_mesh.CellData.GetArrayName(i)) - for varmeta in self._surface_vars: - if self._surface_selection.ArrayIsEnabled(varmeta.name): - if output_mesh.CellData.HasArray(varmeta.name): - to_remove.remove(varmeta.name) + for name, varmeta in self._variables.items(): + if self._variable_selection.ArrayIsEnabled(name): + if output_mesh.CellData.HasArray(name): + to_remove.remove(name) if ( - not output_mesh.CellData.HasArray(varmeta.name) + not output_mesh.CellData.HasArray(name) or self._surface_update ): - data = self._load_2d_variable(vardata, varmeta, timeInd) - output_mesh.CellData.append(data, varmeta.name) - self._surface_update = False - - try: - lev_field_name = "lev" - has_lev_field = output_mesh.FieldData.HasArray(lev_field_name) - lev = self._get_cached_lev(vardata) - if lev is not None: - if not has_lev_field: - output_mesh.FieldData.append(lev, lev_field_name) - if self._lev >= vardata.dimensions[lev_field_name].size: - print_error( - f"User provided input for middle layer {self._lev} larger than actual data {len(lev) - 1}" - ) - lstart = self._lev * ncells2D - lend = lstart + ncells2D - - for varmeta in self._midpoint_vars: - if self._midpoint_selection.ArrayIsEnabled(varmeta.name): - if output_mesh.CellData.HasArray(varmeta.name): - to_remove.remove(varmeta.name) - if ( - not output_mesh.CellData.HasArray(varmeta.name) - or self._midpoint_update - ): - data = self._load_3d_slice( - vardata, varmeta, timeInd, lstart, lend - ) - output_mesh.CellData.append(data, varmeta.name) - self._midpoint_update = False - except Exception as e: - print_error("Error occurred while processing middle layer variables :", e) - traceback.print_exc() - - try: - ilev_field_name = "ilev" - has_ilev_field = output_mesh.FieldData.HasArray(ilev_field_name) - ilev = self._get_cached_ilev(vardata) - if ilev is not None: - if not has_ilev_field: - output_mesh.FieldData.append(ilev, ilev_field_name) - if self._ilev >= vardata.dimensions[ilev_field_name].size: - print_error( - f"User provided input for middle layer {self._ilev} larger than actual data {len(ilev) - 1}" - ) - ilstart = self._ilev * ncells2D - ilend = ilstart + ncells2D - for varmeta in self._interface_vars: - if self._interface_selection.ArrayIsEnabled(varmeta.name): - if output_mesh.CellData.HasArray(varmeta.name): - to_remove.remove(varmeta.name) - if ( - not output_mesh.CellData.HasArray(varmeta.name) - or self._interface_update - ): - data = self._load_3d_slice( - vardata, varmeta, timeInd, ilstart, ilend - ) - output_mesh.CellData.append(data, varmeta.name) - self._interface_update = False - except Exception as e: - print_error( - "Error occurred while processing interface layer variables :", e - ) - traceback.print_exc() - + data = self._load_variable(vardata, varmeta, timeInd) + output_mesh.CellData.append(data, name) + area_var_name = "area" if self._areavar and not output_mesh.CellData.HasArray(area_var_name): data = self._get_cached_area(vardata) From f4965103aa0bd3a9376255f00cc94a988e207986 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Wed, 3 Dec 2025 09:23:24 -0800 Subject: [PATCH 2/6] fix: Add/fix dynamic dimensions sliders --- src/e3sm_quickview/app.py | 84 ++++++++-------- src/e3sm_quickview/components/toolbars.py | 112 +++++++++------------- src/e3sm_quickview/pipeline.py | 27 ++---- src/e3sm_quickview/utils/compute.py | 3 + src/e3sm_quickview/view_manager.py | 5 +- 5 files changed, 95 insertions(+), 136 deletions(-) diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index f02d0ad..d244b84 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -45,13 +45,9 @@ def __init__(self, server=None): "variables_selected": [], # Control 'Load Variables' button availability "variables_loaded": False, - # Level controls - "midpoint_idx": 0, + # Dimension arrays (will be populated dynamically) "midpoints": [], - "interface_idx": 0, "interfaces": [], - # Time controls - "time_idx": 0, "timestamps": [], # Fields summaries "fields_avgs": {}, @@ -208,18 +204,19 @@ def _build_ui(self, **_): @property def selected_variables(self): - vars_per_type = {n: [] for n in "smi"} + from collections import defaultdict + vars_per_type = defaultdict(list) for var in self.state.variables_selected: - type = var[0] - name = var[1:] - vars_per_type[type].append(name) + print(self.state.variables_selected) + type = self.source.varmeta[var].dimensions + vars_per_type[type].append(var) - return vars_per_type + return dict(vars_per_type) @property def selected_variable_names(self): # Remove var type (first char) - return [var[1:] for var in self.state.variables_selected] + return [var for var in self.state.variables_selected] # ------------------------------------------------------------------------- # Methods connected to UI @@ -369,38 +366,39 @@ async def data_loading_open(self, simulation, connectivity): self.state.variables_filter = "" self.state.variables_listing = [ *( - {"name": name, "type": "surface", "id": f"s{name}"} - for name in self.source.surface_vars - ), - *( - {"name": name, "type": "interface", "id": f"i{name}"} - for name in self.source.interface_vars - ), - *( - {"name": name, "type": "midpoint", "id": f"m{name}"} - for name in self.source.midpoint_vars + {"name": var.name, "type": str(var.dimensions), "id": f"{var.name}"} + for _, var in self.source.varmeta.items() ), ] - print(self.state.variables_listing) # Update Layer/Time values and ui layout n_cols = 0 available_tracks = [] - for name in ["midpoints", "interfaces", "timestamps"]: - values = getattr(self.source, name) - self.state[name] = values - - if len(values) > 1: + for name, dim in self.source.dimmeta.items(): + values = dim.data + # Convert to list for JSON serialization + self.state[name] = values.tolist() if hasattr(values, 'tolist') else list(values) if values is not None else [] + if values is not None and len(values) > 1: n_cols += 1 - available_tracks.append(constants.TRACK_ENTRIES[name]) - + available_tracks.append({"title": name, "value": name}) self.state.toolbar_slider_cols = 12 / n_cols if n_cols else 12 + print("************", available_tracks) self.state.animation_tracks = available_tracks self.state.animation_track = ( self.state.animation_tracks[0]["value"] if available_tracks else None ) + + # Initialize dynamic index variables for each dimension + for track in available_tracks: + dim_name = track["value"] + index_var = f"{dim_name}_idx" + if "time" in index_var: + self.state[index_var] = 50 + else: + self.state[index_var] = 0 + print("***********", self.state.time_idx) @controller.set("file_selection_cancel") def data_loading_hide(self): @@ -415,10 +413,11 @@ async def _data_load_variables(self): """Called at 'Load Variables' button click""" vars_to_show = self.selected_variables + # Flatten the list of lists + flattened_vars = [var for var_list in vars_to_show.values() for var in var_list] + self.source.LoadVariables( - vars_to_show["s"], # surfaces - vars_to_show["m"], # midpoints - vars_to_show["i"], # interfaces + flattened_vars ) # Trigger source update + compute avg @@ -456,19 +455,21 @@ async def _on_projection(self, projection, **_): await asyncio.sleep(0.1) self.view_manager.reset_camera() - @change("active_tools") + @change("active_tools", "animation_tracks") def _on_toolbar_change(self, active_tools, **_): top_padding = 0 for name in active_tools: - top_padding += toolbars.SIZES.get(name, 0) + if name == "select-slice-time": + track_count = len(self.state.animation_tracks or []) + rows_needed = max(1, (track_count + 2) // 3) # 3 sliders per row + top_padding += 65 * rows_needed + else: + top_padding += toolbars.SIZES.get(name, 0) self.state.top_padding = top_padding @change( "variables_loaded", - "time_idx", - "midpoint_idx", - "interface_idx", "crop_longitude", "crop_latitude", "projection", @@ -476,10 +477,6 @@ def _on_toolbar_change(self, active_tools, **_): def _on_time_change( self, variables_loaded, - time_idx, - timestamps, - midpoint_idx, - interface_idx, crop_longitude, crop_latitude, projection, @@ -488,12 +485,9 @@ def _on_time_change( if not variables_loaded: return - time_value = timestamps[time_idx] if len(timestamps) else 0.0 - self.source.UpdateLev(midpoint_idx, interface_idx) self.source.ApplyClipping(crop_longitude, crop_latitude) self.source.UpdateProjection(projection[0]) - self.source.UpdateTimeStep(time_idx) - self.source.UpdatePipeline(time_value) + self.source.UpdatePipeline() self.view_manager.update_color_range() self.view_manager.render() diff --git a/src/e3sm_quickview/components/toolbars.py b/src/e3sm_quickview/components/toolbars.py index 119e4a8..61c9194 100644 --- a/src/e3sm_quickview/components/toolbars.py +++ b/src/e3sm_quickview/components/toolbars.py @@ -2,7 +2,8 @@ from trame.app import asynchronous from trame.decorators import change -from trame.widgets import html, vuetify3 as v3 +from trame.widgets import html, vuetify3 as v3, client + from e3sm_quickview.utils import js, constants @@ -233,81 +234,54 @@ def __init__(self): class DataSelection(v3.VToolbar): + + def on_update_slider(self, value, name, *args, **kwargs): + print(value, name, args, kwargs) + print(type(value)) + with self.state: + self.state[f"{name}_idx"] = value + + def __init__(self): super().__init__(**to_kwargs("select-slice-time")) with self: v3.VIcon("mdi-tune-variant", classes="ml-3 opacity-50") - with v3.VRow(classes="ma-0 pr-2 align-center", dense=True): - # midpoint layer + with v3.VRow(classes="ma-0 pr-2 overflow-y-auto", dense=True, style="max-height: 400px;"): + # Debug: Show animation_tracks array + # html.Div("Animation Tracks: {{ JSON.stringify(animation_tracks) }}", classes="col-12") + # Each track gets a column (3 per row) with v3.VCol( - cols=("toolbar_slider_cols", 4), - v_show="midpoints.length > 1", + cols=4, + v_for="(track, idx) in animation_tracks", + key="idx", + classes="pa-2" ): - with v3.VRow(classes="mx-2 my-0"): - v3.VLabel( - "Layer Midpoints", - classes="text-subtitle-2", - ) - v3.VSpacer() - v3.VLabel( - "{{ parseFloat(midpoints[midpoint_idx] || 0).toFixed(2) }} hPa (k={{ midpoint_idx }})", - classes="text-body-2", - ) - v3.VSlider( - v_model=("midpoint_idx", 0), - min=0, - max=("Math.max(0, midpoints.length - 1)",), - step=1, - density="compact", - hide_details=True, - ) - - # interface layer - with v3.VCol( - cols=("toolbar_slider_cols", 4), - v_show="interfaces.length > 1", - ): - with v3.VRow(classes="mx-2 my-0"): - v3.VLabel( - "Layer Interfaces", - classes="text-subtitle-2", - ) - v3.VSpacer() - v3.VLabel( - "{{ parseFloat(interfaces[interface_idx] || 0).toFixed(2) }} hPa (k={{interface_idx}})", - classes="text-body-2", - ) - v3.VSlider( - v_model=("interface_idx", 0), - min=0, - max=("Math.max(0, interfaces.length - 1)",), - step=1, - density="compact", - hide_details=True, - ) + with client.Getter(name=("track.value",), value_name="t_values"): + with client.Getter(name=("track.value + '_idx'",), value_name="t_idx"): + html.Span("Track {{ idx }}: {{ track.title }} = {{ track.value }} {{get(track.value + '_idx')}}") + with v3.VRow(classes="ma-0 align-center", dense=True): + v3.VLabel( + "{{track.title}}", + classes="text-subtitle-2", + ) + v3.VSpacer() + v3.VLabel( + "{{track.value}}", + #"{{ parseFloat(get(track.value)[get(`${track.value}_idx`)]).toFixed(2) }} hPa (k={{ get(`${track.value}_idx`) }})", + classes="text-body-2", + ) + v3.VSlider( + model_value=("t_idx",), + update_modelValue=(self.on_update_slider, "[$event, track.value, t_values]"), + min=0, + #max=100,#("get(track.value).length - 1",), + max=("t_values.length - 1",), + step=1, + density="compact", + hide_details=True, + ) - # time - with v3.VCol( - cols=("toolbar_slider_cols", 4), - v_show="timestamps.length > 1", - ): - self.state.setdefault("time_value", 80.50) - with v3.VRow(classes="mx-2 my-0"): - v3.VLabel("Time", classes="text-subtitle-2") - v3.VSpacer() - v3.VLabel( - "{{ parseFloat(timestamps[time_idx]).toFixed(2) }} (t={{time_idx}})", - classes="text-body-2", - ) - v3.VSlider( - v_model=("time_idx", 0), - min=0, - max=("Math.max(0, timestamps.length - 1)",), - step=1, - density="compact", - hide_details=True, - ) class Animation(v3.VToolbar): @@ -382,7 +356,7 @@ def _on_animation_track_change(self, animation_track, **_): @change("animation_step") def _on_animation_step(self, animation_track, animation_step, **_): if animation_track: - self.state[constants.TRACK_STEPS[animation_track]] = animation_step + self.state[f"{animation_track}_idx"] = animation_step @change("animation_play") def _on_animation_play(self, animation_play, **_): diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index c7e5664..9c2d558 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -40,8 +40,8 @@ def __init__(self): self.conn_file = None # List of all available variables - self.variables = None - self.dimensions = None + self.varmeta = None + self.dimmeta = None self.data = None self.globe = None @@ -189,6 +189,8 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal vtk_obj = data.GetClientSideObject() vtk_obj.AddObserver("ErrorEvent", self.observer) vtk_obj.GetExecutive().AddObserver("ErrorEvent", self.observer) + self.varmeta = vtk_obj.GetVariables() + self.dimmeta = vtk_obj.GetDimensions() self.observer.clear() else: self.data.DataFile = data_file @@ -204,20 +206,7 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal "Please check if the data and connectivity files exist " "and are compatible" ) - - data_wrapped = dsa.WrapDataObject(sm.Fetch(self.data)) - self.midpoints = data_wrapped.FieldData["lev"].tolist() - self.interfaces = data_wrapped.FieldData["ilev"].tolist() - - self.surface_vars = list( - np.asarray(self.data.GetProperty("SurfaceVariablesInfo"))[::2] - ) - self.midpoint_vars = list( - np.asarray(self.data.GetProperty("MidpointVariablesInfo"))[::2] - ) - self.interface_vars = list( - np.asarray(self.data.GetProperty("InterfaceVariablesInfo"))[::2] - ) + # Ensure TimestepValues is always a list timestep_values = self.data.TimestepValues if isinstance(timestep_values, (list, tuple)): @@ -309,12 +298,10 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal return self.valid - def LoadVariables(self, surf, mid, intf): + def LoadVariables(self, vars): if not self.valid: return - self.data.SurfaceVariables = surf - self.data.MidpointVariables = mid - self.data.InterfaceVariables = intf + self.data.Variables = vars if __name__ == "__main__": e = EAMVisSource() diff --git a/src/e3sm_quickview/utils/compute.py b/src/e3sm_quickview/utils/compute.py index ca2beeb..cb5ffd9 100644 --- a/src/e3sm_quickview/utils/compute.py +++ b/src/e3sm_quickview/utils/compute.py @@ -36,9 +36,12 @@ def calculate_weighted_average( def extract_avgs(pv_data, array_names): results = {} vtk_data = servermanager.Fetch(pv_data) + print(vtk_data) area_array = vtk_data.GetCellData().GetArray("area") for name in array_names: vtk_array = vtk_data.GetCellData().GetArray(name) + print(f"Field Array {name} : ", vtk_array) + print("Area Array : ", area_array) avg_value = calculate_weighted_average(vtk_array, area_array) results[name] = avg_value diff --git a/src/e3sm_quickview/view_manager.py b/src/e3sm_quickview/view_manager.py index 7e26461..721681a 100644 --- a/src/e3sm_quickview/view_manager.py +++ b/src/e3sm_quickview/view_manager.py @@ -425,7 +425,8 @@ def build_auto_layout(self, variables=None): with DivLayout(self.server, template_name="auto_layout") as self.ui: if self.state.layout_grouped: with v3.VCol(classes="pa-1"): - for var_type in "smi": + for var_type in variables.keys(): + var_names = variables[var_type] total_size = len(var_names) @@ -436,7 +437,7 @@ def build_auto_layout(self, variables=None): border="start", classes="pr-1 py-1 pl-3 mb-1", variant="flat", - border_color=TYPE_COLOR[var_type], + border_color=TYPE_COLOR.get(var_type, "primary"), ): with v3.VRow(dense=True): for name in var_names: From c6242fbd0954e0a6303eaaa1552722a3790023e4 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Thu, 4 Dec 2025 12:30:06 -0800 Subject: [PATCH 3/6] fix: Fixing slicing based on arbitrary dimensions --- src/e3sm_quickview/app.py | 23 ++++++++-- src/e3sm_quickview/components/toolbars.py | 26 +++++------- src/e3sm_quickview/pipeline.py | 52 ++++++++++------------- src/e3sm_quickview/plugins/eam_reader.py | 8 ++-- src/e3sm_quickview/utils/compute.py | 3 -- 5 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index d244b84..b7f260c 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -207,7 +207,6 @@ def selected_variables(self): from collections import defaultdict vars_per_type = defaultdict(list) for var in self.state.variables_selected: - print(self.state.variables_selected) type = self.source.varmeta[var].dimensions vars_per_type[type].append(var) @@ -382,7 +381,6 @@ async def data_loading_open(self, simulation, connectivity): n_cols += 1 available_tracks.append({"title": name, "value": name}) self.state.toolbar_slider_cols = 12 / n_cols if n_cols else 12 - print("************", available_tracks) self.state.animation_tracks = available_tracks self.state.animation_track = ( self.state.animation_tracks[0]["value"] @@ -390,6 +388,7 @@ async def data_loading_open(self, simulation, connectivity): else None ) + from functools import partial # Initialize dynamic index variables for each dimension for track in available_tracks: dim_name = track["value"] @@ -398,7 +397,7 @@ async def data_loading_open(self, simulation, connectivity): self.state[index_var] = 50 else: self.state[index_var] = 0 - print("***********", self.state.time_idx) + self.state.change(index_var)(partial(self._on_slicing_change, dim_name, index_var)) @controller.set("file_selection_cancel") def data_loading_hide(self): @@ -468,13 +467,28 @@ def _on_toolbar_change(self, active_tools, **_): self.state.top_padding = top_padding + def _on_slicing_change(self, var, ind_var, **_): + print(f"Here Updating {var}") + self.source.UpdateSlicing(var, self.state[ind_var]) + self.source.UpdatePipeline() + + self.view_manager.update_color_range() + self.view_manager.render() + + # Update avg computation + # Get area variable to calculate weighted average + data = self.source.views["atmosphere_data"] + self.state.fields_avgs = compute.extract_avgs( + data, self.selected_variable_names + ) + @change( "variables_loaded", "crop_longitude", "crop_latitude", "projection", ) - def _on_time_change( + def _on_downstream_change( self, variables_loaded, crop_longitude, @@ -482,6 +496,7 @@ def _on_time_change( projection, **_, ): + print("Calling!!!!") if not variables_loaded: return diff --git a/src/e3sm_quickview/components/toolbars.py b/src/e3sm_quickview/components/toolbars.py index 61c9194..1796a50 100644 --- a/src/e3sm_quickview/components/toolbars.py +++ b/src/e3sm_quickview/components/toolbars.py @@ -233,21 +233,16 @@ def __init__(self): ) -class DataSelection(v3.VToolbar): - - def on_update_slider(self, value, name, *args, **kwargs): - print(value, name, args, kwargs) - print(type(value)) - with self.state: - self.state[f"{name}_idx"] = value - - +class DataSelection(html.Div): def __init__(self): - super().__init__(**to_kwargs("select-slice-time")) + style = to_kwargs("select-slice-time") + style["classes"] = style["classes"] # + " d-flex align-center" + super().__init__(**style) with self: v3.VIcon("mdi-tune-variant", classes="ml-3 opacity-50") - with v3.VRow(classes="ma-0 pr-2 overflow-y-auto", dense=True, style="max-height: 400px;"): + + with v3.VRow(classes="ma-0 pr-2 flex-wrap", dense=True): # Debug: Show animation_tracks array # html.Div("Animation Tracks: {{ JSON.stringify(animation_tracks) }}", classes="col-12") # Each track gets a column (3 per row) @@ -259,7 +254,6 @@ def __init__(self): ): with client.Getter(name=("track.value",), value_name="t_values"): with client.Getter(name=("track.value + '_idx'",), value_name="t_idx"): - html.Span("Track {{ idx }}: {{ track.title }} = {{ track.value }} {{get(track.value + '_idx')}}") with v3.VRow(classes="ma-0 align-center", dense=True): v3.VLabel( "{{track.title}}", @@ -267,13 +261,12 @@ def __init__(self): ) v3.VSpacer() v3.VLabel( - "{{track.value}}", - #"{{ parseFloat(get(track.value)[get(`${track.value}_idx`)]).toFixed(2) }} hPa (k={{ get(`${track.value}_idx`) }})", + "{{ parseFloat(t_values[t_idx]).toFixed(2) }} hPa (k={{ t_idx }})", classes="text-body-2", ) v3.VSlider( model_value=("t_idx",), - update_modelValue=(self.on_update_slider, "[$event, track.value, t_values]"), + update_modelValue=(self.on_update_slider, "[track.value, $event]"), min=0, #max=100,#("get(track.value).length - 1",), max=("t_values.length - 1",), @@ -282,6 +275,9 @@ def __init__(self): hide_details=True, ) + def on_update_slider(self, dimension, index, *_, **__): + with self.state: + self.state[f"{dimension}_idx"] = index class Animation(v3.VToolbar): diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index 9c2d558..6f79282 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -1,4 +1,5 @@ import fnmatch +import json import numpy as np import os @@ -15,6 +16,7 @@ from paraview.vtk.numpy_interface import dataset_adapter as dsa from vtkmodules.vtkCommonCore import vtkLogger +from collections import defaultdict # Define a VTK error observer class ErrorObserver: @@ -42,7 +44,8 @@ def __init__(self): # List of all available variables self.varmeta = None self.dimmeta = None - + self.slicing = defaultdict(int) + self.data = None self.globe = None self.projection = "Cyl. Equidistant" @@ -69,34 +72,6 @@ def __init__(self): except Exception as e: print("Error loading plugin :", e) - def UpdateLev(self, lev, ilev): - if not self.valid: - return - - if self.data is None: - return - - # Handle NaN, None, or invalid values - try: - if lev is None or (isinstance(lev, float) and np.isnan(lev)): - lev_idx = 0 - else: - lev_idx = int(lev) - except (ValueError, TypeError): - lev_idx = 0 - - try: - if ilev is None or (isinstance(ilev, float) and np.isnan(ilev)): - ilev_idx = 0 - else: - ilev_idx = int(ilev) - except (ValueError, TypeError): - ilev_idx = 0 - - if self.data.MiddleLayer != lev_idx or self.data.InterfaceLayer != ilev_idx: - self.data.MiddleLayer = lev_idx - self.data.InterfaceLayer = ilev_idx - def ApplyClipping(self, cliplong, cliplat): if not self.valid: return @@ -167,7 +142,18 @@ def UpdatePipeline(self, time=0.0): self.views["continents"] = OutputPort(cont_proj, 0) self.views["grid_lines"] = OutputPort(grid_proj, 0) - def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=False): + def UpdateSlicing(self, dimension, slice): + if self.slicing.get(dimension) == slice: + return + else: + self.slicing[dimension] = slice + if self.data is not None: + x = json.dumps(self.slicing) + print(x) + self.data.Slicing = x + + + def Update(self, data_file, conn_file, force_reload=False): # Check if we need to reload if ( not force_reload @@ -191,6 +177,11 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal vtk_obj.GetExecutive().AddObserver("ErrorEvent", self.observer) self.varmeta = vtk_obj.GetVariables() self.dimmeta = vtk_obj.GetDimensions() + + for dim in self.dimmeta.keys(): + self.slicing[dim] = 0 + + print(self.slicing) self.observer.clear() else: self.data.DataFile = data_file @@ -299,6 +290,7 @@ def Update(self, data_file, conn_file, midpoint=0, interface=0, force_reload=Fal return self.valid def LoadVariables(self, vars): + print(f"Gonna Load {vars}") if not self.valid: return self.data.Variables = vars diff --git a/src/e3sm_quickview/plugins/eam_reader.py b/src/e3sm_quickview/plugins/eam_reader.py index d526b93..274dcd4 100644 --- a/src/e3sm_quickview/plugins/eam_reader.py +++ b/src/e3sm_quickview/plugins/eam_reader.py @@ -405,16 +405,16 @@ def _load_variable(self, vardata, varmeta, timeInd): for dim in varmeta.dimensions: if dim == self._data_horizontal_dim: continue - elif dim == "time": - # Use timeInd for time dimension - slice_tuple.append(timeInd) + #elif dim == "time": + # # Use timeInd for time dimension + # slice_tuple.append(timeInd) elif hasattr(self, '_slices') and dim in self._slices: # Use user-specified slice for this dimension slice_tuple.append(self._slices[dim]) else: # Use all data for unspecified dimensions slice_tuple.append(slice(None)) - + print(f"Fetching {varmeta.name} with tuple : {slice_tuple}") # Get data with proper slicing data = vardata[varmeta.name][tuple(slice_tuple)].data.flatten() data = np.where(data == varmeta.fillval, np.nan, data) diff --git a/src/e3sm_quickview/utils/compute.py b/src/e3sm_quickview/utils/compute.py index cb5ffd9..ca2beeb 100644 --- a/src/e3sm_quickview/utils/compute.py +++ b/src/e3sm_quickview/utils/compute.py @@ -36,12 +36,9 @@ def calculate_weighted_average( def extract_avgs(pv_data, array_names): results = {} vtk_data = servermanager.Fetch(pv_data) - print(vtk_data) area_array = vtk_data.GetCellData().GetArray("area") for name in array_names: vtk_array = vtk_data.GetCellData().GetArray(name) - print(f"Field Array {name} : ", vtk_array) - print("Area Array : ", area_array) avg_value = calculate_weighted_average(vtk_array, area_array) results[name] = avg_value From 87c3d5685b9f5d1e7c9fe2c8b32131cdd85621b2 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Thu, 4 Dec 2025 12:42:16 -0800 Subject: [PATCH 4/6] fix: Remove unnecessary print statements --- src/e3sm_quickview/app.py | 2 -- src/e3sm_quickview/pipeline.py | 1 - src/e3sm_quickview/plugins/eam_reader.py | 1 - 3 files changed, 4 deletions(-) diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index b7f260c..fc8dbf8 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -468,7 +468,6 @@ def _on_toolbar_change(self, active_tools, **_): self.state.top_padding = top_padding def _on_slicing_change(self, var, ind_var, **_): - print(f"Here Updating {var}") self.source.UpdateSlicing(var, self.state[ind_var]) self.source.UpdatePipeline() @@ -496,7 +495,6 @@ def _on_downstream_change( projection, **_, ): - print("Calling!!!!") if not variables_loaded: return diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index 6f79282..807f7bc 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -149,7 +149,6 @@ def UpdateSlicing(self, dimension, slice): self.slicing[dimension] = slice if self.data is not None: x = json.dumps(self.slicing) - print(x) self.data.Slicing = x diff --git a/src/e3sm_quickview/plugins/eam_reader.py b/src/e3sm_quickview/plugins/eam_reader.py index 274dcd4..2e8604e 100644 --- a/src/e3sm_quickview/plugins/eam_reader.py +++ b/src/e3sm_quickview/plugins/eam_reader.py @@ -414,7 +414,6 @@ def _load_variable(self, vardata, varmeta, timeInd): else: # Use all data for unspecified dimensions slice_tuple.append(slice(None)) - print(f"Fetching {varmeta.name} with tuple : {slice_tuple}") # Get data with proper slicing data = vardata[varmeta.name][tuple(slice_tuple)].data.flatten() data = np.where(data == varmeta.fillval, np.nan, data) From 3fbd8fccba4cbc8185dacf7ac0f56af671afccac Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Fri, 5 Dec 2025 13:17:11 -0800 Subject: [PATCH 5/6] fix: fixing errors while variable groupings --- src/e3sm_quickview/utils/constants.py | 14 +------------- src/e3sm_quickview/view_manager.py | 28 +++++++++++++++++---------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/e3sm_quickview/utils/constants.py b/src/e3sm_quickview/utils/constants.py index cf43c8f..278ed3a 100644 --- a/src/e3sm_quickview/utils/constants.py +++ b/src/e3sm_quickview/utils/constants.py @@ -1,16 +1,4 @@ VAR_HEADERS = [ {"title": "Name", "align": "start", "key": "name", "sortable": True}, {"title": "Type", "align": "start", "key": "type", "sortable": True}, -] - -TRACK_STEPS = { - "timestamps": "time_idx", - "interfaces": "interface_idx", - "midpoints": "midpoint_idx", -} - -TRACK_ENTRIES = { - "timestamps": {"title": "Time", "value": "timestamps"}, - "midpoints": {"title": "Layer Midpoints", "value": "midpoints"}, - "interfaces": {"title": "Layer Interfaces", "value": "interfaces"}, -} +] \ No newline at end of file diff --git a/src/e3sm_quickview/view_manager.py b/src/e3sm_quickview/view_manager.py index 721681a..338accc 100644 --- a/src/e3sm_quickview/view_manager.py +++ b/src/e3sm_quickview/view_manager.py @@ -41,11 +41,20 @@ def auto_size_to_col(size): "flow": None, } -TYPE_COLOR = { - "s": "success", - "i": "info", - "m": "warning", -} +TYPE_COLORS = [ + "success", + "info", + "warning", + "error", + "purple", + "cyan", + "teal", + "indigo", + "pink", + "amber", + "lime", + "deep-purple", +] def lut_name(element): @@ -425,8 +434,7 @@ def build_auto_layout(self, variables=None): with DivLayout(self.server, template_name="auto_layout") as self.ui: if self.state.layout_grouped: with v3.VCol(classes="pa-1"): - for var_type in variables.keys(): - + for idx, var_type in enumerate(variables.keys()): var_names = variables[var_type] total_size = len(var_names) @@ -437,7 +445,7 @@ def build_auto_layout(self, variables=None): border="start", classes="pr-1 py-1 pl-3 mb-1", variant="flat", - border_color=TYPE_COLOR.get(var_type, "primary"), + border_color=TYPE_COLORS[idx % len(TYPE_COLORS)], ): with v3.VRow(dense=True): for name in var_names: @@ -468,7 +476,7 @@ def build_auto_layout(self, variables=None): else: all_names = [name for names in variables.values() for name in names] with v3.VRow(dense=True, classes="pa-2"): - for var_type in "smi": + for var_type in variables.keys(): var_names = variables[var_type] for name in var_names: view = self.get_view(name, var_type) @@ -504,7 +512,7 @@ def build_auto_layout(self, variables=None): existed_order = set() order_max = 0 orders_to_update = [] - for var_type in "smi": + for var_type in variables.keys(): var_names = variables[var_type] for name in var_names: config = self.get_view(name, var_type).config From 0de06665cc22260c05e7797c6bef536b9b525988 Mon Sep 17 00:00:00 2001 From: Abhishek Yenpure Date: Mon, 8 Dec 2025 14:13:45 -0800 Subject: [PATCH 6/6] fix : Adding UI fixes for arbitrary dimension selections --- src/e3sm_quickview/app.py | 45 ++++-- src/e3sm_quickview/components/drawers.py | 34 +++-- src/e3sm_quickview/components/toolbars.py | 29 ++-- src/e3sm_quickview/pipeline.py | 24 ++- src/e3sm_quickview/plugins/eam_reader.py | 178 +++++++++++----------- src/e3sm_quickview/utils/colors.py | 18 +++ src/e3sm_quickview/utils/js.py | 14 -- src/e3sm_quickview/view_manager.py | 29 ++-- 8 files changed, 199 insertions(+), 172 deletions(-) create mode 100644 src/e3sm_quickview/utils/colors.py diff --git a/src/e3sm_quickview/app.py b/src/e3sm_quickview/app.py index fc8dbf8..45165a2 100644 --- a/src/e3sm_quickview/app.py +++ b/src/e3sm_quickview/app.py @@ -14,7 +14,7 @@ from e3sm_quickview.assets import ASSETS from e3sm_quickview.components import doc, file_browser, css, toolbars, dialogs, drawers from e3sm_quickview.pipeline import EAMVisSource -from e3sm_quickview.utils import compute, js, constants, cli +from e3sm_quickview.utils import compute, cli from e3sm_quickview.view_manager import ViewManager @@ -45,6 +45,8 @@ def __init__(self, server=None): "variables_selected": [], # Control 'Load Variables' button availability "variables_loaded": False, + # Dynamic type-color mapping (populated when data loads) + "variable_types": [], # Dimension arrays (will be populated dynamically) "midpoints": [], "interfaces": [], @@ -205,6 +207,7 @@ def _build_ui(self, **_): @property def selected_variables(self): from collections import defaultdict + vars_per_type = defaultdict(list) for var in self.state.variables_selected: type = self.source.varmeta[var].dimensions @@ -365,18 +368,39 @@ async def data_loading_open(self, simulation, connectivity): self.state.variables_filter = "" self.state.variables_listing = [ *( - {"name": var.name, "type": str(var.dimensions), "id": f"{var.name}"} + { + "name": var.name, + "type": str(var.dimensions), + "id": f"{var.name}", + } for _, var in self.source.varmeta.items() ), ] + # Build dynamic type-color mapping + from e3sm_quickview.utils.colors import get_type_color + + dim_types = sorted( + set(str(var.dimensions) for var in self.source.varmeta.values()) + ) + self.state.variable_types = [ + {"name": t, "color": get_type_color(i)} + for i, t in enumerate(dim_types) + ] + # Update Layer/Time values and ui layout n_cols = 0 available_tracks = [] for name, dim in self.source.dimmeta.items(): values = dim.data # Convert to list for JSON serialization - self.state[name] = values.tolist() if hasattr(values, 'tolist') else list(values) if values is not None else [] + self.state[name] = ( + values.tolist() + if hasattr(values, "tolist") + else list(values) + if values is not None + else [] + ) if values is not None and len(values) > 1: n_cols += 1 available_tracks.append({"title": name, "value": name}) @@ -387,8 +411,9 @@ async def data_loading_open(self, simulation, connectivity): if available_tracks else None ) - + from functools import partial + # Initialize dynamic index variables for each dimension for track in available_tracks: dim_name = track["value"] @@ -397,7 +422,9 @@ async def data_loading_open(self, simulation, connectivity): self.state[index_var] = 50 else: self.state[index_var] = 0 - self.state.change(index_var)(partial(self._on_slicing_change, dim_name, index_var)) + self.state.change(index_var)( + partial(self._on_slicing_change, dim_name, index_var) + ) @controller.set("file_selection_cancel") def data_loading_hide(self): @@ -414,10 +441,8 @@ async def _data_load_variables(self): # Flatten the list of lists flattened_vars = [var for var_list in vars_to_show.values() for var in var_list] - - self.source.LoadVariables( - flattened_vars - ) + + self.source.LoadVariables(flattened_vars) # Trigger source update + compute avg with self.state: @@ -461,7 +486,7 @@ def _on_toolbar_change(self, active_tools, **_): if name == "select-slice-time": track_count = len(self.state.animation_tracks or []) rows_needed = max(1, (track_count + 2) // 3) # 3 sliders per row - top_padding += 65 * rows_needed + top_padding += 70 * rows_needed else: top_padding += toolbars.SIZES.get(name, 0) diff --git a/src/e3sm_quickview/components/drawers.py b/src/e3sm_quickview/components/drawers.py index 1325e64..19fa1a5 100644 --- a/src/e3sm_quickview/components/drawers.py +++ b/src/e3sm_quickview/components/drawers.py @@ -68,20 +68,26 @@ def __init__(self, load_variables=None): with self: with html.Div(style="position:fixed;top:0;width: 500px;"): - with v3.VCardActions(key="variables_selected.length"): - for name, color in [ - ("surfaces", "success"), - ("interfaces", "info"), - ("midpoints", "warning"), - ]: - v3.VChip( - js.var_title(name), - color=color, - v_show=js.var_count(name), - size="small", - closable=True, - click_close=js.var_remove(name), - ) + with v3.VCardActions( + key="variables_selected.length", + classes="flex-wrap", + style="overflow-y: auto; max-height: 100px;", + ): + v3.VChip( + "{{ variables_selected.filter(id => variables_listing.find(v => v.id === id)?.type === vtype.name).length }} {{ vtype.name }}", + v_for="(vtype, idx) in variable_types", + key="idx", + color=("vtype.color",), + v_show=( + "variables_selected.filter(id => variables_listing.find(v => v.id === id)?.type === vtype.name).length", + ), + size="small", + closable=True, + click_close=( + "variables_selected = variables_selected.filter(id => variables_listing.find(v => v.id === id)?.type !== vtype.name)", + ), + classes="ma-1", + ) v3.VSpacer() v3.VBtn( diff --git a/src/e3sm_quickview/components/toolbars.py b/src/e3sm_quickview/components/toolbars.py index 1796a50..2a3362e 100644 --- a/src/e3sm_quickview/components/toolbars.py +++ b/src/e3sm_quickview/components/toolbars.py @@ -5,7 +5,7 @@ from trame.widgets import html, vuetify3 as v3, client -from e3sm_quickview.utils import js, constants +from e3sm_quickview.utils import js DENSITY = { "adjust-layout": "compact", @@ -17,7 +17,7 @@ SIZES = { "adjust-layout": 49, "adjust-databounds": 65, - "select-slice-time": 65, + "select-slice-time": 70, "animation-controls": 49, } @@ -236,13 +236,17 @@ def __init__(self): class DataSelection(html.Div): def __init__(self): style = to_kwargs("select-slice-time") - style["classes"] = style["classes"] # + " d-flex align-center" + # Use style instead of d-flex class to avoid !important override of v-show + # Add background color to match VToolbar appearance + style["style"] = ( + "display: flex; align-items: center; background: rgb(var(--v-theme-surface));" + ) super().__init__(**style) with self: - v3.VIcon("mdi-tune-variant", classes="ml-3 opacity-50") - - with v3.VRow(classes="ma-0 pr-2 flex-wrap", dense=True): + v3.VIcon("mdi-tune-variant", classes="ml-3 mr-2 opacity-50") + + with v3.VRow(classes="ma-0 pr-2 flex-wrap flex-grow-1", dense=True): # Debug: Show animation_tracks array # html.Div("Animation Tracks: {{ JSON.stringify(animation_tracks) }}", classes="col-12") # Each track gets a column (3 per row) @@ -250,10 +254,12 @@ def __init__(self): cols=4, v_for="(track, idx) in animation_tracks", key="idx", - classes="pa-2" + classes="pa-2", ): with client.Getter(name=("track.value",), value_name="t_values"): - with client.Getter(name=("track.value + '_idx'",), value_name="t_idx"): + with client.Getter( + name=("track.value + '_idx'",), value_name="t_idx" + ): with v3.VRow(classes="ma-0 align-center", dense=True): v3.VLabel( "{{track.title}}", @@ -266,9 +272,12 @@ def __init__(self): ) v3.VSlider( model_value=("t_idx",), - update_modelValue=(self.on_update_slider, "[track.value, $event]"), + update_modelValue=( + self.on_update_slider, + "[track.value, $event]", + ), min=0, - #max=100,#("get(track.value).length - 1",), + # max=100,#("get(track.value).length - 1",), max=("t_values.length - 1",), step=1, density="compact", diff --git a/src/e3sm_quickview/pipeline.py b/src/e3sm_quickview/pipeline.py index 807f7bc..f7cb479 100644 --- a/src/e3sm_quickview/pipeline.py +++ b/src/e3sm_quickview/pipeline.py @@ -1,6 +1,5 @@ import fnmatch import json -import numpy as np import os @@ -12,12 +11,11 @@ LegacyVTKReader, ) -from paraview import servermanager as sm -from paraview.vtk.numpy_interface import dataset_adapter as dsa from vtkmodules.vtkCommonCore import vtkLogger from collections import defaultdict + # Define a VTK error observer class ErrorObserver: def __init__(self): @@ -151,7 +149,6 @@ def UpdateSlicing(self, dimension, slice): x = json.dumps(self.slicing) self.data.Slicing = x - def Update(self, data_file, conn_file, force_reload=False): # Check if we need to reload if ( @@ -165,7 +162,7 @@ def Update(self, data_file, conn_file, force_reload=False): self.conn_file = conn_file if self.data is None: - data = EAMSliceDataReader( + data = EAMSliceDataReader( # noqa: F821 registrationName="AtmosReader", ConnectivityFile=conn_file, DataFile=data_file, @@ -180,7 +177,6 @@ def Update(self, data_file, conn_file, force_reload=False): for dim in self.dimmeta.keys(): self.slicing[dim] = 0 - print(self.slicing) self.observer.clear() else: self.data.DataFile = data_file @@ -196,7 +192,7 @@ def Update(self, data_file, conn_file, force_reload=False): "Please check if the data and connectivity files exist " "and are compatible" ) - + # Ensure TimestepValues is always a list timestep_values = self.data.TimestepValues if isinstance(timestep_values, (list, tuple)): @@ -213,7 +209,7 @@ def Update(self, data_file, conn_file, force_reload=False): ) # Step 1: Extract and transform atmospheric data - atmos_extract = EAMTransformAndExtract( + atmos_extract = EAMTransformAndExtract( # noqa: F821 registrationName="AtmosExtract", Input=self.data ) atmos_extract.LongitudeRange = [-180.0, 180.0] @@ -222,7 +218,7 @@ def Update(self, data_file, conn_file, force_reload=False): self.extents = atmos_extract.GetDataInformation().GetBounds() # Step 2: Apply map projection to atmospheric data - atmos_proj = EAMProject( + atmos_proj = EAMProject( # noqa: F821 registrationName="AtmosProj", Input=OutputPort(atmos_extract, 0) ) atmos_proj.Projection = self.projection @@ -247,14 +243,14 @@ def Update(self, data_file, conn_file, force_reload=False): self.globe = cont_contour # Step 4: Extract and transform continent data - cont_extract = EAMTransformAndExtract( + cont_extract = EAMTransformAndExtract( # noqa: F821 registrationName="ContExtract", Input=self.globe ) cont_extract.LongitudeRange = [-180.0, 180.0] cont_extract.LatitudeRange = [-90.0, 90.0] # Step 5: Apply map projection to continents - cont_proj = EAMProject( + cont_proj = EAMProject( # noqa: F821 registrationName="ContProj", Input=OutputPort(cont_extract, 0) ) cont_proj.Projection = self.projection @@ -262,11 +258,11 @@ def Update(self, data_file, conn_file, force_reload=False): cont_proj.UpdatePipeline() # Step 6: Generate lat/lon grid lines - grid_gen = EAMGridLines(registrationName="GridGen") + grid_gen = EAMGridLines(registrationName="GridGen") # noqa: F821 grid_gen.UpdatePipeline() # Step 7: Apply map projection to grid lines - grid_proj = EAMProject( + grid_proj = EAMProject( # noqa: F821 registrationName="GridProj", Input=OutputPort(grid_gen, 0) ) grid_proj.Projection = self.projection @@ -289,10 +285,10 @@ def Update(self, data_file, conn_file, force_reload=False): return self.valid def LoadVariables(self, vars): - print(f"Gonna Load {vars}") if not self.valid: return self.data.Variables = vars + if __name__ == "__main__": e = EAMVisSource() diff --git a/src/e3sm_quickview/plugins/eam_reader.py b/src/e3sm_quickview/plugins/eam_reader.py index 2e8604e..ab7fa60 100644 --- a/src/e3sm_quickview/plugins/eam_reader.py +++ b/src/e3sm_quickview/plugins/eam_reader.py @@ -1,9 +1,9 @@ -from paraview.util.vtkAlgorithm import * +from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase +from paraview.util.vtkAlgorithm import smproxy, smproperty from vtkmodules.numpy_interface import dataset_adapter as dsa from vtkmodules.vtkCommonCore import vtkPoints, vtkDataArraySelection from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkCellArray from vtkmodules.util import vtkConstants, numpy_support -from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase from paraview import print_error, print_warning try: @@ -32,80 +32,81 @@ class EAMConstants: PS0 = float(1e5) -from enum import Enum # noqa: E402 - class DimMeta: """Simple class to store dimension metadata.""" + def __init__(self, name, size, data=None): self.name = name self.size = size self.long_name = None self.units = None self.data = data # Store the actual dimension coordinate values - + def __getitem__(self, key): """Dict-like access to attributes.""" return getattr(self, key, None) - + def __setitem__(self, key, value): """Dict-like setting of attributes.""" setattr(self, key, value) - + def update_from_variable(self, var_info): """Update metadata from netCDF variable info - only long_name and units.""" try: - self.long_name = var_info.getncattr('long_name') + self.long_name = var_info.getncattr("long_name") except AttributeError: pass - + try: - self.units = var_info.getncattr('units') + self.units = var_info.getncattr("units") except AttributeError: pass - + def __repr__(self): return f"DimMeta(name='{self.name}', size={self.size}, long_name='{self.long_name}')" + class VarMeta: """Simple class to store variable metadata.""" + def __init__(self, name, info, horizontal_dim=None): self.name = name self.dimensions = info.dimensions # Store dimensions for slicing self.fillval = np.nan self.long_name = None - + # Extract metadata from info self._extract_metadata(info) - + def _extract_metadata(self, info): """Helper to extract metadata attributes from netCDF variable.""" # Try to get fill value from either _FillValue or missing_value - for fillattr in ['_FillValue', 'missing_value']: + for fillattr in ["_FillValue", "missing_value"]: value = self._get_attr(info, fillattr) if value is not None: self.fillval = value break - + # Get long_name if available - long_name = self._get_attr(info, 'long_name') + long_name = self._get_attr(info, "long_name") if long_name is not None: self.long_name = long_name - + def _get_attr(self, info, attr_name): """Safely get an attribute from netCDF variable info.""" try: return info.getncattr(attr_name) except (AttributeError, KeyError): return None - + def __getitem__(self, key): """Dict-like access to attributes.""" return getattr(self, key, None) - + def __setitem__(self, key, value): """Dict-like setting of attributes.""" setattr(self, key, value) - + def __repr__(self): return f"VarMeta(name='{self.name}', dimensions={self.dimensions})" @@ -178,9 +179,6 @@ def _markmodified(*args, **kwars): return _markmodified -import traceback # noqa: E402 - - @smproxy.reader( name="EAMSliceSource", label="EAM Slice Data Reader", @@ -244,7 +242,6 @@ def __init__(self): self._timeSteps = [] # Dictionary to store dimension slices self._slices = {} - # vtkDataArraySelection to allow users choice for fields # to fetch from the netCDF data set @@ -270,10 +267,10 @@ def __init__(self): self._cached_ncells2D = None # Special variable caching - #self._cached_lev = None - #self._cached_ilev = None + # self._cached_lev = None + # self._cached_ilev = None self._cached_area = None - + # Dynamic dimension detection self._horizontal_dim = None self._data_horizontal_dim = None # Matched in data file @@ -338,23 +335,25 @@ def _identify_horizontal_dimension(self, meshdata, vardata): """Identify horizontal dimension from connectivity and match with data file.""" if self._horizontal_dim and self._data_horizontal_dim: return # Already identified - + # Get first dimension from connectivity file conn_dims = list(meshdata.dimensions.keys()) if not conn_dims: print_error("No dimensions found in connectivity file") return - + self._horizontal_dim = conn_dims[0] conn_size = meshdata.dimensions[self._horizontal_dim].size - + # Match dimension in data file by size for dim_name, dim_obj in vardata.dimensions.items(): if dim_obj.size == conn_size: self._data_horizontal_dim = dim_name return - - print_error(f"Could not match horizontal dimension size {conn_size} in data file") + + print_error( + f"Could not match horizontal dimension size {conn_size} in data file" + ) def _clear_geometry_cache(self): """Clear cached geometry data.""" @@ -405,10 +404,10 @@ def _load_variable(self, vardata, varmeta, timeInd): for dim in varmeta.dimensions: if dim == self._data_horizontal_dim: continue - #elif dim == "time": + # elif dim == "time": # # Use timeInd for time dimension # slice_tuple.append(timeInd) - elif hasattr(self, '_slices') and dim in self._slices: + elif hasattr(self, "_slices") and dim in self._slices: # Use user-specified slice for this dimension slice_tuple.append(self._slices[dim]) else: @@ -439,12 +438,12 @@ def _build_geometry(self, meshdata): dims = meshdata.dimensions mvars = np.array(list(meshdata.variables.keys())) - + # Use the identified horizontal dimension if not self._horizontal_dim: print_error("Horizontal dimension not identified in connectivity file") return - + ncells2D = dims[self._horizontal_dim].size self._cached_ncells2D = ncells2D @@ -491,58 +490,60 @@ def _build_geometry(self, meshdata): def _populate_variable_metadata(self): if self._DataFileName is None or self._ConnFileName is None: return - + meshdata = self._get_mesh_dataset() vardata = self._get_var_dataset() - + # Identify horizontal dimensions first self._identify_horizontal_dimension(meshdata, vardata) - + if not self._data_horizontal_dim: print_error("Could not detect horizontal dimension in data file") return - + # Clear existing selection arrays BEFORE adding new ones self._variable_selection.RemoveAllArrays() - - # Collect all unique dimensions that need slicing + + # First pass: collect dimensions used by valid variables all_dimensions = set() - - # Store dimension metadata - self._dimensions.clear() - for dim_name, dim_obj in vardata.dimensions.items(): - # Create DimMeta object - dim_meta = DimMeta(dim_name, dim_obj.size) - - # Try to load dimension coordinate variable if it exists - if dim_name in vardata.variables: - dim_var = vardata.variables[dim_name] - # Load dimension data - try: - dim_meta.data = vardata[dim_name][:].data - except: - pass - # Update metadata from variable attributes - dim_meta.update_from_variable(dim_var) - - self._dimensions[dim_name] = dim_meta - for name, info in vardata.variables.items(): dims = set(info.dimensions) - if not self._data_horizontal_dim in dims: + if self._data_horizontal_dim not in dims: continue varmeta = VarMeta(name, info, self._data_horizontal_dim) if len(dims) == 1 and "area" in name: - self._areavar = varmeta + self._areavar = varmeta if len(dims) > 1: all_dimensions.update(dims) - self._variables[name] = varmeta # Store by name as key + self._variables[name] = varmeta self._variable_selection.AddArray(name) - - # Initialize slices for all dimensions at once - for dim in all_dimensions: - if dim not in self._slices: # Only set if not already set + + # Remove the horizontal dimension from sliceable dimensions + all_dimensions.discard(self._data_horizontal_dim) + + # Second pass: only populate _dimensions for dimensions that are: + # 1. Used by at least one valid variable + # 2. Have arity > 1 + self._dimensions.clear() + for dim_name in all_dimensions: + if dim_name in vardata.dimensions: + dim_obj = vardata.dimensions[dim_name] + if dim_obj.size > 1: + dim_meta = DimMeta(dim_name, dim_obj.size) + if dim_name in vardata.variables: + dim_var = vardata.variables[dim_name] + try: + dim_meta.data = vardata[dim_name][:].data + except Exception: + pass + dim_meta.update_from_variable(dim_var) + self._dimensions[dim_name] = dim_meta + + # Initialize slices for relevant dimensions + for dim in self._dimensions: + if dim not in self._slices: self._slices[dim] = 0 + self._variable_selection.DisableAllArrays() # Clear old timestamps before adding new ones @@ -588,18 +589,19 @@ def SetConnFileName(self, fname): def SetSlicing(self, slice_str): # Parse JSON string containing dimension slices and update self._slices # Initialize _slices if not already done - if not hasattr(self, '_slices'): + if not hasattr(self, "_slices"): self._slices = {} - + # Initialize dimensions if not already done - if not hasattr(self, '_dimensions'): + if not hasattr(self, "_dimensions"): self._dimensions = {} - + if slice_str and slice_str.strip(): # Check for non-empty string try: import json + slice_dict = json.loads(slice_str) - + # Validate and update slices for provided dimensions invalid_slices = [] for dim, slice_val in slice_dict.items(): @@ -615,24 +617,26 @@ def SetSlicing(self, slice_str): if dim_meta.long_name: dim_display += f" ({dim_meta.long_name})" invalid_slices.append( - f"{dim_display}={slice_val} (valid range: 0-{dim_size-1})" + f"{dim_display}={slice_val} (valid range: 0-{dim_size - 1})" ) else: self._slices[dim] = slice_val else: - print_error(f"Slice value for '{dim}' must be an integer, got {type(slice_val).__name__}") + print_error( + f"Slice value for '{dim}' must be an integer, got {type(slice_val).__name__}" + ) else: # Store the slice anyway for dimensions we haven't seen yet # (might be populated later) self._slices[dim] = slice_val if self._dimensions: # Only warn if we have dimension info print_warning(f"Dimension '{dim}' not found in data file") - + if invalid_slices: print_error(f"Invalid slice indices: {', '.join(invalid_slices)}") else: self.Modified() - + except (json.JSONDecodeError, ValueError) as e: print_error(f"Invalid JSON for slicing: {e}") except Exception as e: @@ -645,7 +649,7 @@ def SetCalculateAverages(self, calcavg): def GetVariables(self): return self._variables - + def GetDimensions(self): return self._dimensions @@ -724,14 +728,14 @@ def RequestData(self, request, inInfo, outInfo): # Ensure dimensions are identified self._identify_horizontal_dimension(meshdata, vardata) - + if not self._horizontal_dim or not self._data_horizontal_dim: print_error("Could not identify required dimensions from files") return 0 - + # Build geometry if not cached self._build_geometry(meshdata) - + if self._cached_points is None: print_error("Could not build geometry from connectivity file") return 0 @@ -752,9 +756,6 @@ def RequestData(self, request, inInfo, outInfo): self._dirty = False - # Use cached ncells2D - ncells2D = self._cached_ncells2D - # Needed to drop arrays from cached VTK Object to_remove = set() last_num_arrays = output_mesh.CellData.GetNumberOfArrays() @@ -765,13 +766,10 @@ def RequestData(self, request, inInfo, outInfo): if self._variable_selection.ArrayIsEnabled(name): if output_mesh.CellData.HasArray(name): to_remove.remove(name) - if ( - not output_mesh.CellData.HasArray(name) - or self._surface_update - ): + if not output_mesh.CellData.HasArray(name) or self._surface_update: data = self._load_variable(vardata, varmeta, timeInd) output_mesh.CellData.append(data, name) - + area_var_name = "area" if self._areavar and not output_mesh.CellData.HasArray(area_var_name): data = self._get_cached_area(vardata) diff --git a/src/e3sm_quickview/utils/colors.py b/src/e3sm_quickview/utils/colors.py new file mode 100644 index 0000000..b302ed7 --- /dev/null +++ b/src/e3sm_quickview/utils/colors.py @@ -0,0 +1,18 @@ +TYPE_COLORS = [ + "success", + "info", + "warning", + "error", + "purple", + "cyan", + "teal", + "indigo", + "pink", + "amber", + "lime", + "deep-purple", +] + + +def get_type_color(index: int) -> str: + return TYPE_COLORS[index % len(TYPE_COLORS)] diff --git a/src/e3sm_quickview/utils/js.py b/src/e3sm_quickview/utils/js.py index 441936e..7c2e51e 100644 --- a/src/e3sm_quickview/utils/js.py +++ b/src/e3sm_quickview/utils/js.py @@ -1,16 +1,2 @@ -def var_count(name): - return f"variables_selected.filter((v) => v[0] === '{name[0]}').length" - - -def var_remove(name): - return ( - f"variables_selected = variables_selected.filter((v) => v[0] !== '{name[0]}')" - ) - - -def var_title(name): - return " ".join(["{{", var_count(name), "}}", name.capitalize()]) - - def is_active(name): return f"active_tools.includes('{name}')" diff --git a/src/e3sm_quickview/view_manager.py b/src/e3sm_quickview/view_manager.py index 338accc..6a3bd3a 100644 --- a/src/e3sm_quickview/view_manager.py +++ b/src/e3sm_quickview/view_manager.py @@ -9,7 +9,7 @@ from paraview import simple -from e3sm_quickview.components import view +from e3sm_quickview.components import view as tview from e3sm_quickview.utils.color import get_cached_colorbar_image, COLORBAR_CACHE from e3sm_quickview.presets import COLOR_BLIND_SAFE @@ -41,21 +41,6 @@ def auto_size_to_col(size): "flow": None, } -TYPE_COLORS = [ - "success", - "info", - "warning", - "error", - "purple", - "cyan", - "teal", - "indigo", - "pink", - "amber", - "lime", - "deep-purple", -] - def lut_name(element): return element.get("name").lower() @@ -244,7 +229,7 @@ def _build_ui(self): dense=True, classes="ma-0 pa-0 bg-black opacity-90 d-flex align-center", ): - view.create_size_menu(self.name, self.config) + tview.create_size_menu(self.name, self.config) with html.Div( self.variable_name, classes="text-subtitle-2 pr-2", @@ -311,7 +296,7 @@ def _build_ui(self): self.view, interactive_ratio=1, ctx_name=self.name ) - view.create_bottom_bar(self.config, self.update_color_preset) + tview.create_bottom_bar(self.config, self.update_color_preset) class ViewManager(TrameComponent): @@ -431,21 +416,25 @@ def build_auto_layout(self, variables=None): # Create UI based on variables self.state.swap_groups = {} + # Build a lookup from type name to color from state.variable_types + type_to_color = {vt["name"]: vt["color"] for vt in self.state.variable_types} with DivLayout(self.server, template_name="auto_layout") as self.ui: if self.state.layout_grouped: with v3.VCol(classes="pa-1"): - for idx, var_type in enumerate(variables.keys()): + for var_type in variables.keys(): var_names = variables[var_type] total_size = len(var_names) if total_size == 0: continue + # Look up color from variable_types to match chip colors + border_color = type_to_color.get(str(var_type), "primary") with v3.VAlert( border="start", classes="pr-1 py-1 pl-3 mb-1", variant="flat", - border_color=TYPE_COLORS[idx % len(TYPE_COLORS)], + border_color=border_color, ): with v3.VRow(dense=True): for name in var_names: