# zcbot/tests/test_context_compaction.py
#
# 204 lines
# 7.9 KiB
# Python

import unittest
import json
from core.context import prepare_messages_for_llm, prepare_messages_with_stats
class ContextCompactionTests(unittest.TestCase):
    """Behavioral tests for ``prepare_messages_for_llm`` / ``prepare_messages_with_stats``.

    Each test builds a small chat history (system prompt, tool results,
    assistant tool calls, user turns) and runs it through the compaction
    helpers imported from ``core.context``. It then checks which messages
    are kept verbatim, which are shortened, and what the stats report.
    """

    # ------------------------------------------------------------------
    # Message builders. Every call returns newly created objects, so the
    # fixtures of one test are never shared with another.
    # ------------------------------------------------------------------

    @staticmethod
    def _system_message() -> dict:
        """Return the system prompt that opens every scenario."""
        return {"role": "system", "content": "rules"}

    @staticmethod
    def _tool_result(call_id: str, tool_name: str, body: str) -> dict:
        """Return a ``role="tool"`` message answering the tool call *call_id*."""
        return {"role": "tool", "tool_call_id": call_id, "name": tool_name, "content": body}

    @staticmethod
    def _function_call(call_id: str, fn_name: str, arguments: str) -> dict:
        """Return one entry of an assistant message's ``tool_calls`` list."""
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": fn_name, "arguments": arguments},
        }

    @staticmethod
    def _recent_user_turns(count: int = 12) -> list:
        """Return *count* filler user turns (``"recent 0"``, ``"recent 1"``, ...).

        NOTE(review): presumably 12 is enough to push earlier messages outside
        the default ``keep_recent`` window; confirm against ``core.context``.
        """
        return [{"role": "user", "content": f"recent {idx}"} for idx in range(count)]

    # ------------------------------------------------------------------
    # Tool-result compaction
    # ------------------------------------------------------------------

    def test_preserves_system_and_recent_messages(self) -> None:
        history = [
            self._system_message(),
            {"role": "user", "content": "old"},
            self._tool_result("old-tool", "shell", "A" * 200),
            {"role": "user", "content": "latest"},
            self._tool_result("new-tool", "shell", "B" * 200),
        ]

        result = prepare_messages_for_llm(history, keep_recent=2, old_tool_chars=40)

        # The system prompt and the last ``keep_recent`` messages pass through untouched.
        for position in (0, -2, -1):
            self.assertEqual(result[position], history[position])

    def test_compacts_old_tool_content_without_breaking_protocol_fields(self) -> None:
        history = [
            self._system_message(),
            {"role": "assistant", "tool_calls": [{"id": "tc1"}], "content": None},
            self._tool_result("tc1", "run_python", "A" * 200),
            {"role": "user", "content": "continue"},
        ]

        compacted = prepare_messages_for_llm(history, keep_recent=1, old_tool_chars=40)[2]

        # Only ``content`` may shrink. The fields that pair a tool result with
        # its originating call must survive compaction unchanged.
        expected_fields = (("role", "tool"), ("tool_call_id", "tc1"), ("name", "run_python"))
        for field, expected in expected_fields:
            self.assertEqual(compacted[field], expected)
        self.assertIn("[compacted old tool result", compacted["content"])
        self.assertLess(len(compacted["content"]), 120)

    def test_short_old_tool_content_is_left_unchanged(self) -> None:
        history = [
            self._system_message(),
            self._tool_result("tc1", "grep", "short"),
            {"role": "user", "content": "next"},
        ]

        result = prepare_messages_for_llm(history, keep_recent=1, old_tool_chars=40)

        # "short" fits inside the 40-char budget, so it is expected verbatim.
        self.assertEqual(result[1]["content"], "short")

    def test_compacts_old_load_skill_result_to_marker(self) -> None:
        skill_dump = "[skill=proposal, dir=/sandbox/skills/proposal]\n" + "A" * 5000
        history = [
            self._system_message(),
            self._tool_result("tc1", "load_skill", skill_dump),
            {"role": "user", "content": "next"},
        ]

        marker = prepare_messages_for_llm(history, keep_recent=1)[1]["content"]

        # The marker keeps the skill name and directory but drops the skill body.
        self.assertIn("loaded skill: proposal", marker)
        self.assertIn("dir=/sandbox/skills/proposal", marker)
        self.assertNotIn("A" * 100, marker)

    def test_prepare_messages_reports_compaction_stats(self) -> None:
        history = [
            self._system_message(),
            self._tool_result("tc1", "shell", "A" * 200),
            {"role": "user", "content": "next"},
        ]

        result, stats = prepare_messages_with_stats(history, keep_recent=1, old_tool_chars=40)

        self.assertLess(stats["sent_chars"], stats["original_chars"])
        self.assertEqual(stats["compacted_tool_messages"], 1)
        self.assertGreater(stats["saved_chars"], 0)
        # Compaction is expected to rewrite messages, never to drop or add any.
        self.assertEqual(len(result), len(history))

    # ------------------------------------------------------------------
    # Assistant tool-call argument compaction
    # ------------------------------------------------------------------

    def test_defaults_compact_medium_sized_old_write_arguments(self) -> None:
        write_args = json.dumps({"path": "slides/p01.py", "content": "A" * 1000})
        history = [
            self._system_message(),
            {"role": "assistant", "tool_calls": [self._function_call("tc1", "write", write_args)]},
            *self._recent_user_turns(),
        ]

        # No overrides: the defaults alone must compact a 1000-char write payload.
        result, stats = prepare_messages_with_stats(history)

        decoded = json.loads(result[1]["tool_calls"][0]["function"]["arguments"])
        self.assertTrue(decoded["_compacted"])
        self.assertEqual(decoded["path"], "slides/p01.py")
        self.assertEqual(stats["compacted_tool_call_arguments"], 1)

    def test_compacts_old_assistant_tool_call_arguments(self) -> None:
        write_args = json.dumps({"path": "slides/p01.py", "content": "A" * 5000})
        history = [
            self._system_message(),
            {
                "role": "assistant",
                "content": "writing slide",
                "tool_calls": [self._function_call("tc1", "write", write_args)],
            },
            self._tool_result("tc1", "write", "[wrote file]"),
            {"role": "user", "content": "next"},
        ]

        result, stats = prepare_messages_with_stats(
            history,
            keep_recent=1,
            old_tool_arg_chars=200,
        )

        call = result[1]["tool_calls"][0]
        raw_args = call["function"]["arguments"]
        decoded = json.loads(raw_args)
        # The call envelope (id / type / function name) must be untouched...
        self.assertEqual(call["id"], "tc1")
        self.assertEqual(call["type"], "function")
        self.assertEqual(call["function"]["name"], "write")
        # ...while the arguments become a marker that still names the path.
        self.assertTrue(decoded["_compacted"])
        self.assertEqual(decoded["path"], "slides/p01.py")
        self.assertNotIn("A" * 100, raw_args)
        self.assertEqual(stats["compacted_tool_call_arguments"], 1)

    # ------------------------------------------------------------------
    # task_progress special cases
    # ------------------------------------------------------------------

    def test_keeps_old_task_progress_arguments_intact(self) -> None:
        # task_progress arguments are tiny to begin with. Squashing them into a
        # marker such as `{"_compacted":...,"step_id":...}`, which looks like a
        # legitimate call, would poison the model and break front-end progress
        # restoration. Old task_progress call arguments must therefore be kept
        # verbatim.
        plan_args = json.dumps({
            "action": "set_plan",
            "steps": [
                {"id": "s1", "title": "理解需求", "status": "completed"},
                {"id": "s2", "title": "实现功能", "status": "in_progress"},
            ],
        }, ensure_ascii=False)
        history = [
            self._system_message(),
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [self._function_call("tc1", "task_progress", plan_args)],
            },
            *self._recent_user_turns(),
        ]

        result, stats = prepare_messages_with_stats(history)

        kept = json.loads(result[1]["tool_calls"][0]["function"]["arguments"])
        self.assertNotIn("_compacted", kept)
        self.assertEqual(kept["action"], "set_plan")
        self.assertEqual(kept["steps"][0]["title"], "理解需求")
        self.assertEqual(stats["compacted_tool_call_arguments"], 0)

    def test_old_task_progress_tool_result_uses_tiny_marker(self) -> None:
        progress_payload = json.dumps({"ok": True, "steps": [{"title": "A" * 2000}]})
        history = [
            self._system_message(),
            self._tool_result("tc1", "task_progress", progress_payload),
            {"role": "user", "content": "next"},
        ]

        result, stats = prepare_messages_with_stats(history, keep_recent=1)

        self.assertEqual(
            result[1]["content"],
            "[task_progress updated; UI-only details omitted from context]",
        )
        self.assertEqual(stats["compacted_tool_messages"], 1)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()