196 lines
6.2 KiB
Python
196 lines
6.2 KiB
Python
import re
|
|
import psycopg2
|
|
import threading
|
|
from django.db import transaction
|
|
from .models import Message
|
|
|
|
# Database connection helper.
def connect_db():
    """Open and return a new psycopg2 connection.

    Connection parameters are read from the ``'default'`` entry of
    ``server.conf.DATABASES`` (keys HOST, PORT, USER, PASSWORD, NAME).
    ``server.conf`` is imported lazily, at call time.

    The connection is returned open; the caller is responsible for closing it.
    """
    from server.conf import DATABASES

    settings = DATABASES['default']
    connect_kwargs = {
        'host': settings['HOST'],
        'port': settings['PORT'],
        'user': settings['USER'],
        'password': settings['PASSWORD'],
        'database': settings['NAME'],
    }
    return psycopg2.connect(**connect_kwargs)
|
|
|
|
def extract_sql_code(text):
    """Pull a SQL statement out of free-form (typically LLM-generated) text.

    Two strategies are tried in order, and the first hit wins:

    1. the body of a Markdown fence tagged ``sql`` (case-insensitive);
    2. the first ``SELECT ...;`` run, up to and including the first semicolon.

    Returns the stripped statement, or None if neither strategy matches.
    """
    strategies = (
        # Fenced ```sql ... ``` block.
        (r"```sql\s*(.+?)```", re.DOTALL | re.IGNORECASE),
        # Fallback: the first SELECT statement terminated by ';'.
        (r"(SELECT\s.+?;)", re.IGNORECASE | re.DOTALL),
    )
    for pattern, flags in strategies:
        found = re.search(pattern, text, flags)
        if found is not None:
            return found.group(1).strip()
    return None
|
|
|
|
|
|
# def get_schema_text(conn, table_names:list):
|
|
# cur = conn.cursor()
|
|
# query = """
|
|
# SELECT
|
|
# table_name, column_name, data_type
|
|
# FROM
|
|
# information_schema.columns
|
|
# WHERE
|
|
# table_schema = 'public'
|
|
# and table_name in %s;
|
|
# """
|
|
# cur.execute(query, (tuple(table_names), ))
|
|
|
|
# schema = {}
|
|
# for table_name, column_name, data_type in cur.fetchall():
|
|
# if table_name not in schema:
|
|
# schema[table_name] = []
|
|
# schema[table_name].append(f"{column_name} ({data_type})")
|
|
# cur.close()
|
|
# schema_text = ""
|
|
# for table_name, columns in schema.items():
|
|
# schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n"
|
|
# return schema_text
|
|
|
|
def get_schema_text(conn, table_names: list):
    """Describe the columns of the given ``public``-schema tables.

    Column names and column comments are read from the PostgreSQL system
    catalogs (pg_class / pg_attribute / pg_description). The result has one
    entry per table that was found, ordered by table name::

        [{"table": "t", "text": "表 t 包含列:\\ncol_a-comment\\ncol_b"}, ...]

    Each column line is ``"<column>-<comment>"``. Any text from the first
    "备注" ("remarks") marker onward is cut from the comment. A column with no
    (or an empty) comment is listed by its name alone.

    Args:
        conn: an open psycopg2 connection.
        table_names: table names to describe. A list is expected; any
            iterable of names, or a single name as a str, is also accepted.

    Returns:
        list of ``{"table": str, "text": str}`` dicts. Names that match no
        relation in ``public`` are absent; empty/None input returns ``[]``.
    """
    if isinstance(table_names, str):
        names = [table_names]
    else:
        names = list(table_names or [])
    if not names:
        return []

    query = """
        SELECT
            c.relname AS table_name,
            a.attname AS column_name,
            d.description AS column_comment
        FROM pg_class c
        JOIN pg_namespace n ON n.oid = c.relnamespace
        JOIN pg_attribute a ON a.attrelid = c.oid
        LEFT JOIN pg_description d
               ON d.objoid = a.attrelid
              AND d.classoid = 'pg_class'::regclass
              AND d.objsubid = a.attnum
        WHERE n.nspname = 'public'
          AND c.relname = ANY(%s)
          AND a.attnum > 0
          AND NOT a.attisdropped
        ORDER BY c.relname, a.attnum;
    """

    # The context manager closes the cursor even if execute() raises.
    with conn.cursor() as cur:
        # psycopg2 adapts a Python *list* to a PostgreSQL ARRAY, which is what
        # ANY(%s) requires (a tuple would become a row literal instead).
        cur.execute(query, (names,))
        rows = cur.fetchall()

    schema = {}
    for table_name, column_name, comment in rows:
        if comment and "备注" in comment:
            # Keep only the part before the "remarks" marker.
            comment = comment.split("备注")[0].strip()
        # Avoid emitting "col-None" / "col-" for columns without a comment.
        line = f"{column_name}-{comment}" if comment else column_name
        schema.setdefault(table_name, []).append(line)

    return [
        {"table": table, "text": f"表 {table} 包含列:\n" + "\n".join(columns)}
        for table, columns in schema.items()
    ]
|
|
|
|
|
|
# def get_schema_text(conn, table_names: list):
|
|
# cur = conn.cursor()
|
|
|
|
# # 获取字段、类型、注释
|
|
# column_query = """
|
|
# SELECT
|
|
# c.relname AS table_name,
|
|
# a.attname AS column_name,
|
|
# pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
|
# d.description AS column_comment
|
|
# FROM
|
|
# pg_class c
|
|
# JOIN pg_namespace n ON n.oid = c.relnamespace
|
|
# JOIN pg_attribute a ON a.attrelid = c.oid
|
|
# LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
|
|
# WHERE
|
|
# n.nspname = 'public'
|
|
# AND c.relname = ANY(%s)
|
|
# AND a.attnum > 0
|
|
# AND NOT a.attisdropped
|
|
# ORDER BY
|
|
# c.relname, a.attnum;
|
|
# """
|
|
|
|
# # 获取外键信息
|
|
# fk_query = """
|
|
# SELECT
|
|
# conrelid::regclass::text AS table_name,
|
|
# a.attname AS column_name,
|
|
# confrelid::regclass::text AS foreign_table,
|
|
# af.attname AS foreign_column
|
|
# FROM
|
|
# pg_constraint
|
|
# JOIN pg_class ON conrelid = pg_class.oid
|
|
# JOIN pg_namespace n ON pg_class.relnamespace = n.oid
|
|
# JOIN pg_attribute a ON a.attrelid = conrelid AND a.attnum = ANY(conkey)
|
|
# JOIN pg_attribute af ON af.attrelid = confrelid AND af.attnum = ANY(confkey)
|
|
# WHERE
|
|
# contype = 'f'
|
|
# AND n.nspname = 'public'
|
|
# AND conrelid::regclass::text = ANY(%s);
|
|
# """
|
|
|
|
# cur.execute(column_query, (table_names,))
|
|
# columns = cur.fetchall()
|
|
|
|
# cur.execute(fk_query, (table_names,))
|
|
# fks = cur.fetchall()
|
|
|
|
# # 构建外键字典
|
|
# fk_map = {} # {(table, column): "foreign_table(foreign_column)"}
|
|
# for table, column, f_table, f_column in fks:
|
|
# fk_map[(table, column)] = f"{f_table}({f_column})"
|
|
|
|
# # 组织输出结构
|
|
# schema = {}
|
|
# for table, column, dtype, comment in columns:
|
|
# fk_note = f" -> {fk_map[(table, column)]}" if (table, column) in fk_map else ""
|
|
# comment_note = f" -- {comment}" if comment else ""
|
|
# schema.setdefault(table, []).append(f"{column} ({dtype}{fk_note}{comment_note})")
|
|
|
|
# cur.close()
|
|
|
|
# # 生成文本
|
|
# schema_text = ""
|
|
# for table, cols in schema.items():
|
|
# schema_text += f"表 {table} 包含列:\n - " + "\n - ".join(cols) + "\n"
|
|
# return schema_text
|
|
|
|
def is_safe_sql(sql: str) -> bool:
    """Return True if *sql* looks like a read-only statement.

    The statement must start with SELECT or SHOW (case-insensitive; leading
    and trailing whitespace ignored). It must also not contain any
    write/DDL keyword anywhere in its text. This also rejects stacked
    statements such as ``SELECT 1; DROP TABLE t``.

    This is a coarse keyword filter, not a SQL parser. A keyword inside a
    string literal (e.g. ``WHERE note = 'delete me'``) is rejected as well.
    """
    normalized = sql.strip().lower()
    if not normalized.startswith(("select", "show")):
        return False
    # The previous `a or b and not c` expression parsed as `a or (b and not c)`,
    # so every SELECT passed regardless of the keyword check.
    # Word boundaries keep identifiers such as created_at, updated_at and
    # dropped_items from matching. SELECT ... INTO creates a table, so INTO is
    # rejected too.
    return not re.search(
        r"\b(?:delete|update|insert|drop|create|alter|truncate|into)\b",
        normalized,
    )
|
|
|
|
def execute_sql(conn, sql_query):
    """Execute *sql_query* on *conn* and return its result.

    Returns:
        For a statement that produces a result set: a list of dicts, one per
        row, mapping column name to value.
        For a statement without a result set: the cursor's status message
        string (e.g. ``"UPDATE 3"``).

    Errors raised by ``cursor.execute`` propagate to the caller. The cursor is
    closed in every case. No commit or rollback is performed here.
    """
    with conn.cursor() as cur:
        cur.execute(sql_query)
        # psycopg2 leaves `description` as None when the statement produced no
        # result set; fetchall() would raise ProgrammingError in that case.
        if cur.description is None:
            return cur.statusmessage
        columns = [col[0] for col in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
|
|
|
|
def strip_sql_markdown(content: str) -> "str | None":
    """Return the contents of a Markdown code fence holding SQL.

    A fence tagged ``sql`` (case-insensitive) is preferred. If there is none,
    the first fence with no language tag (plain ```) is used. Fences tagged
    with another language, such as ```python, are ignored.

    Returns the stripped fence body, or None if no suitable fence exists.
    """
    tagged = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    if tagged:
        return tagged.group(1).strip()

    # finditer consumes each fence's opening and closing markers together, so
    # prose *between* two fences is never mistaken for a fence body.
    for fence in re.finditer(r"```([^\n`]*)\n(.*?)```", content, re.DOTALL):
        if not fence.group(1).strip():
            return fence.group(2).strip()
    return None
|
|
|
|
# ORM write wrapper that performs the insert on a background thread.
def save_message_thread_safe(**kwargs):
    """Create a ``Message`` row on a background thread (fire-and-forget).

    ``kwargs`` are passed unchanged to ``Message.objects.create`` inside
    ``transaction.atomic()``. The call returns immediately, and the write runs
    on a newly started ``threading.Thread``. An exception in the worker is
    not propagated to the caller; it is only reported by threading's default
    excepthook.

    Django database connections are per-thread, and this worker thread is not
    part of Django's request cycle. It therefore closes its own connections
    when done; otherwise each call would leave one connection open.

    Returns:
        The started ``threading.Thread``. Callers may ``join()`` it, for
        example in tests, but are not required to.
    """
    def _save():
        from django.db import connections

        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        finally:
            # Release the connection(s) this thread opened.
            connections.close_all()

    worker = threading.Thread(target=_save)
    worker.start()
    return worker
|