93 lines
3.2 KiB
Python
93 lines
3.2 KiB
Python
"""共享:把模型给的图片路径解析 + 读成 base64 data URL。
|
|
|
|
seedream(改图参考)与 look_at_image(看图)共用同一套路径解析 + 校验:
|
|
- 三种路径形态都吃:`figures/x.png`(working_dir 相对)/ `<taskname>/figures/x.png`
|
|
(user_root 相对,= tool 上次 saved: 行形态)/ 绝对路径
|
|
- **强制最终落在 user_root 子树内**(防模型借参考图越界读任意文件)
|
|
- 校验存在 + 图片扩展名 + 大小上限,再 base64 编码成 data URL
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
from pathlib import Path
|
|
from typing import Callable, Optional
|
|
|
|
# 支持的扩展名 → MIME(其余拒绝,避免把非图当 base64 喂进模型)
|
|
REF_MIME = {
|
|
".png": "image/png",
|
|
".jpg": "image/jpeg",
|
|
".jpeg": "image/jpeg",
|
|
".webp": "image/webp",
|
|
".gif": "image/gif",
|
|
}
|
|
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 单图 10MB(ARK 约束)
|
|
|
|
|
|
def resolve_in_root(
|
|
rel: str, working_dir: Path, user_root: Optional[Path]
|
|
) -> Optional[Path]:
|
|
"""三形态解析 + user_root 边界校验。命中返回解析后的绝对 Path,否则 None。"""
|
|
p = Path(rel)
|
|
candidates: list[Path] = []
|
|
if p.is_absolute():
|
|
candidates.append(p)
|
|
else:
|
|
candidates.append(working_dir / rel)
|
|
if user_root is not None:
|
|
candidates.append(user_root / rel)
|
|
root = (user_root or working_dir).resolve()
|
|
for c in candidates:
|
|
try:
|
|
rc = c.resolve()
|
|
except OSError:
|
|
continue
|
|
try:
|
|
rc.relative_to(root) # 越界(.. 逃逸 / 软链外指)直接跳过
|
|
except ValueError:
|
|
continue
|
|
if rc.is_file():
|
|
return rc
|
|
return None
|
|
|
|
|
|
def load_image_as_data_url(
|
|
rel: str,
|
|
*,
|
|
working_dir: Path,
|
|
user_root: Optional[Path],
|
|
display_fn: Callable[[Path], str],
|
|
max_bytes: int = MAX_IMAGE_BYTES,
|
|
) -> tuple[str, str, str]:
|
|
"""返回 (data_url, display_path, error)。
|
|
|
|
error 非空时前两者无意义,caller 直接把 error 当 tool 结果返回(已是 `[Error] ...` 形态)。
|
|
display_fn 传 Tool._display,把解析路径渲成对外相对串(不泄漏部署绝对路径)。
|
|
"""
|
|
resolved = resolve_in_root(rel, working_dir, user_root)
|
|
if resolved is None:
|
|
return "", "", (
|
|
f"[Error] 图片找不到或越界: {rel!r}。请传 task_dir 内已存在图片的相对路径"
|
|
f"(如 'figures/xxx.png',或工具上次返回的 saved 路径)。"
|
|
)
|
|
|
|
ext = resolved.suffix.lower()
|
|
mime = REF_MIME.get(ext)
|
|
if mime is None:
|
|
return "", "", (
|
|
f"[Error] 图片扩展名不支持: {ext or '(无)'}。"
|
|
f"仅支持 {'/'.join(sorted(REF_MIME))}。"
|
|
)
|
|
|
|
try:
|
|
raw = resolved.read_bytes()
|
|
except OSError as e:
|
|
return "", "", f"[Error] 读取图片失败: {type(e).__name__}: {e}"
|
|
if len(raw) > max_bytes:
|
|
mb = len(raw) / 1024 / 1024
|
|
return "", "", (
|
|
f"[Error] 图片 {mb:.1f}MB 超过 {max_bytes // 1024 // 1024}MB 上限。先压缩 / 缩小再传。"
|
|
)
|
|
|
|
b64 = base64.b64encode(raw).decode("ascii")
|
|
return f"data:{mime};base64,{b64}", display_fn(resolved), ""
|