Commit 5cfc4c7

Remove int_scaled_mm's dependency on triton for cpu (#128)
Parent: baa78f2

File tree: 2 files changed, +8 -7 lines

torchao/kernel/intmm.py (+8, -1)

@@ -2,7 +2,7 @@
 import os
 import torch
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_2
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
 
 try:
     # Only works for torch2.2 or newer.
@@ -134,6 +134,13 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
     assert scales1.is_contiguous()
     scales1 = scales1.expand((M, N))
     assert scales1.dim() == 2
+
+    if scales1.device.type == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        # CPU prefers decomposed version of int_scaled_matmul
+        # to leverage the fusion capability of Inductor
+        c = torch._int_mm(a, b)
+        return c.to(scales1.dtype) * scales1
+
     if intmm_triton is not None and AUTOTUNER_ENABLE:
         return torch.ops.torchao.int_scaled_matmul(a, b, scales1)
 

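For context, the new branch is just a decomposition into plain ATen ops. Below is a minimal standalone sketch of that path; the shapes and the tensor names `a`, `b`, `scales1` are illustrative, and it assumes a PyTorch build whose `torch._int_mm` has a CPU kernel (which is why the diff gates on 2.6+):

```python
import torch

# Illustrative shapes; torch._int_mm is an int8 x int8 -> int32 matmul.
M, K, N = 32, 64, 16
a = torch.randint(-128, 128, (M, K), dtype=torch.int8)
b = torch.randint(-128, 128, (K, N), dtype=torch.int8)
# Per-row scales expanded to (M, N), as int_scaled_matmul does above.
scales1 = torch.rand(M, 1).expand(M, N).contiguous()

# The decomposed path: two ordinary ops instead of the triton-backed
# custom op, so Inductor can fuse the cast and multiply on CPU.
c = torch._int_mm(a, b)              # int32 accumulator
out = c.to(scales1.dtype) * scales1  # apply scales in scales1's dtype
print(out.shape, out.dtype)          # torch.Size([32, 16]) torch.float32
```

Because the branch returns before the AUTOTUNER_ENABLE check, CPU tensors never reach the triton path at all.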
torchao/kernel/intmm_triton.py (-6)

@@ -356,9 +356,3 @@ def int_scaled_matmul_cuda(a, b, scales1):
         int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs
     )
     return int_scaled_matmul_kernel(a, b, scales1, c, best_config)
-
-
-@torch.library.impl(lib, "int_scaled_matmul", "CPU")
-def int_scaled_matmul_cpu(a, b, scales1):
-    c = torch._int_mm(a, b)
-    return c.to(scales1.dtype) * scales1
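The removed lines were the op's CPU backend registration: `torch.library.impl` binds a function as the implementation of a custom op for a single dispatch key, which kept this triton module on the import path even for CPU-only use. A self-contained sketch of that registration mechanism, using a hypothetical `demo` namespace in place of torchao's actual `lib` object:

```python
import torch

# Hypothetical "demo" namespace; torchao's real code uses its own lib object.
lib = torch.library.Library("demo", "DEF")
lib.define("int_scaled_matmul(Tensor a, Tensor b, Tensor scales1) -> Tensor")

# Bind a CPU implementation for the op, mirroring the deleted registration.
@torch.library.impl(lib, "int_scaled_matmul", "CPU")
def int_scaled_matmul_cpu(a, b, scales1):
    c = torch._int_mm(a, b)
    return c.to(scales1.dtype) * scales1

# The op is then reachable through the dispatcher:
#   torch.ops.demo.int_scaled_matmul(a, b, scales1)
```

With the registration gone, intmm.py handles CPU inline (see above), so triton is no longer required to import or run the CPU path.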
