
Commit 6a24d4d

Supporting Gemma V2
1 parent 4f21b70

6 files changed: +381 -37 lines


README.md: 2 additions & 1 deletion

@@ -10,6 +10,7 @@ This is the official PyTorch implementation of Gemma models. We provide model an
 
 ## Updates
 
+[June 26th] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
 [April 9th] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
 [April 5] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
 

@@ -28,7 +29,7 @@ huggingface-cli download google/gemma-7b-it-pytorch
 Note that you can choose between the 2B, 7B, 7B int8 quantized variants.
 
 ```
-VARIANT=<2b or 7b>
+VARIANT=<2b or 7b or 9b or 27b>
 CKPT_PATH=<Insert ckpt path here>
 ```
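For reference, the new `9b` and `27b` variant strings resolve to model configs through `get_model_config` in `gemma/config.py` (see the diff below). A minimal sketch, assuming the package is importable as `gemma.config` as elsewhere in the repo:

```python
# Minimal sketch: map a VARIANT string to its model config.
# Assumes gemma/config.py is importable as gemma.config, as in the rest of the repo.
from gemma import config as gemma_config

cfg = gemma_config.get_model_config('9b')  # also accepts '2b', '7b', '27b'
print(cfg.num_hidden_layers)    # 42 for the 9b variant
print(cfg.sliding_window_size)  # 4096
```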

gemma/config.py: 76 additions & 3 deletions

@@ -15,8 +15,9 @@
 """Gemma model config."""
 
 import dataclasses
+import enum
 import torch
-from typing import Optional
+from typing import Optional, Sequence
 
 
 # Keep a mapping from dtype strings to the supported torch dtypes.
@@ -28,8 +29,20 @@
 })
 
 
+class AttentionType(enum.Enum):
+    GLOBAL = 1
+    LOCAL_SLIDING = 2
+
+
+class Architecture(enum.Enum):
+    GEMMA_1 = 1
+    GEMMA_2 = 2
+
+
 @dataclasses.dataclass
 class GemmaConfig:
+    # The architecture of the model.
+    architecture: Architecture = Architecture.GEMMA_1
     # The number of tokens in the vocabulary.
     vocab_size: int = 256000
     # The maximum sequence length that this model might ever be used with.
@@ -54,6 +67,21 @@ class GemmaConfig:
     quant: bool = False
     # The path to the model tokenizer.
     tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
+    # The types of attention used in the layers of the model.
+    attn_types: Optional[Sequence[AttentionType]] = None
+    # The size of the sliding window used for local attention.
+    sliding_window_size: Optional[int] = None
+    # If provided, the final logits are softcapped to this value.
+    final_logit_softcapping: Optional[float] = None
+    # If provided, the attention logits are softcapped to this value.
+    attn_logit_softcapping: Optional[float] = None
+    # If provided, the query vector is normalized using the
+    # inverse square root of this value instead of head_dim.
+    query_pre_attn_scalar: Optional[int] = None
+    # Whether to use pre mlp normalization.
+    use_pre_ffw_norm: bool = False
+    # Whether to use post mlp normalization.
+    use_post_ffw_norm: bool = False
 
     def get_dtype(self) -> Optional[torch.dtype]:
         """Gets the torch dtype from the config dtype string."""
@@ -74,10 +102,55 @@ def get_config_for_2b() -> GemmaConfig:
     )
 
 
+def get_config_for_9b() -> GemmaConfig:
+    return GemmaConfig(
+        architecture=Architecture.GEMMA_2,
+        num_hidden_layers=42,
+        num_attention_heads=16,
+        num_key_value_heads=8,
+        hidden_size=3584,
+        intermediate_size=14336,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
+        head_dim=256,
+        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 21,
+        sliding_window_size=4096,
+        query_pre_attn_scalar=224,  # hidden_size / num_attention_heads
+    )
+
+
+def get_config_for_27b() -> GemmaConfig:
+    return GemmaConfig(
+        architecture=Architecture.GEMMA_2,
+        num_hidden_layers=46,
+        num_attention_heads=32,
+        num_key_value_heads=16,
+        hidden_size=4608,
+        intermediate_size=36864,
+        use_pre_ffw_norm=True,
+        use_post_ffw_norm=True,
+        final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
+        head_dim=128,
+        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23,
+        sliding_window_size=4096,
+        query_pre_attn_scalar=144,  # hidden_size / num_attention_heads
+    )
+
+
 def get_model_config(variant: str) -> GemmaConfig:
     if variant == '7b':
         return get_config_for_7b()
     elif variant == '2b':
         return get_config_for_2b()
-    raise ValueError(f'Invalid variant {variant}. Supported variants are "2b"'
-                     'and "7b"')
+    elif variant == '9b':
+        return get_config_for_9b()
+    elif variant == '27b':
+        return get_config_for_27b()
+    else:
+        raise ValueError(
+            f'Invalid variant {variant}. Supported variants are "2b"'
+            'and "7b" and "9b" and "27b".')
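The Gemma 2 configs above interleave local sliding-window and global attention, one entry per layer, so `attn_types` must line up with `num_hidden_layers`. A small illustrative check (not part of the commit):

```python
# Illustrative check of the Gemma 2 layer schedule (not part of the commit).
from gemma import config as gemma_config

cfg = gemma_config.get_config_for_9b()
assert len(cfg.attn_types) == cfg.num_hidden_layers  # 2 * 21 == 42
for i, attn_type in enumerate(cfg.attn_types[:4]):
    # Alternates LOCAL_SLIDING, GLOBAL, LOCAL_SLIDING, GLOBAL, ...
    print(i, attn_type)
```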

gemma/model.py: 141 additions & 12 deletions

@@ -13,6 +13,9 @@
 # limitations under the License.
 """Inference-only Gemma model implementation."""
 
+import json
+import gc
+import os
 import re
 import torch
 from torch import nn
@@ -25,9 +28,10 @@
 
 class Sampler(nn.Module):
 
-    def __init__(self, vocab_size: int):
+    def __init__(self, vocab_size: int, config: gemma_config.GemmaConfig):
         super().__init__()
         self.vocab_size = vocab_size
+        self.config = config
 
     @torch.no_grad()
     def forward(
@@ -47,6 +51,10 @@ def forward(
         logits = torch.matmul(hidden_states, embedding.t())
         if embedding_bias is not None:
             logits += embedding_bias
+        if self.config.final_logit_softcapping is not None:
+            logits.div_(self.config.final_logit_softcapping)
+            logits = torch.tanh(logits)
+            logits.mul_(self.config.final_logit_softcapping)
 
         if temperatures is None:
             return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits
@@ -208,8 +216,12 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        attn_logit_softcapping: Optional[float],
+        query_pre_attn_scalar: Optional[int],
         head_dim: int,
         quant: bool,
+        attn_type: gemma_config.AttentionType,
+        sliding_window_size: Optional[int] = None,
     ):
         super().__init__()
 
@@ -225,7 +237,10 @@ def __init__(
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
 
-        self.scaling = self.head_dim**-0.5
+        if query_pre_attn_scalar is not None:
+            self.scaling = query_pre_attn_scalar**-0.5
+        else:
+            self.scaling = self.head_dim**-0.5
 
         self.qkv_proj = Linear(
             self.hidden_size,
@@ -236,6 +251,10 @@ def __init__(
             self.hidden_size,
             quant=quant)
 
+        self.attn_type = attn_type
+        self.sliding_window_size = sliding_window_size
+        self.attn_logit_softcapping = attn_logit_softcapping
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -283,7 +302,21 @@ def forward(
         v = value.transpose(1, 2)
 
         # [batch_size, n_local_heads, input_len, max_seq_len]
-        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
+        q.mul_(self.scaling)
+        scores = torch.matmul(q, k.transpose(2, 3))
+        if (
+            self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
+            and self.sliding_window_size is not None
+        ):
+            all_ones = torch.ones_like(mask)
+            sliding_mask = torch.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * torch.tril(all_ones, self.sliding_window_size - 1)
+            mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
+        if self.attn_logit_softcapping is not None:
+            scores.div_(self.attn_logit_softcapping)
+            scores = torch.tanh(scores)
+            scores.mul_(self.attn_logit_softcapping)
         scores = scores + mask
         scores = F.softmax(scores.float(), dim=-1).type_as(q)
 
@@ -308,8 +341,11 @@ def __init__(
             hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            attn_logit_softcapping=config.attn_logit_softcapping,
+            query_pre_attn_scalar=config.query_pre_attn_scalar,
             head_dim=config.head_dim,
             quant=config.quant,
+            attn_type=gemma_config.AttentionType.GLOBAL,
         )
         self.mlp = GemmaMLP(
             hidden_size=config.hidden_size,
@@ -350,6 +386,77 @@ def forward(
         return hidden_states
 
 
+class Gemma2DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: gemma_config.GemmaConfig,
+        attn_type: gemma_config.AttentionType,
+    ):
+        super().__init__()
+        self.self_attn = GemmaAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            attn_logit_softcapping=config.attn_logit_softcapping,
+            query_pre_attn_scalar=config.query_pre_attn_scalar,
+            head_dim=config.head_dim,
+            quant=config.quant,
+            attn_type=attn_type,
+            sliding_window_size=config.sliding_window_size,
+        )
+        self.mlp = GemmaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant=config.quant,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+        self.pre_feedforward_layernorm = (
+            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            if config.use_pre_ffw_norm
+            else None
+        )
+        self.post_feedforward_layernorm = (
+            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            if config.use_post_ffw_norm
+            else None
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        kv_write_indices: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            freqs_cis=freqs_cis,
+            kv_write_indices=kv_write_indices,
+            kv_cache=kv_cache,
+            mask=mask,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # MLP
+        residual = hidden_states
+        if self.pre_feedforward_layernorm is not None:
+            hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if self.post_feedforward_layernorm is not None:
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
 class GemmaModel(nn.Module):
 
     def __init__(self, config: gemma_config.GemmaConfig):
@@ -358,8 +465,18 @@ def __init__(self, config: gemma_config.GemmaConfig):
         self.vocab_size = config.vocab_size
 
         self.layers = nn.ModuleList()
-        for _ in range(config.num_hidden_layers):
-            self.layers.append(GemmaDecoderLayer(config))
+        for i in range(config.num_hidden_layers):
+            if config.architecture == gemma_config.Architecture.GEMMA_1:
+                self.layers.append(GemmaDecoderLayer(config))
+            elif config.architecture == gemma_config.Architecture.GEMMA_2:
+                attn_type = (
+                    config.attn_types[i]
+                    if config.attn_types is not None
+                    else gemma_config.AttentionType.GLOBAL
+                )
+                self.layers.append(Gemma2DecoderLayer(config, attn_type))
+            else:
+                raise ValueError(f'Unknown architecture: {config.architecture}')
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -400,7 +517,7 @@ def __init__(
         self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
         self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
         self.model = GemmaModel(config)
-        self.sampler = Sampler(vocab_size)
+        self.sampler = Sampler(vocab_size, config)
 
         # Pre-compute rotary embedding table.
         rope_theta = getattr(config, 'rope_theta', 10000)
@@ -558,9 +675,21 @@ def generate(
         return results[0] if is_str_prompt else results
 
     def load_weights(self, model_path: str):
-        self.load_state_dict(
-            torch.load(
-                model_path, mmap=True, weights_only=True,
-            )['model_state_dict'],
-            strict=False,
-        )
+        if os.path.isfile(model_path):
+            self.load_state_dict(
+                torch.load(
+                    model_path, mmap=True, weights_only=True,
+                )['model_state_dict'],
+                strict=False,
+            )
+        else:
+            index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
+            with open(index_path, "r", encoding="utf-8") as f:
+                index = json.load(f)
+            shard_files = list(set(index["weight_map"].values()))
+            for shard_file in shard_files:
+                shard_path = os.path.join(model_path, shard_file)
+                state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
+                self.load_state_dict(state_dict, strict=False)
+                del state_dict  # Save memory.
+                gc.collect()
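The two numerically novel pieces above are the tanh soft-capping applied to attention scores and final logits, and the sliding-window band mask used by LOCAL_SLIDING layers. A standalone sketch of both (helper names are mine, not part of the commit):

```python
import torch

def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Tanh soft-capping as in the Sampler and GemmaAttention changes:
    # divide by the cap, squash with tanh, scale back, so values stay in (-cap, cap).
    return cap * torch.tanh(x / cap)

def apply_sliding_window(mask: torch.Tensor, window: int) -> torch.Tensor:
    # Restrict an additive causal mask to a band of `window` positions,
    # mirroring the LOCAL_SLIDING branch of GemmaAttention.forward.
    all_ones = torch.ones_like(mask)
    band = torch.triu(all_ones, -window + 1) * torch.tril(all_ones, window - 1)
    return torch.where(band == 1, mask, -2.3819763e38)

# Example: a 6x6 causal mask limited to a window of 3 positions.
causal = torch.triu(torch.full((6, 6), -2.3819763e38), diagonal=1)
print(apply_sliding_window(causal, window=3))
```

In the model itself the band is applied to the existing causal mask before softmax, and the cap values come from the new config fields (50.0 for attention scores, 30.0 for final logits).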
