Compare commits
	
		
			No commits in common. "90b7e2087b71c3eea7e78718830bbcff5768036f" and "ebd125ca1dda01ec2bd423b455355394e6782219" have entirely different histories.
		
	
	
		
			90b7e2087b
			...
			ebd125ca1d
		
	
		| 
						 | 
					@ -9,6 +9,7 @@ from django.utils.timezone import now
 | 
				
			||||||
from user_agents import parse
 | 
					from user_agents import parse
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
 | 
					from django.db import transaction
 | 
				
			||||||
from rest_framework.exceptions import ParseError, ValidationError
 | 
					from rest_framework.exceptions import ParseError, ValidationError
 | 
				
			||||||
from apps.utils.errors import PKS_ERROR
 | 
					from apps.utils.errors import PKS_ERROR
 | 
				
			||||||
from rest_framework.generics import get_object_or_404
 | 
					from rest_framework.generics import get_object_or_404
 | 
				
			||||||
| 
						 | 
					@ -90,6 +91,7 @@ class BulkCreateModelMixin(CreateModelMixin):
 | 
				
			||||||
        many = False
 | 
					        many = False
 | 
				
			||||||
        if isinstance(rdata, list):
 | 
					        if isinstance(rdata, list):
 | 
				
			||||||
            many = True
 | 
					            many = True
 | 
				
			||||||
 | 
					        with transaction.atomic():
 | 
				
			||||||
            sr = self.get_serializer(data=rdata, many=many)
 | 
					            sr = self.get_serializer(data=rdata, many=many)
 | 
				
			||||||
            sr.is_valid(raise_exception=True)
 | 
					            sr.is_valid(raise_exception=True)
 | 
				
			||||||
            self.perform_create(sr)
 | 
					            self.perform_create(sr)
 | 
				
			||||||
| 
						 | 
					@ -122,6 +124,7 @@ class BulkUpdateModelMixin(UpdateModelMixin):
 | 
				
			||||||
            queryset = self.filter_queryset(self.get_queryset())
 | 
					            queryset = self.filter_queryset(self.get_queryset())
 | 
				
			||||||
            objs = []
 | 
					            objs = []
 | 
				
			||||||
            if isinstance(request.data, list):
 | 
					            if isinstance(request.data, list):
 | 
				
			||||||
 | 
					                with transaction.atomic():
 | 
				
			||||||
                    for ind, item in enumerate(request.data):
 | 
					                    for ind, item in enumerate(request.data):
 | 
				
			||||||
                        obj = get_object_or_404(queryset, id=item['id'])
 | 
					                        obj = get_object_or_404(queryset, id=item['id'])
 | 
				
			||||||
                        sr = self.get_serializer(obj, data=item, partial=partial)
 | 
					                        sr = self.get_serializer(obj, data=item, partial=partial)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from django.http import StreamingHttpResponse, Http404
 | 
					from django.http import StreamingHttpResponse
 | 
				
			||||||
from rest_framework.decorators import action
 | 
					from rest_framework.decorators import action
 | 
				
			||||||
from rest_framework.exceptions import ParseError
 | 
					from rest_framework.exceptions import ParseError
 | 
				
			||||||
from rest_framework.mixins import RetrieveModelMixin
 | 
					from rest_framework.mixins import RetrieveModelMixin
 | 
				
			||||||
| 
						 | 
					@ -18,9 +18,6 @@ from apps.utils.serializers import ComplexSerializer
 | 
				
			||||||
from rest_framework.throttling import UserRateThrottle
 | 
					from rest_framework.throttling import UserRateThrottle
 | 
				
			||||||
from drf_yasg.utils import swagger_auto_schema
 | 
					from drf_yasg.utils import swagger_auto_schema
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from django.db import connection
 | 
					 | 
				
			||||||
from django.core.exceptions import ObjectDoesNotExist
 | 
					 | 
				
			||||||
from django.db import transaction
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
 | 
					class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
 | 
				
			||||||
| 
						 | 
					@ -62,30 +59,6 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
 | 
				
			||||||
            cls._initialized = True
 | 
					            cls._initialized = True
 | 
				
			||||||
        return super().__new__(cls)
 | 
					        return super().__new__(cls)
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def dispatch(self, request, *args, **kwargs):
 | 
					 | 
				
			||||||
        # 判断是否需要事务
 | 
					 | 
				
			||||||
        if self._should_use_transaction(request):
 | 
					 | 
				
			||||||
            with transaction.atomic():
 | 
					 | 
				
			||||||
                return super().dispatch(request, *args, **kwargs)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return super().dispatch(request, *args, **kwargs)
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    def _should_use_transaction(self, request):
 | 
					 | 
				
			||||||
        """判断当前请求是否需要事务"""
 | 
					 | 
				
			||||||
        # 标准的写操作需要事务
 | 
					 | 
				
			||||||
        if request.method in ('POST', 'PUT', 'PATCH', 'DELETE'):
 | 
					 | 
				
			||||||
            # 但还要看具体是哪个action
 | 
					 | 
				
			||||||
            action = self.action_map.get(request.method.lower(), {}).get(request.method.lower())
 | 
					 | 
				
			||||||
            if action in ['create', 'update', 'partial_update', 'destroy']:
 | 
					 | 
				
			||||||
                return True
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        # 自定义的action:可以通过在action方法上添加装饰器或特殊属性来判断
 | 
					 | 
				
			||||||
        action = getattr(self, self.action, None) if self.action else None
 | 
					 | 
				
			||||||
        if action and getattr(action, 'requires_transaction', False):
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
    def finalize_response(self, request, response, *args, **kwargs):
 | 
					    def finalize_response(self, request, response, *args, **kwargs):
 | 
				
			||||||
        # 如果是流式响应,直接返回
 | 
					        # 如果是流式响应,直接返回
 | 
				
			||||||
        if isinstance(response, StreamingHttpResponse):
 | 
					        if isinstance(response, StreamingHttpResponse):
 | 
				
			||||||
| 
						 | 
					@ -116,34 +89,6 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
 | 
				
			||||||
            elif hash_v_e:
 | 
					            elif hash_v_e:
 | 
				
			||||||
                return Response(hash_v_e)
 | 
					                return Response(hash_v_e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_object(self, force_lock=False):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        智能加锁的get_object
 | 
					 | 
				
			||||||
        - 只读请求:普通查询
 | 
					 | 
				
			||||||
        - 非只读请求且在事务中:加锁查询
 | 
					 | 
				
			||||||
        - 非只读请求但不在事务中:普通查询(带警告)
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        # 只读方法列表
 | 
					 | 
				
			||||||
        read_only_methods = ['GET', 'HEAD', 'OPTIONS']
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        if self.request.method not in read_only_methods and connection.in_atomic_block:
 | 
					 | 
				
			||||||
            if force_lock:
 | 
					 | 
				
			||||||
                raise ParseError("当前操作需要在事务中进行,请使用事务装饰器")
 | 
					 | 
				
			||||||
            # 非只读请求且在事务中:加锁查询
 | 
					 | 
				
			||||||
            queryset = self.filter_queryset(self.get_queryset())
 | 
					 | 
				
			||||||
            lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
 | 
					 | 
				
			||||||
            filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                obj = queryset.select_for_update().get(**filter_kwargs)
 | 
					 | 
				
			||||||
                self.check_object_permissions(self.request, obj)
 | 
					 | 
				
			||||||
                return obj
 | 
					 | 
				
			||||||
            except ObjectDoesNotExist:
 | 
					 | 
				
			||||||
                raise Http404
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # 其他情况:普通查询
 | 
					 | 
				
			||||||
            return super().get_object()
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    def get_serializer_class(self):
 | 
					    def get_serializer_class(self):
 | 
				
			||||||
        action_serializer_name = f"{self.action}_serializer_class"
 | 
					        action_serializer_name = f"{self.action}_serializer_class"
 | 
				
			||||||
        action_serializer_class = getattr(self, action_serializer_name, None)
 | 
					        action_serializer_class = getattr(self, action_serializer_name, None)
 | 
				
			||||||
| 
						 | 
					@ -249,3 +194,4 @@ class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, CustomListM
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    增强的ModelViewSet
 | 
					    增强的ModelViewSet
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
		Loading…
	
		Reference in New Issue