Skip to content

LLama3 MVP #802

Merged: 11 commits merged into karpathy:llama3 from ngc92/llama3-dev on May 1, 2025

Conversation

@ngc92 ngc92 (Contributor) commented Apr 13, 2025

This implements the minimum set of changes needed to get a functional Llama 3 implementation.
Key fixes compared to the current llama3 branch:

  • ignore bias terms during the backward pass
  • ensure the learning rate matches the PyTorch reference
  • fix gradient checking (Fix gradient tests #801)

When setting up CI we ran into the problem that even the 1B model is too large to fit in memory for training. I've tried two different approaches:
a) Run the reference code with --device=cpu. This works, but the numerical differences are quite prominent; tolerances would need to be increased by roughly 10x for fp32 mode.
b) Use torchao's CPUOffloadOptimizer (see the sketch below). This works, but it introduces another dependency on the Python side. It also changes the numerics, but only for the optimizer, so it doesn't break the gradient step. EDIT: CPUOffloadOptimizer is not compatible with gradient clipping, so I had to add a terrible hack :( The loss values set as targets in test_llama3.cu are generated from the .py file, but I ran that on a larger GPU, so those targets are without offloading.
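
For reference, a minimal sketch of how option (b) can be wired up on the Python side. It assumes a recent torchao release (the class has lived under both torchao.optim and torchao.prototype.low_bit_optim); the toy model and hyperparameters are placeholders, not the actual reference script:

```python
# Hedged sketch (not the actual reference script): using torchao's
# CPUOffloadOptimizer so optimizer state (and optionally gradients) live in
# host memory instead of on the GPU.
import torch
import torch.nn as nn

try:
    from torchao.optim import CPUOffloadOptimizer                    # newer torchao
except ImportError:
    from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # older torchao

# Toy stand-in for the Llama 3 reference model.
model = nn.Sequential(nn.Linear(2048, 2048), nn.GELU(), nn.Linear(2048, 2048)).cuda()

# Optimizer state stays on the CPU; offload_gradients=True additionally moves
# gradients off the GPU during backward, which is what clashes with the usual
# GPU-side gradient clipping mentioned above.
opt = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,
    offload_gradients=True,
    lr=1e-4,
    weight_decay=0.0,
)

x = torch.randn(8, 2048, device="cuda")
loss = model(x).square().mean()
loss.backward()
opt.step()
```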

What is still missing:

  • cuDNN support
  • tied weights (i.e., actually reproducing the 1B/3B models; see the sketch after this list)
  • evaluation in the C training loop
  • generating models from scratch
  • I haven't checked the non-HF code path for the Python reference
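
On the tied-weights item: the 1B/3B checkpoints share the token-embedding matrix with the output projection. A minimal PyTorch sketch of what that means (module and parameter names here are illustrative, not taken from the reference code):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy stand-in; only the embedding/head weight tying is the point here."""
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix, so a
        # checkpoint stores only one (vocab_size, dim) tensor for both roles.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.tok_emb(idx))

m = TinyLM(vocab_size=128256, dim=2048)
assert m.lm_head.weight.data_ptr() == m.tok_emb.weight.data_ptr()
```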

Q: Do we really want to store the hidden dimension size as a floating-point factor, followed by rounding to a multiple of 1024, instead of just specifying hidden_dim explicitly? The factor would make sense if it were the same across all models, but it differs.
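
For context, a sketch of the computation this question refers to, following the shape of Meta's reference implementation (the names ffn_dim_multiplier and multiple_of come from that code; the example values are only illustrative):

```python
def llama_ffn_hidden_dim(dim: int, ffn_dim_multiplier: float | None, multiple_of: int) -> int:
    """FFN hidden size as derived in Meta's Llama reference code."""
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)            # SwiGLU keeps ~2/3 of the 4x width
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# e.g. dim=4096, ffn_dim_multiplier=1.3, multiple_of=1024 gives 14336 (the 8B hidden size)
print(llama_ffn_hidden_dim(4096, 1.3, 1024))  # 14336
```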

I do have code for all of these, but I'd like to keep the PRs at a manageable size, so those changes aren't included here.

@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch 2 times, most recently from 0c647e5 to ce772eb on April 13, 2025 16:27
@ngc92 ngc92 changed the base branch from master to llama3 April 13, 2025 18:08
@ngc92 ngc92 changed the base branch from llama3 to master April 13, 2025 18:33
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch 8 times, most recently from b5bf445 to a94e63e on April 13, 2025 19:46
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch 7 times, most recently from b220db7 to 4c44fcb on April 13, 2025 20:52
@ngc92 ngc92 changed the title from Ngc92/llama3 dev to WIP: Ngc92/llama3 dev Apr 13, 2025
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch 3 times, most recently from ac89723 to 805e271 on April 14, 2025 09:00
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch from f16202e to ce3a145 on April 14, 2025 09:31
@ngc92 ngc92 changed the base branch from master to llama3 April 14, 2025 09:33
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch from ce3a145 to 43d6842 on April 14, 2025 09:33
@ngc92 ngc92 changed the title from WIP: Ngc92/llama3 dev to WIP: LLama3 MVP Apr 14, 2025
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch from 43d6842 to 5b92829 on April 14, 2025 09:48
@ngc92 ngc92 changed the title from WIP: LLama3 MVP to LLama3 MVP Apr 14, 2025
@ngc92 ngc92 force-pushed the ngc92/llama3-dev branch from b6492d7 to 9c52a95 on April 14, 2025 15:06
@karpathy karpathy merged commit 49cef1d into karpathy:llama3 May 1, 2025
3 checks passed