File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -1270,6 +1270,8 @@ def choose_qparams_affine_with_min_max(
1270
1270
if eps is None :
1271
1271
eps = torch .finfo (min_val .dtype ).eps
1272
1272
1273
+ scale_device = min_val .device
1274
+
1273
1275
if preserve_zero :
1274
1276
min_val_neg = torch .min (min_val , torch .zeros_like (min_val ))
1275
1277
max_val_pos = torch .max (max_val , torch .zeros_like (max_val ))
@@ -1316,7 +1318,9 @@ def choose_qparams_affine_with_min_max(
1316
1318
scale = torch .clamp (scale , min = eps )
1317
1319
else :
1318
1320
assert mapping_type == MappingType .ASYMMETRIC
1319
- scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
1321
+ scale = (max_val_pos - min_val_neg ) / torch .tensor (
1322
+ float (quant_max - quant_min ), dtype = scale_dtype , device = scale_device
1323
+ )
1320
1324
scale = torch .clamp (scale , min = eps )
1321
1325
if zero_point_domain == ZeroPointDomain .NONE :
1322
1326
zero_point = None
You can’t perform that action at this time.
0 commit comments