
Commit 4304fe9

abukhoy authored and quic-amitraj committed
trust_remote_code enabled for grok1 only
Signed-off-by: Abukhoyer Shaik <[email protected]>
1 parent: ea0850d

File tree: 1 file changed (+10, -7)

tests/transformers/models/test_causal_lm_models.py: 10 additions, 7 deletions
```diff
@@ -22,6 +22,7 @@
 from QEfficient.utils.device_utils import get_available_device_id
 from QEfficient.utils.run_utils import ApiRunner
 
+extrenal_models = {"hpcai-tech/grok-1"}
 test_models_qaic = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "gpt2",
@@ -61,7 +62,7 @@
 ]
 
 
-def load_causal_lm_model(model_config, model_name):
+def load_causal_lm_model(model_config):
     """
     Function to load model from huggingface and transform to KV model
     --------
@@ -80,11 +81,13 @@ def load_causal_lm_model(model_config, model_name):
         num_hidden_layers=model_config["n_layer"],
         attn_implementation="eager",
         low_cpu_mem_usage=False,
-        trust_remote_code=True if model_name == "hpcai-tech/grok-1" else False,
-    ) # Run models for single layers only
+        trust_remote_code=model_config["model_name"] in extrenal_models,
+    )
+    # Convert to FP32 if model is in BF16
+    if getattr(model_hf.config, "torch_dtype", None) == torch.bfloat16:
+        model_hf = model_hf.to(torch.float32)
+
     params = sum(p.numel() for p in model_hf.parameters())
-    if model_name == "hpcai-tech/grok-1":
-        model_hf.to(torch.float32)
     model_hf.eval()
     return model_hf, params
 
@@ -111,7 +114,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     model_config = {"model_name": model_name}
     model_config["n_layer"] = n_layer
 
-    model_hf, _ = load_causal_lm_model(model_config, model_name)
+    model_hf, _ = load_causal_lm_model(model_config)
 
     tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
     config = model_hf.config
@@ -172,7 +175,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     if prefill_only is not None:
         return
     # testing for CB models
-    model_hf, _ = load_causal_lm_model(model_config, model_name)
+    model_hf, _ = load_causal_lm_model(model_config)
     full_batch_size = 4
     fbs_prompts = Constants.INPUT_STR * 4
     api_runner = ApiRunner(
```
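The core change: instead of hard-coding a string comparison against "hpcai-tech/grok-1", `trust_remote_code` is now derived from membership in the `extrenal_models` set (the typo is preserved from the source), so adding a future external model only means adding its name to the set. Below is a minimal sketch of that pattern, assuming a recent transformers release; the helper name `load_model` and the one-layer default are illustrative, not from the commit:

```python
from transformers import AutoModelForCausalLM

# Models whose modeling code lives outside transformers and must be
# fetched from the Hub repo itself (name mirrors the commit's set).
extrenal_models = {"hpcai-tech/grok-1"}

def load_model(model_name: str, n_layer: int = 1):
    # trust_remote_code is True only for models in the allow-list set;
    # everything else loads with remote code disabled.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        num_hidden_layers=n_layer,  # shrink to a single layer for tests
        attn_implementation="eager",
        low_cpu_mem_usage=False,
        trust_remote_code=model_name in extrenal_models,
    )
```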
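The commit also generalizes the FP32 cast: rather than casting only grok-1, any checkpoint whose `config.torch_dtype` reports `torch.bfloat16` is converted, presumably so the PyTorch reference run compares cleanly against the ONNX Runtime and AI 100 outputs. Reassigning `model_hf = model_hf.to(...)` also tidies the old code, which relied on `nn.Module.to` mutating the module in place. A standalone sketch of the check; the helper name `ensure_fp32` is mine, not from the repo:

```python
import torch

def ensure_fp32(model_hf):
    # config.torch_dtype records the dtype the checkpoint was saved in;
    # cast BF16 checkpoints to FP32 before running reference comparisons.
    if getattr(model_hf.config, "torch_dtype", None) == torch.bfloat16:
        model_hf = model_hf.to(torch.float32)
    return model_hf
```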
