
Commit 530d4bd

refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference-specific tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
1 parent c52ccc4 commit 530d4bd

File tree: 85 files changed (+1268, -1684 lines)


Diff for: llama_stack/apis/inference/inference.py (+56, -7)

```diff
@@ -25,15 +25,64 @@
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
-    SamplingParams,
     StopReason,
     ToolCall,
     ToolDefinition,
+    ToolParamDefinition,
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
+register_schema(ToolCall)
+register_schema(ToolParamDefinition)
+register_schema(ToolDefinition)
+
+
+@json_schema_type
+class GreedySamplingStrategy(BaseModel):
+    type: Literal["greedy"] = "greedy"
+
+
+@json_schema_type
+class TopPSamplingStrategy(BaseModel):
+    type: Literal["top_p"] = "top_p"
+    temperature: Optional[float] = Field(..., gt=0.0)
+    top_p: Optional[float] = 0.95
+
+
+@json_schema_type
+class TopKSamplingStrategy(BaseModel):
+    type: Literal["top_k"] = "top_k"
+    top_k: int = Field(..., ge=1)
+
+
+SamplingStrategy = Annotated[
+    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    Field(discriminator="type"),
+]
+register_schema(SamplingStrategy, name="SamplingStrategy")
+
+
+@json_schema_type
+class SamplingParams(BaseModel):
+    """Sampling parameters.
+
+    :param strategy: The sampling strategy.
+    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+        your prompt plus max_tokens cannot exceed the model's context length.
+    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    :param stop: Up to 4 sequences where the API will stop generating further tokens.
+        The returned text will not contain the stop sequence.
+    """
+
+    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
+
+    max_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop: Optional[List[str]] = None
+
 
 class LogProbConfig(BaseModel):
     """
@@ -48,18 +97,18 @@ class QuantizationType(Enum):
     """Type of model quantization to run inference with.
 
     :cvar bf16: BFloat16 typically this means _no_ quantization
-    :cvar fp8: 8-bit floating point quantization
-    :cvar int4: 4-bit integer quantization
+    :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
+    :cvar int4_mixed: 4-bit integer quantization with mixed precision
     """
 
     bf16 = "bf16"
-    fp8 = "fp8"
-    int4 = "int4"
+    fp8_mixed = "fp8_mixed"
+    int4_mixed = "int4_mixed"
 
 
 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    type: Literal["fp8"] = "fp8"
+    type: Literal["fp8_mixed"] = "fp8_mixed"
 
 
 @json_schema_type
@@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
     :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
     """
 
-    type: Literal["int4"] = "int4"
+    type: Literal["int4_mixed"] = "int4_mixed"
     scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
 
```
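For reviewers who want to see the new schema in action, here is a minimal, self-contained sketch (not part of the diff) that re-declares the models without the llama_stack `@json_schema_type` / `register_schema` wiring, assuming pydantic v2. It shows that the `type` field acts as the discriminator on `SamplingStrategy` and that the quantization config now serializes with the `fp8_mixed` tag.

```python
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field


class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


# The "type" field is the discriminator for the union.
SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]


class SamplingParams(BaseModel):
    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8_mixed"] = "fp8_mixed"


# Greedy is the default strategy; a top-p strategy round-trips via its tag.
params = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9))
print(params.model_dump_json())
# -> {"strategy":{"type":"top_p","temperature":0.7,"top_p":0.9},"max_tokens":0,...}

print(Fp8QuantizationConfig().model_dump())
# -> {'type': 'fp8_mixed'}  (this tag was 'fp8' before this commit)
```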

Diff for: llama_stack/cli/download.py (+1, -1)

```diff
@@ -29,8 +29,8 @@
 from termcolor import cprint
 
 from llama_stack.cli.subcommand import Subcommand
-from llama_stack.models.llama.datatypes import Model
 from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import Model
 
 
 class Download(Subcommand):
```

Diff for: llama_stack/cli/model/describe.py (-11)

```diff
@@ -63,17 +63,6 @@ def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
             ("Model params.json", json.dumps(model.arch_args, indent=4)),
         ]
 
-        if model.recommended_sampling_params is not None:
-            sampling_params = model.recommended_sampling_params.model_dump()
-            for k in ("max_tokens", "repetition_penalty"):
-                del sampling_params[k]
-            rows.append(
-                (
-                    "Recommended sampling params",
-                    json.dumps(sampling_params, indent=4),
-                )
-            )
-
         print_table(
             rows,
             headers,
```

Diff for: llama_stack/cli/model/prompt_format.py (+1, -1)

```diff
@@ -11,7 +11,7 @@
 
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
-from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
+from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
 
 ROOT_DIR = Path(__file__).parent.parent.parent
 
```

Diff for: llama_stack/cli/model/safety_models.py (+2, -3)

```diff
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
 from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
 
 
 class PromptGuardModel(BaseModel):
@@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
     is_instruct_model: bool = False
     quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
     arch_args: Dict[str, Any] = Field(default_factory=dict)
-    recommended_sampling_params: Optional[SamplingParams] = None
 
     def descriptor(self) -> str:
        return self.model_id
```

Diff for: llama_stack/models/llama/checkpoint.py (new file, +164)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
    """Map a new MP rank to a list of old MP ranks given a change in MP size."""
    if new_mp_size % old_mp_size == 0:
        # Read old MP shard and split it into smaller ones
        return [new_mp_rank * old_mp_size // new_mp_size]
    elif old_mp_size % new_mp_size == 0:
        # Merge old MP shards into a single one
        mp_factor = old_mp_size // new_mp_size
        return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
    else:
        raise ValueError(
            f"Either old MP size or new MP size should be a multiple of the other: "
            f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
        )


def maybe_reshard_state_dict(
    ckpt_paths: List[Path],
    n_kv_heads: int,
    moe_num_experts: Optional[int] = None,
    map_location: Union[str, torch.device] = "cpu",
    mmap: bool = True,
) -> Dict[str, torch.Tensor]:
    if str(map_location) == "cpu":
        torch.set_default_tensor_type(torch.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

    ckpt_paths = np.array(sorted(ckpt_paths))

    new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
    old_mp_size = len(ckpt_paths)
    old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
    paths = ckpt_paths[old_mp_ranks]  # type: ignore
    state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

    if new_mp_size == old_mp_size:
        return state_dicts[0]  # type: ignore

    if moe_num_experts is not None:
        state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]

    print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
    return reshard_mp(
        state_dicts,
        size=max(new_mp_size // old_mp_size, 1),
        rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
        repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
    )


_WEIGHT_ROW_KEY = {
    "feed_forward.w2",
    "feed_forward.mlp.fc2",
    "attention.wo",
    "feed_forward.mlp.fc2_weight",
    "feed_forward.w_out_shared_DF.weight",
    "attn.wo.weight",
    "mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}

_WEIGHT_COLUMN_KEY = {
    "output",
    "feed_forward.(w1|w3)",
    "feed_forward.mlp.(fc1|fc3)",
    "feed_forward.mlp.fc1_weight",
    "attention.(wk|wq|wv|wqkv).weight",
    "feed_forward.(w_in_shared_FD|w_swiglu_FD)",
    "attn.(wk|wq|wv).weight",
    "attn.(wk|wq|wv).bias",
    "mlp.c_fc.weight",
    "mlp.c_fc.bias",
    "conv1._linear.weight",
    "tok_embeddings.weight",
    "vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


def reshard_mp(
    state_dicts: List[Dict[str, torch.Tensor]],
    size: int,
    rank: int,
    repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
    """
    Reshard a list of state dicts into a single state dict given a change in MP size.
    If the list has more than one state dict, we concatenate the values of the same
    key across all state dicts. Otherwise, we just slice it for the current MP rank.
    """

    def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
        if len(tensors) > 1:
            return torch.cat(tensors, dim=dim)
        return tensors[0].chunk(size, dim=dim)[rank].clone()

    def process_key(key: str) -> torch.Tensor:
        if row_regex.search(key):
            return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
        elif column_regex.search(key):
            if "w13" in key or "fc1_weight" in key:
                dims = state_dicts[0][key].size()
                values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
                return concat_or_chunk(values, dim=1).flatten(0, 1)
            elif "qkv" in key:
                q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
                kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
                values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
                return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)])  # type: ignore
            elif "wk.weight" in key or "wv.weight" in key:
                # Support MP > #kv_head
                return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
            elif key == "output.bias" or key == "fc.weight":
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
            elif "w_" in key:
                return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
            else:
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
        else:
            return state_dicts[0][key].clone()

    row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
    column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY

    column_regex = re.compile("|".join(column_keys))
    row_regex = re.compile("|".join(row_keys))

    output: Dict[str, torch.Tensor] = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Note: only processes keys in the first state dict.
        # Assumes keys are the same across all state dicts.
        mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
        for future in concurrent.futures.as_completed(mappings):
            output[mappings[future]] = future.result()
    return output


def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
    routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
    routed_regex = re.compile("|".join(routed_keys))
    keys = list(state_dict.keys())
    for key in keys:
        if routed_regex.search(key):
            state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
    return state_dict
```
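To make the resharding behavior concrete, here is a small, hypothetical usage sketch (not part of the commit). The checkpoint directory, shard naming, KV-head count, and expert count below are illustrative assumptions; it also assumes fairscale's model-parallel group has already been initialized in the calling process.

```python
# Hypothetical usage of the new checkpoint helpers; paths, head counts and
# MP sizes below are illustrative assumptions, not values from the commit.
from pathlib import Path

from llama_stack.models.llama.checkpoint import map_mp_rank, maybe_reshard_state_dict

# map_mp_rank tells each new model-parallel rank which old shards it needs:
#   8 old shards -> 4 new ranks: each new rank merges two adjacent old shards.
assert map_mp_rank(old_mp_size=8, new_mp_size=4, new_mp_rank=1) == [2, 3]
#   2 old shards -> 4 new ranks: each new rank reads one old shard and slices it.
assert map_mp_rank(old_mp_size=2, new_mp_size=4, new_mp_rank=3) == [1]

# Inside a process where the model-parallel group is initialized (e.g.
# MODEL_PARALLEL_SIZE=4 as in the test plan), each rank loads only the
# shards it owns and reshards them to the current world size:
ckpt_dir = Path("/path/to/Llama-4-Scout-17B-16E-Instruct")   # hypothetical path
ckpt_paths = sorted(ckpt_dir.glob("consolidated.*.pth"))      # hypothetical naming
state_dict = maybe_reshard_state_dict(
    ckpt_paths,
    n_kv_heads=8,         # illustrative; taken from the model's params.json in practice
    moe_num_experts=16,   # illustrative; only set for MoE checkpoints
)
```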
