
Inconsistent results between torch and jax versions of DINOv2 #37246

Open
MasterXiong opened this issue Apr 3, 2025 · 1 comment

@MasterXiong

System Info

  • transformers version: 4.50.0
  • Platform: Linux-5.15.0-131-generic-x86_64-with-glibc2.31
  • Python version: 3.10.16
  • Huggingface_hub version: 0.29.0
  • Safetensors version: 0.5.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Tensorflow version (GPU?): 2.15.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.1 (gpu)
  • Jax version: 0.4.20
  • JaxLib version: 0.4.20
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA RTX A5000

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoImageProcessor, FlaxDinov2Model, Dinov2Model
from PIL import Image
import requests
import numpy as np

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
jax_inputs = image_processor(images=image, return_tensors="np")

# flax model
model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
outputs = model(**jax_inputs)
jax_results = outputs.last_hidden_state

# torch model
import torch

model = Dinov2Model.from_pretrained("facebook/dinov2-base")
torch_inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**torch_inputs)

torch_results = outputs.last_hidden_state

print(np.abs(jax_results - torch_results.numpy()).max())
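A single max-absolute-difference number is hard to interpret on its own, since it ignores the scale of the activations. Below is a minimal sketch of a more informative comparison reporting absolute and relative error plus an np.allclose verdict. The helper name compare_outputs is my own, and synthetic arrays stand in for the two last_hidden_state tensors so the snippet runs without downloading the model:

```python
import numpy as np

def compare_outputs(a, b, rtol=1e-3, atol=1e-3):
    """Report max absolute/relative differences and an allclose verdict."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    abs_diff = np.abs(a - b)
    rel_diff = abs_diff / (np.abs(b) + 1e-12)  # guard against division by zero
    return {
        "max_abs": float(abs_diff.max()),
        "max_rel": float(rel_diff.max()),
        "allclose": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
    }

# Synthetic stand-ins shaped like a DINOv2-base output (1 CLS + 256 patches, 768 dims).
rng = np.random.default_rng(0)
ref = rng.normal(size=(1, 257, 768))
report = compare_outputs(ref, ref + 1e-5)
print(report)
```

In the actual repro, you would pass np.asarray(jax_results) and torch_results.numpy() instead of the synthetic arrays; a large max_abs paired with a small max_rel would point at a few high-magnitude activations rather than a systematic divergence.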

Expected behavior

Hi,

I'm using the Flax version of DINOv2 and want to make sure that it returns results consistent with the torch version, so I ran the simple test script above. However, I noticed that the token embeddings can differ by more than 6. Is this expected due to numerical differences, or is there something wrong in my code that makes the difference so large? Thanks for your help!

@purusharthmalik
Contributor

The issue most likely arises from numerical differences. When comparing the models for consistency, a better approach is to look at the alignment of the classification-token embeddings of the two models:

from scipy.spatial.distance import cosine

jax_cls = jax_results[:, 0, :]
torch_cls = torch_results[:, 0, :]
cls_similarity = 1 - cosine(jax_cls.flatten(), torch_cls.flatten())
print(f"Cosine similarity between CLS token embeddings: {cls_similarity:.6f}")

The result (in my case) comes out to 0.999900, which indicates that, despite small numerical differences in the full tensors, the two models' semantic understanding of the image is essentially the same.

Note: jax_results[:, 0, :] works here because the first token in DINOv2 is the classification token.
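The CLS-token check above can be extended to every token, which helps spot whether a mismatch is concentrated in particular patch positions. A hedged sketch in plain NumPy (token_cosine_similarities is a hypothetical helper; synthetic arrays stand in for the two model outputs so it runs without the model):

```python
import numpy as np

def token_cosine_similarities(a, b):
    """Per-token cosine similarity between two (batch, tokens, dim) arrays."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    num = (a * b).sum(axis=-1)
    denom = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return num / np.maximum(denom, 1e-12)  # guard against zero-norm tokens

# Synthetic stand-ins shaped like a DINOv2-base output (1 CLS + 256 patches, 768 dims).
rng = np.random.default_rng(0)
jax_like = rng.normal(size=(1, 257, 768))
torch_like = jax_like + rng.normal(scale=1e-3, size=jax_like.shape)

sims = token_cosine_similarities(jax_like, torch_like)
print(sims.shape, sims.min())  # worst-aligned token across the sequence
```

If the minimum similarity stays near 1 while the max absolute difference is large, the discrepancy is likely a handful of high-magnitude values rather than a semantic divergence between the two implementations.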
