diff --git a/tests/test_aismod.py b/tests/test_aismod.py index ff61cd4..8cf1895 100644 --- a/tests/test_aismod.py +++ b/tests/test_aismod.py @@ -5,7 +5,7 @@ #@pytest.mark.need_credentials #def test_get_ais_data(): - +# # bbox = ['5.89', '62.3', '6.5', '62.7'] # sd = '2017-01-03 08' # ed = '2017-01-03 09' diff --git a/tests/test_collocmod.py b/tests/test_collocmod.py index 52b5190..b1b0625 100644 --- a/tests/test_collocmod.py +++ b/tests/test_collocmod.py @@ -27,11 +27,32 @@ def test_sat_collocation_and_validation(test_data, tmpdir): # collocate cco = cc(oco=sco, model=model, leadtime='best', distlim=6).populate() - assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(vars(cco).keys()) == 21 + assert len(cco.vars.keys()) == 10 # validate +def test_cco_multivar(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = 'Hs' + twin = 30 + nID = 'cmems_L3_NRT' + model = 'ww3_4km' + # init satellite_object and check for polygon region + sco = sc(sd=sd, ed=ed, nID=nID, name=name, + varalias=varalias, twin=twin) + # read data + sco = sco.populate(reader='read_local_ncfiles', + path=str(test_data/"L3/s3a")) + # crop to region + sco = sco.crop_to_region(model) + + # collocate + cco = cc(oco=sco, model=model, leadtime='best', distlim=6, varalias=['Hs','Tm01']).populate() + assert len(vars(cco).keys()) == 21 + assert len(cco.vars.keys()) == 11 def test_insitu_collocation_and_validation(test_data, tmpdir): sd = "2022-2-1 12" @@ -51,8 +72,8 @@ def test_insitu_collocation_and_validation(test_data, tmpdir): # collocate cco = cc(oco=ico, model=model, leadtime='best', distlim=6).populate() - assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(vars(cco).keys()) == 21 + assert len(cco.vars.keys()) == 10 # validate @@ -74,8 +95,8 @@ def test_insitu_collocation_leadtime(test_data, tmpdir): # collocate cco = cc(oco=ico, model=model, leadtime=10, twin=9).populate() - assert len(vars(cco).keys()) == 19 - 
assert len(cco.vars.keys()) == 9 + assert len(vars(cco).keys()) == 21 + assert len(cco.vars.keys()) == 10 assert len(cco.vars.time) == 2 def test_poi_collocation(): @@ -90,8 +111,8 @@ def test_poi_collocation(): # collocate cco = cc(oco=pco, model='ww3_4km', leadtime='best').populate() - assert len(vars(cco).keys()) == 19 - assert len(cco.vars.keys()) == 9 + assert len(vars(cco).keys()) == 21 + assert len(cco.vars.keys()) == 10 # # write to nc diff --git a/tests/test_consolidate.py b/tests/test_consolidate.py index 989a39b..7e2ee2c 100644 --- a/tests/test_consolidate.py +++ b/tests/test_consolidate.py @@ -1,17 +1,34 @@ -#import sys -#import os -#import numpy as np -#from datetime import datetime -#import pytest -# -#from wavy.consolidate import consolidate_class as cs -#from wavy.satmod import satellite_class as sc +import sys +import os +import numpy as np +from datetime import datetime +import pytest +from wavy.satellite_module import satellite_class as sc +from wavy.consolidate import consolidate_class as cs -#def test_consolidate_scos(test_data, benchmark): -# sco1 = sc(sdate="2020-11-1 12",region="NordicSeas", -# mission='s3a', path_local=str(test_data/"L3")) -# sco2 = sc(sdate="2020-11-1 13",region="NordicSeas", -# mission='s3b', path_local=str(test_data/"L3")) -# cso = cs([sco1,sco2]) -# assert len(list(vars(cso).keys())) >= 8 -# assert len(list(cso.vars.keys())) == 6 +def test_consolidate_satellite(test_data): + # satellite consolidate + sco1 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3a').populate(path=str(test_data/"L3/s3a")) + sco2 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3b').populate(path=str(test_data/"L3/s3b")) + + cso = cs([sco1,sco2]) + + len(list(vars(cso).keys())) >= 8 + len(list(cso.vars.keys())) == 3 + +def test_consolidate_satellite_multivar(test_data): + + varalias = ['Hs','U'] + + # satellite consolidate + sco1 = sc(sd="2022-2-1",ed 
="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3a', varalias=varalias).populate(path=str(test_data/"L3/s3a")) + sco2 = sc(sd="2022-2-1",ed ="2022-2-3",region="NordicSeas", nID="cmems_L3_NRT", + name='s3b', varalias=varalias).populate(path=str(test_data/"L3/s3b")) + + cso = cs([sco1,sco2]) + + len(list(vars(cso).keys())) >= 8 + len(list(cso.vars.keys())) == 3 \ No newline at end of file diff --git a/tests/test_filtermod.py b/tests/test_filtermod.py index 4d1e799..67babaf 100644 --- a/tests/test_filtermod.py +++ b/tests/test_filtermod.py @@ -143,6 +143,76 @@ def test_filter_distance_to_coast(test_data): assert type(sco.vars == 'xarray.core.dataset.Dataset') assert not 'error' in vars(sco).keys() +def test_filter_ico_multivar(test_data): + + varalias = ['Hs','U'] # default + sd = "2023-7-2 00" + ed = "2023-7-3 00" + nID = 'MO_Draugen_monthly' + name = 'Draugen' + ico = ic(nID=nID, sd=sd, ed=ed, varalias=varalias, name=name) + ico = ico.populate(path=str(test_data/"insitu/monthly/Draugen")) + + assert len(ico.vars['time']) > 0 + assert len(ico.vars.keys()) == 4 + assert not all(np.isnan(v) for v in ico.vars['Hs']) + assert not all(np.isnan(v) for v in ico.vars['U']) + + filter_1 = ico.filter_runmean(window=3, + chunk_min=3, + sampling_rate_Hz=1/600, + varalias='Hs') + + assert len(filter_1.vars['time']) > 0 + assert len(filter_1.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_1.vars['Hs']) + assert not all(np.isnan(v) for v in filter_1.vars['U']) + + filter_2 = filter_1.apply_limits(llim=1, ulim=3, + varalias='Hs') + + assert len(filter_2.vars['time']) > 0 + assert len(filter_2.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_2.vars['Hs']) + assert not all(np.isnan(v) for v in filter_2.vars['U']) + +def test_filter_sco_multivar(test_data): + + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = ['Hs','U'] + twin = 30 + nID = 'cmems_L3_NRT' + # init satellite_object + sco = sc(sd=sd, ed=ed, nID=nID, name=name, 
+ varalias=varalias, + twin=twin) + # read data + sco = sco.populate(path=str(test_data/"L3/s3a")) + + assert len(sco.vars['time']) > 0 + assert len(sco.vars.keys()) == 4 + assert not all(np.isnan(v) for v in sco.vars['Hs']) + assert not all(np.isnan(v) for v in sco.vars['U']) + + filter_1 = sco.filter_runmean(window=3, + chunk_min=3, + sampling_rate_Hz=1/600, + varalias='Hs') + + assert len(filter_1.vars['time']) > 0 + assert len(filter_1.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_1.vars['Hs']) + assert not all(np.isnan(v) for v in filter_1.vars['U']) + + filter_2 = filter_1.apply_limits(llim=1, ulim=3, + varalias='Hs') + + assert len(filter_2.vars['time']) > 0 + assert len(filter_2.vars.keys()) == 4 + assert not all(np.isnan(v) for v in filter_2.vars['Hs']) + assert not all(np.isnan(v) for v in filter_2.vars['U']) def test_filter_landMask_ms(test_data): sd = "2022-2-1 12" ed = "2022-2-1 12" diff --git a/tests/test_insitumod.py b/tests/test_insitumod.py index 59a453a..7fc3db6 100644 --- a/tests/test_insitumod.py +++ b/tests/test_insitumod.py @@ -126,6 +126,29 @@ def test_cmems_insitu_daily(test_data): # if '.nc' in filelist[i]] # assert len(nclist) >= 1 +def test_cmems_insitu_multivar(test_data): + varalias = ['Hs','U'] # default + sd = "2023-7-2 00" + ed = "2023-7-3 00" + nID = 'MO_Draugen_monthly' + name = 'Draugen' + ico = ic(nID=nID, sd=sd, ed=ed, varalias=varalias, name=name) + print(ico) + print(vars(ico).keys()) + assert ico.__class__.__name__ == 'insitu_class' + assert len(vars(ico).keys()) == 12 + ico.list_input_files(show=True) + new = ico.populate(path=str(test_data/"insitu/monthly/Draugen")) + assert len(new.vars.keys()) == 4 + # check if some data was imported + assert len(new.vars['time']) > 0 + # check that not all data is nan + assert not all(np.isnan(v) for v in new.vars['time']) + assert not all(np.isnan(v) for v in new.vars['Hs']) + assert not all(np.isnan(v) for v in new.vars['U']) + assert not all(np.isnan(v) for v in 
new.vars['lons']) + assert not all(np.isnan(v) for v in new.vars['lats']) + def test_insitu_poi(tmpdir): # define poi dictionary for track diff --git a/tests/test_modelmod.py b/tests/test_modelmod.py index 6f41639..d2a5f86 100644 --- a/tests/test_modelmod.py +++ b/tests/test_modelmod.py @@ -59,6 +59,15 @@ def test_NORA3_hc_waves(): assert len(vars(mco).keys()) == 18 assert len(mco.vars.keys()) == 3 +def test_mco_multivar(): + #get_model + mco = mc(nID='ww3_4km', sd="2023-6-1", ed="2023-6-1 00", varalias=['Hs', 'U']) + assert mco.__class__.__name__ == 'model_class' + mco.populate() + print(mco.vars) + assert len(vars(mco).keys()) == 18 + assert len(mco.vars.keys()) == 4 + # Fails #def test_MY_L4_thredds(): # """ diff --git a/tests/test_multisat.py b/tests/test_multisat.py index f57e89e..2c189f9 100644 --- a/tests/test_multisat.py +++ b/tests/test_multisat.py @@ -25,3 +25,29 @@ def test_multisat(test_data): assert len(flst) >= 27 assert type(mso.vars == 'xarray.core.dataset.Dataset') assert not 'error' in vars(mso).keys() + +def test_multisat_multivar(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = ['s3a','s3b'] + varalias = ['Hs','U'] + + # init multisat_object + mso = ms(sd=sd, + ed=ed, + name=name, + varalias = varalias, + path = [str(test_data/"L3/s3a"), + str(test_data/"L3/s3b")]) + # read data + assert mso.__class__.__name__ == 'multisat_class' + # compare number of available variables + vlst = list(vars(mso).keys()) + assert len(vlst) == 17 + # compare number of available functions + dlst = dir(mso) + flst = [n for n in dlst if n not in vlst if '__' not in n] + assert len(flst) >= 27 + assert len(list(mso.vars.variables)) == 5 + assert type(mso.vars == 'xarray.core.dataset.Dataset') + assert not 'error' in vars(mso).keys() diff --git a/tests/test_satellite_module.py b/tests/test_satellite_module.py index 95946cc..dd88bf8 100644 --- a/tests/test_satellite_module.py +++ b/tests/test_satellite_module.py @@ -106,6 +106,35 @@ def 
test_default_reader(test_data): assert not 'error' in vars(sco).keys() +def test_sco_multivar(test_data): + sd = "2022-2-1 12" + ed = "2022-2-1 12" + name = 's3a' + varalias = ['Hs','U'] + twin = 30 + nID = 'cmems_L3_NRT' + # init satellite_object + sco = sc(sd=sd, ed=ed, nID=nID, name=name, + varalias=varalias, + twin=twin) + # read data + sco = sco.populate(path=str(test_data/"L3/s3a")) + assert sco.__class__.__name__ == 'satellite_class' + # compare number of available variables + vlst = list(vars(sco).keys()) + assert len(vlst) == 19 + # compare number of available functions + dlst = dir(sco) + flst = [n for n in dlst if n not in vlst if '__' not in n] + assert len(flst) >= 47 + assert type(sco.vars == 'xarray.core.dataset.Dataset') + assert not 'error' in vars(sco).keys() + assert len(sco.vars['time']) > 0 + assert len(sco.vars.keys()) == 4 + assert not all(np.isnan(v) for v in sco.vars['Hs']) + assert not all(np.isnan(v) for v in sco.vars['U']) + + def test_polygon_region(test_data): sd = "2022-2-01 01" ed = "2022-2-03 23" diff --git a/tests/test_triple_collocation.py b/tests/test_triple_collocation.py index c0ccedb..8fa6a26 100644 --- a/tests/test_triple_collocation.py +++ b/tests/test_triple_collocation.py @@ -194,17 +194,84 @@ def test_least_squares_merging(test_data): assert len(least_squares_merge) == len(dict_data['insitu']) -def test_spectra_cco(test_data): +def test_get_mean_spectra(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + + spectra = tc.get_mean_spectra(ico.vars, varname='Hs', fs=6, nsample=64) + + assert len(spectra) == 32 def test_integrate_r2(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + 
str(test_data/"triple_collocation/Norne_ico.nc") + ) + + # Import model data + mco = mc(sd='2014-01-01', + ed='2018-12-31', + nID='NORA3_hc_waves') + mco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_mco.nc") + ) + + ps_ico = tc.get_mean_spectra(ico.vars, varname='Hs', fs=6, nsample=64) + ps_mod = tc.get_mean_spectra(mco.vars, varname='Hs', fs=6, nsample=64) + + r2 = tc.integrate_r2(ps_ico['spectra'], ps_mod['spectra'], ps_ico['f']) + + assert isinstance(r2, float) + assert not(np.isnan(r2)) def test_filter_collocation_distance(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + # Import satellite data + sco = sc(sd='2014-01-01', + ed='2018-12-31', + nID='CCIv1_L3', + name='multi') + sco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_sco.nc") + ) + # Import model data + mco = mc(sd='2014-01-01', + ed='2018-12-31', + nID='NORA3_hc_waves') + mco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_mco.nc") + ) + # Create dictionary for triple collocation function + dict_data = {'insitu': ico, 'satellite': sco, 'model': mco} + + data_filtered = tc.filter_collocation_distance(dict_data, dist_max=10, name='satellite') + + assert dict_data.keys() == data_filtered.keys() + assert len(dict_data['insitu'].vars.time) >= len(data_filtered['insitu'].vars.time) + assert np.max(data_filtered['satellite'].vars.colloc_dist) <= 10 def test_filter_values(test_data): @@ -245,4 +312,37 @@ def test_filter_values(test_data): def test_filter_dynamic_collocation(test_data): - assert True + # Wavy objects + # Import in-situ data + ico = ic(sd='2014-01-01', + ed='2018-12-31', + nID='history_cmems_NRT', + name='Norne') + ico.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_ico.nc") + ) + # Import satellite data + sco = 
sc(sd='2014-01-01', + ed='2018-12-31', + nID='CCIv1_L3', + name='multi') + sco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_sco.nc") + ) + # Import model data + mco = mc(sd='2014-01-01', + ed='2018-12-31', + nID='NORA3_hc_waves') + mco.vars = xr.open_dataset( + str(test_data/"triple_collocation/Norne_mco.nc") + ) + # Create dictionary for triple collocation function + dict_data = {'insitu': ico.vars.Hs.values, + 'satellite': sco.vars.Hs.values, + 'model': mco.vars.Hs.values} + + data_filtered = tc.filter_dynamic_collocation(dict_data, 'insitu', 'model', max_rel_diff=0.05) + + assert dict_data.keys() == data_filtered.keys() + assert len(dict_data['insitu']) >= len(data_filtered['insitu']) + assert all((data_filtered['insitu'][i] - data_filtered['model'][i])/data_filtered['insitu'][i] <= 0.05 for i in range(len(data_filtered['insitu']))) diff --git a/wavy/collocation_module.py b/wavy/collocation_module.py index ae80c86..46bd49e 100755 --- a/wavy/collocation_module.py +++ b/wavy/collocation_module.py @@ -270,23 +270,30 @@ def __init__(self, oco=None, model=None, poi=None, else: varalias = oco.varalias[0] self.varalias = varalias - self.varalias_obs = varalias - self.varalias_mod = kwargs.get('varalias_mod', varalias) + if isinstance(self.varalias, str): + self.varalias = [self.varalias] + self.varalias_obs = oco.varalias + if isinstance(self.varalias_obs, str): + self.varalias_obs = [self.varalias_obs] + self.varalias_mod = self.varalias self.model = model self.leadtime = leadtime self.oco = oco self.nID = oco.nID self.model = model self.obstype = str(type(oco))[8:-2] - self.stdvarname = oco.stdvarname + self.units = [variable_def[v].get('units') for v in self.varalias] + self.stdvarname = [variable_def[v].get('standard_name') for v in\ + self.varalias] self.region = oco.region - self.units = variable_def[varalias].get('units') self.sd = oco.sd self.ed = oco.ed self.twin = kwargs.get('twin', oco.twin) self.distlim = kwargs.get('distlim', 6) 
self.method = kwargs.get('method', 'closest') self.colloc_time_method = kwargs.get('colloc_time_method', 'nearest') + self.nproc = kwargs.get('nproc',16) + self.res = kwargs.get('res',(0.5,0.5)) print(" ") print(" ### Collocation_class object initialized ###") @@ -304,6 +311,16 @@ def populate(self, **kwargs): ds = new._build_xr_dataset(results_dict) ds = ds.assign_coords(time=ds.time.values) new.vars = ds + + # Add extra variables from oco + list_vars_cco = list(new.vars.keys()) + list_vars_oco = ['obs_' + v for v in list(new.oco.vars.keys())] + list_vars_extra = [v[4:] for v in list_vars_oco if v not in\ + list_vars_cco] + new.vars = new.vars.merge(new.oco.vars[list_vars_extra].\ + rename({v:'obs_'+v for v in list_vars_extra}), + join='left') + new = new._drop_duplicates(**kwargs) t1 = time.time() print(" ") @@ -323,7 +340,9 @@ def populate(self, **kwargs): return new def _build_xr_dataset(self, results_dict): - ds = xr.Dataset({ + + try: + ds = xr.Dataset({ 'time': xr.DataArray( data=results_dict['obs_time'], dims=['time'], @@ -336,12 +355,6 @@ def _build_xr_dataset(self, results_dict): coords={'time': results_dict['obs_time']}, attrs=variable_def['dist'], ), - 'obs_values': xr.DataArray( - data=results_dict['obs_values'], - dims=['time'], - coords={'time': results_dict['obs_time']}, - attrs=variable_def[self.varalias], - ), 'obs_lons': xr.DataArray( data=results_dict['obs_lons'], dims=['time'], @@ -354,11 +367,16 @@ def _build_xr_dataset(self, results_dict): coords={'time': results_dict['obs_time']}, attrs=variable_def['lats'], ), - 'model_values': xr.DataArray( - data=results_dict['model_values'], + **{'obs_'+v:xr.DataArray( + data=results_dict['obs_'+v], + dims=['time'], + coords={'time': results_dict['obs_time']}, + attrs=variable_def[v]) for v in self.varalias_obs}, + 'model_time': xr.DataArray( + data=results_dict['model_time'], dims=['time'], coords={'time': results_dict['obs_time']}, - attrs=variable_def[self.varalias], + attrs={}, ), 'model_lons': 
xr.DataArray( data=results_dict['model_lons'], @@ -372,6 +390,11 @@ def _build_xr_dataset(self, results_dict): coords={'time': results_dict['obs_time']}, attrs=variable_def['lats'], ), + **{'model_'+v:xr.DataArray( + data=results_dict['model_'+v], + dims=['time'], + coords={'time': results_dict['obs_time']}, + attrs=variable_def[v]) for v in self.varalias_mod}, 'colidx_x': xr.DataArray( data=results_dict['collocation_idx_x'], dims=['time'], @@ -387,6 +410,9 @@ def _build_xr_dataset(self, results_dict): }, attrs={'title': str(type(self))[8:-2] + ' dataset'} ) + except Exception as e: + print(e) + print(ds) return ds def _drop_duplicates(self, **kwargs): @@ -408,20 +434,24 @@ def _collocate_field(self, mco, tmp_dict, **kwargs): """ Mlons = mco.vars.lons.data Mlats = mco.vars.lats.data - Mvars = mco.vars[mco.varalias].data + Mvars = {v:mco.vars[v].data for v in self.varalias_mod} + if len(Mlons.shape) > 2: Mlons = Mlons[0, :].squeeze() Mlats = Mlats[0, :].squeeze() - Mvars = Mvars[0, :].squeeze() + Mvars = {v:Mvars[v][0, :].squeeze() for v in\ + Mvars.keys()} + elif len(Mlons.shape) == 2: Mlons = Mlons.squeeze() Mlats = Mlats.squeeze() - Mvars = Mvars.squeeze() + Mvars = {v:Mvars[v].squeeze() for v in Mvars.keys()} + elif len(Mlons.shape) == 1: Mlons, Mlats = np.meshgrid(Mlons, Mlats) - Mvars = Mvars.squeeze() + Mvars = {v:Mvars[v].squeeze() for v in Mvars.keys()} assert len(Mlons.shape) == 2 - obs_vars = tmp_dict[self.varalias_obs] + obs_vars = {v:tmp_dict[v] for v in self.varalias_obs} obs_lons = tmp_dict['lons'] obs_lats = tmp_dict['lats'] # Compare wave heights of satellite with model with @@ -435,22 +465,25 @@ def _collocate_field(self, mco, tmp_dict, **kwargs): # caution: index_array_2d is tuple # impose distlim dist_idx = np.where((distance_array < self.distlim*1000) & - (~np.isnan(Mvars[index_array_2d[0], + (~np.isnan(Mvars[self.varalias_mod[0]]\ + [index_array_2d[0], index_array_2d[1]])))[0] idx_x = index_array_2d[0][dist_idx] idx_y = 
index_array_2d[1][dist_idx] results_dict = { 'dist': list(distance_array[dist_idx]), - 'model_values': list(Mvars[idx_x, idx_y]), 'model_lons': list(Mlons[idx_x, idx_y]), 'model_lats': list(Mlats[idx_x, idx_y]), - 'obs_values': list(obs_vars[dist_idx]), 'obs_lons': list(obs_lons[dist_idx]), 'obs_lats': list(obs_lats[dist_idx]), 'collocation_idx_x': list(idx_x), 'collocation_idx_y': list(idx_y), - 'time': tmp_dict['time'][dist_idx] + 'time': tmp_dict['time'][dist_idx], + **{'model_'+v: Mvars[v][idx_x, idx_y] for v in Mvars.keys()}, + **{'obs_'+v: obs_vars[v][dist_idx] for v in\ + self.varalias_obs} } + return results_dict def _collocate_track(self, **kwargs): @@ -479,18 +512,19 @@ def _collocate_track(self, **kwargs): print(f'... done, used {t2-t1:.2f} seconds') print("Start collocation ...") + results_dict = { 'model_time': [], 'obs_time': [], 'dist': [], - 'model_values': [], 'model_lons': [], 'model_lats': [], - 'obs_values': [], 'obs_lons': [], 'obs_lats': [], 'collocation_idx_x': [], 'collocation_idx_y': [], + **{'model_'+v:[] for v in self.varalias_mod}, + **{'obs_'+v:[] for v in self.varalias_obs} } for i in tqdm(range(len(fc_date))): @@ -535,29 +569,32 @@ def _collocate_track(self, **kwargs): tmp_dict['time'] = self.oco.vars['time'].values[idx] tmp_dict['lats'] = self.oco.vars['lats'].values[idx] tmp_dict['lons'] = self.oco.vars['lons'].values[idx] - tmp_dict[self.varalias] = \ - self.oco.vars[self.varalias].values[idx] - mco = mc(sd=fc_date[i], ed=fc_date[i], nID=self.model, - leadtime=self.leadtime, varalias=self.varalias_mod, + + for v in self.varalias_obs: + tmp_dict[v] = self.oco.vars[v].values[idx] + print("#########################") + print(self.varalias_mod) + mco = mc(sd=fc_date[i], ed=fc_date[i], + nID=self.model, + leadtime=self.leadtime, + varalias=self.varalias_mod, **kwargs) mco = mco.populate(**kwargs) results_dict_tmp = self._collocate_field( mco, tmp_dict, **kwargs) - if (len(results_dict_tmp['model_values']) > 0): + if 
(len(results_dict_tmp["model_" +\ + self.varalias_mod[0]]) > 0): # append to dict - results_dict['model_time'].append(fc_date[i]) + results_dict['model_time'].append(\ + [fc_date[i]]*len(results_dict_tmp['time'])) results_dict['obs_time'].append( results_dict_tmp['time']) results_dict['dist'].append( results_dict_tmp['dist']) - results_dict['model_values'].append( - results_dict_tmp['model_values']) results_dict['model_lons'].append( results_dict_tmp['model_lons']) results_dict['model_lats'].append( results_dict_tmp['model_lats']) - results_dict['obs_values'].append( - results_dict_tmp['obs_values']) results_dict['obs_lats'].append( results_dict_tmp['obs_lats']) results_dict['obs_lons'].append( @@ -566,6 +603,12 @@ def _collocate_track(self, **kwargs): results_dict_tmp['collocation_idx_x']) results_dict['collocation_idx_y'].append( results_dict_tmp['collocation_idx_y']) + for v in self.varalias_mod: + results_dict['model_'+v].append( + results_dict_tmp['model_'+v]) + for v in self.varalias_obs: + results_dict['obs_'+v].append( + results_dict_tmp['obs_'+v]) else: pass if 'results_dict_tmp' in locals(): @@ -577,19 +620,22 @@ def _collocate_track(self, **kwargs): logger.exception(e) print(e) # flatten all aggregated entries - results_dict['model_time'] = results_dict['model_time'] + results_dict['model_time'] = flatten(results_dict['model_time']) results_dict['obs_time'] = flatten(results_dict['obs_time']) results_dict['dist'] = flatten(results_dict['dist']) - results_dict['model_values'] = flatten(results_dict['model_values']) results_dict['model_lons'] = flatten(results_dict['model_lons']) results_dict['model_lats'] = flatten(results_dict['model_lats']) - results_dict['obs_values'] = flatten(results_dict['obs_values']) results_dict['obs_lats'] = flatten(results_dict['obs_lats']) results_dict['obs_lons'] = flatten(results_dict['obs_lons']) results_dict['collocation_idx_x'] = flatten( results_dict['collocation_idx_x']) results_dict['collocation_idx_y'] = flatten( 
results_dict['collocation_idx_y']) + for v in self.varalias_mod: + results_dict['model_'+v] = flatten(results_dict['model_'+v]) + for v in self.varalias_obs: + results_dict['obs_'+v] = flatten(results_dict['obs_'+v]) + return results_dict def _collocate_centered_model_value(self, time, lon, lat, **kwargs): @@ -598,58 +644,64 @@ def _collocate_centered_model_value(self, time, lon, lat, **kwargs): nID_model = self.model name_model = self.model - res = kwargs.get('res', (0.5, 0.5)) + res = self.res colloc_time_method = self.colloc_time_method print('Using resolution {}'.format(res)) # ADD CHECK LIMITS FOR LAT AND LON res_dict = {} - time = pd.to_datetime(time) - time = hour_rounder(time, method=colloc_time_method) - - mco = mc(sd=time, ed=time, - nID=nID_model, name=name_model, - max_lt=12).populate(twin=5) # ADD AS PARAMETERS - - bb = (lon - res[0]/2, - lon + res[0]/2, - lat - res[1]/2, + time_dt = pd.to_datetime(time) + model_time = hour_rounder(time_dt, method=colloc_time_method) + + mco = mc(sd=model_time, ed=model_time, + nID=nID_model, + name=name_model, + varalias=self.varalias_mod, + max_lt=12).populate(twin=5) # ADD AS PARAMETERS + + bb = (lon - res[0]/2, + lon + res[0]/2, + lat - res[1]/2, lat + res[1]/2) - gco = gc(lons=mco.vars.lons.squeeze().values.ravel(), - lats=mco.vars.lats.squeeze().values.ravel(), - values=mco.vars.Hs.squeeze().values.ravel(), - bb=bb, res=res, - varalias=mco.varalias, - units=mco.units, - sdate=mco.vars.time, - edate=mco.vars.time) - - gridvar, lon_grid, lat_grid = apply_metric(gco=gco) - - ts = gridvar['mor'].flatten() + for i, v in enumerate(mco.varalias): + gco = gc(lons=mco.vars.lons.squeeze().values.ravel(), + lats=mco.vars.lats.squeeze().values.ravel(), + values=mco.vars[v].squeeze().values.ravel(), + bb=bb, res=res, + varalias=v, + units=mco.units[i], + sdate=mco.vars.time, + edate=mco.vars.time) + + gridvar, lon_grid, lat_grid = apply_metric(gco=gco) + + ts = gridvar['mor'].flatten() + + res_dict['model_'+v] = ts[0] + 
lon_flat = lon_grid.flatten() lat_flat = lat_grid.flatten() - res_dict['hs'] = ts[0] res_dict['lon'] = lon_flat[0] res_dict['lat'] = lat_flat[0] - res_dict['time'] = time + res_dict['obs_time'] = time + res_dict['model_time'] = model_time return res_dict def _collocate_regridded_model(self, **kwargs): from joblib import Parallel, delayed - - hs_mod_list = [] - lon_mod_list = [] - lat_mod_list = [] - time_mod_list = [] - - nproc = kwargs.get('nproc', 16) - + + hs_mod_list=[] + lon_mod_list=[] + lat_mod_list=[] + time_mod_list=[] + + nproc = self.nproc + oco_vars = self.oco.vars length = len(oco_vars.time.values) @@ -664,41 +716,26 @@ def _collocate_regridded_model(self, **kwargs): ) length_colloc_mod_list = len(colloc_mod_list) - hs_mod_list = [colloc_mod_list[i]['hs'] for i in \ - range(length_colloc_mod_list)] + var_mod_list = {v:[colloc_mod_list[i]['model_'+v] for i in \ + range(length_colloc_mod_list)] for v in\ + self.varalias_mod} lon_mod_list = [colloc_mod_list[i]['lon'] for i in \ range(length_colloc_mod_list)] lat_mod_list = [colloc_mod_list[i]['lat'] for i in \ range(length_colloc_mod_list)] - time_mod_list = [colloc_mod_list[i]['time'] for i in \ + time_mod_list = [colloc_mod_list[i]['model_time'] for i in \ + range(length_colloc_mod_list)] + time_obs_list = [colloc_mod_list[i]['obs_time'] for i in \ range(length_colloc_mod_list)] - - mod_colloc_vars = xr.Dataset( - { - "lats": ( - ("time"), - lat_mod_list, - ), - "lons": ( - ("time"), - lon_mod_list - ), - "Hs": ( - ("time"), - hs_mod_list - ) - }, - coords={"time": time_mod_list}, - ) results_dict = { 'model_time': time_mod_list, - 'obs_time': oco_vars.time.values, + 'obs_time': time_obs_list, 'dist': [0]*length, - 'model_values': hs_mod_list, + **{'model_'+ v: var_mod_list[v] for v in self.varalias_mod}, 'model_lons': lon_mod_list, 'model_lats': lat_mod_list, - 'obs_values': oco_vars.Hs.values, + **{'obs_'+v: oco_vars[v].values for v in self.varalias_obs}, 'obs_lons': oco_vars.lons.values, 'obs_lats': 
oco_vars.lats.values, 'collocation_idx_x': [0]*length, @@ -733,10 +770,24 @@ def collocate(self, **kwargs): def validate_collocated_values(self, **kwargs): + + varalias = kwargs.get('varalias', self.varalias[0]) times = self.vars['time'] dtime = [parse_date(str(t.data)) for t in times] - mods = self.vars['model_values'] - obs = self.vars['obs_values'] + list_vars = list(self.vars.variables) + assert 'model_'+varalias in list_vars, "model_{}".format(varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + assert 'obs_'+varalias in list_vars, "obs_{}".format(varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + print("Validating model_{} against obs_{}".format(varalias, varalias)) + mods = self.vars['model_'+varalias] + obs = self.vars['obs_'+varalias] sdate = dtime[0] edate = dtime[-1] validation_dict = validate_collocated_values( diff --git a/wavy/config/insitu_cfg.yaml.default b/wavy/config/insitu_cfg.yaml.default index 5cd19d3..e7ac1b9 100644 --- a/wavy/config/insitu_cfg.yaml.default +++ b/wavy/config/insitu_cfg.yaml.default @@ -836,6 +836,7 @@ MO_Draugen_monthly: # object nID (nameID) # optional, needs to be defined if not cf and in variable_info.yaml vardef: Hs: VAVH + U: WSPD lats: LATITUDE lons: LONGITUDE time: TIME diff --git a/wavy/config/variable_def.yaml.default b/wavy/config/variable_def.yaml.default index 371c721..95ff5d5 100644 --- a/wavy/config/variable_def.yaml.default +++ b/wavy/config/variable_def.yaml.default @@ -159,9 +159,6 @@ U: standard_name: wind_speed units: m s-1 _FillValue: -999.9 - aliases_of_vector_components: - alt1: [ua,va] - alt2: [ux,vy] valid_range: [0., 200] Udir: @@ -225,3 +222,156 @@ hmaxst: long_name: 'space time domain maximum individual wave height' units: m _FillValue: -999.9 + +UdirTo: + standard_name: wind_to_direction + long_name: 'Wind 
direction' + units: degree + _FillValue: -999.9 + +FV: + standard_name: friction_velocity + long_name: 'friction velocity' + units: m s-1 + _FillValue: -999.9 + +DC: + standard_name: drag_coefficient + long_name: 'drag coefficient' + units: NA + _FillValue: -999.9 + +MdirTo: + standard_name: sea_surface_wave_to_direction + long_name: 'Total mean wave direction' + units: degree + _FillValue: -999.9 + +HsSea: + standard_name: sea_surface_wind_wave_significant_height + long_name: 'Sea significant wave height' + units: m + _FillValue: -999.9 + +TpSea: + standard_name: sea_surface_wind_wave_peak_period_from_variance_spectral_density + long_name: 'Sea peak period' + units: s + _FillValue: -999.9 + +Tm10Sea: + standard_name: sea_surface_wind_wave_mean_period_from_variance_spectral_density_inverse_frequency_moment + long_name: 'Sea mean period' + units: s + _FillValue: -999.9 + +MdirToSea: + standard_name: sea_surface_wind_wave_to_direction + long_name: 'Sea mean wave direction' + units: degree + _FillValue: -999.9 + +SIC: + standard_name: sea_ice_concentration + long_name: 'sea ice concentration' + units: '1' + _FillValue: -999.9 + +HsST: + standard_name: sea_surface_swell_wave_significant_height + long_name: 'Swell significant wave height' + units: m + _FillValue: -999.9 + +TpST: + standard_name: sea_surface_swell_wave_period_at_variance_spectral_density_maximum + long_name: 'Swell peak period' + units: s + _FillValue: -999.9 + +MdirToST: + standard_name: sea_surface_swell_wave_to_direction + long_name: 'Swell mean wave direction' + units: degree + _FillValue: -999.9 + +eTm: + standard_name: expected_wave_period + long_name: 'expected wave period' + units: s + _FillValue: -999.9 + +fpI: + standard_name: interpolated_peak_frequency + long_name: 'interpolated peak frequency' + units: s + _FillValue: -999.9 + +HsS1: + standard_name: sea_surface_primary_swell_wave_significant_height + long_name: 'first swell significant wave height' + units: m + _FillValue: -999.9 + +TmS1: + 
standard_name: sea_surface_primary_swell_wave_mean_period + long_name: 'first swell mean period' + units: s + _FillValue: -999.9 + +MdirToS1: + standard_name: sea_surface_primary_swell_wave_to_direction + long_name: 'first swell direction' + units: degree + _FillValue: -999.9 + +HsS2: + standard_name: sea_surface_secondary_swell_wave_significant_height + long_name: 'second swell significant wave height' + units: m + _FillValue: -999.9 + +TmS2: + standard_name: sea_surface_secondary_swell_wave_mean_period + long_name: 'second swell mean period' + units: s + _FillValue: -999.9 + +MdirToS2: + standard_name: sea_surface_secondary_swell_wave_to_direction + long_name: 'second swell direction' + units: degree + _FillValue: -999.9 + +HsS3: + standard_name: sea_surface_tertiary_swell_wave_significant_height + long_name: 'third swell significant wave height' + units: m + _FillValue: -999.9 + +TmS3: + standard_name: sea_surface_tertiary_swell_wave_mean_period + long_name: 'third swell mean period' + units: m s-1 + _FillValue: -999.9 + +MdirToS3: + standard_name: sea_surface_tertiary_swell_wave_to_direction + long_name: 'third swell direction' + units: degree + _FillValue: -999.9 + +SIT: + standard_name: sea_ice_thickness + long_name: 'sea ice thickness' + units: m + _FillValue: -999.9 + +PdirTo: + standard_name: "sea_surface_wave_to_direction_at_variance_spectral\ + _density_maximum" + units: degree + valid_range: [0.,360.] + _FillValue: -999.9 + type: cyclic + diff --git a/wavy/filtermod.py b/wavy/filtermod.py index 4b5b285..821c3cc 100644 --- a/wavy/filtermod.py +++ b/wavy/filtermod.py @@ -33,6 +33,13 @@ def apply_limits(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. 
".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) llim = kwargs.get('llim', @@ -132,6 +139,13 @@ def filter_lanczos(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) # apply slider if needed @@ -175,8 +189,17 @@ def filter_runmean(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) + + print("Applying filter to {}".format(varalias)) # apply slider if needed win = kwargs.get('slider', len(new.vars.time)) ol = kwargs.get('overlap', 0) @@ -213,8 +236,17 @@ def filter_runmean(self, **kwargs): def filter_GP(self, **kwargs): print('Apply GPR filter') new = deepcopy(self) - varalias = kwargs.get('varalias', new.varalias[0]) - + if isinstance(new.varalias, list): + varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." 
+ print(msg) + else: + varalias = kwargs.get('varalias', new.varalias) # apply slider if needed win = kwargs.get('slider', len(new.vars.time)) ol = kwargs.get('overlap', 0) @@ -256,9 +288,15 @@ def filter_linearGAM(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) - # apply slider if needed win = kwargs.get('slider', len(new.vars.time)) ol = kwargs.get('overlap', 0) @@ -311,6 +349,13 @@ def despike_blockStd(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) @@ -367,6 +412,13 @@ def despike_blockQ(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) @@ -410,6 +462,13 @@ def despike_GP(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. 
".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) @@ -454,6 +513,13 @@ def despike_linearGAM(self, **kwargs): new = deepcopy(self) if isinstance(new.varalias, list): varalias = kwargs.get('varalias', new.varalias[0]) + + if (len(new.varalias) > 1) and\ + (kwargs.get('varalias', None) is None): + msg="Variable {} selected by default. ".format(varalias)\ + + "If you wish to apply the filter to another variable, "+\ + "please specify it using varalias." + print(msg) else: varalias = kwargs.get('varalias', new.varalias) diff --git a/wavy/grid_readers.py b/wavy/grid_readers.py index adaeead..b585961 100755 --- a/wavy/grid_readers.py +++ b/wavy/grid_readers.py @@ -40,13 +40,15 @@ def read_ww3_unstructured_to_grid(**kwargs): pathlst = kwargs.get('pathlst') nID = kwargs.get('nID') varalias = kwargs.get('varalias') + if isinstance(varalias, str): + varalias = [varalias] sd = kwargs.get('sd') ed = kwargs.get('ed') meta = kwargs.get('meta') # get meta data - varstr = get_filevarname(varalias, variable_def, - model_dict[nID], meta) + varstr = [get_filevarname(v, variable_def, + model_dict[nID], meta) for v in varalias] lonstr = get_filevarname('lons', variable_def, model_dict[nID], meta) latstr = get_filevarname('lats', variable_def, @@ -72,12 +74,12 @@ def read_ww3_unstructured_to_grid(**kwargs): if (dt >= sd and dt <= ed): # varnames = (varname, lonstr, latstr, timestr) if len(times) < 2: - var = (ds[varstr].data, + var = (*tuple(ds[v].data for v in varstr), ds[lonstr].data, ds[latstr].data, t.reshape((1,))) else: - var = (ds[varstr].sel({timestr: t}).data, + var = (*tuple(ds[v].sel({timestr: t}).data for v in varstr), ds[lonstr].data, ds[latstr].data, t.reshape((1,))) @@ -99,29 +101,54 @@ def read_ww3_unstructured_to_grid(**kwargs): def get_gridded_dataset(var, t, **kwargs): varstr = kwargs.get('varalias') + if isinstance(varstr, str): + 
varstr = [varstr] + len_varstr = len(varstr) + + var_means_list = [] + + print(var) if kwargs.get('interp') is None: print(" Apply gridding, no interpolation") # grid data - gridvar, lon_grid, lat_grid = \ - grid_point_cloud_ds(var[0], var[1], var[2], t, **kwargs) - - var_means = gridvar['mor'] - # transpose dims - var_means = np.transpose(var_means) - field_shape = (list(var_means.shape[::-1]) + [1])[::-1] - var_means = var_means.reshape(field_shape) + + for i in range(len_varstr): + gridvar, lon_grid, lat_grid = \ + grid_point_cloud_ds(var[i], + var[len_varstr], + var[len_varstr+1], + t, + **kwargs) + + var_means = gridvar['mor'] + # transpose dims + var_means = np.transpose(var_means) + field_shape = (list(var_means.shape[::-1]) + [1])[::-1] + var_means = var_means.reshape(field_shape) + + var_means_list.append(var_means) + else: print(" Apply gridding with interpolation") - var_means, lon_grid, lat_grid = \ - grid_point_cloud_interp_ds( - var[0], var[1], var[2], **kwargs) - # transpose dims - field_shape = (list(var_means.shape[::-1]) + [1])[::-1] - var_means = var_means.reshape(field_shape) + for i in range(len_varstr): + + var_means, lon_grid, lat_grid = \ + grid_point_cloud_interp_ds( + var[i], + var[len_varstr], + var[len_varstr+1], + **kwargs) + + # transpose dims + field_shape = (list(var_means.shape[::-1]) + [1])[::-1] + var_means = var_means.reshape(field_shape) + + var_means_list.append(var_means) + # create xr.dataset - ds = build_xr_ds_grid( - var_means, + ds = build_xr_ds_grid_multivar( + var_means_list, np.unique(lon_grid), np.unique(lat_grid), t.reshape((1,)), varstr=varstr) @@ -158,6 +185,41 @@ def build_xr_ds_grid(var_means, lon_grid, lat_grid, t, **kwargs): ) return ds + +def build_xr_ds_grid_multivar(var_means_list, lon_grid, lat_grid, t, **kwargs): + print(" building xarray dataset from grid") + varstr = kwargs.get('varstr') + if isinstance(varstr, str): + varstr = [varstr] + + ds = xr.Dataset({ + **{varstr[i]: xr.DataArray( + 
data=var_means_list[i], + dims=['time', 'latitude', 'longitude'], + coords={'latitude': lat_grid, + 'longitude': lon_grid, + 'time': t}, + attrs=variable_def[varstr[i]], + ) for i in range(len(var_means_list))}, + 'lons': xr.DataArray( + data=lon_grid, + dims=['longitude'], + coords={'longitude': lon_grid}, + attrs=variable_def['lons'], + ), + 'lats': xr.DataArray( + data=lat_grid, + dims=['latitude'], + coords={'latitude': lat_grid}, + attrs=variable_def['lats'], + ), + }, + attrs={'title': 'wavy dataset'} + ) + return ds + + + def build_xr_ds_grid_2D(var_means, lon_grid, lat_grid, t, **kwargs): print(" building xarray dataset from grid") varstr = kwargs.get('varstr') diff --git a/wavy/gridder_module.py b/wavy/gridder_module.py index 8c4049b..bf0479a 100755 --- a/wavy/gridder_module.py +++ b/wavy/gridder_module.py @@ -7,6 +7,7 @@ from wavy.wconfig import load_or_default validation_metric_abbreviations = load_or_default('validation_metrics.yaml') +variable_def = load_or_default('variable_def.yaml') class gridder_class(): @@ -28,40 +29,67 @@ def __init__( self.varalias = kwargs.get('varalias', oco.varalias) if isinstance(self.varalias, list): if len(self.varalias) > 1: - print("Warning: gridder does not work with more than one \ - variable at the moment.") - print("First variable selected as default: {}".format( - self.varalias[0])) + print("Warning: gridder only expects one varalias.") + print("First varalias selected as default: {}".format( + self.varalias[0])) + print("If you want to select another variable, please "\ + +"specify with varalias argument.") self.varalias=self.varalias[0] - self.stdvarname = oco.stdvarname - if isinstance(self.stdvarname, list): - self.stdvarname=self.stdvarname[0] - self.units = oco.units - if isinstance(self.units, list): - self.units = self.units[0] + self.units = variable_def[self.varalias].get('units') + self.stdvarname = variable_def[self.varalias].get('standard_name') self.olons = 
np.array(oco.vars['lons'].squeeze().values.ravel()) self.olats = np.array(oco.vars['lats'].squeeze().values.ravel()) self.ovals = np.array(oco.vars[self.varalias].squeeze().values.ravel()) self.sdate = oco.vars['time'][0] self.edate = oco.vars['time'][-1] elif cco is not None: + self.varalias = kwargs.get('varalias', cco.varalias) + if isinstance(self.varalias, list): + if len(self.varalias) > 1: + print("Warning: gridder only expects one varalias.") + print("First varalias selected as default: {}".format( + self.varalias[0])) + print("If you want to select another variable, please "\ + +"specify with varalias argument.") + self.varalias=self.varalias[0] + + list_vars = list(cco.vars.variables) + assert 'model_'+self.varalias in list_vars, "model_{}".format(self.varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + assert 'obs_'+self.varalias in list_vars, "obs_{}".format(self.varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." 
+ self.olons = np.array(cco.vars['obs_lons']) self.olats = np.array(cco.vars['obs_lats']) - self.ovals = np.array(cco.vars['obs_values']) - self.mvals = np.array(cco.vars['model_values']) - self.stdvarname = cco.stdvarname - self.varalias = cco.varalias - self.units = cco.units + self.ovals = np.array(cco.vars['obs_'+self.varalias]) + self.mvals = np.array(cco.vars['model_'+self.varalias]) + self.units = variable_def[self.varalias].get('units') + self.stdvarname = variable_def[self.varalias].get('standard_name') self.sdate = cco.vars['time'][0] self.edate = cco.vars['time'][-1] elif mco is not None: + self.varalias = kwargs.get('varalias', mco.varalias) + if isinstance(self.varalias, list): + if len(self.varalias) > 1: + print("Warning: gridder only expects one varalias.") + print("First varalias selected as default: {}".format( + self.varalias[0])) + print("If you want to select another variable, please "\ + +"specify with varalias argument.") + self.varalias=self.varalias[0] self.olons = np.array(mco.vars.lons.squeeze().values.flatten()) self.olats = np.array(mco.vars.lats.squeeze().values.flatten()) self.ovals = np.array( - mco.vars[mco.varalias].squeeze().values.flatten()) + mco.vars[self.varalias].squeeze().values.flatten()) self.stdvarname = mco.stdvarname - self.varalias = mco.varalias - self.units = mco.units + self.units = variable_def[self.varalias].get('units') + self.stdvarname = variable_def[self.varalias].get('standard_name') self.sdate = mco.vars['time'][0] self.edate = mco.vars['time'][-1] else: diff --git a/wavy/insitu_module.py b/wavy/insitu_module.py index 5971450..993cb3e 100755 --- a/wavy/insitu_module.py +++ b/wavy/insitu_module.py @@ -103,6 +103,9 @@ def __init__(self, **kwargs): self.distlim = kwargs.get('distlim', 6) self.filter = kwargs.get('filter', False) self.region = kwargs.get('region', 'global') + depth_lvls = kwargs.get('depth_lvls', None) + if depth_lvls is not None: + self.depth_lvls = depth_lvls self.cfg = dc print(" ") print(" 
### insitu_class object initialized ### ") diff --git a/wavy/insitu_readers.py b/wavy/insitu_readers.py index e15695c..5f10c84 100755 --- a/wavy/insitu_readers.py +++ b/wavy/insitu_readers.py @@ -369,6 +369,8 @@ def get_cmems(**kwargs): varalias = [varalias] pathlst = kwargs.get('pathlst') cfg = vars(kwargs['cfg']) + depth_lvls = kwargs.get('depth_lvls', None) + # check if dimensions are fixed fixed_dim_str = list(cfg['misc']['fixed_dim'].keys())[0] fixed_dim_idx = cfg['misc']['fixed_dim'][fixed_dim_str] @@ -399,14 +401,30 @@ def get_cmems(**kwargs): for coord in list(ds.coords) if coord in [lonstr, latstr, timestr]} + len_timestr = len(dict_var[timestr]) + + for coord in [lonstr, latstr]: + if len(dict_var[coord].shape)==0: + dict_var[coord] = np.array([dict_var[coord]]*len_timestr) + + list_vars_tmp = list(ds.data_vars) + + if depth_lvls is not None: + + dict_var.update({var: ds.sel(DEPTH=depth_lvls[var])[var]\ + .values for var in depth_lvls.keys()}) + + list_vars_tmp = [k for k in list(ds.data_vars) if k + not in list(depth_lvls.keys())] + dict_var.update({var: rebuild_split_variable(ds, fixed_dim_str, var) - for var in list(ds.data_vars)}) + for var in list_vars_tmp}) # build an xr.dataset with timestr as the only coordinate # using build_xr_ds function ds_list.append(build_xr_ds_cmems(dict_var, timestr)) - + except Exception as e: logger.exception(e) diff --git a/wavy/model_module.py b/wavy/model_module.py index 432e303..481469a 100755 --- a/wavy/model_module.py +++ b/wavy/model_module.py @@ -152,9 +152,12 @@ def __init__(self, **kwargs): # add other class object variables self.nID = kwargs.get('nID') self.model = kwargs.get('model', self.nID) - self.varalias = kwargs.get('varalias', 'Hs') - self.units = variable_def[self.varalias].get('units') - self.stdvarname = variable_def[self.varalias].get('standard_name') + self.varalias = kwargs.get('varalias', ['Hs']) + if isinstance(self.varalias, str): + self.varalias = [self.varalias] + self.units = 
[variable_def[v].get('units') for v in self.varalias] + self.stdvarname = [variable_def[v].get('standard_name') for v in\ + self.varalias] self.distlim = kwargs.get('distlim', 6) self.filter = kwargs.get('filter', False) self.region = kwargs.get('region', 'global') @@ -546,27 +549,29 @@ def _enforce_longitude_format(self): def _enforce_meteorologic_convention(self): print(' enforcing meteorological convention') - if ('convention' in vars(self.cfg)['misc'].keys() and - vars(self.cfg)['misc']['convention'] == 'oceanographic'): - print('Convert from oceanographic to meteorologic convention') - self.vars[self.varalias] =\ - convert_meteorologic_oceanographic(self.vars[self.varalias]) - elif 'to_direction' in self.vars[self.varalias].attrs['standard_name']: - print('Convert from oceanographic to meteorologic convention') - self.vars[self.varalias] =\ - convert_meteorologic_oceanographic(self.vars[self.varalias]) + for v in self.varalias: + if ('convention' in vars(self.cfg)['misc'].keys() and + vars(self.cfg)['misc']['convention'] == 'oceanographic'): + print('Convert from oceanographic to meteorologic convention') + + self.vars[v] = convert_meteorologic_oceanographic(self.vars[v]) + + elif 'to_direction' in self.vars[v].attrs['standard_name']: + print('Convert from oceanographic to meteorologic convention') + self.vars[v] = convert_meteorologic_oceanographic(self.vars[v]) return self def _change_varname_to_aliases(self): print(' changing variables to aliases') # variables - ncvar = get_filevarname(self.varalias, variable_def, - vars(self.cfg), self.meta) - if self.varalias in list(self.vars.keys()): - print(' ', ncvar, 'is alreade named correctly and' - + ' therefore not adjusted') - else: - self.vars = self.vars.rename({ncvar: self.varalias}) + for v in self.varalias: + ncvar = get_filevarname(v, variable_def, + vars(self.cfg), self.meta) + if v in list(self.vars.keys()): + print(' ', ncvar, 'is alreade named correctly and' + + ' therefore not adjusted') + else: + 
self.vars = self.vars.rename({ncvar: v}) # coords coords = ['time', 'lons', 'lats'] for c in coords: @@ -590,8 +595,9 @@ def _change_stdvarname_to_cfname(self): self.vars['time'].attrs['standard_name'] = \ variable_def['time'].get('standard_name') # enforce standard_name for variable alias - self.vars[self.varalias].attrs['standard_name'] = \ - self.stdvarname + for i in range(len(self.varalias)): + self.vars[self.varalias[i]].attrs['standard_name'] = \ + self.stdvarname[i] return self def populate(self, **kwargs): @@ -607,8 +613,10 @@ def populate(self, **kwargs): print('') print('Checking variables..') self.meta = ncdumpMeta(self.pathlst[0]) - ncvar = get_filevarname(self.varalias, variable_def, - vars(self.cfg), self.meta) + ncvar = [get_filevarname(v, variable_def, + vars(self.cfg), + self.meta) for \ + v in self.varalias] print('') print('Choosing reader..') # define reader diff --git a/wavy/model_readers.py b/wavy/model_readers.py index 217c58a..17b14c7 100755 --- a/wavy/model_readers.py +++ b/wavy/model_readers.py @@ -36,6 +36,8 @@ def read_ww3_4km(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] ds_lst = [] # retrieve sliced data for i in range(len(fc_dates)): @@ -43,9 +45,8 @@ def read_ww3_4km(**kwargs): p = pathlst[i] ds = xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}) - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], - model_dict[nID]['vardef']['lats']]] + ds_sliced = ds_sliced[varname + [model_dict[nID]['vardef']['lons'], + model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) @@ -68,6 +69,8 @@ def read_meps(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] ds_lst = [] # retrieve sliced data for i in range(len(fc_dates)): @@ -75,9 +78,8 @@ def read_meps(**kwargs): p = 
pathlst[i] ds = xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}) - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], - model_dict[nID]['vardef']['lats']]] + ds_sliced = ds_sliced[varname + [model_dict[nID]['vardef']['lons'], + model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) @@ -101,7 +103,11 @@ def read_noresm_making_waves(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, list): + varname = varname[0] varalias = kwargs.get('varalias') + if isinstance(varalias, list): + varalias = varalias[0] timename = model_dict[nID]['vardef']['time'] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] @@ -148,12 +154,14 @@ def read_remote_ncfiles_aggregated_credentials(**kwargs): ed = kwargs.get('ed') twin = kwargs.get('twin') varalias = kwargs.get('varalias') + if isinstance(varalias, str): + varalias = [varalias] nID = kwargs.get('nID') remoteHostName = kwargs.get('remoteHostName') path = kwargs.get('pathlst')[0] # varnames - varname = model_dict[nID]['vardef'][varalias] + varname = [model_dict[nID]['vardef'][v] for v in varalias] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] timename = model_dict[nID]['vardef']['time'] @@ -170,10 +178,10 @@ def read_remote_ncfiles_aggregated_credentials(**kwargs): path, remoteHostName, usr, pw) ds_sliced = ds.sel(time=slice(sd, ed)) - var_sliced = ds_sliced[[varname, lonsname, latsname]] + var_sliced = ds_sliced[varname + [lonsname, latsname]] # forge into correct format varalias, lons, lats with dim time - ds = build_xr_ds_grid(var_sliced[varname], + ds = build_xr_ds_grid_multivar([var_sliced[v] for v in varname], var_sliced[lonsname], var_sliced[latsname], var_sliced[timename], @@ -203,6 +211,8 @@ def read_field(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = 
kwargs.get('varname') + if isinstance(varname, list): + varname = varname[0] timename = model_dict[nID]['vardef']['time'] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] @@ -249,7 +259,11 @@ def read_ecwam(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, list): + varname = varname[0] varalias = kwargs.get('varalias') + if isinstance(varalias, list): + varalias = varalias[0] timename = model_dict[nID]['vardef']['time'] lonsname = model_dict[nID]['vardef']['lons'] latsname = model_dict[nID]['vardef']['lats'] @@ -291,6 +305,8 @@ def read_era(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] ds_lst = [] # retrieve sliced data for i in range(len(fc_dates)): @@ -300,8 +316,8 @@ def read_era(**kwargs): ds = xr.open_dataset(p, engine='h5netcdf') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}, method='nearest') - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], + ds_sliced = ds_sliced[varname+ + [model_dict[nID]['vardef']['lons'], model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) @@ -326,6 +342,8 @@ def read_NORA3_wind(**kwargs): nID = kwargs.get('nID') fc_dates = kwargs.get('fc_dates') varname = kwargs.get('varname') + if isinstance(varname, str): + varname = [varname] hlevel = kwargs.get('heightlevel', 10) ds_lst = [] # retrieve sliced data @@ -335,8 +353,8 @@ def read_NORA3_wind(**kwargs): ds = xr.open_dataset(p, engine='netcdf4') ds_sliced = ds.sel({model_dict[nID]['vardef']['time']: d}) ds_sliced = ds_sliced.sel({'height': hlevel}) - ds_sliced = ds_sliced[[varname, - model_dict[nID]['vardef']['lons'], + ds_sliced = ds_sliced[varname + + [model_dict[nID]['vardef']['lons'], model_dict[nID]['vardef']['lats']]] ds_lst.append(ds_sliced) diff --git a/wavy/quicklookmod.py b/wavy/quicklookmod.py index 
efaf927..29066bb 100755 --- a/wavy/quicklookmod.py +++ b/wavy/quicklookmod.py @@ -51,6 +51,8 @@ def quicklook(self, a=False, projection=None, **kwargs): if isinstance(self.varalias, list): varalias = kwargs.get('varalias', self.varalias[0]) + assert varalias in self.varalias, "varalias must be one of {}"\ + .format(self.varalias) assert isinstance(varalias, str), "varalias argument should be a string" idx_units = np.argwhere(np.array(self.varalias)==varalias)[0][0] units_to_plot = self.units[idx_units] @@ -64,11 +66,22 @@ def quicklook(self, a=False, projection=None, **kwargs): plot_lons = self.vars.lons plot_lats = self.vars.lats except Exception as e: - plot_var = self.vars.obs_values + list_vars = list(self.vars.variables) + assert 'model_'+varalias in list_vars,"model_{}".format(varalias)+\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." + assert 'obs_'+varalias in list_vars, "obs_{}".format(varalias) +\ + " is missing in "+\ + "the dataset, if you would like to "+\ + "validate another variable, please "+\ + "specify with varalias." 
+ plot_var = self.vars["obs_"+varalias] plot_lons = self.vars.obs_lons plot_lats = self.vars.obs_lats - plot_var_obs = self.vars.obs_values - plot_var_model = self.vars.model_values + plot_var_obs = self.vars["obs_"+varalias] + plot_var_model = self.vars["model_"+varalias] if str(type(self)) == "": if len(plot_lons.shape) < 2: @@ -401,8 +414,8 @@ def quicklook(self, a=False, projection=None, **kwargs): plt.xlabel('obs (' + self.nID + ')') plt.ylabel('models (' + self.model + ')') - maxv = np.nanmax([self.vars['model_values'], - self.vars['obs_values']]) + maxv = np.nanmax([self.vars['model_'+varalias], + self.vars['obs_'+varalias]]) minv = 0 plt.xlim([minv, maxv*1.05]) plt.ylim([minv, maxv*1.05]) diff --git a/wavy/satellite_readers.py b/wavy/satellite_readers.py index f48aa22..a101887 100755 --- a/wavy/satellite_readers.py +++ b/wavy/satellite_readers.py @@ -179,7 +179,7 @@ def read_local_20Hz_files(**kwargs): sd = kwargs.get('sd') ed = kwargs.get('ed') twin = kwargs.get('twin') - + nc_dim = satellite_dict[nID].get('misc',{}).get('nc_dim', 'time') # adjust start and end sd = sd - timedelta(minutes=twin) ed = ed + timedelta(minutes=twin) @@ -195,7 +195,7 @@ def read_local_20Hz_files(**kwargs): satellite_dict[nID], ncmeta) # retrieve sliced data - ds = read_netcdfs(pathlst) + ds = read_netcdfs(pathlst, dim=nc_dim) ds_sort = ds.sortby(timestr) # get indices for included time period diff --git a/wavy/triple_collocation.py b/wavy/triple_collocation.py index be7099f..8afa621 100644 --- a/wavy/triple_collocation.py +++ b/wavy/triple_collocation.py @@ -103,7 +103,7 @@ def filter_dynamic_collocation(data, mod_1, mod_2, max_rel_diff=0.05): mod_1 = np.array(mod_1) mod_2 = np.array(mod_2) - idx = np.abs(mod_1 - mod_2)/mod_1 < 0.05 + idx = np.abs(mod_1 - mod_2)/mod_1 < max_rel_diff data_filtered = {} @@ -590,114 +590,84 @@ def least_squares_merging(data, tc_results=None, return_var=False, **kwargs): return least_squares_merge else: return least_squares_merge, least_squares_var 
+ - -def spectra_cco(cco, fs, returns='average',nsample=64): +def get_mean_spectra(ds, varname, fs, nsample, + median_step=None, mode='average', + window='hamming'): """ - Calculate the power spectra for both model and observation series - of a collocation class object. The series are separated into - samples of a given length. Up to one missing value for each - sample is filled with interpolation. A Hamming window is applied - for each sample. - - cco (collocation_class object or xarray dataset): collocation class object - for which the spectra are calculated. If xarray.dataset - is given, it must have a single time dimension and obs_values - and model_values variables. - fs (float): frequency of the time series - returns (str): If 'average' returns the average power spectra of the time series. - If 'list' returns lists of power spectra of each samples of the given series. - nsample (int): size of the samples to calculate the power spectra. - + Divides a given time series into samples of given size, applies a window to + each sample and calculates the power spectra for each sample, using a Fast + Fourier transform. Returns the frequencies and either the list of the + spectra for each sample or the average spectra over all samples. + + ds (xarray dataset): xarray dataset with dimension time + varname (str): name of the variable for which the spectra is + to be computed, present in ds and indexed by time + fs (float): sampling frequency + nsample (int): number of points of each sample + median_step (np.timedelta64): time to consider between each point of the + time series + mode (str): either 'average' to return the mean power spectra or + 'list' to return the list of the power spectra + window (str): window to apply to the samples before applying the + FFT. See scipy.signal.periodogram for options. 
+ + + return: - PS_mod (numpy array): average spectra of cco model values or individual spectra for each sample - PS_obs (numpy array): average spectra of cco observation values or individual spectra for each sample - f (numpy array): frequencies for the spectra + df_spectra (pandas DataFrame): dataframe containing the mean power spectra + or the power spectra for each sample and the + corresponding frequencies. """ from scipy.signal import periodogram - from wavy.collocation_module import collocation_class as cc - - if isinstance(cco, cc): - cco_vars = cco.vars.dropna(dim='time') - elif isinstance(cco, xr.Dataset): - cco_vars = cco.dropna(dim='time') - - step_cco = [(cco_vars.time.values[i+1] - cco_vars.time.values[i])/np.timedelta64(1,'s') for\ - i in range(len(cco_vars.time)-1)] - step_cco.append(np.nan) - median_step = np.median(step_cco) - - i_0 = 0 - i_1 = i_0 + nsample - - list_PS_obs = [] - list_PS_mod = [] - - while i_1 < len(cco_vars.time.values): - - sample_tmp = cco_vars.isel(time=range(i_0,i_1)) + + ds = ds.dropna(dim='time') - step_tmp = [(sample_tmp.time.values[i+1] - sample_tmp.time.values[i])/np.timedelta64(1,'s') for\ - i in range(len(sample_tmp.time)-1)] - step_tmp.append(np.nan) - - idx_1_na = np.argwhere((step_tmp > 1.5*median_step) & (step_tmp < 2.5*median_step)) - idx_2_na = np.argwhere(step_tmp > 2.5*median_step) + diff_time = ds.time.diff(dim='time').values + if median_step is None: + median_step = np.timedelta64(np.nanmedian(diff_time),'ns') + else: + median_step = np.timedelta64(median_step, 'ns') - if len(idx_2_na) > 0: - i_0 = idx_2_na[-1][0] - i_1 = i_0 + nsample - continue + sample_period = nsample * median_step - for idx in idx_1_na: - idx = idx[0] - - time_idx = sample_tmp.time.values[idx] - obs_idx = sample_tmp.obs_values.values[idx] - mod_idx = sample_tmp.model_values.values[idx] - time_idx_1 = sample_tmp.time.values[idx+1] - obs_idx_1 = sample_tmp.obs_values.values[idx+1] - mod_idx_1 = sample_tmp.model_values.values[idx+1] - 
time_between = time_idx + np.timedelta64(int(1000*median_step),'ms') - obs_between= (obs_idx + obs_idx_1)/2 - mod_between = (mod_idx + mod_idx_1)/2 - - ds_between = xr.Dataset( - data_vars={'obs_values': (('time'), [obs_between]), - 'model_values': (('time'), [mod_between])}, - coords={'time': [time_between]} - ) + run_sum_period = [np.sum(diff_time[i:i+nsample]) for\ + i in range(len(diff_time)-nsample)] - sample_tmp = xr.concat([sample_tmp, ds_between], dim='time').sortby('time') + up_condition = (nsample*median_step + 0.5*median_step > run_sum_period) + low_condition = (nsample*median_step - 0.5*median_step <= run_sum_period) + condition = up_condition & low_condition - sample_tmp = sample_tmp.isel(time=range(0,nsample)) - - obs_val_tmp = sample_tmp.obs_values.values - mod_val_tmp = sample_tmp.model_values.values - - f, PS_obs_tmp = periodogram(obs_val_tmp, fs=fs, window='hamming') - f, PS_mod_tmp = periodogram(mod_val_tmp, fs=fs, window='hamming') + idx_right_length = np.argwhere(up_condition & low_condition).flatten() - list_PS_obs.append(PS_obs_tmp) - list_PS_mod.append(PS_mod_tmp) + list_PS = [] - i_0 = i_1 - len(idx_1_na) - i_1 = i_0 + nsample + i = idx_right_length[0] - mean_PS_mod = np.mean(np.array(list_PS_mod), axis=0) - mean_PS_obs = np.mean(np.array(list_PS_obs), axis=0) + while i < idx_right_length[-1] - nsample: - if returns == 'list': - PS_mod = np.array(list_PS_mod) - PS_obs = np.array(list_PS_obs) - elif returns == 'average': - PS_mod = mean_PS_mod - PS_obs = mean_PS_obs - return PS_mod, PS_obs, f + ds_tmp = ds.isel(time=range(i,i+nsample)) + sample_tmp = ds_tmp[varname].values + + f, PS_tmp = periodogram(sample_tmp, fs=fs, window=window) + + list_PS.append(PS_tmp[1:]) + i = idx_right_length[np.argwhere(idx_right_length >= i + nsample).\ + flatten()[0]] + + if mode == 'average': + mean_PS = np.mean(np.array(list_PS), axis=0) + df_spectra = pd.DataFrame({'f':f[1:], 'spectra': mean_PS}) + return df_spectra + else: + df_spectra = 
pd.DataFrame({'f':f[1:], + **{'spectra_'+str(i):ps for i, ps in enumerate(list_PS)}}) + return df_spectra + + def integrate_r2(PS_mod, PS_obs, f, threshold=np.inf, threshold_type='inv_freq'): """ Estimates the representativeness error r2 by integrating the difference @@ -726,7 +696,6 @@ def integrate_r2(PS_mod, PS_obs, f, threshold=np.inf, threshold_type='inv_freq') else: threshold = 1/threshold - f_1 = [1/f[i] if f[i] != 0 else np.inf for i in range(len(f)) ] idx_threshold = np.argwhere(np.array(f_1) <= threshold)[0][0] r2 = np.sum(weighted_diff_PS[idx_threshold:])