-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrainKernel.m
50 lines (39 loc) · 1.63 KB
/
TrainKernel.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
function [Y2hat,results,model] = TrainKernel(K,K2,Y,Y2,D,p,gamma,e,C,regression)
% TRAINKERNEL Train a kernel regressor on precomputed kernels and test it.
%
%   [Y2hat,results,model] = TrainKernel(K,K2,Y,Y2,D,p,gamma,e,C,regression)
%
%   Inputs
%     K          - training kernel matrix. Presumably n-by-n with
%                  n = numel(Y(p+1:end-D)) (TODO confirm against the caller).
%                  For 'svr', a K that already carries libsvm's serial-number
%                  first column (n-by-(n+1)) is also accepted.
%     K2         - test kernel matrix, presumably test samples x training
%                  samples (its columns are indexed by training sample in
%                  the 'rvm' branch).
%     Y, Y2      - training / test targets; both are trimmed to
%                  (p+1:end-D) before use.
%     D, p       - number of trailing (D) and leading (p) samples dropped
%                  from the targets.
%     gamma      - ridge term added to the kernel diagonal ('krr' only).
%     e          - libsvm '-p' epsilon of the insensitive loss ('svr' only).
%     C          - libsvm '-c' cost ('svr' only).
%     regression - 'krr', 'svr' or 'rvm' (case-insensitive).
%
%   Outputs
%     Y2hat   - predictions for the test kernel rows.
%     results - struct returned by ComputeResults(Y2hat, trimmed Y2), plus a
%               field .svs: percentage of training samples kept as
%               support / relevance vectors (always 100 for 'krr').
%     model   - 'krr': dual coefficient vector; 'svr': libsvm model struct;
%               'rvm': struct with fields weights, used, H, H2, YT.

% Targets actually used: drop the first p and the last D samples.
YT  = Y(p+1:end-D);
YT2 = Y2(p+1:end-D);
nTrain = numel(YT);

switch lower(regression)
    case 'krr'
        % Kernel ridge regression: solve (K + gamma*I) * alpha = YT.
        H = K + gamma*eye(size(K));
        model = H\YT;
        Y2hat = K2*model;
        svs = 100;  % every training sample contributes to the solution
    case 'svr'
        % -----------------------------------------------------------------
        % Training: epsilon-SVR (-s 3) with a precomputed kernel (-t 4)
        % -----------------------------------------------------------------
        % libsvm requires the sample serial number as the first column of
        % precomputed-kernel data. Prepend it when the kernel has exactly
        % one column per training sample; a kernel that already carries the
        % serial column is passed through unchanged.
        H = K;
        if size(H,2) == nTrain
            H = [(1:size(H,1))', H];
        end
        % '%.15g' keeps full precision (num2str rounds to ~5 digits).
        opts = sprintf('-t 4 -s 3 -c %.15g -p %.15g', C, e);
        model = svmtrain(YT,H,opts);
        svs = 100*model.totalSV/nTrain;
        % -----------------------------------------------------------------
        % Testing (the serial column of test data may hold any value)
        % -----------------------------------------------------------------
        H2 = K2;
        if size(H2,2) == nTrain
            H2 = [(1:size(H2,1))', H2];
        end
        Y2hat = svmpredict(YT2,H2,model);
    case 'rvm'
        % Relevance vector machine through sbl_estimate.
        H  = K;   % training design matrix
        H2 = K2;  % test design matrix
        % Fixed hyperparameters passed to sbl_estimate (presumably initial
        % weight-prior and noise precisions - confirm against sbl_estimate).
        alpha  = 1;
        beta   = 1e3;
        maxIts = 1000;
        monIts = 1;
        [weights, used] = sbl_estimate(H,YT,alpha,beta,maxIts,monIts);
        % Predict with only the basis columns sbl_estimate kept.
        PHI2  = H2(:,used);
        Y2hat = PHI2*weights;
        % Count the nonzero weights (length(weights~=0) counted them all).
        svs = 100*nnz(weights)/nTrain;
        model.weights = weights; model.used = used;
        model.H = H; model.H2 = H2; model.YT = YT;
    otherwise
        error('TrainKernel:unknownRegression', ...
            'Unknown regression method ''%s''; expected ''krr'', ''svr'' or ''rvm''.', ...
            num2str(regression));
end
% --------------------------------------------------------------------------
% Results in test
% --------------------------------------------------------------------------
results = ComputeResults(Y2hat,YT2);
results.svs = svs;