[MLIR][XeGPU] Add unroll patterns for scatter ops #143602
Conversation
@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

Changes: Add unrolling support for create_tdesc, load, store, prefetch, and update_offset.

Full diff: https://github.com/llvm/llvm-project/pull/143602.diff (3 files affected)
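For orientation, the effect of the new patterns on a scattered descriptor carrying an inst_data = [16] layout is roughly the following before/after, condensed from the tests added below (SSA names and the split offset values are illustrative, not taken verbatim from the diff):

  // Before: one 32-lane scattered descriptor with inst_data = [16]
  %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<32xindex>
      -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

  // After unrolling: two 16-lane descriptors; the pack/unpack helpers stitch the
  // original-typed value back together for existing users of %tdesc.
  %t0 = xegpu.create_tdesc %src, %off0 : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  %t1 = xegpu.create_tdesc %src, %off1 : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>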
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..900ade8c171d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
}
};
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+ using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+ VectorType indiceVecTy = indiceVec.getType();
+ SmallVector<Type> convertedIndiceTypes =
+ getUnrolledTypes(indiceVecTy, *targetShape);
+ SmallVector<Value> convertedIndiceVec =
+ pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+ op.getSource(), indice);
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+
+ return success();
+ }
+};
+
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+ using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ SmallVector<Value> convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+ auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+ using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ for (auto t : convertedTdesc)
+ rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+ using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy;
+ if (op.getMask())
+ maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> convertedMasks;
+ if (op.getMask()) {
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
+
+ for (size_t i = 0; i < convertedValues.size(); ++i) {
+ Value v = convertedValues[i];
+ Value t = convertedTdescs[i];
+ Value m = op.getMask() ? convertedMasks[i] : nullptr;
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+ using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+ VectorType offsetVecTy = offsetVec.getType();
+ SmallVector<Type> convertedOffsetTypes =
+ getUnrolledTypes(offsetVecTy, *targetShape);
+ SmallVector<Value> convertedOffsetVec =
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+ auto newOp =
+ rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
+ newOps.push_back(newOp);
+ }
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
- patterns.getContext(), options);
+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+ UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+ UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+ options);
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..47c54bfcb89d0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,147 @@ gpu.module @test {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}
+
+//-----
+
+ // CHECK-LABEL: test_create_tdesc_vec
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_create_tdesc_step
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_load
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ gpu.return %ld : vector<32xf32>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_prefetch
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_prefetch(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: test_store
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+ gpu.func @test_store(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %st_vec = arith.constant dense<1023.>: vector<32xf32>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: test_prefetch_load_store_update
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+ xegpu.store %st_vec, %tdesc, %mask:
+ vector<32xf32>,
+ !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,29 @@ struct TestXeGPUUnrollingPatterns
}
}
+ if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+ xegpu::TensorDescType tdescTy;
+ if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+ tdescTy = createOp.getType();
+ } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+ tdescTy = updateOp.getTensorDescType();
+ } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+ tdescTy = prefetchOp.getTensorDescType();
+ } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ tdescTy = loadOp.getTensorDescType();
+ } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+ tdescTy = storeOp.getTensorDescType();
+ }
+
+ if (auto layout = tdescTy.getLayoutAttr()) {
+ auto inst_data = layout.getInstData();
+ if (inst_data && layout.isSgLayout())
+ return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+ inst_data.asArrayRef().end());
+ }
+ }
+
if (isa<xegpu::DpasOp>(op))
return SmallVector<int64_t>{8, 16, 16};
@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

Changes: Add unrolling support for create_tdesc, load, store, prefetch, and update_offset.

Full diff: https://github.com/llvm/llvm-project/pull/143602.diff (3 files affected)
LGTM
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
Why is this function called unpack when it is doing N:1? Shouldn't it be pack?
My understanding is that pack reshapes [m, n] into [m/bm, n/bn, bm, bn], so it is 1 to N; unpack does the reverse, so it is N to 1.
It follows the pack/unpack definition in the tensor dialect.
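Shape-wise, a concrete instance of that convention for the 1-D scatter case in this PR (taken from UnrollCreateDescOp above; counts follow from targetShape = [16]):

  pack:   one vector<32xindex> + targetShape [16]            ->  two vector<16xindex> values            (1 -> N)
  unpack: two !xegpu.tensor_desc<16xf32, ...> + original type ->  one !xegpu.tensor_desc<32xf32, ...>    (N -> 1)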
Looks good % formatting nits
LGTM generally, left a nit comment
    VectorType maskTy;
    if (op.getMask())
      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
How about merging it into line 529?
Have we changed the mask to an optional operand by design? The definition so far hasn't revealed such a change yet.
done
Co-authored-by: Adam Siemieniuk <[email protected]>
correct indentation
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, *targetShape);
Here, the targetShape for indices should drop the last dim if chunkSize != 1.
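A hypothetical illustration of the case this refers to (the chunk_size, inst_data, and shapes here are illustrative and not part of this PR):

  // Chunked scattered descriptor: the offsets are 1-D even though the descriptor is 2-D.
  %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>, #xegpu.layout<inst_data = [8, 8]>>
  // The descriptor would unroll with targetShape [8, 8], but the vector<16xindex> offsets
  // only cover the first dimension, so their unrolled type should be computed with [8],
  // i.e. the trailing chunk dimension dropped.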
I will leave this to the next PR.
✅ With the latest revision this PR passed the C/C++ code formatter.
Add unrolling support for create_tdesc, load, store, prefetch, and update_offset.