From c95da95da3e733475932969ab163bd0188ff4be1 Mon Sep 17 00:00:00 2001 From: Bojun Feng Date: Fri, 13 Oct 2023 03:17:43 -0500 Subject: [PATCH 1/6] add and test 169m --- xinference/model/llm/__init__.py | 2 + xinference/model/llm/llm_family.json | 23 +++ xinference/model/llm/pytorch/core.py | 12 ++ xinference/model/llm/pytorch/rwkv.py | 78 +++++++++ xinference/model/llm/pytorch/utils.py | 238 ++++++++++++++++++++++++++ 5 files changed, 353 insertions(+) create mode 100644 xinference/model/llm/pytorch/rwkv.py diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index abb44caaa8..a1d2fe3417 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -44,6 +44,7 @@ def _install(): from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel from .pytorch.vicuna import VicunaPytorchChatModel + from .pytorch.rwkv import RWKVPilePytorchModel from .vllm.core import VLLMChatModel, VLLMModel # register llm classes. @@ -67,6 +68,7 @@ def _install(): LLM_CLASSES.extend( [ BaichuanPytorchChatModel, + RWKVPilePytorchModel, VicunaPytorchChatModel, FalconPytorchChatModel, ChatglmPytorchChatModel, diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 957585f18c..cf7a391ffc 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -1756,6 +1756,29 @@ } ] }, + { + "version": 1, + "context_length": 2048, + "model_name": "rwkv-4-pile", + "model_lang": [ + "en" + ], + "model_ability": [ + "generate" + ], + "model_description": "The RWKV-4-Pile model is a L12-D768 causal language model trained on the Pile.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 1, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-169m-pile", + "model_revision": "46bdc280eb97b6141d5d51a935e0c4870ecaefcc" + } + ] + }, { "version": 1, "context_length": 2048, diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 99af2afb66..90c279fa64 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -251,6 +251,7 @@ def match( "chatglm2-32k", "llama-2", "llama-2-chat", + "rwkv-4-pile", ]: return False if "generate" not in llm_family.model_ability: @@ -264,6 +265,7 @@ def generate( generate_stream, generate_stream_chatglm, generate_stream_falcon, + generate_stream_rwkv, ) def generator_wrapper( @@ -279,6 +281,11 @@ def generator_wrapper( self._model, self._tokenizer, prompt, self._device, generate_config ): yield completion_chunk + elif "rwkv" in self.model_family.model_name: + for completion_chunk, _ in generate_stream_rwkv( + self._model, self._tokenizer, prompt, self._device, generate_config + ): + yield completion_chunk else: for completion_chunk, _ in generate_stream( self._model, self._tokenizer, prompt, self._device, generate_config @@ -306,6 +313,11 @@ def generator_wrapper( self._model, self._tokenizer, prompt, self._device, generate_config ): pass + elif "rwkv" in self.model_family.model_name: + for completion_chunk, completion_usage in generate_stream_rwkv( + self._model, self._tokenizer, prompt, self._device, generate_config + ): + pass else: for completion_chunk, completion_usage in generate_stream( self._model, self._tokenizer, prompt, self._device, generate_config diff --git a/xinference/model/llm/pytorch/rwkv.py b/xinference/model/llm/pytorch/rwkv.py new file mode 100644 index 0000000000..c436d2808f 
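The new xinference/model/llm/pytorch/rwkv.py introduced below wraps a Hugging Face RWKV checkpoint behind PytorchModel, loading it with RwkvForCausalLM and an AutoTokenizer. For reference, loading the same checkpoint directly with transformers looks roughly like the sketch that follows; it is a minimal illustration, not part of the patch. The model id matches the spec above, the pad token id mirrors what the new loader sets, and the prompt and generation length are illustrative. It assumes a transformers release that ships RwkvForCausalLM.

    # Minimal sketch: load an RWKV "pile" checkpoint directly with transformers.
    from transformers import AutoTokenizer, RwkvForCausalLM

    model_id = "RWKV/rwkv-4-169m-pile"          # same id as the spec above
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = RwkvForCausalLM.from_pretrained(model_id)
    tokenizer.pad_token_id = 9                   # same padding id the new loader sets

    inputs = tokenizer("The Pile is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))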
--- /dev/null +++ b/xinference/model/llm/pytorch/rwkv.py @@ -0,0 +1,78 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from ..llm_family import LLMFamilyV1, LLMSpecV1 +from .core import PytorchModel, PytorchModelConfig + + +class RWKVPilePytorchModel(PytorchModel): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + pytorch_model_config: Optional[PytorchModelConfig] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + pytorch_model_config=pytorch_model_config, + ) + + def _load_model(self, kwargs: dict): + try: + from transformers import RwkvForCausalLM, AutoTokenizer + except ImportError: + error_message = "Failed to import module 'transformers'" + installation_guide = [ + "Please make sure 'transformers' is installed. ", + "You can install it by `pip install transformers`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=kwargs["trust_remote_code"], + revision=kwargs["revision"], + ) + model = RwkvForCausalLM.from_pretrained( + self.model_path, + low_cpu_mem_usage=True, + **kwargs, + ) + tokenizer.pad_token_id = 9 + return model, tokenizer + + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format != "pytorch": + return False + if "rwkv" not in llm_family.model_name: + return False + if "pile" not in llm_family.model_name: + return False + if "generate" not in llm_family.model_ability: + return False + return True + diff --git a/xinference/model/llm/pytorch/utils.py b/xinference/model/llm/pytorch/utils.py index 3dd03f13f8..51e0bf1686 100644 --- a/xinference/model/llm/pytorch/utils.py +++ b/xinference/model/llm/pytorch/utils.py @@ -586,3 +586,241 @@ def generate_stream_chatglm( ) yield completion_chunk, completion_usage + +@torch.inference_mode() +def generate_stream_rwkv( + model, + tokenizer, + prompt, + device, + generate_config, + judge_sent_end=False, +) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]: + context_len = get_context_length(model.config) + stream_interval = generate_config.get("stream_interval", 2) + stream = generate_config.get("stream", False) + + len_prompt = len(prompt) + + temperature = float(generate_config.get("temperature", 1.0)) + repetition_penalty = float(generate_config.get("repetition_penalty", 1.0)) + top_p = float(generate_config.get("top_p", 1.0)) + top_k = int(generate_config.get("top_k", -1)) # -1 means disable + max_new_tokens = int(generate_config.get("max_tokens", 256)) + echo = bool(generate_config.get("echo", False)) + stop_str = generate_config.get("stop", None) + stop_token_ids = generate_config.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + 
temperature, repetition_penalty, top_p, top_k + ) + + if "qwen" in str(type(model)).lower(): + # TODO: hacky + input_ids = tokenizer(prompt, allowed_special="all").input_ids + else: + input_ids = tokenizer(prompt).input_ids + output_ids = list(input_ids) + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + + # RWKV generates using previous states of the model instead of previous tokens + state = out = None + sent_interrupt = False + token = None + last_output_length = 0 + for i in range(max_new_tokens): + if i == 0: + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) + logits = out.logits + state = out.state + else: + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], device=device + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], device=device + ), + use_cache=True, + state=state if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + state = out.state + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. 
+ last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + if stream: + tmp_output_length = len(output) + output = output[last_output_length:] + last_output_length = tmp_output_length + + # prevent yielding partial stop sequence + if not partially_stopped: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=None + ) + completion_chunk = CompletionChunk( + id=str(uuid.uuid1()), + object="text_completion", + created=int(time.time()), + model=generate_config["model"], + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + + yield completion_chunk, completion_usage + + if stopped: + break + + # finish stream event, which contains finish reason + if stopped: + finish_reason = "stop" + elif i == max_new_tokens - 1: + finish_reason = "length" + else: + finish_reason = None + + if stream: + completion_choice = CompletionChoice( + text="", index=0, logprobs=None, finish_reason=finish_reason + ) + else: + completion_choice = CompletionChoice( + text=output, index=0, logprobs=None, finish_reason=finish_reason + ) + + completion_chunk = CompletionChunk( + id=str(uuid.uuid1()), + object="text_completion", + created=int(time.time()), + model=generate_config["model"], + choices=[completion_choice], + ) + completion_usage = CompletionUsage( + prompt_tokens=input_echo_len, + completion_tokens=i, + total_tokens=(input_echo_len + i), + ) + + yield completion_chunk, completion_usage + + # clean + del state, out + gc.collect() + torch.cuda.empty_cache() From 7eb6c0ad91aa1c588d3adea75d1037ed81686cfd Mon Sep 17 00:00:00 2001 From: Bojun Feng Date: Fri, 13 Oct 2023 03:36:17 -0500 Subject: [PATCH 2/6] update sizes --- xinference/model/llm/llm_family.json | 40 
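The generate_stream_rwkv helper above follows the pattern noted in its comment: RWKV is an RNN-style model that carries a recurrent state instead of a key-value cache, so after the prompt is processed once, each decoding step feeds only the newest token together with the state returned by the previous forward pass. Stripped of streaming, stop handling, and sampling, the core loop looks roughly like the sketch below; it is illustrative only (greedy decoding, toy prompt, fixed step count) and assumes transformers' RwkvForCausalLM, whose forward accepts a state argument and returns an updated state when use_cache=True, exactly as the patched code relies on.

    # Sketch of state-carrying RWKV decoding (greedy, no streaming or stop logic).
    import torch
    from transformers import AutoTokenizer, RwkvForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
    model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

    with torch.inference_mode():
        input_ids = tokenizer("The Pile is", return_tensors="pt").input_ids
        out = model(input_ids, use_cache=True)        # prime the state on the full prompt
        state = out.state
        token = int(out.logits[0, -1].argmax())
        generated = [token]

        for _ in range(32):
            # Feed only the newest token; the recurrent state replaces past_key_values.
            out = model(torch.tensor([[token]]), use_cache=True, state=state)
            state = out.state
            token = int(out.logits[0, -1].argmax())
            generated.append(token)

    print(tokenizer.decode(generated, skip_special_tokens=True))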
+++++++++++++++++++++++++-- xinference/model/llm/pytorch/utils.py | 2 +- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index cf7a391ffc..62315b363a 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -1774,8 +1774,44 @@ "quantizations": [ "none" ], - "model_id": "RWKV/rwkv-4-169m-pile", - "model_revision": "46bdc280eb97b6141d5d51a935e0c4870ecaefcc" + "model_id": "RWKV/rwkv-4-430m-pile", + "model_revision": "a4f6ec80438d4262d1bbc8f385feb2ef1a4a9d6b" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 2, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-1b5-pile", + "model_revision": "643585471eaf5821d94dfcb498ab5b94a36b42cf" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 3, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-3b-pile", + "model_revision": "7fdda3c5570d4a9711f8f02cc3a20941a5623cd3" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-7b-pile", + "model_revision": "922e22a761427e50d7be457b31a76b1126021b8b" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 14, + "quantizations": [ + "none" + ], + "model_id": "RWKV/rwkv-4-14b-pile", + "model_revision": "4effb0fa9d15c2f383a1d159f4a40df0e09eb6d5" } ] }, diff --git a/xinference/model/llm/pytorch/utils.py b/xinference/model/llm/pytorch/utils.py index 51e0bf1686..c8545e2e40 100644 --- a/xinference/model/llm/pytorch/utils.py +++ b/xinference/model/llm/pytorch/utils.py @@ -667,7 +667,7 @@ def generate_stream_rwkv( ), encoder_hidden_states=encoder_output, use_cache=True, - past_key_values=past_key_values if not sent_interrupt else None, + state=state if not sent_interrupt else None, ) sent_interrupt = False From 2c2dba886066378f03a4bf75de27bf500ad36d08 Mon Sep 17 00:00:00 2001 From: Bojun Feng Date: Fri, 13 Oct 2023 03:55:58 -0500 Subject: [PATCH 3/6] fix pre-commit formatting --- doc/source/getting_started/using_xinference.rst | 5 ++--- xinference/model/llm/__init__.py | 2 +- xinference/model/llm/pytorch/rwkv.py | 3 +-- xinference/model/llm/pytorch/utils.py | 1 + 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/doc/source/getting_started/using_xinference.rst b/doc/source/getting_started/using_xinference.rst index bf7510603a..c6d5bc7377 100644 --- a/doc/source/getting_started/using_xinference.rst +++ b/doc/source/getting_started/using_xinference.rst @@ -27,7 +27,7 @@ Starting the Supervisor On the server where you want to run the Xinference supervisor, run the following command: .. code-block:: bash - + xinference-supervisor -H "${supervisor_host}" Replace ${supervisor_host} with the actual host of your supervisor server. @@ -38,8 +38,7 @@ Starting the Workers On each of the other servers where you want to run Xinference workers, run the following command: .. code-block:: bash - + xinference-worker -e "http://${supervisor_host}:9997" Once Xinference is running, an endpoint will be accessible for model management via CLI or Xinference client. 
- diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index a1d2fe3417..12a0424d07 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -43,8 +43,8 @@ def _install(): from .pytorch.core import PytorchChatModel, PytorchModel from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel - from .pytorch.vicuna import VicunaPytorchChatModel from .pytorch.rwkv import RWKVPilePytorchModel + from .pytorch.vicuna import VicunaPytorchChatModel from .vllm.core import VLLMChatModel, VLLMModel # register llm classes. diff --git a/xinference/model/llm/pytorch/rwkv.py b/xinference/model/llm/pytorch/rwkv.py index c436d2808f..7a43aef873 100644 --- a/xinference/model/llm/pytorch/rwkv.py +++ b/xinference/model/llm/pytorch/rwkv.py @@ -39,7 +39,7 @@ def __init__( def _load_model(self, kwargs: dict): try: - from transformers import RwkvForCausalLM, AutoTokenizer + from transformers import AutoTokenizer, RwkvForCausalLM except ImportError: error_message = "Failed to import module 'transformers'" installation_guide = [ @@ -75,4 +75,3 @@ def match( if "generate" not in llm_family.model_ability: return False return True - diff --git a/xinference/model/llm/pytorch/utils.py b/xinference/model/llm/pytorch/utils.py index c8545e2e40..722c73d112 100644 --- a/xinference/model/llm/pytorch/utils.py +++ b/xinference/model/llm/pytorch/utils.py @@ -587,6 +587,7 @@ def generate_stream_chatglm( yield completion_chunk, completion_usage + @torch.inference_mode() def generate_stream_rwkv( model, From 2795f7de33a03df9f1a6babb885f114bc8186b5f Mon Sep 17 00:00:00 2001 From: Bojun Feng Date: Fri, 13 Oct 2023 16:12:47 -0500 Subject: [PATCH 4/6] minor --- xinference/model/llm/llm_family.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 62315b363a..aad0c093cc 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -1775,7 +1775,7 @@ "none" ], "model_id": "RWKV/rwkv-4-430m-pile", - "model_revision": "a4f6ec80438d4262d1bbc8f385feb2ef1a4a9d6b" + "model_revision": "6f7bcd0e5b0851cc00710a5c7d2ef5f5f24363a5" }, { "model_format": "pytorch", From e90d560c1dd918d9c3ca7585ab290349c164499a Mon Sep 17 00:00:00 2001 From: Bojun Feng Date: Fri, 10 Nov 2023 02:22:56 -0600 Subject: [PATCH 5/6] update docs --- doc/source/models/builtin/index.rst | 1 + doc/source/models/builtin/rwkv-4-pile.rst | 73 +++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 doc/source/models/builtin/rwkv-4-pile.rst diff --git a/doc/source/models/builtin/index.rst b/doc/source/models/builtin/index.rst index dc60921fc2..0919ab9b3b 100644 --- a/doc/source/models/builtin/index.rst +++ b/doc/source/models/builtin/index.rst @@ -33,6 +33,7 @@ Chat & Instruction-following Models - :ref:`Llama-2 Chat ` - :ref:`Orca Mini ` - :ref:`Qwen Chat ` +- :ref:`RWKV-4-Pile ` - :ref:`Vicuna v1.3 ` - :ref:`Vicuna v1.5 ` - :ref:`Vicuna v1.5 16k ` diff --git a/doc/source/models/builtin/rwkv-4-pile.rst b/doc/source/models/builtin/rwkv-4-pile.rst new file mode 100644 index 0000000000..3fb0f8e843 --- /dev/null +++ b/doc/source/models/builtin/rwkv-4-pile.rst @@ -0,0 +1,73 @@ +.. 
_models_builtin_rwkv_4_pile: + +============= +RWKV-4-Pile +============= + +- **Model Name:** rwkv-4-pile +- **Languages:** en +- **Abilities:** generate +- **Description:** The RWKV (Receptance-Weighted Key-Value) models are a series of language models that ranges in size. They are noteworthy for their large-scale implementation and innovative architecture, which combines the strengths of RNN and Transformer models. + +Specifications +^^^^^^^^^^^^^^ + +Model Spec 1 (pytorch, 1 Billion) ++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 1 +- **Quantizations:** none +- **Model ID:** RWKV/rwkv-4-430m-pile + +Execute the following command to launch the model:: + + xinference launch --model-name rwkv-4-pile --size-in-billions 1 --model-format pytorch --quantization none + +Model Spec 2 (pytorch, 2 Billion) ++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 2 +- **Quantizations:** none +- **Model ID:** RWKV/rwkv-4-1b5-pile + +Execute the following command to launch the model:: + + xinference launch --model-name rwkv-4-pile --size-in-billions 2 --model-format pytorch --quantization none + +Model Spec 3 (pytorch, 3 Billion) ++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 3 +- **Quantizations:** none +- **Model ID:** RWKV/rwkv-4-3b-pile + +Execute the following command to launch the model:: + + xinference launch --model-name rwkv-4-pile --size-in-billions 3 --model-format pytorch --quantization none + +Model Spec 4 (pytorch, 7 Billion) ++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 7 +- **Quantizations:** none +- **Model ID:** RWKV/rwkv-4-7b-pile + +Execute the following command to launch the model:: + + xinference launch --model-name rwkv-4-pile --size-in-billions 7 --model-format pytorch --quantization none + +Model Spec 5 (pytorch, 14 Billion) ++++++++++++++++++++++++++++++++++ + +- **Model Format:** pytorch +- **Model Size (in billions):** 14 +- **Quantizations:** none +- **Model ID:** RWKV/rwkv-4-14b-pile + +Execute the following command to launch the model:: + + xinference launch --model-name rwkv-4-pile --size-in-billions 14 --model-format pytorch --quantization none From 6dec727b2bab1aa734c2d7e7ca482a7539235195 Mon Sep 17 00:00:00 2001 From: Bojun Feng Date: Fri, 10 Nov 2023 02:34:47 -0600 Subject: [PATCH 6/6] minor --- xinference/model/llm/pytorch/rwkv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/pytorch/rwkv.py b/xinference/model/llm/pytorch/rwkv.py index 7a43aef873..74fdec704f 100644 --- a/xinference/model/llm/pytorch/rwkv.py +++ b/xinference/model/llm/pytorch/rwkv.py @@ -37,7 +37,7 @@ def __init__( pytorch_model_config=pytorch_model_config, ) - def _load_model(self, kwargs: dict): + def _load_model(self, **kwargs): try: from transformers import AutoTokenizer, RwkvForCausalLM except ImportError:
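The launch commands documented above start a model from the CLI; once it is running, the same model can be driven from Python. The following is a hedged sketch only, assuming the Xinference Python client and a local endpoint at http://localhost:9997 (the endpoint, prompt, and generate_config values are illustrative, and the client method names should be checked against the client documentation for the installed version).

    # Sketch: launch and query rwkv-4-pile through a running Xinference endpoint.
    from xinference.client import Client

    client = Client("http://localhost:9997")       # assumed local supervisor endpoint
    model_uid = client.launch_model(
        model_name="rwkv-4-pile",
        model_format="pytorch",
        model_size_in_billions=1,
        quantization="none",
    )
    model = client.get_model(model_uid)
    completion = model.generate(
        "The Pile is",
        generate_config={"max_tokens": 64},
    )
    print(completion["choices"][0]["text"])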