factory/apps/ichat/views.py

173 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import psycopg2
from rest_framework.views import APIView
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from rest_framework.response import Response
from ichat.models import Conversation, Message
from rest_framework.generics import get_object_or_404
import os

# LLM backend configuration (OpenAI-compatible ``/chat/completions`` endpoint).
#
# Each value can be overridden through an environment variable so that
# credentials and endpoints do not have to be edited in source control:
#   ICHAT_API_KEY, ICHAT_API_BASE, ICHAT_MODEL
#
# Backends that were previously configured here (supply the matching key via
# ICHAT_API_KEY when switching to one of them):
#   Aliyun DashScope        API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
#                           MODEL    = "qwen-plus"
#   Locally deployed model  API_BASE = "http://106.0.4.200:9000/v1"
#                           MODEL    = "Qwen/Qwen2.5-14B-Instruct"
#   OpenRouter DeepSeek V3  API_BASE = "https://openrouter.ai/api/v1"
#                           MODEL    = "deepseek/deepseek-chat-v3-0324:free"
#
# NOTE(security): the fallback key below is kept only so existing deployments
# keep working without extra configuration. It has been committed to the
# repository and should be rotated, then provided via ICHAT_API_KEY.
# Default backend: Google Gemini through OpenRouter.
API_KEY = os.environ.get(
    "ICHAT_API_KEY",
    "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",
)
API_BASE = os.environ.get("ICHAT_API_BASE", "https://openrouter.ai/api/v1")
MODEL = os.environ.get("ICHAT_MODEL", "google/gemini-2.0-flash-exp:free")

# Giving the model the whole database lowers its accuracy, so only these
# tables are described to it.
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]
# Database connection
def connect_db():
    """Open a new psycopg2 connection to the project's default database.

    The connection parameters are read from the ``'default'`` entry of
    ``server.conf.DATABASES``. The settings module is imported lazily, at
    call time.

    Returns:
        The connection object returned by ``psycopg2.connect``. The caller is
        responsible for closing it.
    """
    from server.conf import DATABASES

    settings = DATABASES['default']
    # psycopg2 keyword -> key of the settings entry that supplies its value.
    keyword_to_setting = {
        "host": "HOST",
        "port": "PORT",
        "user": "USER",
        "password": "PASSWORD",
        "database": "NAME",
    }
    connect_kwargs = {
        keyword: settings[setting]
        for keyword, setting in keyword_to_setting.items()
    }
    return psycopg2.connect(**connect_kwargs)
def get_schema_text(conn, table_names: list):
    """Describe the columns of the given tables as prompt-ready text.

    Queries ``information_schema.columns`` (schema ``public``) for the
    requested tables and renders one line per table listing its columns as
    ``name (data_type)``.

    Args:
        conn: An open DB-API connection (psycopg2 parameter style).
        table_names: Names of the tables to describe.

    Returns:
        One newline-terminated line per table that has columns, concatenated
        into a single string. Returns ``""`` when ``table_names`` is empty or
        none of the tables exist.
    """
    from contextlib import closing

    names = tuple(table_names)
    # psycopg2 renders an empty tuple as ``IN ()``, which PostgreSQL rejects
    # with a syntax error; there is nothing to describe anyway.
    if not names:
        return ""

    query = """
        SELECT
            table_name, column_name, data_type
        FROM
            information_schema.columns
        WHERE
            table_schema = 'public'
            and table_name in %s;
    """
    # closing() guarantees the cursor is released even if the query fails.
    with closing(conn.cursor()) as cur:
        cur.execute(query, (names, ))
        rows = cur.fetchall()

    schema = {}
    for table_name, column_name, data_type in rows:
        schema.setdefault(table_name, []).append(f"{column_name} ({data_type})")

    return "".join(
        f"{table_name} 包含列:{', '.join(columns)}\n"
        for table_name, columns in schema.items()
    )
# Call the LLM to generate SQL
def call_llm_api(prompt, api_key=API_KEY, api_base=API_BASE, model=MODEL, timeout=60):
    """Send a single-turn chat completion request and return the reply text.

    Posts ``prompt`` as the only user message to ``{api_base}/chat/completions``
    with ``temperature`` 0 and returns the content of the first choice.

    Args:
        prompt: Text sent as the user message.
        api_key: Bearer token placed in the ``Authorization`` header.
        api_base: Base URL of the chat-completions API.
        model: Model identifier sent in the payload.
        timeout: Seconds passed to ``requests.post`` as its ``timeout``.
            Without one, requests waits indefinitely on a stalled server.

    Returns:
        ``choices[0].message.content`` of the JSON response.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    url = f"{api_base}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
    }
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    # Decode the body once and reuse it for both the debug print and the result.
    data = response.json()
    print("\n大模型返回:\n", data)
    return data["choices"][0]["message"]["content"]
def execute_sql(conn, sql_query):
    """Run one SQL statement and return its rows or its status message.

    Args:
        conn: An open psycopg2 connection.
        sql_query: The SQL text to execute.

    Returns:
        A list of dicts (column name -> value), one per row, when the
        statement produces a result set; otherwise the cursor's
        ``statusmessage`` string.

    Raises:
        Exception: Whatever the driver raises for a failing statement. The
            transaction is rolled back before the error propagates so the
            connection can be used again.
    """
    from contextlib import closing

    # closing() releases the cursor on every path, including errors.
    with closing(conn.cursor()) as cur:
        try:
            cur.execute(sql_query)
        except Exception:
            # A failed statement leaves the PostgreSQL transaction aborted;
            # without a rollback every later statement on this connection
            # is rejected.
            conn.rollback()
            raise
        # ``description`` is None when the statement returned no result set.
        if cur.description is None:
            return cur.statusmessage
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
def strip_sql_markdown(content: str) -> "str | None":
    """Extract the SQL inside the first Markdown code fence of ``content``.

    Recognises a fence tagged ``sql`` (case-insensitive) as well as an
    untagged fence whose opening backticks are followed by whitespace.
    Fences tagged with another language are not treated as SQL.

    Args:
        content: The raw LLM reply; may be ``None`` or empty.

    Returns:
        The fenced text with surrounding whitespace stripped, or ``None``
        when ``content`` is empty or contains no matching fence.
    """
    import re

    # An empty or missing reply cannot contain a fence; re.search would
    # raise TypeError on None.
    if not content:
        return None
    # First alternative: a fence tagged "sql". Second: an untagged fence,
    # which must be followed by whitespace so that a tag such as "python"
    # is not mistaken for the start of a query.
    match = re.search(
        r"```(?:sql\s*|\s+)(.*?)```",
        content,
        re.DOTALL | re.IGNORECASE,
    )
    if match is None:
        return None
    return match.group(1).strip()
class QueryLLMview(APIView):
    """Answer a natural-language request by having the LLM write SQL and running it."""

    def post(self, request):
        """Generate SQL for the submitted prompt, execute it and return the rows.

        The request data is validated and saved through ``MessageSerializer``.
        The LLM is then given the schema of ``TABLES`` plus the prompt. If its
        reply contains no fenced SQL, the raw reply is returned. If the SQL
        fails, the LLM is asked once to repair it.

        Returns:
            ``{"result": ...}`` with the query result or the raw LLM reply, or
            ``{"error": ..., "detail": ...}`` with status 400 when the repaired
            SQL also fails or no repaired SQL was produced.
        """
        serializer = MessageSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        prompt = serializer.validated_data['prompt']
        conn = connect_db()
        # One try/finally around all work so the connection is closed on every
        # path, including the non-SQL reply and any exception.
        try:
            # Database table structure
            schema_text = get_schema_text(conn, TABLES)
            user_prompt = f"""你是可能是一个专业的数据库工程师,根据以下数据库结构:
{schema_text}
请根据我的需求生成一条标准的PostgreSQL SQL语句直接返回SQL不要额外解释。
需求是:{prompt}
"""
            llm_data = call_llm_api(user_prompt)
            # If the reply is not SQL, return it as a plain message.
            generated_sql = strip_sql_markdown(llm_data)
            if not generated_sql:
                return Response({"result": llm_data})

            # NOTE(security): this executes model-generated SQL derived from
            # user input. Nothing is committed here, but the database role
            # used by connect_db() should still be restricted to read access.
            try:
                return Response({"result": execute_sql(conn, generated_sql)})
            except Exception as e:
                first_error = str(e)
            print("\n第一次执行SQL报错了错误信息", first_error)

            # The failed statement aborted the transaction; reset it so the
            # corrected SQL is not rejected outright.
            conn.rollback()

            # Ask the model to repair the SQL. Each call is a fresh single-turn
            # chat, so the failing statement has to be included explicitly.
            fix_prompt = f"""刚才你生成的SQL出现了错误错误信息是{first_error}
出错的SQL是:
{generated_sql}
请根据这个错误修正你的SQL返回新的正确的SQL直接给出SQL不要解释。
数据库结构如下:
{schema_text}
用户需求是:{prompt}
"""
            fixed_sql = strip_sql_markdown(call_llm_api(fix_prompt))
            if not fixed_sql:
                # The model returned no fenced SQL; report the original error
                # rather than passing None to the driver.
                return Response({"error": "SQL执行失败", "detail": first_error}, status=400)
            try:
                results = execute_sql(conn, fixed_sql)
            except Exception as e2:
                print("\n修正后的SQL仍然报错错误信息", str(e2))
                return Response({"error": "SQL执行失败", "detail": str(e2)}, status=400)
            print("\n修正后的查询结果:")
            print(results)
            return Response({"result": results})
        finally:
            conn.close()
# Create a conversation first; this generates the conversation's session_id
class ConversationView(APIView):
    """List, create and update ``Conversation`` records."""

    @staticmethod
    def _validated_save(serializer):
        """Validate ``serializer``, save it and wrap its data in a Response."""
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)

    def get(self, request):
        """Return every conversation, serialized as a list."""
        all_conversations = Conversation.objects.all()
        listing = ConversationSerializer(all_conversations, many=True)
        return Response(listing.data)

    def post(self, request):
        """Create a conversation from the request data and return it."""
        return self._validated_save(ConversationSerializer(data=request.data))

    def put(self, request, pk):
        """Update the conversation identified by ``pk``; 404 if it does not exist."""
        target = get_object_or_404(Conversation, pk=pk)
        return self._validated_save(
            ConversationSerializer(target, data=request.data)
        )