Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions mtspec/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,29 @@ def mtspec(data, delta, time_bandwidth, nfft=None, number_of_tapers=None,

# Depending if nfft is specified or not initialte MtspecTytpe
# for mtspec_pad_ or mtspec_d_
complex = any(np.iscomplex(data))
if nfft is None or nfft == npts:
nfft = npts
mt = _MtspecType("float64") # mtspec_d_
else:
mt = _MtspecType("float32") # mtspec_pad_
quadratic = False
if complex:
mt = _MtspecType("float32") # mtspec_c_
mt.mtspec = mtspeclib.mtspec_c_
# Use the optimal number of tapers in case no number is specified.
if number_of_tapers is None:
number_of_tapers = int(2 * time_bandwidth) - 1
# Transform the data to work with the library.
data = np.require(data, dtype=mt.float, requirements=[mt.order])
if complex:
data = np.require(data, dtype=mt.complex, requirements=[mt.order])
else:
data = np.require(data, dtype=mt.float, requirements=[mt.order])
# Get some information necessary for the call to the Fortran library.
number_of_frequency_bins = int(nfft / 2) + 1
if complex:
number_of_frequency_bins = nfft
else:
number_of_frequency_bins = int(nfft / 2) + 1
# Create output arrays.
spectrum = mt.empty(number_of_frequency_bins)
frequency_bins = mt.empty(number_of_frequency_bins)
Expand Down Expand Up @@ -155,7 +165,7 @@ def mtspec(data, delta, time_bandwidth, nfft=None, number_of_tapers=None,
fcrit = None
# Call the library. Fortran passes pointers!
args = [C.byref(C.c_int(npts)), C.byref(C.c_int(nfft)),
C.byref(mt.c_float(delta)), mt.p(data),
C.byref(mt.c_float(delta)), mt.p(data, complex),
C.byref(mt.c_float(time_bandwidth)),
C.byref(C.c_int(number_of_tapers)),
C.byref(C.c_int(number_of_frequency_bins)), mt.p(frequency_bins),
Expand All @@ -165,6 +175,7 @@ def mtspec(data, delta, time_bandwidth, nfft=None, number_of_tapers=None,
mt.p(eigenspectra), rshape, mt.p(f_statistics), fcrit, None]
# diffrent arguments, depending on mtspec_pad_ or mtspec_d_, adapt
if npts == nfft:
#print('npts == nfft')
args.pop(1)

# finally call the shared library function
Expand Down Expand Up @@ -764,9 +775,13 @@ def __init__(self, dtype):
:param dtype: 'float32' or 'float64'
"""
if dtype not in self.struct.keys():
raise ValueError("dtype must be either 'float32' or 'float64'")
raise ValueError("dtype must be either 'float32' or 'float64'"
+ " or 'complex32'")
self.float = dtype
self.complex = 'complex%d' % (2 * float(dtype[-2:]))
#self.real = 'float%d' % (float(dtype[-2:]))
self.complex = 'complex%d' % (2*float(dtype[-2:]))
#self.complex = 'complex%d' % (float(dtype[-2:]))
# aboves leads to: TypeError: data type "complex32" not understood
self.c_float = self.struct[dtype][0]
self.pointer = C.POINTER(self.c_float)
self.order = "F"
Expand All @@ -783,7 +798,7 @@ def empty(self, shape, complex=False):
return np.empty(shape, dtype=self.complex, order=self.order)
return np.empty(shape, dtype=self.float, order=self.order)

def p(self, ndarray):
def p(self, ndarray, complex=False):
"""
A wrapper around ctypes.data_as which automatically sets the
correct type. Returns none if ndarray is None.
Expand All @@ -793,4 +808,8 @@ def p(self, ndarray):
# short variable name for passing as argument in function calls
if ndarray is None:
return None
return ndarray.ctypes.data_as(self.pointer)
if complex:
pointer_complex = C.POINTER(2*self.c_float)
return ndarray.ctypes.data_as(pointer_complex)
else:
return ndarray.ctypes.data_as(self.pointer)
22 changes: 22 additions & 0 deletions mtspec/tests/test_multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,28 @@ def test_multitaper_spectrum(self):
np.testing.assert_almost_equal(freq, freq2)
np.testing.assert_almost_equal(spec / spec, spec2 / spec, 5)

def test_multitaper_spectrum_complex_input(self):
"""
Test for mtspec. The result is compared to the output of
test_recreatePaperFigures.py in the same directory. This is assumed to
be correct because they are identical to the figures in the paper on
the machine that created these.
"""
data = _load_mtdata('PASC.dat.gz')
# Calculate the spectra.
spec, freq = mtspec(data+1j*data, 1.0, 4.5, number_of_tapers=5)
# No NaNs are supposed to be in the output.
self.assertEqual(np.isnan(spec).any(), False)
self.assertEqual(np.isnan(spec).any(), False)
# Load the good data.
datafile = os.path.join(os.path.dirname(__file__), 'data',
'multitaper.npz')
spec2 = np.load(datafile)['spec']
freq2 = np.arange(43201) * 1.15740741e-05
# Compare, normalize for subdigit comparision
np.testing.assert_almost_equal(freq, freq2)
np.testing.assert_almost_equal(spec / spec, 2*spec2 / spec, 5)

def test_multitaper_spectrum_optional_output(self):
"""
Test for mtspec. The result is compared to the output of
Expand Down