Skip to content

Commit 9ea38b3

Browse files
Caner Gocmen and facebook-github-bot
authored and committed
Add hashing for Topology
Differential Revision: D76004583
1 parent b1aa49c commit 9ea38b3

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

torchrec/distributed/planner/tests/test_types.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
ParameterConstraints,
1919
Shard,
2020
ShardingOption,
21+
Topology,
2122
)
2223
from torchrec.distributed.types import (
2324
BoundsCheckMode,
@@ -214,6 +215,54 @@ def test_module_pooled_mch_ec(self) -> None:
214215
self.assertEqual(sharding_option.is_pooled, False)
215216

216217

218+
class TestTopologyHash(unittest.TestCase):
    """Hash checks for ``Topology``: equal arguments vs. a differing ``world_size``."""

    def test_hash_equality(self) -> None:
        # Build both topologies from one shared argument set so they are
        # guaranteed to be configured identically.
        shared_kwargs = {
            "world_size": 2,
            "compute_device": "cuda",
            "hbm_cap": 1024 * 1024 * 2,
            "local_world_size": 2,
        }
        first = Topology(**shared_kwargs)
        second = Topology(**shared_kwargs)

        # Identically configured instances must produce the same hash.
        first_hash = hash(first)
        second_hash = hash(second)
        self.assertEqual(
            first_hash,
            second_hash,
            "Hashes should be equal for identical Topology instances",
        )

    def test_hash_inequality(self) -> None:
        # Everything except world_size is common to the two instances.
        common_kwargs = {
            "compute_device": "cuda",
            "hbm_cap": 1024 * 1024 * 2,
            "local_world_size": 2,
        }
        first = Topology(world_size=2, **common_kwargs)
        second = Topology(world_size=4, **common_kwargs)  # Different world_size

        # A differing world_size is expected to change the hash.
        first_hash = hash(first)
        second_hash = hash(second)
        self.assertNotEqual(
            first_hash,
            second_hash,
            "Hashes should be different for different Topology instances",
        )
264+
265+
217266
class TestParameterConstraintsHash(unittest.TestCase):
218267

219268
def test_hash_equality(self) -> None:

torchrec/distributed/planner/types.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -248,6 +248,10 @@ def get_bw(
248248

249249

250250
class Topology:
251+
"""
252+
Representation of a network of devices in a cluster.
253+
"""
254+
251255
def __init__(
252256
self,
253257
world_size: int,
@@ -396,6 +400,40 @@ def __repr__(self) -> str:
396400
topology_repr += str(self._comms_bandwidths) + "\n"
397401
return topology_repr
398402

403+
def __hash__(self) -> int:
    """
    Compute a hash value for this Topology instance.

    This allows Topology objects to be used as dictionary keys or in sets.
    The hash combines the per-device HBM and DDR capacities (in device
    order) with the scalar attributes listed in the tuple below.

    Returns:
        int: A hash value for this Topology instance.
    """
    # Per-device storage capacities, kept in device order. Tuples are used
    # instead of frozensets so that neither ordering nor multiplicity is
    # discarded: capacities such as [1, 1, 2] and [1, 2, 2] would otherwise
    # contribute identical values to the hash.
    # NOTE(review): no __eq__ is visible alongside this method; confirm that
    # any equality definition is at least as fine-grained as this tuple.
    hbms = tuple(device.storage.hbm for device in self._devices)
    ddrs = tuple(device.storage.ddr for device in self._devices)

    # Combine all attributes into a hashable tuple
    hashable_tuple = (
        self._world_size,
        self._compute_device,
        hbms,
        ddrs,
        self._local_world_size,
        self._hbm_mem_bw,
        self._ddr_mem_bw,
        self._hbm_to_ddr_mem_bw,
        self._comms_bandwidths.intra_host_bw,
        self._comms_bandwidths.inter_host_bw,
        self._bwd_compute_multiplier,
        self._weighted_feature_bwd_compute_multiplier,
        self._uneven_sharding_perf_multiplier,
    )

    return hash(hashable_tuple)
436+
399437

400438
# ---- INPUT / OUTPUT ----- #
401439

0 commit comments

Comments
 (0)