diff --git a/cctorch/data.py b/cctorch/data.py index db6260e..e91f278 100644 --- a/cctorch/data.py +++ b/cctorch/data.py @@ -551,11 +551,11 @@ def read_data(file_name, data_path, format="h5", mode="CC", config={}): if format == "h5": data, info = read_das_continuous_data_h5(data_path / file_name, dataset_keys=[]) elif format == "mseed": - data, info = read_mseed(file_name, config=config) + data, info = read_mseed(file_name, config=config, sampling_rate=config.fs) elif mode == "TM": if format == "mseed": - data, info = read_mseed(file_name, config=config) + data, info = read_mseed(file_name, config=config, sampling_rate=config.fs) # data, info = read_mseed_3c(file_name, config=config) else: raise ValueError(f"Unknown mode: {mode}") @@ -576,6 +576,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None): stream += meta # stream += obspy.read(tmp) stream = stream.merge(fill_value="latest") + stream.detrend("demean") ## FIXME: HARDCODE for California if tmp.startswith("s3://ncedc-pds"): @@ -603,10 +604,10 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None): if trace.stats.sampling_rate != sampling_rate: logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz") try: - trace = trace.interpolate(sampling_rate, method="linear") - if tmp.startswith("s3://ncedc-pds"): - trace = trace.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True) - elif tmp.startswith("s3://scedc-pds"): + trace.filter("lowpass", freq=0.45 * sampling_rate, zerophase=True, corners=8) + trace.interpolate(method="lanczos", sampling_rate=sampling_rate, a=1.0) + # trace = trace.interpolate(sampling_rate, method="linear") + if tmp.startswith(("s3://ncedc-pds", "s3://scedc-pds")): trace = trace.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True) except Exception as e: print(f"Error resampling {trace.id}:\n{e}") @@ -647,10 +648,10 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None): nx = len(station_ids) nt = max([len(tr.data) for tr in stream]) - ## FIXME: HARDCODE for California - if tmp.startswith("s3://ncedc-pds") or tmp.startswith("s3://scedc-pds"): - nt = 8640001 - + # ## FIXME: HARDCODE for California + # if tmp.startswith("s3://ncedc-pds") or tmp.startswith("s3://scedc-pds"): + # nt = 8640001 + data = np.zeros([3, nx, nt], dtype=np.float32) for i, sta in enumerate(station_keys): for c in station_ids[sta]: diff --git a/cctorch/model.py b/cctorch/model.py index 08a1672..158f881 100644 --- a/cctorch/model.py +++ b/cctorch/model.py @@ -40,9 +40,28 @@ def __init__( # AN self.nlag = config.nlag self.nfft = self.nlag * 2 - self.window = torch.hann_window(self.nfft, periodic=False).to(self.device) + # self.window = torch.hann_window(self.nfft, periodic=False).to(self.device) self.spectral_whitening = config.spectral_whitening + def partial_hann_taper(self, length, taper_fraction=0.04, device="cpu"): + # print('Chris flag taper', length, taper_fraction) + n_taper = int(length * taper_fraction) + if n_taper == 0: + return torch.ones(length, device=device) + + # Hann window for edges + x = torch.linspace(0, torch.pi / 2, n_taper, device=device) + taper_edge = torch.sin(x)**2 # sin² taper + + taper_start = taper_edge + taper_end = taper_edge.flip(0) + + + # Build full window: start + flat + end + ones_middle = torch.ones(length - 2 * n_taper, device=device) + window = torch.cat([taper_start, ones_middle, taper_end], dim=0) + return window + def forward(self, x): """Perform cross-correlation on input data Args: @@ -54,7 +73,7 @@ def forward(self, x): - data (torch.Tensor): data2 with shape (batch, nsta/nch, nt) - info (dict): information information of data2 """ - + self.window = self.partial_hann_taper(self.nfft, 0.04, device=self.device) x1, x2 = x if self.to_device: data1 = x1["data"].to(self.device) @@ -158,6 +177,7 @@ def forward(self, x): data2 = torch.exp(1j * data2.angle()) xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1) + xcor = xcor / data1.size(1) xcor = torch.roll(xcor, self.nlag, dims=-1) xcor = xcor.view(nb1, nc1, nx1, -1) diff --git a/run.py b/run.py index 97b033d..258e818 100644 --- a/run.py +++ b/run.py @@ -1,6 +1,8 @@ +import os +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + import json import logging -import os import threading from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from dataclasses import dataclass