Skip to content

Commit 83d990d

Browse files
YIWENX14facebook-github-bot
authored andcommitted
Fix device and dtype discrepancy in _choose_qparams_affine (#2210)
Summary: Pull Request resolved: #2210 Reviewed By: jainapurva Differential Revision: D74446877
1 parent 35ffb26 commit 83d990d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,8 @@ def choose_qparams_affine_with_min_max(
12701270
if eps is None:
12711271
eps = torch.finfo(min_val.dtype).eps
12721272

1273+
scale_device = min_val.device
1274+
12731275
if preserve_zero:
12741276
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
12751277
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
@@ -1316,7 +1318,9 @@ def choose_qparams_affine_with_min_max(
13161318
scale = torch.clamp(scale, min=eps)
13171319
else:
13181320
assert mapping_type == MappingType.ASYMMETRIC
1319-
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
1321+
scale = (max_val_pos - min_val_neg) / torch.tensor(
1322+
float(quant_max - quant_min), dtype=scale_dtype, device=scale_device
1323+
)
13201324
scale = torch.clamp(scale, min=eps)
13211325
if zero_point_domain == ZeroPointDomain.NONE:
13221326
zero_point = None

0 commit comments

Comments
 (0)