Skip to content

Commit 6afe1fa

Browse files
require torch 2, allow numpy>2, remove unnecessary warnings
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent af8f44b commit 6afe1fa

File tree

2 files changed

+2
-47
lines changed

2 files changed

+2
-47
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def localversion_func(version: ScmVersion) -> str:
112112
install_requires=[
113113
"loguru",
114114
"pyyaml>=5.0.0",
115-
"numpy>=1.17.0,<2.0",
115+
"numpy>=1.17.0",
116116
"requests>=2.0.0",
117117
"tqdm>=4.0.0",
118-
"torch>=1.7.0",
118+
"torch>=2.7.0",
119119
"transformers>4.0,<5.0",
120120
"datasets",
121121
"accelerate>=0.20.3,!=1.1.0",

src/llmcompressor/pytorch/__init__.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +0,0 @@
1-
"""
2-
Functionality for working with and sparsifying Models in the PyTorch framework
3-
"""
4-
5-
import os
6-
import warnings
7-
8-
from packaging import version
9-
10-
try:
11-
import torch
12-
13-
_PARSED_TORCH_VERSION = version.parse(torch.__version__)
14-
15-
if _PARSED_TORCH_VERSION.major >= 2:
16-
torch_compile_func = torch.compile
17-
18-
def raise_torch_compile_warning(*args, **kwargs):
19-
warnings.warn(
20-
"torch.compile is not supported by llmcompressor for torch 2.0.x"
21-
)
22-
return torch_compile_func(*args, **kwargs)
23-
24-
torch.compile = raise_torch_compile_warning
25-
26-
_BYPASS = bool(int(os.environ.get("NM_BYPASS_TORCH_VERSION", "0")))
27-
if _PARSED_TORCH_VERSION.major == 1 and _PARSED_TORCH_VERSION.minor in [10, 11]:
28-
if not _BYPASS:
29-
raise RuntimeError(
30-
"llmcompressor does not support torch==1.10.* or 1.11.*. "
31-
f"Found torch version {torch.__version__}.\n\n"
32-
"To bypass this error, set environment variable "
33-
"`NM_BYPASS_TORCH_VERSION` to '1'.\n\n"
34-
"Bypassing may result in errors or "
35-
"incorrect behavior, so set at your own risk."
36-
)
37-
else:
38-
warnings.warn(
39-
"llmcompressor quantized onnx export does not work "
40-
"with torch==1.10.* or 1.11.*"
41-
)
42-
except ImportError:
43-
pass
44-
45-
# flake8: noqa

0 commit comments

Comments
 (0)