Skip to content

Commit a6d71e2

Browse files
mergennachinfacebook-github-bot
authored andcommitted
Be able to set RopE.Freq_Base explicitly (#2064)
Summary: Pull Request resolved: #2064 Currently it is set to 10K all the time. Let's make it possible so that one can set it explicitly. This should be a no-op for existing models and CI jobs. bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: larryliu0820 Differential Revision: D54131552 fbshipit-source-id: 46e9eb98bc7bb99c7225d3d7b02ab4bd2ecaed1c
1 parent e585a57 commit a6d71e2

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

examples/models/llama2/llama_transformer.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class ModelArgs:
8282
use_sdpa_with_kv_cache_op: bool = (
8383
False # Use custom sdpa op that updates kv cache in-place
8484
)
85+
rope_freq_base: float = 10000.0 # The base frequency for RoPE
8586
# Additional Model Metadata needed at runtime
8687
bos_idx: int = 1
8788
eos_idx: int = 3
@@ -108,7 +109,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
108109
)
109110

110111

111-
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
112+
def precompute_freqs_cis(dim: int, end: int, theta: float):
112113
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
113114
t = torch.arange(end, device=freqs.device) # pyre-ignore
114115
freqs = torch.outer(t, freqs).float() # pyre-ignore
@@ -411,7 +412,9 @@ def __init__(self, params: ModelArgs):
411412
self.use_kv_cache = params.use_kv_cache
412413

413414
freqs_cos, freqs_sin = precompute_freqs_cis(
414-
self.params.dim // self.params.n_heads, self.params.max_seq_len
415+
params.dim // params.n_heads,
416+
params.max_seq_len,
417+
params.rope_freq_base,
415418
)
416419
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
417420
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

0 commit comments

Comments
 (0)