Skip to content

Commit 775cb7f

Browse files
committed
initial commit
1 parent 2c5196e commit 775cb7f

26 files changed

+2816
-0
lines changed

createFolds.m

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
function folds = createFolds(cfg0, G)
2+
% [folds] = createFolds(cfg, G)
3+
% Divides trials into folds for cross-validation. The function attempts
4+
% to divide the trials as evenly as possible over the number of folds,
5+
% while trying to balance the different conditions as much as possible.
6+
%
7+
% cfg Configuration struct that can possess the following fields:
8+
% .nFold = [scalar] Number of folds the trials should be divided over.
9+
%
10+
% G Vector of length N, where N is the number of trials, that specifies
11+
% the condition to which that trial belongs, as identified by a unique number.
12+
%
13+
% folds A cell-array of length cfg.nFold that contains in each cell the indices
14+
% of the trials belonging to that particular fold.
15+
16+
% Created by Pim Mostert, 2016
17+
18+
G = G(:);
19+
20+
CONDS = unique(G);
21+
N_conds = length(CONDS);
22+
23+
folds = cell(cfg0.nFold, 1);
24+
for iCond = 1:N_conds
25+
% Find indices
26+
index = find(G == CONDS(iCond));
27+
nIndex = length(index);
28+
29+
% Shuffle
30+
index = index(randperm(nIndex));
31+
32+
% Distribute across folds
33+
groupNumber = floor((0:(nIndex-1))*(cfg0.nFold/nIndex))+1;
34+
35+
for iFold = 1:cfg0.nFold
36+
folds{iFold} = [folds{iFold}, index(groupNumber==iFold)'];
37+
end
38+
end
39+
40+
end

decodeCrossValidation.m

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
function Xhat = decodeCrossValidation(cfg0, X, Y)
2+
% [Xhat] = decodeCrossValidation(cfg, X, Y)
3+
% Implements k-fold cross-validation, in which a subset of trials is left out
4+
% in each iteration as testing data, while training on the remaining trials.
5+
%
6+
% cfg Configuration struct that can possess the following fields:
7+
% .trainfun = [function_name] The training function that is used for training.
8+
% .traincfg = [struct] The configuration struct that will be passed on to
9+
% the training function. Default = [];
10+
% .decodefun = [function_name] The decoding function that is used for decoding.
11+
% .decodecfg = [struct] The configuration struct that will be passed on to
12+
% the decoding function. Default = [].
13+
% .folds = [cell_array] A cell-array of length k, where k is the number of folds,
14+
% in which each cell contains a vector with the trial numbers
15+
% belonging to that particular fold.
16+
% .feedback = 'yes' or 'no' Whether the function should provide feedback on its progress.
17+
% Default = 'no'.
18+
%
19+
% X Matrix of arbitrary dimensions, but of which the last dimension is N, that contains
20+
% the training information. In each fold, a selection of this matrix (along the last
21+
% dimension) is sent to the training function.
22+
%
23+
% Y Matrix of arbitrary dimensions, but of which the last dimension corresponds to the
24+
% number of trials N, that contains the data. In each fold, a selection of this matrix
25+
% (along the last dimension) is sent to the training and decoding function.
26+
%
27+
% Xhat Matrix of dimensions as output by the decoding functiong, plus an additional dimension
28+
% of length N, that contains the decoded data.
29+
%
30+
% See also CREATEFOLDS
31+
32+
% Created by Pim Mostert, 2016
33+
34+
tStart = tic;
35+
36+
if ~isfield(cfg0, 'traincfg')
37+
cfg0.traincfg = [];
38+
end
39+
if ~isfield(cfg0, 'decodecfg')
40+
cfg0.decodecfg = [];
41+
end
42+
if ~isfield(cfg0, 'feedback')
43+
cfg0.feedback = 'no';
44+
end
45+
46+
dimsY = size(Y);
47+
48+
numN = dimsY(end);
49+
numFold = length(cfg0.folds);
50+
51+
%% Reshape data to allow for arbitrary dimensionality
52+
Y = reshape(Y, [prod(dimsY(1:(end-1))), numN]);
53+
if isvector(X)
54+
X = X(:)';
55+
dimsX = size(X);
56+
else
57+
dimsX = size(X);
58+
X = reshape(X, [prod(dimsX(1:(end-1))), numN]);
59+
end
60+
61+
%% Do first fold manually, to determine output size of decoder
62+
iFold = 1;
63+
64+
tFold = tic;
65+
66+
index_train = cell2mat(cfg0.folds((1:numFold) ~= iFold)');
67+
index_decode = cfg0.folds{iFold};
68+
69+
% Select training data
70+
Y_train = reshape(Y(:, index_train), [dimsY(1:(end-1)), length(index_train)]);
71+
X_train = reshape(X(:, index_train), [dimsX(1:(end-1)), length(index_train)]);
72+
73+
% Train decoder
74+
decoder = feval(cfg0.trainfun, cfg0.traincfg, X_train, Y_train);
75+
76+
% Select data to be decoded
77+
Y_decode = reshape(Y(:, index_decode), [dimsY(1:(end-1)), length(index_decode)]);
78+
79+
% Decode data
80+
Xhat_curFold = feval(cfg0.decodefun, cfg0.decodecfg, decoder, Y_decode);
81+
82+
% Feedback
83+
if strcmp(cfg0.feedback, 'yes')
84+
fprintf('%s: finished fold %g/%g - it took %.2f s\n', mfilename, iFold, numFold, toc(tFold));
85+
end
86+
87+
%% Allocate memory for results and do rest of folds
88+
dimsOut = size(Xhat_curFold);
89+
dimsOut = dimsOut(1:(end-1));
90+
91+
Xhat = zeros([prod(dimsOut), numN]);
92+
Xhat(:, index_decode) = reshape(Xhat_curFold, [prod(dimsOut), length(index_decode)]);
93+
94+
for iFold = 2:numFold
95+
tFold = tic;
96+
97+
index_train = cell2mat(cfg0.folds((1:numFold) ~= iFold)');
98+
index_decode = cfg0.folds{iFold};
99+
100+
% Select training data
101+
Y_train = reshape(Y(:, index_train), [dimsY(1:(end-1)), length(index_train)]);
102+
X_train = reshape(X(:, index_train), [dimsX(1:(end-1)), length(index_train)]);
103+
104+
% Train decoder
105+
decoder = feval(cfg0.trainfun, cfg0.traincfg, X_train, Y_train);
106+
107+
% Select data to be decoded
108+
Y_decode = reshape(Y(:, index_decode), [dimsY(1:(end-1)), length(index_decode)]);
109+
110+
% Decode data
111+
Xhat_curFold = feval(cfg0.decodefun, cfg0.decodecfg, decoder, Y_decode);
112+
Xhat(:, index_decode) = reshape(Xhat_curFold, [prod(dimsOut), length(index_decode)]);
113+
114+
% Feedback
115+
if strcmp(cfg0.feedback, 'yes')
116+
fprintf('%s: finished fold %g/%g - it took %.2f s\n', mfilename, iFold, numFold, toc(tFold));
117+
end
118+
end
119+
120+
%% Return
121+
Xhat = reshape(Xhat, [dimsOut, numN]);
122+
123+
if strcmp(cfg0.feedback, 'yes')
124+
fprintf('%s - all finished - it took %.2f s\n', mfilename, toc(tStart));
125+
end
126+
127+
end
128+
129+

decode_LDA.m

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
function Xhat = decode_LDA(cfg0, decoder, Y)
2+
% [Xhat] = decode_LDA(cfg, decoder, Y)
3+
% Decodes data using a linear discriminant analysis decoder as obtained from
4+
% an appropriate training function, e.g. train_LDA.
5+
%
6+
% decoder The decoder as obtained from an appropriate training function.
7+
%
8+
% Y Matrix of size F x N, where N is the number of trials and F the number of
9+
% features, that contains the data to be decoded.
10+
%
11+
% cfg Configuration struct that can possess the following fields:
12+
% .demean = 'yes' or 'no' Whether the data should be demeaned (per feature,
13+
% over trials) prior to decoding, based on the mean
14+
% of the training data. Default = 'yes'.
15+
%
16+
% Xhat Vector of length N that contains the decoded data.
17+
%
18+
% See also TRAIN_LDA
19+
20+
% Created by Pim Mostert, 2016
21+
22+
%% Pre-process cfg-struct
23+
if ~isfield(cfg0, 'demean')
24+
cfg0.demean = 'yes';
25+
end
26+
27+
%% Pre-process data
28+
numN = size(Y, 2);
29+
30+
% Demean
31+
if strcmp(cfg0.demean, 'yes')
32+
Y = Y - repmat(decoder.mY', [1, numN]);
33+
end
34+
35+
%% Decode
36+
Xhat = decoder.W*Y;
37+
38+
end

decode_array.m

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
function Xhat = decode_array(cfg0, decoder, Y)
2+
% [Xhat] = decode_array(cfg, decoders, Y)
3+
% Decodes the data along a specified dimension, e.g. time, using an array of decoders.
4+
%
5+
% decoders Cell-vector of length D, that contains an array of decoders, where D is the number
6+
% of decoders. This array may be obtained from an appropriate training function
7+
% such as train_array.
8+
%
9+
% Y Matrix of arbitrary dimension that contains the data to be decoded, though the last
10+
% two dimensions should be D and N, respectively, where N is the number of trials.
11+
% For example, [sensors x time x trials] if the array of decoders was trained along time.
12+
%
13+
% cfg Configuration struct that can possess the following fields:
14+
% .decodefun = [function_name] The decoding function to which each of the decoders
15+
% is passed on.
16+
% .decodecfg = [struct] The configuration struct that will be passed on to
17+
% the decoding function. Default = [];
18+
% .feedback = 'yes' or 'no' Whether the function should provide feedback on its progress.
19+
% Default = 'no'.
20+
%
21+
% Xhat Matrix of dimensions as output by the decoding functiong, plus the additional dimensions
22+
% of D and N, that contains the decoded data for each decoder and trial.
23+
%
24+
% See also TRAIN_ARRAY and DECODE_ARRAYGENERALIZATION
25+
26+
% Created by Pim Mostert, 2016
27+
28+
tStart = tic;
29+
30+
if ~isfield(cfg0, 'decodecfg')
31+
cfg0.decodecfg = [];
32+
end
33+
if ~isfield(cfg0, 'feedback')
34+
cfg0.feedback = 'no';
35+
end
36+
37+
dims = size(Y);
38+
39+
numN = dims(end);
40+
numDec = dims(end-1);
41+
42+
dimsSub = dims(1:(end-2));
43+
44+
%% Reshape data to allow for arbitrary dimensionality
45+
Y = reshape(Y, [prod(dimsSub), numDec, numN]);
46+
47+
%% Do first decoder manually, to obtain output size of decoder
48+
Xhat_curDec = feval(cfg0.decodefun, cfg0.decodecfg, decoder{1}, squeeze(Y(:, 1, :)));
49+
dimsOut = size(Xhat_curDec);
50+
dimsOut = dimsOut(1:(end-1));
51+
52+
%% Allocate memory for output and iterate over remaining decoders
53+
Xhat = zeros([prod(dimsOut), numDec, numN]);
54+
Xhat(:, 1, :) = reshape(Xhat_curDec, [prod(dimsOut), numN]);
55+
56+
tDec = tic;
57+
for iDec = 2:numDec
58+
Xhat_curDec = feval(cfg0.decodefun, cfg0.decodecfg, decoder{iDec}, squeeze(Y(:, iDec, :)));
59+
Xhat(:, iDec, :) = reshape(Xhat_curDec, [prod(dimsOut), numN]);
60+
61+
if (toc(tDec) > 2) && strcmp(cfg0.feedback, 'yes')
62+
fprintf('%s - finished decoder %g/%g\n', mfilename, iDec, numDec);
63+
tDec = tic;
64+
end
65+
end
66+
67+
%% Reshape output
68+
Xhat = reshape(Xhat, [dimsOut, numDec, numN]);
69+
70+
if strcmp(cfg0.feedback, 'yes')
71+
fprintf('%s - all finished - it took %.2f s\n', mfilename, toc(tStart));
72+
end
73+
74+
end
75+
76+
77+
78+
79+
80+

decode_arrayGeneralization.m

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
function Xhat = decode_arrayGeneralization(cfg0, decoder, Y)
2+
% [Xhat] = train_arrayGeneralization(cfg, decoders, Y)
3+
% Decodes the data along a specified dimension, e.g. time, using an array of decoders. Additionally,
4+
% each decoder is also applied to all other points along this dimension. For example, this function
5+
% can implement the temporal generalization method (King & Dehaene, 2014).
6+
%
7+
% decoders Cell-vector of length D, that contains an array of decoders, where D is the number
8+
% of decoders. This array may be obtained from an appropriate training function
9+
% such as train_array.
10+
%
11+
% Y Matrix of arbitrary dimension that contains the data to be decoded, though the last
12+
% two dimensions should be T and N, respectively, where T is the dimension along which all
13+
% decoders should be applied iteratively and N is the number of trials. For example,
14+
% [sensors x time x trials] if the array of decoders should be applied over time. Note that
15+
% although D and T are likely to correspond to the same quantity, e.g. time, they do not
16+
% necessarily have the same size, for instance when training on one task and generalizing
17+
% to another task. If Y is 2-dimensional, then it is assumed to correspond to one single
18+
% trial, i.e. [sensors x time x 1].
19+
%
20+
% cfg Configuration struct that can possess the following fields:
21+
% .decodefun = [function_name] The decoding function to which each of the decoders
22+
% is passed on.
23+
% .decodecfg = [struct] The configuration struct that will be passed on to
24+
% the decoding function. Default = [];
25+
% .feedback = 'yes' or 'no' Whether the function should provide feedback on its progress.
26+
% Default = 'no'.
27+
% .oneTrial = 'yes' or 'no' <yet to be described>
28+
%
29+
% Xhat Matrix of dimensions as output by the decoding functiong, plus the additional dimensions
30+
% of D, D and N, that contains the decoded data for each decoder, applied to all other
31+
% points along the dimension of interast, for each trial.
32+
%
33+
% See also TRAIN_ARRAY and DECODE_ARRAY
34+
35+
% Created by Pim Mostert, 2016
36+
37+
tStart = tic;
38+
39+
if ~isfield(cfg0, 'decodecfg')
40+
cfg0.decodecfg = [];
41+
end
42+
if ~isfield(cfg0, 'feedback')
43+
cfg0.feedback = 'no';
44+
end
45+
if ~isfield(cfg0, 'oneTrial')
46+
cfg0.oneTrial = 'no';
47+
end
48+
49+
dims = size(Y);
50+
51+
if (length(size(Y)) == 2)
52+
dims = [dims 1];
53+
end
54+
55+
numN = dims(end);
56+
numD = length(decoder);
57+
numT = dims(end-1);
58+
59+
dimsSub = dims(1:(end-2));
60+
61+
%% Reshape data to allow for arbitrary dimensionality
62+
Y = reshape(Y, [prod(dimsSub), numT, numN]);
63+
64+
%% Do first decoder manually, to obtain output size of decoder
65+
Xhat_curDec = feval(cfg0.decodefun, cfg0.decodecfg, decoder{1}, reshape(Y, [dimsSub, numT*numN]));
66+
dimsOut = size(Xhat_curDec);
67+
dimsOut = dimsOut(1:(end-1));
68+
69+
%% Allocate memory for output and iterate over remaining decoders
70+
Xhat = zeros([prod(dimsOut), numD, numT, numN]);
71+
Xhat(:, 1, :, :) = reshape(Xhat_curDec, [prod(dimsOut), numT, numN]);
72+
73+
tDec = tic;
74+
for iDec = 2:numD
75+
Xhat_curDec = feval(cfg0.decodefun, cfg0.decodecfg, decoder{iDec}, reshape(Y, [dimsSub, numT*numN]));
76+
Xhat(:, iDec, :, :) = reshape(Xhat_curDec, [prod(dimsOut), numT, numN]);
77+
78+
if (toc(tDec) > 2) && strcmp(cfg0.feedback, 'yes')
79+
fprintf('%s - finished decoder %g/%g\n', mfilename, iDec, numD);
80+
tDec = tic;
81+
end
82+
end
83+
84+
%% Reshape output
85+
Xhat = reshape(Xhat, [dimsOut, numD, numT, numN]);
86+
87+
if strcmp(cfg0.feedback, 'yes')
88+
fprintf('%s - all finished - it took %.2f s\n', mfilename, toc(tStart));
89+
end
90+
91+
end
92+
93+
94+
95+
96+
97+

0 commit comments

Comments
 (0)