zcbot/core/storage/utils.py

178 lines
6.7 KiB
Python

"""Storage 辅助:tasks 表的 idempotent 创建 / UPSERT / UPDATE / no-subtask 校验。"""
from __future__ import annotations
from typing import Any, Iterable, Optional
from uuid import UUID
from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError
from .engine import session_scope
from .models import Message, Task
class NoSubtaskError(ValueError):
"""working_dir 与同 user 已有 task 形成前缀嵌套(§7.4 no-subtask 策略)。"""
def ensure_local_task_row(
task_id: UUID,
name: str,
user_id: UUID,
working_dir: str = "",
skill: str = "",
description: str = "",
model: str = "",
model_profile: str = "",
reasoning_effort: str = "",
channel: str = "web",
scheduled_job_id: Optional[UUID] = None,
) -> None:
"""占位 INSERT(ON CONFLICT DO NOTHING)—— 不覆盖已有字段。
用于 `Session.append` 在首条非 system 消息前打底 tasks 行,避免 messages
FK 违反。字段是 build_agent 阶段已知的最小集;TaskState.save 之后会通过
`upsert_task` 把真实字段(desc/status/tokens 等)写进去。`name` 必填(列 NOT NULL),
调用方应已 validate。
"""
stmt = (
insert(Task)
.values(
task_id=task_id,
user_id=user_id,
name=name,
working_dir=working_dir,
skill=skill,
description=description,
model=model,
model_profile=model_profile,
reasoning_effort=reasoning_effort,
channel=channel,
scheduled_job_id=scheduled_job_id,
)
.on_conflict_do_nothing(index_elements=["task_id"])
)
with session_scope() as s:
s.execute(stmt)
def append_channel_message(
task_id: UUID, content: str, *, role: str = "assistant", kind: Optional[str] = None
) -> None:
"""往 task 追加一条非 agent-run 产生的消息(push 出站记录等)。原子算 idx
(SELECT max(idx)+1)+INSERT;撞 uq_messages_task_idx(与入站 agent run 并发
append)→ 重试。payload 形态同 Session.append 的 {role, content};不设
model_profile / tokens_*(非模型产出,usage 不计)。kind 写 messages.kind 列
(独立列,不进 payload):"push" 标记 push 记录,extract_last_assistant_text 据此跳过。"""
payload = {"role": role, "content": content}
last_err: Optional[Exception] = None
for _ in range(3):
try:
with session_scope() as s:
max_idx = s.execute(
select(func.max(Message.idx)).where(Message.task_id == task_id)
).scalar()
next_idx = (max_idx if max_idx is not None else -1) + 1
s.add(Message(task_id=task_id, idx=next_idx, payload=payload, kind=kind))
return
except IntegrityError as e:
last_err = e
continue
raise RuntimeError(f"append_channel_message: idx 冲突重试耗尽: {last_err}")
def upsert_task(
task_id: UUID,
*,
user_id: UUID,
**fields: Any,
) -> None:
"""INSERT ... ON CONFLICT DO UPDATE —— TaskState.save 的落地点。
fields 可包含 tasks 表任意可写列(name/working_dir/skill/description/status/model/
model_profile/reasoning_effort/tokens_prompt/tokens_completion/cost_cny)。
不传的字段在 INSERT 时走 ORM 默认值,UPDATE 时不动。
INSERT 路径需要 name(NOT NULL)+ working_dir;纯 UPDATE 路径(行已存在)不强制。
"""
values = {"task_id": task_id, "user_id": user_id, **fields}
stmt = insert(Task).values(**values)
update_cols = {k: stmt.excluded[k] for k in fields}
if update_cols:
# ORM 的 onupdate=func.now() 只在 ORM-level UPDATE 触发,DO UPDATE 是 raw DML
# 不会自动刷 updated_at —— 这里显式追加。
update_cols["updated_at"] = func.now()
stmt = stmt.on_conflict_do_update(
index_elements=["task_id"], set_=update_cols
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=["task_id"])
with session_scope() as s:
s.execute(stmt)
def update_task(task_id: UUID, **fields: Any) -> int:
"""UPDATE 已有 tasks 行;不存在则 no-op(返回 0)。
ORM-level update 会带 onupdate=func.now() 自动刷 updated_at,无需显式传。
"""
if not fields:
return 0
with session_scope() as s:
result = s.execute(
update(Task).where(Task.task_id == task_id).values(**fields)
)
return result.rowcount or 0
def get_task(task_id: UUID) -> Optional[Task]:
"""读 tasks 行,不存在返回 None。"""
with session_scope() as s:
return s.execute(
select(Task).where(Task.task_id == task_id)
).scalar_one_or_none()
def check_no_subtask(
working_dir: str,
user_id: UUID,
exclude_task_ids: Optional[Iterable[UUID]] = None,
) -> None:
"""§7.4 no-subtask:同 user 下校验 working_dir 不能与已有 working_dir 形成前缀嵌套。
允许:同 working_dir(同项目多对话)、完全无关路径(平级或不相关)。
拒绝:new 是 existing 的子目录、existing 是 new 的子目录。
空 working_dir / 仅 whitespace 跳过(legacy / 未绑项目)。
`working_dir` 入参既可以是 db 形态(相对 ROOT)也可以是 absolute str,内部统一用
`from_db_path` 归一到 absolute posix 后再比前缀;DB 里行的两种形态同样归一。
数量小(per user 几十量级),全量拉到 Python 端比对,不在 SQL 里拼分隔符 / 前缀。
`exclude_task_ids` 用于 rename 场景:正在被一起改名的 task 是平移过去的,内部不冲突,
需要从比对集合里排掉,否则它们会和"自己未来的 working_dir"误判嵌套。
"""
if not working_dir or not working_dir.strip():
return
from core.paths import from_db_path
new_abs = from_db_path(working_dir).as_posix()
if not new_abs:
return
exclude = set(exclude_task_ids or ())
with session_scope() as s:
rows = s.execute(
select(Task.task_id, Task.working_dir)
.where(Task.user_id == user_id, Task.working_dir != "")
).all()
for existing_id, existing_dir in rows:
if existing_id in exclude:
continue
existing_abs = from_db_path(existing_dir).as_posix()
if not existing_abs or existing_abs == new_abs:
continue
if new_abs.startswith(existing_abs + "/") or existing_abs.startswith(new_abs + "/"):
raise NoSubtaskError(
f"working_dir {working_dir!r} 与已有 task {str(existing_id)[:8]}"
f"working_dir {existing_dir!r} 前缀嵌套 — 同项目多对话请用相同 working_dir"
)