Fix deepspeed loading #37281

Merged: 7 commits into main from fix-deepspeed on Apr 5, 2025

Conversation

Cyrilvallez (Member)

What does this PR do?

github-actions bot marked this pull request as draft April 4, 2025 14:07
github-actions bot (Contributor) commented Apr 4, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Cyrilvallez marked this pull request as ready for review April 4, 2025 21:23
sfc-gh-sbekman commented Apr 4, 2025

@Cyrilvallez, here is a proper fix using the meta device, so it doesn't require much special treatment for deepspeed zero3 other than when loading the state dict. The patch is against ad60356 in this PR, since I had your earlier PR state when I started working on it, so it's not rebased.

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 1700301db5..6178729ef4 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -303,54 +303,6 @@ def deepspeed_config():
         return None


-def _load_state_dict_into_zero3_model(model_to_load, state_dict):
-    """
-    Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
-    tensor parallelism API.
-
-    Nearly identical code to PyTorch's `_load_from_state_dict`
-    """
-    # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
-    metadata = getattr(state_dict, "_metadata", None)
-    state_dict = state_dict.copy()
-    if metadata is not None:
-        state_dict._metadata = metadata
-
-    error_msgs = []
-
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-    # so we need to apply the function recursively.
-    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
-        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
-
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-        # Parameters of module and children will start with prefix. We can exit early if there are none in this
-        # state_dict
-        if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
-            import deepspeed
-
-            # In sharded models, each shard has only part of the full state_dict, so only gather
-            # parameters that are in the current state_dict.
-            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
-            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
-            if len(params_to_gather) > 0:
-                # because zero3 puts placeholders in model params, this context
-                # manager gathers (unpartitions) the params of the current layer, then loads from
-                # the state dict and then re-partitions them again
-                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        module._load_from_state_dict(*args)
-
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
-
-    load(model_to_load, state_dict, assign_to_params_buffers=False)
-
-    return error_msgs
-
-
 def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 79823a5d9e..95bbcfad06 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -57,7 +57,6 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
-from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
@@ -703,8 +702,19 @@ def _infer_parameter_dtype(
 def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
     """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
     module, param_type = get_module_from_name(model, param_name)
-    # This will check potential shape mismatch if skipped before
-    module.load_state_dict({param_type: tensor}, strict=False, assign=True)
+    if is_deepspeed_zero3_enabled():
+        import deepspeed
+        module_named_parameters = dict(module.named_parameters(recurse=False))
+        # because zero3 puts placeholders in model params, this context
+        # manager gathers (unpartitions) the params of the current layer, then loads from
+        # the state dict and then re-partitions them again to all ranks
+        with deepspeed.zero.GatheredParameters(module_named_parameters[param_type], modifier_rank=0):
+            if torch.distributed.get_rank() == 0:
+                # assign=False is crucial for deepspeed zero3 init
+                module.load_state_dict({param_type: tensor}, strict=False, assign=False)
+    else:
+        # This will check potential shape mismatch if skipped before
+        module.load_state_dict({param_type: tensor}, strict=False, assign=True)


 @torch.no_grad()
@@ -725,7 +735,7 @@ def _load_state_dict_into_meta_model(
     unexpected_keys: Optional[List[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
 ) -> Tuple[Optional[Dict], Optional[Dict]]:
-    """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
+    """Load parameters from `state_dict` into the model. The parameters of the `state_dict` are on the meta
     device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
     from `shard_file`, which is the actual state dict file on disk.
     This function takes care of correctly casting dtypes, devices, and sharding tensors in case of tensor parallelism.
@@ -4838,10 +4848,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # Fix the key names
             state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

-            if is_deepspeed_zero3_enabled():
-                error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
             # Skip it with fsdp on ranks other than 0
-            elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
+            if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
                 disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                     model_to_load,
                     state_dict,

@sfc-gh-sbekman

Basically it's just this:

def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
    """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
    module, param_type = get_module_from_name(model, param_name)
    if is_deepspeed_zero3_enabled():
        import deepspeed
        module_named_parameters = dict(module.named_parameters(recurse=False))
        # because zero3 puts placeholders in model params, this context
        # manager gathers (unpartitions) the params of the current layer, then loads from
        # the state dict and then re-partitions them again to all ranks
        with deepspeed.zero.GatheredParameters(module_named_parameters[param_type], modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                # assign=False is crucial for deepspeed zero3 init
                module.load_state_dict({param_type: tensor}, strict=False, assign=False)
    else:
        # This will check potential shape mismatch if skipped before
        module.load_state_dict({param_type: tensor}, strict=False, assign=True)

and dropping the old _load_state_dict_into_zero3_model call and its source.

sfc-gh-sbekman commented Apr 4, 2025

This needs to be addressed as well: #37296, at least for the deepspeed tests. We need to switch to a .safetensors tiny model so that the new code path is properly exercised.

For testing this particular problem, let's switch to:

diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 11bea3c3aa..90a4200971 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -73,7 +73,7 @@ set_seed(42)
 DEFAULT_MASTER_PORT = "10999"

 T5_SMALL = "google-t5/t5-small"
-T5_TINY = "patrickvonplaten/t5-tiny-random"
+T5_TINY = "hf-internal-testing/tiny-random-LlamaForCausalLM"
 GPT2_TINY = "sshleifer/tiny-gpt2"
 GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"

and please rename the variable s/T5_TINY/LLAMA_TINY/ across the file. I checked that this model has safetensors.
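
One way to double-check that a candidate tiny model actually ships safetensors (a quick sketch using huggingface_hub; the expected file name is an assumption):

from huggingface_hub import list_repo_files

# list the files in the model repo and look for a safetensors weight file
files = list_repo_files("hf-internal-testing/tiny-random-LlamaForCausalLM")
print([f for f in files if f.endswith(".safetensors")])  # e.g. ['model.safetensors']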

sfc-gh-sbekman commented Apr 4, 2025

So the only thing that got dropped is the special treatment of local_metadata in the state dict, which was originally there for all models (not just those loaded under deepspeed zero).

Do you handle it elsewhere?
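
For context, a minimal sketch of what that local_metadata handling does in plain PyTorch (assuming standard nn.Module semantics; the BatchNorm example is only an illustration, not code from this PR):

import torch.nn as nn

# state_dict() attaches a hidden `_metadata` dict that maps each module prefix to
# entries such as {"version": 2}; load_state_dict forwards the per-module slice as
# `local_metadata` so modules can migrate checkpoints saved by older versions.
model = nn.BatchNorm1d(4)
sd = model.state_dict()
print(getattr(sd, "_metadata", None))  # e.g. OrderedDict([('', {'version': 2})])

# nn.Module.load_state_dict passes metadata.get(prefix[:-1], {}) into each module's
# _load_from_state_dict; BatchNorm, for instance, uses the version to decide how to
# handle num_batches_tracked from older checkpoints.
model.load_state_dict(sd)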

@winglian mentioned this pull request Apr 5, 2025
Cyrilvallez (Member, Author)

Hey @sfc-gh-sbekman, thanks a lot for the deep dive into this! For now, the simpler fix is to just use this PR -> loading the state dicts on CPU when deepspeed is activated and then feeding them to _load_state_dict_into_zero3_model is correct.

When things are a little less crazy, however, I will definitely come back to this and move deepspeed loading into _load_state_dict_into_meta_model as you propose, since it's much simpler to have only one code path for everything (that's the core idea of the big changes we made to the loading logic lately).
But it will need a little more tweaking than what you propose, e.g. we'll need to make sure the model is declared on the meta device as well before loading weights to truly standardize (that's not currently the case with deepspeed) 😉
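
For reference, a minimal sketch of the flow described above (not the literal code in modeling_utils.py; it assumes a non-safetensors shard file and the existing is_deepspeed_zero3_enabled / _load_state_dict_into_zero3_model helpers):

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model

def load_shard(model_to_load, shard_file):
    if is_deepspeed_zero3_enabled():
        # materialize the shard on CPU instead of meta, then let the zero3 helper
        # gather, copy, and re-partition the matching parameters
        state_dict = torch.load(shard_file, map_location="cpu", weights_only=True)
        return _load_state_dict_into_zero3_model(model_to_load, state_dict)
    # otherwise the regular meta-device path (_load_state_dict_into_meta_model) applies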

ArthurZucker (Collaborator) left a comment

Happy to merge if this is enough to fix the issue; next release we can do as you say, @stas00, and remove the rest!

ArthurZucker merged commit 84aa13d into main Apr 5, 2025
18 of 21 checks passed
ArthurZucker deleted the fix-deepspeed branch April 5, 2025 15:05
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Apr 5, 2025
* Update modeling_utils.py

* Update modeling_utils.py

* fix and remove all imports

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Update modeling_utils.py
stas00 (Contributor) commented Apr 5, 2025

While this PR unbroke the most important functionality (thank you Cyril!), by making a new release you're creating a regression in at least two respects:

  1. This test is failing after this merge: tests/deepspeed/test_deepspeed.py::CoreIntegrationDeepSpeed::test_init_zero3_missing_params, and it wasn't failing before
     (6 other tests have apparently been failing for who knows how long).
  2. You removed critical functionality: just putting a model on CPU is a problem for low-memory situations like Google Colab. Why remove code that worked if you're not replacing it with an equivalent?

Additionally, please fix the tests to use the correct tiny models as suggested here: #37281 (comment).
This is absolutely critical to avoid this situation in the future. It was pure luck that I needed to work with transformers@main and noticed the problem. And hopefully someone on the team will pay attention to failing nightly tests. Nightly tests aren't unimportant tests; they are just heavy and thus don't run in the PR's CI.

If you're in a rush to release now, my recommendation is to revert the 2 PRs that led to the regression, make a release, and then replay those PRs with the additional required fixes.

ArthurZucker (Collaborator)

Completely agree cc @Cyrilvallez

Cyrilvallez (Member, Author) commented Apr 5, 2025

Hey @stas00! Thanks for still looking into this!

Concerning your points:

  1. Indeed, it still fails. However, your proposed changes do not resolve it. The correct fix is in Fix deepspeed loading (part 2) #37306 (see https://github.com/huggingface/transformers/pull/36963/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL3790-L3796: the model was never on meta with deepspeed before, as low_cpu_mem_usage was always False with deepspeed).
  2. There is no regression in terms of memory usage. See https://github.com/huggingface/transformers/pull/36963/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL4248-R4181 and https://github.com/huggingface/transformers/pull/36963/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL4918-R4846 -> the state dicts were always loaded on CPU.

Hope this solves the issue! Don't hesitate to reach out if you find that something might still be missing! The tests will be taken care of very soon, I promise 😉

stas00 (Contributor) commented Apr 5, 2025

My proposed patch was dealing only with the main breakage. Thank you for fixing that remaining test, Cyril.

ZeRO-3 (not ZeRO-1 and not ZeRO-2) was never on CPU, because the model gets init'ed directly on the GPUs. The problem was in reading the whole state_dict into CPU memory, and that's why low_cpu_mem_usage was added; later we solved that by sharding the weights so that one didn't need to load the whole dict at once. After that I'm no longer sure what was done, as it's been a few years since I was involved.
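
A rough sketch of that point (assuming a working deepspeed install and an initialized distributed setup; the printed shapes are typical for ZeRO-3, not guaranteed across versions):

import deepspeed
import torch.nn as nn

ds_config = {"zero_optimization": {"stage": 3}, "train_micro_batch_size_per_gpu": 1}

# under zero.Init every parameter is partitioned across ranks as soon as its module
# is constructed, so the full model never has to sit on CPU (or on any single GPU)
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    layer = nn.Linear(4096, 4096)

print(layer.weight.shape)     # typically torch.Size([0]): the data is a placeholder
print(layer.weight.ds_shape)  # torch.Size([4096, 4096]): the logical, full shape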

I applaud you for having the courage to rework so many special cases into the modern meta device approach, Cyril.

winglian added a commit to winglian/transformers that referenced this pull request Apr 6, 2025
@winglian mentioned this pull request Apr 6, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* Update modeling_utils.py

* Update modeling_utils.py

* fix and remove all imports

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Update modeling_utils.py