feat: utils增加批量接口mixin

This commit is contained in:
caoqianming 2023-07-05 13:43:30 +08:00
parent 388d9213fd
commit 0f1bb7fd72
3 changed files with 113 additions and 39 deletions

View File

@ -8,6 +8,13 @@ from django.db import connection
from django.utils.timezone import now
from user_agents import parse
import logging
from rest_framework.response import Response
from django.db import transaction
from rest_framework.exceptions import ParseError, ValidationError
from apps.utils.errors import PKS_ERROR
from rest_framework.generics import get_object_or_404
from drf_yasg.utils import swagger_auto_schema
from apps.utils.serializers import PkSerializer
# 实例化myLogger
myLogger = logging.getLogger('log')
@ -66,6 +73,105 @@ class CustomUpdateModelMixin(UpdateModelMixin):
serializer.save(update_by=self.request.user)
class BulkCreateModelMixin(CreateModelMixin):
def after_bulk_create(self, objs):
pass
def create(self, request, *args, **kwargs):
"""创建(支持批量)
创建(支持批量)
"""
rdata = request.data
many = False
if isinstance(rdata, list):
many = True
with transaction.atomic():
sr = self.get_serializer(rdata, many=many)
sr.is_valid(raise_exception=True)
self.perform_create(sr)
if many:
self.after_bulk_create(sr.data)
return Response(sr.data, status=201)
class BulkUpdateModelMixin(UpdateModelMixin):
def after_bulk_update(self, objs):
pass
def partial_update(self, request, *args, **kwargs):
"""部分更新(支持批量)
部分更新(支持批量)
"""
kwargs['partial'] = True
return self.update(request, *args, **kwargs)
def update(self, request, *args, **kwargs):
"""更新(支持批量)
更新(支持批量)
"""
partial = kwargs.pop('partial', False)
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
if kwargs[lookup_url_kwarg] == 'bulk': # 如果是批量操作
queryset = self.filter_queryset(self.get_queryset())
objs = []
if isinstance(request.data, list):
with transaction.atomic():
for ind, item in enumerate(request.data):
obj = get_object_or_404(queryset, id=item['id'])
sr = self.get_serializer(obj, data=item, partial=partial)
if not sr.is_valid():
err_dict = { f'{ind+1}': sr.errors}
raise ValidationError(err_dict)
self.perform_update(sr) # 用自带的更新,可能需要做其他操作
objs.append(sr.data)
self.after_bulk_update(objs)
else:
raise ParseError('提交数据非列表')
return Response(objs)
else:
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
return Response(serializer.data)
class BulkDestroyModelMixin(DestroyModelMixin):
@swagger_auto_schema(request_body=PkSerializer)
def destroy(self, request, *args, **kwargs):
"""删除(支持批量)
删除(支持批量和硬删除(需管理员权限))
"""
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
if kwargs[lookup_url_kwarg] == 'bulk': # 如果是批量操作
queryset = self.filter_queryset(self.get_queryset())
ids = request.data.get('ids', None)
soft = request.data.get('soft', True)
if ids:
if soft is True:
queryset.filter(id__in=ids).delete()
elif soft is False:
try:
queryset.model.objects.get_queryset(
all=True).filter(id__in=ids).delete(soft=False)
except Exception:
queryset.filter(id__in=ids).delete()
return Response(status=204)
else:
raise ValidationError(**PKS_ERROR)
else:
instance = self.get_object()
self.perform_destroy(instance)
return Response(status=204)
class MyLoggingMixin(object):
"""Mixin to log requests"""

View File

@ -6,7 +6,8 @@ from rest_framework.request import Request
class PkSerializer(serializers.Serializer):
pks = serializers.ListField(child=serializers.CharField(max_length=20), label="主键ID列表")
ids = serializers.ListField(child=serializers.CharField(max_length=20), label="主键ID列表")
soft = serializers.BooleanField(label="是否软删除", default=True, required=False)
class GenSignatureSerializer(serializers.Serializer):

View File

@ -10,14 +10,16 @@ from rest_framework.viewsets import GenericViewSet
from apps.system.models import DataFilter, Dept, User
from apps.utils.errors import PKS_ERROR
from apps.utils.mixins import MyLoggingMixin
from apps.utils.mixins import MyLoggingMixin, BulkCreateModelMixin, BulkUpdateModelMixin, BulkDestroyModelMixin
from apps.utils.permission import ALL_PERMS, RbacPermission, get_user_perms_map
from apps.utils.queryset import get_child_queryset2
from apps.utils.serializers import PkSerializer, ComplexSerializer
from rest_framework.throttling import UserRateThrottle
from drf_yasg.utils import swagger_auto_schema
from apps.utils.decorators import idempotent
from django.db import transaction
import json
from rest_framework.generics import get_object_or_404
class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
@ -157,8 +159,8 @@ class CustomGenericViewSet(MyLoggingMixin, GenericViewSet):
return queryset.filter(create_by=self.request.user)
class CustomModelViewSet(CreateModelMixin, UpdateModelMixin, ListModelMixin,
RetrieveModelMixin, DestroyModelMixin, CustomGenericViewSet):
class CustomModelViewSet(BulkCreateModelMixin, BulkUpdateModelMixin, ListModelMixin,
RetrieveModelMixin, BulkDestroyModelMixin, CustomGenericViewSet):
"""
增强的ModelViewSet
"""
@ -173,41 +175,6 @@ class CustomModelViewSet(CreateModelMixin, UpdateModelMixin, ListModelMixin,
if v not in ALL_PERMS and v != '*':
ALL_PERMS.append(v)
# @idempotent()
# def create(self, request, *args, **kwargs):
# return super().create(request, *args, **kwargs)
@action(methods=['post'], detail=False, serializer_class=PkSerializer)
def deletes(self, request, *args, **kwargs):
"""
批量删除
"""
request_data = request.data
pks = request_data.get('pks', None)
if pks:
self.get_queryset().filter(id__in=pks).delete()
return Response(status=204)
else:
raise ValidationError(**PKS_ERROR)
@action(methods=['post'], detail=False, serializer_class=PkSerializer, permission_classes=[IsAdminUser])
def deletes_hard(self, request, *args, **kwargs):
"""批量物理删除
批量物理删除
"""
request_data = request.data
pks = request_data.get('pks', None)
if pks:
try:
self.get_queryset().model.objects.get_queryset(
all=True).filter(id__in=pks).delete(soft=False)
except Exception:
self.get_queryset().filter(id__in=pks).delete()
return Response(status=204)
else:
raise ValidationError(**PKS_ERROR)
@swagger_auto_schema(request_body=ComplexSerializer, responses={200: {}})
@action(methods=['post'], detail=False, perms_map={'post': '*'})
def cquery(self, request):