factory/apps/ichat/utils.py

89 lines
2.5 KiB
Python

import re
import psycopg2
import threading
from django.db import transaction
from .models import Message
# Database connection
def connect_db():
    """Open and return a new psycopg2 connection.

    Connection parameters come from the ``'default'`` entry of
    ``server.conf.DATABASES`` (keys ``HOST``, ``PORT``, ``USER``,
    ``PASSWORD`` and ``NAME``). The settings module is imported lazily,
    inside the function. The returned connection is not closed here; the
    caller owns it.
    """
    from server.conf import DATABASES

    cfg = DATABASES['default']
    params = {
        'host': cfg['HOST'],
        'port': cfg['PORT'],
        'user': cfg['USER'],
        'password': cfg['PASSWORD'],
        'database': cfg['NAME'],
    }
    return psycopg2.connect(**params)
def extract_sql_code(text):
    """Extract the first SQL statement found in ``text``.

    Two patterns are tried in order (both case-insensitive, ``.`` matching
    newlines):

    1. the body of the first fenced block opened with ```sql;
    2. otherwise the first span starting with ``SELECT`` plus whitespace and
       ending at the nearest ``;`` (the semicolon is kept).

    The match is returned with surrounding whitespace stripped, or None when
    neither pattern matches.
    """
    flags = re.IGNORECASE | re.DOTALL
    for pattern in (r"```sql\s*(.+?)```", r"(SELECT\s.+?;)"):
        found = re.search(pattern, text, flags)
        if found:
            return found.group(1).strip()
    return None
def get_schema_text(conn, table_names: list):
    """Describe the columns of the given ``public``-schema tables as text.

    Args:
        conn: An open psycopg2 connection. The ``IN %s`` placeholder relies on
            psycopg2 adapting a Python tuple to a parenthesised SQL list.
        table_names: Names of tables in the ``public`` schema to describe.

    Returns:
        One newline-terminated line per table that has columns. Each line is
        the table name, the literal label ``包含列:`` ("contains columns:"),
        then comma-separated ``column (data_type)`` entries in column
        definition order. Returns an empty string when ``table_names`` is
        empty or None, or when none of the tables exist.
    """
    names = tuple(table_names or ())
    if not names:
        # PostgreSQL rejects ``IN ()``, so skip the round-trip entirely.
        return ""
    query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name IN %s
        ORDER BY table_name, ordinal_position;
    """
    cur = conn.cursor()
    try:
        cur.execute(query, (names,))
        rows = cur.fetchall()
    finally:
        # Release the cursor even when the query fails.
        cur.close()

    schema = {}
    for table_name, column_name, data_type in rows:
        schema.setdefault(table_name, []).append(f"{column_name} ({data_type})")
    return "".join(
        f"{table_name} 包含列:{', '.join(columns)}\n"
        for table_name, columns in schema.items()
    )
def is_safe_sql(sql: str) -> bool:
    """Heuristically decide whether ``sql`` is a read-only statement.

    A statement is accepted only if BOTH conditions hold:

    * it starts with ``select`` or ``show`` (case-insensitive, ignoring
      leading/trailing whitespace); and
    * none of the write/DDL/privilege keywords below appears anywhere in it
      as a whole word.

    The second condition rejects stacked statements such as
    ``SELECT 1; DROP TABLE t``. Matching whole words lets identifiers like
    ``created_at`` or ``updated_at`` pass.

    This is a keyword filter, not a SQL parser. It may reject harmless
    queries whose literals or comments contain a keyword. It does not
    replace executing generated SQL under a read-only database role.

    Returns False for empty or None input.
    """
    if not sql:
        return False
    normalized = sql.strip().lower()
    if not normalized.startswith(("select", "show")):
        return False
    # The original `a or b and not c` bound as `a or (b and not c)`, so every
    # SELECT skipped this check; now it applies to every accepted prefix.
    return not re.search(
        r"\b(delete|update|insert|drop|create|alter|truncate|grant|revoke)\b",
        normalized,
    )
def execute_sql(conn, sql_query):
    """Execute ``sql_query`` on ``conn`` and return its result.

    Args:
        conn: An open DB-API connection (psycopg2 elsewhere in this module).
        sql_query: SQL text, executed verbatim with no parameters. Callers
            must vet it first, for example with ``is_safe_sql``.

    Returns:
        A list of ``{column_name: value}`` dicts when the statement produces
        a result set. Otherwise, the cursor's ``statusmessage`` (psycopg2's
        command-status string, e.g. ``"INSERT 0 1"``).

    The cursor is always closed, including when execution raises. Nothing
    here commits or rolls back; transaction handling is left to the caller.
    """
    cur = conn.cursor()
    try:
        cur.execute(sql_query)
        if cur.description is None:
            # DB-API: no result set (e.g. DML/DDL), so report the command status.
            return cur.statusmessage
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
    finally:
        cur.close()
def strip_sql_markdown(content: str) -> str:
    """Return the SQL inside a Markdown code fence, or None if there is none.

    Lookup order:

    1. the first fence opened with ```sql (case-insensitive), exactly as
       before;
    2. otherwise the first fence with no language tag (```` ``` ```` followed
       directly by a newline), so bare fences are handled as well.

    Fences tagged with another language (e.g. ```python) are skipped. The
    extracted text is stripped of surrounding whitespace.
    """
    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    # Walk fences pairwise (each match consumes its opening and closing
    # fence), so the closing fence of a tagged block is never mistaken for
    # the opening fence of an untagged one.
    for fence in re.finditer(r"```([^\n`]*)\n(.*?)```", content, re.DOTALL):
        if not fence.group(1).strip():
            return fence.group(2).strip()
    return None
# Fire-and-forget ORM write helper
def save_message_thread_safe(**kwargs):
    """Create a ``Message`` row on a newly started background thread.

    ``kwargs`` are passed unchanged to ``Message.objects.create`` inside
    ``transaction.atomic()``. The call returns immediately and the caller
    gets no confirmation that the write succeeded. Failures are logged with a
    traceback. Despite the name, no locking is added here; the helper only
    moves the write off the calling thread.

    Returns:
        The started ``threading.Thread``. Callers may ``join()`` it (e.g. in
        tests) or ignore the return value.
    """
    def _save():
        import logging
        from django.db import connections

        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        except Exception:
            # The caller has already returned, so record the failure in the
            # logs rather than leaving it only to the thread excepthook.
            logging.getLogger(__name__).exception(
                "Failed to save Message in background thread"
            )
        finally:
            # Django connections are per-thread, and its automatic cleanup is
            # tied to the request cycle; close this thread's connections
            # explicitly so each call does not leak one.
            connections.close_all()

    worker = threading.Thread(target=_save)
    worker.start()
    return worker