-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmicroxcaling.patch
75 lines (70 loc) · 2.52 KB
/
microxcaling.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
diff --git a/mx/mx_ops.py b/mx/mx_ops.py
index 3aa1072..6d850fe 100644
--- a/mx/mx_ops.py
+++ b/mx/mx_ops.py
@@ -1,5 +1,6 @@
"""
-Copyright (c) Microsoft Corporation.
+Original Copyright (c) Microsoft Corporation.
+Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the MIT License.
Name: mx_ops.py
@@ -170,6 +171,7 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):
# -------------------------------------------------------------------------
# Main funcs
# -------------------------------------------------------------------------
def _quantize_mx(
A,
scale_bits,
@@ -187,6 +189,8 @@ def _quantize_mx(
if elem_format == None:
return A
+ if round == 'dither_scale':
+ assert custom_cuda == False
assert(scale_bits > 0)
# Make sure axes is a list of non-negative numbers
@@ -282,14 +286,23 @@ def _quantize_mx(
shared_exp[shared_exp > scale_emax] = float("NaN")
shared_exp[shared_exp < -scale_emax] = -scale_emax
- A = A / (2**shared_exp)
-
- A = _quantize_elemwise_core(
+ if round == 'dither_scale':
+ # 2025-04-22: Amazon addition.
+ # Stochastic rounding without clipping
+ A = 3 * A / (2**(shared_exp + 2))
+ A = _quantize_elemwise_core(
+ A, mbits, ebits, max_norm, round='dither',
+ allow_denorm=True, saturate_normals=True,
+ custom_cuda=custom_cuda)
+ A = A * (2**(shared_exp + 2)) / 3
+ # End of Amazon addition.
+ else:
+ A = A / (2**shared_exp)
+ A = _quantize_elemwise_core(
A, mbits, ebits, max_norm, round=round,
allow_denorm=True, saturate_normals=True,
custom_cuda=custom_cuda)
-
- A = A * (2**shared_exp)
+ A = A * (2**shared_exp)
# Undo tile reshaping
if block_size:
diff --git a/pyproject.toml b/pyproject.toml
index e80053e..c318c45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,11 +4,7 @@ dynamic = ["version"]
description = 'The Microsoft MX floating point library'
readme = "README.md"
requires-python = ">=3.8"
-dependencies = [
- "torch==2.2.0",
- "torchvision==0.16",
- "torchaudio==2.1.0"
-]
+dependencies = []
license = { file = "LICENSE" }
keywords = ["mx", "floating point", "math", "mathematics", "machine learning", "deep learning", "artificial intelligence", "ai", "ml", "dl", "torch", "torchvision", "torchaudio"]
authors = [