Skip to content

Commit 0560f5b

Browse files
committed
[release/2.5] Enable tf32 testing on test_nn
1. In Context.cpp, change the allow_tf32 variable from a static const variable to a non-const local variable. This allows us to pick up runtime changes to the HIPBLASLT_ALLOW_TF32 environment variable instead of caching its value on first use. 2. Add ROCm architecture support to tf32_is_not_fp32(). Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent 33911de commit 0560f5b

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

aten/src/ATen/Context.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ void Context::setBenchmarkLimitCuDNN(int b) {
233233

234234
bool Context::allowTF32CuBLAS() const {
235235
#ifdef USE_ROCM
236-
const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
236+
auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
237237
if (allow_tf32 != true) {
238238
return false;
239239
}
@@ -243,7 +243,7 @@ bool Context::allowTF32CuBLAS() const {
243243

244244
void Context::setAllowTF32CuBLAS(bool b) {
245245
#ifdef USE_ROCM
246-
const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
246+
auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
247247
if (allow_tf32 != true) {
248248
LOG(INFO) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
249249
<< "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";

torch/testing/_internal/common_cuda.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
import functools
66
import torch
77
import torch.cuda
8-
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
8+
from torch.testing._internal.common_utils import (
9+
LazyVal,
10+
MI300_ARCH,
11+
TEST_NUMBA,
12+
TEST_WITH_ROCM,
13+
TEST_CUDA,
14+
IS_WINDOWS,
15+
)
916
import inspect
1017
import contextlib
1118
import os
@@ -118,7 +125,15 @@ def initialize_cuda_context_rng():
118125
# Test whether hardware TF32 math mode enabled. It is enabled only on:
119126
# - CUDA >= 11
120127
# - arch >= Ampere
128+
#More--
129+
# For AMD GPUs, tf32 is supported on mi300.
121130
def tf32_is_not_fp32():
131+
if torch.version.hip:
132+
prop = torch.cuda.get_device_properties(torch.cuda.current_device())
133+
if prop.gcnArchName.split(":")[0] in MI300_ARCH:
134+
return True
135+
else:
136+
return False
122137
if not torch.cuda.is_available() or torch.version.cuda is None:
123138
return False
124139
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
@@ -141,6 +156,9 @@ def tf32_off():
141156

142157
@contextlib.contextmanager
143158
def tf32_on(self, tf32_precision=1e-5):
159+
if torch.version.hip:
160+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
161+
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
144162
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
145163
old_precision = self.precision
146164
try:
@@ -149,6 +167,11 @@ def tf32_on(self, tf32_precision=1e-5):
149167
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
150168
yield
151169
finally:
170+
if torch.version.hip:
171+
if hip_allow_tf32 is not None:
172+
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
173+
else:
174+
del os.environ["HIPBLASLT_ALLOW_TF32"]
152175
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
153176
self.precision = old_precision
154177

0 commit comments

Comments (0)