diff --git a/mtspec/multitaper.py b/mtspec/multitaper.py index 8ebbcc9..ff10050 100644 --- a/mtspec/multitaper.py +++ b/mtspec/multitaper.py @@ -93,19 +93,29 @@ def mtspec(data, delta, time_bandwidth, nfft=None, number_of_tapers=None, # Depending if nfft is specified or not initialte MtspecTytpe # for mtspec_pad_ or mtspec_d_ + complex = any(np.iscomplex(data)) if nfft is None or nfft == npts: nfft = npts mt = _MtspecType("float64") # mtspec_d_ else: mt = _MtspecType("float32") # mtspec_pad_ quadratic = False + if complex: + mt = _MtspecType("float32") # mtspec_c_ + mt.mtspec = mtspeclib.mtspec_c_ # Use the optimal number of tapers in case no number is specified. if number_of_tapers is None: number_of_tapers = int(2 * time_bandwidth) - 1 # Transform the data to work with the library. - data = np.require(data, dtype=mt.float, requirements=[mt.order]) + if complex: + data = np.require(data, dtype=mt.complex, requirements=[mt.order]) + else: + data = np.require(data, dtype=mt.float, requirements=[mt.order]) # Get some information necessary for the call to the Fortran library. - number_of_frequency_bins = int(nfft / 2) + 1 + if complex: + number_of_frequency_bins = nfft + else: + number_of_frequency_bins = int(nfft / 2) + 1 # Create output arrays. spectrum = mt.empty(number_of_frequency_bins) frequency_bins = mt.empty(number_of_frequency_bins) @@ -155,7 +165,7 @@ def mtspec(data, delta, time_bandwidth, nfft=None, number_of_tapers=None, fcrit = None # Call the library. Fortran passes pointers! args = [C.byref(C.c_int(npts)), C.byref(C.c_int(nfft)), - C.byref(mt.c_float(delta)), mt.p(data), + C.byref(mt.c_float(delta)), mt.p(data, complex), C.byref(mt.c_float(time_bandwidth)), C.byref(C.c_int(number_of_tapers)), C.byref(C.c_int(number_of_frequency_bins)), mt.p(frequency_bins), @@ -165,6 +175,7 @@ def mtspec(data, delta, time_bandwidth, nfft=None, number_of_tapers=None, mt.p(eigenspectra), rshape, mt.p(f_statistics), fcrit, None] # diffrent arguments, depending on mtspec_pad_ or mtspec_d_, adapt if npts == nfft: + #print('npts == nfft') args.pop(1) # finally call the shared library function @@ -764,9 +775,13 @@ def __init__(self, dtype): :param dtype: 'float32' or 'float64' """ if dtype not in self.struct.keys(): - raise ValueError("dtype must be either 'float32' or 'float64'") + raise ValueError("dtype must be either 'float32' or 'float64'" + + " or 'complex32'") self.float = dtype - self.complex = 'complex%d' % (2 * float(dtype[-2:])) + #self.real = 'float%d' % (float(dtype[-2:])) + self.complex = 'complex%d' % (2*float(dtype[-2:])) + #self.complex = 'complex%d' % (float(dtype[-2:])) + # aboves leads to: TypeError: data type "complex32" not understood self.c_float = self.struct[dtype][0] self.pointer = C.POINTER(self.c_float) self.order = "F" @@ -783,7 +798,7 @@ def empty(self, shape, complex=False): return np.empty(shape, dtype=self.complex, order=self.order) return np.empty(shape, dtype=self.float, order=self.order) - def p(self, ndarray): + def p(self, ndarray, complex=False): """ A wrapper around ctypes.data_as which automatically sets the correct type. Returns none if ndarray is None. @@ -793,4 +808,8 @@ def p(self, ndarray): # short variable name for passing as argument in function calls if ndarray is None: return None - return ndarray.ctypes.data_as(self.pointer) + if complex: + pointer_complex = C.POINTER(2*self.c_float) + return ndarray.ctypes.data_as(pointer_complex) + else: + return ndarray.ctypes.data_as(self.pointer) diff --git a/mtspec/tests/test_multitaper.py b/mtspec/tests/test_multitaper.py index f396cc9..ed0c70d 100644 --- a/mtspec/tests/test_multitaper.py +++ b/mtspec/tests/test_multitaper.py @@ -71,6 +71,28 @@ def test_multitaper_spectrum(self): np.testing.assert_almost_equal(freq, freq2) np.testing.assert_almost_equal(spec / spec, spec2 / spec, 5) + def test_multitaper_spectrum_complex_input(self): + """ + Test for mtspec. The result is compared to the output of + test_recreatePaperFigures.py in the same directory. This is assumed to + be correct because they are identical to the figures in the paper on + the machine that created these. + """ + data = _load_mtdata('PASC.dat.gz') + # Calculate the spectra. + spec, freq = mtspec(data+1j*data, 1.0, 4.5, number_of_tapers=5) + # No NaNs are supposed to be in the output. + self.assertEqual(np.isnan(spec).any(), False) + self.assertEqual(np.isnan(spec).any(), False) + # Load the good data. + datafile = os.path.join(os.path.dirname(__file__), 'data', + 'multitaper.npz') + spec2 = np.load(datafile)['spec'] + freq2 = np.arange(43201) * 1.15740741e-05 + # Compare, normalize for subdigit comparision + np.testing.assert_almost_equal(freq, freq2) + np.testing.assert_almost_equal(spec / spec, 2*spec2 / spec, 5) + def test_multitaper_spectrum_optional_output(self): """ Test for mtspec. The result is compared to the output of