148 lines
5.4 KiB
Python
148 lines
5.4 KiB
Python
"""Storage 辅助:tasks 表的 idempotent 创建 / UPSERT / UPDATE / no-subtask 校验。"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Iterable, Optional
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func, select, update
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
|
from .engine import session_scope
|
|
from .models import SENTINEL_USER_ID, Task
|
|
|
|
|
|
class NoSubtaskError(ValueError):
|
|
"""working_dir 与同 user 已有 task 形成前缀嵌套(§7.4 no-subtask 策略)。"""
|
|
|
|
|
|
def ensure_local_task_row(
|
|
task_id: UUID,
|
|
name: str,
|
|
working_dir: str = "",
|
|
skill: str = "",
|
|
description: str = "",
|
|
model: str = "",
|
|
model_profile: str = "",
|
|
reasoning_effort: str = "",
|
|
user_id: UUID = SENTINEL_USER_ID,
|
|
) -> None:
|
|
"""占位 INSERT(ON CONFLICT DO NOTHING)—— 不覆盖已有字段。
|
|
|
|
用于 `Session.append` 在首条非 system 消息前打底 tasks 行,避免 messages
|
|
FK 违反。字段是 build_agent 阶段已知的最小集;TaskState.save 之后会通过
|
|
`upsert_task` 把真实字段(desc/status/tokens 等)写进去。`name` 必填(列 NOT NULL),
|
|
调用方应已 validate。
|
|
"""
|
|
stmt = (
|
|
insert(Task)
|
|
.values(
|
|
task_id=task_id,
|
|
user_id=user_id,
|
|
name=name,
|
|
working_dir=working_dir,
|
|
skill=skill,
|
|
description=description,
|
|
model=model,
|
|
model_profile=model_profile,
|
|
reasoning_effort=reasoning_effort,
|
|
)
|
|
.on_conflict_do_nothing(index_elements=["task_id"])
|
|
)
|
|
with session_scope() as s:
|
|
s.execute(stmt)
|
|
|
|
|
|
def upsert_task(
|
|
task_id: UUID,
|
|
*,
|
|
user_id: UUID = SENTINEL_USER_ID,
|
|
**fields: Any,
|
|
) -> None:
|
|
"""INSERT ... ON CONFLICT DO UPDATE —— TaskState.save 的落地点。
|
|
|
|
fields 可包含 tasks 表任意可写列(name/working_dir/skill/description/status/model/
|
|
model_profile/reasoning_effort/tokens_prompt/tokens_completion/cost_usd)。
|
|
不传的字段在 INSERT 时走 ORM 默认值,UPDATE 时不动。
|
|
INSERT 路径需要 name(NOT NULL)+ working_dir;纯 UPDATE 路径(行已存在)不强制。
|
|
"""
|
|
values = {"task_id": task_id, "user_id": user_id, **fields}
|
|
stmt = insert(Task).values(**values)
|
|
update_cols = {k: stmt.excluded[k] for k in fields}
|
|
if update_cols:
|
|
# ORM 的 onupdate=func.now() 只在 ORM-level UPDATE 触发,DO UPDATE 是 raw DML
|
|
# 不会自动刷 updated_at —— 这里显式追加。
|
|
update_cols["updated_at"] = func.now()
|
|
stmt = stmt.on_conflict_do_update(
|
|
index_elements=["task_id"], set_=update_cols
|
|
)
|
|
else:
|
|
stmt = stmt.on_conflict_do_nothing(index_elements=["task_id"])
|
|
with session_scope() as s:
|
|
s.execute(stmt)
|
|
|
|
|
|
def update_task(task_id: UUID, **fields: Any) -> int:
|
|
"""UPDATE 已有 tasks 行;不存在则 no-op(返回 0)。
|
|
|
|
ORM-level update 会带 onupdate=func.now() 自动刷 updated_at,无需显式传。
|
|
"""
|
|
if not fields:
|
|
return 0
|
|
with session_scope() as s:
|
|
result = s.execute(
|
|
update(Task).where(Task.task_id == task_id).values(**fields)
|
|
)
|
|
return result.rowcount or 0
|
|
|
|
|
|
def get_task(task_id: UUID) -> Optional[Task]:
|
|
"""读 tasks 行,不存在返回 None。"""
|
|
with session_scope() as s:
|
|
return s.execute(
|
|
select(Task).where(Task.task_id == task_id)
|
|
).scalar_one_or_none()
|
|
|
|
|
|
def check_no_subtask(
|
|
working_dir: str,
|
|
user_id: UUID = SENTINEL_USER_ID,
|
|
exclude_task_ids: Optional[Iterable[UUID]] = None,
|
|
) -> None:
|
|
"""§7.4 no-subtask:同 user 下校验 working_dir 不能与已有 working_dir 形成前缀嵌套。
|
|
|
|
允许:同 working_dir(同项目多对话)、完全无关路径(平级或不相关)。
|
|
拒绝:new 是 existing 的子目录、existing 是 new 的子目录。
|
|
空 working_dir / 仅 whitespace 跳过(legacy / 未绑项目)。
|
|
|
|
`working_dir` 入参既可以是 db 形态(相对 ROOT)也可以是 absolute str,内部统一用
|
|
`from_db_path` 归一到 absolute posix 后再比前缀;DB 里行的两种形态同样归一。
|
|
数量小(per user 几十量级),全量拉到 Python 端比对,不在 SQL 里拼分隔符 / 前缀。
|
|
|
|
`exclude_task_ids` 用于 rename 场景:正在被一起改名的 task 是平移过去的,内部不冲突,
|
|
需要从比对集合里排掉,否则它们会和"自己未来的 working_dir"误判嵌套。
|
|
"""
|
|
if not working_dir or not working_dir.strip():
|
|
return
|
|
from core.paths import from_db_path
|
|
|
|
new_abs = from_db_path(working_dir).as_posix()
|
|
if not new_abs:
|
|
return
|
|
exclude = set(exclude_task_ids or ())
|
|
with session_scope() as s:
|
|
rows = s.execute(
|
|
select(Task.task_id, Task.working_dir)
|
|
.where(Task.user_id == user_id, Task.working_dir != "")
|
|
).all()
|
|
for existing_id, existing_dir in rows:
|
|
if existing_id in exclude:
|
|
continue
|
|
existing_abs = from_db_path(existing_dir).as_posix()
|
|
if not existing_abs or existing_abs == new_abs:
|
|
continue
|
|
if new_abs.startswith(existing_abs + "/") or existing_abs.startswith(new_abs + "/"):
|
|
raise NoSubtaskError(
|
|
f"working_dir {working_dir!r} 与已有 task {str(existing_id)[:8]} 的 "
|
|
f"working_dir {existing_dir!r} 前缀嵌套 — 同项目多对话请用相同 working_dir"
|
|
)
|