287 lines
11 KiB
Python
287 lines
11 KiB
Python
import requests
|
||
import json
|
||
import faiss
|
||
import numpy as np
|
||
from rest_framework.views import APIView
|
||
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
|
||
from rest_framework.response import Response
|
||
from apps.ichat.models import Conversation, Message
|
||
from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, is_safe_sql, save_message_thread_safe, get_table_structures
|
||
from django.http import StreamingHttpResponse, JsonResponse
|
||
from rest_framework.decorators import action
|
||
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
|
||
|
||
# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# MODEL = "qwen-plus"

# Locally deployed chat-completion model (OpenAI-compatible endpoint).
# SECURITY NOTE(review): the API key and host addresses are hard-coded in
# source; they should be moved to Django settings / environment variables
# and this key rotated.
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
API_BASE = "http://106.0.4.200:9000/v1"
MODEL = "qwen14b"

# Text-embedding model and endpoint used by embed_text().
EM_MODEL = "m3e-base"
API_BASE_EM = "http://106.0.4.200:9997/v1"

# google gemini
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="google/gemini-2.0-flash-exp:free"

# deepseek v3
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="deepseek/deepseek-chat-v3-0324:free"

# Giving the model the whole database schema lowers accuracy, so only a
# subset of tables is exposed to it.
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]

# Shared HTTP headers for every LLM / embedding API request.
HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}"
}
|
||
|
||
def get_table_names(conn):
    """Return the names of all user tables in the 'public' schema.

    Args:
        conn: An open DB-API connection to the PostgreSQL database.

    Returns:
        A list of table-name strings.
    """
    cursor = conn.cursor()
    cursor.execute(
        """
    SELECT tablename
    FROM pg_tables
    WHERE schemaname = 'public';
    """
    )
    records = cursor.fetchall()
    cursor.close()
    # Each record is a one-column row; keep just the name.
    return [record[0] for record in records]
|
||
|
||
# def get_relation_table(query):
|
||
# conn = connect_db()
|
||
# # table_names = TABLES
|
||
# table_names = get_table_names(conn)
|
||
# schemas = get_table_structures(conn, table_names)
|
||
|
||
# texts = [
|
||
# f"这是一个数据库表结构,表名为 {s['table']},其结构如下:{s['text']}"
|
||
# for s in schemas
|
||
# ]
|
||
# table_names = [s["table"] for s in schemas]
|
||
# embeddings = embed_text(texts)
|
||
# index, index_table_map = create_index(embeddings, texts, table_names)
|
||
|
||
# results = search_similar_tables(query, index, index_table_map, top_k=3)
|
||
|
||
# if not results:
|
||
# return "没有找到相关表结构"
|
||
# return results
|
||
|
||
def get_relation_table(query: str):
    """Find the tables most relevant to *query* and format their schemas for the LLM.

    Embeds every user table's schema text, upserts the vectors into
    ``table_embeddings``, runs a vector similarity search for the query,
    and returns the matching schemas formatted for prompt injection.

    Args:
        query: The user's natural-language question.

    Returns:
        A formatted schema string, or the literal message
        "没有找到相关表结构" when no similar tables are found.
    """
    conn = connect_db()
    # BUGFIX: the connection was previously never closed (leaked on every
    # request); close it on all paths.
    try:
        table_names = get_table_names(conn)  # user tables only
        schemas = get_table_structures(conn, table_names)
        texts = [s["text"] for s in schemas]
        table_names = [s["table"] for s in schemas]
        embeddings = embed_text(texts)

        # Persist the schema vectors (upsert keyed by table name).
        # NOTE(review): this re-embeds and re-stores every table on each
        # call; consider caching if schemas rarely change.
        store_embeddings_pg(conn, embeddings, texts, table_names)

        # Vector search for the most similar tables.
        results = search_similar_tables_pg(conn, query, top_k=5)
        if not results:
            return "没有找到相关表结构"

        # Fetch structures only for the relevant tables.
        schemas = get_table_structures(conn, results)
        return format_schema_for_llm(schemas)
    finally:
        conn.close()
|
||
|
||
def store_embeddings_pg(conn, embeddings: list[list[float]], texts: list[str], table_names: list[str]) -> None:
    """Upsert one embedding row per table into ``table_embeddings``.

    Rows are keyed by ``table_name``; an existing row has its schema text
    and embedding replaced (``ON CONFLICT ... DO UPDATE``).

    Args:
        conn: Open DB-API connection (committed on success).
        embeddings: One embedding vector per table.
        texts: Schema text per table, aligned with *embeddings*.
        table_names: Table name per row, aligned with *embeddings*.
    """
    cur = conn.cursor()
    # BUGFIX: close the cursor even if an execute fails mid-loop.
    try:
        for embedding, text, table_name in zip(embeddings, texts, table_names):
            cur.execute("""
                INSERT INTO table_embeddings (table_name, schema_text, embedding)
                VALUES (%s, %s, %s)
                ON CONFLICT (table_name) DO UPDATE
                SET schema_text = EXCLUDED.schema_text,
                    embedding = EXCLUDED.embedding
            """, (table_name, text, embedding))
        conn.commit()
    finally:
        cur.close()
|
||
|
||
def search_similar_tables_pg(conn, query: str, top_k: int = 5):
    """Return the *top_k* table names whose embeddings are nearest to *query*.

    Uses pgvector's ``<->`` (L2 distance) operator against the
    ``table_embeddings`` table.

    Args:
        conn: Open DB-API connection.
        query: Natural-language question to embed and search with.
        top_k: Maximum number of table names to return.

    Returns:
        List of table-name strings, nearest first.
    """
    # Step 1: embed the query text.
    query_embedding = embed_text([query])[0]
    # Step 2: build the pgvector text literal '[x,y,z]'.
    vector_literal = "[" + ",".join(map(str, query_embedding)) + "]"
    cur = conn.cursor()
    # SECURITY FIX: the vector literal and LIMIT were previously
    # interpolated into the SQL with an f-string; pass them as bound
    # parameters instead. Also avoid shadowing the `query` argument.
    try:
        cur.execute(
            """
            SELECT table_name
            FROM table_embeddings
            ORDER BY embedding <-> %s::vector
            LIMIT %s;
            """,
            (vector_literal, top_k),
        )
        return [row[0] for row in cur.fetchall()]
    finally:
        cur.close()
|
||
|
||
|
||
def format_schema_for_llm(schemas: list[dict]) -> str:
    """Render table schemas as a bullet list suitable for an LLM prompt.

    Args:
        schemas: Dicts with keys ``"table"`` (name) and ``"text"``
            (newline-separated column descriptions, optionally prefixed
            with a "...结构如下:" header).

    Returns:
        One text block per table: a table-name line, a field header, and
        one indented bullet per non-blank column line, separated by blank
        lines.
    """
    lines = []
    for schema in schemas:
        lines.append(f"【表名】:{schema['table']}")
        lines.append("【字段】:")
        # BUGFIX: the original did split("结构如下:")[1], which raises
        # IndexError when the marker is absent; fall back to the full text.
        text = schema["text"]
        _, sep, column_part = text.partition("结构如下:")
        if not sep:
            column_part = text
        for col in column_part.split("\n"):
            if col.strip():
                lines.append(f"  - {col.strip()}")
        lines.append("")  # blank line between tables
    return "\n".join(lines)
|
||
|
||
|
||
def embed_text(texts: list[str]) -> list[list[float]]:
    """Embed *texts* via the m3e embedding service.

    Posts an OpenAI-compatible ``/embeddings`` request to API_BASE_EM.

    Args:
        texts: The strings to embed.

    Returns:
        One embedding vector (list of floats) per input string, in order.

    Raises:
        requests.exceptions.RequestException: on HTTP errors or timeout.
    """
    payload = {  # BUGFIX: was misspelled 'paylaod'
        "input": texts,
        "model": EM_MODEL
    }
    url = f"{API_BASE_EM}/embeddings"
    # BUGFIX: added a timeout (the call previously could hang forever) and
    # raise_for_status() so HTTP failures surface as RequestException
    # instead of an opaque KeyError on the response body.
    response = requests.post(url, headers=HEADERS, json=payload, timeout=30)
    response.raise_for_status()
    json_data = response.json()
    return [e['embedding'] for e in json_data['data']]
|
||
|
||
# def search_similar_tables(query: str, index, index_table_map, top_k:int=3):
|
||
# query_embedding = embed_text([query])[0]
|
||
# distances, indices = index.search(np.array([query_embedding]).astype("float32"), int(top_k))
|
||
# results = []
|
||
# for i in indices[0]:
|
||
# if i != -1 and i in index_table_map:
|
||
# results.append(index_table_map[i])
|
||
# return results
|
||
|
||
# def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]):
|
||
# print(len(embeddings), '-----------')
|
||
# dim = len(embeddings[0])
|
||
# index = faiss.IndexFlatL2(dim)
|
||
# embeddings_np = np.array(embeddings).astype('float32')
|
||
# index.add(embeddings_np)
|
||
|
||
# # 构建索引到表名的映射字典
|
||
# index_table_map = {i: table_names[i] for i in range(len(table_names))}
|
||
# return index, index_table_map
|
||
|
||
|
||
|
||
class QueryLLMviewSet(CustomModelViewSet):
    """Chat-completion API for the ichat app.

    ``completion`` classifies the user's question as database-related or
    general, optionally injects the most relevant table schemas (vector
    search), calls the LLM, and streams the answer back as server-sent
    events. When the answer contains SQL, the SQL is safety-checked,
    executed, and its result streamed as well.
    """
    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    ordering = ['create_time']
    perms_map = {'get': '*', 'post': '*', 'put': '*'}

    @action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
    def completion(self, request):
        """Handle one chat turn and return a streaming SSE response."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # NOTE(review): serializer.save() already persists the incoming
        # message; save_message_thread_safe() below may store it a second
        # time — confirm against the Message model.
        serializer.save()
        prompt = serializer.validated_data['content']
        conversation = serializer.validated_data['conversation']
        if not prompt or not conversation:
            return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
        save_message_thread_safe(content=prompt, conversation=conversation, role="user")
        url = f"{API_BASE}/chat/completions"
        # Step 1: ask the model to classify the question type.
        user_prompt = f"""
        我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。

        注意:
        只需回答"database"或"general"即可,不要有其他内容。
        """
        _payload = {
            "model": MODEL,
            "messages": [{"role": "user", "content": user_prompt}],
            "temperature": 0,  # deterministic classification
            "max_tokens": 10   # a one-word answer is expected
        }
        try:
            class_response = requests.post(url, headers=HEADERS, json=_payload)
            class_response.raise_for_status()
            class_result = class_response.json()
            question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
            print("question_type", question_type)
            if question_type == "database":
                # Step 2a: database question — retrieve related schemas and
                # ask the model to generate a PostgreSQL statement.
                schema_text = get_relation_table(prompt)
                user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
                {schema_text}
                请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。
                需求是:{prompt}
                """
            else:
                # Step 2b: general question — answer directly.
                user_prompt = f"""
                回答以下问题,不需要涉及数据库查询:

                问题: {prompt}

                请直接回答问题,不要提及数据库或SQL。
                """
            # TODO: load the conversation's earlier messages and send them
            # as multi-turn chat history instead of a single-turn prompt.
            # history = Message.objects.filter(conversation=conversation).order_by('create_time')
            # chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
            # chat_history.append({"role": "user", "content": prompt})
            chat_history = [{"role": "user", "content": user_prompt}]
            print("user_prompt", user_prompt)
            payload = {
                "model": MODEL,
                "messages": chat_history,
                "temperature": 0,
                "stream": True
            }
            response = requests.post(url, headers=HEADERS, json=payload)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            return JsonResponse({"error": f"LLM API调用失败: {e}"}, status=500)

        def stream_generator():
            """Relay the LLM's SSE chunks to the client, then persist the
            full answer and, for database questions, run the generated SQL."""
            accumulated_content = ""
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data:'):
                        if decoded_line.strip() == "data: [DONE]":
                            break  # OpenAI-style end-of-stream marker
                        try:
                            data = json.loads(decoded_line[6:])
                            content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
                            if content:
                                accumulated_content += content
                                yield f"data: {content}\n\n"
                        except Exception as e:
                            yield f"data: [解析失败]: {str(e)}\n\n"
            print("accumulated_content", accumulated_content)
            save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")

            if question_type == "database":
                sql = extract_sql_code(accumulated_content)
                if sql:
                    # BUGFIX: initialize conn before the try so the finally
                    # block cannot hit a NameError when connect_db() raises.
                    conn = None
                    try:
                        conn = connect_db()
                        if is_safe_sql(sql):
                            result = execute_sql(conn, sql)
                            save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
                            yield f"data: SQL执行结果: {result}\n\n"
                        else:
                            yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n"
                    except Exception as e:
                        yield f"data: SQL执行失败: {str(e)}\n\n"
                    finally:
                        if conn:
                            conn.close()
            else:
                yield "data: \\n[文本结束]\n\n"
        return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
|
||
|
||
# 先新建对话 生成对话session_id
|
||
class ConversationViewSet(CustomModelViewSet):
    """CRUD endpoints for chat conversations.

    A conversation (session) is created first; its id then groups the
    messages sent through the completion endpoint.
    """

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    perms_map = {'get': '*', 'post': '*', 'put': '*'}
    ordering = ['create_time']
|
||
|
||
|