Skip to content

Commit 09a3cd0

Browse files
committed
Added validation logic to recipe
Signed-off-by: shanjiaz <[email protected]>
1 parent 0e9e055 commit 09a3cd0

File tree

5 files changed

+21
-70
lines changed

5 files changed

+21
-70
lines changed

src/llmcompressor/modifiers/interface.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

src/llmcompressor/modifiers/modifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from abc import abstractmethod
22
from typing import Optional
33

4+
from pydantic import BaseModel, ConfigDict
5+
46
from llmcompressor.core.events import Event, EventType
57
from llmcompressor.core.state import State
6-
from llmcompressor.modifiers.interface import ModifierInterface
78
from llmcompressor.modifiers.utils.hooks import HooksMixin
89

910
__all__ = ["Modifier"]
1011

1112

12-
class Modifier(ModifierInterface, HooksMixin):
13+
class Modifier(HooksMixin, BaseModel):
1314
"""
1415
A base class for all modifiers to inherit from.
1516
Modifiers are used to modify the training process for a model.
@@ -30,6 +31,8 @@ class Modifier(ModifierInterface, HooksMixin):
3031
:param update: The update step for the modifier
3132
"""
3233

34+
model_config = ConfigDict(extra="forbid")
35+
3336
index: Optional[int] = None
3437
group: Optional[str] = None
3538
start: Optional[float] = None

src/llmcompressor/pipelines/registry.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,6 @@ def _validate_infer_pipeline(modifiers: List[Modifier]) -> str:
6464
return "sequential"
6565

6666
active_qmods = _get_active_quant_modifiers(modifiers)
67-
if len(active_qmods) > 1:
68-
raise ValueError(
69-
f"Recipe contains more than one active quantization config "
70-
f"({active_qmods}). These configs may be conflicting, Please modify "
71-
"your recipe to use at most one quantization config"
72-
)
7367

7468
if len(active_qmods) == 1:
7569
quant_modifier = active_qmods[0]

src/llmcompressor/recipe/recipe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from llmcompressor.modifiers import Modifier, ModifierFactory
1010
from llmcompressor.recipe.utils import (
11+
_get_active_quant_modifiers,
1112
_load_json_or_yaml_string,
1213
_parse_recipe_from_md,
1314
deep_merge_dicts,
@@ -161,6 +162,16 @@ def create_instance(
161162

162163
return Recipe.model_validate(filter_dict(obj, target_stage=target_stage))
163164

def validate_pipeline(self) -> None:
    """Validate that this recipe's modifiers form a usable pipeline.

    Checks the recipe's active quantization modifiers and rejects
    recipes that declare more than one, since multiple quantization
    configs may conflict with each other.

    :raises ValueError: if more than one active quantization config
        is present in ``self.modifiers``
    """
    # NOTE(review): this was decorated ``@classmethod`` but reads
    # ``self.modifiers`` and is invoked as ``self.validate_pipeline()``
    # from ``create_modifier`` -- it is an instance method. The declared
    # ``-> str`` return type was also wrong: nothing is ever returned.
    active_qmods = _get_active_quant_modifiers(self.modifiers)
    if len(active_qmods) > 1:
        raise ValueError(
            f"Recipe contains more than one active quantization config "
            f"({active_qmods}). These configs may be conflicting. Please modify "
            "your recipe to use at most one quantization config"
        )
164175
@model_validator(mode="before")
165176
@classmethod
166177
def parse_from_dict(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -226,6 +237,7 @@ def create_modifier(self) -> List[Modifier]:
226237
)
227238
for modifier in self.modifiers
228239
]
240+
self.validate_pipeline()
229241
return self.modifiers
230242

231243
def dict(self, *args, **kwargs) -> Dict[str, Any]:

src/llmcompressor/recipe/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,7 @@ def deep_merge_dicts(d1: dict, d2: dict) -> dict:
139139
i += 1
140140
result[f"{base_key}_{i}"] = val
141141
return result
142+
143+
144+
def _get_active_quant_modifiers(modifiers: List[Modifier]) -> List[Modifier]:
145+
return [m for m in modifiers if m.__class__.__name__ == "QuantizationModifier"]

0 commit comments

Comments (0)