diff --git a/apps/utils/viewsets.py b/apps/utils/viewsets.py index 20ac4b6b..8d3c50a6 100755 --- a/apps/utils/viewsets.py +++ b/apps/utils/viewsets.py @@ -23,6 +23,13 @@ from django.db import connection from django.core.exceptions import ObjectDoesNotExist from django.db.utils import NotSupportedError + +class CachedResponseHit(Exception): + def __init__(self, data): + super().__init__('cached response hit') + self.data = data + + class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): """ 增强的GenericViewSet @@ -64,35 +71,11 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): return super().__new__(cls) def finalize_response(self, request, response, *args, **kwargs): - if self.hash_k and self.cache_seconds: + if self.hash_k and self.cache_seconds and not getattr(self, '_from_cache', False): cache.set(self.hash_k, response.data, timeout=self.cache_seconds) # 将结果存入缓存,设置超时时间 return super().finalize_response(request, response, *args, **kwargs) - def dispatch(self, request, *args, **kwargs): - self.args = args - self.kwargs = kwargs - request = self.initialize_request(request, *args, **kwargs) - self.request = request - self.headers = self.default_response_headers - - try: - self.initial(request, *args, **kwargs) - - if hasattr(self, '_cached_response'): - response = self._cached_response - else: - if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), self.http_method_not_allowed) - else: - handler = self.http_method_not_allowed - response = handler(request, *args, **kwargs) - except Exception as exc: - response = self.handle_exception(exc) - - self.response = self.finalize_response(request, response, *args, **kwargs) - return self.response - def _normalize_cache_value(self, value): if hasattr(value, 'lists'): return {k: v if len(v) > 1 else v[0] for k, v in value.lists()} @@ -114,6 +97,7 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): return hashlib.sha256(payload_str.encode('utf-8')).hexdigest() def initial(self, request, *args, **kwargs): + self._from_cache = False super().initial(request, *args, **kwargs) cache_seconds = getattr( self, f"{self.action}_cache_seconds", getattr(self, 'cache_seconds', 0)) @@ -126,7 +110,13 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): elif hash_v_e == self.cache_processing_flag: # 说明请求正在处理 raise ParseError(f'请求忽略,请{self.cache_seconds}秒后重试') elif hash_v_e: - self._cached_response = Response(hash_v_e) + raise CachedResponseHit(hash_v_e) + + def handle_exception(self, exc): + if isinstance(exc, CachedResponseHit): + self._from_cache = True + return Response(exc.data) + return super().handle_exception(exc) def get_object(self, force_lock=False): """