
Add docs on interfacing with surrogates #804

Open · wants to merge 3 commits into main
Conversation

@theo-brown (Collaborator) commented on Mar 6, 2025:

Includes:

  1. Manually reimplementing the model in JAX
  2. Converting a PyTorch model to a JAX model using torch_xla2
  3. Using an ONNX model with jaxonnxruntime

Potentially closes #538

@jcitrin (Collaborator) left a comment:

Thanks for this! Really awesome, and much appreciated that you're writing documentation.

Made a few comments. @hamelphi, @sbodenstein, @ernoc, could you also take a look at this?

@jcitrin (Collaborator) commented on Mar 7, 2025:

When fusion_transport_surrogates matures more, it will hopefully help abstract away some of this, and we can supplement/modify these docs with examples using that library.

@hamelphi (Collaborator) commented:

Nice! LGTM. Thanks for the contribution, Theo.

@jcitrin added and then removed the copybara:import-manual label on Mar 18, 2025
@theo-brown requested a review from jcitrin on Mar 21, 2025
@theo-brown (Collaborator, Author) commented:
Thanks for the comments, sorry for the delay in responding!

@theo-brown (Collaborator, Author) commented:
Following some of the suggestions made by @sbodenstein, I've added a bit on saving/loading models in HLO format, which is the one backed by OpenXLA.

@theo-brown requested a review from sbodenstein on Mar 25, 2025

torch_model = PyTorchMLP(hidden_dim, n_hidden, output_dim, input_dim)

This model can be converted to a Flax model as follows:
Review comment (Collaborator):
I would prefer: 'can be replicated in Flax as follows'.
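
For concreteness, a replication along those lines might look like the sketch below. This is assumption-laden: the actual PyTorchMLP definition is not shown here, so the ReLU activations and layer layout must be adjusted to mirror the real PyTorch architecture.

.. code-block:: python

    import flax.linen as nn

    class FlaxMLP(nn.Module):
        """Flax mirror of the (assumed) PyTorchMLP architecture."""
        hidden_dim: int
        n_hidden: int
        output_dim: int

        @nn.compact
        def __call__(self, x):
            # n_hidden hidden layers with ReLU, then a linear output layer
            for _ in range(self.n_hidden):
                x = nn.relu(nn.Dense(self.hidden_dim)(x))
            return nn.Dense(self.output_dim)(x)

    flax_model = FlaxMLP(hidden_dim, n_hidden, output_dim)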

params = {'params': params}


The model can then be called like any Flax model,
Review comment (Collaborator):
Have you verified this? E.g., that no params are transposed between libraries.
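
One way to check is a parity test on a random input. Worth noting: torch.nn.Linear stores its weight as (out_features, in_features), while flax.linen.Dense expects an (in_features, out_features) kernel, so any weights copied across must be transposed. A minimal sketch, assuming the torch_model, flax_model, and params defined above:

.. code-block:: python

    import numpy as np
    import torch

    # Run both models on the same random input; a mismatch here usually
    # means a transposed kernel or a different activation function.
    x = np.random.randn(1, input_dim).astype(np.float32)
    with torch.no_grad():
        torch_out = torch_model(torch.from_numpy(x)).numpy()
    flax_out = np.asarray(flax_model.apply(params, x))
    np.testing.assert_allclose(torch_out, flax_out, rtol=1e-5, atol=1e-6)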

output_tensor = flax_model.apply(params, input_tensor)


Option 2: converting a Pytorch model to a JAX model
Review comment (Collaborator):
nit: Pytorch -> PyTorch

import torch_xla2 as tx

# weights_only=False loads the full pickled model object, not just the state dict
trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False)
params, jax_model_from_torch = tx.extract_jax(trained_model)
Review comment (Collaborator):
I think it needs to be jitted (builds in good practice: otherwise many users might be hit by terrible performance).
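
A jitted version might look like the following sketch, assuming the torch_xla2 convention that the extracted function takes (params, args) with args a tuple of inputs:

.. code-block:: python

    import jax

    # Compile the extracted function once; later calls with the same input
    # shapes reuse the compiled executable instead of re-tracing.
    jitted_model = jax.jit(jax_model_from_torch)
    output_tensor = jitted_model(params, (input_tensor,))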


.. code-block:: python

    output_tensor = flax_model.apply(params, input_tensor)
Review comment (Collaborator):
I would show this jitted, as PyTorch users might not know that this line of code is a terrible idea.
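
For example, the jitted call could be shown as follows (a minimal sketch):

.. code-block:: python

    import jax

    # Without jit, apply executes op-by-op on every call; jitting compiles
    # the whole forward pass into a single XLA executable.
    jitted_apply = jax.jit(flax_model.apply)
    output_tensor = jitted_apply(params, input_tensor)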

import jax
import numpy as np
from jax import export

# jax.export uses StableHLO to serialize the model to a binary format.
# Note: export.export expects a function already wrapped in jax.jit.
exported_model = export.export(jax.jit(jax_model_from_torch))(params, (input_tensor,))
Review comment (Collaborator):
In the docs, they first wrap the function in JIT, which is another reason to JIT this: https://pytorch.org/xla/master/features/stablehlo.html#using-extract-jax
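
Building on that, a save/load round trip might look like the sketch below. Assumptions: the jax.export API of recent JAX releases, the exported_model produced above, and surrogate.hlo as a hypothetical filename.

.. code-block:: python

    from jax import export

    # Serialize the exported model to bytes and write it to disk ...
    with open("surrogate.hlo", "wb") as f:
        f.write(exported_model.serialize())

    # ... then reload it later and call it without the original PyTorch model.
    with open("surrogate.hlo", "rb") as f:
        restored = export.deserialize(bytearray(f.read()))
    output_tensor = restored.call(params, (input_tensor,))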

)

However, JAX will not be able to differentiate through the InferenceSession.
To convert the ONNX model to a JAX representation, you can use the `jaxonnxruntime`_ package:
Review comment (Collaborator):
Maybe use a hyperlink to jaxonnxruntime.


jax_model_from_onnx = ONNXJaxBackend.prepare(onnx_model)
# NOTE: run() returns a list of output tensors, in order of the output nodes
output_tensors = jax_model_from_onnx.run({"input": jnp.asarray(input_tensor, dtype=jnp.float32)})
Review comment (Collaborator):
Definitely needs jitting
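
A jitted wrapper might look like the sketch below. Assumption: run() on the prepared jaxonnxruntime backend traces cleanly under jax.jit (the package lowers ONNX ops to JAX ops, so this should hold, but verify for your model).

.. code-block:: python

    import jax
    import jax.numpy as jnp

    @jax.jit
    def onnx_forward(x):
        # Closes over the prepared backend; returns a list of output tensors
        # in the order of the ONNX graph's output nodes.
        return jax_model_from_onnx.run({"input": x})

    output_tensors = onnx_forward(jnp.asarray(input_tensor, dtype=jnp.float32))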

Development
Successfully merging this pull request may close these issues: Provide ONNX converter.
4 participants