@@ -50,58 +50,54 @@ launch_trt_server() {
   git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
   git lfs install
   cd tensorrtllm_backend
-  git checkout $trt_llm_version
-  tensorrtllm_backend_dir=$(pwd)
+  git checkout "$trt_llm_version"
   git submodule update --init --recursive
 
   # build trtllm engine
   cd /tensorrtllm_backend
-  cd ./tensorrt_llm/examples/${model_type}
+  cd "./tensorrt_llm/examples/${model_type}"
   python3 convert_checkpoint.py \
-    --model_dir ${model_path} \
-    --dtype ${model_dtype} \
-    --tp_size ${model_tp_size} \
-    --output_dir ${trt_model_path}
+    --model_dir "${model_path}" \
+    --dtype "${model_dtype}" \
+    --tp_size "${model_tp_size}" \
+    --output_dir "${trt_model_path}"
   trtllm-build \
-    --checkpoint_dir ${trt_model_path} \
+    --checkpoint_dir "${trt_model_path}" \
     --use_fused_mlp \
     --reduce_fusion disable \
     --workers 8 \
-    --gpt_attention_plugin ${model_dtype} \
-    --gemm_plugin ${model_dtype} \
-    --tp_size ${model_tp_size} \
-    --max_batch_size ${max_batch_size} \
-    --max_input_len ${max_input_len} \
-    --max_seq_len ${max_seq_len} \
-    --max_num_tokens ${max_num_tokens} \
-    --output_dir ${trt_engine_path}
+    --gpt_attention_plugin "${model_dtype}" \
+    --gemm_plugin "${model_dtype}" \
+    --tp_size "${model_tp_size}" \
+    --max_batch_size "${max_batch_size}" \
+    --max_input_len "${max_input_len}" \
+    --max_seq_len "${max_seq_len}" \
+    --max_num_tokens "${max_num_tokens}" \
+    --output_dir "${trt_engine_path}"
 
   # handle triton protobuf files and launch triton server
   cd /tensorrtllm_backend
   mkdir triton_model_repo
   cp -r all_models/inflight_batcher_llm/* triton_model_repo/
   cd triton_model_repo
   rm -rf ./tensorrt_llm/1/*
-  cp -r ${trt_engine_path}/* ./tensorrt_llm/1
+  cp -r "${trt_engine_path}"/* ./tensorrt_llm/1
   python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false
-  python3 ../tools/fill_template.py -i preprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5
-  python3 ../tools/fill_template.py -i postprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false
-  python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:$max_batch_size
-  python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:"False",bls_instance_count:1
+  python3 ../tools/fill_template.py -i preprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5"
+  python3 ../tools/fill_template.py -i postprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false"
+  python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:"$max_batch_size"
+  python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt "triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:False,bls_instance_count:1"
   cd /tensorrtllm_backend
   python3 scripts/launch_triton_server.py \
-    --world_size=${model_tp_size} \
+    --world_size="${model_tp_size}" \
     --model_repo=/tensorrtllm_backend/triton_model_repo &
 
 }
 
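Note: the quoting added throughout this hunk addresses shellcheck SC2086. An unquoted expansion such as `--model_dir ${model_path}` undergoes word splitting and glob expansion, so a value containing a space or `*` silently becomes multiple arguments. A minimal sketch of the failure mode, using a hypothetical path:

```bash
#!/usr/bin/env bash
model_path="/models/my model"   # hypothetical value containing a space
printf '<%s>\n' $model_path     # unquoted: splits into </models/my> and <model>
printf '<%s>\n' "$model_path"   # quoted: one argument, </models/my model>
```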
 launch_tgi_server() {
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
 
   if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then
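The `jq -e 'has("fp8")'` guard works because `-e` maps jq's last output to the exit status: `true` exits 0 and `false`/`null` exits non-zero, which is exactly what the shell `if` tests. A self-contained check, with a hypothetical `common_params` payload (the real values come from the benchmark test matrix):

```bash
#!/usr/bin/env bash
common_params='{"model":"meta-llama/Meta-Llama-3-8B","tp":2,"port":8000,"fp8":true}'
if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then
    echo "fp8 requested"   # taken, since the key is present
fi
```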
@@ -129,10 +125,7 @@ launch_tgi_server() {
 launch_lmdeploy_server() {
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
 
   server_command="lmdeploy serve api_server $model \
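`json2args` is defined elsewhere in this script, so it does not appear in the diff. For orientation, a sketch in the style of the vLLM benchmark helpers, turning a JSON object of server parameters into CLI flags; the exact helper in this file may differ:

```bash
json2args() {
    # Convert a JSON object into CLI flags, e.g.
    #   {"max_num_seqs": 512, "gpu_memory_utilization": 0.9}
    # becomes: --max-num-seqs 512 --gpu-memory-utilization 0.9
    local json_string=$1
    echo "$json_string" | jq -r '
        to_entries |
        map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) |
        join(" ")
    '
}
```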
@@ -149,10 +142,7 @@ launch_sglang_server() {
 
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
 
   if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then
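The deleted `dataset_name`/`dataset_path`/`num_prompts` assignments are, presumably, never read inside these launch functions (the benchmarking client consumes those values), which shellcheck reports as SC2034. A quick way to confirm on a hypothetical snippet:

```bash
cat > /tmp/demo.sh <<'EOF'
#!/usr/bin/env bash
common_params='{"num_prompts": 200}'
num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
echo "launching server"
EOF
shellcheck /tmp/demo.sh   # reports SC2034: num_prompts appears unused
```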
@@ -185,10 +175,7 @@ launch_vllm_server() {
 
   model=$(echo "$common_params" | jq -r '.model')
   tp=$(echo "$common_params" | jq -r '.tp')
-  dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
-  dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
   port=$(echo "$common_params" | jq -r '.port')
-  num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
   server_args=$(json2args "$server_params")
 
   if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then
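One subtlety in the `jq -r '.model'`-style extractions used by every launch function: `-r` prints raw strings, and a missing key prints the literal string `null` instead of failing, so a typo in the test matrix surfaces later as a model or port named `null`. For example:

```bash
#!/usr/bin/env bash
common_params='{"model":"meta-llama/Meta-Llama-3-8B"}'   # hypothetical payload
echo "$common_params" | jq -r '.model'   # -> meta-llama/Meta-Llama-3-8B
echo "$common_params" | jq -r '.port'    # -> null (missing keys do not error)
```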
@@ -217,19 +204,19 @@ launch_vllm_server() {
 
 main() {
 
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "trt" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "trt" ]]; then
     launch_trt_server
   fi
 
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "tgi" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "tgi" ]]; then
     launch_tgi_server
   fi
 
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "lmdeploy" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then
     launch_lmdeploy_server
   fi
 
-  if [[ $CURRENT_LLM_SERVING_ENGINE == "sglang" ]]; then
+  if [[ "$CURRENT_LLM_SERVING_ENGINE" == "sglang" ]]; then
     launch_sglang_server
   fi
 
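On the `main()` changes: inside `[[ ... ]]` the left-hand side is not word-split, so quoting `$CURRENT_LLM_SERVING_ENGINE` there is consistency rather than a bug fix. The side where quotes do change meaning is the right-hand side, which is treated as a glob pattern when unquoted:

```bash
#!/usr/bin/env bash
CURRENT_LLM_SERVING_ENGINE="trt"
[[ "$CURRENT_LLM_SERVING_ENGINE" == tr* ]] && echo "pattern match"        # unquoted RHS is a glob
[[ "$CURRENT_LLM_SERVING_ENGINE" == "tr*" ]] || echo "literal, no match"  # quoted RHS matches literally
```

With those guards in place, a hypothetical invocation would be `CURRENT_LLM_SERVING_ENGINE=sglang bash launch-server.sh` (script name assumed), dispatching to exactly one `launch_*` function.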