Skip to content

Commit a07990b

Browse files
committed
fix recipe merge logic
Signed-off-by: shanjiaz <[email protected]>
1 parent 159235f commit a07990b

File tree

3 files changed

+48
-18
lines changed

3 files changed

+48
-18
lines changed

src/llmcompressor/recipe/recipe.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_load_json_or_yaml_string,
1212
_parse_recipe_from_md,
1313
get_yaml_serializable_dict,
14+
deep_merge_dicts,
1415
)
1516

1617
__all__ = [
@@ -265,30 +266,49 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
265266

266267
return get_yaml_serializable_dict(modifiers=self.modifiers, stage=self.stage)
267268

268-
def yaml(
    self,
    file_path: Optional[str] = None,
    existing_recipe_path: Optional[str] = None,
) -> str:
    """
    Return a YAML string representation of the recipe, optionally merged
    on top of an existing recipe file.

    :param file_path: optional path to write the resulting YAML to
    :param existing_recipe_path: optional path to an existing recipe file;
        its contents form the base of the merge, and on conflicting keys the
        values from the current recipe take precedence
    :return: the (merged) YAML string, whether or not it was also written
        to ``file_path``
    """
    # Load the existing recipe first, so that it is read before the output
    # file is opened for writing (the two paths may point at the same file)
    existing_dict = {}
    if existing_recipe_path:
        with open(existing_recipe_path, "r") as existing_file:
            loaded = _load_json_or_yaml_string(existing_file.read())
        # NOTE(review): the loader presumably yields None for an empty
        # file; treat that as "nothing to merge" rather than crashing
        if loaded is not None:
            existing_dict = loaded

    # Serialize the current recipe
    current_dict = get_yaml_serializable_dict(
        modifiers=self.modifiers,
        stage=self.stage,
    )

    # Deep merge: keep the contents of both, current recipe wins on conflicts
    merged_dict = deep_merge_dicts(existing_dict, current_dict)

    # Dump to a string rather than straight to the file: yaml.dump returns
    # None when given a stream, which would break the declared return type
    yaml_str = yaml.dump(
        merged_dict,
        allow_unicode=True,
        sort_keys=False,
        default_flow_style=None,
        width=88,
    )

    if file_path is not None:
        # Context manager guarantees the handle is closed even on error
        with open(file_path, "w") as file_stream:
            file_stream.write(yaml_str)

    return yaml_str
292312

293313

294314
RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]]

src/llmcompressor/recipe/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,15 @@ def get_yaml_serializable_dict(modifiers: List[Modifier], stage: str) -> Dict[st
9696
stage_dict[stage_name][group_name][modifier_type] = args
9797

9898
return stage_dict
99+
100+
def deep_merge_dicts(d1: dict, d2: dict) -> dict:
    """
    Recursively merge d2 on top of d1 and return the result as a new dict.

    Keys whose values are dicts in both inputs are merged recursively; for
    any other key present in both, the value from d2 wins. Key order is the
    keys of d1 followed by the keys that only appear in d2. Neither input is
    modified, and the result shares no mutable containers with them.

    :param d1: base dictionary
    :param d2: dictionary whose values take precedence on conflicts
    :return: a new, independent merged dictionary
    """
    # Local import keeps this helper self-contained
    from copy import deepcopy

    result = {}
    for key, base_val in dict(d1).items():
        if key not in d2:
            # Copy so later mutation of the result cannot leak into d1
            result[key] = deepcopy(base_val)
        elif isinstance(base_val, dict) and isinstance(d2[key], dict):
            result[key] = deep_merge_dicts(base_val, d2[key])
        else:
            result[key] = deepcopy(d2[key])

    # Keys that only exist in d2 are appended after the keys of d1
    for key, val in d2.items():
        if key not in result:
            result[key] = deepcopy(val)
    return result

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,12 +293,10 @@ def update_and_save_recipe(model_stub: str, save_directory: str):
293293
existing recipe
294294
:param save_directory: path to save combined existing recipe and current recipe
295295
"""
296-
recipe_file_name = RECIPE_FILE_NAME
296+
297297
existing_recipe = infer_recipe_from_model_path(model_stub)
298-
if existing_recipe is not None:
299-
recipe_file_name = "new_recipe.yaml"
300298

301299
recipe = active_session().lifecycle.recipe
302-
if len(recipe.modifiers) > 0:
303-
recipe_path = os.path.join(save_directory, recipe_file_name)
304-
recipe.yaml(recipe_path)
300+
301+
recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
302+
recipe.yaml(file_path=recipe_path, existing_recipe_path=existing_recipe)

0 commit comments

Comments
 (0)