factory/apps/utils/viewsets.py

231 lines
10 KiB
Python
Executable File

import hashlib
import json

from django.core.cache import cache
from django.http import StreamingHttpResponse
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.exceptions import ParseError
from rest_framework.mixins import RetrieveModelMixin
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.viewsets import GenericViewSet

from apps.system.models import DataFilter, Dept
from apps.utils.mixins import (MyLoggingMixin, BulkCreateModelMixin, BulkUpdateModelMixin,
                               BulkDestroyModelMixin, CustomListModelMixin, CustomRetrieveModelMixin)
from apps.utils.permission import ALL_PERMS, RbacPermission, get_user_perms_map
from apps.utils.queryset import get_child_queryset2, get_child_queryset_u
from apps.utils.serializers import ComplexSerializer


class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
    """
    Enhanced GenericViewSet.

    On top of DRF's ``GenericViewSet`` this adds:

    * default permission codes (``perms_map``) built from the router
      ``basename`` and registered in ``ALL_PERMS`` on first instantiation
      of each view class;
    * per-action serializer classes (``<action>_serializer_class``);
    * short-lived request de-duplication / response replay keyed on method,
      path, body, query string and user id (``cache_seconds``, overridable per
      action via ``<action>_cache_seconds``);
    * optional department-based row filtering (``data_filter``) driven by
      ``get_user_perms_map``;
    * ``select_related`` / ``prefetch_related`` hooks.
    """
    _initialized = False  # set per class once its perms have been registered
    perms_map = None  # permission code per lower-case HTTP method
    throttle_classes = [UserRateThrottle]
    logging_methods = ['POST', 'PUT', 'PATCH', 'DELETE']
    ordering_fields = '__all__'
    ordering = '-create_time'
    create_serializer_class = None
    update_serializer_class = None
    partial_update_serializer_class = None
    list_serializer_class = None
    retrieve_serializer_class = None
    select_related_fields = []
    prefetch_related_fields = []
    permission_classes = [IsAuthenticated & RbacPermission]
    data_filter = False  # enable data-permission filtering (requires RbacPermission)
    data_filter_field = 'belong_dept'
    hash_k = None  # cache key owned by the current request, if any
    cache_seconds = 5  # default de-dup / replay window, in seconds
    # NOTE(review): aliases the (empty) base select_related_fields list, so this is effectively [].
    filterset_fields = select_related_fields
    _IN_PROGRESS = 'o'  # cache marker: an identical request is still being handled

    class _CachedResponse(Exception):
        """Internal signal raised by ``initial`` to replay a cached payload."""

        def __init__(self, data):
            super().__init__()
            self.data = data

    def __new__(cls, *args, **kwargs):
        """
        On the first instantiation of each view class, build a default
        ``perms_map`` from the router ``basename`` (when none is set) and add
        its permission codes to the global ``ALL_PERMS`` list.
        """
        # Look only at this class's own flag: an inherited True from a parent
        # would otherwise stop a subclass from registering its own perms.
        if not cls.__dict__.get('_initialized', False):
            if cls.perms_map is None:
                basename = kwargs["basename"]
                cls.perms_map = {
                    'get': '*',
                    'post': '{}.create'.format(basename),
                    'put': '{}.update'.format(basename),
                    'patch': '{}.update'.format(basename),
                    'delete': '{}.delete'.format(basename),
                }
            for perm in cls.perms_map.values():
                if perm not in ALL_PERMS and perm != '*':
                    ALL_PERMS.append(perm)
            cls._initialized = True
        return super().__new__(cls)

    def finalize_response(self, request, response, *args, **kwargs):
        """
        Store a successful payload for replay when this request owns a cache
        slot (``self.hash_k``), release the slot otherwise, then delegate to
        the parent implementation.
        """
        # Streaming responses are returned as-is; the parent finalize_response
        # is not called for them.
        if isinstance(response, StreamingHttpResponse):
            return response
        if self.hash_k and self.cache_seconds:
            data = getattr(response, 'data', None)  # plain HttpResponse has no .data
            if data is not None and 200 <= response.status_code < 300:
                cache.set(self.hash_k, data, timeout=self.cache_seconds)
            else:
                # Failed or body-less request: free the slot so a retry is not
                # rejected or answered with a stale error.
                cache.delete(self.hash_k)
        return super().finalize_response(request, response, *args, **kwargs)

    def initial(self, request, *args, **kwargs):
        """
        Run DRF's standard checks, then apply request de-duplication.

        For identical requests inside the cache window:

        * the first one claims the slot and is processed normally;
        * while it is still running, duplicates get ``ParseError``;
        * once it has finished successfully, duplicates are answered with its
          stored payload without running the handler again.

        Raises:
            ParseError: an identical request is still in progress.
        """
        super().initial(request, *args, **kwargs)
        cache_seconds = getattr(
            self, f"{self.action}_cache_seconds", getattr(self, 'cache_seconds', 0))
        if not cache_seconds:
            return
        self.cache_seconds = cache_seconds
        key = self._request_cache_key(request)
        if key is None:
            return
        # cache.add only stores when the key is absent, so only one of several
        # concurrent identical requests claims the slot.
        if cache.add(key, self._IN_PROGRESS, self.cache_seconds):
            self.hash_k = key
            return
        cached = cache.get(key, None)
        if cached == self._IN_PROGRESS:  # the first request is still running
            raise ParseError(f'请求忽略,请{self.cache_seconds}秒后重试')
        if cached:
            # DRF's dispatch ignores initial()'s return value, so short-circuit
            # with a signal that handle_exception turns into a Response.
            raise self._CachedResponse(cached)

    def _request_cache_key(self, request):
        """
        Build a process-independent cache key for *request*.

        Returns None when the request cannot be JSON-encoded (e.g. file
        uploads); such requests are not de-duplicated.
        """
        rdata = {
            'request_method': request.method,
            'request_path': request.path,
            'request_data': request.data,
            'request_query': request.query_params.dict(),
            'request_userid': request.user.id,
        }
        try:
            payload = json.dumps(rdata, sort_keys=True)
        except (TypeError, ValueError):
            return None
        # The builtin hash() of a str is salted per process (PYTHONHASHSEED),
        # so workers sharing one cache would disagree; use a stable digest.
        return 'viewset_req_' + hashlib.sha256(payload.encode('utf-8')).hexdigest()

    def handle_exception(self, exc):
        """Turn the internal cache-replay signal into a normal Response."""
        if isinstance(exc, self._CachedResponse):
            return Response(exc.data)
        return super().handle_exception(exc)

    def get_serializer_class(self):
        """Prefer ``<action>_serializer_class`` when it is set, else DRF's default."""
        action_serializer_name = f"{self.action}_serializer_class"
        action_serializer_class = getattr(self, action_serializer_name, None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    def get_queryset_custom(self, queryset):
        """
        Overridable hook for extra queryset filtering.

        Standard CRUD actions get *queryset* unchanged. Other actions are passed
        to ``get_queryset_<action>(queryset)`` when such a method exists.
        """
        if self.action in ["list", "retrieve", "create", "update", "partial_update", "destroy"]:
            return queryset
        elif hasattr(self, f'get_queryset_{self.action}'):
            return getattr(self, f'get_queryset_{self.action}')(queryset)
        return queryset

    def filter_queryset(self, queryset):
        """Apply DRF filter backends, tree-root narrowing and related-field loading."""
        queryset = super().filter_queryset(queryset)
        # For with_children queries without an explicit parent, return only
        # root nodes (parent=None).
        if (self.request.query_params.get("with_children", "no") in ["yes", "count"]
                and self.request.query_params.get("parent", None) is None):
            queryset = queryset.filter(parent=None)
        # Performance: load declared relations in bulk.
        if self.select_related_fields:
            queryset = queryset.select_related(*self.select_related_fields)
        if self.prefetch_related_fields:
            queryset = queryset.prefetch_related(*self.prefetch_related_fields)
        return queryset

    def get_queryset(self):
        """
        Return the base queryset after ``get_queryset_custom``.

        When ``data_filter`` is enabled, the result is limited to the rows the
        current user may access for this HTTP method. Each
        ``(dept_id, data_range)`` grant is applied to the *unfiltered* queryset,
        and the per-grant results are OR-ed together.
        """
        queryset = super().get_queryset()
        queryset = self.get_queryset_custom(queryset)
        if not self.data_filter:
            return queryset
        user = self.request.user
        if user.is_superuser:
            return queryset
        user_perms_map = get_user_perms_map(user)
        if not isinstance(user_perms_map, dict):
            return queryset
        perms_map = self.perms_map or {}
        if '*' in perms_map:
            return queryset
        action_str = perms_map.get(self.request.method.lower(), None)
        if action_str == '*':
            return queryset
        if action_str not in user_perms_map:
            return queryset.none()
        new_queryset = queryset.none()
        for dept_id, data_range in user_perms_map[action_str].items():
            if data_range == DataFilter.ALL:
                return queryset
            new_queryset = new_queryset | self._filter_by_data_range(
                queryset, dept_id, data_range, user)
        return new_queryset

    def _filter_by_data_range(self, queryset, dept_id, data_range, user):
        """Return the rows of *queryset* visible under a single (dept, data range) grant."""
        if data_range == DataFilter.MYSELF:
            return queryset.filter(create_by=user)
        dept_filters = {
            DataFilter.SAMELEVE_AND_BELOW: self.filter_s_a_b,
            DataFilter.THISLEVEL_AND_BELOW: self.filter_t_a_b,
            DataFilter.THISLEVEL: self.filter_t,
        }
        filter_func = dept_filters.get(data_range)
        if filter_func is None:
            # Unrecognised data range: grant nothing rather than every row.
            return queryset.none()
        dept = Dept.objects.filter(id=dept_id).first()
        if dept is None:  # grant refers to a department that no longer exists
            return queryset.none()
        return filter_func(queryset, dept)

    def _has_data_filter_field(self, queryset):
        """True if the model exposes the first hop of ``data_filter_field``."""
        return hasattr(queryset.model, self.data_filter_field.split('__', 1)[0])

    def filter_s_a_b(self, queryset, dept):
        """
        Filter to "same level and below": rows whose ``data_filter_field`` is in
        ``get_child_queryset2`` of dept's parent (or of dept itself when it has
        no parent). Overridable.

        Falls back to rows created by the current user when the model lacks
        ``data_filter_field``.
        """
        if self._has_data_filter_field(queryset):
            belong_depts = get_child_queryset2(dept.parent if dept.parent else dept)
            return queryset.filter(**{self.data_filter_field + '__in': belong_depts})
        return queryset.filter(create_by=self.request.user)

    def filter_t_a_b(self, queryset, dept):
        """
        Filter to "this level and below": rows whose ``data_filter_field`` is in
        ``get_child_queryset2(dept)``. Overridable.

        Falls back to rows created by the current user when the model lacks
        ``data_filter_field``.
        """
        if self._has_data_filter_field(queryset):
            belong_depts = get_child_queryset2(dept)
            return queryset.filter(**{self.data_filter_field + '__in': belong_depts})
        return queryset.filter(create_by=self.request.user)

    def filter_t(self, queryset, dept):
        """
        Filter to "this level only": rows whose ``data_filter_field`` equals
        *dept*. Overridable.

        Falls back to rows created by the current user when the model lacks
        ``data_filter_field``.
        """
        if self._has_data_filter_field(queryset):
            # Keyword expansion is required: Django's filter() rejects a
            # positional dict ("Cannot parse keyword query as dict").
            return queryset.filter(**{self.data_filter_field: dept})
        return queryset.filter(create_by=self.request.user)
class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListModelMixin,
                         CustomRetrieveModelMixin, BulkDestroyModelMixin, CustomGenericViewSet):
    """
    Enhanced ModelViewSet: bulk create/update/destroy, custom list/retrieve,
    plus the ``cquery`` complex-query endpoint.
    """

    @swagger_auto_schema(request_body=ComplexSerializer, responses={200: {}})
    @action(methods=['post'], detail=False, perms_map={'post': '*'})
    def cquery(self, request):
        """复杂查询
        复杂查询
        """
        # (The docstring above is rendered by drf_yasg as the operation
        # summary/description, so it is kept verbatim.)
        #
        # Complex query: ``querys`` is a list of condition groups. The
        # conditions inside one group are AND-ed, and the groups are OR-ed.
        # Each condition has ``field``, ``compare`` and ``value``:
        #   compare == ''  -> filter(field=value)
        #   compare == '!' -> exclude(field=value)
        #   otherwise      -> filter(field__<compare>=value)
        # With no groups the result is empty.
        #
        # NOTE(review): field/compare come straight from the client, so any
        # model field or relation path (e.g. create_by__password__startswith)
        # can be filtered on. Consider an allow-list of queryable fields.
        sr = ComplexSerializer(data=request.data)
        sr.is_valid(raise_exception=True)
        vdata = sr.validated_data
        queryset = self.filter_queryset(self.get_queryset())
        new_qs = queryset.none()
        try:
            for group in vdata.get('querys', []):
                group_qs = queryset
                for cond in group:
                    compare = cond['compare']
                    if compare in ('!', ''):
                        lookup = cond['field']
                    else:
                        lookup = cond['field'] + '__' + compare
                    if compare == '!':  # exclusion condition
                        group_qs = group_qs.exclude(**{lookup: cond['value']})
                    else:
                        group_qs = group_qs.filter(**{lookup: cond['value']})
                new_qs = new_qs | group_qs
        except Exception as e:
            # Bad field names/values from the client become a 400 instead of a
            # 500; keep the original error as the cause for debugging.
            raise ParseError(str(e)) from e
        page = self.paginate_queryset(new_qs)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(new_qs, many=True)
        return Response(serializer.data)