feat: base CustomGenericViewSet 添加自动事务
This commit is contained in:
parent
c9a2daaa48
commit
412398d461
|
@ -9,7 +9,6 @@ from django.utils.timezone import now
|
|||
from user_agents import parse
|
||||
import logging
|
||||
from rest_framework.response import Response
|
||||
from django.db import transaction
|
||||
from rest_framework.exceptions import ParseError, ValidationError
|
||||
from apps.utils.errors import PKS_ERROR
|
||||
from rest_framework.generics import get_object_or_404
|
||||
|
@ -91,10 +90,9 @@ class BulkCreateModelMixin(CreateModelMixin):
|
|||
many = False
|
||||
if isinstance(rdata, list):
|
||||
many = True
|
||||
with transaction.atomic():
|
||||
sr = self.get_serializer(data=rdata, many=many)
|
||||
sr.is_valid(raise_exception=True)
|
||||
self.perform_create(sr)
|
||||
sr = self.get_serializer(data=rdata, many=many)
|
||||
sr.is_valid(raise_exception=True)
|
||||
self.perform_create(sr)
|
||||
if many:
|
||||
self.after_bulk_create(sr.data)
|
||||
return Response(sr.data, status=201)
|
||||
|
@ -124,16 +122,15 @@ class BulkUpdateModelMixin(UpdateModelMixin):
|
|||
queryset = self.filter_queryset(self.get_queryset())
|
||||
objs = []
|
||||
if isinstance(request.data, list):
|
||||
with transaction.atomic():
|
||||
for ind, item in enumerate(request.data):
|
||||
obj = get_object_or_404(queryset, id=item['id'])
|
||||
sr = self.get_serializer(obj, data=item, partial=partial)
|
||||
if not sr.is_valid():
|
||||
err_dict = { f'第{ind+1}': sr.errors}
|
||||
raise ValidationError(err_dict)
|
||||
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
|
||||
objs.append(sr.data)
|
||||
self.after_bulk_update(objs)
|
||||
for ind, item in enumerate(request.data):
|
||||
obj = get_object_or_404(queryset, id=item['id'])
|
||||
sr = self.get_serializer(obj, data=item, partial=partial)
|
||||
if not sr.is_valid():
|
||||
err_dict = { f'第{ind+1}': sr.errors}
|
||||
raise ValidationError(err_dict)
|
||||
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
|
||||
objs.append(sr.data)
|
||||
self.after_bulk_update(objs)
|
||||
else:
|
||||
raise ParseError('提交数据非列表')
|
||||
return Response(objs)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
from django.core.cache import cache
|
||||
from django.http import StreamingHttpResponse
|
||||
from django.http import StreamingHttpResponse, Http404
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ParseError
|
||||
from rest_framework.mixins import RetrieveModelMixin
|
||||
|
@ -18,6 +18,9 @@ from apps.utils.serializers import ComplexSerializer
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
import json
|
||||
from django.db import connection
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import transaction
|
||||
|
||||
|
||||
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
||||
|
@ -59,6 +62,30 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
|||
cls._initialized = True
|
||||
return super().__new__(cls)
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
# 判断是否需要事务
|
||||
if self._should_use_transaction(request):
|
||||
with transaction.atomic():
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
else:
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def _should_use_transaction(self, request):
|
||||
"""判断当前请求是否需要事务"""
|
||||
# 标准的写操作需要事务
|
||||
if request.method in ('POST', 'PUT', 'PATCH', 'DELETE'):
|
||||
# 但还要看具体是哪个action
|
||||
action = self.action_map.get(request.method.lower(), {}).get(request.method.lower())
|
||||
if action in ['create', 'update', 'partial_update', 'destroy']:
|
||||
return True
|
||||
|
||||
# 自定义的action:可以通过在action方法上添加装饰器或特殊属性来判断
|
||||
action = getattr(self, self.action, None) if self.action else None
|
||||
if action and getattr(action, 'requires_transaction', False):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def finalize_response(self, request, response, *args, **kwargs):
|
||||
# 如果是流式响应,直接返回
|
||||
if isinstance(response, StreamingHttpResponse):
|
||||
|
@ -89,6 +116,34 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
|||
elif hash_v_e:
|
||||
return Response(hash_v_e)
|
||||
|
||||
def get_object(self, force_lock=False):
|
||||
"""
|
||||
智能加锁的get_object
|
||||
- 只读请求:普通查询
|
||||
- 非只读请求且在事务中:加锁查询
|
||||
- 非只读请求但不在事务中:普通查询(带警告)
|
||||
"""
|
||||
# 只读方法列表
|
||||
read_only_methods = ['GET', 'HEAD', 'OPTIONS']
|
||||
|
||||
if self.request.method not in read_only_methods and connection.in_atomic_block:
|
||||
if force_lock:
|
||||
raise ParseError("当前操作需要在事务中进行,请使用事务装饰器")
|
||||
# 非只读请求且在事务中:加锁查询
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
|
||||
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
|
||||
|
||||
try:
|
||||
obj = queryset.select_for_update().get(**filter_kwargs)
|
||||
self.check_object_permissions(self.request, obj)
|
||||
return obj
|
||||
except ObjectDoesNotExist:
|
||||
raise Http404
|
||||
else:
|
||||
# 其他情况:普通查询
|
||||
return super().get_object()
|
||||
|
||||
def get_serializer_class(self):
|
||||
action_serializer_name = f"{self.action}_serializer_class"
|
||||
action_serializer_class = getattr(self, action_serializer_name, None)
|
||||
|
@ -193,5 +248,4 @@ class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListM
|
|||
CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet):
|
||||
"""
|
||||
增强的ModelViewSet
|
||||
"""
|
||||
pass
|
||||
"""
|
Loading…
Reference in New Issue