zcbot/web/auth.py

188 lines
7.2 KiB
Python

"""Auth: 两条 login 路径,签同款 JWT(§7 D' 过渡形态)。
模型:
- `PLATFORM_KEY` env(必填):platform 服务端 / zcbot 间机器对机器共享密钥
- `JWT_SECRET` env(必填):HS256 签 token;泄漏 = 任意伪造,与 PLATFORM_KEY 同级保护
- `POST /v1/auth/login {user_id, platform_key}` → JWT(platform 服务端用,自带 user_id 注入)
- `POST /v1/auth/login_password {email, password}` → JWT
(dev SPA 用,users.email UNIQUE + users.password_hash bcrypt 校验;0005 加 UNIQUE)
- 后续 `/v1/*`(除 /healthz、/docs、/openapi.json、/、/v1/auth/login*)走 `Authorization: Bearer <jwt>`
- Token TTL: `ZCBOT_JWT_TTL_SECONDS` env 覆盖,默 7 天
发用户:`.venv/Scripts/python.exe main.py user add --email X --password Y`,后台直接
bcrypt + INSERT users;撤用户 `DELETE FROM users WHERE email=...`(messages CASCADE,
tasks 通过 FK 拦,要先 DELETE 该 user 的 tasks)。
OIDC(D')替换:只动 `/v1/auth/login` 实现(校验 ID token 代替 key);password 路径
真发布时下线。
"""
from __future__ import annotations
import os
import time
from typing import Optional
from uuid import UUID
import bcrypt
import jwt
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import select
from core.storage import session_scope
from core.storage.models import User
_DEFAULT_TTL_SECONDS = 7 * 24 * 3600  # 7 days


class AuthConfig:
    """Auth settings resolved from the environment.

    Holds the platform shared key, the JWT signing secret and the token TTL
    in seconds. Build it with `AuthConfig.from_env()`, which validates that
    the required variables are present.
    """

    def __init__(self, platform_key: str, jwt_secret: str, ttl_seconds: int):
        self.platform_key = platform_key
        self.jwt_secret = jwt_secret
        self.ttl_seconds = ttl_seconds

    @classmethod
    def from_env(cls) -> "AuthConfig":
        """Read and validate `PLATFORM_KEY`, `JWT_SECRET` and the optional TTL.

        Returns:
            An `AuthConfig` with whitespace-stripped values. The TTL comes
            from `ZCBOT_JWT_TTL_SECONDS` when set, otherwise it defaults to
            `_DEFAULT_TTL_SECONDS`.

        Raises:
            RuntimeError: if `PLATFORM_KEY` or `JWT_SECRET` is unset/blank,
                or if `ZCBOT_JWT_TTL_SECONDS` is not a positive integer.
        """
        key = os.environ.get("PLATFORM_KEY", "").strip()
        secret = os.environ.get("JWT_SECRET", "").strip()

        # Report every missing variable in one error instead of one at a time.
        missing = []
        if not key:
            missing.append("PLATFORM_KEY")
        if not secret:
            missing.append("JWT_SECRET")
        if missing:
            raise RuntimeError(
                f"{', '.join(missing)} env not set. zcbot web requires both:\n"
                " PLATFORM_KEY=<shared secret between platform and zcbot>\n"
                " JWT_SECRET=<HMAC secret used to sign session tokens>"
            )

        ttl_raw = os.environ.get("ZCBOT_JWT_TTL_SECONDS", "").strip()
        try:
            ttl = int(ttl_raw) if ttl_raw else _DEFAULT_TTL_SECONDS
        except ValueError as err:
            # Chain explicitly so the traceback shows the parse failure as the
            # direct cause rather than "during handling of the above...".
            raise RuntimeError(
                f"ZCBOT_JWT_TTL_SECONDS must be int seconds, got {ttl_raw!r}"
            ) from err
        if ttl <= 0:
            raise RuntimeError(f"ZCBOT_JWT_TTL_SECONDS must be > 0, got {ttl}")
        return cls(platform_key=key, jwt_secret=secret, ttl_seconds=ttl)
def hash_password(password: str) -> str:
    """Hash `password` with bcrypt using a freshly generated salt.

    The password is UTF-8 encoded before hashing, and the resulting bcrypt
    hash is decoded as ASCII so it can be stored directly in a text column.
    The work factor is whatever `bcrypt.gensalt()` defaults to.
    """
    secret = password.encode("utf-8")
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(secret, salt)
    return digest.decode("ascii")
def verify_password(password: str, stored_hash: str) -> bool:
    """Check `password` against a stored bcrypt hash.

    Returns True only when bcrypt reports a match. A stored hash that is
    malformed or not ASCII-encodable is treated as a mismatch (False)
    instead of propagating the error to the caller.
    """
    try:
        candidate = password.encode("utf-8")
        reference = stored_hash.encode("ascii")
        matched = bcrypt.checkpw(candidate, reference)
    except (ValueError, TypeError):
        # Corrupt / hand-written hash value: report "no match", do not crash.
        return False
    return matched
def resolve_user_by_email(email: str, password: str) -> Optional[tuple[UUID, str]]:
    """Resolve `email` + `password` to `(user_id, email)`, or None on any mismatch.

    The email is stripped and lower-cased before lookup. None is returned when
    either input is empty, no user has that email, the user has no password
    hash, or the password does not verify. Nothing is cached: each call does
    one SELECT plus one bcrypt verification.

    To avoid leaking account state through response time, a dummy bcrypt
    verification is run whenever there is no real hash to check. That covers
    both an unknown email and an account that exists without a password.
    """
    normalized = (email or "").strip().lower()
    if not normalized or not password:
        return None

    # Only plain column values are selected, so the row stays usable after the
    # session is released; the bcrypt work below then runs without holding a
    # DB session open.
    with session_scope() as s:
        row = s.execute(
            select(User.user_id, User.email, User.password_hash).where(
                User.email == normalized
            )
        ).first()

    if row is None or not row.password_hash:
        # No hash to compare against: spend roughly the same time as a real
        # verification so the two "no credential" cases are indistinguishable
        # from a wrong password by latency.
        bcrypt.checkpw(b"x", b"$2b$12$" + b"." * 53)
        return None
    if not verify_password(password, row.password_hash):
        return None
    return row.user_id, row.email
def mint_token(cfg: AuthConfig, user_id: UUID) -> tuple[str, int]:
    """Sign an HS256 JWT for `user_id`.

    The token carries `sub` (the user id as a string), `iat` (current Unix
    time) and `exp` (`iat` plus `cfg.ttl_seconds`).

    Returns:
        `(token, exp)` where `exp` is the expiry as Unix seconds.
    """
    issued_at = int(time.time())
    expires_at = issued_at + cfg.ttl_seconds
    claims = {"sub": str(user_id), "iat": issued_at, "exp": expires_at}
    return jwt.encode(claims, cfg.jwt_secret, algorithm="HS256"), expires_at
def verify_token(cfg: AuthConfig, token: str) -> UUID:
    """Verify an HS256 JWT and return its `sub` claim as a UUID.

    Raises:
        HTTPException: 401 when the token is expired, fails validation, or
            carries a `sub` claim that is missing, not a string, or not a
            well-formed UUID.
    """
    try:
        payload = jwt.decode(token, cfg.jwt_secret, algorithms=["HS256"])
    except jwt.ExpiredSignatureError as err:
        raise HTTPException(401, "token expired") from err
    except jwt.InvalidTokenError as err:
        raise HTTPException(401, f"invalid token: {err}") from err

    sub = payload.get("sub", "")
    # uuid.UUID() calls str methods on its argument, so a non-string claim
    # (e.g. an integer) would raise AttributeError and surface as a 500.
    # Reject it explicitly as an authentication failure instead.
    if not isinstance(sub, str):
        raise HTTPException(401, f"invalid sub in token: {sub!r}")
    try:
        return UUID(sub)
    except (ValueError, TypeError) as err:
        raise HTTPException(401, f"invalid sub in token: {sub!r}") from err
def ensure_user_row(user_id: UUID) -> None:
    """Insert a placeholder `users` row for `user_id` if none exists yet.

    Uses a PostgreSQL `INSERT ... ON CONFLICT (user_id) DO NOTHING`, so
    calling it repeatedly for the same id is harmless. Per the original
    author's note, this serves the platform-key login path, where the
    platform may pass a user id this service has not seen before and
    downstream foreign keys need the row to exist.
    """
    from sqlalchemy.dialects.postgresql import insert as pg_insert

    placeholder = pg_insert(User).values(user_id=user_id)
    placeholder = placeholder.on_conflict_do_nothing(index_elements=["user_id"])
    with session_scope() as session:
        session.execute(placeholder)
# ──────────────── FastAPI Depends ────────────────
# auto_error=False: when the Authorization header is absent or unusable the
# dependency yields None, so this module can raise its own 401 message rather
# than FastAPI's built-in "Not authenticated" error.
_bearer = HTTPBearer(auto_error=False)
def make_require_user(cfg: AuthConfig):
    """Build a FastAPI dependency that authenticates the caller's bearer token.

    The returned coroutine closes over `cfg`, so the configuration captured
    here is what every request is verified against.

    Usage:
        require_user = make_require_user(cfg)

        @app.get("/v1/...")
        def route(user_id: UUID = Depends(require_user)):
            ...

    Declaring it as a parameter both enforces authentication and hands the
    route the authenticated user id.
    """

    async def require_user(
        creds: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
    ) -> UUID:
        # No header at all, or a header with an empty token part.
        token = creds.credentials if creds is not None else None
        if not token:
            raise HTTPException(401, "missing Authorization: Bearer <token>")
        if creds.scheme.lower() != "bearer":
            raise HTTPException(401, f"unsupported auth scheme: {creds.scheme!r}")
        return verify_token(cfg, token)

    return require_user
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
"AuthConfig",
"ensure_user_row",
"hash_password",
"make_require_user",
"mint_token",
"resolve_user_by_email",
"verify_password",
"verify_token",
]