zcbot/core/session.py

154 lines
5.5 KiB
Python

"""会话: 内存中的消息列表 + meta + 落 PG `messages` 表。
§7 B Step 2:消息走 ORM(append-only, idx 严格递增,payload jsonb)。
system prompt **不入库** —— 每次 build_agent 重建拼到 messages[0](§3.7
"memory 演化即时生效")。Session 内存里仍维持 [system, user_1, assistant_1, ...]
全列表;DB idx 从 0 开始数第一条非 system 消息。
保留 `atomic_write_text` 给 skill 产物 / 其他 .md 文件写入使用。
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import UUID
from sqlalchemy import delete, select
from .storage import session_scope
from .storage.models import Message, Task
def _to_dict(msg: Any) -> Any:
if isinstance(msg, dict):
return msg
if hasattr(msg, "model_dump"):
return msg.model_dump(exclude_none=True)
if hasattr(msg, "dict"):
return msg.dict(exclude_none=True)
return msg
def atomic_write_text(path: Path, text: str, encoding: str = "utf-8") -> None:
"""原子写: 先写到 path.tmp 再 os.replace 到 path。
防止写中途异常(磁盘满 / surrogate 编码错 / 进程被杀)留下 0 字节或半文件。
skill 产物(spec_lock.md / sections/*.md 等)走这里,messages 已改走 PG。
"""
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
with open(tmp, "w", encoding=encoding, newline="\n") as f:
f.write(text)
f.flush()
os.fsync(f.fileno())
os.replace(tmp, path)
class Session:
"""消息列表 anchored on task_id。
Lazy-persist: 构造时不动 DB,第一条非 system 消息 append 时:
1) 调 ensure_task_row 保证 tasks 行存在(Step 2 用占位值,Step 3 由 TaskState 提供完整值)
2) INSERT 一行 messages
系统 reset 走 DB DELETE 该 task 全部 messages。
"""
def __init__(
self,
task_id: UUID,
system_prompt: str = "",
meta: Optional[dict] = None,
) -> None:
self.task_id: UUID = task_id
self.messages: List[dict] = []
self.meta: Dict[str, Any] = dict(meta or {})
self._db_idx: int = 0 # 下一条要写 DB 的 idx
if system_prompt:
self.messages.append({"role": "system", "content": system_prompt})
def append(self, msg: Any) -> None:
"""追加消息;非 system 落 DB,system 仅内存。"""
msg_dict = _to_dict(msg)
self.messages.append(msg_dict)
if msg_dict.get("role") == "system":
return
# 首次写入前,让 tasks 行就位。`ensure_local_task_row` 在 storage 层 idempotent。
# meta 字段(mode/description/reasoning_effort)走 INSERT 一次性带入,避免
# 首次 append 后 _list_task_rows 看到空 meta;后续 task_state.save() 走 UPSERT 覆盖。
from .storage.utils import ensure_local_task_row
ensure_local_task_row(
task_id=self.task_id,
task_dir=self.meta.get("task_dir", ""),
mode=self.meta.get("mode", ""),
description=self.meta.get("description", ""),
model=self.meta.get("model", ""),
model_profile=self.meta.get("model_profile", ""),
reasoning_effort=self.meta.get("reasoning_effort", ""),
)
with session_scope() as s:
s.add(Message(
task_id=self.task_id,
idx=self._db_idx,
payload=msg_dict,
))
self._db_idx += 1
def reset(self, keep_system: bool = True) -> None:
"""清空消息。keep_system 仅影响内存(system 本来就不在 DB)。"""
if keep_system and self.messages and self.messages[0].get("role") == "system":
self.messages = [self.messages[0]]
else:
self.messages = []
with session_scope() as s:
s.execute(delete(Message).where(Message.task_id == self.task_id))
self._db_idx = 0
@classmethod
def load(
cls,
task_id: UUID,
system_prompt: str = "",
meta: Optional[dict] = None,
) -> "Session":
"""从 DB 读历史 messages。system_prompt 由调用方注入(memory 演化即时生效)。
若 task_id 在 DB 不存在,返回空 Session(messages 只含 system,_db_idx=0);
调用方判断该不该报错。
"""
sess = cls(task_id=task_id, system_prompt=system_prompt, meta=meta)
with session_scope() as s:
rows = s.execute(
select(Message)
.where(Message.task_id == task_id)
.order_by(Message.idx)
).scalars().all()
for row in rows:
sess.messages.append(dict(row.payload))
sess._db_idx = len(rows)
return sess
@classmethod
def task_exists(cls, task_id: UUID) -> bool:
"""tasks 行 + messages 至少 1 条 → 该 task 真存在(不是 lazy 占位)。"""
with session_scope() as s:
row = s.execute(
select(Task.task_id).where(Task.task_id == task_id)
).scalar_one_or_none()
if row is None:
return False
cnt = s.execute(
select(Message.message_id)
.where(Message.task_id == task_id)
.limit(1)
).scalar_one_or_none()
return cnt is not None
def n_user_msgs(self) -> int:
"""内存里 user 消息数,用于 _cleanup_if_empty 守门(避免回 DB)。"""
return sum(1 for m in self.messages if m.get("role") == "user")