|
| 1 | +import streamlit as st |
| 2 | +from openai import OpenAI |
| 3 | +import os |
| 4 | +import sys |
| 5 | +import yaml |
| 6 | + |
| 7 | +if "config" not in st.session_state: |
| 8 | + # 配置根目录 |
| 9 | + root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) |
| 10 | + root_dir = os.path.abspath(root_dir) |
| 11 | + |
| 12 | + original_pythonpath = os.environ.get("PYTHONPATH", "") |
| 13 | + os.environ["PYTHONPATH"] = original_pythonpath + ":" + root_dir |
| 14 | + sys.path.append(root_dir) |
| 15 | + support_models = [] |
| 16 | + config_path = os.path.join(root_dir, "gpt_server/script/config.yaml") |
| 17 | + with open(config_path, "r") as f: |
| 18 | + config = yaml.safe_load(f) |
| 19 | + # TODO 没有添加别名 |
| 20 | + for model_config_ in config["models"]: |
| 21 | + for model_name, model_config in model_config_.items(): |
| 22 | + # 启用的模型 |
| 23 | + if model_config["enable"]: |
| 24 | + if ( |
| 25 | + model_config["model_type"] != "embedding" |
| 26 | + and model_config["model_type"] != "embedding_infinity" |
| 27 | + ): |
| 28 | + support_models.append(model_name) |
| 29 | + port = config["serve_args"]["port"] |
| 30 | + client = OpenAI( |
| 31 | + api_key="EMPTY", |
| 32 | + base_url=f"http://localhost:{port}/v1", |
| 33 | + ) |
| 34 | + |
| 35 | + |
| 36 | +def clear_chat_history(): |
| 37 | + del st.session_state.messages |
| 38 | + |
| 39 | + |
| 40 | +def init_chat_history(): |
| 41 | + with st.chat_message("assistant", avatar="🤖"): |
| 42 | + st.markdown("您好,很高兴为您服务!🥰") |
| 43 | + |
| 44 | + if "messages" in st.session_state: |
| 45 | + for message in st.session_state.messages: |
| 46 | + avatar = "🧑💻" if message["role"] == "user" else "🤖" |
| 47 | + with st.chat_message(message["role"], avatar=avatar): |
| 48 | + st.markdown(message["content"]) |
| 49 | + else: |
| 50 | + st.session_state.messages = [] |
| 51 | + |
| 52 | + return st.session_state.messages |
| 53 | + |
| 54 | + |
| 55 | +def main(): |
| 56 | + st.title(f"GPT_SERVER") |
| 57 | + models = [i.id for i in client.models.list() if i.id in support_models] |
| 58 | + model = st.sidebar.selectbox(label="选择模型", options=models) |
| 59 | + temperature = st.sidebar.slider( |
| 60 | + label="temperature", min_value=0.0, max_value=2.0, value=0.8, step=0.1 |
| 61 | + ) |
| 62 | + top_p = st.sidebar.slider( |
| 63 | + label="top_p", min_value=0.0, max_value=1.0, value=1.0, step=0.1 |
| 64 | + ) |
| 65 | + messages = init_chat_history() |
| 66 | + |
| 67 | + if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"): |
| 68 | + with st.chat_message("user", avatar="🧑"): |
| 69 | + st.markdown(prompt) |
| 70 | + messages.append({"role": "user", "content": prompt}) |
| 71 | + stream = client.chat.completions.create( |
| 72 | + model=model, # Model name to use |
| 73 | + messages=messages, # Chat history |
| 74 | + temperature=temperature, # Temperature for text generation |
| 75 | + top_p=top_p, |
| 76 | + stream=True, # Stream response |
| 77 | + ) |
| 78 | + with st.chat_message("assistant", avatar="🤖"): |
| 79 | + placeholder = st.empty() |
| 80 | + partial_message = "" |
| 81 | + for chunk in stream: |
| 82 | + partial_message += chunk.choices[0].delta.content or "" |
| 83 | + placeholder.markdown(partial_message) |
| 84 | + messages.append({"role": "assistant", "content": partial_message}) |
| 85 | + |
| 86 | + st.button("清空对话", on_click=clear_chat_history) |
| 87 | + |
| 88 | + |
| 89 | +if __name__ == "__main__": |
| 90 | + main() |
0 commit comments