diff --git a/apps/utils/viewsets.py b/apps/utils/viewsets.py index 3c5ed2c7..8b9b6944 100755 --- a/apps/utils/viewsets.py +++ b/apps/utils/viewsets.py @@ -17,6 +17,7 @@ from apps.utils.queryset import get_child_queryset2, get_child_queryset_u from apps.utils.serializers import ComplexSerializer from rest_framework.throttling import UserRateThrottle from drf_yasg.utils import swagger_auto_schema +import hashlib import json from django.db import connection from django.core.exceptions import ObjectDoesNotExist @@ -43,6 +44,7 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): data_filter = False # 数据权限过滤是否开启(需要RbacPermission) data_filter_field = 'belong_dept' hash_k = None + cache_processing_flag = '__processing__' cache_seconds = 5 # 接口缓存时间默认5秒 filterset_fields = select_related_fields @@ -67,26 +69,64 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): timeout=self.cache_seconds) # 将结果存入缓存,设置超时时间 return super().finalize_response(request, response, *args, **kwargs) + def dispatch(self, request, *args, **kwargs): + self.args = args + self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request + self.headers = self.default_response_headers + + try: + self.initial(request, *args, **kwargs) + + if hasattr(self, '_cached_response'): + response = self._cached_response + else: + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + response = handler(request, *args, **kwargs) + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + return self.response + + def _normalize_cache_value(self, value): + if hasattr(value, 'lists'): + return {k: v if len(v) > 1 else v[0] for k, v in value.lists()} + if isinstance(value, dict): + return {k: self._normalize_cache_value(v) for k, v in value.items()} + if isinstance(value, list): + return [self._normalize_cache_value(item) for item in value] + return value + + def _build_cache_key(self, request): + payload = { + 'request_method': request.method, + 'request_path': request.path, + 'request_data': self._normalize_cache_value(getattr(request, 'data', {})), + 'request_query': self._normalize_cache_value(request.query_params), + 'request_userid': request.user.id, + } + payload_str = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str) + return hashlib.sha256(payload_str.encode('utf-8')).hexdigest() + def initial(self, request, *args, **kwargs): super().initial(request, *args, **kwargs) cache_seconds = getattr( self, f"{self.action}_cache_seconds", getattr(self, 'cache_seconds', 0)) if cache_seconds: self.cache_seconds = cache_seconds - rdata = {} - rdata['request_method'] = request.method - rdata['request_path'] = request.path - rdata['request_data'] = request.data - rdata['request_query'] = request.query_params.dict() - rdata['request_userid'] = request.user.id - self.hash_k = hash(json.dumps(rdata)) + self.hash_k = self._build_cache_key(request) hash_v_e = cache.get(self.hash_k, None) if hash_v_e is None: - cache.set(self.hash_k, 'o', self.cache_seconds) - elif hash_v_e == 'o': # 说明请求正在处理 + cache.set(self.hash_k, self.cache_processing_flag, self.cache_seconds) + elif hash_v_e == self.cache_processing_flag: # 说明请求正在处理 raise ParseError(f'请求忽略,请{self.cache_seconds}秒后重试') elif hash_v_e: - return Response(hash_v_e) + self._cached_response = Response(hash_v_e) def get_object(self, force_lock=False): """ @@ -229,4 +269,4 @@ class EuModelViewSet(BulkCreateModelMixin, CustomListModelMixin, CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet): """ 不支持更新的增强ModelViewSet - """ \ No newline at end of file + """