zcbot/core/context.py

172 lines
5.6 KiB
Python

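"""Prepare the message context sent to the LLM.

The Session's persisted history is never modified; only a copy of the messages
is compacted, with low-risk rules, right before it is sent to the model. For
messages older than the recent window: long tool results are reduced to a
head/tail excerpt, ``load_skill`` results are replaced by a one-line
placeholder, and long assistant tool_call arguments are replaced by a small
JSON marker. Protocol fields (role, tool_call_id, name, tool_call structure)
are preserved, so old command output / retrieval results stop repeatedly
filling up the prompt.
"""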
from __future__ import annotations
from copy import deepcopy
from typing import Any, List
import json
import re
def _compact_old_tool_content(content: str, max_chars: int) -> str:
if len(content) <= max_chars:
return content
head = max_chars // 2
tail = max_chars - head
omitted = len(content) - head - tail
return (
content[:head]
+ f"\n[compacted old tool result, {omitted} chars omitted]\n"
+ content[-tail:]
)
_LOAD_SKILL_HEADER_RE = re.compile(r"\[skill=([^,\]]+)(?:,\s*dir=([^\]]+))?\]")
def _compact_load_skill_content(content: str) -> str:
first_line = content.splitlines()[0] if content else ""
match = _LOAD_SKILL_HEADER_RE.search(first_line)
if match:
skill = match.group(1)
skill_dir = match.group(2) or ""
suffix = f", dir={skill_dir}" if skill_dir else ""
return f"[loaded skill: {skill}{suffix}; full SKILL.md omitted from old context]"
return "[loaded skill; full SKILL.md omitted from old context]"
def _message_chars(msg: dict[str, Any]) -> int:
try:
return len(json.dumps(msg, ensure_ascii=False))
except TypeError:
return len(str(msg))
def _compact_tool_call_arguments(raw: Any, max_chars: int) -> tuple[Any, bool]:
if not isinstance(raw, str) or len(raw) <= max_chars:
return raw, False
marker: dict[str, Any] = {
"_compacted": True,
"original_chars": len(raw),
"note": "old assistant tool_call arguments omitted from context",
}
try:
parsed = json.loads(raw)
except Exception:
parsed = None
if isinstance(parsed, dict):
for key in ("path", "script_path", "file_path", "name"):
value = parsed.get(key)
if isinstance(value, str) and value:
marker[key] = value
content = parsed.get("content")
if isinstance(content, str):
marker["content_chars"] = len(content)
return json.dumps(marker, ensure_ascii=False), True
def _compact_assistant_tool_calls(
msg: dict[str, Any],
*,
max_arg_chars: int,
) -> tuple[int, int]:
tool_calls = msg.get("tool_calls")
if not isinstance(tool_calls, list):
return 0, 0
compacted = 0
saved = 0
for tc in tool_calls:
if not isinstance(tc, dict):
continue
fn = tc.get("function")
if not isinstance(fn, dict):
continue
before = fn.get("arguments")
after, did_compact = _compact_tool_call_arguments(
before,
max_chars=max(0, max_arg_chars),
)
if did_compact:
fn["arguments"] = after
compacted += 1
saved += len(before) - len(after)
return compacted, max(0, saved)
def prepare_messages_for_llm(
    messages: List[dict[str, Any]],
    *,
    keep_recent: int = 12,
    old_tool_chars: int = 2_000,
    old_tool_arg_chars: int = 800,
) -> List[dict[str, Any]]:
    """Return the (deep-copied) messages to send to the LLM, without stats.

    This is a convenience wrapper around ``prepare_messages_with_stats`` that
    discards the statistics dict.

    - The last *keep_recent* messages, and any message that is not an old
      assistant/tool message (e.g. system/user), are copied unchanged.
    - Older tool results are compacted: a head/tail excerpt of at most
      *old_tool_chars*, or a one-line placeholder for ``load_skill``.
    - Older assistant tool_call arguments longer than *old_tool_arg_chars*
      are replaced by a small JSON marker.
    - Protocol fields such as role/tool_call_id/name are left untouched.
    """
    result = prepare_messages_with_stats(
        messages,
        keep_recent=keep_recent,
        old_tool_chars=old_tool_chars,
        old_tool_arg_chars=old_tool_arg_chars,
    )
    return result[0]
def prepare_messages_with_stats(
    messages: List[dict[str, Any]],
    *,
    keep_recent: int = 12,
    old_tool_chars: int = 2_000,
    old_tool_arg_chars: int = 800,
) -> tuple[List[dict[str, Any]], dict[str, int]]:
    """Return a compacted deep copy of *messages* plus compaction statistics.

    The input list and its dicts are never mutated; every message is
    ``deepcopy``-ed. Messages at index ``len(messages) - keep_recent`` or later
    form the recent window and are copied verbatim. A negative *keep_recent*
    is treated as 0.

    For older messages:

    - ``assistant``: tool_call ``arguments`` strings longer than
      *old_tool_arg_chars* are replaced by a compact marker.
    - ``tool`` with string content and ``name == "load_skill"``: the content
      is replaced by a one-line placeholder.
    - any other ``tool`` with string content: the content is reduced to a
      head/tail excerpt when it exceeds *old_tool_chars*.

    Returns:
        ``(prepared, stats)``. ``stats`` holds ``original_chars``,
        ``sent_chars`` and ``saved_chars`` (clamped at 0), plus counts of
        compacted tool messages, skill messages and tool_call arguments.
    """
    keep_recent = max(keep_recent, 0)
    original_chars = sum(_message_chars(m) for m in messages)
    window_start = max(0, len(messages) - keep_recent)

    prepared: List[dict[str, Any]] = []
    tool_hits = 0
    skill_hits = 0
    arg_hits = 0
    for index, source in enumerate(messages):
        outgoing = deepcopy(source)
        prepared.append(outgoing)
        if index >= window_start:
            continue  # recent window: keep verbatim

        role = outgoing.get("role")
        if role == "assistant":
            compacted_args, _ = _compact_assistant_tool_calls(
                outgoing,
                max_arg_chars=old_tool_arg_chars,
            )
            arg_hits += compacted_args
            continue
        if role != "tool":
            continue
        body = outgoing.get("content")
        if not isinstance(body, str):
            continue

        is_skill = outgoing.get("name") == "load_skill"
        # NOTE(review): load_skill results are replaced regardless of their
        # text; if load_skill can return an error message, it will also read
        # as "loaded skill" in old context — confirm against the tool.
        if is_skill:
            shrunk = _compact_load_skill_content(body)
        else:
            shrunk = _compact_old_tool_content(body, max_chars=max(0, old_tool_chars))
        outgoing["content"] = shrunk
        if shrunk != body:
            if is_skill:
                skill_hits += 1
            else:
                tool_hits += 1

    sent_chars = sum(_message_chars(m) for m in prepared)
    return prepared, {
        "original_chars": original_chars,
        "sent_chars": sent_chars,
        "saved_chars": max(0, original_chars - sent_chars),
        "compacted_tool_messages": tool_hits,
        "compacted_skill_messages": skill_hits,
        "compacted_tool_call_arguments": arg_hits,
    }