zcbot/core/storage/usage.py

222 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""用量记账(0006 + 0007):一次产生成本的调用 = 一行 usage_events + 双写 messages 列。
chat 类型的入口由 loop.py 在 assistant message 入库后调用;媒体工具(image/video/audio)
在 tool execute 完后由 tool 直接调用对应入口(record_image_usage 等)。
币种(0007):全表统一 CNY(`cost_cny` 列)。chat 路径走 litellm 的 USD cost_map → 内部
×USD_TO_CNY 折算落库;媒体路径价格本身就是 CNY,直接落。units jsonb 里 snapshot 当时
的关键价格参数(chat 没有,media 存 price_cny_per_image 等),便于跨调价对账。
"""
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import Any, Mapping, Optional
from uuid import UUID
from sqlalchemy import func, select, update
from .engine import session_scope
from .models import Message, UsageEvent
# litellm 的 cost map 给的是 USD,落库前折成 CNY。汇率近似(每年看一次,实质偏差不大);
# 真要精算的话应该按调用时刻的汇率,但开发期/个人用接受。
USD_TO_CNY = Decimal("7.2")
def _safe_chat_cost_usd(response: Any) -> Decimal:
"""litellm.completion_cost(response) 包一层:任何异常都吞掉返 0。
未知 model / cost map 没收录 / response 结构变都不影响主流程 —— usage_events
仍写入,只是 cost=0,后续人工补算 OK。返 USD,由 caller 折算。
"""
try:
from litellm import completion_cost # type: ignore[import-not-found]
cost = completion_cost(completion_response=response)
if cost is None:
return Decimal("0")
return Decimal(str(cost))
except Exception:
return Decimal("0")
def record_chat_usage(
*,
task_id: UUID,
user_id: UUID,
message_id: Optional[UUID],
model_profile: str,
prompt_tokens: int,
completion_tokens: int,
response: Any = None,
) -> Decimal:
"""记一次 chat 调用:写 usage_events 行 + 回填 messages.model_profile/tokens_in/out。
`message_id` 来自 `Session.append` 的返回值;若为 None(系统消息 / 旧路径未拿到)
则 usage_events 仍写但 message_id=NULL,messages 列不回填。
`model_profile` 形如 `"deepseek_v4.pro"`(family.variant)。
返回算出的 cost_cny(已落库),调用方可用作 SSE 显示。
"""
cost_usd = _safe_chat_cost_usd(response)
cost_cny = (cost_usd * USD_TO_CNY).quantize(Decimal("0.000001"))
units = {
"tokens_in": int(prompt_tokens),
"tokens_out": int(completion_tokens),
# snapshot 折算系数,便于历史对账(汇率/价格涨跌后仍能还原当时折算逻辑)
"usd_to_cny": float(USD_TO_CNY),
}
with session_scope() as s:
s.add(UsageEvent(
user_id=user_id,
task_id=task_id,
message_id=message_id,
kind="chat",
model_profile=model_profile,
units=units,
cost_cny=cost_cny,
))
if message_id is not None:
s.execute(
update(Message)
.where(Message.message_id == message_id)
.values(
tokens_in=int(prompt_tokens),
tokens_out=int(completion_tokens),
model_profile=model_profile,
)
)
return cost_cny
def record_image_usage(
*,
task_id: UUID,
user_id: UUID,
model_profile: str,
n_images: int,
size: str,
price_cny_per_image: float,
search: bool = False,
extra_units: Optional[Mapping[str, Any]] = None,
) -> Decimal:
"""记一次图像生成:写 usage_events(kind=image)。
单价(CNY/张)由 caller 从配置文件读出后传入,**同步 snapshot 进 units jsonb** ——
将来豆包调价改 YAML 即可,历史记录不动且仍能完整还原当时单价。
`model_profile` 形如 `"doubao.seedream_5"`(family.variant 风格,跟 chat 对齐)。
`extra_units` 给将来扩展(如 quality / style 等额外加价维度)预留。
返回 cost_cny(已落库,可作 SSE / tool 返回串显示用)。
"""
price = Decimal(str(price_cny_per_image))
cost_cny = (price * n_images).quantize(Decimal("0.000001"))
units: dict[str, Any] = {
"n_images": int(n_images),
"size": size,
"search": bool(search),
"price_cny_per_image": float(price_cny_per_image),
}
if extra_units:
units.update(extra_units)
with session_scope() as s:
s.add(UsageEvent(
user_id=user_id,
task_id=task_id,
message_id=None, # image tool 在 tool execute 时调用,message 还未落库
kind="image",
model_profile=model_profile,
units=units,
cost_cny=cost_cny,
))
return cost_cny
def record_video_usage(
*,
task_id: UUID,
user_id: UUID,
model_profile: str,
resolution: str,
ratio: str,
duration_s: int,
fps: int,
width: int,
height: int,
tokens: int,
price_cny_per_mtoken: float,
has_video_input: bool = False,
watermark: bool = False,
extra_units: Optional[Mapping[str, Any]] = None,
) -> Decimal:
"""记一次视频生成:写 usage_events(kind=video)。
成本算法:`cost_cny = tokens / 1_000_000 * price_cny_per_mtoken`。tokens 由 caller
传入(响应里若有官方 usage 字段就用,否则按公式 `(in_dur+out_dur)*W*H*fps/1024`
估算),`price_cny_per_mtoken` 同步 snapshot 进 units jsonb,日后调价不影响历史对账。
`model_profile` 形如 `"doubao.seedance_2_fast"`(family.variant 风格)。
`has_video_input=True` 表示图生/视频编辑路径(单价 22 元/Mtok);False = 文生视频(37 元/Mtok)。
caller 自行根据请求 body 是否含 image_url/video_url 传对应单价 + flag。
**失败任务不要走这里** —— Volcengine 失败不计费,失败的 tool 调用直接返 [Error] 不写 usage。
"""
price = Decimal(str(price_cny_per_mtoken))
tok = Decimal(str(int(tokens)))
cost_cny = (price * tok / Decimal("1000000")).quantize(Decimal("0.000001"))
units: dict[str, Any] = {
"resolution": resolution,
"ratio": ratio,
"duration_s": int(duration_s),
"fps": int(fps),
"width": int(width),
"height": int(height),
"tokens": int(tokens),
"price_cny_per_mtoken": float(price_cny_per_mtoken),
"has_video_input": bool(has_video_input),
"watermark": bool(watermark),
}
if extra_units:
units.update(extra_units)
with session_scope() as s:
s.add(UsageEvent(
user_id=user_id,
task_id=task_id,
message_id=None,
kind="video",
model_profile=model_profile,
units=units,
cost_cny=cost_cny,
))
return cost_cny
def check_daily_quota(*, user_id: UUID, kind: str, limit: int) -> tuple[int, bool]:
"""每账号每日 kind=image/video 调用配额检查。返回 (今日已用次数, 是否超额)。
`limit <= 0` 视为不限,直接返回 (0, False) 不查 DB。
"一天"=服务器本地 00:00 起算(用户直觉日历日;非滑动 24h 窗口)。
失败任务不计 —— `record_*_usage` 只在成功+下载完才落库 → 失败 retry 不烧配额。
跨 task 跨 variant 全口径合计 —— 配额是账号级,与具体调用哪个 variant 无关。
并发 race:同 user 两次调用同时通过 check 会两个都跑成功(off-by-one),
可接受 —— 单 task 单活 run gate 已经挡了同 task 并发;跨 task 罕见,
且这是软上限(非计费 hard cap),日级偶尔多 1 张不影响保护意图。
"""
if limit <= 0:
return 0, False
now_local = datetime.now().astimezone()
today_start = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
with session_scope() as s:
used = s.execute(
select(func.count()).select_from(UsageEvent).where(
UsageEvent.user_id == user_id,
UsageEvent.kind == kind,
UsageEvent.created_at >= today_start,
)
).scalar_one()
used_int = int(used)
return used_int, used_int >= int(limit)