diff --git a/apps/ichat/__init__.py b/apps/ichat/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/apps/ichat/admin.py b/apps/ichat/admin.py deleted file mode 100644 index 8c38f3f3..00000000 --- a/apps/ichat/admin.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.contrib import admin - -# Register your models here. diff --git a/apps/ichat/apps.py b/apps/ichat/apps.py deleted file mode 100644 index c7bf0cf9..00000000 --- a/apps/ichat/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class ChatConfig(AppConfig): - default_auto_field = 'django.db.models.BigAutoField' - name = 'apps.ichat' diff --git a/apps/ichat/migrations/0001_initial.py b/apps/ichat/migrations/0001_initial.py deleted file mode 100644 index 290c1cff..00000000 --- a/apps/ichat/migrations/0001_initial.py +++ /dev/null @@ -1,48 +0,0 @@ -# Generated by Django 3.2.12 on 2025-05-21 05:59 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion -import django.utils.timezone - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name='Conversation', - fields=[ - ('id', models.CharField(editable=False, help_text='主键ID', max_length=20, primary_key=True, serialize=False, verbose_name='主键ID')), - ('create_time', models.DateTimeField(default=django.utils.timezone.now, help_text='创建时间', verbose_name='创建时间')), - ('update_time', models.DateTimeField(auto_now=True, help_text='修改时间', verbose_name='修改时间')), - ('is_deleted', models.BooleanField(default=False, help_text='删除标记', verbose_name='删除标记')), - ('title', models.CharField(default='新对话', max_length=200, verbose_name='对话标题')), - ('create_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='conversation_create_by', to=settings.AUTH_USER_MODEL, verbose_name='创建人')), - ('update_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='conversation_update_by', to=settings.AUTH_USER_MODEL, verbose_name='最后编辑人')), - ], - options={ - 'abstract': False, - }, - ), - migrations.CreateModel( - name='Message', - fields=[ - ('id', models.CharField(editable=False, help_text='主键ID', max_length=20, primary_key=True, serialize=False, verbose_name='主键ID')), - ('create_time', models.DateTimeField(default=django.utils.timezone.now, help_text='创建时间', verbose_name='创建时间')), - ('update_time', models.DateTimeField(auto_now=True, help_text='修改时间', verbose_name='修改时间')), - ('is_deleted', models.BooleanField(default=False, help_text='删除标记', verbose_name='删除标记')), - ('content', models.TextField(verbose_name='消息内容')), - ('role', models.CharField(default='user', help_text='system/user', max_length=10, verbose_name='角色')), - ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='ichat.conversation', verbose_name='对话')), - ], - options={ - 'abstract': False, - }, - ), - ] diff --git a/apps/ichat/migrations/__init__.py b/apps/ichat/migrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/apps/ichat/models.py b/apps/ichat/models.py deleted file mode 100644 index e8e79fad..00000000 --- a/apps/ichat/models.py +++ /dev/null @@ -1,17 +0,0 @@ -from django.db import models -from apps.system.models import CommonADModel, BaseModel - -# Create your models here. -class Conversation(CommonADModel): - """ - TN: 对话 - """ - title = models.CharField(max_length=200, default='新对话',verbose_name='对话标题') - -class Message(BaseModel): - """ - TN: 消息 - """ - conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages', verbose_name='对话') - content = models.TextField(verbose_name='消息内容') - role = models.CharField("角色", max_length=10, default='user', help_text="system/user") diff --git a/apps/ichat/promot/w_ana.md b/apps/ichat/promot/w_ana.md deleted file mode 100644 index 0751e9a1..00000000 --- a/apps/ichat/promot/w_ana.md +++ /dev/null @@ -1,14 +0,0 @@ -# 角色 -你是一位数据分析专家和前端程序员,具备深厚的专业知识和丰富的实践经验。你能够精准理解用户的文本描述, 并形成报告。 -# 技能 -1. 仔细分析用户提供的JSON格式数据,分析用户需求。 -2. 依据得到的需求, 分别获取JSON数据中的关键信息。 -3. 根据2中的关键信息最优化选择表格/饼图/柱状图/折线图等格式绘制报告。 -# 回答要求 -1. 仅生成完整的HTML代码,所有功能都需要实现,支持响应式,不要输出任何解释或说明。 -2. 代码中如需要Echarts等js库,请直接使用中国大陆的CDN链接例如bootcdn的链接。 -3. 标题为 数据分析报告。 -3. 在开始部分,请以表格形式简略展示获取的JSON数据。 -4. 之后选择最合适的图表方式生成相应的图。 -5. 在最后提供可下载该报告的完整PDF的按钮和功能。 -6. 在最后提供可下载含有JSON数据的EXCEL文件的按钮和功能。 \ No newline at end of file diff --git a/apps/ichat/promot/w_sql.md b/apps/ichat/promot/w_sql.md deleted file mode 100644 index c987e161..00000000 --- a/apps/ichat/promot/w_sql.md +++ /dev/null @@ -1,53 +0,0 @@ -# 角色 -你是一位资深的Postgresql数据库SQL专家,具备深厚的专业知识和丰富的实践经验。你能够精准理解用户的文本描述,并生成准确可执行的SQL语句。 -# 技能 -1. 仔细分析用户提供的文本描述,明确用户需求。 -2. 根据对用户需求的理解,生成符合Postgresql数据库语法的准确可执行的SQL语句。 -# 回答要求 -1. 如果用户的询问未以 查询 开头,请直接回复 "请以 查询 开头,重新描述你的需求"。 -2. 生成的SQL语句必须符合Postgresql数据库的语法规范。 -3. 不要使用 Markerdown 和 SQL 语法格式输出,禁止添加语法标准、备注、说明等信息。 -4. 直接输出符合Postgresql标准的SQL语句,用txt纯文本格式展示即可。 -5. 如果无法生成符合要求的SQL语句,请直接回复 "无法生成"。 -# 示例 -1. 问:查询 外协白片抛 工段在2025年6月1日到2025年6月15日之间的生产合格数以及合格率等 - 答:select - sum(mlog.count_use) as 领用数, - sum(mlog.count_real) as 生产数, - sum(mlog.count_ok) as 合格数, - sum(mlog.count_notok) as 不合格数, - CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率 - from wpm_mlog mlog - left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id - where mlog.submit_time is not null - and mgroup.name = '外协白片抛' - and mlog.handle_date >= '2025-06-01' - and mlog.handle_date <= '2025-06-15' -2. 问:查询 黑化 工段在2025年6月的生产合格数以及合格率等 - 答: select - sum(mlog.count_use) as 领用数, - sum(mlog.count_real) as 生产数, - sum(mlog.count_ok) as 合格数, - sum(mlog.count_notok) as 不合格数, - CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率 - from wpm_mlog mlog - left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id - where mlog.submit_time is not null - and mgroup.name = '黑化' - and mlog.handle_date >= '2025-06-01' - and mlog.handle_date <= '2025-06-30' -3. 问:查询 各工段 在2025年6月的生产合格数以及合格率等 - 答: select - mgroup.name as 工段, - sum(mlog.count_use) as 领用数, - sum(mlog.count_real) as 生产数, - sum(mlog.count_ok) as 合格数, - sum(mlog.count_notok) as 不合格数, - CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率 - from wpm_mlog mlog - left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id - where mlog.submit_time is not null - and mlog.handle_date >= '2025-06-01' - and mlog.handle_date <= '2025-06-30' - group by mgroup.id - order by mgroup.sort \ No newline at end of file diff --git a/apps/ichat/script.py b/apps/ichat/script.py deleted file mode 100644 index 705ce0d0..00000000 --- a/apps/ichat/script.py +++ /dev/null @@ -1,22 +0,0 @@ -import json -from .models import Message -from django.http import StreamingHttpResponse - -def stream_generator(stream_response: bytes, conversation_id: str): - full_content = '' - for chunk in stream_response.iter_content(chunk_size=1024): - if chunk: - full_content += chunk.decode('utf-8') - try: - data = json.loads(full_content) - content = data.get("choices", [{}])[0].get("delta", {}).get("content", "") - Message.objects.create( - conversation_id=conversation_id, - content=content - ) - yield f" data:{content}\n\n" - full_content = '' - except json.JSONDecodeError: - continue - return StreamingHttpResponse(stream_generator(stream_response, conversation_id), content_type='text/event-stream') - diff --git a/apps/ichat/serializers.py b/apps/ichat/serializers.py deleted file mode 100644 index 43102545..00000000 --- a/apps/ichat/serializers.py +++ /dev/null @@ -1,18 +0,0 @@ -from rest_framework import serializers -from .models import Conversation, Message -from apps.utils.constants import EXCLUDE_FIELDS - - -class MessageSerializer(serializers.ModelSerializer): - class Meta: - model = Message - fields = ['id', 'conversation', 'content', 'role'] - read_only_fields = EXCLUDE_FIELDS - - -class ConversationSerializer(serializers.ModelSerializer): - messages = MessageSerializer(many=True, read_only=True) - class Meta: - model = Conversation - fields = ['id', 'title', 'messages'] - read_only_fields = EXCLUDE_FIELDS \ No newline at end of file diff --git a/apps/ichat/tests.py b/apps/ichat/tests.py deleted file mode 100644 index 7ce503c2..00000000 --- a/apps/ichat/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/apps/ichat/urls.py b/apps/ichat/urls.py deleted file mode 100644 index 88a41b81..00000000 --- a/apps/ichat/urls.py +++ /dev/null @@ -1,16 +0,0 @@ - -from django.urls import path, include -from rest_framework.routers import DefaultRouter -from apps.ichat.views import QueryLLMviewSet, ConversationViewSet -from apps.ichat.views2 import WorkChain - -API_BASE_URL = 'api/ichat/' - -router = DefaultRouter() - -router.register('conversation', ConversationViewSet, basename='conversation') -router.register('message', QueryLLMviewSet, basename='message') -urlpatterns = [ - path(API_BASE_URL, include(router.urls)), - path(API_BASE_URL + 'workchain/ask/', WorkChain.as_view(), name='workchain') -] diff --git a/apps/ichat/utils.py b/apps/ichat/utils.py deleted file mode 100644 index e055a672..00000000 --- a/apps/ichat/utils.py +++ /dev/null @@ -1,88 +0,0 @@ -import re -import psycopg2 -import threading -from django.db import transaction -from .models import Message - -# 数据库连接 -def connect_db(): - from server.conf import DATABASES - db_conf = DATABASES['default'] - conn = psycopg2.connect( - host=db_conf['HOST'], - port=db_conf['PORT'], - user=db_conf['USER'], - password=db_conf['PASSWORD'], - database=db_conf['NAME'] - ) - return conn - -def extract_sql_code(text): - # 优先尝试 ```sql 包裹的语句 - match = re.search(r"```sql\s*(.+?)```", text, re.DOTALL | re.IGNORECASE) - if match: - return match.group(1).strip() - - # fallback: 寻找首个 select 语句 - match = re.search(r"(SELECT\s.+?;)", text, re.IGNORECASE | re.DOTALL) - if match: - return match.group(1).strip() - - return None - - -def get_schema_text(conn, table_names:list): - cur = conn.cursor() - query = """ - SELECT - table_name, column_name, data_type - FROM - information_schema.columns - WHERE - table_schema = 'public' - and table_name in %s; - """ - cur.execute(query, (tuple(table_names), )) - - schema = {} - for table_name, column_name, data_type in cur.fetchall(): - if table_name not in schema: - schema[table_name] = [] - schema[table_name].append(f"{column_name} ({data_type})") - cur.close() - schema_text = "" - for table_name, columns in schema.items(): - schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n" - return schema_text - - -def is_safe_sql(sql:str) -> bool: - sql = sql.strip().lower() - return sql.startswith("select") or sql.startswith("show") and not re.search(r"delete|update|insert|drop|create|alter", sql) - -def execute_sql(conn, sql_query): - cur = conn.cursor() - cur.execute(sql_query) - try: - rows = cur.fetchall() - columns = [desc[0] for desc in cur.description] - result = [dict(zip(columns, row)) for row in rows] - except psycopg2.ProgrammingError: - result = cur.statusmessage - cur.close() - return result - -def strip_sql_markdown(content: str) -> str: - # 去掉包裹在 ```sql 或 ``` 中的内容 - match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE) - if match: - return match.group(1).strip() - else: - return None - -# ORM 写入包装函数 -def save_message_thread_safe(**kwargs): - def _save(): - with transaction.atomic(): - Message.objects.create(**kwargs) - threading.Thread(target=_save).start() diff --git a/apps/ichat/view_bak.py b/apps/ichat/view_bak.py deleted file mode 100644 index 4a5570c7..00000000 --- a/apps/ichat/view_bak.py +++ /dev/null @@ -1,87 +0,0 @@ -import requests -from langchain_core.language_models import LLM -from langchain_core.outputs import LLMResult, Generation -from langchain_experimental.sql import SQLDatabaseChain -from langchain_community.utilities import SQLDatabase -from server.conf import DATABASES -from apps.ichat.serializers import CustomLLMrequestSerializer -from rest_framework.views import APIView -from urllib.parse import quote_plus -from rest_framework.response import Response - - -db_conf = DATABASES['default'] -# 密码需要 URL 编码(因为有特殊字符如 @) -password_encodeed = quote_plus(db_conf['PASSWORD']) - -db = SQLDatabase.from_uri(f"postgresql+psycopg2://{db_conf['USER']}:{password_encodeed}@{db_conf['HOST']}/{db_conf['NAME']}", include_tables=["enm_mpoint", "enm_mpointstat"]) -# model_url = "http://14.22.88.72:11025/v1/chat/completions" -model_url = "http://139.159.180.64:11434/v1/chat/completions" - -class CustomLLM(LLM): - model_url: str - mode: str = 'chat' - def _call(self, prompt: str, stop: list = None) -> str: - data = { - "model":"glm4", - "messages": self.build_message(prompt), - "stream": False, - } - response = requests.post(self.model_url, json=data, timeout=600) - response.raise_for_status() - content = response.json()["choices"][0]["message"]["content"] - print('content---', content) - clean_sql = self.strip_sql_markdown(content) if self.mode == 'sql' else content.strip() - return clean_sql - - def _generate(self, prompts: list, stop: list = None) -> LLMResult: - generations = [] - for prompt in prompts: - text = self._call(prompt, stop) - generations.append([Generation(text=text)]) - return LLMResult(generations=generations) - - def strip_sql_markdown(self, content: str) -> str: - import re - # 去掉包裹在 ```sql 或 ``` 中的内容 - match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE) - if match: - return match.group(1).strip() - else: - return content.strip() - - def build_message(self, prompt: str) -> list: - if self.mode == 'sql': - system_prompt = ( - "你是一个 SQL 助手,严格遵循以下规则:\n" - "1. 只返回 PostgreSQL 语法 SQL 语句。\n" - "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n" - "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n" - "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。" - ) - else: - system_prompt = "你是一个聊天助手,请根据用户的问题,提供简洁明了的答案。" - return [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}, - ] - - @property - def _llm_type(self) -> str: - return "custom_llm" - - -class QueryLLMview(APIView): - def post(self, request): - serializer = CustomLLMrequestSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - prompt = serializer.validated_data['prompt'] - mode = serializer.validated_data.get('mode', 'chat') - llm = CustomLLM(model_url=model_url, mode=mode) - print('prompt---', prompt, mode) - if mode == 'sql': - chain = SQLDatabaseChain.from_llm(llm, db, verbose=True) - result = chain.invoke(prompt) - else: - result = llm._call(prompt) - return Response({"result": result}) \ No newline at end of file diff --git a/apps/ichat/views.py b/apps/ichat/views.py deleted file mode 100644 index 922d0fee..00000000 --- a/apps/ichat/views.py +++ /dev/null @@ -1,155 +0,0 @@ -import requests -import json -from rest_framework.views import APIView -from apps.ichat.serializers import MessageSerializer, ConversationSerializer -from rest_framework.response import Response -from apps.ichat.models import Conversation, Message -from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe -from django.http import StreamingHttpResponse, JsonResponse -from rest_framework.decorators import action -from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet - -# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693" -# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" -# MODEL = "qwen-plus" - -# #本地部署的模式 -API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw" -API_BASE = "http://106.0.4.200:9000/v1" -MODEL = "qwen14b" - -# google gemini -# API_KEY = "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" -# API_BASE = "https://openrouter.ai/api/v1" -# MODEL="google/gemini-2.0-flash-exp:free" - -# deepseek v3 -# API_KEY = "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" -# API_BASE = "https://openrouter.ai/api/v1" -# MODEL="deepseek/deepseek-chat-v3-0324:free" - -TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表 - - -class QueryLLMviewSet(CustomModelViewSet): - queryset = Message.objects.all() - serializer_class = MessageSerializer - ordering = ['create_time'] - perms_map = {'get':'*', 'post':'*', 'put':'*'} - - @action(methods=['post'], detail=False, perms_map={'post':'*'} ,serializer_class=MessageSerializer) - def completion(self, request): - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - serializer.save() - prompt = serializer.validated_data['content'] - conversation = serializer.validated_data['conversation'] - if not prompt or not conversation: - return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400) - save_message_thread_safe(content=prompt, conversation=conversation, role="user") - url = f"{API_BASE}/chat/completions" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {API_KEY}" - } - - user_prompt = f""" -我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。 - -注意: -只需回答"database"或"general"即可,不要有其他内容。 -""" - _payload = { - "model": MODEL, - "messages": [{"role": "user", "content": user_prompt}, {"role":"system" , "content": "只返回一个结果'database'或'general'"}], - "temperature": 0, - "max_tokens": 10 - } - try: - class_response = requests.post(url, headers=headers, json=_payload) - class_response.raise_for_status() - class_result = class_response.json() - question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower() - print("question_type", question_type) - if question_type == "database": - conn = connect_db() - schema_text = get_schema_text(conn, TABLES) - print("schema_text----------------------", schema_text) - user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构: -{schema_text} -请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。 -需求是:{prompt} -""" - else: - user_prompt = f""" -回答以下问题,不需要涉及数据库查询: - -问题: {prompt} - -请直接回答问题,不要提及数据库或SQL。 -""" - # TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages - history = Message.objects.filter(conversation=conversation).order_by('create_time') - # chat_history = [{"role": msg.role, "content": msg.content} for msg in history] - # chat_history.append({"role": "user", "content": prompt}) - chat_history = [{"role":"user", "content":prompt}] - print("chat_history", chat_history) - payload = { - "model": MODEL, - "messages": chat_history, - "temperature": 0, - "stream": True - } - response = requests.post(url, headers=headers, json=payload) - response.raise_for_status() - except requests.exceptions.RequestException as e: - return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500) - def stream_generator(): - accumulated_content = "" - for line in response.iter_lines(): - if line: - decoded_line = line.decode('utf-8') - if decoded_line.startswith('data:'): - if decoded_line.strip() == "data: [DONE]": - break # OpenAI-style标志结束 - try: - data = json.loads(decoded_line[6:]) - content = data.get('choices', [{}])[0].get('delta', {}).get('content', '') - if content: - accumulated_content += content - yield f"data: {content}\n\n" - - except Exception as e: - yield f"data: [解析失败]: {str(e)}\n\n" - print("accumulated_content", accumulated_content) - save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system") - - if question_type == "database": - sql = extract_sql_code(accumulated_content) - if sql: - try: - conn = connect_db() - if is_safe_sql(sql): - result = execute_sql(conn, sql) - save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system") - yield f"data: SQL执行结果: {result}\n\n" - else: - yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n" - except Exception as e: - yield f"data: SQL执行失败: {str(e)}\n\n" - finally: - if conn: - conn.close() - else: - yield "data: \\n[文本结束]\n\n" - return StreamingHttpResponse(stream_generator(), content_type='text/event-stream') - - -# 先新建对话 生成对话session_id -class ConversationViewSet(CustomModelViewSet): - queryset = Conversation.objects.all() - serializer_class = ConversationSerializer - ordering = ['create_time'] - perms_map = {'get':'*', 'post':'*', 'put':'*'} - - diff --git a/apps/ichat/views2.py b/apps/ichat/views2.py deleted file mode 100644 index 9b2766bb..00000000 --- a/apps/ichat/views2.py +++ /dev/null @@ -1,129 +0,0 @@ -import requests -import os -from apps.utils.sql import execute_raw_sql -import json -from apps.utils.tools import MyJSONEncoder -from .utils import is_safe_sql -from rest_framework.views import APIView -from drf_yasg.utils import swagger_auto_schema -from rest_framework import serializers -from rest_framework.exceptions import ParseError -from rest_framework.response import Response -from django.conf import settings -from apps.utils.mixins import MyLoggingMixin -from django.core.cache import cache -import uuid -from apps.utils.thread import MyThread - -LLM_URL = getattr(settings, "LLM_URL", "") -API_KEY = getattr(settings, "LLM_API_KEY", "") -MODEL = "qwen14b" -HEADERS = { - "Authorization": f"Bearer {API_KEY}", - "Content-Type": "application/json" -} -CUR_DIR = os.path.dirname(os.path.abspath(__file__)) - -def load_promot(name): - with open(os.path.join(CUR_DIR, f'promot/{name}.md'), 'r') as f: - return f.read() - - -def ask(input:str, p_name:str, stream=False): - his = [{"role":"system", "content": load_promot(p_name)}] - his.append({"role":"user", "content": input}) - payload = { - "model": MODEL, - "messages": his, - "temperature": 0, - "stream": stream - } - response = requests.post(LLM_URL, headers=HEADERS, json=payload, stream=stream) - if not stream: - return response.json()["choices"][0]["message"]["content"] - else: - # 处理流式响应 - full_content = "" - for chunk in response.iter_lines(): - if chunk: - # 通常流式响应是SSE格式(data: {...}) - decoded_chunk = chunk.decode('utf-8') - if decoded_chunk.startswith("data:"): - json_str = decoded_chunk[5:].strip() - if json_str == "[DONE]": - break - try: - chunk_data = json.loads(json_str) - if "choices" in chunk_data and chunk_data["choices"]: - delta = chunk_data["choices"][0].get("delta", {}) - if "content" in delta: - print(delta["content"]) - full_content += delta["content"] - except json.JSONDecodeError: - continue - return full_content - -def work_chain(input:str, t_key:str): - pdict = {"state": "progress", "steps": [{"state":"ok", "msg":"正在生成查询语句"}]} - cache.set(t_key, pdict) - res_text = ask(input, 'w_sql') - if res_text == '请以 查询 开头,重新描述你的需求': - pdict["state"] = "error" - pdict["steps"].append({"state":"error", "msg":res_text}) - cache.set(t_key, pdict) - return - else: - pdict["steps"].append({"state":"ok", "msg":"查询语句生成成功", "content":res_text}) - cache.set(t_key, pdict) - if not is_safe_sql(res_text): - pdict["state"] = "error" - pdict["steps"].append({"state":"error", "msg":"当前查询存在风险,请重新描述你的需求"}) - cache.set(t_key, pdict) - return - pdict["steps"].append({"state":"ok", "msg":"正在执行查询语句"}) - cache.set(t_key, pdict) - res = execute_raw_sql(res_text) - pdict["steps"].append({"state":"ok", "msg":"查询语句执行成功", "content":res}) - cache.set(t_key, pdict) - pdict["steps"].append({"state":"ok", "msg":"正在生成报告"}) - cache.set(t_key, pdict) - res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana') - content = res2.lstrip('```html ').rstrip('```') - pdict["state"] = "done" - pdict["content"] = content - pdict["steps"].append({"state":"ok", "msg":"报告生成成功", "content": content}) - cache.set(t_key, pdict) - return - -class InputSerializer(serializers.Serializer): - input = serializers.CharField(label="查询需求") - -class WorkChain(MyLoggingMixin, APIView): - - @swagger_auto_schema( - operation_summary="提交查询需求", - request_body=InputSerializer) - def post(self, request): - llm_enabled = getattr(settings, "LLM_ENABLED", False) - if not llm_enabled: - raise ParseError('LLM功能未启用') - input = request.data.get('input') - t_key = f'ichat_{uuid.uuid4()}' - MyThread(target=work_chain, args=(input, t_key)).start() - return Response({'ichat_tid': t_key}) - - @swagger_auto_schema( - operation_summary="获取查询进度") - def get(self, request): - llm_enabled = getattr(settings, "LLM_ENABLED", False) - if not llm_enabled: - raise ParseError('LLM功能未启用') - ichat_tid = request.GET.get('ichat_tid') - if ichat_tid: - return Response(cache.get(ichat_tid)) - -if __name__ == "__main__": - print(work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告")) - - from apps.ichat.views2 import work_chain - print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告')) \ No newline at end of file diff --git a/server/urls.py b/server/urls.py index da981a42..ac46779f 100755 --- a/server/urls.py +++ b/server/urls.py @@ -45,7 +45,6 @@ urlpatterns = [ # api path('', include('apps.auth1.urls')), - path('', include('apps.ichat.urls')), path('', include('apps.system.urls')), path('', include('apps.monitor.urls')), path('', include('apps.wf.urls')),