feat: base CustomGenericViewSet 添加自动事务

This commit is contained in:
caoqianming 2025-09-12 12:14:29 +08:00
parent c9a2daaa48
commit 412398d461
2 changed files with 69 additions and 18 deletions

View File

@ -9,7 +9,6 @@ from django.utils.timezone import now
from user_agents import parse
import logging
from rest_framework.response import Response
from django.db import transaction
from rest_framework.exceptions import ParseError, ValidationError
from apps.utils.errors import PKS_ERROR
from rest_framework.generics import get_object_or_404
@ -91,10 +90,9 @@ class BulkCreateModelMixin(CreateModelMixin):
many = False
if isinstance(rdata, list):
many = True
with transaction.atomic():
sr = self.get_serializer(data=rdata, many=many)
sr.is_valid(raise_exception=True)
self.perform_create(sr)
sr = self.get_serializer(data=rdata, many=many)
sr.is_valid(raise_exception=True)
self.perform_create(sr)
if many:
self.after_bulk_create(sr.data)
return Response(sr.data, status=201)
@ -124,16 +122,15 @@ class BulkUpdateModelMixin(UpdateModelMixin):
queryset = self.filter_queryset(self.get_queryset())
objs = []
if isinstance(request.data, list):
with transaction.atomic():
for ind, item in enumerate(request.data):
obj = get_object_or_404(queryset, id=item['id'])
sr = self.get_serializer(obj, data=item, partial=partial)
if not sr.is_valid():
err_dict = { f'{ind+1}': sr.errors}
raise ValidationError(err_dict)
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
objs.append(sr.data)
self.after_bulk_update(objs)
for ind, item in enumerate(request.data):
obj = get_object_or_404(queryset, id=item['id'])
sr = self.get_serializer(obj, data=item, partial=partial)
if not sr.is_valid():
err_dict = { f'{ind+1}': sr.errors}
raise ValidationError(err_dict)
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
objs.append(sr.data)
self.after_bulk_update(objs)
else:
raise ParseError('提交数据非列表')
return Response(objs)

View File

@ -1,6 +1,6 @@
from django.core.cache import cache
from django.http import StreamingHttpResponse
from django.http import StreamingHttpResponse, Http404
from rest_framework.decorators import action
from rest_framework.exceptions import ParseError
from rest_framework.mixins import RetrieveModelMixin
@ -18,6 +18,9 @@ from apps.utils.serializers import ComplexSerializer
from rest_framework.throttling import UserRateThrottle
from drf_yasg.utils import swagger_auto_schema
import json
from django.db import connection
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
@ -59,6 +62,30 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
cls._initialized = True
return super().__new__(cls)
def dispatch(self, request, *args, **kwargs):
# 判断是否需要事务
if self._should_use_transaction(request):
with transaction.atomic():
return super().dispatch(request, *args, **kwargs)
else:
return super().dispatch(request, *args, **kwargs)
def _should_use_transaction(self, request):
"""判断当前请求是否需要事务"""
# 标准的写操作需要事务
if request.method in ('POST', 'PUT', 'PATCH', 'DELETE'):
# 但还要看具体是哪个action
action = self.action_map.get(request.method.lower(), {}).get(request.method.lower())
if action in ['create', 'update', 'partial_update', 'destroy']:
return True
# 自定义的action可以通过在action方法上添加装饰器或特殊属性来判断
action = getattr(self, self.action, None) if self.action else None
if action and getattr(action, 'requires_transaction', False):
return True
return False
def finalize_response(self, request, response, *args, **kwargs):
# 如果是流式响应,直接返回
if isinstance(response, StreamingHttpResponse):
@ -89,6 +116,34 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
elif hash_v_e:
return Response(hash_v_e)
def get_object(self, force_lock=False):
"""
智能加锁的get_object
- 只读请求普通查询
- 非只读请求且在事务中加锁查询
- 非只读请求但不在事务中普通查询带警告
"""
# 只读方法列表
read_only_methods = ['GET', 'HEAD', 'OPTIONS']
if self.request.method not in read_only_methods and connection.in_atomic_block:
if force_lock:
raise ParseError("当前操作需要在事务中进行,请使用事务装饰器")
# 非只读请求且在事务中:加锁查询
queryset = self.filter_queryset(self.get_queryset())
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
try:
obj = queryset.select_for_update().get(**filter_kwargs)
self.check_object_permissions(self.request, obj)
return obj
except ObjectDoesNotExist:
raise Http404
else:
# 其他情况:普通查询
return super().get_object()
def get_serializer_class(self):
action_serializer_name = f"{self.action}_serializer_class"
action_serializer_class = getattr(self, action_serializer_name, None)
@ -193,5 +248,4 @@ class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListM
CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet):
"""
增强的ModelViewSet
"""
pass
"""