94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
"""Session: an in-memory message list plus meta (cwd / model / created_at), persisted as a JSON file.

File format:

    {
        "meta": {"id": "...", "created_at": "...", "cwd": "...", "model": "..."},
        "messages": [...]
    }

Legacy format: if the file's root is a list, it is treated as the messages and meta is left empty.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
def _to_dict(msg: Any) -> Any:
    """Normalize a message into a plain dict when the object supports it.

    - A ``dict`` is returned unchanged (same object).
    - An object with a ``model_dump`` method (as pydantic v2 models have) is
      converted with ``model_dump(exclude_none=True)``.
    - Otherwise, an object with a ``dict`` method (pydantic v1 style) is
      converted with ``dict(exclude_none=True)``.
    - Anything else is returned as-is.
    """
    if isinstance(msg, dict):
        return msg
    # Probe order matters: prefer ``model_dump`` over the older ``dict`` method.
    for method_name in ("model_dump", "dict"):
        if hasattr(msg, method_name):
            return getattr(msg, method_name)(exclude_none=True)
    return msg
|
|
|
|
|
|
def atomic_write_text(path: Path, text: str, encoding: str = "utf-8") -> None:
    """Atomically write ``text`` to ``path`` via ``<path>.tmp`` + fsync + ``os.replace``.

    Guards against a mid-write failure (disk full, surrogate encoding error,
    killed process) leaving ``path`` as a 0-byte or half-written file: the new
    content only becomes visible at ``path`` once it is fully written and
    flushed to disk. Parent directories are created as needed.

    If the write raises, the ``.tmp`` file is removed before the exception
    propagates, so failed writes do not leave partial files behind. The
    ``.tmp`` name is fixed (single-REPL, single-task assumption). An orphan
    left by a killed process is simply overwritten by the next write, and
    ``_cleanup_if_empty`` is written to tolerate ``*.tmp`` files.

    Raises:
        OSError: on filesystem errors.
        UnicodeEncodeError: if ``text`` cannot be encoded with ``encoding``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "w", encoding=encoding, newline="\n") as f:
            f.write(text)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # The file is closed at this point. Cleanup is best-effort: the original
        # exception is what the caller needs to see.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
|
|
|
|
|
|
class Session:
    """A conversation: an in-memory message list plus a ``meta`` dict.

    When ``path`` is set, every mutating call (``append`` / ``reset``)
    immediately rewrites the whole JSON file via ``save``. The file has the
    form ``{"meta": {...}, "messages": [...]}``.
    """

    def __init__(
        self,
        system_prompt: str = "",
        path: Optional[Path] = None,
        meta: Optional[dict] = None,
    ) -> None:
        """Create a session (does not write to disk).

        Args:
            system_prompt: If non-empty, seeded as the first message with
                role ``"system"``.
            path: JSON file to persist to; ``None`` keeps the session in memory only.
            meta: Initial metadata; shallow-copied so the caller's dict is not shared.
        """
        self.messages: List[dict] = []
        self.path = path
        self.meta: Dict[str, Any] = dict(meta or {})
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})

    def append(self, msg: Any) -> None:
        """Append ``msg`` (normalized via ``_to_dict``); persist if ``path`` is set."""
        self.messages.append(_to_dict(msg))
        if self.path is not None:
            self.save()

    def reset(self, keep_system: bool = True) -> None:
        """Clear the conversation; persist the result if ``path`` is set.

        Args:
            keep_system: When true and the first message is a dict with role
                ``"system"``, that message is kept. Otherwise all messages are dropped.
        """
        first = self.messages[0] if self.messages else None
        # ``_to_dict`` passes non-dict messages through unchanged, and legacy
        # files are loaded verbatim, so don't assume ``first`` has ``.get``.
        if keep_system and isinstance(first, dict) and first.get("role") == "system":
            self.messages = [first]
        else:
            self.messages = []
        if self.path is not None:
            self.save()

    def save(self) -> None:
        """Write ``{"meta": ..., "messages": ...}`` to ``path``; no-op if ``path`` is None.

        Non-ASCII text is written verbatim for readability. If the payload
        contains strings that cannot be encoded as UTF-8 (lone surrogates,
        e.g. from undecodable terminal input), the JSON is written ASCII-only
        (escaped) instead, which ``json.loads`` can still read back.
        """
        if self.path is None:
            return
        payload = {"meta": self.meta, "messages": self.messages}
        text = json.dumps(payload, ensure_ascii=False, indent=2)
        try:
            text.encode("utf-8")
        except UnicodeEncodeError:
            # Without this fallback, one bad message would make every later
            # save (and hence every append) fail for the rest of the session.
            text = json.dumps(payload, ensure_ascii=True, indent=2)
        atomic_write_text(self.path, text)

    @classmethod
    def load(cls, path: Path) -> "Session":
        """Load a session from ``path``; the returned session persists back to it.

        A missing or empty (whitespace-only) file yields an empty session.
        A JSON list root is the legacy format (messages only, empty meta).
        A dict root is the current ``{"meta": ..., "messages": ...}`` format.
        Any other root yields an empty session.

        Raises:
            json.JSONDecodeError: if the file is non-empty but not valid JSON.
        """
        s = cls(path=path)
        if not path.exists():
            return s
        text = path.read_text(encoding="utf-8")
        if not text.strip():
            # A 0-byte file (e.g. left by an earlier non-atomic write) holds no data.
            return s
        data = json.loads(text)
        if isinstance(data, list):
            # Legacy format: a bare list of messages.
            s.messages = data
            s.meta = {}
        elif isinstance(data, dict):
            s.messages = data.get("messages", []) or []
            s.meta = data.get("meta", {}) or {}
        return s
|