156 lines
5.7 KiB
Python
156 lines
5.7 KiB
Python
"""会话: 内存中的消息列表 + meta + 落 PG `messages` 表。
|
|
|
|
§7 B Step 2:消息走 ORM(append-only, idx 严格递增,payload jsonb)。
|
|
|
|
system prompt **不入库** —— 每次 build_agent 重建拼到 messages[0](§3.7
|
|
"memory 演化即时生效")。Session 内存里仍维持 [system, user_1, assistant_1, ...]
|
|
全列表;DB idx 从 0 开始数第一条非 system 消息。
|
|
|
|
保留 `atomic_write_text` 给 skill 产物 / 其他 .md 文件写入使用。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import delete, select
|
|
|
|
from .storage import session_scope
|
|
from .storage.models import Message, Task
|
|
|
|
|
|
def _to_dict(msg: Any) -> Any:
|
|
if isinstance(msg, dict):
|
|
return msg
|
|
if hasattr(msg, "model_dump"):
|
|
return msg.model_dump(exclude_none=True)
|
|
if hasattr(msg, "dict"):
|
|
return msg.dict(exclude_none=True)
|
|
return msg
|
|
|
|
|
|
def atomic_write_text(path: Path, text: str, encoding: str = "utf-8") -> None:
    """Atomically write ``text`` to ``path``.

    The text is first written to ``<path>.tmp``, flushed and fsync'ed, then
    moved over ``path`` with ``os.replace``. A failure mid-write (disk full,
    un-encodable surrogate, interrupted process) therefore never leaves
    ``path`` as a 0-byte or half-written file. Used for skill artifacts
    (spec_lock.md, sections/*.md, ...); messages go to PG instead.

    Args:
        path: Destination file; missing parent directories are created.
        text: Content to write. Newlines are written as ``"\\n"`` untranslated.
        encoding: Text encoding, ``"utf-8"`` by default.

    Raises:
        OSError / UnicodeEncodeError (or whatever the write raised): the
        original error is re-raised after the temporary file is removed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "w", encoding=encoding, newline="\n") as f:
            f.write(text)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leak a stray half-written .tmp file; a cleanup failure must
        # not mask the original error.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
|
|
|
|
|
|
class Session:
    """Message list anchored on a ``task_id``, persisted to the ``messages`` table.

    In memory the list is ``[system?, msg_0, msg_1, ...]``. The system prompt
    (if any) lives only in memory. Non-system messages are written to the DB,
    where ``idx`` starts at 0 for the first non-system message.

    Lazy-persist: the constructor does not touch the DB. Each non-system
    ``append``:
      1) calls ``ensure_local_task_row`` so the ``tasks`` row exists (Step 2
         uses placeholder values; Step 3 gets full values from TaskState);
      2) INSERTs one ``messages`` row.

    ``reset`` DELETEs all of this task's ``messages`` rows.
    """

    def __init__(
        self,
        task_id: UUID,
        system_prompt: str = "",
        meta: Optional[dict] = None,
    ) -> None:
        """Create an in-memory session; no DB access happens here.

        Args:
            task_id: Task the messages belong to.
            system_prompt: If non-empty, kept in memory as ``messages[0]``.
            meta: Task metadata passed to ``ensure_local_task_row``; copied.
        """
        self.task_id: UUID = task_id
        self.messages: List[dict] = []
        self.meta: Dict[str, Any] = dict(meta or {})
        self._db_idx: int = 0  # idx for the next message written to the DB
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})

    def _ensure_task_row(self) -> None:
        """Make sure the ``tasks`` row exists (idempotent at the storage layer)."""
        # Kept as a local import, as in the original code.
        from .storage.utils import ensure_local_task_row

        # meta fields (name/working_dir/skill/description/reasoning_effort) go in
        # with the one-shot INSERT, so _list_task_rows does not see empty meta
        # after the first append; a later task_state.save() UPSERTs over them.
        # name is NOT NULL: build_agent must put it into meta (new and resume).
        ensure_local_task_row(
            task_id=self.task_id,
            name=self.meta.get("name", ""),
            working_dir=self.meta.get("working_dir", ""),
            skill=self.meta.get("skill", ""),
            description=self.meta.get("description", ""),
            model=self.meta.get("model", ""),
            model_profile=self.meta.get("model_profile", ""),
            reasoning_effort=self.meta.get("reasoning_effort", ""),
        )

    def append(self, msg: Any) -> None:
        """Append a message. Non-system messages are persisted; system ones stay in memory.

        For non-system messages, the DB row is written *before* the in-memory
        list and ``_db_idx`` are updated. A failing INSERT therefore propagates
        without leaving memory ahead of the DB.

        Raises:
            TypeError: if ``msg`` cannot be converted to a dict.
        """
        msg_dict = _to_dict(msg)
        if not isinstance(msg_dict, dict):
            raise TypeError(
                f"message must be dict-convertible, got {type(msg).__name__}"
            )

        if msg_dict.get("role") == "system":
            # System prompts are never stored in the DB.
            self.messages.append(msg_dict)
            return

        self._ensure_task_row()

        with session_scope() as s:
            s.add(Message(
                task_id=self.task_id,
                idx=self._db_idx,
                payload=msg_dict,
            ))

        # Reached only if session_scope exited without raising (presumably
        # committed; confirm against session_scope).
        self._db_idx += 1
        self.messages.append(msg_dict)

    def reset(self, keep_system: bool = True) -> None:
        """Delete all messages of this task from the DB, then clear memory.

        Args:
            keep_system: If True and ``messages[0]`` is a system message, keep
                it in memory. This affects memory only; system messages are
                never in the DB.
        """
        # DB first: if the DELETE fails, in-memory state is left untouched.
        with session_scope() as s:
            s.execute(delete(Message).where(Message.task_id == self.task_id))
        self._db_idx = 0
        if keep_system and self.messages and self.messages[0].get("role") == "system":
            self.messages = [self.messages[0]]
        else:
            self.messages = []

    @classmethod
    def load(
        cls,
        task_id: UUID,
        system_prompt: str = "",
        meta: Optional[dict] = None,
    ) -> "Session":
        """Load the task's stored messages from the DB, ordered by ``idx``.

        ``system_prompt`` is injected by the caller instead of being stored,
        so memory changes take effect immediately.

        If the task has no messages in the DB, the returned Session contains
        only the system prompt (if given) and ``_db_idx == 0``. It is up to the
        caller to decide whether that is an error.
        """
        sess = cls(task_id=task_id, system_prompt=system_prompt, meta=meta)
        with session_scope() as s:
            rows = s.execute(
                select(Message)
                .where(Message.task_id == task_id)
                .order_by(Message.idx)
            ).scalars().all()
            for row in rows:
                sess.messages.append(dict(row.payload))
            # idx is strictly increasing but not guaranteed to be gap-free.
            # Continue after the highest stored idx rather than len(rows), so
            # the next append can never reuse an existing idx.
            sess._db_idx = rows[-1].idx + 1 if rows else 0
        return sess

    @classmethod
    def task_exists(cls, task_id: UUID) -> bool:
        """Return True if the task has a ``tasks`` row AND at least one message.

        A ``tasks`` row without messages is treated as a lazy placeholder,
        not a real task.
        """
        with session_scope() as s:
            task_row = s.execute(
                select(Task.task_id).where(Task.task_id == task_id)
            ).scalar_one_or_none()
            if task_row is None:
                return False
            first_message_id = s.execute(
                select(Message.message_id)
                .where(Message.task_id == task_id)
                .limit(1)
            ).scalar_one_or_none()
        return first_message_id is not None

    def n_user_msgs(self) -> int:
        """Count in-memory ``user`` messages (guards _cleanup_if_empty without a DB query)."""
        return sum(1 for m in self.messages if m.get("role") == "user")
|