diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index f63fe17da51ff..c56cfec81acdd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6149,7 +6149,12 @@ SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
 
   if (ExtType == ISD::NON_EXTLOAD &&
       TLI.isOperationLegalOrCustom(ISD::VP_LOAD, WidenVT) &&
-      TLI.isTypeLegal(WideMaskVT)) {
+      TLI.isTypeLegal(WideMaskVT) &&
+      // If there is a passthru, we shouldn't use vp.load. However, the
+      // type legalizer struggles with masked.load on scalable vectors, so
+      // for scalable vectors we still use vp.load and manually merge the
+      // load result with the passthru using vp.select.
+      (N->getPassThru()->isUndef() || VT.isScalableVector())) {
     Mask = DAG.getInsertSubvector(dl, DAG.getUNDEF(WideMaskVT), Mask, 0);
     SDValue EVL = DAG.getElementCount(dl, TLI.getVPExplicitVectorLengthTy(),
                                       VT.getVectorElementCount());
@@ -6157,12 +6162,20 @@ SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
         DAG.getLoadVP(N->getAddressingMode(), ISD::NON_EXTLOAD, WidenVT, dl,
                       N->getChain(), N->getBasePtr(), N->getOffset(), Mask, EVL,
                       N->getMemoryVT(), N->getMemOperand());
+    SDValue NewVal = NewLoad;
+
+    // Manually merge the load result with the passthru using vp.select.
+    if (!N->getPassThru()->isUndef()) {
+      assert(WidenVT.isScalableVector());
+      NewVal =
+          DAG.getNode(ISD::VP_SELECT, dl, WidenVT, Mask, NewVal, PassThru, EVL);
+    }
 
     // Modified the chain - switch anything that used the old chain to use
     // the new one.
     ReplaceValueWith(SDValue(N, 1), NewLoad.getValue(1));
 
-    return NewLoad;
+    return NewVal;
   }
 
   // The mask should be widened as well
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
index 545c89495e621..ed60d91308495 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
@@ -341,3 +341,16 @@ define <7 x i8> @masked_load_v7i8(ptr %a, <7 x i1> %mask) {
   ret <7 x i8> %load
 }
 
+define <7 x i8> @masked_load_passthru_v7i8(ptr %a, <7 x i1> %mask) {
+; CHECK-LABEL: masked_load_passthru_v7i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a1, 127
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vmv.s.x v8, a1
+; CHECK-NEXT:    vmand.mm v0, v0, v8
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <7 x i8> @llvm.masked.load.v7i8(ptr %a, i32 8, <7 x i1> %mask, <7 x i8> zeroinitializer)
+  ret <7 x i8> %load
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/masked-load-int.ll b/llvm/test/CodeGen/RISCV/rvv/masked-load-int.ll
index d992669306fb1..75537406f3515 100644
--- a/llvm/test/CodeGen/RISCV/rvv/masked-load-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/masked-load-int.ll
@@ -21,7 +21,27 @@ define <vscale x 1 x i8> @masked_load_nxv1i8(ptr %a, <vscale x 1 x i1> %mask) no
   %load = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %a, i32 1, <vscale x 1 x i1> %mask, <vscale x 1 x i8> undef)
   ret <vscale x 1 x i8> %load
 }
-declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
+
+define <vscale x 1 x i8> @masked_load_passthru_nxv1i8(ptr %a, <vscale x 1 x i1> %mask) nounwind {
+; V-LABEL: masked_load_passthru_nxv1i8:
+; V:       # %bb.0:
+; V-NEXT:    vsetvli a1, zero, e8, mf8, ta, mu
+; V-NEXT:    vmv.v.i v8, 0
+; V-NEXT:    vle8.v v8, (a0), v0.t
+; V-NEXT:    ret
+;
+; ZVE32-LABEL: masked_load_passthru_nxv1i8:
+; ZVE32:       # %bb.0:
+; ZVE32-NEXT:    csrr a1, vlenb
+; ZVE32-NEXT:    srli a1, a1, 3
+; ZVE32-NEXT:    vsetvli a2, zero, e8, mf4, ta, ma
+; ZVE32-NEXT:    vmv.v.i v8, 0
+; ZVE32-NEXT:    vsetvli zero, a1, e8, mf4, ta, mu
+; ZVE32-NEXT:    vle8.v v8, (a0), v0.t
+; ZVE32-NEXT:    ret
+  %load = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %a, i32 1, <vscale x 1 x i1> %mask, <vscale x 1 x i8> zeroinitializer)
+  ret <vscale x 1 x i8> %load
+}
 
 define <vscale x 1 x i16> @masked_load_nxv1i16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
 ; V-LABEL: masked_load_nxv1i16:
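
For readers less familiar with the vp.* intrinsics, the new widening path is roughly equivalent to the IR sketch below. This is an illustration only: the actual transform operates on widened SelectionDAG nodes during type legalization, and the function name @passthru_lowering_sketch, the local value names, and the use of llvm.vscale to compute the EVL are illustrative assumptions rather than real legalizer output.

declare i32 @llvm.vscale.i32()
declare <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr, <vscale x 1 x i1>, i32)
declare <vscale x 1 x i8> @llvm.vp.select.nxv1i8(<vscale x 1 x i1>, <vscale x 1 x i8>, <vscale x 1 x i8>, i32)

; Sketch of how a masked.load with a non-undef passthru behaves on the new
; vp.load path (illustrative, not legalizer output).
define <vscale x 1 x i8> @passthru_lowering_sketch(ptr %a, <vscale x 1 x i1> %mask, <vscale x 1 x i8> %passthru) {
  ; The EVL covers the original element count (vscale x 1 elements here).
  %evl = call i32 @llvm.vscale.i32()
  ; Load under the mask with vp.load; masked-off lanes are not yet filled
  ; from the passthru.
  %v = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %a, <vscale x 1 x i1> %mask, i32 %evl)
  ; Merge the loaded lanes with the passthru under the same mask and EVL.
  %merged = call <vscale x 1 x i8> @llvm.vp.select.nxv1i8(<vscale x 1 x i1> %mask, <vscale x 1 x i8> %v, <vscale x 1 x i8> %passthru, i32 %evl)
  ret <vscale x 1 x i8> %merged
}

Per the condition added in WidenVecRes_MLOAD, fixed-length vectors with a passthru keep the existing behaviour and fall through to the regular widened masked.load; only scalable vectors take the vp.load + vp.select route described above.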