Skip to content

Commit d7e64d9

Browse files
authored
[MSAN] handle assorted AVX permutations (#143462)
1 parent ca38027 commit d7e64d9

File tree

7 files changed

+1164
-652
lines changed

7 files changed

+1164
-652
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4173,7 +4173,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
41734173

41744174
// Instrument AVX permutation intrinsic.
41754175
// We apply the same permutation (argument index 1) to the shadow of the
// permuted operand (argument index 0).
4176-
void handleAVXVpermilvar(IntrinsicInst &I) {
4176+
void handleAVXPermutation(IntrinsicInst &I) {
4177+
assert(I.arg_size() == 2);
4178+
assert(isa<FixedVectorType>(I.getArgOperand(0)->getType()));
4179+
assert(isa<FixedVectorType>(I.getArgOperand(1)->getType()));
4180+
[[maybe_unused]] auto ArgVectorSize =
4181+
cast<FixedVectorType>(I.getArgOperand(0)->getType())->getNumElements();
4182+
assert(cast<FixedVectorType>(I.getArgOperand(1)->getType())
4183+
->getNumElements() == ArgVectorSize);
4184+
assert(I.getType() == I.getArgOperand(0)->getType());
41774185
IRBuilder<> IRB(&I);
41784186
Value *Shadow = getShadow(&I, 0);
41794187
insertShadowCheck(I.getArgOperand(1), &I);
@@ -4187,6 +4195,38 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
41874195
setShadow(&I, IRB.CreateBitCast(CI, getShadowTy(&I)));
41884196
setOriginForNaryOp(I);
41894197
}
4198+
// Instrument AVX permutation intrinsic.
4199+
// We apply the same permutation (argument index 1) to the shadows.
4200+
void handleAVXVpermil2var(IntrinsicInst &I) {
4201+
assert(I.arg_size() == 3);
4202+
assert(isa<FixedVectorType>(I.getArgOperand(0)->getType()));
4203+
assert(isa<FixedVectorType>(I.getArgOperand(1)->getType()));
4204+
assert(isa<FixedVectorType>(I.getArgOperand(2)->getType()));
4205+
[[maybe_unused]] auto ArgVectorSize =
4206+
cast<FixedVectorType>(I.getArgOperand(0)->getType())->getNumElements();
4207+
assert(cast<FixedVectorType>(I.getArgOperand(1)->getType())
4208+
->getNumElements() == ArgVectorSize);
4209+
assert(cast<FixedVectorType>(I.getArgOperand(2)->getType())
4210+
->getNumElements() == ArgVectorSize);
4211+
assert(I.getArgOperand(0)->getType() == I.getArgOperand(2)->getType());
4212+
assert(I.getType() == I.getArgOperand(0)->getType());
4213+
assert(I.getArgOperand(1)->getType()->isIntOrIntVectorTy());
4214+
IRBuilder<> IRB(&I);
4215+
Value *AShadow = getShadow(&I, 0);
4216+
Value *Idx = I.getArgOperand(1);
4217+
Value *BShadow = getShadow(&I, 2);
4218+
insertShadowCheck(Idx, &I);
4219+
4220+
// Shadows are integer-ish types but some intrinsics require a
4221+
// different (e.g., floating-point) type.
4222+
AShadow = IRB.CreateBitCast(AShadow, I.getArgOperand(0)->getType());
4223+
BShadow = IRB.CreateBitCast(BShadow, I.getArgOperand(2)->getType());
4224+
CallInst *CI = IRB.CreateIntrinsic(I.getType(), I.getIntrinsicID(),
4225+
{AShadow, Idx, BShadow});
4226+
4227+
setShadow(&I, IRB.CreateBitCast(CI, getShadowTy(&I)));
4228+
setOriginForNaryOp(I);
4229+
}
41904230

41914231
// Instrument BMI / BMI2 intrinsics.
41924232
// All of these intrinsics are Z = I(X, Y)
@@ -5132,16 +5172,52 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
51325172
assert(Success);
51335173
break;
51345174
}
5135-
5175+
case Intrinsic::x86_avx2_permd:
5176+
case Intrinsic::x86_avx2_permps:
5177+
case Intrinsic::x86_ssse3_pshuf_b_128:
5178+
case Intrinsic::x86_avx2_pshuf_b:
5179+
case Intrinsic::x86_avx512_pshuf_b_512:
5180+
case Intrinsic::x86_avx512_permvar_df_256:
5181+
case Intrinsic::x86_avx512_permvar_df_512:
5182+
case Intrinsic::x86_avx512_permvar_di_256:
5183+
case Intrinsic::x86_avx512_permvar_di_512:
5184+
case Intrinsic::x86_avx512_permvar_hi_128:
5185+
case Intrinsic::x86_avx512_permvar_hi_256:
5186+
case Intrinsic::x86_avx512_permvar_hi_512:
5187+
case Intrinsic::x86_avx512_permvar_qi_128:
5188+
case Intrinsic::x86_avx512_permvar_qi_256:
5189+
case Intrinsic::x86_avx512_permvar_qi_512:
5190+
case Intrinsic::x86_avx512_permvar_sf_512:
5191+
case Intrinsic::x86_avx512_permvar_si_512:
51365192
case Intrinsic::x86_avx_vpermilvar_pd:
51375193
case Intrinsic::x86_avx_vpermilvar_pd_256:
51385194
case Intrinsic::x86_avx512_vpermilvar_pd_512:
51395195
case Intrinsic::x86_avx_vpermilvar_ps:
51405196
case Intrinsic::x86_avx_vpermilvar_ps_256:
51415197
case Intrinsic::x86_avx512_vpermilvar_ps_512: {
5142-
handleAVXVpermilvar(I);
5198+
handleAVXPermutation(I);
51435199
break;
51445200
}
5201+
case Intrinsic::x86_avx512_vpermi2var_d_128:
5202+
case Intrinsic::x86_avx512_vpermi2var_d_256:
5203+
case Intrinsic::x86_avx512_vpermi2var_d_512:
5204+
case Intrinsic::x86_avx512_vpermi2var_hi_128:
5205+
case Intrinsic::x86_avx512_vpermi2var_hi_256:
5206+
case Intrinsic::x86_avx512_vpermi2var_hi_512:
5207+
case Intrinsic::x86_avx512_vpermi2var_pd_128:
5208+
case Intrinsic::x86_avx512_vpermi2var_pd_256:
5209+
case Intrinsic::x86_avx512_vpermi2var_pd_512:
5210+
case Intrinsic::x86_avx512_vpermi2var_ps_128:
5211+
case Intrinsic::x86_avx512_vpermi2var_ps_256:
5212+
case Intrinsic::x86_avx512_vpermi2var_ps_512:
5213+
case Intrinsic::x86_avx512_vpermi2var_q_128:
5214+
case Intrinsic::x86_avx512_vpermi2var_q_256:
5215+
case Intrinsic::x86_avx512_vpermi2var_q_512:
5216+
case Intrinsic::x86_avx512_vpermi2var_qi_128:
5217+
case Intrinsic::x86_avx512_vpermi2var_qi_256:
5218+
case Intrinsic::x86_avx512_vpermi2var_qi_512:
5219+
handleAVXVpermil2var(I);
5220+
break;
51455221

51465222
case Intrinsic::x86_avx512fp16_mask_add_sh_round:
51475223
case Intrinsic::x86_avx512fp16_mask_sub_sh_round:

llvm/test/Instrumentation/MemorySanitizer/X86/avx2-intrinsics-x86.ll

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -740,8 +740,15 @@ define <32 x i8> @test_x86_avx2_pshuf_b(<32 x i8> %a0, <32 x i8> %a1) #0 {
740740
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i8>, ptr @__msan_param_tls, align 8
741741
; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i8>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
742742
; CHECK-NEXT: call void @llvm.donothing()
743-
; CHECK-NEXT: [[_MSPROP:%.*]] = or <32 x i8> [[TMP1]], [[TMP2]]
744-
; CHECK-NEXT: [[RES:%.*]] = call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> [[A0:%.*]], <32 x i8> [[A1:%.*]])
743+
; CHECK-NEXT: [[_MSPROP:%.*]] = call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> [[TMP1]], <32 x i8> [[A1:%.*]])
744+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <32 x i8> [[TMP2]] to i256
745+
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP4]], 0
746+
; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF1]]
747+
; CHECK: 5:
748+
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR6]]
749+
; CHECK-NEXT: unreachable
750+
; CHECK: 6:
751+
; CHECK-NEXT: [[RES:%.*]] = call <32 x i8> @llvm.x86.avx2.pshuf.b(<32 x i8> [[A0:%.*]], <32 x i8> [[A1]])
745752
; CHECK-NEXT: store <32 x i8> [[_MSPROP]], ptr @__msan_retval_tls, align 8
746753
; CHECK-NEXT: ret <32 x i8> [[RES]]
747754
;
@@ -969,8 +976,15 @@ define <8 x i32> @test_x86_avx2_permd(<8 x i32> %a0, <8 x i32> %a1) #0 {
969976
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
970977
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
971978
; CHECK-NEXT: call void @llvm.donothing()
972-
; CHECK-NEXT: [[_MSPROP:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
973-
; CHECK-NEXT: [[RES:%.*]] = call <8 x i32> @llvm.x86.avx2.permd(<8 x i32> [[A0:%.*]], <8 x i32> [[A1:%.*]])
979+
; CHECK-NEXT: [[_MSPROP:%.*]] = call <8 x i32> @llvm.x86.avx2.permd(<8 x i32> [[TMP1]], <8 x i32> [[A1:%.*]])
980+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP2]] to i256
981+
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP4]], 0
982+
; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF1]]
983+
; CHECK: 5:
984+
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR6]]
985+
; CHECK-NEXT: unreachable
986+
; CHECK: 6:
987+
; CHECK-NEXT: [[RES:%.*]] = call <8 x i32> @llvm.x86.avx2.permd(<8 x i32> [[A0:%.*]], <8 x i32> [[A1]])
974988
; CHECK-NEXT: store <8 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
975989
; CHECK-NEXT: ret <8 x i32> [[RES]]
976990
;
@@ -985,18 +999,18 @@ define <8 x float> @test_x86_avx2_permps(<8 x float> %a0, <8 x i32> %a1) #0 {
985999
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
9861000
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
9871001
; CHECK-NEXT: call void @llvm.donothing()
988-
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i32> [[TMP1]] to i256
989-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP3]], 0
1002+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i32> [[TMP1]] to <8 x float>
1003+
; CHECK-NEXT: [[TMP6:%.*]] = call <8 x float> @llvm.x86.avx2.permps(<8 x float> [[TMP3]], <8 x i32> [[A1:%.*]])
1004+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <8 x float> [[TMP6]] to <8 x i32>
9901005
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP2]] to i256
9911006
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i256 [[TMP4]], 0
992-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
993-
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF1]]
994-
; CHECK: 5:
1007+
; CHECK-NEXT: br i1 [[_MSCMP1]], label [[TMP7:%.*]], label [[TMP8:%.*]], !prof [[PROF1]]
1008+
; CHECK: 7:
9951009
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR6]]
9961010
; CHECK-NEXT: unreachable
997-
; CHECK: 6:
998-
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx2.permps(<8 x float> [[A0:%.*]], <8 x i32> [[A1:%.*]])
999-
; CHECK-NEXT: store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
1011+
; CHECK: 8:
1012+
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx2.permps(<8 x float> [[A0:%.*]], <8 x i32> [[A1]])
1013+
; CHECK-NEXT: store <8 x i32> [[TMP5]], ptr @__msan_retval_tls, align 8
10001014
; CHECK-NEXT: ret <8 x float> [[RES]]
10011015
;
10021016
%res = call <8 x float> @llvm.x86.avx2.permps(<8 x float> %a0, <8 x i32> %a1) ; <<8 x float>> [#uses=1]

0 commit comments

Comments
 (0)