import re
import threading
from typing import Optional

import psycopg2
from django.db import connections, transaction

from .models import Message

# A ```sql ... ``` fenced block with a non-empty body (used by extract_sql_code).
_SQL_FENCE_RE = re.compile(r"```sql\s*(.+?)```", re.DOTALL | re.IGNORECASE)
# Fallback: the first SELECT statement terminated by a semicolon.
_SELECT_STMT_RE = re.compile(r"(SELECT\s.+?;)", re.IGNORECASE | re.DOTALL)
# A ```sql ... ``` fenced block whose body may be empty (used by strip_sql_markdown).
_SQL_FENCE_ANY_BODY_RE = re.compile(r"```sql\s*(.*?)```", re.DOTALL | re.IGNORECASE)
# Keywords that mark a write / DDL / privilege statement. Word boundaries keep
# identifiers such as ``created_at`` or ``updated_by`` from being rejected.
_FORBIDDEN_SQL_RE = re.compile(
    r"\b(?:delete|update|insert|drop|create|alter|truncate|grant|revoke)\b",
    re.IGNORECASE,
)


def connect_db():
    """Open a new psycopg2 connection from the ``default`` entry of ``server.conf.DATABASES``.

    ``server.conf`` is imported inside the function rather than at module top
    (presumably to avoid an import-time dependency or cycle -- confirm).

    Returns:
        A new psycopg2 connection. The caller is responsible for closing it.
    """
    from server.conf import DATABASES

    db_conf = DATABASES['default']
    conn = psycopg2.connect(
        host=db_conf['HOST'],
        port=db_conf['PORT'],
        user=db_conf['USER'],
        password=db_conf['PASSWORD'],
        database=db_conf['NAME'],
    )
    return conn


def extract_sql_code(text) -> Optional[str]:
    """Pull a SQL statement out of free-form (e.g. LLM-generated) text.

    Prefers the body of a ```sql fenced block. Otherwise falls back to the
    first ``SELECT ... ;`` statement found anywhere in *text*.

    Returns:
        The stripped SQL string, or ``None`` if nothing matched or *text* is empty.
    """
    if not text:
        return None
    match = _SQL_FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    match = _SELECT_STMT_RE.search(text)
    if match:
        return match.group(1).strip()
    return None


def get_schema_text(conn, table_names: list):
    """Describe the columns of the given ``public``-schema tables as text.

    Produces one line per table of the form
    ``表<table> 包含列:<col> (<type>), ...`` (the format string is runtime
    output and is kept verbatim). Columns are listed in their declared order.

    Args:
        conn: An open psycopg2 connection.
        table_names: Names of tables in the ``public`` schema.

    Returns:
        The schema description. It is ``""`` when *table_names* is empty,
        because SQL ``IN ()`` is a syntax error. Tables not found in the
        database are silently omitted.
    """
    if not table_names:
        return ""
    query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name IN %s
        ORDER BY table_name, ordinal_position;
    """
    schema = {}
    # The context manager closes the cursor even if execute() raises.
    with conn.cursor() as cur:
        cur.execute(query, (tuple(table_names),))
        for table_name, column_name, data_type in cur.fetchall():
            schema.setdefault(table_name, []).append(f"{column_name} ({data_type})")
    return "".join(
        f"表{table_name} 包含列:{', '.join(columns)}\n"
        for table_name, columns in schema.items()
    )


def is_safe_sql(sql: str) -> bool:
    """Heuristically decide whether *sql* is a single read-only statement.

    Accepts only statements that:
    - start with ``SELECT`` or ``SHOW``;
    - contain no second statement (one trailing ``;`` is allowed);
    - contain no write/DDL/privilege keyword as a whole word.

    A ``;`` inside a string literal also causes rejection. This fails on the
    safe side.

    NOTE: a keyword deny-list is not a security boundary. Running generated SQL
    under a read-only database role (or a READ ONLY transaction) is the real
    protection.

    Returns:
        ``True`` if the statement passes all checks, ``False`` otherwise
        (including for ``None`` or empty input).
    """
    if not sql:
        return False
    normalized = sql.strip().lower()
    if not normalized.startswith(("select", "show")):
        return False
    # Reject stacked statements such as "select 1; drop table t".
    if ";" in normalized.rstrip(";"):
        return False
    return _FORBIDDEN_SQL_RE.search(normalized) is None


def execute_sql(conn, sql_query):
    """Execute *sql_query* on *conn* and return its result.

    No commit is issued here. Whether any side effects persist depends on the
    caller's transaction handling of *conn*.

    Returns:
        For statements that produce rows, a list of ``{column_name: value}``
        dicts. Otherwise, the driver's status message string (e.g. ``"UPDATE 3"``).
    """
    with conn.cursor() as cur:
        cur.execute(sql_query)
        # psycopg2 leaves ``description`` as None for statements that return no rows.
        if cur.description is None:
            return cur.statusmessage
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]


def strip_sql_markdown(content: str) -> Optional[str]:
    """Return the body of the first ```sql fenced block in *content*.

    Only fences tagged ``sql`` (case-insensitive) are recognised. A bare ```
    fence is NOT unwrapped.

    Returns:
        The stripped fence body (possibly ``""``), or ``None`` if no ```sql
        fence is present.
    """
    match = _SQL_FENCE_ANY_BODY_RE.search(content)
    if match:
        return match.group(1).strip()
    return None


def save_message_thread_safe(**kwargs):
    """Create a ``Message`` row in a background thread (fire-and-forget).

    The insert runs inside ``transaction.atomic()`` on a new, non-daemon
    thread. The caller does not wait for it, and exceptions raised in the
    thread do not reach the caller.

    Django gives each thread its own database connections. The worker closes
    them when it finishes so that repeated calls do not leak connections.

    Args:
        **kwargs: Field values passed to ``Message.objects.create``.

    Returns:
        The started ``threading.Thread``, so callers may ``join()`` it if needed.
    """
    def _save():
        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        finally:
            # Connections opened by this worker thread are not reused; close them.
            connections.close_all()

    worker = threading.Thread(target=_save)
    worker.start()
    return worker