Keras <> NNX integration #21252

Open — wants to merge 79 commits into base: master
Conversation

@divyashreepathihalli (Collaborator) commented May 5, 2025

This PR integrates Flax NNX into the JAX backend!

The following snippet shows how to enable the NNX backend:

import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
import keras

Demo colab here: https://colab.sandbox.google.com/drive/1mK-4qbce2HGRIkcb4v5n4niWGDezL_6n#scrollTo=m-ZH9Mpnphfz
Added a GitHub workflow action for the NNX backend. Note that this will fail for now, because it needs a new release of Flax to work.
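One detail worth calling out about the snippet above: backend flags of this kind are read when keras is imported, so the environment variables have to be set before the import. A minimal sketch of the ordering (the keras import is left commented out so the snippet stands alone without a Keras/Flax installation):

```python
import os

# Set the flags BEFORE importing keras; they are read at import time.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"

# import keras  # would now initialize the JAX backend with NNX enabled
```

Setting these variables after `import keras` has no effect on an already-initialized backend.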

@divyashreepathihalli marked this pull request as draft May 5, 2025 23:05
@codecov-commenter commented May 5, 2025

Codecov Report

Attention: Patch coverage is 23.75000% with 122 lines in your changes missing coverage. Please review.

Project coverage is 82.62%. Comparing base (92bfff0) to head (d8ca752).
Report is 7 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/core.py 8.75% 72 Missing and 1 partial ⚠️
keras/src/backend/config.py 30.55% 20 Missing and 5 partials ⚠️
keras/src/layers/layer.py 42.85% 5 Missing and 3 partials ⚠️
keras/src/backend/jax/layer.py 20.00% 3 Missing and 1 partial ⚠️
keras/src/backend/common/variables.py 0.00% 2 Missing and 1 partial ⚠️
keras/src/backend/jax/trainer.py 66.66% 2 Missing and 1 partial ⚠️
keras/src/ops/operation.py 25.00% 2 Missing and 1 partial ⚠️
keras/src/backend/jax/__init__.py 50.00% 1 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/config/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21252      +/-   ##
==========================================
- Coverage   82.67%   82.62%   -0.06%     
==========================================
  Files         565      565              
  Lines       55064    55494     +430     
  Branches     8569     8691     +122     
==========================================
+ Hits        45525    45851     +326     
- Misses       7441     7528      +87     
- Partials     2098     2115      +17     
Flag Coverage Δ
keras 82.43% <23.75%> (-0.06%) ⬇️
keras-jax 63.39% <23.75%> (-0.12%) ⬇️
keras-numpy 58.28% <16.87%> (-0.40%) ⬇️
keras-openvino 33.92% <15.62%> (+0.46%) ⬆️
keras-tensorflow 63.79% <16.87%> (-0.12%) ⬇️
keras-torch 63.42% <16.87%> (-0.12%) ⬇️

import jax.numpy as jnp

x = ops.ones(3)

-@jax.jit
+@nnx.jit
Collaborator:

Would the integration prevent the use of jax.jit with Keras layers?

Collaborator (Author):

Yes! It would only work with nnx.jit for now (they might be working on adding support for jax.jit).

Collaborator (Author):

Added NNX as an opt-in with the os.environ["KERAS_NNX_ENABLED"] flag.

@@ -230,6 +233,25 @@ def is_flash_attention_enabled():
    return global_state.get_global_attribute("flash_attention", default=None)


@keras_export("keras.config.is_nnx_backend_enabled")
def is_nnx_backend_enabled():
Collaborator:

Oh you renamed the setter set_nnx_enabled, but not the getter is_nnx_enabled.
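For context, a getter like this typically just consults the opt-in flag. A minimal sketch, assuming a plain environment-variable read (the names, default, and accepted spellings here are illustrative; the PR's actual implementation may cache the value in global state instead):

```python
import os

def is_nnx_enabled():
    # Hypothetical sketch: treat common truthy spellings of the
    # KERAS_NNX_ENABLED environment variable as enabling NNX.
    return os.environ.get("KERAS_NNX_ENABLED", "false").lower() in ("true", "1")

os.environ["KERAS_NNX_ENABLED"] = "true"
print(is_nnx_enabled())  # -> True
```

Keeping the getter and setter names consistent (e.g. is_nnx_enabled / set_nnx_enabled) is exactly the point raised in the comment above.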

@@ -1533,7 +1539,19 @@ def __setattr__(self, name, value):
    if not hasattr(self, "_tracker"):
        self._initialize_tracker()
    value = self._tracker.track(value)
    return super().__setattr__(name, value)

    # NNX-specific bypass for `_called` and `built` attributes
Collaborator:

Oh, I thought we were fixing that in NNX, is that not the case?

Collaborator (Author):

Apparently not.
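The bypass under discussion can be illustrated with a minimal standalone sketch (the class, tracker, and attribute names are illustrative, not Keras's actual code): attributes that NNX's object machinery touches, such as _called and built, skip the tracker and are written directly to the instance.

```python
class TrackedLayer:
    # Attributes that bypass tracking, mirroring the `_called`/`built`
    # bypass discussed above (illustration only).
    _BYPASS_TRACKING = ("_called", "built")

    def __init__(self):
        # Set the tracker dict directly to avoid recursing into __setattr__.
        object.__setattr__(self, "_tracked", {})

    def __setattr__(self, name, value):
        if name in self._BYPASS_TRACKING:
            # Write directly, without going through the tracker.
            object.__setattr__(self, name, value)
            return
        # Everything else is "tracked" (stand-in for Keras's tracker).
        self._tracked[name] = value
        object.__setattr__(self, name, value)

layer = TrackedLayer()
layer.built = True        # bypasses tracking
layer.kernel = "weights"  # goes through the tracker
```

The design trade-off is the same one debated in this thread: the bypass works, but it hard-codes knowledge of NNX internals into Layer.__setattr__ rather than fixing the interaction on the NNX side.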


vars(instance)["_object__state"] = nnx.object.ObjectState()
Collaborator:

The comment got lost; can we move this to Layer?

Collaborator (Author):

Unfortunately that did not work. It is strange. Linking my comment here
#21252 (comment)

@abheesht17 (Collaborator) commented Jun 19, 2025

Noticed the following when using a custom NNX training loop:

  • When the model is just a Keras layer, grads are non-zero and the loss goes down.
  • But whenever I have a keras.Model instance (for example, a keras.Sequential model), grads are all 0.

Here is a notebook demonstrating the same: https://colab.research.google.com/drive/1hvGYN00aETHlrwjUlLCjfwVJrCcQqySN?resourcekey=0-oyfeCvzJzqOf0aLroY4rhg&usp=sharing.

Let me know if I've messed up somewhere! Trying to figure out why it's happening. Will changing the trainable_variables filter work? Probably not, because the filter is working properly (see the cell with nnx.state; it returns the correct variables).
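One way to narrow down a report like this is a minimal gradient sanity check in plain JAX, independent of Keras and NNX (this is a generic diagnostic sketch, not the notebook's code): if the loss actually depends on the parameters being differentiated, jax.grad should return non-zero gradients. All-zero gradients usually mean the traced computation never saw the parameters, which is the symptom described above for keras.Model.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Simple least-squares loss; any parameter the loss depends on
    # should receive a non-zero gradient.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

w = jnp.ones((3,))
x = jnp.eye(3)
y = jnp.zeros((3,))

grads = jax.grad(loss_fn)(w, x, y)
print(bool(jnp.any(grads != 0)))  # -> True: the loss sees w
```

Applied to the bug report: running the equivalent check with the model's trainable variables would distinguish "the filter returns the wrong variables" from "the variables are disconnected from the traced forward pass".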
