factory/apps/utils/viewsets.py

126 lines
5.9 KiB
Python

from rest_framework.viewsets import GenericViewSet
from rest_framework.decorators import action
from apps.system.models import Dept, Post
from apps.utils.errors import PKS_ERROR
from apps.utils.mixins import CustomDestoryModelMixin, MyLoggingMixin
from apps.utils.permission import ALL_PERMS, RbacPermission, get_user_perms_map
from apps.utils.queryset import get_child_queryset2
from apps.utils.serializers import PkSerializer
from rest_framework.response import Response
from rest_framework.mixins import RetrieveModelMixin, ListModelMixin, CreateModelMixin, UpdateModelMixin
from rest_framework.permissions import IsAuthenticated
from rest_framework.exceptions import ValidationError
from django.core.cache import cache
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
    """
    Enhanced GenericViewSet.

    On top of DRF's ``GenericViewSet`` this adds:

    * per-action serializers via ``<action>_serializer_class`` attributes;
    * registration of the view's permission codes in ``ALL_PERMS``;
    * optional ``select_related`` / ``prefetch_related`` on the queryset;
    * optional department-scoped data filtering (``data_filter``).
    """
    perms_map = {}  # permission codes, keyed by lower-case HTTP method or action name
    logging_methods = ['POST', 'PUT', 'PATCH', 'DELETE']
    ordering_fields = '__all__'
    filter_fields = '__all__'
    ordering = '-create_time'
    filterset_fields = '__all__'
    create_serializer_class = None
    update_serializer_class = None
    partial_update_serializer_class = None
    list_serializer_class = None
    retrieve_serializer_class = None
    select_related_fields = []
    prefetch_related_fields = []
    permission_classes = [IsAuthenticated & RbacPermission]
    data_filter = False  # whether data-permission filtering is enabled (requires RbacPermission)

    def get_serializer_class(self):
        """Return ``<action>_serializer_class`` when set, else the default serializer class."""
        action_serializer_name = f"{self.action}_serializer_class"
        action_serializer_class = getattr(self, action_serializer_name, None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register every concrete permission code; '*' is a wildcard, not a code.
        for code in self.perms_map.values():
            if code != '*' and code not in ALL_PERMS:
                ALL_PERMS.append(code)

    def get_queryset(self):
        """
        Return the base queryset with eager loading and optional data filtering.

        ``select_related_fields`` / ``prefetch_related_fields`` are applied first.
        When ``data_filter`` is enabled, non-superusers are restricted on models
        that have a ``belong_dept`` attribute (see ``_filter_by_data_range``).
        """
        queryset = super().get_queryset()
        if self.select_related_fields:
            queryset = queryset.select_related(*self.select_related_fields)
        if self.prefetch_related_fields:
            queryset = queryset.prefetch_related(*self.prefetch_related_fields)
        if not self.data_filter or self.request.user.is_superuser:
            return queryset
        if not hasattr(queryset.model, 'belong_dept'):
            return queryset
        return self._filter_by_data_range(queryset)

    def _filter_by_data_range(self, queryset):
        """Restrict *queryset* to the data ranges the current user holds for this request."""
        user = self.request.user
        # f-string so the key also builds when the pk is an int ('perms_' + int raises).
        user_perms_map = cache.get(f'perms_{user.id}', None)
        if user_perms_map is None:
            user_perms_map = get_user_perms_map(user)
        if not isinstance(user_perms_map, dict):
            return queryset
        # Use this view's own map; a viewset has no ``self.view`` attribute.
        perms_map = self.perms_map
        if '*' in perms_map:
            return queryset
        # Prefer an action-specific code (e.g. 'deletes', which is sent via POST)
        # and fall back to the HTTP method.
        action_str = perms_map.get(
            self.action, perms_map.get(self.request.method.lower(), None))
        if action_str == '*':
            return queryset
        if action_str not in user_perms_map:
            return queryset.none()
        # Presumably {dept_id: data_range}, per the loop shape; verify against
        # get_user_perms_map.
        result = queryset.none()
        for dept_id, data_range in user_perms_map[action_str].items():
            if data_range == Post.POST_DATA_ALL:
                return queryset
            # Each grant filters the *unfiltered* queryset; grants are ORed together.
            result = result | self._queryset_for_grant(queryset, dept_id, data_range, user)
        return result

    @staticmethod
    def _queryset_for_grant(queryset, dept_id, data_range, user):
        """Return the part of *queryset* visible under one (dept_id, data_range) grant."""
        dept_ranges = (
            Post.POST_DATA_SAMELEVE_AND_BELOW,
            Post.POST_DATA_THISLEVEL_AND_BELOW,
            Post.POST_DATA_THISLEVEL,
        )
        if data_range not in dept_ranges:
            # Any other range is limited to the user's own records. The original
            # re-tested POST_DATA_THISLEVEL here, which made this branch unreachable
            # and left such grants unfiltered (i.e. full access).
            return queryset.filter(create_by=user)
        dept = Dept.objects.filter(id=dept_id).first()
        if dept is None:
            # Stale grant for a deleted dept: grants nothing instead of raising.
            return queryset.none()
        if data_range == Post.POST_DATA_THISLEVEL:
            return queryset.filter(belong_dept=dept)
        if data_range == Post.POST_DATA_SAMELEVE_AND_BELOW and dept.parent:
            # "Same level and below" is rooted at the parent dept.
            dept = dept.parent
        # get_child_queryset2 presumably yields the dept and its descendants.
        return queryset.filter(belong_dept__in=get_child_queryset2(dept))
class CustomModelViewSet(CreateModelMixin, UpdateModelMixin, ListModelMixin,
                         RetrieveModelMixin, CustomDestoryModelMixin,
                         CustomGenericViewSet):
    """
    Enhanced ModelViewSet.

    Combines DRF's create/update/list/retrieve mixins with the project's
    ``CustomDestoryModelMixin`` and adds a bulk ``deletes`` action. When a
    subclass declares no ``perms_map``, default permission codes are derived
    from ``self.basename``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Add default permission codes when the subclass declares none.
        if not self.perms_map:
            basename = self.basename
            self.perms_map = {
                'get': '*',
                'post': f'{basename}.create',
                'put': f'{basename}.update',
                'patch': f'{basename}.update',
                'delete': f'{basename}.delete',
                'deletes': f'{basename}.delete',
            }
            for code in self.perms_map.values():
                if code != '*' and code not in ALL_PERMS:
                    ALL_PERMS.append(code)

    @action(methods=['post'], detail=False, serializer_class=PkSerializer)
    def deletes(self, request, *args, **kwargs):
        """
        Bulk-delete the objects whose ids are listed in ``request.data['pks']``.

        Only rows reachable through ``get_queryset()`` are affected. The
        queryset's ``delete`` is called with ``update_by=request.user``, so this
        relies on the project's custom queryset accepting that keyword.

        Raises:
            ValidationError: built from ``PKS_ERROR`` when ``pks`` is missing,
                empty, or not a list of ids.
        """
        request_data = request.data
        # A JSON array (or other non-mapping) body has no ``.get``.
        pks = request_data.get('pks', None) if isinstance(request_data, dict) else None
        # A bare string would be iterated character by character by ``id__in``.
        if not pks or not isinstance(pks, (list, tuple)):
            raise ValidationError(**PKS_ERROR)
        self.get_queryset().filter(id__in=pks).delete(update_by=request.user)
        return Response()