2 changes: 2 additions & 0 deletions .gitignore
@@ -160,3 +160,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
decoding_test-1.py
prepare_dat.py
120 changes: 88 additions & 32 deletions decoding.py
@@ -18,12 +18,14 @@
import pandas as pd
import json
import mne
import misc
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import clone, is_classifier
from sklearn.ensemble._voting import LabelEncoder, _routing_enabled
from sklearn.ensemble._voting import process_routing, Bunch
from sklearn.ensemble._voting import _fit_single_estimator

from sklearn import metrics as sk_metrics
import inspect  # used below to inspect the chosen metric's call signature

try:
from . import misc
@@ -154,7 +156,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
n_jobs=-2, plot_confmat=False, title_add="",
ex_per_fold=2, simulate=False, subj="",
tmin=-0.1, tmax=0.5, sfreq=100,
return_probas=False,
return_probas=True, metric='accuracy',
metric_kwargs={},
verbose=True):
"""
Perform cross-validation across time on the given dataset.
@@ -185,8 +188,17 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
Subject identifier. Default is an empty string.
ms_per_point : int, optional
Milliseconds per time point. Default is 10.
return_preds : bool, optional
If True, return predictions along with the DataFrame. Default is False.
metric : str or callable, optional
Scoring function used for model evaluation.
Either the name of a scikit-learn scoring function passed as a
string (e.g. "average_precision_score") or a user-defined callable.
Default is "accuracy" (fraction of correct predictions across folds).
metric_kwargs : dict, optional
Extra keyword arguments for the scoring function beyond the
predictions / probabilities and the true labels. Valid keys depend
on the scoring function that is chosen.


Returns
-------
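For illustration, the new metric API can be exercised as follows; a minimal usage sketch, assuming data_x, data_y and subj exist as described in the docstring (the classifier and the roc_auc_score kwargs are illustrative choices, not part of this PR):

from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(max_iter=1000)
# "roc_auc_score" is resolved via getattr(sklearn.metrics, ...);
# metric_kwargs are forwarded after being checked against its signature
df, all_results = cross_validation_across_time(
    data_x, data_y, clf,
    metric="roc_auc_score",
    metric_kwargs={"multi_class": "ovr"},  # valid kwarg of roc_auc_score
    subj="sub-01",
)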
@@ -198,7 +210,6 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
# Ensure each class has the same number of examples
assert (len(set(np.bincount(data_y)).difference(set([0]))) == 1), \
    "WARNING: not every class has the same number of examples"
# warnings.warn('RETURN THIS')
# Set random seed based on subject ID for reproducibility
np.random.seed(misc.string_to_seed(subj))

@@ -218,8 +229,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
tqdm_loop = tqdm(total=total, desc=f"CV Fold {subj}", disable=not verbose)
df = pd.DataFrame()

# Initialize array to store all predictions
all_probas = np.zeros([len(data_y), time_max, len(labels)])
# Initialize array to store the probabilities for each class
all_results = np.zeros([len(data_y), time_max, len(labels)])

times = np.linspace(tmin*1000, tmax*1000, time_max).round()

@@ -234,39 +245,84 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
train_x = data_x[idxs_train]
train_y = data_y[idxs_train]
test_x = data_x[idxs_test]
test_y = data_y[idxs_test]

# Add null data if specified
neg_x = np.hstack(train_x[:, :, 0:1].T).T if add_null_data else None

# Train and predict in parallel across time points
probas = Parallel(n_jobs=n_jobs)(
    delayed(train_predict)(
        train_x=train_x[:, :, start],
        train_y=train_y,
        test_x=test_x[:, :, start],
        neg_x=neg_x,
        clf=clf,
        proba=True
        # ova=ova,

# determine scoring method
if metric == "accuracy":
    func = sk_metrics.top_k_accuracy_score
    metric_kwargs = {"k": 1, "labels": labels}
    needs_probas = True
else:
    if isinstance(metric, str):
        if not hasattr(sk_metrics, metric):
            raise ValueError(f"sklearn.metrics has no function named '{metric}'")
        func = getattr(sk_metrics, metric)
    elif callable(metric):
        func = metric
    else:
        raise TypeError("metric must be 'accuracy', a sklearn.metrics name (str), or a callable.")

# Determine if the metric function expects probabilities or predictions
# from the name of its second positional parameter.
sig = inspect.signature(func)
inputs_names = list(sig.parameters)

second_param = inputs_names[1]
prob_indicators = ['y_score', 'probas_pred', 'y_proba']
pred_indicators = ['y_pred', 'labels_pred']
if any(indicator in second_param for indicator in prob_indicators):
    needs_probas = True
elif any(indicator in second_param for indicator in pred_indicators):
    needs_probas = False
else:
    raise ValueError("Cannot determine from the metric's signature whether it "
                     "expects probabilities or predictions.")

# validate any extra parameters that are not predictions and data_y
if metric_kwargs:
    missing_kwargs = set(metric_kwargs).difference(inputs_names)
    if missing_kwargs:
        raise ValueError("The following metric_kwargs were given but are not part "
                         f"of the function signature: {missing_kwargs}")
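The detection above keys off the name of the metric's second positional parameter. A quick sanity check of that rule against a few standard scikit-learn metrics (a sketch; expected output shown as comments):

import inspect
from sklearn import metrics as sk_metrics

for name in ["roc_auc_score", "accuracy_score", "top_k_accuracy_score"]:
    func = getattr(sk_metrics, name)
    second_param = list(inspect.signature(func).parameters)[1]
    print(f"{name}: second parameter is '{second_param}'")
# roc_auc_score: second parameter is 'y_score'        -> probabilities
# accuracy_score: second parameter is 'y_pred'        -> hard predictions
# top_k_accuracy_score: second parameter is 'y_score' -> probabilities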

results_preds = Parallel(n_jobs=n_jobs)(
    delayed(train_predict)(
        train_x=train_x[:, :, start],
        train_y=train_y,
        test_x=test_x[:, :, start],
        neg_x=neg_x,
        clf=clf,
        proba=True
    )
    for start in list(range(0, time_max))
)
probas = np.swapaxes(probas, 0, 1)

# Store predictions and calculate accuracy
all_probas[idxs_test] = probas

preds = np.argmax(probas, -1)

accuracy = (preds == test_y[:, None]).mean(axis=0)

results_swp = np.swapaxes(results_preds, 0, 1)

# store results for each fold together
all_results[idxs_test] = results_swp

# convert probabilities to predictions if necessary
if not needs_probas:
    preds_idx = np.argmax(all_results, axis=2)
    preds_lbl = labels[preds_idx]

# compute scoring metric
score = np.zeros(time_max)
for t in list(range(0, time_max)):
    if needs_probas:
        score[t] = func(data_y, all_results[:, t], **metric_kwargs)
    else:
        score[t] = func(data_y, preds_lbl[:, t], **metric_kwargs)
Comment on lines +305 to +316


critical

This block of code contains several critical bugs that will lead to incorrect results or runtime errors:

  1. Unbound variable: preds_lbl is defined within an if not needs_probas: block but is used later in a separate else block. If needs_probas is True, this will raise an UnboundLocalError.
  2. Incorrect Data Scope for Predictions: Predictions are generated using np.argmax(all_results, ...). The all_results array contains data from all folds processed so far, not just the current test fold. Predictions should be based on results_swp, which holds the probabilities for the current fold.
  3. Incorrect Scoring Data: The scoring function func is called with data_y (all labels) and all_results. Scoring in cross-validation must be performed on the hold-out test set for the current fold, using data_y[idxs_test] and the corresponding predictions/probabilities from results_swp.

The suggested change restructures the logic to fix these issues by correctly scoping data to the current fold.

Suggested change
-if not needs_probas:
-    preds_idx = np.argmax(all_results, axis=2)
-    preds_lbl = labels[preds_idx]
-# compute scoring metric
-score = np.zeros(time_max)
-for t in list(range(0, time_max)):
-    if needs_probas:
-        score[t] = func(data_y, all_results[:,t], **metric_kwargs)
-    else:
-        score[t] = func(data_y, preds_lbl[:,t], **metric_kwargs)
+test_y = data_y[idxs_test]
+score = np.zeros(time_max)
+if needs_probas:
+    # Score using probabilities for each time point
+    for t in range(time_max):
+        score[t] = func(test_y, results_swp[:, t], **metric_kwargs)
+else:
+    # Convert probabilities to label predictions once
+    preds_idx = np.argmax(results_swp, axis=2)
+    preds_lbl = labels[preds_idx]
+    # Score using predictions for each time point
+    for t in range(time_max):
+        score[t] = func(test_y, preds_lbl[:, t], **metric_kwargs)
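A toy demonstration of points 2 and 3 (shapes are hypothetical): rows of all_results whose fold has not been evaluated yet still hold the zeros the array was initialized with, so predictions and scores computed from the full array mix real and placeholder values:

import numpy as np

all_results = np.zeros([6, 1, 3])   # 6 trials, 1 time point, 3 classes
fold_idx = np.array([0, 1])         # only the first fold has been filled in
all_results[fold_idx, 0] = [[0.1, 0.8, 0.1],
                            [0.7, 0.2, 0.1]]

preds = np.argmax(all_results[:, 0], axis=-1)
print(preds)  # [1 0 0 0 0 0] -- only the first two entries are real predictions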


# Create a temporary DataFrame for the current fold
df_temp = pd.DataFrame(
{"timepoint": times,
"fold": [j] * len(accuracy),
"accuracy": accuracy,
"subject": [subj] * len(accuracy),
"fold": [j] * len(score),
"score": score,
"metric_used": str(func),
"metric_kwargs": str(metric_kwargs),
"subject": [subj] * len(score)
}
)
# Concatenate the temporary DataFrame with the main DataFrame
@@ -279,7 +335,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
tqdm_loop.close()

# Return results
return (df, all_probas) if return_probas else df
return (df, all_results)
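Since the function now returns the (df, all_results) tuple unconditionally, regardless of return_probas, callers must unpack both values. A short sketch of aggregating the per-fold scores from the returned DataFrame (column names taken from df_temp above; data_x, data_y and clf assumed as in the earlier sketch):

df, all_results = cross_validation_across_time(data_x, data_y, clf, subj="sub-01")

# average score across folds at each time point
mean_score = df.groupby("timepoint")["score"].mean()
print(mean_score.head())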


