Skip to content

Commit 4ed4b25

Browse files
committed
[mlir] Initial patch to add an MPI dialect
1 parent 30240e4 commit 4ed4b25

File tree

14 files changed

+302
-0
lines changed

14 files changed

+302
-0
lines changed

mlir/include/mlir/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_subdirectory(Math)
2121
add_subdirectory(MemRef)
2222
add_subdirectory(Mesh)
2323
add_subdirectory(MLProgram)
24+
add_subdirectory(MPI)
2425
add_subdirectory(NVGPU)
2526
add_subdirectory(OpenACC)
2627
add_subdirectory(OpenACCMPCommon)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS MPIOps.td)
2+
add_mlir_dialect(MPIOps mpi)
3+
add_mlir_doc(MPIOps MPIOps Dialects/ -gen-dialect-doc)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MPI.h - MPI dialect ----------------------------*- C++-*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_DIALECT_MPI_IR_MPI_H_
9+
#define MLIR_DIALECT_MPI_IR_MPI_H_
10+
11+
#include "mlir/IR/Dialect.h"
12+
#include "mlir/IR/OpDefinition.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// MPIDialect
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "mlir/Dialect/MPI/IR/MPIOpsDialect.h.inc"
20+
21+
//===----------------------------------------------------------------------===//
22+
// MPI Dialect Operations
23+
//===----------------------------------------------------------------------===//
24+
25+
#define GET_OP_CLASSES
26+
#include "mlir/Dialect/MPI/IR/MPIOps.h.inc"
27+
28+
#endif // MLIR_DIALECT_MPI_IR_MPI_H_
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- MPIBase.td - Base defs for mpi dialect --*- tablegen -*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MPI_BASE
10+
#define MPI_BASE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def MPI_Dialect : Dialect {
15+
let name = "mpi";
16+
let cppNamespace = "::mlir::mpi";
17+
let description = [{
18+
This dialect models the Message Passing Interface (MPI), version 4.0. It is
19+
meant to serve as an interfacing dialect that is targeted by higher-level dialects.
20+
The MPI dialect itself can be lowered to multiple MPI implementations and hide
21+
differences in ABI. The dialect models the functions of the MPI specification as
22+
close to 1:1 as possible while preserving SSA value semantics where it makes sense,
23+
and uses `memref` types instead of bare pointers.
24+
25+
This dialect is under active development, and while stability is an
26+
eventual goal, it is not guaranteed at this juncture. Given the early state,
27+
it is recommended to inquire further prior to using this dialect.
28+
29+
For an in-depth documentation of the MPI library interface, please refer to official documentation
30+
such as the [OpenMPI online documentation](https://www.open-mpi.org/doc/current/).
31+
}];
32+
33+
let usePropertiesForAttributes = 1;
34+
}
35+
36+
#endif // MPI_BASE
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
//===- MPI.td - Message Passing Interface Ops ---------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MPI_OPS
10+
#define MPI_OPS
11+
12+
include "mlir/Dialect/MPI/IR/MPIBase.td"
13+
14+
class MPI_Op<string mnemonic, list<Trait> traits = []>
15+
: Op<MPI_Dialect, mnemonic, traits>;
16+
17+
//===----------------------------------------------------------------------===//
18+
// InitOp
19+
//===----------------------------------------------------------------------===//
20+
21+
def MPI_InitOp : MPI_Op<"init", [
22+
23+
]> {
24+
let summary =
25+
"Initialize the MPI library, equivalent to `MPI_Init(NULL, NULL)`";
26+
let description = [{
27+
This operation must preceed most MPI calls (except for very few exceptions,
28+
please consult with the MPI specification on these).
29+
30+
Passing &argc, &argv is not supported currently.
31+
Inspecting the functions return value (error code) is also not supported.
32+
}];
33+
34+
let assemblyFormat = "attr-dict";
35+
}
36+
37+
//===----------------------------------------------------------------------===//
38+
// CommRankOp
39+
//===----------------------------------------------------------------------===//
40+
41+
def MPI_CommRankOp : MPI_Op<"comm_rank", [
42+
43+
]> {
44+
let summary = "Get the current rank, equivalent to "
45+
"`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
46+
let description = [{
47+
Communicators other than `MPI_COMM_WORLD` are not supprted for now.
48+
Inspecting the functions return value (error code) is also not supported.
49+
}];
50+
51+
let results = (outs I32 : $result);
52+
53+
let assemblyFormat = "attr-dict `:` type($result)";
54+
}
55+
56+
//===----------------------------------------------------------------------===//
57+
// SendOp
58+
//===----------------------------------------------------------------------===//
59+
60+
def MPI_SendOp : MPI_Op<"send", [
61+
62+
]> {
63+
let summary =
64+
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
65+
let description = [{
66+
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank `dest`.
67+
The `tag` value and communicator enables the library to determine the matching of
68+
multiple sends and receives between the same ranks.
69+
70+
Communicators other than `MPI_COMM_WORLD` are not supprted for now.
71+
Inspecting the functions return value (error code) is also not supported.
72+
}];
73+
74+
let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
75+
76+
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
77+
"type($ref) `,` type($tag) `,` type($rank)";
78+
}
79+
80+
//===----------------------------------------------------------------------===//
81+
// RecvOp
82+
//===----------------------------------------------------------------------===//
83+
84+
def MPI_RecvOp : MPI_Op<"recv", [
85+
86+
]> {
87+
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
88+
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
89+
let description = [{
90+
MPI_Recv performs a blocking receive of `size` elements of type `dtype` from rank `dest`.
91+
The `tag` value and communicator enables the library to determine the matching of
92+
multiple sends and receives between the same ranks.
93+
94+
Communicators other than `MPI_COMM_WORLD` are not supprted for now.
95+
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR.
96+
Inspecting the functions return value (error code) is also not supported.
97+
}];
98+
99+
let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
100+
101+
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
102+
"type($ref) `,` type($tag) `,` type($rank)";
103+
}
104+
105+
//===----------------------------------------------------------------------===//
106+
// FinalizeOp
107+
//===----------------------------------------------------------------------===//
108+
109+
def MPI_FinalizeOp : MPI_Op<"finalize", [
110+
111+
]> {
112+
let summary = "Finalize the MPI library, equivalent to `MPI_Finalize()`";
113+
let description = [{
114+
This function cleans up the MPI state. Afterwards, no MPI methods may be invoked
115+
(excpet for MPI_Get_version, MPI_Initialized, and MPI_Finalized).
116+
Notably, MPI_Init cannot be called again in the same program.
117+
118+
Inspecting the functions return value (error code) is not supported.
119+
}];
120+
121+
let assemblyFormat = "attr-dict";
122+
}
123+
124+
#endif // MPI_OPS

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
4949
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
5050
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
51+
#include "mlir/Dialect/MPI/IR/MPI.h"
5152
#include "mlir/Dialect/Math/IR/Math.h"
5253
#include "mlir/Dialect/MemRef/IR/MemRef.h"
5354
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
@@ -120,6 +121,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
120121
memref::MemRefDialect,
121122
mesh::MeshDialect,
122123
ml_program::MLProgramDialect,
124+
mpi::MPIDialect,
123125
nvgpu::NVGPUDialect,
124126
NVVM::NVVMDialect,
125127
omp::OpenMPDialect,

mlir/lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_subdirectory(Math)
2121
add_subdirectory(MemRef)
2222
add_subdirectory(Mesh)
2323
add_subdirectory(MLProgram)
24+
add_subdirectory(MPI)
2425
add_subdirectory(NVGPU)
2526
add_subdirectory(OpenACC)
2627
add_subdirectory(OpenACCMPCommon)

mlir/lib/Dialect/MPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_dialect_library(MLIRMPIDialect
2+
MPIOps.cpp
3+
MPIDialect.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MPI
7+
8+
DEPENDS
9+
MLIRMPIOpsIncGen
10+
11+
LINK_LIBS PUBLIC
12+
MLIRDialect
13+
MLIRIR
14+
MLIRInferTypeOpInterface
15+
MLIRSideEffectInterfaces
16+
)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MPIDialect.cpp - MPI dialect implementation ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MPI/IR/MPI.h"
10+
#include "mlir/IR/DialectImplementation.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::mpi;
14+
15+
//===----------------------------------------------------------------------===//
16+
/// Tablegen Definitions
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "mlir/Dialect/MPI/IR/MPIOpsDialect.cpp.inc"
20+
21+
void mpi::MPIDialect::initialize() {
22+
23+
addOperations<
24+
#define GET_OP_LIST
25+
#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
26+
>();
27+
}

mlir/lib/Dialect/MPI/IR/MPIOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===- MPIOps.cpp - MPI dialect ops implementation ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MPI/IR/MPI.h"
10+
11+
using namespace mlir;
12+
using namespace mlir::mpi;
13+
14+
//===----------------------------------------------------------------------===//
15+
// TableGen'd op method definitions
16+
//===----------------------------------------------------------------------===//
17+
18+
#define GET_OP_CLASSES
19+
#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"

mlir/test/Dialect/MPI/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
2+
3+
// expected-error @+1 {{op result #0 must be 32-bit signless integer, but got 'i64'}}
4+
%rank = mpi.comm_rank : i64
5+
6+
// -----
7+
8+
func.func @mpi_test(%ref : !llvm.ptr, %rank: i32) -> () {
9+
// expected-error @+1 {{invalid kind of type specified}}
10+
mpi.send(%ref, %rank, %rank) : !llvm.ptr, i32, i32
11+
12+
return
13+
}
14+
15+
// -----
16+
17+
func.func @mpi_test(%ref : !llvm.ptr, %rank: i32) -> () {
18+
// expected-error @+1 {{invalid kind of type specified}}
19+
mpi.recv(%ref, %rank, %rank) : !llvm.ptr, i32, i32
20+
21+
return
22+
}

mlir/test/Dialect/MPI/ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
3+
4+
func.func @mpi_test(%ref : memref<100xf32>) -> () {
5+
// CHECK: mpi.init
6+
mpi.init
7+
8+
// CHECK-NEXT: mpi.comm_rank : i32
9+
%rank = mpi.comm_rank : i32
10+
11+
// CHECK-NEXT: mpi.send(%arg0, %0, %0) : memref<100xf32>, i32, i32
12+
mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
13+
14+
// CHECK-NEXT: mpi.recv(%arg0, %0, %0) : memref<100xf32>, i32, i32
15+
mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
16+
17+
// CHECK-NEXT: mpi.finalize
18+
mpi.finalize
19+
20+
func.return
21+
}

0 commit comments

Comments
 (0)