@@ -305,16 +305,25 @@ class InsertGPUAllocsPass final
305
305
filter.insert (copy);
306
306
}
307
307
308
- if (allocType != memrefType)
309
- allocResult = builder.create <mlir::memref::CastOp>(loc, memrefType,
310
- allocResult);
311
-
312
- op.replaceAllUsesExcept (allocResult, filter);
313
- builder.setInsertionPoint (term);
314
- if (access .hostRead && access .deviceWrite ) {
315
- builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
308
+ if (allocType != memrefType) {
309
+ mlir::Value castedAllocResult = builder.create <mlir::memref::CastOp>(
310
+ loc, memrefType, allocResult);
311
+
312
+ op.replaceAllUsesExcept (castedAllocResult, filter);
313
+ builder.setInsertionPoint (term);
314
+ if (access .hostRead && access .deviceWrite ) {
315
+ builder.create <mlir::memref::CopyOp>(loc, castedAllocResult, op);
316
+ }
317
+ builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt,
318
+ castedAllocResult);
319
+ } else {
320
+ op.replaceAllUsesExcept (allocResult, filter);
321
+ builder.setInsertionPoint (term);
322
+ if (access .hostRead && access .deviceWrite ) {
323
+ builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
324
+ }
325
+ builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt, allocResult);
316
326
}
317
- builder.create <mlir::gpu::DeallocOp>(loc, std::nullopt, allocResult);
318
327
} else if (m_clientAPI == " vulkan" ) {
319
328
auto gpuAlloc =
320
329
builder.create <mlir::memref::AllocOp>(loc, allocType, dims);
@@ -325,14 +334,21 @@ class InsertGPUAllocsPass final
325
334
filter.insert (copy);
326
335
}
327
336
328
- if (allocType != memrefType)
329
- allocResult = builder.create <mlir::memref::CastOp>(loc, memrefType,
330
- allocResult);
337
+ if (allocType != memrefType) {
338
+ mlir::Value castedAllocResult = builder.create <mlir::memref::CastOp>(
339
+ loc, memrefType, allocResult);
331
340
332
- op.replaceAllUsesExcept (allocResult, filter);
333
- builder.setInsertionPoint (term);
334
- if (access .hostRead && access .deviceWrite ) {
335
- builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
341
+ op.replaceAllUsesExcept (castedAllocResult, filter);
342
+ builder.setInsertionPoint (term);
343
+ if (access .hostRead && access .deviceWrite ) {
344
+ builder.create <mlir::memref::CopyOp>(loc, castedAllocResult, op);
345
+ }
346
+ } else {
347
+ op.replaceAllUsesExcept (allocResult, filter);
348
+ builder.setInsertionPoint (term);
349
+ if (access .hostRead && access .deviceWrite ) {
350
+ builder.create <mlir::memref::CopyOp>(loc, allocResult, op);
351
+ }
336
352
}
337
353
}
338
354
};
0 commit comments