# limitations under the License.
"""Inference-only Gemma model implementation."""

+import json
+import gc
+import os
import re
import torch
from torch import nn


class Sampler(nn.Module):

-    def __init__(self, vocab_size: int):
+    def __init__(self, vocab_size: int, config: gemma_config.GemmaConfig):
        super().__init__()
        self.vocab_size = vocab_size
+        self.config = config

    @torch.no_grad()
    def forward(
@@ -47,6 +51,10 @@ def forward(
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
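+        # Soft-cap the final logits: logits = cap * tanh(logits / cap).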
+        if self.config.final_logit_softcapping is not None:
+            logits.div_(self.config.final_logit_softcapping)
+            logits = torch.tanh(logits)
+            logits.mul_(self.config.final_logit_softcapping)

        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits
@@ -208,8 +216,12 @@ def __init__(
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
+        attn_logit_softcapping: Optional[float],
+        query_pre_attn_scalar: Optional[int],
        head_dim: int,
        quant: bool,
+        attn_type: gemma_config.AttentionType,
+        sliding_window_size: Optional[int] = None,
    ):
        super().__init__()

@@ -225,7 +237,10 @@ def __init__(
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

-        self.scaling = self.head_dim**-0.5
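+        # Gemma 2 can override the query scaling factor via query_pre_attn_scalar;
+        # otherwise fall back to the usual 1/sqrt(head_dim).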
+        if query_pre_attn_scalar is not None:
+            self.scaling = query_pre_attn_scalar**-0.5
+        else:
+            self.scaling = self.head_dim**-0.5

        self.qkv_proj = Linear(
            self.hidden_size,
@@ -236,6 +251,10 @@ def __init__(
            self.hidden_size,
            quant=quant)

+        self.attn_type = attn_type
+        self.sliding_window_size = sliding_window_size
+        self.attn_logit_softcapping = attn_logit_softcapping
+
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -283,7 +302,21 @@ def forward(
        v = value.transpose(1, 2)

        # [batch_size, n_local_heads, input_len, max_seq_len]
-        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
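+        # Scale q in place before the matmul (equivalent to scaling the scores).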
+        q.mul_(self.scaling)
+        scores = torch.matmul(q, k.transpose(2, 3))
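+        # For local sliding-window layers, keep only positions within
+        # sliding_window_size of each query and mask out everything else.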
+        if (
+            self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
+            and self.sliding_window_size is not None
+        ):
+            all_ones = torch.ones_like(mask)
+            sliding_mask = torch.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * torch.tril(all_ones, self.sliding_window_size - 1)
+            mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
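+        # Soft-cap the attention logits: scores = cap * tanh(scores / cap).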
+        if self.attn_logit_softcapping is not None:
+            scores.div_(self.attn_logit_softcapping)
+            scores = torch.tanh(scores)
+            scores.mul_(self.attn_logit_softcapping)
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)
@@ -308,8 +341,11 @@ def __init__(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
+            attn_logit_softcapping=config.attn_logit_softcapping,
+            query_pre_attn_scalar=config.query_pre_attn_scalar,
            head_dim=config.head_dim,
            quant=config.quant,
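+            # Gemma 1 decoder layers always use global (full-context) attention.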
+            attn_type=gemma_config.AttentionType.GLOBAL,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
@@ -350,6 +386,77 @@ def forward(
        return hidden_states


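+# Gemma 2 decoder layer: attention and MLP blocks wrapped with additional
+# pre/post feed-forward RMSNorms and a per-layer attention type (global or
+# local sliding window).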
+class Gemma2DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: gemma_config.GemmaConfig,
+        attn_type: gemma_config.AttentionType,
+    ):
+        super().__init__()
+        self.self_attn = GemmaAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            attn_logit_softcapping=config.attn_logit_softcapping,
+            query_pre_attn_scalar=config.query_pre_attn_scalar,
+            head_dim=config.head_dim,
+            quant=config.quant,
+            attn_type=attn_type,
+            sliding_window_size=config.sliding_window_size,
+        )
+        self.mlp = GemmaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant=config.quant,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+        self.pre_feedforward_layernorm = (
+            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            if config.use_pre_ffw_norm
+            else None
+        )
+        self.post_feedforward_layernorm = (
+            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            if config.use_post_ffw_norm
+            else None
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        kv_write_indices: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        mask: torch.Tensor,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            freqs_cis=freqs_cis,
+            kv_write_indices=kv_write_indices,
+            kv_cache=kv_cache,
+            mask=mask,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # MLP
+        residual = hidden_states
+        if self.pre_feedforward_layernorm is not None:
+            hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if self.post_feedforward_layernorm is not None:
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
class GemmaModel(nn.Module):

    def __init__(self, config: gemma_config.GemmaConfig):
@@ -358,8 +465,18 @@ def __init__(self, config: gemma_config.GemmaConfig):
        self.vocab_size = config.vocab_size

        self.layers = nn.ModuleList()
-        for _ in range(config.num_hidden_layers):
-            self.layers.append(GemmaDecoderLayer(config))
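+        # Select the decoder layer class per architecture; Gemma 2 layers can
+        # take a per-layer attention type from config.attn_types.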
+        for i in range(config.num_hidden_layers):
+            if config.architecture == gemma_config.Architecture.GEMMA_1:
+                self.layers.append(GemmaDecoderLayer(config))
+            elif config.architecture == gemma_config.Architecture.GEMMA_2:
+                attn_type = (
+                    config.attn_types[i]
+                    if config.attn_types is not None
+                    else gemma_config.AttentionType.GLOBAL
+                )
+                self.layers.append(Gemma2DecoderLayer(config, attn_type))
+            else:
+                raise ValueError(f'Unknown architecture: {config.architecture}')
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
@@ -400,7 +517,7 @@ def __init__(
        self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
        self.model = GemmaModel(config)
-        self.sampler = Sampler(vocab_size)
+        self.sampler = Sampler(vocab_size, config)

        # Pre-compute rotary embedding table.
        rope_theta = getattr(config, 'rope_theta', 10000)
@@ -558,9 +675,21 @@ def generate(
        return results[0] if is_str_prompt else results

    def load_weights(self, model_path: str):
-        self.load_state_dict(
-            torch.load(
-                model_path, mmap=True, weights_only=True,
-            )['model_state_dict'],
-            strict=False,
-        )
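+        # Support both a single checkpoint file and a sharded checkpoint
+        # directory indexed by pytorch_model.bin.index.json.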
+        if os.path.isfile(model_path):
+            self.load_state_dict(
+                torch.load(
+                    model_path, mmap=True, weights_only=True,
+                )['model_state_dict'],
+                strict=False,
+            )
+        else:
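+            # Load each shard listed in the index and merge it into the model.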
+            index_path = os.path.join(model_path, 'pytorch_model.bin.index.json')
+            with open(index_path, "r", encoding="utf-8") as f:
+                index = json.load(f)
+            shard_files = list(set(index["weight_map"].values()))
+            for shard_file in shard_files:
+                shard_path = os.path.join(model_path, shard_file)
+                state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
+                self.load_state_dict(state_dict, strict=False)
+                del state_dict  # Save memory.
+                gc.collect()