203 lines
6.7 KiB
Python
203 lines
6.7 KiB
Python
"""Per-user 工作目录配额(§7.5 #4 软配额,应用层 gate)。
|
|
|
|
调用入口:
|
|
- `scan_user_dir(user_root) -> (bytes, count)` ── os.walk 累加,跳 dotfile / 损坏 stat
|
|
- `upsert_user_usage(user_id, bytes, count)` ── 落 user_disk_usage 表
|
|
- `check_disk_quota(user_id, limit_bytes) -> Optional[str]` ── 写前查,返 None=放行 /
|
|
str=拒绝原因。`limit_bytes <= 0` 短路放行(不限)
|
|
- `scan_all_users(user_root_base) -> int` ── lifespan 后台 task 周期跑,
|
|
per user 跑完后下一个,避免 IO 风暴
|
|
|
|
字节单位解析(yaml `disk_bytes_per_user`):
|
|
- 整数字节 / "5gb" / "500mb" / "1.5g" 等 case-insensitive 后缀
|
|
- 失败返 None,caller 视为不限
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Iterable, List, Optional, Tuple
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
|
|
from .engine import session_scope
|
|
from .models import UserDiskUsage
|
|
|
|
|
|
# Size-string parsing for YAML values such as "5gb", "500mb", "1024", "1.5g".
# Group 1 is the numeric part; group 2 is the optional, case-insensitive unit.
_SIZE_RE = re.compile(r"^\s*([\d.]+)\s*([kmgt]?b?)?\s*$", re.IGNORECASE)

# Binary (1024-based) multipliers keyed by the lower-cased unit suffix.
_UNIT_FACTORS = {
    "": 1, "b": 1,
    "k": 1024, "kb": 1024,
    "m": 1024 ** 2, "mb": 1024 ** 2,
    "g": 1024 ** 3, "gb": 1024 ** 3,
    "t": 1024 ** 4, "tb": 1024 ** 4,
}


def parse_bytes(value) -> Optional[int]:
    """Convert a YAML size value into a byte count.

    Args:
        value: Either an ``int`` (returned unchanged) or a string made of a
            decimal number followed by an optional unit suffix
            (``b``, ``k``/``kb``, ``m``/``mb``, ``g``/``gb``, ``t``/``tb``;
            case-insensitive, surrounding whitespace allowed).

    Returns:
        The size in bytes, truncated toward zero for fractional inputs, or
        ``None`` when the value cannot be interpreted (``None``, booleans,
        other non-int/non-str types, malformed strings, or numbers too large
        to represent).
    """
    # bool is a subclass of int; a YAML `true` must not become a 1-byte quota.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if not isinstance(value, str):
        return None

    match = _SIZE_RE.match(value)
    if match is None:
        return None

    unit = (match.group(2) or "").lower()
    factor = _UNIT_FACTORS.get(unit)
    if factor is None:
        return None

    try:
        # ValueError: the permissive numeric group lets "1.2.3" or "." through.
        # OverflowError: a huge digit string parses to inf, which int() rejects.
        return int(float(match.group(1)) * factor)
    except (ValueError, OverflowError):
        return None
|
|
|
|
|
|
# Top-level entry names excluded from the scan. Per the original note this
# saves I/O and matches what the /v1/files API hides (not verifiable here).
_SKIP_TOPLEVEL = frozenset({".zcbot_tmp", ".memory"})


def scan_user_dir(user_root: Path) -> Tuple[int, int]:
    """Sum the sizes of the files stored under ``user_root``.

    Top-level entries named in ``_SKIP_TOPLEVEL`` are ignored; the same names
    deeper in the tree are counted normally. Symlinks are never followed.
    Entries that cannot be inspected are skipped rather than failing the scan.

    Args:
        user_root: The user's working directory.

    Returns:
        ``(total_bytes, file_count)``; ``(0, 0)`` when ``user_root`` is
        missing or is not a directory.
    """
    if not user_root.is_dir():
        return 0, 0

    total_bytes = 0
    total_count = 0
    try:
        # The context manager releases the directory handle even when
        # iteration is cut short by an OSError.
        with os.scandir(user_root) as entries:
            for entry in entries:
                if entry.name in _SKIP_TOPLEVEL:
                    continue
                try:
                    if entry.is_file(follow_symlinks=False):
                        total_bytes += entry.stat(follow_symlinks=False).st_size
                        total_count += 1
                    elif entry.is_dir(follow_symlinks=False):
                        sub_bytes, sub_count = _walk_dir(Path(entry.path))
                        total_bytes += sub_bytes
                        total_count += sub_count
                    # NOTE(review): top-level symlinks and special files fall
                    # through uncounted, whereas _walk_dir counts nested
                    # non-directory entries via lstat -- confirm this is intended.
                except OSError:
                    # Unreadable entry: leave it out of the totals.
                    continue
    except OSError:
        # Directory vanished or became unreadable: report what was gathered.
        pass
    return total_bytes, total_count


def _walk_dir(d: Path) -> Tuple[int, int]:
    """Return ``(total_bytes, file_count)`` for everything below ``d``.

    Uses ``os.walk`` without following symlinks; each non-directory entry is
    measured with an lstat-style call, so a symlink contributes its own size
    rather than its target's. Unreadable directories and files are skipped.
    """
    total_b, total_c = 0, 0
    for root, _dirs, files in os.walk(d, followlinks=False, onerror=lambda _e: None):
        for name in files:
            try:
                st = os.stat(os.path.join(root, name), follow_symlinks=False)
            except OSError:
                continue
            total_b += st.st_size
            total_c += 1
    return total_b, total_c
|
|
|
|
|
|
def upsert_user_usage(user_id: UUID, bytes_used: int, file_count: int) -> None:
    """Store the latest scan result for one user in ``user_disk_usage``.

    Issues a PostgreSQL ``INSERT ... ON CONFLICT (user_id) DO UPDATE``: the
    first call for a user inserts a row, later calls overwrite ``bytes_used``
    and ``file_count`` and set ``scanned_at`` to the database's ``now()``.

    Args:
        user_id: Owner of the scanned directory.
        bytes_used: Total size in bytes found by the scan.
        file_count: Number of files found by the scan.
    """
    from sqlalchemy import func

    # Columns refreshed when a row for this user already exists.
    refreshed = {
        "bytes_used": bytes_used,
        "file_count": file_count,
        "scanned_at": func.now(),
    }
    insert_stmt = pg_insert(UserDiskUsage).values(
        user_id=user_id,
        bytes_used=bytes_used,
        file_count=file_count,
    )
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["user_id"],
        set_=refreshed,
    )
    # presumably session_scope commits on a clean exit -- confirm in .engine
    with session_scope() as session:
        session.execute(upsert_stmt)
|
|
|
|
|
|
def get_user_usage(user_id: UUID) -> Optional[Tuple[int, int]]:
    """Fetch the stored scan result for a user.

    Args:
        user_id: User whose ``user_disk_usage`` row is looked up.

    Returns:
        ``(bytes_used, file_count)`` as plain ints, or ``None`` when no row
        exists for the user.
    """
    query = select(UserDiskUsage.bytes_used, UserDiskUsage.file_count).where(
        UserDiskUsage.user_id == user_id
    )
    with session_scope() as session:
        record = session.execute(query).first()
        if record is None:
            return None
        return int(record[0]), int(record[1])
|
|
|
|
|
|
def check_disk_quota(user_id: UUID, limit_bytes: int) -> Optional[str]:
    """Pre-write gate: decide whether a user may write more data.

    The decision uses the usage recorded by the most recent scan, not a live
    measurement.

    Args:
        user_id: User about to write.
        limit_bytes: Quota in bytes; a value ``<= 0`` means "unlimited".

    Returns:
        ``None`` when the write is allowed, otherwise an error message
        (intended to be read directly by the LLM) stating the used and
        allowed sizes in MB. A user with no recorded scan yet is allowed, so
        writes are not blocked before the first scan; per the original note
        the periodic scan (roughly every 15 minutes) then activates the gate.
    """
    if limit_bytes <= 0:
        return None

    usage = get_user_usage(user_id)
    if usage is None:
        # Never scanned: let the write through until a scan records usage.
        return None

    bytes_used = usage[0]
    if bytes_used < limit_bytes:
        return None

    mib = 1024 ** 2
    used_mb = bytes_used / mib
    limit_mb = limit_bytes / mib
    return (
        f"[Error] 已达磁盘配额上限({used_mb:.1f} MB / {limit_mb:.1f} MB);"
        f"清理旧产物或联系管理员升配后重试"
    )
|
|
|
|
|
|
def list_user_ids_with_root(user_root_base: Path) -> List[UUID]:
    """List the users that have a workspace directory under ``user_root_base``.

    Only real sub-directories (symlinks are not followed) whose name is the
    canonical string form of a UUID are returned. Non-canonical spellings
    (no hyphens, upper case, braces, ``urn:uuid:`` prefix) are ignored because
    a caller that rebuilds the path with ``str(uid)`` would not find them.

    Per the original note, the users table is deliberately not consulted:
    users without a workspace directory occupy no disk and need no row.

    Args:
        user_root_base: Directory that holds one sub-directory per user.

    Returns:
        The UUIDs found, in directory-listing order; an empty list when
        ``user_root_base`` is not a directory or cannot be read.
    """
    if not user_root_base.is_dir():
        return []

    found: List[UUID] = []
    try:
        # The context manager closes the directory handle on every exit path.
        with os.scandir(user_root_base) as entries:
            for entry in entries:
                try:
                    if not entry.is_dir(follow_symlinks=False):
                        continue
                    uid = UUID(entry.name)
                except (OSError, ValueError):
                    # Unreadable entry or non-UUID name: skip just this one
                    # instead of abandoning the rest of the listing.
                    continue
                if str(uid) != entry.name:
                    continue
                found.append(uid)
    except OSError:
        pass
    return found
|
|
|
|
|
|
def scan_all_users(user_root_base: Path) -> int:
    """Scan every user directory and persist the results.

    Users are processed one after another (per the original note, to avoid an
    I/O storm), so total time grows linearly with the number of users. A
    failure for one user is logged and skipped so it cannot block the others;
    the next periodic run retries it.

    Args:
        user_root_base: Directory that holds one sub-directory per user.

    Returns:
        The number of users whose usage was scanned and stored successfully.
    """
    import logging

    log = logging.getLogger(__name__)
    scanned = 0
    for uid in list_user_ids_with_root(user_root_base):
        try:
            used_bytes, file_count = scan_user_dir(user_root_base / str(uid))
            upsert_user_usage(uid, used_bytes, file_count)
        except Exception:
            # Best-effort boundary: keep going, but leave a trace for operators.
            log.warning(
                "disk usage scan failed for user %s; will retry next cycle",
                uid,
                exc_info=True,
            )
            continue
        scanned += 1
    return scanned
|