From f3ab4476a4c2ee120ae73ce4201814208a2fa0a6 Mon Sep 17 00:00:00 2001
From: zty
Date: Mon, 12 May 2025 09:54:13 +0800
Subject: [PATCH] feat: ichat / add ichat module and improve conversation
 stream handling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/ichat/models.py      |   2 +-
 apps/ichat/script.py      |  22 ++++
 apps/ichat/serializers.py |   2 +-
 apps/ichat/urls.py        |  16 ++-
 apps/ichat/utils.py       |  88 +++++++++++++
 apps/ichat/views.py       | 268 ++++++++++++++++++--------------------
 apps/utils/viewsets.py    |   4 +
 server/settings.py        |   1 +
 server/urls.py            |   2 +-
 9 files changed, 253 insertions(+), 152 deletions(-)
 create mode 100644 apps/ichat/script.py
 create mode 100644 apps/ichat/utils.py

diff --git a/apps/ichat/models.py b/apps/ichat/models.py
index d83065dd..e8e79fad 100644
--- a/apps/ichat/models.py
+++ b/apps/ichat/models.py
@@ -12,6 +12,6 @@ class Message(BaseModel):
     """
     TN: 消息
     """
-    conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, verbose_name='对话')
+    conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages', verbose_name='对话')
     content = models.TextField(verbose_name='消息内容')
     role = models.CharField("角色", max_length=10, default='user', help_text="system/user")
diff --git a/apps/ichat/script.py b/apps/ichat/script.py
new file mode 100644
index 00000000..705ce0d0
--- /dev/null
+++ b/apps/ichat/script.py
@@ -0,0 +1,22 @@
+import json
+from .models import Message
+from django.http import StreamingHttpResponse
+
+def stream_generator(stream_response, conversation_id: str):
+    full_content = ''
+    for chunk in stream_response.iter_content(chunk_size=1024):
+        if chunk:
+            full_content += chunk.decode('utf-8')
+            try:
+                data = json.loads(full_content)
+                content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
+                Message.objects.create(
+                    conversation_id=conversation_id,
+                    content=content
+                )
+                yield f"data: {content}\n\n"
+                full_content = ''
+            except json.JSONDecodeError:
+                continue
+
+# Callers wrap this generator themselves: StreamingHttpResponse(stream_generator(resp, conversation_id), content_type='text/event-stream')
diff --git a/apps/ichat/serializers.py b/apps/ichat/serializers.py
index c5c23831..43102545 100644
--- a/apps/ichat/serializers.py
+++ b/apps/ichat/serializers.py
@@ -6,7 +6,7 @@ from apps.utils.constants import EXCLUDE_FIELDS
 class MessageSerializer(serializers.ModelSerializer):
     class Meta:
         model = Message
-        fields = ['id', 'conversation', 'mode', 'content', 'role']
+        fields = ['id', 'conversation', 'content', 'role']
         read_only_fields = EXCLUDE_FIELDS
 
 
diff --git a/apps/ichat/urls.py b/apps/ichat/urls.py
index 22d7553d..fcc50def 100644
--- a/apps/ichat/urls.py
+++ b/apps/ichat/urls.py
@@ -1,10 +1,14 @@
-from django.urls import path
-from apps.ichat.views import QueryLLMview, ConversationView
+from django.urls import path, include
+from rest_framework.routers import DefaultRouter
+from apps.ichat.views import QueryLLMviewSet, ConversationViewSet
 
 API_BASE_URL = 'api/ichat/'
+
+router = DefaultRouter()
+
+router.register('conversation', ConversationViewSet, basename='conversation')
+router.register('message', QueryLLMviewSet, basename='message')
 
 urlpatterns = [
-    path(API_BASE_URL + 'query/', QueryLLMview.as_view(), name='llm_query'),
-    path(API_BASE_URL + 'conversation/', ConversationView.as_view(), name='conversation')
-
-]
\ No newline at end of file
+    path(API_BASE_URL, include(router.urls)),
+]
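
Note: with the DefaultRouter registration above, everything is served under API_BASE_URL, so the detail=False action "completion" added in apps/ichat/views.py further down resolves to api/ichat/message/completion/. A minimal client sketch for consuming the SSE stream it returns; the host/port and the conversation id are illustrative assumptions, not part of this patch:

    import requests

    resp = requests.post(
        "http://localhost:8000/api/ichat/message/completion/",  # assumed local dev server
        json={"conversation": 1, "content": "统计enm_mpoint表共有多少条记录"},
        stream=True,  # the view answers with a text/event-stream StreamingHttpResponse
    )
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data: "):
            print(line[6:], flush=True)
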
diff --git a/apps/ichat/utils.py b/apps/ichat/utils.py
new file mode 100644
index 00000000..e055a672
--- /dev/null
+++ b/apps/ichat/utils.py
@@ -0,0 +1,88 @@
+import re
+import psycopg2
+import threading
+from django.db import transaction
+from .models import Message
+
+# Database connection
+def connect_db():
+    from server.conf import DATABASES
+    db_conf = DATABASES['default']
+    conn = psycopg2.connect(
+        host=db_conf['HOST'],
+        port=db_conf['PORT'],
+        user=db_conf['USER'],
+        password=db_conf['PASSWORD'],
+        database=db_conf['NAME']
+    )
+    return conn
+
+def extract_sql_code(text):
+    # Prefer a statement wrapped in ```sql fences
+    match = re.search(r"```sql\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+
+    # Fallback: look for the first SELECT statement
+    match = re.search(r"(SELECT\s.+?;)", text, re.IGNORECASE | re.DOTALL)
+    if match:
+        return match.group(1).strip()
+
+    return None
+
+
+def get_schema_text(conn, table_names:list):
+    cur = conn.cursor()
+    query = """
+    SELECT
+        table_name, column_name, data_type
+    FROM
+        information_schema.columns
+    WHERE
+        table_schema = 'public'
+        and table_name in %s;
+    """
+    cur.execute(query, (tuple(table_names), ))
+
+    schema = {}
+    for table_name, column_name, data_type in cur.fetchall():
+        if table_name not in schema:
+            schema[table_name] = []
+        schema[table_name].append(f"{column_name} ({data_type})")
+    cur.close()
+    schema_text = ""
+    for table_name, columns in schema.items():
+        schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n"
+    return schema_text
+
+
+def is_safe_sql(sql:str) -> bool:
+    sql = sql.strip().lower()
+    return (sql.startswith("select") or sql.startswith("show")) and not re.search(r"delete|update|insert|drop|create|alter", sql)
+
+def execute_sql(conn, sql_query):
+    cur = conn.cursor()
+    cur.execute(sql_query)
+    try:
+        rows = cur.fetchall()
+        columns = [desc[0] for desc in cur.description]
+        result = [dict(zip(columns, row)) for row in rows]
+    except psycopg2.ProgrammingError:
+        result = cur.statusmessage
+    cur.close()
+    return result
+
+def strip_sql_markdown(content: str) -> str:
+    # Strip content wrapped in ```sql or ``` fences
+    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+    else:
+        return None
+
+# Wrapper that performs the ORM write in a separate thread
+def save_message_thread_safe(**kwargs):
+    def _save():
+        with transaction.atomic():
+            Message.objects.create(**kwargs)
+    threading.Thread(target=_save).start()
diff --git a/apps/ichat/views.py b/apps/ichat/views.py
index 473d488a..99f97980 100644
--- a/apps/ichat/views.py
+++ b/apps/ichat/views.py
@@ -1,173 +1,155 @@
 import requests
-import psycopg2
+import json
 from rest_framework.views import APIView
 from apps.ichat.serializers import MessageSerializer, ConversationSerializer
 from rest_framework.response import Response
-from ichat.models import Conversation, Message
-from rest_framework.generics import get_object_or_404
-#本地部署模型
+from apps.ichat.models import Conversation, Message
+from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe
+from django.http import StreamingHttpResponse, JsonResponse
+from rest_framework.decorators import action
+from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
+
 # API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
 # API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
 # MODEL = "qwen-plus"
 #本地部署的模式
-# API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
-# API_BASE = "http://106.0.4.200:9000/v1"
-# MODEL = "Qwen/Qwen2.5-14B-Instruct"
+API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
+API_BASE = "http://106.0.4.200:9000/v1"
+MODEL = "qwen14b"
 # google gemini
-API_KEY = "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"
-API_BASE = "https://openrouter.ai/api/v1"
-MODEL="google/gemini-2.0-flash-exp:free"
+# API_KEY = "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"
+# API_BASE = "https://openrouter.ai/api/v1"
+# MODEL="google/gemini-2.0-flash-exp:free"
 # deepseek v3
 # API_KEY = "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"
 # API_BASE = "https://openrouter.ai/api/v1"
 # MODEL="deepseek/deepseek-chat-v3-0324:free"
-
-
 TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表
-# 数据库连接
-def connect_db():
-    from server.conf import DATABASES
-    db_conf = DATABASES['default']
-    conn = psycopg2.connect(
-        host=db_conf['HOST'],
-        port=db_conf['PORT'],
-        user=db_conf['USER'],
-        password=db_conf['PASSWORD'],
-        database=db_conf['NAME']
-    )
-    return conn
-
-def get_schema_text(conn, table_names:list):
-    cur = conn.cursor()
-    query = """
-    SELECT
-        table_name, column_name, data_type
-    FROM
-        information_schema.columns
-    WHERE
-        table_schema = 'public'
-        and table_name in %s;
-    """
-    cur.execute(query, (tuple(table_names), ))
-
-    schema = {}
-    for table_name, column_name, data_type in cur.fetchall():
-        if table_name not in schema:
-            schema[table_name] = []
-        schema[table_name].append(f"{column_name} ({data_type})")
-    cur.close()
-    schema_text = ""
-    for table_name, columns in schema.items():
-        schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n"
-    return schema_text
-# 调用大模型生成sql
-def call_llm_api(prompt, api_key=API_KEY, api_base=API_BASE, model=MODEL):
-    url = f"{api_base}/chat/completions"
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}"
-    }
-    payload = {
-        "model": model,
-        "messages": [{"role": "user", "content": prompt}],
-        "temperature": 0,
-    }
-    response = requests.post(url, headers=headers, json=payload)
-    response.raise_for_status()
-    print("\n大模型返回:\n", response.json())
-    return response.json()["choices"][0]["message"]["content"]
+class QueryLLMviewSet(CustomModelViewSet):
+    queryset = Message.objects.all()
+    serializer_class = MessageSerializer
+    ordering = ['create_time']
+    perms_map = {'get':'*', 'post':'*', 'put':'*'}
 
-
-def execute_sql(conn, sql_query):
-    cur = conn.cursor()
-    cur.execute(sql_query)
-    try:
-        rows = cur.fetchall()
-        columns = [desc[0] for desc in cur.description]
-        result = [dict(zip(columns, row)) for row in rows]
-    except psycopg2.ProgrammingError:
-        result = cur.statusmessage
-    cur.close()
-    return result
-
-
-def strip_sql_markdown(content: str) -> str:
-    import re
-    # 去掉包裹在 ```sql 或 ``` 中的内容
-    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
-    if match:
-        return match.group(1).strip()
-    else:
-        return None
-
-
-class QueryLLMview(APIView):
-    def post(self, request):
-        serializer = MessageSerializer(data=request.data)
+    @action(methods=['post'], detail=False, perms_map={'post':'*'}, serializer_class=MessageSerializer)
+    def completion(self, request):
+        serializer = self.get_serializer(data=request.data)
         serializer.is_valid(raise_exception=True)
         serializer.save()
-        prompt = serializer.validated_data['prompt']
-        conn = connect_db()
-        # 数据库表结构
-        schema_text = get_schema_text(conn, TABLES)
-        user_prompt = f"""你是可能是一个专业的数据库工程师,根据以下数据库结构:
+        prompt = serializer.validated_data['content']
+        conversation = serializer.validated_data['conversation']
+        if not prompt or not conversation:
+            return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
+        save_message_thread_safe(content=prompt, conversation=conversation, role="user")
+        url = f"{API_BASE}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {API_KEY}"
+        }
+
+        user_prompt = f"""
+请判断以下问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。
+
+问题: {prompt}
+
+只需回答"database"或"general",不要有其他内容。
+"""
+        _payload = {
+            "model": MODEL,
+            "messages": [{"role": "user", "content": user_prompt}],
+            "temperature": 0,
+            "max_tokens": 10
+        }
+        try:
+            class_response = requests.post(url, headers=headers, json=_payload)
+            class_response.raise_for_status()
+            class_result = class_response.json()
+            question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
+            if question_type == "database":
+                conn = connect_db()
+                schema_text = get_schema_text(conn, TABLES)
+                user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
 {schema_text}
 请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。
 需求是:{prompt}
-"""
-        llm_data = call_llm_api(user_prompt)
-        # 判断是否生成的是sql 如果不是直接返回message
-        generated_sql = strip_sql_markdown(llm_data)
-        if generated_sql:
-            try:
-                result = execute_sql(conn, generated_sql)
-                return Response({"result": result})
-            except Exception as e:
-                print("\n第一次执行SQL报错了,错误信息:", str(e))
-                # 如果第一次执行SQL报错,则重新生成SQL
-                fix_prompt = f"""刚才你生成的SQL出现了错误,错误信息是:{str(e)}
-                请根据这个错误修正你的SQL,返回新的正确的SQL,直接给出SQL,不要解释。
-                数据库结构如下:
-                {schema_text}
-                用户需求是:{prompt}
-                """
-                fixed_sql = call_llm_api(fix_prompt)
-                fixed_sql = strip_sql_markdown(fixed_sql)
-                try:
-                    results = execute_sql(conn, fixed_sql)
-                    print("\n修正后的查询结果:")
-                    print(results)
-                    return Response({"result": results})
-                except Exception as e2:
-                    print("\n修正后的SQL仍然报错,错误信息:", str(e2))
-                    return Response({"error": "SQL执行失败", "detail": str(e2)}, status=400)
-            finally:
-                conn.close()
-        else:
-            return Response({"result": llm_data})
+"""
+            else:
+                user_prompt = f"""
+回答以下问题,不需要涉及数据库查询:
+
+问题: {prompt}
+
+请直接回答问题,不要提及数据库或SQL。
+"""
+            # TODO: fetch all of the conversation's messages from the DB by its id and use them as the message list
+            history = Message.objects.filter(conversation=conversation).order_by('create_time')
+            chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
+            chat_history.append({"role": "user", "content": prompt})
+            print("chat_history", chat_history)
+            payload = {
+                "model": MODEL,
+                "messages": chat_history,
+                "temperature": 0,
+                "stream": True
+            }
+            response = requests.post(url, headers=headers, json=payload)
+            response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500)
+        def stream_generator():
+            accumulated_content = ""
+            for line in response.iter_lines():
+                if line:
+                    decoded_line = line.decode('utf-8')
+                    if decoded_line.startswith('data:'):
+                        if decoded_line.strip() == "data: [DONE]":
+                            break  # OpenAI-style end-of-stream marker
+                        try:
+                            data = json.loads(decoded_line[6:])
+                            content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
+                            print("content", content)
+                            if content:
+                                accumulated_content += content
+                                yield f"data: {content}\n\n"
+
+                        except Exception as e:
+                            yield f"data: [解析失败]: {str(e)}\n\n"
+            print("accumulated_content", accumulated_content)
+            print("question_type", question_type)
+            print("conversation", conversation)
+            save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
+
+            if question_type == "database":
+                sql = extract_sql_code(accumulated_content)
+                if sql:
+                    conn = connect_db()
+                    try:
+                        if is_safe_sql(sql):
+                            result = execute_sql(conn, sql)
+                            save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
+                            yield f"data: SQL执行结果: {result}\n\n"
+                        else:
+                            yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n"
+                    except Exception as e:
+                        yield f"data: SQL执行失败: {str(e)}\n\n"
+                    finally:
+                        if conn:
+                            conn.close()
+            else:
+                yield "data: \\n[文本结束]\n\n"
+        return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
 
 # 先新建对话 生成对话session_id
-class ConversationView(APIView):
-    def get(self, request):
-        conversation = Conversation.objects.all()
-        serializer = ConversationSerializer(conversation, many=True)
-        return Response(serializer.data)
+class ConversationViewSet(CustomModelViewSet):
+    queryset = Conversation.objects.all()
+    serializer_class = ConversationSerializer
+    ordering = ['create_time']
+    perms_map = {'get':'*', 'post':'*', 'put':'*'}
 
-    def post(self, request):
-        serializer = ConversationSerializer(data=request.data)
-        serializer.is_valid(raise_exception=True)
-        serializer.save()
-        return Response(serializer.data)
-    def put(self, request, pk):
-        conversation = get_object_or_404(Conversation, pk=pk)
-        serializer = ConversationSerializer(conversation, data=request.data)
-        serializer.is_valid(raise_exception=True)
-        serializer.save()
-        return Response(serializer.data)
\ No newline at end of file
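
Note: stream_generator() inside the completion action parses OpenAI-compatible SSE chunks from the upstream LLM and re-emits only the delta text as its own data: events; the stream stops at the data: [DONE] sentinel, and for database questions the accumulated reply is then run through extract_sql_code()/is_safe_sql() before the query result is appended as a final event. Shape of one upstream chunk and the extraction step, with an illustrative payload (not captured from a real response):

    import json

    line = 'data: {"choices": [{"index": 0, "delta": {"content": "SELECT"}}]}'
    payload = json.loads(line[6:])                     # strip the "data: " prefix
    delta = payload["choices"][0]["delta"]["content"]  # -> "SELECT"
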
diff --git a/apps/utils/viewsets.py b/apps/utils/viewsets.py
index 6de001ac..9311ecb8 100755
--- a/apps/utils/viewsets.py
+++ b/apps/utils/viewsets.py
@@ -1,5 +1,6 @@
 from django.core.cache import cache
+from django.http import StreamingHttpResponse
 from rest_framework.decorators import action
 from rest_framework.exceptions import ParseError
 from rest_framework.mixins import RetrieveModelMixin
 
@@ -58,6 +59,9 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
         return super().__new__(cls)
 
     def finalize_response(self, request, response, *args, **kwargs):
+        # If this is a streaming response, return it directly
+        if isinstance(response, StreamingHttpResponse):
+            return response
         if self.hash_k and self.cache_seconds:
             cache.set(self.hash_k, response.data, timeout=self.cache_seconds)  # 将结果存入缓存,设置超时时间
 
diff --git a/server/settings.py b/server/settings.py
index dceec682..40d9f3dd 100755
--- a/server/settings.py
+++ b/server/settings.py
@@ -63,6 +63,7 @@ INSTALLED_APPS = [
     'apps.wf',
     'apps.ecm',
     'apps.hrm',
+    'apps.ichat',
     'apps.am',
     'apps.vm',
     'apps.rpm',
diff --git a/server/urls.py b/server/urls.py
index 473e362f..0eb2954c 100755
--- a/server/urls.py
+++ b/server/urls.py
@@ -44,7 +44,7 @@ urlpatterns = [
 
     # api
     path('', include('apps.auth1.urls')),
-    # path('', include('apps.ichat.urls')),
+    path('', include('apps.ichat.urls')),
     path('', include('apps.system.urls')),
     path('', include('apps.monitor.urls')),
     path('', include('apps.wf.urls')),
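
Note: the finalize_response() change in apps/utils/viewsets.py is what lets the ichat SSE response pass through CustomGenericViewSet untouched, since a StreamingHttpResponse has no response.data to cache. A hedged smoke-test sketch; the fixture names are hypothetical and the upstream LLM call made by the completion action would need to be stubbed out in a real test run:

    import pytest
    from rest_framework.test import APIClient

    @pytest.mark.django_db
    def test_completion_returns_event_stream(conversation, mock_llm):  # hypothetical fixtures
        client = APIClient()
        resp = client.post(
            "/api/ichat/message/completion/",
            {"conversation": conversation.id, "content": "你好"},
            format="json",
        )
        assert resp.status_code == 200
        assert resp["Content-Type"].startswith("text/event-stream")
        body = b"".join(resp.streaming_content).decode("utf-8")
        assert body.startswith("data: ")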