import re

import requests
import psycopg2
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.generics import get_object_or_404

from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from ichat.models import Conversation, Message
# Aliyun DashScope (OpenAI-compatible endpoint), qwen-plus
# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# MODEL = "qwen-plus"

# Locally deployed model
# API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
# API_BASE = "http://106.0.4.200:9000/v1"
# MODEL = "Qwen/Qwen2.5-14B-Instruct"

# Google Gemini via OpenRouter (currently active)
API_KEY = "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"
API_BASE = "https://openrouter.ai/api/v1"
MODEL = "google/gemini-2.0-flash-exp:free"

# DeepSeek V3 via OpenRouter
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL = "deepseek/deepseek-chat-v3-0324:free"

TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]  # Exposing the whole database to the model hurts accuracy, so only these tables are given


# Database connection
def connect_db():
    from server.conf import DATABASES

    db_conf = DATABASES['default']
    conn = psycopg2.connect(
        host=db_conf['HOST'],
        port=db_conf['PORT'],
        user=db_conf['USER'],
        password=db_conf['PASSWORD'],
        database=db_conf['NAME']
    )
    return conn
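
# Note: psycopg2 opens connections with autocommit disabled, so every statement runs in a
# transaction; after a failed statement the connection must be rolled back before it can be
# reused (see the retry path in QueryLLMview below).
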
def get_schema_text(conn, table_names: list):
    cur = conn.cursor()
    query = """
        SELECT
            table_name, column_name, data_type
        FROM
            information_schema.columns
        WHERE
            table_schema = 'public'
            AND table_name IN %s;
    """
    cur.execute(query, (tuple(table_names),))

    # Group columns by table
    schema = {}
    for table_name, column_name, data_type in cur.fetchall():
        if table_name not in schema:
            schema[table_name] = []
        schema[table_name].append(f"{column_name} ({data_type})")
    cur.close()

    # Render one line per table for the LLM prompt
    schema_text = ""
    for table_name, columns in schema.items():
        schema_text += f"Table {table_name} has columns: {', '.join(columns)}\n"
    return schema_text
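
# Illustrative output of get_schema_text (the column names shown here are hypothetical):
#   Table enm_mpoint has columns: id (integer), name (character varying), ...
#   Table enm_mpointstat has columns: id (integer), mpoint_id (integer), ...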


# Call the LLM to generate SQL
def call_llm_api(prompt, api_key=API_KEY, api_base=API_BASE, model=MODEL):
    url = f"{api_base}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
    }
    # A timeout is added so a stalled API call cannot hang the request indefinitely
    response = requests.post(url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    print("\nLLM response:\n", response.json())
    return response.json()["choices"][0]["message"]["content"]
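
# Illustrative shape of the JSON returned by an OpenAI-compatible chat/completions endpoint:
#   {"choices": [{"message": {"role": "assistant", "content": "SELECT ..."}}], ...}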


def execute_sql(conn, sql_query):
    cur = conn.cursor()
    try:
        cur.execute(sql_query)
        try:
            # Statements that return rows: build a list of dicts keyed by column name
            rows = cur.fetchall()
            columns = [desc[0] for desc in cur.description]
            result = [dict(zip(columns, row)) for row in rows]
        except psycopg2.ProgrammingError:
            # Statements with no result set: return the status string instead
            result = cur.statusmessage
    finally:
        cur.close()
    return result
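
# Note: because autocommit is off, write statements (INSERT/UPDATE/DELETE) executed here
# would also need a conn.commit() to persist their changes.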


def strip_sql_markdown(content: str) -> str:
    # Extract the SQL wrapped in a ```sql ... ``` (or plain ``` ... ```) fence
    match = re.search(r"```(?:sql)?\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    else:
        return None
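
# Illustrative behaviour:
#   strip_sql_markdown("```sql\nSELECT 1;\n```")   -> "SELECT 1;"
#   strip_sql_markdown("Sorry, no SQL here")       -> None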


class QueryLLMview(APIView):
    def post(self, request):
        serializer = MessageSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        prompt = serializer.validated_data['prompt']
        conn = connect_db()
        # Database table schema for the prompt
        schema_text = get_schema_text(conn, TABLES)
        user_prompt = f"""You are a professional database engineer. Given the following database schema:
{schema_text}
Generate one standard PostgreSQL SQL statement for my request. Return only the SQL, with no extra explanation.
The request is: {prompt}
"""
        llm_data = call_llm_api(user_prompt)
        # Check whether the model actually returned SQL; if not, pass its message through
        generated_sql = strip_sql_markdown(llm_data)
        if generated_sql:
            try:
                result = execute_sql(conn, generated_sql)
                return Response({"result": result})
            except Exception as e:
                print("\nFirst SQL execution failed, error:", str(e))
                # Roll back the aborted transaction so the connection can be reused for the retry
                conn.rollback()
                # If the first SQL fails, ask the model to regenerate it from the error message
                fix_prompt = f"""The SQL you just generated raised an error: {str(e)}
Correct your SQL based on this error and return only the new, correct SQL, with no explanation.
The database schema is:
{schema_text}
The user's request is: {prompt}
"""
                fixed_sql = call_llm_api(fix_prompt)
                fixed_sql = strip_sql_markdown(fixed_sql)
                try:
                    results = execute_sql(conn, fixed_sql)
                    print("\nResult of the corrected SQL:")
                    print(results)
                    return Response({"result": results})
                except Exception as e2:
                    print("\nThe corrected SQL still failed, error:", str(e2))
                    return Response({"error": "SQL execution failed", "detail": str(e2)}, status=400)
            finally:
                conn.close()
        else:
            # No SQL was returned; close the connection and return the model's message as-is
            conn.close()
            return Response({"result": llm_data})
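
# Illustrative request/response for QueryLLMview (the URL path is an assumption; the request
# fields depend on MessageSerializer, only "prompt" is used directly above):
#   POST /api/ichat/query/   {"prompt": "List the 10 most recent rows from enm_mplogx"}
#   -> {"result": [...]} on success, {"result": "<model message>"} when no SQL was returned,
#      or {"error": "SQL execution failed", "detail": "..."} if both attempts fail.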


# Create a conversation first; this yields the conversation's session_id
class ConversationView(APIView):
    def get(self, request):
        conversation = Conversation.objects.all()
        serializer = ConversationSerializer(conversation, many=True)
        return Response(serializer.data)

    def post(self, request):
        serializer = ConversationSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)

    def put(self, request, pk):
        conversation = get_object_or_404(Conversation, pk=pk)
        serializer = ConversationSerializer(conversation, data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)
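
# Possible URL wiring (a sketch; the actual paths and urls.py location are assumptions):
#   urlpatterns = [
#       path("llm/query/", QueryLLMview.as_view()),
#       path("conversations/", ConversationView.as_view()),
#       path("conversations/<int:pk>/", ConversationView.as_view()),
#   ]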