diff --git a/keras_tuner/backend/config.py b/keras_tuner/backend/config.py index 439a2e041..1665a32a0 100644 --- a/keras_tuner/backend/config.py +++ b/keras_tuner/backend/config.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras +import os -def _multi_backend(): +if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"): + from tensorflow as keras + _MULTI_BACKEND = False +else: + import keras version_fn = getattr(keras, "version", None) - return version_fn and version_fn().startswith("3.") - - -_MULTI_BACKEND = _multi_backend() + _MULTI_BACKEND = version_fn and version_fn().startswith("3.") def multi_backend():