
Commit a0ac63c

Skylion007 authored and pytorchmergebot committed
[BE]: Apply ruff PERF403 to use dict comprehensions more often (pytorch#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#149257
Approved by: https://github.com/jansel
1 parent 811f587 commit a0ac63c

23 files changed, +48 -81 lines changed
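For readers unfamiliar with the rule: ruff's PERF403 flags loops that populate a dict one key at a time and suggests building the mapping directly. Below is a minimal, self-contained sketch of the rewrite patterns this commit applies (the names are illustrative, not taken from the PyTorch sources); the actual per-file diffs follow.

```python
# Illustrative sketch of the PERF403 rewrites applied in this commit
# (hypothetical names; see the per-file diffs below for the real changes).

pairs = [("a", 1), ("b", 2), ("c", 3)]

# Before: building a dict one assignment at a time.
filtered = {}
for key, value in pairs:
    if value > 1:
        filtered[key] = value

# After: a dict comprehension expresses the same mapping directly.
filtered = {key: value for key, value in pairs if value > 1}

# When there is no filter or transformation, dict() / .update() on an
# iterable of (key, value) pairs is the idiomatic equivalent.
copied = dict(pairs)
copied.update([("d", 4)])

assert filtered == {"b": 2, "c": 3}
assert copied == {"a": 1, "b": 2, "c": 3, "d": 4}
```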

.github/scripts/trymerge.py

+3-4
@@ -819,10 +819,9 @@ def _get_reviews(self) -> list[tuple[str, str]]:
                 cursor=info["reviews"]["pageInfo"]["startCursor"],
             )
             info = rc["data"]["repository"]["pullRequest"]
-        reviews = {}
-        for author, state in self._reviews:
-            if state != "COMMENTED":
-                reviews[author] = state
+        reviews = {
+            author: state for author, state in self._reviews if state != "COMMENTED"
+        }
         return list(reviews.items())

     def get_approved_by(self) -> list[str]:

benchmarks/operator_benchmark/benchmark_core.py

+1-2
@@ -296,8 +296,7 @@ def split(s):
             (key.strip(), value.strip())
             for key, value in map(lambda str: str.split(":"), key_vals)  # noqa: C417
         ]  # ['M: (32, 16)', 'ZPB: 2'] -> [('M', '(32, 16)'), ('ZPB', '2')]
-        for key, value in key_vals:
-            out[key] = value
+        out.update(key_vals)

         return out

pyproject.toml

-1
@@ -72,7 +72,6 @@ ignore = [
     # these ignores are from ruff PERF; please fix!
     "PERF203",
     "PERF401",
-    "PERF403",
     # these ignores are from PYI; please fix!
     "PYI024",
     "PYI036",

test/inductor/test_custom_post_grad_passes.py

+1-3
@@ -219,9 +219,7 @@ def merge_mm_shared_rhs(graph: fx.Graph):
     for m in matmuls:
         rhs_vals[m.args[1]].add(m)

-    order = {}
-    for idx, n in enumerate(graph.nodes):
-        order[n] = idx
+    order = {n: idx for idx, n in enumerate(graph.nodes)}

     for rhs, matmuls in rhs_vals.items():
         if len(matmuls) == 1:

test/test_fx.py

+1-3
@@ -2324,9 +2324,7 @@ def test_deepcopy_recursion_depth(self):

         copied_graph = copy.deepcopy(g)

-        val_map = {}
-        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
-            val_map[orig_node] = new_node
+        val_map = dict(zip(g.nodes, copied_graph.nodes))

         for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
             orig_users = set(orig_node.users.keys())

test/test_jit.py

+1-2
@@ -1761,8 +1761,7 @@ def doit(x, y):
         for node in g.nodes():
             n_ = g2.createClone(node, lambda x: g_to_g2[x])
             g2.appendNode(n_)
-            for o, no in zip(node.outputs(), n_.outputs()):
-                g_to_g2[o] = no
+            g_to_g2.update(zip(node.outputs(), n_.outputs()))

         for node in g.outputs():
             g2.registerOutput(g_to_g2[node])

test/test_linalg.py

+1-3
@@ -91,9 +91,7 @@ def tunableop_matmul(device, dtype):

 def get_tunableop_validators():
     assert len(torch.cuda.tunable.get_validators()) > 0
-    validators = {}
-    for key, value in torch.cuda.tunable.get_validators():
-        validators[key] = value
+    validators = dict(torch.cuda.tunable.get_validators())
     return validators

 class TestLinalg(TestCase):

torch/_export/serde/serialize.py

+1-3
@@ -1564,9 +1564,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
         # TODO: Directly serialize exported_program.constants once
         # CustomClassHolders get stored in the ExportedProgram rather than in
         # the graph
-        constants: dict[str, Any] = {}
-        for n, c in gm_serializer.custom_objs.items():
-            constants[n] = c
+        constants: dict[str, Any] = gm_serializer.custom_objs.copy()
         for n, t in exported_program.constants.items():
             assert n not in constants
             constants[n] = t

torch/_functorch/partitioners.py

+1-3
@@ -580,9 +580,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
     for node in gm.graph.find_nodes(op="placeholder"):
         env[node] = new_graph.node_copy(node, lambda x: env[x])

-    order = {}
-    for idx, node in enumerate(gm.graph.nodes):
-        order[node] = idx
+    order = {node: idx for idx, node in enumerate(gm.graph.nodes)}

     def insert_node_in_graph(node):
         cur_nodes = [node]

torch/_functorch/top_operators_github_usage.py

+2-5
@@ -625,8 +625,5 @@ def get_nn_functional_top_list():
     return top_nn_functional_


-usage_count = {}
-for k, v in get_nn_functional_top_list():
-    usage_count[k] = v
-for k, v in top_torch:
-    usage_count[k] = v
+usage_count = dict(get_nn_functional_top_list())
+usage_count.update(top_torch)

torch/_inductor/autotune_process.py

+1-4
@@ -398,12 +398,9 @@ def benchmark(
         assert self.processes is not None, "Tuning process pool is not initialized"
         assert self.executor is not None

-        results = {}
-
         # Use a ThreadExecutorPool to spread the work across the subprocesses and
         # to grab subprocesses as soon as they're free.
-        for choice, result in zip(choices, self.executor.map(self.target, choices)):
-            results[choice] = result
+        results = dict(zip(choices, self.executor.map(self.target, choices)))

         return results

torch/_inductor/memory.py

+3-3
@@ -267,9 +267,9 @@ class BufferInfo:

     # get the execution step of each node, this will be used to determine
     # the end_step of buffers
-    node_to_step: dict[BaseSchedulerNode, int] = dict()
-    for step, node in enumerate(nodes):
-        node_to_step[node] = step
+    node_to_step: dict[BaseSchedulerNode, int] = {
+        node: step for step, node in enumerate(nodes)
+    }

     # get buffers' size and liveliness information
     buf_info_list: list[BufferInfo] = []

torch/_inductor/runtime/triton_heuristics.py

+1-2
@@ -154,8 +154,7 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
         else:
             call_kwargs[k] = v
     if not triton_version_uses_attrs_dict():
-        for k, v in launcher.config.kwargs.items():
-            call_kwargs[k] = v
+        call_kwargs.update(launcher.config.kwargs)
     call_kwargs["num_warps"] = launcher.config.num_warps
     call_kwargs["num_stages"] = launcher.config.num_stages
     args_str = [*call_args]

torch/_inductor/utils.py

+2-5
@@ -340,8 +340,7 @@ def _type_of(key: Optional[torch.dtype]) -> str:
         "uint64": "u64",
     }
     # reinterpret can create triton type
-    for v in list(tys.values()):
-        tys[v] = v
+    tys.update({v: v for v in list(tys.values())})
     return key if isinstance(key, str) else f"*{tys[dtype_str]}"


@@ -635,9 +634,7 @@ def get_kernel_metadata(
     single_graph = inductor_nodes[0].graph
     # create a map of idx -> node and cache it
     if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
-        node_to_idx_map = {}
-        for idx, n in enumerate(single_graph.nodes):
-            node_to_idx_map[n] = idx
+        node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)}
         single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map  # type: ignore[attr-defined]
     inductor_nodes.sort(
         key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]  # type: ignore[attr-defined]

torch/ao/ns/_numeric_suite.py

+1-3
@@ -368,9 +368,7 @@ def prepare_model_with_stubs(
         "quantization_api._numeric_suite.prepare_model_with_stubs"
     )

-    float_module_children = {}
-    for name, mod in float_module.named_children():
-        float_module_children[name] = mod
+    float_module_children = dict(float_module.named_children())

     reassign = {}
     for name, mod in q_module.named_children():

torch/ao/quantization/_correct_bias.py

+5-4
@@ -119,10 +119,11 @@ def bias_correction(
         float_model, quantized_model, _supported_modules, MeanShadowLogger
     )

-    uncorrected_modules = {}
-    for name, submodule in quantized_model.named_modules():
-        if type(submodule) in target_modules:
-            uncorrected_modules[name] = submodule
+    uncorrected_modules = {
+        name: submodule
+        for name, submodule in quantized_model.named_modules()
+        if type(submodule) in target_modules
+    }

     for uncorrected_module in uncorrected_modules:
         quantized_submodule = get_module(quantized_model, uncorrected_module)

torch/ao/quantization/fx/qconfig_mapping_utils.py

+2-4
@@ -376,10 +376,8 @@ def _get_flattened_qconfig_dict(
     flattened: dict[Union[Callable, str], QConfigAny] = {
         "": qconfig_mapping.global_qconfig
     }
-    for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
-        flattened[obj] = qconfig
-    for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
-        flattened[obj] = qconfig
+    flattened.update(qconfig_mapping.object_type_qconfigs)
+    flattened.update(qconfig_mapping.module_name_qconfigs)  # type: ignore[arg-type]
     return flattened

torch/distributed/checkpoint/state_dict.py

+1-2
@@ -596,8 +596,7 @@ def _load_model_state_dict(
         )
     elif info.full_state_dict:
         _distribute_state_dict(state_dict, local_state_dict, device=devices.pop())
-        for fqn, local_state in local_state_dict.items():
-            state_dict[fqn] = local_state
+        state_dict.update(local_state_dict)

     with info.fsdp_context():
         return cast(

torch/distributed/fsdp/_optim_utils.py

+8-9
@@ -314,11 +314,9 @@ def _unflatten_communicated_optim_state(
             unflat_state_param[state_name] = optim_state

         # Add zero-dimension tensor state: take the target rank's value
-        for state_name, zero_dim_tensor in sorted_items(zero_dim_tensor_state):
-            unflat_state_param[state_name] = zero_dim_tensor
+        unflat_state_param.update(sorted_items(zero_dim_tensor_state))
         # Add non-tensor state: take the target rank's value
-        for state_name, non_tensor in sorted_items(non_tensor_state):
-            unflat_state_param[state_name] = non_tensor
+        unflat_state_param.update(sorted_items(non_tensor_state))
         unflat_param_state.append(unflat_state_param)
     return unflat_param_state


@@ -1827,11 +1825,12 @@ def _convert_state_with_flat_params(
             )
             if to_save:
                 assert len(unflat_state) == len(optim_state_key.unflat_param_names)
-                for unflat_param_name, unflat_param_state in zip(
-                    optim_state_key.unflat_param_names,
-                    unflat_state,
-                ):
-                    fsdp_osd_state[unflat_param_name] = unflat_param_state
+                fsdp_osd_state.update(
+                    zip(
+                        optim_state_key.unflat_param_names,
+                        unflat_state,
+                    )
+                )
         elif to_save:
             assert len(optim_state_key.unflat_param_names) == 1
             unflat_param_name = optim_state_key.unflat_param_names[0]

torch/distributed/tensor/_ops/_common_rules.py

+2-5
@@ -265,7 +265,6 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
     # check if we replace the all inputs dim char with singleton dimension,
     # if we replace all inputs, we also need to replace the output dimension.
     for output_dim_idx in range(len(out_dimchars)):
-        out_dimchar = out_dimchars[output_dim_idx]
         if singleton_counter[output_dim_idx] == len(input_specs):
             out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx)


@@ -274,12 +273,10 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
     enforce_sharding: dict[str, int] = {}
     if _is_inplace_op(op_schema.op):
         # inplace op should keep the input sharding it writes to
-        for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map):
-            enforce_sharding[out_dimchar] = mesh_dim
+        enforce_sharding.update(zip(out_dimchars, input_specs[0].dim_map))
     elif _is_out_variant_op(op_schema.op):
         out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"])
-        for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map):
-            enforce_sharding[out_dimchar] = mesh_dim
+        enforce_sharding.update(zip(out_dimchars, out_spec.dim_map))

     return einop_rule(
         fmt,

torch/onnx/_internal/_exporter_legacy.py

+8-7
@@ -596,13 +596,14 @@ def export(self) -> _onnx_program.ONNXProgram:
         # not valid.
         # Concrete data is expected to be filled for those initializers later during `ONNXProgram.save`.
         if self.options.fake_context is not None:
-            initializers_with_real_tensors: dict[str, torch.Tensor] = {}
-            for (
-                initializer_name,
-                initializer,
-            ) in onnxscript_graph.initializers.items():
-                if not isinstance(initializer, torch._subclasses.FakeTensor):
-                    initializers_with_real_tensors[initializer_name] = initializer
+            initializers_with_real_tensors: dict[str, torch.Tensor] = {
+                initializer_name: initializer
+                for (
+                    initializer_name,
+                    initializer,
+                ) in onnxscript_graph.initializers.items()
+                if not isinstance(initializer, torch._subclasses.FakeTensor)
+            }
             onnxscript_graph.initializers = initializers_with_real_tensors

         # Export TorchScript graph to ONNX ModelProto.

torch/testing/_internal/common_cuda.py

+1-2
@@ -217,8 +217,7 @@ def wrapper(f):

     @functools.wraps(f)
     def wrapped(*args, **kwargs):
-        for k, v in zip(arg_names, args):
-            kwargs[k] = v
+        kwargs.update(zip(arg_names, args))
         cond = torch.cuda.is_tf32_supported()
         if 'device' in kwargs:
             cond = cond and (torch.device(kwargs['device']).type == 'cuda')

torch/testing/_internal/common_mkldnn.py

+1-2
@@ -60,8 +60,7 @@ def wrapper(f):

     @functools.wraps(f)
     def wrapped(*args, **kwargs):
-        for k, v in zip(arg_names, args):
-            kwargs[k] = v
+        kwargs.update(zip(arg_names, args))
         cond = bf32_is_not_fp32()
         if "device" in kwargs:
             cond = cond and (torch.device(kwargs["device"]).type == "cpu")
