Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 74 additions & 22 deletions src/e3sm_quickview/plugins/eam_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
)
_has_deps = False

dims1 = set(["ncol"])
dims2 = set(["time", "ncol"])
dims3i = set(["time", "ilev", "ncol"])
dims3m = set(["time", "lev", "ncol"])
# Dimensions will be dynamically determined from connectivity and data files


class EAMConstants:
Expand All @@ -46,7 +43,7 @@ class VarType(Enum):


class VarMeta:
def __init__(self, name, info):
def __init__(self, name, info, horizontal_dim=None):
self.name = name
self.type = None
self.transpose = False
Expand All @@ -64,7 +61,8 @@ def __init__(self, name, info):
elif "ilev" in dims:
self.type = VarType._3Di

if "ncol" in dims[1]:
# Use dynamic horizontal dimension
if horizontal_dim and len(dims) > 1 and horizontal_dim in dims[1]:
self.transpose = True


Expand All @@ -74,7 +72,7 @@ def compare(data, arrays, dim):
raise Exception(
"Length of hya_/hyb_ variable does not match the corresponding dimension"
)
for i, array in enumerate(arrays[1:], start=1):
for array in arrays[1:]:
comp = data[array][:].flatten()
if not np.array_equal(ref, comp):
return None
Expand Down Expand Up @@ -251,6 +249,10 @@ def __init__(self):
self._cached_lev = None
self._cached_ilev = None
self._cached_area = None

# Dynamic dimension detection
self._horizontal_dim = None # From connectivity file
self._data_horizontal_dim = None # Matched in data file

def __del__(self):
"""Clean up NetCDF file handles on deletion."""
Expand Down Expand Up @@ -308,6 +310,31 @@ def _clear(self):
self._cached_lev = None
self._cached_ilev = None
self._cached_area = None
# Clear dimension detection
self._horizontal_dim = None
self._data_horizontal_dim = None

def _identify_horizontal_dimension(self, meshdata, vardata):
"""Identify horizontal dimension from connectivity and match with data file."""
if self._horizontal_dim and self._data_horizontal_dim:
return # Already identified

# Get first dimension from connectivity file
conn_dims = list(meshdata.dimensions.keys())
if not conn_dims:
print_error("No dimensions found in connectivity file")
return

self._horizontal_dim = conn_dims[0]
conn_size = meshdata.dimensions[self._horizontal_dim].size

# Match dimension in data file by size
for dim_name, dim_obj in vardata.dimensions.items():
if dim_obj.size == conn_size:
self._data_horizontal_dim = dim_name
return

print_error(f"Could not match horizontal dimension size {conn_size} in data file")

def _clear_geometry_cache(self):
"""Clear cached geometry data."""
Expand Down Expand Up @@ -381,18 +408,14 @@ def _build_geometry(self, meshdata):
return

dims = meshdata.dimensions
mdims = np.array(list(meshdata.dimensions.keys()))
mvars = np.array(list(meshdata.variables.keys()))

# Find ncells2D
ncells2D = dims[
mdims[
np.where(
(np.char.find(mdims, "grid_size") > -1)
| (np.char.find(mdims, "ncol") > -1)
)[0][0]
]
].size

# Use the identified horizontal dimension
if not self._horizontal_dim:
print_error("Horizontal dimension not identified in connectivity file")
return

ncells2D = dims[self._horizontal_dim].size
self._cached_ncells2D = ncells2D

# Find lat/lon dimensions
Expand Down Expand Up @@ -436,20 +459,35 @@ def _build_geometry(self, meshdata):
)

def _populate_variable_metadata(self):
if self._DataFileName is None:
if self._DataFileName is None or self._ConnFileName is None:
return

meshdata = self._get_mesh_dataset()
vardata = self._get_var_dataset()


# Identify horizontal dimensions first
self._identify_horizontal_dimension(meshdata, vardata)

if not self._data_horizontal_dim:
print_error("Could not detect horizontal dimension in data file")
return

# Clear existing selection arrays BEFORE adding new ones
self._surface_selection.RemoveAllArrays()
self._midpoint_selection.RemoveAllArrays()
self._interface_selection.RemoveAllArrays()

# Define dimension sets dynamically based on detected dimension
dims1 = set([self._data_horizontal_dim])
dims2 = set(['time', self._data_horizontal_dim])
dims3m = set(['time', 'lev', self._data_horizontal_dim])
dims3i = set(['time', 'ilev', self._data_horizontal_dim])

for name, info in vardata.variables.items():
dims = set(info.dimensions)
if not (dims == dims1 or dims == dims2 or dims == dims3m or dims == dims3i):
continue
varmeta = VarMeta(name, info)
varmeta = VarMeta(name, info, self._data_horizontal_dim)
if varmeta.type == VarType._1D:
self._info_vars.append(varmeta)
if "area" in name:
Expand Down Expand Up @@ -507,15 +545,18 @@ def SetConnFileName(self, fname):
self._surface_update = True
self._midpoint_update = True
self._interface_update = True
self._clear() # Clear dimension cache
# Close old dataset if filename changed
if self._cached_mesh_filename != fname and self._mesh_dataset is not None:
try:
self._mesh_dataset.close()
except Exception:
pass
self._mesh_dataset = None
# Clear geometry cache when connectivity file changes
self._clear_geometry_cache()
# Re-populate metadata if data file is already set
if self._DataFileName:
self._populate_variable_metadata()
self.Modified()

def SetMiddleLayer(self, lev):
Expand Down Expand Up @@ -616,8 +657,19 @@ def RequestData(self, request, inInfo, outInfo):
meshdata = self._get_mesh_dataset()
vardata = self._get_var_dataset()

# Ensure dimensions are identified
self._identify_horizontal_dimension(meshdata, vardata)

if not self._horizontal_dim or not self._data_horizontal_dim:
print_error("Could not identify required dimensions from files")
return 0

# Build geometry if not cached
self._build_geometry(meshdata)

if self._cached_points is None:
print_error("Could not build geometry from connectivity file")
return 0

output_mesh = dsa.WrapDataObject(self._output)

Expand Down
Loading