From f97da6e8bc72722217be294ef7f18e278262dd7f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 17 Sep 2024 20:32:55 -0700 Subject: [PATCH] Add example code for printing the operator and shapes in a model Summary: This will be useful for people to do understand the ops/shapes for a model that they are interested in optimizing, also helpful for microbenchmarks with target ops/shapes Test Plan: python tutorials/developer_api_guide/print_op_and_shapes.py Reviewers: Subscribers: Tasks: Tags: --- .../print_op_and_shapes.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tutorials/developer_api_guide/print_op_and_shapes.py diff --git a/tutorials/developer_api_guide/print_op_and_shapes.py b/tutorials/developer_api_guide/print_op_and_shapes.py new file mode 100644 index 0000000000..0be26fd941 --- /dev/null +++ b/tutorials/developer_api_guide/print_op_and_shapes.py @@ -0,0 +1,35 @@ +import torch + +linear_shapes = [] +from torch.overrides import TorchFunctionMode +class TorchFunctionLoggingMode(TorchFunctionMode): + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.linear: + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + flattened_input_tensor = input_tensor.view(-1, input_tensor.shape[-1]) + M, K = flattened_input_tensor.shape[0], flattened_input_tensor.shape[1] + assert K == weight_tensor.shape[1] + N = weight_tensor.shape[0] + print(f"TORCH_FUNC={str(func)} (M, K, N):", M, K, N) + linear_shapes.append((M, K, N)) + else: + arg_shape = args[0].shape if len(args) > 0 and isinstance(args[0], torch.Tensor) else None + print(f"TORCH_FUNC={str(func)} args[0] shape:", arg_shape) + return func(*args, **kwargs) + +# NOTE: Modify this with your own model +from torchvision import models +m = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1) +example_inputs = (torch.randn(1, 3, 224, 224),) + +with TorchFunctionLoggingMode(): + m(*example_inputs) + +print() +print("all linear shapes (M, K, N):", linear_shapes)