Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions SalishSeaTools/salishsea_tools/evaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,14 @@ def _match_model_to_data(
mesh_data["mask"],
model_file_hours_res,
),
"salinity": lambda: _salinity_match(
data,
file_lists,
file_types,
file_type_model_vars,
mesh_data["mask"],
model_var_file_types,
),
"vvlZ": lambda: _interpvvlZ(
data,
file_lists,
Expand Down Expand Up @@ -803,6 +811,100 @@ def _binmatch(
return data


def _salinity_match(data, flist, ftypes, filemap_r, omask, fdict):
    """Match model values to observations at the level of closest salinity.

    For every row of ``data`` the model file whose ``[start, end)`` window
    contains the observation time is located, the salinity profile at the
    observation's ``(j, i)`` grid point is read (nearest ``time_counter``),
    and the index ``k`` along the profile's first axis whose salinity is
    closest to the observed value is selected. Every variable listed in
    ``filemap_r`` is then read at ``[k, j, i]`` and stored in ``data`` as
    ``mod_<variable>``. Rows for which no level can be chosen (no covering
    file, all-NaN profile, NaN observation, read error) get NaN everywhere.

    :param data: Observations; must provide the columns ``"dtUTC"``,
                 ``"Sal (g kg-1)"``, ``"j"`` and ``"i"``. Modified in place
                 (``mod_*`` and ``matched_salinity`` columns are added).
    :param flist: Mapping of file type to a three-column table of
                  (file name, window start, window end), in that column
                  order. The tables are read positionally and are NOT
                  modified.
    :param ftypes: Iterable of file types to process.
    :param filemap_r: Mapping of file type to the variable names it holds.
                      The first name containing ``"sal"`` (case-insensitive)
                      is used as the salinity variable.
    :param omask: Unused; kept for signature compatibility with the caller.
    :param fdict: Unused; kept for signature compatibility with the caller.

    :returns: ``data`` with the added columns.

    :raises ValueError: If no variable name in ``filemap_r`` contains
                        ``"sal"``.
    """
    import numpy as np

    try:
        from tqdm import tqdm
    except ImportError:
        # The progress bar is cosmetic; iterate plainly when it is missing.
        def tqdm(iterable, **kwargs):
            return iterable

    # Find which file type contains salinity.
    salinity_var, salinity_ftype = None, None
    for ftype in ftypes:
        salinity_var = next(
            (var for var in filemap_r[ftype] if "sal" in var.lower()), None
        )
        if salinity_var is not None:
            salinity_ftype = ftype
            break
    if salinity_var is None:
        raise ValueError("No salinity variable found in filemap_r.")

    # Snapshot each file table once as (fname, start, end) tuples. Reading
    # the columns positionally avoids renaming the caller's DataFrame columns
    # and avoids re-scanning the DataFrame for every observation.
    file_windows = {
        ft: list(flist[ft].itertuples(index=False, name=None)) for ft in ftypes
    }

    def covering_file(ftype, when):
        """Return the first file of ``ftype`` whose window contains ``when``."""
        for fname, start, end in file_windows[ftype]:
            if start <= when < end:
                return fname
        return None

    # Cache of opened datasets, keyed by file name.
    dataset_cache = {}

    def open_cached(fname):
        """Open ``fname`` once and reuse the dataset on later requests."""
        if fname not in dataset_cache:
            # Imported on first use: only needed once a file must be read.
            import xarray as xr

            dataset_cache[fname] = xr.open_dataset(fname)
        return dataset_cache[fname]

    matched_salinities = []

    try:
        for idx, row in tqdm(data.iterrows(), total=len(data)):
            obs_time = row["dtUTC"]
            obs_sal = row["Sal (g kg-1)"]
            j, i = int(row["j"]), int(row["i"])
            k = None
            matched = np.nan

            # Step 1: find the level whose model salinity is closest.
            fname = covering_file(salinity_ftype, obs_time)
            if fname is not None:
                ds = open_cached(fname)
                try:
                    # Select time (nearest if needed), then slice j, i.
                    sel = ds[salinity_var].sel(
                        time_counter=obs_time, method="nearest"
                    )
                    sal_profile = sel[:, j, i].values  # depth profile
                    sal_diff = np.abs(sal_profile - obs_sal)
                    # All-NaN differences (empty water column or missing
                    # observation) leave nothing to match against.
                    if not np.isnan(sal_diff).all():
                        k = int(np.nanargmin(sal_diff))
                        matched = sal_profile[k]
                except Exception as e:
                    # Best effort: a bad read yields NaN for this row only.
                    print(f"Error reading salinity at {fname}: {e}")
                    k = None
                    matched = np.nan

            # Exactly one entry per observation, whatever happened above, so
            # the list always lines up with the rows of ``data``.
            matched_salinities.append(matched)

            if k is None:
                # Fill all variables with NaN.
                for ft in ftypes:
                    for var in filemap_r[ft]:
                        data.at[idx, f"mod_{var}"] = np.nan
                continue

            # Step 2: grab each variable at (time, k, j, i).
            for ft in ftypes:
                fname = covering_file(ft, obs_time)
                if fname is None:
                    continue
                ds = open_cached(fname)
                for var in filemap_r[ft]:
                    try:
                        sel = ds[var].sel(time_counter=obs_time, method="nearest")
                        val = sel[k, j, i].item()
                    except Exception:
                        # Best effort: variables that cannot be read at this
                        # point are recorded as missing.
                        val = np.nan
                    data.at[idx, f"mod_{var}"] = val
    finally:
        # Close every dataset, even if processing failed part-way through.
        for ds in dataset_cache.values():
            ds.close()

    data["matched_salinity"] = matched_salinities
    return data


def _vvlBin(
data,
flist,
Expand Down
194 changes: 194 additions & 0 deletions SalishSeaTools/tests/test_evaltools_salinity_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright 2013 – present by the SalishSeaCast contributors
# and The University of British Columbia
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for evaltools module _salinity_match() function."""

import numpy
import pandas
import pytest
import xarray

from salishsea_tools import evaltools


class TestSalinityMatch:
    """Tests for evaltools._salinity_match() using stubbed model datasets.

    The fixtures describe a single observation and a single ``grid_T`` model
    file whose time window contains that observation. Tests that reach file
    access replace ``open_dataset`` with a stub that returns a fixed
    three-level salinity profile.
    """

    @staticmethod
    @pytest.fixture
    def sample_data():
        # One observation with salinity 30.0 at grid point j=10, i=20.
        columns = {
            "dtUTC": pandas.to_datetime(["2025-09-15 12:00"]),
            "Sal (g kg-1)": [30.0],
            "j": [10],
            "i": [20],
        }
        return pandas.DataFrame(columns)

    @staticmethod
    @pytest.fixture
    def sample_flist():
        # One file covering the whole of 2025-09-15.
        grid_t_files = pandas.DataFrame(
            {
                "fname": ["file1.nc"],
                "start": [pandas.Timestamp("2025-09-15 00:00")],
                "end": [pandas.Timestamp("2025-09-16 00:00")],
            }
        )
        return {"grid_T": grid_t_files}

    @staticmethod
    @pytest.fixture
    def sample_ftypes():
        return ["grid_T"]

    @staticmethod
    @pytest.fixture
    def sample_filemap_r():
        return {"grid_T": ["salinity"]}

    @staticmethod
    @pytest.fixture
    def sample_omask():
        return None

    @staticmethod
    @pytest.fixture
    def sample_fdict():
        return {}

    @staticmethod
    def _stub_open_dataset(profile_values):
        """Return an ``open_dataset`` stand-in serving ``profile_values``.

        Whatever variable, time, or index is requested from the stubbed
        dataset, the result is a 1-D ``deptht`` profile holding
        ``profile_values``.
        """

        class MockSalinityProfile:
            def __getitem__(self, *indices):
                return xarray.DataArray(
                    data=numpy.array(profile_values),
                    coords={
                        "deptht": numpy.array([0.5, 1.5, 2.5]),
                    },
                    dims=("deptht",),
                )

        class MockDataArray:
            def sel(self, time_counter, method):
                return MockSalinityProfile()

        class MockDataset:
            def __getitem__(self, key):
                return MockDataArray()

            def close(self):
                pass

        def mock_open_dataset(path):
            return MockDataset()

        return mock_open_dataset

    def test_no_salinity_variable(
        self,
        sample_data,
        sample_flist,
        sample_ftypes,
        sample_filemap_r,
        sample_omask,
        sample_fdict,
    ):
        # A variable map without any salinity-like name must be rejected.
        filemap_r_no_sal = {"grid_T": ["temperature"]}
        expected_message = "No salinity variable found in filemap_r."

        with pytest.raises(ValueError, match=expected_message):
            evaltools._salinity_match(
                sample_data,
                sample_flist,
                sample_ftypes,
                filemap_r_no_sal,
                sample_omask,
                sample_fdict,
            )

    def test_matching_salinity(
        self,
        sample_data,
        sample_flist,
        sample_ftypes,
        sample_filemap_r,
        sample_omask,
        sample_fdict,
        monkeypatch,
    ):
        # The middle level (30.0) equals the observed salinity exactly.
        monkeypatch.setattr(
            evaltools.xr,
            "open_dataset",
            self._stub_open_dataset([29.5, 30.0, 30.5]),
        )

        result = evaltools._salinity_match(
            sample_data,
            sample_flist,
            sample_ftypes,
            sample_filemap_r,
            sample_omask,
            sample_fdict,
        )

        assert "matched_salinity" in result
        expected = numpy.array([30.0])
        numpy.testing.assert_array_almost_equal(result["matched_salinity"], expected)

    def test_no_matching_salinity(
        self,
        sample_data,
        sample_flist,
        sample_ftypes,
        sample_filemap_r,
        sample_omask,
        sample_fdict,
        monkeypatch,
    ):
        # An all-NaN model profile offers no level to match against.
        monkeypatch.setattr(
            evaltools.xr,
            "open_dataset",
            self._stub_open_dataset([numpy.nan, numpy.nan, numpy.nan]),
        )

        result = evaltools._salinity_match(
            sample_data,
            sample_flist,
            sample_ftypes,
            sample_filemap_r,
            sample_omask,
            sample_fdict,
        )

        assert "matched_salinity" in result
        expected = numpy.array([numpy.nan])
        numpy.testing.assert_array_almost_equal(result["matched_salinity"], expected)
Loading