156 lines
6.5 KiB
Python
156 lines
6.5 KiB
Python
"""入站长轮询管理器(DESIGN §8.7):收用户消息 → 跑 agent → 回复发回。
|
|
|
|
- 每个 active 绑定一条 `getupdates` 长轮询(ilink 同步,放 to_thread);收到消息:
|
|
① `service.refresh_context_token` 刷新 24h 推送窗口;② 调注入的 `handle_message`
|
|
(app.py 提供:解析/建该用户常驻「微信」task → 抢 run 锁 → `_run_agent_bg` → 取回复);
|
|
③ 用本轮新鲜 `context_token` 分块发回。
|
|
- 每绑定 loop **串行**处理(收→跑→回→再收):天然避免同用户并发 run 锁冲突;不同用户并发。
|
|
- 管理器周期性对账 active 绑定:新增起 loop、撤销/revoke 停 loop。
|
|
|
|
`handle_message` 注入解耦 app.py 内部(broker / run 锁 / _run_agent_bg);本模块只管协议循环
|
|
与回复提取(`extract_last_assistant_text` 纯函数可测)。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any, Awaitable, Callable, Optional
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
|
|
from core.storage import session_scope
|
|
from core.storage.models import Message
|
|
from core.wechat import service
|
|
from core.wechat.ilink import ILinkClient, InboundAttachment
|
|
from core.wechat.service import BindingSnapshot
|
|
|
|
# app.py 注入:跑该用户的微信对话 task,返回 assistant 回复文本(可空)。
|
|
# 第三参 attachments:已下载解密(att.data 已回填)的入站附件,app.py 负责落盘 + 拼提示行。
|
|
HandleMessage = Callable[[UUID, str, list[InboundAttachment]], Awaitable[str]]
|
|
|
|
|
|
def _content_to_text(content: Any) -> str:
|
|
"""OpenAI 风格 content → 纯文本(str 直返;content blocks 拼 text 段)。"""
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts = []
|
|
for b in content:
|
|
if isinstance(b, dict) and b.get("type") in (None, "text"):
|
|
parts.append(b.get("text", ""))
|
|
return "".join(parts)
|
|
return ""
|
|
|
|
|
|
def extract_last_assistant_text(task_id: UUID, *, scan: int = 20) -> str:
|
|
"""取该 task 最后一条**有正文**的 assistant 消息文本(跳过纯 tool_calls 行)。"""
|
|
with session_scope() as s:
|
|
rows = s.execute(
|
|
select(Message.payload)
|
|
.where(Message.task_id == task_id)
|
|
.order_by(Message.idx.desc())
|
|
.limit(scan)
|
|
).all()
|
|
for (payload,) in rows:
|
|
if not isinstance(payload, dict) or payload.get("role") != "assistant":
|
|
continue
|
|
text = _content_to_text(payload.get("content"))
|
|
if text.strip():
|
|
return text
|
|
return ""
|
|
|
|
|
|
async def _poll_binding(
|
|
snap: BindingSnapshot, handle_message: HandleMessage, stop: asyncio.Event
|
|
) -> None:
|
|
"""单个绑定的长轮询循环。异常退避重试,直到 stop。"""
|
|
client = ILinkClient(snap.bot_token, snap.base_url)
|
|
cursor = ""
|
|
backoff = 2
|
|
while not stop.is_set():
|
|
try:
|
|
msgs, cursor = await asyncio.to_thread(client.get_updates, cursor)
|
|
backoff = 2
|
|
except Exception as e: # noqa: BLE001
|
|
print(f"[wechat-inbound] {str(snap.user_id)[:8]} getupdates err: "
|
|
f"{type(e).__name__}: {e}; retry in {backoff}s")
|
|
await asyncio.sleep(backoff)
|
|
backoff = min(backoff * 2, 60)
|
|
continue
|
|
for m in msgs:
|
|
if stop.is_set():
|
|
break
|
|
# 下载入站附件(图片/文件):CDN 取密文 → AES 解密 → 回填 att.data
|
|
atts: list[InboundAttachment] = []
|
|
for att in m.attachments:
|
|
try:
|
|
att.data = await asyncio.to_thread(client.download_media, att)
|
|
atts.append(att)
|
|
except Exception as e: # noqa: BLE001
|
|
print(f"[wechat-inbound] {str(snap.user_id)[:8]} download "
|
|
f"{att.kind} err: {type(e).__name__}: {e}")
|
|
# 文本和附件都没有(纯文本为空 / 附件全下载失败)→ 跳过整条
|
|
if not m.text.strip() and not atts:
|
|
continue
|
|
# ① 刷新该用户推送窗口(主动推靠它续命)
|
|
await asyncio.to_thread(
|
|
service.refresh_context_token, snap.user_id, m.from_user_id, m.context_token
|
|
)
|
|
# ② 跑 agent 取回复(附件由 handle_message 落盘 + 拼 [用户上传的...] 行)
|
|
try:
|
|
reply = await handle_message(snap.user_id, m.text, atts)
|
|
except Exception as e: # noqa: BLE001
|
|
reply = f"[出错] {type(e).__name__}: {e}"
|
|
# ③ 用本轮新鲜 token 分块回
|
|
if reply and reply.strip():
|
|
try:
|
|
await asyncio.to_thread(
|
|
client.send_text, m.from_user_id, m.context_token, reply
|
|
)
|
|
except Exception as e: # noqa: BLE001
|
|
print(f"[wechat-inbound] {str(snap.user_id)[:8]} reply send err: "
|
|
f"{type(e).__name__}: {e}")
|
|
|
|
|
|
async def run_inbound_manager(
|
|
handle_message: HandleMessage,
|
|
stop: asyncio.Event,
|
|
*,
|
|
reconcile_seconds: int = 60,
|
|
) -> None:
|
|
"""常驻管理器:周期对账 active 绑定,起/停 per-binding 长轮询循环。"""
|
|
loops: dict[UUID, asyncio.Task] = {}
|
|
try:
|
|
while not stop.is_set():
|
|
try:
|
|
active = await asyncio.to_thread(service.list_active_bindings)
|
|
except Exception as e: # noqa: BLE001
|
|
print(f"[wechat-inbound] list bindings err: {type(e).__name__}: {e}")
|
|
active = []
|
|
active_ids = {s.user_id for s in active}
|
|
# 起新增
|
|
for snap in active:
|
|
t = loops.get(snap.user_id)
|
|
if t is None or t.done():
|
|
loops[snap.user_id] = asyncio.create_task(
|
|
_poll_binding(snap, handle_message, stop),
|
|
name=f"wechat-poll-{str(snap.user_id)[:8]}",
|
|
)
|
|
# 清撤销 / 已结束
|
|
for uid in list(loops):
|
|
if uid not in active_ids:
|
|
loops.pop(uid).cancel()
|
|
elif loops[uid].done():
|
|
loops.pop(uid)
|
|
await _wait_stop(stop, reconcile_seconds) # 等 stop 或到下次对账
|
|
finally:
|
|
for t in loops.values():
|
|
t.cancel()
|
|
|
|
|
|
async def _wait_stop(stop: asyncio.Event, timeout: float) -> None:
|
|
try:
|
|
await asyncio.wait_for(stop.wait(), timeout=timeout)
|
|
except asyncio.TimeoutError:
|
|
pass
|