zcbot/scripts/backfill_chat_cost_cache_di...

169 lines
7.0 KiB
Python

"""Backfill 历史 chat usage_events 的 cost_cny —— 给前缀缓存命中折价 + 修 ¥0 旧账。
背景:`usage.py::_fallback_chat_cost_cny` 早期(a)对未知模型 litellm 返 0 又无兜底 →
大量 chat 事件 cost_cny 记成 ¥0;(b)后来加了兜底但把缓存命中段也按 input 全价算 →
命中率高的 task 虚高 2-3x。本脚本按每条事件 units 里已存的 token 数 + 模型档案价
**重算 cost_cny**,只改成本列,**不动任何 token 数 / units**。
价格来源:`ModelCapabilities.load(model_profile)`(config 当前价,含 cache_hit 折价)。
config 里没有该 profile 或无 input/output 价(如 glm 未配价)→ 退用 units 里记录的价格快照;
快照也无价 → 跳过,保留原值(不臆造价格)。缓存命中数取 units.cache_hit_tokens(缺 → 0,
即按全价,绝不少记;若加 --assume-cache-hit-rate RATE,则对缺该字段的事件按 tokens_in×RATE 估算)。
跑法: .venv/Scripts/python.exe scripts/backfill_chat_cost_cache_discount.py
默认 dry-run 只打印汇总,加 --apply 真写。
幂等:重算是确定性的;用同一组参数再跑一遍 0 改动。前端任务成本是现算 SUM(usage_events.cost_cny),
改完即时反映,无需动 tasks 表。
"""
from __future__ import annotations
import argparse
import os
import sys
from collections import defaultdict
from decimal import Decimal
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
env_file = ROOT / ".env"
if env_file.exists():
for line in env_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
k, _, v = line.partition("=")
os.environ.setdefault(k.strip(), v.strip())
from sqlalchemy import select
from core.agent_builder import ROOT as AB_ROOT, load_config
from core.capabilities import ModelCapabilities
from core.storage import session_scope
from core.storage.models import UsageEvent
from core.storage.usage import _fallback_chat_cost_cny
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--apply", action="store_true", help="真写;默认 dry-run 只打印")
ap.add_argument(
"--assume-cache-hit-rate",
type=float,
default=None,
metavar="RATE",
help="对 units 没记 cache_hit_tokens 的老事件,假定命中率 RATE(0~1)折价。"
"DeepSeek 当时其实缓存了前缀只是没记 → 全价偏高;给个保守估算(如 0.7)更贴近真实。"
"已记 cache_hit_tokens 的事件用真实值,不受影响。缺省=不假定(无字段按 0 命中/全价)。",
)
args = ap.parse_args()
assume_rate = args.assume_cache_hit_rate
if assume_rate is not None and not (0.0 <= assume_rate <= 1.0):
ap.error("--assume-cache-hit-rate 必须在 0~1 之间")
cfg = load_config()
models_dir = AB_ROOT / cfg["models_dir"]
# ModelCapabilities.load 按 profile 缓存(避免每行重读 yaml);None = 无法定价
caps_cache: dict[str, ModelCapabilities | None] = {}
def get_caps(profile: str) -> ModelCapabilities | None:
if profile not in caps_cache:
try:
caps_cache[profile] = ModelCapabilities.load(profile, models_dir)
except Exception:
caps_cache[profile] = None
return caps_cache[profile]
# per-profile 统计:事件数 / 改动数 / 跳过数 / 假定命中数 / 旧总额 / 新总额
stat_n: dict[str, int] = defaultdict(int)
stat_changed: dict[str, int] = defaultdict(int)
stat_skipped: dict[str, int] = defaultdict(int)
stat_assumed: dict[str, int] = defaultdict(int)
old_sum: dict[str, Decimal] = defaultdict(lambda: Decimal("0"))
new_sum: dict[str, Decimal] = defaultdict(lambda: Decimal("0"))
with session_scope() as s:
rows = s.execute(
select(UsageEvent).where(UsageEvent.kind == "chat")
).scalars().all()
for e in rows:
profile = e.model_profile or "?"
stat_n[profile] += 1
u = e.units or {}
caps = get_caps(profile)
if caps and (caps.input_cny_per_mtoken or caps.output_cny_per_mtoken):
inp = caps.input_cny_per_mtoken
outp = caps.output_cny_per_mtoken
chp = caps.cache_hit_cny_per_mtoken
else:
# config 无价 → 退 units 价格快照(老事件多半也没有);仍无 → 跳过
inp = float(u.get("input_cny_per_mtoken") or 0)
outp = float(u.get("output_cny_per_mtoken") or 0)
chp = float(u.get("cache_hit_cny_per_mtoken") or 0)
if not (inp or outp):
stat_skipped[profile] += 1
old_sum[profile] += Decimal(str(e.cost_cny))
new_sum[profile] += Decimal(str(e.cost_cny)) # 无价不变
continue
tin = int(u.get("tokens_in") or 0)
# cache_hit:优先用真实记录;没记(key 缺失)且开了 --assume-cache-hit-rate
# 时按估算命中率折(DeepSeek 当时缓存了只是没记)。key 在(含 =0)= 真实值,不假定。
if "cache_hit_tokens" in u:
cache_hit = int(u.get("cache_hit_tokens") or 0)
elif assume_rate is not None:
cache_hit = int(round(tin * assume_rate))
stat_assumed[profile] += 1
else:
cache_hit = 0
new_cost = _fallback_chat_cost_cny(
prompt_tokens=tin,
completion_tokens=int(u.get("tokens_out") or 0),
input_cny_per_mtoken=inp,
output_cny_per_mtoken=outp,
cache_hit_tokens=cache_hit,
cache_hit_cny_per_mtoken=chp,
)
old_cost = Decimal(str(e.cost_cny))
old_sum[profile] += old_cost
new_sum[profile] += new_cost
if new_cost != old_cost:
e.cost_cny = new_cost
stat_changed[profile] += 1
if args.apply:
s.commit()
else:
s.rollback()
print()
if assume_rate is not None:
print(f"[assume] 无 cache_hit 字段的老事件按命中率 {assume_rate:.0%} 估算折价")
print(f"{'model_profile':<22}{'events':>8}{'changed':>9}{'skipped':>9}"
f"{'assumed':>9}{'old_¥':>12}{'new_¥':>12}")
tot_old = Decimal("0")
tot_new = Decimal("0")
for profile in sorted(stat_n):
o, n = old_sum[profile], new_sum[profile]
tot_old += o
tot_new += n
print(f"{profile:<22}{stat_n[profile]:>8}{stat_changed[profile]:>9}"
f"{stat_skipped[profile]:>9}{stat_assumed[profile]:>9}"
f"{float(o):>12.4f}{float(n):>12.4f}")
print(f"{'TOTAL':<22}{sum(stat_n.values()):>8}"
f"{sum(stat_changed.values()):>9}{sum(stat_skipped.values()):>9}"
f"{sum(stat_assumed.values()):>9}{float(tot_old):>12.4f}{float(tot_new):>12.4f}")
print()
print(f"[mode] {'APPLIED (committed)' if args.apply else 'DRY-RUN (no commit, rerun with --apply)'}")
return 0
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())