fix(ops): validate log and backup paths

This commit is contained in:
caoqianming 2026-03-20 09:00:15 +08:00
parent e153ee6d54
commit 77616b37d2
1 changed files with 32 additions and 9 deletions

View File

@ -5,6 +5,7 @@ from rest_framework.views import APIView
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from django.conf import settings from django.conf import settings
import os import os
from pathlib import Path
from apps.ops.serializers import DbbackupDeleteSerializer, MemDiskSerializer, CpuSerializer, DrfRequestLogSerializer, TlogSerializer, TextListSerializer from apps.ops.serializers import DbbackupDeleteSerializer, MemDiskSerializer, CpuSerializer, DrfRequestLogSerializer, TlogSerializer, TextListSerializer
from rest_framework.exceptions import NotFound from rest_framework.exceptions import NotFound
from rest_framework.mixins import ListModelMixin from rest_framework.mixins import ListModelMixin
@ -22,6 +23,21 @@ from server.settings import BACKUP_PATH
# Create your views here. # Create your views here.
def _resolve_path_in_dir(base_dir, user_input):
base_path = Path(base_dir).resolve()
input_path = Path(user_input)
target_path = input_path.resolve() if input_path.is_absolute() else (base_path / input_path).resolve()
try:
target_path.relative_to(base_path)
except ValueError as exc:
raise NotFound(**LOG_NOT_FONED) from exc
return target_path
def _get_backup_database_dir():
return Path(BACKUP_PATH, 'database').resolve()
def index(request): def index(request):
return render(request, 'ops/index.html') return render(request, 'ops/index.html')
@ -122,7 +138,7 @@ class DiskView(APIView):
def get_file_list(file_path): def get_file_list(file_path):
dir_list = os.listdir(file_path) dir_list = os.listdir(file_path)
if not dir_list: if not dir_list:
return return []
else: else:
# 注意这里使用lambda表达式将文件按照最后修改时间顺序升序排列 # 注意这里使用lambda表达式将文件按照最后修改时间顺序升序排列
# os.path.getmtime() 函数是获取文件最后修改时间 # os.path.getmtime() 函数是获取文件最后修改时间
@ -147,7 +163,7 @@ class LogView(APIView):
break break
filepath = os.path.join(settings.LOG_PATH, file) filepath = os.path.join(settings.LOG_PATH, file)
if name: if name:
if name in filepath: if name in file:
fsize = os.path.getsize(filepath) fsize = os.path.getsize(filepath)
if fsize: if fsize:
logs.append({ logs.append({
@ -172,7 +188,10 @@ class LogDetailView(APIView):
@swagger_auto_schema(operation_summary="查看日志详情", responses=None) @swagger_auto_schema(operation_summary="查看日志详情", responses=None)
def get(self, request, name): def get(self, request, name):
try: try:
with open(os.path.join(settings.LOG_PATH, name)) as f: filepath = _resolve_path_in_dir(settings.LOG_PATH, name)
if not filepath.is_file():
raise NotFound(**LOG_NOT_FONED)
with filepath.open(encoding='utf-8') as f:
data = f.read() data = f.read()
return Response(data) return Response(data)
except Exception: except Exception:
@ -184,8 +203,10 @@ class DbBackupDeleteView(APIView):
@swagger_auto_schema(operation_summary="删除备份", responses={204: None}) @swagger_auto_schema(operation_summary="删除备份", responses={204: None})
def delete(self, request, filepath): def delete(self, request, filepath):
if BACKUP_PATH in filepath: target_path = _resolve_path_in_dir(_get_backup_database_dir(), filepath)
os.remove(filepath) if not target_path.is_file():
raise NotFound(**LOG_NOT_FONED)
target_path.unlink()
return Response() return Response()
@ -196,21 +217,23 @@ class DbBackupView(APIView):
def post(self, request): def post(self, request):
filepaths = request.data.get('filepaths', []) filepaths = request.data.get('filepaths', [])
for i in filepaths: for i in filepaths:
if BACKUP_PATH in i: target_path = _resolve_path_in_dir(_get_backup_database_dir(), i)
os.remove(i) if not target_path.is_file():
raise NotFound(**LOG_NOT_FONED)
target_path.unlink()
return Response() return Response()
@swagger_auto_schema(operation_summary="查看最近的备份列表", responses={200: TextListSerializer(many=True)}, request_body=None) @swagger_auto_schema(operation_summary="查看最近的备份列表", responses={200: TextListSerializer(many=True)}, request_body=None)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
items = [] items = []
name = request.GET.get('name', None) name = request.GET.get('name', None)
backpath = settings.BACKUP_PATH + '/database' backpath = str(_get_backup_database_dir())
for file in get_file_list(backpath): for file in get_file_list(backpath):
if len(items) > 50: if len(items) > 50:
break break
filepath = os.path.join(backpath, file) filepath = os.path.join(backpath, file)
if name: if name:
if name in filepath: if name in file:
fsize = os.path.getsize(filepath) fsize = os.path.getsize(filepath)
if fsize: if fsize:
items.append({ items.append({