feat: ichat 修改大模型的提问
This commit is contained in:
parent
6fcda8c0a7
commit
8834d66066
|
@ -13,7 +13,7 @@ from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
|
||||||
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
# MODEL = "qwen-plus"
|
# MODEL = "qwen-plus"
|
||||||
|
|
||||||
#本地部署的模式
|
# #本地部署的模式
|
||||||
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
|
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
|
||||||
API_BASE = "http://106.0.4.200:9000/v1"
|
API_BASE = "http://106.0.4.200:9000/v1"
|
||||||
MODEL = "qwen14b"
|
MODEL = "qwen14b"
|
||||||
|
@ -27,6 +27,7 @@ MODEL = "qwen14b"
|
||||||
# API_KEY = "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"
|
# API_KEY = "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"
|
||||||
# API_BASE = "https://openrouter.ai/api/v1"
|
# API_BASE = "https://openrouter.ai/api/v1"
|
||||||
# MODEL="deepseek/deepseek-chat-v3-0324:free"
|
# MODEL="deepseek/deepseek-chat-v3-0324:free"
|
||||||
|
|
||||||
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表
|
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,15 +54,14 @@ class QueryLLMviewSet(CustomModelViewSet):
|
||||||
}
|
}
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
请判断以下问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。
|
我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。
|
||||||
|
|
||||||
问题: {prompt}
|
注意:
|
||||||
|
只需回答"database"或"general"即可,不要有其他内容。
|
||||||
只需回答"database"或"general",不要有其他内容。
|
|
||||||
"""
|
"""
|
||||||
_payload = {
|
_payload = {
|
||||||
"model": MODEL,
|
"model": MODEL,
|
||||||
"messages": [{"role": "user", "content": user_prompt}],
|
"messages": [{"role": "user", "content": user_prompt}, {"role":"system" , "content": "只返回一个结果'database'或'general'"}],
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_tokens": 10
|
"max_tokens": 10
|
||||||
}
|
}
|
||||||
|
@ -70,9 +70,11 @@ class QueryLLMviewSet(CustomModelViewSet):
|
||||||
class_response.raise_for_status()
|
class_response.raise_for_status()
|
||||||
class_result = class_response.json()
|
class_result = class_response.json()
|
||||||
question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
|
question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
|
||||||
|
print("question_type", question_type)
|
||||||
if question_type == "database":
|
if question_type == "database":
|
||||||
conn = connect_db()
|
conn = connect_db()
|
||||||
schema_text = get_schema_text(conn, TABLES)
|
schema_text = get_schema_text(conn, TABLES)
|
||||||
|
print("schema_text----------------------", schema_text)
|
||||||
user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
|
user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
|
||||||
{schema_text}
|
{schema_text}
|
||||||
请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。
|
请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。
|
||||||
|
@ -88,8 +90,9 @@ class QueryLLMviewSet(CustomModelViewSet):
|
||||||
"""
|
"""
|
||||||
# TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages
|
# TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages
|
||||||
history = Message.objects.filter(conversation=conversation).order_by('create_time')
|
history = Message.objects.filter(conversation=conversation).order_by('create_time')
|
||||||
chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
|
# chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
|
||||||
chat_history.append({"role": "user", "content": prompt})
|
# chat_history.append({"role": "user", "content": prompt})
|
||||||
|
chat_history = [{"role":"user", "content":prompt}]
|
||||||
print("chat_history", chat_history)
|
print("chat_history", chat_history)
|
||||||
payload = {
|
payload = {
|
||||||
"model": MODEL,
|
"model": MODEL,
|
||||||
|
@ -112,7 +115,6 @@ class QueryLLMviewSet(CustomModelViewSet):
|
||||||
try:
|
try:
|
||||||
data = json.loads(decoded_line[6:])
|
data = json.loads(decoded_line[6:])
|
||||||
content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
|
content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
|
||||||
print("content", content)
|
|
||||||
if content:
|
if content:
|
||||||
accumulated_content += content
|
accumulated_content += content
|
||||||
yield f"data: {content}\n\n"
|
yield f"data: {content}\n\n"
|
||||||
|
@ -120,8 +122,6 @@ class QueryLLMviewSet(CustomModelViewSet):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield f"data: [解析失败]: {str(e)}\n\n"
|
yield f"data: [解析失败]: {str(e)}\n\n"
|
||||||
print("accumulated_content", accumulated_content)
|
print("accumulated_content", accumulated_content)
|
||||||
print("question_type", question_type)
|
|
||||||
print("conversation", conversation)
|
|
||||||
save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
|
save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
|
||||||
|
|
||||||
if question_type == "database":
|
if question_type == "database":
|
||||||
|
|
Loading…
Reference in New Issue