78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
"""用量记账(0006):一次产生成本的调用 = 一行 usage_events + 双写 messages 列。
|
|
|
|
chat 类型的入口由 loop.py 在 assistant message 入库后调用;未来的媒体工具
|
|
(image/video/audio)在 tool execute 后由 loop 顺手记账。
|
|
|
|
成本计算依赖 litellm 的 cost map(litellm.cost_calculator.completion_cost)。
|
|
未知 model 或 map 缺失时 cost=0(不阻塞主流程),emit warn 给 sink。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from decimal import Decimal
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import update
|
|
|
|
from .engine import session_scope
|
|
from .models import Message, UsageEvent
|
|
|
|
|
|
def _safe_chat_cost(response: Any) -> Decimal:
|
|
"""litellm.completion_cost(response) 包一层:任何异常都吞掉返 0。
|
|
|
|
未知 model / cost map 没收录 / response 结构变都不影响主流程 —— usage_events
|
|
仍写入,只是 cost_usd=0,后续人工补算 OK。
|
|
"""
|
|
try:
|
|
from litellm import completion_cost # type: ignore[import-not-found]
|
|
cost = completion_cost(completion_response=response)
|
|
if cost is None:
|
|
return Decimal("0")
|
|
return Decimal(str(cost))
|
|
except Exception:
|
|
return Decimal("0")
|
|
|
|
|
|
def record_chat_usage(
|
|
*,
|
|
task_id: UUID,
|
|
user_id: UUID,
|
|
message_id: Optional[UUID],
|
|
model_profile: str,
|
|
prompt_tokens: int,
|
|
completion_tokens: int,
|
|
response: Any = None,
|
|
) -> Decimal:
|
|
"""记一次 chat 调用:写 usage_events 行 + 回填 messages.model_profile/tokens_in/out。
|
|
|
|
`message_id` 来自 `Session.append` 的返回值;若为 None(系统消息 / 旧路径未拿到)
|
|
则 usage_events 仍写但 message_id=NULL,messages 列不回填。
|
|
`model_profile` 形如 `"deepseek_v4.pro"`(family.variant)。
|
|
返回算出的 cost_usd(已落库),调用方可用作 SSE 显示。
|
|
"""
|
|
cost = _safe_chat_cost(response)
|
|
units = {"tokens_in": int(prompt_tokens), "tokens_out": int(completion_tokens)}
|
|
|
|
with session_scope() as s:
|
|
s.add(UsageEvent(
|
|
user_id=user_id,
|
|
task_id=task_id,
|
|
message_id=message_id,
|
|
kind="chat",
|
|
model_profile=model_profile,
|
|
units=units,
|
|
cost_usd=cost,
|
|
))
|
|
if message_id is not None:
|
|
s.execute(
|
|
update(Message)
|
|
.where(Message.message_id == message_id)
|
|
.values(
|
|
tokens_in=int(prompt_tokens),
|
|
tokens_out=int(completion_tokens),
|
|
model_profile=model_profile,
|
|
)
|
|
)
|
|
return cost
|