FEAT: Support RWKV Pile #535

Draft: wants to merge 7 commits into main.

2 changes: 1 addition & 1 deletion doc/source/getting_started/using_xinference.rst
@@ -61,4 +61,4 @@ Logging in Xinference
Xinference supports log rotation of log files. By default, logs rotate when they reach 100MB (maxBytes),
and up to 30 backup files (backupCount) are kept.

-All the logs are stored in the ``<XINFERENCE_HOME>/logs`` directory, where ``<XINFERENCE_HOME>`` can be configured as above.
+All the logs are stored in the ``<XINFERENCE_HOME>/logs`` directory, where ``<XINFERENCE_HOME>`` can be configured as above.
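
This rotation behavior matches Python's standard ``logging.handlers.RotatingFileHandler``. A minimal sketch with the documented defaults (not the project's actual logging setup; the file name is illustrative):

    import logging
    from logging.handlers import RotatingFileHandler

    # Rotate at 100MB (maxBytes) and keep up to 30 backups (backupCount),
    # mirroring the defaults documented above.
    handler = RotatingFileHandler(
        "xinference.log", maxBytes=100 * 1024 * 1024, backupCount=30
    )
    logging.getLogger("xinference").addHandler(handler)
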
1 change: 1 addition & 0 deletions doc/source/models/builtin/index.rst
@@ -40,6 +40,7 @@ Chat & Instruction-following Models
- :ref:`OpenBuddy v11.1 <models_builtin_openbuddy_v11.1>`
- :ref:`Orca Mini <models_builtin_orca_mini>`
- :ref:`Qwen Chat <models_builtin_qwen_chat>`
- :ref:`RWKV-4-Pile <models_builtin_rwkv_4_pile>`
- :ref:`Vicuna v1.3 <models_builtin_vicuna_v1_3>`
- :ref:`Vicuna v1.5 <models_builtin_vicuna_v1_5>`
- :ref:`Vicuna v1.5 16k <models_builtin_vicuna_v1_5_16k>`
73 changes: 73 additions & 0 deletions doc/source/models/builtin/rwkv-4-pile.rst
@@ -0,0 +1,73 @@
.. _models_builtin_rwkv_4_pile:

=============
RWKV-4-Pile
=============

- **Model Name:** rwkv-4-pile
- **Languages:** en
- **Abilities:** generate
- **Description:** The RWKV (Receptance Weighted Key Value) models are a series of language models that range in size. They are notable for an innovative architecture that combines the strengths of RNN and Transformer models.

Specifications
^^^^^^^^^^^^^^

Model Spec 1 (pytorch, 1 Billion)
+++++++++++++++++++++++++++++++++

- **Model Format:** pytorch
- **Model Size (in billions):** 1
- **Quantizations:** none
- **Model ID:** RWKV/rwkv-4-430m-pile

Execute the following command to launch the model::

   xinference launch --model-name rwkv-4-pile --size-in-billions 1 --model-format pytorch --quantization none

Model Spec 2 (pytorch, 2 Billion)
+++++++++++++++++++++++++++++++++

- **Model Format:** pytorch
- **Model Size (in billions):** 2
- **Quantizations:** none
- **Model ID:** RWKV/rwkv-4-1b5-pile

Execute the following command to launch the model::

   xinference launch --model-name rwkv-4-pile --size-in-billions 2 --model-format pytorch --quantization none

Model Spec 3 (pytorch, 3 Billion)
+++++++++++++++++++++++++++++++++

- **Model Format:** pytorch
- **Model Size (in billions):** 3
- **Quantizations:** none
- **Model ID:** RWKV/rwkv-4-3b-pile

Execute the following command to launch the model::

   xinference launch --model-name rwkv-4-pile --size-in-billions 3 --model-format pytorch --quantization none

Model Spec 4 (pytorch, 7 Billion)
+++++++++++++++++++++++++++++++++

- **Model Format:** pytorch
- **Model Size (in billions):** 7
- **Quantizations:** none
- **Model ID:** RWKV/rwkv-4-7b-pile

Execute the following command to launch the model::

   xinference launch --model-name rwkv-4-pile --size-in-billions 7 --model-format pytorch --quantization none

Model Spec 5 (pytorch, 14 Billion)
++++++++++++++++++++++++++++++++++

- **Model Format:** pytorch
- **Model Size (in billions):** 14
- **Quantizations:** none
- **Model ID:** RWKV/rwkv-4-14b-pile

Execute the following command to launch the model::

   xinference launch --model-name rwkv-4-pile --size-in-billions 14 --model-format pytorch --quantization none
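
Once launched, the model can also be driven from the Python client. A minimal sketch, assuming the ``xinference.client`` API and a supervisor at the default local endpoint::

   from xinference.client import Client

   client = Client("http://localhost:9997")  # assumed default endpoint
   model_uid = client.launch_model(
       model_name="rwkv-4-pile",
       model_format="pytorch",
       model_size_in_billions=1,
       quantization="none",
   )
   model = client.get_model(model_uid)
   completion = model.generate("The Pile is", generate_config={"max_tokens": 64})
   print(completion)
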
2 changes: 2 additions & 0 deletions xinference/model/llm/__init__.py
@@ -44,6 +44,7 @@ def _install():
    from .pytorch.core import PytorchChatModel, PytorchModel
    from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
    from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
    from .pytorch.rwkv import RWKVPilePytorchModel
    from .pytorch.vicuna import VicunaPytorchChatModel
    from .vllm.core import VLLMChatModel, VLLMModel

@@ -68,6 +69,7 @@ def _install():
    LLM_CLASSES.extend(
        [
            BaichuanPytorchChatModel,
            RWKVPilePytorchModel,
            VicunaPytorchChatModel,
            FalconPytorchChatModel,
            ChatglmPytorchChatModel,
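
Classes appended to ``LLM_CLASSES`` are selected through their ``match`` classmethods. A simplified sketch of how such a registry can resolve a class for a launch request (illustrative only; the code base's actual selection logic lives elsewhere):

    # Hypothetical resolver over the registry populated by _install() above.
    LLM_CLASSES: list = []  # in the real module this is filled by _install()

    def resolve_llm_class(llm_family, llm_spec, quantization):
        for cls in LLM_CLASSES:
            if cls.match(llm_family, llm_spec, quantization):
                return cls
        raise ValueError(f"no LLM class matches {llm_family.model_name}")
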
59 changes: 59 additions & 0 deletions xinference/model/llm/llm_family.json
@@ -2018,6 +2018,65 @@
      }
    ]
  },
  {
    "version": 1,
    "context_length": 2048,
    "model_name": "rwkv-4-pile",
    "model_lang": [
      "en"
    ],
    "model_ability": [
      "generate"
    ],
    "model_description": "The RWKV-4-Pile models are causal language models trained on the Pile, available in several sizes.",
    "model_specs": [
      {
        "model_format": "pytorch",
        "model_size_in_billions": 1,
        "quantizations": [
          "none"
        ],
        "model_id": "RWKV/rwkv-4-430m-pile",
        "model_revision": "6f7bcd0e5b0851cc00710a5c7d2ef5f5f24363a5"
      },
      {
        "model_format": "pytorch",
        "model_size_in_billions": 2,
        "quantizations": [
          "none"
        ],
        "model_id": "RWKV/rwkv-4-1b5-pile",
        "model_revision": "643585471eaf5821d94dfcb498ab5b94a36b42cf"
      },
      {
        "model_format": "pytorch",
        "model_size_in_billions": 3,
        "quantizations": [
          "none"
        ],
        "model_id": "RWKV/rwkv-4-3b-pile",
        "model_revision": "7fdda3c5570d4a9711f8f02cc3a20941a5623cd3"
      },
      {
        "model_format": "pytorch",
        "model_size_in_billions": 7,
        "quantizations": [
          "none"
        ],
        "model_id": "RWKV/rwkv-4-7b-pile",
        "model_revision": "922e22a761427e50d7be457b31a76b1126021b8b"
      },
      {
        "model_format": "pytorch",
        "model_size_in_billions": 14,
        "quantizations": [
          "none"
        ],
        "model_id": "RWKV/rwkv-4-14b-pile",
        "model_revision": "4effb0fa9d15c2f383a1d159f4a40df0e09eb6d5"
      }
    ]
  },
  {
    "version": 1,
    "context_length": 2048,
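
Each family entry maps a launch request (name, size, format) to a pinned Hugging Face repository and revision. A rough sketch of that lookup over the parsed JSON (illustrative; not the code base's actual loader):

    import json

    def find_model_id(families, model_name, size_in_billions, model_format):
        # Resolve a launch request to a pinned (model_id, model_revision) pair.
        for family in families:
            if family["model_name"] != model_name:
                continue
            for spec in family["model_specs"]:
                if (spec["model_format"] == model_format
                        and spec["model_size_in_billions"] == size_in_billions):
                    return spec["model_id"], spec["model_revision"]
        raise KeyError(f"{model_name} ({size_in_billions}B, {model_format}) not found")

    with open("llm_family.json") as f:
        families = json.load(f)
    print(find_model_id(families, "rwkv-4-pile", 1, "pytorch"))
    # -> ('RWKV/rwkv-4-430m-pile', '6f7bcd0e5b0851cc00710a5c7d2ef5f5f24363a5')
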
12 changes: 12 additions & 0 deletions xinference/model/llm/pytorch/core.py
@@ -228,6 +228,7 @@ def match(
"chatglm2-32k",
"llama-2",
"llama-2-chat",
"rwkv-4-pile",
]:
return False
if "generate" not in llm_family.model_ability:
@@ -241,6 +242,7 @@ def generate(
            generate_stream,
            generate_stream_chatglm,
            generate_stream_falcon,
            generate_stream_rwkv,
        )

        def generator_wrapper(
@@ -266,6 +268,11 @@ def generator_wrapper(
                    generate_config,
                ):
                    yield completion_chunk
            elif "rwkv" in self.model_family.model_name:
                for completion_chunk, _ in generate_stream_rwkv(
                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    yield completion_chunk
            else:
                for completion_chunk, _ in generate_stream(
                    self.model_uid,
@@ -308,6 +315,11 @@ def generator_wrapper(
                    generate_config,
                ):
                    pass
            elif "rwkv" in self.model_family.model_name:
                for completion_chunk, completion_usage in generate_stream_rwkv(
                    self._model, self._tokenizer, prompt, self._device, generate_config
                ):
                    pass
            else:
                for completion_chunk, completion_usage in generate_stream(
                    self.model_uid,
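
``generate_stream_rwkv`` is imported from the pytorch utils module, which this diff does not show. A minimal sketch of such a streaming generator, assuming the Hugging Face ``RwkvForCausalLM`` interface, where each forward call returns a recurrent ``state`` that is threaded into the next call (the PR's real implementation may differ):

    import torch

    @torch.inference_mode()
    def generate_stream_rwkv_sketch(model, tokenizer, prompt, device, generate_config):
        # Illustrative streaming loop, not the PR's actual generate_stream_rwkv.
        max_new_tokens = generate_config.get("max_tokens", 128)
        temperature = generate_config.get("temperature", 1.0)

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        state, text = None, ""
        for step in range(max_new_tokens):
            out = model(input_ids=input_ids, state=state, use_cache=True)
            state = out.state  # RWKV threads an RNN state instead of a KV cache
            probs = torch.softmax(out.logits[0, -1] / max(temperature, 1e-5), dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            text += tokenizer.decode(next_id)
            input_ids = next_id.unsqueeze(0)  # only the new token is fed next step
            yield text, step + 1  # (partial completion text, tokens generated so far)
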
77 changes: 77 additions & 0 deletions xinference/model/llm/pytorch/rwkv.py
@@ -0,0 +1,77 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from ..llm_family import LLMFamilyV1, LLMSpecV1
from .core import PytorchModel, PytorchModelConfig


class RWKVPilePytorchModel(PytorchModel):
    def __init__(
        self,
        model_uid: str,
        model_family: "LLMFamilyV1",
        model_spec: "LLMSpecV1",
        quantization: str,
        model_path: str,
        pytorch_model_config: Optional[PytorchModelConfig] = None,
    ):
        super().__init__(
            model_uid,
            model_family,
            model_spec,
            quantization,
            model_path,
            pytorch_model_config=pytorch_model_config,
        )

    def _load_model(self, **kwargs):
        try:
            from transformers import AutoTokenizer, RwkvForCausalLM
        except ImportError:
            error_message = "Failed to import module 'transformers'"
            installation_guide = [
                "Please make sure 'transformers' is installed. ",
                "You can install it by `pip install transformers`\n",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=kwargs["trust_remote_code"],
            revision=kwargs["revision"],
        )
        model = RwkvForCausalLM.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        # The RWKV tokenizer defines no pad token by default; pin a fixed
        # existing token id so that padding is well-defined.
        tokenizer.pad_token_id = 9
        return model, tokenizer

    @classmethod
    def match(
        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        if llm_spec.model_format != "pytorch":
            return False
        if "rwkv" not in llm_family.model_name:
            return False
        if "pile" not in llm_family.model_name:
            return False
        if "generate" not in llm_family.model_ability:
            return False
        return True