Skip to content

Use the extended jax.experimental.colocated_python.colocated_cpu_devices API #1822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions MaxText/multihost_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,9 @@ def _colocated_cpu_devices(
return colocated_python.colocated_cpu_devices(devices)


def _get_cpu_mesh(mesh: Mesh):
flat_devices = tuple(mesh.devices.flat)
flat_cpu_devices = _colocated_cpu_devices(flat_devices)
cpu_mesh = jax.sharding.Mesh(
np.array(flat_cpu_devices).reshape(mesh.devices.shape), mesh.axis_names, axis_types=mesh.axis_types
)
return cpu_mesh
def _colocated_cpu_mesh(mesh: Mesh) -> Mesh:
"""Returns a CPU mesh that has colocated CPU devices."""
return colocated_python.colocated_cpu_devices(mesh)


class RemoteIterator:
Expand All @@ -179,7 +175,7 @@ class RemoteIterator:
def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape):
self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
self.tpu_devices = jax.local_devices()
self.cpu_mesh = _get_cpu_mesh(global_mesh)
self.cpu_mesh = _colocated_cpu_mesh(global_mesh)
self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
self.dummy_array = jnp.zeros((len(self.cpu_devices)))
Expand Down
Loading