import logging
import os
import re
from collections import defaultdict
from typing import Optional

import requests
import psycopg2
from rest_framework.views import APIView
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from rest_framework.response import Response
from ichat.models import Conversation, Message
from rest_framework.generics import get_object_or_404

logger = logging.getLogger(__name__)

# LLM endpoint configuration (an OpenAI-compatible /chat/completions API).
# The API key is read from the environment instead of being committed to
# source control; any key that was previously hard-coded in this file must be
# treated as leaked and revoked.
# Other endpoints used during development (select them via the env vars below):
#   DashScope:      https://dashscope.aliyuncs.com/compatible-mode/v1   model: qwen-plus
#   Self-hosted:    http://106.0.4.200:9000/v1                          model: Qwen/Qwen2.5-14B-Instruct
#   OpenRouter:     https://openrouter.ai/api/v1                        model: deepseek/deepseek-chat-v3-0324:free
API_KEY = os.environ.get("LLM_API_KEY", "")
API_BASE = os.environ.get("LLM_API_BASE", "https://openrouter.ai/api/v1")
MODEL = os.environ.get("LLM_MODEL", "google/gemini-2.0-flash-exp:free")

# Only these tables are described to the model: handing it the whole database
# schema lowers the accuracy of the generated SQL.
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]

# Matches a ```sql fenced block (case-insensitive) and captures its body.
_SQL_FENCE_RE = re.compile(r"```sql\s*(.*?)```", re.DOTALL | re.IGNORECASE)
# Matches any fenced block, capturing (info string, body); used to find
# fences that carry no language tag at all.
_FENCED_BLOCK_RE = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL)


def connect_db():
    """Open a new psycopg2 connection from ``server.conf.DATABASES['default']``.

    The caller owns the returned connection and is responsible for closing it.
    """
    from server.conf import DATABASES
    db_conf = DATABASES['default']
    return psycopg2.connect(
        host=db_conf['HOST'],
        port=db_conf['PORT'],
        user=db_conf['USER'],
        password=db_conf['PASSWORD'],
        database=db_conf['NAME'],
    )


def get_schema_text(conn, table_names: list) -> str:
    """Describe the columns of ``table_names`` (schema ``public``) as prompt text.

    Produces one line per table, ``表<name> 包含列:<col> (<type>), ...``, with
    columns in their declared order. Tables that do not exist are omitted.

    Args:
        conn: An open psycopg2 connection.
        table_names: Names of the tables to describe.

    Returns:
        The schema description, or ``""`` when ``table_names`` is empty
        (an empty ``IN ()`` list would be a SQL syntax error).
    """
    if not table_names:
        return ""
    query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name IN %s
        ORDER BY table_name, ordinal_position;
    """
    schema = defaultdict(list)
    with conn.cursor() as cur:
        cur.execute(query, (tuple(table_names),))
        for table_name, column_name, data_type in cur.fetchall():
            schema[table_name].append(f"{column_name} ({data_type})")
    return "".join(
        f"表{table_name} 包含列:{', '.join(columns)}\n"
        for table_name, columns in schema.items()
    )


def call_llm_api(prompt, api_key=API_KEY, api_base=API_BASE, model=MODEL, timeout=120):
    """Send ``prompt`` as a single user message and return the model's reply text.

    Args:
        prompt: The user message content.
        api_key: Bearer token for the endpoint.
        api_base: Base URL; ``/chat/completions`` is appended.
        model: Model identifier sent in the payload.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        ``choices[0].message.content`` of the response.

    Raises:
        requests.HTTPError: The endpoint answered with a non-2xx status.
        requests.Timeout: No response within ``timeout`` seconds.
    """
    url = f"{api_base}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        # Deterministic output: we want the same SQL for the same question.
        "temperature": 0,
    }
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    data = response.json()
    logger.debug("大模型返回: %s", data)
    return data["choices"][0]["message"]["content"]


def execute_sql(conn, sql_query):
    """Execute ``sql_query`` on ``conn`` and return its result.

    Returns:
        A list of ``{column_name: value}`` dicts for statements that return
        rows; otherwise the cursor's status message (e.g. ``"UPDATE 3"``).

    Raises:
        psycopg2.Error: The statement failed. The cursor is still closed, but
            the connection's transaction is left aborted until the caller
            rolls back.
    """
    with conn.cursor() as cur:
        cur.execute(sql_query)
        # ``description`` is None exactly when the statement produced no rows
        # to fetch (the case in which ``fetchall`` would raise).
        if cur.description is None:
            return cur.statusmessage
        columns = [col[0] for col in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]


def strip_sql_markdown(content: str) -> Optional[str]:
    """Extract SQL from a fenced code block in an LLM reply.

    A ```sql block is preferred; failing that, the first fenced block with no
    language tag is used. Blocks tagged with another language are ignored.

    Returns:
        The stripped block body, or ``None`` if ``content`` is empty or has
        no suitable fenced block (i.e. the model answered in prose).
    """
    if not content:
        return None
    match = _SQL_FENCE_RE.search(content)
    if match:
        return match.group(1).strip()
    for block in _FENCED_BLOCK_RE.finditer(content):
        if not block.group(1).strip():
            return block.group(2).strip()
    return None


def _build_query_prompt(schema_text, prompt):
    """Build the prompt asking the model for one PostgreSQL statement for ``prompt``."""
    return (
        "你是一个专业的数据库工程师,根据以下数据库结构:\n"
        f"{schema_text}\n"
        "请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。\n"
        f"需求是:{prompt}\n"
    )


def _build_fix_prompt(schema_text, prompt, failed_sql, error):
    """Build the prompt asking the model to repair ``failed_sql`` given ``error``.

    Each LLM call is stateless, so the failing SQL must be included explicitly
    for the model to know what it is correcting.
    """
    return (
        f"刚才你生成的SQL出现了错误,错误信息是:{error}\n"
        f"出错的SQL是:\n{failed_sql}\n"
        "请根据这个错误修正你的SQL,返回新的正确的SQL,直接给出SQL,不要解释。\n"
        "数据库结构如下:\n"
        f"{schema_text}\n"
        f"用户需求是:{prompt}\n"
    )


class QueryLLMview(APIView):
    """Answer a natural-language question by having an LLM write SQL and running it.

    The POST body is validated and saved through ``MessageSerializer``; its
    ``prompt`` field is the question. Responses:

    * ``{"result": <rows or status>}`` when the (possibly corrected) SQL ran;
    * ``{"result": <reply text>}`` when the model answered without SQL;
    * ``{"error": ..., "detail": ...}`` with status 400 when the SQL still
      failed after one correction attempt.
    """

    def post(self, request):
        serializer = MessageSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        prompt = serializer.validated_data['prompt']

        conn = connect_db()
        try:
            # The SQL is produced by an LLM steered by user input, so it is
            # untrusted: a read-only session makes any write statement fail.
            conn.set_session(readonly=True)
            schema_text = get_schema_text(conn, TABLES)
            llm_data = call_llm_api(_build_query_prompt(schema_text, prompt))

            # No fenced SQL block means the model answered in prose; relay it.
            generated_sql = strip_sql_markdown(llm_data)
            if not generated_sql:
                return Response({"result": llm_data})

            try:
                result = execute_sql(conn, generated_sql)
            except psycopg2.Error as e:
                logger.warning("第一次执行SQL报错了,错误信息: %s", e)
                # A failed statement leaves the transaction aborted; without a
                # rollback every following statement on this connection fails.
                conn.rollback()
                return self._retry_with_fixed_sql(conn, schema_text, prompt, generated_sql, e)
            return Response({"result": result})
        finally:
            conn.close()

    def _retry_with_fixed_sql(self, conn, schema_text, prompt, failed_sql, error):
        """Ask the model to repair ``failed_sql`` and execute the result once."""
        reply = call_llm_api(_build_fix_prompt(schema_text, prompt, failed_sql, error))
        # The fix prompt explicitly asks for bare SQL, so an unfenced reply is
        # taken as the SQL itself rather than passing None to the driver.
        fixed_sql = strip_sql_markdown(reply) or (reply or "").strip()
        try:
            results = execute_sql(conn, fixed_sql)
        except psycopg2.Error as e2:
            logger.warning("修正后的SQL仍然报错,错误信息: %s", e2)
            return Response({"error": "SQL执行失败", "detail": str(e2)}, status=400)
        logger.debug("修正后的查询结果: %s", results)
        return Response({"result": results})


# Create a conversation first; this generates the conversation session_id.
class ConversationView(APIView):
    """List, create and update ``Conversation`` records."""

    def get(self, request):
        """Return every conversation."""
        conversation = Conversation.objects.all()
        serializer = ConversationSerializer(conversation, many=True)
        return Response(serializer.data)

    def post(self, request):
        """Create a conversation from the request body and return it."""
        serializer = ConversationSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)

    def put(self, request, pk):
        """Update conversation ``pk`` from the request body (404 if it is missing)."""
        conversation = get_object_or_404(Conversation, pk=pk)
        serializer = ConversationSerializer(conversation, data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)