
Commit f5ddf69

Cleaning done 3
Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent cf7a323 commit f5ddf69

File tree: 7 files changed (+117 additions, -71 deletions)

QEfficient/utils/_utils.py

Lines changed: 73 additions & 51 deletions
@@ -521,27 +521,57 @@ def __repr__(self):
 def dump_qconfig(func):
     def wrapper(self, *args, **kwargs):
         result = func(self, *args, **kwargs)
-        create_and_dump_qconfigs(
-            self.qpc_path,
-            self.onnx_path,
-            self.get_model_config,
-            [cls.__name__ for cls in self._pytorch_transforms],
-            [cls.__name__ for cls in self._onnx_transforms],
-            kwargs.get("specializations"),
-            kwargs.get("mdp_ts_num_devices", 1),
-            kwargs.get("num_speculative_tokens"),
-            **{
-                k: v
-                for k, v in kwargs.items()
-                if k
-                not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"]
-            },
-        )
+        try:
+            create_and_dump_qconfigs(
+                self.qpc_path,
+                self.onnx_path,
+                self.get_model_config,
+                [cls.__name__ for cls in self._pytorch_transforms],
+                [cls.__name__ for cls in self._onnx_transforms],
+                kwargs.get("specializations"),
+                kwargs.get("mdp_ts_num_devices", 1),
+                kwargs.get("num_speculative_tokens"),
+                **{
+                    k: v
+                    for k, v in kwargs.items()
+                    if k
+                    not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"]
+                },
+            )
+        except Exception as e:
+            print(f"An unexpected error occurred while dumping the qconfig: {e}")
         return result
 
     return wrapper
 
 
+def get_qaic_sdk_version(qaic_sdk_xml_path: str) -> Optional[str]:
+    """
+    Extracts the QAIC SDK version from the given SDK XML file.
+
+    Args:
+        qaic_sdk_xml_path (str): Path to the SDK XML file.
+    Returns:
+        The SDK version as a string if found, otherwise None.
+    """
+    qaic_sdk_version = None
+
+    # Check and extract version from the given SDK XML file
+    if os.path.exists(qaic_sdk_xml_path):
+        try:
+            tree = ET.parse(qaic_sdk_xml_path)
+            root = tree.getroot()
+            base_version_element = root.find(".//base_version")
+            if base_version_element is not None:
+                qaic_sdk_version = base_version_element.text
+        except ET.ParseError as e:
+            print(f"Error parsing XML file {qaic_sdk_xml_path}: {e}")
+        except Exception as e:
+            print(f"An unexpected error occurred while processing {qaic_sdk_xml_path}: {e}")
+
+    return qaic_sdk_version
+
+
 def create_and_dump_qconfigs(
     qpc_path,
     onnx_path,
@@ -558,29 +588,12 @@ def create_and_dump_qconfigs(
     Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and
     many other compilation options.
     """
-    qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None
-    enable_qnn = True if "qnn_config" in compiler_options else None
-
+    enable_qnn = compiler_options.get("enable_qnn", False)
+    qnn_config_path = compiler_options.get("qnn_config", None)
     qconfig_file_path = os.path.join(os.path.dirname(qpc_path), "qconfig.json")
     onnx_path = str(onnx_path)
     specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json"))
     compile_dir = str(os.path.dirname(qpc_path))
-    qnn_config_path = (
-        (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None
-    )
-
-    # Extract QAIC SDK Apps Version from SDK XML file
-    tree = ET.parse(Constants.SDK_APPS_XML)
-    root = tree.getroot()
-    qaic_version = root.find(".//base_version").text
-
-    # Extract QNN SDK details from YAML file if the environment variable is set
-    qnn_sdk_details = None
-    qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
-    if enable_qnn and qnn_sdk_path:
-        qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
-        with open(qnn_sdk_yaml_path, "r") as file:
-            qnn_sdk_details = yaml.safe_load(file)
 
     # Ensure all objects in the configs dictionary are JSON serializable
     def make_serializable(obj):
@@ -602,29 +615,38 @@ def make_serializable(obj):
                 "onnx_transforms": make_serializable(onnx_transforms),
                 "onnx_path": onnx_path,
             },
+            "compiler_config": {
+                "enable_qnn": enable_qnn,
+                "compile_dir": compile_dir,
+                "specializations_file_path": specializations_file_path,
+                "specializations": make_serializable(specializations),
+                "mdp_ts_num_devices": mdp_ts_num_devices,
+                "num_speculative_tokens": num_speculative_tokens,
+                **compiler_options,
+            },
+            "aic_sdk_config": {
+                "qaic_apps_version": get_qaic_sdk_version(Constants.SDK_APPS_XML),
+                "qaic_platform_version": get_qaic_sdk_version(Constants.SDK_PLATFORM_XML),
+            },
         },
     }
 
-    aic_compiler_config = {
-        "apps_sdk_version": qaic_version,
-        "compile_dir": compile_dir,
-        "specializations_file_path": specializations_file_path,
-        "specializations": make_serializable(specializations),
-        "mdp_ts_num_devices": mdp_ts_num_devices,
-        "num_speculative_tokens": num_speculative_tokens,
-        **compiler_options,
-    }
-    qnn_config = {
-        "enable_qnn": enable_qnn,
-        "qnn_config_path": qnn_config_path,
-    }
-    # Put AIC or qnn details.
     if enable_qnn:
+        qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
+        if not qnn_sdk_path:
+            raise EnvironmentError(
+                f"QNN_SDK_PATH {qnn_sdk_path} is not set. Please set {QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME}"
+            )
+        qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
+        qnn_sdk_details = load_yaml(
+            qnn_sdk_yaml_path
+        )  # Extract QNN SDK details from YAML file if the environment variable is set
+        qnn_config = {
+            "qnn_config_path": qnn_config_path,
+        }
         qconfigs["qpc_config"]["qnn_config"] = qnn_config
         if qnn_sdk_details:
             qconfigs["qpc_config"]["qnn_config"].update(qnn_sdk_details)
-    else:
-        qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config
 
     create_json(qconfig_file_path, qconfigs)
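The new get_qaic_sdk_version helper is deliberately forgiving: a missing file yields None rather than an exception, and parse errors are printed but swallowed. A minimal usage sketch, assuming the helper is imported from QEfficient/utils/_utils.py as added above; the XML shape and version string are illustrative, with only the .//base_version lookup mirroring the diff:

    import os
    import tempfile

    from QEfficient.utils._utils import get_qaic_sdk_version

    # Illustrative XML containing the <base_version> element the helper searches for.
    sdk_xml = "<versions><build><base_version>1.18.2</base_version></build></versions>"
    with tempfile.NamedTemporaryFile("w", suffix=".xml", delete=False) as f:
        f.write(sdk_xml)
        xml_path = f.name

    print(get_qaic_sdk_version(xml_path))             # -> "1.18.2"
    print(get_qaic_sdk_version("/no/such/apps.xml"))  # -> None; a missing file is not an error
    os.remove(xml_path)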

QEfficient/utils/constants.py

Lines changed: 0 additions & 2 deletions
@@ -105,15 +105,13 @@ class Constants:
     MAX_QPC_LIMIT = 30
     MAX_RETRIES = 10  # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
     NUM_SPECULATIVE_TOKENS = 2
-
     MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS
     SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml"  # This xml file is parsed to find out the SDK apps version.
     SDK_PLATFORM_XML = (
         "/opt/qti-aic/versions/platform.xml"  # This xml file is parsed to find out the SDK platform version.
     )
 
 
-
 @dataclass
 class QnnConstants:
     # QNN PATH to be read from environment variable.
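These two XML paths are exactly what the new aic_sdk_config block feeds into get_qaic_sdk_version. A one-line sketch of that wiring, assuming a machine with the AIC SDK installed at the default /opt/qti-aic location (otherwise both calls return None):

    from QEfficient.utils._utils import get_qaic_sdk_version
    from QEfficient.utils.constants import Constants

    print(get_qaic_sdk_version(Constants.SDK_APPS_XML))      # apps SDK version string, or None
    print(get_qaic_sdk_version(Constants.SDK_PLATFORM_XML))  # platform SDK version string, or None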

scripts/Jenkinsfile

Lines changed: 1 addition & 1 deletion
@@ -170,4 +170,4 @@ pipeline {
             deleteDir()
         }
     }
-}
+}

scripts/finetune/run_ft_model.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from peft import AutoPeftModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.finetune.configs.training import TrainConfig
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -25,7 +25,7 @@
     print(f"Warning: {e}. Moving ahead without these qaic modules.")
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-train_config = TRAIN_CONFIG()
+train_config = TrainConfig()
 model = AutoModelForCausalLM.from_pretrained(
     train_config.model_name,
     use_cache=False,
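The import change reflects renaming the config class from train_config (aliased to TRAIN_CONFIG) to the PEP 8-style TrainConfig. For orientation, a minimal sketch of the pattern this implies, a dataclass with defaults instantiated once per run; the fields shown are placeholders for illustration, not the real TrainConfig signature:

    from dataclasses import dataclass

    @dataclass
    class TrainConfig:
        # Placeholder fields; the real class in QEfficient.finetune.configs.training has many more.
        model_name: str = "meta-llama/Llama-3.2-1B"
        lr: float = 3e-4
        num_epochs: int = 1

    train_config = TrainConfig()  # same call shape as run_ft_model.py above
    print(train_config.model_name)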

tests/finetune/test_finetune.py

Lines changed: 36 additions & 10 deletions
@@ -8,6 +8,7 @@
 import os
 import shutil
 
+import numpy as np
 import pytest
 import torch.optim as optim
 from torch.utils.data import DataLoader
@@ -22,12 +23,25 @@ def clean_up(path):
     shutil.rmtree(path)
 
 
-configs = [pytest.param("meta-llama/Llama-3.2-1B", 1, 1, 1, None, True, True, "cpu", id="llama_config")]
+configs = [
+    pytest.param(
+        "meta-llama/Llama-3.2-1B",  # model_name
+        10,  # max_eval_step
+        20,  # max_train_step
+        1,  # intermediate_step_save
+        None,  # context_length
+        True,  # run_validation
+        True,  # use_peft
+        "qaic",  # device
+        id="llama_config",  # config name
+    )
+]
 
 
-# TODO:enable this once docker is available
+@pytest.mark.skip(reason="Currently CI is broken. Once it is fixed we will enable this test.")
+@pytest.mark.cli
 @pytest.mark.on_qaic
-@pytest.mark.skip(reason="eager docker not available in sdk")
+@pytest.mark.finetune
 @pytest.mark.parametrize(
     "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
     configs,
@@ -43,7 +57,7 @@ def test_finetune(
     device,
     mocker,
 ):
-    train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TRAIN_CONFIG")
+    train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig")
     generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config")
     generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config")
     get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs")
@@ -65,23 +79,28 @@ def test_finetune(
         "device": device,
     }
 
-    finetune(**kwargs)
+    results = finetune(**kwargs)
+    assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
+    assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
+    assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching."
+    assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching."
+    assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."
 
     train_config_spy.assert_called_once()
     generate_dataset_config_spy.assert_called_once()
     generate_peft_config_spy.assert_called_once()
-    update_config_spy.assert_called_once()
     get_custom_data_collator_spy.assert_called_once()
     get_longest_seq_length_spy.assert_called_once()
     print_model_size_spy.assert_called_once()
     train_spy.assert_called_once()
 
+    assert update_config_spy.call_count == 2
     assert get_dataloader_kwargs_spy.call_count == 2
     assert get_preprocessed_dataset_spy.call_count == 2
 
     args, kwargs = train_spy.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+    train_dataloader = args[2]
+    eval_dataloader = args[3]
     optimizer = args[4]
 
     batch = next(iter(train_dataloader))
@@ -97,12 +116,19 @@ def test_finetune(
     else:
         assert eval_dataloader is None
 
-    args, kwargs = update_config_spy.call_args
+    args, kwargs = update_config_spy.call_args_list[0]
     train_config = args[0]
+    assert max_train_step >= train_config.gradient_accumulation_steps, (
+        "Total training step should be more than "
+        f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps."
+    )
 
-    saved_file = os.path.join(train_config.output_dir, "adapter_model.safetensors")
+    saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
     assert os.path.isfile(saved_file)
 
     clean_up(train_config.output_dir)
     clean_up("runs")
     clean_up(train_config.dump_root_dir)
+
+
+# TODO (Meet): Add seperate tests for BERT FT and LLama FT
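Two patterns in the updated test are worth calling out: pytest-mock's mocker.spy wraps a real callable so its calls can be asserted without stubbing its behavior, and np.allclose with an absolute tolerance gives stable regression checks on floating-point metrics. A self-contained sketch under those assumptions (requires pytest and pytest-mock; the spied function and expected value are illustrative, not taken from the test above):

    import numpy as np

    def test_spy_and_tolerance(mocker):
        spy = mocker.spy(np, "mean")     # wraps np.mean; the real computation still runs
        avg_loss = np.mean([0.0023, 0.0024])
        spy.assert_called_once()         # the call and its arguments were recorded
        assert np.allclose(avg_loss, 0.00235, atol=1e-5), "Loss is not matching."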

tests/transformers/spd/test_pld_inference.py

Lines changed: 2 additions & 2 deletions
@@ -262,7 +262,7 @@ def test_pld_spec_decode_inference(
         num_speculative_tokens=num_speculative_tokens,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
     draft_model_session = None
 
     # skip inputs/outputs buffers
@@ -453,7 +453,7 @@ def test_pld_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR, device_group)
+    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size

tests/transformers/spd/test_spd_inference.py

Lines changed: 3 additions & 3 deletions
@@ -157,8 +157,8 @@ def test_spec_decode_inference(
         full_batch_size=full_batch_size,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
-    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
+    draft_model_session = QAICInferenceSession(draft_model_qpc_path)
 
     # skip inputs/outputs buffers
     target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
@@ -341,7 +341,7 @@ def test_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group)
+    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size
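Both speculative-decoding tests now build sessions without device_ids, leaving device selection to the runtime. A hedged sketch of the shared setup pattern, assuming the usual import path for QAICInferenceSession; the QPC path is a placeholder, and the _RetainedState output filter is an assumption extrapolated from the past_ input filter shown in the diff above:

    from QEfficient.generation.cloud_infer import QAICInferenceSession

    session = QAICInferenceSession("path/to/qpc")  # placeholder path; omitting device_ids lets the runtime pick a device
    # Skip host-side I/O for KV-cache buffers that stay resident on the device.
    session.skip_buffers([x for x in session.input_names if x.startswith("past_")])
    session.skip_buffers([x for x in session.output_names if x.endswith("_RetainedState")])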
