import json
import logging
import os

import requests
from django.http import JsonResponse, StreamingHttpResponse
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.models import Conversation, Message
from apps.ichat.serializers import ConversationSerializer, MessageSerializer
from apps.ichat.utils import (
    connect_db,
    execute_sql,
    extract_sql_code,
    get_schema_text,
    is_safe_sql,
    save_message_thread_safe,
)
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet

logger = logging.getLogger(__name__)

# LLM endpoint configuration for an OpenAI-compatible /chat/completions API.
# The defaults point at the locally deployed model; each value can be
# overridden through the environment so credentials need not live in source.
# NOTE(review): the default key below, and the DashScope / OpenRouter keys that
# used to sit here as commented-out alternatives, were committed to source
# control. They should be rotated and then supplied only via environment/settings.
API_KEY = os.environ.get("ICHAT_LLM_API_KEY", "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw")
API_BASE = os.environ.get("ICHAT_LLM_API_BASE", "http://106.0.4.200:9000/v1")
MODEL = os.environ.get("ICHAT_LLM_MODEL", "qwen14b")

# (connect, read) timeouts in seconds for LLM HTTP calls. Without a timeout,
# requests waits forever and a hung model pins a worker. For the streaming call,
# the read timeout bounds the gap between chunks, not the total duration.
LLM_TIMEOUT = (10, 120)

# Giving the model the whole database schema lowers its accuracy, so only this
# subset of tables is described to it.
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]


def _llm_headers():
    """Return the HTTP headers for an authenticated LLM API request."""
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }


def _sse(text):
    """Format *text* as one server-sent-event ``data:`` frame."""
    return f"data: {text}\n\n"


def _classify_question(url, headers, prompt):
    """Ask the LLM whether *prompt* concerns a database query or operation.

    Returns ``"database"`` or ``"general"``.
    Raises ``requests.exceptions.RequestException`` on transport or HTTP errors.
    """
    classify_prompt = (
        "\n"
        f"我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答\"database\";如果不是,回答\"general\"。\n"
        "注意:\n"
        "只需回答\"database\"或\"general\"即可,不要有其他内容。\n"
    )
    payload = {
        "model": MODEL,
        "messages": [
            # Chat templates conventionally expect the system turn first.
            {"role": "system", "content": "只返回一个结果'database'或'general'"},
            {"role": "user", "content": classify_prompt},
        ],
        "temperature": 0,
        "max_tokens": 10,
    }
    response = requests.post(url, headers=headers, json=payload, timeout=LLM_TIMEOUT)
    response.raise_for_status()
    result = response.json()
    # Tolerate a missing/empty ``choices`` list or a null ``content``.
    choices = result.get("choices") or [{}]
    answer = ((choices[0].get("message") or {}).get("content") or "").strip().lower()
    logger.debug("question_type raw answer: %r", answer)
    # The model may wrap the label in quotes or punctuation ("Database."),
    # so match by containment rather than exact equality.
    return "database" if "database" in answer else "general"


def _build_answer_prompt(question_type, prompt):
    """Build the user message for the streamed answer request.

    For ``"database"`` questions, the schema of ``TABLES`` is embedded and the
    model is asked for a single PostgreSQL statement. Otherwise the model is
    told to answer without mentioning databases or SQL.
    """
    if question_type == "database":
        conn = connect_db()
        try:
            schema_text = get_schema_text(conn, TABLES)
        finally:
            # This connection is only needed to read the schema.
            conn.close()
        logger.debug("schema_text: %s", schema_text)
        return (
            "你是一个专业的数据库工程师,根据以下数据库结构:\n"
            f"{schema_text}\n"
            "请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。\n"
            f"需求是:{prompt}\n"
        )
    return (
        "\n"
        "回答以下问题,不需要涉及数据库查询:\n"
        f"问题: {prompt}\n"
        "请直接回答问题,不要提及数据库或SQL。\n"
    )


def _stream_answer(response, conversation, question_type):
    """Relay an OpenAI-style streamed chat response to the client as SSE frames.

    When the upstream stream ends:
    - the accumulated answer is saved to *conversation*;
    - for ``"database"`` questions, the SQL extracted from the answer is run
      only if ``is_safe_sql`` accepts it, and the result is streamed and saved;
    - for other questions, an end-of-text marker frame is sent.
    """
    chunks = []
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            decoded_line = raw_line.decode("utf-8")
            if not decoded_line.startswith("data:"):
                continue
            # The space after "data:" is optional in SSE, so strip rather than
            # slicing off a fixed number of characters.
            data_str = decoded_line[len("data:"):].strip()
            if data_str == "[DONE]":  # OpenAI-style end-of-stream marker
                break
            try:
                data = json.loads(data_str)
                choices = data.get("choices") or [{}]
                content = (choices[0].get("delta") or {}).get("content") or ""
            except (ValueError, AttributeError, IndexError, KeyError, TypeError) as e:
                yield _sse(f"[解析失败]: {e}")
                continue
            if content:
                chunks.append(content)
                # NOTE(review): if content contains "\n", the extra lines are not
                # "data:"-prefixed and strict SSE clients (EventSource) drop
                # them. Confirm how the frontend consumes this stream.
                yield _sse(content)
    finally:
        # With stream=True the connection stays open until the body is fully
        # consumed or closed. Also runs if the client disconnects mid-stream.
        response.close()

    accumulated_content = "".join(chunks)
    logger.debug("accumulated_content: %s", accumulated_content)
    save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")

    if question_type != "database":
        yield "data: \\n[文本结束]\n\n"
        return

    sql = extract_sql_code(accumulated_content)
    if not sql:
        return
    conn = None  # bound up front so `finally` is safe even if connect_db() raises
    try:
        if is_safe_sql(sql):
            conn = connect_db()
            result = execute_sql(conn, sql)
            save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
            yield _sse(f"SQL执行结果: {result}")
        else:
            # Check safety before connecting so rejected SQL never opens a connection.
            yield _sse(f"拒绝执行非查询类 SQL:{sql}")
    except Exception as e:  # boundary: report any DB/SQL failure to the client
        yield _sse(f"SQL执行失败: {e}")
    finally:
        if conn is not None:
            conn.close()


class QueryLLMviewSet(CustomModelViewSet):
    """Message endpoints plus a ``completion`` action that streams LLM answers."""

    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    ordering = ['create_time']
    # NOTE(review): '*' for every method presumably means no permission check.
    # Confirm this is intended, since `completion` can execute model-generated
    # SQL against the database.
    perms_map = {'get': '*', 'post': '*', 'put': '*'}

    @action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
    def completion(self, request):
        """Answer the posted message with a streamed LLM reply.

        Steps:
        1. Validate and store the incoming message.
        2. Ask the LLM to classify the question as ``"database"`` or ``"general"``.
        3. Send a type-specific prompt; database prompts include the schema of ``TABLES``.
        4. Return the reply as a ``text/event-stream`` response.

        Error responses:
        - 400 if ``content`` or ``conversation`` is empty.
        - 500 with a JSON error if an LLM HTTP call fails.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data.get('content')
        conversation = serializer.validated_data.get('conversation')
        # Check before persisting so an incomplete request stores nothing.
        if not prompt or not conversation:
            return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
        serializer.save()
        # NOTE(review): serializer.save() already creates a Message from this
        # request. If save_message_thread_safe also inserts a row, the user prompt
        # is stored twice. Confirm, and keep only one of the two writes.
        save_message_thread_safe(content=prompt, conversation=conversation, role="user")

        url = f"{API_BASE}/chat/completions"
        headers = _llm_headers()
        response = None
        try:
            question_type = _classify_question(url, headers, prompt)
            user_prompt = _build_answer_prompt(question_type, prompt)
            # TODO: consider sending earlier messages of `conversation` as chat
            # history to give the model multi-turn context.
            payload = {
                "model": MODEL,
                # Send the constructed prompt (schema/instructions included),
                # not the raw question.
                "messages": [{"role": "user", "content": user_prompt}],
                "temperature": 0,
                "stream": True,
            }
            # stream=True makes requests yield chunks as they arrive instead of
            # buffering the whole body, which would defeat the SSE relay.
            response = requests.post(url, headers=headers, json=payload, stream=True, timeout=LLM_TIMEOUT)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            if response is not None:
                response.close()
            return JsonResponse({"error": f"LLM API调用失败: {e}"}, status=500)

        return StreamingHttpResponse(
            _stream_answer(response, conversation, question_type),
            content_type='text/event-stream',
        )


# Create a conversation first (this yields the conversation / session id);
# messages then reference it through their `conversation` field.
class ConversationViewSet(CustomModelViewSet):
    """Model viewset for conversations, ordered by ``create_time``."""

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    ordering = ['create_time']
    perms_map = {'get': '*', 'post': '*', 'put': '*'}