Skip to content

Commit 6551e79

Browse files
committed
Add support for b4_SSE2 batched mode (24)
Signed-off-by: Tuomas Tonteri <[email protected]>
1 parent f4d033c commit 6551e79

File tree

2 files changed: +52 additions, -19 deletions

src/liboslexec/llvm_util.cpp

+44 -16
Original file line number | Diff line number | Diff line change
@@ -3854,13 +3854,36 @@ LLVM_Util::mask_as_int8(llvm::Value* mask)
38543854
llvm::Value*
38553855
LLVM_Util::mask4_as_int8(llvm::Value* mask)
38563856
{
3857-
OSL_ASSERT(m_supports_llvm_bit_masks_natively);
3858-
// combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it
3859-
// to i8
3860-
llvm::Value* zero_mask4
3861-
= llvm::ConstantDataVector::getSplat(4, constant_bool(false));
3862-
return builder().CreateBitCast(op_combine_4x_vectors(mask, zero_mask4),
3863-
type_int8());
3857+
if (m_supports_llvm_bit_masks_natively) {
3858+
// combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it
3859+
// to i8
3860+
llvm::Value* zero_mask4
3861+
= llvm::ConstantDataVector::getSplat(4, constant_bool(false));
3862+
return builder().CreateBitCast(op_combine_4x_vectors(mask, zero_mask4),
3863+
type_int8());
3864+
} else {
3865+
// Convert <4 x i1> -> <4 x i32>
3866+
llvm::Value* wide_int_mask = builder().CreateSExt(mask,
3867+
type_wide_int());
3868+
3869+
// Now we will use the horizontal sign extraction intrinsic
3870+
// to build a 32 bit mask value. However the only 128bit
3871+
// version works on floats, so we will cast from int32 to
3872+
// float beforehand
3873+
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
3874+
llvm::Value* w4_float_mask = builder().CreateBitCast(wide_int_mask,
3875+
w4_float_type);
3876+
3877+
llvm::Function* func = llvm::Intrinsic::getDeclaration(
3878+
module(), llvm::Intrinsic::x86_sse_movmsk_ps);
3879+
3880+
llvm::Value* args[1] = { w4_float_mask };
3881+
llvm::Value* int32 = builder().CreateCall(func, toArrayRef(args));
3882+
3883+
llvm::Value* i8 = builder().CreateIntCast(int32, type_int8(), true);
3884+
3885+
return i8;
3886+
}
38643887
}
38653888

38663889

@@ -4013,17 +4036,22 @@ LLVM_Util::op_1st_active_lane_of(llvm::Value* mask)
40134036
intMaskType = type_int8();
40144037
break;
40154038
case 4: {
4016-
// We can just reinterpret cast a 4 bit mask to a 8 bit integer
4017-
// and all types are happy
40184039
intMaskType = type_int8();
40194040

4020-
// extended_int_vector_type = (llvm::Type *) llvm::VectorType::get(llvm::Type::getInt32Ty (*m_llvm_context), m_vector_width);
4021-
// llvm::Value * wide_int_mask = builder().CreateSExt(mask, extended_int_vector_type);
4022-
//
4023-
// int_reinterpret_cast_vector_type = (llvm::Type *) llvm::Type::getInt128Ty (*m_llvm_context);
4024-
// zeroConstant = constant128(0);
4025-
//
4026-
// llvm::Value * mask_as_int = builder().CreateBitCast (wide_int_mask, int_reinterpret_cast_vector_type);
4041+
llvm::Value* mask_as_int = mask4_as_int8(mask);
4042+
4043+
// Count trailing zeros, least significant
4044+
llvm::Type* types[] = { intMaskType };
4045+
llvm::Function* func_cttz
4046+
= llvm::Intrinsic::getDeclaration(module(), llvm::Intrinsic::cttz,
4047+
toArrayRef(types));
4048+
4049+
llvm::Value* args[2] = { mask_as_int, constant_bool(true) };
4050+
4051+
llvm::Value* firstNonZeroIndex = builder().CreateCall(func_cttz,
4052+
toArrayRef(args));
4053+
return firstNonZeroIndex;
4054+
40274055
break;
40284056
}
40294057
default: OSL_ASSERT(0 && "unsupported native bit mask width");

testsuite/example-batched-deformer/oslbatcheddeformer.cpp

+8 -3
Original file line number | Diff line number | Diff line change
@@ -237,11 +237,13 @@ main(int argc, char* argv[])
237237
batch_width = 16;
238238
} else if (shadsys->configure_batch_execution_at(8)) {
239239
batch_width = 8;
240+
} else if (shadsys->configure_batch_execution_at(4)) {
241+
batch_width = 4;
240242
} else {
241243
std::cout
242-
<< "Error: Hardware doesn't support 8 or 16 wide SIMD or the OSL has not been configured and built with a proper USE_BATCHED."
244+
<< "Error: Hardware doesn't support 4, 8 or 16 wide SIMD or the OSL has not been configured and built with a proper USE_BATCHED."
243245
<< std::endl;
244-
std::cout << "Error: e.g.: USE_BATCHED=b8_AVX2,b8_AVX512,b16_AVX512"
246+
std::cout << "Error: e.g.: USE_BATCHED=b4_SSE2,b8_AVX2,b8_AVX512,b16_AVX512"
245247
<< std::endl;
246248
return -1;
247249
}
@@ -437,8 +439,11 @@ main(int argc, char* argv[])
437439

438440
if (batch_width == 16) {
439441
batched_shadepoints(std::integral_constant<int, 16> {});
440-
} else {
442+
}
443+
else if (batch_width == 8) {
441444
batched_shadepoints(std::integral_constant<int, 8> {});
445+
} else {
446+
batched_shadepoints(std::integral_constant<int, 4> {});
442447
}
443448

444449
// Print some results to prove that we generated an expected Pout.

0 commit comments

Comments
 (0)