Add ColQwen2 to 🤗 transformers #35778


Merged: 68 commits merged from tonywu71:add-colqwen2 into huggingface:main on Jun 2, 2025

Conversation

@tonywu71 (Contributor) commented on Jan 19, 2025

What does this PR do?

Adds ColQwen2 to 🤗 Transformers. ColQwen2 is a model that uses the ColPali architecture with a Qwen2-VL backbone.
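
For reviewers who want to try it, here is a minimal retrieval sketch. The class names, the `.embeddings` output field, the `process_images`/`process_queries` helpers, and the `vidore/colqwen2-v1.0-hf` checkpoint are all taken from later in this thread, so treat this as an illustration rather than final documented usage:

import torch
from PIL import Image

from transformers import ColQwen2ForRetrieval, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0-hf"  # checkpoint name taken from the discussion below
model = ColQwen2ForRetrieval.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
processor = ColQwen2Processor.from_pretrained(model_name)

# One document page (placeholder image here) and one text query.
page = Image.new("RGB", (448, 448), color="white")
image_inputs = processor.process_images(images=[page], return_tensors="pt").to(model.device)
query_inputs = processor.process_queries(text=["What is ColQwen2?"], return_tensors="pt").to(model.device)

with torch.no_grad():
    image_embeds = model(**image_inputs).embeddings  # (n_pages, n_image_tokens, dim) multi-vector embeddings
    query_embeds = model(**query_inputs).embeddings  # (n_queries, n_query_tokens, dim)

# ColPali-style late-interaction (MaxSim) scoring, written out explicitly.
scores = torch.einsum("qnd,pmd->qpnm", query_embeds.float(), image_embeds.float()).max(dim=-1).values.sum(dim=-1)
print(scores)  # (n_queries, n_pages) similarity matrix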

Who can review?

Additional details

Progress checklist

  • (Optional) Understood the model’s theoretical aspects
  • Prepared 🤗 Transformers dev environment
  • Set up debugging environment of the original repository
  • Created script that successfully runs the forward() pass using the original repository and checkpoint
  • Successfully added the model skeleton to 🤗 Transformers
  • Successfully converted original checkpoint to 🤗 Transformers checkpoint
  • Successfully ran forward() pass in 🤗 Transformers that gives identical output to original checkpoint
  • Finished model tests in 🤗 Transformers
  • Successfully added tokenizer in 🤗 Transformers
  • Run end-to-end integration tests
  • Finished docs
  • Uploaded model weights to the Hub
  • Submitted the pull request
  • (Optional) Added a demo notebook → can be found in Add ColQwen2 Hf cookbook tonywu71/colpali-cookbooks#17

@ArthurZucker (Collaborator)

Feel free to ping us once this is ready for review!

@ArthurZucker (Collaborator)

Feel free to ping @Cyrilvallez once this is ready for review! 🤗

@tonywu71 tonywu71 marked this pull request as ready for review April 15, 2025 18:10
@tonywu71 tonywu71 marked this pull request as draft April 15, 2025 18:19
@tonywu71 tonywu71 force-pushed the add-colqwen2 branch 2 times, most recently from 5ec9758 to 7cfd9dc on April 16, 2025 09:29
@tonywu71 tonywu71 marked this pull request as ready for review April 16, 2025 12:41

@yonigozlan (Member) left a comment

Hey @tonywu71 ! Thanks for contributing 🤗. Looks almost ready to go to me, I just pointed out a few nits to change

)
self.query_prefix = query_prefix or "Query: "

self.tokenizer.padding_side = "left"

Member

This should be set when saving the tokenizer/processor

Contributor Author

Fixed! The Hf Hub commit with the new processor_config.json can be found here for reference.

Contributor Author

Update: after discussion with @yonigozlan, I have realized it makes much more sense to let tokenizer_config.json handle padding_side. I've just applied the necessary changes!
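
For reference, a minimal sketch of the tokenizer_config.json approach; the base tokenizer repo is the Qwen2-VL one referenced later in this thread, and the local checkpoint path is only illustrative:

from transformers import AutoTokenizer

# Setting padding_side on the tokenizer and saving it persists the value in
# tokenizer_config.json, so the processor no longer needs to hardcode it at load time.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
tokenizer.padding_side = "left"
tokenizer.save_pretrained("./colqwen2-checkpoint")  # writes "padding_side": "left" into tokenizer_config.json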

@yonigozlan (Member) left a comment

Nice, thanks for iterating! I see two small things left to change, then LGTM!

@yonigozlan (Member)

Hey @tonywu71 @Cyrilvallez. Indeed, I remember we tried the modular way for ColQwen2 by inheriting from ColPali for the modeling, but we ended up needing to override the whole thing. The only thing I see that could be inherited is the input and start docstrings, but those will hopefully be gone soon anyway (#33771 👀)

@Cyrilvallez (Member) left a comment

Hey! Sorry for the delay! This is pretty clean, great work! 🤗 I just left a few last comments!

Comment on lines -167 to 174
loss = None
if not return_dict:
    output = (embeddings,) + outputs[2:]
    output[2] = output[2] if output_hidden_states is not None else None
    output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
    return (loss,) + output if loss is not None else output

return ColPaliForRetrievalOutput(
    loss=loss,
    embeddings=embeddings,

Member

Why are we removing the loss here? 👀

Contributor Author

Strictly speaking, the loss wasn't removed:

  • it used to be hardcoded to None.
  • the default value for loss in ColPaliForRetrievalOutput is already None.

So I removed the redundant lines to make the code clearer.
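
To illustrate the point with a simplified stand-in (the real ColPaliForRetrievalOutput is a ModelOutput subclass, but the default behaves the same way):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ColPaliForRetrievalOutput:  # simplified stand-in, only to show the default
    loss: Optional[torch.Tensor] = None
    embeddings: Optional[torch.Tensor] = None


# Passing loss=None explicitly is equivalent to omitting it, so the removed lines changed nothing.
out = ColPaliForRetrievalOutput(embeddings=torch.zeros(1, 4, 128))
assert out.loss is None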

Comment on lines +87 to +92
    visual_prompt_prefix: str = "Describe the image.",
    query_prefix: str = "Question: ",
):
    super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
    self.visual_prompt_prefix = visual_prompt_prefix
    self.query_prefix = query_prefix

Member

This kind of prefix should be part of the chat_template directly, not hardcoded here 🤗

Comment on lines 27 to 29
if is_torch_available():
    import torch
    from torch import nn

Member

No need to protect the torch import here!

Comment on lines 41 to 43
raise AttributeError(
    "The `initializer_range` attribute is not set in the configuration. Please set it before using the model."
)

Member

Let's make sure it is correctly defined in the Config with some default value instead of raising here

Contributor Author

The ColQwen2Config already has a default value for initializer_range, so I'll just remove the raise 👍🏼
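
A sketch of what that looks like on the config side; the 0.02 default and the flat signature are assumptions for illustration, since the real ColQwen2Config also nests a vlm_config:

from transformers import PretrainedConfig


class ColQwen2Config(PretrainedConfig):  # simplified sketch, not the full config from this PR
    model_type = "colqwen2"

    def __init__(self, initializer_range: float = 0.02, **kwargs):
        # A sane default means the model never has to raise when the attribute is missing.
        self.initializer_range = initializer_range
        super().__init__(**kwargs)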

Comment on lines 160 to 187
if inputs_embeds is None:
    inputs_embeds = self.vlm.model.embed_tokens(input_ids)

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
        image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
        image_mask = (
            (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

outputs = self.vlm.model(
    input_ids=None,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
)
return outputs

Member

If you don't mind, I think it would help readability to have this block directly in the main forward instead of splitting it across two functions (with the large signatures, we have to go back and forth)

Contributor Author

No I don't mind, I think it's actually a good idea! 🤗

Comment on lines +93 to +99
if visual_prompt_prefix is None:
    visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
self.visual_prompt_prefix = visual_prompt_prefix

if query_prefix is None:
    query_prefix = "Query: "
self.query_prefix = query_prefix

Member

These should be incorporated to the chat_template 🤗

Member

Not sure if we should have a chat template here since this is not really a chat model. We had the same issue with GOT-OCR and ended up not using a chat template. Wdyt?

@Cyrilvallez (Member) commented on May 28, 2025

Hmm, indeed I was a bit fast here; let's keep it as is, especially as it aligns with ColPali!

Comment on lines +157 to +158
if text is not None and images is not None:
    raise ValueError("Only one of text or images can be processed at a time")

Member

Alright, let's keep it then!

Comment on lines 234 to 276
def process_images(
    self,
    images: ImageInput = None,
    **kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's
    [`ColQwen2Processor.__call__`].

    This method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`].

    Args:
        images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
            The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
            tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
            number of channels, H and W are image height and width.
        return_tensors (`str` or [`~utils.TensorType`], *optional*):
            If set, will return tensors of a particular framework. Acceptable values are:

            - `'tf'`: Return TensorFlow `tf.constant` objects.
            - `'pt'`: Return PyTorch `torch.Tensor` objects.
            - `'np'`: Return NumPy `np.ndarray` objects.
            - `'jax'`: Return JAX `jnp.ndarray` objects.

    Returns:
        [`BatchFeature`]: A [`BatchFeature`] with the following fields:

        - **input_ids** -- List of token ids to be fed to a model.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
          `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
          `None`).
        - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
    """
    return self.__call__(images=images, **kwargs)

def process_queries(
    self,
    text: Union[TextInput, List[TextInput]],
    **kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's
    [`ColQwen2Processor.__call__`].

Member

Why do we redefine them here? They will be inherited directly!

Contributor Author

Oh, you're right! However, I think the docstring will be inherited from ColPaliProcessor's docstring and will thus reference ColPali. Is there a way to simply override the docstring here? If not, should we keep the code as it is?

Contributor Author

Wdyt @yonigozlan?

Member

I don't see a clean way to do this, but we can just remove specific references to the tokenizer and image processor in the docstring imo

@yonigozlan (Member)

@Cyrilvallez taking over for the final push on this PR as Tony is quite busy. I pushed some necessary updates after the refactoring of Qwen2VL (so nice to have btw), all should be good now and we use modular much more, including for the modeling code 🤗. @tonywu71 you'll still have to run the updated convert_weights script and push to your repo :), but apart from that we should be ready to merge!

@tonywu71 (Contributor Author)

> @Cyrilvallez taking over for the final push on this PR as Tony is quite busy. I pushed some necessary updates after the refactoring of Qwen2VL (so nice to have btw), all should be good now and we use modular much more, including for the modeling code 🤗. @tonywu71 you'll still have to run the updated convert_weights script and push to your repo :), but apart from that we should be ready to merge!

@yonigozlan Done, the model repo is updated! 🤗 I've also pushed a commit to fix the Hf model path for ColQwen2 integration tests. Lmk if there's anything left to do before merging!

@Cyrilvallez (Member) left a comment

All right! Amazing work, congrats to you both @tonywu71 @yonigozlan! Super clean 🤗 I left 2 ultra small comments as my job here is to be very picky 🙃, but that's it! Feel free to merge @yonigozlan!
Thanks for the great addition 🤗

Comment on lines 53 to 54
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
)

Member

This is a super nit, feel free to disregard if you're too annoyed by the review process 😆 But passing None is a bit misleading for an example IMO, even if it's equivalent

Suggested change:

- attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
- )
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
+ )

Contributor Author

Agreed! It's been addressed 👌🏼

@tonywu71 (Contributor Author) commented on May 29, 2025

Actually it seems sdpa doesn't work out-of-the-box for ColQwen2 as I get this error when loading the model on MPS.

❌ Code:

model_name = "vidore/colqwen2-v1.0-hf"

# Load model
model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)

Note: Leaving attn_implementation=None works.

The error:

ValueError: ColQwen2ForRetrieval does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`

✅ However, I managed to load Qwen2VL with SDPA:

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)

@Cyrilvallez @yonigozlan I read about the instructions for enabling SDPA on ColQwen2 but next steps are a bit unclear as ColQwen2 essentially piggybacks on Qwen2VL thanks to modular. Any ideas about the right fix? 🤗

Member

I believe it's only because the flags are not set in the PreTrainedModel - adding

_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True

should solve it
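
For reference, a sketch of where those flags live; the class name, the base_model_prefix, and the top-level ColQwen2Config export are assumptions based on the rest of this PR:

from transformers import ColQwen2Config, PreTrainedModel  # top-level ColQwen2Config export assumed post-merge


class ColQwen2PreTrainedModel(PreTrainedModel):
    """Simplified stand-in for the class in this PR, shown only to illustrate the opt-in flags."""

    config_class = ColQwen2Config
    base_model_prefix = "model"  # prefix assumed; check the actual modular file

    # These flags let from_pretrained(..., attn_implementation=...) accept the corresponding backends.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True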

@tonywu71 (Contributor Author) commented on May 30, 2025

Thanks so much, the fix works like a charm! And as you expected, ColQwen2 works with attn_implementation="flex_attention" too 👌🏼

Comment on lines +30 to +31
if is_torch_available():
    import torch

Member

Let's not protect, simply import it 🤗

Member

Problem is we need to protect the import for the processor :(

Member

Oh I see. Not a big issue, you can disregard (it's just a bit odd given that torch.nn is imported without protection anyway), but it's really not a big concern.

@yonigozlan yonigozlan enabled auto-merge (squash) June 2, 2025 12:56
@yonigozlan yonigozlan merged commit c72ba69 into huggingface:main Jun 2, 2025
20 checks passed