Fix a performance issue with Helion-emitted Flash Attention #181
- Lint failures. Run ./lint.sh
- Test failures, you may need to run EXPECTTEST_ACCEPT=1 pytest test to update expected outputs.
- Do we need to support codegen_addmm as well?
- Fix code duplication
lhsSize = node.args[1].meta['val'].size()
rhsSize = node.args[2].meta['val'].size()
# check to see if it is 3D
reduceDim = False
if len(lhsSize) == 3:
    env = CompileEnvironment.current()
    lhsDimIdx = env.get_block_id(lhsSize[0])
    rhsDimIdx = env.get_block_id(rhsSize[0])
    if lhsDimIdx is not None and rhsDimIdx is not None:
        lhsDimVal = env.block_sizes[lhsDimIdx]
        rhsDimVal = env.block_sizes[rhsDimIdx]
        if (lhsDimVal.from_config(ctx.cg.device_function.config) == 1 and
                rhsDimVal.from_config(ctx.cg.device_function.config) == 1):
            reduceDim = True

if not reduceDim:
    return expr_from_string(
        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
        lhs=lhs,
        rhs=rhs,
        acc=acc,
    )
# create reshape, dot, then reshape
lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
    [*node.args[1].meta["val"].size()[1:]]
)
rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
    [*node.args[2].meta["val"].size()[1:]]
)
acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
    [*node.args[0].meta["val"].size()[1:]]
)
out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
    [*node.meta['val'].size()]
)
lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
comp = expr_from_string(
    f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
    lhs=lhs_reshape,
    rhs=rhs_reshape,
    acc=acc_reshape,
)
return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
Duplicate code with the above? Please refactor into a helper function.
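One possible shape for such a helper, as a minimal sketch only: the function below builds the emitted Triton source as a plain string so that both call sites could share it. The names `shape_str`, `emit_dot_source`, and `squeeze_batch` are hypothetical and are not part of Helion; the real code would go through `expr_from_string` and the tile strategy's own `shape_str` instead.

```python
from typing import Sequence


def shape_str(shape: Sequence[object]) -> str:
    """Format a shape as a list literal, e.g. [64, 32] (hypothetical stand-in)."""
    return "[" + ", ".join(str(dim) for dim in shape) + "]"


def emit_dot_source(
    lhs_shape: Sequence[object],
    rhs_shape: Sequence[object],
    acc_shape: Sequence[object],
    out_shape: Sequence[object],
    squeeze_batch: bool,
    input_precision: str = "tf32",
) -> str:
    """Return the source text for a dot, optionally wrapped in reshapes.

    When squeeze_batch is True, the leading dimension of each operand is
    dropped before the dot and the result is reshaped back to out_shape.
    """
    dot = "tl.dot({lhs}, {rhs}, acc={acc}, input_precision={prec!r})"
    if not squeeze_batch:
        # Plain dot, no reshapes around it.
        return dot.format(lhs="lhs", rhs="rhs", acc="acc", prec=input_precision)
    # Reshape each operand to 2D, dot, then restore the full output shape.
    inner = dot.format(
        lhs=f"tl.reshape(lhs, {shape_str(lhs_shape[1:])})",
        rhs=f"tl.reshape(rhs, {shape_str(rhs_shape[1:])})",
        acc=f"tl.reshape(acc, {shape_str(acc_shape[1:])})",
        prec=input_precision,
    )
    return f"tl.reshape({inner}, {shape_str(out_shape)})"
```

With a helper like this, each codegen path would only decide whether the batch dimension can be squeezed and pass the shapes through.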
some minor nits, but otherwise lgtm
lgtm, need to fix merge conflict though
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@@ -12,9 +12,9 @@
@helion.kernel(
    config=helion.Config(
        # This config was autotuned on a 3090, it won't be fast for other architectures
is this comment still valid? maybe delete it?
Triton performs poorly on 3D dots. When Helion emits a 3D dot whose leading dimension has a block size of 1, perform the following operation:
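Judging from the diff under review, the operation appears to be: reshape the operands to 2D, run the dot, then reshape the result back to 3D. The sketch below uses plain nested Python lists (not Triton) to illustrate why this is safe when the leading dimension is 1. The function names are hypothetical and exist only for this illustration.

```python
def matmul2d(a, b):
    """Multiply two 2D matrices given as nested lists."""
    return [
        [sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
        for row in a
    ]


def batched_matmul(a, b):
    """Reference 3D (batched) matmul: one 2D matmul per batch entry."""
    return [matmul2d(x, y) for x, y in zip(a, b)]


def batch1_matmul_via_2d(a, b):
    """Reshape to 2D, multiply, reshape back; valid only for batch size 1."""
    if len(a) != 1 or len(b) != 1:
        raise ValueError("leading dimension must be 1")
    out2d = matmul2d(a[0], b[0])  # dropping the size-1 leading dim is the reshape
    return [out2d]  # restore the leading dimension


a = [[[1, 2], [3, 4]]]  # shape (1, 2, 2)
b = [[[5, 6], [7, 8]]]  # shape (1, 2, 2)
assert batch1_matmul_via_2d(a, b) == batched_matmul(a, b)
```

Because a size-1 leading dimension holds no extra elements, dropping and restoring it moves no data, so the 2D path computes the same values as the 3D one.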
Perf Results on H100: