89 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			89 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
import re
 | 
						|
import psycopg2
 | 
						|
import threading
 | 
						|
from django.db import transaction
 | 
						|
from .models import Message
 | 
						|
 | 
						|
# Database connection
def connect_db():
    """Open a new psycopg2 connection using the ``default`` database settings.

    The settings dict is read from ``server.conf.DATABASES['default']`` at
    call time; the import is local to this function (presumably to avoid an
    import cycle or early settings access — TODO confirm).

    Returns:
        A freshly opened psycopg2 connection. It is not closed here, so the
        caller owns it.
    """
    from server.conf import DATABASES

    settings = DATABASES['default']
    # psycopg2.connect keyword -> settings key, looked up in this order.
    option_keys = (
        ('host', 'HOST'),
        ('port', 'PORT'),
        ('user', 'USER'),
        ('password', 'PASSWORD'),
        ('database', 'NAME'),
    )
    return psycopg2.connect(**{kwarg: settings[key] for kwarg, key in option_keys})
 | 
						|
 | 
						|
def extract_sql_code(text):
    """Pull a SQL statement out of free-form text.

    Two patterns are tried in order (both case-insensitive, ``.`` spanning
    newlines):

    1. the body of the first ```sql ... ``` fenced block;
    2. otherwise, the first ``SELECT <whitespace> ... ;`` run, up to and
       including the first semicolon.

    Returns:
        The matched statement with surrounding whitespace stripped, or
        ``None`` when neither pattern matches.
    """
    flags = re.IGNORECASE | re.DOTALL
    candidates = (
        r"```sql\s*(.+?)```",  # preferred: fenced ```sql block
        r"(SELECT\s.+?;)",      # fallback: first semicolon-terminated SELECT
    )
    for pattern in candidates:
        found = re.search(pattern, text, flags)
        if found:
            return found.group(1).strip()
    return None
 | 
						|
 | 
						|
 | 
						|
def get_schema_text(conn, table_names: list):
    """Describe the columns of the given ``public``-schema tables as text.

    Queries ``information_schema.columns`` and renders one line per table,
    ``表<table> 包含列:<col> (<type>), ...`` ("table <table> contains columns:
    ..."), with columns listed in their declared (ordinal) order.

    Args:
        conn: an open psycopg2 connection.
        table_names: table names in the ``public`` schema. Names with no
            matching columns (e.g. nonexistent tables) are simply absent from
            the output.

    Returns:
        The newline-terminated description, or ``""`` when ``table_names`` is
        empty or no columns were found.
    """
    names = tuple(table_names)
    # psycopg2 renders an empty tuple as ``()``, and ``IN ()`` is a syntax
    # error in PostgreSQL, so skip the query entirely.
    if not names:
        return ""

    # ORDER BY makes the output deterministic and lists columns in table order.
    query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name IN %s
        ORDER BY table_name, ordinal_position;
    """
    # The context manager closes the cursor even if execute()/fetchall() raises.
    with conn.cursor() as cur:
        cur.execute(query, (names,))
        rows = cur.fetchall()

    schema = {}
    for table_name, column_name, data_type in rows:
        schema.setdefault(table_name, []).append(f"{column_name} ({data_type})")

    return "".join(
        f"表{table_name} 包含列:{', '.join(columns)}\n"
        for table_name, columns in schema.items()
    )
 | 
						|
 | 
						|
 | 
						|
def is_safe_sql(sql: str) -> bool:
    """Heuristically decide whether ``sql`` is a read-only statement.

    ``sql`` is accepted only if BOTH hold:

    * after stripping and lower-casing, it starts with ``select`` or ``show``;
    * no write/DDL keyword appears anywhere as a whole word, so stacked
      statements such as ``SELECT 1; DROP TABLE t`` are rejected.

    Whole-word matching keeps identifiers like ``created_at`` or
    ``update_time`` from causing false rejections. This is a keyword
    blacklist, not a SQL parser; treat it as a guard, not a security boundary.
    """
    normalized = sql.strip().lower()
    if not normalized.startswith(("select", "show")):
        return False
    # \b word boundaries avoid hits inside identifiers (e.g. "created_at").
    forbidden = r"\b(?:delete|update|insert|drop|create|alter|truncate|grant|revoke)\b"
    return re.search(forbidden, normalized) is None
 | 
						|
 | 
						|
def execute_sql(conn, sql_query):
    """Execute ``sql_query`` on ``conn`` and return its outcome.

    Returns:
        If the statement produced a result set: a list of dicts, one per row,
        mapping column name to value. Otherwise (e.g. DML/DDL): the cursor's
        ``statusmessage`` string as reported by the driver.

    Raises:
        Whatever ``cursor.execute`` raises. The cursor is closed in every
        case; no rollback is performed here.
    """
    # The context manager closes the cursor even when execute() raises.
    with conn.cursor() as cur:
        cur.execute(sql_query)
        # ``description`` is None when there is no result set; in that case
        # psycopg2's fetchall() would raise ProgrammingError.
        if cur.description is None:
            return cur.statusmessage
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
 | 
						|
 | 
						|
def strip_sql_markdown(content: str) -> "str | None":
    """Return the body of the first ```sql fenced block in ``content``.

    Matching is case-insensitive (```SQL also works) and spans newlines.
    Only fences tagged ``sql`` are recognised; a bare ``` fence without a
    language tag is NOT matched.

    Returns:
        The block body with surrounding whitespace stripped (``""`` for an
        empty block), or ``None`` when ``content`` has no ```sql fence.
    """
    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None
 | 
						|
        
 | 
						|
# ORM write wrapper
def save_message_thread_safe(**kwargs):
    """Create a ``Message`` row from ``kwargs`` on a background thread.

    Fire-and-forget: the call returns as soon as the worker thread starts.
    The write runs inside its own ``transaction.atomic()`` block. Despite the
    name, nothing here synchronises concurrent callers; each call simply
    gets its own thread.

    Returns:
        The started ``threading.Thread``, so callers that must wait for the
        write can ``join()`` it. Callers that ignore the return value (the
        original behaviour returned ``None``) are unaffected.
    """
    from django.db import connections

    def _save():
        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        finally:
            # Django DB connections are thread-local, and Django only cleans
            # them up automatically for request-handling threads. Close this
            # worker's connections so each call does not leak one.
            connections.close_all()

    worker = threading.Thread(target=_save)
    worker.start()
    return worker
 |