Skip to content

[WIP] Add FSDP2 #797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 67 additions & 118 deletions src/llama_recipes/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
from accelerate.utils import is_xpu_available

from llama_recipes.configs import (
fsdp_config as FSDP_CONFIG,
quantization_config as QUANTIZATION_CONFIG,
train_config as TRAIN_CONFIG,
fsdp_config as FsdpConfig,
quantization_config as QuantizationConfig,
train_config as TrainConfig,
)
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils import get_model_and_data_processor
from llama_recipes.utils.config_utils import (
check_fsdp_config,
generate_dataset_config,
Expand All @@ -38,8 +38,6 @@
from llama_recipes.utils.train_utils import (
clear_gpu_cache,
freeze_transformer_layers,
get_policies,
print_model_size,
setup,
setup_environ_flags,
train,
Expand All @@ -53,8 +51,6 @@
AutoProcessor,
AutoTokenizer,
BitsAndBytesConfig,
LlamaForCausalLM,
MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
Expand All @@ -72,9 +68,9 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
"You are trying to use wandb which is not currently installed. "
"Please install it using pip install wandb"
)
from llama_recipes.configs import wandb_config as WANDB_CONFIG
from llama_recipes.configs import wandb_config as WandBConfig

wandb_config = WANDB_CONFIG()
wandb_config = WandBConfig()
update_config(wandb_config, **kwargs)
init_dict = dataclasses.asdict(wandb_config)
run = wandb.init(**init_dict)
Expand All @@ -85,7 +81,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):

def main(**kwargs):
# Update the configuration for the training and sharding process
train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
train_config, fsdp_config = TrainConfig(), FsdpConfig()
update_config((train_config, fsdp_config), **kwargs)
# Set the seeds for reproducibility
if is_xpu_available():
Expand Down Expand Up @@ -116,7 +112,7 @@ def main(**kwargs):
wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

# setting quantization configs
bnb_config = None
quant_config = None
if train_config.quantization:
if type(train_config.quantization) == type(True):
warn(
Expand All @@ -130,70 +126,15 @@ def main(**kwargs):
"8bit quantization is not supported with FSDP, please use 4bit quantization"
)

quant_config = QUANTIZATION_CONFIG()
quant_config = QuantizationConfig()
update_config(quant_config, **kwargs)
bnb_config = quant_config.create_bnb_config(train_config.quantization)

# Load the pre-trained model and setup its configuration
use_cache = False if train_config.enable_fsdp else None
config = AutoConfig.from_pretrained(train_config.model_name)
if config.model_type == "mllama":
is_vision = True
model = MllamaForConditionalGeneration.from_pretrained(
train_config.model_name,
quantization_config=bnb_config,
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
device_map=(
"auto"
if train_config.quantization and not train_config.enable_fsdp
else None
),
torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(
train_config.model_name
if train_config.tokenizer_name is None
else train_config.tokenizer_name
)
processor.tokenizer.padding_side = "right"
model.supports_gradient_checkpointing = True
model.language_model.supports_gradient_checkpointing = True
elif config.model_type == "llama":
is_vision = False
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
quantization_config=bnb_config,
use_cache=use_cache,
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
device_map=(
"auto"
if train_config.quantization and not train_config.enable_fsdp
else None
),
torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
)
model, dataset_processer, is_vision = get_model_and_data_processor(train_config, quant_config)
if is_vision:
tokenizer = dataset_processer.tokenizer
else:
raise ValueError(
f"Model type {config.model_type} is not supported. Please use llama or mllama model."
)
# Load the tokenizer and add special tokens
tokenizer = AutoTokenizer.from_pretrained(
train_config.model_name
if train_config.tokenizer_name is None
else train_config.tokenizer_name
)
if not tokenizer.pad_token_id:
tokenizer.pad_token_id = tokenizer.eos_token_id

# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
print(
"WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
)
model.resize_token_embeddings(len(tokenizer))

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
tokenizer = dataset_processer

# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if (
Expand Down Expand Up @@ -235,71 +176,79 @@ def main(**kwargs):

if not train_config.use_peft and train_config.freeze_layers:
freeze_transformer_layers(model, train_config.num_freeze_layers)

device_id = 0
if is_xpu_available():
device_id = torch.xpu.current_device()
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()
from llama_recipes.utils.fsdp_utils import parallelize_model

# model = FSDP(
#
# cpu_offload=(
# CPUOffload(offload_params=True)
# if fsdp_config.fsdp_cpu_offload
# else None
# ),
# mixed_precision=(
# mixed_precision_policy if not fsdp_config.pure_bf16 else None
# ),
# sharding_strategy=fsdp_config.sharding_strategy,
# device_mesh=hsdp_device_mesh_plan,
# device_id=device_id,
# limit_all_gathers=True,
# sync_module_states=train_config.low_cpu_fsdp,
# param_init_fn=(
# (
# lambda module: module.to_empty(
# device=torch.device("cuda"), recurse=False
# )
# )
# if train_config.low_cpu_fsdp and rank != 0
# else None
# ),
# )

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
# Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
if is_vision:
my_auto_wrapping_policy = fsdp_auto_wrap_policy(
model,
[
MODS = (
MllamaSelfAttentionDecoderLayer,
MllamaSelfAttentionDecoderLayer,
MllamaVisionEncoderLayer,
],
)
sharding_conditions = [
lambda m: any(isinstance(m,n) for n in MODS),
]
else:
# Create the FSDP wrapper for LlamaDecoderLayer in text models
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
device_id = 0
if is_xpu_available():
device_id = torch.xpu.current_device()
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()
model = FSDP(
model,
auto_wrap_policy=(
my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
),
cpu_offload=(
CPUOffload(offload_params=True)
if fsdp_config.fsdp_cpu_offload
else None
),
mixed_precision=(
mixed_precision_policy if not fsdp_config.pure_bf16 else None
),
sharding_strategy=fsdp_config.sharding_strategy,
device_mesh=hsdp_device_mesh_plan,
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=(
(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]

if train_config.use_peft:
sharding_conditions += [
lambda m: (
len(list(m.named_children())) == 0
and getattr(m, "weight", None) is not None
and m.weight.requires_grad
)
if train_config.low_cpu_fsdp and rank != 0
else None
),
]

parallelize_model(
model,
fsdp_config,
device_mesh = hsdp_device_mesh_plan,
sharding_conditions = sharding_conditions,
)

if fsdp_config.fsdp_activation_checkpointing:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
if is_xpu_available():
model.to("xpu:0")
elif torch.cuda.is_available():
model.to("cuda")
dataset_config = generate_dataset_config(train_config, kwargs)
if is_vision:
dataset_processer = processor
else:
dataset_processer = tokenizer


# Load and preprocess the dataset for training and validation

dataset_train = get_preprocessed_dataset(
dataset_processer,
dataset_config,
Expand Down
8 changes: 4 additions & 4 deletions src/llama_recipes/model_checkpointing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from llama_recipes.model_checkpointing.checkpoint_handler import (
load_model_checkpoint,
save_fsdp_model_checkpoint_full,
save_fsdp_checkpoint_full,
save_fsdp_checkpoint_sharded,
save_peft_checkpoint,
save_model_checkpoint,
save_checkpoint,
load_optimizer_checkpoint,
save_optimizer_checkpoint,
save_model_and_optimizer_sharded,
load_model_sharded,
load_fsdp_checkpoint_sharded,
load_sharded_model_single_gpu
)
Loading