zcbot/core/storage/disk_quota.py

207 lines
6.8 KiB
Python

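"""Per-user workspace disk quota (§7.5 #4 soft quota, enforced as an application-layer gate).
Entry points:
- `scan_user_dir(user_root) -> (bytes, count)` ── sums file sizes with os.walk; skips the
  top-level `.zcbot_tmp` / `.memory` entries and any entry whose stat fails
- `upsert_user_usage(user_id, bytes, count)` ── persists the result into the user_disk_usage table
- `check_disk_quota(user_id, limit_bytes) -> Optional[str]` ── pre-write check; returns
  None = allow / str = rejection reason. `limit_bytes <= 0` short-circuits to allow (no limit)
- `scan_all_users(user_root_base) -> int` ── run periodically by a lifespan background task;
  users are scanned one after another (finish one, then the next) to avoid an IO storm
Byte-size parsing (yaml `disk_bytes_per_user`):
- plain integer bytes / "5gb" / "500mb" / "1.5g" etc., with case-insensitive suffixes
- returns None on failure; the caller treats that as "no limit"
"""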
from __future__ import annotations

import os
import re
from datetime import datetime
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Tuple
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert

from .engine import session_scope
from .models import UserDiskUsage
# yaml byte-size parsing: 5gb / 500mb / 1024 / 1.5g / 2GiB
# Unit = optional k/m/g/t, optionally followed by "i" (binary-prefix spelling such as "Gi")
# and/or "b". Combinations that are not keys of _UNIT_FACTORS (e.g. "ib") are rejected.
_SIZE_RE = re.compile(r"^\s*([\d.]+)\s*([kmgt]?i?b?)\s*$", re.IGNORECASE)
# All units are binary: 1 k = 1024 bytes, whichever spelling is used.
_UNIT_FACTORS = {
    "": 1, "b": 1,
    "k": 1024, "kb": 1024, "ki": 1024, "kib": 1024,
    "m": 1024 ** 2, "mb": 1024 ** 2, "mi": 1024 ** 2, "mib": 1024 ** 2,
    "g": 1024 ** 3, "gb": 1024 ** 3, "gi": 1024 ** 3, "gib": 1024 ** 3,
    "t": 1024 ** 4, "tb": 1024 ** 4, "ti": 1024 ** 4, "tib": 1024 ** 4,
}


def parse_bytes(value) -> Optional[int]:
    """Convert a yaml byte-size value to an int number of bytes.

    Accepted inputs:
      * ``int`` -- returned unchanged;
      * finite ``float`` (e.g. ``1073741824.0``) -- truncated toward zero;
      * ``str`` -- a number plus an optional case-insensitive unit: ``b``,
        ``k``/``kb``/``ki``/``kib``, and likewise for ``m``, ``g``, ``t``.

    Returns:
        The byte count, or None when the value cannot be parsed (callers treat
        None as "no limit"). ``bool`` is rejected even though it subclasses
        ``int``: a yaml ``yes``/``true`` is not a size.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        try:
            return int(value)
        except (OverflowError, ValueError):  # inf / nan
            return None
    if not isinstance(value, str):
        return None
    m = _SIZE_RE.match(value)
    if m is None:
        return None
    num_s, unit_s = m.group(1), (m.group(2) or "").lower()
    factor = _UNIT_FACTORS.get(unit_s)
    if factor is None:
        return None
    try:
        if "." not in num_s:
            # Pure digits: exact integer arithmetic, no float rounding above 2**53.
            return int(num_s) * factor
        # A huge literal parses to inf and int(inf) raises OverflowError;
        # malformed numbers such as "1.2.3" raise ValueError.
        return int(float(num_s) * factor)
    except (ValueError, OverflowError):
        return None
# 扫描跳过的 dotfile 顶层名(节省 IO,且 /v1/files API 也隐藏)
_SKIP_TOPLEVEL = frozenset({".zcbot_tmp", ".memory"})
def scan_user_dir(user_root: Path) -> Tuple[int, int]:
"""os.walk 累加 user_root 下所有文件大小,返 (bytes, count)。
跳过顶层 .zcbot_tmp / .memory(开发期临时 + 用户记忆 dotfile,不算入产品配额);
follow_symlinks=False 防 symlink 循环爆。
"""
if not user_root.exists() or not user_root.is_dir():
return 0, 0
total_bytes = 0
total_count = 0
try:
for entry in os.scandir(user_root):
if entry.name in _SKIP_TOPLEVEL:
continue
try:
if entry.is_file(follow_symlinks=False):
try:
total_bytes += entry.stat(follow_symlinks=False).st_size
total_count += 1
except OSError:
pass
elif entry.is_dir(follow_symlinks=False):
sub_b, sub_c = _walk_dir(Path(entry.path))
total_bytes += sub_b
total_count += sub_c
except OSError:
pass
except OSError:
pass
return total_bytes, total_count
def _walk_dir(d: Path) -> Tuple[int, int]:
total_b, total_c = 0, 0
for root, dirs, files in os.walk(d, followlinks=False, onerror=lambda _e: None):
for f in files:
try:
st = os.stat(os.path.join(root, f), follow_symlinks=False)
total_b += st.st_size
total_c += 1
except OSError:
pass
return total_b, total_c
def upsert_user_usage(user_id: UUID, bytes_used: int, file_count: int) -> None:
    """Write the user's single ``user_disk_usage`` row: INSERT first time, UPDATE afterwards.

    ``scanned_at`` is stamped with the database ``now()`` on both the insert and the
    update path, so a freshly inserted row does not depend on whatever column default
    the model may or may not declare. The statement is PostgreSQL-specific
    (``INSERT ... ON CONFLICT (user_id) DO UPDATE``).
    """
    from sqlalchemy import func

    ins = pg_insert(UserDiskUsage).values(
        user_id=user_id,
        bytes_used=bytes_used,
        file_count=file_count,
        scanned_at=func.now(),
    )
    # On conflict, reuse the values proposed for insertion (the EXCLUDED pseudo-row)
    # so the insert and update paths can never drift apart.
    upsert = ins.on_conflict_do_update(
        index_elements=["user_id"],
        set_={
            "bytes_used": ins.excluded.bytes_used,
            "file_count": ins.excluded.file_count,
            "scanned_at": ins.excluded.scanned_at,
        },
    )
    with session_scope() as s:
        s.execute(upsert)
def get_user_usage(user_id: UUID) -> Optional[Tuple[int, int, Optional[datetime]]]:
    """Fetch the latest persisted scan for ``user_id``.

    Returns:
        ``(bytes_used, file_count, scanned_at)`` from ``user_disk_usage``, with the first
        two coerced to ``int``; ``None`` when the user has no row yet.
    """
    columns = (
        UserDiskUsage.bytes_used,
        UserDiskUsage.file_count,
        UserDiskUsage.scanned_at,
    )
    query = select(*columns).where(UserDiskUsage.user_id == user_id)
    with session_scope() as session:
        hit = session.execute(query).first()
        if hit is not None:
            used, files, scanned_at = hit
            return int(used), int(files), scanned_at
    return None
def check_disk_quota(user_id: UUID, limit_bytes: Optional[int]) -> Optional[str]:
    """Pre-write gate: return an error message when the user is over quota, else None.

    The message is meant to be read directly by the LLM.

    Args:
        user_id: user whose last persisted scan result is checked.
        limit_bytes: quota in bytes. ``None`` (what ``parse_bytes`` returns for an
            unparsable config value) or ``<= 0`` short-circuits to "allow" (no limit).

    When the user has no scan record yet (before the first scan), the write is
    allowed so that writes are not all blocked during cold start. The gate takes
    effect once the periodic background scan has recorded the user's usage. Usage
    equal to the limit already counts as over quota.
    """
    if limit_bytes is None or limit_bytes <= 0:
        return None
    usage = get_user_usage(user_id)
    if usage is None:
        return None  # not scanned yet: allow; the gate applies after the first scan
    used, _, _ = usage
    if used < limit_bytes:
        return None
    used_mb = used / (1024 ** 2)
    limit_mb = limit_bytes / (1024 ** 2)
    return (
        f"[Error] 已达磁盘配额上限({used_mb:.1f} MB / {limit_mb:.1f} MB);"
        f"清理旧产物或联系管理员升配后重试"
    )
def list_user_ids_with_root(user_root_base: Path) -> List[UUID]:
    """Return the ids of users that own a workspace directory under ``user_root_base``.

    Only direct subdirectories (symlinks are not followed) whose name is exactly the
    canonical ``str(UUID)`` spelling are returned. Other spellings that ``UUID()`` still
    accepts (bare 32-char hex, upper case, braces, ``urn:uuid:`` prefix) are ignored,
    because ``scan_all_users`` rebuilds the path as ``user_root_base / str(uid)``. It
    would otherwise scan a directory that does not exist and record zero usage.

    The users table is deliberately not consulted. Users who never sent a message have
    no workspace directory and no disk usage, so they need no placeholder row.
    Returns ``[]`` when ``user_root_base`` is missing. If the listing itself fails,
    the ids collected so far are returned.
    """
    if not user_root_base.is_dir():
        return []
    out: List[UUID] = []
    try:
        with os.scandir(user_root_base) as it:
            for entry in it:
                try:
                    if not entry.is_dir(follow_symlinks=False):
                        continue
                except OSError:
                    continue  # one unreadable entry must not hide the remaining users
                try:
                    uid = UUID(entry.name)
                except ValueError:
                    continue
                if str(uid) == entry.name:
                    out.append(uid)
    except OSError:
        pass  # best effort: keep whatever was collected before the listing failed
    return out
def scan_all_users(
    user_root_base: Path,
    *,
    on_error: Optional[Callable[[UUID, Exception], None]] = None,
) -> int:
    """Scan every user's workspace, persist the result, and return how many users succeeded.

    Meant to be called by the lifespan background task. Users are processed serially
    (one finishes before the next starts) to avoid an IO storm, so total time grows
    linearly with the number of users. A user whose scan or upsert raises is skipped
    and will be retried on the next periodic run.

    Args:
        user_root_base: directory holding one ``<user_id>`` subdirectory per user.
        on_error: optional hook, called as ``on_error(uid, exc)`` for each failed user,
            so the caller can log or report it. When omitted, the failure is logged
            with its traceback through this module's logger. Exceptions raised by the
            hook itself propagate.
    """
    import logging

    count = 0
    for uid in list_user_ids_with_root(user_root_base):
        try:
            b, c = scan_user_dir(user_root_base / str(uid))
            upsert_user_usage(uid, b, c)
        except Exception as exc:
            # One failing user must not block the others; it is retried next cycle.
            if on_error is not None:
                on_error(uid, exc)
            else:
                logging.getLogger(__name__).warning(
                    "disk quota scan failed for user %s", uid, exc_info=True
                )
            continue
        count += 1
    return count