diff --git a/.gitignore b/.gitignore
index 82f9275..9f5fa48 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,5 @@ cython_debug/
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+decoding_test-1.py
+prepare_dat.py
diff --git a/decoding.py b/decoding.py
index 0cc7b0c..f169bed 100644
--- a/decoding.py
+++ b/decoding.py
@@ -18,12 +18,14 @@ import pandas as pd
 import json
 import mne
+import inspect
+import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.base import clone, is_classifier
 from sklearn.ensemble._voting import LabelEncoder, _routing_enabled
 from sklearn.ensemble._voting import process_routing, Bunch
 from sklearn.ensemble._voting import _fit_single_estimator
-
+from sklearn import metrics as sk_metrics
 
 
 try:
     from . import misc
@@ -154,7 +156,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
                                  n_jobs=-2, plot_confmat=False, title_add="",
                                  ex_per_fold=2, simulate=False, subj="",
                                  tmin=-0.1, tmax=0.5, sfreq=100,
-                                 return_probas=False,
+                                 return_probas=True, metric='accuracy',
+                                 metric_kwargs=None,
                                  verbose=True):
     """
     Perform cross-validation across time on the given dataset.
@@ -185,8 +188,17 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
         Subject identifier. Default is an empty string.
     ms_per_point : int, optional
         Milliseconds per time point. Default is 10.
-    return_preds : bool, optional
-        If True, return predictions along with the DataFrame. Default is False.
+    metric : str or callable, optional
+        Scoring function used for model evaluation. Either the name of a
+        scoring function available in sklearn.metrics (passed as a string,
+        e.g. "average_precision_score") or a user-defined callable.
+        Default is "accuracy" (fraction of correct predictions per fold
+        and time point).
+    metric_kwargs : dict, optional
+        Extra keyword arguments for the scoring function beyond the
+        predictions/probabilities and the true labels. Valid keys depend
+        on the chosen scoring function. Default is None.
+
 
     Returns
     -------
@@ -198,7 +210,6 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
     # Ensure each class has the same number of examples
     assert (len(set(np.bincount(data_y)).difference(set([0]))) == 1), \
         "WARNING not each class has the same number of examples"
-    # warnings.warn('RETURN THIS')
 
     # Set random seed based on subject ID for reproducibility
     np.random.seed(misc.string_to_seed(subj))
@@ -218,8 +229,8 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
     tqdm_loop = tqdm(total=total, desc=f"CV Fold {subj}", disable=not verbose)
     df = pd.DataFrame()
 
-    # Initialize array to store all predictions
-    all_probas = np.zeros([len(data_y), time_max, len(labels)])
+    # Initialize array to store the probabilities for each class
+    all_results = np.zeros([len(data_y), time_max, len(labels)])
 
     times = np.linspace(tmin*1000, tmax*1000, time_max).round()
 
@@ -234,39 +245,90 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
         train_x = data_x[idxs_train]
         train_y = data_y[idxs_train]
         test_x = data_x[idxs_test]
         test_y = data_y[idxs_test]
 
         # Add null data if specified
         neg_x = np.hstack(train_x[:, :, 0:1].T).T if add_null_data else None
-
-        # Train and predict in parallel across time points
-        probas = Parallel(n_jobs=n_jobs)(
-            delayed(train_predict)(
-                train_x=train_x[:, :, start],
-                train_y=train_y,
-                test_x=test_x[:, :, start],
-                neg_x=neg_x,
-                clf=clf,
-                proba=True
-                # ova=ova,
-            )
-            for start in list(range(0, time_max))
-        )
-        probas = np.swapaxes(probas, 0, 1)
-
-        # Store predictions and calculate accuracy
-        all_probas[idxs_test] = probas
-
-        preds = np.argmax(probas, -1)
-
-        accuracy = (preds == test_y[:, None]).mean(axis=0)
+
+        # Determine the scoring method
+        if metric_kwargs is None:
+            metric_kwargs = {}
+        if metric == "accuracy":
+            func = sk_metrics.top_k_accuracy_score
+            metric_kwargs = {"k": 1, "labels": labels}
+            needs_probas = True
+        else:
+            if isinstance(metric, str):
+                if not hasattr(sk_metrics, metric):
+                    raise ValueError(f"sklearn.metrics has no function named '{metric}'")
+                func = getattr(sk_metrics, metric)
+            elif callable(metric):
+                func = metric
+            else:
+                raise TypeError("metric must be 'accuracy', a sklearn.metrics "
+                                "name (str), or a callable.")
+
+            # Determine whether the metric expects probabilities or label
+            # predictions from the name of its second parameter.
+            sig = inspect.signature(func)
+            input_names = list(sig.parameters)
+            second_param = input_names[1]
+            prob_indicators = ['y_score', 'probas_pred', 'y_proba']
+            pred_indicators = ['y_pred', 'labels_pred']
+            if any(ind in second_param for ind in prob_indicators):
+                needs_probas = True
+            elif any(ind in second_param for ind in pred_indicators):
+                needs_probas = False
+            else:
+                raise ValueError(
+                    f"Cannot determine from parameter '{second_param}' whether "
+                    f"'{metric}' expects probabilities or predictions.")
+
+            # Reject extra kwargs that are not in the metric's signature
+            if metric_kwargs:
+                missing_kwargs = set(metric_kwargs).difference(input_names)
+                if missing_kwargs:
+                    raise ValueError("The following metric_kwargs are not part "
+                                     f"of the function signature: {missing_kwargs}")
+
+        # Train and predict class probabilities in parallel across time points
+        results_preds = Parallel(n_jobs=n_jobs)(
+            delayed(train_predict)(
+                train_x=train_x[:, :, start],
+                train_y=train_y,
+                test_x=test_x[:, :, start],
+                neg_x=neg_x,
+                clf=clf,
+                proba=True
+            )
+            for start in range(time_max)
+        )
+
+        results_swp = np.swapaxes(results_preds, 0, 1)
+
+        # Store the class probabilities of this fold's test examples
+        all_results[idxs_test] = results_swp
+
+        # Convert probabilities to label predictions if necessary
+        if not needs_probas:
+            preds_lbl = np.asarray(labels)[np.argmax(results_swp, axis=2)]
+
+        # Compute the scoring metric per time point on this fold's test set
+        score = np.zeros(time_max)
+        for t in range(time_max):
+            if needs_probas:
+                score[t] = func(test_y, results_swp[:, t], **metric_kwargs)
+            else:
+                score[t] = func(test_y, preds_lbl[:, t], **metric_kwargs)
 
         # Create a temporary DataFrame for the current fold
         df_temp = pd.DataFrame(
             {"timepoint": times,
-             "fold": [j] * len(accuracy),
-             "accuracy": accuracy,
-             "subject": [subj] * len(accuracy),
+             "fold": [j] * len(score),
+             "score": score,
+             "metric_used": getattr(func, "__name__", str(func)),
+             "metric_kwargs": str(metric_kwargs),
+             "subject": [subj] * len(score)
              }
         )
         # Concatenate the temporary DataFrame with the main DataFrame
@@ -279,7 +341,7 @@ def cross_validation_across_time(data_x, data_y, clf, add_null_data=False,
     tqdm_loop.close()
 
     # Return results
-    return (df, all_probas) if return_probas else df
+    return (df, all_results) if return_probas else df
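
The patch dispatches between probability- and label-based metrics by inspecting
the name of the metric's second positional parameter. A minimal standalone
sketch of that heuristic (the helper `expects_probabilities` is illustrative,
not part of the patch):

    import inspect
    from sklearn import metrics as sk_metrics

    def expects_probabilities(metric_name):
        # Resolve the metric by name, as the patched
        # cross_validation_across_time does for string inputs
        func = getattr(sk_metrics, metric_name)
        second_param = list(inspect.signature(func).parameters)[1]
        if any(s in second_param for s in ('y_score', 'probas_pred', 'y_proba')):
            return True   # metric consumes class probabilities/scores
        if any(s in second_param for s in ('y_pred', 'labels_pred')):
            return False  # metric consumes hard label predictions
        raise ValueError(f"cannot tell from parameter '{second_param}'")

    print(expects_probabilities("roc_auc_score"))  # True:  second param is 'y_score'
    print(expects_probabilities("f1_score"))       # False: second param is 'y_pred'

The heuristic works because sklearn.metrics generally names the second argument
y_score / y_proba / probas_pred for score-based metrics and y_pred for
label-based ones; user-defined callables should follow the same convention.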
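
A usage sketch of the new metric / metric_kwargs interface. Assumptions: the
module is importable as `decoding`, data_x is shaped (examples, channels,
time points), and classes are balanced (the function asserts this); the
synthetic data and LogisticRegression classifier are placeholders only:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from decoding import cross_validation_across_time

    # 40 class-balanced examples, 30 channels, 60 time points (0.6 s at 100 Hz)
    rng = np.random.default_rng(0)
    data_x = rng.normal(size=(40, 30, 60))
    data_y = np.repeat([0, 1, 2, 3], 10)

    clf = LogisticRegression(max_iter=1000)

    # Default metric: per-fold, per-time-point top-1 accuracy; with
    # return_probas=True (the new default) a (DataFrame, probabilities) tuple
    # is returned
    df, probas = cross_validation_across_time(data_x, data_y, clf, subj="S01")

    # Any sklearn.metrics function name works; metric_kwargs are validated
    # against the metric's signature before scoring
    df_f1, _ = cross_validation_across_time(
        data_x, data_y, clf, subj="S01",
        metric="f1_score", metric_kwargs={"average": "macro"})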