Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions examples/cross_spectra.ipynb

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pywk99/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

Changelog
---------
Jul 18, 2025: Version 0.4.4
- Adding coherence to the cross spectra function.
- Fixing bugs in the cross spectra.

Apr 8, 2025: Version 0.4.3
- Adding support for healpix

Expand Down Expand Up @@ -70,4 +74,4 @@
- Several bugs where detected and corrected with the tests.
"""

__version__ = "0.4.3"
__version__ = "0.4.4"
2 changes: 1 addition & 1 deletion pywk99/filter/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class FilterPoint(Point):
"""See shapely.geometry.Point."""

@dataclass(frozen=True)
@dataclass
class FilterWindow:
name: str
polygon: Union[Polygon, MultiPolygon]
Expand Down
1 change: 1 addition & 0 deletions pywk99/spectrum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from pywk99.spectrum.background import get_background_spectrum
from pywk99.spectrum.plot import plot_spectrum
from pywk99.spectrum.plot import plot_spectrum_peaks

16 changes: 12 additions & 4 deletions pywk99/spectrum/_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_spectrum(spc_quantity: str,
]
number_of_spectrums = len(wk_spectrums)
wk_spectrum = (sum(wk_spectrums) / number_of_spectrums).sum("lat")
wk_spectrum = wk_spectrum[wk_spectrum.frequency > 0]
wk_spectrum = wk_spectrum.where(wk_spectrum.frequency > 0, drop=True)
wk_spectrum = wk_spectrum.sortby(["frequency", "wavenumber"])
return wk_spectrum

Expand Down Expand Up @@ -178,9 +178,17 @@ def _compute_hayashi_cross_spectrum(variables: xr.Dataset) -> xr.DataArray:
n_time = len(variables.time)
n_lon = len(variables.lon)
varlist = list(variables.keys())
variable1_fft = fourier_transform(variables[varlist[0]]) / (n_time * n_lon)
variable2_fft = fourier_transform(variables[varlist[1]]) / (n_time * n_lon)
cross_spectrum = variable1_fft * np.conj(variable2_fft)
variable1_name = varlist[0]
variable2_name = varlist[1]
variable1_fft = fourier_transform(variables[variable1_name])
variable2_fft = fourier_transform(variables[variable2_name])
spectrum_1 = np.abs(variable1_fft)**2 / (n_time * n_lon)**2
spectrum_2 = np.abs(variable2_fft)**2 / (n_time * n_lon)**2
cross_1_2 = variable1_fft * np.conj(variable2_fft) / (n_time * n_lon)**2
spectrum_1.name = f"spectra1"
spectrum_2.name = f"spectra2"
cross_1_2.name = f"cross"
cross_spectrum = xr.merge([spectrum_1, spectrum_2, cross_1_2])
return cross_spectrum


Expand Down
18 changes: 14 additions & 4 deletions pywk99/spectrum/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def get_cross_spectrum(
taper_alpha: Optional[float] = 0.5,
grid_type: str = "latlon",
grid_dict: Optional[dict] = None,
) -> xr.DataArray:
) -> xr.Dataset:
"""
Get the Wheeler and Kiladis 1999 cross spectrum of two variables.
Get the Wheeler and Kiladis 1999 cross spectrum.

See pywk99.spectrum.get_spectrum for argument documentation.
"""
Expand All @@ -111,9 +111,18 @@ def get_cross_spectrum(
grid_type,
grid_dict
)
cross_spectrum["coherence_squared"] = _coherence_squared(cross_spectrum)
return cross_spectrum


def _coherence_squared(cross_spectrum: xr.Dataset) -> xr.Dataset:
"""Compute the squared coherence a cross spectrum."""
sxy2 = np.abs(cross_spectrum.cross)**2
sxx = np.abs(cross_spectrum.spectra1)
syy = np.abs(cross_spectrum.spectra2)
return sxy2 / (sxx * syy)


def get_co_spectrum(
variable: xr.Dataset,
component_type: str,
Expand Down Expand Up @@ -144,7 +153,7 @@ def get_co_spectrum(
grid_type,
grid_dict
)
co_spectrum = np.real(cross_spectrum)
co_spectrum = np.real(cross_spectrum.cross)
return co_spectrum


Expand Down Expand Up @@ -178,5 +187,6 @@ def get_quadrature_spectrum(
grid_type,
grid_dict
)
quadrature_spectrum = np.imag(cross_spectrum)
quadrature_spectrum = np.imag(cross_spectrum.cross)
return quadrature_spectrum

32 changes: 32 additions & 0 deletions tests/test_spectrum_coh2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Test power spectrum for time-lon-lat fields."""

import numpy as np
import pytest
import xarray as xr

from pywk99.spectrum.spectrum import get_cross_spectrum


@pytest.fixture
def variables():
olr1 = xr.open_dataarray("tests/olr.test.nc").transpose(
"time", "lon", "lat"
)
olr2 = olr1.copy()
variables = xr.Dataset({"olr1": olr1, "olr2": olr2})
variables = variables.sortby(["lat"])
return variables


@pytest.fixture(params=["symmetric", "asymmetric"])
def cross_spectrum(variables, request):
component_type = request.param
cross_spectrum = get_cross_spectrum(
variables, component_type, window_length="30D", overlap_length="10D"
)
return cross_spectrum


def test_coherence_squared_is_one_for_the_same_field(cross_spectrum) -> None:
coh2 = cross_spectrum.coherence_squared
assert np.all(np.ravel(coh2.values) == pytest.approx(1.0))