539 lines
20 KiB
Python
539 lines
20 KiB
Python
"""CLI 入口: 简单 REPL。
|
|
|
|
用法:
|
|
python cli.py chat # 新建一个 task
|
|
python cli.py chat --mode coding --desc "修一处 bug" # 带元数据建任务
|
|
python cli.py chat --resume last # 恢复最近一个 task
|
|
python cli.py chat --resume <uuid-or-prefix> # 显式 task_id(前缀 ≥8 字符)
|
|
python cli.py chat --model deepseek_v4.pro
|
|
python cli.py tasks # 列出 task
|
|
python cli.py probe # 实测对账 yaml 声称的能力
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import shutil
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import click
|
|
from rich.prompt import Prompt
|
|
from rich.table import Table
|
|
|
|
from core.ui import make_console
|
|
from main import (
|
|
ROOT,
|
|
_resolve_uuid_or_prefix,
|
|
build_agent,
|
|
load_config,
|
|
resolve_workspace,
|
|
sync_task_tokens,
|
|
tasks_dir,
|
|
)
|
|
|
|
|
|
@click.group()
def cli() -> None:
    """zcbot - 个人任务 agent"""
    # Root click group: every command in this file hangs off it via
    # @cli.command() / @cli.group(). The docstring above is click's --help
    # text (runtime output), so it is intentionally left verbatim.
|
|
|
|
|
|
@cli.group()
def db() -> None:
    """数据库管理 (alembic upgrade/downgrade/current)。需先 export ZCBOT_DB_URL。"""
    # `db` sub-group; upgrade/downgrade/current are attached with @db.command().
    # The docstring is click --help text (runtime output) and stays verbatim;
    # it reads: "Database management (alembic ...). Export ZCBOT_DB_URL first."
|
|
|
|
|
|
def _alembic_cfg():
    """Return an alembic ``Config`` built from ``<ROOT>/alembic.ini``.

    alembic is imported inside the function, so it is only loaded when a
    ``db`` command actually runs.
    """
    from alembic.config import Config

    ini_path = ROOT / "alembic.ini"
    return Config(str(ini_path))
|
|
|
|
|
|
def _run_alembic(fn, *args) -> None:
    """Invoke alembic command ``fn(config, *args)`` with terse error reporting.

    Failures (e.g. ZCBOT_DB_URL unset or the database unreachable) are printed
    as a one-line ``[err]`` message on stderr instead of a traceback:

    - ``RuntimeError``: message only, exit status 2.
    - any other ``Exception``: ``<ExcType>: message``, exit status 3.
    """
    try:
        fn(_alembic_cfg(), *args)
    except Exception as exc:
        if isinstance(exc, RuntimeError):
            message, code = f"[err] {exc}", 2
        else:
            message, code = f"[err] {type(exc).__name__}: {exc}", 3
        click.echo(message, err=True)
        sys.exit(code)
|
|
|
|
|
|
@db.command("upgrade")
@click.argument("revision", default="head")
def db_upgrade(revision: str) -> None:
    """alembic upgrade <revision> (default head)."""
    # Docstring above is click --help text; kept verbatim.
    from alembic.command import upgrade as alembic_upgrade

    _run_alembic(alembic_upgrade, revision)
|
|
|
|
|
|
@db.command("downgrade")
@click.argument("revision")
def db_downgrade(revision: str) -> None:
    """alembic downgrade <revision> (use -1 for one step, base for all)."""
    # Docstring above is click --help text; kept verbatim.
    from alembic.command import downgrade as alembic_downgrade

    _run_alembic(alembic_downgrade, revision)
|
|
|
|
|
|
@db.command("current")
def db_current() -> None:
    """alembic current -- show currently applied revision."""
    # Docstring above is click --help text; kept verbatim.
    from alembic.command import current as alembic_current

    _run_alembic(alembic_current)
|
|
|
|
|
|
def _cleanup_if_empty(task_dir, session, workspace_dir, console=None) -> bool:
    """Discard the current task if it never received a user message.

    Called before the REPL switches away from a task or exits. Decision steps:

    1. The in-memory session holds any user message -> keep everything and
       return False.
    2. ``task_dir`` does not exist -> delete only the DB placeholder row and
       return False.
    3. ``task_dir`` contains anything besides ``*.tmp`` files -> keep
       everything (DB row included) and return False.
    4. Otherwise delete the DB row. The directory is rmtree'd only when it is
       the default-derived managed path (per ``is_managed_task_dir``). A
       user-supplied ``--task-dir`` is never removed. Return True.

    NOTE(review): step 1 only looks at the in-memory session. After ``/reset``
    a task that already has messages persisted in PG may look empty here;
    confirm what ``session.reset`` does to stored messages.
    """
    from main import is_managed_task_dir

    if session.n_user_msgs() > 0:
        return False

    managed = is_managed_task_dir(task_dir, workspace_dir)
    try:
        children = list(task_dir.iterdir())
    except FileNotFoundError:
        # Directory was never created: only the DB placeholder row exists.
        _delete_task_db_row(session.task_id)
        return False

    # Leftover *.tmp files do not count as real artifacts.
    artifacts = [
        child for child in children
        if not child.is_file() or not child.name.endswith(".tmp")
    ]
    if artifacts:
        return False

    if managed:
        shutil.rmtree(task_dir, ignore_errors=True)
    _delete_task_db_row(session.task_id)

    if console is not None:
        label = "empty" if managed else "empty (kept user dir)"
        short_id = str(session.task_id)[:8]
        console.print(f"[muted]cleaned {label} task {short_id}[/muted]")
    return True
|
|
|
|
|
|
def _delete_task_db_row(task_id) -> None:
    """Delete the PG ``tasks`` row whose ``task_id`` equals *task_id*.

    The original note says messages are removed via ON DELETE CASCADE (schema
    not visible here). If the task was never inserted, the DELETE matches zero
    rows and has no effect.
    """
    from sqlalchemy import delete

    from core.storage import session_scope
    from core.storage.models import Task

    stmt = delete(Task).where(Task.task_id == task_id)
    with session_scope() as s:
        s.execute(stmt)
|
|
|
|
|
|
def _task_has_messages(task_id_str: str) -> bool:
    """Return True if PG stores at least one message for this task.

    *task_id_str* must be a full UUID string. A string that does not parse as
    a UUID yields False without querying the database.
    """
    from uuid import UUID

    from sqlalchemy import select

    from core.storage import session_scope
    from core.storage.models import Message

    try:
        task_uuid = UUID(task_id_str)
    except ValueError:
        return False

    probe_query = (
        select(Message.message_id)
        .where(Message.task_id == task_uuid)
        .limit(1)
    )
    with session_scope() as s:
        first_id = s.execute(probe_query).scalar_one_or_none()
    return first_id is not None
|
|
|
|
|
|
def _list_task_rows(workspace_dir, limit=20, status=None):
    """Return recent task summaries from PG, newest ``updated_at`` first.

    Each row is a tuple
    ``(updated_at, task_id_str, status, mode, model, tokens, n_msgs, desc)``:

    - ``model`` is the model profile if set, else the raw model name.
    - ``tokens`` is prompt + completion, with NULL treated as 0.
    - ``mode``, ``model`` and ``desc`` are normalised to ``""`` when NULL, so
      callers can safely ``len()``/slice them.

    Args:
        workspace_dir: unused. All data now comes from the PG ``tasks`` table
            and ``state.json`` is no longer read; the parameter is kept only
            for signature compatibility.
        limit: maximum number of tasks returned.
        status: if truthy, only tasks with this exact status (SQL WHERE).
    """
    from sqlalchemy import func, select

    from core.storage import session_scope
    from core.storage.models import Message, Task

    _ = workspace_dir  # kept for backward-compatible signature only

    with session_scope() as s:
        q = select(
            Task.task_id, Task.updated_at, Task.status, Task.mode,
            Task.model, Task.model_profile, Task.tokens_prompt,
            Task.tokens_completion, Task.description,
        ).order_by(Task.updated_at.desc())
        if status:
            q = q.where(Task.status == status)
        rows_db = s.execute(q.limit(limit)).all()

        # Count messages only for the tasks actually returned, instead of
        # aggregating over every message of every task in the database.
        task_ids = [row[0] for row in rows_db]
        msg_counts = {}
        if task_ids:
            msg_counts = dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(task_ids))
                .group_by(Message.task_id)
            ).all())

    return [
        (
            updated_at,
            str(tid),
            st_,
            md or "",
            prof or mdl or "",
            (tp or 0) + (tc or 0),
            msg_counts.get(tid, 0),
            desc or "",
        )
        for tid, updated_at, st_, md, mdl, prof, tp, tc, desc in rows_db
    ]
|
|
|
|
|
|
def _command_arg(cmd: str, name: str) -> str | None:
    """Match a REPL slash command that takes an optional argument.

    Returns the stripped argument text (``""`` when none) if *cmd* is exactly
    *name* or *name* followed by whitespace, otherwise None. A bare
    ``startswith`` check would also treat e.g. ``/description x`` as ``/desc``.
    *cmd* is expected to be already stripped.
    """
    if cmd == name:
        return ""
    if cmd.startswith(name) and cmd[len(name)].isspace():
        return cmd[len(name):].strip()
    return None


def _pick_resume_target(console, ws_dir, arg: str, current_sid: str) -> str | None:
    """Turn a ``/resume`` argument into a task id string, or None to abort.

    - ``last``: the most recently updated task *other than* the current one.
      A fresh REPL's own placeholder task is usually the newest row, so
      without skipping it ``/resume last`` would only say "already current".
    - any other non-empty text: returned as typed (id or prefix).
    - empty: show the 10 most recent tasks and let the user pick by number or
      type an id. Enter, EOF or Ctrl-C cancels.
    """
    if arg == "last":
        others = [r for r in _list_task_rows(ws_dir, limit=2) if r[1] != current_sid]
        if not others:
            console.print("[warn]没有可恢复的 task[/warn]")
            return None
        return others[0][1]
    if arg:
        return arg

    rows = _list_task_rows(ws_dir, limit=10)
    if not rows:
        console.print("[warn]没有可恢复的 task[/warn]")
        return None

    tbl = Table(show_lines=False)
    tbl.add_column("#", style="bold")
    tbl.add_column("task id")
    tbl.add_column("status")
    tbl.add_column("mode")
    tbl.add_column("msgs", justify="right")
    tbl.add_column("desc")
    sc = {"active": "status.active", "completed": "status.completed", "abandoned": "status.abandoned"}
    for i, (_, tid, st, md, _mdl, _tok, n, dsc) in enumerate(rows, 1):
        c = sc.get(st, "info")
        dsc = dsc or ""  # guard against NULL descriptions
        d_show = dsc if len(dsc) <= 50 else dsc[:47] + "..."
        tbl.add_row(str(i), tid[:8], f"[{c}]{st}[/{c}]", md, str(n), d_show)
    console.print(tbl)

    try:
        sel = Prompt.ask("[user]选编号或输入 task_id (回车取消)[/user]", console=console, default="")
    except (EOFError, KeyboardInterrupt):
        return None
    sel = sel.strip()
    if not sel:
        return None
    if not sel.isdigit():
        return sel
    idx = int(sel) - 1
    if 0 <= idx < len(rows):
        return rows[idx][1]
    console.print(f"[err]编号超界: {sel}[/err]")
    return None


def _repl_export(console, ws_dir, arg: str, sid: str, task_dir) -> None:
    """Handle ``/export [last|<id>]`` inside the REPL.

    With no argument the current task is exported into its own task_dir.
    Otherwise the id/prefix (or ``last``) is resolved, and ``task_dir=None``
    lets ``export_chat_to_docx`` look the directory up itself. Every failure
    is reported on the console; nothing is raised.
    """
    from uuid import UUID

    if arg:
        if arg == "last":
            rs = _list_task_rows(ws_dir, limit=1)
            if not rs:
                console.print("[warn]没有 task 可导出[/warn]")
                return
            arg = rs[0][1]
        try:
            target_tid = _resolve_uuid_or_prefix(arg)
        except Exception as e:
            console.print(f"[err]task_id 解析失败:[/err] {type(e).__name__}: {e}")
            return
        target_dir = None  # let export_chat_to_docx read task_dir from PG
    else:
        target_tid = UUID(sid)
        target_dir = task_dir

    if not _task_has_messages(str(target_tid)):
        console.print(
            f"[warn]无可导出内容: {str(target_tid)[:8]} 还没有消息[/warn]"
        )
        return
    try:
        from core.export_docx import export_chat_to_docx
        out = export_chat_to_docx(target_tid, target_dir)
    except Exception as e:
        console.print(f"[err]导出失败:[/err] {type(e).__name__}: {e}")
        return
    console.print(f"[ok]已导出[/ok] -> {out}")


@cli.command()
@click.option("--model", default=None, help="模型档案,如 deepseek_v4.flash 或 deepseek_v4.pro")
@click.option("--workspace", default=None, help="工作目录(存 tasks/ 和 sessions/)")
@click.option("--resume", default=None, help="恢复 task: 'last' 或 task_id")
@click.option("--mode", default="", help="任务模式标签(coding/ppt/proposal/...自由形式)")
@click.option("--desc", default="", help="一句话任务描述,便于 tasks 列表识别")
@click.option("--task-dir", "task_dir_arg", default=None,
              help="项目化 task:把产物落到指定目录(绝对或相对当前 cwd);留空走默认派生 workspace/tasks/<uuid>/")
def chat(model: str, workspace: str, resume: str, mode: str, desc: str,
         task_dir_arg: str) -> None:
    """启动交互式 REPL。每次启动默认开新 task,用 --resume 接老的。"""
    # The docstring above is click --help text (runtime output), kept verbatim.
    #
    # Task switches (/new, /resume) build the replacement task FIRST and only
    # then clean up the current one if it is empty. Cleaning first would leave
    # the REPL bound to a deleted task whenever build_agent fails.
    console = make_console()
    ws_dir = resolve_workspace(workspace)
    try:
        agent, session, sid, task_state, task_dir = build_agent(
            model_name=model,
            workspace=workspace,
            console=console,
            session_id=resume,
            resume=bool(resume),
            mode=mode,
            description=desc,
            task_dir_arg=task_dir_arg,
        )
    except Exception as e:
        console.print(f"[err]启动失败:[/err] {type(e).__name__}: {e}")
        sys.exit(1)

    if resume:
        console.print(
            f"[ok]恢复 task[/ok] [bold]{sid[:8]}[/bold] ({len(session.messages)} 条消息) "
            f"model: [accent]{agent.caps.model_id}[/accent]"
        )
    else:
        meta_tail = ""
        if task_state.mode or task_state.description:
            meta_tail = f" mode={task_state.mode!r} desc={task_state.description!r}"
        console.print(
            f"[ok]新 task[/ok] [bold]{sid[:8]}[/bold] "
            f"model: [accent]{agent.caps.model_id}[/accent]{meta_tail}"
        )
    # "\[" escapes the bracket for Rich markup. Unescaped, "[last|<id>]" is
    # parsed as a style tag and silently dropped from the printed help line.
    console.print(
        "[info]/exit 退出 /reset 清空对话(保留 task) /new 开新 task "
        "/resume \\[last|<id>] 切到已有 task /id /status 查看 "
        "/done /abandon 改状态 /desc <文本> 设描述 "
        "/export [<id>] 导出对话为 .docx[/info]\n"
    )

    while True:
        try:
            user_input = Prompt.ask("[user]you[/user]", console=console)
        except (EOFError, KeyboardInterrupt):
            console.print("\n[muted]bye[/muted]")
            _cleanup_if_empty(task_dir, session, ws_dir, console)
            break

        cmd = user_input.strip()
        if cmd in ("/exit", "/quit"):
            _cleanup_if_empty(task_dir, session, ws_dir, console)
            break

        if cmd == "/reset":
            session.reset(keep_system=True)
            console.print("[info]当前 task 对话已重置(保留 system 和 state)[/info]")
            continue

        if cmd == "/new":
            try:
                new_task = build_agent(
                    model_name=model, workspace=workspace, console=console,
                    mode=mode, description=desc,
                    task_dir_arg=task_dir_arg,
                )
            except Exception as e:
                console.print(f"[err]新建失败:[/err] {type(e).__name__}: {e}")
                continue
            # The replacement exists; only now drop the old task if it is empty.
            _cleanup_if_empty(task_dir, session, ws_dir, console)
            agent, session, sid, task_state, task_dir = new_task
            console.print(f"[ok]新 task[/ok] [bold]{sid[:8]}[/bold]")
            continue

        resume_arg = _command_arg(cmd, "/resume")
        if resume_arg is not None:
            target_id = _pick_resume_target(console, ws_dir, resume_arg, sid)
            if target_id is None:
                continue
            if target_id == sid:
                console.print(f"[info]已是当前 task: {sid}[/info]")
                continue
            try:
                resumed = build_agent(
                    model_name=model, workspace=workspace, console=console,
                    session_id=target_id, resume=True,
                )
            except Exception as e:
                console.print(f"[err]恢复失败:[/err] {type(e).__name__}: {e}")
                continue
            if resumed[2] == sid:
                # target_id was a prefix of the current task's id. Cleaning up
                # here would delete the very task being resumed.
                console.print(f"[info]已是当前 task: {sid}[/info]")
                continue
            _cleanup_if_empty(task_dir, session, ws_dir, console)
            agent, session, sid, task_state, task_dir = resumed
            console.print(
                f"[ok]切到 task[/ok] [bold]{sid[:8]}[/bold] ({len(session.messages)} 条消息) "
                f"model: [accent]{agent.caps.model_id}[/accent]"
            )
            continue

        if cmd == "/id":
            cwd_disp = session.meta.get("cwd", "?")
            model_disp = session.meta.get("model", agent.caps.model_id)
            console.print(f"[info]task: {sid} model: {model_disp} cwd: {cwd_disp}[/info]")
            continue

        if cmd == "/status":
            console.print(
                f"[info]task {task_state.task_id} status={task_state.status} "
                f"mode={task_state.mode!r} desc={task_state.description!r}\n"
                f" model={task_state.model} tokens={task_state.tokens_total} "
                f"(p={task_state.tokens_prompt}/c={task_state.tokens_completion}) "
                f"created={task_state.created_at} updated={task_state.updated_at}[/info]"
            )
            continue

        if cmd == "/done":
            task_state.status = "completed"
            task_state.save()
            console.print(f"[ok]task {sid} marked completed[/ok]")
            break

        if cmd == "/abandon":
            task_state.status = "abandoned"
            task_state.save()
            console.print(f"[warn]task {sid} marked abandoned[/warn]")
            break

        new_desc = _command_arg(cmd, "/desc")
        if new_desc is not None:
            task_state.description = new_desc
            task_state.save()
            console.print(f"[info]description set: {new_desc!r}[/info]")
            continue

        export_arg = _command_arg(cmd, "/export")
        if export_arg is not None:
            _repl_export(console, ws_dir, export_arg, sid, task_dir)
            continue

        if not cmd:
            continue

        try:
            agent.run(user_input)
        except KeyboardInterrupt:
            console.print("\n[warn]已中断本轮。下一条输入会继续这个 task。[/warn]")
        except Exception as e:
            console.print(f"[err]运行错误:[/err] {type(e).__name__}: {e}")
        finally:
            # Persist token usage even when the turn was interrupted or failed.
            sync_task_tokens(task_state, agent.llm)
|
|
|
|
|
|
@cli.command()
@click.option("--workspace", default=None, help="工作目录")
@click.option("--limit", default=20, help="显示最近 N 个")
@click.option("--status", default=None, help="只看某状态: active / completed / abandoned")
def tasks(workspace: str, limit: int, status: str) -> None:
    """列出已有 task(从 PG tasks 表读,按 updated_at 降序)。"""
    # Docstring above is click --help text; kept verbatim.
    ws = resolve_workspace(workspace, load_config())
    rows = _list_task_rows(ws, limit=limit, status=status)
    if not rows:
        click.echo(f"(no tasks in {tasks_dir(ws)})")
        return

    columns = (
        ("task id", {"style": "bold"}),
        ("status", {}),
        ("mode", {}),
        ("model", {}),
        ("msgs", {"justify": "right"}),
        ("tokens", {"justify": "right"}),
        ("desc", {}),
    )
    table = Table(show_lines=False)
    for header, options in columns:
        table.add_column(header, **options)

    known_statuses = ("active", "completed", "abandoned")
    for _, tid, st, mode, model, tok, n, desc in rows:
        # Known statuses map to theme styles "status.<name>"; others use "info".
        style = f"status.{st}" if st in known_statuses else "info"
        shown = desc if len(desc) <= 50 else desc[:47] + "..."
        table.add_row(tid[:8], f"[{style}]{st}[/{style}]", mode, model, str(n), str(tok), shown)
    make_console().print(table)
|
|
|
|
|
|
@cli.command()
@click.argument("task_id")
@click.option("--workspace", default=None, help="工作目录")
@click.option("-o", "--output", default=None,
              help="输出 .docx 路径,默认 <task_dir>/chat_<task_id>.docx")
@click.option("--include-system", is_flag=True,
              help="包含 system prompt(默认跳过,信息密度低)")
@click.option("--no-reasoning", is_flag=True,
              help="不包含 reasoning_content(默认带)")
@click.option("--tool-head", default=1000, type=int,
              help="tool 结果保留前 N 字符(默认 1000)")
@click.option("--tool-tail", default=500, type=int,
              help="tool 结果保留后 N 字符(默认 500)")
def export(task_id: str, workspace: str, output: str, include_system: bool,
           no_reasoning: bool, tool_head: int, tool_tail: int) -> None:
    """把指定 task 的对话导出为 .docx。task_id 用 'last' 取最近一个。"""
    # Docstring above is click --help text; kept verbatim.
    from core.export_docx import export_chat_to_docx

    console = make_console()

    def abort(message):
        # Print the error and terminate the command with exit status 1.
        console.print(message)
        sys.exit(1)

    ws = resolve_workspace(workspace, load_config())

    if task_id == "last":
        latest = _list_task_rows(ws, limit=1)
        if not latest:
            abort("[err]没有 task 可导出[/err]")
        task_id = latest[0][1]

    try:
        tid = _resolve_uuid_or_prefix(task_id)
    except Exception as e:
        abort(f"[err]task_id 解析失败:[/err] {type(e).__name__}: {e}")
    if not _task_has_messages(str(tid)):
        abort(f"[err]task 不存在或无 messages:[/err] {tid}")

    out_path = None if not output else Path(output).resolve()
    try:
        # task_dir=None: export_chat_to_docx resolves it on its own.
        path = export_chat_to_docx(
            tid, None, out_path,
            include_system=include_system,
            include_reasoning=not no_reasoning,
            tool_head=tool_head,
            tool_tail=tool_tail,
        )
    except Exception as e:
        abort(f"[err]导出失败:[/err] {type(e).__name__}: {e}")
    console.print(f"[ok]导出完成[/ok] -> {path}")
|
|
|
|
|
|
@cli.command()
@click.option("--model", default=None, help="模型档案,如 deepseek_v4.flash 或 deepseek_v4.pro")
@click.option("--long-context", is_flag=True, help="加跑 needle-in-haystack(费 token,默认关)")
def probe(model: str, long_context: bool) -> None:
    """实测对账模型 yaml 声称的能力。会调用 LLM,有 API 开销。"""
    # Docstring above is click --help text; kept verbatim.
    #
    # Exit codes: 1 = profile/LLM setup failed, 2 = declared vs observed
    # capability mismatch, 3 = probe error (individual probes, or the probe
    # run itself raising).
    from core.capabilities import ModelCapabilities
    from core.llm import LLM
    from core.probe import probe_capabilities

    cfg = load_config()
    name = model or cfg["default_model"]
    models_dir = cfg["models_dir"]

    console = make_console()
    try:
        caps = ModelCapabilities.load(name, ROOT / models_dir)
    except Exception as e:
        console.print(f"[err]档案加载失败:[/err] {type(e).__name__}: {e}")
        sys.exit(1)

    console.print(
        f"[bold]probing[/bold] [accent]{caps.model_id}[/accent] (profile: {name}) "
        f"[muted]long-context={long_context}[/muted]\n"
    )

    try:
        llm = LLM(caps)
    except Exception as e:
        console.print(f"[err]LLM 构造失败:[/err] {type(e).__name__}: {e}")
        sys.exit(1)

    try:
        with console.status("[muted]running probes...[/muted]", spinner="dots"):
            report = probe_capabilities(caps, llm, include_long_context=long_context)
    except Exception as e:
        # Report a crash of the probe run like other errors instead of a traceback.
        console.print(f"[err]探测失败:[/err] {type(e).__name__}: {e}")
        sys.exit(3)

    tbl = Table(show_lines=False)
    tbl.add_column("capability", style="bold")
    tbl.add_column("declared")
    tbl.add_column("observed")
    tbl.add_column("status")
    tbl.add_column("detail")
    color = {"ok": "ok", "mismatch": "warn", "error": "err", "skip": "muted"}
    for r in report.results:
        c = color.get(r.status, "info")
        tbl.add_row(
            r.name,
            str(r.declared),
            str(r.observed),
            f"[{c}]{r.status}[/{c}]",
            r.detail,
        )
    console.print(tbl)

    if report.has_mismatch:
        # Point at the yaml under the configured models_dir (the directory
        # ModelCapabilities.load actually read), not a hard-coded path.
        yaml_hint = (Path(models_dir) / f"{caps.family}.yaml").as_posix()
        console.print(
            "\n[warn]存在能力对账差异 —— 看 detail,必要时改 "
            f"{yaml_hint}[/warn]"
        )
        sys.exit(2)
    if any(r.status == "error" for r in report.results):
        console.print("\n[err]部分探测出错(见 detail)[/err]")
        sys.exit(3)
    console.print("\n[ok]全部能力声明与实测一致。[/ok]")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point (`python cli.py <command> ...`, see module docstring):
    # hand control to the root click group defined above.
    cli()
|