Gemma 3 AWQ Quantization doesn't work: cache_position is NoneType #1577

Closed
@Gman3214

Description

Describe the bug
I am trying to quantize Gemma 3 with AWQ. I managed to quantize it using GPTQ, but when I try AWQ I get the error below.

Expected behavior
The model should be compressed successfully.

Environment
Include all relevant environment information:

  1. OS: Ubuntu 20.04
  2. Python version: 3.11
  3. LLM Compressor version: latest

Code:

import torch
from datasets import load_dataset
from transformers import Gemma3ForConditionalGeneration, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping

# Select model and load it
MODEL_ID = "google/gemma-3-4b-it-qat-int4-unquantized"
model = Gemma3ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset
DATASET_ID = "ram3214/calibration_hebrew"
DATASET_SPLIT = "train[:3000]"
NUM_CALIBRATION_SAMPLES = 36
MAX_SEQUENCE_LENGTH = 128

# Load and preprocess dataset
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)
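
# NOTE (assumption): the calibration set is assumed to expose a raw "text"
# column; if "ram3214/calibration_hebrew" already contains tokenized
# input_ids / attention_mask fields, skip this step. Without tokenized
# fields, the simple data_collator below has nothing to turn into tensors.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(tokenize, remove_columns=ds.column_names)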

# Gemma 3 specific mappings
_gemma3_mappings = [
    AWQMapping(
        smooth_layer=r"re:.*input_layernorm$",
        balance_layers=[
            r"re:.*self_attn\.q_proj$",
            r"re:.*self_attn\.k_proj$",
            r"re:.*self_attn\.v_proj$",
        ],
    ),
    AWQMapping(
        smooth_layer=r"re:.*self_attn\.v_proj$",
        balance_layers=[r"re:.*self_attn\.o_proj$"],
    ),
    AWQMapping(
        smooth_layer=r"re:.*post_attention_layernorm$",
        balance_layers=[
            r"re:.*mlp\.gate_proj$",
            r"re:.*mlp\.up_proj$",
        ],
    ),
    AWQMapping(
        smooth_layer=r"re:.*mlp\.up_proj$",
        balance_layers=[r"re:.*mlp\.down_proj$"],
    ),
]

# Configure the quantization algorithm
recipe = [
    AWQModifier(
        ignore=[
            "lm_head",
            "re:model\.vision_tower.*",
            "re:model\.multi_modal_projector.*",
        ],
        mappings=_gemma3_mappings,
        scheme="W4A16_ASYM",
        targets=["Linear"],
    ),
]
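
# Collate single-sample batches into tensors for the model's forward pass;
# the assert reflects the calibration batch size of 1 used here.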
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# Apply algorithms
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
)

# Save to disk compressed
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-asym"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

Errors

/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3/modeling_gemma3.py in prepare_inputs_for_generation(self, input_ids, past_key_values, inputs_embeds, cache_position, position_ids, pixel_values, attention_mask, token_type_ids, use_cache, logits_to_keep, labels, **kwargs)
   1431         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
   1432         # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
-> 1433         if cache_position[0] == 0:
   1434             model_inputs["pixel_values"] = pixel_values
   1435         is_training = token_type_ids is not None and labels is not None

TypeError: 'NoneType' object is not subscriptable

Additional context
I have run into this issue before when I quantized the model with GPTQ; all I had to do was change the ignore list to the following and it worked (a sketch of that GPTQ recipe follows the snippet below):

ignore=[
    "lm_head",
    r"re:model\.vision_tower.*",
    r"re:model\.multi_modal_projector.*",
],
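
For reference, here is a minimal sketch of the GPTQ recipe described above with the same ignore list (the exact GPTQModifier arguments are an assumption, not a copy of the recipe I ran):

from llmcompressor.modifiers.quantization import GPTQModifier

gptq_recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "lm_head",
        r"re:model\.vision_tower.*",
        r"re:model\.multi_modal_projector.*",
    ],
)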

Metadata

Labels: bug (Something isn't working)
