
Commit 0ccbcb6

Add hl.wait & AllGather Matmul example (via hl_ext helper).
stack-info: PR: #189, branch: joydddd/stack/5
1 parent 206adc7 commit 0ccbcb6

File tree

7 files changed: +569 -0 lines changed


examples/all_gather_matmul.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


def copy_engine_all_gather_w_progress(
    output: torch.Tensor,
    inp: torch.Tensor,  # Must be a symmetric tensor
    progress: torch.Tensor,
    splits_per_rank: int,
    backend_stream: torch.cuda.Stream | None = None,
) -> torch.cuda.Stream:
    backend_stream = symm_mem._get_backend_stream(priority=-1)
    assert inp.is_contiguous()
    symm_mem_group = dist.group.WORLD
    if symm_mem_group is None:
        raise RuntimeError("No symmetric memory group available")
    symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group)
    assert symm_mem_hdl is not None

    rank = symm_mem_hdl.rank
    world_size = symm_mem_hdl.world_size

    assert inp.numel() % splits_per_rank == 0
    assert progress.numel() >= world_size * splits_per_rank

    output_shape = list(inp.shape)
    output_shape[0] *= world_size
    assert list(output.shape) == output_shape, (list(output.shape), output_shape)

    chunks = output.chunk(world_size * splits_per_rank)

    symm_mem_hdl.barrier()
    backend_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(backend_stream):
        for step in range(world_size):
            src_rank = (rank + step + 1) % world_size
            for split_id in range(splits_per_rank):
                src_buf = symm_mem_hdl.get_buffer(
                    src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
                )
                chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
                # cuStreamWriteValue32 issues a system-level fence before the write
                symm_mem_hdl.stream_write_value32(
                    progress,
                    offset=src_rank * splits_per_rank + split_id,
                    val=1,
                )
        symm_mem_hdl.barrier()

    return backend_stream


@helion.jit(
    config=helion.Config(
        block_sizes=[128, 256, 64],
        num_warps=8,
        num_stages=3,
        indexing="block_ptr",
    ),
    static_shapes=True,
)
def helion_matmul_w_progress(
    a: torch.Tensor,
    a_shared: torch.Tensor,
    b: torch.Tensor,
    progress: torch.Tensor,
    SPLITS_PER_RANK: int,
    RANK: int,
) -> torch.Tensor:
    M, K = a.size()
    K2, N = b.size()
    assert K2 == K, f"size mismatch {K2} != {K}"

    out = torch.empty(
        [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
    )

    M_per_rank = a_shared.size(0)

    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        hl.wait(
            progress,
            [
                tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
            ],
            signal=1,
            update=None,
            op="ld",
            scope="gpu",
            sem="acquire",
        )
        for tile_k in hl.tile(K):
            # TODO(joydddd): use a_shared and skip the barrier when data is available on the local rank.
            # if tile_k.begin // M_per_rank == RANK:
            #     acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
            # else:
            #     hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
            acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
        out[tile_m, tile_n] = acc
    return out


def helion_all_gather_matmul(
    a_shared: torch.Tensor,
    b: torch.Tensor,
    a_out: torch.Tensor | None = None,
    progress: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
    configs = {
        "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
        "BLOCK_SIZE_M": kwargs.get("block_size_m", 128),
        "BLOCK_SIZE_N": kwargs.get("block_size_n", 256),
        "BLOCK_SIZE_K": kwargs.get("block_size_k", 64),
        "GROUP_SIZE_M": kwargs.get("group_size_m", 4),
        "num_stages": kwargs.get("num_stages", 3),
        "num_warps": kwargs.get("num_warps", 8),
    }

    symm_mem_group = dist.group.WORLD
    if symm_mem_group is None:
        raise RuntimeError("No symmetric memory group available")

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group)

    a_shape = list(a_shared.shape)
    a_shape[0] *= symm_mem_hdl.world_size

    configs["RANK"] = symm_mem_hdl.rank
    configs["WORLD_SIZE"] = symm_mem_hdl.world_size
    if (
        configs["SPLITS_PER_RANK"]
        * configs["WORLD_SIZE"]
        * configs["BLOCK_SIZE_M"]
        * configs["GROUP_SIZE_M"]
        > a_shape[0]
    ):
        configs["GROUP_SIZE_M"] = 1
        configs["SPLITS_PER_RANK"] = 1

    configs["COMM_BLOCK_SIZE_M"] = (
        a_shape[0] // configs["WORLD_SIZE"] // configs["SPLITS_PER_RANK"]
    )
    assert (
        configs["COMM_BLOCK_SIZE_M"]
        % (configs["BLOCK_SIZE_M"] * configs["GROUP_SIZE_M"])
        == 0
    )

    if a_out is None:
        a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)

    if progress is None:
        progress = torch.zeros(
            symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
            dtype=torch.uint32,
            device=a_shared.device,
        )
    else:
        progress.fill_(
            0
        )  # Reset progress to 0. Maybe we should reset inside the kernel using cas?

    backend_stream = copy_engine_all_gather_w_progress(
        a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
    )

    c = helion_matmul_w_progress(
        a_out,
        a_shared,
        b,
        progress,
        SPLITS_PER_RANK=configs["SPLITS_PER_RANK"],
        RANK=configs["RANK"],
    )
    assert type(c) is torch.Tensor

    torch.cuda.current_stream().wait_stream(backend_stream)

    return a_out, c


def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
    a_shared = symm_mem.empty(
        M // world_size, K, dtype=torch.bfloat16, device=device
    ).normal_()
    b = torch.randn((K, N), device="cuda", dtype=torch.bfloat16).T.contiguous().T

    a_out, c = helion_all_gather_matmul(a_shared, b)

    golden_a = a_shared.clone()
    dist_group = dist.group.WORLD
    if dist_group is None:
        raise RuntimeError("No distributed group available")
    ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
        golden_a, [b], gather_dim=0, group_name=dist_group.group_name
    )
    torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(a_out, ag_golden)


def main() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(4096, 6656, 16384, world_size, device)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_gather_matmul.py
    """
    main()
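
The heart of the example is the per-tile hl.wait call: the copy engine writes progress[i] = 1 once chunk i of the all-gather has landed, and the matmul kernel spins on that flag (acquire semantics, GPU scope) before consuming rows from that chunk. The standalone sketch below isolates just that pattern; the kernel name, shapes, config values, and int32 flag dtype are made up for illustration, and only the hl.wait signature exercised in the example above is assumed.

import torch

import helion
import helion.language as hl


# Hypothetical kernel: consume each row-chunk of `x` only after its flag in `ready`
# has been set to 1 by some producer (e.g. a copy-engine stream, as in the example).
@helion.jit(config=helion.Config(block_sizes=[64, 64]), static_shapes=True)
def consume_when_ready(x: torch.Tensor, ready: torch.Tensor, chunk_rows: int) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        # Spin until the chunk covering this tile's rows has been published (flag == 1).
        hl.wait(
            ready,
            [tile_m.begin // chunk_rows],
            signal=1,
            update=None,
            op="ld",
            scope="gpu",
            sem="acquire",
        )
        out[tile_m, tile_n] = x[tile_m, tile_n] * 2
    return out


# Hypothetical host-side driver: a real producer would set ready[i] asynchronously;
# here the flags are pre-set so the sketch runs standalone. int32 flags are used for
# simplicity (the example above uses uint32).
x = torch.randn(256, 128, device="cuda", dtype=torch.float32)
ready = torch.ones(4, dtype=torch.int32, device="cuda")  # 4 row-chunks of 64 rows, all "published"
y = consume_when_ready(x, ready, 64)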

helion/_compiler/output_header.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
     "triton_helpers": "from torch._inductor.runtime import triton_helpers",
     "tl_math": "from torch._inductor.runtime.triton_helpers import math as tl_math",
     "libdevice": "from torch._inductor.runtime.triton_compat import libdevice",
+    "hl_ext": "from helion import _triton_ext as hl_ext",
 }
 
 if supports_tensor_descriptor():
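
For context, this one-line entry registers the alias that generated Triton modules use to reach the new helpers. The snippet below only spells out what the registered import string resolves to; no new API beyond what this commit adds is assumed.

# What the "hl_ext" header entry expands to at the top of generated code:
from helion import _triton_ext as hl_ext

# Generated kernels can then refer to the spin-wait helpers added in this commit,
# e.g. hl_ext._triton_wait_signal(...), without importing helion._triton_ext directly.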

helion/_triton_ext/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from __future__ import annotations

from .gmem_barrier import _triton_wait_multiple_signal
from .gmem_barrier import _triton_wait_signal

__all__ = ["_triton_wait_multiple_signal", "_triton_wait_signal"]

helion/_triton_ext/gmem_barrier.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# pyre-ignore-all-errors[2]  # ignore Missing parameter annotation
from __future__ import annotations

import triton
import triton.language as tl


@triton.jit
def _triton_wait_signal(
    addr,
    expect: tl.constexpr,  # wait until the lock is set to expect
    update: tl.constexpr,  # update the lock once it is acquired
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
) -> None:
    """
    Wait for a global memory barrier to reach the expected state.

    This function implements a spin-wait loop that continuously checks a memory location
    until it reaches the expected value, providing synchronization across GPU threads.

    Args:
        addr: Memory address of the barrier to wait on (must be a scalar)
        expect: Expected value to wait for
        update: Value to update the barrier with once it is acquired
        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
    """
    tl.static_assert(
        addr.type.is_ptr(),
        "Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'?",
    )

    tl.static_assert(
        sem == "acquire" or sem == "relaxed",
        "Invalid memory semantic. Options: 'acquire', 'relaxed'.",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. Options: 'gpu', 'sys'."
    )
    tl.static_assert(
        op == "ld" or op == "atomic_cas",
        "Invalid op. Options: 'ld', 'atomic_cas'.",
    )

    # Spin-wait loop:
    #   Uses atomic_add with update=0 for ld.global.{sem}.{scope}.
    #   Triton generates smem broadcasting of the tl.atomic_add return value in PTX,
    #   but it is optimized away by ptxas in SASS, hence no performance overhead.
    if op == "ld":
        tl.static_assert(
            update == 0, "ld wait on gmem_barriers cannot update the lock."
        )
        while tl.atomic_add(addr, 0, sem=sem, scope=scope) != expect:
            pass
    elif op == "atomic_cas":
        while tl.atomic_cas(addr, expect, update, sem=sem, scope=scope) != expect:
            pass
    else:
        raise NotImplementedError(
            f"Unsupported op '{op}' for wait signal on gmem barrier."
        )

    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )
    # tl.debug_barrier() causes a significant performance loss. (Perhaps it breaks Triton prefetching?)


@triton.jit
def _triton_wait_multiple_signal(
    addr,
    expect: tl.constexpr,  # wait until the lock is set to expect
    update: tl.constexpr,  # update the lock once it is acquired
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
) -> None:
    raise NotImplementedError("Waiting on multiple barriers is not implemented yet.")
    # TODO(joydddd): wait on multiple barriers at the same time, where each thread waits on a different barrier
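
To make the helper's semantics concrete, here is a small hypothetical harness that drives _triton_wait_signal directly from plain Triton: one kernel publishes a flag, a second spins on it with op="ld" and acquire semantics. Kernel names, launch order, and dtypes are illustrative, not part of the commit; only the helper's signature comes from the file above, and the producer is launched first so this toy single-stream setup cannot deadlock.

import torch
import triton
import triton.language as tl

from helion._triton_ext import _triton_wait_signal


@triton.jit
def _producer_kernel(flag_ptr):
    # Publish the flag; atomic_xchg with release semantics makes prior writes visible.
    tl.atomic_xchg(flag_ptr, 1, sem="release", scope="gpu")


@triton.jit
def _consumer_kernel(flag_ptr, out_ptr):
    # Block until the flag equals 1, then do the dependent work.
    _triton_wait_signal(
        flag_ptr,
        expect=1,
        update=0,  # must be 0 for op="ld" (see the static_assert above)
        sem="acquire",
        scope="gpu",
        op="ld",
        skip_sync=False,
    )
    tl.store(out_ptr, 1)


def _smoke_test() -> None:
    flag = torch.zeros(1, dtype=torch.int32, device="cuda")
    out = torch.zeros(1, dtype=torch.int32, device="cuda")
    _producer_kernel[(1,)](flag)  # same stream: flag is set before the consumer starts
    _consumer_kernel[(1,)](flag, out)
    torch.cuda.synchronize()
    assert out.item() == 1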

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end
