from django.core.cache import cache
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.mixins import (CreateModelMixin, ListModelMixin,
                                   RetrieveModelMixin, UpdateModelMixin,
                                   DestroyModelMixin)
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.viewsets import GenericViewSet

from apps.system.models import DataFilter, Dept
from apps.utils.errors import PKS_ERROR
from apps.utils.mixins import MyLoggingMixin
from apps.utils.permission import ALL_PERMS, RbacPermission, get_user_perms_map
from apps.utils.queryset import get_child_queryset2
from apps.utils.serializers import PkSerializer


def _register_perms(perms_map):
    """Append every concrete permission code in *perms_map* to ALL_PERMS.

    The wildcard '*' is skipped, and codes already present are not duplicated.
    """
    for code in perms_map.values():
        if code != '*' and code not in ALL_PERMS:
            ALL_PERMS.append(code)


class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
    """
    Enhanced GenericViewSet.

    Adds per-action serializer classes (``<action>_serializer_class``),
    select_related / prefetch_related shortcuts, registration of the view's
    permission codes into ALL_PERMS, and optional department-based data
    filtering (``data_filter``).
    """
    perms_map = {}  # permission codes, keyed by lower-case HTTP method
    throttle_classes = [UserRateThrottle]
    logging_methods = ['POST', 'PUT', 'PATCH', 'DELETE']
    ordering_fields = '__all__'
    ordering = '-create_time'
    filterset_fields = []
    create_serializer_class = None
    update_serializer_class = None
    partial_update_serializer_class = None
    list_serializer_class = None
    retrieve_serializer_class = None
    select_related_fields = []
    prefetch_related_fields = []
    permission_classes = [IsAuthenticated & RbacPermission]
    data_filter = False  # enable data-permission filtering (requires RbacPermission)

    def get_serializer_class(self):
        """Return ``<action>_serializer_class`` when set, else the default serializer class."""
        action_serializer_class = getattr(self, f"{self.action}_serializer_class", None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.perms_map:
            _register_perms(self.perms_map)

    def get_queryset(self):
        """Return the base queryset with related-field optimisations and,
        when ``data_filter`` is enabled, restricted to the rows the current
        user may access."""
        queryset = super().get_queryset()
        if self.select_related_fields:
            queryset = queryset.select_related(*self.select_related_fields)
        if self.prefetch_related_fields:
            queryset = queryset.prefetch_related(*self.prefetch_related_fields)
        if self.data_filter:
            queryset = self._apply_data_filter(queryset)
        return queryset

    def _apply_data_filter(self, queryset):
        """Restrict *queryset* according to the user's data-range grants for
        the permission code mapped to the current HTTP method."""
        user = self.request.user
        if user.is_superuser:
            return queryset
        # Only models carrying a department can be filtered by department.
        if not hasattr(queryset.model, 'belong_dept'):
            return queryset
        # NOTE(review): assumes the permission cache is written under
        # f'perms_{user.id}'; the old `'perms_' + user.id` crashed for int ids.
        user_perms_map = cache.get(f'perms_{user.id}', None)
        if user_perms_map is None:
            user_perms_map = get_user_perms_map(user)
        if not isinstance(user_perms_map, dict):
            return queryset

        # The old code read `self.view.perms_map`, which does not exist on a
        # viewset; hasattr() hid the error, so filtering never ran.
        perms_map = self.perms_map
        action_str = perms_map.get(self.request.method.lower(), None)
        if '*' in perms_map or action_str == '*':
            return queryset
        if action_str not in user_perms_map:
            return queryset.none()

        # Union of what each (dept, data_range) grant allows. Every grant is
        # applied to the ORIGINAL queryset, not to a previously narrowed one.
        # Presumably user_perms_map[action_str] maps dept id -> data range.
        visible = queryset.none()
        for dept_id, data_range in user_perms_map[action_str].items():
            if data_range == DataFilter.ALL:
                return queryset
            scoped = self._filter_by_data_range(queryset, dept_id, data_range, user)
            if scoped is not None:
                visible = visible | scoped
        return visible

    @staticmethod
    def _filter_by_data_range(queryset, dept_id, data_range, user):
        """Return the part of *queryset* visible under a single grant, or
        None when the grant yields nothing (missing dept / unknown range)."""
        if data_range == DataFilter.MYSELF:
            return queryset.filter(create_by=user)
        dept = Dept.objects.filter(id=dept_id).first()
        if dept is None:
            # The granting department no longer exists; skip instead of a 500.
            return None
        if data_range == DataFilter.SAMELEVE_AND_BELOW:
            root = dept.parent if dept.parent else dept
            return queryset.filter(belong_dept__in=get_child_queryset2(root))
        if data_range == DataFilter.THISLEVEL_AND_BELOW:
            return queryset.filter(belong_dept__in=get_child_queryset2(dept))
        if data_range == DataFilter.THISLEVEL:
            return queryset.filter(belong_dept=dept)
        # Unrecognised range: grant nothing rather than everything.
        return None


class CustomModelViewSet(CreateModelMixin, UpdateModelMixin, ListModelMixin,
                         RetrieveModelMixin, DestroyModelMixin, CustomGenericViewSet):
    """
    Enhanced ModelViewSet: full CRUD, batch delete, and default permission
    codes derived from the router basename when ``perms_map`` is empty.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Provide default permission codes when the subclass defines none.
        if not self.perms_map:
            basename = self.basename
            self.perms_map = {
                'get': '*',
                'post': f'{basename}:create',
                'put': f'{basename}:update',
                'patch': f'{basename}:update',
                'delete': f'{basename}:delete',
            }
            _register_perms(self.perms_map)

    @action(methods=['post'], detail=False, serializer_class=PkSerializer)
    def deletes(self, request, *args, **kwargs):
        """
        Batch delete.

        Deletes the objects of this view's queryset whose id is in
        ``request.data['pks']``, passing ``update_by=request.user`` to the
        queryset's ``delete``. Responds 204.

        Raises:
            ValidationError: built from PKS_ERROR when ``pks`` is missing or empty.
        """
        pks = request.data.get('pks', None)
        if not pks:
            raise ValidationError(**PKS_ERROR)
        self.get_queryset().filter(id__in=pks).delete(update_by=request.user)
        return Response(status=204)

    def perform_destroy(self, instance):
        """Delete *instance*, recording the requesting user via ``update_by``."""
        instance.delete(update_by=self.request.user)