@@ -3,22 +3,38 @@
 import argparse
 import os
 import sys
+from optimum.utils import NormalizedConfigManager
+
+class Net(torch.nn.Module):
+    def __init__(self, ori_model):
+        super(Net, self).__init__()
+        self.model = ori_model
+    def forward(self, input_ids, pastkv, mask):
+        return self.model(input_ids=input_ids, attention_mask=mask, past_key_values=pastkv, return_dict=False)
 
 parser = argparse.ArgumentParser('GPT-J Generation ir', add_help=False)
 parser.add_argument("--model",
                     type=str,
-                    help="path to bfloat16 or int8 IR files",
+                    help="path to original config and weight files",
                     default="EleutherAI/gpt-j-6B",
                     )
 parser.add_argument('--dtype', default=None, type=str)
 parser.add_argument('--output_model', default="./ir", type=str)
-parser.add_argument('--model_type', default="gpt-j", type=str)
 parser.add_argument('--pt_file', type=str)
 args = parser.parse_args()
 print(args)
 
 model_id = args.model
-model_type = args.model_type
+model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
+model.eval()
+
+normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
+num_layers = normalized_config.num_layers
+num_attention_heads = normalized_config.num_attention_heads
+hidden_size = normalized_config.hidden_size
+d_k = hidden_size // num_attention_heads
+model_type = model.config.model_type
+
 if 'llama' in model_type:
     from transformers import LlamaTokenizer
     tokenizer = LlamaTokenizer.from_pretrained(model_id)
@@ -28,54 +44,59 @@
 prompt = "Once upon a time, there existed a little girl, who liked to have adventures." + \
          " She wanted to go to places and meet new people, and have fun."
 init_input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
-input_ids = init_input_ids.clone()
-attention_mask = torch.ones(len(input_ids)+1)
-attention_mask[0] = 0
-past_key_value_torch = tuple([(torch.zeros([1,16,32,256]), torch.zeros([1,16,32,256])) for i in range(28)])
-input_ids = input_ids[0:1].unsqueeze(0)
-attention_mask = attention_mask.unsqueeze(0)
+input_ids = init_input_ids.clone().unsqueeze(0)
+attention_mask = torch.ones(len(input_ids)).unsqueeze(0)
+past_key_value = tuple([(torch.zeros([1,num_attention_heads,0,d_k]),
+                         torch.zeros([1,num_attention_heads,0,d_k])) for i in range(num_layers)])
 
+if 'llama' in model_type:
+    input_ids = init_input_ids.clone()
+    attention_mask = torch.ones(len(input_ids)+1)
+    attention_mask[0] = 0
+    input_ids = input_ids[0:1].unsqueeze(0)
+    attention_mask = attention_mask.unsqueeze(0)
+    past_key_value = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
+    if 'llama_13b' in model_type:
+        past_key_value = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])
 
 traced_model = None
-if 'llama' in model_type:
-    past_key_value_torch = tuple([(torch.zeros([1,32,34,128]), torch.zeros([1,32,34,128])) for i in range(32)])
-if 'llama_13b' in model_type:
-    past_key_value_torch = tuple([(torch.zeros([1,40,34,128]), torch.zeros([1,40,34,128])) for i in range(40)])
 
 if args.pt_file and os.path.exists(args.pt_file):
     print('PT model exists, compile will be executed.')
+    del model
     traced_model = torch.jit.load(args.pt_file)
 else:
-    model = AutoModelForCausalLM.from_pretrained(model_id, return_dict=False)
-    model.eval()
-    if args.dtype in ['fp32', 'bf16']:
-        if 'llama' in model_type:
-            traced_model = torch.jit.trace(model, (input_ids, attention_mask, past_key_value_torch))
-            print("Traced model is saved as {}".format(args.pt_file))
-        else:
-            traced_model = torch.jit.trace(model, (input_ids, past_key_value_torch, attention_mask))
-            print("Traced model is saved as {}".format(args.pt_file))
+    assert args.dtype in ['fp32', 'bf16'], "Model with {} can't be traced, please provide one.".format(args.dtype)
+    if 'llama' in model_type:
+        net = model
+        traced_model = torch.jit.trace(net, (input_ids, attention_mask, past_key_value))
     else:
-        print("Model with {} can't be traced, please provide one.".format(args.dtype))
-        sys.exit(1)
+        net = Net(model)
+        traced_model = torch.jit.trace(net, (input_ids, past_key_value, attention_mask))
 
 from intel_extension_for_transformers.backends.neural_engine.compile import compile, autocast
-if 'llama' not in model_type:
+if 'llama' in model_type:
     if args.dtype == "bf16":
         with autocast("bf16"):
-            graph = compile(traced_model)
+            graph = compile(traced_model, './llama_pattern.conf')
     elif args.dtype == "int8":
-        graph = compile(traced_model, './int8_pattern.conf')
+        graph = compile(traced_model, './llama_int8_pattern.conf')
     else:
-        graph = compile(traced_model)
+        graph = compile(traced_model, './llama_pattern.conf')
+elif 'gpt_neox' in model_type:
+    if args.dtype == "bf16":
+        with autocast("bf16"):
+            graph = compile(traced_model, './gpt_neox_pattern.conf')
+    else:
+        graph = compile(traced_model, './gpt_neox_pattern.conf')
 else:
     if args.dtype == "bf16":
         with autocast("bf16"):
-            graph = compile(traced_model, './llama_pattern.conf')
+            graph = compile(traced_model)
     elif args.dtype == "int8":
-        graph = compile(traced_model, './llama_int8_pattern.conf')
+        graph = compile(traced_model, './int8_pattern.conf')
     else:
-        graph = compile(traced_model, './llama_pattern.conf')
+        graph = compile(traced_model)
 
 graph.save(args.output_model)
 print('Neural Engine ir is saved as {}'.format(args.output_model))
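
For context on the new shape logic, here is a minimal sketch of how the NormalizedConfigManager lookup yields the zero-length KV-cache used as a tracing example input. It assumes optimum and transformers are installed and, unlike the script above (which loads the full model), it reads only the config. For EleutherAI/gpt-j-6B this gives 28 layers, 16 heads and a head dimension of 256, matching the [1,16,*,256] tensors that were previously hardcoded, except the past-sequence dimension now starts at 0.

# Sketch only: rebuild the dummy past_key_values from the model config alone.
import torch
from transformers import AutoConfig
from optimum.utils import NormalizedConfigManager

config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
normalized = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

num_layers = normalized.num_layers            # 28 for gpt-j-6B
num_heads = normalized.num_attention_heads    # 16
d_k = normalized.hidden_size // num_heads     # 4096 // 16 = 256

# One (key, value) pair of empty tensors per layer: [batch, heads, past_len=0, head_dim].
past_key_value = tuple((torch.zeros([1, num_heads, 0, d_k]),
                        torch.zeros([1, num_heads, 0, d_k])) for _ in range(num_layers))
print(past_key_value[0][0].shape)             # torch.Size([1, 16, 0, 256])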
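Since --model_type is removed, the model type now comes from the checkpoint's own config, so an export run only needs the model path, dtype and output directory. For example (the file name export_ir.py is hypothetical; substitute whatever this script is saved as): python export_ir.py --model EleutherAI/gpt-j-6B --dtype bf16 --output_model ./ir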