zcbot/core/session.py

94 lines
2.9 KiB
Python

"""会话: 内存中的消息列表 + meta(cwd / model / created_at) + 落盘 json。
文件格式:
{
"meta": {"id": "...", "created_at": "...", "cwd": "...", "model": "..."},
"messages": [...]
}
兼容老格式: 如果文件根是 list,就当 messages 处理,meta 为空。
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
def _to_dict(msg: Any) -> Any:
if isinstance(msg, dict):
return msg
if hasattr(msg, "model_dump"):
return msg.model_dump(exclude_none=True)
if hasattr(msg, "dict"):
return msg.dict(exclude_none=True)
return msg
def atomic_write_text(path: Path, text: str, encoding: str = "utf-8") -> None:
    """Atomically replace the contents of ``path`` with ``text``.

    The text is written to ``<path>.tmp`` (flushed and fsynced), then moved
    over ``path`` with ``os.replace``. A write that fails part-way (disk
    full, surrogate encoding error, interrupt) therefore never leaves
    ``path`` as a 0-byte or half-written file.

    On any exception before the replace completes, the temp file is removed
    and the exception is re-raised, so ``path`` keeps its previous contents.
    A hard kill (e.g. SIGKILL) can still leave an orphan ``.tmp``. The temp
    name is fixed (one writer at a time is assumed), so the next successful
    write simply overwrites that orphan.

    Args:
        path: destination file; missing parent directories are created.
        text: full file contents to write.
        encoding: text encoding used for the file.

    Raises:
        OSError, UnicodeEncodeError, ...: whatever the open/write/fsync/replace
            raised, after the temp file has been cleaned up.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "w", encoding=encoding, newline="\n") as f:
            f.write(text)
            f.flush()
            # Make sure the bytes are on disk before the rename publishes them.
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a half-written temp file behind. A failure to unlink
        # must not mask the original error.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
class Session:
    """An in-memory list of chat messages plus metadata, optionally persisted.

    ``messages`` holds message dicts. ``meta`` is a free-form dict; per the
    module docstring it carries id / created_at / cwd / model. When ``path``
    is set, every ``append`` and ``reset`` rewrites the whole JSON file via
    ``save``.
    """

    def __init__(
        self,
        system_prompt: str = "",
        path: Optional[Path] = None,
        meta: Optional[dict] = None,
    ) -> None:
        """Create a session.

        Args:
            system_prompt: if non-empty, seeded as the first message with
                role ``"system"``.
            path: JSON file to persist to. ``None`` keeps the session
                in memory only. Nothing is written by the constructor itself.
            meta: initial metadata. It is shallow-copied, so later changes to
                ``self.meta`` do not touch the caller's dict.
        """
        self.messages: List[dict] = []
        self.path = path
        self.meta: Dict[str, Any] = dict(meta or {})
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})

    def append(self, msg: Any) -> None:
        """Append ``msg`` and persist the session if it has a path.

        ``msg`` is normalized with ``_to_dict``. Dicts and pydantic-style
        models become dicts; any other value is stored as-is.
        """
        self.messages.append(_to_dict(msg))
        if self.path is not None:
            self.save()

    def reset(self, keep_system: bool = True) -> None:
        """Clear the message list, then persist if the session has a path.

        Args:
            keep_system: when true and the first message is a system
                message, that message is kept and everything else is dropped.
        """
        first = self.messages[0] if self.messages else None
        # Messages loaded from disk are not validated one by one, so don't
        # assume the first entry is a dict before calling .get on it.
        if keep_system and isinstance(first, dict) and first.get("role") == "system":
            self.messages = [first]
        else:
            self.messages = []
        if self.path is not None:
            self.save()

    def save(self) -> None:
        """Write ``{"meta": ..., "messages": [...]}`` to ``path``.

        Does nothing when ``path`` is None. The write goes through
        ``atomic_write_text`` as pretty-printed JSON with non-ASCII kept
        verbatim.
        """
        if self.path is None:
            return
        payload = {"meta": self.meta, "messages": self.messages}
        atomic_write_text(
            self.path,
            json.dumps(payload, ensure_ascii=False, indent=2),
        )

    @classmethod
    def load(cls, path: Path) -> "Session":
        """Load a session from ``path``. A missing file yields an empty session.

        Two layouts are accepted:

        - the current object layout ``{"meta": {...}, "messages": [...]}``,
          where missing or null fields default to empty;
        - the legacy layout whose root is a bare message list, in which case
          ``meta`` is ``{}``.

        The returned session keeps ``path``, so later appends persist there.

        Raises:
            ValueError: the file is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass), or its structure matches
                neither layout.
        """
        s = cls(path=path)
        if not path.exists():
            return s
        data = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(data, list):
            # Legacy format: a bare list of messages.
            s.messages = data
            s.meta = {}
        elif isinstance(data, dict):
            messages = data.get("messages", []) or []
            meta = data.get("meta", {}) or {}
            if not isinstance(messages, list):
                raise ValueError(
                    f"{path}: 'messages' must be a list, got {type(messages).__name__}"
                )
            if not isinstance(meta, dict):
                raise ValueError(
                    f"{path}: 'meta' must be an object, got {type(meta).__name__}"
                )
            s.messages = messages
            s.meta = meta
        else:
            # Refuse rather than return an empty session. The next
            # append()/reset() would otherwise save over this file and
            # destroy its contents.
            raise ValueError(
                f"{path}: unsupported session file root {type(data).__name__}; "
                "expected an object or a list"
            )
        return s