factory/apps/ichat/views.py

156 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os

import requests
from django.http import StreamingHttpResponse, JsonResponse
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.models import Conversation, Message
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
# LLM backend configuration.
#
# Each value can be overridden through an environment variable so that
# credentials and endpoints do not have to live in source control. The
# fallbacks are the previously hard-coded values (the locally deployed model).
# NOTE(review): the fallback API key is still committed here -- rotate it and
# drop the literal once every deployment sets ICHAT_LLM_API_KEY.
#
# Other OpenAI-compatible backends this module has been pointed at
# (keys intentionally omitted):
#   DashScope:  https://dashscope.aliyuncs.com/compatible-mode/v1
#               model "qwen-plus"
#   OpenRouter: https://openrouter.ai/api/v1
#               model "google/gemini-2.0-flash-exp:free"
#               or    "deepseek/deepseek-chat-v3-0324:free"
API_KEY = os.environ.get("ICHAT_LLM_API_KEY", "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw")
API_BASE = os.environ.get("ICHAT_LLM_API_BASE", "http://106.0.4.200:9000/v1")
MODEL = os.environ.get("ICHAT_LLM_MODEL", "qwen14b")

# Only these tables are described to the model: handing it the whole database
# lowers its accuracy.
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]
class QueryLLMviewSet(CustomModelViewSet):
    """Message endpoints plus the LLM-backed ``completion`` action.

    ``completion`` stores the incoming user message, asks the model whether
    the question is database related, then streams the model's answer back as
    ``text/event-stream``. For database questions the SQL extracted from the
    answer is executed (only when ``is_safe_sql`` accepts it) and its result
    is streamed as a final event.
    """

    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    ordering = ['create_time']
    perms_map = {'get': '*', 'post': '*', 'put': '*'}

    # (connect, read) timeouts in seconds for calls to the LLM server, so an
    # unreachable or stalled backend cannot block a worker indefinitely.
    llm_timeout = (10, 300)

    @staticmethod
    def _first_choice(body):
        """Return ``body['choices'][0]`` when it is a dict, otherwise ``{}``."""
        choices = body.get('choices') if isinstance(body, dict) else None
        first = choices[0] if isinstance(choices, list) and choices else None
        return first if isinstance(first, dict) else {}

    @staticmethod
    def _classification_prompt(prompt):
        """Build the prompt asking the model to label ``prompt`` as database/general."""
        return (
            "\n"
            f'我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"\n'
            "注意:\n"
            '只需回答"database""general"即可,不要有其他内容。\n'
        )

    @staticmethod
    def _database_prompt(schema_text, prompt):
        """Build the text-to-SQL prompt that embeds the schema description."""
        return (
            "你是一个专业的数据库工程师,根据以下数据库结构:\n"
            f"{schema_text}\n"
            "请根据我的需求生成一条标准的PostgreSQL SQL语句直接返回SQL不要额外解释。\n"
            f"需求是:{prompt}\n"
        )

    @staticmethod
    def _general_prompt(prompt):
        """Build the prompt for questions that do not involve the database."""
        return (
            "\n"
            "回答以下问题,不需要涉及数据库查询:\n"
            f"问题: {prompt}\n"
            "请直接回答问题不要提及数据库或SQL。\n"
        )

    @action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
    def completion(self, request):
        """Answer a user message with the LLM and stream the reply as SSE.

        Args:
            request: DRF request whose body validates against
                ``MessageSerializer`` (``content`` and ``conversation`` are read).

        Returns:
            ``StreamingHttpResponse`` (``text/event-stream``) on success,
            ``JsonResponse`` with status 400 when ``content`` or
            ``conversation`` is empty, or with status 500 when a call to the
            LLM API fails before streaming starts.
        """
        logger = logging.getLogger(__name__)

        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # NOTE(review): the message is persisted here and again through
        # save_message_thread_safe below -- confirm this does not store the
        # user message twice.
        serializer.save()
        prompt = serializer.validated_data['content']
        conversation = serializer.validated_data['conversation']
        if not prompt or not conversation:
            return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
        save_message_thread_safe(content=prompt, conversation=conversation, role="user")

        url = f"{API_BASE}/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        classify_payload = {
            "model": MODEL,
            "messages": [
                {"role": "user", "content": self._classification_prompt(prompt)},
                {"role": "system", "content": "只返回一个结果'database''general'"},
            ],
            "temperature": 0,
            "max_tokens": 10,
        }
        try:
            # Step 1: one short, non-streamed call that only classifies the question.
            class_response = requests.post(
                url, headers=headers, json=classify_payload, timeout=self.llm_timeout
            )
            class_response.raise_for_status()
            message = self._first_choice(class_response.json()).get('message') or {}
            # The prompt shows the labels in quotes, so tolerate a quoted reply.
            # Anything other than "database" is handled as a general question.
            question_type = (message.get('content') or '').strip().strip('"\'').lower()
            logger.debug("question_type %s", question_type)

            if question_type == "database":
                conn = connect_db()
                try:
                    schema_text = get_schema_text(conn, TABLES)
                finally:
                    # This connection is only needed for the schema lookup.
                    if conn:
                        conn.close()
                logger.debug("schema_text %s", schema_text)
                user_prompt = self._database_prompt(schema_text, prompt)
            else:
                user_prompt = self._general_prompt(prompt)

            # Send the prompt built above (with the schema for database
            # questions); the raw user text alone gives the model no schema.
            # TODO: decide whether earlier messages of this conversation
            # should be loaded and sent along as chat history.
            chat_history = [{"role": "user", "content": user_prompt}]
            logger.debug("chat_history %s", chat_history)
            payload = {
                "model": MODEL,
                "messages": chat_history,
                "temperature": 0,
                "stream": True,
            }
            # Step 2: stream=True lets iter_lines() yield chunks as they
            # arrive instead of after the whole body has been downloaded.
            response = requests.post(
                url, headers=headers, json=payload, stream=True, timeout=self.llm_timeout
            )
            try:
                response.raise_for_status()
            except requests.exceptions.RequestException:
                response.close()
                raise
        except requests.exceptions.RequestException as e:
            return JsonResponse({"error": f"LLM API调用失败: {e}"}, status=500)

        def stream_generator():
            """Relay the model's chunks as SSE events, then run the SQL if any."""
            pieces = []
            try:
                for line in response.iter_lines():
                    if not line:
                        continue
                    decoded_line = line.decode('utf-8')
                    if not decoded_line.startswith('data:'):
                        continue
                    # Accept both "data: {...}" and "data:{...}".
                    data_text = decoded_line[len('data:'):].strip()
                    if data_text == "[DONE]":
                        break  # OpenAI-style end-of-stream marker
                    try:
                        delta = self._first_choice(json.loads(data_text)).get('delta') or {}
                        content = delta.get('content') or ''
                    except (ValueError, TypeError, AttributeError) as e:
                        yield f"data: [解析失败]: {str(e)}\n\n"
                        continue
                    if content:
                        pieces.append(content)
                        # NOTE(review): content containing newlines is sent
                        # inside a single "data:" field -- confirm the client
                        # copes with that framing.
                        yield f"data: {content}\n\n"
            except requests.exceptions.RequestException as e:
                # The upstream connection broke after streaming had started.
                logger.exception("LLM stream interrupted")
                yield f"data: LLM API调用失败: {e}\n\n"
                return
            finally:
                # Also runs when the client disconnects and the generator is closed.
                response.close()

            accumulated_content = "".join(pieces)
            logger.debug("accumulated_content %s", accumulated_content)
            # NOTE(review): the model's reply is stored with role "system" --
            # confirm that is the intended role value.
            save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")

            # NOTE(review): the original indentation was lost; the end-of-text
            # marker is assumed to belong to non-database questions -- confirm.
            if question_type == "database":
                sql = extract_sql_code(accumulated_content)
                if sql:
                    conn = None  # keeps the finally clause valid if connect_db() raises
                    try:
                        if is_safe_sql(sql):
                            conn = connect_db()
                            result = execute_sql(conn, sql)
                            save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
                            yield f"data: SQL执行结果: {result}\n\n"
                        else:
                            yield f"data: 拒绝执行非查询类 SQL{sql}\n\n"
                    except Exception as e:
                        # Report any failure to the client rather than
                        # aborting the stream mid-response.
                        yield f"data: SQL执行失败: {str(e)}\n\n"
                    finally:
                        if conn:
                            conn.close()
            else:
                yield "data: \\n[文本结束]\n\n"

        return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
# Create a conversation first; that generates the conversation session_id.
class ConversationViewSet(CustomModelViewSet):
    """Viewset exposing ``Conversation`` records through ``ConversationSerializer``."""

    serializer_class = ConversationSerializer
    queryset = Conversation.objects.all()
    # The same '*' permission entry for each of the three HTTP methods.
    perms_map = dict.fromkeys(('get', 'post', 'put'), '*')
    ordering = ['create_time']