Skip to content

Commit 8ecaaf2

Browse files
committed
Fix LLVM IR generation when loading Bool types in VCalls
Credits to @mcrescas for finding this
1 parent 69fc24b commit 8ecaaf2

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

src/llvm_eval.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ void jitc_llvm_assemble_func(const char *name, uint32_t inst_id,
344344
v, v, v, v);
345345

346346
uint32_t offset = it->second - data_offset;
347+
bool is_pointer_or_bool =
348+
(vt == VarType::Pointer) || (vt == VarType::Bool);
347349
// Expand $<..$> only when we are compiling a recursive function call
348350
callable_depth--;
349351
fmt( " $v_p1 = getelementptr inbounds i8, $<{i8*}$> %data, i32 $u\n"
@@ -353,12 +355,15 @@ void jitc_llvm_assemble_func(const char *name, uint32_t inst_id,
353355
v, offset,
354356
v, v,
355357
v, v, v,
356-
v, vt == VarType::Pointer ? "_p4" : "", v, v, v, v, v, v);
358+
v, is_pointer_or_bool ? "_p4" : "", v, v, v, v, v, v);
357359
callable_depth++;
358360

359361
if (vt == VarType::Pointer)
360362
fmt(" $v = inttoptr <$w x i64> $v_p4 to <$w x {i8*}>\n",
361363
v, v);
364+
else if (vt == VarType::Bool)
365+
fmt(" $v = trunc <$w x i8> $v_p4 to <$w x i1>\n",
366+
v, v);
362367
} else if (!v->is_stmt()) {
363368
jitc_llvm_render_var(sv.index, v);
364369
} else {

tests/vcall.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,3 +916,48 @@ TEST_BOTH(12_nested_with_side_effects) {
916916
jit_registry_trim();
917917
}
918918
}
919+
920+
TEST_BOTH(13_load_bool_data) {
921+
struct Base {
922+
virtual Float f() = 0;
923+
};
924+
using BasePtr = Array<Base *>;
925+
926+
struct F1 : Base {
927+
Mask cond = dr::opaque<Mask>(true);
928+
Float val1 = dr::opaque<Float>(1);
929+
Float val2 = dr::opaque<Float>(2);
930+
Float f() override {
931+
return select(cond, val1, val2);
932+
}
933+
};
934+
935+
struct F2 : Base {
936+
Mask cond = dr::opaque<Mask>(false);
937+
Float val1 = dr::opaque<Float>(3);
938+
Float val2 = dr::opaque<Float>(4);
939+
Float f() override {
940+
return select(cond, val1, val2);
941+
}
942+
};
943+
944+
BasePtr self = arange<UInt32>(5) % 3;
945+
for (uint32_t i = 0; i < 2; ++i) {
946+
jit_set_flag(JitFlag::VCallOptimize, i);
947+
948+
F1 f1; F2 f2;
949+
uint32_t i1 = jit_registry_put(Backend, "Base", &f1);
950+
uint32_t i2 = jit_registry_put(Backend, "Base", &f2);
951+
jit_assert(i1 == 1 && i2 == 2);
952+
953+
Float result = vcall(
954+
"Base", [](Base *self2) { return self2->f(); }, self);
955+
956+
jit_var_schedule(result.index());
957+
jit_assert(strcmp(result.str(), "[0, 1, 4, 0, 1]") == 0);
958+
959+
jit_registry_remove(Backend, &f1);
960+
jit_registry_remove(Backend, &f2);
961+
jit_registry_trim();
962+
}
963+
}

0 commit comments

Comments
 (0)