def _find_salinity_variable(ftypes, filemap_r):
    """Return ``(variable, file type)`` for the first variable whose name contains "sal".

    File types are searched in the order given by ``ftypes`` and, within a
    file type, variables in the order listed in ``filemap_r``.  The name
    comparison is case-insensitive.

    :raises ValueError: if no variable name contains "sal".
    """
    for ftype in ftypes:
        for var in filemap_r[ftype]:
            if "sal" in var.lower():
                return var, ftype
    raise ValueError("No salinity variable found in filemap_r.")


def _salinity_match(data, flist, ftypes, filemap_r, omask, fdict):
    """Match observations to the model depth level with the closest salinity.

    For each row of ``data`` the model salinity profile at the row's
    ``(j, i)`` grid point and the time step nearest ``dtUTC`` is read, and
    the depth index ``k`` that minimises ``|model - observed salinity|`` is
    chosen.  Every variable listed in ``filemap_r`` is then sampled at
    ``(k, j, i)`` and stored in a ``mod_<variable>`` column.

    :param data: Observations; must provide the columns ``dtUTC``,
                 ``Sal (g kg-1)``, ``j`` and ``i``.  Modified in place.
    :param flist: Mapping of file type to a 3-column table of
                  (file name, start time, end time), read positionally.
                  The tables are not modified.
    :param ftypes: File types to process.
    :param filemap_r: Mapping of file type to the variable names it holds.
    :param omask: Unused; accepted so that the signature matches the call
                  made from the matching-method dispatch.
    :param fdict: Unused; accepted for the same reason as ``omask``.

    :returns: ``data`` with the ``mod_<variable>`` columns and a
              ``matched_salinity`` column added.  Rows for which no depth
              could be chosen (no file covers the time, an all-NaN
              profile, or a read error) hold NaN in all of those columns.

    :raises ValueError: if no variable name in ``filemap_r`` contains "sal".
    """
    # Validate the configuration first; it needs none of the imports below.
    salinity_var, salinity_ftype = _find_salinity_variable(ftypes, filemap_r)

    import numpy as np
    import xarray as xr
    from tqdm import tqdm

    # Snapshot each file table once as (fname, start, end) tuples.  Reading
    # the columns positionally avoids renaming the caller's data frames, and
    # building the lists here avoids re-walking the frames for every row.
    file_tables = {
        ft: list(flist[ft].itertuples(index=False, name=None)) for ft in ftypes
    }

    def covering_file(ft, when):
        """Name of the first file of type ``ft`` with start <= when < end."""
        for fname, start, end in file_tables[ft]:
            if start <= when < end:
                return fname
        return None

    dataset_cache = {}

    def dataset(fname):
        """Open ``fname`` once and reuse the handle for later rows."""
        if fname not in dataset_cache:
            dataset_cache[fname] = xr.open_dataset(fname)
        return dataset_cache[fname]

    matched_salinities = []
    try:
        for idx, row in tqdm(data.iterrows(), total=len(data)):
            obs_time = row["dtUTC"]
            obs_sal = row["Sal (g kg-1)"]
            j, i = int(row["j"]), int(row["i"])

            # Step 1: choose the depth index whose salinity is closest to
            # the observation.
            k = None
            matched = np.nan
            fname = covering_file(salinity_ftype, obs_time)
            if fname is not None:
                ds = dataset(fname)
                try:
                    sel = ds[salinity_var].sel(time_counter=obs_time, method="nearest")
                    sal_profile = sel[:, j, i].values  # depth profile
                    if not np.isnan(sal_profile).all():
                        k = np.nanargmin(np.abs(sal_profile - obs_sal))
                        matched = sal_profile[k]
                except Exception as e:
                    # Best effort: a row that cannot be read is reported
                    # and left as NaN rather than aborting the whole match.
                    print(f"Error reading salinity at {fname}: {e}")
                    k = None
                    matched = np.nan
            # Exactly one entry per row, including rows that no file
            # covers, so the list always lines up with ``data``.
            matched_salinities.append(matched)

            if k is None:
                for ft in ftypes:
                    for var in filemap_r[ft]:
                        data.at[idx, f"mod_{var}"] = np.nan
                continue

            # Step 2: sample every variable at (time, k, j, i).
            for ft in ftypes:
                fname = covering_file(ft, obs_time)
                if fname is None:
                    # No file of this type covers the observation time.
                    for var in filemap_r[ft]:
                        data.at[idx, f"mod_{var}"] = np.nan
                    continue
                ds = dataset(fname)
                for var in filemap_r[ft]:
                    try:
                        sel = ds[var].sel(time_counter=obs_time, method="nearest")
                        val = sel[k, j, i].item()
                    except Exception:
                        val = np.nan
                    data.at[idx, f"mod_{var}"] = val
    finally:
        # Release the file handles even when a row raises.
        for ds in dataset_cache.values():
            ds.close()

    data["matched_salinity"] = matched_salinities
    return data
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for evaltools module _salinity_match() function."""

import numpy
import pandas
import pytest
import xarray

from salishsea_tools import evaltools


class TestSalinityMatch:
    """Unit tests for evaltools._salinity_match() using a stubbed open_dataset()."""

    @staticmethod
    def _fake_open_dataset(profile_values):
        """Build a replacement for xarray.open_dataset().

        The replacement returns a dataset whose every variable yields,
        after a ``.sel()`` call and any indexing, a 3-level depth profile
        holding ``profile_values``.
        """

        class FakeSelection:
            def __getitem__(self, key):
                return xarray.DataArray(
                    data=numpy.array(profile_values),
                    coords={
                        "deptht": numpy.array([0.5, 1.5, 2.5]),
                    },
                    dims=("deptht",),
                )

        class FakeVariable:
            def sel(self, time_counter, method):
                return FakeSelection()

        class FakeDataset:
            def __getitem__(self, key):
                return FakeVariable()

            def close(self):
                pass

        def open_dataset(path):
            return FakeDataset()

        return open_dataset

    @staticmethod
    @pytest.fixture
    def sample_data():
        # One observation at noon on 15-Sep-2025 at grid point (j=10, i=20)
        return pandas.DataFrame(
            {
                "dtUTC": pandas.to_datetime(["2025-09-15 12:00"]),
                "Sal (g kg-1)": [30.0],
                "j": [10],
                "i": [20],
            }
        )

    @staticmethod
    @pytest.fixture
    def sample_flist():
        # One grid_T file whose time window contains the observation
        grid_t_files = pandas.DataFrame(
            {
                "fname": ["file1.nc"],
                "start": [pandas.Timestamp("2025-09-15 00:00")],
                "end": [pandas.Timestamp("2025-09-16 00:00")],
            }
        )
        return {"grid_T": grid_t_files}

    @staticmethod
    @pytest.fixture
    def sample_ftypes():
        return ["grid_T"]

    @staticmethod
    @pytest.fixture
    def sample_filemap_r():
        return {"grid_T": ["salinity"]}

    @staticmethod
    @pytest.fixture
    def sample_omask():
        return None

    @staticmethod
    @pytest.fixture
    def sample_fdict():
        return {}

    def test_no_salinity_variable(
        self,
        sample_data,
        sample_flist,
        sample_ftypes,
        sample_filemap_r,
        sample_omask,
        sample_fdict,
    ):
        # A variable map without any name containing "sal"
        filemap_r_no_sal = {"grid_T": ["temperature"]}

        with pytest.raises(
            ValueError, match="No salinity variable found in filemap_r."
        ):
            evaltools._salinity_match(
                sample_data,
                sample_flist,
                sample_ftypes,
                filemap_r_no_sal,
                sample_omask,
                sample_fdict,
            )

    def test_matching_salinity(
        self,
        sample_data,
        sample_flist,
        sample_ftypes,
        sample_filemap_r,
        sample_omask,
        sample_fdict,
        monkeypatch,
    ):
        monkeypatch.setattr(
            evaltools.xr,
            "open_dataset",
            self._fake_open_dataset([29.5, 30.0, 30.5]),
        )

        result = evaltools._salinity_match(
            sample_data,
            sample_flist,
            sample_ftypes,
            sample_filemap_r,
            sample_omask,
            sample_fdict,
        )

        assert "matched_salinity" in result
        expected = numpy.array([30.0])
        numpy.testing.assert_array_almost_equal(result["matched_salinity"], expected)

    def test_no_matching_salinity(
        self,
        sample_data,
        sample_flist,
        sample_ftypes,
        sample_filemap_r,
        sample_omask,
        sample_fdict,
        monkeypatch,
    ):
        monkeypatch.setattr(
            evaltools.xr,
            "open_dataset",
            self._fake_open_dataset([numpy.nan, numpy.nan, numpy.nan]),
        )

        result = evaltools._salinity_match(
            sample_data,
            sample_flist,
            sample_ftypes,
            sample_filemap_r,
            sample_omask,
            sample_fdict,
        )

        assert "matched_salinity" in result
        expected = numpy.array([numpy.nan])
        numpy.testing.assert_array_almost_equal(result["matched_salinity"], expected)