diff --git a/matlab/func/cpm_cv.m b/matlab/func/cpm_cv.m index 3bb91d8..bd8d7fe 100644 --- a/matlab/func/cpm_cv.m +++ b/matlab/func/cpm_cv.m @@ -7,10 +7,16 @@ % y_test y data used for testing % y_predict Predictions of y data used for testing -% Split data +% cross-validation indices nsubs=size(x,2); -randinds=randperm(nsubs); -ksample=floor(nsubs/kfolds); +all_ind=[]; +fold_size=floor(nsubs/kfolds); +for idx=1:kfolds + all_ind=[all_ind; idx*ones(fold_size, 1)]; +end +leftover=mod(nsubs, kfolds); +indices=[all_ind; randperm(kfolds, leftover)']; +indices=indices(randperm(length(indices))); y_predict = zeros(nsubs, 1); % Run CPM over all folds @@ -18,16 +24,9 @@ for leftout = 1:kfolds fprintf('%1.0f ',leftout); - if kfolds == nsubs % doing leave-one-out - testinds=randinds(leftout); - traininds=setdiff(randinds,testinds); - else - si=1+((leftout-1)*ksample); - fi=si+ksample-1; - - testinds=randinds(si:fi); - traininds=setdiff(randinds,testinds); - end + testinds=(indices==leftout); + traininds=(indices~=leftout); + % Assign x and y data to train and test groups x_train = x(:,traininds);