From 229a69f1dfd17148378f73c3c3f16315f2fe62be Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Thu, 13 Nov 2025 12:16:07 +0100 Subject: [PATCH 01/22] Working on adding the option to import several variables with model objects. --- wavy/grid_readers.py | 102 +++++++++++++++++++++++++++++++++--------- wavy/model_module.py | 54 ++++++++++++---------- wavy/model_readers.py | 44 ++++++++++++------ 3 files changed, 144 insertions(+), 56 deletions(-) diff --git a/wavy/grid_readers.py b/wavy/grid_readers.py index adaeead..b585961 100755 --- a/wavy/grid_readers.py +++ b/wavy/grid_readers.py @@ -40,13 +40,15 @@ def read_ww3_unstructured_to_grid(**kwargs): pathlst = kwargs.get('pathlst') nID = kwargs.get('nID') varalias = kwargs.get('varalias') + if isinstance(varalias, str): + varalias = [varalias] sd = kwargs.get('sd') ed = kwargs.get('ed') meta = kwargs.get('meta') # get meta data - varstr = get_filevarname(varalias, variable_def, - model_dict[nID], meta) + varstr = [get_filevarname(v, variable_def, + model_dict[nID], meta) for v in varalias] lonstr = get_filevarname('lons', variable_def, model_dict[nID], meta) latstr = get_filevarname('lats', variable_def, @@ -72,12 +74,12 @@ def read_ww3_unstructured_to_grid(**kwargs): if (dt >= sd and dt <= ed): # varnames = (varname, lonstr, latstr, timestr) if len(times) < 2: - var = (ds[varstr].data, + var = (*tuple(ds[v].data for v in varstr), ds[lonstr].data, ds[latstr].data, t.reshape((1,))) else: - var = (ds[varstr].sel({timestr: t}).data, + var = (*tuple(ds[v].sel({timestr: t}).data for v in varstr), ds[lonstr].data, ds[latstr].data, t.reshape((1,))) @@ -99,29 +101,54 @@ def read_ww3_unstructured_to_grid(**kwargs): def get_gridded_dataset(var, t, **kwargs): varstr = kwargs.get('varalias') + if isinstance(varstr, str): + varstr = [varstr] + len_varstr = len(varstr) + + var_means_list = [] + + print(var) if kwargs.get('interp') is None: print(" Apply gridding, no interpolation") # grid data - gridvar, lon_grid, lat_grid = \ - 
grid_point_cloud_ds(var[0], var[1], var[2], t, **kwargs) - - var_means = gridvar['mor'] - # transpose dims - var_means = np.transpose(var_means) - field_shape = (list(var_means.shape[::-1]) + [1])[::-1] - var_means = var_means.reshape(field_shape) + + for i in range(len_varstr): + gridvar, lon_grid, lat_grid = \ + grid_point_cloud_ds(var[i], + var[len_varstr], + var[len_varstr+1], + t, + **kwargs) + + var_means = gridvar['mor'] + # transpose dims + var_means = np.transpose(var_means) + field_shape = (list(var_means.shape[::-1]) + [1])[::-1] + var_means = var_means.reshape(field_shape) + + var_means_list.append(var_means) + else: print(" Apply gridding with interpolation") - var_means, lon_grid, lat_grid = \ - grid_point_cloud_interp_ds( - var[0], var[1], var[2], **kwargs) - # transpose dims - field_shape = (list(var_means.shape[::-1]) + [1])[::-1] - var_means = var_means.reshape(field_shape) + for i in range(len_varstr): + + var_means, lon_grid, lat_grid = \ + grid_point_cloud_interp_ds( + var[i], + var[len_varstr], + var[len_varstr+1], + **kwargs) + + # transpose dims + field_shape = (list(var_means.shape[::-1]) + [1])[::-1] + var_means = var_means.reshape(field_shape) + + var_means_list.append(var_means) + # create xr.dataset - ds = build_xr_ds_grid( - var_means, + ds = build_xr_ds_grid_multivar( + var_means_list, np.unique(lon_grid), np.unique(lat_grid), t.reshape((1,)), varstr=varstr) @@ -158,6 +185,41 @@ def build_xr_ds_grid(var_means, lon_grid, lat_grid, t, **kwargs): ) return ds + +def build_xr_ds_grid_multivar(var_means_list, lon_grid, lat_grid, t, **kwargs): + print(" building xarray dataset from grid") + varstr = kwargs.get('varstr') + if isinstance(varstr, str): + varstr = [varstr] + + ds = xr.Dataset({ + **{varstr[i]: xr.DataArray( + data=var_means_list[i], + dims=['time', 'latitude', 'longitude'], + coords={'latitude': lat_grid, + 'longitude': lon_grid, + 'time': t}, + attrs=variable_def[varstr[i]], + ) for i in range(len(var_means_list))}, + 'lons': 
xr.DataArray( + data=lon_grid, + dims=['longitude'], + coords={'longitude': lon_grid}, + attrs=variable_def['lons'], + ), + 'lats': xr.DataArray( + data=lat_grid, + dims=['latitude'], + coords={'latitude': lat_grid}, + attrs=variable_def['lats'], + ), + }, + attrs={'title': 'wavy dataset'} + ) + return ds + + + def build_xr_ds_grid_2D(var_means, lon_grid, lat_grid, t, **kwargs): print(" building xarray dataset from grid") varstr = kwargs.get('varstr') diff --git a/wavy/model_module.py b/wavy/model_module.py index dd61c18..d5ad029 100755 --- a/wavy/model_module.py +++ b/wavy/model_module.py @@ -152,9 +152,12 @@ def __init__(self, **kwargs): # add other class object variables self.nID = kwargs.get('nID') self.model = kwargs.get('model', self.nID) - self.varalias = kwargs.get('varalias', 'Hs') - self.units = variable_def[self.varalias].get('units') - self.stdvarname = variable_def[self.varalias].get('standard_name') + self.varalias = kwargs.get('varalias', ['Hs']) + if isinstance(self.varalias, str): + self.varalias = [self.varalias] + self.units = [variable_def[v].get('units') for v in self.varalias] + self.stdvarname = [variable_def[v].get('standard_name') for v in\ + self.varalias] self.distlim = kwargs.get('distlim', 6) self.filter = kwargs.get('filter', False) self.region = kwargs.get('region', 'global') @@ -540,27 +543,29 @@ def _enforce_longitude_format(self): def _enforce_meteorologic_convention(self): print(' enforcing meteorological convention') - if ('convention' in vars(self.cfg)['misc'].keys() and - vars(self.cfg)['misc']['convention'] == 'oceanographic'): - print('Convert from oceanographic to meteorologic convention') - self.vars[self.varalias] =\ - convert_meteorologic_oceanographic(self.vars[self.varalias]) - elif 'to_direction' in self.vars[self.varalias].attrs['standard_name']: - print('Convert from oceanographic to meteorologic convention') - self.vars[self.varalias] =\ - convert_meteorologic_oceanographic(self.vars[self.varalias]) + for v in 
self.varalias: + if ('convention' in vars(self.cfg)['misc'].keys() and + vars(self.cfg)['misc']['convention'] == 'oceanographic'): + print('Convert from oceanographic to meteorologic convention') + + self.vars[v] = convert_meteorologic_oceanographic(self.vars[v]) + + elif 'to_direction' in self.vars[v].attrs['standard_name']: + print('Convert from oceanographic to meteorologic convention') + self.vars[v] = convert_meteorologic_oceanographic(self.vars[v]) return self def _change_varname_to_aliases(self): print(' changing variables to aliases') # variables - ncvar = get_filevarname(self.varalias, variable_def, - vars(self.cfg), self.meta) - if self.varalias in list(self.vars.keys()): - print(' ', ncvar, 'is alreade named correctly and' - + ' therefore not adjusted') - else: - self.vars = self.vars.rename({ncvar: self.varalias}) + for v in self.varalias: + ncvar = get_filevarname(v, variable_def, + vars(self.cfg), self.meta) + if v in list(self.vars.keys()): + print(' ', ncvar, 'is alreade named correctly and' + + ' therefore not adjusted') + else: + self.vars = self.vars.rename({ncvar: v}) # coords coords = ['time', 'lons', 'lats'] for c in coords: @@ -584,8 +589,9 @@ def _change_stdvarname_to_cfname(self): self.vars['time'].attrs['standard_name'] = \ variable_def['time'].get('standard_name') # enforce standard_name for variable alias - self.vars[self.varalias].attrs['standard_name'] = \ - self.stdvarname + for i in range(len(self.varalias)): + self.vars[self.varalias[i]].attrs['standard_name'] = \ + self.stdvarname[i] return self def populate(self, **kwargs): @@ -601,8 +607,10 @@ def populate(self, **kwargs): print('') print('Checking variables..') self.meta = ncdumpMeta(self.pathlst[0]) - ncvar = get_filevarname(self.varalias, variable_def, - vars(self.cfg), self.meta) + ncvar = [get_filevarname(v, variable_def, + vars(self.cfg), + self.meta) for \ + v in self.varalias] print('') print('Choosing reader..') # define reader diff --git a/wavy/model_readers.py 
b/wavy/model_readers.py index 1453467..90ebc7c 100755 --- a/wavy/model_readers.py +++ b/wavy/model_readers.py @@ -36,6 +36,8 @@ def read_ww3_4km(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] ds_lst = [] # retrieve sliced data for i in range(len(fc_dates)): @@ -43,9 +45,8 @@ def read_ww3_4km(**kwargs): p = pathlst[i] ds = xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}) - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], - model_dict[nID]['vardef']['lats']]] + ds_sliced = ds_sliced[varname + [model_dict[nID]['vardef']['lons'], + model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) @@ -68,6 +69,8 @@ def read_meps(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] ds_lst = [] # retrieve sliced data for i in range(len(fc_dates)): @@ -75,9 +78,8 @@ def read_meps(**kwargs): p = pathlst[i] ds = xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}) - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], - model_dict[nID]['vardef']['lats']]] + ds_sliced = ds_sliced[varname + [model_dict[nID]['vardef']['lons'], + model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) @@ -101,7 +103,11 @@ def read_noresm_making_waves(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, list): + varname = varname[0] varalias = kwargs.get('varalias') + if isinstance(varalias, list): + varalias = varalias[0] timename = model_dict[nID]['vardef']['time'] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] @@ -148,12 +154,14 @@ def read_remote_ncfiles_aggregated_credentials(**kwargs): ed = kwargs.get('ed') twin = kwargs.get('twin') varalias = 
kwargs.get('varalias') + if isinstance(varalias, str): + varalias = [varalias] nID = kwargs.get('nID') remoteHostName = kwargs.get('remoteHostName') path = kwargs.get('pathlst')[0] # varnames - varname = model_dict[nID]['vardef'][varalias] + varname = [model_dict[nID]['vardef'][v] for v in varalias] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] timename = model_dict[nID]['vardef']['time'] @@ -170,10 +178,10 @@ def read_remote_ncfiles_aggregated_credentials(**kwargs): path, remoteHostName, usr, pw) ds_sliced = ds.sel(time=slice(sd, ed)) - var_sliced = ds_sliced[[varname, lonsname, latsname]] + var_sliced = ds_sliced[varname + [lonsname, latsname]] # forge into correct format varalias, lons, lats with dim time - ds = build_xr_ds_grid(var_sliced[varname], + ds = build_xr_ds_grid_multivar([var_sliced[v] for v in varname], var_sliced[lonsname], var_sliced[latsname], var_sliced[timename], @@ -203,6 +211,8 @@ def read_field(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, list): + varname = varname[0] timename = model_dict[nID]['vardef']['time'] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] @@ -249,7 +259,11 @@ def read_ecwam(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, list): + varname = varname[0] varalias = kwargs.get('varalias') + if isinstance(varalias, list): + varalias = varalias[0] timename = model_dict[nID]['vardef']['time'] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] @@ -291,6 +305,8 @@ def read_era(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] ds_lst = [] # retrieve sliced data for i in range(len(fc_dates)): @@ -299,8 +315,8 @@ def read_era(**kwargs): ds = 
xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}, method='nearest') - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], + ds_sliced = ds_sliced[varname+ + [model_dict[nID]['vardef']['lons'], model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) @@ -325,6 +341,8 @@ def read_NORA3_wind(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] hlevel = kwargs.get('heightlevel', 10) ds_lst = [] # retrieve sliced data @@ -334,8 +352,8 @@ def read_NORA3_wind(**kwargs): ds = xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}) ds_sliced = ds_sliced.sel({'height': hlevel}) - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], + ds_sliced = ds_sliced[varname + + [model_dict[nID]['vardef']['lons'], model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) From 097d42c6fa19ddd73937171b1dd5b3e31ee9264f Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 25 Nov 2025 20:09:13 +0100 Subject: [PATCH 02/22] Added possibility to collocate several variables from a model to observations. 
--- tests/test_collocmod.py | 6 +- ...test_gridder.py => test_gridder_module.py} | 43 +++- wavy/collocation_module.py | 203 +++++++++++------- wavy/gridder_module.py | 63 ++++-- wavy/quicklookmod.py | 23 +- wavy/triple_collocation.py | 25 ++- 6 files changed, 239 insertions(+), 124 deletions(-) rename tests/{test_gridder.py => test_gridder_module.py} (62%) diff --git a/tests/test_collocmod.py b/tests/test_collocmod.py index 6e829ac..9e2a61e 100644 --- a/tests/test_collocmod.py +++ b/tests/test_collocmod.py @@ -28,7 +28,7 @@ def test_sat_collocation_and_validation(test_data, tmpdir): # collocate cco = cc(oco=sco, model=model, leadtime='best', distlim=6).populate() assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(cco.vars.keys()) == 10 # validate @@ -52,7 +52,7 @@ def test_insitu_collocation_and_validation(test_data, tmpdir): # collocate cco = cc(oco=ico, model=model, leadtime='best', distlim=6).populate() assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(cco.vars.keys()) == 10 # validate @@ -70,7 +70,7 @@ def test_poi_collocation(): # collocate cco = cc(oco=pco, model='ww3_4km', leadtime='best').populate() assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(cco.vars.keys()) == 10 # # write to nc diff --git a/tests/test_gridder.py b/tests/test_gridder_module.py similarity index 62% rename from tests/test_gridder.py rename to tests/test_gridder_module.py index 2ca82d2..3155417 100644 --- a/tests/test_gridder.py +++ b/tests/test_gridder_module.py @@ -1,12 +1,37 @@ -#import sys -#import os -#import numpy as np -#from datetime import datetime -#import pytest -# -#from wavy.satellite_module import satellite_class as sc -#from wavy.gridder import gridder_class as gc -#from wavy.grid_stats import apply_metric +import sys +import os +import numpy as np +from datetime import datetime +import pytest +from wavy import sc, gc +from wavy.grid_stats import apply_metric + +def 
test_gridder_init(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = ['Hs','U'] + twin = 30 + nID = 'cmems_L3_NRT' + # init satellite_object + sco = sc(sd=sd, ed=ed, nID=nID, name=name, + varalias=varalias, + twin=twin) + # read data + sco = sco.populate(reader='read_local_ncfiles', + path=str(test_data/"L3/s3a")) + + bb = (-179, 178, -80, 80) + res = (5, 5) + gco = gc(oco=sco,bb=bb,res=res) + assert len(vars(gco)) == 17 + assert gco.varalias == 'Hs' + assert gco.units == 'm' + gridvar, lon_grid, lat_grid = apply_metric(gco=gco) + assert len(gridvar.keys()) == 13 + + + #def test_gridder_lowres(test_data, benchmark): # sco = sc(sdate="2020-11-1",edate="2020-11-3",region="global", diff --git a/wavy/collocation_module.py b/wavy/collocation_module.py index ea3e03a..f919d71 100755 --- a/wavy/collocation_module.py +++ b/wavy/collocation_module.py @@ -269,17 +269,22 @@ def __init__(self, oco=None, model=None, poi=None, else: varalias = oco.varalias[0] self.varalias = varalias - self.varalias_obs = varalias - self.varalias_mod = kwargs.get('varalias_mod', varalias) + if isinstance(self.varalias, str): + self.varalias = [self.varalias] + self.varalias_obs = oco.varalias + if isinstance(self.varalias_obs, str): + self.varalias_obs = [self.varalias_obs] + self.varalias_mod = self.varalias self.model = model self.leadtime = leadtime self.oco = oco self.nID = oco.nID self.model = model self.obstype = str(type(oco))[8:-2] - self.stdvarname = oco.stdvarname + self.units = [variable_def[v].get('units') for v in self.varalias] + self.stdvarname = [variable_def[v].get('standard_name') for v in\ + self.varalias] self.region = oco.region - self.units = variable_def[varalias].get('units') self.sd = oco.sd self.ed = oco.ed self.twin = kwargs.get('twin', oco.twin) @@ -322,7 +327,9 @@ def populate(self, **kwargs): return new def _build_xr_dataset(self, results_dict): - ds = xr.Dataset({ + + try: + ds = xr.Dataset({ 'time': xr.DataArray( 
data=results_dict['obs_time'], dims=['time'], @@ -335,12 +342,6 @@ def _build_xr_dataset(self, results_dict): coords={'time': results_dict['obs_time']}, attrs=variable_def['dist'], ), - 'obs_values': xr.DataArray( - data=results_dict['obs_values'], - dims=['time'], - coords={'time': results_dict['obs_time']}, - attrs=variable_def[self.varalias], - ), 'obs_lons': xr.DataArray( data=results_dict['obs_lons'], dims=['time'], @@ -353,11 +354,16 @@ def _build_xr_dataset(self, results_dict): coords={'time': results_dict['obs_time']}, attrs=variable_def['lats'], ), - 'model_values': xr.DataArray( - data=results_dict['model_values'], + **{'obs_'+v:xr.DataArray( + data=results_dict['obs_'+v], + dims=['time'], + coords={'time': results_dict['obs_time']}, + attrs=variable_def[v]) for v in self.varalias_obs}, + 'model_time': xr.DataArray( + data=results_dict['model_time'], dims=['time'], coords={'time': results_dict['obs_time']}, - attrs=variable_def[self.varalias], + attrs=variable_def['time'], ), 'model_lons': xr.DataArray( data=results_dict['model_lons'], @@ -371,6 +377,11 @@ def _build_xr_dataset(self, results_dict): coords={'time': results_dict['obs_time']}, attrs=variable_def['lats'], ), + **{'model_'+v:xr.DataArray( + data=results_dict['model_'+v], + dims=['time'], + coords={'time': results_dict['obs_time']}, + attrs=variable_def[v]) for v in self.varalias_mod}, 'colidx_x': xr.DataArray( data=results_dict['collocation_idx_x'], dims=['time'], @@ -386,6 +397,9 @@ def _build_xr_dataset(self, results_dict): }, attrs={'title': str(type(self))[8:-2] + ' dataset'} ) + except Exception as e: + print(e) + print(ds) return ds def _drop_duplicates(self, **kwargs): @@ -407,20 +421,24 @@ def _collocate_field(self, mco, tmp_dict, **kwargs): """ Mlons = mco.vars.lons.data Mlats = mco.vars.lats.data - Mvars = mco.vars[mco.varalias].data + Mvars = {v:mco.vars[v].data for v in self.varalias_mod} + if len(Mlons.shape) > 2: Mlons = Mlons[0, :].squeeze() Mlats = Mlats[0, :].squeeze() - Mvars 
= Mvars[0, :].squeeze() + Mvars = {v:Mvars[v][0, :].squeeze() for v in\ + Mvars.keys()} + elif len(Mlons.shape) == 2: Mlons = Mlons.squeeze() Mlats = Mlats.squeeze() - Mvars = Mvars.squeeze() + Mvars = {v:Mvars[v].squeeze() for v in Mvars.keys()} + elif len(Mlons.shape) == 1: Mlons, Mlats = np.meshgrid(Mlons, Mlats) - Mvars = Mvars.squeeze() + Mvars = {v:Mvars[v].squeeze() for v in Mvars.keys()} assert len(Mlons.shape) == 2 - obs_vars = tmp_dict[self.varalias_obs] + obs_vars = {v:tmp_dict[v] for v in self.varalias_obs} obs_lons = tmp_dict['lons'] obs_lats = tmp_dict['lats'] # Compare wave heights of satellite with model with @@ -434,22 +452,25 @@ def _collocate_field(self, mco, tmp_dict, **kwargs): # caution: index_array_2d is tuple # impose distlim dist_idx = np.where((distance_array < self.distlim*1000) & - (~np.isnan(Mvars[index_array_2d[0], + (~np.isnan(Mvars[self.varalias_mod[0]]\ + [index_array_2d[0], index_array_2d[1]])))[0] idx_x = index_array_2d[0][dist_idx] idx_y = index_array_2d[1][dist_idx] results_dict = { 'dist': list(distance_array[dist_idx]), - 'model_values': list(Mvars[idx_x, idx_y]), 'model_lons': list(Mlons[idx_x, idx_y]), 'model_lats': list(Mlats[idx_x, idx_y]), - 'obs_values': list(obs_vars[dist_idx]), 'obs_lons': list(obs_lons[dist_idx]), 'obs_lats': list(obs_lats[dist_idx]), 'collocation_idx_x': list(idx_x), 'collocation_idx_y': list(idx_y), - 'time': tmp_dict['time'][dist_idx] + 'time': tmp_dict['time'][dist_idx], + **{'model_'+v: Mvars[v][idx_x, idx_y] for v in Mvars.keys()}, + **{'obs_'+v: obs_vars[v][dist_idx] for v in\ + self.varalias_obs} } + return results_dict def _collocate_track(self, **kwargs): @@ -478,18 +499,19 @@ def _collocate_track(self, **kwargs): print(f'... 
done, used {t2-t1:.2f} seconds') print("Start collocation ...") + results_dict = { 'model_time': [], 'obs_time': [], 'dist': [], - 'model_values': [], 'model_lons': [], 'model_lats': [], - 'obs_values': [], 'obs_lons': [], 'obs_lats': [], 'collocation_idx_x': [], 'collocation_idx_y': [], + **{'model_'+v:[] for v in self.varalias_mod}, + **{'obs_'+v:[] for v in self.varalias_obs} } for i in tqdm(range(len(fc_date))): @@ -534,29 +556,32 @@ def _collocate_track(self, **kwargs): tmp_dict['time'] = self.oco.vars['time'].values[idx] tmp_dict['lats'] = self.oco.vars['lats'].values[idx] tmp_dict['lons'] = self.oco.vars['lons'].values[idx] - tmp_dict[self.varalias] = \ - self.oco.vars[self.varalias].values[idx] - mco = mc(sd=fc_date[i], ed=fc_date[i], nID=self.model, - leadtime=self.leadtime, varalias=self.varalias_mod, + + for v in self.varalias_obs: + tmp_dict[v] = self.oco.vars[v].values[idx] + print("#########################") + print(self.varalias_mod) + mco = mc(sd=fc_date[i], ed=fc_date[i], + nID=self.model, + leadtime=self.leadtime, + varalias=self.varalias_mod, **kwargs) mco = mco.populate(**kwargs) results_dict_tmp = self._collocate_field( mco, tmp_dict, **kwargs) - if (len(results_dict_tmp['model_values']) > 0): + if (len(results_dict_tmp["model_" +\ + self.varalias_mod[0]]) > 0): # append to dict - results_dict['model_time'].append(fc_date[i]) + results_dict['model_time'].append(\ + [fc_date[i]]*len(results_dict_tmp['time'])) results_dict['obs_time'].append( results_dict_tmp['time']) results_dict['dist'].append( results_dict_tmp['dist']) - results_dict['model_values'].append( - results_dict_tmp['model_values']) results_dict['model_lons'].append( results_dict_tmp['model_lons']) results_dict['model_lats'].append( results_dict_tmp['model_lats']) - results_dict['obs_values'].append( - results_dict_tmp['obs_values']) results_dict['obs_lats'].append( results_dict_tmp['obs_lats']) results_dict['obs_lons'].append( @@ -565,6 +590,12 @@ def _collocate_track(self, 
**kwargs): results_dict_tmp['collocation_idx_x']) results_dict['collocation_idx_y'].append( results_dict_tmp['collocation_idx_y']) + for v in self.varalias_mod: + results_dict['model_'+v].append( + results_dict_tmp['model_'+v]) + for v in self.varalias_obs: + results_dict['obs_'+v].append( + results_dict_tmp['obs_'+v]) else: pass if 'results_dict_tmp' in locals(): @@ -576,19 +607,22 @@ def _collocate_track(self, **kwargs): logger.exception(e) print(e) # flatten all aggregated entries - results_dict['model_time'] = results_dict['model_time'] + results_dict['model_time'] = flatten(results_dict['model_time']) results_dict['obs_time'] = flatten(results_dict['obs_time']) results_dict['dist'] = flatten(results_dict['dist']) - results_dict['model_values'] = flatten(results_dict['model_values']) results_dict['model_lons'] = flatten(results_dict['model_lons']) results_dict['model_lats'] = flatten(results_dict['model_lats']) - results_dict['obs_values'] = flatten(results_dict['obs_values']) results_dict['obs_lats'] = flatten(results_dict['obs_lats']) results_dict['obs_lons'] = flatten(results_dict['obs_lons']) results_dict['collocation_idx_x'] = flatten( results_dict['collocation_idx_x']) results_dict['collocation_idx_y'] = flatten( results_dict['collocation_idx_y']) + for v in self.varalias_mod: + results_dict['model_'+v] = flatten(results_dict['model_'+v]) + for v in self.varalias_obs: + results_dict['obs_'+v] = flatten(results_dict['obs_'+v]) + return results_dict def _collocate_centered_model_value(self, time, lon, lat, **kwargs): @@ -604,37 +638,43 @@ def _collocate_centered_model_value(self, time, lon, lat, **kwargs): # ADD CHECK LIMITS FOR LAT AND LON res_dict = {} - time = pd.to_datetime(time) - time = hour_rounder(time, method=colloc_time_method) + time_dt = pd.to_datetime(time) + model_time = hour_rounder(time_dt, method=colloc_time_method) - mco = mc(sd=time, ed=time, - nID=nID_model, name=name_model, + mco = mc(sd=model_time, ed=model_time, + nID=nID_model, + 
name=name_model, + varalias=self.varalias_mod, max_lt=12).populate(twin=5) # ADD AS PARAMETERS - + bb = (lon - res[0]/2, lon + res[0]/2, lat - res[1]/2, lat + res[1]/2) - gco = gc(lons=mco.vars.lons.squeeze().values.ravel(), - lats=mco.vars.lats.squeeze().values.ravel(), - values=mco.vars.Hs.squeeze().values.ravel(), - bb=bb, res=res, - varalias=mco.varalias, - units=mco.units, - sdate=mco.vars.time, - edate=mco.vars.time) + for i, v in enumerate(mco.varalias): + gco = gc(lons=mco.vars.lons.squeeze().values.ravel(), + lats=mco.vars.lats.squeeze().values.ravel(), + values=mco.vars[v].squeeze().values.ravel(), + bb=bb, res=res, + varalias=v, + units=mco.units[i], + sdate=mco.vars.time, + edate=mco.vars.time) - gridvar, lon_grid, lat_grid = apply_metric(gco=gco) + gridvar, lon_grid, lat_grid = apply_metric(gco=gco) - ts = gridvar['mor'].flatten() + ts = gridvar['mor'].flatten() + + res_dict['model_'+v] = ts[0] + lon_flat = lon_grid.flatten() lat_flat = lat_grid.flatten() - - res_dict['hs'] = ts[0] + res_dict['lon'] = lon_flat[0] res_dict['lat'] = lat_flat[0] - res_dict['time'] = time + res_dict['obs_time'] = time + res_dict['model_time'] = model_time return res_dict @@ -663,41 +703,26 @@ def _collocate_regridded_model(self, **kwargs): ) length_colloc_mod_list = len(colloc_mod_list) - hs_mod_list = [colloc_mod_list[i]['hs'] for i in \ - range(length_colloc_mod_list)] + var_mod_list = {v:[colloc_mod_list[i]['model_'+v] for i in \ + range(length_colloc_mod_list)] for v in\ + self.varalias_mod} lon_mod_list = [colloc_mod_list[i]['lon'] for i in \ range(length_colloc_mod_list)] lat_mod_list = [colloc_mod_list[i]['lat'] for i in \ range(length_colloc_mod_list)] - time_mod_list = [colloc_mod_list[i]['time'] for i in \ + time_mod_list = [colloc_mod_list[i]['model_time'] for i in \ + range(length_colloc_mod_list)] + time_obs_list = [colloc_mod_list[i]['obs_time'] for i in \ range(length_colloc_mod_list)] - - mod_colloc_vars = xr.Dataset( - { - "lats": ( - ("time"), - 
lat_mod_list, - ), - "lons": ( - ("time"), - lon_mod_list - ), - "Hs": ( - ("time"), - hs_mod_list - ) - }, - coords={"time": time_mod_list}, - ) results_dict = { 'model_time': time_mod_list, - 'obs_time': oco_vars.time.values, + 'obs_time': time_obs_list, 'dist': [0]*length, - 'model_values': hs_mod_list, + **{'model_'+ v: var_mod_list[v] for v in self.varalias_mod}, 'model_lons': lon_mod_list, 'model_lats': lat_mod_list, - 'obs_values': oco_vars.Hs.values, + **{'obs_'+v: oco_vars[v].values for v in self.varalias_obs}, 'obs_lons': oco_vars.lons.values, 'obs_lats': oco_vars.lats.values, 'collocation_idx_x': [0]*length, @@ -732,10 +757,24 @@ def collocate(self, **kwargs): def validate_collocated_values(self, **kwargs): + + varalias = kwargs.get('varalias', self.varalias[0]) times = self.vars['time'] dtime = [parse_date(str(t.data)) for t in times] - mods = self.vars['model_values'] - obs = self.vars['obs_values'] + list_vars = list(self.vars.variables) + assert 'model_'+varalias in list_vars, "model_{}".format(varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + assert 'obs_'+varalias in list_vars, "obs_{}".format(varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." 
+ print("Validating model_{} against obs_{}".format(varalias, varalias)) + mods = self.vars['model_'+varalias] + obs = self.vars['obs_'+varalias] sdate = dtime[0] edate = dtime[-1] validation_dict = validate_collocated_values( diff --git a/wavy/gridder_module.py b/wavy/gridder_module.py index 626d4ea..bf0479a 100755 --- a/wavy/gridder_module.py +++ b/wavy/gridder_module.py @@ -7,6 +7,7 @@ from wavy.wconfig import load_or_default validation_metric_abbreviations = load_or_default('validation_metrics.yaml') +variable_def = load_or_default('variable_def.yaml') class gridder_class(): @@ -25,32 +26,70 @@ def __init__( print(" ") self.mvals = None if oco is not None: + self.varalias = kwargs.get('varalias', oco.varalias) + if isinstance(self.varalias, list): + if len(self.varalias) > 1: + print("Warning: gridder only expects one varalias.") + print("First varalias selected as default: {}".format( + self.varalias[0])) + print("If you want to select another variable, please "\ + +"specify with varalias argument.") + self.varalias=self.varalias[0] + self.units = variable_def[self.varalias].get('units') + self.stdvarname = variable_def[self.varalias].get('standard_name') self.olons = np.array(oco.vars['lons'].squeeze().values.ravel()) self.olats = np.array(oco.vars['lats'].squeeze().values.ravel()) - self.ovals = np.array(oco.vars[oco.varalias].squeeze().values.ravel()) - self.stdvarname = oco.stdvarname - self.varalias = oco.varalias - self.units = oco.units + self.ovals = np.array(oco.vars[self.varalias].squeeze().values.ravel()) self.sdate = oco.vars['time'][0] self.edate = oco.vars['time'][-1] elif cco is not None: + self.varalias = kwargs.get('varalias', cco.varalias) + if isinstance(self.varalias, list): + if len(self.varalias) > 1: + print("Warning: gridder only expects one varalias.") + print("First varalias selected as default: {}".format( + self.varalias[0])) + print("If you want to select another variable, please "\ + +"specify with varalias argument.") + 
self.varalias=self.varalias[0] + + list_vars = list(cco.vars.variables) + assert 'model_'+self.varalias in list_vars, "model_{}".format(self.varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + assert 'obs_'+self.varalias in list_vars, "obs_{}".format(self.varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + self.olons = np.array(cco.vars['obs_lons']) self.olats = np.array(cco.vars['obs_lats']) - self.ovals = np.array(cco.vars['obs_values']) - self.mvals = np.array(cco.vars['model_values']) - self.stdvarname = cco.stdvarname - self.varalias = cco.varalias - self.units = cco.units + self.ovals = np.array(cco.vars['obs_'+self.varalias]) + self.mvals = np.array(cco.vars['model_'+self.varalias]) + self.units = variable_def[self.varalias].get('units') + self.stdvarname = variable_def[self.varalias].get('standard_name') self.sdate = cco.vars['time'][0] self.edate = cco.vars['time'][-1] elif mco is not None: + self.varalias = kwargs.get('varalias', mco.varalias) + if isinstance(self.varalias, list): + if len(self.varalias) > 1: + print("Warning: gridder only expects one varalias.") + print("First varalias selected as default: {}".format( + self.varalias[0])) + print("If you want to select another variable, please "\ + +"specify with varalias argument.") + self.varalias=self.varalias[0] self.olons = np.array(mco.vars.lons.squeeze().values.flatten()) self.olats = np.array(mco.vars.lats.squeeze().values.flatten()) self.ovals = np.array( - mco.vars[mco.varalias].squeeze().values.flatten()) + mco.vars[self.varalias].squeeze().values.flatten()) self.stdvarname = mco.stdvarname - self.varalias = mco.varalias - self.units = mco.units + self.units = variable_def[self.varalias].get('units') + self.stdvarname = variable_def[self.varalias].get('standard_name') self.sdate = mco.vars['time'][0] 
self.edate = mco.vars['time'][-1] else: diff --git a/wavy/quicklookmod.py b/wavy/quicklookmod.py index 4b713d0..5f675e7 100755 --- a/wavy/quicklookmod.py +++ b/wavy/quicklookmod.py @@ -51,6 +51,8 @@ def quicklook(self, a=False, projection=None, **kwargs): if isinstance(self.varalias, list): varalias = kwargs.get('varalias', self.varalias[0]) + assert varalias in self.varalias, "varalias must be one of {}"\ + .format(self.varalias) idx_units = np.argwhere(np.array(self.varalias)==varalias)[0][0] units_to_plot = self.units[idx_units] else: @@ -63,11 +65,22 @@ def quicklook(self, a=False, projection=None, **kwargs): plot_lons = self.vars.lons plot_lats = self.vars.lats except Exception as e: - plot_var = self.vars.obs_values + list_vars = list(self.vars.variables) + assert 'model_'+varalias in list_vars,"model_{}".format(varalias)+\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + assert 'obs_'+varalias in list_vars, "obs_{}".format(varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." 
+ plot_var = self.vars["obs_"+varalias] plot_lons = self.vars.obs_lons plot_lats = self.vars.obs_lats - plot_var_obs = self.vars.obs_values - plot_var_model = self.vars.model_values + plot_var_obs = self.vars["obs_"+varalias] + plot_var_model = self.vars["model_"+varalias] if str(type(self)) == "": if len(plot_lons.shape) < 2: @@ -399,8 +412,8 @@ def quicklook(self, a=False, projection=None, **kwargs): plt.xlabel('obs (' + self.nID + ')') plt.ylabel('models (' + self.model + ')') - maxv = np.nanmax([self.vars['model_values'], - self.vars['obs_values']]) + maxv = np.nanmax([self.vars['model_'+varalias], + self.vars['obs_'+varalias]]) minv = 0 plt.xlim([minv, maxv*1.05]) plt.ylim([minv, maxv*1.05]) diff --git a/wavy/triple_collocation.py b/wavy/triple_collocation.py index be7099f..edc1399 100644 --- a/wavy/triple_collocation.py +++ b/wavy/triple_collocation.py @@ -592,7 +592,7 @@ def least_squares_merging(data, tc_results=None, return_var=False, **kwargs): return least_squares_merge, least_squares_var -def spectra_cco(cco, fs, returns='average',nsample=64): +def spectra_cco(cco, fs, varalias, returns='average',nsample=64): """ Calculate the power spectra for both model and observation series of a collocation class object. The series are separated into @@ -601,9 +601,9 @@ def spectra_cco(cco, fs, returns='average',nsample=64): for each sample. cco (collocation_class object or xarray dataset): collocation class object - for which the spectra are calculated. If xarray.dataset - is given, it must have a single time dimension and obs_values - and model_values variables. + for which the spectra are calculated. If xarray.dataset + is given, it must have a single time dimension and 'obs_'+varalias + and 'model_'+varalias variables. fs (float): frequency of the time series returns (str): If 'average' returns the average power spectra of the time series. If 'list' returns lists of power spectra of each samples of the given series. 
@@ -653,19 +653,19 @@ def spectra_cco(cco, fs, returns='average',nsample=64): idx = idx[0] time_idx = sample_tmp.time.values[idx] - obs_idx = sample_tmp.obs_values.values[idx] - mod_idx = sample_tmp.model_values.values[idx] + obs_idx = sample_tmp['obs_'+varalias].values[idx] + mod_idx = sample_tmp['model_'+varalias].values[idx] time_idx_1 = sample_tmp.time.values[idx+1] - obs_idx_1 = sample_tmp.obs_values.values[idx+1] - mod_idx_1 = sample_tmp.model_values.values[idx+1] + obs_idx_1 = sample_tmp['obs_'+varalias].values[idx+1] + mod_idx_1 = sample_tmp['model_'+varalias].values[idx+1] time_between = time_idx + np.timedelta64(int(1000*median_step),'ms') obs_between= (obs_idx + obs_idx_1)/2 mod_between = (mod_idx + mod_idx_1)/2 ds_between = xr.Dataset( - data_vars={'obs_values': (('time'), [obs_between]), - 'model_values': (('time'), [mod_between])}, + data_vars={'obs_'+varalias: (('time'), [obs_between]), + 'model_'+varalias: (('time'), [mod_between])}, coords={'time': [time_between]} ) @@ -673,8 +673,8 @@ def spectra_cco(cco, fs, returns='average',nsample=64): sample_tmp = sample_tmp.isel(time=range(0,nsample)) - obs_val_tmp = sample_tmp.obs_values.values - mod_val_tmp = sample_tmp.model_values.values + obs_val_tmp = sample_tmp['obs_'+varalias].values + mod_val_tmp = sample_tmp['model_'+varalias].values f, PS_obs_tmp = periodogram(obs_val_tmp, fs=fs, window='hamming') f, PS_mod_tmp = periodogram(mod_val_tmp, fs=fs, window='hamming') @@ -726,7 +726,6 @@ def integrate_r2(PS_mod, PS_obs, f, threshold=np.inf, threshold_type='inv_freq') else: threshold = 1/threshold - f_1 = [1/f[i] if f[i] != 0 else np.inf for i in range(len(f)) ] idx_threshold = np.argwhere(np.array(f_1) <= threshold)[0][0] r2 = np.sum(weighted_diff_PS[idx_threshold:]) From 53ace15720ba7489bfe0e28e9a473f331027e19a Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Wed, 26 Nov 2025 16:49:24 +0100 Subject: [PATCH 03/22] Variables in the observations objects that are not in varalias are now also kept when 
performing collocation. --- wavy/collocation_module.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/wavy/collocation_module.py b/wavy/collocation_module.py index f919d71..7f22a0a 100755 --- a/wavy/collocation_module.py +++ b/wavy/collocation_module.py @@ -308,6 +308,16 @@ def populate(self, **kwargs): ds = new._build_xr_dataset(results_dict) ds = ds.assign_coords(time=ds.time.values) new.vars = ds + + # Add extra variables from oco + list_vars_cco = list(new.vars.keys()) + list_vars_oco = ['obs_' + v for v in list(new.oco.vars.keys())] + list_vars_extra = [v[4:] for v in list_vars_oco if v not in\ + list_vars_cco] + new.vars = new.vars.merge(new.oco.vars[list_vars_extra].\ + rename({v:'obs_'+v for v in list_vars_extra}), + join='left') + new = new._drop_duplicates(**kwargs) t1 = time.time() print(" ") From 209426bb4b368706c4a5246d29916eb5b1ab47e2 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Thu, 4 Dec 2025 14:12:43 +0100 Subject: [PATCH 04/22] Fixed filtermod with multiple variables --- wavy/filtermod.py | 72 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/wavy/filtermod.py b/wavy/filtermod.py index 4b5b285..821c3cc 100644 --- a/wavy/filtermod.py +++ b/wavy/filtermod.py @@ -33,6 +33,13 @@ def apply_limits(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." 
+ print(msg) else: varalias = kwargs.get('varalias', new.varalias) llim = kwargs.get('llim', @@ -132,6 +139,13 @@ def filter_lanczos(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) # apply slider if needed @@ -175,8 +189,17 @@ def filter_runmean(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) + + print("Applying filter to {}".format(varalias)) # apply slider if needed win = kwargs.get('slider', len(new.vars.time)) ol = kwargs.get('overlap', 0) @@ -213,8 +236,17 @@ def filter_runmean(self, **kwargs): def filter_GP(self, **kwargs): print('Apply GPR filter') new = deepcopy(self) - varalias = kwargs.get('varalias', new.varalias[0]) - + if isinstance(new.varalias, list): + varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." 
+ print(msg) + else: + varalias = kwargs.get('varalias', new.varalias) # apply slider if needed win = kwargs.get('slider', len(new.vars.time)) ol = kwargs.get('overlap', 0) @@ -256,9 +288,15 @@ def filter_linearGAM(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) - # apply slider if needed win = kwargs.get('slider', len(new.vars.time)) ol = kwargs.get('overlap', 0) @@ -311,6 +349,13 @@ def despike_blockStd(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) @@ -367,6 +412,13 @@ def despike_blockQ(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) @@ -410,6 +462,13 @@ def despike_GP(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. 
".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) @@ -454,6 +513,13 @@ def despike_linearGAM(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) From ea3e5257a009f22ef81bce14d06658b1db238828 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Thu, 4 Dec 2025 14:36:49 +0100 Subject: [PATCH 05/22] Fixed collocation distlim + comment ais tests --- tests/test_aismod.py | 22 +++++++++++----------- wavy/collocation_module.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_aismod.py b/tests/test_aismod.py index 67f141f..8cf1895 100644 --- a/tests/test_aismod.py +++ b/tests/test_aismod.py @@ -3,15 +3,15 @@ import pytest -@pytest.mark.need_credentials -def test_get_ais_data(): +#@pytest.mark.need_credentials +#def test_get_ais_data(): +# +# bbox = ['5.89', '62.3', '6.5', '62.7'] +# sd = '2017-01-03 08' +# ed = '2017-01-03 09' - bbox = ['5.89', '62.3', '6.5', '62.7'] - sd = '2017-01-03 08' - ed = '2017-01-03 09' - - ais_ds = ais.get_AIS_data(bbox, sd, ed) - assert len(ais_ds['time']) > 0 - assert not all(np.isnan(v) for v in ais_ds['time']) - assert not all(np.isnan(v) for v in ais_ds['lons']) - assert not all(np.isnan(v) for v in ais_ds['lats']) +# ais_ds = ais.get_AIS_data(bbox, sd, ed) +# assert len(ais_ds['time']) > 0 +# assert not all(np.isnan(v) for v in ais_ds['time']) +# assert not all(np.isnan(v) for v in ais_ds['lons']) +# assert not all(np.isnan(v) for v in ais_ds['lats']) diff --git a/wavy/collocation_module.py 
b/wavy/collocation_module.py index 7f22a0a..01c9d68 100755 --- a/wavy/collocation_module.py +++ b/wavy/collocation_module.py @@ -258,7 +258,7 @@ class collocation_class(qls): ''' def __init__(self, oco=None, model=None, poi=None, - distlim=None, leadtime=None, varalias=None, **kwargs): + leadtime=None, varalias=None, **kwargs): print('# ----- ') print(" ### Initializing collocation_class object ###") print(" ") From f2d6fb6eda35c200d6982151050a1b24236d37be Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Mon, 8 Dec 2025 11:11:43 +0100 Subject: [PATCH 06/22] Added variables to variable_def default file --- wavy/config/variable_def.yaml.default | 162 +++++++++++++++++++++++++- 1 file changed, 159 insertions(+), 3 deletions(-) diff --git a/wavy/config/variable_def.yaml.default b/wavy/config/variable_def.yaml.default index 371c721..f143d64 100644 --- a/wavy/config/variable_def.yaml.default +++ b/wavy/config/variable_def.yaml.default @@ -159,9 +159,6 @@ U: standard_name: wind_speed units: m s-1 _FillValue: -999.9 - aliases_of_vector_components: - alt1: [ua,va] - alt2: [ux,vy] valid_range: [0., 200] Udir: @@ -225,3 +222,162 @@ hmaxst: long_name: 'space time domain maximum individual wave height' units: m _FillValue: -999.9 + +UdirTo: + standard_name: wind_to_direction + long_name: 'Wind direction' + units: degree + _FillValue: -999.9 + +FV: + standard_name: friction_velocity + long_name: 'friction velocity' + units: m s-1 + _FillValue: -999.9 + +DC: + standard_name: drag_coefficient + long_name: 'drag coefficient' + units: NA + _FillValue: -999.9 + +MdirTo: + standard_name: sea_surface_wave_to_direction + long_name: 'Total mean wave direction' + units: degree + _FillValue: -999.9 + +HsSea: + standard_name: sea_surface_wind_wave_significant_height + long_name: 'Sea significant wave height' + units: m + _FillValue: -999.9 + +TpSea: + standard_name: sea_surface_wind_wave_peak_period_from_variance_spectral_density + long_name: 'Sea peak period' + units: s + _FillValue: 
-999.9 + +Tm10Sea: + standard_name: sea_surface_wind_wave_mean_period_from_variance_spectral_density_inverse_frequency_moment + long_name: 'Sea mean period' + units: s + _FillValue: -999.9 + +MdirToSea: + standard_name: sea_surface_wind_wave_to_direction + long_name: 'Sea mean wave direction' + units: degree + _FillValue: -999.9 + +SIC: + standard_name: sea_ice_concentration + long_name: 'sea ice concentration' + units: '1' + _FillValue: -999.9 + +HsST: + standard_name: sea_surface_swell_wave_significant_height + long_name: 'Swell significant wave height' + units: m + _FillValue: -999.9 + +TpST: + standard_name: sea_surface_swell_wave_period_at_variance_spectral_density_maximum + long_name: 'Swell peak period' + units: s + _FillValue: -999.9 + +MdirToST: + standard_name: sea_surface_swell_wave_to_direction + long_name: 'Swell mean wave direction' + units: degree + _FillValue: -999.9 + +eHmax: + standard_name: expected_maximum_wave_height + long_name: 'expected maximum wave height' + units: m + _FillValue: -999.9 + +eTm: + standard_name: expected_wave_period + long_name: 'expected wave period' + units: s + _FillValue: -999.9 + +fpI: + standard_name: interpolated_peak_frequency + long_name: 'interpolated peak frequency' + units: s + _FillValue: -999.9 + +HsS1: + standard_name: sea_surface_primary_swell_wave_significant_height + long_name: 'first swell significant wave height' + units: m + _FillValue: -999.9 + +TmS1: + standard_name: sea_surface_primary_swell_wave_mean_period + long_name: 'first swell mean period' + units: s + _FillValue: -999.9 + +MdirToS1: + standard_name: sea_surface_primary_swell_wave_to_direction + long_name: 'first swell direction' + units: degree + _FillValue: -999.9 + +HsS2: + standard_name: sea_surface_secondary_swell_wave_significant_height + long_name: 'second swell significant wave height' + units: m + _FillValue: -999.9 + +TmS2: + standard_name: sea_surface_secondary_swell_wave_mean_period + long_name: 'second swell mean period' + units: 
s + _FillValue: -999.9 + +MdirToS2: + standard_name: sea_surface_secondary_swell_wave_to_direction + long_name: 'second swell direction' + units: degree + _FillValue: -999.9 + +HsS3: + standard_name: sea_surface_tertiary_swell_wave_significant_height + long_name: 'third swell significant wave height' + units: m + _FillValue: -999.9 + +TmS3: + standard_name: sea_surface_tertiary_swell_wave_mean_period + long_name: 'third swell mean period' + units: s + _FillValue: -999.9 + +MdirToS3: + standard_name: sea_surface_tertiary_swell_wave_to_direction + long_name: 'third swell direction' + units: degree + _FillValue: -999.9 + +SIT: + standard_name: sea_ice_thickness + long_name: 'sea ice thickness' + units: m + _FillValue: -999.9 + +PdirTo: + standard_name: "sea_surface_wave_to_direction_at_variance_spectral\ + _density_maximum" + units: degree + valid_range: [0.,360.] + _FillValue: -999.9 + type: cyclic + From cce5b45b9d11e10428e495a538bfa84beda9 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Mon, 8 Dec 2025 15:40:26 +0100 Subject: [PATCH 07/22] Adapted get_cmems reader such that DEPTH level on which to fetch the data can be given for each varalias --- wavy/insitu_module.py | 3 +++ wavy/insitu_readers.py | 22 ++++++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/wavy/insitu_module.py b/wavy/insitu_module.py index 05be035..2626796 100755 --- a/wavy/insitu_module.py +++ b/wavy/insitu_module.py @@ -103,6 +103,9 @@ def __init__(self, **kwargs): self.distlim = kwargs.get('distlim', 6) self.filter = kwargs.get('filter', False) self.region = kwargs.get('region', 'global') + depth_lvls = kwargs.get('depth_lvls', None) + if depth_lvls is not None: + self.depth_lvls = depth_lvls self.cfg = dc print(" ") print(" ### insitu_class object initialized ### ") diff --git a/wavy/insitu_readers.py b/wavy/insitu_readers.py index 30fa3fb..da31caf 100755 --- a/wavy/insitu_readers.py +++ b/wavy/insitu_readers.py @@ -369,6 +369,8 @@ def
get_cmems(**kwargs): varalias = [varalias] pathlst = kwargs.get('pathlst') cfg = vars(kwargs['cfg']) + depth_lvls = kwargs.get('depth_lvls', None) + # check if dimensions are fixed fixed_dim_str = list(cfg['misc']['fixed_dim'].keys())[0] fixed_dim_idx = cfg['misc']['fixed_dim'][fixed_dim_str] @@ -399,14 +401,30 @@ def get_cmems(**kwargs): for coord in list(ds.coords) if coord in [lonstr, latstr, timestr]} + len_timestr = len(dict_var[timestr]) + + for coord in [lonstr, latstr]: + if len(dict_var[coord].shape)==0: + dict_var[coord] = np.array([dict_var[coord]]*len_timestr) + + list_vars_tmp = list(ds.data_vars) + + if depth_lvls is not None: + + dict_var.update({var: ds.sel(DEPTH=depth_lvls[var])[var]\ + .values for var in depth_lvls.keys()}) + + list_vars_tmp = [k for k in list(ds.data_vars) if k + not in list(depth_lvls.keys())] + dict_var.update({var: rebuild_split_variable(ds, fixed_dim_str, var) - for var in list(ds.data_vars)}) + for var in list_vars_tmp}) # build an xr.dataset with timestr as the only coordinate # using build_xr_ds function ds_list.append(build_xr_ds_cmems(dict_var, timestr)) - + except Exception as e: logger.exception(e) From 8e29f5a817b1a55511faa580ee2787dfa8029269 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Mon, 8 Dec 2025 16:26:09 +0100 Subject: [PATCH 08/22] Removed duplicate from variable_def default file --- wavy/config/variable_def.yaml.default | 6 ------ 1 file changed, 6 deletions(-) diff --git a/wavy/config/variable_def.yaml.default b/wavy/config/variable_def.yaml.default index f143d64..95ff5d5 100644 --- a/wavy/config/variable_def.yaml.default +++ b/wavy/config/variable_def.yaml.default @@ -295,12 +295,6 @@ MdirToST: units: degree _FillValue: -999.9 -eHmax: - standard_name: expected_maximum_wave_height - long_name: 'expected maximum wave height' - units: m - _FillValue: -999.9 - eTm: standard_name: expected_wave_period long_name: 'expected wave period' From e3bd956cb46f61aa0e56f4452aa0fb656dacf250 Mon Sep 17 00:00:00 2001 
From: Fabien Collas Date: Tue, 9 Dec 2025 14:01:03 +0100 Subject: [PATCH 09/22] Modified read_local_20Hz_files satellite reader to deal with ncfiles with original dimension other than time --- wavy/satellite_readers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wavy/satellite_readers.py b/wavy/satellite_readers.py index f48aa22..a101887 100755 --- a/wavy/satellite_readers.py +++ b/wavy/satellite_readers.py @@ -179,7 +179,7 @@ def read_local_20Hz_files(**kwargs): sd = kwargs.get('sd') ed = kwargs.get('ed') twin = kwargs.get('twin') - + nc_dim = satellite_dict[nID].get('misc',{}).get('nc_dim', 'time') # adjust start and end sd = sd - timedelta(minutes=twin) ed = ed + timedelta(minutes=twin) @@ -195,7 +195,7 @@ def read_local_20Hz_files(**kwargs): satellite_dict[nID], ncmeta) # retrieve sliced data - ds = read_netcdfs(pathlst) + ds = read_netcdfs(pathlst, dim=nc_dim) ds_sort = ds.sortby(timestr) # get indices for included time period From 5737519b2a9bd49c36e44b6a7f8fdfbe488b752b Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Thu, 11 Dec 2025 12:10:40 +0100 Subject: [PATCH 10/22] Fixed model_time attributes causing netcdf error in cco.vars --- wavy/collocation_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wavy/collocation_module.py b/wavy/collocation_module.py index 01c9d68..1b56d13 100755 --- a/wavy/collocation_module.py +++ b/wavy/collocation_module.py @@ -317,7 +317,7 @@ def populate(self, **kwargs): new.vars = new.vars.merge(new.oco.vars[list_vars_extra].\ rename({v:'obs_'+v for v in list_vars_extra}), join='left') - + new = new._drop_duplicates(**kwargs) t1 = time.time() print(" ") @@ -373,7 +373,7 @@ def _build_xr_dataset(self, results_dict): data=results_dict['model_time'], dims=['time'], coords={'time': results_dict['obs_time']}, - attrs=variable_def['time'], + attrs={}, ), 'model_lons': xr.DataArray( data=results_dict['model_lons'], From c207cd6b1c90e03e7c734f8c1bb751f2a158b333 Mon Sep 17 
00:00:00 2001 From: Fabien Collas Date: Fri, 12 Dec 2025 15:41:54 +0100 Subject: [PATCH 11/22] Corrected error in triple_collocation module --- wavy/triple_collocation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wavy/triple_collocation.py b/wavy/triple_collocation.py index edc1399..5442c33 100644 --- a/wavy/triple_collocation.py +++ b/wavy/triple_collocation.py @@ -103,7 +103,7 @@ def filter_dynamic_collocation(data, mod_1, mod_2, max_rel_diff=0.05): mod_1 = np.array(mod_1) mod_2 = np.array(mod_2) - idx = np.abs(mod_1 - mod_2)/mod_1 < 0.05 + idx = np.abs(mod_1 - mod_2)/mod_1 < max_rel_diff data_filtered = {} From 6c4893221f3e8942c614d880f61ba00886acd5c4 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Fri, 12 Dec 2025 17:45:46 +0100 Subject: [PATCH 12/22] Fixed arguments in collocation module with regridded method --- tests/test_collocmod.py | 6 +++--- wavy/collocation_module.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_collocmod.py b/tests/test_collocmod.py index 9e2a61e..0569590 100644 --- a/tests/test_collocmod.py +++ b/tests/test_collocmod.py @@ -27,7 +27,7 @@ def test_sat_collocation_and_validation(test_data, tmpdir): # collocate cco = cc(oco=sco, model=model, leadtime='best', distlim=6).populate() - assert len(vars(cco).keys()) == 19 + assert len(vars(cco).keys()) == 21 assert len(cco.vars.keys()) == 10 # validate @@ -51,7 +51,7 @@ def test_insitu_collocation_and_validation(test_data, tmpdir): # collocate cco = cc(oco=ico, model=model, leadtime='best', distlim=6).populate() - assert len(vars(cco).keys()) == 19 + assert len(vars(cco).keys()) == 21 assert len(cco.vars.keys()) == 10 # validate @@ -69,7 +69,7 @@ def test_poi_collocation(): # collocate cco = cc(oco=pco, model='ww3_4km', leadtime='best').populate() - assert len(vars(cco).keys()) == 19 + assert len(vars(cco).keys()) == 21 assert len(cco.vars.keys()) == 10 diff --git a/wavy/collocation_module.py b/wavy/collocation_module.py index 
1b56d13..6c108ef 100755 --- a/wavy/collocation_module.py +++ b/wavy/collocation_module.py @@ -291,6 +291,8 @@ def __init__(self, oco=None, model=None, poi=None, self.distlim = kwargs.get('distlim', 6) self.method = kwargs.get('method', 'closest') self.colloc_time_method = kwargs.get('colloc_time_method', 'nearest') + self.nproc = kwargs.get('nproc',16) + self.res = kwargs.get('res',(0.5,0.5)) print(" ") print(" ### Collocation_class object initialized ###") @@ -641,7 +643,7 @@ def _collocate_centered_model_value(self, time, lon, lat, **kwargs): nID_model = self.model name_model = self.model - res = kwargs.get('res', (0.5,0.5)) + res = self.res colloc_time_method = self.colloc_time_method print('Using resolution {}'.format(res)) @@ -697,7 +699,7 @@ def _collocate_regridded_model(self, **kwargs): lat_mod_list=[] time_mod_list=[] - nproc = kwargs.get('nproc', 16) + nproc = self.nproc oco_vars = self.oco.vars From a76fdc00441aa543b359ebdd5f718786dd644b87 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Thu, 18 Dec 2025 15:03:10 +0100 Subject: [PATCH 13/22] Improved function to calculate the power spectra in triple collocation, making it faster and more generic for different usage. 
--- tests/test_triple_collocation.py | 16 +++- wavy/triple_collocation.py | 152 +++++++++++++------------------ 2 files changed, 75 insertions(+), 93 deletions(-) diff --git a/tests/test_triple_collocation.py b/tests/test_triple_collocation.py index c0ccedb..e9d86d3 100644 --- a/tests/test_triple_collocation.py +++ b/tests/test_triple_collocation.py @@ -194,9 +194,21 @@ def test_least_squares_merging(test_data): assert len(least_squares_merge) == len(dict_data['insitu']) -def test_spectra_cco(test_data): +def test_get_mean_spectra(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + + spectra = tc.get_mean_spectra(ico.vars, varname='Hs', fs=6, nsample=64) + + assert len(spectra) == 32 def test_integrate_r2(test_data): diff --git a/wavy/triple_collocation.py b/wavy/triple_collocation.py index 5442c33..8afa621 100644 --- a/wavy/triple_collocation.py +++ b/wavy/triple_collocation.py @@ -590,114 +590,84 @@ def least_squares_merging(data, tc_results=None, return_var=False, **kwargs): return least_squares_merge else: return least_squares_merge, least_squares_var + - -def spectra_cco(cco, fs, varalias, returns='average',nsample=64): +def get_mean_spectra(ds, varname, fs, nsample, + median_step=None, mode='average', + window='hamming'): """ - Calculate the power spectra for both model and observation series - of a collocation class object. The series are separated into - samples of a given length. Up to one missing value for each - sample is filled with interpolation. A Hamming window is applied - for each sample. - - cco (collocation_class object or xarray dataset): collocation class object - for which the spectra are calculated. If xarray.dataset - is given, it must have a single time dimension and 'obs_'+varalias - and 'model_'+varalias variables. 
- fs (float): frequency of the time series - returns (str): If 'average' returns the average power spectra of the time series. - If 'list' returns lists of power spectra of each samples of the given series. - nsample (int): size of the samples to calculate the power spectra. - + Divides a given time series into sample of given size, applies a window to + each sample and calculates the power spectra for each sample, using a Fast + Fourier transform. Returns the frequencies and either the list of the + spectra for each sample or the average spectra over all samples. + + ds (xarray dataset): xarray dataset with dimension time + varname (str): name of the variable for which the spectra is + to be computed, present in ds and indexed by time + fs (float): sampling frequency + nsample (int): number of points of each sample + median_step (np.timedelta64): time to consider between each point of the + time series + mode (str): either 'average' to return the mean power spectra or + 'list' to return the list of the power spectra + window (str): window to apply to the samples before applying the + FFT. See scipy.signal.periodogram for options. + + + return: - PS_mod (numpy array): average spectra of cco model values or individual spectra for each sample - PS_obs (numpy array): average spectra of cco observation values or individual spectra for each sample - f (numpy array): frequencies for the spectra + df_spectra (pandas DataFrame): dataframe containing the mean power spectra + or the power spectra for each sample and the + corresponding frequencies.
""" from scipy.signal import periodogram - from wavy.collocation_module import collocation_class as cc - - if isinstance(cco, cc): - cco_vars = cco.vars.dropna(dim='time') - elif isinstance(cco, xr.Dataset): - cco_vars = cco.dropna(dim='time') - - step_cco = [(cco_vars.time.values[i+1] - cco_vars.time.values[i])/np.timedelta64(1,'s') for\ - i in range(len(cco_vars.time)-1)] - step_cco.append(np.nan) - median_step = np.median(step_cco) - - i_0 = 0 - i_1 = i_0 + nsample - - list_PS_obs = [] - list_PS_mod = [] - - while i_1 < len(cco_vars.time.values): - - sample_tmp = cco_vars.isel(time=range(i_0,i_1)) + + ds = ds.dropna(dim='time') - step_tmp = [(sample_tmp.time.values[i+1] - sample_tmp.time.values[i])/np.timedelta64(1,'s') for\ - i in range(len(sample_tmp.time)-1)] - step_tmp.append(np.nan) - - idx_1_na = np.argwhere((step_tmp > 1.5*median_step) & (step_tmp < 2.5*median_step)) - idx_2_na = np.argwhere(step_tmp > 2.5*median_step) + diff_time = ds.time.diff(dim='time').values + if median_step == None: + median_step = np.timedelta64(np.nanmedian(diff_time),'ns') + else: + median_step = np.timedelta64(median_step, 'ns') - if len(idx_2_na) > 0: - i_0 = idx_2_na[-1][0] - i_1 = i_0 + nsample - continue + sample_period = nsample * median_step - for idx in idx_1_na: - idx = idx[0] - - time_idx = sample_tmp.time.values[idx] - obs_idx = sample_tmp['obs_'+varalias].values[idx] - mod_idx = sample_tmp['model_'+varalias].values[idx] - time_idx_1 = sample_tmp.time.values[idx+1] - obs_idx_1 = sample_tmp['obs_'+varalias].values[idx+1] - mod_idx_1 = sample_tmp['model_'+varalias].values[idx+1] - - time_between = time_idx + np.timedelta64(int(1000*median_step),'ms') - obs_between= (obs_idx + obs_idx_1)/2 - mod_between = (mod_idx + mod_idx_1)/2 - - ds_between = xr.Dataset( - data_vars={'obs_'+varalias: (('time'), [obs_between]), - 'model_'+varalias: (('time'), [mod_between])}, - coords={'time': [time_between]} - ) + run_sum_period = [np.sum(diff_time[i:i+nsample]) for\ + i in 
range(len(diff_time)-nsample)] - sample_tmp = xr.concat([sample_tmp, ds_between], dim='time').sortby('time') + up_condition = (nsample*median_step + 0.5*median_step > run_sum_period) + low_condition = (nsample*median_step - 0.5*median_step <= run_sum_period) + condition = up_condition & low_condition - sample_tmp = sample_tmp.isel(time=range(0,nsample)) - - obs_val_tmp = sample_tmp['obs_'+varalias].values - mod_val_tmp = sample_tmp['model_'+varalias].values - - f, PS_obs_tmp = periodogram(obs_val_tmp, fs=fs, window='hamming') - f, PS_mod_tmp = periodogram(mod_val_tmp, fs=fs, window='hamming') + idx_right_length = np.argwhere(up_condition & low_condition).flatten() - list_PS_obs.append(PS_obs_tmp) - list_PS_mod.append(PS_mod_tmp) + list_PS = [] - i_0 = i_1 - len(idx_1_na) - i_1 = i_0 + nsample + i = idx_right_length[0] - mean_PS_mod = np.mean(np.array(list_PS_mod), axis=0) - mean_PS_obs = np.mean(np.array(list_PS_obs), axis=0) + while i < idx_right_length[-1] - nsample: - if returns == 'list': - PS_mod = np.array(list_PS_mod) - PS_obs = np.array(list_PS_obs) - elif returns == 'average': - PS_mod = mean_PS_mod - PS_obs = mean_PS_obs - return PS_mod, PS_obs, f + ds_tmp = ds.isel(time=range(i,i+nsample)) + sample_tmp = ds_tmp[varname].values + + f, PS_tmp = periodogram(sample_tmp, fs=fs, window=window) + + list_PS.append(PS_tmp[1:]) + i = idx_right_length[np.argwhere(idx_right_length >= i + nsample).\ + flatten()[0]] + + if mode == 'average': + mean_PS = np.mean(np.array(list_PS), axis=0) + df_spectra = pd.DataFrame({'f':f[1:], 'spectra': mean_PS}) + return df_spectra + else: + df_spectra = pd.DataFrame({'f':f[1:], + **{'spectra_'+str(i):ps for i, ps in enumerate(list_PS)}}) + return df_spectra + + def integrate_r2(PS_mod, PS_obs, f, threshold=np.inf, threshold_type='inv_freq'): """ Estimates the representativeness error r2 by integrating the difference From e5e45735a4a1f55933e96f9097dedb50580df7f4 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 6 Jan 2026 
15:17:32 +0100 Subject: [PATCH 14/22] Added tests for insit, satellite, model, collocation and filter modules, when dealing with several varalias. --- tests/test_collocmod.py | 21 +++++++++ tests/test_filtermod.py | 71 +++++++++++++++++++++++++++++ tests/test_insitumod.py | 23 ++++++++++ tests/test_modelmod.py | 9 ++++ tests/test_satellite_module.py | 29 ++++++++++++ wavy/config/insitu_cfg.yaml.default | 1 + 6 files changed, 154 insertions(+) diff --git a/tests/test_collocmod.py b/tests/test_collocmod.py index 0569590..38811ee 100644 --- a/tests/test_collocmod.py +++ b/tests/test_collocmod.py @@ -32,6 +32,27 @@ def test_sat_collocation_and_validation(test_data, tmpdir): # validate +def test_cco_multivar(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = 'Hs' + twin = 30 + nID = 'cmems_L3_NRT' + model = 'ww3_4km' + # init satellite_object and check for polygon region + sco = sc(sd=sd, ed=ed, nID=nID, name=name, + varalias=varalias, twin=twin) + # read data + sco = sco.populate(reader='read_local_ncfiles', + path=str(test_data/"L3/s3a")) + # crop to region + sco = sco.crop_to_region(model) + + # collocate + cco = cc(oco=sco, model=model, leadtime='best', distlim=6, varalias=['Hs','Tm01']).populate() + assert len(vars(cco).keys()) == 21 + assert len(cco.vars.keys()) == 11 def test_insitu_collocation_and_validation(test_data, tmpdir): sd = "2022-2-1 12" diff --git a/tests/test_filtermod.py b/tests/test_filtermod.py index 32a4534..36aca85 100644 --- a/tests/test_filtermod.py +++ b/tests/test_filtermod.py @@ -144,3 +144,74 @@ def test_filter_distance_to_coast(test_data): assert len(flst) >= 47 assert type(sco.vars == 'xarray.core.dataset.Dataset') assert not 'error' in vars(sco).keys() + +def test_filter_ico_multivar(test_data): + + varalias = ['Hs','U'] # default + sd = "2023-7-2 00" + ed = "2023-7-3 00" + nID = 'MO_Draugen_monthly' + name = 'Draugen' + ico = ic(nID=nID, sd=sd, ed=ed, varalias=varalias, name=name) + ico = 
ico.populate(path=str(test_data/"insitu/monthly/Draugen")) + + assert len(ico.vars['time']) > 0 + assert len(ico.vars.keys()) == 4 + assert not all(np.isnan(v) for v in ico.vars['Hs']) + assert not all(np.isnan(v) for v in ico.vars['U']) + + filter_1 = ico.filter_runmean(window=3, + chunk_min=3, + sampling_rate_Hz=1/600, + varalias='Hs') + + assert len(filter_1.vars['time']) > 0 + assert len(filter_1.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_1.vars['Hs']) + assert not all(np.isnan(v) for v in filter_1.vars['U']) + + filter_2 = filter_1.apply_limits(llim=1, ulim=3, + varalias='Hs') + + assert len(filter_2.vars['time']) > 0 + assert len(filter_2.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_2.vars['Hs']) + assert not all(np.isnan(v) for v in filter_2.vars['U']) + +def test_filter_sco_multivar(test_data): + + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = ['Hs','U'] + twin = 30 + nID = 'cmems_L3_NRT' + # init satellite_object + sco = sc(sd=sd, ed=ed, nID=nID, name=name, + varalias=varalias, + twin=twin) + # read data + sco = sco.populate(path=str(test_data/"L3/s3a")) + + assert len(sco.vars['time']) > 0 + assert len(sco.vars.keys()) == 4 + assert not all(np.isnan(v) for v in sco.vars['Hs']) + assert not all(np.isnan(v) for v in sco.vars['U']) + + filter_1 = sco.filter_runmean(window=3, + chunk_min=3, + sampling_rate_Hz=1/600, + varalias='Hs') + + assert len(filter_1.vars['time']) > 0 + assert len(filter_1.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_1.vars['Hs']) + assert not all(np.isnan(v) for v in filter_1.vars['U']) + + filter_2 = filter_1.apply_limits(llim=1, ulim=3, + varalias='Hs') + + assert len(filter_2.vars['time']) > 0 + assert len(filter_2.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_2.vars['Hs']) + assert not all(np.isnan(v) for v in filter_2.vars['U']) \ No newline at end of file diff --git a/tests/test_insitumod.py b/tests/test_insitumod.py index 
59a453a..7fc3db6 100644 --- a/tests/test_insitumod.py +++ b/tests/test_insitumod.py @@ -126,6 +126,29 @@ def test_cmems_insitu_daily(test_data): # if '.nc' in filelist[i]] # assert len(nclist) >= 1 +def test_cmems_insitu_multivar(test_data): + varalias = ['Hs','U'] # default + sd = "2023-7-2 00" + ed = "2023-7-3 00" + nID = 'MO_Draugen_monthly' + name = 'Draugen' + ico = ic(nID=nID, sd=sd, ed=ed, varalias=varalias, name=name) + print(ico) + print(vars(ico).keys()) + assert ico.__class__.__name__ == 'insitu_class' + assert len(vars(ico).keys()) == 12 + ico.list_input_files(show=True) + new = ico.populate(path=str(test_data/"insitu/monthly/Draugen")) + assert len(new.vars.keys()) == 4 + # check if some data was imported + assert len(new.vars['time']) > 0 + # check that not all data is nan + assert not all(np.isnan(v) for v in new.vars['time']) + assert not all(np.isnan(v) for v in new.vars['Hs']) + assert not all(np.isnan(v) for v in new.vars['U']) + assert not all(np.isnan(v) for v in new.vars['lons']) + assert not all(np.isnan(v) for v in new.vars['lats']) + def test_insitu_poi(tmpdir): # define poi dictionary for track diff --git a/tests/test_modelmod.py b/tests/test_modelmod.py index 6f41639..d2a5f86 100644 --- a/tests/test_modelmod.py +++ b/tests/test_modelmod.py @@ -59,6 +59,15 @@ def test_NORA3_hc_waves(): assert len(vars(mco).keys()) == 18 assert len(mco.vars.keys()) == 3 +def test_mco_multivar(): + #get_model + mco = mc(nID='ww3_4km', sd="2023-6-1", ed="2023-6-1 00", varalias=['Hs', 'U']) + assert mco.__class__.__name__ == 'model_class' + mco.populate() + print(mco.vars) + assert len(vars(mco).keys()) == 18 + assert len(mco.vars.keys()) == 4 + # Fails #def test_MY_L4_thredds(): # """ diff --git a/tests/test_satellite_module.py b/tests/test_satellite_module.py index 95946cc..dd88bf8 100644 --- a/tests/test_satellite_module.py +++ b/tests/test_satellite_module.py @@ -106,6 +106,35 @@ def test_default_reader(test_data): assert not 'error' in vars(sco).keys() 
+def test_sco_multivar(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = ['Hs','U'] + twin = 30 + nID = 'cmems_L3_NRT' + # init satellite_object + sco = sc(sd=sd, ed=ed, nID=nID, name=name, + varalias=varalias, + twin=twin) + # read data + sco = sco.populate(path=str(test_data/"L3/s3a")) + assert sco.__class__.__name__ == 'satellite_class' + # compare number of available variables + vlst = list(vars(sco).keys()) + assert len(vlst) == 19 + # compare number of available functions + dlst = dir(sco) + flst = [n for n in dlst if n not in vlst if '__' not in n] + assert len(flst) >= 47 + assert type(sco.vars == 'xarray.core.dataset.Dataset') + assert not 'error' in vars(sco).keys() + assert len(sco.vars['time']) > 0 + assert len(sco.vars.keys()) == 4 + assert not all(np.isnan(v) for v in sco.vars['Hs']) + assert not all(np.isnan(v) for v in sco.vars['U']) + + def test_polygon_region(test_data): sd = "2022-2-01 01" ed = "2022-2-03 23" diff --git a/wavy/config/insitu_cfg.yaml.default b/wavy/config/insitu_cfg.yaml.default index 5cd19d3..e7ac1b9 100644 --- a/wavy/config/insitu_cfg.yaml.default +++ b/wavy/config/insitu_cfg.yaml.default @@ -836,6 +836,7 @@ MO_Draugen_monthly: # object nID (nameID) # optional, needs to be defined if not cf and in variable_info.yaml vardef: Hs: VAVH + U: WSPD lats: LATITUDE lons: LONGITUDE time: TIME From 37ceee2ff2a8aeb60bfabaf20f4530a8adb2d867 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 6 Jan 2026 16:58:09 +0100 Subject: [PATCH 15/22] Completed tests for triple collocation module --- tests/test_triple_collocation.py | 96 ++++++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 4 deletions(-) diff --git a/tests/test_triple_collocation.py b/tests/test_triple_collocation.py index e9d86d3..8fa6a26 100644 --- a/tests/test_triple_collocation.py +++ b/tests/test_triple_collocation.py @@ -205,18 +205,73 @@ def test_get_mean_spectra(test_data): ico.vars = xr.open_dataset( 
str(test_data/"triple_collocation/Norne_ico.nc") ) - + spectra = tc.get_mean_spectra(ico.vars, varname='Hs', fs=6, nsample=64) assert len(spectra) == 32 def test_integrate_r2(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + + # Import model data + mco = mc(sd='2014-01-01', + ed='2018-12-31', + nID='NORA3_hc_waves') + mco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_mco.nc") + ) + + ps_ico = tc.get_mean_spectra(ico.vars, varname='Hs', fs=6, nsample=64) + ps_mod = tc.get_mean_spectra(mco.vars, varname='Hs', fs=6, nsample=64) + + r2 = tc.integrate_r2(ps_ico['spectra'], ps_mod['spectra'], ps_ico['f']) + + assert isinstance(r2, float) + assert not(np.isnan(r2)) def test_filter_collocation_distance(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + # Import satellite data + sco = sc(sd='2014-01-01', + ed='2018-12-31', + nID='CCIv1_L3', + name='multi') + sco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_sco.nc") + ) + # Import model data + mco = mc(sd='2014-01-01', + ed='2018-12-31', + nID='NORA3_hc_waves') + mco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_mco.nc") + ) + # Create dictionary for triple collocation function + dict_data = {'insitu': ico, 'satellite': sco, 'model': mco} + + data_filtered = tc.filter_collocation_distance(dict_data, dist_max=10, name='satellite') + + assert dict_data.keys() == data_filtered.keys() + assert len(dict_data['insitu'].vars.time) >= len(data_filtered['insitu'].vars.time) + assert np.max(data_filtered['satellite'].vars.colloc_dist) <= 10 def test_filter_values(test_data): @@ -257,4 +312,37 @@ 
def test_filter_values(test_data): def test_filter_dynamic_collocation(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + # Import satellite data + sco = sc(sd='2014-01-01', + ed='2018-12-31', + nID='CCIv1_L3', + name='multi') + sco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_sco.nc") + ) + # Import model data + mco = mc(sd='2014-01-01', + ed='2018-12-31', + nID='NORA3_hc_waves') + mco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_mco.nc") + ) + # Create dictionary for triple collocation function + dict_data = {'insitu': ico.vars.Hs.values, + 'satellite': sco.vars.Hs.values, + 'model': mco.vars.Hs.values} + + data_filtered = tc.filter_dynamic_collocation(dict_data, 'insitu', 'model', max_rel_diff=0.05) + + assert dict_data.keys() == data_filtered.keys() + assert len(dict_data['insitu']) >= len(data_filtered['insitu']) + assert all((data_filtered['insitu'][i] - data_filtered['model'][i])/data_filtered['insitu'][i] <= 0.05 for i in range(len(data_filtered['insitu']))) From 946b05af74e6ee6e7c4f02a091a23dbe9b3cc447 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Wed, 7 Jan 2026 11:16:47 +0100 Subject: [PATCH 16/22] Added test for consalidate module --- tests/test_consolidate.py | 49 ++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/tests/test_consolidate.py b/tests/test_consolidate.py index 989a39b..7e2ee2c 100644 --- a/tests/test_consolidate.py +++ b/tests/test_consolidate.py @@ -1,17 +1,34 @@ -#import sys -#import os -#import numpy as np -#from datetime import datetime -#import pytest -# -#from wavy.consolidate import consolidate_class as cs -#from wavy.satmod import satellite_class as sc +import sys +import os +import numpy as np +from datetime import datetime +import pytest +from 
wavy.satellite_module import satellite_class as sc +from wavy.consolidate import consolidate_class as cs -#def test_consolidate_scos(test_data, benchmark): -# sco1 = sc(sdate="2020-11-1 12",region="NordicSeas", -# mission='s3a', path_local=str(test_data/"L3")) -# sco2 = sc(sdate="2020-11-1 13",region="NordicSeas", -# mission='s3b', path_local=str(test_data/"L3")) -# cso = cs([sco1,sco2]) -# assert len(list(vars(cso).keys())) >= 8 -# assert len(list(cso.vars.keys())) == 6 +def test_consolidate_satellite(test_data): + # satellite consolidate + sco1 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3a').populate(path=str(test_data/"L3/s3a")) + sco2 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3b').populate(path=str(test_data/"L3/s3b")) + + cso = cs([sco1,sco2]) + + len(list(vars(cso).keys())) >= 8 + len(list(cso.vars.keys())) == 3 + +def test_consolidate_satellite_multivar(test_data): + + varalias = ['Hs','U'] + + # satellite consolidate + sco1 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3a', varalias=varalias).populate(path=str(test_data/"L3/s3a")) + sco2 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3b', varalias=varalias).populate(path=str(test_data/"L3/s3b")) + + cso = cs([sco1,sco2]) + + len(list(vars(cso).keys())) >= 8 + len(list(cso.vars.keys())) == 3 \ No newline at end of file From 7f66cdc9e43571462c0219483e4fa245ead6e8f5 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 27 Jan 2026 12:53:47 +0100 Subject: [PATCH 17/22] Added test for multivariables with multisat module --- tests/test_multisat.py | 60 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/tests/test_multisat.py b/tests/test_multisat.py index f92e10a..5033770 100644 --- a/tests/test_multisat.py +++ b/tests/test_multisat.py @@ -1,9 +1,53 @@ -#import pytest -# -#from wavy.multisat import 
multisat_class as ms +import pytest +from wavy import ms -#def test_multisat(test_data): -# mso = ms(sdate="2020-11-1",edate="2020-11-3",region="NordicSeas", -# mission=['s3a','s3b'], path_local=str(test_data/"L3")) -# assert len(list(vars(mso).keys())) == 15 -# assert len(list(mso.ocos)) >= 1 +def test_multisat(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = ['s3a','s3b'] + varalias = 'Hs' + + # init multisat_object + mso = ms(sd=sd, + ed=ed, + name=name, + varalias = varalias, + path = [str(test_data/"L3/s3a"), + str(test_data/"L3/s3b")]) + # read data + assert mso.__class__.__name__ == 'multisat_class' + # compare number of available variables + vlst = list(vars(mso).keys()) + assert len(vlst) == 15 + # compare number of available functions + dlst = dir(mso) + flst = [n for n in dlst if n not in vlst if '__' not in n] + assert len(flst) >= 27 + assert type(mso.vars == 'xarray.core.dataset.Dataset') + assert not 'error' in vars(mso).keys() + +def test_multisat_multivar(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = ['s3a','s3b'] + varalias = ['Hs','U'] + + # init multisat_object + mso = ms(sd=sd, + ed=ed, + name=name, + varalias = varalias, + path = [str(test_data/"L3/s3a"), + str(test_data/"L3/s3b")]) + # read data + assert mso.__class__.__name__ == 'multisat_class' + # compare number of available variables + vlst = list(vars(mso).keys()) + assert len(vlst) == 15 + # compare number of available functions + dlst = dir(mso) + flst = [n for n in dlst if n not in vlst if '__' not in n] + assert len(flst) >= 27 + assert len(list(mso.vars.variables)) == 5 + assert type(mso.vars == 'xarray.core.dataset.Dataset') + assert not 'error' in vars(mso).keys() From 11b5480294e1d042c80aca6768cc111f5d2b1a77 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 27 Jan 2026 13:21:42 +0100 Subject: [PATCH 18/22] Fix test multisat --- tests/test_multisat.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_multisat.py 
b/tests/test_multisat.py index 5033770..190832d 100644 --- a/tests/test_multisat.py +++ b/tests/test_multisat.py @@ -1,5 +1,8 @@ import pytest +import os +import sys from wavy import ms +from wavy.wconfig import load_or_default def test_multisat(test_data): sd = "2022-2-1 12" @@ -18,6 +21,7 @@ def test_multisat(test_data): assert mso.__class__.__name__ == 'multisat_class' # compare number of available variables vlst = list(vars(mso).keys()) + print(vlst) assert len(vlst) == 15 # compare number of available functions dlst = dir(mso) From e768a1cd3b5684e9634b719aeed55fe8ba2ac074 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 27 Jan 2026 13:34:08 +0100 Subject: [PATCH 19/22] Fix test_multisat 2 --- tests/test_multisat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_multisat.py b/tests/test_multisat.py index 190832d..f548cc7 100644 --- a/tests/test_multisat.py +++ b/tests/test_multisat.py @@ -1,7 +1,8 @@ import pytest import os import sys -from wavy import ms +from wavy import ms, sc +from wavy.consolidate import consolidate_class as cs from wavy.wconfig import load_or_default def test_multisat(test_data): From 15c56864a5a5d38d0cc277a2bc48b042ef3ddb9a Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Tue, 27 Jan 2026 14:37:21 +0100 Subject: [PATCH 20/22] Fix multisat test 3 --- tests/test_multisat.py | 4 ++-- wavy/multisat_module.py | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/test_multisat.py b/tests/test_multisat.py index f548cc7..7774d51 100644 --- a/tests/test_multisat.py +++ b/tests/test_multisat.py @@ -23,7 +23,7 @@ def test_multisat(test_data): # compare number of available variables vlst = list(vars(mso).keys()) print(vlst) - assert len(vlst) == 15 + assert len(vlst) == 17 # compare number of available functions dlst = dir(mso) flst = [n for n in dlst if n not in vlst if '__' not in n] @@ -48,7 +48,7 @@ def test_multisat_multivar(test_data): assert mso.__class__.__name__ 
== 'multisat_class' # compare number of available variables vlst = list(vars(mso).keys()) - assert len(vlst) == 15 + assert len(vlst) == 17 # compare number of available functions dlst = dir(mso) flst = [n for n in dlst if n not in vlst if '__' not in n] diff --git a/wavy/multisat_module.py b/wavy/multisat_module.py index b31f9d7..3894dab 100644 --- a/wavy/multisat_module.py +++ b/wavy/multisat_module.py @@ -38,6 +38,8 @@ def __init__(self, **kwargs): self.twin = kwargs.get('twin', 30) self.distlim = kwargs.get('distlim', 6) self.region = kwargs.get('region', 'global') + self.path = kwargs.get('path', len(self.name)*[None]) + self.wavy_path = kwargs.get('wavy_path', len(self.name)*[None]) t0 = time.time() # products: either None, same as names, or one product @@ -54,7 +56,8 @@ def __init__(self, **kwargs): nID=self.nID[i], name=n, twin=self.twin, distlim=self.distlim, region=self.region, varalias=self.varalias) - sco = sco.populate() + sco = sco.populate(path=self.path[i], + wavy_path=self.wavy_path[i]) if 'vars' in list(vars(sco)): scos.append(deepcopy(sco)) del sco @@ -85,6 +88,19 @@ def __init__(self, **kwargs): print(" ### multisat object initialized ###") print('# ----- ') + def crop_to_period(self, **kwargs): + """ + Function to crop the variable dictionary to a given period + """ + new = deepcopy(self) + sd = parse_date(kwargs.get('sd', str(new.sd))) + ed = parse_date(kwargs.get('ed', str(new.ed))) + print('Crop to time period:', sd, 'to', ed) + new.vars = new.vars.sortby("time").sel(time=slice(sd, ed)) + new.sd = sd + new.ed = ed + return new + def find_valid_names(scos): names = [scos[0].name] From 61e63362f9337e1a4b410e1b1a9c8cfe9faa9f01 Mon Sep 17 00:00:00 2001 From: Fabien Collas <32676562+fcollas@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:58:50 +0100 Subject: [PATCH 21/22] Update tests/test_collocmod.py --- tests/test_collocmod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_collocmod.py 
b/tests/test_collocmod.py index 216df48..5f7a286 100644 --- a/tests/test_collocmod.py +++ b/tests/test_collocmod.py @@ -96,7 +96,7 @@ def test_insitu_collocation_leadtime(test_data, tmpdir): # collocate cco = cc(oco=ico, model=model, leadtime=10, twin=9).populate() assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(cco.vars.keys()) == 10 assert len(cco.vars.time) == 2 def test_poi_collocation(): From a891148ecc87e45e83b64952e495a289b6e42784 Mon Sep 17 00:00:00 2001 From: Fabien Collas <32676562+fcollas@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:58:55 +0100 Subject: [PATCH 22/22] Update tests/test_collocmod.py --- tests/test_collocmod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_collocmod.py b/tests/test_collocmod.py index 5f7a286..b1b0625 100644 --- a/tests/test_collocmod.py +++ b/tests/test_collocmod.py @@ -95,7 +95,7 @@ def test_insitu_collocation_leadtime(test_data, tmpdir): # collocate cco = cc(oco=ico, model=model, leadtime=10, twin=9).populate() - assert len(vars(cco).keys()) == 19 + assert len(vars(cco).keys()) == 21 assert len(cco.vars.keys()) == 10 assert len(cco.vars.time) == 2