Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
644 changes: 644 additions & 0 deletions cov.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Binary file added gedai/data/fsavLEADFIELD_4_GEDAI-cov.fif
Binary file not shown.
Binary file removed gedai/data/fsavLEADFIELD_4_GEDAI.mat
Binary file not shown.
140 changes: 43 additions & 97 deletions gedai/gedai/covariances.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,55 @@
import h5py
import numpy as np
import os
import mne
import sklearn.metrics

from ..utils._checks import check_type


def _compute_distance_cov(raw):
    """Build a channel-by-channel matrix equal to ``1 - distance``.

    Parameters
    ----------
    raw : object
        Object exposing ``raw.info["chs"]`` and ``raw.info["nchan"]``. The
        first three entries of every channel's ``loc`` are used as its
        coordinates.

    Returns
    -------
    cov : array
        ``1`` minus the pairwise Euclidean distance between channel
        coordinates, as computed by ``sklearn.metrics.pairwise_distances``.
    """
    info = raw.info
    n_channels = info["nchan"]
    # Keep only the leading three values of ``loc`` for every channel.
    xyz = [info["chs"][idx]["loc"][:3] for idx in range(n_channels)]
    distances = sklearn.metrics.pairwise_distances(xyz, metric="euclidean")
    return 1 - distances


def _compute_refcov(inst, mat):
    """Extract the reference covariance for the channels of ``inst``.

    Channel names of ``inst`` are matched against the electrode names stored
    in the leadfield file in two passes: first by case-insensitive equality,
    then, for channels still unmatched, by substring containment in either
    direction. In the second pass the candidate with the longest overlap is
    kept and the first candidate wins ties.

    Parameters
    ----------
    inst : object
        Object exposing ``inst.info["ch_names"]``.
    mat : object
        Passed to ``h5py.File(mat, "r")``. The file must contain the group
        ``leadfield4GEDAI`` with ``electrodes/Name`` and
        ``gram_matrix_avref``.

    Returns
    -------
    refCOV : ndarray
        Sub-matrix of the transposed ``gram_matrix_avref`` restricted to the
        matched electrodes (same selection on rows and columns).
    ch_names : list of str
        Lower-cased leadfield names of the matched electrodes, in the order
        used for ``refCOV``: all exact matches first, then the substring
        matches. This is not necessarily the channel order of ``inst``.

    Raises
    ------
    ValueError
        If not a single channel could be matched.

    Warns
    -----
    UserWarning
        If at least one channel of ``inst`` remains unmatched.
    """
    data_names = inst.info["ch_names"]

    with h5py.File(mat, "r") as f:
        template = f["leadfield4GEDAI"]
        electrodes = template["electrodes"]
        # Each name entry is dereferenced through the file and decoded from
        # UTF-16-LE bytes before being lower-cased.
        template_names = [
            f[ref[0]][()].tobytes().decode("utf-16le").lower()
            for ref in electrodes["Name"]
        ]
        gram = np.array(template["gram_matrix_avref"]).T

    template_lower = [name.lower() for name in template_names]
    # First occurrence wins, which mirrors a front-to-back scan that stops at
    # the first equal name.
    first_exact = {}
    for idx, lowered in enumerate(template_lower):
        first_exact.setdefault(lowered, idx)

    picked = []  # indices into the leadfield matrix
    picked_names = []
    done = set()  # positions in ``data_names`` that already have a match

    # Pass 1: case-insensitive equality.
    for pos, name in enumerate(data_names):
        idx = first_exact.get(name.lower())
        if idx is None:
            continue
        picked.append(idx)
        picked_names.append(template_names[idx])
        done.add(pos)
    n_exact = len(picked)

    # Pass 2: substring containment for whatever pass 1 left over.
    for pos, name in enumerate(data_names):
        if pos in done:
            continue
        wanted = name.lower()
        top_idx, top_len = None, 0
        for idx, candidate in enumerate(template_lower):
            if candidate not in wanted and wanted not in candidate:
                continue
            # A longer shared name is less likely to be a false positive.
            overlap = min(len(candidate), len(wanted))
            if overlap > top_len:
                top_idx, top_len = idx, overlap
        if top_idx is None:
            continue
        picked.append(top_idx)
        picked_names.append(template_names[top_idx])
        done.add(pos)
    n_substring = len(picked) - n_exact

    n_total = len(data_names)
    n_found = len(picked)

    if not picked:
        raise ValueError(
            f"No electrode matches found between data and leadfield "
            f"template.\n"
            f"Your channels: {data_names[:10]}\n"
            f"Leadfield channels: {template_names[:10]}\n"
            f"Please check that your electrode names follow standard "
            f"conventions (e.g., Fp1, Fp2, F3, F4)."
        )

    # Partial matches are accepted but always reported.
    if n_found < n_total:
        import warnings

        leftover = [
            name for pos, name in enumerate(data_names) if pos not in done
        ]
        warnings.warn(
            f"Electrode matching: {n_found}/{n_total} channels "
            f"matched ({n_exact} exact, {n_substring} substring). "
            f"Unmatched channels ({len(leftover)}): "
            f"{leftover}",
            UserWarning,
            stacklevel=2,
        )

    selection = np.ix_(picked, picked)
    return (gram[selection], picked_names)
def _ensure_cov(reference_cov):
    """Return a covariance object for ``reference_cov``.

    Parameters
    ----------
    reference_cov : str | mne.Covariance
        Either the string ``"leadfield"``, in which case the covariance file
        ``fsavLEADFIELD_4_GEDAI-cov.fif`` from the sibling ``data`` directory
        is read with ``mne.read_cov``, or a covariance instance, which is
        returned as is.

    Returns
    -------
    reference_cov : mne.Covariance
        The covariance that was read or the instance that was passed in.

    Raises
    ------
    ValueError
        If ``reference_cov`` is a string other than ``"leadfield"``.

    Notes
    -----
    The argument type is validated by ``check_type`` first; presumably that
    helper raises for anything that is neither ``str`` nor
    ``mne.Covariance`` (see its definition to confirm).
    """
    check_type(reference_cov, (str, mne.Covariance), "reference_cov")
    if not isinstance(reference_cov, str):
        return reference_cov
    if reference_cov != "leadfield":
        # The two literals are concatenated, so the separator has to be
        # written explicitly to keep the message readable.
        raise ValueError(
            "Reference covariance must be 'leadfield', "
            f"got '{reference_cov}' instead."
        )
    fname = os.path.join(
        os.path.dirname(__file__),
        os.pardir,
        "data",
        "fsavLEADFIELD_4_GEDAI-cov.fif",
    )
    return mne.read_cov(fname)


def _pick_cov(cov, ch_names):
    """Restrict ``cov`` to ``ch_names`` and relabel it with those names.

    Channel names are compared case-insensitively. Channels are selected in
    the order of ``ch_names`` so that the names written back at the end
    describe the same rows and columns they are attached to.

    Parameters
    ----------
    cov : mne.Covariance
        Covariance to pick from. It is not modified; a copy is returned.
    ch_names : list of str
        Channel names of the instance. Every one of them must be present in
        ``cov`` (ignoring case).

    Returns
    -------
    cov : mne.Covariance
        Copy of ``cov`` restricted to the requested channels, ordered like
        ``ch_names`` and carrying exactly these names.

    Raises
    ------
    ValueError
        If no channel matches, or if only a subset of ``ch_names`` is
        available in ``cov``.
    """
    cov_ch_names = cov.ch_names

    # Case-insensitive lookup. If the covariance holds names that differ only
    # by case, the first one is kept.
    available = {}
    for cov_name in cov_ch_names:
        available.setdefault(cov_name.lower(), cov_name)

    picks_cov = []
    picks_ch_names = []
    for ch_name in ch_names:
        # ``pop`` guarantees a covariance channel is handed out only once, so
        # the picked channels and the instance names stay one-to-one.
        cov_name = available.pop(ch_name.lower(), None)
        if cov_name is not None:
            picks_cov.append(cov_name)
            picks_ch_names.append(ch_name)

    if not picks_cov:
        raise ValueError("No matching channel names found between inst and cov.\n"
                         f"Available channels in covariance are {cov_ch_names}.\n"
                         f"but instance has channels {ch_names}.")
    if len(picks_cov) < len(ch_names):
        raise ValueError("Only a subset of channels in the instance are present"
                         " in the covariance.\n"
                         f"Use inst.pick_channels({picks_ch_names}) to select only the channels"
                         f" that are in the covariance or provide a covariance that contains"
                         f" all channels in the instance.")

    # ``ordered=True`` makes the picked covariance follow ``picks_cov``, i.e.
    # the instance order, instead of the order stored in the covariance.
    cov = cov.copy().pick_channels(picks_cov, ordered=True)
    # Replace the covariance spelling of the names by the instance spelling.
    # A fresh list avoids sharing the caller's list with the covariance.
    cov.update(names=list(ch_names))
    return cov
Loading
Loading