# factory/apps/ichat/utils.py
# 196 lines, 6.2 KiB, Python

import logging
import re
import threading
from typing import Optional

import psycopg2
from django.db import connection, transaction

from .models import Message
def connect_db():
    """Open a new raw psycopg2 connection to the ``default`` database.

    Host, port, credentials and database name are taken from
    ``DATABASES['default']`` in ``server.conf``. Nothing in this function
    closes the connection, so the caller is responsible for doing so.

    Returns:
        The connection object returned by ``psycopg2.connect``.

    Raises:
        KeyError: if HOST, PORT, USER, PASSWORD or NAME is missing from the
            ``default`` settings entry.
        Any error raised by ``psycopg2.connect`` propagates unchanged.
    """
    # Settings are imported at call time, not at module import time.
    from server.conf import DATABASES

    cfg = DATABASES['default']
    params = dict(
        host=cfg['HOST'],
        port=cfg['PORT'],
        user=cfg['USER'],
        password=cfg['PASSWORD'],
        database=cfg['NAME'],
    )
    return psycopg2.connect(**params)
def extract_sql_code(text):
    """Pull a SQL statement out of free-form text.

    Strategy:
      1. If the text contains a fenced ```sql ... ``` block (tag matched
         case-insensitively), return its contents with surrounding
         whitespace stripped.
      2. Otherwise return the first ``SELECT ...;`` statement, from the
         ``SELECT`` keyword up to and including the first semicolon.

    Args:
        text: The text to search. ``None`` or an empty string yields ``None``.

    Returns:
        The extracted SQL string, or ``None`` if nothing matched.
    """
    if not text:
        return None

    # Prefer an explicit ```sql fenced block.
    match = re.search(r"```sql\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # Fallback: first SELECT statement terminated by ';'. The leading \b keeps
    # the keyword from matching inside longer words such as "preselect".
    match = re.search(r"(\bSELECT\s.+?;)", text, re.IGNORECASE | re.DOTALL)
    if match:
        return match.group(1).strip()
    return None
# def get_schema_text(conn, table_names:list):
# cur = conn.cursor()
# query = """
# SELECT
# table_name, column_name, data_type
# FROM
# information_schema.columns
# WHERE
# table_schema = 'public'
# and table_name in %s;
# """
# cur.execute(query, (tuple(table_names), ))
# schema = {}
# for table_name, column_name, data_type in cur.fetchall():
# if table_name not in schema:
# schema[table_name] = []
# schema[table_name].append(f"{column_name} ({data_type})")
# cur.close()
# schema_text = ""
# for table_name, columns in schema.items():
# schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n"
# return schema_text
def get_schema_text(conn, table_names: list):
    """Describe the columns of the given ``public`` tables as text snippets.

    Column names and column comments are read from the PostgreSQL catalogs
    (``pg_class`` / ``pg_attribute`` / ``pg_description``) for each table in
    ``table_names`` that lives in the ``public`` schema. Dropped columns and
    system columns (``attnum <= 0``) are skipped.

    If a comment contains the marker ``备注`` ("remarks"), that marker and
    everything after it are cut off. Columns with no comment left are listed
    by name only, not rendered as ``name-None``.

    Args:
        conn: An open DB connection. Assumed to be psycopg2: the query uses
            ``%s`` placeholders and passes a Python list as an array.
        table_names: Table names to describe. Any iterable of strings works;
            it is converted to a list before being sent.

    Returns:
        A list of ``{"table": name, "text": description}`` dicts, one per
        table that was found, ordered by table name. Tables that do not exist
        in ``public`` are simply absent.
    """
    # psycopg2 adapts a list to an ARRAY (what ANY(%s) needs); a tuple would
    # be adapted as a record, so normalise to a list.
    names = list(table_names)
    if not names:
        return []

    query = """
        SELECT
            c.relname AS table_name,
            a.attname AS column_name,
            pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
            d.description AS column_comment
        FROM
            pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            JOIN pg_attribute a ON a.attrelid = c.oid
            LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
        WHERE
            n.nspname = 'public'
            AND c.relname = ANY(%s)
            AND a.attnum > 0
            AND NOT a.attisdropped
        ORDER BY
            c.relname, a.attnum;
    """

    schema = {}
    # The context manager closes the cursor even if execute()/fetchall() raise.
    with conn.cursor() as cur:
        cur.execute(query, (names,))
        # data_type is selected but not currently included in the output.
        for table_name, column_name, _data_type, comment in cur.fetchall():
            if comment and "备注" in comment:
                comment = comment.split("备注")[0].strip()
            entry = f"{column_name}-{comment}" if comment else column_name
            schema.setdefault(table_name, []).append(entry)

    return [
        {"table": table, "text": f"{table} 包含列:\n" + "\n".join(columns)}
        for table, columns in schema.items()
    ]
# def get_schema_text(conn, table_names: list):
# cur = conn.cursor()
# # 获取字段、类型、注释
# column_query = """
# SELECT
# c.relname AS table_name,
# a.attname AS column_name,
# pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
# d.description AS column_comment
# FROM
# pg_class c
# JOIN pg_namespace n ON n.oid = c.relnamespace
# JOIN pg_attribute a ON a.attrelid = c.oid
# LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
# WHERE
# n.nspname = 'public'
# AND c.relname = ANY(%s)
# AND a.attnum > 0
# AND NOT a.attisdropped
# ORDER BY
# c.relname, a.attnum;
# """
# # 获取外键信息
# fk_query = """
# SELECT
# conrelid::regclass::text AS table_name,
# a.attname AS column_name,
# confrelid::regclass::text AS foreign_table,
# af.attname AS foreign_column
# FROM
# pg_constraint
# JOIN pg_class ON conrelid = pg_class.oid
# JOIN pg_namespace n ON pg_class.relnamespace = n.oid
# JOIN pg_attribute a ON a.attrelid = conrelid AND a.attnum = ANY(conkey)
# JOIN pg_attribute af ON af.attrelid = confrelid AND af.attnum = ANY(confkey)
# WHERE
# contype = 'f'
# AND n.nspname = 'public'
# AND conrelid::regclass::text = ANY(%s);
# """
# cur.execute(column_query, (table_names,))
# columns = cur.fetchall()
# cur.execute(fk_query, (table_names,))
# fks = cur.fetchall()
# # 构建外键字典
# fk_map = {} # {(table, column): "foreign_table(foreign_column)"}
# for table, column, f_table, f_column in fks:
# fk_map[(table, column)] = f"{f_table}({f_column})"
# # 组织输出结构
# schema = {}
# for table, column, dtype, comment in columns:
# fk_note = f" -> {fk_map[(table, column)]}" if (table, column) in fk_map else ""
# comment_note = f" -- {comment}" if comment else ""
# schema.setdefault(table, []).append(f"{column} ({dtype}{fk_note}{comment_note})")
# cur.close()
# # 生成文本
# schema_text = ""
# for table, cols in schema.items():
# schema_text += f"表 {table} 包含列:\n - " + "\n - ".join(cols) + "\n"
# return schema_text
def is_safe_sql(sql: str) -> bool:
    """Heuristically decide whether ``sql`` is a single read-only statement.

    A statement is accepted only if all of the following hold:
      * it starts with ``select`` or ``show`` (case-insensitive, surrounding
        whitespace ignored);
      * it is a single statement. Trailing semicolons are allowed, but any
        other ``;`` (even one inside a string literal) rejects it
        conservatively;
      * it contains none of ``delete``, ``update``, ``insert``, ``drop``,
        ``create``, ``alter``, ``truncate``, ``grant``, ``revoke`` as whole
        words. Identifiers such as ``created_at`` or ``updated_by`` are
        therefore still allowed.

    NOTE(review): this is a text filter, not a security boundary. A SELECT
    can still call side-effecting functions, so running these queries under
    a read-only role or a read-only transaction is the reliable guard.

    Returns:
        ``True`` if every check passes. ``False`` otherwise, including for
        ``None`` or empty input.
    """
    if not sql:
        return False

    normalized = sql.strip().lower()
    if not normalized.startswith(("select", "show")):
        return False

    # A single execute() call can run several ';'-separated statements,
    # so "select 1; drop table t" must be rejected as a whole.
    if ";" in normalized.rstrip("; \t\r\n"):
        return False

    return not re.search(
        r"\b(delete|update|insert|drop|create|alter|truncate|grant|revoke)\b",
        normalized,
    )
def execute_sql(conn, sql_query):
    """Execute ``sql_query`` on ``conn`` and return its result.

    Returns:
        * For statements that produce a result set: a list of dicts, one per
          row, mapping column name to value. If two result columns share a
          name, the later one wins.
        * For statements without a result set (``cursor.description`` is
          ``None``): the cursor's ``statusmessage``.

    The cursor is always closed, even if execution fails. Driver exceptions
    propagate to the caller. No commit or rollback is issued here.

    NOTE(review): with psycopg2's default (non-autocommit) mode, a failed
    statement leaves the connection's transaction aborted. The caller should
    roll back before reusing the connection.
    """
    with conn.cursor() as cur:
        cur.execute(sql_query)
        # description is None when the statement returned no rows at all,
        # which is exactly when fetchall() would raise ProgrammingError.
        if cur.description is None:
            return cur.statusmessage
        columns = [col[0] for col in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
def strip_sql_markdown(content: str) -> Optional[str]:
    """Return the body of the first ```sql fenced block in ``content``.

    Only fences tagged ``sql`` (case-insensitive) are recognised; an
    untagged ``` fence does not match. The body is returned with surrounding
    whitespace stripped.

    Returns:
        The fenced SQL, which may be an empty string for an empty block.
        Returns ``None`` when ``content`` is empty/``None`` or contains no
        such fence.
    """
    if not content:
        return None
    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None
def save_message_thread_safe(**kwargs):
    """Create a ``Message`` row on a background thread (fire-and-forget).

    ``kwargs`` are passed unchanged to ``Message.objects.create``. The insert
    runs inside ``transaction.atomic()`` on a newly started thread, so this
    call returns without waiting for the write to finish.

    Failures inside the worker are logged via ``logging`` (logger named after
    this module) and are not raised to the caller.

    NOTE(review): the write happens on a separate DB connection. If the caller
    is inside an uncommitted transaction, rows it created may not be visible
    to this insert yet (e.g. for foreign keys). Confirm against call sites.

    Returns:
        The started ``threading.Thread``. Callers may ``join()`` it; callers
        that ignore the return value are unaffected.
    """
    def _save():
        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        except Exception:
            # No caller to raise to from a worker thread: log with traceback
            # instead of relying on threading.excepthook's stderr output.
            logging.getLogger(__name__).exception(
                "Failed to save Message in background thread"
            )
        finally:
            # Django keeps one DB connection per thread and only closes them
            # around requests it handles; close this worker's connection
            # explicitly instead of relying on garbage collection.
            connection.close()

    worker = threading.Thread(target=_save)
    worker.start()
    return worker