-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvolknnsearch.m
292 lines (251 loc) · 12.3 KB
/
volknnsearch.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
function [patches, pDst, pIdx, pRefIdxs, srcgridsize, refgridsize] = ...
    volknnsearch(srcvol, refvols, patchSize, varargin)
% VOLKNNSEARCH k-NN search of patches in source given a set of reference volumes
% patches = volknnsearch(src, refs, patchSize) k-NN search of patches in source volume src given
% a set of reference volumes refs. refs can be a volume or a cell of volumes, of the same
% dimensionality as src. patchSize is the size of the patches. patches is a [M x V x K] array,
% with M being the number of patches, V the number of voxels in a patch, and K from kNN.
%
% patches = volknnsearch(src, refs, patchSize, srcPatchOverlap) allows the specification of a
% patchOverlap amount or kind for src. See patchlib.overlapkind for more information. Default is
% the default used in patchlib.vol2lib.
%
% patches = volknnsearch(src, refs, patchSize, srcPatchOverlap, refsPatchOverlap) allows the
% specification of a patchOverlap amount or kind for refs as well.
%
% patches = volknnsearch(..., Param, Value) allows for parameter/value pairs:
% - 'local' (voxel spacing, scalar or [1 x nDims]): if desired to do local search around each
%   voxel, instead of global (which is default). Cannot be combined with 'searchfn'.
%
% - 'searchfn': function handle --- allow for a different type of search, rather than the
%   standard local or global search functions. For example, this can be useful if you want to do
%   a local search in some registration space, where the source location of a voxel corresponds
%   to a different location in each reference. It is called as
%   >> [pIdx, pRefIdxs, pDst] = searchfn(src, refs, knnvarargin{:})
%   where src is a struct with fields vol, lib, grididx, cropVolSize, gridSize and mask; refs is
%   a struct array (one element per reference) with fields vol, lib, grididx, cropVolSize,
%   gridSize and, with the default libfn, refidx; and knnvarargin holds the Param/Value pairs
%   not recognised by volknnsearch. If 'location' is used, lib already includes the location
%   columns.
%
% - 'location' - a location based weight. the location of each voxel (in spatial coordinates) is
%   added to the feature vector, times the factor passed in via 'location'. default is 0
%   (location does not factor in). scalar or [1 x nDims] vector.
%
% - 'buildreflibs': logical (default: true) on whether to pre-compute the reference libraries
%   (which can take space and memory). If false, each element of refs is used as-is, so refs
%   should already hold library structs. This must stay true for the default global or local
%   search functions (for 'local' it is asserted).
%
% - 'mask' logical mask the size of the source vol, volknnsearch will only run for voxels where
%   mask is true
%
% - 'fillK' logical. If true: when searching in a local window, the window might be too small
%   for K, especially in volume corner/edges. true fillK will fill up K NN and assign them a
%   distance infinity and location and reference indexes of 1.
%
% - 'libfn' (function handle).
%   if you want to use your own library construction method, the signature is:
%   >> libstruct = libfn(inputvol, patchSize, patchOverlap)
%   where inputvol is either srcvol or refvols (refvols is always wrapped in a cell first), and
%   patchOverlap is optional and only gets passed if it's passed into volknnsearch.
%   libstruct should be a struct with fields vol, lib, grididx, cropVolSize, gridSize and
%   optionally refidx. If libfn returns a cell of library structs for refvols, it is
%   concatenated into a struct array.
%
%   Note: if using this option, srcvol and refvols don't actually need to be volumes. For
%   example, they could be pre-computed libstructs, in which case one can use
%   >> libfn = @(x, y, z) x;
%
% - 'excludePatches' (default: false): if true, patches is returned empty ([]).
%
% - 'separateProc' (default: 0)
%   0 - no separate processing
%   1 - separately process each reference and gather the results
%       at the end - saves a bit of memory but might be a bit slower due to overhead
%   [TODO: NOT IMPLEMENTED] X for X > 10 - separate by reference and do separate knnsearches
%       where each call is ... (description unfinished)
% - 'separateProcAgg' (default: 'agg') 'agg' or 'sep'; handed (via the parsed inputs) to the
%   aggregation of the separately processed results.
%
% - 'verbose' (default: false): in separateProc mode, print the time taken for each reference.
%
% - 'tmpfolder', 'memory': when 'tmpfolder' is non-empty it must be an existing folder and
%   'memory' must be > 0; no other use of them is visible here.
%
% - any Param/Value argument for knnsearch.
%
% [patches, pDst, pIdx, pRefIdxs, srcgridsize, refgridsize] = volknnsearch(...) also returns
% pDst, the distance of every match as given by the search (e.g. knnsearch()); pIdx (M x K),
% the patch indexes into the reference libraries (i.e. matching refgridsize, *not* the entire
% reference volume(s)); pRefIdxs, the index of the reference library of every match;
% srcgridsize, the source grid size; and refgridsize, the reference grid size.
% NOTE(review): outside 'separateProc' mode refgridsize is refs(1).gridSize only, even when
% refs is a cell of volumes (earlier docs said a cell would be returned) -- confirm intent.
%
% Contact: [email protected]

% parse inputs
narginchk(3, inf);
[refvolscell, srcoverlap, refoverlap, knnvarargin, inputs] = ...
    parseinputs(refvols, patchSize, varargin{:});
nRefs = numel(refvolscell); % guaranteed refvolscell is a cell.
assert(nRefs > 0, 'No reference volumes given');

% if processing references separately
if inputs.separateProc == 1 && nRefs > 1
    % exclude patches in the per-reference searches. The patches are built at the end, if
    % necessary.
    subvarargin = setParamValue('excludePatches', true, varargin);

    % build the source library once and share it with every per-reference search
    srcpass = prepsrc(srcvol, patchSize, inputs, srcoverlap{:});

    vargout = cell(nRefs, 1);
    for i = 1:nRefs
        if inputs.verbose
            reftic = tic;
        end
        vargout{i} = cell(6, 1);
        % each recursive call sees a single reference, so it takes the non-separate path
        [vargout{i}{:}] = patchlib.volknnsearch(srcpass, refvolscell{i}, patchSize, subvarargin{:});
        vargout{i}{1} = []; % empty out patches
        if inputs.verbose
            fprintf('volknnsearch with reference %d: %3.2f\n', i, toc(reftic));
        end
    end

    % combine the results
    [pDst, pIdx, pRefIdxs, srcgridsize, refgridsize] = volknnaggresults(vargout, inputs);

    % get the patches if necessary. Note, this is slower than if we got the patches originally
    % with the volknnsearch mode, but that can lead to memory issues in large numbers of
    % references (since you're getting a factor of nRefs more patches than you need, which can
    % grow quickly). Since separateProc is a low-memory mode, we choose this path.
    patches = [];
    if ~inputs.excludePatches
        libfn = @(x, y) getfield(inputs.libfn(x, patchSize, refoverlap{:}), 'lib'); %#ok<GFLD>
        patches = getpatches(srcpass, refvolscell, patchSize, pIdx, pRefIdxs, libfn);
    end
    return
end

% compute source library
src = prepsrc(srcvol, patchSize, inputs, srcoverlap{:});
assert(size(src.lib, 1) > 0, 'Source library is empty');

% build the reference libraries
if inputs.buildreflibs
    refs = inputs.libfn(refvolscell, patchSize, refoverlap{:});
    % a pass-through libfn (e.g. @(x, y, z) x on pre-built libstructs) hands back the cell;
    % flatten it into the struct array indexed below (refs(1), {refs.lib}, ...).
    if iscell(refs)
        refs = [refs{:}];
    end
else
    % references are expected to already be library structs
    for i = 1:nRefs
        refs(i) = refvolscell{i}; %#ok<AGROW>
    end
end

% optionally append the (weighted) voxel locations to every feature vector
if ~isempty(inputs.location) && any(inputs.location ~= 0)
    % TODO: add nDim location support.
    [src, refs] = addlocation(src, refs, inputs.location);
end

% compute the search
[pIdx, pRefIdxs, pDst] = inputs.searchfn(src, refs, knnvarargin{:});
srcgridsize = src.gridSize;

% extract the patches
if inputs.excludePatches
    patches = [];
else
    patches = getpatches(src, refs, patchSize, pIdx, pRefIdxs);
end

% NOTE(review): only the first reference's grid size is reported
refgridsize = refs(1).gridSize;
end
function src = prepsrc(srcvol, patchSize, inputs, varargin)
% PREPSRC prepare the source library struct and attach the search mask.
%   src = prepsrc(srcvol, patchSize, inputs, ...) returns srcvol unchanged when it is already a
%   library struct (a struct with a 'lib' field); otherwise it builds one with
%   inputs.libfn(srcvol, patchSize, ...). src.mask is then set to inputs.mask, or, when no mask
%   was given, to an all-true logical array the size of src.vol.

% prepare library
if isstruct(srcvol) && isfield(srcvol, 'lib')
    src = srcvol;
else
    src = inputs.libfn(srcvol, patchSize, varargin{:});
end

% prepare mask. An explicit if/else (rather than ifelse(), whose arguments are all evaluated
% before the call) avoids allocating a full-volume default mask when a mask was supplied.
if isempty(inputs.mask)
    src.mask = true(size(src.vol));
else
    src.mask = inputs.mask;
end
end
function patches = getpatches(src, refs, patchSize, pIdx, pRefIdxs, varargin)
% GETPATCHES gather the matched reference patches for the masked source grid locations.
%   patches = getpatches(src, refs, patchSize, pIdx, pRefIdxs) takes the patch contents from the
%   first prod(patchSize) columns of each refs(i).lib, so any appended location columns are
%   dropped.
%   patches = getpatches(src, refs, patchSize, pIdx, pRefIdxs, libfn) instead hands refs and
%   libfn straight to patchlib.lib2patches.
%   Only rows of pIdx / pRefIdxs at grid locations where src.mask is true are used. When the
%   mask excludes any location, the result is expanded back with maskvox2vol.
%
% TODO: what about the patchOverlap? Need to pass this in as well, or to lib2patches?

% source grid locations selected by the mask, and the matches at those locations
gridmask = src.mask(src.grididx);
selIdx = pIdx(gridmask, :, :);
selRefIdx = pRefIdxs(gridmask, :, :);

if isempty(varargin)
    % keep only the patch voxels of every reference library
    nVoxels = prod(patchSize);
    reflibs = cellfunc(@(lib) lib(:, 1:nVoxels), {refs.lib});
    pm = patchlib.lib2patches(reflibs, selIdx, selRefIdx, patchSize);
else
    pm = patchlib.lib2patches(refs, selIdx, selRefIdx, patchSize, varargin{:});
end

% TODO - unsure. if the mask is *very* small compared to the volume, might need to work in
% sparse, e.g. maskvox2vol(pm, mask, @sparse)? but then, most other functions need to worry
% about memory as well.
if all(gridmask(:))
    patches = pm;
else
    patches = maskvox2vol(pm, gridmask);
end
end
function [src, refs] = addlocation(src, refs, locwt)
% ADDLOCATION append weighted voxel subscripts to the src and refs patch libraries.
%   [src, refs] = addlocation(src, refs, locwt) computes, for every library row, the subscripts
%   (via size2sub) of its grid voxel (src.grididx / refs(r).grididx). It scales them
%   elementwise by locwt with bsxfun and appends them as extra columns of the corresponding lib.
%   NOTE(review): locwt presumably has one weight per subscript column (one per volume
%   dimension); a mismatch makes bsxfun error -- confirm with callers.

% source library
src.lib = [src.lib, weightedlocations(src.vol, src.grididx, locwt)];

% reference libraries, one struct element at a time
for r = 1:numel(refs)
    refs(r).lib = [refs(r).lib, weightedlocations(refs(r).vol, refs(r).grididx, locwt)];
end
end

function wloc = weightedlocations(vol, grididx, locwt)
% WEIGHTEDLOCATIONS locwt-weighted subscripts of the grididx voxels of vol, one row per voxel.
subs = size2sub(size(vol));
allsubs = cat(2, subs{:});
wloc = bsxfun(@times, locwt, allsubs(grididx, :));
end
function volstruct = vol2libwrap(vol, patchSize, varargin)
% VOL2LIBWRAP default library builder: pack the patchlib.vol2lib outputs into a struct.
%   volstruct = vol2libwrap(vol, patchSize, ...) calls patchlib.vol2lib(vol, patchSize, ...) and
%   packs vol, lib, grididx, cropVolSize and gridSize with structrich. When vol is a cell
%   (several volumes), vol2lib's fifth output is requested too, and its elements are dealt into
%   the refidx field of the struct elements.
%   Note: if lib is a cell, all packed elements should be cells as well.
%   (The variable names below are kept as-is since they are handed to structrich.)

isMulti = iscell(vol);

if isMulti
    [lib, grididx, cropVolSize, gridSize, refidxcell] = ...
        patchlib.vol2lib(vol, patchSize, varargin{:});
else
    [lib, grididx, cropVolSize, gridSize] = patchlib.vol2lib(vol, patchSize, varargin{:});
end

volstruct = structrich(vol, lib, grididx, cropVolSize, gridSize);

if isMulti
    [volstruct.refidx] = refidxcell{:};
end
end
function [refs, srcoverlap, refoverlap, knnvarargin, inputs] = parseinputs(refs, patchSize, varargin)
% PARSEINPUTS parse the optional arguments of volknnsearch.
%   [refs, srcoverlap, refoverlap, knnvarargin, inputs] = parseinputs(refs, patchSize, ...)
%   splits the optional arguments into, in order: an optional source patch overlap, an optional
%   reference patch overlap, and parameter/value pairs.
%   - refs is returned wrapped in a cell if it was not already a cell.
%   - srcoverlap / refoverlap are {} when not given, or a 1x1 cell holding the overlap, so they
%     can be forwarded with srcoverlap{:} / refoverlap{:}.
%   - knnvarargin holds the parameter/value pairs not recognised here (e.g. knnsearch options),
%     as produced by struct2cellWithNames(p.Unmatched).
%   - inputs holds the recognised options, with searchfn and location defaults filled in.
%
% Open ideas: a getPatchFunction(2dLocation_in_src, ref) hook, a method for extracting the
% actual patches (which could probably live with getPatchFunction), and voxel pre-selection.

% leading positional overlaps: first for the source, then for the references
srcoverlap = {};
if isleadingoverlap(varargin)
    srcoverlap = varargin(1);
    varargin = varargin(2:end);
end
refoverlap = {};
if isleadingoverlap(varargin)
    refoverlap = varargin(1);
    varargin = varargin(2:end);
end

% 'local' means local search, and takes a numeric spacing (scalar or 1 x nDims).
% TODO: also allow 'localpreprocess'
p = inputParser();
p.addParameter('local', [], @isnumeric);
p.addParameter('location', 0, @isnumeric);
p.addParameter('searchfn', [], @(x) isa(x, 'function_handle'));
p.addParameter('buildreflibs', true, @islogical);
p.addParameter('excludePatches', false, @islogical);
p.addParameter('mask', [], @islogical);
p.addParameter('separateProc', 0, @isnumeric); % see help
% validatestring takes a cell of candidates, returns the matched string and throws on no match;
% any() turns the returned string into the logical result inputParser expects.
p.addParameter('separateProcAgg', 'agg', @(x) any(validatestring(x, {'agg', 'sep'}))); % see help
p.addParameter('libfn', @vol2libwrap, @(x) isa(x, 'function_handle'));
p.addParameter('verbose', false, @islogical);
p.addParameter('fillK', false, @islogical);
% tmpfolder is a folder path and memory a number (both checked below); validating them with
% islogical rejected every usable value.
p.addParameter('tmpfolder', '', @ischar);
p.addParameter('memory', -1, @isnumeric);
p.KeepUnmatched = true;
p.parse(varargin{:});
knnvarargin = struct2cellWithNames(p.Unmatched);
inputs = p.Results;

% make sure refs is a cell
if ~iscell(refs)
    refs = {refs};
end
nDims = numel(patchSize);

% default global search
if isempty(inputs.local) && isempty(inputs.searchfn)
    % assert(inputs.buildreflibs)
    inputs.searchfn = @volknnglobalsearch;
end

% default local search
if ~isempty(inputs.local)
    assert(inputs.buildreflibs, 'Local search requires buildreflibs to be true');
    assert(isnumeric(inputs.local));
    assert(isempty(inputs.searchfn), 'Only provide local spacing or search function, not both');
    if isscalar(inputs.local), inputs.local = inputs.local * ones(1, nDims); end
    inputs.searchfn = @(x, y, varargin) volknnlocalsearch(x, y, inputs.local, inputs.fillK, varargin{:});
end

% expand a scalar location weight to one weight per dimension
if isscalar(inputs.location)
    inputs.location = repmat(inputs.location, [1, nDims]);
end

if ~isempty(inputs.tmpfolder)
    assert(sys.isdir(inputs.tmpfolder));
    assert(inputs.memory > 0);
end
end

function tf = isleadingoverlap(args)
% ISLEADINGOVERLAP true if the first of args is a positional patch-overlap argument.
% With several arguments the first one is checked directly. A single argument may also be an
% overlap, since a lone parameter name would have no value. The exception is a struct, which
% inputParser accepts on its own as a set of parameter/value pairs.
if isempty(args) || (isscalar(args) && isstruct(args{1}))
    tf = false;
else
    tf = patchlib.isvalidoverlap(args{1});
end
end