Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .skyignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
**/.*
**/__pycache__
venv/
data/
42 changes: 30 additions & 12 deletions cctorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,9 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
meta = obspy.read(fs, format="MSEED")
stream += meta
# stream += obspy.read(tmp)
stream = stream.merge(fill_value="latest")

stream_mask = stream.copy().merge(fill_value=None)
stream = stream.merge(fill_value=0)

## FIXME: HARDCODE for California
if tmp.startswith("s3://ncedc-pds"):
Expand All @@ -611,12 +613,14 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
begin_time = obspy.UTCDateTime(year=year, julday=jday)
end_time = begin_time + 86400 ## 1 day
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True)
stream_mask = stream_mask.trim(begin_time, end_time, pad=True, fill_value=None, nearest_sample=True)
elif tmp.startswith("s3://scedc-pds"):
year_jday = tmp.split("/")[-1].rstrip(".ms")[-7:]
year, jday = int(year_jday[:4]), int(year_jday[4:])
begin_time = obspy.UTCDateTime(year=year, julday=jday)
end_time = begin_time + 86400 ## 1 day
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0, nearest_sample=True)
stream_mask = stream_mask.trim(begin_time, end_time, pad=True, fill_value=None, nearest_sample=True)
except Exception as e:
print(f"Error reading {fname}:\n{e}")
return None
Expand Down Expand Up @@ -661,6 +665,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
begin_time = min([st.stats.starttime for st in stream])
end_time = max([st.stats.endtime for st in stream])
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
stream_mask = stream_mask.trim(begin_time, end_time, pad=True, fill_value=None)

comp = ["3", "2", "1", "E", "N", "Z"]
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
Expand All @@ -681,6 +686,7 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
nt = int(24 * 60 * 60 * sampling_rate) + 1

data = np.zeros([3, nx, nt], dtype=np.float32)
mask = np.zeros([3, nx, nt], dtype=np.int8)
for i, sta in enumerate(station_keys):
for c in station_ids[sta]:
j = comp2idx[c]
Expand All @@ -690,19 +696,26 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
continue

trace = stream.select(id=sta + c)[0]
trace_mask = stream_mask.select(id=sta + c)[0]
try:
mask_array = trace_mask.data.mask
mask_array = mask_array.astype(int)
except:
mask_array = np.zeros(len(trace_mask.data))

## accerleration to velocity
if sta[-1] == "N":
trace = trace.integrate().filter("highpass", freq=1.0)

tmp = trace.data.astype("float32")
data[j, i, : len(tmp)] = tmp[:nt]
mask[j, i, : len(mask_array)] = mask_array[:nt]

# return data, {
# "begin_time": begin_time.datetime, # .strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
# "end_time": end_time.datetime, # .strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
# }
return data, {
return data, {"mask": mask,
"begin_time": np.datetime64(begin_time.datetime),
"end_time": np.datetime64(end_time.datetime),
}
Expand Down Expand Up @@ -766,18 +779,23 @@ def read_mseed_3c(fname, response=None, highpass_filter=0.0, sampling_rate=100,


def read_das_continuous_data_h5(fn, dataset_keys=[]):
with h5py.File(fn, "r") as f:
if "Data" in f:
data = f["Data"][:]
elif "data" in f:
data = f["data"][:]
else:
raise ValueError("Cannot find data in the file")
info = {}
for key in dataset_keys:
info[key] = f[key][:]
fs = fsspec.filesystem("gs", token="google_default")
with fs.open(fn, "rb") as f:
with h5py.File(f, "r") as hf:
if "Data" in hf:
data = hf["Data"][:]
elif "data" in hf:
data = hf["data"][:]
elif "Acquisition" in hf:
data = hf["Acquisition/Raw[0]/RawData"][:]
else:
raise ValueError("Cannot find data in the file")
info = {}
for key in dataset_keys:
info[key] = hf[key][:]
if data.ndim == 2:
data = data[np.newaxis, :, :] # (nc, nx, nt)

return data, info


Expand Down
65 changes: 55 additions & 10 deletions cctorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch.nn.functional as F
from tqdm import tqdm

from .utils import partial_hann_taper, cosine_taper_4freq, custom_demeaned_stft
from scipy.fft import next_fast_len

class CCModel(nn.Module):
def __init__(
Expand Down Expand Up @@ -39,8 +41,8 @@ def __init__(

# AN
self.nlag = config.nlag
self.nfft = self.nlag * 2
self.window = torch.hann_window(self.nfft, periodic=False).to(self.device)
self.nfft = next_fast_len(self.nlag * 2 + 1)
self.window = partial_hann_taper(self.nfft, 0.04, self.device)
self.spectral_whitening = config.spectral_whitening

def forward(self, x):
Expand Down Expand Up @@ -127,27 +129,57 @@ def forward(self, x):
xcor = torch.mean(xcor, dim=(-3), keepdim=True)

elif self.domain == "stft":
overlap_ratio = 0.0 # costumize overlap ratio for stft
hop_length = int(self.nlag * ((1-overlap_ratio)/0.5))

# --- for masking ---
# mask1 = x1['info']['mask']
# mask2 = x2['info']['mask']
# mask1 = torch.from_numpy(np.stack(mask1, axis=0)).float()
# mask2 = torch.from_numpy(np.stack(mask2, axis=0)).float()

# pooled_mask1 = F.max_pool2d(
# mask1,
# kernel_size=(1, self.nlag * 2 + 5),
# stride=(1, hop_length),
# )
# pooled_mask2 = F.max_pool2d(
# mask2,
# kernel_size=(1, self.nlag * 2 + 5),
# stride=(1, hop_length),
# )
# mask_reshaped1 = pooled_mask1.view(pooled_mask1.shape[0]*pooled_mask1.shape[1], 1, pooled_mask1.shape[-1])
# mask_reshaped2 = pooled_mask2.view(pooled_mask2.shape[0]*pooled_mask2.shape[1], 1, pooled_mask2.shape[-1])

# mask_reshaped1 = 1 - mask_reshaped1
# mask_reshaped2 = 1 - mask_reshaped2

nlag = self.nlag
nb1, nc1, nx1, nt1 = data1.shape
# nb2, nc2, nx2, nt2 = data2.shape
data1 = data1.view(nb1 * nc1 * nx1, nt1)
# data2 = data2.view(nb2 * nc2 * nx2, nt2)
data2 = data2.view(nb1 * nc1 * nx1, nt1)
if not self.pre_fft:

# data1 = custom_demeaned_stft(data1, nlag, hop_length, self.window) # slow down a lot, only use while overlap_ratio > 0
# data2 = custom_demeaned_stft(data2, nlag, hop_length, self.window)


data1 = torch.stft(
data1,
n_fft=self.nlag * 2,
hop_length=self.nlag,
n_fft=self.nfft,
hop_length=hop_length,
window=self.window,
center=True,
center=False, # turn off centering to prevent window shift caused by padding at both ends
return_complex=True,
)
data2 = torch.stft(
data2,
n_fft=self.nlag * 2,
hop_length=self.nlag,
n_fft=self.nfft,
hop_length=hop_length,
window=self.window,
center=True,
center=False,
return_complex=True,
)
if self.spectral_whitening:
Expand All @@ -157,9 +189,22 @@ def forward(self, x):
data1 = torch.exp(1j * data1.angle())
data2 = torch.exp(1j * data2.angle())

xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1)
f_taper_asym = cosine_taper_4freq(data1.shape[1], low=0.001, high=0.49/self.dt, sample_rate=1.0/self.dt)
f_taper_asym_on_device = f_taper_asym.to(data1.device)
data1 = data1 * f_taper_asym_on_device
data2 = data2 * f_taper_asym_on_device

# --- 3 components cross-correlation only ---
# xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1)
# # xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2) * mask_reshaped1 * mask_reshaped2, dim=-1), n=(self.nlag * 2 + 1),dim=-1) # only for masking

# --- all 9 components cross-correlation ---
xcor = torch.fft.irfft(torch.sum(data1.unsqueeze(1) * torch.conj(data2.unsqueeze(0)), dim=-1), dim=-1).reshape(nc1**2, -1)

xcor = xcor / data1.size(1)
xcor = torch.roll(xcor, self.nlag, dims=-1)
xcor = xcor.view(nb1, nc1, nx1, -1)
nc1_update = xcor.shape[0]
xcor = xcor.view(nb1, nc1_update, nx1, -1)

else:
raise ValueError("domain should be frequency or time or stft")
Expand Down
104 changes: 101 additions & 3 deletions cctorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ def _parse_pair_ids(fname1, fname2, pair_index_fallback):

def write_ambient_noise(results, root):
"""
Write ambient noise results to disk.
Append ambient noise results to zarr arrays.

Args:
results: List of result dicts with xcorr data and pair info
root: Zarr group with 'xcorr', 'id1', 'id2' arrays (created on first call)
Expand Down Expand Up @@ -350,8 +350,8 @@ def write_ambient_noise_indexed(results, store_path, start_idx, storage_options=
fallback = meta.get("pair_index", [(str(i), str(i))] * xcorr.shape[0])[i]
id1, id2 = _parse_pair_ids(fnames1[i], fnames2[i], fallback)
batch_xcorr.append(np.squeeze(xcorr[i]))
batch_id1.append(id1)
batch_id2.append(id2)
batch_id1.append(str(id1))
batch_id2.append(str(id2))

if not batch_xcorr:
return
Expand Down Expand Up @@ -489,6 +489,104 @@ def write_h5(fn, dataset_name, data, attrs_dict):
fid[dataset_name].attrs.modify(key, val)


def partial_hann_taper(length, taper_fraction=0.04, device="cpu"):
n_taper = int(length * taper_fraction)
if n_taper == 0:
return torch.ones(length, device=device)

# Hann window for edges
x = torch.linspace(0, torch.pi / 2, n_taper, device=device)
taper_edge = torch.sin(x)**2 # sin² taper

taper_start = taper_edge
taper_end = taper_edge.flip(0)


# Build full window: start + flat + end
ones_middle = torch.ones(length - 2 * n_taper, device=device)
window = torch.cat([taper_start, ones_middle, taper_end], dim=0)
return window

def custom_demeaned_stft(data1, nlag, hop_length, window):
"""
Custom STFT with per-window demeaning that matches torch.stft(..., center=False)

Args:
data1: (B, T) time-domain signal
nlag: for computing n_fft = 2 * nlag + 5
hop_length: step size between windows
window: (n_fft,) window function (e.g., Hann)

Returns:
Complex STFT of shape (B, freq_bins, time_frames), matching torch.stft
"""
n_fft = 2 * nlag + 5
B, T = data1.shape

# Compute number of complete frames (no padding)
num_frames = (T - n_fft) // hop_length + 1

# Use unfold to extract frames
frames = data1.unfold(dimension=-1, size=n_fft, step=hop_length) # (B, num_frames, n_fft)

# Demean each frame
frames = frames - frames.mean(dim=-1, keepdim=True)

# Apply window
window = window.to(data1.device)
frames = frames * window.view(1, 1, -1)

# Apply FFT
stft_result = torch.fft.rfft(frames, dim=-1) # (B, num_frames, freq_bins)

# Transpose to match torch.stft output: (B, freq_bins, time_frames)
stft_result = stft_result.transpose(-1, -2)

return stft_result # shape: (B, freq_bins, time_frames)

def cosine_taper_4freq(n_freqs, low, high, sample_rate=20):
"""
Create a 1D cosine taper with flat region between left_end and right_start,
and cosine transitions on both sides.

Parameters:
- n_freqs: total number of frequency bins
- left_start, left_end, right_start, right_end: index positions in frequency domain

Returns:
- taper: tensor of shape [n_freqs]
"""
delta_f = sample_rate / ((n_freqs - 1)*2 + 1)
low_idx = math.ceil(low / delta_f)
high_idx = math.floor(high / delta_f)
low_left = low_idx - 100
if low_left < 0:
low_left = 0
high_right = high_idx + 100
high_right = min(high_right, n_freqs-1)
# print(f"Doing the classic Brutal Whiten {n_freqs} {low_left} {low_idx} {high_idx} {high_right}")
left_start = low_left
left_end = low_idx
right_start = high_idx
right_end = high_right

taper = np.zeros(n_freqs)

# Left cosine ramp
for i in range(left_start, left_end):
frac = (i - left_start) / (left_end - left_start)
taper[i] = 0.5 * (1 - np.cos(np.pi * frac))

# Flat part
taper[left_end:right_start] = 1.0

# Right cosine ramp
for i in range(right_start, right_end):
frac = (i - right_start) / (right_end - right_start)
taper[i] = 0.5 * (1 + np.cos(np.pi * frac))
cos_taper = torch.tensor(taper, dtype=torch.float32)
return cos_taper[None, :, None]

# # %%
# @dataclass
# class Config:
Expand Down
Binary file added examples/california/.mseeds1_2005_123.txt.swp
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/california/.skyignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ waveforms/
gs:/
__pycache__/
mseeds*.txt
pairs*.txt
pairs*.txt
10 changes: 10 additions & 0 deletions examples/california/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ def parse_args():
parser.add_argument("--root_path", type=str, default="./")
parser.add_argument("--result_path", type=str, default="./results")
parser.add_argument("--knn_dist", type=int, default=300)

# --- add for more flexibility ---
parser.add_argument("--year_start", type=int, default=2024)
parser.add_argument("--year_end", type=int, default=2024)
parser.add_argument("--jday_start", type=int, default=1)
parser.add_argument("--jday_end", type=int, default=1)
parser.add_argument("--local_station_file", type=str, default="")

parser.add_argument("--subcluster", type=str, default="")

args = parser.parse_args()

return args
Loading