Torchao's CPU overhead counteracts the performance benefit of using quantization kernel. #1930
Comments
Hi @LuFinch, for the overhead mentioned in the compile approach, could you share the absolute time and relative ratio of this overhead in an E2E run with the example code and the profiling? |
I ran the test on an A100-PCIE-40GB with an AMD EPYC 7713. In the example code, bf16 has 0.109ms E2E latency while int4woq has 0.162ms E2E latency with torch.compile on. However, the dynamo preparation time is 0.024ms for bf16 but 0.098ms for int4woq. There is an extra ~0.075ms CPU overhead caused by flattening torchao's tensor subclass, which accounts for ~46% of the E2E latency. When running with huggingface Qwen/Qwen2.5-0.5B, bs=1, beam=1, 1024in-128out, bf16 has 17.11ms next-token latency while int4woq has 29.16ms next-token latency with torch.compile on. There is a 9ms overhead caused by an extra Pad op. |
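For reference, the kind of breakdown described above (E2E latency vs. CPU-side preparation time) can be collected with torch.profiler; the sketch below uses a placeholder single-linear module and shapes, not the original example code.

```python
# Sketch of how the E2E latency and CPU-side preparation overhead can be inspected
# with torch.profiler (placeholder module and shapes, not the original benchmark).
import torch
import torch.profiler as profiler

model = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
compiled = torch.compile(model)

for _ in range(10):          # warm up so one-time compilation is excluded
    compiled(x)

with profiler.profile(activities=[profiler.ProfilerActivity.CPU,
                                  profiler.ProfilerActivity.CUDA]) as prof:
    compiled(x)

# CPU-side entries (dynamo guard evaluation, subclass flattening, etc.) show up here.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
```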
cc @bdhirsh @IvanKobzarev for the subclass flatten overhead |
Theoretically the parametrization work from Ivan could eliminate this extra runtime overhead in compile |
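Roughly, the parametrization idea (a hypothetical sketch below, not the actual implementation referenced here) is to register the subclass's plain inner tensors as the module's parameters and rebuild the subclass on access via torch.nn.utils.parametrize, so AOTAutograd no longer has to flatten a subclass parameter on every call. `rebuild_fn` and `unpack_fn` are placeholders for subclass-specific logic.

```python
# Hypothetical sketch of the parametrization approach (not the actual torchao/PyTorch
# change): the module stores plain tensors as parameters, and the parametrization
# rebuilds the tensor subclass when the weight is accessed.
import torch
from torch.nn.utils import parametrize

class SubclassParametrization(torch.nn.Module):
    def __init__(self, rebuild_fn, unpack_fn):
        super().__init__()
        self.rebuild_fn = rebuild_fn   # plain tensors -> subclass instance
        self.unpack_fn = unpack_fn     # subclass instance -> tuple of plain tensors

    def forward(self, *plain_tensors):
        return self.rebuild_fn(*plain_tensors)

    def right_inverse(self, subclass_tensor):
        # Returning a tuple registers each plain tensor as a separate parameter.
        return self.unpack_fn(subclass_tensor)

# Usage (assuming `linear.weight` is a quantized tensor subclass):
# parametrize.register_parametrization(
#     linear, "weight", SubclassParametrization(rebuild_fn, unpack_fn), unsafe=True)
```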
Thanks for the reply @bdhirsh @IvanKobzarev. May I know if there is a rough plan to mitigate this overhead? |
This is actually not merged into torchao main: #1114 cc @IvanKobzarev |
@leslie-fang-intel This parametrization requires Tensor Subclass parameter constructors. In torchao in 2024 some Tensor Subclasses had |
@jerryzh168 |
Thanks for the comment @IvanKobzarev. Will #1951 solve this issue? Or do you still suggest invoking |
I use "unwrap_tensor_subclass_parameters" and pr1951in the script, and "is_traceable_wrapper_subclass" returns true. But I found that the performance is even worse. model = model_class[0].from_pretrained(
args.model_id, torch_dtype=load_dtype, config=config, low_cpu_mem_usage=True, trust_remote_code=args.use_hf_code,
device_map=device_map, quantization_config=quantization_config)
unwrap_tensor_subclass_parameters(model) |
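As a side note, one quick way to verify whether the unwrapping took effect is to count how many registered parameters are still wrapper subclasses; a small sketch (assuming a torchao-quantized `model`) follows.

```python
# Sketch (assumes a torchao-quantized `model`): count parameters that are still
# tensor wrapper subclasses; after unwrapping, this count is expected to drop to zero.
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def count_subclass_params(model):
    return sum(1 for p in model.parameters() if is_traceable_wrapper_subclass(p))

# print(count_subclass_params(model))
```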
@sunjiweiswift In my testing, using #1951 and |
Hi,
I did some benchmarking of LLM models with int4_weight_only on CPU/GPU/XPU and expected the models to show E2E speedups compared with pure bf16/fp16. From the kernel perspective, int4 GEMM kernels are generally 2x~3x faster than bf16/fp16 GEMM.
However, I did not see E2E performance improvements, and some models even slowed down.
After profiling, I found that Torchao's CPU overhead is too high, and it might be higher than the time saved by the int4 GEMM kernel. The reason is that Torchao uses a Tensor subclass and __torch_function__ to redispatch nn.Linear to the custom int4 matmul op. torch.compile could optimize away this redispatching. However, the subclass defined in torchao/dtypes/affine_quantized_tensor.py will be flattened by torch/_functorch/_aot_autograd/subclass_utils.py(233): flatten_subclass, and the flatten time is close to the time of aten::_weight_int4pack_mm_cpu. Both eager mode and compile mode suffer from this device-agnostic Torchao CPU overhead, which may counteract the performance benefit we get from int4 GEMM in host-bound models, e.g. when the model is small or the GPU/XPU is very fast (we hit this issue with Qwen2-0.5b and Phi3-3.8b from huggingface on an Nvidia A100 GPU and an Intel Data Center GPU Max Series).
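For context, the redispatch mechanism looks roughly like the stripped-down sketch below (illustrative only, not torchao's actual AffineQuantizedTensor); the point is that the interception happens in Python before any kernel launches, which is exactly the CPU work visible in the profile.

```python
# Minimal sketch of the __torch_function__ redispatch pattern (illustrative only,
# not torchao's AffineQuantizedTensor): a weight subclass intercepts F.linear and
# would route it to a custom int4 kernel. The interception itself is pure Python/CPU work.
import torch
import torch.nn.functional as F

class QuantizedWeight(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            # A real subclass would call the packed int4 kernel here, e.g.
            # torch.ops.aten._weight_int4pack_mm(...). This sketch just runs the
            # plain linear with subclass handling disabled, as a stand-in.
            with torch._C.DisableTorchFunctionSubclass():
                return F.linear(x, w, bias)
        return super().__torch_function__(func, types, args, kwargs)

# Usage: wrap a plain weight and call F.linear as usual.
w = torch.randn(16, 32).as_subclass(QuantizedWeight)
x = torch.randn(4, 32)
y = F.linear(x, w)   # dispatches through QuantizedWeight.__torch_function__
```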
Could you optimize this CPU overhead?
Reproducer
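The original reproducer script is not included in this capture. A hypothetical minimal version along the lines described above (bf16 vs. int4 weight-only linear under torch.compile) might look like the following; the torchao API names (quantize_, int4_weight_only) may differ between versions.

```python
# Hypothetical minimal reproducer (the original script is not shown here): compare
# bf16 vs. int4 weight-only linear latency under torch.compile on a CUDA device.
# torchao API names (quantize_, int4_weight_only) may differ across versions.
import time
import torch
from torchao.quantization import quantize_, int4_weight_only

def bench(model, x, iters=100):
    for _ in range(10):                 # warm-up, including compilation
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3   # ms per iteration

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")

m_bf16 = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
m_int4 = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(m_int4, int4_weight_only())

print("bf16  :", bench(torch.compile(m_bf16), x), "ms")
print("int4wo:", bench(torch.compile(m_int4), x), "ms")
```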