Fix deepspeed loading #37281
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 2739e64 to 4362cad
@Cyrilvallez, here is a proper fix with meta device so it doesn't require much special treatment in the case of deepspeed zero3 other than in loading the state dict. The patch is against ad60356 in this PR as I had your earlier PR state when I started working on it, so it's not rebased.
Basically it's just this:
and dropping the old
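The meta-device idea referred to above can be illustrated with a toy sketch. This is hypothetical code, not the actual transformers implementation: the `MetaParam` class and its methods are invented for illustration only.

```python
# Hypothetical illustration of the meta-device approach (not the actual
# transformers implementation): parameters start as shape-only placeholders,
# and storage is allocated only when real checkpoint weights arrive.
class MetaParam:
    def __init__(self, shape):
        self.shape = shape  # metadata only; no storage is allocated yet
        self.data = None

    def materialize(self, values):
        # Allocate for real only once the checkpoint values are available.
        assert len(values) == self.shape[0] * self.shape[1]
        self.data = list(values)
        return self

p = MetaParam((2, 3)).materialize(range(6))
assert p.data == [0, 1, 2, 3, 4, 5]
```

The point is that model construction becomes cheap (no memory is touched), and the expensive allocation happens exactly once, at state-dict loading time.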
This needs to be addressed as well #37296, at least for the deepspeed tests. Need to switch to a .safetensors tiny model so that the new code path is properly exercised. For testing of this particular problem let's switch to:
and please rename
So the only thing that got dropped is the special treatment of Do you handle it elsewhere?
Hey @sfc-gh-sbekman, thanks a lot for the deep dive into this! For now, it is a simpler fix to simply use this PR -> loading the state dicts in cpu when deepspeed is activated, then feeding it to When things are a little less crazy, however, I will definitely come back to this one and move deepspeed loading to
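The simpler fix described here — load on CPU when DeepSpeed is active, defer to the meta device otherwise — can be sketched as a small dispatch helper. This is a hypothetical sketch, not the real transformers code path; the function name is invented.

```python
# Hypothetical sketch (not the real transformers code path): under ZeRO-3 the
# checkpoint tensors must be materialized on CPU so DeepSpeed can scatter
# them across ranks, while the normal path can defer allocation via the
# meta device.
def pick_load_device(deepspeed_zero3_enabled: bool) -> str:
    return "cpu" if deepspeed_zero3_enabled else "meta"

assert pick_load_device(True) == "cpu"
assert pick_load_device(False) == "meta"
```

This keeps the DeepSpeed special case confined to a single decision at state-dict loading time, which is the design goal both commenters agree on.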
Happy to merge if this is enough to fix the issue, next release we can do as you say @stas00 and remove the rest!
* Update modeling_utils.py
* Update modeling_utils.py
* fix and remove all imports
* Update modeling_utils.py
* Update modeling_utils.py
* style
* Update modeling_utils.py
While this PR unbroke the most important functionality (thank you Cyril!), by making a new release you're creating a regression in at least 2 aspects:
Additionally, please fix the tests to use the correct tiny models as suggested here #37281 (comment). If you're in a rush to release now, my recommendation is to revert the 2 PRs that led to the regression, make a release, and then replay those PRs with the additional required fixes.
Completely agree cc @Cyrilvallez
Hey @stas00! Thanks for still looking into this! Concerning your points:
Hope this solves the issue! Don't hesitate if you find that something might still be missing! The tests will be taken care of very soon, I promise 😉
My proposed patch was dealing only with the main breakage. Thank you for fixing that remaining test, Cyril. ZeRO-3 (not ZeRO-1 and not ZeRO-2) was never on cpu because the model gets init'ed directly on gpus. The problem was in reading the whole
I applaud you for having the courage to rework so many special cases into the modern meta device approach, Cyril.
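The reason ZeRO-3 needs different loading, as described above, is that each rank owns only a shard of every parameter. A toy sketch of this partitioning (hypothetical, not DeepSpeed's actual API) shows why a full tensor read from the checkpoint must be scattered rather than copied wholesale:

```python
# Toy sketch of ZeRO-3 parameter partitioning (hypothetical, not DeepSpeed's
# API): each rank owns only a contiguous shard of every parameter, so a full
# tensor read from the checkpoint on one process must be scattered across
# ranks at load time.
def shard_for_rank(full, rank, world_size):
    per = len(full) // world_size  # assume the size divides evenly in this toy
    return full[rank * per : (rank + 1) * per]

full = list(range(8))
shards = [shard_for_rank(full, r, 4) for r in range(4)]
assert shards[0] == [0, 1]
assert sum(shards, []) == full  # reassembling the shards recovers the tensor
```

Hence a ZeRO-3 model is never fully resident on CPU (or on any single device): only the checkpoint read itself happens in one place, and the shards then live on the GPUs.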
This reverts commit 84aa13d.
* Update modeling_utils.py
* Update modeling_utils.py
* fix and remove all imports
* Update modeling_utils.py
* Update modeling_utils.py
* style
* Update modeling_utils.py
What does this PR do?