78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
"""会话: 内存中的消息列表 + meta(cwd / model / created_at) + 落盘 json。
|
|
|
|
文件格式:
|
|
{
|
|
"meta": {"id": "...", "created_at": "...", "cwd": "...", "model": "..."},
|
|
"messages": [...]
|
|
}
|
|
|
|
兼容老格式: 如果文件根是 list,就当 messages 处理,meta 为空。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
def _to_dict(msg: Any) -> Any:
    """Return a plain-dict view of *msg* when one can be obtained.

    * A ``dict`` is handed back unchanged (the very same object, not a copy).
    * An object exposing ``model_dump`` or, failing that, ``dict`` (e.g.
      Pydantic-style models) is exported through that method with
      ``exclude_none=True``.
    * Anything else is returned as-is.
    """
    if isinstance(msg, dict):
        return msg
    # Order matters: ``model_dump`` takes precedence over ``dict`` when an
    # object offers both.
    for exporter_name in ("model_dump", "dict"):
        if hasattr(msg, exporter_name):
            return getattr(msg, exporter_name)(exclude_none=True)
    return msg
|
|
|
|
|
|
class Session:
    """Conversation session: a message list plus metadata, optionally persisted as JSON.

    Attributes:
        messages: The messages held in memory, in insertion order.
        path: Target JSON file, or ``None`` for a purely in-memory session.
        meta: Free-form metadata stored next to the messages in the file.

    On-disk layout written by :meth:`save`::

        {"meta": {...}, "messages": [...]}

    :meth:`load` additionally accepts the legacy layout whose root is a bare
    list of messages.
    """

    def __init__(
        self,
        system_prompt: str = "",
        path: Optional[Path] = None,
        meta: Optional[dict] = None,
    ) -> None:
        """Create a session.

        Args:
            system_prompt: When non-empty, stored as the first message with
                role ``"system"``.
            path: File to persist to. ``str`` / ``os.PathLike`` values are
                accepted and converted to ``Path``. ``None`` disables saving.
            meta: Initial metadata; shallow-copied so the caller's dict is
                not shared with the session.
        """
        self.messages: List[dict] = []
        # Normalise to Path: save() and load() rely on Path methods, so a
        # plain string would otherwise fail on first use.
        self.path: Optional[Path] = Path(path) if path is not None else None
        self.meta: Dict[str, Any] = dict(meta or {})
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})

    def append(self, msg: Any) -> None:
        """Add *msg* (converted via ``_to_dict``) and persist the session."""
        self.messages.append(_to_dict(msg))
        # save() is a no-op for in-memory sessions (path is None).
        self.save()

    def reset(self, keep_system: bool = True) -> None:
        """Drop all messages and persist the session.

        Args:
            keep_system: When true and the first message is a dict with role
                ``"system"``, that message is kept.
        """
        first = self.messages[0] if self.messages else None
        # The isinstance check guards against non-dict entries, which
        # append() can store when ``_to_dict`` passes a value through.
        has_system = isinstance(first, dict) and first.get("role") == "system"
        self.messages = [first] if keep_system and has_system else []
        self.save()

    def save(self) -> None:
        """Write ``{"meta": ..., "messages": ...}`` to ``self.path`` as UTF-8 JSON.

        Does nothing when the session has no path. Missing parent directories
        are created. The data goes to a sibling ``<name>.tmp`` file first and
        is then moved over the target, so an interrupted write does not leave
        a truncated session file behind.

        Raises:
            TypeError: If a message or meta value is not JSON-serialisable
                (raised before anything is written).
            OSError: If the file cannot be written or moved into place.
        """
        if self.path is None:
            return
        # Serialise before touching the disk so a bad payload changes nothing.
        serialized = json.dumps(
            {"meta": self.meta, "messages": self.messages},
            ensure_ascii=False,
            indent=2,
        )
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        try:
            tmp_path.write_text(serialized, encoding="utf-8")
            tmp_path.replace(self.path)
        except OSError:
            # Best-effort cleanup of the partial temp file; the original
            # error is the one the caller should see.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    @classmethod
    def load(cls, path: Path) -> "Session":
        """Build a session bound to *path*, reading its contents if the file exists.

        Args:
            path: Session file location (``str`` / ``os.PathLike`` accepted).

        Returns:
            A session whose ``path`` is *path*. If the file does not exist the
            session is empty. A JSON root that is neither a list nor a dict
            also yields an empty session.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        path = Path(path)
        session = cls(path=path)
        if not path.exists():
            return session
        data = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(data, list):
            # Legacy format: the root is the bare message list, no meta.
            session.messages = data
            session.meta = {}
        elif isinstance(data, dict):
            # ``or`` also maps an explicit null to the empty default.
            session.messages = data.get("messages", []) or []
            session.meta = data.get("meta", {}) or {}
        return session
|