Commit 8a80c89

vbaddi and quic-rishinr authored and committed

nit: QAic changes

Signed-off-by: vbaddi <[email protected]>
Signed-off-by: Rishin <[email protected]>
1 parent 6341be3 commit 8a80c89

5 files changed, +110 −7 lines changed


QEfficient/transformers/models/modeling_auto.py

Lines changed: 29 additions & 3 deletions
@@ -1461,14 +1461,38 @@ def export(self, export_dir: Optional[str] = None) -> str:
             0: "full_batch_size" if self.continuous_batching else "batch_size",
             2: "ctx_len",
         }
+        pkv_dynamic_sliding_axes = {
+            0: "full_batch_size" if self.continuous_batching else "batch_size",
+            2: "chunk_attn",
+        }
+
         output_names = ["logits"]
 
-        for i in range(self.num_layers):
+        is_chunked_attention = torch.tensor(
+            [bool((i + 1) % 4) for i in range(self.model.config.num_hidden_layers)], dtype=torch.bool
+        )
+        global_cache_shape = [1, 8, seq_len, 128]
+        chunked_cache_shape = [
+            1,
+            8,
+            self.model.config.attention_chunk_size,
+            128,
+        ]
+
+        for i in range(self.model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
-                dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+                cache_shape = global_cache_shape if not is_chunked_attention[i] else chunked_cache_shape
+                apply_dynamic_axes = pkv_dynamic_axes if not is_chunked_attention[i] else pkv_dynamic_sliding_axes
+                example_inputs["past_key_values"][i].append(torch.zeros(cache_shape, dtype=torch.float32))
+                dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
                 output_names.append(f"past_{kv}.{i}_RetainedState")
 
+        # for i in range(self.num_layers):
+        #     for kv in ["key", "value"]:
+        #         example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+        #         dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+        #         output_names.append(f"past_{kv}.{i}_RetainedState")
+
         if self.continuous_batching:
             example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
             dynamic_axes["batch_index"] = {0: "batch_size"}

@@ -1497,6 +1521,7 @@ def build_prefill_specialization(
             "batch_size": 1 if self.continuous_batching else batch_size,
             "seq_len": prefill_seq_len,
             "ctx_len": ctx_len,
+            "chunk_attn": self.model.config.attention_chunk_size,
             "num_logits_to_keep": 1 if self.is_tlm else None,
         }
         if self.continuous_batching:

@@ -1522,6 +1547,7 @@ def build_decode_specialization(
             "batch_size": full_batch_size if self.continuous_batching else batch_size,
             "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
             "ctx_len": ctx_len,
+            "chunk_attn": self.model.config.attention_chunk_size,
             "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
         }
         if self.continuous_batching:
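
The export hunk above keeps two per-layer KV cache shapes: layers flagged by `bool((i + 1) % 4)` get a cache bounded by `attention_chunk_size` and a `chunk_attn` dynamic axis, while every fourth layer keeps the full `ctx_len` cache; the prefill and decode specializations then bind `chunk_attn` to `config.attention_chunk_size`. A minimal sketch of that per-layer selection follows, using made-up values (8 layers, chunk size 2048, seq_len 4096) together with the literal `[1, 8, ..., 128]` example shapes hardcoded in the hunk; it is illustrative only, not the exported model itself.

import torch

num_hidden_layers = 8          # assumed toy values, not from the commit
attention_chunk_size = 2048
seq_len = 4096

# bool((i + 1) % 4) is True unless (i + 1) is a multiple of 4, so layers
# 3, 7, ... stay global and the remaining layers use chunked (sliding) attention.
is_chunked_attention = torch.tensor(
    [bool((i + 1) % 4) for i in range(num_hidden_layers)], dtype=torch.bool
)

global_cache_shape = [1, 8, seq_len, 128]                 # full-context KV cache
chunked_cache_shape = [1, 8, attention_chunk_size, 128]   # capped at the chunk size

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
pkv_dynamic_sliding_axes = {0: "batch_size", 2: "chunk_attn"}

for i in range(num_hidden_layers):
    shape = chunked_cache_shape if is_chunked_attention[i] else global_cache_shape
    axes = pkv_dynamic_sliding_axes if is_chunked_attention[i] else pkv_dynamic_axes
    print(f"layer {i}: chunked={bool(is_chunked_attention[i])}, shape={shape}, axis 2 -> {axes[2]}")
# layers 0-2 and 4-6 report chunk_attn on axis 2; layers 3 and 7 report ctx_len.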

QEfficient/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     get_onnx_dir_name,
     get_padding_shape_from_config,
     get_qpc_dir_path,
+    get_sliding_window_shapes,
     hf_download,
     load_hf_processor,
     load_hf_tokenizer,

QEfficient/utils/_utils.py

Lines changed: 58 additions & 0 deletions
@@ -282,6 +282,64 @@ def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokeni
         tokenizer.pad_token_id = tokenizer.vocab_size - 1
 
 
+def get_sliding_window_shapes(config, batch_size, seq_len):
+    """
+    Gets padding dims from model config - number of kv heads and d_head
+    and returns padding shape - (batch_size, number of kv heads, seq_len, hidden size)
+    required for initialization of past_key_values
+    --------
+
+    :config: AutoConfig from pretrained model.
+    :batch_size: int. number of input prompts used to create inputs
+    :seq_len: int. sequence length to run the model for.
+
+    Return:
+        List[int, int, int, int]
+    """
+
+    if hasattr(config, "n_head"):  # Assuming n_head is a key in the config (GPTs/CodeGen)
+        n_heads = config.n_head
+        d_head = config.n_embd // config.n_head
+    elif hasattr(config, "num_key_value_heads") and hasattr(
+        config, "num_attention_heads"
+    ):  # Check for num_key_value_heads (Llama/Mistral)
+        n_heads = config.num_key_value_heads
+
+        if hasattr(config, "head_dim"):
+            d_head = config.head_dim
+        else:
+            d_head = config.hidden_size // config.num_attention_heads
+
+    elif hasattr(config, "n_heads"):  # Check for n_heads and d_model in the config (MPT Model)
+        n_heads = config.n_heads
+        d_head = config.d_model // config.n_heads
+    elif hasattr(config, "new_decoder_architecture"):  # Check for Falcon
+        new_decoder_architecture = getattr(config, "new_decoder_architecture")
+        if new_decoder_architecture:  # multi_query is ignored when new_decoder_architecture is True
+            n_heads = config.num_attention_heads
+        else:
+            if hasattr(config, "multi_query"):
+                multi_query_value = getattr(config, "multi_query")
+                if multi_query_value:
+                    n_heads = 1  # MQA, multi query is true
+                else:
+                    n_heads = config.num_attention_heads
+        d_head = config.hidden_size // config.num_attention_heads
+    else:
+        raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.")
+
+    # is_chunked_attention = torch.tensor([bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool)
+    global_cache_shape = [batch_size, n_heads, seq_len, d_head]
+    chunked_cache_shape = [
+        batch_size,
+        n_heads,
+        seq_len if seq_len < config.attention_chunk_size else config.attention_chunk_size,
+        d_head,
+    ]
+
+    return global_cache_shape, chunked_cache_shape
+
+
 def get_padding_shape_from_config(config, batch_size, seq_len):
     """
     Gets padding dims from model config - number of kv heads and d_head
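
For context, `get_sliding_window_shapes` walks the same config-attribute fallbacks as the existing `get_padding_shape_from_config` and returns a pair of shapes: the full-context KV shape and one capped at `attention_chunk_size`. A hypothetical usage sketch, assuming a QEfficient install that includes this commit; the config values below are illustrative and not taken from the diff:

from types import SimpleNamespace

from QEfficient.utils import get_sliding_window_shapes

# Toy Llama-style config: 8 KV heads, head_dim 128, chunked-attention window 8192.
config = SimpleNamespace(
    num_key_value_heads=8,
    num_attention_heads=32,
    head_dim=128,
    hidden_size=4096,
    attention_chunk_size=8192,
)

global_shape, chunked_shape = get_sliding_window_shapes(config, batch_size=1, seq_len=32768)
print(global_shape)   # [1, 8, 32768, 128] -- full-context cache shape
print(chunked_shape)  # [1, 8, 8192, 128]  -- capped at attention_chunk_size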

QEfficient/utils/generate_inputs.py

Lines changed: 21 additions & 3 deletions
@@ -8,7 +8,12 @@
 import numpy as np
 import torch
 
-from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix
+from QEfficient.utils import (
+    get_num_layers_from_config,
+    get_padding_shape_from_config,
+    get_sliding_window_shapes,
+    padding_check_and_fix,
+)
 
 
 class InputHandler:

@@ -39,6 +44,12 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         self.past_key_values = get_padding_shape_from_config(
             config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
         )
+        self.is_chunked_attention = torch.tensor(
+            [bool((i + 1) % 4) for i in range(config.num_hidden_layers)], dtype=torch.bool
+        )
+        self.global_shape, self.sliding_shape = get_sliding_window_shapes(
+            config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
+        )
 
     def prepare_pytorch_inputs(self):
         """

@@ -152,9 +163,16 @@ def prepare_ort_inputs(self):
             axis=1,
         ).astype(np.int64)
 
+        # for i in range(self.n_layer):
+        #     inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+        #     inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+
         for i in range(self.n_layer):
-            inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
-            inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32)
+            cache_shape = self.global_shape if not self.is_chunked_attention[i] else self.sliding_shape
+            inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
+            inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
+
+        return inputs
 
         return inputs
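
With this change, `prepare_ort_inputs` allocates each layer's ONNX Runtime past-KV buffer from whichever of the two shapes applies. A small standalone sketch of that allocation under assumed values (4 layers, global shape `[1, 8, 4096, 128]`, sliding shape `[1, 8, 2048, 128]`), not the class itself:

import numpy as np
import torch

n_layer = 4
# Same interleaving as the commit: layers where (i + 1) % 4 != 0 use the sliding shape.
is_chunked_attention = torch.tensor([bool((i + 1) % 4) for i in range(n_layer)], dtype=torch.bool)
global_shape, sliding_shape = [1, 8, 4096, 128], [1, 8, 2048, 128]

inputs = {}
for i in range(n_layer):
    cache_shape = sliding_shape if is_chunked_attention[i] else global_shape
    inputs["past_key." + str(i)] = np.zeros(cache_shape, dtype=np.float32)
    inputs["past_value." + str(i)] = np.zeros(cache_shape, dtype=np.float32)

# Layers 0-2 get 2048-token window buffers; layer 3 gets the full 4096-token cache.
print({name: arr.shape for name, arr in inputs.items()})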

QEfficient/utils/run_utils.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def run_hf_model_on_pytorch(self, model_hf):
         input_len = model_inputs["input_ids"].shape[-1]
 
         with torch.inference_mode():
-            generation = model_hf.generate(**model_inputs, max_new_tokens=8, do_sample=False)
+            generation = model_hf.generate(**model_inputs, max_new_tokens=12, do_sample=False)
             generation = generation[0][input_len:]
 
         # generated_ids = input_ids[0][input_ids_len:].detach().numpy()
