zcbot/core/storage/utils.py

96 lines
3.0 KiB
Python

"""Storage 辅助:tasks 表的 idempotent 创建 / UPSERT / UPDATE。"""
from __future__ import annotations
from typing import Any, Optional
from uuid import UUID
from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert
from .engine import session_scope
from .models import SENTINEL_USER_ID, Task
def ensure_local_task_row(
    task_id: UUID,
    task_dir: str = "",
    mode: str = "",
    description: str = "",
    model: str = "",
    model_profile: str = "",
    reasoning_effort: str = "",
    user_id: UUID = SENTINEL_USER_ID,
) -> None:
    """Insert a placeholder ``tasks`` row without overwriting an existing one.

    Emits ``INSERT ... ON CONFLICT (task_id) DO NOTHING``: if a row with this
    ``task_id`` already exists, none of its columns are modified.

    Used by ``Session.append`` to make sure the ``tasks`` row exists before the
    first non-system message is written, so the ``messages`` foreign key is not
    violated. The columns accepted here are the minimal set known during
    ``build_agent``; after ``TaskState.save`` the real values
    (description / status / tokens, ...) are written via ``upsert_task``.

    Args:
        task_id: Primary key of the task row.
        task_dir, mode, description, model, model_profile, reasoning_effort:
            Initial column values for a freshly inserted row.
        user_id: Owner of the row; defaults to ``SENTINEL_USER_ID``.
    """
    placeholder_row = dict(
        task_id=task_id,
        user_id=user_id,
        task_dir=task_dir,
        mode=mode,
        description=description,
        model=model,
        model_profile=model_profile,
        reasoning_effort=reasoning_effort,
    )
    insert_if_missing = insert(Task).values(**placeholder_row)
    # Conflict on the primary key means the row is already there: keep it as-is.
    insert_if_missing = insert_if_missing.on_conflict_do_nothing(
        index_elements=["task_id"]
    )

    with session_scope() as session:
        session.execute(insert_if_missing)
def upsert_task(
    task_id: UUID,
    *,
    user_id: UUID = SENTINEL_USER_ID,
    **fields: Any,
) -> None:
    """Insert or update a ``tasks`` row (``INSERT ... ON CONFLICT DO UPDATE``).

    This is where ``TaskState.save`` persists task state.

    ``fields`` may hold any writable column of the ``tasks`` table
    (task_dir / mode / description / status / model / model_profile /
    reasoning_effort / tokens_prompt / tokens_completion / cost_usd).
    Columns not passed fall back to their ORM defaults on INSERT and are left
    unchanged on UPDATE. ``user_id`` is only written on INSERT. When ``fields``
    is empty, an existing row is left completely untouched (``DO NOTHING``).

    Args:
        task_id: Primary key of the task row.
        user_id: Owner used when the row is newly inserted.
        **fields: Column name -> value pairs to insert/overwrite.
    """
    base_insert = insert(Task).values(task_id=task_id, user_id=user_id, **fields)

    if not fields:
        # Nothing to overwrite on conflict: plain "insert if missing".
        upsert = base_insert.on_conflict_do_nothing(index_elements=["task_id"])
    else:
        proposed = base_insert.excluded
        assignments = {column: proposed[column] for column in fields}
        # The ORM's onupdate=func.now() only fires for ORM-level UPDATEs; the
        # DO UPDATE branch of this raw INSERT does not refresh updated_at on
        # its own, so set it explicitly.
        assignments["updated_at"] = func.now()
        upsert = base_insert.on_conflict_do_update(
            index_elements=["task_id"], set_=assignments
        )

    with session_scope() as session:
        session.execute(upsert)
def update_task(task_id: UUID, **fields: Any) -> int:
    """Update columns of an existing ``tasks`` row.

    If no row matches ``task_id`` this is a no-op. Calling it with no
    ``fields`` returns ``0`` without touching the database.

    Per the original design note, the ORM-level update carries
    ``onupdate=func.now()``, so ``updated_at`` does not need to be passed.

    Args:
        task_id: Primary key of the row to update.
        **fields: Column name -> new value pairs.

    Returns:
        Number of rows reported as updated by the driver (``0`` if the row
        does not exist or the driver reports no count).
    """
    if fields:
        statement = update(Task).where(Task.task_id == task_id).values(**fields)
        with session_scope() as session:
            affected = session.execute(statement).rowcount
        return affected or 0
    return 0
def get_task(task_id: UUID) -> Optional[Task]:
    """Load the ``tasks`` row for ``task_id``.

    The instance is expunged from the session before the session scope
    ends, so its already-loaded column attributes remain readable after
    the scope commits/closes.

    Args:
        task_id: Primary key of the task row.

    Returns:
        The detached ``Task`` instance, or ``None`` if no row exists.
    """
    with session_scope() as s:
        task = s.execute(
            select(Task).where(Task.task_id == task_id)
        ).scalar_one_or_none()
        if task is not None:
            # NOTE(review): session_scope's exit behaviour is not visible here.
            # If it commits with the default expire_on_commit=True, the commit
            # would expire this instance, and any attribute access after the
            # session closes would raise DetachedInstanceError. Expunging first
            # detaches it with its loaded state intact; harmless if
            # expire_on_commit is already disabled.
            s.expunge(task)
        return task