-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdriverGradientFlow.m
More file actions
101 lines (84 loc) · 2.8 KB
/
driverGradientFlow.m
File metadata and controls
101 lines (84 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
%
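% Example of applying hybrid regularization tools to a random feature
% model on the MNIST or CIFAR-10 dataset, using the gradient-flow filter
% factors of Ma et al. over a grid of regularization parameters alpha.
%
% In this file, each row of his holds [optimal alpha, corresponding test
% error] for one number of random features m (one row per entry of ms).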
% clear all;
close all;
% Reset the global random stream and then seed it, so that the
% train/validation split and the random features below are reproducible.
rng("default")
rng(1)

doPlot = true;   % plot the test-error curve for every feature count m
nTrain = 2^10;   % number of training examples
nVal   = 10000;  % number of held-out (test) examples

% The caller may preset 'dataset' ('MNIST' or 'CIFAR10'); otherwise use CIFAR10.
if ~exist('dataset','var')
    dataset = 'CIFAR10';
end
sample = 'Sd';   % random-feature sampling scheme used by the sweep below

% Load nTrain+nVal examples via the project's setup helpers.
switch dataset
    case 'MNIST'
        [Y,C] = setupMNIST(nTrain+nVal);
    case 'CIFAR10'
        [Y,C] = setupCIFAR10(nTrain+nVal);
    otherwise
        warning('no such data!')
        return
end

% Examples are indexed as Y(:,:,:,i) below; each one is flattened into a
% column of length dim1*dim2*dim3.
[dim1,dim2,dim3,~] = size(Y);
Y = normalizeData(Y,dim1*dim2*dim3);

% Random permutation of the example indices (columns of C), split into
% the first nTrain for training and the remainder for testing.
id  = randperm(size(C,2));
idt = id(1:nTrain);
idv = id(nTrain+1:end);

Yt = reshape(Y(:,:,:,idt),dim1*dim2*dim3,[]);
Ct = C(:,idt);
Yv = reshape(Y(:,:,:,idv),dim1*dim2*dim3,[]);
Cv = C(:,idv);
% Sweep the number of random features m. For each m: build ReLU random
% features plus a constant row, evaluate the gradient-flow filtered
% least-squares weights at every alpha in tt, record the train/test error
% curves, and refine the best grid alpha with fminsearch.
ms = 2.^(4:15);                          % feature counts (m-1 random rows + 1 constant row)
his = zeros(numel(ms),2);                % row k: [opt_alpha, opt_error] for ms(k)
tt = logspace(-8,8,100);                 % alpha grid from 1e-8 to 1e8
ftest_all = zeros(numel(ms),numel(tt));  % test-error curve for each m
ftrain_all = zeros(numel(ms),numel(tt)); % train-error curve for each m
for k=1:numel(ms)
    m = ms(k);
    fprintf('%s : \t dataset=%s, \t m=%d\n',mfilename,dataset,m);
    % Random feature map. In the uniform branch K is (m-1) x (dim1*dim2*dim3)
    % and b is (m-1) x 1; sampleSd is assumed to return matching shapes.
    switch sample
        case 'Sd'
            K = sampleSd(dim1*dim2*dim3,m-1);
            b = sampleSd(m-1,1)';
        otherwise
            sample = 'uniform';
            K = 2*(rand(m-1,dim1*dim2*dim3)-0.5);
            b = 2*(rand(m-1,1)-0.5);
    end
    % ReLU features with an appended row of ones.
    Zt = [max(K*Yt+b,0); ones(1,size(Yt,2))];
    Zv = [max(K*Yv+b,0); ones(1,size(Yv,2))];
    [U,S,V] = svd(Zt, 'econ');
    diagS = diag(S);
    % Filter factors from Ma et al.: phi_i(alpha) = (1-exp(-s_i^2*alpha/numel(Zt)))/s_i.
    % -expm1(-x) equals 1-exp(-x) without cancellation for small x, and an
    % exactly-zero singular value gets its s->0 limit 0 instead of 0/0 = NaN.
    % NOTE(review): the exponent is scaled by numel(Zt) = m*nTrain while the
    % errors below divide by size(Zt,2); confirm this time scaling is intended.
    nZ = numel(Zt);
    invS = zeros(size(diagS));
    isPos = diagS > 0;
    invS(isPos) = 1./diagS(isPos);
    phiS = @(alpha) -expm1(-diagS.^2*alpha/nZ).*invS;
    % Ct*V does not depend on alpha: compute it once instead of per evaluation.
    CtV = Ct*V;
    WOpt = @(alpha) CtV*(phiS(alpha).*U');
    train_error = @(alpha) norm(WOpt(alpha)*Zt-Ct,'fro')^2/(2*size(Zt,2));
    test_error = @(alpha) norm(WOpt(alpha)*Zv-Cv,'fro')^2/(2*size(Zv,2));
    ftest = zeros(size(tt));
    ftrain = zeros(size(tt));
    for j=1:numel(tt)
        ftest(j) = test_error(tt(j));
        ftrain(j) = train_error(tt(j));
    end
    ftest_all(k,:) = ftest;
    ftrain_all(k,:) = ftrain;
    % Start the local search from the best alpha on the grid.
    [~,j0] = min(ftest);
    [opt_alpha,opt_error,flag] = fminsearch(test_error,tt(j0));
    if doPlot
        fig = figure(); clf;
        fig.Name = sprintf('GF_%s,m-%d',dataset,m);
        loglog(tt,ftest,'LineWidth',2,'DisplayName','test error')
        hold on;
        loglog(opt_alpha,opt_error,'.r','MarkerSize',30,'DisplayName','optimal')
        legend()
        drawnow
    end
    his(k,:) = [opt_alpha, opt_error];
    % fminsearch reports exitflag 1 only when it converged to a solution.
    if flag~=1
        warning('fminsearch did not converge (exitflag=%d)',flag);
    end
    fprintf('m=%d\topt_alpha=%1.2e\topt_error=%1.4f\n',m,opt_alpha,opt_error);
end
save(sprintf('%s_%s_%s.mat',mfilename,dataset,sample),'his','ms','tt','ftest_all','ftrain_all')