diff --git a/keras/api/_tf_keras/keras/callbacks/__init__.py b/keras/api/_tf_keras/keras/callbacks/__init__.py
index 4e165cddb6a8..e7f0bedd62c3 100644
--- a/keras/api/_tf_keras/keras/callbacks/__init__.py
+++ b/keras/api/_tf_keras/keras/callbacks/__init__.py
@@ -16,6 +16,9 @@
 from keras.src.callbacks.learning_rate_scheduler import (
     LearningRateScheduler as LearningRateScheduler,
 )
+from keras.src.callbacks.memory_usage_callback import (
+    MemoryUsageCallback as MemoryUsageCallback,
+)
 from keras.src.callbacks.model_checkpoint import (
     ModelCheckpoint as ModelCheckpoint,
 )
diff --git a/keras/api/callbacks/__init__.py b/keras/api/callbacks/__init__.py
index 4e165cddb6a8..e7f0bedd62c3 100644
--- a/keras/api/callbacks/__init__.py
+++ b/keras/api/callbacks/__init__.py
@@ -16,6 +16,9 @@
 from keras.src.callbacks.learning_rate_scheduler import (
     LearningRateScheduler as LearningRateScheduler,
 )
+from keras.src.callbacks.memory_usage_callback import (
+    MemoryUsageCallback as MemoryUsageCallback,
+)
 from keras.src.callbacks.model_checkpoint import (
     ModelCheckpoint as ModelCheckpoint,
 )
diff --git a/keras/src/callbacks/__init__.py b/keras/src/callbacks/__init__.py
index 427c4f6da95f..dea6f62fe27d 100644
--- a/keras/src/callbacks/__init__.py
+++ b/keras/src/callbacks/__init__.py
@@ -6,6 +6,7 @@
 from keras.src.callbacks.history import History
 from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
+from keras.src.callbacks.memory_usage_callback import MemoryUsageCallback
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
 from keras.src.callbacks.progbar_logger import ProgbarLogger
diff --git a/keras/src/callbacks/memory_usage_callback.py b/keras/src/callbacks/memory_usage_callback.py
new file mode 100644
index 000000000000..ad6831e7041f
--- /dev/null
+++ b/keras/src/callbacks/memory_usage_callback.py
@@ -0,0 +1,278 @@
+import os
+import time
+import warnings
+
+from keras.src import backend as K
+from keras.src.api_export import keras_export
+from keras.src.callbacks.callback import Callback
+
+try:
+    import psutil
+except ImportError:
+    psutil = None
+
+
+def running_on_gpu():
+    """Detect if any GPU is available on the current Keras backend."""
+    backend_name = K.backend()
+    if backend_name == "tensorflow":
+        import tensorflow as tf
+
+        try:
+            return bool(tf.config.list_physical_devices("GPU"))
+        except Exception:
+            return False
+    elif backend_name == "torch":
+        try:
+            import torch
+
+            return torch.cuda.is_available()
+        except ImportError:
+            return False
+    elif backend_name == "jax":
+        try:
+            import jax
+
+            return any(d.platform.upper() == "GPU" for d in jax.devices())
+        except ImportError:
+            return False
+    return False
+
+
+def running_on_tpu():
+    """Detect if any TPU is available on the current Keras backend."""
+    backend_name = K.backend()
+    if backend_name == "tensorflow":
+        import tensorflow as tf
+
+        try:
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
+            tf.config.experimental_connect_to_cluster(resolver)
+            tf.tpu.experimental.initialize_tpu_system(resolver)
+        except Exception:
+            pass
+        try:
+            return bool(tf.config.list_physical_devices("TPU"))
+        except Exception:
+            return False
+    elif backend_name == "jax":
+        try:
+            import jax
+
+            return any(d.platform.upper() == "TPU" for d in jax.devices())
+        except ImportError:
+            return False
+    return False
+
+
+@keras_export("keras.callbacks.MemoryUsageCallback")
+class MemoryUsageCallback(Callback):
+    """
+    Monitors and logs memory usage (CPU + optional GPU/TPU) during training.
+
+    This callback measures:
+    - **CPU**: via `psutil.Process().memory_info().rss`
+    - **GPU**: if a GPU is detected, via backend-specific APIs
+      (TensorFlow, PyTorch, JAX)
+    - **TPU**: if a TPU is detected, via backend-specific APIs
+      (TensorFlow, JAX)
+
+    Logs are printed to stdout at the start and end of each epoch
+    (with a leading newline to avoid clobbering the progress bar)
+    and, if `log_every_batch=True`, after every batch.
+    If `tensorboard_log_dir` is provided, the same values are also
+    written as scalars via `tf.summary` (TensorBoard).
+
+    Args:
+        log_every_batch (bool): If True, also log after each batch.
+            Defaults to False.
+        tensorboard_log_dir (str|None): Directory for TensorBoard logs;
+            if None, no TF summary writer is created.
+
+    Raises:
+        ImportError: If `psutil` is not installed (required for CPU logging).
+    """
+
+    def __init__(
+        self,
+        log_every_batch=False,
+        tensorboard_log_dir=None,
+    ):
+        super().__init__()
+
+        if psutil is None:
+            raise ImportError(
+                "MemoryUsageCallback requires the 'psutil' library. "
+                "To install, please use: pip install psutil"
+            )
+        self.log_every_batch = log_every_batch
+        self._proc = psutil.Process()
+        self._step_counter = 0
+        self._writer = None
+
+        if tensorboard_log_dir:
+            try:
+                import tensorflow as tf
+
+                logdir = os.path.expanduser(tensorboard_log_dir)
+                self._writer = tf.summary.create_file_writer(logdir)
+                print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
+            except Exception as e:
+                warnings.warn(
+                    f"Could not initialize TensorBoard writer: {e}",
+                    RuntimeWarning,
+                )
+                self._writer = None
+
+    def on_train_begin(self, logs=None):
+        self._step_counter = 0
+
+    def on_epoch_begin(self, epoch, logs=None):
+        print()
+        self._log_epoch("start", epoch)
+
+    def on_epoch_end(self, epoch, logs=None):
+        print()
+        self._log_epoch("end", epoch, offset=1)
+
+    def on_batch_end(self, batch, logs=None):
+        if self.log_every_batch:
+            print()
+            self._log_step(f"Batch {self._step_counter} end", self._step_counter)
+        self._step_counter += 1
+
+    def on_train_end(self, logs=None):
+        if self._writer:
+            self._writer.close()
+
+    def _log_epoch(self, when, epoch, offset=0):
+        label = f"Epoch {epoch} {when}"
+        step = epoch + offset
+        self._log_step(label, step)
+
+    def _log_step(self, label, step):
+        """
+        Internal helper to measure and print CPU/GPU/TPU memory.
+        Briefly yields control (`time.sleep(0)`) so the printed line is
+        emitted cleanly alongside progress-bar output.
+        """
+        cpu_mb = self._get_cpu_memory()
+        gpu_mb = self._get_gpu_memory()
+        tpu_mb = self._get_tpu_memory()
+
+        msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
+        if gpu_mb is not None:
+            msg += f"; GPU Memory: {gpu_mb:.2f} MB"
+        if tpu_mb is not None:
+            msg += f"; TPU Memory: {tpu_mb:.2f} MB"
+        print(msg)
+        time.sleep(0)
+
+        if self._writer:
+            import tensorflow as tf
+
+            with self._writer.as_default(step=int(step)):
+                tf.summary.scalar("Memory/CPU_MB", cpu_mb)
+                if gpu_mb is not None:
+                    tf.summary.scalar("Memory/GPU_MB", gpu_mb)
+                if tpu_mb is not None:
+                    tf.summary.scalar("Memory/TPU_MB", tpu_mb)
+
+    def _get_cpu_memory(self):
+        """Return current process CPU memory usage in MB."""
+        return self._proc.memory_info().rss / (1024**2)
+
+    def _get_gpu_memory(self):
+        """
+        Return current GPU memory usage in MB for the detected backend,
+        or None if no GPU is present or if measurement fails.
+ """ + if not running_on_gpu(): + return None + backend_name = K.backend() + try: + if backend_name == "tensorflow": + import tensorflow as tf + + try: + mem_info = tf.config.experimental.get_memory_info("GPU:0") + return mem_info["current"] / (1024**2) + except Exception: + gpus = tf.config.list_physical_devices("GPU") + if not gpus: + return None + total = 0 + for i in range(len(gpus)): + try: + info = tf.config.experimental.get_memory_info(f"GPU:{i}") + total += info.get("current", 0) + except Exception: + continue + return total / (1024**2) + elif backend_name == "torch": + import torch + + if not torch.cuda.is_available(): + return None + total_bytes = 0 + for i in range(torch.cuda.device_count()): + total_bytes += torch.cuda.memory_allocated(i) + return total_bytes / (1024**2) + elif backend_name == "jax": + import jax + + devs = [d for d in jax.devices() if d.platform.upper() == "GPU"] + if not devs: + return None + total = 0 + for d in devs: + stats = getattr(d, "memory_stats", lambda: {})() + total += stats.get("bytes_in_use", 0) + return total / (1024**2) + return None + except ImportError as imp_err: + if not hasattr(self, "_warn_import"): + warnings.warn( + f"Could not import library for GPU memory tracking ({backend_name}): {imp_err}", + RuntimeWarning, + ) + self._warn_import = True + return None + except Exception as exc: + if not hasattr(self, "_warn_exc"): + warnings.warn(f"Error retrieving GPU memory: {exc}", RuntimeWarning) + self._warn_exc = True + return None + + def _get_tpu_memory(self): + """ + Return current TPU memory usage in MB for the detected backend, + or None if no TPU is present or if measurement fails. + Note: TPU memory APIs vary; here we attempt best‐effort. + """ + if not running_on_tpu(): + return None + backend_name = K.backend() + try: + if backend_name == "tensorflow": + return None + elif backend_name == "jax": + import jax + + devs = [d for d in jax.devices() if d.platform.upper() == "TPU"] + if not devs: + return None + stats = getattr(devs[0], "memory_stats", lambda: {})() + tpu_bytes = stats.get("bytes_in_use", stats.get("allocated_bytes", 0)) + return tpu_bytes / (1024**2) + return None + except ImportError as imp_err: + if not hasattr(self, "_warn_tpu_imp"): + warnings.warn( + f"Could not import library for TPU memory tracking ({backend_name}): {imp_err}", + RuntimeWarning, + ) + self._warn_tpu_imp = True + return None + except Exception as exc: + if not hasattr(self, "_warn_tpu_exc"): + warnings.warn(f"Error retrieving TPU memory: {exc}", RuntimeWarning) + self._warn_tpu_exc = True + return None \ No newline at end of file diff --git a/keras/src/callbacks/memory_usage_callback_test.py b/keras/src/callbacks/memory_usage_callback_test.py new file mode 100644 index 000000000000..852d72fe9afb --- /dev/null +++ b/keras/src/callbacks/memory_usage_callback_test.py @@ -0,0 +1,169 @@ +import os +import glob +import sys +import tempfile +import re + +import numpy as np +import pytest +import tensorflow as tf + +from io import StringIO +from contextlib import redirect_stdout +from importlib import reload +from unittest.mock import patch, MagicMock + +from keras.src import backend as K +from keras.src.callbacks.memory_usage_callback import ( + MemoryUsageCallback, + running_on_gpu, + running_on_tpu, +) +from keras.src.models import Sequential +from keras.src.layers import Dense + +try: + import psutil +except ImportError: + psutil = None + + +@pytest.mark.skipif(psutil is None, reason="psutil is required for MemoryUsageCallback tests.") +class 
+    @pytest.fixture(autouse=True)
+    def setup_model(self):
+        self.x_train = np.random.random((20, 10)).astype(np.float32)
+        self.y_train = np.random.randint(0, 2, (20, 1)).astype(np.float32)
+
+        self.model = Sequential([
+            Dense(5, activation="relu", input_shape=(10,)),
+            Dense(1, activation="sigmoid")
+        ])
+        self.model.compile(optimizer="adam", loss="binary_crossentropy")
+
+        self.epochs = 2
+        self.batch_size = 5
+        self.steps_per_epoch = len(self.x_train) // self.batch_size
+        yield
+
+    @pytest.mark.requires_trainable_backend
+    def test_cpu_only_epoch_logging(self):
+        # Force TF backend and no GPU/TPU
+        monkey = patch.object(K, "backend", lambda: "tensorflow")
+        with monkey:
+            out = StringIO()
+            with redirect_stdout(out), \
+                 patch("keras.src.callbacks.memory_usage_callback.running_on_gpu", return_value=False), \
+                 patch("keras.src.callbacks.memory_usage_callback.running_on_tpu", return_value=False):
+                cb = MemoryUsageCallback(log_every_batch=False)
+                self.model.fit(self.x_train, self.y_train,
+                               epochs=self.epochs,
+                               batch_size=self.batch_size,
+                               callbacks=[cb],
+                               verbose=0)
+
+            lines = out.getvalue().splitlines()
+            start = [ln for ln in lines if re.match(r"^Epoch \d+ start - CPU Memory:", ln)]
+            end = [ln for ln in lines if re.match(r"^Epoch \d+ end - CPU Memory:", ln)]
+            assert len(start) == self.epochs
+            assert len(end) == self.epochs
+            assert not any("GPU Memory" in ln or "TPU Memory" in ln for ln in lines)
+
+    @pytest.mark.requires_trainable_backend
+    def test_log_every_batch(self):
+        monkey = patch.object(K, "backend", lambda: "tensorflow")
+        with monkey:
+            out = StringIO()
+            with redirect_stdout(out), \
+                 patch("keras.src.callbacks.memory_usage_callback.running_on_gpu", return_value=False), \
+                 patch("keras.src.callbacks.memory_usage_callback.running_on_tpu", return_value=False):
+                cb = MemoryUsageCallback(log_every_batch=True)
+                self.model.fit(self.x_train, self.y_train,
+                               epochs=self.epochs,
+                               batch_size=self.batch_size,
+                               callbacks=[cb],
+                               verbose=0)
+
+            lines = out.getvalue().splitlines()
+            batches = [ln for ln in lines if re.match(r"^Batch \d+ end - CPU Memory:", ln)]
+            assert len(batches) == self.epochs * self.steps_per_epoch
+
+    @pytest.mark.requires_trainable_backend
+    def test_tensorboard_log_dir(self):
+        monkey = patch.object(K, "backend", lambda: "tensorflow")
+        with monkey:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                log_dir = os.path.join(tmpdir, "tb_logs")
+                with patch("keras.src.callbacks.memory_usage_callback.running_on_gpu", return_value=False), \
+                     patch("keras.src.callbacks.memory_usage_callback.running_on_tpu", return_value=False):
+                    cb = MemoryUsageCallback(log_every_batch=True, tensorboard_log_dir=log_dir)
+                    assert os.path.isdir(log_dir)
+                    self.model.fit(self.x_train, self.y_train,
+                                   epochs=self.epochs,
+                                   batch_size=self.batch_size,
+                                   callbacks=[cb],
+                                   verbose=0)
+                    files = glob.glob(os.path.join(log_dir, "events.out.tfevents.*"))
+                    assert files, f"No events files under {log_dir}"
+
+    @pytest.mark.requires_trainable_backend
+    def test_get_gpu_memory_tensorflow(self):
+        patch_backend = patch.object(K, "backend", lambda: "tensorflow")
+        fake_tf = MagicMock()
+        # mock physical devices
+        fake_tf.config.list_physical_devices.return_value = ["GPU:0"]
+        fake_tf.config.experimental.get_memory_info.return_value = {"current": 150 * 1024**2}
+
+        with patch_backend, \
+             patch.dict(sys.modules, {
+                 "tensorflow": fake_tf,
+                 "tensorflow.config": fake_tf.config,
+                 "tensorflow.config.experimental": fake_tf.config.experimental
+             }):
+            cb = MemoryUsageCallback()
+            assert pytest.approx(150.0, rel=1e-6) == cb._get_gpu_memory()
+
+    @pytest.mark.requires_trainable_backend
+    def test_get_gpu_memory_torch(self):
+        patch_backend = patch.object(K, "backend", lambda: "torch")
+        fake_torch = MagicMock()
+        fake_torch.cuda.is_available.return_value = True
+        fake_torch.cuda.device_count.return_value = 2
+        # return 100MB then 200MB
+        fake_torch.cuda.memory_allocated.side_effect = [100 * 1024**2, 200 * 1024**2]
+
+        with patch_backend, \
+             patch.dict(sys.modules, {"torch": fake_torch, "torch.cuda": fake_torch.cuda}):
+            cb = MemoryUsageCallback()
+            assert pytest.approx(300.0, rel=1e-6) == cb._get_gpu_memory()
+
+    @pytest.mark.requires_trainable_backend
+    def test_get_gpu_memory_jax(self):
+        patch_backend = patch.object(K, "backend", lambda: "jax")
+
+        class Dev:
+            platform = "gpu"
+
+            def memory_stats(self):
+                return {"bytes_in_use": 200 * 1024**2}
+
+        fake_jax = MagicMock()
+        fake_jax.devices.return_value = [Dev(), Dev()]
+
+        with patch_backend, patch.dict(sys.modules, {"jax": fake_jax}):
+            cb = MemoryUsageCallback()
+            assert pytest.approx(400.0, rel=1e-6) == cb._get_gpu_memory()
+
+    def test_running_on_gpu_and_tpu_flags(self):
+        g = running_on_gpu()
+        t = running_on_tpu()
+        assert isinstance(g, bool) and isinstance(t, bool)
+
+    def test_psutil_missing(self):
+        # ensure ImportError if psutil absent
+        orig = sys.modules.pop("psutil", None)
+        try:
+            import keras.src.callbacks.memory_usage_callback as mod
+            with patch.dict(sys.modules, {"psutil": None}):
+                with pytest.raises(ImportError):
+                    reload(mod)
+                    _ = mod.MemoryUsageCallback()
+        finally:
+            if orig is not None:
+                sys.modules["psutil"] = orig
+            reload(sys.modules["keras.src.callbacks.memory_usage_callback"])
diff --git a/requirements-common.txt b/requirements-common.txt
index 7edc40c97a1a..e1c042e3555c 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -23,4 +23,4 @@ dm_tree
 coverage!=7.6.5 # 7.6.5 breaks CI
 # for onnx_test.py
 onnxruntime
-openvino
+psutil
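
For reference, below is a minimal usage sketch of the callback this patch adds (not part of the diff). It assumes the patch is applied and `psutil` is installed; the toy model, data shapes, and log directory are purely illustrative.

import numpy as np
import keras

# Illustrative toy data and model.
x = np.random.random((64, 10)).astype("float32")
y = np.random.randint(0, 2, (64, 1)).astype("float32")

model = keras.Sequential(
    [
        keras.layers.Dense(5, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy")

# Print CPU (and GPU/TPU, when detected) memory at each epoch boundary and
# after every batch, and mirror the same scalars to TensorBoard.
memory_cb = keras.callbacks.MemoryUsageCallback(
    log_every_batch=True,
    tensorboard_log_dir="/tmp/memory_logs",
)
model.fit(x, y, epochs=2, batch_size=16, callbacks=[memory_cb], verbose=0)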