This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 402bb90

[Engine]: Enable gpt neox and dolly (#939)
1 parent 2580f3c commit 402bb90

39 files changed: +3115 -112 lines changed

examples/.config/engine_deploy.json

file mode 100644 → 100755
+3 -4
@@ -553,16 +553,15 @@
         "model": "/tf_dataset2/models/nlp_toolkit/llama-7b-hf",
         "dtype": "fp32/bf16/int8",
         "output_model": "ir",
-        "pt_file": "pt",
-        "model_type": "llama_7b"
+        "pt_file": "pt"
       }
     },
     "benchmark": {
       "cmd": "python run_llm.py",
       "params": {
         "max-new-tokens": "32",
-        "model_path": "ir",
-        "model_type": "llama_7b"
+        "model": "decapoda-research/llama-7b-hf",
+        "model_path": "ir"
       }
     },
     "launcher": {

examples/huggingface/pytorch/text-generation/deployment/README.md

+1 -1
@@ -38,7 +38,7 @@ python optimize_llm.py --model=EleutherAI/gpt-j-6B --dtype=(fp32|bf16) --output_
 
 # int8
 wget https://huggingface.co/Intel/gpt-j-6B-pytorch-int8-static/resolve/main/pytorch_model.bin -O <path to int8_model.pt>
-python gen_ir.py --model=EleutherAI/gpt-j-6B --dtype=int8 --output_model=<path to ir> --pt_file=<path to int8_model.pt>
+python optimize_llm.py --model=EleutherAI/gpt-j-6B --dtype=int8 --output_model=<path to ir> --pt_file=<path to int8_model.pt>
 ```
 - When the input dtype is fp32 or bf16, the model will be downloaded if it does not exist.
 - When the input dtype is int8, the int8 trace model should exist.

examples/huggingface/pytorch/text-generation/deployment/generation_utils.py

+13 -17
@@ -92,6 +92,8 @@
 # Load past kv caches from files
 import numpy as np
 import pickle
+from optimum.utils import NormalizedConfigManager
+
 
 logger = logging.get_logger(__name__)
 
@@ -551,7 +553,7 @@ def _prepare_model_inputs(
             )
         elif inputs_kwarg is not None:
             inputs = inputs_kwarg
-        if not model_kwargs["llama"]:
+        if model_kwargs['model_type'] != 'llama':
             # 3. models with `input_ids` can also make use of `inputs_embeds`
             if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
                 inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
@@ -2854,6 +2856,7 @@ def beam_search(
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
         this_peer_finished = False  # used by synced_gpus only
+
         while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2866,7 +2869,8 @@
                 break
 
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            if not model_kwargs["llama"]:
+
+            if model_kwargs['model_type'] != "llama":
                 if model_inputs["past_key_values"] is None:
                     first_token = model_inputs["input_ids"].size()[1] != 1
                     if first_token:
@@ -2876,10 +2880,8 @@
                         model_inputs["input_ids"] = model_inputs["input_ids"][:1,:]
                     input_ids_1 = model_inputs['input_ids'].cpu().numpy().astype(np.int32)
                     attention_mask_1 = model_inputs['attention_mask'].cpu().numpy().astype(np.int32)
-
-                    past_k_v = np.ones([1,0,16,256]).astype(np.float32)
-                    predictions = engine_model.inference([input_ids_1] + [past_k_v for _ in range(2 * model_kwargs["past_kv_nums"])] + [attention_mask_1])
-
+                    past_k_v = np.zeros([1,0,model_kwargs['num_attention_heads'],model_kwargs['d_k']]).astype(np.float32)
+                    predictions = engine_model.inference([input_ids_1] + [past_k_v for _ in range(2 * model_kwargs['past_kv_nums'])] + [attention_mask_1])
                     for key in predictions:
                         predictions[key] = torch.from_numpy(predictions[key])
 
@@ -2901,28 +2903,22 @@
                         value = value.view(value.size(1) * value.size(0), value.size(2), value.size(3))
                         past_key_values.append(tuple([key, value]))
                     outputs.past_key_values = tuple(past_key_values)
+
                 if synced_gpus and this_peer_finished:
                     cur_len = cur_len + 1
                     continue  # don't waste resources running the code we don't need
                 next_token_logits = outputs.logits[:, -1, :]
 
             else:
-                example_inputs = []
-                for k, v in model_inputs.items():
-                    if v is not None and not isinstance(v, bool):
-                        example_inputs.append(v)
-                example_inputs = tuple(example_inputs)
-
-                input_ids_1 = example_inputs[0].cpu().numpy().astype(np.int32)
-                attention_mask_1 = example_inputs[-1].cpu().numpy().astype(np.int32)
-                past_key_values = [example_inputs[1][i][j] for i in range(model_kwargs["past_kv_nums"]) for j in range(2)]
+                input_ids_1 = model_inputs['input_ids'].cpu().numpy().astype(np.int32)
+                attention_mask_1 = model_inputs['attention_mask'].cpu().numpy().astype(np.int32)
+                past_key_values = [model_inputs['past_key_values'][i][j] for i in range(model_kwargs["past_kv_nums"]) for j in range(2)]
                 predictions = engine_model.inference([input_ids_1] + past_key_values + [attention_mask_1])
-
                 # ts=time.time()
                 for key in predictions:
                     predictions[key] = torch.from_numpy(predictions[key])
                 outputs = CausalLMOutputWithPast()
-                outputs.logits = list(predictions.values())[0].reshape(-1,1,50400)
+                outputs.logits = list(predictions.values())[0].reshape(-1,1,model_kwargs['vocab_size'])
                 outputs.past_key_values = [(list(predictions.values())[2*i+1], list(predictions.values())[2*i+2]) for i in range(model_kwargs["past_kv_nums"])]
 
                 # print(2,time.time()-ts)
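The hard-coded GPT-J values (the `[1,0,16,256]` key/value buffers and the 50400 vocab size) are replaced by per-model values carried in `model_kwargs`, which is what lets the same beam-search path drive gpt-neox and dolly. A rough sketch of where those values come from, using the optimum NormalizedConfig attributes that this commit relies on elsewhere; the model id is only an example:

```python
import numpy as np
from transformers import AutoConfig
from optimum.utils import NormalizedConfigManager

config = AutoConfig.from_pretrained("databricks/dolly-v2-3b")  # example id, not from the diff
norm = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

num_attention_heads = norm.num_attention_heads
d_k = norm.hidden_size // num_attention_heads
past_kv_nums = norm.num_layers

# First beam-search step: one empty (sequence length 0) key and value buffer per layer,
# matching the past_k_v placeholder built in the diff above.
past_k_v = np.zeros([1, 0, num_attention_heads, d_k]).astype(np.float32)
engine_inputs_len = 1 + 2 * past_kv_nums + 1  # input_ids + K/V per layer + attention_mask
print(past_k_v.shape, engine_inputs_len)
```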
@@ -0,0 +1 @@
+{"pattern_switch": {"MatMulWithTranspose": false, "RemoveLastView": false, "NeoxReorderChange": true, "NeoxRoraryPosEmb": true, 'MultiHeadAttention': true}}

examples/huggingface/pytorch/text-generation/deployment/optimize_llm.py

+52 -31
@@ -3,22 +3,38 @@
 import argparse
 import os
 import sys
+from optimum.utils import NormalizedConfigManager
+
+class Net(torch.nn.Module):
+    def __init__(self, ori_model):
+        super(Net, self).__init__()
+        self.model = ori_model
+    def forward(self, input_ids, pastkv, mask):
+        return self.model(input_ids=input_ids, attention_mask=mask, past_key_values=pastkv, return_dict=False)
 
 parser = argparse.ArgumentParser('GPT-J Generation ir', add_help=False)
 parser.add_argument("--model",
                     type=str,
-                    help="path to bfloat16 or int8 IR files",
+                    help="path to original config and weight files",
                     default="EleutherAI/gpt-j-6B",
                     )
 parser.add_argument('--dtype', default=None, type=str)
 parser.add_argument('--output_model', default="./ir", type=str)
-parser.add_argument('--model_type', default="gpt-j", type=str)
 parser.add_argument('--pt_file', type=str)
 args = parser.parse_args()
 print(args)
 
 model_id = args.model
-model_type = args.model_type
+model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
+model.eval()
+
+normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
+num_layers = normalized_config.num_layers
+num_attention_heads = normalized_config.num_attention_heads
+hidden_size = normalized_config.hidden_size
+d_k = hidden_size // num_attention_heads
+model_type = model.config.model_type
+
 if 'llama' in model_type:
     from transformers import LlamaTokenizer
     tokenizer = LlamaTokenizer.from_pretrained(model_id)
@@ -28,54 +44,59 @@
 prompt = "Once upon a time, there existed a little girl, who liked to have adventures." + \
          " She wanted to go to places and meet new people, and have fun."
 init_input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
-input_ids = init_input_ids.clone()
-attention_mask = torch.ones(len(input_ids)+1)
-attention_mask[0] = 0
-past_key_value_torch = tuple([(torch.zeros([1,16,32,256]), torch.zeros([1,16,32,256])) for i in range(28)])
-input_ids = input_ids[0:1].unsqueeze(0)
-attention_mask = attention_mask.unsqueeze(0)
+input_ids = init_input_ids.clone().unsqueeze(0)
+attention_mask = torch.ones(len(input_ids)).unsqueeze(0)
+past_key_value = tuple([(torch.zeros([1,num_attention_heads,0,d_k]),
+                         torch.zeros([1,num_attention_heads,0,d_k])) for i in range(num_layers)])
 
+if 'llama' in model_type:
+    input_ids = init_input_ids.clone()
+    attention_mask = torch.ones(len(input_ids)+1)
+    attention_mask[0] = 0
+    input_ids = input_ids[0:1].unsqueeze(0)
+    attention_mask = attention_mask.unsqueeze(0)
+    past_key_value = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
+    if 'llama_13b' in model_type:
+        past_key_value = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])
 
 traced_model = None
-if 'llama' in model_type:
-    past_key_value_torch = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
-    if 'llama_13b' in model_type:
-        past_key_value_torch = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])
 
 if args.pt_file and os.path.exists(args.pt_file):
     print('PT model exists, compile will be executed.')
+    del model
     traced_model = torch.jit.load(args.pt_file)
 else:
-    model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
-    model.eval()
-    if args.dtype in ['fp32', 'bf16']:
-        if 'llama' in model_type:
-            traced_model = torch.jit.trace(model, (input_ids, attention_mask, past_key_value_torch))
-            print("Traced model is saved as {}".format(args.pt_file))
-        else:
-            traced_model = torch.jit.trace(model, (input_ids, past_key_value_torch, attention_mask))
-            print("Traced model is saved as {}".format(args.pt_file))
+    assert args.dtype in ['fp32', 'bf16'], "Model with {} can't be traced, please provide one.".format(args.dtype)
+    if 'llama' in model_type:
+        net = model
+        traced_model = torch.jit.trace(net, (input_ids, attention_mask, past_key_value))
     else:
-        print("Model with {} can't be traced, please provide one.".format(args.dtype))
-        sys.exit(1)
+        net = Net(model)
+        traced_model = torch.jit.trace(net, (input_ids, past_key_value, attention_mask))
 
 from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast
-if 'llama' not in model_type:
+if 'llama' in model_type:
     if args.dtype == "bf16":
         with autocast("bf16"):
-            graph = compile(traced_model)
+            graph = compile(traced_model, './llama_pattern.conf')
     elif args.dtype == "int8":
-        graph = compile(traced_model, './int8_pattern.conf')
+        graph = compile(traced_model, './llama_int8_pattern.conf')
     else:
-        graph = compile(traced_model)
+        graph = compile(traced_model, './llama_pattern.conf')
+elif 'gpt_neox' in model_type:
+    if args.dtype == "bf16":
+        with autocast("bf16"):
+            graph = compile(traced_model, './gpt_neox_pattern.conf')
+    else:
+        graph = compile(traced_model, './gpt_neox_pattern.conf')
 else:
     if args.dtype == "bf16":
         with autocast("bf16"):
-            graph = compile(traced_model, './llama_pattern.conf')
+            graph = compile(traced_model)
     elif args.dtype == "int8":
-        graph = compile(traced_model, './llama_int8_pattern.conf')
+        graph = compile(traced_model, './int8_pattern.conf')
     else:
-        graph = compile(traced_model, './llama_pattern.conf')
+        graph = compile(traced_model)
 
 graph.save(args.output_model)
 print('Neural Engine ir is saved as {}'.format(args.output_model))
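Put together, the non-llama path now wraps the model in `Net` so tracing always sees `(input_ids, past_key_values, attention_mask)` in that order, derives the KV-cache shapes from the config, and picks a pattern conf by `model_type`. A condensed usage sketch for a gpt_neox-family model; the dolly model id is an example and not taken from this diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.utils import NormalizedConfigManager
from intel_extension_for_transformers.backends.neural_engine.compile import compile

model_id = "databricks/dolly-v2-3b"          # example gpt_neox-type model (assumption)
model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
model.eval()

norm = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
heads = norm.num_attention_heads
d_k = norm.hidden_size // heads

tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids)
past_key_value = tuple((torch.zeros([1, heads, 0, d_k]), torch.zeros([1, heads, 0, d_k]))
                       for _ in range(norm.num_layers))

class Net(torch.nn.Module):                  # same wrapper as in the diff above
    def __init__(self, ori_model):
        super().__init__()
        self.model = ori_model
    def forward(self, input_ids, pastkv, mask):
        return self.model(input_ids=input_ids, attention_mask=mask,
                          past_key_values=pastkv, return_dict=False)

traced_model = torch.jit.trace(Net(model), (input_ids, past_key_value, attention_mask))
graph = compile(traced_model, './gpt_neox_pattern.conf')
graph.save('./ir')
```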
@@ -1,4 +1,5 @@
 transformers==4.27.4
 torch==2.0
 accelerate
-sentencepiece
+sentencepiece
+optimum
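optimum is added because both optimize_llm.py and run_llm.py now read layer, head, and vocab sizes through its NormalizedConfigManager instead of hard-coding them per model. A quick sanity check; which model_type strings are registered depends on the installed optimum version:

```python
from optimum.utils import NormalizedConfigManager

# "gptj", "gpt_neox" and "llama" are the transformers model_type strings for the
# models used in these examples; their registration depends on the optimum version.
for model_type in ("gptj", "gpt_neox", "llama"):
    normalized_config_class = NormalizedConfigManager.get_normalized_config_class(model_type)
    print(model_type, "->", normalized_config_class)
```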

examples/huggingface/pytorch/text-generation/deployment/run_llm.py

+22 -18
@@ -9,6 +9,7 @@
 from torch.profiler import profile, record_function, ProfilerActivity
 from accelerate import init_empty_weights
 import generation_utils as itrex_generation_utils
+from optimum.utils import NormalizedConfigManager
 
 # args
 parser = argparse.ArgumentParser('GPT-J generation script', add_help=False)
@@ -17,39 +18,42 @@
                     help="path to bfloat16 or int8 IR files",
                     default="bfloat16",
                     )
+parser.add_argument("--model",
+                    type=str,
+                    help="path to original config and weight files",
+                    default="EleutherAI/gpt-j-6B",
+                    )
 parser.add_argument('--max-new-tokens', default=32, type=int, help="output max new tokens")
 parser.add_argument('--input-tokens', default='32', type=str)
 parser.add_argument('--prompt', default=None, type=str)
 parser.add_argument('--batch-size', default=1, type=int)
 parser.add_argument('--weight_type', default=None, type=str)
-parser.add_argument('--model_type', default='gpt-j', type=str)
 args = parser.parse_args()
 print(args)
 
 generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
-if args.model_type == 'llama_7b':
-    generate_kwargs["past_kv_nums"] = 32
-    generate_kwargs["llama"] = True
-    model_id = "decapoda-research/llama-7b-hf"
-    from transformers import LlamaForCausalLM, LlamaTokenizer
-    tokenizer = LlamaTokenizer.from_pretrained(model_id)
-    prompt_json = '/llamaprompt.json'
-elif args.model_type == 'llama_13b':
-    generate_kwargs["past_kv_nums"] = 40
-    generate_kwargs["llama"] = True
-    model_id = "decapoda-research/llama-13b-hf"
-    from transformers import LlamaForCausalLM, LlamaTokenizer
+
+model_id = args.model
+config = AutoConfig.from_pretrained(model_id)
+model_type = config.model_type
+normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
+num_attention_heads = normalized_config.num_attention_heads
+hidden_size = normalized_config.hidden_size
+generate_kwargs["past_kv_nums"] = normalized_config.num_layers
+generate_kwargs["model_type"] = model_type
+generate_kwargs["num_attention_heads"] = num_attention_heads
+generate_kwargs["d_k"] = hidden_size // num_attention_heads
+generate_kwargs["vocab_size"] = normalized_config.vocab_size
+
+if 'llama' in model_type:
+    from transformers import LlamaTokenizer
     tokenizer = LlamaTokenizer.from_pretrained(model_id)
     prompt_json = '/llamaprompt.json'
-elif args.model_type == 'gpt-j':
-    generate_kwargs["past_kv_nums"] = 28
-    generate_kwargs["llama"] = False
-    model_id = "EleutherAI/gpt-j-6B"
+else:
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     prompt_json = '/prompt.json'
 
 # load model
-config = AutoConfig.from_pretrained(model_id)
 with init_empty_weights():
     model = AutoModelForCausalLM.from_config(config)
 setattr(model, "generate", types.MethodType(itrex_generation_utils.GenerationMixin.generate, model))
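run_llm.py now derives everything the patched beam search needs (past_kv_nums, num_attention_heads, d_k, vocab_size, model_type) from the config of whatever --model is given, so enabling a new model no longer means adding an elif branch. A trimmed sketch of that flow; the dolly model id is an example and the engine setup is omitted:

```python
from transformers import AutoConfig
from optimum.utils import NormalizedConfigManager

def build_generate_kwargs(model_id):
    """Mirror of the run_llm.py logic above: per-model values for the patched beam search."""
    config = AutoConfig.from_pretrained(model_id)
    norm = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
    heads = norm.num_attention_heads
    return dict(do_sample=False, temperature=0.9, num_beams=4,
                past_kv_nums=norm.num_layers,
                model_type=config.model_type,
                num_attention_heads=heads,
                d_k=norm.hidden_size // heads,
                vocab_size=norm.vocab_size)

# Example ids: GPT-J comes from the script's default, the dolly id is an assumption.
for model_id in ("EleutherAI/gpt-j-6B", "databricks/dolly-v2-3b"):
    print(model_id, build_generate_kwargs(model_id))
```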