89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
import re
|
|
import psycopg2
|
|
import threading
|
|
from django.db import transaction
|
|
from .models import Message
|
|
|
|
# Database connection
|
|
def connect_db():
    """Open and return a new psycopg2 connection for the ``default`` database.

    Connection parameters come from ``DATABASES['default']`` in ``server.conf``
    (keys HOST, PORT, USER, PASSWORD, NAME). A fresh connection is created on
    every call and is not closed here, so closing it is up to the caller.
    """
    # Imported inside the function, so server.conf is only loaded when a
    # connection is actually requested.
    from server.conf import DATABASES

    default_settings = DATABASES['default']
    # (psycopg2 keyword, settings key) pairs, read in this order.
    param_sources = (
        ('host', 'HOST'),
        ('port', 'PORT'),
        ('user', 'USER'),
        ('password', 'PASSWORD'),
        ('database', 'NAME'),
    )
    connect_kwargs = {kw: default_settings[key] for kw, key in param_sources}
    return psycopg2.connect(**connect_kwargs)
|
|
|
|
def extract_sql_code(text):
    """Pull a SQL statement out of free-form text.

    Strategy:
      1. Return the stripped contents of the first fenced ```sql ... ``` block
         (case-insensitive), provided it is non-empty.
      2. Otherwise return the first ``SELECT ...;`` statement found anywhere in
         the text, up to and including the first semicolon.

    Args:
        text: The text to search. ``None`` or an empty string yields ``None``.

    Returns:
        The extracted SQL string, or ``None`` if nothing matched.
    """
    if not text:
        return None

    # Prefer an explicit ```sql fence. ``\b`` keeps tags such as ```sqlite from
    # being mistaken for ```sql (which would leak "ite" into the result).
    fenced = re.search(r"```sql\b\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        candidate = fenced.group(1).strip()
        # An empty fence is not SQL; fall through to the SELECT search instead.
        if candidate:
            return candidate

    # Fallback: first SELECT statement terminated by ';'. The leading ``\b``
    # prevents matches inside words such as "deselect" or "unselect".
    statement = re.search(r"\bSELECT\s.+?;", text, re.IGNORECASE | re.DOTALL)
    if statement:
        return statement.group(0).strip()

    return None
|
|
|
|
|
|
def get_schema_text(conn, table_names: list):
    """Describe the columns of the given ``public`` tables as plain text.

    Queries ``information_schema.columns`` for the requested tables in the
    ``public`` schema and renders one line per table found, in the order the
    tables were requested:
    ``表<table> 包含列:<col> (<type>), ...`` ("table <table> contains
    columns: ..."). Columns are listed in their declared (ordinal) order.

    Args:
        conn: An open DB-API connection (psycopg2 in this module).
        table_names: Iterable of table names to describe. A single string is
            treated as one table name. Tables that are not found are omitted.

    Returns:
        The rendered text, one ``\\n``-terminated line per table found; an
        empty string when ``table_names`` is empty or nothing matched.
    """
    if isinstance(table_names, str):
        # tuple("users") would split the name into single characters.
        table_names = [table_names]
    names = tuple(table_names)
    if not names:
        # PostgreSQL rejects ``IN ()`` as a syntax error, so skip the query.
        return ""

    query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name IN %s
        ORDER BY table_name, ordinal_position;
    """
    # Context manager guarantees the cursor is closed even if execute/fetch raises.
    with conn.cursor() as cur:
        # psycopg2 adapts a Python tuple to a parenthesized list for IN.
        cur.execute(query, (names,))
        rows = cur.fetchall()

    schema = {}
    for table_name, column_name, data_type in rows:
        schema.setdefault(table_name, []).append(f"{column_name} ({data_type})")

    # dict.fromkeys de-duplicates while keeping the caller's order.
    return "".join(
        f"表{table_name} 包含列:{', '.join(schema[table_name])}\n"
        for table_name in dict.fromkeys(names)
        if table_name in schema
    )
|
|
|
|
|
|
def is_safe_sql(sql: str) -> bool:
    """Heuristically decide whether ``sql`` is a read-only statement.

    The statement must:
      * start with SELECT or SHOW (case-insensitive),
      * be a single statement (only trailing semicolons are allowed), and
      * not contain any write/DDL/privilege keyword as a whole word.

    NOTE: this is a keyword blacklist, not a SQL parser. It can reject harmless
    queries (e.g. a string literal containing 'drop') and cannot detect every
    side effect (e.g. calling a function that writes). Executing under a
    read-only role or read-only transaction is the real safeguard.
    """
    if not sql:
        return False
    normalized = sql.strip().lower()
    # Parenthesization matters: the previous `a or b and not c` form let any
    # statement starting with "select" through without the keyword check.
    if not normalized.startswith(("select", "show")):
        return False
    # Allow a trailing terminator but reject stacked statements such as
    # "select 1; drop table t".
    if ";" in normalized.rstrip("; \t\r\n"):
        return False
    # \b avoids false hits on identifiers like updated_at / created_by;
    # "into" catches SELECT ... INTO, which creates a new table.
    return not re.search(
        r"\b(?:delete|update|insert|drop|create|alter|truncate|grant|revoke|into)\b",
        normalized,
    )
|
|
|
|
def execute_sql(conn, sql_query, params=None):
    """Run one statement on ``conn`` and return its result.

    Args:
        conn: An open DB-API connection (psycopg2 in this module).
        sql_query: The SQL text to execute.
        params: Optional query parameters passed to ``cursor.execute`` so the
            driver quotes them. ``None`` (the default) means no interpolation.

    Returns:
        For statements that produce a result set: a list of dicts mapping
        column name to value, one per row. Otherwise: the cursor's
        ``statusmessage`` string.

    Nothing is committed or rolled back here; transaction handling is left to
    the caller.
    """
    # Context manager closes the cursor even if execute() raises.
    with conn.cursor() as cur:
        cur.execute(sql_query, params)
        # description is None when the statement returned no result set,
        # which is exactly when fetchall() would raise ProgrammingError.
        if cur.description is None:
            return cur.statusmessage
        columns = [col[0] for col in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
|
|
|
|
def strip_sql_markdown(content: str) -> "str | None":
    """Return the text inside a Markdown code fence, or ``None`` if there is none.

    A ```sql fence (case-insensitive) is preferred. If there is none, the first
    untagged ``` fence (opening backticks followed only by a line break) is
    used. The extracted text is stripped of surrounding whitespace.
    """
    if not content:
        return None

    # ``\b`` keeps ```sqlite (or similar tags) from being read as ```sql.
    tagged = re.search(r"```sql\b\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    if tagged:
        return tagged.group(1).strip()

    # Untagged fence: require a line break right after the backticks so that
    # fences tagged with another language (```python ...) are not picked up.
    bare = re.search(r"```[ \t]*\r?\n(.*?)```", content, re.DOTALL)
    if bare:
        return bare.group(1).strip()

    return None
|
|
|
|
# ORM write wrapper (performs the write on a background thread)
|
|
def save_message_thread_safe(**kwargs):
    """Create a ``Message`` row on a new background thread (fire-and-forget).

    The row is created inside ``transaction.atomic()`` on a separate thread, so
    the caller does not wait for the write. Exceptions raised by the write
    happen on the worker thread and are not propagated to the caller.

    Args:
        **kwargs: Field values passed straight to ``Message.objects.create``.

    Returns:
        The started ``threading.Thread``; callers may ``join()`` it if they
        need to wait for the write to finish. Ignoring it is fine.
    """
    def _save():
        # Local import so this helper only depends on django.db, which the
        # file already uses.
        from django.db import connections

        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        finally:
            # Django keeps DB connections per thread; a worker thread outside the
            # request cycle must close its own, or each call leaks one.
            connections.close_all()

    worker = threading.Thread(target=_save)
    worker.start()
    return worker
|