factory/apps/ichat/view_bak2.py

287 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import faiss
import numpy as np
from rest_framework.views import APIView
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from rest_framework.response import Response
from apps.ichat.models import Conversation, Message
from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, is_safe_sql, save_message_thread_safe, get_table_structures
from django.http import StreamingHttpResponse, JsonResponse
from rest_framework.decorators import action
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# MODEL = "qwen-plus"
# Locally deployed model endpoint.
# NOTE(review): credentials are hard-coded in source; move API_KEY into
# Django settings or an environment variable before deployment.
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
API_BASE = "http://106.0.4.200:9000/v1"
MODEL = "qwen14b"
# Text embedding model (served on a separate port from the chat model).
EM_MODEL = "m3e-base"
API_BASE_EM = "http://106.0.4.200:9997/v1"
# google gemini
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="google/gemini-2.0-flash-exp:free"
# deepseek v3
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="deepseek/deepseek-chat-v3-0324:free"
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # Feeding the whole database to the model hurts accuracy, so only a subset of tables is exposed
HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}"
}
def get_table_names(conn):
    """Return the names of all user tables in the ``public`` schema.

    Args:
        conn: an open DB-API connection to the PostgreSQL database.

    Returns:
        list[str]: table names in the order ``pg_tables`` yields them.
    """
    sql = """
    SELECT tablename
    FROM pg_tables
    WHERE schemaname = 'public';
    """
    cur = conn.cursor()
    try:
        cur.execute(sql)
        rows = cur.fetchall()
    finally:
        # Release the cursor even if execute/fetch raises.
        cur.close()
    return [row[0] for row in rows]
# def get_relation_table(query):
# conn = connect_db()
# # table_names = TABLES
# table_names = get_table_names(conn)
# schemas = get_table_structures(conn, table_names)
# texts = [
# f"这是一个数据库表结构,表名为 {s['table']},其结构如下:{s['text']}"
# for s in schemas
# ]
# table_names = [s["table"] for s in schemas]
# embeddings = embed_text(texts)
# index, index_table_map = create_index(embeddings, texts, table_names)
# results = search_similar_tables(query, index, index_table_map, top_k=3)
# if not results:
# return "没有找到相关表结构"
# return results
def get_relation_table(query: str):
    """Find the tables most relevant to *query* and return their schemas.

    Embeds every user table's schema text, stores/refreshes the vectors in
    ``table_embeddings``, then runs a similarity search and formats the
    matching schemas for inclusion in an LLM prompt.

    Args:
        query: the user's natural-language question.

    Returns:
        str: the formatted schema text, or a "not found" message in Chinese
        when no similar table exists.
    """
    conn = connect_db()
    try:
        table_names = get_table_names(conn)  # user tables only
        schemas = get_table_structures(conn, table_names)
        texts = [s["text"] for s in schemas]
        table_names = [s["table"] for s in schemas]
        embeddings = embed_text(texts)
        # Persist/refresh the vectors so the similarity search is current.
        store_embeddings_pg(conn, embeddings, texts, table_names)
        related = search_similar_tables_pg(conn, query, top_k=5)
        if not related:
            return "没有找到相关表结构"
        # Re-fetch structures for only the relevant tables.
        related_schemas = get_table_structures(conn, related)
        return format_schema_for_llm(related_schemas)
    finally:
        # The original leaked this connection on every call.
        conn.close()
def store_embeddings_pg(conn, embeddings: list[list[float]], texts: list[str], table_names: list[str]):
    """Upsert one embedding row per table into ``table_embeddings``.

    Args:
        conn: open DB-API connection (committed on success).
        embeddings: one vector per table, aligned with *texts*/*table_names*.
        texts: the schema text that was embedded.
        table_names: primary key of each row (``ON CONFLICT`` target).
    """
    cur = conn.cursor()
    try:
        for embedding, text, table_name in zip(embeddings, texts, table_names):
            cur.execute("""
            INSERT INTO table_embeddings (table_name, schema_text, embedding)
            VALUES (%s, %s, %s)
            ON CONFLICT (table_name) DO UPDATE
            SET schema_text = EXCLUDED.schema_text,
                embedding = EXCLUDED.embedding
            """, (table_name, text, embedding))
        conn.commit()
    finally:
        # Close the cursor even when an insert fails.
        cur.close()
def search_similar_tables_pg(conn, query: str, top_k: int = 5):
    """Return the names of the *top_k* tables nearest to *query* by embedding.

    Args:
        conn: open DB-API connection to a database with the pgvector
            ``table_embeddings`` table populated.
        query: natural-language question to embed and match.
        top_k: maximum number of table names to return.

    Returns:
        list[str]: table names ordered by ascending L2 distance.
    """
    # Step 1: embed the query text.
    query_embedding = embed_text([query])[0]
    # Step 2: pgvector accepts a '[x, y, z]' text literal; build it once and
    # pass it as a bound parameter instead of interpolating into the SQL.
    embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
    sql = """
    SELECT table_name
    FROM table_embeddings
    ORDER BY embedding <-> %s::vector
    LIMIT %s;
    """
    cur = conn.cursor()
    try:
        cur.execute(sql, (embedding_str, int(top_k)))
        return [row[0] for row in cur.fetchall()]
    finally:
        cur.close()
def format_schema_for_llm(schemas: list[dict]) -> str:
    """Render table schemas as a bulleted text block for an LLM prompt.

    Args:
        schemas: dicts with ``table`` (name) and ``text`` (column listing,
            optionally prefixed with a "结构如下:" marker).

    Returns:
        str: one "【表名】/【字段】" section per table, blank-line separated.
    """
    lines = []
    for schema in schemas:
        lines.append(f"【表名】:{schema['table']}")
        lines.append("【字段】:")
        # Older schema texts wrap the column list after a "结构如下:" marker;
        # current callers pass the raw column list. The original indexed
        # split(...)[1] unconditionally and crashed on marker-less text.
        text = schema["text"]
        _prefix, sep, columns = text.partition("结构如下:")
        if not sep:
            columns = text
        for col in columns.split("\n"):
            if col.strip():
                lines.append(f" - {col.strip()}")
        lines.append("")  # blank line between tables
    return "\n".join(lines)
def embed_text(texts: list[str]) -> list[list[float]]:
    """Embed *texts* via the m3e embeddings endpoint.

    Args:
        texts: one or more strings to vectorize.

    Returns:
        list[list[float]]: one embedding vector per input, in order.

    Raises:
        requests.HTTPError: if the embeddings service returns an error status.
    """
    payload = {  # original had a "paylaod" typo (harmless but confusing)
        "input": texts,
        "model": EM_MODEL
    }
    url = f"{API_BASE_EM}/embeddings"
    # Bound the call and fail fast with an HTTP error instead of letting a
    # non-2xx response surface later as a confusing KeyError on 'data'.
    response = requests.post(url, headers=HEADERS, json=payload, timeout=30)
    response.raise_for_status()
    json_data = response.json()
    return [e['embedding'] for e in json_data['data']]
# def search_similar_tables(query: str, index, index_table_map, top_k:int=3):
# query_embedding = embed_text([query])[0]
# distances, indices = index.search(np.array([query_embedding]).astype("float32"), int(top_k))
# results = []
# for i in indices[0]:
# if i != -1 and i in index_table_map:
# results.append(index_table_map[i])
# return results
# def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]):
# print(len(embeddings), '-----------')
# dim = len(embeddings[0])
# index = faiss.IndexFlatL2(dim)
# embeddings_np = np.array(embeddings).astype('float32')
# index.add(embeddings_np)
# # 构建索引到表名的映射字典
# index_table_map = {i: table_names[i] for i in range(len(table_names))}
# return index, index_table_map
class QueryLLMviewSet(CustomModelViewSet):
    """Chat completion endpoint backed by a locally deployed LLM.

    ``completion`` persists the user message, classifies the question as
    database-related or general, builds the matching prompt, and streams
    the LLM answer back as server-sent events; database answers are then
    parsed for SQL which is executed (read-only) against PostgreSQL.
    """
    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    ordering = ['create_time']
    perms_map = {'get':'*', 'post':'*', 'put':'*'}

    @action(methods=['post'], detail=False, perms_map={'post':'*'} ,serializer_class=MessageSerializer)
    def completion(self, request):
        """Answer one user message, streaming the reply via SSE.

        Request body: a MessageSerializer payload carrying ``content``
        (the prompt) and ``conversation``. Returns a
        ``text/event-stream`` response, or a JSON error on failure.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        prompt = serializer.validated_data['content']
        conversation = serializer.validated_data['conversation']
        if not prompt or not conversation:
            return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
        # Persist the user's message before contacting the LLM.
        save_message_thread_safe(content=prompt, conversation=conversation, role="user")
        url = f"{API_BASE}/chat/completions"
        # First LLM round-trip: classify the question as "database"/"general".
        user_prompt = f"""
我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"
注意:
只需回答"database""general"即可,不要有其他内容。
"""
        _payload = {
            "model": MODEL,
            "messages": [{"role": "user", "content": user_prompt}],
            "temperature": 0,
            "max_tokens": 10
        }
        try:
            class_response = requests.post(url, headers=HEADERS, json=_payload)
            class_response.raise_for_status()
            class_result = class_response.json()
            question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
            print("question_type", question_type)
            if question_type == "database":
                # Retrieve only the schemas relevant to this question.
                schema_text = get_relation_table(prompt)
                user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
{schema_text}
请根据我的需求生成一条标准的PostgreSQL SQL语句直接返回SQL不要额外解释。
需求是:{prompt}
"""
            else:
                user_prompt = f"""
回答以下问题,不需要涉及数据库查询:
问题: {prompt}
请直接回答问题不要提及数据库或SQL。
"""
            # TODO: consider fetching this conversation's past messages by id
            # and sending them as chat history instead of a single-turn prompt.
            # history = Message.objects.filter(conversation=conversation).order_by('create_time')
            # chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
            # chat_history.append({"role": "user", "content": prompt})
            chat_history = [{"role":"user", "content":user_prompt}]
            print("user_prompt", user_prompt)
            # Second LLM round-trip: the actual answer, streamed.
            payload = {
                "model": MODEL,
                "messages": chat_history,
                "temperature": 0,
                "stream": True
            }
            response = requests.post(url, headers=HEADERS, json=payload)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500)

        def stream_generator():
            # Relay SSE chunks from the LLM while accumulating the full reply
            # so it can be saved and (for database questions) parsed for SQL.
            accumulated_content = ""
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data:'):
                        if decoded_line.strip() == "data: [DONE]":
                            break  # OpenAI-style end-of-stream marker
                        try:
                            # Skip the "data: " prefix (6 chars) before parsing.
                            data = json.loads(decoded_line[6:])
                            content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
                            if content:
                                accumulated_content += content
                                yield f"data: {content}\n\n"
                        except Exception as e:
                            yield f"data: [解析失败]: {str(e)}\n\n"
            print("accumulated_content", accumulated_content)
            save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
            if question_type == "database":
                sql = extract_sql_code(accumulated_content)
                if sql:
                    try:
                        conn = connect_db()
                        if is_safe_sql(sql):
                            result = execute_sql(conn, sql)
                            save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
                            yield f"data: SQL执行结果: {result}\n\n"
                        else:
                            # Refuse anything is_safe_sql rejects (non-SELECT).
                            yield f"data: 拒绝执行非查询类 SQL{sql}\n\n"
                    except Exception as e:
                        yield f"data: SQL执行失败: {str(e)}\n\n"
                    finally:
                        # NOTE(review): if connect_db() itself raises, `conn`
                        # is unbound here and `if conn` raises NameError —
                        # initialize `conn = None` before the try. Confirm.
                        if conn:
                            conn.close()
            else:
                yield "data: \\n[文本结束]\n\n"
        return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
# A conversation is created first so the client obtains a session id
# before posting messages against it.
class ConversationViewSet(CustomModelViewSet):
    """CRUD endpoints for chat conversations."""

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    perms_map = {'get':'*', 'post':'*', 'put':'*'}
    ordering = ['create_time']