Add docs on interfacing with surrogates #804
base: main
Conversation
Thanks for this! Really awesome, and much appreciated that you're writing documentation.
I made a few comments. @hamelphi, @sbodenstein, @ernoc, could you also take a look at this?
When fusion_transport_surrogates matures more, it will hopefully help abstract away some of this, and we can supplement/modify these docs with examples using that library.
Nice! LGTM. Thanks for the contribution, Theo.
Thanks for the comments, sorry for the delay in responding!
Following some of the suggestions made by @sbodenstein, I've added a bit on saving/loading models in HLO format, which is the one backed by OpenXLA.
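For reference, a minimal sketch of what that save/load roundtrip looks like with `jax.export` (the `model_fn` and shapes here are placeholders, not the names used in the docs; API available in recent JAX versions):

```python
import jax
import jax.numpy as jnp
from jax import export

def model_fn(x):  # stand-in for any jittable JAX model
    return jnp.tanh(x)

input_spec = jax.ShapeDtypeStruct((1, 4), jnp.float32)

# Export: jax.export serializes the jitted function via StableHLO.
exported = export.export(jax.jit(model_fn))(input_spec)
with open("model.hlo", "wb") as f:
    f.write(exported.serialize())

# Load: deserialize and call; no PyTorch is needed at this point.
with open("model.hlo", "rb") as f:
    restored = export.deserialize(f.read())

output = restored.call(jnp.ones((1, 4), jnp.float32))
```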
torch_model = PyTorchMLP(hidden_dim, n_hidden, output_dim, input_dim)

This model can be converted to a Flax model as follows:
I would prefer: 'can be replicated in Flax as follows'.
params = {'params': params}

The model can then be called like any Flax model:
Have you verified this? E.g., that no params are transposed between libraries.
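One quick check (a sketch; `torch_model`, `flax_model`, `params`, and `input_dim` are the names from the snippets above): PyTorch's `nn.Linear` stores weights as `(out_features, in_features)` while Flax's `nn.Dense` kernel is `(in_features, out_features)`, so the weight matrices must be transposed when copying, and it's worth asserting the two models agree on a random input:

```python
import numpy as np
import torch
import jax.numpy as jnp

# Random input shared by both frameworks.
x = np.random.randn(1, input_dim).astype(np.float32)

with torch.no_grad():
    torch_out = torch_model(torch.from_numpy(x)).numpy()

flax_out = np.asarray(flax_model.apply(params, jnp.asarray(x)))

# Fails loudly if, e.g., a kernel was copied without the transpose.
np.testing.assert_allclose(flax_out, torch_out, rtol=1e-5, atol=1e-6)
```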
output_tensor = flax_model.apply(params, input_tensor)

Option 2: converting a Pytorch model to a JAX model
nit: Pytorch -> PyTorch
import torch_xla2 as tx

trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False)  # Use weights_only=False if you want to load the full model
params, jax_model_from_torch = tx.extract_jax(trained_model)
I think it needs to be jitted (builds in good practice; otherwise many users might be hit by terrible performance).
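Something like this, perhaps (a sketch; per the `extract_jax` examples in the torch_xla docs, the extracted function takes the params plus a tuple of inputs, but check the calling convention of your installed version):

```python
import jax

# Wrap the extracted function in jax.jit so it is compiled once
# rather than re-traced and executed op-by-op on every call.
jax_model_from_torch = jax.jit(jax_model_from_torch)

# First call compiles; later calls with the same shapes reuse it.
output = jax_model_from_torch(params, (input_tensor,))
```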
.. code-block:: python
output_tensor = flax_model.apply(params, input_tensor)
I would show this jitted, as PyTorch users might not know that this line of code is a terrible idea.
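E.g. (using the names from the snippet above):

```python
import jax

# Compile apply once; repeated calls then reuse the compiled kernel
# instead of dispatching each op eagerly.
jitted_apply = jax.jit(flax_model.apply)
output_tensor = jitted_apply(params, input_tensor)
```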
import numpy as np

# jax.export uses StableHLO to serialize the model to a binary format
exported_model = jax.export(jax_model_from_torch)
In the docs, they first wrap the function in JIT; another reason to JIT this: https://pytorch.org/xla/master/features/stablehlo.html#using-extract-jax
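Agreed. Also note that `jax.export` is a module, so the call needs to go through `jax.export.export`, which takes the jitted function and then the (abstract) input shapes. Roughly (a sketch; the input shape is a placeholder, and the `(params, (inputs,))` calling convention is the one assumed above for the extracted function):

```python
import jax
import jax.numpy as jnp
from jax import export

# export() takes a jitted function, then the input shapes to trace with.
exported_model = export.export(jax.jit(jax_model_from_torch))(
    params, (jax.ShapeDtypeStruct((1, input_dim), jnp.float32),)
)
serialized = exported_model.serialize()
```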
However, JAX will not be able to differentiate through the InferenceSession.
To convert the ONNX model to a JAX representation, you can use the `jaxonnxruntime`_ package: |
Maybe use a hyperlink to jaxonnxruntime.
jax_model_from_onnx = ONNXJaxBackend.prepare(onnx_model)
# NOTE: run() returns a list of output tensors, in order of the output nodes
output_tensors = jax_model_from_onnx.run({"input": jnp.asarray(input_tensor, dtype=jnp.float32)}) |
Definitely needs jitting
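Agreed. Since jaxonnxruntime builds the model out of JAX ops, the call should be traceable; something like this (a sketch using the names above; the `[0]` assumes a single output node):

```python
import jax
import jax.numpy as jnp

@jax.jit
def run_model(x):
    # run() takes a dict keyed by input node name and returns a list
    # of outputs; take the first (and here, only) output tensor.
    return jax_model_from_onnx.run({"input": x})[0]

output_tensor = run_model(jnp.asarray(input_tensor, dtype=jnp.float32))
```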
Includes:
- torch_xla2
- jaxonnxruntime
Potentially closes #538