156 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			156 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
import json
import logging
import os

import requests
from django.http import JsonResponse, StreamingHttpResponse
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.models import Conversation, Message
from apps.ichat.serializers import ConversationSerializer, MessageSerializer
from apps.ichat.utils import (
    connect_db,
    execute_sql,
    extract_sql_code,
    get_schema_text,
    is_safe_sql,
    save_message_thread_safe,
)
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet

# LLM endpoint configuration. Any OpenAI-compatible ``/chat/completions``
# server can be used by setting these environment variables, e.g.:
#   - DashScope compatible mode: https://dashscope.aliyuncs.com/compatible-mode/v1
#     with model "qwen-plus"
#   - OpenRouter: https://openrouter.ai/api/v1 with model
#     "google/gemini-2.0-flash-exp:free" or "deepseek/deepseek-chat-v3-0324:free"
# The defaults below point at the self-hosted (local deployment) model.
# Keep credentials out of source control: supply the key via
# ICHAT_LLM_API_KEY in the deployment environment.
API_KEY = os.environ.get("ICHAT_LLM_API_KEY", "")
API_BASE = os.environ.get("ICHAT_LLM_API_BASE", "http://106.0.4.200:9000/v1")
MODEL = os.environ.get("ICHAT_LLM_MODEL", "qwen14b")

# Only a subset of tables is described to the model: giving it the whole
# database schema lowers answer accuracy.
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]


logger = logging.getLogger(__name__)


class QueryLLMviewSet(CustomModelViewSet):
    """Message API plus an LLM-backed ``completion`` action.

    ``completion`` asks the OpenAI-compatible endpoint at ``API_BASE`` to
    classify the user's question as ``"database"`` or ``"general"``. It then
    sends the question wrapped in a matching instruction template (database
    questions embed the schema of ``TABLES``) and streams the model's answer
    back as server-sent events. For database questions, SQL extracted from the
    answer is executed when ``is_safe_sql`` accepts it. The result is streamed
    and saved as well.
    """

    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    ordering = ['create_time']
    perms_map = {'get': '*', 'post': '*', 'put': '*'}

    # (connect, read) timeouts in seconds for the LLM HTTP calls. Without a
    # timeout a stalled upstream would pin the worker indefinitely. For the
    # streaming call the read timeout bounds the gap between received bytes,
    # not the total answer time.
    classify_timeout = (10, 60)
    completion_timeout = (10, 300)

    @action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
    def completion(self, request):
        """Answer the posted message, streaming the reply as ``text/event-stream``.

        Expects ``content`` (the question) and ``conversation`` in the request
        body, validated by ``MessageSerializer``.

        Returns:
            A 400 ``JsonResponse`` when either field is empty, a 500
            ``JsonResponse`` when an LLM request fails before streaming
            starts, otherwise a ``StreamingHttpResponse`` of SSE events.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data.get('content')
        conversation = serializer.validated_data.get('conversation')
        # Reject before persisting anything so invalid requests leave no rows.
        if not prompt or not conversation:
            return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
        serializer.save()
        # NOTE(review): serializer.save() presumably already stores this message;
        # confirm whether this additional user-role copy is intended.
        save_message_thread_safe(content=prompt, conversation=conversation, role="user")

        url = f"{API_BASE}/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }

        response = None
        try:
            question_type = self._classify_question(url, headers, prompt)
            logger.debug("question_type=%s", question_type)
            user_prompt = self._build_answer_prompt(prompt, question_type)
            # TODO: send earlier messages of this conversation as chat history;
            # currently only the current (templated) question is sent.
            payload = {
                "model": MODEL,
                "messages": [{"role": "user", "content": user_prompt}],
                "temperature": 0,
                "stream": True,
            }
            # stream=True makes requests expose the body incrementally; without
            # it the entire answer is downloaded before iter_lines() yields.
            response = requests.post(
                url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=self.completion_timeout,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            if response is not None:
                response.close()
            return JsonResponse({"error": f"LLM API调用失败: {e}"}, status=500)

        return StreamingHttpResponse(
            self._stream_answer(response, conversation, question_type),
            content_type='text/event-stream',
        )

    def _classify_question(self, url, headers, prompt):
        """Ask the LLM whether *prompt* concerns a database query/operation.

        Returns the model's normalized answer, expected to be ``"database"``
        or ``"general"``. Raises ``requests.exceptions.RequestException`` on
        transport or HTTP errors.
        """
        classify_prompt = f"""
我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。

注意:
只需回答"database"或"general"即可,不要有其他内容。
"""
        payload = {
            "model": MODEL,
            # The system instruction goes first, as chat templates expect.
            "messages": [
                {"role": "system", "content": "只返回一个结果'database'或'general'"},
                {"role": "user", "content": classify_prompt},
            ],
            "temperature": 0,
            "max_tokens": 10,
        }
        resp = requests.post(url, headers=headers, json=payload, timeout=self.classify_timeout)
        resp.raise_for_status()
        result = resp.json()
        choices = result.get('choices') or [{}]
        content = (choices[0].get('message') or {}).get('content') or ''
        # Models sometimes wrap the label in quotes or end it with punctuation.
        return content.strip().strip('\'"`.。').lower()

    @staticmethod
    def _build_answer_prompt(prompt, question_type):
        """Wrap *prompt* in the instruction template matching *question_type*.

        For ``"database"`` the schema of ``TABLES`` is embedded so the model
        can write PostgreSQL against it. Any other type gets the
        general-answer template.
        """
        if question_type == "database":
            conn = connect_db()
            try:
                schema_text = get_schema_text(conn, TABLES)
            finally:
                # The connection is only needed to read the schema.
                conn.close()
            logger.debug("schema_text: %s", schema_text)
            return f"""你是一个专业的数据库工程师,根据以下数据库结构:
{schema_text}
请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。
需求是:{prompt}
"""
        return f"""
回答以下问题,不需要涉及数据库查询:

问题: {prompt}

请直接回答问题,不要提及数据库或SQL。
"""

    @staticmethod
    def _delta_content(chunk):
        """Return the text delta carried by one parsed stream chunk, or ''."""
        # Some servers send chunks with an empty ``choices`` list (e.g. usage).
        choices = chunk.get('choices') or [{}]
        delta = choices[0].get('delta') or {}
        return delta.get('content') or ''

    def _stream_answer(self, response, conversation, question_type):
        """Yield SSE ``data:`` events for a streamed chat-completion *response*.

        Forwards each content delta, then saves the full answer. For database
        questions it runs the extracted SQL afterwards. The upstream response
        is always closed, including when the client disconnects.
        """
        chunks = []
        try:
            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                if not decoded_line.startswith('data:'):
                    continue
                # SSE permits an optional space after "data:".
                data_text = decoded_line[len('data:'):].strip()
                if data_text == "[DONE]":
                    break  # OpenAI-style end-of-stream marker
                try:
                    content = self._delta_content(json.loads(data_text))
                except (ValueError, AttributeError, TypeError, IndexError) as e:
                    yield f"data: [解析失败]: {str(e)}\n\n"
                    continue
                if content:
                    chunks.append(content)
                    # NOTE(review): a delta containing "\n" is not split into
                    # separate SSE data lines; confirm the client tolerates it.
                    yield f"data: {content}\n\n"
        except requests.exceptions.RequestException as e:
            logger.warning("LLM stream interrupted: %s", e)
            yield f"data: LLM API调用失败: {e}\n\n"
            return
        finally:
            response.close()

        answer = "".join(chunks)
        logger.debug("accumulated_content: %s", answer)
        # NOTE(review): the model answer is stored with role "system"; the
        # OpenAI-compatible role for model output is "assistant".
        save_message_thread_safe(content=answer, conversation=conversation, role="system")

        if question_type != "database":
            return
        sql = extract_sql_code(answer)
        if not sql:
            # NOTE(review): this end marker is only sent when a database answer
            # contains no SQL; general answers end without it — confirm intent.
            yield "data: \\n[文本结束]\n\n"
            return
        yield from QueryLLMviewSet._run_sql(sql, conversation)

    @staticmethod
    def _run_sql(sql, conversation):
        """Execute *sql* if ``is_safe_sql`` allows it; yield SSE outcome events.

        The result is also saved to *conversation*. Any execution error is
        reported to the client as an event rather than aborting the stream.
        """
        conn = None
        try:
            if not is_safe_sql(sql):
                yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n"
                return
            conn = connect_db()
            result = execute_sql(conn, sql)
            save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
            yield f"data: SQL执行结果: {result}\n\n"
        except Exception as e:  # boundary: report any failure to the client
            logger.exception("SQL execution failed")
            yield f"data: SQL执行失败: {str(e)}\n\n"
        finally:
            if conn is not None:
                conn.close()


# Create a conversation first; this generates the conversation id (session)
# that subsequent chat messages refer to.
class ConversationViewSet(CustomModelViewSet):
    """Model viewset exposing ``Conversation`` records, ordered by ``create_time``."""

    perms_map = dict(get='*', post='*', put='*')
    ordering = ['create_time']
    serializer_class = ConversationSerializer
    queryset = Conversation.objects.all()