From 412398d46174851d50de6f87bb754529ae7b54e9 Mon Sep 17 00:00:00 2001 From: caoqianming Date: Fri, 12 Sep 2025 12:14:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20base=20CustomGenericViewSet=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E8=87=AA=E5=8A=A8=E4=BA=8B=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/utils/mixins.py | 27 +++++++++---------- apps/utils/viewsets.py | 60 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/apps/utils/mixins.py b/apps/utils/mixins.py index 9e4a277f..dff0fd27 100755 --- a/apps/utils/mixins.py +++ b/apps/utils/mixins.py @@ -9,7 +9,6 @@ from django.utils.timezone import now from user_agents import parse import logging from rest_framework.response import Response -from django.db import transaction from rest_framework.exceptions import ParseError, ValidationError from apps.utils.errors import PKS_ERROR from rest_framework.generics import get_object_or_404 @@ -91,10 +90,9 @@ class BulkCreateModelMixin(CreateModelMixin): many = False if isinstance(rdata, list): many = True - with transaction.atomic(): - sr = self.get_serializer(data=rdata, many=many) - sr.is_valid(raise_exception=True) - self.perform_create(sr) + sr = self.get_serializer(data=rdata, many=many) + sr.is_valid(raise_exception=True) + self.perform_create(sr) if many: self.after_bulk_create(sr.data) return Response(sr.data, status=201) @@ -124,16 +122,15 @@ class BulkUpdateModelMixin(UpdateModelMixin): queryset = self.filter_queryset(self.get_queryset()) objs = [] if isinstance(request.data, list): - with transaction.atomic(): - for ind, item in enumerate(request.data): - obj = get_object_or_404(queryset, id=item['id']) - sr = self.get_serializer(obj, data=item, partial=partial) - if not sr.is_valid(): - err_dict = { f'第{ind+1}': sr.errors} - raise ValidationError(err_dict) - self.perform_update(sr) # 用自带的更新,可能需要做其他操作 - objs.append(sr.data) - self.after_bulk_update(objs) + for ind, item in enumerate(request.data): + obj = get_object_or_404(queryset, id=item['id']) + sr = self.get_serializer(obj, data=item, partial=partial) + if not sr.is_valid(): + err_dict = { f'第{ind+1}': sr.errors} + raise ValidationError(err_dict) + self.perform_update(sr) # 用自带的更新,可能需要做其他操作 + objs.append(sr.data) + self.after_bulk_update(objs) else: raise ParseError('提交数据非列表') return Response(objs) diff --git a/apps/utils/viewsets.py b/apps/utils/viewsets.py index b33c973d..1e2758d8 100755 --- a/apps/utils/viewsets.py +++ b/apps/utils/viewsets.py @@ -1,6 +1,6 @@ from django.core.cache import cache -from django.http import StreamingHttpResponse +from django.http import StreamingHttpResponse, Http404 from rest_framework.decorators import action from rest_framework.exceptions import ParseError from rest_framework.mixins import RetrieveModelMixin @@ -18,6 +18,9 @@ from apps.utils.serializers import ComplexSerializer from rest_framework.throttling import UserRateThrottle from drf_yasg.utils import swagger_auto_schema import json +from django.db import connection +from django.core.exceptions import ObjectDoesNotExist +from django.db import transaction class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): @@ -59,6 +62,30 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): cls._initialized = True return super().__new__(cls) + def dispatch(self, request, *args, **kwargs): + # 判断是否需要事务 + if self._should_use_transaction(request): + with transaction.atomic(): + return super().dispatch(request, *args, **kwargs) + else: + return super().dispatch(request, *args, **kwargs) + + def _should_use_transaction(self, request): + """判断当前请求是否需要事务""" + # 标准的写操作需要事务 + if request.method in ('POST', 'PUT', 'PATCH', 'DELETE'): + # 但还要看具体是哪个action + action = self.action_map.get(request.method.lower(), {}).get(request.method.lower()) + if action in ['create', 'update', 'partial_update', 'destroy']: + return True + + # 自定义的action:可以通过在action方法上添加装饰器或特殊属性来判断 + action = getattr(self, self.action, None) if self.action else None + if action and getattr(action, 'requires_transaction', False): + return True + + return False + def finalize_response(self, request, response, *args, **kwargs): # 如果是流式响应,直接返回 if isinstance(response, StreamingHttpResponse): @@ -89,6 +116,34 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet): elif hash_v_e: return Response(hash_v_e) + def get_object(self, force_lock=False): + """ + 智能加锁的get_object + - 只读请求:普通查询 + - 非只读请求且在事务中:加锁查询 + - 非只读请求但不在事务中:普通查询(带警告) + """ + # 只读方法列表 + read_only_methods = ['GET', 'HEAD', 'OPTIONS'] + + if self.request.method not in read_only_methods and connection.in_atomic_block: + if force_lock: + raise ParseError("当前操作需要在事务中进行,请使用事务装饰器") + # 非只读请求且在事务中:加锁查询 + queryset = self.filter_queryset(self.get_queryset()) + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} + + try: + obj = queryset.select_for_update().get(**filter_kwargs) + self.check_object_permissions(self.request, obj) + return obj + except ObjectDoesNotExist: + raise Http404 + else: + # 其他情况:普通查询 + return super().get_object() + def get_serializer_class(self): action_serializer_name = f"{self.action}_serializer_class" action_serializer_class = getattr(self, action_serializer_name, None) @@ -193,5 +248,4 @@ class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListM CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet): """ 增强的ModelViewSet - """ - pass \ No newline at end of file + """ \ No newline at end of file