zcbot/tools/image_ref.py

93 lines
3.2 KiB
Python

"""共享:把模型给的图片路径解析 + 读成 base64 data URL。
seedream(改图参考)与 look_at_image(看图)共用同一套路径解析 + 校验:
- 三种路径形态都吃:`figures/x.png`(working_dir 相对)/ `<taskname>/figures/x.png`
(user_root 相对,= tool 上次 saved: 行形态)/ 绝对路径
- **强制最终落在 user_root 子树内**(防模型借参考图越界读任意文件)
- 校验存在 + 图片扩展名 + 大小上限,再 base64 编码成 data URL
"""
from __future__ import annotations
import base64
from pathlib import Path
from typing import Callable, Optional
# 支持的扩展名 → MIME(其余拒绝,避免把非图当 base64 喂进模型)
REF_MIME = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".webp": "image/webp",
".gif": "image/gif",
}
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 单图 10MB(ARK 约束)
def resolve_in_root(
rel: str, working_dir: Path, user_root: Optional[Path]
) -> Optional[Path]:
"""三形态解析 + user_root 边界校验。命中返回解析后的绝对 Path,否则 None。"""
p = Path(rel)
candidates: list[Path] = []
if p.is_absolute():
candidates.append(p)
else:
candidates.append(working_dir / rel)
if user_root is not None:
candidates.append(user_root / rel)
root = (user_root or working_dir).resolve()
for c in candidates:
try:
rc = c.resolve()
except OSError:
continue
try:
rc.relative_to(root) # 越界(.. 逃逸 / 软链外指)直接跳过
except ValueError:
continue
if rc.is_file():
return rc
return None
def load_image_as_data_url(
rel: str,
*,
working_dir: Path,
user_root: Optional[Path],
display_fn: Callable[[Path], str],
max_bytes: int = MAX_IMAGE_BYTES,
) -> tuple[str, str, str]:
"""返回 (data_url, display_path, error)。
error 非空时前两者无意义,caller 直接把 error 当 tool 结果返回(已是 `[Error] ...` 形态)。
display_fn 传 Tool._display,把解析路径渲成对外相对串(不泄漏部署绝对路径)。
"""
resolved = resolve_in_root(rel, working_dir, user_root)
if resolved is None:
return "", "", (
f"[Error] 图片找不到或越界: {rel!r}。请传 task_dir 内已存在图片的相对路径"
f"(如 'figures/xxx.png',或工具上次返回的 saved 路径)。"
)
ext = resolved.suffix.lower()
mime = REF_MIME.get(ext)
if mime is None:
return "", "", (
f"[Error] 图片扩展名不支持: {ext or '(无)'}"
f"仅支持 {'/'.join(sorted(REF_MIME))}"
)
try:
raw = resolved.read_bytes()
except OSError as e:
return "", "", f"[Error] 读取图片失败: {type(e).__name__}: {e}"
if len(raw) > max_bytes:
mb = len(raw) / 1024 / 1024
return "", "", (
f"[Error] 图片 {mb:.1f}MB 超过 {max_bytes // 1024 // 1024}MB 上限。先压缩 / 缩小再传。"
)
b64 = base64.b64encode(raw).decode("ascii")
return f"data:{mime};base64,{b64}", display_fn(resolved), ""