
Torchao's CPU overhead counteracts the performance benefit of using quantization kernels. #1930


Open
LuFinch opened this issue Mar 21, 2025 · 11 comments

LuFinch commented Mar 21, 2025

Hi,

I ran some benchmarks on LLM models with int4_weight_only on CPU/GPU/XPU and expected to see E2E speedups compared with pure bf16/fp16.

At the kernel level, int4 GEMM kernels are generally 2x~3x faster than bf16/fp16 GEMM.

However, I did not see an E2E performance improvement, and I even saw slowdowns in some models.

After profiling, I found that Torchao's CPU overhead is too high; it can exceed the time saved by the int4 GEMM kernel. The reason is that Torchao uses a tensor subclass and __torch_function__ to redispatch the linear op to a custom int4 matmul op (a minimal sketch of this mechanism is shown below).
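
For illustration, a minimal sketch of this redispatch pattern (not torchao's actual implementation; Int4WeightOnlyTensor and the matmul fallback are placeholders) could look like this:

import torch
import torch.nn.functional as F

class Int4WeightOnlyTensor(torch.Tensor):
    """Toy wrapper that intercepts F.linear via __torch_function__."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, weight = args[0], args[1]
            # torchao would call its packed int4 kernel here (e.g.
            # aten::_weight_int4pack_mm); this toy falls back to a plain matmul.
            return torch.matmul(x, weight.t())
        # Every other op still enters this Python hook before falling back to
        # the default handling, which is the per-call overhead seen in eager mode.
        return super().__torch_function__(func, types, args, kwargs)

# Hypothetical usage: wrap a weight tensor and call F.linear on it.
w = torch.randn(8, 4).as_subclass(Int4WeightOnlyTensor)
x = torch.randn(2, 4)
out = F.linear(x, w)  # routed through __torch_function__ above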

  • In eager mode, the dispatching time (the blue square shown below) is longer than aten::_weight_int4pack_mm_cpu itself. I thought torch.compile could optimize away this redispatching.

[Image: eager-mode profiler trace]

  • However, in compile mode, the redispatching disappears from the compiled model.forward region, but it introduces extra host work in dynamo/inductor: every tensor subclass from torchao/dtypes/affine_quantized_tensor.py is flattened by torch/_functorch/_aot_autograd/subclass_utils.py(233): flatten_subclass, and the flattening time is close to the time of aten::_weight_int4pack_mm_cpu.

[Image: compile-mode profiler trace showing flatten_subclass host work]

Both eager mode and compile mode suffer from this device-agnostic Torchao CPU overhead, which may counteract the performance benefit we get from the int4 GEMM in host-bound scenarios, e.g. small models or very fast GPUs/XPUs (we hit this issue with Qwen2-0.5b and Phi3-3.8b from Hugging Face on an NVIDIA A100 GPU and an Intel Data Center GPU Max Series).

Could you optimize this CPU overhead?

Reproducer

import torch
from torchao.quantization.quant_api import (
    int4_weight_only,
    quantize_,
)

class Linear_Gate_Up(torch.nn.Module):
    def __init__(self, in_feature, out_feature):
        super(Linear_Gate_Up, self).__init__()
        self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=False)
        self.gate_proj2 = torch.nn.Linear(out_feature, out_feature, bias=False)
 
    def forward(self, x):
        return self.gate_proj2(self.gate_proj(x))
 
if __name__ == "__main__":
    # device = "cpu"
    device = "cuda"
    # device = "xpu"
    quantize_model = True
    compile_model = True
    run_with_profiler = False

    with torch.no_grad():
        model = Linear_Gate_Up(512, 1024).eval().to(device).to(torch.bfloat16)
        x = torch.randn(1, 512).to(device).to(torch.bfloat16)
        if quantize_model:
            if device == "cpu":
                from torchao.dtypes import Int4CPULayout
                quantize_(model, int4_weight_only(layout=Int4CPULayout()))
            elif device == "cuda":
                quantize_(model, int4_weight_only())
            elif device == "xpu":
                from torchao.dtypes import Int4XPULayout
                quantize_(model, int4_weight_only(layout=Int4XPULayout()))

        if compile_model:
            model = torch.compile(model)
          
        # warmup run the actual torch.compile
        model(x)
        if device == "cuda":
            torch.cuda.synchronize()
        if device == "xpu":
            torch.xpu.synchronize()
        
        latencies = []
        import time, contextlib
        with (torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU,
                                torch.profiler.ProfilerActivity.CUDA,
                                torch.profiler.ProfilerActivity.XPU],
                    with_stack=True) 
            if run_with_profiler else contextlib.nullcontext()) as prof:
            for i in range(10):
                start = time.time()
                model(x)
                if device == "cuda":
                    torch.cuda.synchronize()
                if device == "xpu":
                    torch.xpu.synchronize()
                if i > 3: # warmup
                    latencies.append(time.time() - start)

        print("Latency: {} ms".format(sum(latencies) / len(latencies) * 1000))
        if run_with_profiler:
            prof.export_chrome_trace("./{}.json".format(device))

@leslie-fang-intel (Collaborator)

Hi @LuFinch, for the overhead mentioned in the compile approach, could you share the absolute time and the relative ratio of this overhead in an E2E run, with both the example code and Qwen2-0.5b as you mentioned?

From the profiling, runtime_unwrap_tensor_subclasses seems to come from https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/torch/_functorch/_aot_autograd/subclass_utils.py#L233, which unwraps the tensor subclass (the quantized tensor in TorchAO) into PyTorch native tensors. cc @jerryzh168, have you observed similar overhead before?
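
For context, a rough sketch of the traceable-wrapper-subclass hooks that this runtime flattening relies on (a hypothetical PackedWeightTensor, not torchao's AffineQuantizedTensor):

import torch

class PackedWeightTensor(torch.Tensor):
    """Toy wrapper subclass holding a single inner plain tensor.
    A real subclass also needs __torch_dispatch__, omitted here."""

    @staticmethod
    def __new__(cls, packed: torch.Tensor):
        # Wrapper subclass: the outer tensor carries no storage of its own.
        return torch.Tensor._make_wrapper_subclass(
            cls, packed.shape, dtype=packed.dtype, device=packed.device
        )

    def __init__(self, packed: torch.Tensor):
        self.packed = packed

    def __tensor_flatten__(self):
        # Runtime flattening (e.g. flatten_subclass) relies on this hook to pull
        # out the names of the inner plain tensors plus any non-tensor metadata.
        return ["packed"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # Rebuild the wrapper subclass from the plain inner tensors.
        return PackedWeightTensor(inner_tensors["packed"])

# The is_traceable_wrapper_subclass check mentioned later in this thread keys
# off exactly these two hooks:
# from torch.utils._python_dispatch import is_traceable_wrapper_subclass
# is_traceable_wrapper_subclass(PackedWeightTensor(torch.randn(4, 4)))  # -> True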

LuFinch (Author) commented Mar 24, 2025

I ran the test on an A100-PCIE-40GB with an AMD EPYC 7713.

In the example code, bf16 has 0.109ms E2E latency while int4woq has 0.162ms with torch.compile on. However, the dynamo preparation time is 0.024ms for bf16 versus 0.098ms for int4woq. There is an extra ~0.075ms of CPU overhead caused by flattening torchao's tensor subclass, which accounts for ~46% of the E2E latency.

When running Hugging Face Qwen/Qwen2.5-0.5B (bs=1, beam=1, 1024 input / 128 output tokens), bf16 has 17.11ms next-token latency while int4woq has 29.16ms next-token latency with torch.compile on. There is a 9ms overhead caused by an extra pad op,

act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))
and a 2ms overhead caused by tensor flattening. Even if we optimize the pad op (for example, by adding an if-else here to skip the pad when the shape already meets the requirement, as sketched below), the 2ms flattening overhead will still account for ~10% of the next-token latency.
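
A minimal sketch of such a guard (hypothetical; pad_size here stands for the target width already computed at the call site):

import torch

def maybe_pad(act_mat: torch.Tensor, pad_size: int) -> torch.Tensor:
    # Skip the pad op (and its host-side overhead) when the activation already
    # has the required last-dimension size.
    if act_mat.shape[-1] == pad_size:
        return act_mat
    return torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))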

drisspg (Contributor) commented Mar 24, 2025

cc @bdhirsh @IvanKobzarev for the subclass flatten overhead

bdhirsh (Contributor) commented Mar 24, 2025

Theoretically, the parametrization work from Ivan could eliminate this extra runtime overhead in compile.

leslie-fang-intel (Collaborator) commented Mar 25, 2025

Thanks for the reply @bdhirsh @IvanKobzarev. May I know if there is a rough plan to mitigate this overhead?

@jerryzh168 (Contributor)

This is actually not merged into torchao main: #1114. cc @IvanKobzarev

@IvanKobzarev (Contributor)

@leslie-fang-intel
Parametrization can already be used; see the example in the test: https://github.com/pytorch/pytorch/blob/main/test/functorch/test_aotdispatch.py#L6395

This parametrization requires the tensor subclass parameter constructors (__init__, __new__) to be traceable by dynamo.

In 2024, some torchao tensor subclasses had a Callable argument in their constructor, which was not supported in dynamo.
But I cannot find this anymore in torchao. So you can try parametrization, and if everything fails we can work on it with the dynamo people.
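
For readers unfamiliar with the approach, here is a rough, hypothetical sketch of the idea (not the code in the linked test or in #1951): register a parametrization that stores the plain inner tensors as the module's real parameters and rebuilds the subclass on access, so AOTAutograd no longer sees a subclass parameter that it has to flatten at runtime.

import torch
from torch import nn
from torch.nn.utils import parametrize

class UnwrapSubclass(nn.Module):
    """Hypothetical parametrization keeping a wrapper-subclass parameter
    stored as its plain inner tensors."""

    def right_inverse(self, subclass_tensor):
        # Called when the parametrization is registered (and on assignment):
        # decompose the subclass into plain tensors, which become the module's
        # actual parameters.
        self.subclass_type = type(subclass_tensor)
        self.inner_names, self.meta = subclass_tensor.__tensor_flatten__()
        self.outer_size = subclass_tensor.size()
        self.outer_stride = subclass_tensor.stride()
        return [getattr(subclass_tensor, name) for name in self.inner_names]

    def forward(self, *inner_tensors):
        # Rebuild the subclass from the plain tensors whenever the parameter
        # is accessed.
        return self.subclass_type.__tensor_unflatten__(
            dict(zip(self.inner_names, inner_tensors)),
            self.meta, self.outer_size, self.outer_stride,
        )

# Hypothetical usage on a quantized linear layer:
# parametrize.register_parametrization(linear, "weight", UnwrapSubclass(), unsafe=True)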

@IvanKobzarev (Contributor)

@jerryzh168
Recreated PR to main - #1951

@leslie-fang-intel (Collaborator)

But I cannot find this anymore in torchao. So you can try parametrization, and if everything fails we can work on it with the dynamo people.

Thanks for the comment @IvanKobzarev. Will #1951 solve this issue? Or do you still suggest invoking unwrap_tensor_subclass_parameters after #1951 lands?

sunjiweiswift (Contributor) commented Mar 26, 2025

@jerryzh168 Recreated PR to main - #1951

I use "unwrap_tensor_subclass_parameters" and pr1951in the script, and "is_traceable_wrapper_subclass" returns true. But I found that the performance is even worse.

model = model_class[0].from_pretrained(
    args.model_id, torch_dtype=load_dtype, config=config, low_cpu_mem_usage=True,
    trust_remote_code=args.use_hf_code, device_map=device_map,
    quantization_config=quantization_config)
unwrap_tensor_subclass_parameters(model)

LuFinch (Author) commented Mar 26, 2025

@sunjiweiswift In my testing, using #1951 and unwrap_tensor_subclass_parameters(model), Llama3.1-8b's dynamo/inductor preparation time decreases from ~4.5ms to ~2.5ms, and the torch/_functorch/_aot_autograd/subclass_utils.py(233): flatten_subclass entries no longer appear in the trace. However, pure bf16 only needs ~0.6ms from dynamo_cache_lookup to dispatching the first op, so something is still adding CPU overhead to int4woq compared with bf16 that is not recorded by the profiler.
