From a78542352f1dbf9d4f4e6cc5316f09e79d527b75 Mon Sep 17 00:00:00 2001 From: mattrosenblatt7 <52167418+mattrosenblatt7@users.noreply.github.com> Date: Sun, 7 Feb 2021 21:57:06 -0500 Subject: [PATCH] update cpm_cv.m adjusted cross-validation indices to avoid completely leaving some subject out --- matlab/func/cpm_cv.m | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/matlab/func/cpm_cv.m b/matlab/func/cpm_cv.m index 3bb91d8..bd8d7fe 100644 --- a/matlab/func/cpm_cv.m +++ b/matlab/func/cpm_cv.m @@ -7,10 +7,16 @@ % y_test y data used for testing % y_predict Predictions of y data used for testing -% Split data +% cross-validation indices nsubs=size(x,2); -randinds=randperm(nsubs); -ksample=floor(nsubs/kfolds); +all_ind=[]; +fold_size=floor(nsubs/kfolds); +for idx=1:kfolds + all_ind=[all_ind; idx*ones(fold_size, 1)]; +end +leftover=mod(nsubs, kfolds); +indices=[all_ind; randperm(kfolds, leftover)']; +indices=indices(randperm(length(indices))); y_predict = zeros(nsubs, 1); % Run CPM over all folds @@ -18,16 +24,9 @@ for leftout = 1:kfolds fprintf('%1.0f ',leftout); - if kfolds == nsubs % doing leave-one-out - testinds=randinds(leftout); - traininds=setdiff(randinds,testinds); - else - si=1+((leftout-1)*ksample); - fi=si+ksample-1; - - testinds=randinds(si:fi); - traininds=setdiff(randinds,testinds); - end + testinds=(indices==leftout); + traininds=(indices~=leftout); + % Assign x and y data to train and test groups x_train = x(:,traininds);