From 65875d135ae48cdb2893f9a05187bf201eea4189 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 4 Apr 2025 10:47:27 -0700 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- benchmarks/microbenchmarks/README.md | 17 +++++++ .../microbenchmarks/benchmark_runner.py | 43 +++++++++++++++++- .../microbenchmarks/test/benchmark_config.yml | 14 ++++++ .../test/test_benchmark_runner.py | 44 +++++++++++++++++++ benchmarks/microbenchmarks/utils.py | 2 + 5 files changed, 119 insertions(+), 1 deletion(-) diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index a95dc53755..25e3374f7a 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -46,6 +46,16 @@ model_params: [2048, 4096, 1024], [4096, 4096, 1024] ] + - name: "llama" + - name: "pow2" + min_power: 10 # Optional, default is 10 (1024) + max_power: 14 # Optional, default is 14 (16,384) + - name: "pow2_extended" + min_power: 10 # Optional, default is 10 (1024) + max_power: 14 # Optional, default is 14 (16,384) + - name: "sweep" + min_power: 8 # Optional, default is 8 (256) + max_power: 15 # Optional, default is 15 (32,768) high_precision_dtype: "torch.bfloat16" compile: "max-autotune" # Options: "default", "max-autotune", "false" device: "cuda" # Options: "cuda", "mps", "xpu", "cpu" @@ -54,6 +64,13 @@ model_params: ## Configuration Options +### Shape Generation Options +- `custom`: Manually specify shapes as a list of [m, k, n] dimensions +- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13) +- `pow2`: Generate shapes with dimensions that are powers of 2 (e.g., 1024, 2048, 4096, etc.) +- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half (e.g., 1024, 1536, 2048, 3072, etc.) +- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions + ### Quantization Methods Currently, quantization string is in same format as the one being passed in llama/generate.py. 
- `baseline`: No quantization diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 1a60ca6b16..e52b018470 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -48,9 +48,50 @@ def get_shapes_for_config( name = shape_config["name"] if name == "custom": shapes.extend([(name, shape) for shape in shape_config["shapes"]]) + elif name == "llama": + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + bsz, seq_len = 4, 4096 + M = bsz * seq_len + llama_shapes = { + "attn.wqkv": (M, 8192, 1280), + "attn.w0": (M, 1024, 8192), + "ffn.w13": (M, 8192, 7168), + "ffn.w2": (M, 3584, 8192), + } + shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()]) + elif name == "pow2": + # Generate shapes with dimensions that are powers of 2 + min_power_of_2 = shape_config.get("min_power", 10) # 1024 + max_power_of_2 = shape_config.get("max_power", 14) # 16,384 + for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): + val = 2**power_of_2 + shapes.append((f"{name}_{idx}", [val, val, val])) + elif name == "pow2_extended": + # Generate shapes with dimensions that are powers of 2 and powers of 2 + half + min_power_of_2 = shape_config.get("min_power", 10) # 1024 + max_power_of_2 = shape_config.get("max_power", 14) # 16,384 + for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): + val1 = 2**power_of_2 + val2 = 2**power_of_2 + 2 ** (power_of_2 - 1) + shapes.append((f"{name}_{idx*2}", [val1, val1, val1])) + shapes.append((f"{name}_{idx*2+1}", [val2, val2, val2])) + elif name == "sweep": + # Generate a sweep of shapes with different powers of 2 for M, K, N + min_p2 = shape_config.get("min_power", 8) # 256 + max_p2 = shape_config.get("max_power", 15) # 32,768 + counter = 0 + for M_p2 in range(min_p2, max_p2 + 1): + M = 2**M_p2 + for K_p2 in range(min_p2, max_p2 + 1): + K = 2**K_p2 + for N_p2 in range(min_p2, max_p2 + 1): + N = 2**N_p2 + shapes.append((f"{name}_{counter}", [M, K, N])) + counter += 1 else: raise NotImplementedError( - f"Shape config {name} not supported. Currently only supports custom shapes." + f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep." 
) return shapes diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 4394d0208b..b0c660dfce 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -31,6 +31,20 @@ model_params: [2048, 4096, 1024], # [4096, 4096, 1024] ] + # Example of using LLaMa shapes + - name: "llama" + # Example of using power of 2 shapes + - name: "pow2" + min_power: 10 # 1024 + max_power: 12 # 4096 + # Example of using extended power of 2 shapes + - name: "pow2_extended" + min_power: 10 # 1024 + max_power: 11 # 2048 + # Example of using sweep shapes (commented out as it generates many shapes) + # - name: "sweep" + # min_power: 8 # 256 + # max_power: 9 # 512 high_precision_dtype: "torch.bfloat16" use_torch_compile: true torch_compile_mode: "max-autotune" diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index a8683a1de8..20991e4122 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -57,11 +57,55 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_get_shapes_for_config(self): + # Test custom shapes shapes = get_shapes_for_config( self.test_config["model_params"][0]["matrix_shapes"] ) self.assertEqual(len(shapes), 1) self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024])) + + # Test llama shapes + llama_shapes = get_shapes_for_config([ + {"name": "llama"} + ]) + self.assertEqual(len(llama_shapes), 4) # 4 LLaMa shapes + self.assertTrue(any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)) + self.assertTrue(any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)) + self.assertTrue(any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)) + self.assertTrue(any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)) + + # Test pow2 shapes + pow2_shapes = get_shapes_for_config([ + {"name": "pow2", "min_power": 10, "max_power": 12} + ]) + self.assertEqual(len(pow2_shapes), 3) # 3 powers of 2 (10, 11, 12) + self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024])) # 2^10 + self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048])) # 2^11 + self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096])) # 2^12 + + # Test pow2_extended shapes + pow2_extended_shapes = get_shapes_for_config([ + {"name": "pow2_extended", "min_power": 10, "max_power": 11} + ]) + self.assertEqual(len(pow2_extended_shapes), 4) # 2 powers of 2, each with 2 variants + self.assertEqual(pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])) # 2^10 + self.assertEqual(pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536])) # 2^10 + 2^9 + self.assertEqual(pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])) # 2^11 + self.assertEqual(pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])) # 2^11 + 2^10 + + # Test sweep shapes (limited to a small range for testing) + sweep_shapes = get_shapes_for_config([ + {"name": "sweep", "min_power": 8, "max_power": 9} + ]) + # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations) + self.assertEqual(len(sweep_shapes), 8) + # Check that all shapes have the expected format + for name, shape in sweep_shapes: + self.assertTrue(name.startswith("sweep_")) + self.assertEqual(len(shape), 3) # [M, K, N] + # Check that all dimensions are powers of 2 between 2^8 and 2^9 + for dim in shape: + 
self.assertTrue(dim in [256, 512]) # 2^8, 2^9 def test_get_param_combinations(self): model_param = self.test_config["model_params"][0] diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 9e978f70fa..4a326d3c26 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -753,6 +753,7 @@ def print_results(results: List[BenchmarkResult]): result.config.name, result.config.quantization or "baseline", result.config.sparsity or "none", + f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})", f"{result.model_inference_time_in_ms:.2f}", str(result.config.enable_profiler), str(result.config.enable_memory_profile), @@ -774,6 +775,7 @@ def print_results(results: List[BenchmarkResult]): "Name", "Quantization", "Sparsity", + "Shape", "Inference Time (ms)", "Profiler Enabled", "Memory Profiling Enabled", From a886a270fb33b4775f1b8eb9fc3232a6a8e237c3 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 4 Apr 2025 10:59:42 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- .../test/test_benchmark_runner.py | 66 ++++++++++++------- benchmarks/microbenchmarks/test/test_utils.py | 62 +++++++++++------ benchmarks/microbenchmarks/utils.py | 46 +++++++------ 3 files changed, 110 insertions(+), 64 deletions(-) diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index 20991e4122..7f93213a22 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -63,40 +63,56 @@ def test_get_shapes_for_config(self): ) self.assertEqual(len(shapes), 1) self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024])) - + # Test llama shapes - llama_shapes = get_shapes_for_config([ - {"name": "llama"} - ]) + llama_shapes = get_shapes_for_config([{"name": "llama"}]) self.assertEqual(len(llama_shapes), 4) # 4 LLaMa shapes - self.assertTrue(any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)) - self.assertTrue(any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)) - self.assertTrue(any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)) - self.assertTrue(any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)) - + self.assertTrue( + any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes) + ) + self.assertTrue( + any(name.startswith("llama_attn.w0") for name, _ in llama_shapes) + ) + self.assertTrue( + any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes) + ) + self.assertTrue( + any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes) + ) + # Test pow2 shapes - pow2_shapes = get_shapes_for_config([ - {"name": "pow2", "min_power": 10, "max_power": 12} - ]) + pow2_shapes = get_shapes_for_config( + [{"name": "pow2", "min_power": 10, "max_power": 12}] + ) self.assertEqual(len(pow2_shapes), 3) # 3 powers of 2 (10, 11, 12) self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024])) # 2^10 self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048])) # 2^11 self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096])) # 2^12 - + # Test pow2_extended shapes - pow2_extended_shapes = get_shapes_for_config([ - {"name": "pow2_extended", "min_power": 10, "max_power": 11} - ]) - self.assertEqual(len(pow2_extended_shapes), 4) # 2 powers of 2, each with 2 variants - self.assertEqual(pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])) # 2^10 - self.assertEqual(pow2_extended_shapes[1], 
("pow2_extended_1", [1536, 1536, 1536])) # 2^10 + 2^9 - self.assertEqual(pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])) # 2^11 - self.assertEqual(pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])) # 2^11 + 2^10 - + pow2_extended_shapes = get_shapes_for_config( + [{"name": "pow2_extended", "min_power": 10, "max_power": 11}] + ) + self.assertEqual( + len(pow2_extended_shapes), 4 + ) # 2 powers of 2, each with 2 variants + self.assertEqual( + pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024]) + ) # 2^10 + self.assertEqual( + pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536]) + ) # 2^10 + 2^9 + self.assertEqual( + pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048]) + ) # 2^11 + self.assertEqual( + pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072]) + ) # 2^11 + 2^10 + # Test sweep shapes (limited to a small range for testing) - sweep_shapes = get_shapes_for_config([ - {"name": "sweep", "min_power": 8, "max_power": 9} - ]) + sweep_shapes = get_shapes_for_config( + [{"name": "sweep", "min_power": 8, "max_power": 9}] + ) # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations) self.assertEqual(len(sweep_shapes), 8) # Check that all shapes have the expected format diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 46f6a74685..20d8915e78 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -171,7 +171,7 @@ def test_rms_norm(self): x = torch.randn(16, 64) out = rms_norm(x) self.assertEqual(out.shape, (16, 64)) - + # Test with different eps rms_norm = RMSNorm(dim=64, eps=1e-5) out = rms_norm(x) @@ -184,38 +184,50 @@ def test_rms_norm_linear_activation(self): out = model(x) self.assertEqual(out.shape, (16, 32)) self.assertEqual(out.dtype, torch.float32) - + # Test with ReLU activation - model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu") + model = RMSNormLinearActivation( + fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu" + ) out = model(x) self.assertEqual(out.shape, (16, 32)) self.assertTrue(torch.all(out >= 0)) # Check ReLU output range - + # Test with SiLU activation - model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu") + model = RMSNormLinearActivation( + fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu" + ) out = model(x) self.assertEqual(out.shape, (16, 32)) - + # Test with invalid activation with self.assertRaises(ValueError): - RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid") + RMSNormLinearActivation( + fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid" + ) def test_transformer_block(self): # Test with default parameters - model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + model = TransformerBlock( + hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32 + ) x = torch.randn(16, 16, 64) # [batch_size, seq_len, hidden_dim] out = model(x) self.assertEqual(out.shape, (16, 16, 64)) self.assertEqual(out.dtype, torch.float32) - + # Test with different parameters - model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32) + model = TransformerBlock( + hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32 + ) x = torch.randn(8, 32, 128) out = model(x) self.assertEqual(out.shape, (8, 32, 128)) - + # Test with different head 
dimensions - model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32) + model = TransformerBlock( + hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32 + ) x = torch.randn(4, 8, 96) out = model(x) self.assertEqual(out.shape, (4, 8, 96)) @@ -255,7 +267,7 @@ def test_create_model_and_input(self): ) self.assertIsInstance(model, RMSNormLinearActivation) self.assertEqual(input_data.shape, (m, k)) - + # Test TransformerBlock model, input_data = create_model_and_input( model_type="transformer_block", @@ -266,40 +278,50 @@ def test_create_model_and_input(self): device="cpu", ) self.assertIsInstance(model, TransformerBlock) - self.assertEqual(input_data.shape, (m, 16, k)) # [batch_size, seq_len, hidden_dim] + self.assertEqual( + input_data.shape, (m, 16, k) + ) # [batch_size, seq_len, hidden_dim] def test_quantization_on_models(self): # Test quantization on RMSNormLinearActivation model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) x = torch.randn(16, 64) - + # Test with Int8WeightOnlyConfig config = string_to_config(quantization="int8wo", sparsity=None) if config is not None: # Skip quantization test if torchao.quantization.quantize is not available try: from torchao.quantization import quantize + quantized_model = quantize(model, config) out = quantized_model(x) self.assertEqual(out.shape, (16, 32)) except ImportError: - print("Skipping quantization test: torchao.quantization.quantize not available") - + print( + "Skipping quantization test: torchao.quantization.quantize not available" + ) + # Test quantization on TransformerBlock - model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + model = TransformerBlock( + hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32 + ) x = torch.randn(16, 16, 64) - + # Test with Int8WeightOnlyConfig config = string_to_config(quantization="int8wo", sparsity=None) if config is not None: # Skip quantization test if torchao.quantization.quantize is not available try: from torchao.quantization import quantize + quantized_model = quantize(model, config) out = quantized_model(x) self.assertEqual(out.shape, (16, 16, 64)) except ImportError: - print("Skipping quantization test: torchao.quantization.quantize not available") + print( + "Skipping quantization test: torchao.quantization.quantize not available" + ) def test_generate_results_csv(self): results = [ diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 4a326d3c26..d405d985a1 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -399,7 +399,7 @@ def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"): super().__init__() self.rms_norm = RMSNorm(fc_dim1, dtype=dtype) self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) - + if activation == "gelu": self.activation = torch.nn.GELU() elif activation == "relu": @@ -422,58 +422,64 @@ def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16): self.hidden_dim = hidden_dim self.num_heads = num_heads self.head_dim = hidden_dim // num_heads - + # Self-attention self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype) self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype) - + # MLP self.mlp_ratio = mlp_ratio self.mlp_hidden_dim = int(hidden_dim * mlp_ratio) - self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype) - self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, 
hidden_dim, bias=False).to(dtype) - + self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to( + dtype + ) + self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to( + dtype + ) + # Layer norms self.norm1 = RMSNorm(hidden_dim, dtype=dtype) self.norm2 = RMSNorm(hidden_dim, dtype=dtype) - + # Activation self.activation = torch.nn.GELU() def forward(self, x): batch_size, seq_len, _ = x.shape - + # Self-attention residual = x x = self.norm1(x) - + # Reshape qkv projection for better memory layout qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim] qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) - qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim] + qkv = qkv.permute( + 2, 0, 3, 1, 4 + ) # [3, batch_size, num_heads, seq_len, head_dim] q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim] - + # Scaled dot-product attention with proper reshaping # Reshape for better memory layout and avoid broadcasting issues q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim) k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim) v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim) - + # Compute attention scores - attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5)) + attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim**0.5)) attn = torch.softmax(attn, dim=-1) - + # Apply attention to values x = attn @ v # [batch_size * num_heads, seq_len, head_dim] - + # Reshape back to original dimensions x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim) x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim) - + # Project back to hidden dimension x = self.proj(x) x = residual + x - + # MLP residual = x x = self.norm2(x) @@ -481,7 +487,7 @@ def forward(self, x): x = self.activation(x) x = self.mlp_fc2(x) x = residual + x - + return x @@ -683,7 +689,9 @@ def create_model_and_input( input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) elif model_type == "transformer_block": # For transformer block, k is the hidden dimension - model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device) + model = TransformerBlock( + k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype + ).to(device) # Input shape for transformer is [batch_size, seq_len, hidden_dim] input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype) else: From 36b3639f88d8eef9fd0c24f8e5ee72b36bf6a23e Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 7 Apr 2025 11:08:57 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- .../microbenchmarks/test/benchmark_config.yml | 80 ++++++++++--------- benchmarks/microbenchmarks/utils.py | 7 +- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index b0c660dfce..63e975a8ad 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -2,27 +2,29 @@ benchmark_mode: "inference" quantization_config_recipe_names: # Will run a baseline inference for model by default, without quantization for comparison - # - "int4wo-32" + - "int4wo-32" # - "marlin" - "int8wo" + - "int8dq" + - "float8dq" # sparsity_config_recipe_names: # Will run a baseline inference for model by default, without sparsity for comparison # - "semi-sparse" # - "block" output_dir: 
"benchmarks/microbenchmarks/results" model_params: - # - name: "small_bf16_linear" - # matrix_shapes: - # - name: "custom" - # shapes: [ - # [1024, 1024, 1024], # [m, k, n] - # ] - # high_precision_dtype: "torch.bfloat16" - # use_torch_compile: true - # torch_compile_mode: "max-autotune" - # device: "cuda" - # model_type: "linear" - # enable_profiler: true # Enable profiling for this model + - name: "small_bf16_linear" + matrix_shapes: + - name: "custom" + shapes: [ + [1024, 1024, 1024], # [m, k, n] + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "linear" + enable_profiler: true # Enable profiling for this model - name: "large_bf16_ln_linear" matrix_shapes: @@ -65,30 +67,30 @@ model_params: # model_type: "linear" # enable_profiler: true # Enable profiling for this model - - name: "bf16_rms_norm_linear_activation" - matrix_shapes: - - name: "custom" - shapes: [ - [2048, 4096, 1024], - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "rms_norm_linear_activation" - enable_profiler: true - enable_memory_profile: true + # - name: "bf16_rms_norm_linear_activation" + # matrix_shapes: + # - name: "custom" + # shapes: [ + # [2048, 4096, 1024], + # ] + # high_precision_dtype: "torch.bfloat16" + # use_torch_compile: true + # torch_compile_mode: "max-autotune" + # device: "cuda" + # model_type: "rms_norm_linear_activation" + # enable_profiler: true + # enable_memory_profile: true - - name: "bf16_transformer_block" - matrix_shapes: - - name: "custom" - shapes: [ - [2048, 4096, 1024], # For transformer_block, k is the hidden dimension - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "transformer_block" - enable_profiler: true - enable_memory_profile: true + # - name: "bf16_transformer_block" + # matrix_shapes: + # - name: "custom" + # shapes: [ + # [2048, 4096, 1024], # For transformer_block, k is the hidden dimension + # ] + # high_precision_dtype: "torch.bfloat16" + # use_torch_compile: true + # torch_compile_mode: "max-autotune" + # device: "cuda" + # model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition) + # enable_profiler: true + # enable_memory_profile: true diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index d405d985a1..dbbb86b37b 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -359,6 +359,8 @@ def to_dict(self) -> Dict[str, Any]: return result_dict +# TODO: MOE block (Maybe) +# TODO: Move stuff to torchao/testing class ToyLinearModel(torch.nn.Module): def __init__(self, k=64, n=32, dtype=torch.bfloat16): super().__init__() @@ -369,12 +371,14 @@ def forward(self, x): return x +# TODO: Maybe we can specify a diy for activation function and use it in the model +# TODO: MLP block class LNLinearSigmoid(torch.nn.Module): def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16): super().__init__() self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False) self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) - self.sigmoid = torch.nn.Sigmoid() + self.sigmoid = torch.nn.Sigmoid() # TODO: Find a way to make it configurable def forward(self, x): x = self.ln(x) @@ -383,6 +387,7 @@ def forward(self, x): return x +# TODO: We might not need it, need to figure of 
it's relevant in any technique
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16):
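
---

For reviewers skimming the diff, below is a minimal usage sketch of the new shape generators, mirroring the cases covered in `test_get_shapes_for_config` above. The import path is an assumption based on the test module's location in this patch; the expected values are copied from the tests rather than re-derived.

```python
# Usage sketch for the new shape generators (mirrors the unit tests in this patch).
# Assumes it runs from the repo root so the import below resolves.
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

# "llama": four fixed LLaMa 2 70B weight shapes, named llama_attn.wqkv,
# llama_attn.w0, llama_ffn.w13, llama_ffn.w2, with M = bsz * seq_len = 4 * 4096.
llama = get_shapes_for_config([{"name": "llama"}])
assert len(llama) == 4

# "pow2": one square shape [2^p, 2^p, 2^p] per power p in [min_power, max_power].
pow2 = get_shapes_for_config([{"name": "pow2", "min_power": 10, "max_power": 12}])
assert pow2 == [
    ("pow2_0", [1024, 1024, 1024]),
    ("pow2_1", [2048, 2048, 2048]),
    ("pow2_2", [4096, 4096, 4096]),
]

# "pow2_extended": each power p contributes 2^p and 2^p + 2^(p-1).
ext = get_shapes_for_config(
    [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
)
assert [dims for _, dims in ext] == [[1024] * 3, [1536] * 3, [2048] * 3, [3072] * 3]

# "sweep": full cross-product over M, K, N, i.e. (max_power - min_power + 1) ** 3
# shapes: 8 here, and 512 with the defaults of 8..15.
sweep = get_shapes_for_config([{"name": "sweep", "min_power": 8, "max_power": 9}])
assert len(sweep) == 2 ** 3
```

The dicts passed in are exactly the parsed `matrix_shapes` entries from the YAML config, so the same keys (`name`, `shapes`, `min_power`, `max_power`) apply in both places.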
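The hand-rolled attention inside `TransformerBlock.forward` (scale by 1/sqrt(head_dim), softmax, then a weighted sum over V) computes the same result as PyTorch's fused `F.scaled_dot_product_attention` with no mask or dropout. The sketch below checks that on random inputs; it is only an illustration of the design choice, not part of the patch.

```python
# Sanity sketch (not part of the patch): the manual attention math used in
# TransformerBlock.forward matches torch's fused scaled_dot_product_attention.
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 16, 64
q, k, v = (torch.randn(batch * heads, seq, head_dim) for _ in range(3))

# Manual path, as written in the patch: scores -> softmax -> weighted sum of V
attn = (q @ k.transpose(-2, -1)) * (1.0 / head_dim**0.5)
manual = torch.softmax(attn, dim=-1) @ v

# Fused reference; the default scale is also 1/sqrt(head_dim), no mask, no dropout
ref = F.scaled_dot_product_attention(q, k, v)

torch.testing.assert_close(manual, ref, rtol=1e-4, atol=1e-4)
```

Keeping the explicit reshape/softmax path makes the benchmarked matmul pattern easy to see in profiles, which appears to be the intent here; the fused call is shown only as an equivalence check.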