131 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
			
		
		
	
	
			131 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
| 
 | |
| from django.core.cache import cache
 | |
| from rest_framework.decorators import action
 | |
| from rest_framework.exceptions import ValidationError
 | |
| from rest_framework.mixins import (CreateModelMixin, ListModelMixin,
 | |
|                                    RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin)
 | |
| from rest_framework.permissions import IsAuthenticated
 | |
| from rest_framework.response import Response
 | |
| from rest_framework.viewsets import GenericViewSet
 | |
| 
 | |
| from apps.system.models import DataFilter, Dept
 | |
| from apps.utils.errors import PKS_ERROR
 | |
| from apps.utils.mixins import MyLoggingMixin
 | |
| from apps.utils.permission import ALL_PERMS, RbacPermission, get_user_perms_map
 | |
| from apps.utils.queryset import get_child_queryset2
 | |
| from apps.utils.serializers import PkSerializer
 | |
| from rest_framework.throttling import UserRateThrottle
 | |
| 
 | |
| 
 | |
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
    """
    Enhanced GenericViewSet.

    Adds on top of DRF's GenericViewSet:
    - per-action serializers via ``<action>_serializer_class`` attributes;
    - declarative ``select_related`` / ``prefetch_related`` on the queryset;
    - registration of this view's permission codes into ``ALL_PERMS``;
    - optional department-based data-scope filtering (``data_filter``).
    """
    perms_map = {}  # permission codes keyed by lower-case HTTP method; '*' is a wildcard
    throttle_classes = [UserRateThrottle]
    logging_methods = ['POST', 'PUT', 'PATCH', 'DELETE']  # presumably consumed by MyLoggingMixin — confirm
    ordering_fields = '__all__'
    ordering = '-create_time'
    filterset_fields = []
    # Optional per-action serializers, looked up by get_serializer_class().
    create_serializer_class = None
    update_serializer_class = None
    partial_update_serializer_class = None
    list_serializer_class = None
    retrieve_serializer_class = None
    select_related_fields = []  # passed to queryset.select_related()
    prefetch_related_fields = []  # passed to queryset.prefetch_related()
    permission_classes = [IsAuthenticated & RbacPermission]
    data_filter = False  # whether data-permission filtering is enabled (requires RbacPermission)

    def get_serializer_class(self):
        """Return ``<action>_serializer_class`` when set, else the default serializer class."""
        action_serializer_class = getattr(self, f"{self.action}_serializer_class", None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register this view's permission codes in the shared ALL_PERMS list.
        for perm in self.perms_map.values():
            if perm != '*' and perm not in ALL_PERMS:
                ALL_PERMS.append(perm)

    def get_queryset(self):
        """
        Return the base queryset with the configured eager loading applied and,
        when ``data_filter`` is enabled, restricted to the current user's data scope.
        """
        queryset = super().get_queryset()
        if self.select_related_fields:
            queryset = queryset.select_related(*self.select_related_fields)
        if self.prefetch_related_fields:
            queryset = queryset.prefetch_related(*self.prefetch_related_fields)
        if self.data_filter:
            queryset = self._filter_by_data_scope(queryset)
        return queryset

    def _filter_by_data_scope(self, queryset):
        """
        Restrict ``queryset`` to the rows the current user may access for the
        permission code mapped to this request's HTTP method.

        Superusers, models without a ``belong_dept`` attribute, wildcard ('*')
        permission entries, and non-dict permission maps leave the queryset
        unfiltered. A code the user does not hold yields an empty queryset.
        """
        user = self.request.user
        if user.is_superuser or not hasattr(queryset.model, 'belong_dept'):
            return queryset
        # f-string formatting works for both integer and string primary keys.
        user_perms_map = cache.get(f'perms_{user.id}', None)
        if user_perms_map is None:
            user_perms_map = get_user_perms_map(user)
        if not isinstance(user_perms_map, dict):
            return queryset
        perms_map = self.perms_map
        action_str = perms_map.get(self.request._request.method.lower(), None)
        if '*' in perms_map or action_str == '*':
            return queryset
        if action_str not in user_perms_map:
            return queryset.none()
        scoped = queryset.none()
        for dept_id, data_range in user_perms_map[action_str].items():
            if data_range == DataFilter.ALL:
                return queryset
            # Each grant is filtered from the *base* queryset and OR-ed in, so
            # one grant never narrows another.
            part = self._scope_for_range(queryset, user, dept_id, data_range)
            if part is not None:
                scoped = scoped | part
        return scoped

    @staticmethod
    def _scope_for_range(queryset, user, dept_id, data_range):
        """
        Return ``queryset`` filtered for a single (department id, data range)
        grant, or ``None`` when the grant cannot be applied. This happens when
        the department no longer exists or the range is unrecognised.
        """
        if data_range == DataFilter.MYSELF:
            return queryset.filter(create_by=user)
        dept = Dept.objects.filter(id=dept_id).first()
        if dept is None:
            # Stale department id in the cached grants: grant nothing for it.
            return None
        if data_range == DataFilter.SAMELEVE_AND_BELOW:
            # Uses the parent's subtree (or the dept's own subtree at the root).
            # Assumes get_child_queryset2 returns a dept plus its descendants — confirm.
            return queryset.filter(belong_dept__in=get_child_queryset2(dept.parent or dept))
        if data_range == DataFilter.THISLEVEL_AND_BELOW:
            return queryset.filter(belong_dept__in=get_child_queryset2(dept))
        if data_range == DataFilter.THISLEVEL:
            return queryset.filter(belong_dept=dept)
        # Unrecognised range: fail closed instead of granting the whole queryset.
        return None
 | |
| 
 | |
| 
 | |
class CustomModelViewSet(CreateModelMixin, UpdateModelMixin, ListModelMixin,
                         RetrieveModelMixin, DestroyModelMixin, CustomGenericViewSet):
    """
    Enhanced ModelViewSet.

    Full CRUD on top of CustomGenericViewSet. It also provides a default
    permission map derived from the view's ``basename`` and a ``deletes``
    batch-delete action.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Add default permission codes when the subclass declares none:
        # 'get' maps to the '*' wildcard, and writes map to '<basename>.<operation>'.
        # ``self.basename`` is normally supplied by the DRF router.
        if not self.perms_map:
            basename = self.basename
            self.perms_map = {
                'get': '*',
                'post': f'{basename}.create',
                'put': f'{basename}.update',
                'patch': f'{basename}.update',
                'delete': f'{basename}.delete',
            }
        # Register the (possibly defaulted) codes in the shared ALL_PERMS list.
        for perm in self.perms_map.values():
            if perm != '*' and perm not in ALL_PERMS:
                ALL_PERMS.append(perm)

    @action(methods=['post'], detail=False, serializer_class=PkSerializer)
    def deletes(self, request, *args, **kwargs):
        """
        Batch delete.

        Expects a ``pks`` list in the request body. Deletes the matching rows
        from ``get_queryset()``, so any data-scope filtering applies, and
        responds with HTTP 204.

        NOTE(review): this is a POST action, so the data-scope filter looks up
        ``perms_map['post']`` (the create code), not the delete code. Confirm
        that this is intended.

        Raises:
            ValidationError: built from ``PKS_ERROR`` when ``pks`` is missing,
                empty, or not a list.
        """
        request_data = request.data
        if hasattr(request_data, 'getlist'):
            # Form/multipart bodies (QueryDict): .get() would return only the last value.
            pks = request_data.getlist('pks')
        elif isinstance(request_data, dict):
            pks = request_data.get('pks', None)
        else:
            pks = None
        # Reject bare strings and other scalars: ``id__in="12"`` would be
        # iterated character by character and delete ids '1' and '2'.
        if not pks or not isinstance(pks, (list, tuple)):
            raise ValidationError(**PKS_ERROR)
        self.get_queryset().filter(id__in=pks).delete()
        return Response(status=204)
 |