feat: restore ichat functionality and files under default

This commit is contained in:
TianyangZhang 2026-03-13 16:59:12 +08:00
parent d657c9fd26
commit e3dcb492d7
26 changed files with 1180 additions and 1 deletion

0
apps/ichat/__init__.py Normal file

3
apps/ichat/admin.py Normal file

@@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

6
apps/ichat/apps.py Normal file

@@ -0,0 +1,6 @@
from django.apps import AppConfig
class ChatConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'apps.ichat'

48
apps/ichat/migrations/0001_initial.py Normal file

@@ -0,0 +1,48 @@
# Generated by Django 3.2.12 on 2025-05-21 05:59
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='Conversation',
fields=[
('id', models.CharField(editable=False, help_text='主键ID', max_length=20, primary_key=True, serialize=False, verbose_name='主键ID')),
('create_time', models.DateTimeField(default=django.utils.timezone.now, help_text='创建时间', verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, help_text='修改时间', verbose_name='修改时间')),
('is_deleted', models.BooleanField(default=False, help_text='删除标记', verbose_name='删除标记')),
('title', models.CharField(default='新对话', max_length=200, verbose_name='对话标题')),
('create_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='conversation_create_by', to=settings.AUTH_USER_MODEL, verbose_name='创建人')),
('update_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='conversation_update_by', to=settings.AUTH_USER_MODEL, verbose_name='最后编辑人')),
],
options={
'abstract': False,
},
),
migrations.CreateModel(
name='Message',
fields=[
('id', models.CharField(editable=False, help_text='主键ID', max_length=20, primary_key=True, serialize=False, verbose_name='主键ID')),
('create_time', models.DateTimeField(default=django.utils.timezone.now, help_text='创建时间', verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, help_text='修改时间', verbose_name='修改时间')),
('is_deleted', models.BooleanField(default=False, help_text='删除标记', verbose_name='删除标记')),
('content', models.TextField(verbose_name='消息内容')),
('role', models.CharField(default='user', help_text='system/user', max_length=10, verbose_name='角色')),
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='ichat.conversation', verbose_name='对话')),
],
options={
'abstract': False,
},
),
]


17
apps/ichat/models.py Normal file

@@ -0,0 +1,17 @@
from django.db import models
from apps.system.models import CommonADModel, BaseModel
# Create your models here.
class Conversation(CommonADModel):
"""
TN: 对话
"""
title = models.CharField(max_length=200, default='新对话',verbose_name='对话标题')
class Message(BaseModel):
"""
TN: 消息
"""
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages', verbose_name='对话')
content = models.TextField(verbose_name='消息内容')
role = models.CharField("角色", max_length=10, default='user', help_text="system/user")
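
For orientation, a minimal usage sketch of the two models (hypothetical, not part of this commit; it assumes the project's BaseModel/CommonADModel populate the char primary key on save):

    # Hypothetical Django shell session (assumes apps.ichat is installed and migrated).
    from apps.ichat.models import Conversation, Message

    conv = Conversation.objects.create(title='新对话')  # id assumed to be filled by the base model
    Message.objects.create(conversation=conv, role='user', content='查询 黑化 工段在2025年6月的生产合格数')
    Message.objects.create(conversation=conv, role='system', content='select ...')
    # related_name='messages' exposes the ordered history:
    history = conv.messages.order_by('create_time')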

14
apps/ichat/promot/w_ana.md Normal file

@@ -0,0 +1,14 @@
# 角色
你是一位数据分析专家和前端程序员,具备深厚的专业知识和丰富的实践经验。你能够精准理解用户的文本描述, 并形成报告。
# 技能
1. 仔细分析用户提供的JSON格式数据,分析用户需求。
2. 依据得到的需求, 分别获取JSON数据中的关键信息。
3. 根据2中的关键信息,最优化选择表格/饼图/柱状图/折线图等格式绘制报告。
# 回答要求
1. 仅生成完整的HTML代码,所有功能都需要实现,支持响应式,不要输出任何解释或说明。
2. 代码中如需要Echarts等js库,请直接使用中国大陆的CDN链接,例如bootcdn的链接。
3. 标题为 数据分析报告。
4. 在开始部分,请以表格形式简略展示获取的JSON数据。
5. 之后选择最合适的图表方式生成相应的图。
6. 在最后提供可下载该报告的完整PDF的按钮和功能。
7. 在最后提供可下载含有JSON数据的EXCEL文件的按钮和功能。

53
apps/ichat/promot/w_sql.md Normal file

@@ -0,0 +1,53 @@
# 角色
你是一位资深的Postgresql数据库SQL专家,具备深厚的专业知识和丰富的实践经验。你能够精准理解用户的文本描述,并生成准确可执行的SQL语句。
# 技能
1. 仔细分析用户提供的文本描述,明确用户需求。
2. 根据对用户需求的理解,生成符合Postgresql数据库语法的准确可执行的SQL语句。
# 回答要求
1. 如果用户的询问未以 查询 开头,请直接回复 "请以 查询 开头,重新描述你的需求"。
2. 生成的SQL语句必须符合Postgresql数据库的语法规范。
3. 不要使用 Markdown 和 SQL 语法格式输出,禁止添加语法标准、备注、说明等信息。
4. 直接输出符合Postgresql标准的SQL语句,用txt纯文本格式展示即可。
5. 如果无法生成符合要求的SQL语句,请直接回复 "无法生成"。
# 示例
1. 问:查询 外协白片抛 工段在2025年6月1日到2025年6月15日之间的生产合格数以及合格率等
答: select
sum(mlog.count_use) as 领用数,
sum(mlog.count_real) as 生产数,
sum(mlog.count_ok) as 合格数,
sum(mlog.count_notok) as 不合格数,
CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率
from wpm_mlog mlog
left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id
where mlog.submit_time is not null
and mgroup.name = '外协白片抛'
and mlog.handle_date >= '2025-06-01'
and mlog.handle_date <= '2025-06-15'
2. 问:查询 黑化 工段在2025年6月的生产合格数以及合格率等
答: select
sum(mlog.count_use) as 领用数,
sum(mlog.count_real) as 生产数,
sum(mlog.count_ok) as 合格数,
sum(mlog.count_notok) as 不合格数,
CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率
from wpm_mlog mlog
left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id
where mlog.submit_time is not null
and mgroup.name = '黑化'
and mlog.handle_date >= '2025-06-01'
and mlog.handle_date <= '2025-06-30'
3. 问:查询 各工段 在2025年6月的生产合格数以及合格率等
答: select
mgroup.name as 工段,
sum(mlog.count_use) as 领用数,
sum(mlog.count_real) as 生产数,
sum(mlog.count_ok) as 合格数,
sum(mlog.count_notok) as 不合格数,
CAST ( SUM ( mlog.count_ok ) AS FLOAT ) / NULLIF ( SUM ( mlog.count_real ), 0 ) * 100 AS 合格率
from wpm_mlog mlog
left join mtm_mgroup mgroup on mgroup.id = mlog.mgroup_id
where mlog.submit_time is not null
and mlog.handle_date >= '2025-06-01'
and mlog.handle_date <= '2025-06-30'
group by mgroup.id
order by mgroup.sort
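
The prompt above defines a strict reply contract: either bare SQL text, or one of two literal refusal phrases (views2.py compares the first one verbatim). A hedged sketch of caller-side validation under that contract (hypothetical helper, not part of this commit):

    # Hypothetical validation of a reply produced under the w_sql.md contract.
    from typing import Optional

    REFUSALS = {"请以 查询 开头,重新描述你的需求", "无法生成"}

    def check_w_sql_reply(reply: str) -> Optional[str]:
        text = reply.strip()
        if text in REFUSALS:
            return None  # the model declined, per the prompt rules
        if not text.lower().startswith("select"):
            return None  # contract violated: expected a bare SELECT statement
        return text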

22
apps/ichat/script.py Normal file

@@ -0,0 +1,22 @@
import json
from .models import Message
from django.http import StreamingHttpResponse
def stream_generator(stream_response, conversation_id: str):
full_content = ''
for chunk in stream_response.iter_content(chunk_size=1024):
if chunk:
full_content += chunk.decode('utf-8')
try:
data = json.loads(full_content)
content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
Message.objects.create(
conversation_id=conversation_id,
content=content
)
yield f" data:{content}\n\n"
full_content = ''
except json.JSONDecodeError:
continue
# Wrap the generator in a server-sent-events response
def stream_view(stream_response, conversation_id: str):
return StreamingHttpResponse(stream_generator(stream_response, conversation_id), content_type='text/event-stream')

18
apps/ichat/serializers.py Normal file

@@ -0,0 +1,18 @@
from rest_framework import serializers
from .models import Conversation, Message
from apps.utils.constants import EXCLUDE_FIELDS
class MessageSerializer(serializers.ModelSerializer):
class Meta:
model = Message
fields = ['id', 'conversation', 'content', 'role']
read_only_fields = EXCLUDE_FIELDS
class ConversationSerializer(serializers.ModelSerializer):
messages = MessageSerializer(many=True, read_only=True)
class Meta:
model = Conversation
fields = ['id', 'title', 'messages']
read_only_fields = EXCLUDE_FIELDS

3
apps/ichat/tests.py Normal file

@@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

16
apps/ichat/urls.py Normal file

@@ -0,0 +1,16 @@
from django.urls import path, include
from rest_framework.routers import DefaultRouter
from apps.ichat.views import QueryLLMviewSet, ConversationViewSet
from apps.ichat.views2 import WorkChain
API_BASE_URL = 'api/ichat/'
router = DefaultRouter()
router.register('conversation', ConversationViewSet, basename='conversation')
router.register('message', QueryLLMviewSet, basename='message')
urlpatterns = [
path(API_BASE_URL, include(router.urls)),
path(API_BASE_URL + 'workchain/ask/', WorkChain.as_view(), name='workchain')
]

195
apps/ichat/utils.py Normal file

@@ -0,0 +1,195 @@
import re
import psycopg2
import threading
from django.db import transaction
from .models import Message
# Database connection
def connect_db():
from server.conf import DATABASES
db_conf = DATABASES['default']
conn = psycopg2.connect(
host=db_conf['HOST'],
port=db_conf['PORT'],
user=db_conf['USER'],
password=db_conf['PASSWORD'],
database=db_conf['NAME']
)
return conn
def extract_sql_code(text):
# Try a ```sql fenced statement first
match = re.search(r"```sql\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)
if match:
return match.group(1).strip()
# Fallback: find the first SELECT statement
match = re.search(r"(SELECT\s.+?;)", text, re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip()
return None
# def get_schema_text(conn, table_names:list):
# cur = conn.cursor()
# query = """
# SELECT
# table_name, column_name, data_type
# FROM
# information_schema.columns
# WHERE
# table_schema = 'public'
# and table_name in %s;
# """
# cur.execute(query, (tuple(table_names), ))
# schema = {}
# for table_name, column_name, data_type in cur.fetchall():
# if table_name not in schema:
# schema[table_name] = []
# schema[table_name].append(f"{column_name} ({data_type})")
# cur.close()
# schema_text = ""
# for table_name, columns in schema.items():
# schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n"
# return schema_text
def get_schema_text(conn, table_names: list):
cur = conn.cursor()
query = """
SELECT
c.relname AS table_name,
a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
d.description AS column_comment
FROM
pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
JOIN pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
WHERE
n.nspname = 'public'
AND c.relname = ANY(%s)
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY
c.relname, a.attnum;
"""
cur.execute(query, (table_names,))
schema = {}
for table_name, column_name, data_type, comment in cur.fetchall():
if comment and "备注" in comment:
comment = comment.split("备注")[0].strip()
schema.setdefault(table_name, []).append(
f"{column_name}-{comment or ''}"
)
cur.close()
return [
{"table": table, "text": f"{table} 包含列:\n" + "\n".join(columns)}
for table, columns in schema.items()
]
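
For reference, the list this returns looks roughly like the following (illustrative only; table names come from TABLES elsewhere in the app, but the column names and comments here are invented):

    # Illustrative shape of get_schema_text(...) output:
    [
        {"table": "enm_mpoint", "text": "enm_mpoint 包含列:\nid-主键ID\nname-点位名称"},
        {"table": "enm_mplogx", "text": "enm_mplogx 包含列:\nid-主键ID\nvalue-采集值"},
    ]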
# def get_schema_text(conn, table_names: list):
# cur = conn.cursor()
# # 获取字段、类型、注释
# column_query = """
# SELECT
# c.relname AS table_name,
# a.attname AS column_name,
# pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
# d.description AS column_comment
# FROM
# pg_class c
# JOIN pg_namespace n ON n.oid = c.relnamespace
# JOIN pg_attribute a ON a.attrelid = c.oid
# LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
# WHERE
# n.nspname = 'public'
# AND c.relname = ANY(%s)
# AND a.attnum > 0
# AND NOT a.attisdropped
# ORDER BY
# c.relname, a.attnum;
# """
# # 获取外键信息
# fk_query = """
# SELECT
# conrelid::regclass::text AS table_name,
# a.attname AS column_name,
# confrelid::regclass::text AS foreign_table,
# af.attname AS foreign_column
# FROM
# pg_constraint
# JOIN pg_class ON conrelid = pg_class.oid
# JOIN pg_namespace n ON pg_class.relnamespace = n.oid
# JOIN pg_attribute a ON a.attrelid = conrelid AND a.attnum = ANY(conkey)
# JOIN pg_attribute af ON af.attrelid = confrelid AND af.attnum = ANY(confkey)
# WHERE
# contype = 'f'
# AND n.nspname = 'public'
# AND conrelid::regclass::text = ANY(%s);
# """
# cur.execute(column_query, (table_names,))
# columns = cur.fetchall()
# cur.execute(fk_query, (table_names,))
# fks = cur.fetchall()
# # 构建外键字典
# fk_map = {} # {(table, column): "foreign_table(foreign_column)"}
# for table, column, f_table, f_column in fks:
# fk_map[(table, column)] = f"{f_table}({f_column})"
# # 组织输出结构
# schema = {}
# for table, column, dtype, comment in columns:
# fk_note = f" -> {fk_map[(table, column)]}" if (table, column) in fk_map else ""
# comment_note = f" -- {comment}" if comment else ""
# schema.setdefault(table, []).append(f"{column} ({dtype}{fk_note}{comment_note})")
# cur.close()
# # 生成文本
# schema_text = ""
# for table, cols in schema.items():
# schema_text += f"表 {table} 包含列:\n - " + "\n - ".join(cols) + "\n"
# return schema_text
def is_safe_sql(sql: str) -> bool:
sql = sql.strip().lower()
# Whole-word match so column names like update_time do not trip the filter
return (sql.startswith("select") or sql.startswith("show")) and not re.search(r"\b(delete|update|insert|drop|create|alter)\b", sql)
def execute_sql(conn, sql_query):
cur = conn.cursor()
cur.execute(sql_query)
try:
rows = cur.fetchall()
columns = [desc[0] for desc in cur.description]
result = [dict(zip(columns, row)) for row in rows]
except psycopg2.ProgrammingError:
result = cur.statusmessage
cur.close()
return result
def strip_sql_markdown(content: str) -> str:
# Extract the content wrapped in ```sql or ``` fences
match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
if match:
return match.group(1).strip()
else:
return None
# Thread-safe wrapper for ORM writes
def save_message_thread_safe(**kwargs):
def _save():
with transaction.atomic():
Message.objects.create(**kwargs)
threading.Thread(target=_save).start()
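
A few spot checks for the guard above (hypothetical snippet, not part of this commit; the parentheses matter because `and` binds tighter than `or`, and the `\b` boundaries keep column names like update_time legal):

    # Hypothetical spot checks for is_safe_sql.
    from apps.ichat.utils import is_safe_sql

    assert is_safe_sql("SELECT sum(count_ok) FROM wpm_mlog")
    assert is_safe_sql("select mlog.update_time from wpm_mlog mlog")  # update_time is not "update"
    assert not is_safe_sql("DROP TABLE wpm_mlog")
    assert not is_safe_sql("select 1; delete from wpm_mlog")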

0
apps/ichat/view Normal file

155
apps/ichat/view_bak.py Normal file

@@ -0,0 +1,155 @@
import requests
import json
from rest_framework.views import APIView
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from rest_framework.response import Response
from apps.ichat.models import Conversation, Message
from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe
from django.http import StreamingHttpResponse, JsonResponse
from rest_framework.decorators import action
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# MODEL = "qwen-plus"
# Locally deployed model
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
API_BASE = "http://106.0.4.200:9000/v1"
MODEL = "qwen14b"
# google gemini
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="google/gemini-2.0-flash-exp:free"
# deepseek v3
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="deepseek/deepseek-chat-v3-0324:free"
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]  # Feeding the whole schema to the model hurts accuracy, so only these tables are exposed
class QueryLLMviewSet(CustomModelViewSet):
queryset = Message.objects.all()
serializer_class = MessageSerializer
ordering = ['create_time']
perms_map = {'get':'*', 'post':'*', 'put':'*'}
@action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
def completion(self, request):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
prompt = serializer.validated_data['content']
conversation = serializer.validated_data['conversation']
if not prompt or not conversation:
return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
save_message_thread_safe(content=prompt, conversation=conversation, role="user")
url = f"{API_BASE}/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}
user_prompt = f"""
我提问的问题是:{prompt}。请判断我的问题是否与数据库查询或操作相关:如果是,回答"database";如果不是,回答"general"。
注意:
只需回答"database"或"general"即可,不要有其他内容。
"""
_payload = {
"model": MODEL,
"messages": [{"role": "user", "content": user_prompt}],
"temperature": 0,
"max_tokens": 10
}
try:
class_response = requests.post(url, headers=headers, json=_payload)
class_response.raise_for_status()
class_result = class_response.json()
question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
print("question_type", question_type)
if question_type == "database":
conn = connect_db()
schema_text = get_schema_text(conn, TABLES)
print("schema_text----------------------", schema_text)
user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
{schema_text}
请根据我的需求生成一条标准的PostgreSQL SQL语句直接返回SQL不要额外解释
需求是{prompt}
"""
else:
user_prompt = f"""
回答以下问题不需要涉及数据库查询
问题: {prompt}
请直接回答问题不要提及数据库或SQL
"""
# TODO: should we load all messages for this conversation id from the database and pass them as the chat history?
# history = Message.objects.filter(conversation=conversation).order_by('create_time')
# chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
# chat_history.append({"role": "user", "content": prompt})
chat_history = [{"role":"user", "content":user_prompt}]
print("chat_history", chat_history)
payload = {
"model": MODEL,
"messages": chat_history,
"temperature": 0,
"stream": True
}
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
except requests.exceptions.RequestException as e:
return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500)
def stream_generator():
accumulated_content = ""
for line in response.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data:'):
if decoded_line.strip() == "data: [DONE]":
break  # OpenAI-style end-of-stream marker
try:
data = json.loads(decoded_line[6:])
content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
if content:
accumulated_content += content
yield f"data: {content}\n\n"
except Exception as e:
yield f"data: [解析失败]: {str(e)}\n\n"
print("accumulated_content", accumulated_content)
save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
if question_type == "database":
sql = extract_sql_code(accumulated_content)
if sql:
conn = None  # ensure the finally block can always reference conn
try:
conn = connect_db()
if is_safe_sql(sql):
result = execute_sql(conn, sql)
save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
yield f"data: SQL执行结果: {result}\n\n"
else:
yield f"data: 拒绝执行非查询类 SQL{sql}\n\n"
except Exception as e:
yield f"data: SQL执行失败: {str(e)}\n\n"
finally:
if conn:
conn.close()
else:
yield "data: \\n[文本结束]\n\n"
return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
# Create a conversation first to obtain its session id
class ConversationViewSet(CustomModelViewSet):
queryset = Conversation.objects.all()
serializer_class = ConversationSerializer
ordering = ['create_time']
perms_map = {'get':'*', 'post':'*', 'put':'*'}

286
apps/ichat/view_bak2.py Normal file

@@ -0,0 +1,286 @@
import requests
import json
import faiss
import numpy as np
from rest_framework.views import APIView
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from rest_framework.response import Response
from apps.ichat.models import Conversation, Message
from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, is_safe_sql, save_message_thread_safe, get_table_structures
from django.http import StreamingHttpResponse, JsonResponse
from rest_framework.decorators import action
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# MODEL = "qwen-plus"
# Locally deployed model
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
API_BASE = "http://106.0.4.200:9000/v1"
MODEL = "qwen14b"
# Text embedding model
EM_MODEL = "m3e-base"
API_BASE_EM = "http://106.0.4.200:9997/v1"
# google gemini
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="google/gemini-2.0-flash-exp:free"
# deepseek v3
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="deepseek/deepseek-chat-v3-0324:free"
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]  # Feeding the whole schema to the model hurts accuracy, so only these tables are exposed
HEADERS = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}
def get_table_names(conn):
sql = """
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public';
"""
cur = conn.cursor()
cur.execute(sql)
data = cur.fetchall()
cur.close()
return [row[0] for row in data]
# def get_relation_table(query):
# conn = connect_db()
# # table_names = TABLES
# table_names = get_table_names(conn)
# schemas = get_table_structures(conn, table_names)
# texts = [
# f"这是一个数据库表结构,表名为 {s['table']},其结构如下:{s['text']}"
# for s in schemas
# ]
# table_names = [s["table"] for s in schemas]
# embeddings = embed_text(texts)
# index, index_table_map = create_index(embeddings, texts, table_names)
# results = search_similar_tables(query, index, index_table_map, top_k=3)
# if not results:
# return "没有找到相关表结构"
# return results
def get_relation_table(query: str):
conn = connect_db()
table_names = get_table_names(conn)  # user tables only
schemas = get_table_structures(conn, table_names)
texts = [s["text"] for s in schemas]
table_names = [s["table"] for s in schemas]
embeddings = embed_text(texts)
# Store the embeddings
store_embeddings_pg(conn, embeddings, texts, table_names)
# Retrieve similar tables
results = search_similar_tables_pg(conn, query, top_k=5)
if len(results) == 0:
return "没有找到相关表结构"
# Keep only the structures of the matched tables
schemas = get_table_structures(conn, results)
llm_results = format_schema_for_llm(schemas)
return llm_results
def store_embeddings_pg(conn, embeddings: list[list[float]], texts: list[str], table_names: list[str]):
cur = conn.cursor()
for embedding, text, table_name in zip(embeddings, texts, table_names):
cur.execute("""
INSERT INTO table_embeddings (table_name, schema_text, embedding)
VALUES (%s, %s, %s)
ON CONFLICT (table_name) DO UPDATE
SET schema_text = EXCLUDED.schema_text,
embedding = EXCLUDED.embedding
""", (table_name, text, embedding))
conn.commit()
cur.close()
def search_similar_tables_pg(conn, query: str, top_k: int = 5):
# Step 1: embed the query text
query_embedding = embed_text([query])[0]
# Step 2: serialize the embedding as a '[x, y, z]' pgvector literal
embedding_str = ",".join(map(str, query_embedding))
cur = conn.cursor()
sql = f"""
SELECT table_name
FROM table_embeddings
ORDER BY embedding <-> '[{embedding_str}]'::vector
LIMIT {top_k};
"""
cur.execute(sql)
results = [row[0] for row in cur.fetchall()]
cur.close()
return results
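
store_embeddings_pg and search_similar_tables_pg assume a table_embeddings table with a pgvector column, whose DDL is not in this commit. A hedged one-time setup sketch (the vector(768) dimension is an assumption based on the m3e-base model):

    # Hypothetical setup for the table_embeddings table assumed above.
    from apps.ichat.utils import connect_db

    conn = connect_db()
    cur = conn.cursor()
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")  # pgvector provides the <-> operator
    cur.execute("""
        CREATE TABLE IF NOT EXISTS table_embeddings (
            table_name  text PRIMARY KEY,  -- ON CONFLICT (table_name) relies on this unique key
            schema_text text,
            embedding   vector(768)        -- assumed dimension of the m3e-base model
        );
    """)
    conn.commit()
    cur.close()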
def format_schema_for_llm(schemas: list[dict]) -> str:
lines = []
for schema in schemas:
lines.append(f"【表名】:{schema['table']}")
lines.append("【字段】:")
for col in schema["text"].split("结构如下:")[1].split("\n"):
if col.strip():
lines.append(f" - {col.strip()}")
lines.append("") # 空行分隔表
return "\n".join(lines)
def embed_text(texts: list[str]) -> list[list[float]]:
payload = {
"input": texts,
"model": EM_MODEL
}
url = f"{API_BASE_EM}/embeddings"
response = requests.post(url, headers=HEADERS, json=payload)
json_data = response.json()
return [e['embedding'] for e in json_data['data']]
# def search_similar_tables(query: str, index, index_table_map, top_k:int=3):
# query_embedding = embed_text([query])[0]
# distances, indices = index.search(np.array([query_embedding]).astype("float32"), int(top_k))
# results = []
# for i in indices[0]:
# if i != -1 and i in index_table_map:
# results.append(index_table_map[i])
# return results
# def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]):
# print(len(embeddings), '-----------')
# dim = len(embeddings[0])
# index = faiss.IndexFlatL2(dim)
# embeddings_np = np.array(embeddings).astype('float32')
# index.add(embeddings_np)
# # 构建索引到表名的映射字典
# index_table_map = {i: table_names[i] for i in range(len(table_names))}
# return index, index_table_map
class QueryLLMviewSet(CustomModelViewSet):
queryset = Message.objects.all()
serializer_class = MessageSerializer
ordering = ['create_time']
perms_map = {'get':'*', 'post':'*', 'put':'*'}
@action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
def completion(self, request):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
prompt = serializer.validated_data['content']
conversation = serializer.validated_data['conversation']
if not prompt or not conversation:
return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
save_message_thread_safe(content=prompt, conversation=conversation, role="user")
url = f"{API_BASE}/chat/completions"
user_prompt = f"""
我提问的问题是:{prompt}。请判断我的问题是否与数据库查询或操作相关:如果是,回答"database";如果不是,回答"general"。
注意:
只需回答"database"或"general"即可,不要有其他内容。
"""
_payload = {
"model": MODEL,
"messages": [{"role": "user", "content": user_prompt}],
"temperature": 0,
"max_tokens": 10
}
try:
class_response = requests.post(url, headers=HEADERS, json=_payload)
class_response.raise_for_status()
class_result = class_response.json()
question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
print("question_type", question_type)
if question_type == "database":
schema_text = get_relation_table(prompt)
user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
{schema_text}
请根据我的需求生成一条标准的PostgreSQL SQL语句直接返回SQL不要额外解释
需求是{prompt}
"""
else:
user_prompt = f"""
回答以下问题不需要涉及数据库查询
问题: {prompt}
请直接回答问题不要提及数据库或SQL
"""
# TODO: should we load all messages for this conversation id from the database and pass them as the chat history?
# history = Message.objects.filter(conversation=conversation).order_by('create_time')
# chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
# chat_history.append({"role": "user", "content": prompt})
chat_history = [{"role":"user", "content":user_prompt}]
print("user_prompt", user_prompt)
payload = {
"model": MODEL,
"messages": chat_history,
"temperature": 0,
"stream": True
}
response = requests.post(url, headers=HEADERS, json=payload)
response.raise_for_status()
except requests.exceptions.RequestException as e:
return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500)
def stream_generator():
accumulated_content = ""
for line in response.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data:'):
if decoded_line.strip() == "data: [DONE]":
break  # OpenAI-style end-of-stream marker
try:
data = json.loads(decoded_line[6:])
content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
if content:
accumulated_content += content
yield f"data: {content}\n\n"
except Exception as e:
yield f"data: [解析失败]: {str(e)}\n\n"
print("accumulated_content", accumulated_content)
save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
if question_type == "database":
sql = extract_sql_code(accumulated_content)
if sql:
conn = None  # ensure the finally block can always reference conn
try:
conn = connect_db()
if is_safe_sql(sql):
result = execute_sql(conn, sql)
save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
yield f"data: SQL执行结果: {result}\n\n"
else:
yield f"data: 拒绝执行非查询类 SQL{sql}\n\n"
except Exception as e:
yield f"data: SQL执行失败: {str(e)}\n\n"
finally:
if conn:
conn.close()
else:
yield "data: \\n[文本结束]\n\n"
return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
# Create a conversation first to obtain its session id
class ConversationViewSet(CustomModelViewSet):
queryset = Conversation.objects.all()
serializer_class = ConversationSerializer
ordering = ['create_time']
perms_map = {'get':'*', 'post':'*', 'put':'*'}

214
apps/ichat/views.py Normal file

@@ -0,0 +1,214 @@
import requests
import json
import faiss
import numpy as np
from rest_framework.views import APIView
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
from rest_framework.response import Response
from apps.ichat.models import Conversation, Message
from apps.ichat.utils import connect_db, extract_sql_code, execute_sql, get_schema_text, is_safe_sql, save_message_thread_safe
from django.http import StreamingHttpResponse, JsonResponse
from rest_framework.decorators import action
from apps.utils.viewsets import CustomGenericViewSet, CustomModelViewSet
# API_KEY = "sk-5644e2d6077b46b9a04a8a2b12d6b693"
# API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# MODEL = "qwen-plus"
# Locally deployed model
API_KEY = "JJVAide0hw3eaugGmxecyYYFw45FX2LfhnYJtC+W2rw"
API_BASE = "http://106.0.4.200:9000/v1"
MODEL = "qwen14b"
# Text embedding model
EM_MODEL = "m3e-base"
# google gemini
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="google/gemini-2.0-flash-exp:free"
# deepseek v3
# API_KEY = "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"
# API_BASE = "https://openrouter.ai/api/v1"
# MODEL="deepseek/deepseek-chat-v3-0324:free"
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"]  # Feeding the whole schema to the model hurts accuracy, so only these tables are exposed
HEADERS = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}
# Embed table schema texts
def embed_text(texts: list[str]) -> list[list[float]]:
url = f"{API_BASE}/embeddings"
_payload = {
"model": EM_MODEL,
"input": texts
}
# Let request errors propagate; completion() already catches RequestException
response = requests.post(url, headers=HEADERS, json=_payload)
response.raise_for_status()
return [e['embedding'] for e in response.json()['data']]
# Build a FAISS L2 index over the schema embeddings
def create_index(embeddings: list[list[float]], texts: list[str], table_names: list[str]):
index = faiss.IndexFlatL2(len(embeddings[0]))
index.add(np.array(embeddings).astype("float32"))
index_table_map = {i: {"table": table_names[i], "text": texts[i]} for i in range(len(table_names))}
return index, index_table_map
# Search the index for the k most similar table schemas
def search_similar_tables(query: str, index, index_table_map, k: int = 5):
query_embedding = embed_text([query])[0]
distances, indices = index.search(np.array([query_embedding]).astype("float32"), k)
return [index_table_map[i] for i in indices[0] if i in index_table_map]  # skip -1 padding
def get_tables(conn) -> list[str]:
with conn.cursor() as cur:
cur.execute("""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
AND tableowner = 'postgres';
""")
return [row[0] for row in cur.fetchall()]
# Main entry: extract table structures, embed them and search the FAISS index
def get_relation_table(query):
conn = connect_db()
table_names = get_tables(conn)
schemas = get_schema_text(conn, table_names)
texts = [s["text"] for s in schemas]
table_names = [s["table"] for s in schemas]  # realign names with the schema order
embeddings = embed_text(texts)
index, index_table_map = create_index(embeddings, texts, table_names)
results = search_similar_tables(query, index, index_table_map)
for result in results:
print(f"表名: {result['table']}\n结构: {result['text']}")
if len(results) == 0:
return "没有找到相关表结构"
return results
class QueryLLMviewSet(CustomModelViewSet):
queryset = Message.objects.all()
serializer_class = MessageSerializer
ordering = ['create_time']
perms_map = {'get':'*', 'post':'*', 'put':'*'}
@action(methods=['post'], detail=False, perms_map={'post': '*'}, serializer_class=MessageSerializer)
def completion(self, request):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
prompt = serializer.validated_data['content']
conversation = serializer.validated_data['conversation']
if not prompt or not conversation:
return JsonResponse({"error": "缺少 prompt 或 conversation"}, status=400)
save_message_thread_safe(content=prompt, conversation=conversation, role="user")
url = f"{API_BASE}/chat/completions"
user_prompt = f"""
我提问的问题是:{prompt}。请判断我的问题是否与数据库查询或操作相关:如果是,回答"database";如果不是,回答"general"。
注意:
只需回答"database"或"general"即可,不要有其他内容。
"""
_payload = {
"model": MODEL,
"messages": [{"role": "user", "content": user_prompt}],
"temperature": 0,
"max_tokens": 10
}
try:
class_response = requests.post(url, headers=HEADERS, json=_payload)
class_response.raise_for_status()
class_result = class_response.json()
question_type = class_result.get('choices', [{}])[0].get('message', {}).get('content', '').strip().lower()
print("question_type", question_type)
if question_type == "database":
schema_text = get_relation_table(prompt)
print("schema_text----------------------", schema_text)
user_prompt = f"""你是一个专业的数据库工程师,根据以下数据库结构:
{schema_text}
请根据我的需求生成一条标准的PostgreSQL SQL语句直接返回SQL不要额外解释
需求是{prompt}
"""
else:
user_prompt = f"""
回答以下问题不需要涉及数据库查询
问题: {prompt}
请直接回答问题不要提及数据库或SQL
"""
# TODO: should we load all messages for this conversation id from the database and pass them as the chat history?
# history = Message.objects.filter(conversation=conversation).order_by('create_time')
# chat_history = [{"role": msg.role, "content": msg.content} for msg in history]
# chat_history.append({"role": "user", "content": prompt})
chat_history = [{"role":"user", "content":user_prompt}]
print("chat_history", chat_history)
payload = {
"model": MODEL,
"messages": chat_history,
"temperature": 0,
"stream": True
}
response = requests.post(url, headers=HEADERS, json=payload)
response.raise_for_status()
except requests.exceptions.RequestException as e:
return JsonResponse({"error":f"LLM API调用失败: {e}"}, status=500)
def stream_generator():
accumulated_content = ""
for line in response.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data:'):
if decoded_line.strip() == "data: [DONE]":
break  # OpenAI-style end-of-stream marker
try:
data = json.loads(decoded_line[6:])
content = data.get('choices', [{}])[0].get('delta', {}).get('content', '')
if content:
accumulated_content += content
yield f"data: {content}\n\n"
except Exception as e:
yield f"data: [解析失败]: {str(e)}\n\n"
print("accumulated_content", accumulated_content)
save_message_thread_safe(content=accumulated_content, conversation=conversation, role="system")
if question_type == "database":
sql = extract_sql_code(accumulated_content)
if sql:
conn = None  # ensure the finally block can always reference conn
try:
conn = connect_db()
if is_safe_sql(sql):
result = execute_sql(conn, sql)
save_message_thread_safe(content=f"SQL结果: {result}", conversation=conversation, role="system")
yield f"data: SQL执行结果: {result}\n\n"
else:
yield f"data: 拒绝执行非查询类 SQL{sql}\n\n"
except Exception as e:
yield f"data: SQL执行失败: {str(e)}\n\n"
finally:
if conn:
conn.close()
else:
yield "data: \\n[文本结束]\n\n"
return StreamingHttpResponse(stream_generator(), content_type='text/event-stream')
# Create a conversation first to obtain its session id
class ConversationViewSet(CustomModelViewSet):
queryset = Conversation.objects.all()
serializer_class = ConversationSerializer
ordering = ['create_time']
perms_map = {'get':'*', 'post':'*', 'put':'*'}
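
The completion action streams OpenAI-style SSE lines back to the caller. A hedged client sketch (hypothetical, not part of this commit; host, port, auth and the conversation id are placeholders, the route comes from apps/ichat/urls.py):

    # Hypothetical streaming client for POST /api/ichat/message/completion/.
    import requests

    resp = requests.post(
        "http://localhost:8000/api/ichat/message/completion/",  # host/port assumed
        json={"conversation": "<conversation-id>", "content": "查询 黑化 工段在2025年6月的生产合格数", "role": "user"},
        stream=True,
    )
    for line in resp.iter_lines():
        if line and line.startswith(b"data: "):
            print(line[len(b"data: "):].decode("utf-8"))  # incremental token or SQL result line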

129
apps/ichat/views2.py Normal file

@@ -0,0 +1,129 @@
import requests
import os
from apps.utils.sql import execute_raw_sql
import json
from apps.utils.tools import MyJSONEncoder
from .utils import is_safe_sql
from rest_framework.views import APIView
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
from rest_framework.exceptions import ParseError
from rest_framework.response import Response
from django.conf import settings
from apps.utils.mixins import MyLoggingMixin
from django.core.cache import cache
import uuid
from apps.utils.thread import MyThread
LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
MODEL = "qwen14b"
HEADERS = {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json"
}
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
def load_promot(name):
with open(os.path.join(CUR_DIR, f'promot/{name}.md'), 'r', encoding='utf-8') as f:
return f.read()
def ask(input:str, p_name:str, stream=False):
his = [{"role":"system", "content": load_promot(p_name)}]
his.append({"role":"user", "content": input})
payload = {
"model": MODEL,
"messages": his,
"temperature": 0,
"stream": stream
}
response = requests.post(LLM_URL, headers=HEADERS, json=payload, stream=stream)
if not stream:
return response.json()["choices"][0]["message"]["content"]
else:
# Handle the streaming response
full_content = ""
for chunk in response.iter_lines():
if chunk:
# Stream replies usually arrive as SSE lines: data: {...}
decoded_chunk = chunk.decode('utf-8')
if decoded_chunk.startswith("data:"):
json_str = decoded_chunk[5:].strip()
if json_str == "[DONE]":
break
try:
chunk_data = json.loads(json_str)
if "choices" in chunk_data and chunk_data["choices"]:
delta = chunk_data["choices"][0].get("delta", {})
if "content" in delta:
print(delta["content"])
full_content += delta["content"]
except json.JSONDecodeError:
continue
return full_content
def work_chain(input:str, t_key:str):
pdict = {"state": "progress", "steps": [{"state":"ok", "msg":"正在生成查询语句"}]}
cache.set(t_key, pdict)
res_text = ask(input, 'w_sql')
if res_text == '请以 查询 开头,重新描述你的需求':
pdict["state"] = "error"
pdict["steps"].append({"state":"error", "msg":res_text})
cache.set(t_key, pdict)
return
else:
pdict["steps"].append({"state":"ok", "msg":"查询语句生成成功", "content":res_text})
cache.set(t_key, pdict)
if not is_safe_sql(res_text):
pdict["state"] = "error"
pdict["steps"].append({"state":"error", "msg":"当前查询存在风险,请重新描述你的需求"})
cache.set(t_key, pdict)
return
pdict["steps"].append({"state":"ok", "msg":"正在执行查询语句"})
cache.set(t_key, pdict)
res = execute_raw_sql(res_text)
pdict["steps"].append({"state":"ok", "msg":"查询语句执行成功", "content":res})
cache.set(t_key, pdict)
pdict["steps"].append({"state":"ok", "msg":"正在生成报告"})
cache.set(t_key, pdict)
res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
content = res2.strip().removeprefix('```html').removesuffix('```').strip()  # remove the surrounding ```html fence
pdict["state"] = "done"
pdict["content"] = content
pdict["steps"].append({"state":"ok", "msg":"报告生成成功", "content": content})
cache.set(t_key, pdict)
return
class InputSerializer(serializers.Serializer):
input = serializers.CharField(label="查询需求")
class WorkChain(MyLoggingMixin, APIView):
@swagger_auto_schema(
operation_summary="提交查询需求",
request_body=InputSerializer)
def post(self, request):
llm_enabled = getattr(settings, "LLM_ENABLED", False)
if not llm_enabled:
raise ParseError('LLM功能未启用')
input = request.data.get('input')
t_key = f'ichat_{uuid.uuid4()}'
MyThread(target=work_chain, args=(input, t_key)).start()
return Response({'ichat_tid': t_key})
@swagger_auto_schema(
operation_summary="获取查询进度")
def get(self, request):
llm_enabled = getattr(settings, "LLM_ENABLED", False)
if not llm_enabled:
raise ParseError('LLM功能未启用')
ichat_tid = request.GET.get('ichat_tid')
if ichat_tid:
return Response(cache.get(ichat_tid))
raise ParseError('缺少参数ichat_tid')
if __name__ == "__main__":
print(work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告"))
from apps.ichat.views2 import work_chain
print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))
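
WorkChain is a submit-then-poll API: the POST returns an ichat_tid, and repeated GETs read the cached step list until state reaches done or error. A hedged client sketch (hypothetical, not part of this commit; host, port and auth are assumptions):

    # Hypothetical submit-and-poll client for the WorkChain endpoint.
    import time
    import requests

    URL = "http://localhost:8000/api/ichat/workchain/ask/"  # host/port assumed
    tid = requests.post(URL, json={"input": "查询 黑化 工段在2025年6月的生产合格数等并形成报告"}).json()["ichat_tid"]
    while True:
        progress = requests.get(URL, params={"ichat_tid": tid}).json() or {}
        for step in progress.get("steps", []):
            print(step["state"], step["msg"])
        if progress.get("state") in ("done", "error"):
            break
        time.sleep(1)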

BIN
media/default/abc.mp3 Normal file

Binary file not shown.

BIN
media/default/abc2.mp3 Normal file

Binary file not shown.

BIN
media/default/alarm.mp3 Normal file

Binary file not shown.

BIN
media/default/avatar.png Normal file

Binary file not shown.




@@ -63,7 +63,7 @@ INSTALLED_APPS = [
'apps.wf',
'apps.ecm',
'apps.hrm',
'apps.ichat',
#'apps.ichat',
'apps.am',
'apps.vm',
'apps.rpm',