
Commit d23d2ab

Large model inference (#2215)
* mv HF large model to new dir
* adding HF models to the new dir
* adding HF pippy
* adding large model doc
* adding large model utils
* adding base handler for pippy
* clean up
* clean up
* clean up
* clean up
* formatting
* clean up
* clean ups
* adding max lenght
* clean up
* clean up
* adding logger
* update to latests
* update to latests
* update model-config
* clean up
* adding recent settings
* update steps
* update to use ctx
* update to use ctx
* update to use ctx
* adding logger
* remove setu_up config
* clean up
* changing max length
* adding generate to the inference
* adding better prompt
* typos
* clean up
* adding test for pippy inference
* uncommenting the padding
* update steps
* update instructions
* update instructions
* update instructions
* addressing comments
* adding rpc threads and manual seed
* addressing comments
* clean up
* clean up
* adding LMI doc
* clean up
* remove dist handler test
* add utils here due to circular dependency
* add check for pippy install
* add large model post man json
* addiing dist inference test
* clean up
* adding large model inference json
* using ts start func with mar gen false
* clean up
* remove uneccesary move logs func
* remove uneccesary move logs func
* removing install from src
* fix the model name
* removing expected json type
* adding assertion for exitcode
* moving torchrun to frontend spec
* moving torchrun to frontend spec
* clean up
* make sure for HF it returns the patched model that supports generate for others pipe_driver
* adding torchpippy
* fix typos
* fix typos
* extending the vocab
* extending the vocab
* fix typos
* fix typos
* update examples readme and fix deadinks
* fixing typos

---------

Co-authored-by: lxning <[email protected]>
1 parent 044bbc1 commit d23d2ab

36 files changed (+893 −202 lines)

docs/contents.rst

+5 −4
@@ -3,7 +3,7 @@
    :numbered:
    :caption: Contents:
    :titlesonly:
-
+
    index
    Troubleshooting
    batch_inference_with_ts
@@ -23,14 +23,15 @@
    torchserve_on_wsl
    use_cases
    workflows
+   large_model_inference

 .. toctree::
    :maxdepth: 0
    :caption: Service APIs:
-
+
    apis

 .. toctree::
    :caption: Developer APIs:
-
-   api/dev_api
+
+   api/dev_api

docs/large_model_inference.md

+80
@@ -0,0 +1,80 @@
# Serving large models with Torchserve

This document explains how Torchserve supports large model serving. Here, a large model is one that does not fit on a single GPU and therefore has to be split into multiple partitions across multiple GPUs.

## PiPPy (PyTorch Native solution for large model inference)

PiPPy provides pipeline parallelism for serving large models that do not fit on one GPU. It takes your model, splits it into equally sized stages partitioned over the number of devices you specify, and then uses microbatching to run your batched input for inference (this is most beneficial for batch sizes > 1).
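For intuition, the sketch below shows what microbatching does to a batch: it is split into `chunks` smaller microbatches that are pipelined through the stages one after another. The tensor shapes here are made up purely for illustration.

```python
import torch

# A hypothetical batch of 4 tokenized inputs, each padded to length 80.
batch = torch.randint(0, 50000, (4, 80))

# With chunks=2, the batch is split into two microbatches of size 2 that are
# fed through the pipeline stages back to back, keeping more GPUs busy at once.
chunks = 2
for i, microbatch in enumerate(torch.chunk(batch, chunks, dim=0)):
    print(f"microbatch {i}: shape {tuple(microbatch.shape)}")  # (2, 80)
```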
## How to use PiPPy in Torchserve

To use PiPPy in Torchserve, we need a custom handler that inherits from base_pippy_handler, and we put our settings in model-config.yaml.

A custom handler in Torchserve is simply a Python script that defines the model loading, preprocess, inference, and postprocess logic specific to your workflow.

It would look like below:

Create `custom_handler.py` or any other descriptive name.
```python
# Import the necessary packages along with the following
from abc import ABC

import torch

from ts.torch_handler.distributed.base_pippy_handler import BasePippyHandler
from ts.handler_utils.distributed.pt_pippy import initialize_rpc_workers, get_pipeline_driver


class ModelHandler(BasePippyHandler, ABC):
    def __init__(self):
        super(ModelHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        super().initialize(ctx)
        model = ...  # load your eager-mode model from model_dir
        # local_rank picks the device that model inputs are moved to (self.device)
        self.device = self.local_rank % torch.cuda.device_count()
        self.model = get_pipeline_driver(model, self.world_size, ctx)
```
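The handler also defines the usual `preprocess`/`inference`/`postprocess` methods. As a minimal sketch of `inference` for a HuggingFace model, assuming `model_type: "HF"` (so the returned pipeline driver supports `generate`, as in the HuggingFace example handler included in this commit) and that a tokenizer was loaded in `initialize`:

```python
def inference(self, input_batch):
    # input_batch comes from preprocess(); both tensors are assumed to already
    # be on self.device, and self.tokenizer to have been created in initialize().
    input_ids, attention_mask = input_batch
    # With model_type "HF" the pipeline driver supports generate(); other model
    # types call the returned pipeline driver directly with the traced inputs.
    outputs = self.model.generate(
        input_ids, attention_mask=attention_mask, max_length=30
    )
    return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
```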
Here is what your `model-config.yaml` needs. This config file is very flexible: you can add settings related to the frontend, backend, and handler.

```yaml
# frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 120
parallelLevel: 4
deviceType: "gpu"
parallelType: "pp" # options depend on the solution: pp (pipeline parallelism), tp (tensor parallelism), pptp (pipeline and tensor parallelism)
# this is used to route input to either rank0 or all ranks from the frontend, based on the solution (e.g. DeepSpeed supports tp, PiPPy supports pp)
torchrun:
    nproc-per-node: 4 # number of processes torchrun starts to serve your model; set to world_size, i.e. the number of GPUs you wish to split your model over
# backend settings
pippy:
    chunks: 1 # sets the microbatch size: microbatch = batch size / chunks
    input_names: ['input_ids'] # input arg names to the model, required for FX tracing
    model_type: "HF" # set the model type to "HF" if you are using a Huggingface model, otherwise leave it blank or set it to the model type you use
    rpc_timeout: 1800
    num_worker_threads: 512 # number of threads for RPC worker init

handler:
    max_length: 80 # max length of tokens for the tokenizer in the handler
```
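Before packaging, it can help to double-check the values you set; a minimal sketch that just loads and prints them, assuming PyYAML is installed and the file is named `model-config.yaml`:

```python
import yaml

# Load the config to sanity-check values before archiving; in the example above,
# parallelLevel and torchrun.nproc-per-node are both set to the world size (4).
with open("model-config.yaml") as f:
    cfg = yaml.safe_load(f)

print("parallelLevel:", cfg["parallelLevel"])
print("nproc-per-node:", cfg["torchrun"]["nproc-per-node"])
print("chunks:", cfg["pippy"]["chunks"])
```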
**How do you access these settings in the handler?** Here is an example:

```python
def initialize(self, ctx):
    model_type = ctx.model_yaml_config["pippy"]["model_type"]
```
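`ctx.model_yaml_config` is just the parsed YAML as a dictionary, so the other sections can be read the same way. A small sketch (the fallback defaults here are illustrative, not part of the example config):

```python
def initialize(self, ctx):
    pippy_cfg = ctx.model_yaml_config.get("pippy", {})
    handler_cfg = ctx.model_yaml_config.get("handler", {})

    model_type = pippy_cfg.get("model_type", "")    # "HF" for Huggingface models
    chunks = pippy_cfg.get("chunks", 1)             # microbatch count
    max_length = handler_cfg.get("max_length", 80)  # tokenizer max length
```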
The rest is as usual in Torchserve: package your model and start the server.

An example of the command for packaging your model; make sure you pass model-config.yaml:

```bash
torch-model-archiver --model-name bloom --version 1.0 --handler pippy_handler.py --extra-files $MODEL_CHECKPOINTS_PATH -r requirements.txt --config-file model-config.yaml --archive-format tgz
```
Tensor parallelism support is in progress and will be added as soon as it is ready.

examples/README.md

+5 −3
@@ -25,6 +25,8 @@

 * [Serving HuggingFace transformers model](Huggingface_Transformers)

+### PiPPy [Serving Large Models with PyTorch Native Solution PiPPy](large_models/Huggingface_pippy/Readme.md)
+
 ### MLFlow <img src="images/mlflow.png" width="50" title="MLFlow" style="float:right padding:20px" />

 * [Deploy models using `mlflow-torchserve` plugin](https://github.com/mlflow/mlflow-torchserve/tree/master/examples)
@@ -43,7 +45,7 @@

 ### Microsoft DeepSpeed-MII <img src="images/mii-white.svg" width="80" title="DeepSpeed MII" style="float:top" />

-* [HuggingFace Stable Diffusion Model with Microsoft DeepSpeed-MII](deepspeed_mii)
+* [HuggingFace Stable Diffusion Model with Microsoft DeepSpeed-MII](large_models/deepspeed_mii/Readme.md)

 ### Prometheus and mtail <img src="images/prometheus-logo.svg" width="30" title="Prometheus" style="float:top" />

@@ -66,8 +68,8 @@
 ### Stable Diffusion <img src="images/huggingface_logo-noborder.svg" width="30" height="30" title="Hugging Face" style="float:right padding:10px" />
 * [Stable Diffusion using HuggingFace Diffusers](diffusers)

-### HuggingFace Large Models <img src="images/huggingface_logo-noborder.svg" width="30" height="30" title="Hugging Face" style="float:right padding:10px" />
-* [HuggingFace Large Models with constrained resources](Huggingface_Largemodels)
+### HuggingFace Large Models with Accelerate <img src="images/huggingface_logo-noborder.svg" width="30" height="30" title="Hugging Face" style="float:right padding:10px" />
+* [HuggingFace Large Models with constrained resources](large_models/Huggingface_accelerate/Readme.md)

 ## UseCases

examples/large_models/Huggingface_pippy/Readme.md

+82
@@ -0,0 +1,82 @@
# Loading large Huggingface models with PiPPy (PyTorch Native Large inference solution)

This document describes how to serve a large HuggingFace (HF) model with PiPPy.

PiPPy provides pipeline parallelism for serving large models that do not fit on one GPU. It takes your model, splits it into equally sized stages partitioned over the number of devices you specify, and then uses micro-batching to run your batched input for inference (this is most beneficial for batch sizes > 1). Micro-batching is the technique pipeline parallelism uses to maximize GPU utilization.

## How to serve your large HuggingFace models with PiPPy in Torchserve?

We use a Torchserve custom handler that inherits from base_pippy_handler to load the model and define our logic for preprocess, inference, and postprocess. This is basically very similar to your evaluation process.
### Step 1: Download model

Log in to the HuggingFace Hub with your token by running the command below

```bash
huggingface-cli login
```
and paste the token generated on the HuggingFace Hub.

```bash
python Download_model.py --model_name facebook/opt-6.7b
```
The script prints the path where the model is downloaded, as below. This is an example; in your workload you will want to use your actual trained model checkpoints.

`model/models--bigscience-bloom-7b1/snapshots/5546055f03398095e385d7dc625e636cc8910bf2/`

The downloaded model is around 14GB.
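If you prefer to script the download rather than use `Download_model.py` (whose contents are not shown here), a minimal sketch using `huggingface_hub` (assumed to be installed) looks like this; the model name and cache directory are just the ones used in this example:

```python
from huggingface_hub import snapshot_download

# Download the checkpoint into ./model, mirroring the layout the example expects.
path = snapshot_download(
    repo_id="facebook/opt-6.7b",
    cache_dir="model",
)
print("Model downloaded to:", path)
```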
### Step 2: Create a model-config.yaml that includes the following

```yaml
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 120
parallelLevel: 4
deviceType: "gpu"
parallelType: "pp" # PiPPy as the solution for distributed inference
torchrun:
    nproc-per-node: 4 # number of processes torchrun starts to serve your model; set to world_size, i.e. the number of GPUs you wish to split your model over
pippy:
    chunks: 1 # sets the microbatch size: microbatch = batch size / chunks
    input_names: ['input_ids'] # input arg names to the model, required for FX tracing
    model_type: "HF" # set the model type to "HF" if you are using a Huggingface model, otherwise leave it blank or set it to the model type you use
    rpc_timeout: 1800

handler:
    max_length: 80 # max length of tokens for the tokenizer in the handler
```
### Step 3: Generate the Tar/MAR file

Navigate back to the `Huggingface_pippy` directory.

```bash
torch-model-archiver --model-name bloom --version 1.0 --handler pippy_handler.py --extra-files model/models--facebook--opt-iml-max-1.3b/snapshots/d60fa58f50def19751da2075791da359ca19d273 -r requirements.txt --config-file model-config.yaml --archive-format tgz
```

### Step 4: Add the MAR file to the model store

```bash
mkdir model_store
mv bloom.mar model_store
```
### Step 5: Start torchserve

Update config.properties and start torchserve

```bash
torchserve --ncs --start --model-store model_store --models bloom.mar
```
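Once the server is up, you can confirm that the model is registered and its workers are healthy before sending traffic; a minimal sketch against TorchServe's management API, assuming the default management port 8081 and that `requests` is installed:

```python
import requests

# Describe the registered model; the response lists its workers and their status.
resp = requests.get("http://localhost:8081/models/bloom")
resp.raise_for_status()
print(resp.json())
```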
### Step 6: Run inference

```bash
curl -v "http://localhost:8080/predictions/bloom" -T sample_text.txt
```
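The same request can be sent from Python instead of curl; a small sketch assuming the default inference port 8080 and the `sample_text.txt` prompt shipped with this example:

```python
import requests

# Read the sample prompt and post it to the TorchServe inference API.
with open("sample_text.txt", "rb") as f:
    prompt = f.read()

resp = requests.post("http://localhost:8080/predictions/bloom", data=prompt)
resp.raise_for_status()
print(resp.text)
```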
examples/large_models/Huggingface_pippy/model-config.yaml

+21
@@ -0,0 +1,21 @@
#frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 120
parallelType: "pp"
deviceType: "gpu"
torchrun:
    nproc-per-node: 4

#backend settings
pippy:
    rpc_timeout: 1800
    model_type: "HF"
    chunks: 1
    input_names: ["input_ids"]
    num_worker_threads: 512

handler:
    max_length: 50
    manual_seed: 40
examples/large_models/Huggingface_pippy/pippy_handler.py

+131
@@ -0,0 +1,131 @@
import logging
import time
from abc import ABC

import requests
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from ts.handler_utils.distributed.pt_pippy import get_pipeline_driver
from ts.torch_handler.distributed.base_pippy_handler import BasePippyHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class TransformersSeqClassifierHandler(BasePippyHandler, ABC):
    """
    Transformers handler class for sequence, token classification and question answering.
    """

    def __init__(self):
        super(TransformersSeqClassifierHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """In this initialize function, the HF large model is loaded and
        partitioned into multiple stages each on one device using PiPPy.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artefacts parameters.
        """
        super().initialize(ctx)
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = self.local_rank

        seed = ctx.model_yaml_config["handler"]["manual_seed"]
        torch.manual_seed(seed)

        self.model = AutoModelForCausalLM.from_pretrained(model_dir, use_cache=False)

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, return_tensors="pt")

        self.max_length = ctx.model_yaml_config["handler"]["max_length"]

        logger.info("Instantiating model Pipeline")
        model_init_start = time.time()
        self.model = get_pipeline_driver(self.model, self.world_size, ctx)

        logger.info("Transformer model from path %s loaded successfully", model_dir)

        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
                containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
                attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            pad_to_max_length=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Predicts the class (or classes) of the received text using the serialized transformers
        checkpoint.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
                of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the predicted values for each input text in the batch.
        """
        input_ids_batch, attention_mask_batch = input_batch
        input_ids_batch = input_ids_batch.to(self.device)
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            max_length=30,
        )

        inferences = [
            self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        ]
        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into Torchserve readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the Predictions and Explanations.
        """
        return inference_output
examples/large_models/Huggingface_pippy/requirements.txt

+2
@@ -0,0 +1,2 @@
transformers
examples/large_models/Huggingface_pippy/sample_text.txt

+1
@@ -0,0 +1 @@
Hey, are you conscious? Can you talk to me?
