diff --git a/benchmarks/README.md b/benchmarks/README.md
index 6f9fbb91cbd..b39f4ff6ca1 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -58,6 +58,12 @@ become available.
     <td style="text-align: center;">✅</td>
     <td><code>AI-MO/aimo-validation-aime</code> , <code>AI-MO/NuminaMath-1.5</code>, <code>AI-MO/NuminaMath-CoT</code></td>
   </tr>
+  <tr>
+    <td><strong>HuggingFace-Unsloth</strong></td>
+    <td style="text-align: center;">✅</td>
+    <td style="text-align: center;">✅</td>
+    <td><code>unsloth/LaTeX_OCR</code>, <code>unsloth/Radiology_mini</code></td>
+  </tr>
   <tr>
     <td><strong>HuggingFace-Other</strong></td>
     <td style="text-align: center;">✅</td>
@@ -251,6 +257,49 @@ python3 vllm/benchmarks/benchmark_serving.py \
   --num-prompts 80
 ```
 
+**`unsloth/LaTeX_OCR`**
+
+``` bash
+# Serve the model
+vllm serve unsloth/Qwen2-VL-2B-Instruct \
+  --dtype bfloat16 \
+  --max-model-len 4096 \
+  --max-num-seqs 5 \
+  --limit-mm-per-prompt "image=1,video=0" \
+  --max-seq-len-to-capture 4096 \
+  --mm-processor-kwargs '{"min_pixels": 784, "max_pixels": 1003520}'
+```
+
+``` bash
+python3 vllm/benchmarks/benchmark_serving.py \
+  --backend openai-chat \
+  --request-rate 5 \
+  --max-concurrency 5 \
+  --model unsloth/Qwen2-VL-2B-Instruct \
+  --endpoint /v1/chat/completions \
+  --dataset-name hf \
+  --dataset-path unsloth/LaTeX_OCR \
+  --hf-split train \
+  --hf-output-len 256 \
+  --num-prompts 1000
+```
+
+**`unsloth/Radiology_mini`**
+
+``` bash
+python3 vllm/benchmarks/benchmark_serving.py \
+  --backend openai-chat \
+  --request-rate 5 \
+  --max-concurrency 5 \
+  --model unsloth/Qwen2-VL-2B-Instruct \
+  --endpoint /v1/chat/completions \
+  --dataset-name hf \
+  --dataset-path unsloth/Radiology_mini \
+  --hf-split train \
+  --hf-output-len 256 \
+  --num-prompts 1000
+```
+
 ### Running With Sampling Parameters
 
 When using OpenAI-compatible backends such as `vllm`, optional sampling
@@ -371,6 +420,30 @@ python3 benchmarks/benchmark_throughput.py \
   --num-prompts 10
 ```
 
+**`unsloth/LaTeX_OCR`**
+
+```bash
+python3 vllm/benchmarks/benchmark_throughput.py \
+  --model unsloth/Qwen2-VL-2B-Instruct \
+  --backend vllm-chat \
+  --dataset-name hf \
+  --dataset-path unsloth/LaTeX_OCR \
+  --hf-split train \
+  --num-prompts 1000
+```
+
+**`unsloth/Radiology_mini`**
+
+```bash
+python3 vllm/benchmarks/benchmark_throughput.py \
+  --model unsloth/Qwen2-VL-2B-Instruct \
+  --backend vllm-chat \
+  --dataset-name hf \
+  --dataset-path unsloth/Radiology_mini \
+  --hf-split train \
+  --num-prompts 1000
+```
+
 ### Benchmark with LoRA Adapters
 
 ``` bash
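For context on the serving examples above, here is a rough sketch (not part of the patch) of the single-request flow that the `openai-chat` backend drives: one dataset image is paired with the dataset's fixed instruction prompt and posted to the OpenAI-compatible endpoint. The base URL `http://localhost:8000/v1`, the `EMPTY` API key, and the base64 data-URL encoding are assumptions for illustration; `benchmark_serving.py` handles all of this internally.

```python
# Sketch only -- assumes the `vllm serve` command from the README example
# above is running on the default port 8000.
import base64
import io

from datasets import load_dataset
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Pull one record; the "image" column decodes to a PIL image.
ds = load_dataset("unsloth/LaTeX_OCR", split="train", streaming=True)
item = next(iter(ds))

# Encode the image as a base64 data URL for the chat-completions payload.
buf = io.BytesIO()
item["image"].save(buf, format="PNG")
image_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

resp = client.chat.completions.create(
    model="unsloth/Qwen2-VL-2B-Instruct",
    max_tokens=256,  # mirrors --hf-output-len 256
    messages=[{
        "role": "user",
        "content": [
            {"type": "text",
             "text": "Write the LaTeX representation for this image."},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }],
)
print(resp.choices[0].message.content)
```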
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 8671719bce7..485b336d941 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -832,6 +832,67 @@ def sample(
         return sampled_requests
 
 
+# -----------------------------------------------------------------------------
+# Unsloth Vision Dataset Implementation
+# -----------------------------------------------------------------------------
+
+class UnslothVisionDataset(HuggingFaceDataset):
+    """
+    Unsloth vision-language datasets: an image paired with a fixed instruction.
+    """
+
+    DEFAULT_OUTPUT_LEN = 256
+    SUPPORTED_DATASET_PATHS = {
+        "unsloth/LaTeX_OCR": lambda x: (
+            "Write the LaTeX representation for this image.",
+            x["image"]
+        ),
+        "unsloth/Radiology_mini": lambda x: (
+            "You are an expert radiographer. "
+            "Describe accurately what you see in this image.",
+            x["image"]
+        ),
+    }
+    IS_MULTIMODAL = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
+        sampled_requests: list[SampleRequest] = []
+
+        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+        if parser_fn is None:
+            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt, mm_content = parser_fn(item)
+            mm_content = process_image(mm_content)
+            prompt_len = len(tokenizer(prompt).input_ids)
+            if enable_multimodal_chat:
+                # Note: when chat is enabled the request prompt_len is no longer
+                # accurate and we will be using request output to count the
+                # actual prompt len
+                prompt = self.apply_multimodal_chat_transformation(
+                    prompt, mm_content)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
 # -----------------------------------------------------------------------------
 # Instruct Coder Dataset Implementation
 # -----------------------------------------------------------------------------
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index f38e45b2611..2d7fbd7af11 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -70,6 +70,7 @@
     SampleRequest,
     ShareGPTDataset,
     SonnetDataset,
+    UnslothVisionDataset,
     VisionArenaDataset,
 )
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
@@ -687,6 +688,10 @@ def main(args: argparse.Namespace):
     elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
         dataset_class = ASRDataset
         args.hf_split = "train"
+    elif args.dataset_path in UnslothVisionDataset.SUPPORTED_DATASET_PATHS:
+        dataset_class = UnslothVisionDataset
+        args.hf_split = "train"
+        args.hf_subset = None
     else:
         supported_datasets = set(
             [
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 0ded34c70ba..3bfb6d923d4 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -25,6 +25,7 @@
     SampleRequest,
     ShareGPTDataset,
     SonnetDataset,
+    UnslothVisionDataset,
     VisionArenaDataset,
 )
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
@@ -373,6 +374,11 @@ def get_requests(args, tokenizer):
         dataset_cls = AIMODataset
         common_kwargs["dataset_subset"] = None
         common_kwargs["dataset_split"] = "train"
+    elif args.dataset_path in UnslothVisionDataset.SUPPORTED_DATASET_PATHS:
+        dataset_cls = UnslothVisionDataset
+        common_kwargs["dataset_subset"] = None
+        common_kwargs["dataset_split"] = args.hf_split
+        sample_kwargs["enable_multimodal_chat"] = True
     else:
         raise ValueError(f"Unknown dataset name: {args.dataset_name}")
     # Remove None values
@@ -527,6 +533,7 @@ def validate_args(args):
     if args.dataset_path in (
         VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
         | ConversationDataset.SUPPORTED_DATASET_PATHS
+        | UnslothVisionDataset.SUPPORTED_DATASET_PATHS.keys()
     ):
         assert args.backend == "vllm-chat", (
             f"{args.dataset_path} needs to use vllm-chat as the backend."
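As a closing illustration, here is a small sketch (again, not part of the patch) of what the parser lambdas in `UnslothVisionDataset.SUPPORTED_DATASET_PATHS` produce for a raw record, and how `sample()` derives `prompt_len` from the text prompt alone. The import assumes the script runs from the `benchmarks/` directory, and the tokenizer repo name is an assumption for illustration.

```python
# Sketch only -- exercises the parser mapping added in benchmark_dataset.py.
from datasets import load_dataset
from transformers import AutoTokenizer

from benchmark_dataset import UnslothVisionDataset  # assumes cwd is benchmarks/

# Each entry maps a raw HF record to (fixed instruction prompt, PIL image).
parser_fn = UnslothVisionDataset.SUPPORTED_DATASET_PATHS["unsloth/Radiology_mini"]

ds = load_dataset("unsloth/Radiology_mini", split="train", streaming=True)
prompt, image = parser_fn(next(iter(ds)))

# prompt_len counts only text tokens, mirroring sample(); image tokens are not
# included, which is why sample() flags prompt_len as inaccurate once the
# multimodal chat transformation is applied.
tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen2-VL-2B-Instruct")
print(prompt)
print(len(tokenizer(prompt).input_ids), image.size)
```

The dict-of-lambdas registry mirrors the existing `VisionArenaDataset` pattern, so supporting another Unsloth-style dataset is a one-entry change.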