diff --git a/apps/ichat/models.py b/apps/ichat/models.py index daf6944c..d83065dd 100644 --- a/apps/ichat/models.py +++ b/apps/ichat/models.py @@ -14,4 +14,4 @@ class Message(BaseModel): """ conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, verbose_name='对话') content = models.TextField(verbose_name='消息内容') - role = models.CharField("角色", max_length=10, help_text="system/user") + role = models.CharField("角色", max_length=10, default='user', help_text="system/user") diff --git a/apps/ichat/serializers.py b/apps/ichat/serializers.py index 1e7a6b33..c5c23831 100644 --- a/apps/ichat/serializers.py +++ b/apps/ichat/serializers.py @@ -1,4 +1,18 @@ from rest_framework import serializers +from .models import Conversation, Message +from apps.utils.constants import EXCLUDE_FIELDS -class CustomLLMrequestSerializer(serializers.Serializer): - prompt = serializers.CharField() \ No newline at end of file + +class MessageSerializer(serializers.ModelSerializer): + class Meta: + model = Message + fields = ['id', 'conversation', 'mode', 'content', 'role'] + read_only_fields = EXCLUDE_FIELDS + + +class ConversationSerializer(serializers.ModelSerializer): + messages = MessageSerializer(many=True, read_only=True) + class Meta: + model = Conversation + fields = ['id', 'title', 'messages'] + read_only_fields = EXCLUDE_FIELDS \ No newline at end of file diff --git a/apps/ichat/urls.py b/apps/ichat/urls.py index 3fdc93c1..22d7553d 100644 --- a/apps/ichat/urls.py +++ b/apps/ichat/urls.py @@ -1,8 +1,10 @@ from django.urls import path -from apps.ichat.views import QueryLLMview +from apps.ichat.views import QueryLLMview, ConversationView -API_BASE_URL = 'api/llm/ichat/' +API_BASE_URL = 'api/ichat/' urlpatterns = [ path(API_BASE_URL + 'query/', QueryLLMview.as_view(), name='llm_query'), + path(API_BASE_URL + 'conversation/', ConversationView.as_view(), name='conversation') + ] \ No newline at end of file diff --git a/apps/ichat/view_bak.py b/apps/ichat/view_bak.py new file mode 100644 index 00000000..4a5570c7 --- /dev/null +++ b/apps/ichat/view_bak.py @@ -0,0 +1,87 @@ +import requests +from langchain_core.language_models import LLM +from langchain_core.outputs import LLMResult, Generation +from langchain_experimental.sql import SQLDatabaseChain +from langchain_community.utilities import SQLDatabase +from server.conf import DATABASES +from apps.ichat.serializers import CustomLLMrequestSerializer +from rest_framework.views import APIView +from urllib.parse import quote_plus +from rest_framework.response import Response + + +db_conf = DATABASES['default'] +# 密码需要 URL 编码(因为有特殊字符如 @) +password_encodeed = quote_plus(db_conf['PASSWORD']) + +db = SQLDatabase.from_uri(f"postgresql+psycopg2://{db_conf['USER']}:{password_encodeed}@{db_conf['HOST']}/{db_conf['NAME']}", include_tables=["enm_mpoint", "enm_mpointstat"]) +# model_url = "http://14.22.88.72:11025/v1/chat/completions" +model_url = "http://139.159.180.64:11434/v1/chat/completions" + +class CustomLLM(LLM): + model_url: str + mode: str = 'chat' + def _call(self, prompt: str, stop: list = None) -> str: + data = { + "model":"glm4", + "messages": self.build_message(prompt), + "stream": False, + } + response = requests.post(self.model_url, json=data, timeout=600) + response.raise_for_status() + content = response.json()["choices"][0]["message"]["content"] + print('content---', content) + clean_sql = self.strip_sql_markdown(content) if self.mode == 'sql' else content.strip() + return clean_sql + + def _generate(self, prompts: list, stop: list = None) -> LLMResult: + generations = [] + for prompt in prompts: + text = self._call(prompt, stop) + generations.append([Generation(text=text)]) + return LLMResult(generations=generations) + + def strip_sql_markdown(self, content: str) -> str: + import re + # 去掉包裹在 ```sql 或 ``` 中的内容 + match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE) + if match: + return match.group(1).strip() + else: + return content.strip() + + def build_message(self, prompt: str) -> list: + if self.mode == 'sql': + system_prompt = ( + "你是一个 SQL 助手,严格遵循以下规则:\n" + "1. 只返回 PostgreSQL 语法 SQL 语句。\n" + "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n" + "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n" + "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。" + ) + else: + system_prompt = "你是一个聊天助手,请根据用户的问题,提供简洁明了的答案。" + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] + + @property + def _llm_type(self) -> str: + return "custom_llm" + + +class QueryLLMview(APIView): + def post(self, request): + serializer = CustomLLMrequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + prompt = serializer.validated_data['prompt'] + mode = serializer.validated_data.get('mode', 'chat') + llm = CustomLLM(model_url=model_url, mode=mode) + print('prompt---', prompt, mode) + if mode == 'sql': + chain = SQLDatabaseChain.from_llm(llm, db, verbose=True) + result = chain.invoke(prompt) + else: + result = llm._call(prompt) + return Response({"result": result}) \ No newline at end of file diff --git a/apps/ichat/views.py b/apps/ichat/views.py index 61ed583e..473d488a 100644 --- a/apps/ichat/views.py +++ b/apps/ichat/views.py @@ -1,76 +1,173 @@ import requests -from langchain_core.language_models import LLM -from langchain_core.outputs import LLMResult, Generation -from langchain_experimental.sql import SQLDatabaseChain -from langchain_community.utilities import SQLDatabase -from server.conf import DATABASES -from apps.ichat.serializers import CustomLLMrequestSerializer +import psycopg2 from rest_framework.views import APIView -from urllib.parse import quote_plus +from apps.ichat.serializers import MessageSerializer, ConversationSerializer from rest_framework.response import Response +from ichat.models import Conversation, Message +from rest_framework.generics import get_object_or_404 +#本地部署模型 +# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693" +# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" +# MODEL = "qwen-plus" + +#本地部署的模式 +# API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw" +# API_BASE = "http://106.0.4.200:9000/v1" +# MODEL = "Qwen/Qwen2.5-14B-Instruct" + +# google gemini +API_KEY = "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" +API_BASE = "https://openrouter.ai/api/v1" +MODEL="google/gemini-2.0-flash-exp:free" + +# deepseek v3 +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="deepseek/deepseek-chat-v3-0324:free" -db_conf = DATABASES['default'] -# 密码需要 URL 编码(因为有特殊字符如 @) -password_encodeed = quote_plus(db_conf['PASSWORD']) +TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表 +# 数据库连接 +def connect_db(): + from server.conf import DATABASES + db_conf = DATABASES['default'] + conn = psycopg2.connect( + host=db_conf['HOST'], + port=db_conf['PORT'], + user=db_conf['USER'], + password=db_conf['PASSWORD'], + database=db_conf['NAME'] + ) + return conn -db = SQLDatabase.from_uri(f"postgresql+psycopg2://{db_conf['USER']}:{password_encodeed}@{db_conf['HOST']}/{db_conf['NAME']}", include_tables=["enm_mpoint", "enm_mpointstat"]) -# model_url = "http://14.22.88.72:11025/v1/chat/completions" -model_url = "http://139.159.180.64:11434/v1/chat/completions" +def get_schema_text(conn, table_names:list): + cur = conn.cursor() + query = """ + SELECT + table_name, column_name, data_type + FROM + information_schema.columns + WHERE + table_schema = 'public' + and table_name in %s; + """ + cur.execute(query, (tuple(table_names), )) -class CustomLLM(LLM): - model_url: str - def _call(self, prompt: str, stop: list = None) -> str: - data = { - "model": "glm4", - "messages": [ - { - "role": "system", - "content": "你是一个 SQL 助手,严格遵循以下规则:\n" - "1. 只返回 PostgreSQL 语法 SQL 语句。\n" - "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n" - "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。" - "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。" - }, - {"role": "user", "content": prompt} - ], - "stream": False - } - response = requests.post(self.model_url, json=data, timeout=600) - response.raise_for_status() - content = response.json()["choices"][0]["message"]["content"] - clean_sql = self.strip_sql_markdown(content) - return clean_sql + schema = {} + for table_name, column_name, data_type in cur.fetchall(): + if table_name not in schema: + schema[table_name] = [] + schema[table_name].append(f"{column_name} ({data_type})") + cur.close() + schema_text = "" + for table_name, columns in schema.items(): + schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n" + return schema_text - def _generate(self, prompts: list, stop: list = None) -> LLMResult: - generations = [] - for prompt in prompts: - text = self._call(prompt, stop) - generations.append([Generation(text=text)]) - return LLMResult(generations=generations) - - def strip_sql_markdown(self, content: str) -> str: + +# 调用大模型生成sql +def call_llm_api(prompt, api_key=API_KEY, api_base=API_BASE, model=MODEL): + url = f"{api_base}/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + payload = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0, + } + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + print("\n大模型返回:\n", response.json()) + return response.json()["choices"][0]["message"]["content"] + + +def execute_sql(conn, sql_query): + cur = conn.cursor() + cur.execute(sql_query) + try: + rows = cur.fetchall() + columns = [desc[0] for desc in cur.description] + result = [dict(zip(columns, row)) for row in rows] + except psycopg2.ProgrammingError: + result = cur.statusmessage + cur.close() + return result + + +def strip_sql_markdown(content: str) -> str: import re # 去掉包裹在 ```sql 或 ``` 中的内容 match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE) if match: return match.group(1).strip() - match = re.search(r"```\s*(.*?)```", content, re.DOTALL) - if match: - return match.group(1).strip() - return content.strip() - - @property - def _llm_type(self) -> str: - return "custom_llm" - + else: + return None + class QueryLLMview(APIView): def post(self, request): - serializer = CustomLLMrequestSerializer(data=request.data) + serializer = MessageSerializer(data=request.data) serializer.is_valid(raise_exception=True) + serializer.save() prompt = serializer.validated_data['prompt'] - llm = CustomLLM(model_url=model_url) - chain = SQLDatabaseChain.from_llm(llm, db, verbose=True) - result = chain.invoke(prompt) - return Response({"result": result}) \ No newline at end of file + conn = connect_db() + # 数据库表结构 + schema_text = get_schema_text(conn, TABLES) + user_prompt = f"""你是可能是一个专业的数据库工程师,根据以下数据库结构: +{schema_text} +请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。 +需求是:{prompt} +""" + llm_data = call_llm_api(user_prompt) + # 判断是否生成的是sql 如果不是直接返回message + generated_sql = strip_sql_markdown(llm_data) + if generated_sql: + try: + result = execute_sql(conn, generated_sql) + return Response({"result": result}) + except Exception as e: + print("\n第一次执行SQL报错了,错误信息:", str(e)) + # 如果第一次执行SQL报错,则重新生成SQL + fix_prompt = f"""刚才你生成的SQL出现了错误,错误信息是:{str(e)} + 请根据这个错误修正你的SQL,返回新的正确的SQL,直接给出SQL,不要解释。 + 数据库结构如下: + {schema_text} + 用户需求是:{prompt} + """ + fixed_sql = call_llm_api(fix_prompt) + fixed_sql = strip_sql_markdown(fixed_sql) + try: + results = execute_sql(conn, fixed_sql) + print("\n修正后的查询结果:") + print(results) + return Response({"result": results}) + except Exception as e2: + print("\n修正后的SQL仍然报错,错误信息:", str(e2)) + return Response({"error": "SQL执行失败", "detail": str(e2)}, status=400) + finally: + conn.close() + else: + return Response({"result": llm_data}) + + +# 先新建对话 生成对话session_id +class ConversationView(APIView): + def get(self, request): + conversation = Conversation.objects.all() + serializer = ConversationSerializer(conversation, many=True) + return Response(serializer.data) + + def post(self, request): + serializer = ConversationSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(serializer.data) + + def put(self, request, pk): + conversation = get_object_or_404(Conversation, pk=pk) + serializer = ConversationSerializer(conversation, data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(serializer.data) \ No newline at end of file