
Open-Sourcing the Main Implementation Code of the Chat Model chat.guanjihuan.com

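This follows an earlier post, "Downloading and Installing the Bilingual Dialogue Language Model ChatGLM". This post organizes and open-sources the detailed implementation code. The code is also on my GitHub: https://github.com/guanjihuan/chat.guanjihuan.com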

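This post open-sources the main implementation code behind https://chat.guanjihuan.com. The code draws on several sources: the GitHub or HuggingFace pages of the open-source models, the official API sites of the third-party models, and the HuggingFace and PyTorch documentation.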

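Hardware requirements: to run the models locally on a GPU, you also need an Nvidia graphics card with at least 6 GB of VRAM. Note that only a few models were tested here. There are many more open-source LLMs, and you can test them yourself if you are interested. As a rule of thumb:

- A card with 8 GB of VRAM can load a model of around 7B parameters (7 billion) in quantized form.
- A card with 16 GB can load a ~7B model unquantized, or a ~14B model (14 billion parameters) in quantized form.
- Models with more parameters need cards with more VRAM.

Leaderboards for open-source LLMs:

https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard
https://cevalbenchmark.com/static/leaderboard.html
https://opencompass.org.cn/leaderboard-llm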
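The rule of thumb above follows mostly from the size of the weights: the parameter count times the bytes per parameter. The sketch below makes rough estimates under one assumption: the weights dominate memory use. It ignores the KV cache, activations and the CUDA context, which is why real requirements are somewhat higher.

```python
def weight_memory_gib(num_params: float, bits_per_param: int) -> float:
    """Approximate memory needed just to hold the model weights, in GiB."""
    bytes_total = num_params * bits_per_param / 8
    return bytes_total / 1024**3


# Compare typical precisions: CPU float32, GPU half precision, 4-bit quantization.
for label, bits in [("fp32 (.float())", 32), ("fp16 (.half())", 16), ("4-bit (NF4/Int4)", 4)]:
    for params in (6e9, 7e9, 14e9):
        print(f"{params / 1e9:>4.0f}B {label:<18} ~{weight_memory_gib(params, bits):5.1f} GiB")
```

ChatGLM3-6B has about 6 billion parameters. For it, this gives roughly 11 GiB in half precision and 22 GiB in float32. The 13G and 25G figures quoted for ChatGLM below are higher, presumably because of runtime overhead on top of the weights.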

I. Basic Environment

Running the code here requires a Python environment. One option is to install Anaconda: https://www.anaconda.com

The web framework used is Streamlit: https://streamlit.io , https://github.com/streamlit/streamlit .

Installing Streamlit:

pip install streamlit

Run command:

streamlit run web_demo.py

or, equivalently:

python -m streamlit run web_demo.py

To make the app reachable over a public IP, on port 8501 and with the dark theme, run:

streamlit run web_demo.py --theme.base dark --server.port 8501 --server.address 0.0.0.0 

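To avoid some unnecessary errors, you can update the operating system's graphics driver and then reboot. The commands below are for Ubuntu: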

sudo apt-get update

sudo apt-get install ubuntu-drivers-common

sudo ubuntu-drivers autoinstall

Updating PyTorch ( https://pytorch.org/get-started/locally/ ) can also help prevent some errors:

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
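After you update the driver and PyTorch, check that PyTorch sees the GPU and how much memory it has before you load any model. The sketch below uses only standard `torch.cuda` calls. It reports the setup and does not change anything. The function name `describe_torch_setup` is illustrative.

```python
def describe_torch_setup() -> str:
    """Return a one-line summary of the PyTorch/CUDA setup, or a hint if PyTorch is missing."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        # Scripts that call .cuda() will fail in this state.
        return f"PyTorch {torch.__version__}, CUDA not available"
    name = torch.cuda.get_device_name(0)
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return (f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}, "
            f"GPU 0: {name} with {total_gib:.1f} GiB")


print(describe_torch_setup())
```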

II. Running Open-Source Large Language Models Locally

1. Open-source model ChatGLM

ChatGLM3-6B homepage: https://github.com/THUDM/ChatGLM3 . Install the model's dependencies:

pip install -r requirements.txt

The requirements.txt file is:

# basic requirements

protobuf>=4.25.2
transformers>=4.36.2
tokenizers>=0.15.0
cpm_kernels>=1.0.11
torch>=2.1.0
gradio>=4.14.0
sentencepiece>=0.1.99
sentence_transformers>=2.2.2
accelerate>=0.26.1
streamlit>=1.30.0
fastapi>=0.109.0
loguru~=0.7.2
mdtex2html>=1.2.0
latex2mathml>=3.77.0

# for openai demo

openai>=1.7.2
zhipuai>=2.0.0

pydantic>=2.5.3
sse-starlette>=1.8.2
uvicorn>=0.25.0
timm>=0.9.12
tiktoken>=0.5.2

# for langchain demo

langchain>=0.1.0
langchainhub>=0.1.14
arxiv>=2.1.0

Download the model files from https://huggingface.co/THUDM/chatglm3-6b-32k and place them in the directory THUDM/chatglm3-6b-32k.
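The scripts pass the repo id (for example "THUDM/chatglm3-6b-32k") to `from_pretrained`. transformers treats the id as a local path if a directory with that relative path exists in the working directory. Otherwise it tries to fetch the files from the Hugging Face Hub. One way to fill that directory is `huggingface_hub.snapshot_download`. Below is a minimal sketch, assuming huggingface_hub is installed (transformers depends on it). The helper names `local_dir_for` and `download_model` are illustrative.

```python
from pathlib import Path


def local_dir_for(repo_id: str) -> Path:
    """Mirror the Hugging Face repo id as a relative directory, e.g. THUDM/chatglm3-6b-32k."""
    return Path(*repo_id.split("/"))


def download_model(repo_id: str) -> Path:
    """Download every file of a model repo into ./<org>/<name> (requires huggingface_hub)."""
    from huggingface_hub import snapshot_download  # imported lazily; pip install huggingface_hub

    target = local_dir_for(repo_id)
    snapshot_download(repo_id=repo_id, local_dir=str(target))
    return target
```

For example, `download_model("THUDM/chatglm3-6b-32k")` fills THUDM/chatglm3-6b-32k. The same call with `"Qwen/Qwen-7B-Chat-Int4"` or `"internlm/internlm-chat-7b"` fills the directories used in the later sections.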

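VRAM/RAM requirements:

- Quantized loading needs about 6 GB of VRAM.
- Default loading needs about 13 GB of VRAM.
- Loading on the CPU needs about 25 GB of RAM.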

Run command:

python -m streamlit run ./ChatGLM.py --theme.base dark --server.port 8501

If bitsandbytes reports an error during quantized loading, install the package: pip install bitsandbytes

ChatGLM.py code:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/38502
"""

import streamlit as st
st.set_page_config(
    page_title="Chat",
    layout='wide'
)

choose_load_method = 1  # how to load the model: 0 = default (GPU, fp16), 1 = 4-bit quantized (GPU), 2 = CPU

if choose_load_method == 0:
    # Default loading on the GPU in half precision (needs about 13 GB of VRAM)
    @st.cache_resource
    def load_model_chatglm3():
        from transformers import AutoModel, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True).half().cuda()
        model = model.eval()
        return model, tokenizer
    model_chatglm3, tokenizer_chatglm3 = load_model_chatglm3()

elif choose_load_method == 1:
    # 4-bit NF4 quantized loading (needs about 6 GB of VRAM)
    @st.cache_resource
    def load_model_chatglm3():
        from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True)
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
        )
        model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True, quantization_config=nf4_config)
        model = model.eval()
        return  model, tokenizer
    model_chatglm3, tokenizer_chatglm3 = load_model_chatglm3()

elif choose_load_method == 2:
    # Loading on the CPU in float32 (needs about 25 GB of RAM; responses will be slow)
    @st.cache_resource
    def load_model_chatglm3():
        from transformers import AutoModel, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True).float()
        model = model.eval()
        return model, tokenizer
    model_chatglm3, tokenizer_chatglm3 = load_model_chatglm3()

with st.sidebar:
    with st.expander('Parameters', expanded=True):
        max_length = 409600
        top_p = st.slider('top_p', 0.01, 1.0, step=0.01, value=0.8, key='top_p_session')
        temperature = st.slider('temperature', 0.51, 1.0, step=0.01, value=0.8, key='temperature_session')
        def reset_parameter():
            st.session_state['top_p_session'] = 0.8
            st.session_state['temperature_session'] = 0.8
        reset_parameter_button = st.button('Reset parameters', on_click=reset_parameter)

prompt = st.chat_input("Enter your command here")

def chat_response_chatglm3(prompt):
    history, past_key_values = st.session_state.history_ChatGLM3, st.session_state.past_key_values_ChatGLM3
    for response, history, past_key_values in model_chatglm3.stream_chat(tokenizer_chatglm3, prompt, history,
                                                                past_key_values=past_key_values,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature,
                                                                return_past_key_values=True):
        message_placeholder_chatglm3.markdown(response)
        if stop_button:
            break
    st.session_state.ai_response.append({"role": "robot", "content": response, "avatar": "assistant"})
    st.session_state.history_ChatGLM3 = history
    st.session_state.past_key_values_ChatGLM3 = past_key_values
    return response

def clear_all():
    st.session_state.history_ChatGLM3 = []
    st.session_state.past_key_values_ChatGLM3 = None
    st.session_state.ai_response = []

if 'history_ChatGLM3' not in st.session_state:
    st.session_state.history_ChatGLM3 = []
if 'past_key_values_ChatGLM3' not in st.session_state:
    st.session_state.past_key_values_ChatGLM3 = None
if 'ai_response' not in st.session_state:
    st.session_state.ai_response = []

for ai_response in st.session_state.ai_response:
    with st.chat_message(ai_response["role"], avatar=ai_response.get("avatar")):
        st.markdown(ai_response["content"])

prompt_placeholder = st.chat_message("user", avatar='user')
with st.chat_message("robot", avatar="assistant"):
    message_placeholder_chatglm3 = st.empty()

if prompt:
    prompt_placeholder.markdown(prompt)
    st.session_state.ai_response.append({"role": "user", "content": prompt, "avatar": 'user'})
    stop = st.empty()
    stop_button = stop.button('Stop', key='break_response')
    chat_response_chatglm3(prompt)
    stop.empty()
button_clear = st.button("Clear", on_click=clear_all, key='clear')
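In the script above, `stream_chat` yields the whole reply generated so far at each step, not just the new tokens. That is why each iteration simply overwrites the placeholder with `markdown(response)`. Qwen's `chat_stream` in the next script behaves the same way. The third-party APIs in Part III work differently: they deliver increments, which must be concatenated (`answer += ...`). The plain-Python sketch below contrasts the two patterns. Its generators are hypothetical stand-ins for the model and the API.

```python
from typing import Iterable, Iterator, List


def cumulative_stream(tokens: List[str]) -> Iterator[str]:
    """Hypothetical stand-in for stream_chat/chat_stream: yields the full text so far."""
    text = ""
    for tok in tokens:
        text += tok
        yield text


def delta_stream(tokens: List[str]) -> Iterator[str]:
    """Hypothetical stand-in for an SSE/WebSocket API: yields only the new piece."""
    yield from tokens


def render_cumulative(stream: Iterable[str]) -> List[str]:
    """Each update replaces what is shown (like placeholder.markdown(response))."""
    return [response for response in stream]


def render_delta(stream: Iterable[str]) -> List[str]:
    """Each update is appended before redrawing (like answer += chunk)."""
    frames, answer = [], ""
    for chunk in stream:
        answer += chunk
        frames.append(answer)
    return frames


tokens = ["Hel", "lo", ", ", "world"]
print(render_cumulative(cumulative_stream(tokens)))  # ['Hel', 'Hello', 'Hello, ', 'Hello, world']
print(render_delta(delta_stream(tokens)))            # ['Hel', 'Hello', 'Hello, ', 'Hello, world']
```

Mixing up the two patterns, for example appending cumulative output, repeats the earlier text in every frame.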

2. Open-source model Qwen

Qwen homepage: https://github.com/QwenLM/Qwen . Install the model's dependencies:

pip install -r requirements.txt

The requirements.txt file is:

transformers==4.32.0
accelerate
tiktoken
einops
transformers_stream_generator==0.0.4
scipy

Download the Qwen-7B-Chat-Int4 model files from https://huggingface.co/Qwen/Qwen-7B-Chat-Int4 and place them in the directory Qwen/Qwen-7B-Chat-Int4.

Download the Qwen-14B-Chat-Int4 model files from https://huggingface.co/Qwen/Qwen-14B-Chat-Int4 and place them in the directory Qwen/Qwen-14B-Chat-Int4.

VRAM requirements: Qwen-7B-Chat-Int4 needs about 8 GB of VRAM, and Qwen-14B-Chat-Int4 needs about 12 GB.

Run command:

python -m streamlit run ./Qwen.py --theme.base dark --server.port 8501

If you get errors at runtime, you may also need to install:

pip install optimum
pip install auto-gptq
pip install --upgrade s3fs aiobotocore botocore

Qwen.py code:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/38502
"""

import streamlit as st
st.set_page_config(
    page_title="Chat",
    layout='wide'
)

choose_load_model = 0  # which model to load: 0 = Qwen-7B, 1 = Qwen-14B

if choose_load_model == 0:
    # Qwen-7B (needs about 8 GB of VRAM)
    @st.cache_resource
    def load_model_qwen_7B():
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat-Int4", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat-Int4",
            device_map="auto",
            trust_remote_code=True,
        ).eval()
        return tokenizer, model
    tokenizer_qwen_7B, model_qwen_7B = load_model_qwen_7B()

elif choose_load_model == 1:
    # Qwen-14B (needs about 12 GB of VRAM)
    @st.cache_resource
    def load_model_qwen_14B():
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat-Int4", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-14B-Chat-Int4",
            device_map="auto",
            trust_remote_code=True
        ).eval()
        return tokenizer, model
    tokenizer_qwen_14B, model_qwen_14B = load_model_qwen_14B()

with st.sidebar:
    with st.expander('Parameters', expanded=True):
        max_length = 409600
        top_p = st.slider('top_p', 0.01, 1.0, step=0.01, value=0.8, key='top_p_session')
        temperature = st.slider('temperature', 0.51, 1.0, step=0.01, value=0.8, key='temperature_session')
        def reset_parameter():
            st.session_state['top_p_session'] = 0.8
            st.session_state['temperature_session'] = 0.8
        reset_parameter_button = st.button('Reset parameters', on_click=reset_parameter)

prompt = st.chat_input("Enter your command here")

from transformers.generation import GenerationConfig

if choose_load_model == 0:
    config_qwen_7b = GenerationConfig.from_pretrained(
        "Qwen/Qwen-7B-Chat-Int4", trust_remote_code=True, resume_download=True, max_length = max_length, top_p = top_p, temperature = temperature
    )
    def chat_response_qwen_7B(query):
        for response in model_qwen_7B.chat_stream(tokenizer_qwen_7B, query, history=st.session_state.history_qwen, generation_config=config_qwen_7b):
            message_placeholder_qwen.markdown(response)
            if stop_button:
                break
        st.session_state.history_qwen.append((query, response))
        st.session_state.ai_response.append({"role": "robot", "content": response, "avatar": "assistant"})
        return response

elif choose_load_model == 1:
    config_qwen_14b = GenerationConfig.from_pretrained(
        "Qwen/Qwen-14B-Chat-Int4", trust_remote_code=True, resume_download=True, max_length = max_length, top_p = top_p, temperature = temperature
    )
    def chat_response_qwen_14B(query):
        for response in model_qwen_14B.chat_stream(tokenizer_qwen_14B, query, history=st.session_state.history_qwen, generation_config=config_qwen_14b):
            message_placeholder_qwen.markdown(response)
            if stop_button:
                break
        st.session_state.history_qwen.append((query, response))
        st.session_state.ai_response.append({"role": "robot", "content": response, "avatar": "assistant"})
        return response

def clear_all():
    st.session_state.history_qwen = []
    st.session_state.ai_response = []

if 'history_qwen' not in st.session_state:
    st.session_state.history_qwen = []
if 'ai_response' not in st.session_state:
    st.session_state.ai_response = []

for ai_response in st.session_state.ai_response:
    with st.chat_message(ai_response["role"], avatar=ai_response.get("avatar")):
        st.markdown(ai_response["content"])

prompt_placeholder = st.chat_message("user", avatar='user')
with st.chat_message("robot", avatar="assistant"):
    message_placeholder_qwen = st.empty()

if prompt:
    prompt_placeholder.markdown(prompt)
    st.session_state.ai_response.append({"role": "user", "content": prompt, "avatar": 'user'})
    stop = st.empty()
    stop_button = stop.button('Stop', key='break_response')
    if choose_load_model == 0:
        chat_response_qwen_7B(prompt)
    elif choose_load_model == 1:
        chat_response_qwen_14B(prompt)
    stop.empty()
button_clear = st.button("Clear", on_click=clear_all, key='clear')

3. Open-source model InternLM

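InternLM homepage: https://github.com/InternLM/InternLM . The code imports from the repository's tools folder (tools.transformers.interface), so that folder must be available when you run it.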

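Download the internlm-chat-7b model files from https://huggingface.co/internlm/internlm-chat-7b and place them in the directory internlm/internlm-chat-7b. Note: the code provided loads the internlm-chat-7b model. An internlm2-chat-7b model is now available, but I have not tested it yet. Download the internlm2-chat-7b model files from https://huggingface.co/internlm/internlm2-chat-7b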

VRAM requirements: about 7 GB of VRAM.

Run command:

python -m streamlit run ./InternLM.py --theme.base dark --server.port 8501

InternLM.py code:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/38502
"""

import streamlit as st
st.set_page_config(
    page_title="Chat",
    layout='wide'
)

@st.cache_resource
def load_model_internlm_7B():
    # internlm-chat-7b with 4-bit NF4 quantization (needs about 7 GB of VRAM)
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True, quantization_config=nf4_config)
    # torch_dtype is a model-loading option, so it is not passed to the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
    model = model.eval()
    return model, tokenizer
model_internlm_7B, tokenizer_internlm_7B = load_model_internlm_7B()

with st.sidebar:
    with st.expander('Parameters', expanded=True):
        max_length = 409600
        top_p = st.slider('top_p', 0.01, 1.0, step=0.01, value=0.8, key='top_p_session')
        temperature = st.slider('temperature', 0.51, 1.0, step=0.01, value=0.8, key='temperature_session') 
        def reset_parameter():
            st.session_state['top_p_session'] = 0.8
            st.session_state['temperature_session'] = 0.8
        reset_parameter_button = st.button('Reset parameters', on_click=reset_parameter)

prompt = st.chat_input("Enter your command here")

from tools.transformers.interface import GenerationConfig, generate_interactive

def prepare_generation_config():
    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
    return generation_config

def combine_history(prompt, messages):
    total_prompt = ""
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.replace("{user}", cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.replace("{robot}", cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
    return total_prompt

user_prompt = "<|User|>:{user}<eoh>\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
generation_config = prepare_generation_config()

if "messages_internlm_7B" not in st.session_state:
    st.session_state.messages_internlm_7B = []

from dataclasses import asdict

def chat_response_internlm_7B(prompt):
    real_prompt = combine_history(prompt, messages = st.session_state.messages_internlm_7B)
    st.session_state.messages_internlm_7B.append({"role": "user", "content": prompt, "avatar": 'user'})
    for cur_response in generate_interactive(
        model=model_internlm_7B,
        tokenizer=tokenizer_internlm_7B,
        prompt=real_prompt,
        additional_eos_token_id=103028,
        **asdict(generation_config),
    ):
        message_placeholder_internlm_7B.markdown(cur_response + "▌")
        if stop_button:
            break
    message_placeholder_internlm_7B.markdown(cur_response)
    st.session_state.messages_internlm_7B.append({"role": "robot", "content": cur_response, "avatar": "assistant"})
    st.session_state.ai_response.append({"role": "robot", "content": cur_response, "avatar": "assistant"})
    return cur_response


def clear_all():
    st.session_state.messages_internlm_7B = []
    st.session_state.ai_response = []

if 'messages_internlm_7B' not in st.session_state:
    st.session_state.messages_internlm_7B = []
if 'ai_response' not in st.session_state:
    st.session_state.ai_response = []

for ai_response in st.session_state.ai_response:
    with st.chat_message(ai_response["role"], avatar=ai_response.get("avatar")):
        st.markdown(ai_response["content"])

prompt_placeholder = st.chat_message("user", avatar='user')
with st.chat_message("robot", avatar="assistant"):
    message_placeholder_internlm_7B = st.empty()

if prompt:
    prompt_placeholder.markdown(prompt)
    st.session_state.ai_response.append({"role": "user", "content": prompt, "avatar": 'user'})
    stop = st.empty()
    stop_button = stop.button('Stop', key='break_response')
    chat_response_internlm_7B(prompt)
    stop.empty()
button_clear = st.button("Clear", on_click=clear_all, key='clear')
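`combine_history` flattens the stored turns into InternLM-chat's plain-text dialogue template. The string ends with an open `<|Bot|>:`, so the model continues as the assistant. The sketch below restates the same logic as a standalone function and shows the exact string passed to `generate_interactive`. The name `build_prompt` is hypothetical; the templates are copied from the script.

```python
from typing import Dict, List

# Templates copied from InternLM.py (first-generation internlm-chat-7b).
USER_TEMPLATE = "<|User|>:{user}<eoh>\n"
ROBOT_TEMPLATE = "<|Bot|>:{robot}<eoa>\n"
QUERY_TEMPLATE = "<|User|>:{user}<eoh>\n<|Bot|>:"


def build_prompt(messages: List[Dict[str, str]], query: str) -> str:
    """Flatten previous turns plus the new query into one prompt string."""
    parts = []
    for m in messages:
        if m["role"] == "user":
            parts.append(USER_TEMPLATE.replace("{user}", m["content"]))
        elif m["role"] == "robot":
            parts.append(ROBOT_TEMPLATE.replace("{robot}", m["content"]))
        else:
            raise ValueError(f"unknown role: {m['role']!r}")
    parts.append(QUERY_TEMPLATE.replace("{user}", query))
    return "".join(parts)


history = [
    {"role": "user", "content": "Hi"},
    {"role": "robot", "content": "Hello! How can I help?"},
]
print(build_prompt(history, "What is 2+2?"))
```

These templates are specific to the first-generation internlm-chat-7b. internlm2-chat-7b is documented with a different chat format, so they would likely need to change for it.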

III. Using Third-Party Model APIs

1. Zhipu - ChatGLM_Turbo

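Get an API key for Zhipu ChatGLM Turbo at https://maas.aminer.cn (paid, with a free trial). The code below calls the v1 interface of the zhipuai SDK (zhipuai.model_api.sse_invoke). zhipuai 2.x switched to a client-based interface, and the ChatGLM3 requirements above specify zhipuai>=2.0.0. This script therefore needs a 1.x release, for example: pip install "zhipuai<2"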

Run command:

python -m streamlit run ./ChatGLM_Turbo.py --theme.base dark --server.port 8501

ChatGLM_Turbo.py code:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/38502
"""

import streamlit as st
st.set_page_config(
    page_title="Chat",
    layout='wide'
)

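try:
    import zhipuai
except ImportError:
    import os
    # This script uses the v1 interface (zhipuai.model_api), which zhipuai 2.x no longer provides
    os.system('pip install "zhipuai<2"')
    import zhipuai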

# Get the API_KEY from the official website
zhipuai.api_key = " "

with st.sidebar:
    with st.expander('Parameters', expanded=True):
        top_p = st.slider('top_p', 0.01, 1.0, value=0.7, step=0.01)
        temperature = st.slider('temperature', 0.01, 1.0, value=0.95, step=0.01)

def chatglm_chat(prompt=[]):
    response = zhipuai.model_api.sse_invoke(
        model="chatglm_turbo",
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
    )
    return response

def getlength(text):
    # Total number of characters across all message contents
    length = 0
    for content in text:
        length += len(content["content"])
    return length

def checklen(text):
    # Drop the oldest messages (in place) until the total is at most 8000 characters
    while (getlength(text) > 8000):
        del text[0]
    return text

def getText(role, content, text):
    # Append a {"role": ..., "content": ...} message to the list
    text.append({"role": role, "content": content})
    return text

answer = ""
if "text0" not in st.session_state:
    st.session_state.text0 = []
if "messages0" not in st.session_state:
    st.session_state.messages0 = [] 
def clear_all0():
    st.session_state.messages0 = []
    st.session_state.text0 = []
if st.session_state.messages0 == []:
    with st.chat_message("user", avatar="user"):
        input_placeholder = st.empty() 
    with st.chat_message("robot", avatar="assistant"):
        message_placeholder = st.empty()
for message in st.session_state.messages0:
    with st.chat_message(message["role"], avatar=message.get("avatar")):
        st.markdown(message["content"])
prompt_text = st.chat_input("Enter your command here")

if prompt_text:
    if st.session_state.messages0 != []:
        with st.chat_message("user", avatar="user"):
            input_placeholder = st.empty()
        with st.chat_message("robot", avatar="assistant"):
            message_placeholder = st.empty()
    input_placeholder.markdown(prompt_text)
    st.session_state.messages0.append({"role": "user", "content": prompt_text, "avatar": "user"})
    st.session_state.text0 = getText("user", prompt_text, st.session_state.text0)
    question = checklen(st.session_state.text0)
    response  = chatglm_chat(question)
    for event in response.events():
        answer += event.data
        message_placeholder.markdown(answer)
    st.session_state.text0 = getText("assistant", answer, st.session_state.text0)
    st.session_state.messages0.append({"role": "robot", "content": answer, "avatar": "assistant"})
    st.rerun()
button_clear = st.button("Clear", on_click=clear_all0, key='clear0')
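`checklen` bounds the request size by deleting the oldest messages until their combined length is at most 8000. Two details are worth noting. First, it counts characters, not tokens. Second, if the newest message alone exceeds the limit, the loop empties the list entirely. The sketch below is a variant that always keeps the latest message. The 8000-character budget comes from the script; whether it matches the API's actual limit is not verified here. The function names are illustrative.

```python
from typing import Dict, List


def total_chars(messages: List[Dict[str, str]]) -> int:
    """Total number of characters across all message contents (not tokens)."""
    return sum(len(m["content"]) for m in messages)


def trim_history(messages: List[Dict[str, str]], max_chars: int = 8000) -> List[Dict[str, str]]:
    """Drop the oldest messages until the total fits, but never drop the latest one."""
    trimmed = list(messages)  # work on a copy; checklen in the script edits in place
    while len(trimmed) > 1 and total_chars(trimmed) > max_chars:
        del trimmed[0]
    return trimmed


msgs = [
    {"role": "user", "content": "a" * 5000},
    {"role": "assistant", "content": "b" * 2500},
    {"role": "user", "content": "c" * 1000},
]
print([m["role"] for m in trim_history(msgs)])  # ['assistant', 'user']
```

Some chat APIs may reject a history that starts with an assistant turn. If that happens, dropping messages two at a time is one possible variation.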

2. iFlytek (讯飞) - Spark large model (星火大模型)

Get an API key for the iFlytek Spark large model at https://xinghuo.xfyun.cn (paid, with a free trial).

Run command:

python -m streamlit run ./星火大模型.py --theme.base dark --server.port 8501

星火大模型.py code:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/38502
"""

import streamlit as st
st.set_page_config(
    page_title="Chat",
    layout='wide'
)

# The following credentials are obtained from the console
appid = " "       # APPID from the console
api_secret = " "  # APISecret from the console
api_key = " "     # APIKey from the console

with st.sidebar:
    with st.expander('Model', expanded=True):
        API_model = st.radio('Choose:', ('讯飞 - 星火大模型 V1.5', '讯飞 - 星火大模型 V2.0', '讯飞 - 星火大模型 V3.0'), key='choose_API_model')
        if API_model == '讯飞 - 星火大模型 V1.5':
            API_model_0 = '星火大模型 V1.5'
        elif API_model == '讯飞 - 星火大模型 V2.0':
            API_model_0 = '星火大模型 V2.0'
        elif API_model == '讯飞 - 星火大模型 V3.0':
            API_model_0 = '星火大模型 V3.0'
        st.write('Current model: ' + API_model_0)

    with st.expander('Parameters', expanded=True):
        top_k = st.slider('top_k', 1, 6, value=4, step=1)
        temperature = st.slider('temperature', 0.01, 1.0, value=0.5, step=0.01)

# Cloud service endpoint for each version
if API_model == '讯飞 - 星火大模型 V1.5':
    domain = "general"   # v1.5
    Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # v1.5 endpoint

elif API_model == '讯飞 - 星火大模型 V2.0':
    domain = "generalv2"    # v2.0
    Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"  # v2.0 endpoint

elif API_model == '讯飞 - 星火大模型 V3.0':
    domain = "generalv3"    # v3.0
    Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"  # v3.0 endpoint

import _thread as thread
import base64
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket  # provided by the websocket-client package (pip install websocket-client)
answer = ""

class Ws_Param(object):
    # Initialization
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # Build the authenticated URL
    def create_url(self):
        # Timestamp in RFC 1123 format
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        # Assemble the string to sign
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"
        # Sign it with HMAC-SHA256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Collect the authentication parameters into a dict
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # Append the authentication parameters as the query string
        url = self.Spark_url + '?' + urlencode(v)
        # For debugging, print(url) here and compare it with the URL the official demo generates for the same parameters
        return url

# Handle a WebSocket error
def on_error(ws, error):
    print("### error:", error)

# Handle the WebSocket being closed
def on_close(ws,one,two):
    print(" ")

# Handle the WebSocket connection being opened: send the request from a new thread
def on_open(ws):
    thread.start_new_thread(run, (ws,))

def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
    ws.send(data)

# Handle an incoming WebSocket message: append the text chunk, and close once status == 2
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'Request error: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        global answer
        answer += content
        message_placeholder.markdown(answer)
        if status == 2:
            ws.close()

def gen_params(appid, domain,question):
    """
    Build the request parameters from the appid and the user's question
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "temperature": temperature,
                "top_k": top_k,
                "max_tokens": 4096,
                "auditing": "default"
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data

def main_chat(appid, api_key, api_secret, Spark_url,domain, question):
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

def getlength(text):
    # Total number of characters across all message contents
    length = 0
    for content in text:
        length += len(content["content"])
    return length

def checklen(text):
    # Drop the oldest messages (in place) until the total is at most 8000 characters
    while (getlength(text) > 8000):
        del text[0]
    return text

def getText(role, content, text):
    # Append a {"role": ..., "content": ...} message to the list
    text.append({"role": role, "content": content})
    return text

prompt_text = st.chat_input("Enter your command here")

if API_model == '讯飞 - 星火大模型 V1.5':
    if "text" not in st.session_state:
        st.session_state.text = []
    if "messages" not in st.session_state:
        st.session_state.messages = [] 
    def clear_all():
        st.session_state.messages = []
        st.session_state.text = []
    if st.session_state.messages == []:
        with st.chat_message("user", avatar="user"):
            input_placeholder = st.empty()
        with st.chat_message("robot", avatar="assistant"):
            message_placeholder = st.empty()
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])
    if prompt_text:
        if st.session_state.messages != []:
            with st.chat_message("user", avatar="user"):
                input_placeholder = st.empty()
            with st.chat_message("robot", avatar="assistant"):
                message_placeholder = st.empty()
        input_placeholder.markdown(prompt_text)
        st.session_state.messages.append({"role": "user", "content": prompt_text, "avatar": "user"})
        st.session_state.text = getText("user", prompt_text, st.session_state.text)
        question = checklen(st.session_state.text)
        main_chat(appid,api_key,api_secret,Spark_url,domain,question)
        st.session_state.text = getText("assistant", answer, st.session_state.text)
        st.session_state.messages.append({"role": "robot", "content": answer, "avatar": "assistant"})
        st.rerun()
    button_clear = st.button("Clear", on_click=clear_all)

elif  API_model == '讯飞 - 星火大模型 V2.0':
    if "text2" not in st.session_state:
        st.session_state.text2 = []
    if "messages2" not in st.session_state:
        st.session_state.messages2 = [] 
    def clear_all2():
        st.session_state.messages2 = []
        st.session_state.text2 = []
    if st.session_state.messages2 == []:
        with st.chat_message("user", avatar="user"):
            input_placeholder = st.empty()
        with st.chat_message("robot", avatar="assistant"):
            message_placeholder = st.empty()
    for message in st.session_state.messages2:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])
    if prompt_text:
        if st.session_state.messages2 != []:
            with st.chat_message("user", avatar="user"):
                input_placeholder = st.empty()
            with st.chat_message("robot", avatar="assistant"):
                message_placeholder = st.empty()
        input_placeholder.markdown(prompt_text)
        st.session_state.messages2.append({"role": "user", "content": prompt_text, "avatar": "user"})
        st.session_state.text2 = getText("user", prompt_text, st.session_state.text2)
        question = checklen(st.session_state.text2)
        main_chat(appid,api_key,api_secret,Spark_url,domain,question)
        st.session_state.text2 = getText("assistant", answer, st.session_state.text2)
        st.session_state.messages2.append({"role": "robot", "content": answer, "avatar": "assistant"})
        st.rerun()
    button_clear = st.button("Clear", on_click=clear_all2, key='clear2')

elif  API_model == '讯飞 - 星火大模型 V3.0':
    if "text3" not in st.session_state:
        st.session_state.text3 = []
    if "messages3" not in st.session_state:
        st.session_state.messages3 = [] 
    def clear_all3():
        st.session_state.messages3 = []
        st.session_state.text3 = []
    if st.session_state.messages3 == []:
        with st.chat_message("user", avatar="user"):
            input_placeholder = st.empty()
        with st.chat_message("robot", avatar="assistant"):
            message_placeholder = st.empty()
    for message in st.session_state.messages3:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])
    if prompt_text:
        if st.session_state.messages3 != []:
            with st.chat_message("user", avatar="user"):
                input_placeholder = st.empty()
            with st.chat_message("robot", avatar="assistant"):
                message_placeholder = st.empty()
        input_placeholder.markdown(prompt_text)
        st.session_state.messages3.append({"role": "user", "content": prompt_text, "avatar": "user"})
        st.session_state.text3 = getText("user", prompt_text, st.session_state.text3)
        question = checklen(st.session_state.text3)
        main_chat(appid,api_key,api_secret,Spark_url,domain,question)
        st.session_state.text3 = getText("assistant", answer, st.session_state.text3)
        st.session_state.messages3.append({"role": "robot", "content": answer, "avatar": "assistant"})
        st.rerun()
    button_clear = st.button("Clear", on_click=clear_all3, key='clear3')
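The core of `Ws_Param.create_url` is the request-signing scheme the Spark WebSocket API uses for authentication. The sketch below restates it as a standalone function and passes the date in explicitly, so the output is reproducible; `create_url` itself uses the current time. The function name `build_signed_url` is illustrative, and the key, secret and date are dummy values.

```python
import base64
import hashlib
import hmac
from urllib.parse import parse_qs, urlencode, urlparse


def build_signed_url(spark_url: str, api_key: str, api_secret: str, date: str) -> str:
    """Same signing steps as Ws_Param.create_url, with the RFC 1123 date passed in."""
    parsed = urlparse(spark_url)
    # 1. String to sign: host, date and request line, separated by newlines.
    signature_origin = f"host: {parsed.netloc}\ndate: {date}\nGET {parsed.path} HTTP/1.1"
    # 2. HMAC-SHA256 keyed with the APISecret, then Base64.
    digest = hmac.new(api_secret.encode("utf-8"), signature_origin.encode("utf-8"), hashlib.sha256).digest()
    signature = base64.b64encode(digest).decode("utf-8")
    # 3. The authorization value, itself Base64-encoded.
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    # 4. Everything goes into the query string of the WebSocket URL.
    return spark_url + "?" + urlencode({"authorization": authorization, "date": date, "host": parsed.netloc})


url = build_signed_url("ws://spark-api.xf-yun.com/v3.1/chat", "demo-key", "demo-secret",
                       "Mon, 01 Jan 2024 00:00:00 GMT")
query = parse_qs(urlparse(url).query)
print(base64.b64decode(query["authorization"][0]).decode("utf-8"))
```

The server can recompute the same HMAC from the host, date and path, so the secret itself never travels over the connection. The date inside the signature also lets the server reject stale requests.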

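[Note: This site mainly hosts personal notes and code, and the content may be revised from time to time. To keep the latest version the only one online, please do not repost articles from this site without permission. When citing, please credit the source: https://www.guanjihuan.com]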
