import hashlib
import json

from django.core.cache import cache
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
from django.db.utils import NotSupportedError
from django.http import StreamingHttpResponse, Http404
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.exceptions import ParseError
from rest_framework.generics import get_object_or_404
from rest_framework.mixins import RetrieveModelMixin
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.viewsets import GenericViewSet

from apps.system.models import DataFilter, Dept
from apps.utils.mixins import (MyLoggingMixin, BulkCreateModelMixin, BulkUpdateModelMixin,
                               BulkDestroyModelMixin, CustomListModelMixin,
                               CustomRetrieveModelMixin, ComplexQueryMixin)
from apps.utils.permission import ALL_PERMS, RbacPermission, get_user_perms_map
from apps.utils.queryset import get_child_queryset2, get_child_queryset_u
from apps.utils.serializers import ComplexSerializer


class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
    """
    Enhanced GenericViewSet.

    On top of DRF's GenericViewSet this adds:
    - registration of the viewset's permission codes (``perms_map``) into ``ALL_PERMS``;
    - per-action serializer classes (``<action>_serializer_class``);
    - a short-lived response cache that also rejects an identical request while
      the first one is still running (``cache_seconds`` / ``<action>_cache_seconds``);
    - row locking in ``get_object`` for write requests inside a transaction;
    - department-based data-scope filtering (``data_filter``).
    """
    _initialized = False
    perms_map = None  # permission codes keyed by lowercase HTTP method ('*' = no check)
    throttle_classes = [UserRateThrottle]
    logging_methods = ['POST', 'PUT', 'PATCH', 'DELETE']
    ordering_fields = '__all__'
    ordering = '-create_time'
    create_serializer_class = None
    update_serializer_class = None
    partial_update_serializer_class = None
    list_serializer_class = None
    retrieve_serializer_class = None
    select_related_fields = []
    prefetch_related_fields = []
    permission_classes = [IsAuthenticated & RbacPermission]
    data_filter = False  # enable data-scope filtering (requires RbacPermission)
    data_filter_field = 'belong_dept'  # model field (lookup path) holding the owning department
    hash_k = None  # cache key of the current request, set in initial()
    _cache_owner = False  # True when this request claimed the cache slot in initial()
    cache_processing_flag = '__processing__'
    cache_seconds = 5  # default response cache lifetime in seconds
    filterset_fields = select_related_fields

    def __new__(cls, *args, **kwargs):
        """
        On the first instantiation of each concrete class, register its
        permission codes in the global ``ALL_PERMS`` list.

        If the class has no ``perms_map``, one is derived from the router
        ``basename`` keyword argument.
        """
        # Look only at this class's own __dict__: an inherited True from an
        # already-initialised parent must not skip registering this class's codes.
        if not cls.__dict__.get('_initialized', False):
            if cls.perms_map is None:
                basename = kwargs["basename"]
                cls.perms_map = {'get': '*', 'post': '{}.create'.format(basename), 'put': '{}.update'.format(
                    basename), 'patch': '{}.update'.format(basename), 'delete': '{}.delete'.format(basename)}
            for code in cls.perms_map.values():
                if code != '*' and code not in ALL_PERMS:
                    ALL_PERMS.append(code)
            cls._initialized = True
        return super().__new__(cls)

    def finalize_response(self, request, response, *args, **kwargs):
        """
        Store a fresh successful result in the response cache, then defer to DRF.

        Only the request that claimed the cache slot in ``initial`` touches it:
        - a 2xx DRF ``Response`` has its ``data`` cached for ``cache_seconds``;
        - anything else (error responses, or non-DRF responses such as
          streaming/file responses that have no ``.data``) releases the slot.
          This lets an identical retry run instead of being served the error
          body as a 200.
        Requests that were answered from the cache, or rejected as duplicates,
        leave the entry alone.
        """
        if self.hash_k and self.cache_seconds and self._cache_owner:
            if isinstance(response, Response) and 200 <= response.status_code < 300:
                cache.set(self.hash_k, response.data, timeout=self.cache_seconds)
            else:
                cache.delete(self.hash_k)
        return super().finalize_response(request, response, *args, **kwargs)

    def dispatch(self, request, *args, **kwargs):
        """
        Request entry point: same steps as DRF's ``APIView.dispatch``, except
        that the handler is skipped when ``initial`` found a cached payload
        (``self._cached_response``).
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers
        try:
            self.initial(request, *args, **kwargs)
            if hasattr(self, '_cached_response'):
                response = self._cached_response
            else:
                method = request.method.lower()
                if method in self.http_method_names:
                    handler = getattr(self, method, self.http_method_not_allowed)
                else:
                    handler = self.http_method_not_allowed
                response = handler(request, *args, **kwargs)
        except Exception as exc:
            response = self.handle_exception(exc)
        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

    def _normalize_cache_value(self, value):
        """
        Recursively convert request data into plain dicts and lists for hashing.

        Objects exposing ``.lists()`` (e.g. ``QueryDict``) become dicts in which a
        single-item value list is unwrapped to that item.
        """
        if hasattr(value, 'lists'):
            return {k: v[0] if len(v) == 1 else v for k, v in value.lists()}
        if isinstance(value, dict):
            return {k: self._normalize_cache_value(v) for k, v in value.items()}
        if isinstance(value, list):
            return [self._normalize_cache_value(item) for item in value]
        return value

    def _build_cache_key(self, request):
        """
        Return a SHA-256 hex digest identifying the request.

        The digest covers the method, path, body, query string and user id.
        """
        payload = {
            'request_method': request.method,
            'request_path': request.path,
            'request_data': self._normalize_cache_value(getattr(request, 'data', {})),
            'request_query': self._normalize_cache_value(request.query_params),
            'request_userid': request.user.id,
        }
        # default=str makes non-JSON values (e.g. uploaded files) hashable by their str().
        payload_str = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str)
        return hashlib.sha256(payload_str.encode('utf-8')).hexdigest()

    def initial(self, request, *args, **kwargs):
        """
        Run DRF's auth/permission/throttle checks, then consult the response cache.

        The lifetime is ``<action>_cache_seconds`` if defined, else
        ``cache_seconds``; a falsy value disables caching for this request.
        - Cache miss: claim the key with ``cache_processing_flag``.
        - Key holds the flag: an identical request is still running, so raise
          ``ParseError``.
        - Otherwise: replay the cached payload (``dispatch`` skips the handler).
        """
        super().initial(request, *args, **kwargs)
        cache_seconds = getattr(
            self, f"{self.action}_cache_seconds", getattr(self, 'cache_seconds', 0))
        if not cache_seconds:
            return
        self.cache_seconds = cache_seconds
        self.hash_k = self._build_cache_key(request)
        cached = cache.get(self.hash_k, None)
        if cached is None:
            cache.set(self.hash_k, self.cache_processing_flag, self.cache_seconds)
            self._cache_owner = True
        elif cached == self.cache_processing_flag:
            # An identical request is still being processed.
            raise ParseError(f'请求忽略,请{self.cache_seconds}秒后重试')
        else:
            self._cached_response = Response(cached)

    def get_object(self, force_lock=False):
        """
        Fetch the object for this request, row-locking it when appropriate.

        - Write request (not GET/HEAD/OPTIONS) inside a transaction: the object
          is re-fetched with ``select_for_update``.
        - Any other request: plain ``GenericAPIView.get_object`` lookup (no lock).

        :param force_lock: require a lock. The call must then run inside a
            transaction, and the object is locked even for read-only methods.
        :raises ParseError: ``force_lock`` is set but no transaction is active.
        :raises Http404: no object matches the lookup.
        """
        read_only_methods = ('GET', 'HEAD', 'OPTIONS')
        in_transaction = connection.in_atomic_block
        if force_lock and not in_transaction:
            # A row lock only exists for the lifetime of a transaction.
            raise ParseError("当前操作需要在事务中进行,请使用事务装饰器")
        if in_transaction and (force_lock or self.request.method not in read_only_methods):
            return self._get_locked_object()
        return super().get_object()

    def _get_locked_object(self):
        """Resolve the lookup like ``get_object`` and return a ``select_for_update`` copy."""
        queryset = self.filter_queryset(self.get_queryset())
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
        # Resolve through the filtered queryset first so visibility/data-scope
        # rules apply. Then lock the row by pk through the model's base manager
        # (without the select_related/prefetch options of the filtered queryset).
        obj = get_object_or_404(queryset, **filter_kwargs)
        locked = get_object_or_404(
            queryset.model._base_manager.select_for_update(), pk=obj.pk)
        self.check_object_permissions(self.request, locked)
        return locked

    def get_serializer_class(self):
        """Prefer ``<action>_serializer_class`` when set, else DRF's default."""
        action_serializer_class = getattr(self, f"{self.action}_serializer_class", None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    def get_queryset_custom(self, queryset):
        """
        Hook for action-specific queryset customisation; override as needed.

        Standard CRUD actions get the queryset unchanged. Any other action is
        passed through ``get_queryset_<action>(queryset)`` when the viewset
        defines such a method.
        """
        if self.action in ["list", "retrieve", "create", "update", "partial_update", "destroy"]:
            return queryset
        if hasattr(self, f'get_queryset_{self.action}'):
            return getattr(self, f'get_queryset_{self.action}')(queryset)
        return queryset

    def filter_queryset(self, queryset):
        """Apply ``select_related``/``prefetch_related`` hints, then DRF filter backends."""
        if self.select_related_fields:
            queryset = queryset.select_related(*self.select_related_fields)
        if self.prefetch_related_fields:
            queryset = queryset.prefetch_related(*self.prefetch_related_fields)
        # NOTE: narrowing `with_children` queries to parent=None was considered
        # here as an optimisation and is intentionally left disabled.
        return super().filter_queryset(queryset)

    def get_queryset(self):
        """
        Return the base queryset, narrowed by data-scope rules when ``data_filter`` is on.

        Superusers and methods mapped to ``'*'`` see everything. Otherwise the
        user's scopes for this method's permission code are OR-ed together;
        each scope is applied to the unfiltered queryset. A user without that
        permission code gets an empty queryset.
        """
        queryset = self.get_queryset_custom(super().get_queryset())
        if not self.data_filter:
            return queryset
        user = self.request.user
        if user.is_superuser:
            return queryset
        user_perms_map = get_user_perms_map(user)
        if not isinstance(user_perms_map, dict):
            return queryset
        perms_map = self.perms_map or {}
        action_str = perms_map.get(self.request._request.method.lower(), None)
        if '*' in perms_map or action_str == '*':
            return queryset
        if action_str not in user_perms_map:
            return queryset.none()

        visible = queryset.none()
        for dept_id, data_range in user_perms_map[action_str].items():
            if data_range == DataFilter.ALL:
                return queryset
            if data_range == DataFilter.MYSELF:
                visible = visible | queryset.filter(create_by=user)
                continue
            if data_range == DataFilter.SAMELEVE_AND_BELOW:
                narrow = self.filter_s_a_b
            elif data_range == DataFilter.THISLEVEL_AND_BELOW:
                narrow = self.filter_t_a_b
            elif data_range == DataFilter.THISLEVEL:
                narrow = self.filter_t
            else:
                # Unknown scope: grant nothing rather than everything.
                continue
            dept = Dept.objects.filter(id=dept_id).first()
            if dept is None:
                # Stale department id in the permission map: skip it.
                continue
            visible = visible | narrow(queryset, dept)
        return visible

    def _has_data_filter_field(self, queryset):
        """True if the model declares the first segment of ``data_filter_field``."""
        return hasattr(queryset.model, self.data_filter_field.split('__', 1)[0])

    def filter_s_a_b(self, queryset, dept):
        """
        Restrict to the department's own level and below; override if needed.

        Keeps rows whose ``data_filter_field`` is in the subtree rooted at the
        department's parent, or at the department itself when it has no parent.
        Models without that field fall back to rows created by the current user.
        """
        if not self._has_data_filter_field(queryset):
            return queryset.filter(create_by=self.request.user)
        root = dept.parent if dept.parent else dept
        return queryset.filter(**{self.data_filter_field + '__in': get_child_queryset2(root)})

    def filter_t_a_b(self, queryset, dept):
        """
        Restrict to this department and below; override if needed.

        Keeps rows whose ``data_filter_field`` is in the subtree rooted at
        ``dept``. Models without that field fall back to rows created by the
        current user.
        """
        if not self._has_data_filter_field(queryset):
            return queryset.filter(create_by=self.request.user)
        return queryset.filter(**{self.data_filter_field + '__in': get_child_queryset2(dept)})

    def filter_t(self, queryset, dept):
        """
        Restrict to exactly this department; override if needed.

        Models without ``data_filter_field`` fall back to rows created by the
        current user.
        """
        if not self._has_data_filter_field(queryset):
            return queryset.filter(create_by=self.request.user)
        return queryset.filter(**{self.data_filter_field: dept})


class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListModelMixin,
                         CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin,
                         CustomGenericViewSet):
    """
    Enhanced ModelViewSet (create, update, list, retrieve, destroy, complex query).
    """


class EuModelViewSet(BulkCreateModelMixin, CustomListModelMixin, CustomRetrieveModelMixin,
                     BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet):
    """
    Enhanced ModelViewSet without update support.
    """