Add ColQwen2 to 🤗 transformers #35778
Conversation
Feel free to ping us once this is ready for review!
(force-pushed 025ca25 to 48e0aa5)
Feel free to ping @Cyrilvallez once this is ready for review! 🤗
(force-pushed 5ec9758 to 7cfd9dc)
Hey @tonywu71! Thanks for contributing 🤗. Looks almost ready to go to me; I just pointed out a few nits to change.
```python
)
self.query_prefix = query_prefix or "Query: "

self.tokenizer.padding_side = "left"
```
This should be set when saving the tokenizer/processor
Fixed! The HF Hub commit with the new `processor_config.json` can be found here for reference.
Update: after discussion with @yonigozlan, I have realized it makes much more sense to let `tokenizer_config.json` handle `padding_side`. I've just applied the necessary changes!
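For context, a minimal sketch of how `padding_side` ends up persisted in `tokenizer_config.json` (the checkpoint name matches the one used later in this thread; the save path is illustrative):

```python
from transformers import AutoTokenizer

# Setting padding_side before save_pretrained makes it part of
# tokenizer_config.json, so the processor no longer has to force it at runtime.
tokenizer = AutoTokenizer.from_pretrained("vidore/colqwen2-v1.0-hf")
tokenizer.padding_side = "left"
tokenizer.save_pretrained("./colqwen2-tokenizer")
```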
Nice, thanks for iterating! I see two small things left to change, then LGTM from me!
Hey @tonywu71 @Cyrilvallez. Indeed, I remember we had tried the modular way for ColQwen2 by inheriting from ColPali for the modeling, but ended up needing to override the whole thing. The only thing I see that could be inherited are the input and start docstrings, but these will hopefully be gone soon anyway (#33771 👀)
Hey! Sorry for the delay! This is pretty clean, great work! 🤗 I just left a few last comments!
```python
loss = None
if not return_dict:
    output = (embeddings,) + outputs[2:]
    output[2] = output[2] if output_hidden_states is not None else None
    output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
    return (loss,) + output if loss is not None else output

return ColPaliForRetrievalOutput(
    loss=loss,
    embeddings=embeddings,
```
Why are we removing the loss here? 👀
Strictly speaking, the loss wasn't removed:

- it used to be set to `None`;
- the default value for `loss` in `ColPaliForRetrievalOutput` is `None`.

So I have removed the unneeded lines to make the code clearer.
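In other words, the return reduces to something like the sketch below (remaining output fields elided), with unchanged behavior:

```python
# loss is omitted because ColPaliForRetrievalOutput already defaults it to None.
return ColPaliForRetrievalOutput(
    embeddings=embeddings,
)
```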
```python
    visual_prompt_prefix: str = "Describe the image.",
    query_prefix: str = "Question: ",
):
    super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
    self.visual_prompt_prefix = visual_prompt_prefix
    self.query_prefix = query_prefix
```
These kinds of prefixes should be part of the chat_template directly, not hardcoded here 🤗
src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py
```python
if is_torch_available():
    import torch
    from torch import nn
```
No need to protect the torch import here!
```python
raise AttributeError(
    "The `initializer_range` attribute is not set in the configuration. Please set it before using the model."
)
```
Let's make sure it is correctly defined in the Config with some default value instead of raising here
The `ColQwen2Config` already has a default value for `initializer_range`, so I'll just remove the `raise` 👍🏼
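A minimal sketch of what that looks like on the config side (the `0.02` default is an assumption, following the usual transformers convention):

```python
from transformers import PretrainedConfig

class ColQwen2Config(PretrainedConfig):
    model_type = "colqwen2"

    def __init__(self, initializer_range: float = 0.02, **kwargs):
        # With a default value, downstream code can rely on the attribute
        # existing instead of raising an AttributeError.
        self.initializer_range = initializer_range
        super().__init__(**kwargs)
```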
```python
if inputs_embeds is None:
    inputs_embeds = self.vlm.model.embed_tokens(input_ids)

if pixel_values is not None:
    # Run the vision tower and scatter its embeddings into the positions
    # of the image placeholder tokens.
    pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
    image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
    image_mask = (
        (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
    )
    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if attention_mask is not None:
    attention_mask = attention_mask.to(inputs_embeds.device)

# Forward the merged embeddings through the language-model backbone.
outputs = self.vlm.model(
    input_ids=None,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
)
return outputs
```
If you don't mind, I think it would help readability to have this block directly in the main `forward` instead of separating it into two functions (with the large signatures, we have to go back and forth).
No I don't mind, I think it's actually a good idea! 🤗
```python
if visual_prompt_prefix is None:
    visual_prompt_prefix = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
self.visual_prompt_prefix = visual_prompt_prefix

if query_prefix is None:
    query_prefix = "Query: "
self.query_prefix = query_prefix
```
These should be incorporated into the chat_template 🤗
Not sure if we should have a chat template here, since this is not really a chat model. We had the same issue with GOT-OCR and ended up not using a chat template. Wdyt?
Hmm, indeed I was a bit fast here; let's keep it as is, especially as it aligns with ColPali!
```python
if text is not None and images is not None:
    raise ValueError("Only one of text or images can be processed at a time")
```
Alright, let's keep it then!
```python
def process_images(
    self,
    images: ImageInput = None,
    **kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColQwen2Processor's
    [`ColQwen2Processor.__call__`].

    This method forwards the `images` and `kwargs` arguments to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`].

    Args:
        images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
            The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
            tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
            number of channels, H and W are image height and width.
        return_tensors (`str` or [`~utils.TensorType`], *optional*):
            If set, will return tensors of a particular framework. Acceptable values are:

            - `'tf'`: Return TensorFlow `tf.constant` objects.
            - `'pt'`: Return PyTorch `torch.Tensor` objects.
            - `'np'`: Return NumPy `np.ndarray` objects.
            - `'jax'`: Return JAX `jnp.ndarray` objects.

    Returns:
        [`BatchFeature`]: A [`BatchFeature`] with the following fields:

        - **input_ids** -- List of token ids to be fed to a model.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
          `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
          `None`).
        - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
    """
    return self.__call__(images=images, **kwargs)

def process_queries(
    self,
    text: Union[TextInput, List[TextInput]],
    **kwargs: Unpack[ColQwen2ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColQwen2Processor's
    [`ColQwen2Processor.__call__`].
```
Why do we redefine them here? They will be inherited directly!
Oh, you're right! However, I think the docstring will be inherited from `ColPaliProcessor`'s docstring and will thus reference ColPali. Is there a way to simply override the docstring here? If not, should we keep the code as it is?
Wdyt @yonigozlan?
I don't see a clean way to do this, but we can just remove the specific references to the tokenizer and image processor in the docstring, IMO.
@Cyrilvallez taking over for the final push on this PR as Tony is quite busy. I pushed some necessary updates after the refactoring of Qwen2VL (so nice to have, btw); all should be good now and we use modular much more, including for the modeling code 🤗. @tonywu71, you'll still have to run the updated convert_weights script and push to your repo :), but apart from that we should be ready to merge!
@yonigozlan Done, the model repo is updated! 🤗 I've also pushed a commit to fix the HF model path for the ColQwen2 integration tests. Lmk if there's anything left to do before merging!
All right! Amazing work, congrats to you both @tonywu71 @yonigozlan! Super clean 🤗 I left 2 ultra small comments as my job here is to be very picky 🙃, but that's it! Feel free to merge @yonigozlan!
Thanks for the great addition 🤗
docs/source/en/model_doc/colqwen2.md
```python
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
)
```
This is a super nit, feel free to disregard if you're too annoyed by the review process 😆 But passing `None` is a bit misleading for an example IMO, even if it's equivalent.
Suggested change:

```diff
-    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
-)
+    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
+)
```
Agreed! It's been addressed 👌🏼
Actually, it seems `sdpa` doesn't work out of the box for ColQwen2, as I get this error when loading the model on MPS.
❌ Code:

```python
import torch
from transformers import ColQwen2ForRetrieval
from transformers.utils import is_flash_attn_2_available

model_name = "vidore/colqwen2-v1.0-hf"

# Load model
model = ColQwen2ForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
```
Note: leaving `attn_implementation=None` works.
The error:

```
ValueError: ColQwen2ForRetrieval does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`
```
✅ However, I managed to load Qwen2VL with SDPA:

```python
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
)
```
@Cyrilvallez @yonigozlan I read the instructions for enabling SDPA on ColQwen2, but the next steps are a bit unclear, as ColQwen2 essentially piggybacks on Qwen2VL thanks to modular. Any ideas about the right fix? 🤗
I believe it's only because the flags are not set in the PreTrainedModel; adding

```python
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
```

should solve it.
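For reference, a minimal sketch of where those flags would sit (the class name is an assumption, following the usual transformers layout):

```python
from transformers import PreTrainedModel

class ColQwen2PreTrainedModel(PreTrainedModel):
    # Flags checked by from_pretrained when an attn_implementation
    # ("sdpa", "flash_attention_2", "flex_attention") is requested.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
```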
Tsm, the fix is working like a charm! And as you expected, ColQwen2 works with `attn_implementation="flex_attention"` too 👌🏼
```python
if is_torch_available():
    import torch
```
Let's not protect, simply import it 🤗
The problem is that we need to protect the import for the processor :(
Oh I see; not a big issue anyway, you can disregard (it's just that `torch.nn` is imported without protection anyway, so it's a bit weird), but really not a big concern.
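For context, the guarded-import pattern under discussion looks roughly like this (a sketch; where exactly it sits in the processing file is an assumption):

```python
from transformers.utils import is_torch_available

# Processing code must stay importable without torch installed,
# so the torch import is guarded instead of unconditional.
if is_torch_available():
    import torch
```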
What does this PR do?

Add ColQwen2 to 🤗 `transformers`. ColQwen2 is a model that uses the ColPali architecture with a Qwen2-VL backbone.

Who can review?

Additional details

- `colpali-engine==v0.3.6`
- `vidore/colqwen2-v1.0-hf`

Progress checklist
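For readers landing on this thread, a minimal usage sketch of the merged model (the embedding and scoring calls follow the processor API discussed above; `score_retrieval` and the output's `embeddings` field are assumptions based on the ColPali implementation):

```python
import torch
from PIL import Image
from transformers import ColQwen2ForRetrieval, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0-hf"
model = ColQwen2ForRetrieval.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
processor = ColQwen2Processor.from_pretrained(model_name)

# Embed one page image and one query, then score them against each other.
images = [Image.new("RGB", (448, 448), color="white")]  # placeholder document page
queries = ["What is shown in the document?"]

image_inputs = processor.process_images(images=images).to(model.device)
query_inputs = processor.process_queries(text=queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**image_inputs).embeddings
    query_embeddings = model(**query_inputs).embeddings

# Late-interaction (MaxSim) scores with shape (n_queries, n_images).
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print(scores)
```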