Skip to content

Commit 69fc24b

Browse files
committed
Fix gather operations on results of other gather operations
With symbolic gather operations, invalid indices might have been pushed/propagated to previous gather nodes in the dependency tree which could produce out-of-bound memory reads if there was an other gather operation in the tree. This commit also propagates the mask throughout the tree to guarantee that no out-of-bound reads will be made.
1 parent 2581dd6 commit 69fc24b

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

src/op.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,7 +1415,7 @@ static uint32_t jitc_scatter_gather_index(uint32_t source, uint32_t index) {
14151415

14161416
/// Change all indices/counters in an expression tree to 'new_index'
14171417
static uint32_t jitc_var_reindex(uint32_t var_index, uint32_t new_index,
1418-
uint32_t size) {
1418+
uint32_t mask, uint32_t size) {
14191419
Variable *v = jitc_var(var_index);
14201420

14211421
if (v->is_data() || (VarType) v->type == VarType::Void)
@@ -1439,7 +1439,16 @@ static uint32_t jitc_var_reindex(uint32_t var_index, uint32_t new_index,
14391439
uint32_t index_2 = v->dep[i];
14401440
if (!index_2)
14411441
continue;
1442-
dep[i] = steal(jitc_var_reindex(index_2, new_index, size));
1442+
1443+
if (v->kind == VarKind::Gather && i == 2) {
1444+
// Gather nodes must have their masks replaced rather than reindexed
1445+
JitBackend backend = (JitBackend) v->backend;
1446+
Ref default_mask = steal(jitc_var_mask_default(backend, size));
1447+
dep[i] = steal(jitc_var_and(mask, default_mask));
1448+
} else {
1449+
dep[i] = steal(jitc_var_reindex(index_2, new_index, mask, size));
1450+
}
1451+
14431452
v = jitc_var(var_index);
14441453
if (!dep[i])
14451454
return 0; // recursive call failed, give up
@@ -1527,7 +1536,7 @@ uint32_t jitc_var_gather(uint32_t src, uint32_t index, uint32_t mask) {
15271536
// Don't perform the gather operation if the inputs are trivial / can be re-indexed
15281537
if (!result) {
15291538
Ref index_2 = steal(jitc_var_cast(index, VarType::UInt32, 0));
1530-
Ref src_reindexed = steal(jitc_var_reindex(src, index_2, var_info.size));
1539+
Ref src_reindexed = steal(jitc_var_reindex(src, index_2, mask, var_info.size));
15311540
if (src_reindexed) {
15321541
// Temporarily hold an extra reference to prevent 'jitc_var_resize' from changing 'src'
15331542
Ref unused = borrow(src_reindexed);

tests/mem.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,22 @@ TEST_LLVM(14_gather_symbolic_llvm_mask) {
213213
Float buf_3 = gather<Float>(buf_2, arange<UInt32>(4, 8, 1));
214214
jit_assert(strcmp(buf_3.str(), "[5, 6, 7, 8]") == 0);
215215
}
216+
217+
TEST_BOTH(15_gather_symbolic_multiple_mask) {
218+
/* A gather expression that is reindexed/rewritten should properly apply its
219+
* mask to any previous gather operations it depends on */
220+
Float buf_0 = Float(1, 2, 3, 4, 5, 6, 7, 8);
221+
222+
// true, true, true, false, true
223+
Mask mask_1 = (arange<UInt32>(0, 5, 1) % 4) != 0;
224+
UInt32 index_1 = arange<UInt32>(0, 5, 1);
225+
Float buf_1 = gather<Float>(buf_0, index_1, mask_1);
226+
227+
Mask mask_2 = Mask(true, true, false, false);
228+
UInt32 index_2 = UInt32(0, 1, -1, -1);
229+
230+
// This gather will reindex, and should apply `mask_2` to the previous
231+
// gather, or else it will lookup invalid memory
232+
Float buf_2 = gather<Float>(buf_1, index_2, mask_2);
233+
jit_assert(strcmp(buf_2.str(), "[1, 2, 0, 0]") == 0);
234+
}

0 commit comments

Comments
 (0)