
Improvements in Gemma2 model card #37076

Merged: 5 commits into huggingface:main, Apr 7, 2025

Conversation

devesh-2002
Contributor

What does this PR do?

Fixes #36979.
This PR aims to improve the model card for Gemma2 based on the format mentioned here.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@stevhliu, please let me know if there are any changes needed here.


Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@github-actions github-actions bot marked this pull request as draft March 28, 2025 13:56
@devesh-2002 devesh-2002 marked this pull request as ready for review March 28, 2025 17:26
@github-actions github-actions bot requested a review from stevhliu March 28, 2025 17:27
Member

@stevhliu stevhliu left a comment


Nice job, thanks! 🤗

<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<div style="float: right;">
Member


The badges should go above # Gemma2

</div>

## Overview

The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma2 Team at Google.
Two Gemma2 models were released, with parameter sizes of 9 billion (9B) and 27 billion (27B).
**[Gemma 2](https://arxiv.org/pdf/2408.00118)** is Google's open-weight language model family (2B, 9B, 27B parameters) featuring interleaved local-global attention (4K sliding window + 8K global context), knowledge distillation for smaller models, and GQA for efficient inference. The 27B variant rivals models twice its size, scoring 75.2 on MMLU and 74.0 on GSM8K, while the instruction-tuned versions excel in multi-turn chat.
Member


Suggested change
**[Gemma 2](https://arxiv.org/pdf/2408.00118)** is Google's open-weight language model family (2B, 9B, 27B parameters) featuring interleaved local-global attention (4K sliding window + 8K global context), knowledge distillation for smaller models, and GQA for efficient inference. The 27B variant rivals models twice its size, scoring 75.2 on MMLU and 74.0 on GSM8K, while the instruction-tuned versions excel in multi-turn chat.
[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, and 27B parameter sizes. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens), and grouped-query attention (GQA) to increase inference performance.


The abstract from the blog post is the following:
Key improvements over Gemma 1 include deeper networks, logit soft-capping, and stricter safety filters (<0.1% memorization). Available in base and instruction-tuned variants.
Member


Suggested change
Key improvements over Gemma 1 include deeper networks, logit soft-capping, and stricter safety filters (<0.1% memorization). Available in base and instruction-tuned variants.
The 2B and 9B models are trained with knowledge distillation, and the instruction-tuned variant was post-trained with supervised fine-tuning and reinforcement learning.


*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*
The original checkpoints of Gemma 2 can be found [here](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315).
Member


Suggested change
The original checkpoints of Gemma 2 can be found [here](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315).
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) release.


Tips:
> [!TIP]
> Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.
Member


Suggested change
> Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.
> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma to different language tasks.

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```
#### Using 4-bit precision
Member


Suggested change
#### Using 4-bit precision

```
#### Using 4-bit precision
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# quantize the 27B checkpoint to 4-bit on load
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=quantization_config,
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

```
### AttentionMaskVisualizer
Member


Suggested change
### AttentionMaskVisualizer
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.

### AttentionMaskVisualizer

```python
visualizer = AttentionMaskVisualizer("google/gemma-2b")
Member


Suggested change
visualizer = AttentionMaskVisualizer("google/gemma-2b")
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")

Member


Can you print out the image, upload it to https://huggingface.co/datasets/huggingface/documentation-images/tree/main/transformers/model_doc (ping me to merge!), and then add it here?

Comment on lines 149 to 150
## Notes
- Gemma 2's sliding window attention enables efficient long-context processing - see sidebar examples for >4K token use cases
Member


Suggested change
## Notes
- Gemma 2's sliding window attention enables efficient long-context processing - see sidebar examples for >4K token use cases

@devesh-2002
Contributor Author

Hi @stevhliu, I have made changes per the suggestions. Please let me know if any further adjustments are needed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@stevhliu stevhliu left a comment


Thanks for your changes, we're close to merging!

<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview
Member


Suggested change
## Overview


*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) release.
Member


Suggested change
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) release.
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) collection.


- The original checkpoints can be converted using the conversion script `src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py`

<Tip warning={true}>
Member


Remove <Tip warning={true}> here and everything up to "The example below demonstrates..."


```
echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
```
Member


Close the <hfoptions> block after the transformers-cli example

</hfoption>
</hfoptions>

### AttentionMaskVisualizer

```python
visualizer = AttentionMaskVisualizer("google/gemma-2b")
Member


Can you print out the image, upload it to https://huggingface.co/datasets/huggingface/documentation-images/tree/main/transformers/model_doc (ping me to merge!), and then add it here?

from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/gemma-2b")
visualizer("You are an assistant. Make sure you print me")
```
Member


Add image below here:

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/gemma-2-attn-mask.png"/>
</div>

Let's also add a ## Notes section below this:

  • Use a [HybridCache] instance to enable caching in Gemma 2. Gemma 2 doesn't support kv-caching strategies like [DynamicCache] or tuples of tensors because it uses sliding window attention every second layer.

    from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
    
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    
    inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
    max_generated_length = inputs.input_ids.shape[1] + 10
    past_key_values = HybridCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=max_generated_length,
        device=model.device,
        dtype=model.dtype,
    )
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)

Contributor Author


@stevhliu, could you please verify whether this attention visualization image is correct? If so, I will push it to huggingface/documentation-images.
[attached image: attention mask visualization]

Member


Yeah that looks correct!

Contributor Author


I have created a PR, please have a look at it.

Member


Thanks, once all the other comments have been addressed we can merge :)

@devesh-2002
Contributor Author

@stevhliu, I have made the changes as suggested. Please let me know if anything else is needed.

@stevhliu stevhliu merged commit 6cc109c into huggingface:main Apr 7, 2025
10 checks passed
@devesh-2002
Contributor Author

Thank you for all the help! @stevhliu
