diff --git a/doc/source/getting_started/using_xinference.rst b/doc/source/getting_started/using_xinference.rst
index 3401060625..fdb45fdd29 100644
--- a/doc/source/getting_started/using_xinference.rst
+++ b/doc/source/getting_started/using_xinference.rst
@@ -61,4 +61,4 @@ Logging in Xinference
 Xinference supports log rotation of log files. By default, logs rotate when they reach 100MB (maxBytes), and up to 30 backup files (backupCount) are kept.
-All the logs are stored in the ``<XINFERENCE_HOME>/logs`` directory, where ``<XINFERENCE_HOME>`` can be configured as above.
+All the logs are stored in the ``<XINFERENCE_HOME>/logs`` directory, where ``<XINFERENCE_HOME>`` can be configured as above.
\ No newline at end of file
diff --git a/doc/source/models/builtin/index.rst b/doc/source/models/builtin/index.rst
index a4bfb256c8..c3d33e98c9 100644
--- a/doc/source/models/builtin/index.rst
+++ b/doc/source/models/builtin/index.rst
@@ -40,6 +40,7 @@ Chat & Instruction-following Models
 - :ref:`OpenBuddy v11.1 `
 - :ref:`Orca Mini `
 - :ref:`Qwen Chat `
+- :ref:`RWKV-4-Pile <models_builtin_rwkv_4_pile>`
 - :ref:`Vicuna v1.3 `
 - :ref:`Vicuna v1.5 `
 - :ref:`Vicuna v1.5 16k `
diff --git a/doc/source/models/builtin/rwkv-4-pile.rst b/doc/source/models/builtin/rwkv-4-pile.rst
new file mode 100644
index 0000000000..3fb0f8e843
--- /dev/null
+++ b/doc/source/models/builtin/rwkv-4-pile.rst
@@ -0,0 +1,73 @@
+.. _models_builtin_rwkv_4_pile:
+
+=============
+RWKV-4-Pile
+=============
+
+- **Model Name:** rwkv-4-pile
+- **Languages:** en
+- **Abilities:** generate
+- **Description:** The RWKV (Receptance-Weighted Key-Value) models are a series of language models that range in size. They are noteworthy for their large-scale implementation and innovative architecture, which combines the strengths of RNN and Transformer models.
+
+Specifications
+^^^^^^^^^^^^^^
+
+Model Spec 1 (pytorch, 1 Billion)
++++++++++++++++++++++++++++++++++
+
+- **Model Format:** pytorch
+- **Model Size (in billions):** 1
+- **Quantizations:** none
+- **Model ID:** RWKV/rwkv-4-430m-pile
+
+Execute the following command to launch the model::
+
+   xinference launch --model-name rwkv-4-pile --size-in-billions 1 --model-format pytorch --quantization none
+
+Model Spec 2 (pytorch, 2 Billion)
++++++++++++++++++++++++++++++++++
+
+- **Model Format:** pytorch
+- **Model Size (in billions):** 2
+- **Quantizations:** none
+- **Model ID:** RWKV/rwkv-4-1b5-pile
+
+Execute the following command to launch the model::
+
+   xinference launch --model-name rwkv-4-pile --size-in-billions 2 --model-format pytorch --quantization none
+
+Model Spec 3 (pytorch, 3 Billion)
++++++++++++++++++++++++++++++++++
+
+- **Model Format:** pytorch
+- **Model Size (in billions):** 3
+- **Quantizations:** none
+- **Model ID:** RWKV/rwkv-4-3b-pile
+
+Execute the following command to launch the model::
+
+   xinference launch --model-name rwkv-4-pile --size-in-billions 3 --model-format pytorch --quantization none
+
+Model Spec 4 (pytorch, 7 Billion)
++++++++++++++++++++++++++++++++++
+
+- **Model Format:** pytorch
+- **Model Size (in billions):** 7
+- **Quantizations:** none
+- **Model ID:** RWKV/rwkv-4-7b-pile
+
+Execute the following command to launch the model::
+
+   xinference launch --model-name rwkv-4-pile --size-in-billions 7 --model-format pytorch --quantization none
+
+Model Spec 5 (pytorch, 14 Billion)
++++++++++++++++++++++++++++++++++++
+
+- **Model Format:** pytorch
+- **Model Size (in billions):** 14
+- **Quantizations:** none
+- **Model ID:** RWKV/rwkv-4-14b-pile
+
+Execute the following command to launch the model::
+
+   xinference launch --model-name rwkv-4-pile 
--size-in-billions 14 --model-format pytorch --quantization none diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index 6271c9bec0..71c2200a16 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -44,6 +44,7 @@ def _install(): from .pytorch.core import PytorchChatModel, PytorchModel from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel + from .pytorch.rwkv import RWKVPilePytorchModel from .pytorch.vicuna import VicunaPytorchChatModel from .vllm.core import VLLMChatModel, VLLMModel @@ -68,6 +69,7 @@ def _install(): LLM_CLASSES.extend( [ BaichuanPytorchChatModel, + RWKVPilePytorchModel, VicunaPytorchChatModel, FalconPytorchChatModel, ChatglmPytorchChatModel, diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index a2b3c4234b..b327c9c3b4 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2018,6 +2018,65 @@ } ] }, + { + "version": 1, + "context_length": 2048, + "model_name": "rwkv-4-pile", + "model_lang": [ + "en" + ], + "model_ability": [ + "generate" + ], + "model_description": "The RWKV-4-Pile model is a L12-D768 causal language model trained on the Pile.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 1, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-430m-pile", + "model_revision": "6f7bcd0e5b0851cc00710a5c7d2ef5f5f24363a5" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 2, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-1b5-pile", + "model_revision": "643585471eaf5821d94dfcb498ab5b94a36b42cf" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 3, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-3b-pile", + "model_revision": "7fdda3c5570d4a9711f8f02cc3a20941a5623cd3" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-7b-pile", + "model_revision": "922e22a761427e50d7be457b31a76b1126021b8b" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 14, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-14b-pile", + "model_revision": "4effb0fa9d15c2f383a1d159f4a40df0e09eb6d5" + } + ] + }, { "version": 1, "context_length": 2048, diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 2537250c13..040f664963 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -228,6 +228,7 @@ def match( "chatglm2-32k", "llama-2", "llama-2-chat", + "rwkv-4-pile", ]: return False if "generate" not in llm_family.model_ability: @@ -241,6 +242,7 @@ def generate( generate_stream, generate_stream_chatglm, generate_stream_falcon, + generate_stream_rwkv, ) def generator_wrapper( @@ -266,6 +268,11 @@ def generator_wrapper( generate_config, ): yield completion_chunk + elif "rwkv" in self.model_family.model_name: + for completion_chunk, _ in generate_stream_rwkv( + self._model, self._tokenizer, prompt, self._device, generate_config + ): + yield completion_chunk else: for completion_chunk, _ in generate_stream( self.model_uid, @@ -308,6 +315,11 @@ def generator_wrapper( generate_config, ): pass + elif "rwkv" in self.model_family.model_name: + for completion_chunk, completion_usage in generate_stream_rwkv( + self._model, self._tokenizer, prompt, self._device, generate_config + ): + pass else: for 
completion_chunk, completion_usage in generate_stream( self.model_uid, diff --git a/xinference/model/llm/pytorch/rwkv.py b/xinference/model/llm/pytorch/rwkv.py new file mode 100644 index 0000000000..74fdec704f --- /dev/null +++ b/xinference/model/llm/pytorch/rwkv.py @@ -0,0 +1,77 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from ..llm_family import LLMFamilyV1, LLMSpecV1 +from .core import PytorchModel, PytorchModelConfig + + +class RWKVPilePytorchModel(PytorchModel): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + pytorch_model_config: Optional[PytorchModelConfig] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + pytorch_model_config=pytorch_model_config, + ) + + def _load_model(self, **kwargs): + try: + from transformers import AutoTokenizer, RwkvForCausalLM + except ImportError: + error_message = "Failed to import module 'transformers'" + installation_guide = [ + "Please make sure 'transformers' is installed. ", + "You can install it by `pip install transformers`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=kwargs["trust_remote_code"], + revision=kwargs["revision"], + ) + model = RwkvForCausalLM.from_pretrained( + self.model_path, + low_cpu_mem_usage=True, + **kwargs, + ) + tokenizer.pad_token_id = 9 + return model, tokenizer + + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format != "pytorch": + return False + if "rwkv" not in llm_family.model_name: + return False + if "pile" not in llm_family.model_name: + return False + if "generate" not in llm_family.model_ability: + return False + return True diff --git a/xinference/model/llm/pytorch/utils.py b/xinference/model/llm/pytorch/utils.py index 917e87c98a..0d2252d2ab 100644 --- a/xinference/model/llm/pytorch/utils.py +++ b/xinference/model/llm/pytorch/utils.py @@ -596,3 +596,242 @@ def generate_stream_chatglm( ) yield completion_chunk, completion_usage + + +@torch.inference_mode() +def generate_stream_rwkv( + model, + tokenizer, + prompt, + device, + generate_config, + judge_sent_end=False, +) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]: + context_len = get_context_length(model.config) + stream_interval = generate_config.get("stream_interval", 2) + stream = generate_config.get("stream", False) + + len_prompt = len(prompt) + + temperature = float(generate_config.get("temperature", 1.0)) + repetition_penalty = float(generate_config.get("repetition_penalty", 1.0)) + top_p = float(generate_config.get("top_p", 1.0)) + top_k = int(generate_config.get("top_k", -1)) # -1 means disable + max_new_tokens = int(generate_config.get("max_tokens", 256)) + echo = bool(generate_config.get("echo", False)) + stop_str = 
generate_config.get("stop", None) + stop_token_ids = generate_config.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + + if "qwen" in str(type(model)).lower(): + # TODO: hacky + input_ids = tokenizer(prompt, allowed_special="all").input_ids + else: + input_ids = tokenizer(prompt).input_ids + output_ids = list(input_ids) + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + + # RWKV generates using previous states of the model instead of previous tokens + state = out = None + sent_interrupt = False + token = None + last_output_length = 0 + for i in range(max_new_tokens): + if i == 0: + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) + logits = out.logits + state = out.state + else: + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], device=device + ), + encoder_hidden_states=encoder_output, + use_cache=True, + state=state if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], device=device + ), + use_cache=True, + state=state if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + state = out.state + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. 
+ last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + if stream: + tmp_output_length = len(output) + output = output[last_output_length:] + last_output_length = tmp_output_length + + # prevent yielding partial stop sequence + if not partially_stopped: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=None + ) + completion_chunk = CompletionChunk( + id=str(uuid.uuid1()), + object="text_completion", + created=int(time.time()), + model=generate_config["model"], + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + + yield completion_chunk, completion_usage + + if stopped: + break + + # finish stream event, which contains finish reason + if stopped: + finish_reason = "stop" + elif i == max_new_tokens - 1: + finish_reason = "length" + else: + finish_reason = None + + if stream: + completion_choice = CompletionChoice( + text="", index=0, logprobs=None, finish_reason=finish_reason + ) + else: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=finish_reason + ) + + completion_chunk = CompletionChunk( + id=str(uuid.uuid1()), + object="text_completion", + created=int(time.time()), + model=generate_config["model"], + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + + yield completion_chunk, completion_usage + + # clean + del state, out + gc.collect() + torch.cuda.empty_cache()
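
For reference, ``generate_stream_rwkv`` relies on RWKV's recurrent state: after the prompt has been consumed once, each later step feeds the model only the newest token plus the saved ``state`` rather than the full token history. A minimal, illustrative sketch of that pattern written directly against the Hugging Face ``transformers`` RWKV API (greedy decoding, an arbitrary prompt and token budget, and the smallest checkpoint from this family as an example; this is not the code path Xinference itself runs) could look like::

   import torch
   from transformers import AutoTokenizer, RwkvForCausalLM

   model_id = "RWKV/rwkv-4-430m-pile"  # smallest spec registered by this patch
   tokenizer = AutoTokenizer.from_pretrained(model_id)
   model = RwkvForCausalLM.from_pretrained(model_id)

   prompt = "In a shocking finding, scientists discovered"  # example prompt
   input_ids = tokenizer(prompt, return_tensors="pt").input_ids

   state = None
   generated = input_ids[0].tolist()
   with torch.inference_mode():
       for _ in range(32):
           # First step consumes the whole prompt; later steps feed only the
           # newest token together with the recurrent state, mirroring the
           # `state = out.state` bookkeeping in generate_stream_rwkv.
           step_input = input_ids if state is None else torch.tensor([[generated[-1]]])
           out = model(input_ids=step_input, state=state, use_cache=True)
           state = out.state
           next_token = int(out.logits[0, -1].argmax())
           generated.append(next_token)
           if next_token == tokenizer.eos_token_id:  # stop at end-of-text
               break

   print(tokenizer.decode(generated, skip_special_tokens=True))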