From dccc2166f67324ba7a37e9a87126a42e2a90154a Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Fri, 12 Sep 2025 09:45:17 +0200 Subject: [PATCH 1/9] Update cross_validation_across_time to use different scoring functions --- decoding.py | 48 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/decoding.py b/decoding.py index 0cc7b0c..97394cb 100644 --- a/decoding.py +++ b/decoding.py @@ -23,7 +23,7 @@ from sklearn.ensemble._voting import LabelEncoder, _routing_enabled from sklearn.ensemble._voting import process_routing, Bunch from sklearn.ensemble._voting import _fit_single_estimator - +from sklearn import metrics try: from . import misc @@ -154,7 +154,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, n_jobs=-2, plot_confmat=False, title_add="", ex_per_fold=2, simulate=False, subj="", tmin=-0.1, tmax=0.5, sfreq=100, - return_probas=False, + return_probas=False, metric='accuracy', + metric_kwargs={}, proba=True, verbose=True): """ Perform cross-validation across time on the given dataset. @@ -187,6 +188,21 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, Milliseconds per time point. Default is 10. return_preds : bool, optional If True, return predictions along with the DataFrame. Default is False. + metric : str or function, optional + Scoring function used for model evaluation. + Either one of the scoring functions available from scikit-learn + (input needs to be string with name of the function, like + "average_precision_score") or self-defined function. + Default is "accuracy" (% correct predictions across folds) + proba: bool, optional + If True, predict_proba is used. If false, predict is used and class + labels instead of probabilities are used as input to the metric function. + Choose depending on what kind of input the scoring function requires. + metric_kwargs: dict, optional + extra parameters for the scoring function that are not + predictions / probabilities and correct_labels. Possible inputs depend + on the scoring function that is chosen + Returns ------- @@ -247,7 +263,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, test_x=test_x[:, :, start], neg_x=neg_x, clf=clf, - proba=True + proba=proba # ova=ova, ) for start in list(range(0, time_max)) @@ -258,8 +274,30 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, all_probas[idxs_test] = probas preds = np.argmax(probas, -1) - - accuracy = (preds == test_y[:, None]).mean(axis=0) + + if metric == "accuracy": + accuracy = (preds == test_y[:, None]).mean(axis=0) + + elif isinstance(metric, str): + + function_name = metric + + if not hasattr(metrics, function_name): + raise ValueError(f"sklearn.metrics has no function named '{function_name}'") + + func = getattr(metrics, function_name) + sig = inspect.signature(func) + # add any extra parameters that are not preds and data_y + if metric_kwargs: + kwargs = {} + for k, v in metric_kwargs.items(): + if k in sig.parameters: + kwargs[k] = v + + # need to loop over timepoints + accuracy = np.zeros(time_max) + for t in list(range(0, time_max)): + accuracy[t] = func(data_y, all_probas[:,t], **kwargs) # Create a temporary DataFrame for the current fold df_temp = pd.DataFrame( From e3c354bdaaf41958a8fe1fd24afe85de027bfcdb Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Fri, 12 Sep 2025 09:56:41 +0200 Subject: [PATCH 2/9] cross_validation_across_time: allow self-defined metric function --- decoding.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/decoding.py b/decoding.py index 97394cb..3572959 100644 --- a/decoding.py +++ b/decoding.py @@ -23,7 +23,7 @@ from sklearn.ensemble._voting import LabelEncoder, _routing_enabled from sklearn.ensemble._voting import process_routing, Bunch from sklearn.ensemble._voting import _fit_single_estimator -from sklearn import metrics +from sklearn import metrics as sk_metrics try: from . import misc @@ -277,15 +277,17 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, if metric == "accuracy": accuracy = (preds == test_y[:, None]).mean(axis=0) - - elif isinstance(metric, str): - - function_name = metric - - if not hasattr(metrics, function_name): - raise ValueError(f"sklearn.metrics has no function named '{function_name}'") - - func = getattr(metrics, function_name) + else: + # resolve metric function + if isinstance(metric, str): + if not hasattr(sk_metrics, metric): + raise ValueError(f"sklearn.metrics has no function named '{metric}'") + func = getattr(sk_metrics, metric) + elif callable(metric): + func = metric + else: + raise TypeError("metric must be 'accuracy', a sklearn.metrics name (str), or a callable.") + sig = inspect.signature(func) # add any extra parameters that are not preds and data_y if metric_kwargs: From 6c06960bfb4d6eab1979d76d19d591a7c322a49c Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:26:34 +0200 Subject: [PATCH 3/9] Update decoding.py add error message when given metric_kwargs don't match the input arguments required by the scoring function Co-authored-by: Simon Kern <14980558+skjerns@users.noreply.github.com> --- decoding.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/decoding.py b/decoding.py index 3572959..3d85fb4 100644 --- a/decoding.py +++ b/decoding.py @@ -291,10 +291,10 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, sig = inspect.signature(func) # add any extra parameters that are not preds and data_y if metric_kwargs: - kwargs = {} - for k, v in metric_kwargs.items(): - if k in sig.parameters: - kwargs[k] = v + sig = inspect.signature(func) + missing_kwargs = set(metric_kwargs).difference(sig.parameters) + if missing_kwargs: + raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs} ') # need to loop over timepoints accuracy = np.zeros(time_max) From 9d9219d868ec69a931c35219cf26a581f63e9727 Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:39:19 +0200 Subject: [PATCH 4/9] implemented some more suggested changes input to function, default scoring function (top_1_accuracy) --- decoding.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/decoding.py b/decoding.py index 3d85fb4..db88cd0 100644 --- a/decoding.py +++ b/decoding.py @@ -154,8 +154,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, n_jobs=-2, plot_confmat=False, title_add="", ex_per_fold=2, simulate=False, subj="", tmin=-0.1, tmax=0.5, sfreq=100, - return_probas=False, metric='accuracy', - metric_kwargs={}, proba=True, + return_probas=True, metric='accuracy', + metric_kwargs={}, verbose=True): """ Perform cross-validation across time on the given dataset. @@ -263,7 +263,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, test_x=test_x[:, :, start], neg_x=neg_x, clf=clf, - proba=proba + proba=return_probas # ova=ova, ) for start in list(range(0, time_max)) @@ -276,7 +276,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, preds = np.argmax(probas, -1) if metric == "accuracy": - accuracy = (preds == test_y[:, None]).mean(axis=0) + func = sk_metrics.top_k_accuracy_score + metric_kwargs["k"]=1 else: # resolve metric function if isinstance(metric, str): From 90513357b9908e8a9168a13a01da209da8064fcd Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:42:13 +0200 Subject: [PATCH 5/9] output variable naming fixed --- decoding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/decoding.py b/decoding.py index db88cd0..14890b0 100644 --- a/decoding.py +++ b/decoding.py @@ -298,16 +298,16 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs} ') # need to loop over timepoints - accuracy = np.zeros(time_max) + score = np.zeros(time_max) for t in list(range(0, time_max)): - accuracy[t] = func(data_y, all_probas[:,t], **kwargs) + score[t] = func(data_y, all_probas[:,t], **metric_kwargs) # Create a temporary DataFrame for the current fold df_temp = pd.DataFrame( {"timepoint": times, - "fold": [j] * len(accuracy), - "accuracy": accuracy, - "subject": [subj] * len(accuracy), + "fold": [j] * len(score), + "score": score, + "subject": [subj] * len(score), } ) # Concatenate the temporary DataFrame with the main DataFrame From 8579ab7e6f6b309cca24fb5b96cc2e5e3585e8e4 Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:45:53 +0200 Subject: [PATCH 6/9] save metric_kwargs (optinal inputs for scoring function) in the final df --- decoding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/decoding.py b/decoding.py index 14890b0..169b2e4 100644 --- a/decoding.py +++ b/decoding.py @@ -307,7 +307,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, {"timepoint": times, "fold": [j] * len(score), "score": score, - "subject": [subj] * len(score), + "metric_kwargs": str(metric_kwargs), + "subject": [subj] * len(score) } ) # Concatenate the temporary DataFrame with the main DataFrame From 956e879667e1c2ff1574a5021db47477daeb5587 Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:17:17 +0200 Subject: [PATCH 7/9] fix default case ("accuracy) for cross_validation_across_time --- decoding.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/decoding.py b/decoding.py index 169b2e4..fd33792 100644 --- a/decoding.py +++ b/decoding.py @@ -290,24 +290,25 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, raise TypeError("metric must be 'accuracy', a sklearn.metrics name (str), or a callable.") sig = inspect.signature(func) - # add any extra parameters that are not preds and data_y - if metric_kwargs: - sig = inspect.signature(func) - missing_kwargs = set(metric_kwargs).difference(sig.parameters) - if missing_kwargs: - raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs} ') - - # need to loop over timepoints - score = np.zeros(time_max) - for t in list(range(0, time_max)): - score[t] = func(data_y, all_probas[:,t], **metric_kwargs) + # add any extra parameters that are not preds and data_y + if metric_kwargs: + sig = inspect.signature(func) + missing_kwargs = set(metric_kwargs).difference(sig.parameters) + if missing_kwargs: + raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs} ') + + # need to loop over timepoints + score = np.zeros(time_max) + print(func) + for t in list(range(0, time_max)): + score[t] = func(data_y, all_probas[:,t], **metric_kwargs) # Create a temporary DataFrame for the current fold df_temp = pd.DataFrame( {"timepoint": times, "fold": [j] * len(score), "score": score, - "metric_kwargs": str(metric_kwargs), + "metric_used": str(func) * len(score), "subject": [subj] * len(score) } ) From cf6e4c04357f65a29542cc358b029279f66f3d19 Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Wed, 24 Sep 2025 09:53:08 +0200 Subject: [PATCH 8/9] try to automatically determine the input the performance metric function for the classifier needs --- .gitignore | 2 ++ decoding.py | 96 ++++++++++++++++++++++++++++++++--------------------- 2 files changed, 60 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 82f9275..9f5fa48 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +decoding_test-1.py +prepare_dat.py diff --git a/decoding.py b/decoding.py index fd33792..a41430d 100644 --- a/decoding.py +++ b/decoding.py @@ -18,6 +18,7 @@ import pandas as pd import json import mne +import misc from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import clone, is_classifier from sklearn.ensemble._voting import LabelEncoder, _routing_enabled @@ -186,18 +187,12 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, Subject identifier. Default is an empty string. ms_per_point : int, optional Milliseconds per time point. Default is 10. - return_preds : bool, optional - If True, return predictions along with the DataFrame. Default is False. metric : str or function, optional Scoring function used for model evaluation. Either one of the scoring functions available from scikit-learn (input needs to be string with name of the function, like "average_precision_score") or self-defined function. Default is "accuracy" (% correct predictions across folds) - proba: bool, optional - If True, predict_proba is used. If false, predict is used and class - labels instead of probabilities are used as input to the metric function. - Choose depending on what kind of input the scoring function requires. metric_kwargs: dict, optional extra parameters for the scoring function that are not predictions / probabilities and correct_labels. Possible inputs depend @@ -235,7 +230,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, df = pd.DataFrame() # Initialize array to store all predictions - all_probas = np.zeros([len(data_y), time_max, len(labels)]) + all_results = np.zeros([len(data_y), time_max, len(labels)]) times = np.linspace(tmin*1000, tmax*1000, time_max).round() @@ -250,34 +245,13 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, train_x = data_x[idxs_train] train_y = data_y[idxs_train] test_x = data_x[idxs_test] - test_y = data_y[idxs_test] # Add null data if specified neg_x = np.hstack(train_x[:, :, 0:1].T).T if add_null_data else None - - # Train and predict in parallel across time points - probas = Parallel(n_jobs=n_jobs)( - delayed(train_predict)( - train_x=train_x[:, :, start], - train_y=train_y, - test_x=test_x[:, :, start], - neg_x=neg_x, - clf=clf, - proba=return_probas - # ova=ova, - ) - for start in list(range(0, time_max)) - ) - probas = np.swapaxes(probas, 0, 1) - - # Store predictions and calculate accuracy - all_probas[idxs_test] = probas - - preds = np.argmax(probas, -1) if metric == "accuracy": func = sk_metrics.top_k_accuracy_score - metric_kwargs["k"]=1 + needs_probas = False else: # resolve metric function if isinstance(metric, str): @@ -288,27 +262,73 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, func = metric else: raise TypeError("metric must be 'accuracy', a sklearn.metrics name (str), or a callable.") - + + # Determine if metric function expects probabilities or predictions. sig = inspect.signature(func) + inputs_names = list(sig.parameters) + + # try: + # scorer = sk_metrics.get_scorer(metric) + # # Check if scorer has _response_method attribute + # if hasattr(scorer, '_response_method'): + # if scorer._response_method == 'predict_proba': + # needs_probas=True + # elif scorer._response_method == 'predict': + # needs_probas=False + # print(needs_probas, (scorer._response_method)) + + # except: + # print("using get_scorer not possible") + + second_param = inputs_names[1] + prob_indicators = ['y_score', 'probas_pred', 'y_proba'] + pred_indicators = ['y_pred', 'labels_pred'] + if any(indicator in second_param for indicator in prob_indicators): + needs_probas=True + elif any(indicator in second_param for indicator in pred_indicators): + needs_probas=False + else: + print("determining response method not possible") + # add any extra parameters that are not preds and data_y if metric_kwargs: - sig = inspect.signature(func) - missing_kwargs = set(metric_kwargs).difference(sig.parameters) + missing_kwargs = set(list(metric_kwargs)).difference(inputs_names) if missing_kwargs: raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs} ') - - # need to loop over timepoints + + results_preds = Parallel(n_jobs=n_jobs)( + delayed(train_predict)( + train_x=train_x[:, :, start], + train_y=train_y, + test_x=test_x[:, :, start], + neg_x=neg_x, + clf=clf, + proba=True + ) + for start in list(range(0, time_max)) + ) + + results_swp = np.swapaxes(results_preds, 0, 1) + # Store probabilities + all_results[idxs_test] = results_swp + # preds = np.argmax(probas, -1) + score = np.zeros(time_max) - print(func) + # need to loop over timepoints for t in list(range(0, time_max)): - score[t] = func(data_y, all_probas[:,t], **metric_kwargs) + if needs_probas: + score[t] = func(data_y, all_results[:,t], **metric_kwargs) + else: + #preds_t = np.argmax(all_probas[:,t], -1) + score[t] = func(data_y, all_results[:,t], **metric_kwargs) # Create a temporary DataFrame for the current fold df_temp = pd.DataFrame( {"timepoint": times, "fold": [j] * len(score), "score": score, - "metric_used": str(func) * len(score), + "metric_used": str(func), + "metric_kwargs": str(metric_kwargs), "subject": [subj] * len(score) } ) @@ -322,7 +342,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, tqdm_loop.close() # Return results - return (df, all_probas) if return_probas else df + return (df, all_results) From d0bee49aa5f346739e775bb0e146afe8826d47b1 Mon Sep 17 00:00:00 2001 From: sianna338 <157957826+sianna338@users.noreply.github.com> Date: Tue, 30 Sep 2025 10:41:59 +0200 Subject: [PATCH 9/9] implement way to determine which type of input the scoring function requires --- decoding.py | 51 ++++++++++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/decoding.py b/decoding.py index a41430d..f169bed 100644 --- a/decoding.py +++ b/decoding.py @@ -19,6 +19,7 @@ import json import mne import misc +import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import clone, is_classifier from sklearn.ensemble._voting import LabelEncoder, _routing_enabled @@ -209,7 +210,6 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, # Ensure each class has the same number of examples assert (len(set(np.bincount(data_y)).difference(set([0]))) == 1), \ "WARNING not each class has the same number of examples" - # warnings.warn('RETURN THIS') # Set random seed based on subject ID for reproducibility np.random.seed(misc.string_to_seed(subj)) @@ -229,7 +229,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, tqdm_loop = tqdm(total=total, desc=f"CV Fold {subj}", disable=not verbose) df = pd.DataFrame() - # Initialize array to store all predictions + # Initialize array to store the probabilities for each class all_results = np.zeros([len(data_y), time_max, len(labels)]) times = np.linspace(tmin*1000, tmax*1000, time_max).round() @@ -249,11 +249,12 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, # Add null data if specified neg_x = np.hstack(train_x[:, :, 0:1].T).T if add_null_data else None + # determine scoring method if metric == "accuracy": func = sk_metrics.top_k_accuracy_score - needs_probas = False + metric_kwargs={"k":1, "labels": labels} + needs_probas = True else: - # resolve metric function if isinstance(metric, str): if not hasattr(sk_metrics, metric): raise ValueError(f"sklearn.metrics has no function named '{metric}'") @@ -262,24 +263,11 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, func = metric else: raise TypeError("metric must be 'accuracy', a sklearn.metrics name (str), or a callable.") - + # Determine if metric function expects probabilities or predictions. sig = inspect.signature(func) inputs_names = list(sig.parameters) - # try: - # scorer = sk_metrics.get_scorer(metric) - # # Check if scorer has _response_method attribute - # if hasattr(scorer, '_response_method'): - # if scorer._response_method == 'predict_proba': - # needs_probas=True - # elif scorer._response_method == 'predict': - # needs_probas=False - # print(needs_probas, (scorer._response_method)) - - # except: - # print("using get_scorer not possible") - second_param = inputs_names[1] prob_indicators = ['y_score', 'probas_pred', 'y_proba'] pred_indicators = ['y_pred', 'labels_pred'] @@ -290,11 +278,11 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, else: print("determining response method not possible") - # add any extra parameters that are not preds and data_y - if metric_kwargs: - missing_kwargs = set(list(metric_kwargs)).difference(inputs_names) - if missing_kwargs: - raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs} ') + # add any extra parameters that are not preds and data_y + if metric_kwargs: + missing_kwargs = set(list(metric_kwargs)).difference(inputs_names) + if missing_kwargs: + raise ValueError(f'The following metric_kwargs were given but are not part of the function signature {missing_kwargs}') results_preds = Parallel(n_jobs=n_jobs)( delayed(train_predict)( @@ -309,18 +297,23 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False, ) results_swp = np.swapaxes(results_preds, 0, 1) - # Store probabilities + + # store results for each fold together all_results[idxs_test] = results_swp - # preds = np.argmax(probas, -1) - + + # convert probabilities to prediction if necessary + if not needs_probas: + preds_idx = np.argmax(all_results, axis=2) + preds_lbl = labels[preds_idx] + + + # compute scoring metric score = np.zeros(time_max) - # need to loop over timepoints for t in list(range(0, time_max)): if needs_probas: score[t] = func(data_y, all_results[:,t], **metric_kwargs) else: - #preds_t = np.argmax(all_probas[:,t], -1) - score[t] = func(data_y, all_results[:,t], **metric_kwargs) + score[t] = func(data_y, preds_lbl[:,t], **metric_kwargs) # Create a temporary DataFrame for the current fold df_temp = pd.DataFrame(