Commit da26d3e

added retrying downloads logic for stability (#370)
Signed-off-by: Onkar Chougule <[email protected]>
1 parent 3de4072 commit da26d3e

File tree

5 files changed: +20 -11 lines changed


QEfficient/base/common.py (+4)

@@ -12,13 +12,15 @@
 QEFFAutoModel provides a common interface for loading the HuggingFace models using either the HF card name of local path of downloaded model.
 """
 
+import os
 from typing import Any
 
 from transformers import AutoConfig
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.utils import login_and_download_hf_lm
 
 
 class QEFFCommonLoader:
@@ -50,6 +52,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         )
 
         local_model_dir = kwargs.pop("local_model_dir", None)
+        if not os.path.isdir(pretrained_model_name_or_path) and local_model_dir is None:
+            pretrained_model_name_or_path = login_and_download_hf_lm(pretrained_model_name_or_path, *args, **kwargs)
         hf_token = kwargs.pop("hf_token", None)
         continuous_batching = True if kwargs.pop("full_batch_size", None) else False
 
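The new guard only triggers a hub download when the argument is not already a directory on disk and no local_model_dir override was passed. A minimal usage sketch of the two paths (the model id and local path below are illustrative, not from this commit):

from QEfficient.base.common import QEFFCommonLoader

# Hub card name: os.path.isdir() is False, so the loader first resolves it
# through login_and_download_hf_lm, which retries flaky downloads.
model = QEFFCommonLoader.from_pretrained("gpt2")

# Already-downloaded local path: isdir() is True and no download is attempted.
model = QEFFCommonLoader.from_pretrained("/path/to/local/model")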

QEfficient/utils/constants.py (+1 -1)

@@ -95,7 +95,7 @@ class Constants:
     INPUT_STR = ["My name is"]
     GB = 2**30
     MAX_QPC_LIMIT = 30
-    MAX_RETRIES = 5  # This constant will be used to set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
+    MAX_RETRIES = 10  # This constant will be used to set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
     NUM_SPECULATIVE_TOKENS = 2
     SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml"  # This xml file is parsed to find out the SDK version.
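The body of login_and_download_hf_lm is not shown in this diff; a minimal sketch of the retry loop the constant implies, assuming huggingface_hub's snapshot_download and a hypothetical helper name download_with_retries:

from huggingface_hub import snapshot_download
from QEfficient.utils.constants import Constants

def download_with_retries(repo_id: str, **kwargs) -> str:
    # Retry transient download failures up to MAX_RETRIES times (sketch only;
    # real code should catch specific network exceptions, not bare Exception).
    last_exc = None
    for attempt in range(1, Constants.MAX_RETRIES + 1):
        try:
            return snapshot_download(repo_id=repo_id, **kwargs)
        except Exception as exc:
            last_exc = exc
            print(f"Download attempt {attempt}/{Constants.MAX_RETRIES} failed: {exc}")
    raise last_exc

Doubling MAX_RETRIES from 5 to 10 simply widens this loop, trading a longer worst case for fewer spurious failures on unstable connections.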

scripts/replicate_kv_head/replicate_kv_heads.py (+4 -1)

@@ -15,6 +15,7 @@
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
 from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
+from QEfficient.utils._utils import login_and_download_hf_lm
 
 
 def duplicate_weights_for_linear_layer(
@@ -79,7 +80,9 @@ def main(args):
     model_kwargs = {"attn_implementation": "eager"}
     if args.num_hidden_layers:
         model_kwargs["num_hidden_layers"] = args.num_hidden_layers
-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+
+    pretrained_model_name_or_path = login_and_download_hf_lm(model_name)
+    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
 
     # Undo the effect of replace_transformers_quantizers
     undo_transformers_quantizers()
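With this change the script resolves the model to a local snapshot before calling from_pretrained, so the download (and its retries) happens in one place and reruns hit the local cache. A hedged illustration; the model id is an example, and the return value is assumed from how the result is fed to from_pretrained above:

from QEfficient.utils._utils import login_and_download_hf_lm

# Presumably returns the local directory of the downloaded snapshot, since it
# is passed straight to AutoModelForCausalLM.from_pretrained in the diff above.
local_path = login_and_download_hf_lm("meta-llama/Llama-2-7b-hf")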

tests/transformers/spd/test_pld_inference.py (+4 -3)

@@ -145,9 +145,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int):
     """
     num_chunks = -(input_len // -prefill_seq_len)  # ceil divide without float
     input_len_padded = num_chunks * prefill_seq_len  # Convert input_len to a multiple of prefill_seq_len
-    assert input_len_padded <= ctx_len, (
-        "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
-    )
+    assert (
+        input_len_padded <= ctx_len
+    ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
     return input_len_padded
 
 
@@ -202,6 +202,7 @@ def find_candidate_pred_tokens(
     return np.full(num_pred_tokens, fill_tok, dtype=np.int64), has_empty_tokens
 
 
+@pytest.mark.on_qaic
 @pytest.mark.parametrize(
     "prompts, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, target_model_name, full_batch_size, max_ngram_size",
     configs,
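The float-free ceiling division in get_padded_input_len works because Python's floor division rounds toward negative infinity; a quick worked check with illustrative values:

input_len, prefill_seq_len = 100, 32  # illustrative values
num_chunks = -(input_len // -prefill_seq_len)  # 100 // -32 == -4, negated to 4
input_len_padded = num_chunks * prefill_seq_len
assert (num_chunks, input_len_padded) == (4, 128)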

tests/transformers/spd/test_spd_inference.py (+7 -6)

@@ -75,9 +75,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int):
     """
     num_chunks = -(input_len // -prefill_seq_len)  # ceil divide without float
     input_len_padded = num_chunks * prefill_seq_len  # Convert input_len to a multiple of prefill_seq_len
-    assert input_len_padded <= ctx_len, (
-        "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
-    )
+    assert (
+        input_len_padded <= ctx_len
+    ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
     return input_len_padded
 
 
@@ -93,6 +93,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
     return bonus_token_inputs, dlm_decode_inputs
 
 
+@pytest.mark.on_qaic
 @pytest.mark.skip()  # remove when the SDK 1.20.0 issue is solved for compiling this model
 @pytest.mark.parametrize(
     "prompts, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
@@ -319,9 +320,9 @@ def test_spec_decode_inference(
     for prompt, generation in zip(prompts, batch_decode):
         print(f"{prompt=} {generation=}")
     # validation check
-    assert mean_num_accepted_tokens == float(num_speculative_tokens + 1), (
-        f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}"
-    )
+    assert mean_num_accepted_tokens == float(
+        num_speculative_tokens + 1
+    ), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}"
     del target_model_session
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
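Both test files gain the @pytest.mark.on_qaic marker, letting CI select or skip device-bound tests, e.g. with pytest -m on_qaic. The marker's registration is not part of this commit; a conftest.py sketch of what it typically looks like:

# conftest.py (sketch; registration assumed, not shown in this diff)
def pytest_configure(config):
    config.addinivalue_line("markers", "on_qaic: test must run on Qualcomm AI 100 hardware")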
