@@ -3854,13 +3854,36 @@ LLVM_Util::mask_as_int8(llvm::Value* mask)
3854
3854
llvm::Value*
3855
3855
LLVM_Util::mask4_as_int8 (llvm::Value* mask)
3856
3856
{
3857
- OSL_ASSERT (m_supports_llvm_bit_masks_natively);
3858
- // combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it
3859
- // to i8
3860
- llvm::Value* zero_mask4
3861
- = llvm::ConstantDataVector::getSplat (4 , constant_bool (false ));
3862
- return builder ().CreateBitCast (op_combine_4x_vectors (mask, zero_mask4),
3863
- type_int8 ());
3857
+ if (m_supports_llvm_bit_masks_natively) {
3858
+ // combine <4xi1> mask with <4xi1> zero init to get <8xi1> and cast it
3859
+ // to i8
3860
+ llvm::Value* zero_mask4
3861
+ = llvm::ConstantDataVector::getSplat (4 , constant_bool (false ));
3862
+ return builder ().CreateBitCast (op_combine_4x_vectors (mask, zero_mask4),
3863
+ type_int8 ());
3864
+ } else {
3865
+ // Convert <4 x i1> -> <4 x i32>
3866
+ llvm::Value* wide_int_mask = builder ().CreateSExt (mask,
3867
+ type_wide_int ());
3868
+
3869
+ // Now we will use the horizontal sign extraction intrinsic
3870
+ // to build a 32 bit mask value. However the only 128bit
3871
+ // version works on floats, so we will cast from int32 to
3872
+ // float beforehand
3873
+ llvm::Type* w4_float_type = llvm_vector_type (m_llvm_type_float, 4 );
3874
+ llvm::Value* w4_float_mask = builder ().CreateBitCast (wide_int_mask,
3875
+ w4_float_type);
3876
+
3877
+ llvm::Function* func = llvm::Intrinsic::getDeclaration (
3878
+ module (), llvm::Intrinsic::x86_sse_movmsk_ps);
3879
+
3880
+ llvm::Value* args[1 ] = { w4_float_mask };
3881
+ llvm::Value* int32 = builder ().CreateCall (func, toArrayRef (args));
3882
+
3883
+ llvm::Value* i8 = builder ().CreateIntCast (int32, type_int8 (), true );
3884
+
3885
+ return i8;
3886
+ }
3864
3887
}
3865
3888
3866
3889
@@ -4013,17 +4036,22 @@ LLVM_Util::op_1st_active_lane_of(llvm::Value* mask)
4013
4036
intMaskType = type_int8 ();
4014
4037
break ;
4015
4038
case 4 : {
4016
- // We can just reinterpret cast a 4 bit mask to a 8 bit integer
4017
- // and all types are happy
4018
4039
intMaskType = type_int8 ();
4019
4040
4020
- // extended_int_vector_type = (llvm::Type *) llvm::VectorType::get(llvm::Type::getInt32Ty (*m_llvm_context), m_vector_width);
4021
- // llvm::Value * wide_int_mask = builder().CreateSExt(mask, extended_int_vector_type);
4022
- //
4023
- // int_reinterpret_cast_vector_type = (llvm::Type *) llvm::Type::getInt128Ty (*m_llvm_context);
4024
- // zeroConstant = constant128(0);
4025
- //
4026
- // llvm::Value * mask_as_int = builder().CreateBitCast (wide_int_mask, int_reinterpret_cast_vector_type);
4041
+ llvm::Value* mask_as_int = mask4_as_int8 (mask);
4042
+
4043
+ // Count trailing zeros, least significant
4044
+ llvm::Type* types[] = { intMaskType };
4045
+ llvm::Function* func_cttz
4046
+ = llvm::Intrinsic::getDeclaration (module (), llvm::Intrinsic::cttz,
4047
+ toArrayRef (types));
4048
+
4049
+ llvm::Value* args[2 ] = { mask_as_int, constant_bool (true ) };
4050
+
4051
+ llvm::Value* firstNonZeroIndex = builder ().CreateCall (func_cttz,
4052
+ toArrayRef (args));
4053
+ return firstNonZeroIndex;
4054
+
4027
4055
break ;
4028
4056
}
4029
4057
default : OSL_ASSERT (0 && " unsupported native bit mask width" );
0 commit comments