feat: base CustomGenericViewSet 添加自动事务
This commit is contained in:
parent
c9a2daaa48
commit
412398d461
|
@ -9,7 +9,6 @@ from django.utils.timezone import now
|
||||||
from user_agents import parse
|
from user_agents import parse
|
||||||
import logging
|
import logging
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from django.db import transaction
|
|
||||||
from rest_framework.exceptions import ParseError, ValidationError
|
from rest_framework.exceptions import ParseError, ValidationError
|
||||||
from apps.utils.errors import PKS_ERROR
|
from apps.utils.errors import PKS_ERROR
|
||||||
from rest_framework.generics import get_object_or_404
|
from rest_framework.generics import get_object_or_404
|
||||||
|
@ -91,10 +90,9 @@ class BulkCreateModelMixin(CreateModelMixin):
|
||||||
many = False
|
many = False
|
||||||
if isinstance(rdata, list):
|
if isinstance(rdata, list):
|
||||||
many = True
|
many = True
|
||||||
with transaction.atomic():
|
sr = self.get_serializer(data=rdata, many=many)
|
||||||
sr = self.get_serializer(data=rdata, many=many)
|
sr.is_valid(raise_exception=True)
|
||||||
sr.is_valid(raise_exception=True)
|
self.perform_create(sr)
|
||||||
self.perform_create(sr)
|
|
||||||
if many:
|
if many:
|
||||||
self.after_bulk_create(sr.data)
|
self.after_bulk_create(sr.data)
|
||||||
return Response(sr.data, status=201)
|
return Response(sr.data, status=201)
|
||||||
|
@ -124,16 +122,15 @@ class BulkUpdateModelMixin(UpdateModelMixin):
|
||||||
queryset = self.filter_queryset(self.get_queryset())
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
objs = []
|
objs = []
|
||||||
if isinstance(request.data, list):
|
if isinstance(request.data, list):
|
||||||
with transaction.atomic():
|
for ind, item in enumerate(request.data):
|
||||||
for ind, item in enumerate(request.data):
|
obj = get_object_or_404(queryset, id=item['id'])
|
||||||
obj = get_object_or_404(queryset, id=item['id'])
|
sr = self.get_serializer(obj, data=item, partial=partial)
|
||||||
sr = self.get_serializer(obj, data=item, partial=partial)
|
if not sr.is_valid():
|
||||||
if not sr.is_valid():
|
err_dict = { f'第{ind+1}': sr.errors}
|
||||||
err_dict = { f'第{ind+1}': sr.errors}
|
raise ValidationError(err_dict)
|
||||||
raise ValidationError(err_dict)
|
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
|
||||||
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
|
objs.append(sr.data)
|
||||||
objs.append(sr.data)
|
self.after_bulk_update(objs)
|
||||||
self.after_bulk_update(objs)
|
|
||||||
else:
|
else:
|
||||||
raise ParseError('提交数据非列表')
|
raise ParseError('提交数据非列表')
|
||||||
return Response(objs)
|
return Response(objs)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.http import StreamingHttpResponse
|
from django.http import StreamingHttpResponse, Http404
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.exceptions import ParseError
|
from rest_framework.exceptions import ParseError
|
||||||
from rest_framework.mixins import RetrieveModelMixin
|
from rest_framework.mixins import RetrieveModelMixin
|
||||||
|
@ -18,6 +18,9 @@ from apps.utils.serializers import ComplexSerializer
|
||||||
from rest_framework.throttling import UserRateThrottle
|
from rest_framework.throttling import UserRateThrottle
|
||||||
from drf_yasg.utils import swagger_auto_schema
|
from drf_yasg.utils import swagger_auto_schema
|
||||||
import json
|
import json
|
||||||
|
from django.db import connection
|
||||||
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
|
from django.db import transaction
|
||||||
|
|
||||||
|
|
||||||
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
||||||
|
@ -59,6 +62,30 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
||||||
cls._initialized = True
|
cls._initialized = True
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def dispatch(self, request, *args, **kwargs):
|
||||||
|
# 判断是否需要事务
|
||||||
|
if self._should_use_transaction(request):
|
||||||
|
with transaction.atomic():
|
||||||
|
return super().dispatch(request, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
|
def _should_use_transaction(self, request):
|
||||||
|
"""判断当前请求是否需要事务"""
|
||||||
|
# 标准的写操作需要事务
|
||||||
|
if request.method in ('POST', 'PUT', 'PATCH', 'DELETE'):
|
||||||
|
# 但还要看具体是哪个action
|
||||||
|
action = self.action_map.get(request.method.lower(), {}).get(request.method.lower())
|
||||||
|
if action in ['create', 'update', 'partial_update', 'destroy']:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 自定义的action:可以通过在action方法上添加装饰器或特殊属性来判断
|
||||||
|
action = getattr(self, self.action, None) if self.action else None
|
||||||
|
if action and getattr(action, 'requires_transaction', False):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def finalize_response(self, request, response, *args, **kwargs):
|
def finalize_response(self, request, response, *args, **kwargs):
|
||||||
# 如果是流式响应,直接返回
|
# 如果是流式响应,直接返回
|
||||||
if isinstance(response, StreamingHttpResponse):
|
if isinstance(response, StreamingHttpResponse):
|
||||||
|
@ -89,6 +116,34 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
|
||||||
elif hash_v_e:
|
elif hash_v_e:
|
||||||
return Response(hash_v_e)
|
return Response(hash_v_e)
|
||||||
|
|
||||||
|
def get_object(self, force_lock=False):
|
||||||
|
"""
|
||||||
|
智能加锁的get_object
|
||||||
|
- 只读请求:普通查询
|
||||||
|
- 非只读请求且在事务中:加锁查询
|
||||||
|
- 非只读请求但不在事务中:普通查询(带警告)
|
||||||
|
"""
|
||||||
|
# 只读方法列表
|
||||||
|
read_only_methods = ['GET', 'HEAD', 'OPTIONS']
|
||||||
|
|
||||||
|
if self.request.method not in read_only_methods and connection.in_atomic_block:
|
||||||
|
if force_lock:
|
||||||
|
raise ParseError("当前操作需要在事务中进行,请使用事务装饰器")
|
||||||
|
# 非只读请求且在事务中:加锁查询
|
||||||
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
|
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
|
||||||
|
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
|
||||||
|
|
||||||
|
try:
|
||||||
|
obj = queryset.select_for_update().get(**filter_kwargs)
|
||||||
|
self.check_object_permissions(self.request, obj)
|
||||||
|
return obj
|
||||||
|
except ObjectDoesNotExist:
|
||||||
|
raise Http404
|
||||||
|
else:
|
||||||
|
# 其他情况:普通查询
|
||||||
|
return super().get_object()
|
||||||
|
|
||||||
def get_serializer_class(self):
|
def get_serializer_class(self):
|
||||||
action_serializer_name = f"{self.action}_serializer_class"
|
action_serializer_name = f"{self.action}_serializer_class"
|
||||||
action_serializer_class = getattr(self, action_serializer_name, None)
|
action_serializer_class = getattr(self, action_serializer_name, None)
|
||||||
|
@ -193,5 +248,4 @@ class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListM
|
||||||
CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet):
|
CustomRetrieveModelMixin, BulkDestroyModelMixin, ComplexQueryMixin, CustomGenericViewSet):
|
||||||
"""
|
"""
|
||||||
增强的ModelViewSet
|
增强的ModelViewSet
|
||||||
"""
|
"""
|
||||||
pass
|
|
Loading…
Reference in New Issue