import json

import faiss
import numpy as np
import requests
from django.http import StreamingHttpResponse, JsonResponse
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.models import Conversation, Message
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from apps.ichat.utils import (
    connect_db,
    execute_sql,
    extract_sql_code,
    get_table_structures,
    is_safe_sql,
    save_message_thread_safe,
)
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet

# SECURITY NOTE(review): credentials are hard-coded and committed to source
# control. Move API keys into Django settings / environment variables.
# (Previously commented-out alternate endpoints — DashScope qwen-plus,
# OpenRouter gemini/deepseek — removed; restore from VCS history if needed.)

# Locally deployed chat model endpoint (本地部署的模式)
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
API_BASE = "http://106.0.4.200:9000/v1"
MODEL = "qwen14b"

# Text-embedding model endpoint (文本向量化模型)
EM_MODEL = "m3e-base"
API_BASE_EM = "http://106.0.4.200:9997/v1"

# Exposing the whole database to the model lowers accuracy, so only a subset
# of tables is listed here (如果整个数据库全都给模型,准确率下降).
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]

HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}


def get_table_names(conn):
    """Return the names of all user tables in the ``public`` schema.

    Args:
        conn: an open DB-API connection to PostgreSQL.

    Returns:
        list[str]: table names, one per row of ``pg_tables``.
    """
    sql = """
        SELECT tablename
        FROM pg_tables
        WHERE schemaname = 'public';
    """
    cur = conn.cursor()
    try:
        # finally-close so an execute() failure does not leak the cursor
        cur.execute(sql)
        rows = cur.fetchall()
    finally:
        cur.close()
    return [row[0] for row in rows]
def get_relation_table(query: str):
    """Find the tables most relevant to *query* and return their schemas as
    LLM-ready prompt text.

    Embeds every user table's schema text, upserts the vectors into the
    ``table_embeddings`` pgvector table, then retrieves the five most similar
    tables for the query.

    Args:
        query: natural-language question from the user.

    Returns:
        str: formatted schema text, or a Chinese "not found" message when no
        similar table exists.
    """
    conn = connect_db()
    try:
        table_names = get_table_names(conn)  # user tables only
        schemas = get_table_structures(conn, table_names)
        texts = [s["text"] for s in schemas]
        table_names = [s["table"] for s in schemas]
        embeddings = embed_text(texts)
        # Persist the vectors so the similarity search can run inside Postgres.
        store_embeddings_pg(conn, embeddings, texts, table_names)
        results = search_similar_tables_pg(conn, query, top_k=5)
        if not results:
            return "没有找到相关表结构"
        # Re-fetch structures for the relevant tables only.
        schemas = get_table_structures(conn, results)
        return format_schema_for_llm(schemas)
    finally:
        # Fix: the original never closed this connection (leak per request).
        conn.close()


def store_embeddings_pg(conn, embeddings: list[list[float]], texts: list[str], table_names: list[str]):
    """Upsert one (table_name, schema_text, embedding) row per table into
    ``table_embeddings``; commits once at the end.
    """
    cur = conn.cursor()
    try:
        for embedding, text, table_name in zip(embeddings, texts, table_names):
            cur.execute("""
                INSERT INTO table_embeddings (table_name, schema_text, embedding)
                VALUES (%s, %s, %s)
                ON CONFLICT (table_name) DO UPDATE
                SET schema_text = EXCLUDED.schema_text,
                    embedding = EXCLUDED.embedding
            """, (table_name, text, embedding))
        conn.commit()
    finally:
        cur.close()


def search_similar_tables_pg(conn, query: str, top_k: int = 5):
    """Return the names of the *top_k* tables whose stored embeddings are
    nearest (pgvector ``<->`` L2 distance) to the embedded *query*.
    """
    # Step 1: embed the query text.
    query_embedding = embed_text([query])[0]
    # Step 2: pgvector accepts a '[x,y,z]' literal cast to ::vector.
    embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
    cur = conn.cursor()
    try:
        # Parameterized (was an f-string) so no value is interpolated into SQL.
        cur.execute(
            """
            SELECT table_name
            FROM table_embeddings
            ORDER BY embedding <-> %s::vector
            LIMIT %s;
            """,
            (embedding_str, top_k),
        )
        results = [row[0] for row in cur.fetchall()]
    finally:
        cur.close()
    return results


def format_schema_for_llm(schemas: list[dict]) -> str:
    """Render table schemas as a bulleted text block for the LLM prompt.

    Each item of *schemas* must carry a ``table`` name and a ``text`` field
    whose column list follows the marker ``结构如下:``.

    Returns:
        str: one ``【表名】/【字段】`` section per table, blank-line separated.
    """
    lines = []
    for schema in schemas:
        lines.append(f"【表名】:{schema['table']}")
        lines.append("【字段】:")
        for col in schema["text"].split("结构如下:")[1].split("\n"):
            if col.strip():
                lines.append(f"  - {col.strip()}")
        lines.append("")  # blank line between tables
    return "\n".join(lines)
"\n".join(lines) def embed_text(texts: list[str]) -> list[list[float]]: paylaod = { "input":texts, "model":EM_MODEL } url = f"{API_BASE_EM}/embeddings" response = requests.post(url, headers=HEADERS, json=paylaod) json_data = response.json() return [e['embedding'] for e in json_data['data']] # def search_similar_tables(query: str, index, index_table_map, top_k:int=3): # query_embedding = embed_text([query])[0] # distances, indices = index.search(np.array([query_embedding]).astype("float32"), int(top_k)) # results = [] # for i in indices[0]: # if i != -1 and i in index_table_map: # results.append(index_table_map[i]) # return results # def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]): # print(len(embeddings), '-----------') # dim = len(embeddings[0]) # index = faiss.IndexFlatL2(dim) # embeddings_np = np.array(embeddings).astype('float32') # index.add(embeddings_np) # # 构建索引到表名的映射字典 # index_table_map = {i: table_names[i] for i in range(len(table_names))} # return index, index_table_map class QueryLLMviewSet(CustomModelViewSet): queryset = Message.objects.all() serializer_class = MessageSerializer ordering = ['create_time'] perms_map = {'get':'*', 'post':'*', 'put':'*'} @action(methods=['post'], detail=False, perms_map={'post':'*'} ,serializer_class=MessageSerializer) def completion(self, request): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) serializer.save() prompt = serializer.validated_data['content'] conversation = serializer.validated_data['conversation'] if not prompt or not conversation: return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400) save_message_thread_safe(content=prompt, conversation=conversation, role="user") url = f"{API_BASE}/chat/completions" user_prompt = f""" 我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。 注意: 只需回答"database"或"general"即可,不要有其他内容。 """ _payload = { "model": MODEL, "messages": [{"role": "user", "content": 
user_prompt}], "temperature": 0, "max_tokens": 10 } try: class_response = requests.post(url, headers=HEADERS, json=_payload) class_response.raise_for_status() class_result = class_response.json() question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower() print("question_type", question_type) if question_type == "database": schema_text = get_relation_table(prompt) user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构: {schema_text} 请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。 需求是:{prompt} """ else: user_prompt = f""" 回答以下问题,不需要涉及数据库查询: 问题: {prompt} 请直接回答问题,不要提及数据库或SQL。 """ # TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages # history = Message.objects.filter(conversation=conversation).order_by('create_time') # chat_history = [{"role": msg.role, "content": msg.content} for msg in history] # chat_history.append({"role": "user", "content": prompt}) chat_history = [{"role":"user", "content":user_prompt}] print("user_prompt", user_prompt) payload = { "model": MODEL, "messages": chat_history, "temperature": 0, "stream": True } response = requests.post(url, headers=HEADERS, json=payload) response.raise_for_status() except requests.exceptions.RequestException as e: return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500) def stream_generator(): accumulated_content = "" for line in response.iter_lines(): if line: decoded_line = line.decode('utf-8') if decoded_line.startswith('data:'): if decoded_line.strip() == "data: [DONE]": break # OpenAI-style标志结束 try: data = json.loads(decoded_line[6:]) content = data.get('choices', [{}])[0].get('delta', {}).get('content', '') if content: accumulated_content += content yield f"data: {content}\n\n" except Exception as e: yield f"data: [解析失败]: {str(e)}\n\n" print("accumulated_content", accumulated_content) save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system") if question_type == "database": sql = extract_sql_code(accumulated_content) if 
sql: try: conn = connect_db() if is_safe_sql(sql): result = execute_sql(conn, sql) save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system") yield f"data: SQL执行结果: {result}\n\n" else: yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n" except Exception as e: yield f"data: SQL执行失败: {str(e)}\n\n" finally: if conn: conn.close() else: yield "data: \\n[文本结束]\n\n" return StreamingHttpResponse(stream_generator(), content_type='text/event-stream') # 先新建对话 生成对话session_id class ConversationViewSet(CustomModelViewSet): queryset = Conversation.objects.all() serializer_class = ConversationSerializer ordering = ['create_time'] perms_map = {'get':'*', 'post':'*', 'put':'*'}