Torchao import time #1944

Open
felipemello1 opened this issue Mar 24, 2025 · 1 comment

felipemello1 commented Mar 24, 2025

Hi folks, this is not a bug report. In torchtune, importing the library takes ~7s. When I profile it, the majority of that time comes from torchao imports.

A plain 'import torchao' on its own takes ~4s:

import time
start = time.perf_counter()
import torchao
end = time.perf_counter()
print("time import: ", end - start)
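An in-process measurement like the one above depends on what is already in `sys.modules` (for example, if `torch` was imported earlier in the same session, the number shrinks). A minimal sketch of a cold measurement, assuming a fresh interpreter per run is acceptable; the helper name `time_import` is hypothetical and the result includes interpreter startup, so compare it against a baseline such as `time_import("sys")`:

```python
import subprocess
import sys
import time


def time_import(module: str, runs: int = 3) -> float:
    """Best wall-clock time (seconds) to import `module` in a fresh interpreter.

    Each run starts a new Python process, so nothing is cached in sys.modules.
    The figure includes interpreter startup time.
    """
    best = float("inf")
    for _ in range(runs):
        start = time.perf_counter()
        subprocess.run([sys.executable, "-c", f"import {module}"], check=True)
        best = min(best, time.perf_counter() - start)
    return best


# Example usage (requires torchao to be installed):
# startup = time_import("sys")
# print("torchao import: ", time_import("torchao") - startup)
```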

It's possible to profile the import like this (sorted by cumulative time):

python -X importtime -c "import torchao" 2> torchao_import_times.txt
sort -k5,5nr torchao_import_times.txt > sorted_torchao_by_cumulative.txt

Or sorted by self time:

sort -k3,3nr torchao_import_times.txt > sorted_torchao_by_self.txt
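The same sorting can be done in Python, which also skips the header line and any unrelated stderr output. This is a rough sketch, assuming the `-X importtime` line layout shown in the output below (`import time: self | cumulative | indented name`, two spaces of indentation per nesting level); the names `Entry`, `parse_importtime`, and `top` are hypothetical:

```python
import re
from typing import NamedTuple

# Matches lines such as:
#   import time:    763864 |     819415 |     torch._C
LINE_RE = re.compile(r"^import time:\s*(\d+)\s*\|\s*(\d+)\s*\| ( *)(\S+)\s*$")


class Entry(NamedTuple):
    self_us: int        # time spent in the module itself, in microseconds
    cumulative_us: int  # time including nested imports, in microseconds
    depth: int          # nesting level, derived from the indentation
    module: str


def parse_importtime(text: str) -> list[Entry]:
    """Parse `python -X importtime` stderr output into entries."""
    entries = []
    for line in text.splitlines():
        m = LINE_RE.match(line)
        if m is None:  # header line or unrelated output
            continue
        self_us, cum_us, indent, module = m.groups()
        entries.append(Entry(int(self_us), int(cum_us), len(indent) // 2, module))
    return entries


def top(entries: list[Entry], key: str = "self_us", n: int = 10) -> list[Entry]:
    """Return the n slowest entries by `self_us` or `cumulative_us`."""
    return sorted(entries, key=lambda e: getattr(e, key), reverse=True)[:n]


# Example usage with the file produced by the command above:
# with open("torchao_import_times.txt") as f:
#     for e in top(parse_importtime(f.read()), key="cumulative_us"):
#         print(f"{e.cumulative_us / 1e6:8.3f}s  {e.module}")
```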

Just wanted to share it here in case someone wants to take a look. Thanks!

Output sorted by self time (columns: self [us] | cumulative [us] | imported package):

import time:    763864 |     819415 |     torch._C
import time:    436553 |     436553 |                             torchao.float8.float8_utils
import time:    247885 |     249457 |           torch._prims
import time:    220180 |     999455 |     torchao.quantization.autoquant
import time:    150263 |     285311 |               torch._inductor.decomposition
import time:    111738 |     111738 |                         torch.ao.quantization.quantizer.quantizer
import time:    109062 |     478208 |     torch._meta_registrations
import time:     83299 |     238870 |                         torch._dynamo.codegen
import time:     57174 |      57174 |                           mpmath.ctx_iv
import time:     53617 |      73218 |         torch._refs
import time:     42450 |     291907 |         torch._decomp.decompositions
import time:     36143 |      36143 |                                                 torch.distributed.tensor._collective_utils
import time:     32693 |     212937 |           torch._dynamo.polyfills.loader
import time:     30891 |    1893773 |   torch
import time:     25798 |      25798 |                       torch._dynamo.source
import time:     22112 |      23282 |                 torchgen.model
import time:     20698 |      20698 |                     triton.language.standard
import time:     16428 |     373652 |                   torch.fx.experimental.symbolic_shapes
import time:     15278 |      15627 |           torch._inductor.config
import time:     15120 |      15120 |                                   _csv
import time:     14515 |      43410 |           torch.fx.experimental._constant_symnode
import time:     13942 |      15210 |                         torch.fx.node
import time:     13130 |      13130 |                 torch._inductor.inductor_prims
import time:     12562 |     817979 |             torch._dynamo.symbolic_convert
import time:     11626 |      11735 |                       triton._C.libtriton
import time:     11372 |      12810 |                   numpy._core._multiarray_umath
import time:     10966 |     401829 |               torch._dynamo.trace_rules
import time:      9922 |       9922 |                       networkx.utils.backends
import time:      9892 |      37945 |                                 torch._subclasses.fake_tensor
import time:      9745 |      38499 |                             torch.utils._pytree
import time:      9706 |     136029 |         torchao.kernel.intmm_triton
import time:      9687 |      79312 |               torch.nn.functional
import time:      9642 |       9642 |                           torch.onnx.symbolic_opset9
import time:      9526 |       9776 |                                   torch._subclasses.meta_utils
import time:      8763 |       8763 |                     triton.language.random
import time:      8699 |       8699 |                     torch._functorch._aot_autograd.schemas
import time:      8633 |       8633 |                                 torch._dynamo._trace_wrapped_higher_order_op
import time:      8586 |       8586 |                                         dill.session
import time:      8447 |     398127 |                 torch._dynamo.utils
import time:      7971 |      10009 |                                   sympy.integrals.transforms
import time:      7683 |       7683 |                                           torch.distributed.tensor._ops._pointwise_ops
import time:      7199 |     350781 |                     torch.utils._sympy.functions
import time:      7145 |       7361 |                             torch._guards
import time:      6978 |      65819 |                     triton.backends
import time:      6969 |     107833 |                   networkx
import time:      6952 |       6952 |                         torch.ao.quantization.backend_config.backend_config
import time:      6506 |       8114 |                           torch._dynamo.variables.torch_function
import time:      6481 |       6648 |           torch._refs.nn.functional
import time:      6459 |       7007 |               torch._dynamo.config
import time:      6405 |       6405 |           torch._refs.fft
import time:      6368 |       6368 |                               sympy.sets.handlers.intersection
import time:      6067 |       6283 |                                   torch._subclasses.fake_impls
import time:      5856 |       5856 |             torch.masked.maskedtensor.unary
import time:      5738 |      12546 |                                               sympy.functions.combinatorial.numbers
import time:      5604 |      18271 |                                 sympy.core.numbers
import time:      5450 |       5450 |                                     torch.distributed.fsdp.api
import time:      5442 |       5442 |                         torch.export.graph_signature
import time:      5193 |      15833 |                 torch._functorch._aot_autograd.runtime_wrappers
import time:      4764 |       6177 |         torch.profiler._memory_profiler

Output sorted by cumulative time:

import time: self [us] | cumulative | imported package
import time:      1740 |    4098456 | torchao
import time:       534 |    2193863 |   torchao.quantization
import time:     30891 |    1893773 |   torch
import time:       123 |    1186976 |     torchao.kernel
import time:       299 |    1186853 |       torchao.kernel.intmm
import time:       506 |    1046467 |         torch._dynamo
import time:    220180 |     999455 |     torchao.quantization.autoquant
import time:      1298 |     830580 |           torch._dynamo.convert_frame
import time:    763864 |     819415 |     torch._C
import time:     12562 |     817979 |             torch._dynamo.symbolic_convert
import time:       298 |     775419 |       torchao.dtypes
import time:      1153 |     772792 |         torchao.dtypes.affine_quantized_tensor_ops
import time:    109062 |     478208 |     torch._meta_registrations
import time:        42 |     471058 |           torchao.dtypes.floatx.float8_layout
import time:       325 |     471016 |             torchao.dtypes.floatx
import time:      1162 |     469622 |               torchao.dtypes.floatx.float8_layout
import time:        30 |     468460 |                 torchao.float8.inference
import time:       268 |     468431 |                   torchao.float8
import time:       528 |     463194 |                     torchao.float8.float8_linear_utils
import time:       448 |     462115 |                       torchao.float8.float8_linear
import time:       206 |     460604 |                         torchao.float8.distributed_utils
import time:      1337 |     437890 |                           torchao.float8.float8_tensor
import time:    436553 |     436553 |                             torchao.float8.float8_utils
import time:     10966 |     401829 |               torch._dynamo.trace_rules
import time:      1147 |     399274 |               torch._dynamo.exc
import time:      8447 |     398127 |                 torch._dynamo.utils
import time:       334 |     384203 |                 torch._dynamo.variables
import time:       876 |     383869 |                   torch._dynamo.variables.base
import time:     16428 |     373652 |                   torch.fx.experimental.symbolic_shapes
import time:       546 |     365670 |       torch._decomp
import time:      7199 |     350781 |                     torch.utils._sympy.functions
import time:      3333 |     344590 |                     torch._dynamo.variables.builder
import time:      1394 |     342723 |                       sympy
import time:     42450 |     291907 |         torch._decomp.decompositions
import time:       439 |     291042 |           torchao.dtypes.affine_quantized_tensor
import time:      3597 |     289228 |             torchao.quantization.quant_primitives
import time:    150263 |     285311 |               torch._inductor.decomposition
import time:    247885 |     249457 |           torch._prims
import time:       563 |     239433 |                       torch._dynamo.side_effects
import time:     83299 |     238870 |                         torch._dynamo.codegen
import time:     32693 |     212937 |           torch._dynamo.polyfills.loader
import time:      3344 |     206276 |     torch.functional
import time:        24 |     202417 |       torch.nn.functional
import time:       288 |     202393 |         torch.nn
import time:       768 |     186719 |           torch.nn.modules
import time:      1098 |     179754 |             torch._functorch.aot_autograd
import time:       435 |     148021 |                         sympy.polys
import time:       353 |     138673 |     torch.quantization
import time:       171 |     137354 |       torch.quantization.fake_quantize
import time:        28 |     137183 |         torch.ao.quantization.fake_quantize
import time:       588 |     137155 |           torch.ao.quantization
import time:      9706 |     136029 |         torchao.kernel.intmm_triton
import time:        26 |     135048 |                 torch.ao.quantization.fx._decomposed
import time:       238 |     135022 |                   torch.ao.quantization.fx
import time:      1581 |     130600 |                           torch._dynamo.variables.functions
import time:      2767 |     125378 |               torch._functorch.partitioners
import time:        24 |     124152 |                             torch.distributed.fsdp._fully_shard
import time:       368 |     124128 |                               torch.distributed.fsdp
import time:      2008 |     123693 |             torch.ao.quantization.pt2e._numeric_debugger
import time:       357 |     120122 |               torch.ao.quantization.pt2e.graph_utils
import time:       505 |     118673 |                 torch.export
import time:       771 |     112702 |                     torch.ao.quantization.fx.prepare
import time:       194 |     111932 |                       torch.ao.quantization.quantizer
import time:    111738 |     111738 |                         torch.ao.quantization.quantizer.quantizer
import time:       259 |     110492 |           triton
import time:       498 |     108531 |                 torch._functorch._activation_checkpointing.graph_info_provider
import time:      6969 |     107833 |                   networkx
import time:      2846 |      97056 |                                 torch.distributed.fsdp._flat_param
import time:       315 |      85810 |                                   torch.distributed.fsdp._fsdp_extensions
import time:      1922 |      85796 |             torch.nn.modules.module
import time:        31 |      84666 |                   torch.fx.passes.infra.pass_base
import time:       207 |      84636 |                     torch.fx.passes.infra
import time:       380 |      83535 |                       torch.fx.passes
import time:      2907 |      83486 |                       torch._dynamo.variables.torch
import time:       228 |      83183 |                           sympy.polys.polyfuncs
import time:       337 |      82955 |                             sympy.polys.specialpolys
import time:      2156 |      82613 |                         mpmath
import time:      1691 |      82388 |               torch.utils._python_dispatch
import time:       523 |      80528 |             torch.nn.modules.linear
import time:      9687 |      79312 |               torch.nn.functional
import time:      1201 |      78049 |                               sympy.polys.rings
import time:       173 |      75427 |                                 sympy.printing.defaults
import time:       612 |      75255 |                                   sympy.printing
import time:     53617 |      73218 |         torch._refs
import time:       412 |      70004 |                         torch.fx.passes.graph_drawer
import time:       278 |      69541 |                                     torch.distributed.fsdp._shard_utils
import time:       675 |      69482 |                           torch.fx.passes.shape_prop
import time:       204 |      69263 |                                       torch.distributed.tensor
import time:       411 |      69060 |                                         torch.distributed.tensor._ops
import time:       188 |      67813 |             triton.runtime
import time:       355 |      67626 |               triton.runtime.autotuner
import time:       909 |      66945 |                 triton.runtime.jit
import time:       900 |      66753 |                 torch._jit_internal
import time:       218 |      66036 |                   triton.runtime.driver
import time:      6978 |      65819 |                     triton.backends
import time:       762 |      60191 |                                     sympy.printing.pycode
import time:      1307 |      59438 |                   torch.distributed.rpc
import time:       501 |      59430 |                                       sympy.printing.codeprinter
import time:        23 |      58930 |                                         sympy.functions.elementary.complexes
import time:        22 |      58907 |                                           sympy.functions.elementary
import time:       641 |      58885 |                                             sympy.functions
import time:      1467 |      57545 |                     networkx.algorithms
import time:       380 |      57234 |                 torch.utils
import time:     57174 |      57174 |                           mpmath.ctx_iv
import time:      1157 |      55663 |                         torch.onnx.utils
import time:      3006 |      55551 |       numpy
import time:       232 |      54895 |                   torch.utils.data
import time:       188 |      54830 |                     torch.distributed.rpc.server_process_global_profiler
import time:       267 |      54643 |                       torch.autograd.profiler_legacy
import time:      1045 |      54408 |                     torch.utils.data.dataloader
import time:      2022 |      54376 |                         torch.autograd
import time:       278 |      51423 |                                           torch.distributed.tensor._ops._conv_ops
import time:       211 |      50727 |                           torch.onnx._internal.diagnostics
import time:       501 |      49081 |                             torch.onnx._internal.diagnostics._diagnostic
import time:       197 |      48581 |                               torch.onnx._internal.diagnostics.infra
import time:      4544 |      45907 |                                 torch.onnx._internal.diagnostics.infra._infra
import time:       378 |      45331 |     torch.nested
import time:       691 |      44954 |       torch.nested._internal.nested_tensor
import time:       318 |      44771 |                       triton.runtime.build
import time:       586 |      43970 |                         setuptools
import time:       230 |      43639 |         torch.nested._internal.nested_int
import time:     14515 |      43410 |           torch.fx.experimental._constant_symnode
import time:       999 |      43305 |                                             torch.distributed.tensor._dtensor_spec
import time:      2460 |      42306 |                                               torch.distributed.tensor.placement_types
import time:       145 |      42012 |             triton.compiler
import time:       522 |      41868 |               triton.compiler.compiler
import time:       258 |      41363 |                                   torch.onnx._internal.diagnostics.infra.formatter
import time:       625 |      41106 |                                     torch.onnx._internal.diagnostics.infra.sarif
import time:      1153 |      40580 |                 triton.compiler.code_generator
import time:       459 |      39132 |                   triton.language
import time:        30 |      39052 |                         sympy.core.cache
import time:       346 |      39022 |                           sympy.core
import time:       409 |      38907 |                           torch._vmap_internals
import time:        28 |      38576 |                             torch._subclasses.meta_utils
import time:       203 |      38549 |                               torch._subclasses
import time:      9745 |      38499 |                             torch.utils._pytree
import time:      3227 |      38278 |                     torch._dynamo.guards
import time:      9892 |      37945 |                                 torch._subclasses.fake_tensor
import time:     36143 |      36143 |                                                 torch.distributed.tensor._collective_utils
import time:       236 |      35946 |                       torch.utils.data.graph_settings
import time:        18 |      35517 |                         torch.utils.data.datapipes.iter.sharding
import time:        16 |      35500 |                           torch.utils.data.datapipes.iter
import time:       175 |      35484 |                             torch.utils.data.datapipes
import time:      2365 |      34309 |               torch._functorch._aot_autograd.autograd_cache
import time:       255 |      32348 |                           sympy.polys.partfrac
import time:       294 |      32093 |                             sympy.matrices
import time:       161 |      28896 |             torch.fx.experimental
import time:      1691 |      28754 |                               importlib.metadata
import time:       224 |      28735 |               torch.fx
import time:      1794 |      28666 |                             sympy.core.expr
import time:       445 |      27855 |         numpy.__config__
import time:        24 |      27410 |           numpy._core._multiarray_umath
import time:       723 |      27387 |             numpy._core
import time:       273 |      26109 |                   torch.export.decomp_utils
import time:       778 |      26105 |                           setuptools.dist
import time:      1567 |      25842 |                 torch.fx._symbolic_trace
import time:        36 |      25837 |                     torch._export.utils
import time:       952 |      25802 |                       torch._export
import time:     25798 |      25798 |                       torch._dynamo.source
import time:      1244 |      25464 | site
import time:       907 |      24864 |     torch.distributions
import time:     22112 |      23282 |                 torchgen.model
import time:       421 |      22871 |                             torch._dispatch.python
import time:       591 |      22721 |         numpy.lib
import time:       203 |      22618 |                               torch.utils.data.datapipes.dataframe
import time:       886 |      22509 |                           torch.distributed._tensor
import time:      1783 |      22340 |                               unittest.mock
import time:       743 |      21252 |   __editable___torchtune_0_0_0_finder
import time:       310 |      21224 |                   torch.fx._lazy_graph_module
import time:       886 |      20915 |                     torch.fx.graph_module
import time:      1856 |      20761 |                               sympy.core.mul
import time:     20698 |      20698 |                     triton.language.standard
import time:      2042 |      20645 |                           sympy.polys.polytools
import time:       881 |      20517 |                                 torch.utils.data.datapipes.dataframe.dataframes
import time:      3886 |      20029 |                       torch.fx.graph
import time:       965 |      19650 |                     torch.ao.quantization.fx.convert
import time:       565 |      19196 |                                   torch.utils.data.datapipes._decorator
import time:        52 |      18579 |                             torch.distributed.checkpoint.metadata
import time:       295 |      18528 |                               torch.distributed.checkpoint
import time:      5604 |      18271 |                                 sympy.core.numbers
import time:       408 |      18270 |                         torch._export.wrappers
import time:       393 |      18155 |                           mpmath.ctx_fp
import time:       262 |      17634 |                         sympy.geometry
import time:      1046 |      17413 |                                     torch.utils.data.datapipes.datapipe
import time:       187 |      17154 |     torch.nn.intrinsic
import time:       632 |      16896 |                             mpmath.ctx_base
import time:       747 |      16858 |                           torch._dynamo.variables.nn_module
import time:      1330 |      16812 |                               sympy.matrices.immutable
import time:        25 |      16503 |                           torch._higher_order_ops.strict_mode
import time:       532 |      16479 |                             torch._higher_order_ops
import time:      3976 |      16112 |                 torch._inductor.codecache
import time:        20 |      15956 |                                     torch.distributed._shard.sharded_tensor.api
import time:        16 |      15936 |                                       torch.distributed._shard.sharded_tensor
import time:       117 |      15921 |                                         torch.distributed._shard
import time:      5193 |      15833 |                 torch._functorch._aot_autograd.runtime_wrappers
import time:       279 |      15804 |                                           torch.distributed._shard.api
import time:       184 |      15751 |                         sympy.concrete
import time:       582 |      15702 |                                 csv
import time:     15278 |      15627 |           torch._inductor.config
import time:       493 |      15567 |                           sympy.concrete.products
import time:       395 |      15482 |                                 sympy.matrices.expressions
import time:       404 |      15448 |     torch._utils_internal
import time:      4500 |      15380 |                       torch.distributed
import time:     13942 |      15210 |                         torch.fx.node
import time:     15120 |      15120 |                                   _csv
import time:       602 |      14949 |           numpy.lib._index_tricks_impl
import time:       209 |      14817 |     torch.masked
import time:       434 |      14798 |                                       dill
import time:       580 |      14798 |     torch.hub
import time:       260 |      14729 |                             sympy.polys.constructor
import time:       827 |      14609 |       torch.masked._ops
import time:       375 |      14470 |                               sympy.polys.domains
import time:       389 |      14391 |                                             torch.distributed._shard.sharded_tensor
import time:      3002 |      14310 |                                 torch.distributed.fsdp.fully_sharded_data_parallel
import time:       833 |      14287 |                             sympy.concrete.summations
import time:      2792 |      14236 |               torch._inductor.output_code
import time:       672 |      14043 |               numpy._core.multiarray
import time:       420 |      14008 |                               sympy.matrices.dense
import time:       394 |      13735 |                                 asyncio
import time:        24 |      13621 |         torch.masked.maskedtensor.core
import time:       233 |      13597 |           torch.masked.maskedtensor
import time:       393 |      13372 |                 numpy._core.overrides
import time:        22 |      13188 |                               sympy.integrals.integrals
import time:       543 |      13186 |                             torch._dynamo.variables.lazy
import time:       190 |      13166 |                                 sympy.integrals
import time:     13130 |      13130 |                 torch._inductor.inductor_prims
import time:       974 |      13038 |                           _distutils_hack.override
import time:       294 |      12884 |                     networkx.utils
import time:     11372 |      12810 |                   numpy._core._multiarray_umath
import time:      1236 |      12644 |                               torch._dynamo.variables.tensor
import time:       155 |      12624 |             numpy.matrixlib
import time:      5738 |      12546 |                                               sympy.functions.combinatorial.numbers
import time:       356 |      12470 |               numpy.matrixlib.defmatrix
import time:       221 |      12396 |                                 torch.distributed.fsdp._fully_shard
import time:       201 |      12114 |                 numpy.linalg
import time:      3242 |      12102 |     torch.cuda
import time:      2196 |      12043 |     pathlib
import time:      2868 |      11991 |                                 sympy.matrices.matrixbase
import time:       733 |      11967 |                           sympy.geometry.point
import time:      1720 |      11779 |                   numpy.linalg._linalg
import time:     11626 |      11735 |                       triton._C.libtriton
import time:       152 |      11686 |       torch.nn.intrinsic.quantized
import time:      1421 |      11322 |                                   sympy.core.power
import time:       315 |      11322 |                     networkx.readwrite
import time:       335 |      11224 |                         sympy.solvers
import time:      1889 |      11094 |                             sympy.geometry.entity
import time:      1155 |      11042 |                                               torch.distributed._shard.sharded_tensor.api
import time:       152 |      10965 |         torch.nn.intrinsic.quantized.dynamic
import time:       255 |      10945 |                           sympy.polys.numberfields
import time:       148 |      10813 |           torch.nn.intrinsic.quantized.dynamic.modules
....
supriyar (Contributor) commented

Thanks for reporting this @felipemello1!

It looks like two of the culprits are torchao.float8.float8_utils (~437 ms self time) and torchao.quantization.autoquant (~220 ms self time, ~999 ms cumulative).

I also noticed that the float8-related modules sit at the end of a very deeply nested import chain for inference.

@vkuzo and @jerryzh168 can we fix this?
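The nesting mentioned above can be recovered from the unsorted `-X importtime` output, where each module is printed after its own imports finish and is indented two spaces per nesting level. The following is only a sketch under that assumption; `import_chain` is a hypothetical helper, and it must be run on the original file rather than the sorted ones, because sorting destroys the ordering it relies on:

```python
import re

LINE_RE = re.compile(r"^import time:\s*\d+\s*\|\s*\d+\s*\| ( *)(\S+)\s*$")


def import_chain(text: str, target: str) -> list[str]:
    """Return the chain of importers leading to `target`, outermost first.

    Relies on importtime printing children before their parent: the parent of
    a line is the next later line with a smaller indentation depth.
    """
    parsed = []
    for line in text.splitlines():
        m = LINE_RE.match(line)
        if m:
            parsed.append((len(m.group(1)) // 2, m.group(2)))

    # Locate the first occurrence of the target module.
    idx = next(i for i, (_, name) in enumerate(parsed) if name == target)
    depth = parsed[idx][0]
    chain = [target]
    for d, name in parsed[idx + 1:]:
        if depth == 0:
            break
        if d < depth:  # first shallower line is the enclosing importer
            chain.append(name)
            depth = d
    return chain[::-1]


# Example usage with the unsorted file from the issue:
# with open("torchao_import_times.txt") as f:
#     print(" -> ".join(import_chain(f.read(), "torchao.float8.float8_utils")))
```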

drisspg added float8 and removed float8 labels Mar 24, 2025