diff --git a/apps/ichat/__init__.py b/apps/ichat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/ichat/admin.py b/apps/ichat/admin.py new file mode 100644 index 00000000..8c38f3f3 --- /dev/null +++ b/apps/ichat/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/apps/ichat/apps.py b/apps/ichat/apps.py new file mode 100644 index 00000000..c7bf0cf9 --- /dev/null +++ b/apps/ichat/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class ChatConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.ichat' diff --git a/apps/ichat/migrations/0001_initial.py b/apps/ichat/migrations/0001_initial.py new file mode 100644 index 00000000..290c1cff --- /dev/null +++ b/apps/ichat/migrations/0001_initial.py @@ -0,0 +1,48 @@ +# Generated by Django 3.2.12 on 2025-05-21 05:59 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='Conversation', + fields=[ + ('id', models.CharField(editable=False, help_text='主键ID', max_length=20, primary_key=True, serialize=False, verbose_name='主键ID')), + ('create_time', models.DateTimeField(default=django.utils.timezone.now, help_text='创建时间', verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, help_text='修改时间', verbose_name='修改时间')), + ('is_deleted', models.BooleanField(default=False, help_text='删除标记', verbose_name='删除标记')), + ('title', models.CharField(default='新对话', max_length=200, verbose_name='对话标题')), + ('create_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='conversation_create_by', to=settings.AUTH_USER_MODEL, verbose_name='创建人')), + ('update_by', 
models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='conversation_update_by', to=settings.AUTH_USER_MODEL, verbose_name='最后编辑人')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='Message', + fields=[ + ('id', models.CharField(editable=False, help_text='主键ID', max_length=20, primary_key=True, serialize=False, verbose_name='主键ID')), + ('create_time', models.DateTimeField(default=django.utils.timezone.now, help_text='创建时间', verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, help_text='修改时间', verbose_name='修改时间')), + ('is_deleted', models.BooleanField(default=False, help_text='删除标记', verbose_name='删除标记')), + ('content', models.TextField(verbose_name='消息内容')), + ('role', models.CharField(default='user', help_text='system/user', max_length=10, verbose_name='角色')), + ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='ichat.conversation', verbose_name='对话')), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/apps/ichat/migrations/__init__.py b/apps/ichat/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/ichat/models.py b/apps/ichat/models.py new file mode 100644 index 00000000..e8e79fad --- /dev/null +++ b/apps/ichat/models.py @@ -0,0 +1,17 @@ +from django.db import models +from apps.system.models import CommonADModel, BaseModel + +# Create your models here. 
+class Conversation(CommonADModel): + """ + TN: 对话 + """ + title = models.CharField(max_length=200, default='新对话',verbose_name='对话标题') + +class Message(BaseModel): + """ + TN: 消息 + """ + conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages', verbose_name='对话') + content = models.TextField(verbose_name='消息内容') + role = models.CharField("角色", max_length=10, default='user', help_text="system/user") diff --git a/apps/ichat/promot/w_ana.md b/apps/ichat/promot/w_ana.md new file mode 100644 index 00000000..0751e9a1 --- /dev/null +++ b/apps/ichat/promot/w_ana.md @@ -0,0 +1,14 @@ +# 角色 +你是一位数据分析专家和前端程序员,具备深厚的专业知识和丰富的实践经验。你能够精准理解用户的文本描述, 并形成报告。 +# 技能 +1. 仔细分析用户提供的JSON格式数据,分析用户需求。 +2. 依据得到的需求, 分别获取JSON数据中的关键信息。 +3. 根据2中的关键信息最优化选择表格/饼图/柱状图/折线图等格式绘制报告。 +# 回答要求 +1. 仅生成完整的HTML代码,所有功能都需要实现,支持响应式,不要输出任何解释或说明。 +2. 代码中如需要Echarts等js库,请直接使用中国大陆的CDN链接例如bootcdn的链接。 +3. 标题为 数据分析报告。 +3. 在开始部分,请以表格形式简略展示获取的JSON数据。 +4. 之后选择最合适的图表方式生成相应的图。 +5. 在最后提供可下载该报告的完整PDF的按钮和功能。 +6. 在最后提供可下载含有JSON数据的EXCEL文件的按钮和功能。 \ No newline at end of file diff --git a/apps/ichat/promot/w_sql.md b/apps/ichat/promot/w_sql.md new file mode 100644 index 00000000..c987e161 --- /dev/null +++ b/apps/ichat/promot/w_sql.md @@ -0,0 +1,53 @@ +# 角色 +你是一位资深的Postgresql数据库SQL专家,具备深厚的专业知识和丰富的实践经验。你能够精准理解用户的文本描述,并生成准确可执行的SQL语句。 +# 技能 +1. 仔细分析用户提供的文本描述,明确用户需求。 +2. 根据对用户需求的理解,生成符合Postgresql数据库语法的准确可执行的SQL语句。 +# 回答要求 +1. 如果用户的询问未以 查询 开头,请直接回复 "请以 查询 开头,重新描述你的需求"。 +2. 生成的SQL语句必须符合Postgresql数据库的语法规范。 +3. 不要使用 Markerdown 和 SQL 语法格式输出,禁止添加语法标准、备注、说明等信息。 +4. 直接输出符合Postgresql标准的SQL语句,用txt纯文本格式展示即可。 +5. 如果无法生成符合要求的SQL语句,请直接回复 "无法生成"。 +# 示例 +1. 
问:查询 外协白片抛 工段在2025年6月1日到2025年6月15日之间的生产合格数以及合格率等 + 答:select + sum(mlog.count_use) as 领用数, + sum(mlog.count_real) as 生产数, + sum(mlog.count_ok) as 合格数, + sum(mlog.count_notok) as 不合格数, + CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率 + from wpm_mlog mlog + left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id + where mlog.submit_time is not null + and mgroup.name = '外协白片抛' + and mlog.handle_date >= '2025-06-01' + and mlog.handle_date <= '2025-06-15' +2. 问:查询 黑化 工段在2025年6月的生产合格数以及合格率等 + 答: select + sum(mlog.count_use) as 领用数, + sum(mlog.count_real) as 生产数, + sum(mlog.count_ok) as 合格数, + sum(mlog.count_notok) as 不合格数, + CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率 + from wpm_mlog mlog + left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id + where mlog.submit_time is not null + and mgroup.name = '黑化' + and mlog.handle_date >= '2025-06-01' + and mlog.handle_date <= '2025-06-30' +3. 问:查询 各工段 在2025年6月的生产合格数以及合格率等 + 答: select + mgroup.name as 工段, + sum(mlog.count_use) as 领用数, + sum(mlog.count_real) as 生产数, + sum(mlog.count_ok) as 合格数, + sum(mlog.count_notok) as 不合格数, + CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率 + from wpm_mlog mlog + left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id + where mlog.submit_time is not null + and mlog.handle_date >= '2025-06-01' + and mlog.handle_date <= '2025-06-30' + group by mgroup.id + order by mgroup.sort \ No newline at end of file diff --git a/apps/ichat/script.py b/apps/ichat/script.py new file mode 100644 index 00000000..705ce0d0 --- /dev/null +++ b/apps/ichat/script.py @@ -0,0 +1,22 @@ +import json +from .models import Message +from django.http import StreamingHttpResponse + +def stream_generator(stream_response: bytes, conversation_id: str): + full_content = '' + for chunk in stream_response.iter_content(chunk_size=1024): + if chunk: + full_content += chunk.decode('utf-8') + 
try: + data = json.loads(full_content) + content = data.get("choices", [{}])[0].get("delta", {}).get("content", "") + Message.objects.create( + conversation_id=conversation_id, + content=content + ) + yield f" data:{content}\n\n" + full_content = '' + except json.JSONDecodeError: + continue + return StreamingHttpResponse(stream_generator(stream_response, conversation_id), content_type='text/event-stream') + diff --git a/apps/ichat/serializers.py b/apps/ichat/serializers.py new file mode 100644 index 00000000..43102545 --- /dev/null +++ b/apps/ichat/serializers.py @@ -0,0 +1,18 @@ +from rest_framework import serializers +from .models import Conversation, Message +from apps.utils.constants import EXCLUDE_FIELDS + + +class MessageSerializer(serializers.ModelSerializer): + class Meta: + model = Message + fields = ['id', 'conversation', 'content', 'role'] + read_only_fields = EXCLUDE_FIELDS + + +class ConversationSerializer(serializers.ModelSerializer): + messages = MessageSerializer(many=True, read_only=True) + class Meta: + model = Conversation + fields = ['id', 'title', 'messages'] + read_only_fields = EXCLUDE_FIELDS \ No newline at end of file diff --git a/apps/ichat/tests.py b/apps/ichat/tests.py new file mode 100644 index 00000000..7ce503c2 --- /dev/null +++ b/apps/ichat/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. 
diff --git a/apps/ichat/urls.py b/apps/ichat/urls.py new file mode 100644 index 00000000..88a41b81 --- /dev/null +++ b/apps/ichat/urls.py @@ -0,0 +1,16 @@ + +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from apps.ichat.views import QueryLLMviewSet, ConversationViewSet +from apps.ichat.views2 import WorkChain + +API_BASE_URL = 'api/ichat/' + +router = DefaultRouter() + +router.register('conversation', ConversationViewSet, basename='conversation') +router.register('message', QueryLLMviewSet, basename='message') +urlpatterns = [ + path(API_BASE_URL, include(router.urls)), + path(API_BASE_URL + 'workchain/ask/', WorkChain.as_view(), name='workchain') +] diff --git a/apps/ichat/utils.py b/apps/ichat/utils.py new file mode 100644 index 00000000..1ce4c822 --- /dev/null +++ b/apps/ichat/utils.py @@ -0,0 +1,195 @@ +import re +import psycopg2 +import threading +from django.db import transaction +from .models import Message + +# 数据库连接 +def connect_db(): + from server.conf import DATABASES + db_conf = DATABASES['default'] + conn = psycopg2.connect( + host=db_conf['HOST'], + port=db_conf['PORT'], + user=db_conf['USER'], + password=db_conf['PASSWORD'], + database=db_conf['NAME'] + ) + return conn + +def extract_sql_code(text): + # 优先尝试 ```sql 包裹的语句 + match = re.search(r"```sql\s*(.+?)```", text, re.DOTALL | re.IGNORECASE) + if match: + return match.group(1).strip() + + # fallback: 寻找首个 select 语句 + match = re.search(r"(SELECT\s.+?;)", text, re.IGNORECASE | re.DOTALL) + if match: + return match.group(1).strip() + + return None + + +# def get_schema_text(conn, table_names:list): +# cur = conn.cursor() +# query = """ +# SELECT +# table_name, column_name, data_type +# FROM +# information_schema.columns +# WHERE +# table_schema = 'public' +# and table_name in %s; +# """ +# cur.execute(query, (tuple(table_names), )) + +# schema = {} +# for table_name, column_name, data_type in cur.fetchall(): +# if table_name not in schema: +# 
def get_schema_text(conn, table_names: list):
    """Return LLM-friendly schema descriptions for the given tables.

    Queries the PostgreSQL system catalogs for every column of the
    requested public-schema tables together with its column comment and
    builds one entry per table of the form::

        {"table": <name>, "text": "表 <name> 包含列:\\n<col>-<comment>\\n..."}

    Any text following a "备注" (remark) marker inside a comment is
    stripped.  Columns without a comment are listed by their bare name
    (FIX: the previous implementation rendered the literal string
    "None" for them, which leaked into the LLM prompt).

    :param conn: open DB-API connection (e.g. psycopg2) to the target DB
    :param table_names: names of the tables to describe; unknown names
        simply produce no rows
    :return: list of ``{"table": str, "text": str}`` dicts, one per table,
        in catalog order
    """
    query = """
    SELECT
        c.relname AS table_name,
        a.attname AS column_name,
        pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
        d.description AS column_comment
    FROM
        pg_class c
        JOIN pg_namespace n ON n.oid = c.relnamespace
        JOIN pg_attribute a ON a.attrelid = c.oid
        LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
    WHERE
        n.nspname = 'public'
        AND c.relname = ANY(%s)
        AND a.attnum > 0
        AND NOT a.attisdropped
    ORDER BY
        c.relname, a.attnum;
    """
    cur = conn.cursor()
    cur.execute(query, (table_names,))
    schema = {}
    for table_name, column_name, data_type, comment in cur.fetchall():
        if comment and "备注" in comment:
            # Keep only the part of the comment before the remark marker.
            comment = comment.split("备注")[0].strip()
        # FIX: a missing comment used to produce "col-None"; emit just the
        # column name when there is no (or an empty) comment.
        entry = f"{column_name}-{comment}" if comment else column_name
        schema.setdefault(table_name, []).append(entry)
    cur.close()
    return [
        {"table": table, "text": f"表 {table} 包含列:\n" + "\n".join(columns)}
        for table, columns in schema.items()
    ]
def is_safe_sql(sql: str) -> bool:
    """Return True only for read-only SQL (SELECT/SHOW with no write verbs).

    FIX: the original expression was effectively
    ``starts_select or (starts_show and not has_write_verb)`` because
    ``and`` binds tighter than ``or`` — so ANY statement beginning with
    SELECT was accepted, including ``select 1; drop table t``, which
    defeated the whole safety check.  The read-only prefix test and the
    write-verb scan are now combined correctly.  The verb scan also uses
    word boundaries so identifiers such as ``update_time`` or
    ``created_at`` no longer cause a false rejection.

    :param sql: SQL text produced by the LLM (untrusted input)
    :return: True when the statement looks read-only and safe to run
    """
    normalized = sql.strip().lower()
    is_read_only = normalized.startswith(("select", "show"))
    has_write_verb = re.search(r"\b(delete|update|insert|drop|create|alter)\b", normalized)
    return is_read_only and not has_write_verb

def execute_sql(conn, sql_query):
    """Execute *sql_query* and return its rows as a list of dicts.

    When the statement produces no result set (psycopg2 raises
    ``ProgrammingError`` from ``fetchall``), the cursor's status message
    is returned instead.

    :param conn: open psycopg2 connection
    :param sql_query: SQL text to execute
    :return: ``list[dict]`` of rows keyed by column name, or the status
        message string for statements without a result set
    """
    cur = conn.cursor()
    try:
        cur.execute(sql_query)
        try:
            rows = cur.fetchall()
            columns = [desc[0] for desc in cur.description]
            result = [dict(zip(columns, row)) for row in rows]
        except psycopg2.ProgrammingError:
            # No result set (utility command) — report the status instead.
            result = cur.statusmessage
    finally:
        # FIX: release the cursor even if execute() itself raises.
        cur.close()
    return result

def strip_sql_markdown(content: str):
    """Extract the SQL inside a ```sql ... ``` fenced markdown block.

    :param content: raw LLM output that may contain a fenced SQL block
    :return: the fenced SQL with surrounding whitespace removed, or
        ``None`` when *content* contains no fenced block
    """
    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None

# ORM 
写入包装函数 +def save_message_thread_safe(**kwargs): + def _save(): + with transaction.atomic(): + Message.objects.create(**kwargs) + threading.Thread(target=_save).start() diff --git a/apps/ichat/view b/apps/ichat/view new file mode 100644 index 00000000..e69de29b diff --git a/apps/ichat/view_bak.py b/apps/ichat/view_bak.py new file mode 100644 index 00000000..8e98ded8 --- /dev/null +++ b/apps/ichat/view_bak.py @@ -0,0 +1,155 @@ +import requests +import json +from rest_framework.views import APIView +from apps.ichat.serializers import MessageSerializer, ConversationSerializer +from rest_framework.response import Response +from apps.ichat.models import Conversation, Message +from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe +from django.http import StreamingHttpResponse, JsonResponse +from rest_framework.decorators import action +from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet + +# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693" +# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" +# MODEL = "qwen-plus" + +# #本地部署的模式 +API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw" +API_BASE = "http://106.0.4.200:9000/v1" +MODEL = "qwen14b" + +# google gemini +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="google/gemini-2.0-flash-exp:free" + +# deepseek v3 +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="deepseek/deepseek-chat-v3-0324:free" + +TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表 + + +class QueryLLMviewSet(CustomModelViewSet): + queryset = Message.objects.all() + serializer_class = MessageSerializer + ordering = ['create_time'] + perms_map = {'get':'*', 'post':'*', 'put':'*'} + + @action(methods=['post'], detail=False, 
perms_map={'post':'*'} ,serializer_class=MessageSerializer) + def completion(self, request): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + prompt = serializer.validated_data['content'] + conversation = serializer.validated_data['conversation'] + if not prompt or not conversation: + return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400) + save_message_thread_safe(content=prompt, conversation=conversation, role="user") + url = f"{API_BASE}/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {API_KEY}" + } + + user_prompt = f""" +我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。 + +注意: +只需回答"database"或"general"即可,不要有其他内容。 +""" + _payload = { + "model": MODEL, + "messages": [{"role": "user", "content": user_prompt}], + "temperature": 0, + "max_tokens": 10 + } + try: + class_response = requests.post(url, headers=headers, json=_payload) + class_response.raise_for_status() + class_result = class_response.json() + question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower() + print("question_type", question_type) + if question_type == "database": + conn = connect_db() + schema_text = get_schema_text(conn, TABLES) + print("schema_text----------------------", schema_text) + user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构: +{schema_text} +请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。 +需求是:{prompt} +""" + else: + user_prompt = f""" +回答以下问题,不需要涉及数据库查询: + +问题: {prompt} + +请直接回答问题,不要提及数据库或SQL。 +""" + # TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages + # history = Message.objects.filter(conversation=conversation).order_by('create_time') + # chat_history = [{"role": msg.role, "content": msg.content} for msg in history] + # chat_history.append({"role": "user", "content": prompt}) + chat_history = [{"role":"user", "content":user_prompt}] + print("chat_history", 
chat_history) + payload = { + "model": MODEL, + "messages": chat_history, + "temperature": 0, + "stream": True + } + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + except requests.exceptions.RequestException as e: + return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500) + def stream_generator(): + accumulated_content = "" + for line in response.iter_lines(): + if line: + decoded_line = line.decode('utf-8') + if decoded_line.startswith('data:'): + if decoded_line.strip() == "data: [DONE]": + break # OpenAI-style标志结束 + try: + data = json.loads(decoded_line[6:]) + content = data.get('choices', [{}])[0].get('delta', {}).get('content', '') + if content: + accumulated_content += content + yield f"data: {content}\n\n" + + except Exception as e: + yield f"data: [解析失败]: {str(e)}\n\n" + print("accumulated_content", accumulated_content) + save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system") + + if question_type == "database": + sql = extract_sql_code(accumulated_content) + if sql: + try: + conn = connect_db() + if is_safe_sql(sql): + result = execute_sql(conn, sql) + save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system") + yield f"data: SQL执行结果: {result}\n\n" + else: + yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n" + except Exception as e: + yield f"data: SQL执行失败: {str(e)}\n\n" + finally: + if conn: + conn.close() + else: + yield "data: \\n[文本结束]\n\n" + return StreamingHttpResponse(stream_generator(), content_type='text/event-stream') + + +# 先新建对话 生成对话session_id +class ConversationViewSet(CustomModelViewSet): + queryset = Conversation.objects.all() + serializer_class = ConversationSerializer + ordering = ['create_time'] + perms_map = {'get':'*', 'post':'*', 'put':'*'} + + diff --git a/apps/ichat/view_bak2.py b/apps/ichat/view_bak2.py new file mode 100644 index 00000000..46a1779e --- /dev/null +++ b/apps/ichat/view_bak2.py @@ -0,0 +1,286 @@ +import 
requests +import json +import faiss +import numpy as np +from rest_framework.views import APIView +from apps.ichat.serializers import MessageSerializer, ConversationSerializer +from rest_framework.response import Response +from apps.ichat.models import Conversation, Message +from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, is_safe_sql, save_message_thread_safe, get_table_structures +from django.http import StreamingHttpResponse, JsonResponse +from rest_framework.decorators import action +from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet + +# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693" +# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" +# MODEL = "qwen-plus" + +#本地部署的模式 +API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw" +API_BASE = "http://106.0.4.200:9000/v1" +MODEL = "qwen14b" + +# 文本向量化模型 +EM_MODEL = "m3e-base" +API_BASE_EM = "http://106.0.4.200:9997/v1" + +# google gemini +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="google/gemini-2.0-flash-exp:free" + +# deepseek v3 +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="deepseek/deepseek-chat-v3-0324:free" + +TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表 + +HEADERS = { + "Content-Type": "application/json", + "Authorization": f"Bearer {API_KEY}" + } + +def get_table_names(conn): + sql = """ + SELECT tablename + FROM pg_tables + WHERE schemaname = 'public'; + """ + cur = conn.cursor() + cur.execute(sql) + data = cur.fetchall() + cur.close() + return [row[0] for row in data] + +# def get_relation_table(query): +# conn = connect_db() +# # table_names = TABLES +# table_names = get_table_names(conn) +# schemas = get_table_structures(conn, table_names) + +# texts = [ +# f"这是一个数据库表结构,表名为 {s['table']},其结构如下:{s['text']}" +# for 
s in schemas +# ] +# table_names = [s["table"] for s in schemas] +# embeddings = embed_text(texts) +# index, index_table_map = create_index(embeddings, texts, table_names) + +# results = search_similar_tables(query, index, index_table_map, top_k=3) + +# if not results: +# return "没有找到相关表结构" +# return results + +def get_relation_table(query: str): + conn = connect_db() + table_names = get_table_names(conn) # 只获取用户表 + schemas = get_table_structures(conn, table_names) + texts = [s["text"] for s in schemas] + table_names = [s["table"] for s in schemas] + embeddings = embed_text(texts) + + # 存储向量 + store_embeddings_pg(conn, embeddings, texts, table_names) + + # 查询相似表 + results = search_similar_tables_pg(conn, query, top_k=5) + + if len(results) == 0: + return "没有找到相关表结构" + # 只取相关表的结构 + schemas = get_table_structures(conn, results) + + llm_results = format_schema_for_llm(schemas) + return llm_results + +def store_embeddings_pg(conn, embeddings: list[list[float]], texts: list[str], table_names: list[str]): + cur = conn.cursor() + for embedding, text, table_name in zip(embeddings, texts, table_names): + cur.execute(""" + INSERT INTO table_embeddings (table_name, schema_text, embedding) + VALUES (%s, %s, %s) + ON CONFLICT (table_name) DO UPDATE + SET schema_text = EXCLUDED.schema_text, + embedding = EXCLUDED.embedding + """, (table_name, text, embedding)) + conn.commit() + cur.close() + +def search_similar_tables_pg(conn, query: str, top_k: int = 5): + # 第一步:将 query 转为 embedding + query_embedding = embed_text([query])[0] + # 第二步:embedding 转成 '[x, y, z]' 格式字符串 + embedding_str = ",".join(map(str, query_embedding)) + cur = conn.cursor() + query = f""" + SELECT table_name + FROM table_embeddings + ORDER BY embedding <-> '[{embedding_str}]'::vector + LIMIT {top_k}; + """ + cur.execute(query) + results = [row[0] for row in cur.fetchall()] + cur.close() + return results + + +def format_schema_for_llm(schemas: list[dict]) -> str: + lines = [] + for schema in schemas: + 
lines.append(f"【表名】:{schema['table']}") + lines.append("【字段】:") + for col in schema["text"].split("结构如下:")[1].split("\n"): + if col.strip(): + lines.append(f" - {col.strip()}") + lines.append("") # 空行分隔表 + return "\n".join(lines) + + +def embed_text(texts: list[str]) -> list[list[float]]: + paylaod = { + "input":texts, + "model":EM_MODEL + } + url = f"{API_BASE_EM}/embeddings" + response = requests.post(url, headers=HEADERS, json=paylaod) + json_data = response.json() + return [e['embedding'] for e in json_data['data']] + +# def search_similar_tables(query: str, index, index_table_map, top_k:int=3): +# query_embedding = embed_text([query])[0] +# distances, indices = index.search(np.array([query_embedding]).astype("float32"), int(top_k)) +# results = [] +# for i in indices[0]: +# if i != -1 and i in index_table_map: +# results.append(index_table_map[i]) +# return results + +# def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]): +# print(len(embeddings), '-----------') +# dim = len(embeddings[0]) +# index = faiss.IndexFlatL2(dim) +# embeddings_np = np.array(embeddings).astype('float32') +# index.add(embeddings_np) + +# # 构建索引到表名的映射字典 +# index_table_map = {i: table_names[i] for i in range(len(table_names))} +# return index, index_table_map + + + +class QueryLLMviewSet(CustomModelViewSet): + queryset = Message.objects.all() + serializer_class = MessageSerializer + ordering = ['create_time'] + perms_map = {'get':'*', 'post':'*', 'put':'*'} + + @action(methods=['post'], detail=False, perms_map={'post':'*'} ,serializer_class=MessageSerializer) + def completion(self, request): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + prompt = serializer.validated_data['content'] + conversation = serializer.validated_data['conversation'] + if not prompt or not conversation: + return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400) + 
save_message_thread_safe(content=prompt, conversation=conversation, role="user") + url = f"{API_BASE}/chat/completions" + user_prompt = f""" +我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。 + +注意: +只需回答"database"或"general"即可,不要有其他内容。 +""" + _payload = { + "model": MODEL, + "messages": [{"role": "user", "content": user_prompt}], + "temperature": 0, + "max_tokens": 10 + } + try: + class_response = requests.post(url, headers=HEADERS, json=_payload) + class_response.raise_for_status() + class_result = class_response.json() + question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower() + print("question_type", question_type) + if question_type == "database": + schema_text = get_relation_table(prompt) + user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构: +{schema_text} +请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。 +需求是:{prompt} +""" + else: + user_prompt = f""" +回答以下问题,不需要涉及数据库查询: + +问题: {prompt} + +请直接回答问题,不要提及数据库或SQL。 +""" + # TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages + # history = Message.objects.filter(conversation=conversation).order_by('create_time') + # chat_history = [{"role": msg.role, "content": msg.content} for msg in history] + # chat_history.append({"role": "user", "content": prompt}) + chat_history = [{"role":"user", "content":user_prompt}] + print("user_prompt", user_prompt) + payload = { + "model": MODEL, + "messages": chat_history, + "temperature": 0, + "stream": True + } + response = requests.post(url, headers=HEADERS, json=payload) + response.raise_for_status() + except requests.exceptions.RequestException as e: + return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500) + def stream_generator(): + accumulated_content = "" + for line in response.iter_lines(): + if line: + decoded_line = line.decode('utf-8') + if decoded_line.startswith('data:'): + if decoded_line.strip() == "data: [DONE]": + break # OpenAI-style标志结束 + try: + data = json.loads(decoded_line[6:]) 
+ content = data.get('choices', [{}])[0].get('delta', {}).get('content', '') + if content: + accumulated_content += content + yield f"data: {content}\n\n" + + except Exception as e: + yield f"data: [解析失败]: {str(e)}\n\n" + print("accumulated_content", accumulated_content) + save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system") + + if question_type == "database": + sql = extract_sql_code(accumulated_content) + if sql: + try: + conn = connect_db() + if is_safe_sql(sql): + result = execute_sql(conn, sql) + save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system") + yield f"data: SQL执行结果: {result}\n\n" + else: + yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n" + except Exception as e: + yield f"data: SQL执行失败: {str(e)}\n\n" + finally: + if conn: + conn.close() + else: + yield "data: \\n[文本结束]\n\n" + return StreamingHttpResponse(stream_generator(), content_type='text/event-stream') + +# 先新建对话 生成对话session_id +class ConversationViewSet(CustomModelViewSet): + queryset = Conversation.objects.all() + serializer_class = ConversationSerializer + ordering = ['create_time'] + perms_map = {'get':'*', 'post':'*', 'put':'*'} + + diff --git a/apps/ichat/views.py b/apps/ichat/views.py new file mode 100644 index 00000000..482f519c --- /dev/null +++ b/apps/ichat/views.py @@ -0,0 +1,214 @@ +import requests +import json +import faiss +import numpy as np +from rest_framework.views import APIView +from apps.ichat.serializers import MessageSerializer, ConversationSerializer +from rest_framework.response import Response +from apps.ichat.models import Conversation, Message +from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe +from django.http import StreamingHttpResponse, JsonResponse +from rest_framework.decorators import action +from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet + +# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693" 
+# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1" +# MODEL = "qwen-plus" + +#本地部署的模式 +API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw" +API_BASE = "http://106.0.4.200:9000/v1" +MODEL = "qwen14b" + +# 文本向量化模型 +EM_MODEL = "m3e-base" + +# google gemini +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="google/gemini-2.0-flash-exp:free" + +# deepseek v3 +# API_KEY = "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" +# API_BASE = "https://openrouter.ai/api/v1" +# MODEL="deepseek/deepseek-chat-v3-0324:free" + +TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表 + +HEADERS = { + "Content-Type": "application/json", + "Authorization": f"Bearer {API_KEY}" + } + +# 表结构向量化 +def embed_text(texts: list[str]) -> list[list[float]]: + url = f"{API_BASE}/embeddings" + _payload = { + "model": EM_MODEL, + "input": texts + } + try: + response = requests.post(url, headers=HEADERS, json=_payload) + except requests.exceptions.RequestException as e: + return JsonResponse({"error":f"Embedding API调用失败: {e}"}, status=500) + print("embeddings", response["data"]) + return [e['embedding'] for e in response['data']] + + +# 创建Faiss索引 +def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]): + index = faiss.IndexFlatL2(len(embeddings[0])) + index.add(np.array(embeddings)).astype("float32") + index_table_map = {i: {"table": table_names[i], "text": texts[i]} for i in range(len(table_names))} + return index, index_table_map + +# 查询 +def search_similar_tables(query:str, index, index_table_map, k:int=5): + query_embedding = embed_text([query])[0] + distances, indices = index.search(np.array([query_embedding]).astype("float32"), k) + return [index_table_map[i] for i in indices[0]] + +def get_tables(conn) -> list[str]: + with conn.cursor() as cur: + cur.execute(""" + SELECT tablename + FROM 
pg_tables + WHERE schemaname = 'public' + AND tableowner = 'postgres'; + """) + return [row[0] for row in cur.fetchall()] + + +# 主函数:提取表结构、嵌入向量并存储到 FAISS +def get_relation_table(query): + conn = connect_db() + table_names = get_tables(conn) + schemas = get_schema_text(conn, table_names) + texts = [s["text"] for s in schemas] + # table_names = [s["table"] for s in schemas] + embeddings = embed_text(texts) + index, index_table_map = create_index(embeddings, texts, table_names) + results = search_similar_tables(query, index, index_table_map) + + for result in results: + print(f"表名: {result['table']}\n结构: {result['text']}") + if len(results) == 0: + return "没有找到相关表结构" + return results + + +class QueryLLMviewSet(CustomModelViewSet): + queryset = Message.objects.all() + serializer_class = MessageSerializer + ordering = ['create_time'] + perms_map = {'get':'*', 'post':'*', 'put':'*'} + + @action(methods=['post'], detail=False, perms_map={'post':'*'} ,serializer_class=MessageSerializer) + def completion(self, request): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + prompt = serializer.validated_data['content'] + conversation = serializer.validated_data['conversation'] + if not prompt or not conversation: + return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400) + save_message_thread_safe(content=prompt, conversation=conversation, role="user") + url = f"{API_BASE}/chat/completions" + user_prompt = f""" +我提问的问题是:{prompt}请判断我的问题是否与数据库查询或操作相关。如果是,回答"database";如果不是,回答"general"。 + +注意: +只需回答"database"或"general"即可,不要有其他内容。 +""" + _payload = { + "model": MODEL, + "messages": [{"role": "user", "content": user_prompt}], + "temperature": 0, + "max_tokens": 10 + } + try: + class_response = requests.post(url, headers=HEADERS, json=_payload) + class_response.raise_for_status() + class_result = class_response.json() + question_type = class_result.get('choices', [{}])[0].get('message', 
{}).get('content', '').strip().lower() + print("question_type", question_type) + if question_type == "database": + schema_text = get_relation_table(prompt) + print("schema_text----------------------", schema_text) + user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构: +{schema_text} +请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。 +需求是:{prompt} +""" + else: + user_prompt = f""" +回答以下问题,不需要涉及数据库查询: + +问题: {prompt} + +请直接回答问题,不要提及数据库或SQL。 +""" + # TODO 是否应该拿到conservastion的id,然后根据id去数据库查询所以的messages, 然后赋值给messages + # history = Message.objects.filter(conversation=conversation).order_by('create_time') + # chat_history = [{"role": msg.role, "content": msg.content} for msg in history] + # chat_history.append({"role": "user", "content": prompt}) + chat_history = [{"role":"user", "content":user_prompt}] + print("chat_history", chat_history) + payload = { + "model": MODEL, + "messages": chat_history, + "temperature": 0, + "stream": True + } + response = requests.post(url, headers=HEADERS, json=payload) + response.raise_for_status() + except requests.exceptions.RequestException as e: + return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500) + def stream_generator(): + accumulated_content = "" + for line in response.iter_lines(): + if line: + decoded_line = line.decode('utf-8') + if decoded_line.startswith('data:'): + if decoded_line.strip() == "data: [DONE]": + break # OpenAI-style标志结束 + try: + data = json.loads(decoded_line[6:]) + content = data.get('choices', [{}])[0].get('delta', {}).get('content', '') + if content: + accumulated_content += content + yield f"data: {content}\n\n" + + except Exception as e: + yield f"data: [解析失败]: {str(e)}\n\n" + print("accumulated_content", accumulated_content) + save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system") + + if question_type == "database": + sql = extract_sql_code(accumulated_content) + if sql: + try: + conn = connect_db() + if is_safe_sql(sql): + result = execute_sql(conn, sql) + 
save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system") + yield f"data: SQL执行结果: {result}\n\n" + else: + yield f"data: 拒绝执行非查询类 SQL:{sql}\n\n" + except Exception as e: + yield f"data: SQL执行失败: {str(e)}\n\n" + finally: + if conn: + conn.close() + else: + yield "data: \\n[文本结束]\n\n" + return StreamingHttpResponse(stream_generator(), content_type='text/event-stream') + +# 先新建对话 生成对话session_id +class ConversationViewSet(CustomModelViewSet): + queryset = Conversation.objects.all() + serializer_class = ConversationSerializer + ordering = ['create_time'] + perms_map = {'get':'*', 'post':'*', 'put':'*'} + + diff --git a/apps/ichat/views2.py b/apps/ichat/views2.py new file mode 100644 index 00000000..9b2766bb --- /dev/null +++ b/apps/ichat/views2.py @@ -0,0 +1,129 @@ +import requests +import os +from apps.utils.sql import execute_raw_sql +import json +from apps.utils.tools import MyJSONEncoder +from .utils import is_safe_sql +from rest_framework.views import APIView +from drf_yasg.utils import swagger_auto_schema +from rest_framework import serializers +from rest_framework.exceptions import ParseError +from rest_framework.response import Response +from django.conf import settings +from apps.utils.mixins import MyLoggingMixin +from django.core.cache import cache +import uuid +from apps.utils.thread import MyThread + +LLM_URL = getattr(settings, "LLM_URL", "") +API_KEY = getattr(settings, "LLM_API_KEY", "") +MODEL = "qwen14b" +HEADERS = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" +} +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +def load_promot(name): + with open(os.path.join(CUR_DIR, f'promot/{name}.md'), 'r') as f: + return f.read() + + +def ask(input:str, p_name:str, stream=False): + his = [{"role":"system", "content": load_promot(p_name)}] + his.append({"role":"user", "content": input}) + payload = { + "model": MODEL, + "messages": his, + "temperature": 0, + "stream": stream + } + response 
= requests.post(LLM_URL, headers=HEADERS, json=payload, stream=stream) + if not stream: + return response.json()["choices"][0]["message"]["content"] + else: + # 处理流式响应 + full_content = "" + for chunk in response.iter_lines(): + if chunk: + # 通常流式响应是SSE格式(data: {...}) + decoded_chunk = chunk.decode('utf-8') + if decoded_chunk.startswith("data:"): + json_str = decoded_chunk[5:].strip() + if json_str == "[DONE]": + break + try: + chunk_data = json.loads(json_str) + if "choices" in chunk_data and chunk_data["choices"]: + delta = chunk_data["choices"][0].get("delta", {}) + if "content" in delta: + print(delta["content"]) + full_content += delta["content"] + except json.JSONDecodeError: + continue + return full_content + +def work_chain(input:str, t_key:str): + pdict = {"state": "progress", "steps": [{"state":"ok", "msg":"正在生成查询语句"}]} + cache.set(t_key, pdict) + res_text = ask(input, 'w_sql') + if res_text == '请以 查询 开头,重新描述你的需求': + pdict["state"] = "error" + pdict["steps"].append({"state":"error", "msg":res_text}) + cache.set(t_key, pdict) + return + else: + pdict["steps"].append({"state":"ok", "msg":"查询语句生成成功", "content":res_text}) + cache.set(t_key, pdict) + if not is_safe_sql(res_text): + pdict["state"] = "error" + pdict["steps"].append({"state":"error", "msg":"当前查询存在风险,请重新描述你的需求"}) + cache.set(t_key, pdict) + return + pdict["steps"].append({"state":"ok", "msg":"正在执行查询语句"}) + cache.set(t_key, pdict) + res = execute_raw_sql(res_text) + pdict["steps"].append({"state":"ok", "msg":"查询语句执行成功", "content":res}) + cache.set(t_key, pdict) + pdict["steps"].append({"state":"ok", "msg":"正在生成报告"}) + cache.set(t_key, pdict) + res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana') + content = res2.lstrip('```html ').rstrip('```') + pdict["state"] = "done" + pdict["content"] = content + pdict["steps"].append({"state":"ok", "msg":"报告生成成功", "content": content}) + cache.set(t_key, pdict) + return + +class InputSerializer(serializers.Serializer): + input = 
serializers.CharField(label="查询需求") + +class WorkChain(MyLoggingMixin, APIView): + + @swagger_auto_schema( + operation_summary="提交查询需求", + request_body=InputSerializer) + def post(self, request): + llm_enabled = getattr(settings, "LLM_ENABLED", False) + if not llm_enabled: + raise ParseError('LLM功能未启用') + input = request.data.get('input') + t_key = f'ichat_{uuid.uuid4()}' + MyThread(target=work_chain, args=(input, t_key)).start() + return Response({'ichat_tid': t_key}) + + @swagger_auto_schema( + operation_summary="获取查询进度") + def get(self, request): + llm_enabled = getattr(settings, "LLM_ENABLED", False) + if not llm_enabled: + raise ParseError('LLM功能未启用') + ichat_tid = request.GET.get('ichat_tid') + if ichat_tid: + return Response(cache.get(ichat_tid)) + +if __name__ == "__main__": + print(work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告")) + + from apps.ichat.views2 import work_chain + print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告')) \ No newline at end of file diff --git a/media/default/abc.mp3 b/media/default/abc.mp3 new file mode 100644 index 00000000..e4283950 Binary files /dev/null and b/media/default/abc.mp3 differ diff --git a/media/default/abc2.mp3 b/media/default/abc2.mp3 new file mode 100644 index 00000000..a1352813 Binary files /dev/null and b/media/default/abc2.mp3 differ diff --git a/media/default/alarm.mp3 b/media/default/alarm.mp3 new file mode 100644 index 00000000..2727033d Binary files /dev/null and b/media/default/alarm.mp3 differ diff --git a/media/default/avatar.png b/media/default/avatar.png new file mode 100644 index 00000000..07e80789 Binary files /dev/null and b/media/default/avatar.png differ diff --git a/media/default/template/material.xlsx b/media/default/template/material.xlsx new file mode 100644 index 00000000..291a3067 Binary files /dev/null and b/media/default/template/material.xlsx differ diff --git a/media/default/template/mioitemw_t.xlsx b/media/default/template/mioitemw_t.xlsx new file mode 100644 index 00000000..ad58070f Binary 
files /dev/null and b/media/default/template/mioitemw_t.xlsx differ diff --git a/media/default/template/出入库明细.xlsx b/media/default/template/出入库明细.xlsx new file mode 100644 index 00000000..f7e526e6 Binary files /dev/null and b/media/default/template/出入库明细.xlsx differ diff --git a/server/settings.py b/server/settings.py index 0797700d..fc46cb06 100755 --- a/server/settings.py +++ b/server/settings.py @@ -63,7 +63,7 @@ INSTALLED_APPS = [ 'apps.wf', 'apps.ecm', 'apps.hrm', - 'apps.ichat', + #'apps.ichat', 'apps.am', 'apps.vm', 'apps.rpm',