Improvements in Gemma2 model card #37076
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button.
Nice job, thanks! 🤗
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white"> | ||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat"> | ||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> | ||
<div style="float: right;"> |
The badges should go above # Gemma2
docs/source/en/model_doc/gemma2.md
Outdated
</div>

## Overview

The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma2 Team, Google.
Two Gemma2 models are released, with parameter sizes of 9 billion (9B) and 27 billion (27B).
**[Gemma 2](https://arxiv.org/pdf/2408.00118)** is Google's open-weight language model family (2B, 9B, 27B parameters) featuring interleaved local-global attention (4K sliding window + 8K global context), knowledge distillation for smaller models, and GQA for efficient inference. The 27B variant rivals models twice its size, scoring 75.2 on MMLU and 74.0 on GSM8K, while the instruction-tuned versions excel in multi-turn chat.
Suggested change:
[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, 27B parameters. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens) and grouped-query attention (GQA) to increase inference performance.
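For reference, the attention layout and GQA settings described above are exposed on the model config; a minimal sketch, assuming access to the checkpoint and the attribute names defined in Transformers' `Gemma2Config`:

```python
from transformers import AutoConfig

# Inspect the interleaved local/global attention and GQA settings on a Gemma 2 checkpoint
config = AutoConfig.from_pretrained("google/gemma-2-9b")
print(config.sliding_window)           # local attention window, 4096 tokens
print(config.max_position_embeddings)  # global context length, 8192 tokens
print(config.num_attention_heads, config.num_key_value_heads)  # GQA: fewer KV heads than query heads
```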
docs/source/en/model_doc/gemma2.md
Outdated

The abstract from the blog post is the following:
Key improvements over Gemma 1 include deeper networks, logit soft-capping, and stricter safety filters (<0.1% memorization). Available in base and instruction-tuned variants.
Suggested change:
The 2B and 9B models are trained with knowledge distillation, and the instruction-tuned variant was post-trained with supervised fine-tuning and reinforcement learning.
docs/source/en/model_doc/gemma2.md
Outdated

*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*

The original checkpoints of Gemma 2 can be found [here](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315).
Suggested change:
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) release.
docs/source/en/model_doc/gemma2.md
Outdated

Tips:

> [!TIP]
> Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.
Suggested change:
> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma to different language tasks.
docs/source/en/model_doc/gemma2.md
Outdated
outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```
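For comparison, the same generation can also be run through the high-level `pipeline` API; a minimal sketch, with the checkpoint id assumed from the CLI example elsewhere in this card:

```python
from transformers import pipeline

# Text generation with a Gemma 2 checkpoint via the pipeline API
generator = pipeline("text-generation", model="google/gemma-2-2b", device_map="auto")
print(generator("Explain quantum computing simply.", max_new_tokens=32)[0]["generated_text"])
```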
#### Using 4-bit precision
#### Using 4-bit precision
docs/source/en/model_doc/gemma2.md
Outdated
```
#### Using 4-bit precision
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=quantization_config,  # pass the 4-bit config so it is actually applied
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
docs/source/en/model_doc/gemma2.md
Outdated
print(tokenizer.decode(outputs[0]))
```
### AttentionMaskVisualizer
Suggested change:
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
### AttentionMaskVisualizer

```python
visualizer = AttentionMaskVisualizer("google/gemma-2b")
Suggested change:
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
Can you print out the image, upload it to https://huggingface.co/datasets/huggingface/documentation-images/tree/main/transformers/model_doc (ping me to merge!), and then add it here?
docs/source/en/model_doc/gemma2.md
Outdated
## Notes
- Gemma 2's sliding window attention enables efficient long-context processing - see sidebar examples for >4K token use cases
Suggested change: remove these lines.
Force-pushed from 9e08b48 to 6257220
Hi @stevhliu. I have tried to make the changes as per the suggestions. Please let me know if any further adjustments are needed.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for your changes, we're close to merging!
docs/source/en/model_doc/gemma2.md
Outdated
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat"> | ||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> | ||
</div> | ||
|
||
## Overview |
## Overview
docs/source/en/model_doc/gemma2.md
Outdated

*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*

You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) release.
Suggested change:
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) collection.
docs/source/en/model_doc/gemma2.md
Outdated

- The original checkpoints can be converted using the conversion script `src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py`

<Tip warning={true}>
Remove `<Tip warning={true}>` here and everything up to "The example below demonstrates...".

```
echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
```
Close the `<hfoptions>` block after the `transformers-cli` example:

</hfoption>
</hfoptions>
### AttentionMaskVisualizer

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("You are an assistant. Make sure you print me")
```
Add image below here:

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/gemma-2-attn-mask.png"/>
</div>

Let's also add a `## Notes` section below this:

- Use a [`HybridCache`] instance to enable caching in Gemma 2. Gemma 2 doesn't support kv-caching strategies like [`DynamicCache`] or tuples of tensors because it uses sliding window attention every second layer.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

max_generated_length = inputs.input_ids.shape[1] + 10
past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```
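As a side note, `generate` can also build this cache internally; a minimal sketch, assuming the `"hybrid"` cache implementation is available for Gemma 2 and reusing `model`, `tokenizer`, and `inputs` from the snippet above:

```python
# Let generate() construct the HybridCache internally instead of passing past_key_values
outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="hybrid")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```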
@stevhliu, could you please verify whether this attention visualization image is correct? If so, I will push it to huggingface/documentation-images.
Yeah that looks correct!
I have created a PR, please have a look at it.
Thanks, once all the other comments have been addressed we can merge :)
@stevhliu, I have made the changes as suggested. Please let me know if anything else is needed.
Thank you for all the help! @stevhliu
What does this PR do?
Fixes #36979.
This PR improves the model card for Gemma2 based on the format mentioned here.
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@stevhliu, please let me know if any changes are needed here.