From 0560f5bc08bd95722b1fed1e8e65124f10c10f27 Mon Sep 17 00:00:00 2001
From: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Date: Sun, 26 Jan 2025 21:51:16 -0800
Subject: [PATCH] [release/2.5] Enable tf32 testing on test_nn

1. In Context.cpp, change the allow_tf32 variable from a static const to a
   non-const local, so that changes to the HIPBLASLT_ALLOW_TF32 environment
   variable are picked up on every call instead of being cached at first use.
2. Add ROCm arch support to tf32_is_not_fp32() (see the usage sketch below).
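
A minimal sketch of the resulting behavior (assuming a ROCm build of
PyTorch with an MI300-class GPU visible; the assertions are illustrative
and not part of this patch):

    import os
    import torch
    from torch.testing._internal.common_cuda import tf32_is_not_fp32

    # With the static const removed, toggling the env var between calls
    # is now observed by allowTF32CuBLAS()/setAllowTF32CuBLAS().
    os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
    torch.backends.cuda.matmul.allow_tf32 = True
    assert torch.backends.cuda.matmul.allow_tf32

    os.environ["HIPBLASLT_ALLOW_TF32"] = "0"
    # Before this patch the first (cached) env lookup would still win
    # here; now the getter re-reads the variable and returns False.
    assert not torch.backends.cuda.matmul.allow_tf32

    # tf32_is_not_fp32() now reports True when the device's gcnArchName
    # prefix (e.g. "gfx942") is in MI300_ARCH.
    assert tf32_is_not_fp32()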

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
---
 aten/src/ATen/Context.cpp              |  4 ++--
 torch/testing/_internal/common_cuda.py | 25 ++++++++++++++++++++++++-
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 492e58cdd6991..3c50cbe0dc57b 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -233,7 +233,7 @@ void Context::setBenchmarkLimitCuDNN(int b) {
 
 bool Context::allowTF32CuBLAS() const {
 #ifdef USE_ROCM
-    const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+    auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
     if (allow_tf32 != true) {
       return false;
     }
@@ -243,7 +243,7 @@ bool Context::allowTF32CuBLAS() const {
 
 void Context::setAllowTF32CuBLAS(bool b) {
 #ifdef USE_ROCM
-  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
   if (allow_tf32 != true) {
     LOG(INFO) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
               << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index 01eeac86ae135..51da5f42bb41a 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -5,7 +5,14 @@
 import functools
 import torch
 import torch.cuda
-from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
+from torch.testing._internal.common_utils import (
+    LazyVal,
+    MI300_ARCH,
+    TEST_NUMBA,
+    TEST_WITH_ROCM,
+    TEST_CUDA,
+    IS_WINDOWS,
+)
 import inspect
 import contextlib
 import os
@@ -118,7 +125,15 @@ def initialize_cuda_context_rng():
 # Test whether hardware TF32 math mode enabled. It is enabled only on:
 # - CUDA >= 11
 # - arch >= Ampere
+# On ROCm:
+# - TF32 is supported on MI300-class GPUs (arch in MI300_ARCH).
 def tf32_is_not_fp32():
+    if torch.version.hip:
+        if not torch.cuda.is_available():
+            return False
+        # TF32 is only available on MI300-class GPUs (e.g. gfx942).
+        prop = torch.cuda.get_device_properties(torch.cuda.current_device())
+        return prop.gcnArchName.split(":")[0] in MI300_ARCH
     if not torch.cuda.is_available() or torch.version.cuda is None:
         return False
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
@@ -141,6 +156,9 @@ def tf32_off():
 
 @contextlib.contextmanager
 def tf32_on(self, tf32_precision=1e-5):
+    if torch.version.hip:
+        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
     old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
     old_precision = self.precision
     try:
@@ -149,6 +167,11 @@ def tf32_on(self, tf32_precision=1e-5):
         with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
             yield
     finally:
+        if torch.version.hip:
+            if hip_allow_tf32 is not None:
+                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+            else:
+                del os.environ["HIPBLASLT_ALLOW_TF32"]
         torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
         self.precision = old_precision