import re
import threading

import psycopg2
from django.db import transaction

from .models import Message


# Database connection
def connect_db():
    """Open a new psycopg2 connection from the ``default`` database settings.

    The settings are read from ``server.conf.DATABASES['default']``, which is
    imported inside the function. The connection is returned open; closing it
    is left to the caller.
    """
    from server.conf import DATABASES

    db_conf = DATABASES['default']
    conn = psycopg2.connect(
        host=db_conf['HOST'],
        port=db_conf['PORT'],
        user=db_conf['USER'],
        password=db_conf['PASSWORD'],
        database=db_conf['NAME'],
    )
    return conn


def extract_sql_code(text):
    """Extract a SQL statement from free-form text.

    Tries, in order:

    1. the body of the first ```sql fenced block (tag matched case-insensitively);
    2. the first ``SELECT ...;`` statement (must end with a semicolon).

    Returns:
        The stripped SQL string, or ``None`` when *text* is empty/``None`` or
        neither pattern matches.
    """
    if not text:
        return None
    # Prefer a statement wrapped in a ```sql fence.
    match = re.search(r"```sql\s*(.+?)```", text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    # Fallback: the first SELECT statement terminated by a semicolon.
    match = re.search(r"(SELECT\s.+?;)", text, re.IGNORECASE | re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def get_schema_text(conn, table_names: list):
    """Describe the columns of the given tables in the ``public`` schema.

    Column names and comments are read from the PostgreSQL catalogs
    (``pg_class`` / ``pg_attribute`` / ``pg_description``). If a column comment
    contains the marker "备注" ("remarks"), only the text before it is kept.
    The query also fetches each column's formatted type, which is currently
    not included in the output.

    Args:
        conn: an open psycopg2 connection.
        table_names: names of the tables to describe. Any iterable of ``str``
            is accepted; it is converted to a list before binding.

    Returns:
        A list of ``{"table": name, "text": description}`` dicts, one per table
        found in ``public``, ordered by table name. Tables that are not found
        are omitted. Each description lists one column per line in column
        order, as ``column-comment`` or just ``column`` when it has no comment.
    """
    query = """
        SELECT
            c.relname AS table_name,
            a.attname AS column_name,
            pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
            d.description AS column_comment
        FROM
            pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            JOIN pg_attribute a ON a.attrelid = c.oid
            LEFT JOIN pg_description d
                ON d.objoid = a.attrelid
                AND d.classoid = 'pg_catalog.pg_class'::regclass
                AND d.objsubid = a.attnum
        WHERE
            n.nspname = 'public'
            AND c.relname = ANY(%s)
            AND a.attnum > 0
            AND NOT a.attisdropped
        ORDER BY
            c.relname, a.attnum;
    """
    # psycopg2 binds a Python list as an SQL ARRAY, which ``= ANY(%s)`` needs;
    # a tuple would be bound as a row constructor and fail.
    names = list(table_names)
    with conn.cursor() as cur:  # closes the cursor even if the query fails
        cur.execute(query, (names,))
        rows = cur.fetchall()

    schema = {}
    for table_name, column_name, _data_type, comment in rows:
        if comment and "备注" in comment:
            # Keep only the part of the comment before the "备注" marker.
            comment = comment.split("备注")[0].strip()
        # Columns without a comment (NULL / empty) are listed by name alone
        # instead of rendering as "col-None".
        entry = f"{column_name}-{comment}" if comment else column_name
        schema.setdefault(table_name, []).append(entry)

    return [
        {"table": table, "text": f"表 {table} 包含列:\n" + "\n".join(columns)}
        for table, columns in schema.items()
    ]


# NOTE: Alternative draft of get_schema_text that also renders column types and
# foreign-key targets, e.g. "col (integer -> other(id) -- comment)". It is kept
# commented out for reference and is not used.
# def get_schema_text(conn, table_names: list):
#     cur = conn.cursor()
#
#     # Fetch column names, types and comments
#     column_query = """
#         SELECT
#             c.relname AS table_name,
#             a.attname AS column_name,
#             pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
#             d.description AS column_comment
#         FROM
#             pg_class c
#             JOIN pg_namespace n ON n.oid = c.relnamespace
#             JOIN pg_attribute a ON a.attrelid = c.oid
#             LEFT JOIN pg_description d ON d.objoid = a.attrelid AND d.objsubid = a.attnum
#         WHERE
#             n.nspname = 'public'
#             AND c.relname = ANY(%s)
#             AND a.attnum > 0
#             AND NOT a.attisdropped
#         ORDER BY
#             c.relname, a.attnum;
#     """
#
#     # Fetch foreign-key information
#     fk_query = """
#         SELECT
#             conrelid::regclass::text AS table_name,
#             a.attname AS column_name,
#             confrelid::regclass::text AS foreign_table,
#             af.attname AS foreign_column
#         FROM
#             pg_constraint
#             JOIN pg_class ON conrelid = pg_class.oid
#             JOIN pg_namespace n ON pg_class.relnamespace = n.oid
#             JOIN pg_attribute a ON a.attrelid = conrelid AND a.attnum = ANY(conkey)
#             JOIN pg_attribute af ON af.attrelid = confrelid AND af.attnum = ANY(confkey)
#         WHERE
#             contype = 'f'
#             AND n.nspname = 'public'
#             AND conrelid::regclass::text = ANY(%s);
#     """
#
#     cur.execute(column_query, (table_names,))
#     columns = cur.fetchall()
#
#     cur.execute(fk_query, (table_names,))
#     fks = cur.fetchall()
#
#     # Build the foreign-key lookup
#     fk_map = {}  # {(table, column): "foreign_table(foreign_column)"}
#     for table, column, f_table, f_column in fks:
#         fk_map[(table, column)] = f"{f_table}({f_column})"
#
#     # Assemble the per-table column descriptions
#     schema = {}
#     for table, column, dtype, comment in columns:
#         fk_note = f" -> {fk_map[(table, column)]}" if (table, column) in fk_map else ""
#         comment_note = f" -- {comment}" if comment else ""
#         schema.setdefault(table, []).append(f"{column} ({dtype}{fk_note}{comment_note})")
#     cur.close()
#
#     # Build the output text
#     schema_text = ""
#     for table, cols in schema.items():
#         schema_text += f"表 {table} 包含列:\n - " + "\n - ".join(cols) + "\n"
#     return schema_text


def is_safe_sql(sql: str) -> bool:
    """Return True if *sql* looks like a read-only SELECT/SHOW statement.

    This is a keyword heuristic, not a SQL parser. The statement must start
    with ``select`` or ``show`` (case-insensitive, leading whitespace ignored).
    It must also contain none of the data/schema-modifying keywords below as a
    whole word. Empty or ``None`` input is rejected. The check cannot prove a
    query is harmless, so it is best combined with a read-only database role.
    """
    if not sql:
        return False
    normalized = sql.strip().lower()
    # Both conditions must hold: the prefix check AND the keyword check.
    if not normalized.startswith(("select", "show")):
        return False
    # \b keeps identifiers such as ``created_at`` / ``update_time`` from matching.
    forbidden = r"\b(delete|update|insert|drop|create|alter|truncate|grant|revoke)\b"
    return re.search(forbidden, normalized) is None


def execute_sql(conn, sql_query):
    """Execute *sql_query* on *conn* and return its result.

    Returns:
        For a statement that produces a result set, a list of dicts (one per
        row) mapping column name to value. Otherwise the cursor's status
        message (e.g. ``"INSERT 0 1"``).

    The cursor is closed even if execution fails; driver exceptions propagate
    unchanged. No commit or rollback is performed here.
    """
    with conn.cursor() as cur:
        cur.execute(sql_query)
        # ``description`` is None when the statement returned no rows to fetch
        # (psycopg2's fetchall() would raise ProgrammingError in that case).
        if cur.description is None:
            return cur.statusmessage
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]


def strip_sql_markdown(content: str) -> "str | None":
    """Return the body of a fenced code block in *content*, stripped.

    The first ```sql fence (tag matched case-insensitively) is preferred. If
    there is none, the first untagged fence (``` followed by a newline) is
    used. Returns ``None`` when *content* is empty or has no such fence.
    """
    if not content:
        return None
    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    if match is None:
        # Untagged fence; requiring a newline right after ``` skips fences
        # tagged with another language (e.g. ```python).
        match = re.search(r"```[ \t]*\n(.*?)```", content, re.DOTALL)
    return match.group(1).strip() if match else None


# ORM write wrapper
def save_message_thread_safe(**kwargs):
    """Create a ``Message`` from *kwargs* on a background thread.

    The call returns immediately. The insert runs on a new thread inside
    ``transaction.atomic()``. The started ``threading.Thread`` is returned so
    callers can ``join()`` it if they need to wait. Callers that ignore the
    return value are unaffected.

    Exceptions raised by the insert are not propagated to the caller; they go
    to ``threading``'s default exception hook.
    """
    def _save():
        from django.db import connection

        try:
            with transaction.atomic():
                Message.objects.create(**kwargs)
        finally:
            # Django keeps one DB connection per thread. Close this thread's
            # connection explicitly instead of relying on garbage collection
            # when the thread exits.
            connection.close()

    # NOTE(review): if this is called inside an open outer transaction, the
    # worker thread cannot see that transaction's uncommitted rows (e.g. a
    # parent the message references). Consider transaction.on_commit(...) at
    # the call site.
    worker = threading.Thread(target=_save)
    worker.start()
    return worker