
Commit 9283e50

Authored by mergennachin, committed by facebook-github-bot
Skeleton for GGUF conversion (#2018)

Summary: Starting a skeleton implementation.
- Only llama for now. New architectures will be added inside gguf_util/converters/.
- Only fp32 for now. Quantization will be figured out later.
- Reuses the existing llama code in examples to reduce duplication. For other architectures, there won't be much duplication.
- Currently converts to PyTorch, and then goes through export, to_edge, to_executorch. But that's an implementation detail.

Pull Request resolved: #2018

Test Plan:
`python extension/gguf_util/convert_main.py --gguf_file="/Users/mnachin/models_gguf/OpenHermes-2.5-Mistral-7B-fp16.gguf"`

Reviewed By: shoumikhin
Differential Revision: D53982833
Pulled By: mergennachin
fbshipit-source-id: 5402c0de3e729e434763a5d6a390448603e77429
1 parent a6d71e2 commit 9283e50

File tree: 6 files changed, +317 -0 lines changed

extension/gguf_util/README.md (+6)

# Summary

This is an experimental feature to convert a model in [GGUF format](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) into a PTE file, which can be executed directly on ExecuTorch.

## Usage

python executorch/extension/gguf_util/convert_main.py --gguf_file=<path_to_gguf_file> --pte_file=<output_pte_file>

extension/gguf_util/convert_main.py (+53)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from executorch.extension.gguf_util.converter import convert_to_pte
from executorch.extension.gguf_util.load_gguf import load_file


def save_pte_program(_, pte_file) -> None:
    # TODO (mnachin): Save the PTE program to a file
    print(f"Saving PTE program to {pte_file}")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gguf_file",
        type=str,
        help="The GGUF file to load.",
    )
    parser.add_argument(
        "--pte_file",
        type=str,
        help="The path to save the PTE file.",
    )
    args = parser.parse_args()

    # Step 1: Load the GGUF file
    gguf_model_args, gguf_weights = load_file(args.gguf_file)

    # Step 2: Convert the GGUF model to PTE
    # Currently, under the hood, it first converts the GGUF model
    # to a PyTorch model (nn.Module), then exports to ET.
    #
    # NOTE: In the future, it may make sense to refactor the GGUF-to-nn.Module conversion
    # into its own package that can be shared between ExecuTorch and PyTorch core. There may
    # be a need to load a GGUF file directly into PyTorch core and use
    # torch.compile/AOTInductor to accelerate on server, without ever touching ExecuTorch.
    #
    # TODO(mnachin): Add a knob to delegate to various backends.
    pte_program = convert_to_pte(gguf_model_args, gguf_weights)

    # Step 3: Save the PTE program so that
    # it can be used by the ExecuTorch runtime
    save_pte_program(pte_program, args.pte_file)


if __name__ == "__main__":
    main()
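
Note that save_pte_program is still a stub in this commit. A minimal sketch of how it could be completed, assuming convert_to_pte eventually returns the serialized ExecuTorch program as bytes (that return type is an assumption here, since the export step is also still a TODO):

def save_pte_program(pte_program: bytes, pte_file: str) -> None:
    # Assumption: pte_program is the serialized ExecuTorch program buffer.
    with open(pte_file, "wb") as f:
        f.write(pte_program)
    print(f"Saved PTE program to {pte_file}")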

extension/gguf_util/converter.py (+27)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.extension.gguf_util.load_gguf import GGUFModelArgs, GGUFWeights


def convert_to_pte(model_args: GGUFModelArgs, weights: GGUFWeights) -> None:
    """Convert a GGUF model into a PTE file, an ExecuTorch program.

    Args:
        model_args: The arguments for the GGUF model.
        weights: The weights of the GGUF model.
    """

    # Switch statement based on the architecture enum.
    # Each enum has its own converter function.
    if model_args.arch == "llama":
        from executorch.extension.gguf_util.converters.llama_converter import (
            convert_to_pte as llama_convert_to_pte,
        )

        return llama_convert_to_pte(model_args, weights)
    else:
        raise NotImplementedError("Unsupported architecture.")
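
Per the summary, additional architectures are meant to live in their own modules under gguf_util/converters/. A hypothetical skeleton for such a module (the mistral name and file are illustrative only, not part of this commit), showing the contract the dispatch above expects, namely a convert_to_pte(model_args, weights) function:

# Hypothetical converters/mistral_converter.py (illustrative, not in this commit).
from executorch.extension.gguf_util.load_gguf import GGUFModelArgs, GGUFWeights


def convert_to_pte(gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights) -> bytes:
    # A real converter would build the architecture-specific nn.Module,
    # load the GGUF weights into it, and export it, as llama_converter does.
    raise NotImplementedError("mistral conversion is not implemented yet.")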

extension/gguf_util/converters/llama_converter.py (+121)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Mapping

import torch
import torch.nn as nn
from executorch.examples.models.llama2.llama_transformer import (
    ModelArgs as LlamaModelArgs,
    Transformer as LlamaTransformer,
)
from executorch.extension.gguf_util.load_gguf import GGUFModelArgs, GGUFWeights


def _create_pt_model(
    gguf_model_args: GGUFModelArgs,
) -> nn.Module:
    llama_model_args = LlamaModelArgs(
        dim=gguf_model_args.embedding_length,
        n_layers=gguf_model_args.block_count,
        n_heads=gguf_model_args.attention.head_count,
        n_kv_heads=gguf_model_args.attention.head_count_kv,
        vocab_size=gguf_model_args.vocab_size,
        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
        hidden_dim=gguf_model_args.feed_forward_length,
        rope_freq_base=gguf_model_args.rope.freq_base,
    )
    pt_model = LlamaTransformer(llama_model_args)
    pt_model.eval()
    return pt_model


# Mapping from GGUF tensor-name fragments to the parameter names used by the
# llama Transformer in executorch/examples/models/llama2.
_name_replacements = [
    ("blk", "layers"),
    ("token_embd", "tok_embeddings"),
    ("attn_q", "attention.wq"),
    ("attn_k", "attention.wk"),
    ("attn_v", "attention.wv"),
    ("attn_output", "attention.wo"),
    ("attn_norm", "attention_norm"),
    ("output_norm.weight", "norm.weight"),
    ("ffn_down", "feed_forward.w2"),
    ("ffn_gate", "feed_forward.w1"),
    ("ffn_up", "feed_forward.w3"),
]


def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
    result = copy.deepcopy(gguf_name)
    for gguf_string, replacement in _name_replacements:
        result = result.replace(gguf_string, replacement)
    return result


def _convert_to_state_dict(gguf_weights: GGUFWeights) -> Mapping[str, Any]:
    state_dict = {}
    for tensor in gguf_weights.tensors:
        gguf_tensor_name = tensor.name
        nn_tensor_name = _convert_gguf_tensor_name_to_llama_nn(gguf_tensor_name)
        new_tensor = tensor.data.reshape(tensor.shape).transpose()
        state_dict[nn_tensor_name] = torch.from_numpy(new_tensor)

    return state_dict


def _load_weights_into_nn(
    pt_model: nn.Module, gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights
):
    state_dict: Mapping[str, Any] = _convert_to_state_dict(gguf_weights)

    # We need to fake-initialize the mask to match llama_transformer.py
    for id in range(gguf_model_args.block_count):
        mask_name = f"layers.{id}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask

    pt_model.load_state_dict(state_dict)
    return


def _create_pte_program(pt_model: nn.Module) -> bytes:
    # TODO (mnachin): Export
    return


def convert_to_pte(gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights) -> bytes:
    """Convert a GGUF llama model into an ExecuTorch program.

    Args:
        gguf_model_args: The arguments for the GGUF model.
        gguf_weights: The weights of the GGUF model.
    """

    assert (
        gguf_model_args.arch == "llama"
    ), "Only LLaMa models are supported by this converter."

    # Step 1: Create the PyTorch model
    print("Create the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    # Step 2: Load the weights into the PyTorch model
    print("Load the weights into the PyTorch model")
    _load_weights_into_nn(pt_model, gguf_model_args, gguf_weights)

    # Step 3: Export to ExecuTorch
    print("Exporting to ExecuTorch.")
    pte_program = _create_pte_program(pt_model)
    return pte_program
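
_create_pte_program is a TODO in this commit. The summary says the flow goes through export, to_edge, and to_executorch; below is a minimal sketch under that assumption. The example input shape/dtype and the exact export calls are assumptions for illustration, not the committed implementation:

import torch
import torch.nn as nn
from executorch.exir import to_edge


def _create_pte_program(pt_model: nn.Module) -> bytes:
    # Assumed example input: a single sequence of token ids; real shapes depend
    # on the llama Transformer's forward() signature.
    example_inputs = (torch.tensor([[1]], dtype=torch.long),)
    exported_program = torch.export.export(pt_model, example_inputs)
    edge_program = to_edge(exported_program)
    executorch_program = edge_program.to_executorch()
    return executorch_program.buffer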

Install script (+8)

#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

pip install gguf==0.6.0

extension/gguf_util/load_gguf.py (+102)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import gguf
from gguf import GGUFValueType, ReaderTensor


@dataclass
class AttentionArgs:
    head_count: int
    head_count_kv: int
    layer_norm_rms_epsilon: float


@dataclass
class RopeArgs:
    freq_base: float


@dataclass
class GGUFModelArgs:
    arch: str
    embedding_length: int
    block_count: int
    feed_forward_length: int
    vocab_size: int
    attention: AttentionArgs
    rope: RopeArgs


@dataclass
class GGUFWeights:
    tensors: list[ReaderTensor]


def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
    metadata: dict[str, Any] = {}

    for idx, field in enumerate(reader.fields.values()):
        val = None
        if field.types[:1] == [GGUFValueType.ARRAY]:
            itype = field.types[-1]
            if itype == GGUFValueType.STRING:
                val = [
                    str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
                ]
            else:
                val = [pv for idx in field.data for pv in field.parts[idx].tolist()]
        elif field.types[0] == GGUFValueType.STRING:
            val = str(bytes(field.parts[-1]), encoding="utf-8")
        else:
            val = field.parts[-1].tolist()[0]

        metadata[field.name] = val

    return metadata


def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
    arch = metadata["general.architecture"]

    return GGUFModelArgs(
        arch=arch,
        embedding_length=metadata[f"{arch}.embedding_length"],
        block_count=metadata[f"{arch}.block_count"],
        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
        attention=AttentionArgs(
            head_count=metadata[f"{arch}.attention.head_count"],
            head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
            layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
        ),
        rope=RopeArgs(
            freq_base=metadata[f"{arch}.rope.freq_base"],
        ),
    )


def load_file(gguf_file: str) -> (GGUFModelArgs, GGUFWeights):
    """
    Load a GGUF file and return the model arguments and weights.
    """
    if not Path(gguf_file).is_file():
        raise ValueError(f"Could not find file {gguf_file}")

    reader = gguf.GGUFReader(gguf_file, "r")

    # Step 1: Build GGUFModelArgs
    metadata = _get_metadata(reader)
    model_args = _build_model_args(metadata)

    # Step 2: Build GGUFWeights
    gguf_weights = GGUFWeights(tensors=reader.tensors)

    return (model_args, gguf_weights)
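
A quick sanity check of the loader (the local GGUF path below is a hypothetical example, not part of this commit):

from executorch.extension.gguf_util.load_gguf import load_file

# Hypothetical path to a llama-architecture GGUF checkpoint.
model_args, weights = load_file("/path/to/llama-7b-fp16.gguf")
print(model_args.arch, model_args.block_count, model_args.vocab_size)
print(f"Loaded {len(weights.tensors)} tensors")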
