
Use torch.compile to speed up GPTQ algo #1561


Open · aladerran wants to merge 3 commits into main

Conversation

aladerran (Author)

SUMMARY:
In response to #1496, this PR uses torch.compile to speed up the GPTQ quantization process in gptq_quantize.py, and adds simple benchmarking tooling.

I tested on a single NVIDIA A100-SXM4-80GB, with:
PyTorch version: 2.7.0+cu126
CUDA version: 12.6
cuDNN version: 90501

gptq_baseline_profile.txt
gptq_tc_profile.txt

TEST PLAN:

# Run with torch.compile
CUDA_VISIBLE_DEVICES=0 python benchmarks/gptq_runtime_profiler.py > gptq_tc_profile.txt 2>&1
# Baseline run without torch.compile
CUDA_VISIBLE_DEVICES=0 python benchmarks/gptq_runtime_profiler.py > gptq_baseline_profile.txt 2>&1

First-time contributor here, please let me know if you have any tips!


👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: this is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.

gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @aladerran, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request aims to enhance the performance of the GPTQ quantization process by integrating torch.compile. The changes primarily focus on optimizing the quantize_weight function, which is a critical component of the GPTQ algorithm, to achieve faster execution times.

Highlights

  • Performance Optimization: The core change involves applying torch.compile to the quantize_weight function within the GPTQ quantization algorithm to significantly speed up its execution.
  • TorchDynamo Configuration: Necessary torch._dynamo.config imports and settings, specifically torch._dynamo.config.capture_scalar_outputs = True, have been added to ensure proper compilation and avoid potential issues with scalar outputs.
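
A minimal sketch of the shape these highlights describe (illustrative stand-in code, not the actual gptq_quantize.py diff; the real decorated function is quantize_weight):

import torch
import torch._dynamo

# Allow scalar reads (e.g. tensor.item()) to be captured inside the compiled graph
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def quantize_weight_sketch(W: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in body: symmetric 4-bit round-to-nearest quantize/dequantize
    q = torch.clamp(torch.round(W / scale), -8, 7)
    return q * scale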

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request successfully integrates torch.compile into the GPTQ quantization process, which the provided profiling results indicate leads to significant speedups. The changes involve adding the necessary import, applying the @torch.compile decorator, and setting a torch._dynamo.config flag. The core change of applying the decorator is well-placed. The only feedback is a minor suggestion regarding the global configuration setting for torch.compile to improve code clarity and maintainability.

@@ -16,6 +17,8 @@
from llmcompressor.observers.base import Observer
from llmcompressor.pytorch.utils.helpers import tensor_sparsity

torch._dynamo.config.capture_scalar_outputs = True

Severity: medium

Setting torch._dynamo.config.capture_scalar_outputs = True at the module level applies this configuration globally to any code that imports this module. While this might be necessary for torch.compile to function correctly with the quantize_weight function, it's a broad setting that could potentially affect other parts of the codebase in unexpected ways. Consider adding a brief comment explaining why this setting is needed specifically for this module/function and acknowledging its global scope.

Suggested change
torch._dynamo.config.capture_scalar_outputs = True
# Enable scalar capture for torch.compile, potentially needed for control flow
torch._dynamo.config.capture_scalar_outputs = True
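
A narrower alternative, shown purely as an illustration (the PR itself keeps the module-level flag), is to scope the setting with torch._dynamo.config.patch, which restores the previous value when the block exits:

import torch
import torch._dynamo

@torch.compile
def _scalar_demo(x: torch.Tensor) -> torch.Tensor:
    n = (x > 0).sum().item()  # scalar read that needs capture_scalar_outputs to stay in-graph
    return x * n

with torch._dynamo.config.patch(capture_scalar_outputs=True):
    _scalar_demo(torch.randn(4))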

aladerran (Author) commented Jun 17, 2025

Benchmarking script I used:

import torch
from contextlib import contextmanager
from typing import Dict, List, Tuple, Callable
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
    quantize_weight,
    quantize_weight_optimized,
    GPTQ_PRECISION
)
from llmcompressor.utils.pytorch.utils import measure_cuda_memory

# Tolerance values for floating-point comparisons
RTOL = 1e-5
ATOL = 1e-8

class GPTQRuntimeProfiler:
    def __init__(self, warmup_runs: int = 3, benchmark_runs: int = 5, quant_args: QuantizationArgs = None):
        self.timing_results = {}
        self.profile_data = {}
        self.warmup_runs = warmup_runs
        self.benchmark_runs = benchmark_runs
        
        # Quantization parameter configuration
        if quant_args is None:
            self.quant_args = QuantizationArgs(
                num_bits=4,
                type="int",
                symmetric=True,
                strategy=QuantizationStrategy.CHANNEL
            )
        else:
            self.quant_args = quant_args


    @contextmanager
    def cuda_time_section(self, section_name: str):
        """Measure GPU execution time"""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        yield
        end_event.record()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        gpu_time = start_event.elapsed_time(end_event) / 1000.0  # Convert to seconds
        self.timing_results[section_name] = gpu_time

    def warmup_gpu(self, module: torch.nn.Module, H: torch.Tensor):
        """Stabilize GPU timing through warmup runs"""
        print("Performing GPU warmup...")

        for _ in range(self.warmup_runs):
            H_copy = H.clone()
            hessians_dict = {module: H_copy}
            hessians_dict_optimized = {module: H_copy}

            _ = quantize_weight(
                module=module,
                quant_args=self.quant_args,
                hessians_dict=hessians_dict,
                blocksize=128,
                percdamp=0.01,
            )

            _ = quantize_weight_optimized(
                module=module,
                quant_args=self.quant_args,
                hessians_dict=hessians_dict_optimized,
                blocksize=128,
                percdamp=0.01,
            )

            if torch.cuda.is_available():
                torch.cuda.synchronize()

        print(f"Warmup completed ({self.warmup_runs} runs)")

    def _profile_function(self, func: Callable, module: torch.nn.Module, 
                         H: torch.Tensor, section_name: str) -> Dict:
        """Profile function performance with multiple runs"""
        run_times = []
        all_results = []  # Store results for validation
        
        with measure_cuda_memory() as mem_tracker:
            for _ in range(self.benchmark_runs):
                H_copy = H.clone()
                hessians_dict = {module: H_copy}

                with self.cuda_time_section(section_name):
                    result = func(
                        module=module,
                        quant_args=self.quant_args,
                        hessians_dict=hessians_dict,
                        blocksize=128,
                        percdamp=0.01,
                    )
                run_times.append(self.timing_results[section_name])
                all_results.append(result)

        # Validate result consistency
        self._validate_results(all_results, section_name)
        
        avg_time = sum(run_times) / len(run_times)
        min_time = min(run_times)
        max_time = max(run_times)
        std_dev = (sum((t - avg_time) ** 2 for t in run_times) / len(run_times)) ** 0.5

        return {
            "average": avg_time,
            "min": min_time,
            "max": max_time,
            "std_dev": std_dev,
            "peak_memory_mb": mem_tracker.peak_consumed_memory / 1024 / 1024,
        }

    def _validate_results(self, results: List, section_name: str):
        """Validate consistency across multiple runs"""
        if len(results) < 2:
            return
            
        ref_loss, ref_weight, ref_scale, ref_zero, ref_g_idx = results[0]
        
        for i, result in enumerate(results[1:], start=1):
            loss, weight, scale, zero, g_idx = result
            
            # Validate loss value
            assert abs(loss - ref_loss) < ATOL, (
                f"{section_name} run {i} loss mismatch: {loss} vs {ref_loss}"
            )
            
            # Validate weight tensor
            assert torch.allclose(weight, ref_weight, rtol=RTOL, atol=ATOL), (
                f"{section_name} run {i} weight mismatch"
            )
            
            # Validate scale
            assert torch.allclose(scale, ref_scale, rtol=RTOL, atol=ATOL), (
                f"{section_name} run {i} scale mismatch"
            )
            
            # Validate zero-point
            assert torch.allclose(zero, ref_zero, rtol=RTOL, atol=ATOL), (
                f"{section_name} run {i} zero-point mismatch"
            )
            
            # Validate g_idx
            if ref_g_idx is None:
                assert g_idx is None, f"{section_name} run {i} g_idx should be None"
            else:
                assert torch.equal(g_idx, ref_g_idx), f"{section_name} run {i} g_idx mismatch"
        
        print(f"Validation passed for {section_name}: {len(results)} runs consistent")

    def profile_quantize_weight(self, module: torch.nn.Module, H: torch.Tensor):
        """Profile quantize_weight function"""
        print("\n=== Benchmarking quantize_weight ===")
        result = self._profile_function(quantize_weight, module, H, "quantize_weight")
        self._print_results(result)
        return result

    def profile_quantize_weight_optimized(self, module: torch.nn.Module, H: torch.Tensor):
        """Profile quantize_weight_optimized function"""
        print("\n=== Benchmarking quantize_weight_optimized ===")
        result = self._profile_function(quantize_weight_optimized, module, H, "quantize_weight_optimized")
        self._print_results(result)
        return result
    
    def _print_results(self, result: Dict):
        """Print performance metrics"""
        print(f"Average time: {result['average']:.4f}s")
        print(f"Min time: {result['min']:.4f}s")
        print(f"Max time: {result['max']:.4f}s")
        print(f"Std dev: {result['std_dev']:.4f}s")
        print(f"Peak memory: {result['peak_memory_mb']:.2f} MB")

    def validate_implementations(self, module: torch.nn.Module, H: torch.Tensor):
        """Validate equivalence between implementations"""
        print("\n=== Validating implementations ===")
        
        # Run reference implementation
        H_ref = H.clone()
        hessians_ref = {module: H_ref}
        ref_result = quantize_weight(
            module=module,
            quant_args=self.quant_args,
            hessians_dict=hessians_ref,
            blocksize=128,
            percdamp=0.01,
        )
        
        # Run optimized implementation
        H_opt = H.clone()
        hessians_opt = {module: H_opt}
        opt_result = quantize_weight_optimized(
            module=module,
            quant_args=self.quant_args,
            hessians_dict=hessians_opt,
            blocksize=128,
            percdamp=0.01,
        )
        
        # Unpack results
        ref_loss, ref_weight, ref_scale, ref_zero, ref_g_idx = ref_result
        opt_loss, opt_weight, opt_scale, opt_zero, opt_g_idx = opt_result
        
        # Validate numerical equivalence
        assert abs(ref_loss - opt_loss) < ATOL, f"Loss mismatch: {ref_loss} vs {opt_loss}"
        assert torch.allclose(ref_weight, opt_weight, rtol=RTOL, atol=ATOL), "Weight tensors differ"
        assert torch.allclose(ref_scale, opt_scale, rtol=RTOL, atol=ATOL), "Scale tensors differ"
        assert torch.allclose(ref_zero, opt_zero, rtol=RTOL, atol=ATOL), "Zero-point tensors differ"
        
        if ref_g_idx is None:
            assert opt_g_idx is None, "g_idx should be None"
        else:
            assert torch.equal(ref_g_idx, opt_g_idx), "g_idx tensors differ"
        
        print("Validation passed: Both implementations return identical results")

    def benchmark_matrix_sizes(self, sizes: List[Tuple[int, int]]):
        """Benchmark functions across different matrix sizes"""
        results = {}

        for rows, cols in sizes:
            print(f"\n=== Benchmarking {rows}x{cols} matrix (GPU) ===")
            module = torch.nn.Linear(cols, rows).cuda()
            H = torch.randn(cols, cols, device='cuda', dtype=GPTQ_PRECISION)
            H = H @ H.T  # Ensure positive definite
            
            # GPU warmup
            self.warmup_gpu(module, H)
            
            # Implementation validation
            self.validate_implementations(module, H)
            
            # Performance testing
            print(f"\nTesting quantize_weight:")
            results[f"{rows}x{cols}_quantize_weight"] = self.profile_quantize_weight(module, H)

            print(f"\nTesting quantize_weight_optimized:")
            results[f"{rows}x{cols}_quantize_weight_optimized"] = self.profile_quantize_weight_optimized(module, H)

        return results


def main():
    if not torch.cuda.is_available():
        print("CUDA not available, skipping GPU benchmarks")
        return

    matrix_sizes = [
        (512, 512),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096)
    ]
    
    strategies = [
        {
            "name": QuantizationStrategy.TENSOR,
            "args": {"strategy": QuantizationStrategy.TENSOR}
        },
        {
            "name": QuantizationStrategy.CHANNEL,
            "args": {"strategy": QuantizationStrategy.CHANNEL}
        },
        {
            "name": QuantizationStrategy.GROUP,
            "args": {"strategy": QuantizationStrategy.GROUP, "group_size": 128}
        }
    ]

    print("=== GPTQ Runtime Profiling (GPU) ===")
    
    for strategy_config in strategies:
        print(f"\n{'='*50}")
        print(f"Testing strategy: {strategy_config['name']}")
        print(f"{'='*50}")
        
        # Create quantization args with strategy-specific parameters
        quant_args = QuantizationArgs(
            num_bits=4,
            type="int",
            symmetric=True,
            **strategy_config["args"]
        )
        
        # Create profiler with current strategy
        profiler = GPTQRuntimeProfiler(
            warmup_runs=10,
            benchmark_runs=50,
            quant_args=quant_args
        )
        
        # Run benchmark
        results = profiler.benchmark_matrix_sizes(matrix_sizes)

        # Print summary for current strategy
        print(f"\n=== Summary for {strategy_config['name']} strategy ===")
        for key, data in results.items():
            avg_time = data['average']
            peak_mem = data['peak_memory_mb']
            print(f"{key}: {avg_time:.4f}s (±{data['max'] - data['min']:.4f}s), {peak_mem:.1f}MB")


if __name__ == "__main__":
    main()

kylesayrs (Collaborator) commented Jun 17, 2025

Hi @aladerran!

Thank you for your contribution and thorough profiling data! It seems like the new runtime is about 86% of the original, a notable improvement! This change should be good to merge now, but there are a few other small modifications to the gptq_quantize method that have the potential to drastically improve runtime.

Specifically, removing branching logic in the algorithm would reduce graph breaks. You can debug graph breaks with TORCH_LOGS="graph_breaks". Linked below are a couple of suggestions from ChatGPT on places to look for optimization.
https://chatgpt.com/s/t_68517d104cb88191814964727ba0d8db
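
A rough illustration of both suggestions follows. Graph breaks can be logged while running the profiler from the TEST PLAN above:

TORCH_LOGS="graph_breaks" CUDA_VISIBLE_DEVICES=0 python benchmarks/gptq_runtime_profiler.py

And a toy before/after of removing a data-dependent branch (hypothetical, not the actual gptq_quantize code):

import torch

def damp_branching(err: torch.Tensor) -> torch.Tensor:
    if err.abs().max() > 1.0:  # Python-level branch on a tensor value causes a graph break
        return 0.5 * err
    return err

def damp_fused(err: torch.Tensor) -> torch.Tensor:
    # torch.where keeps the decision inside the traced graph
    return torch.where(err.abs().max() > 1.0, 0.5 * err, err)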

aladerran (Author)

Hi @kylesayrs,

Thank you for the feedback! I'll look into further optimizing the runtime.

aladerran (Author) commented Jun 22, 2025

Hi @kylesayrs,

The new commit introduces quantize_weight_optimized, which isolates the main GPTQ quantization loop into a function that can be accelerated with torch.compile. The core logic should remain functionally equivalent to the original implementation. Without torch.compile, this version already runs in ~70% of the original runtime. With torch.compile enabled, execution time drops further to ~10-20% of the original.
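
As a rough sketch of that split (a toy illustration; the names and body below are stand-ins, not the actual quantize_weight_optimized implementation):

import torch

@torch.compile
def _gptq_core_sketch(W: torch.Tensor, blocksize: int) -> torch.Tensor:
    # Stand-in for the tensor-heavy block loop that torch.compile accelerates
    Q = torch.zeros_like(W)
    for i1 in range(0, W.shape[1], blocksize):
        i2 = min(i1 + blocksize, W.shape[1])
        Q[:, i1:i2] = torch.round(W[:, i1:i2])
    return Q

def quantize_weight_optimized_sketch(module: torch.nn.Linear, blocksize: int = 128):
    # Eager preparation stays outside the compiled region
    # (Hessian handling, observers, etc. in the real code)
    W = module.weight.data.clone()
    return _gptq_core_sketch(W, blocksize)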

I have updated my test script above and some of the test results are shown here:

gptq_baseline_profile.txt
gptq_tc_dynamic_profile.txt

However, there are a few considerations:

  1. Precision: I observed that numerical differences may occur due to torch.compile optimizations.
  2. Memory: The peak memory increases by 1.5x on average.
  3. Compilation Overhead: Initial compilation time can be significant.

Given the overhead, should we make torch.compile an optional feature?

Any feedback on how best to integrate this optimization would be great.
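
One hypothetical shape for an opt-in switch (purely illustrative; the helper name and environment variable below are not part of this PR):

import os
import torch

def maybe_compile(fn):
    # Hypothetical gate: only compile when the user opts in via an env var
    if os.environ.get("LLMCOMPRESSOR_TORCH_COMPILE", "0") == "1":
        return torch.compile(fn)
    return fn

# e.g. quantize_weight_core = maybe_compile(quantize_weight_core)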

kylesayrs (Collaborator) commented Jun 24, 2025

@aladerran Amazing work! Thank you for the contribution! I'll verify this asap so we can start quantizing faster ⚡💪

kylesayrs added the ready label on Jun 24, 2025