Add memory usage monitor callback #21245
Open
DimiChatzipavlis wants to merge 20 commits into keras-team:master from DimiChatzipavlis:master
Changes from all commits (20 commits)
6f7143c  Add memory usage monitor callback
7695601  Add memory usage monitor callback
5f9d975  Add memory usage monitor callback
daddf29  Fix formatting errors
105cbdc  Add openvino support
f340101  Fix openvino support
5af4a44  Appropriate API integration
e7f225c  Merge branch 'keras-team:master' into master
1ae7659  Reformatted code
cb00aa2  Merge branch 'master' of https://github.com/DimiChatzipavlis/keras
e13528e  Fix API integration
728b770  Format the code
c4c0e5e  Format the code (2)
e064d0e  Fix openvino case
a9e0212  Add keras devs' comments into code
9148a60  Fix minor error
20cbdf9  Fix test file
a671b62  Fix TPU support
bd7fc07  Fix callback code
8f37649  Fix tests
@@ -0,0 +1,278 @@
import os
import time
import warnings

from keras.src import backend as K
from keras.src.api_export import keras_export
from keras.src.callbacks.callback import Callback

try:
    import psutil
except ImportError:
    psutil = None


def running_on_gpu():
    """Detect if any GPU is available on the current Keras backend."""
    backend_name = K.backend()
    if backend_name == "tensorflow":
        import tensorflow as tf

        try:
            return bool(tf.config.list_physical_devices("GPU"))
        except Exception:
            return False
    elif backend_name == "torch":
        try:
            import torch

            return torch.cuda.is_available()
        except ImportError:
            return False
    elif backend_name == "jax":
        try:
            import jax

            return any(d.platform.upper() == "GPU" for d in jax.devices())
        except ImportError:
            return False
    return False


def running_on_tpu():
    """Detect if any TPU is available on the current Keras backend."""
    backend_name = K.backend()
    if backend_name == "tensorflow":
        import tensorflow as tf

        try:
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
        except Exception:
            pass
        try:
            return bool(tf.config.list_physical_devices("TPU"))
        except Exception:
            return False
    elif backend_name == "jax":
        try:
            import jax

            return any(d.platform.upper() == "TPU" for d in jax.devices())
        except ImportError:
            return False
    return False


@keras_export("keras.callbacks.MemoryUsageCallback")
class MemoryUsageCallback(Callback):
    """
    Monitors and logs memory usage (CPU + optional GPU/TPU) during training.

    This callback measures:
      - **CPU**: via psutil.Process().memory_info().rss
      - **GPU**: if a GPU is detected, via backend-specific APIs
        (TensorFlow, PyTorch, JAX)
      - **TPU**: if a TPU is detected, via backend-specific APIs
        (TensorFlow, JAX)

    Logs are printed to stdout at the start and end of each epoch
    (with a leading newline to avoid clobbering the progress bar),
    and, if `log_every_batch=True`, after every batch.
    If `tensorboard_log_dir` is provided, scalars are also written
    via tf.summary (TensorBoard).

    Args:
        log_every_batch (bool): If True, also log after each batch.
            Defaults to False.
        tensorboard_log_dir (str|None): Directory for TensorBoard logs;
            if None, no TF summary writer is created.

    Raises:
        ImportError: If `psutil` is not installed (required for CPU logging).
    """

    def __init__(
        self,
        log_every_batch=False,
        tensorboard_log_dir=None,
    ):
        super().__init__()

        if psutil is None:
            raise ImportError(
                "MemoryUsageCallback requires the 'psutil' library. "
                "To install, please use: pip install psutil"
            )
        self.log_every_batch = log_every_batch
        self._proc = psutil.Process()
        self._step_counter = 0
        self._writer = None

        if tensorboard_log_dir:
            try:
                import tensorflow as tf

                logdir = os.path.expanduser(tensorboard_log_dir)
                self._writer = tf.summary.create_file_writer(logdir)
                print(f"MemoryUsageCallback: TensorBoard logs → {logdir}")
            except Exception as e:
                warnings.warn(
                    f"Could not initialize TensorBoard writer: {e}",
                    RuntimeWarning,
                )
                self._writer = None

    def on_train_begin(self, logs=None):
        self._step_counter = 0

    def on_epoch_begin(self, epoch, logs=None):
        print()
        self._log_epoch("start", epoch)

    def on_epoch_end(self, epoch, logs=None):
        print()
        self._log_epoch("end", epoch, offset=1)

    def on_batch_end(self, batch, logs=None):
        if self.log_every_batch:
            print()
            self._log_step(
                f"Batch {self._step_counter} end", self._step_counter
            )
        self._step_counter += 1

    def on_train_end(self, logs=None):
        if self._writer:
            self._writer.close()

    def _log_epoch(self, when, epoch, offset=0):
        label = f"Epoch {epoch} {when}"
        step = epoch + offset
        self._log_step(label, step)

    def _log_step(self, label, step):
        """
        Internal helper to measure and print CPU/GPU/TPU memory.
        Inserts a short delay (time.sleep(0)) to let stdout flush cleanly.
        """
        cpu_mb = self._get_cpu_memory()
        gpu_mb = self._get_gpu_memory()
        tpu_mb = self._get_tpu_memory()

        msg = f"{label} - CPU Memory: {cpu_mb:.2f} MB"
        if gpu_mb is not None:
            msg += f"; GPU Memory: {gpu_mb:.2f} MB"
        if tpu_mb is not None:
            msg += f"; TPU Memory: {tpu_mb:.2f} MB"
        print(msg)
        time.sleep(0)

        if self._writer:
            import tensorflow as tf

            with self._writer.as_default(step=int(step)):
                tf.summary.scalar("Memory/CPU_MB", cpu_mb)
                if gpu_mb is not None:
                    tf.summary.scalar("Memory/GPU_MB", gpu_mb)
                if tpu_mb is not None:
                    tf.summary.scalar("Memory/TPU_MB", tpu_mb)

    def _get_cpu_memory(self):
        """Return current process CPU memory usage in MB."""
        return self._proc.memory_info().rss / (1024**2)

    def _get_gpu_memory(self):
[Inline review comment] another function to get tpu memory would be needed as well
""" | ||
Return current GPU memory usage in MB for the detected backend, | ||
or None if no GPU is present or if measurement fails. | ||
""" | ||
if not running_on_gpu(): | ||
return None | ||
backend_name = K.backend() | ||
try: | ||
if backend_name == "tensorflow": | ||
import tensorflow as tf | ||
|
||
try: | ||
mem_info = tf.config.experimental.get_memory_info("GPU:0") | ||
return mem_info["current"] / (1024**2) | ||
except Exception: | ||
gpus = tf.config.list_physical_devices("GPU") | ||
if not gpus: | ||
return None | ||
total = 0 | ||
for i in range(len(gpus)): | ||
try: | ||
info = tf.config.experimental.get_memory_info(f"GPU:{i}") | ||
total += info.get("current", 0) | ||
except Exception: | ||
continue | ||
return total / (1024**2) | ||
elif backend_name == "torch": | ||
import torch | ||
|
||
if not torch.cuda.is_available(): | ||
return None | ||
total_bytes = 0 | ||
for i in range(torch.cuda.device_count()): | ||
total_bytes += torch.cuda.memory_allocated(i) | ||
return total_bytes / (1024**2) | ||
elif backend_name == "jax": | ||
import jax | ||
|
||
devs = [d for d in jax.devices() if d.platform.upper() == "GPU"] | ||
if not devs: | ||
return None | ||
total = 0 | ||
for d in devs: | ||
stats = getattr(d, "memory_stats", lambda: {})() | ||
total += stats.get("bytes_in_use", 0) | ||
return total / (1024**2) | ||
return None | ||
except ImportError as imp_err: | ||
if not hasattr(self, "_warn_import"): | ||
warnings.warn( | ||
f"Could not import library for GPU memory tracking ({backend_name}): {imp_err}", | ||
RuntimeWarning, | ||
) | ||
self._warn_import = True | ||
return None | ||
except Exception as exc: | ||
if not hasattr(self, "_warn_exc"): | ||
warnings.warn(f"Error retrieving GPU memory: {exc}", RuntimeWarning) | ||
self._warn_exc = True | ||
return None | ||
|
||
    def _get_tpu_memory(self):
        """
        Return current TPU memory usage in MB for the detected backend,
        or None if no TPU is present or if measurement fails.
        Note: TPU memory APIs vary; here we attempt best-effort.
        """
        if not running_on_tpu():
            return None
        backend_name = K.backend()
        try:
            if backend_name == "tensorflow":
                return None
            elif backend_name == "jax":
                import jax

                devs = [
                    d for d in jax.devices() if d.platform.upper() == "TPU"
                ]
                if not devs:
                    return None
                stats = getattr(devs[0], "memory_stats", lambda: {})()
                tpu_bytes = stats.get(
                    "bytes_in_use", stats.get("allocated_bytes", 0)
                )
                return tpu_bytes / (1024**2)
            return None
        except ImportError as imp_err:
            if not hasattr(self, "_warn_tpu_imp"):
                warnings.warn(
                    "Could not import library for TPU memory tracking "
                    f"({backend_name}): {imp_err}",
                    RuntimeWarning,
                )
                self._warn_tpu_imp = True
            return None
        except Exception as exc:
            if not hasattr(self, "_warn_tpu_exc"):
                warnings.warn(
                    f"Error retrieving TPU memory: {exc}", RuntimeWarning
                )
                self._warn_tpu_exc = True
            return None
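
For context on how this callback would be wired into a training loop, here is a minimal usage sketch. It assumes the PR is merged and the class is importable from keras.callbacks (as the @keras_export decorator above indicates); the toy data, model, and log directory are hypothetical placeholders, not part of this PR.

import numpy as np

import keras
from keras.callbacks import MemoryUsageCallback  # assumed import path once this PR is merged

# Hypothetical toy data and model, just to show where the callback plugs in.
x = np.random.rand(256, 20).astype("float32")
y = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy")

memory_cb = MemoryUsageCallback(
    log_every_batch=False,  # log only at epoch start/end
    tensorboard_log_dir="./logs/memory",  # optional; omit to skip tf.summary
)
model.fit(x, y, epochs=2, batch_size=32, callbacks=[memory_cb])

When tensorboard_log_dir is set, the Memory/CPU_MB scalar (and, where a device is detected, Memory/GPU_MB / Memory/TPU_MB) can then be inspected with tensorboard --logdir ./logs/memory.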
[Inline review comment] From the Colab output I am observing that the epoch end is not logged when log_every_batch is False.
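
A minimal sketch to reproduce this observation, assuming the callback is importable as in the usage sketch above; the data and model are hypothetical. With log_every_batch=False, on_epoch_begin and on_epoch_end should still each print a memory line around the progress bar, so the missing "Epoch ... end" line is what is being reported here.

import numpy as np

import keras
from keras.callbacks import MemoryUsageCallback  # assumed import path

x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Expected stdout per epoch (besides the progress bar):
#   Epoch 0 start - CPU Memory: ... MB
#   Epoch 0 end - CPU Memory: ... MB
model.fit(
    x,
    y,
    epochs=1,
    batch_size=16,
    callbacks=[MemoryUsageCallback(log_every_batch=False)],
)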