diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3d8de4..09a2f56 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,21 +18,10 @@ jobs: uses: actions/setup-python@v3 with: python-version: '3.10' - - name: Add conda to system path - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - echo $CONDA/bin >> $GITHUB_PATH - - name: Install dependencies - run: | - conda env update --file environment.yml --name base - - name: Lint with flake8 - run: | - conda install flake8 - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - uses: mamba-org/setup-micromamba@v2 + with: + environment-file: environment.yml + generate-run-shell: true - name: Test with pytest - run: | - conda install pytest - pytest + run: pytest + shell: micromamba-shell {0} diff --git a/.gitignore b/.gitignore index aebc337..b164271 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ Untitled* # python package *.egg-info +*build # Jupyterhub checkpoints *.ipynb_checkpoints* diff --git a/environment.yml b/environment.yml index 894f9fe..7970f7f 100644 --- a/environment.yml +++ b/environment.yml @@ -2,20 +2,20 @@ name: pywk99 channels: - conda-forge dependencies: - - python>=3.9 - - pytest - - numpy - - pandas - - xarray - - netcdf4 - - shapely - - matplotlib - cartopy - cmocean - - ipython + - dask + - healpy - ipykernel + - ipython + - matplotlib + - netcdf4 + - numpy + - pandas - pip - - dask + - pytest + - python>=3.9 + - shapely + - xarray - pip: - - -e . - + - -e . \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ddfe1cd..7b916bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ authors = [ {name = "MPI-M"}, ] dependencies = [ - "numpy", - "xarray", - "netcdf4", + "cartopy", "matplotlib", + "netcdf4", + "numpy", + "pandas", "shapely", - "cartopy", + "xarray", ] description = "Wheeler and Kiladis Wavenumber-Frequency Analysis in Python" requires-python = ">=3.10" @@ -26,4 +27,4 @@ classifiers=[ dynamic = ["version"] [tool.setuptools.dynamic] -version = {attr = "pywk99.__version__"} \ No newline at end of file +version = {attr = "pywk99.__version__"} diff --git a/pywk99/__init__.py b/pywk99/__init__.py index 68b4771..8acf24e 100644 --- a/pywk99/__init__.py +++ b/pywk99/__init__.py @@ -3,6 +3,9 @@ Changelog --------- +Apr 8, 2025: Version 0.4.3 + - Adding support for healpix + Mar 7, 2025: Version 0.4.2 - Organizing first github public version. @@ -67,4 +70,4 @@ - Several bugs where detected and corrected with the tests. """ -__version__ = "0.4.2" +__version__ = "0.4.3" diff --git a/pywk99/filter/filter.py b/pywk99/filter/filter.py index 3ea2384..1285c44 100644 --- a/pywk99/filter/filter.py +++ b/pywk99/filter/filter.py @@ -1,6 +1,6 @@ """Filter for equatorial wave bands following Wheeler and Kiladis 1999.""" -from typing import Union +from typing import Optional, Union, Tuple import xarray as xr import numpy as np @@ -11,19 +11,26 @@ from pywk99.timeseries.timeseries import remove_seasonal_cycle from pywk99.timeseries.timeseries import check_variable_coordinates_are_sorted from pywk99.filter.window import FilterPoint, FilterWindow - - -def filter_variable(variable: xr.DataArray, - filter_windows: Union[FilterWindow, list[FilterWindow]], - taper: bool = True, - taper_alpha: float = 0.5, - rm_seasonal_cycle: bool = True, - rm_linear_trend: bool = True) -> xr.Dataset: - modified_variable = _preprocess(variable, - taper, taper_alpha, - rm_seasonal_cycle, - rm_linear_trend) - spectrum = fourier_transform(modified_variable) +from pywk99.grid.grid import dataarray_to_equatorial_latlon_grid + + +def filter_variable( + variable: xr.DataArray, + filter_windows: Union[FilterWindow, list[FilterWindow]], + taper: bool = True, + taper_alpha: float = 0.5, + rm_seasonal_cycle: bool = True, + rm_linear_trend: bool = True, + grid_type: str = "latlon", + grid_dict: Optional[dict] = None, +) -> xr.Dataset: + new_variable = dataarray_to_equatorial_latlon_grid( + variable, grid_type, grid_dict + ) + new_variable = _preprocess( + new_variable, taper, taper_alpha, rm_seasonal_cycle, rm_linear_trend + ) + spectrum = fourier_transform(new_variable) if isinstance(filter_windows, FilterWindow): filter_windows = [filter_windows] data_vars = dict() @@ -31,34 +38,37 @@ def filter_variable(variable: xr.DataArray, for window in filter_windows: field_name = _set_field_name(window.name, seen_window_names) seen_window_names.append(window.name) - data_vars[field_name] = _filter(spectrum, window, variable.coords) - filtered_variables = xr.Dataset(data_vars=data_vars) + data_vars[field_name] = _filter(spectrum, window, new_variable.coords) + filtered_variables = xr.Dataset(data_vars=data_vars) return filtered_variables -def modify_spectrum(spectrum: xr.DataArray, - filter_window: FilterWindow, - action: str = "filter") -> xr.DataArray: +def modify_spectrum( + spectrum: xr.DataArray, filter_window: FilterWindow, action: str = "filter" +) -> xr.DataArray: """Filter the spectrum with the wave filter window.""" mask = _get_window_mask(spectrum, filter_window, action) modified_spectrum = mask * spectrum return modified_spectrum -def _filter(spectrum: xr.DataArray, - filter_window: FilterWindow, - xarray_coords) -> xr.DataArray: +def _filter( + spectrum: xr.DataArray, filter_window: FilterWindow, xarray_coords +) -> xr.DataArray: masked_spectrum = modify_spectrum(spectrum, filter_window, "filter") - filtered_variable = inverse_fourier_transform(masked_spectrum, xarray_coords) + filtered_variable = inverse_fourier_transform( + masked_spectrum, xarray_coords + ) return filtered_variable -def _preprocess(variable: Union[xr.DataArray, xr.Dataset], - taper: bool, - taper_alpha: float, - rm_seasonal_cycle: bool, - rm_linear_trend: bool - ) -> Union[xr.DataArray, xr.Dataset]: +def _preprocess( + variable: Union[xr.DataArray, xr.Dataset], + taper: bool, + taper_alpha: float, + rm_seasonal_cycle: bool, + rm_linear_trend: bool, +) -> Union[xr.DataArray, xr.Dataset]: check_variable_coordinates_are_sorted(variable) modified_variable = variable.transpose("time", "lon", ...) if rm_linear_trend: @@ -66,8 +76,9 @@ def _preprocess(variable: Union[xr.DataArray, xr.Dataset], if rm_seasonal_cycle: modified_variable = remove_seasonal_cycle(modified_variable) if taper: - modified_variable = taper_variable_time_ends(modified_variable, - taper_alpha) + modified_variable = taper_variable_time_ends( + modified_variable, taper_alpha + ) return modified_variable @@ -79,27 +90,31 @@ def _set_field_name(window_name: str, seen_window_names: list[str]) -> str: return f"{window_name}{name_count + 1}" -def _get_window_mask(spectrum: xr.DataArray, - filter_window: FilterWindow, - action: str = "filter") -> xr.DataArray: +def _get_window_mask( + spectrum: xr.DataArray, filter_window: FilterWindow, action: str = "filter" +) -> xr.DataArray: """Get a wavenumber-frequency mask corresponding to the filter window.""" - bbox_wavenumbers, bbox_frequencies = \ - _window_bbox_wavenumber_and_frequencies( - spectrum, filter_window - ) + bbox_wavenumbers, bbox_frequencies = ( + _window_bbox_wavenumber_and_frequencies(spectrum, filter_window) + ) mask = _get_mask_base(spectrum, action) include_fft_reflection = bool(np.any(spectrum.frequency.values < 0)) for wavenumber in bbox_wavenumbers.values: for frequency in bbox_frequencies.values: - mask = _modify_mask_value_at_point(mask, filter_window, action, - wavenumber, frequency, - include_fft_reflection) + mask = _modify_mask_value_at_point( + mask, + filter_window, + action, + wavenumber, + frequency, + include_fft_reflection, + ) return mask def _window_bbox_wavenumber_and_frequencies( - spectrum: xr.DataArray, - filter_window: FilterWindow) -> tuple[np.ndarray, np.ndarray]: + spectrum: xr.DataArray, filter_window: FilterWindow +) -> Tuple[np.ndarray, np.ndarray]: """Get the wavenumbers and frequencies of the window bounding box.""" k_wmin, w_wmin, k_wmax, w_wmax = filter_window.bounds if not np.all(np.diff(spectrum.wavenumber.values) > 0): @@ -121,12 +136,13 @@ def _get_mask_base(spectrum: xr.DataArray, action: str) -> xr.DataArray: def _modify_mask_value_at_point( - mask: xr.DataArray, - filter_window: FilterWindow, - action: str, - wavenumber: float, - frequency: float, - include_fft_reflection: bool = True) -> xr.DataArray: + mask: xr.DataArray, + filter_window: FilterWindow, + action: str, + wavenumber: float, + frequency: float, + include_fft_reflection: bool = True, +) -> xr.DataArray: ACTION_VALUE = {"filter": True, "substract": False} point = FilterPoint(wavenumber, frequency) point_is_contained = filter_window.covers(point) @@ -135,12 +151,15 @@ def _modify_mask_value_at_point( mask.loc[point_loc_dict1] = ACTION_VALUE[action] if include_fft_reflection: # rounding errors in the index make the following necessary - approx_point_loc_dict2 = dict(wavenumber=-wavenumber, - frequency=-frequency) - reflection_point = mask.sel(approx_point_loc_dict2, - method='nearest') + approx_point_loc_dict2 = dict( + wavenumber=-wavenumber, frequency=-frequency + ) + reflection_point = mask.sel( + approx_point_loc_dict2, method="nearest" + ) point_loc_dict2 = dict( wavenumber=reflection_point.wavenumber.values, - frequency=reflection_point.frequency.values) + frequency=reflection_point.frequency.values, + ) mask.loc[point_loc_dict2] = ACTION_VALUE[action] return mask diff --git a/pywk99/filter/window.py b/pywk99/filter/window.py index b45762f..7b103c5 100644 --- a/pywk99/filter/window.py +++ b/pywk99/filter/window.py @@ -1,7 +1,7 @@ """Define filtering windows for various waves following Wheeler and Kiladis.""" from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Tuple, Union import numpy as np from shapely.geometry import Point, Polygon, MultiPolygon @@ -18,7 +18,7 @@ class FilterWindow: polygon: Union[Polygon, MultiPolygon] @property - def bounds(self) -> tuple[float, float, float, float]: + def bounds(self) -> Tuple[float, float, float, float]: return self.polygon.bounds def union(self, other) -> "FilterWindow": diff --git a/pywk99/grid/__init__.py b/pywk99/grid/__init__.py new file mode 100644 index 0000000..8f0bcbd --- /dev/null +++ b/pywk99/grid/__init__.py @@ -0,0 +1,5 @@ +""" +Transform grids to latlon +""" + +from pywk99.grid.grid import dataset_to_equatorial_latlon_grid \ No newline at end of file diff --git a/pywk99/grid/grid.py b/pywk99/grid/grid.py new file mode 100644 index 0000000..5bda3a1 --- /dev/null +++ b/pywk99/grid/grid.py @@ -0,0 +1,34 @@ +"""Convert datasets to latlon grid""" + +from typing import Optional +import xarray as xr + + +from pywk99.grid.healpix import dataarray_healpix_to_equatorial_latlon +from pywk99.grid.healpix import dataset_healpix_to_equatorial_latlon + + +def dataset_to_equatorial_latlon_grid( + dataset: xr.Dataset, grid_type: str, grid_dict: Optional[dict] +) -> xr.Dataset: + if grid_type == "latlon": + return dataset + elif grid_type == "healpix": + if grid_dict is None: + raise ValueError("No grid_dict provided for healpix conversion.") + return dataset_healpix_to_equatorial_latlon(dataset, **grid_dict) + else: + raise ValueError("Grid type not found.") + + +def dataarray_to_equatorial_latlon_grid( + dataarray: xr.DataArray, grid_type: str, grid_dict: Optional[dict] +) -> xr.DataArray: + if grid_type == "latlon": + return dataarray + elif grid_type == "healpix": + if grid_dict is None: + raise ValueError("No grid_dict provided for healpix conversion.") + return dataarray_healpix_to_equatorial_latlon(dataarray, **grid_dict) + else: + raise ValueError("Grid type not found.") \ No newline at end of file diff --git a/pywk99/grid/healpix.py b/pywk99/grid/healpix.py new file mode 100644 index 0000000..6cbf979 --- /dev/null +++ b/pywk99/grid/healpix.py @@ -0,0 +1,108 @@ +"""Convert healpix data to a latlon grid""" + +import xarray as xr +import numpy as np +import healpy as hp +from scipy import interpolate + + +MAXIMUM_LAT_RANGE = 25 + + +def dataset_healpix_to_equatorial_latlon( + dataset: xr.Dataset, + nside: int, + nest: str, + minmax_lat: float +) -> xr.Dataset: + """ + Extract a latlon dataarray from a healpix dataset. + + The latlon array extracted is for a band around the equator. + """ + latlon_datarrays = [] + for variable_name in dataset.data_vars: + dataarray = dataset[variable_name] + latlon_datarray_aux = dataarray_healpix_to_equatorial_latlon( + dataarray, nside, nest, minmax_lat + ) + latlon_datarray_aux.name = variable_name + latlon_datarrays.append(latlon_datarray_aux) + return xr.merge(latlon_datarrays) + + +def dataarray_healpix_to_equatorial_latlon( + healpix_dataarray: xr.DataArray, + nside: int, + nest: str, + minmax_lat: float +) -> xr.DataArray: + """ + Extract a latlon dataarray from a healpix dataset. + + The latlon array extracted is for a band around the equator. + """ + if minmax_lat > MAXIMUM_LAT_RANGE: + msg = (f"Selected latitudinal belt (minmax_lat = {minmax_lat}) is too " + "wide for a meaningful analysis of equatorial waves.") + raise ValueError(msg) + # get data + data = healpix_dataarray.values + time = healpix_dataarray.time.values + lat, lon = _get_pix_latlon(nside, nest) + + # get latitudes + unique_lats = np.unique(np.round(lat, 10)) + unique_lats = unique_lats[unique_lats <= minmax_lat] + unique_lats = unique_lats[unique_lats >= -minmax_lat] + + # get final longitudes + final_lons = lon[np.where(np.round(lat, 10) == unique_lats[0])] + final_lons = np.sort(final_lons) + + # sort and remap, if needed, to final longitudes + ntime = len(time) + nlon = len(final_lons) + nlat = len(unique_lats) + resampled_data = np.zeros((ntime, nlon, nlat)) + for i, unique_lat in enumerate(unique_lats): + # get values + ring_index = np.where(np.round(lat, 10) == unique_lat)[0] + lons = lon[ring_index] + data_ring = data[:, ring_index] + # sort + sorted_index = np.argsort(lons) + sorted_lons = lons[sorted_index] + sorted_data_ring = data_ring[:, sorted_index] + # interpolate + if i % 2 != 0: + sorted_data_ring = _interp_array_along_first_axis( + final_lons, sorted_lons, sorted_data_ring, period=360 + ) + resampled_data[:, :, i] = sorted_data_ring + + # save as latlon dataarray + latlon_dataarray = xr.DataArray( + data=resampled_data, + dims=["time", "lon", "lat"], + coords={"time": time, "lat": unique_lats, "lon": final_lons}, + ) + return latlon_dataarray + + +def _interp_array_along_first_axis(x, xp, fp, period): + x = x % period + xp = xp % period + xp = np.concatenate((xp[-1:] - period, xp, xp[0:1] + period)) + fp = np.concatenate((fp[:, -1:], fp, fp[:, 0:1]), axis=1) + interp_func = interpolate.interp1d(xp, fp) + return interp_func(x) + + +def _get_pix_latlon(nside, nest): + if nest is False: + raise NotImplementedError("nest=False is not implemented.") + npix = hp.nside2npix(nside) + cell = hp.reorder(np.arange(npix), r2n=True) + lon, lat = hp.pix2ang(nside, cell, lonlat=True) + return lat, lon diff --git a/pywk99/spectrum/_spectrum.py b/pywk99/spectrum/_spectrum.py index f53a3f7..595e937 100644 --- a/pywk99/spectrum/_spectrum.py +++ b/pywk99/spectrum/_spectrum.py @@ -13,6 +13,7 @@ from pywk99.timeseries.timeseries import check_variable_coordinates_are_sorted from pywk99.timeseries.timeseries import taper_variable_time_ends from pywk99.timeseries.timeseries import remove_linear_trend +from pywk99.grid import dataset_to_equatorial_latlon_grid _VALID_SEASONS = ["DJF", "MAM", "JJA", "SON"] @@ -25,7 +26,9 @@ def get_spectrum(spc_quantity: str, overlap_length: Optional[str] = None, season: Optional[str] = None, min_periods_season: Optional[int] = None, - taper_alpha: Optional[float] = None) -> xr.DataArray: + taper_alpha: Optional[float] = None, + grid_type: str = None, + grid_dict: Optional[dict] = None) -> xr.DataArray: """ Get a Wheeler and Kiladis 1999 power, amplitude or cross spectrum. @@ -59,6 +62,12 @@ def get_spectrum(spc_quantity: str, default it is set to int(0.25*overlap_length/data_frequency). taper_alpha: float, optional Alpha value determining the shape of the Tukey window filter function. + grid_type: str, optional, default "latlon" + The type of grid of the dataarray. Either "latlon" or "healpix". If + "healpix" then a grid_dict must be also provided. + grid_dict: dict, optional + A dictionary with grid metadata. Used when grid_type = "healpix". The + dictionary must have keys for "nside", "nested" and "minmax_lat". Returns ------- @@ -76,9 +85,12 @@ def get_spectrum(spc_quantity: str, If the season is not recognized. """ # process inputs - check_variable_coordinates_are_sorted(variable) check_for_one_max_two_variables(variable) variable = convert_to_dataset(variable) + variable = dataset_to_equatorial_latlon_grid(variable, + grid_type, + grid_dict) + check_variable_coordinates_are_sorted(variable) window_length_np = pd.Timedelta(window_length).to_numpy() overlap_length_np = pd.Timedelta(overlap_length).to_numpy() data_frequency_np = _get_data_frequency(variable, data_frequency) diff --git a/pywk99/spectrum/spectrum.py b/pywk99/spectrum/spectrum.py index fb6c528..5228455 100644 --- a/pywk99/spectrum/spectrum.py +++ b/pywk99/spectrum/spectrum.py @@ -16,102 +16,167 @@ def get_power_spectrum( - variable: xr.DataArray, - component_type: str, - data_frequency: Optional[str] = None, - window_length: str = "96D", - overlap_length: str = "60D", - season: Optional[str] = None, - min_periods_season: Optional[int] = None, - taper_alpha: Optional[float] = 0.5) -> xr.DataArray: + variable: xr.DataArray, + component_type: str, + data_frequency: Optional[str] = None, + window_length: str = "96D", + overlap_length: str = "60D", + season: Optional[str] = None, + min_periods_season: Optional[int] = None, + taper_alpha: Optional[float] = 0.5, + grid_type: str = "latlon", + grid_dict: Optional[dict] = None, +) -> xr.DataArray: """ Get the Wheeler and Kiladis 1999 power spectrum of a variable. See pywk99.spectrum.get_spectrum for argument documentation. """ power_spectrum = get_spectrum( - "power", variable, component_type, data_frequency, window_length, - overlap_length, season, min_periods_season, taper_alpha) + "power", + variable, + component_type, + data_frequency, + window_length, + overlap_length, + season, + min_periods_season, + taper_alpha, + grid_type, + grid_dict + ) return power_spectrum def get_amplitude_spectrum( - variable: xr.DataArray, - component_type: str, - data_frequency: Optional[str] = None, - window_length: str = "96D", - overlap_length: str = "60D", - season: Optional[str] = None, - min_periods_season: Optional[int] = None, - taper_alpha: Optional[float] = 0.5) -> xr.DataArray: + variable: xr.DataArray, + component_type: str, + data_frequency: Optional[str] = None, + window_length: str = "96D", + overlap_length: str = "60D", + season: Optional[str] = None, + min_periods_season: Optional[int] = None, + taper_alpha: Optional[float] = 0.5, + grid_type: str = "latlon", + grid_dict: Optional[dict] = None, +) -> xr.DataArray: """ Get the Wheeler and Kiladis 1999 amplitude spectrum of a variable. See pywk99.spectrum.get_spectrum for argument documentation. """ amplitude_spectrum = get_spectrum( - "amplitude", variable, component_type, data_frequency, window_length, - overlap_length, season, min_periods_season, taper_alpha) + "amplitude", + variable, + component_type, + data_frequency, + window_length, + overlap_length, + season, + min_periods_season, + taper_alpha, + grid_type, + grid_dict + ) return amplitude_spectrum def get_cross_spectrum( - variable: xr.Dataset, - component_type: str, - data_frequency: Optional[str] = None, - window_length: str = "96D", - overlap_length: str = "60D", - season: Optional[str] = None, - min_periods_season: Optional[int] = None, - taper_alpha: Optional[float] = 0.5) -> xr.DataArray: + variable: xr.Dataset, + component_type: str, + data_frequency: Optional[str] = None, + window_length: str = "96D", + overlap_length: str = "60D", + season: Optional[str] = None, + min_periods_season: Optional[int] = None, + taper_alpha: Optional[float] = 0.5, + grid_type: str = "latlon", + grid_dict: Optional[dict] = None, +) -> xr.DataArray: """ Get the Wheeler and Kiladis 1999 cross spectrum of two variables. See pywk99.spectrum.get_spectrum for argument documentation. """ cross_spectrum = get_spectrum( - "cross", variable, component_type, data_frequency, window_length, - overlap_length, season, min_periods_season, taper_alpha) + "cross", + variable, + component_type, + data_frequency, + window_length, + overlap_length, + season, + min_periods_season, + taper_alpha, + grid_type, + grid_dict + ) return cross_spectrum def get_co_spectrum( - variable: xr.Dataset, - component_type: str, - data_frequency: Optional[str] = None, - window_length: str = "96D", - overlap_length: str = "60D", - season: Optional[str] = None, - min_periods_season: Optional[int] = None, - taper_alpha: Optional[float] = 0.5) -> xr.DataArray: + variable: xr.Dataset, + component_type: str, + data_frequency: Optional[str] = None, + window_length: str = "96D", + overlap_length: str = "60D", + season: Optional[str] = None, + min_periods_season: Optional[int] = None, + taper_alpha: Optional[float] = 0.5, + grid_type: str = "latlon", + grid_dict: Optional[dict] = None, +) -> xr.DataArray: """ Get the Wheeler and Kiladis 1999 co-spectrum of two variables. See pywk99.spectrum.get_spectrum for argument documentation. """ cross_spectrum = get_spectrum( - "cross", variable, component_type, data_frequency, window_length, - overlap_length, season, min_periods_season, taper_alpha) + "cross", + variable, + component_type, + data_frequency, + window_length, + overlap_length, + season, + min_periods_season, + taper_alpha, + grid_type, + grid_dict + ) co_spectrum = np.real(cross_spectrum) return co_spectrum def get_quadrature_spectrum( - variable: xr.Dataset, - component_type: str, - data_frequency: Optional[str] = None, - window_length: str = "96D", - overlap_length: str = "60D", - season: Optional[str] = None, - min_periods_season: Optional[int] = None, - taper_alpha: Optional[float] = 0.5) -> xr.DataArray: + variable: xr.Dataset, + component_type: str, + data_frequency: Optional[str] = None, + window_length: str = "96D", + overlap_length: str = "60D", + season: Optional[str] = None, + min_periods_season: Optional[int] = None, + taper_alpha: Optional[float] = 0.5, + grid_type: str = "latlon", + grid_dict: Optional[dict] = None, +) -> xr.DataArray: """ Get the Wheeler and Kiladis 1999 quadrature-spectrum of two variables. See pywk99.spectrum.get_spectrum for argument documentation. """ cross_spectrum = get_spectrum( - "cross", variable, component_type, data_frequency, window_length, - overlap_length, season, min_periods_season, taper_alpha) + "cross", + variable, + component_type, + data_frequency, + window_length, + overlap_length, + season, + min_periods_season, + taper_alpha, + grid_type, + grid_dict + ) quadrature_spectrum = np.imag(cross_spectrum) return quadrature_spectrum diff --git a/pywk99/waves/plot.py b/pywk99/waves/plot.py index ba82322..7c1a3d4 100644 --- a/pywk99/waves/plot.py +++ b/pywk99/waves/plot.py @@ -1,5 +1,5 @@ """Plot dispersion relations as seen in Wheeler and Kiladis, 1999.""" -from typing import Optional +from typing import Optional, Tuple from matplotlib import pyplot as plt import numpy as np @@ -30,7 +30,7 @@ def plot_dispersion_relations( _set_axis_limits(ax, k_min, k_max, w_min, w_max) def plot_individual_dispersion_relations( - wave_list: list[tuple[str, int]], + wave_list: list[Tuple[str, int]], ax: Optional[plt.Axes] = None, k_min: float = -14, k_max: Optional[float] = 14, diff --git a/pywk99/waves/waves.py b/pywk99/waves/waves.py index 62b6786..8cbbc49 100644 --- a/pywk99/waves/waves.py +++ b/pywk99/waves/waves.py @@ -2,7 +2,7 @@ from math import sqrt from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple import numpy as np EARTH_RADIUS = 6371000.0 # earth radius [m] @@ -67,7 +67,7 @@ def wk_waves(component_type: str, waves.append(wave) return waves -def individual_waves(wave_list: list[tuple[str, int]], +def individual_waves(wave_list: list[Tuple[str, int]], equivalent_depths: list[float]): """Get the user-specified wave types""" waves = [] @@ -80,7 +80,7 @@ def individual_waves(wave_list: list[tuple[str, int]], return waves -def _wk_curves_names(component_type: str) -> tuple[list[str], int]: +def _wk_curves_names(component_type: str) -> Tuple[list[str], int]: """Get curves and polynomial number as used in WK99 figures.""" if component_type == "symmetric": waves = SYMMETRIC_WAVE_TYPES diff --git a/tests/test_filter_healpix.py b/tests/test_filter_healpix.py new file mode 100644 index 0000000..500960a --- /dev/null +++ b/tests/test_filter_healpix.py @@ -0,0 +1,38 @@ +import pandas as pd +import pytest + +import numpy as np +import xarray as xr +import healpy as hp + +from pywk99.filter.filter import filter_variable +from pywk99.filter.window import get_box_filter_window + + +@pytest.fixture +def variable(): + """Variable with frequency 15/360 CPD and wavenumber 5""" + nside = 32 # zoom 6 + npix = hp.nside2npix(nside) + cell = hp.reorder(np.arange(npix), r2n=True) + time = pd.date_range("2020-01-01", freq="1D", periods=365) + lon, lat = hp.pix2ang(nside, cell, lonlat=True) + data = np.ones(shape=(len(time), len(cell))) + time_cycle = np.cos(15 * 2 * np.pi * np.arange(len(time)) / len(time)) + lon_cycle = np.cos(2 * np.deg2rad(lon)) * np.exp(-(lat**2) / 15**2) + data = 10 + ((data * lon_cycle).T * time_cycle).T + variable = xr.DataArray( + data, dims=["time", "cell"], coords={"time": time, "cell": cell} + ) + return variable + + +def test_filter_works_with_healpix_data(variable): + filter_window = get_box_filter_window(0, 10, 0 / 360, 5 / 360) + test_filtered_variable = filter_variable( + variable, + filter_window, + grid_type="healpix", + grid_dict={"nside": 32, "nest": True, "minmax_lat": 15}, + ) + assert (2 * test_filtered_variable).std() < 0.1 / 100 diff --git a/tests/test_grid_healpix.py b/tests/test_grid_healpix.py new file mode 100644 index 0000000..7476d19 --- /dev/null +++ b/tests/test_grid_healpix.py @@ -0,0 +1,68 @@ +"""Test healpix transformation to latlongirds.""" + +import pytest + +import healpy as hp +import xarray as xr +import numpy as np +import pandas as pd + +from pywk99.grid.healpix import dataset_healpix_to_equatorial_latlon + + +@pytest.fixture +def variable(): + nside = 32 # zoom 6 + npix = hp.nside2npix(nside) + cell = hp.reorder(np.arange(npix), r2n=True) + time = pd.date_range( + start="2000-01-01 06:00", end="2000-01-02 06:00", freq="12h" + ) + lon, lat = hp.pix2ang(nside, cell, lonlat=True) + data = np.exp(-((lat) ** 2) / (5**2)) * (np.sin(5 * np.deg2rad(lon))) + data = np.reshape(data, (-1, 1)) + data = np.tile(data, len(time)).T + dataarray = xr.DataArray( + data=data, dims=("time", "cell"), coords={"time": time, "cell": cell} + ) + variable = xr.Dataset( + {"olr": dataarray, "olr2": dataarray}, + coords={ + "crs": ( + "crs", + [np.nan], + { + "grid_mapping_name": "healpix", + "healpix_nside": 32, + "healpix_order": "nest", + }, + ) + }, + ) + return variable + + +@pytest.fixture +def latlon_variable(variable): + grid_dict = dict(nside=32, nest=True, minmax_lat=20) + latlon_variable = dataset_healpix_to_equatorial_latlon( + variable, **grid_dict + ) + return latlon_variable + + +def test_transformation_has_4xnside_equatorial_points(latlon_variable) -> None: + assert len(latlon_variable.lon) == 128 + + +def test_transformation_keeps_variables(latlon_variable) -> None: + assert len(latlon_variable.data_vars) == 2 + + +def test_transformation_keeps_time(latlon_variable, variable) -> None: + assert np.all(latlon_variable.time == variable.time) + + +def test_transformation_has_equatorial_lats(latlon_variable) -> None: + assert np.all(latlon_variable.lat <= 20) + assert np.all(latlon_variable.lat >= -20) diff --git a/tests/test_spectrum_healpix.py b/tests/test_spectrum_healpix.py new file mode 100644 index 0000000..1e952f9 --- /dev/null +++ b/tests/test_spectrum_healpix.py @@ -0,0 +1,60 @@ +"""Test power spectrum for healpix data fields.""" +import pytest + +import pandas as pd +import healpy as hp +import xarray as xr +import numpy as np + +from pywk99.spectrum.spectrum import get_power_spectrum + +@pytest.fixture +def variable(): + nside = 32 # zoom 6 + npix = hp.nside2npix(nside) + cell = hp.reorder(np.arange(npix), r2n=True) + time = pd.date_range(start="2000-01-01 06:00", + end="2001-01-01 06:00", + freq="12h") + lon, lat = hp.pix2ang(nside, cell, lonlat=True) + data = np.exp(-((lat) ** 2) / (5 ** 2))*(np.sin(5 * np.deg2rad(lon))) + data = np.reshape(data, (-1, 1)) + data = np.tile(data, len(time)).T + dataarray = xr.DataArray(data=data, + dims=("time", "cell"), + coords={"time": time, "cell": cell}) + variable = xr.Dataset({"olr": dataarray}, + coords ={"crs": ("crs", [np.nan], + {'grid_mapping_name': 'healpix', + 'healpix_nside': 32, + 'healpix_order': 'nest'})}) + return variable + + +@pytest.fixture(params=["symmetric", "asymmetric"]) +def spectrum(variable, request): + component_type = request.param + spectrum = get_power_spectrum(variable, + component_type, + window_length="30D", + overlap_length="10D", + grid_type="healpix", + grid_dict={"nside": 32, + "nest": True, + "minmax_lat": 15}) + return spectrum + + +def test_power_spectrum_shape(spectrum): + # from variable time segments ((time_points - 1)/2, 4*nside) + assert np.shape(spectrum) == (29, 128) + + +def test_power_spectrum_frequency_between_zero_and_one(spectrum): + assert np.all(spectrum.frequency.values == np.arange(1, 30)/29) + + +def test_power_spectrum_wavenumbers(spectrum): + # note that positive wave number is eastward in WK99 + assert np.all(spectrum.wavenumber.values == np.arange(-63, 65)) +