Skip to content

Fix a performance issue with Helion-emitted Flash Attention #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 23, 2025

Conversation

manman-ren
Copy link
Contributor

@manman-ren manman-ren commented Jun 16, 2025

Triton doesn't handle 3D dots well in terms of performance. Perform the following operation when emitting dots from Helion when the highest dimension is 1 for 3D dots:

-        qk = tl.dot(q_copy, k, input_precision='tf32')
+        q_copy_r = q.reshape(_BLOCK_SIZE_1, 64)
+        k_r = k.reshape(64, _BLOCK_SIZE_2)
+        qk_r = tl.dot(q_copy_r, k_r, input_precision='tf32')
+        qk = qk_r.reshape(1, _BLOCK_SIZE_1, _BLOCK_SIZE_2)

Perf Results on H100:

CUDA_VISIBLE_DEVICES=5 python examples/attention.py 
Helion time: 0.0642ms, flex time: 0.0638, torch time: 0.0716

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 16, 2025
@manman-ren manman-ren requested review from jansel and yf225 June 16, 2025 21:39
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Lint failures. Run ./lint.sh
  2. Test failures, you may need to run EXPECTTEST_ACCEPT=1 pytest test to update expected outputs.
  3. Do we need to support codegen_addmm as well?
  4. Fix code duplication

Comment on lines 851 to 895
lhsSize = node.args[1].meta['val'].size()
rhsSize = node.args[2].meta['val'].size()
# check to see if it is 3D
reduceDim = False
if len(lhsSize) == 3:
env = CompileEnvironment.current()
lhsDimIdx = env.get_block_id(lhsSize[0])
rhsDimIdx = env.get_block_id(rhsSize[0])
if lhsDimIdx is not None and rhsDimIdx is not None:
lhsDimVal = env.block_sizes[lhsDimIdx]
rhsDimVal = env.block_sizes[rhsDimIdx]
if (lhsDimVal.from_config(ctx.cg.device_function.config) == 1 and
rhsDimVal.from_config(ctx.cg.device_function.config) == 1):
reduceDim = True

if not reduceDim:
return expr_from_string(
f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
lhs=lhs,
rhs=rhs,
acc=acc,
)
# create reshape, dot, then reshape
lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
[*node.args[1].meta["val"].size()[1:]]
)
rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
[*node.args[2].meta["val"].size()[1:]]
)
acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
[*node.args[0].meta["val"].size()[1:]]
)
out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
[*node.meta['val'].size()]
)
lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
comp = expr_from_string(
f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
lhs=lhs_reshape,
rhs=rhs_reshape,
acc=acc_reshape,
)
return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate code with the above? Please refactor into a helper function.

Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some minor nits, but otherwise lgtm

Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, need to fix merge conflict though

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@manman-ren manman-ren merged commit c86b278 into main Jun 23, 2025
6 checks passed
@facebook-github-bot
Copy link

@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -12,9 +12,9 @@
@helion.kernel(
config=helion.Config(
# This config was autotuned on a 3090, it won't be fast for other architectures
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this comment still valid? maybe delete it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants