diff --git a/apps/ops/views.py b/apps/ops/views.py index 629e1868..2c8c91e0 100644 --- a/apps/ops/views.py +++ b/apps/ops/views.py @@ -5,6 +5,7 @@ from rest_framework.views import APIView from rest_framework.permissions import IsAuthenticated from django.conf import settings import os +from pathlib import Path from apps.ops.serializers import DbbackupDeleteSerializer, MemDiskSerializer, CpuSerializer, DrfRequestLogSerializer, TlogSerializer, TextListSerializer from rest_framework.exceptions import NotFound from rest_framework.mixins import ListModelMixin @@ -22,6 +23,21 @@ from server.settings import BACKUP_PATH # Create your views here. +def _resolve_path_in_dir(base_dir, user_input): + base_path = Path(base_dir).resolve() + input_path = Path(user_input) + target_path = input_path.resolve() if input_path.is_absolute() else (base_path / input_path).resolve() + try: + target_path.relative_to(base_path) + except ValueError as exc: + raise NotFound(**LOG_NOT_FONED) from exc + return target_path + + +def _get_backup_database_dir(): + return Path(BACKUP_PATH, 'database').resolve() + + def index(request): return render(request, 'ops/index.html') @@ -122,7 +138,7 @@ class DiskView(APIView): def get_file_list(file_path): dir_list = os.listdir(file_path) if not dir_list: - return + return [] else: # 注意,这里使用lambda表达式,将文件按照最后修改时间顺序升序排列 # os.path.getmtime() 函数是获取文件最后修改时间 @@ -147,7 +163,7 @@ class LogView(APIView): break filepath = os.path.join(settings.LOG_PATH, file) if name: - if name in filepath: + if name in file: fsize = os.path.getsize(filepath) if fsize: logs.append({ @@ -172,7 +188,10 @@ class LogDetailView(APIView): @swagger_auto_schema(operation_summary="查看日志详情", responses=None) def get(self, request, name): try: - with open(os.path.join(settings.LOG_PATH, name)) as f: + filepath = _resolve_path_in_dir(settings.LOG_PATH, name) + if not filepath.is_file(): + raise NotFound(**LOG_NOT_FONED) + with filepath.open(encoding='utf-8') as f: data = f.read() return Response(data) except Exception: @@ -184,8 +203,10 @@ class DbBackupDeleteView(APIView): @swagger_auto_schema(operation_summary="删除备份", responses={204: None}) def delete(self, request, filepath): - if BACKUP_PATH in filepath: - os.remove(filepath) + target_path = _resolve_path_in_dir(_get_backup_database_dir(), filepath) + if not target_path.is_file(): + raise NotFound(**LOG_NOT_FONED) + target_path.unlink() return Response() @@ -196,21 +217,23 @@ class DbBackupView(APIView): def post(self, request): filepaths = request.data.get('filepaths', []) for i in filepaths: - if BACKUP_PATH in i: - os.remove(i) + target_path = _resolve_path_in_dir(_get_backup_database_dir(), i) + if not target_path.is_file(): + raise NotFound(**LOG_NOT_FONED) + target_path.unlink() return Response() @swagger_auto_schema(operation_summary="查看最近的备份列表", responses={200: TextListSerializer(many=True)}, request_body=None) def get(self, request, *args, **kwargs): items = [] name = request.GET.get('name', None) - backpath = settings.BACKUP_PATH + '/database' + backpath = str(_get_backup_database_dir()) for file in get_file_list(backpath): if len(items) > 50: break filepath = os.path.join(backpath, file) if name: - if name in filepath: + if name in file: fsize = os.path.getsize(filepath) if fsize: items.append({