diff --git a/src/e3sm_quickview/plugins/eam_reader.py b/src/e3sm_quickview/plugins/eam_reader.py index 275cf10..2e1173e 100644 --- a/src/e3sm_quickview/plugins/eam_reader.py +++ b/src/e3sm_quickview/plugins/eam_reader.py @@ -18,10 +18,7 @@ ) _has_deps = False -dims1 = set(["ncol"]) -dims2 = set(["time", "ncol"]) -dims3i = set(["time", "ilev", "ncol"]) -dims3m = set(["time", "lev", "ncol"]) +# Dimensions will be dynamically determined from connectivity and data files class EAMConstants: @@ -46,7 +43,7 @@ class VarType(Enum): class VarMeta: - def __init__(self, name, info): + def __init__(self, name, info, horizontal_dim=None): self.name = name self.type = None self.transpose = False @@ -64,7 +61,8 @@ def __init__(self, name, info): elif "ilev" in dims: self.type = VarType._3Di - if "ncol" in dims[1]: + # Use dynamic horizontal dimension + if horizontal_dim and len(dims) > 1 and horizontal_dim in dims[1]: self.transpose = True @@ -74,7 +72,7 @@ def compare(data, arrays, dim): raise Exception( "Length of hya_/hyb_ variable does not match the corresponding dimension" ) - for i, array in enumerate(arrays[1:], start=1): + for array in arrays[1:]: comp = data[array][:].flatten() if not np.array_equal(ref, comp): return None @@ -251,6 +249,10 @@ def __init__(self): self._cached_lev = None self._cached_ilev = None self._cached_area = None + + # Dynamic dimension detection + self._horizontal_dim = None # From connectivity file + self._data_horizontal_dim = None # Matched in data file def __del__(self): """Clean up NetCDF file handles on deletion.""" @@ -308,6 +310,31 @@ def _clear(self): self._cached_lev = None self._cached_ilev = None self._cached_area = None + # Clear dimension detection + self._horizontal_dim = None + self._data_horizontal_dim = None + + def _identify_horizontal_dimension(self, meshdata, vardata): + """Identify horizontal dimension from connectivity and match with data file.""" + if self._horizontal_dim and self._data_horizontal_dim: + return # Already identified + + # Get first dimension from connectivity file + conn_dims = list(meshdata.dimensions.keys()) + if not conn_dims: + print_error("No dimensions found in connectivity file") + return + + self._horizontal_dim = conn_dims[0] + conn_size = meshdata.dimensions[self._horizontal_dim].size + + # Match dimension in data file by size + for dim_name, dim_obj in vardata.dimensions.items(): + if dim_obj.size == conn_size: + self._data_horizontal_dim = dim_name + return + + print_error(f"Could not match horizontal dimension size {conn_size} in data file") def _clear_geometry_cache(self): """Clear cached geometry data.""" @@ -381,18 +408,14 @@ def _build_geometry(self, meshdata): return dims = meshdata.dimensions - mdims = np.array(list(meshdata.dimensions.keys())) mvars = np.array(list(meshdata.variables.keys())) - - # Find ncells2D - ncells2D = dims[ - mdims[ - np.where( - (np.char.find(mdims, "grid_size") > -1) - | (np.char.find(mdims, "ncol") > -1) - )[0][0] - ] - ].size + + # Use the identified horizontal dimension + if not self._horizontal_dim: + print_error("Horizontal dimension not identified in connectivity file") + return + + ncells2D = dims[self._horizontal_dim].size self._cached_ncells2D = ncells2D # Find lat/lon dimensions @@ -436,20 +459,35 @@ def _build_geometry(self, meshdata): ) def _populate_variable_metadata(self): - if self._DataFileName is None: + if self._DataFileName is None or self._ConnFileName is None: return + + meshdata = self._get_mesh_dataset() vardata = self._get_var_dataset() - + + # Identify horizontal dimensions first + self._identify_horizontal_dimension(meshdata, vardata) + + if not self._data_horizontal_dim: + print_error("Could not detect horizontal dimension in data file") + return + # Clear existing selection arrays BEFORE adding new ones self._surface_selection.RemoveAllArrays() self._midpoint_selection.RemoveAllArrays() self._interface_selection.RemoveAllArrays() + + # Define dimension sets dynamically based on detected dimension + dims1 = set([self._data_horizontal_dim]) + dims2 = set(['time', self._data_horizontal_dim]) + dims3m = set(['time', 'lev', self._data_horizontal_dim]) + dims3i = set(['time', 'ilev', self._data_horizontal_dim]) for name, info in vardata.variables.items(): dims = set(info.dimensions) if not (dims == dims1 or dims == dims2 or dims == dims3m or dims == dims3i): continue - varmeta = VarMeta(name, info) + varmeta = VarMeta(name, info, self._data_horizontal_dim) if varmeta.type == VarType._1D: self._info_vars.append(varmeta) if "area" in name: @@ -507,6 +545,7 @@ def SetConnFileName(self, fname): self._surface_update = True self._midpoint_update = True self._interface_update = True + self._clear() # Clear dimension cache # Close old dataset if filename changed if self._cached_mesh_filename != fname and self._mesh_dataset is not None: try: @@ -514,8 +553,10 @@ def SetConnFileName(self, fname): except Exception: pass self._mesh_dataset = None - # Clear geometry cache when connectivity file changes self._clear_geometry_cache() + # Re-populate metadata if data file is already set + if self._DataFileName: + self._populate_variable_metadata() self.Modified() def SetMiddleLayer(self, lev): @@ -616,8 +657,19 @@ def RequestData(self, request, inInfo, outInfo): meshdata = self._get_mesh_dataset() vardata = self._get_var_dataset() + # Ensure dimensions are identified + self._identify_horizontal_dimension(meshdata, vardata) + + if not self._horizontal_dim or not self._data_horizontal_dim: + print_error("Could not identify required dimensions from files") + return 0 + # Build geometry if not cached self._build_geometry(meshdata) + + if self._cached_points is None: + print_error("Could not build geometry from connectivity file") + return 0 output_mesh = dsa.WrapDataObject(self._output)