
Commit af1f69d

safeguard against resolved mappings including observer/transform layers
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 04ada71 commit af1f69d

File tree

src/llmcompressor/modifiers/awq/base.py
src/llmcompressor/utils/pytorch/module.py

2 files changed: +46 −14 lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 5 additions & 7 deletions
@@ -304,13 +304,13 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(mapping.smooth_layer, model)
+            smooth_layers = get_layers(
+                mapping.smooth_layer, model, exclude_internal_modules=True
+            )
             smooth_names = [
                 smooth_name
                 for smooth_name in smooth_layers
-                if not find_name_or_class_matches(
-                    smooth_name, model, self.ignore + ["re:.*_observer$"]
-                )
+                if not find_name_or_class_matches(smooth_name, model, self.ignore)
             ]
 
             num_skipped_mappings = 0
@@ -331,10 +331,8 @@ def _set_resolved_mappings(self, model: Module) -> None:
                 for balance_suffix, balance_layer in get_layers(
                     balance_regex,
                     smooth_parent,
+                    exclude_internal_modules=True,
                 ).items():
-                    if balance_suffix.endswith("observer"):
-                        continue
-
                     balance_name = f"{smooth_parent_name}.{balance_suffix}"
 
                     # exclude v_proj->o_proj mappings whose shapes are incompatible
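
For context, here is a minimal sketch of the suffix filtering this hunk moves into get_layers(). The module names below are illustrative, not taken from the commit: the old code stripped only observer entries via an ignore regex and an endswith("observer") check, while the new exclude_internal_modules=True path also drops names ending in "_transform" and "perm", mirroring is_internal_module() added in the second file.

# Illustrative names only; a sketch of the old vs. new filtering behavior
candidate_names = [
    "self_attn.v_proj",
    "self_attn.v_proj.weight_observer",
    "self_attn.v_proj.weight_transform",
]

# old behavior: ad-hoc endswith("observer") / "re:.*_observer$" filtering
old_kept = [n for n in candidate_names if not n.endswith("observer")]

# new behavior, matching the suffixes recognized by is_internal_module()
new_kept = [
    n for n in candidate_names if not n.endswith(("_observer", "_transform", "perm"))
]

print(old_kept)  # ['self_attn.v_proj', 'self_attn.v_proj.weight_transform']
print(new_kept)  # ['self_attn.v_proj']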

src/llmcompressor/utils/pytorch/module.py

Lines changed: 41 additions & 7 deletions
@@ -9,7 +9,6 @@
 
 import torch
 from compressed_tensors.quantization.utils import is_module_quantized
-from packaging import version
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
 from transformers import PreTrainedModel
@@ -64,10 +63,6 @@
     "get_layer_by_name",
 ]
 
-
-_PARSED_TORCH_VERSION = version.parse(torch.__version__)
-
-
 ALL_TARGET = "__ALL__"
 ALL_PRUNABLE_TARGET = "__ALL_PRUNABLE__"
 ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__"
@@ -164,8 +159,47 @@ def match_layers_params(
     return resolved
 
 
-def get_layers(targets: Union[str, List[str]], module: Module) -> Dict[str, Module]:
-    return match_layers_params(targets, module)
+def is_internal_module(name: str) -> bool:
+    """
+    llm-compressor adds additional modules to a model, like observers
+    and transforms, as part of its operation.
+    Return whether module is internally instantiated by llm-compressor,
+    based on its name.
+
+    :param name: name of module
+    :return: True if name indicates a module instantiated
+    """
+    return name.endswith(("_observer", "_transform", "perm"))
+
+
+def get_layers(
+    targets: Union[str, List[str]],
+    module: Module,
+    exclude_internal_modules: bool = False,
+) -> Dict[str, Module]:
+    """
+    Get layers (also known as submodules) of module based on targets
+
+    :param targets: names or regexes to search for
+        Can be regex, e.g. "re:.*input_layernorm$" to find all layers
+        in module whose names end in string "input_layernorm"
+    :param module: Parent module in which to search for targets
+    :param exclude_internal_modules: If True, don't include internal
+        modules added by llm-compressor, e.g. Observers and Transforms.
+        Defaults to False to maintain backward compatibility
+
+    :return: dict of layer name -> layer module of all layers in module
+        that match targets
+    """
+    layer_dict = match_layers_params(targets, module)
+    if exclude_internal_modules:
+        layer_dict = {
+            layer_name: layer
+            for layer_name, layer in layer_dict.items()
+            if not is_internal_module(layer_name)
+        }
+
+    return layer_dict
 
 
 def get_layer(target: str, module: Module) -> Tuple[str, Module]:
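
A short usage sketch of the helper and keyword added here, assuming an llm-compressor install that includes this commit and that "re:"-prefixed targets are resolved against fully qualified submodule names, as the new docstring describes. The toy module and its manually attached "weight_observer" child are illustrative stand-ins for what llm-compressor registers during quantization.

from torch import nn

from llmcompressor.utils.pytorch.module import get_layers, is_internal_module


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_proj = nn.Linear(4, 4)
        self.o_proj = nn.Linear(4, 4)
        # stand-in for an observer llm-compressor would attach to v_proj
        self.v_proj.add_module("weight_observer", nn.Identity())


model = ToyAttention()

# name-based predicate the new filter relies on
assert is_internal_module("v_proj.weight_observer")
assert not is_internal_module("v_proj")

# default call keeps backward-compatible behavior: internal modules can match
assert "v_proj.weight_observer" in get_layers("re:.*observer$", model)

# opting in drops them, which is what the AWQ mapping resolution now does
assert "v_proj.weight_observer" not in get_layers(
    "re:.*observer$", model, exclude_internal_modules=True
)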
