Skip to content

Commit 9734aae

Browse files
committed
added a probabilistic multiclass classifier: train_probClass.m and decode_probClass.m
1 parent 691d218 commit 9734aae

9 files changed

+488
-1
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# Data
22
data/*
33
*.m~
4+
STDIN*

decode_beamformer.m

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
Y = Y - repmat(m, [1, numN]);
4444
elseif isnumeric(cfg0.demean)
4545
Y = Y - repmat(cfg0.demean, [1, numN]);
46+
elseif strcmp(cfg0.demean, 'no')
4647
else
4748
error('Demeaning configuration ''%s'' is unknown', cfg0.demean);
4849
end

decode_pattern.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
% .covariance Whether the pattern should be multiplied by the
1818
% inverse of a covariance matrix.
1919
% = 'testData' The nancov of the testing data. Specify
20-
% covariance.gamma for regularization.
20+
% cfg.gamma for regularization.
2121
% = [F x F] vector Manually specified covariance matrix, where F is
2222
% the number of features (e.g. sensors).
2323
% = 'no' (default)

decode_probClass.m

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
function pPost = decode_probClass(cfg0, decoder, Y)
2+
% pPost = decode_beamformer(cfg, decoder, Y)
3+
% Apply probabilistic multiclass decoder, obtained from an appropriate training function. It
4+
% returns the posterior probabilities of a the trial belonging to a class, given the data. At
5+
% this moment the inclusion of (non-uniform) prior probabilities is however not yet implemented.
6+
%
7+
% decoder The probabilistic multiclass decoder obtained from e.g. train_probClass.
8+
% Y Matrix of size F x N, where F is the number of features and the N the number of trials,
9+
% that contains the data that is to be classified.
10+
% cfg Configuration struct that can possess the following fields:
11+
% .demean Whether the data should be demeaned (per feature,
12+
% over trials) prior to decoding. The mean can be
13+
% specified in the following ways:
14+
% = 'trainData' The mean of the training data (default).
15+
% = 'testData' The mean of the testing data.
16+
% = [F x 1] vector Manually specified mean, where F is the number of
17+
% features (e.g. sensors).
18+
% = 'no' No demeaning.
19+
%
20+
% pPost Matrix of size C x N, where C is the number of classes, that contains the posterior
21+
% probabilities of the trial belonging to each of the classes, given the data.
22+
%
23+
% See also TRAIN_PROBCLASS.
24+
25+
% Created by Pim Mostert, 2017
26+
27+
if ~isfield(cfg0, 'demean')
28+
cfg0.demean = 'trainData';
29+
end
30+
31+
% Useful variables
32+
numN = size(Y, 2);
33+
numC = size(decoder.W, 1);
34+
35+
% Demean
36+
if strcmp(cfg0.demean, 'trainData')
37+
if ~isfield(decoder, 'mY')
38+
error('No mean found in decoder');
39+
end
40+
41+
Y = Y - repmat(decoder.mY, [1, numN]);
42+
elseif strcmp(cfg0.demean, 'testData')
43+
m = nanmean(Y, 2);
44+
45+
Y = Y - repmat(m, [1, numN]);
46+
elseif isnumeric(cfg0.demean)
47+
Y = Y - repmat(cfg0.demean, [1, numN]);
48+
elseif strcmp(cfg0.demean, 'no')
49+
else
50+
error('Demeaning configuration ''%s'' is unknown', cfg0.demean);
51+
end
52+
53+
% Decode
54+
pPost = exp(decoder.W*Y + repmat(decoder.b, [1, numN]));
55+
pPost = pPost ./ repmat(sum(pPost, 1), [numC, 1]);
56+
57+
end

development/decode_probClass.m

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
function pPost = decode_probClass(cfg0, decoder, Y)
2+
% pPost = decode_beamformer(cfg, decoder, Y)
3+
% Apply probabilistic multiclass decoder, obtained from an appropriate training function. It
4+
% returns the posterior probabilities of a the trial belonging to a class, given the data. At
5+
% this moment the inclusion of (non-uniform) prior probabilities is however not yet implemented.
6+
%
7+
% decoder The probabilistic multiclass decoder obtained from e.g. train_probClass.
8+
% Y Matrix of size F x N, where F is the number of features and the N the number of trials,
9+
% that contains the data that is to be classified.
10+
% cfg Configuration struct that can possess the following fields:
11+
% .demean Whether the data should be demeaned (per feature,
12+
% over trials) prior to decoding. The mean can be
13+
% specified in the following ways:
14+
% = 'trainData' The mean of the training data (default).
15+
% = 'testData' The mean of the testing data.
16+
% = [F x 1] vector Manually specified mean, where F is the number of
17+
% features (e.g. sensors).
18+
% = 'no' No demeaning.
19+
%
20+
% pPost Matrix of size C x N, where C is the number of classes, that contains the posterior
21+
% probabilities of the trial belonging to each of the classes, given the data.
22+
%
23+
% See also TRAIN_PROBCLASS.
24+
25+
% Created by Pim Mostert, 2017
26+
27+
if ~isfield(cfg0, 'demean')
28+
cfg0.demean = 'trainData';
29+
end
30+
31+
% Useful variables
32+
numN = size(Y, 2);
33+
numC = size(decoder.W, 1);
34+
35+
% Demean
36+
if strcmp(cfg0.demean, 'trainData')
37+
if ~isfield(decoder, 'mY')
38+
error('No mean found in decoder');
39+
end
40+
41+
Y = Y - repmat(decoder.mY, [1, numN]);
42+
elseif strcmp(cfg0.demean, 'testData')
43+
m = nanmean(Y, 2);
44+
45+
Y = Y - repmat(m, [1, numN]);
46+
elseif isnumeric(cfg0.demean)
47+
Y = Y - repmat(cfg0.demean, [1, numN]);
48+
elseif strcmp(cfg0.demean, 'no')
49+
else
50+
error('Demeaning configuration ''%s'' is unknown', cfg0.demean);
51+
end
52+
53+
% Decode
54+
pPost = exp(decoder.W*Y + repmat(decoder.b, [1, numN]));
55+
pPost = pPost ./ repmat(sum(pPost, 1), [numC, 1]);
56+
57+
end

development/example_multiClass.m

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
clear all;
2+
3+
addpath('../');
4+
5+
% Gives me: X (labels), Y (data), time, label
6+
tmp = load('../data/testdata_orientation.mat');
7+
8+
Y = tmp.Y; % Y: features x time x trials
9+
X = tmp.X';
10+
time = tmp.time;
11+
label = tmp.label;
12+
13+
phi = X * (180/8); % Presented orientation in degrees
14+
15+
numF = size(Y, 1);
16+
numT = size(Y, 2);
17+
numN = size(Y, 3);
18+
19+
%% Probabilistic classification
20+
% Create folds
21+
cfg = [];
22+
cfg.nFold = 5;
23+
folds = createFolds(cfg, phi);
24+
25+
cfg = [];
26+
cfg.folds = folds;
27+
cfg.feedback = 'yes';
28+
cfg.trainfun = 'train_array';
29+
cfg.traincfg.feedback = 'yes';
30+
cfg.traincfg.trainfun = 'train_probClass';
31+
cfg.traincfg.traincfg.gamma = 0.01;
32+
cfg.decodefun = 'decode_arrayGeneralization';
33+
cfg.decodecfg.feedback = 'yes';
34+
cfg.decodecfg.decodefun = 'decode_probClass';
35+
36+
pPost = decodeCrossValidation(cfg, phi, Y);
37+
38+
% Extract classes
39+
[~, class] = max(pPost, [], 1);
40+
class = squeeze(class);
41+
42+
% Classification accuracy
43+
correct = (class == repmat(permute((X+1), [1 3 2]), [numT, numT]));
44+
45+
[~, pCorrect] = ttest(correct*1, 1/8, 'dim', 3);
46+
mCorrect = mean(correct, 3);
47+
48+
figure; colormap(jet(256));
49+
subplot(2, 1, 1); imagesc(time, time, mCorrect); colorbar; eqClims(1/8); axis image; axis xy;
50+
subplot(2, 1, 2); imagesc(time, time, log10(pCorrect) .* (pCorrect < 0.05)); colorbar; axis image; axis xy;
51+
52+
numC = 8;
53+
54+
pPostCorrect = zeros(numT, numT, numN);
55+
for ic = 1:numC
56+
pPostCorrect(:, :, X==(ic-1)) = pPost(ic, :, :, X==(ic-1));
57+
end
58+
59+
[~, ppPost] = ttest(pPostCorrect, 1/8, 'dim', 3);
60+
mCorrect = mean(pPostCorrect, 3);
61+
62+
% Plot
63+
figure; colormap(jet(256));
64+
subplot(2, 1, 1); imagesc(time, time, mCorrect); colorbar; eqClims(1/8); axis image; axis xy;
65+
subplot(2, 1, 2); imagesc(time, time, log10(ppPost) .* (ppPost < 0.05)); colorbar; axis image; axis xy;
66+
67+
%% B&H model
68+
% Create design matrix
69+
numC = 8;
70+
71+
cfg = [];
72+
cfg.numC = numC;
73+
cfg.tuningCurve = 'vonMises';
74+
cfg.kappa = 5;
75+
76+
design = designMatrix_BH(cfg, phi);
77+
78+
% Create folds
79+
cfg = [];
80+
cfg.nFold = 5;
81+
folds = createFolds(cfg, X);
82+
83+
cfg = [];
84+
cfg.folds = folds;
85+
cfg.feedback = 'yes';
86+
cfg.trainfun = 'train_array';
87+
cfg.traincfg.feedback = 'yes';
88+
cfg.traincfg.trainfun = 'train_beamformer';
89+
cfg.traincfg.traincfg.gamma = 0.01;
90+
cfg.decodefun = 'decode_arrayGeneralization';
91+
cfg.decodecfg.feedback = 'yes';
92+
cfg.decodecfg.decodefun = 'decode_beamformer';
93+
94+
Xhat = decodeCrossValidation(cfg, design, Y);
95+
96+
% Extract single orientations and correlate
97+
kernel = exp(1i * (0:(numC-1)) * (2*pi/numC));
98+
Z = reshape(kernel*reshape(Xhat, [numC, numT*numT*numN]), [numT, numT, numN]);
99+
theta = mod(angle(Z), 2*pi) * (180/pi) * 0.5; % Decoded orientation
100+
101+
r = exp(1i * (theta - repmat(permute(phi, [1 3 2]), [numT, numT, 1])) * (pi/180)*2);
102+
r = abs(r) .* cos(angle(r));
103+
104+
[~, p] = ttest(r, 0, 'dim', 3);
105+
mr = mean(r, 3);
106+
107+
% Plot
108+
figure; colormap(jet(256));
109+
subplot(2, 1, 1); imagesc(time, time, mr); colorbar; eqClims; axis image; axis xy;
110+
subplot(2, 1, 2); imagesc(time, time, log10(p) .* (p < 0.05)); colorbar; axis image; axis xy;
111+
112+
113+
114+
115+
116+
117+
118+
119+
120+
121+
122+
123+

development/logistic_regression.m

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
clear all;
2+
3+
%% Load data
4+
% Gives me: X (labels), Y (data), time, label
5+
tmp = load('../data/testdata_orientation.mat');
6+
7+
Y = tmp.Y; % Y: features x time x trials
8+
X = tmp.X';
9+
time = tmp.time;
10+
label = tmp.label;
11+
12+
numF = size(Y, 1);
13+
numN = size(Y, 3);
14+
15+
%% Single time opint
16+
% Select time point
17+
sel_t = find(time >= .2, 1);
18+
19+
% Split into train- and test-set
20+
cfg = [];
21+
cfg.nFold = 2;
22+
folds = createFolds(cfg, X);
23+
24+
X_train = X(folds{1});
25+
Y_train = squeeze(Y(:, sel_t, folds{1}));
26+
27+
X_test = X(folds{2});
28+
Y_test = squeeze(Y(:, sel_t, folds{2}));
29+
30+
%% Train classifier
31+
numC = length(unique(X));
32+
gamma = 0.1;
33+
34+
% Calculate means and covariance matrix
35+
m = zeros(numF, numC);
36+
S = zeros(numF, numF);
37+
38+
for ic = 1:numC
39+
m(:, ic) = nanmean(Y_train(:, X_train==(ic-1)), 2);
40+
41+
S = S + nancov(Y_train(:, X_train==(ic-1))');
42+
end
43+
S = S/numC;
44+
45+
% Regularize
46+
S = (1-gamma)*S + gamma*eye(numF)*trace(S)/numF;
47+
48+
% Calculate weights and bias
49+
W = S\m;
50+
b = -0.5*diag(m'*W);
51+
52+
cfg = [];
53+
cfg.gamma = 0.1;
54+
cfg.demean = 'yes';
55+
decoder = train_probClass(cfg, X_train, Y_train);
56+
57+
cfg = [];
58+
cfg.demean = 'no';
59+
pPost = decode_probClass(cfg, decoder, Y_test);
60+
61+
decoder = train_beamformer(cfg, X_train, Y_train);
62+
63+
%% Test classifier
64+
p = exp(W'*Y_test + repmat(b, [1, length(X_test)]));
65+
p = p ./ repmat(sum(p, 1), [numC, 1]);
66+
67+
68+
mp = zeros(numC, numC);
69+
for ic = 1:numC
70+
mp(:, ic) = mean(p(:, X_test==(ic-1)), 2);
71+
end
72+
73+
figure;
74+
imagesc(mp); colorbar; axis image;
75+
76+
77+
78+
79+
80+

0 commit comments

Comments
 (0)