From ec7a36816e32a4ee9cb70edb7968a4e84848b071 Mon Sep 17 00:00:00 2001 From: Chris Lin Date: Sun, 11 May 2025 13:27:30 -0700 Subject: [PATCH 1/4] 1) Let read_mseed receive assigned sample_rate. 2) demean before fill zeros. 3) apply lowpass and update downsample method. 4) comment nt=8640001 for using corrected nt after downsample. --- cctorch/data.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cctorch/data.py b/cctorch/data.py index db6260e..e91f278 100644 --- a/cctorch/data.py +++ b/cctorch/data.py @@ -551,11 +551,11 @@ def read_data(file_name, data_path, format="h5", mode="CC", config={}): if format == "h5": data, info = read_das_continuous_data_h5(data_path / file_name, dataset_keys=[]) elif format == "mseed": - data, info = read_mseed(file_name, config=config) + data, info = read_mseed(file_name, config=config, sampling_rate=config.fs) elif mode == "TM": if format == "mseed": - data, info = read_mseed(file_name, config=config) + data, info = read_mseed(file_name, config=config, sampling_rate=config.fs) # data, info = read_mseed_3c(file_name, config=config) else: raise ValueError(f"Unknown mode: {mode}") @@ -576,6 +576,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None): stream += meta # stream += obspy.read(tmp) stream = stream.merge(fill_value="latest") + stream.detrend("demean") ## FIXME: HARDCODE for California if tmp.startswith("s3://ncedc-pds"): @@ -603,10 +604,10 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None): if trace.stats.sampling_rate != sampling_rate: logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz") try: - trace = trace.interpolate(sampling_rate, method="linear") - if tmp.startswith("s3://ncedc-pds"): - trace = trace.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True) - elif tmp.startswith("s3://scedc-pds"): + trace.filter("lowpass", freq=0.45 * sampling_rate, zerophase=True, corners=8) + trace.interpolate(method="lanczos", sampling_rate=sampling_rate, a=1.0) + # trace = trace.interpolate(sampling_rate, method="linear") + if tmp.startswith(("s3://ncedc-pds", "s3://scedc-pds")): trace = trace.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True) except Exception as e: print(f"Error resampling {trace.id}:\n{e}") @@ -647,10 +648,10 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None): nx = len(station_ids) nt = max([len(tr.data) for tr in stream]) - ## FIXME: HARDCODE for California - if tmp.startswith("s3://ncedc-pds") or tmp.startswith("s3://scedc-pds"): - nt = 8640001 - + # ## FIXME: HARDCODE for California + # if tmp.startswith("s3://ncedc-pds") or tmp.startswith("s3://scedc-pds"): + # nt = 8640001 + data = np.zeros([3, nx, nt], dtype=np.float32) for i, sta in enumerate(station_keys): for c in station_ids[sta]: From de94b5779f072942d0b727a304db9a87c27c36e8 Mon Sep 17 00:00:00 2001 From: Chris Lin Date: Sun, 11 May 2025 13:38:12 -0700 Subject: [PATCH 2/4] 1) Replace hann_window with partial_hann_taper. 2) Devide ccf with window length. --- cctorch/model.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/cctorch/model.py b/cctorch/model.py index 08a1672..5811e9e 100644 --- a/cctorch/model.py +++ b/cctorch/model.py @@ -40,9 +40,28 @@ def __init__( # AN self.nlag = config.nlag self.nfft = self.nlag * 2 - self.window = torch.hann_window(self.nfft, periodic=False).to(self.device) + # self.window = torch.hann_window(self.nfft, periodic=False).to(self.device) self.spectral_whitening = config.spectral_whitening + def partial_hann_taper(self, length, taper_fraction=0.04, device="cpu"): + # print('Chris flag taper', length, taper_fraction) + n_taper = int(length * taper_fraction) + if n_taper == 0: + return torch.ones(length, device=device) + + # Hann window for edges + x = torch.linspace(0, torch.pi / 2, n_taper, device=device) + taper_edge = torch.sin(x)**2 # sin² taper + + taper_start = taper_edge + taper_end = taper_edge.flip(0) + + + # Build full window: start + flat + end + ones_middle = torch.ones(length - 2 * n_taper, device=device) + window = torch.cat([taper_start, ones_middle, taper_end], dim=0) + return window + def forward(self, x): """Perform cross-correlation on input data Args: @@ -54,7 +73,7 @@ def forward(self, x): - data (torch.Tensor): data2 with shape (batch, nsta/nch, nt) - info (dict): information information of data2 """ - + self.window = self.partial_hann_taper(self.nfft, 0.04) x1, x2 = x if self.to_device: data1 = x1["data"].to(self.device) @@ -158,6 +177,7 @@ def forward(self, x): data2 = torch.exp(1j * data2.angle()) xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1) + xcor = xcor / data1.size(1) xcor = torch.roll(xcor, self.nlag, dims=-1) xcor = xcor.view(nb1, nc1, nx1, -1) From dd0ce0be38ade240617f8d5d25a651b799365323 Mon Sep 17 00:00:00 2001 From: Chris Lin Date: Sun, 11 May 2025 13:48:04 -0700 Subject: [PATCH 3/4] Allow the code to run with --device=mps --- run.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/run.py b/run.py index 97b033d..258e818 100644 --- a/run.py +++ b/run.py @@ -1,6 +1,8 @@ +import os +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + import json import logging -import os import threading from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from dataclasses import dataclass From 9244b0d354b7545773bc803fc154b6849b85f183 Mon Sep 17 00:00:00 2001 From: Chris Lin Date: Sun, 11 May 2025 13:58:42 -0700 Subject: [PATCH 4/4] fix partial_hann_taper to read --device --- cctorch/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cctorch/model.py b/cctorch/model.py index 5811e9e..158f881 100644 --- a/cctorch/model.py +++ b/cctorch/model.py @@ -73,7 +73,7 @@ def forward(self, x): - data (torch.Tensor): data2 with shape (batch, nsta/nch, nt) - info (dict): information information of data2 """ - self.window = self.partial_hann_taper(self.nfft, 0.04) + self.window = self.partial_hann_taper(self.nfft, 0.04, device=self.device) x1, x2 = x if self.to_device: data1 = x1["data"].to(self.device)