66 lines
1.8 KiB
Python
66 lines
1.8 KiB
Python
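"""PostgreSQL connection and Session factory.

`ZCBOT_DB_URL` is required and must be a standard SQLAlchemy URL, e.g.:

    postgresql+psycopg://user:pass@host:5432/zcbot

If it is not set, get_engine() raises RuntimeError whose message explains how
to set it (it deliberately does not point users at a docker setup).

Rows in `users` are inserted on demand by the web auth entry point
(`web.auth.ensure_user_row`); the engine layer no longer bootstraps them.
"""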
|
|
from __future__ import annotations

import os
import threading
from contextlib import contextmanager
from typing import Iterator, Optional

from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
# Lazily-initialized process-wide singletons; get_engine() populates both
# on its first call.
_engine: Optional[Engine] = None
_SessionLocal: Optional[sessionmaker[Session]] = None

# Message of the RuntimeError raised by _read_db_url() when ZCBOT_DB_URL is
# missing: states the problem, then shows an example export line.
_DB_URL_HINT = "\n".join(
    (
        "ZCBOT_DB_URL is not set.",
        " export ZCBOT_DB_URL='postgresql+psycopg://user:pass@host:5432/dbname'",
        " (local: dev/staging PG; SaaS: production PG)",
    )
)
|
|
|
|
|
|
def _read_db_url() -> str:
    """Return the ``ZCBOT_DB_URL`` environment value with surrounding whitespace removed.

    Raises:
        RuntimeError: the variable is absent, empty, or whitespace-only; the
            exception message is ``_DB_URL_HINT``.
    """
    raw_value = os.environ.get("ZCBOT_DB_URL", "")
    db_url = raw_value.strip()
    if db_url:
        return db_url
    raise RuntimeError(_DB_URL_HINT)
|
|
|
|
|
|
# Serializes the one-time engine/sessionmaker initialization in get_engine().
_ENGINE_INIT_LOCK = threading.Lock()


def get_engine() -> Engine:
    """Return the process-wide singleton Engine, creating it on first call.

    The first successful call reads the URL via ``_read_db_url()``, creates
    the Engine with ``pool_pre_ping=True``, and also creates the module-level
    sessionmaker bound to it (``expire_on_commit=False``).

    Initialization is guarded by a lock with a re-check, so concurrent first
    calls create exactly one Engine. Afterwards this is a lock-free read.
    Connection pooling itself is handled by SQLAlchemy.

    Raises:
        RuntimeError: ``ZCBOT_DB_URL`` is unset or blank (from ``_read_db_url``).
            Nothing is cached in that case, so a later call retries.
    """
    global _engine, _SessionLocal
    if _engine is None:
        with _ENGINE_INIT_LOCK:
            # Re-check: another thread may have finished init while we waited.
            if _engine is None:
                engine = create_engine(_read_db_url(), pool_pre_ping=True, future=True)
                _SessionLocal = sessionmaker(
                    bind=engine, expire_on_commit=False, future=True
                )
                # Publish the engine last. get_sessionmaker() returns early from
                # get_engine() once _engine is non-None and then expects
                # _SessionLocal to be set, so it must be assigned first.
                _engine = engine
    return _engine
|
|
|
|
|
|
def get_sessionmaker() -> sessionmaker[Session]:
    """Return the shared sessionmaker, initializing the engine if necessary.

    The sessionmaker is created by get_engine() together with the engine, so
    the first call here can raise anything get_engine() raises.
    """
    factory = _SessionLocal
    if factory is None:
        get_engine()  # side effect: populates the module-level _SessionLocal
        factory = _SessionLocal
    assert factory is not None  # narrows Optional for type checkers
    return factory
|
|
|
|
|
|
@contextmanager
def session_scope() -> Iterator[Session]:
    """Yield a Session whose work forms one transaction.

    Usage: ``with session_scope() as session: ...``.

    - If the ``with`` body completes, the session is committed.
    - If the body or the commit raises an ``Exception``, the session is
      rolled back and the exception re-raised.
    - The session is closed in every case.
    """
    session = get_sessionmaker()()
    try:
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
    finally:
        session.close()
|