examtest/test_server/utils/custom.py

109 lines
3.5 KiB
Python

import logging
from rest_framework import serializers
from rest_framework.generics import ListAPIView
from rest_framework.pagination import PageNumberPagination,_positive_int,InvalidPage
from rest_framework.response import Response
from rest_framework.views import exception_handler
from django.core.paginator import InvalidPage
from rest_framework.exceptions import NotFound
# Module-level loggers registered under the names 'error' and 'info'.
# NOTE(review): handlers and levels for these names are presumably defined in
# the project's Django LOGGING setting -- confirm before relying on them.
error_logger = logging.getLogger('error')
info_logger = logging.getLogger('info')
class CommonPagination(PageNumberPagination):
    """
    Pagination settings.

    Page-number pagination that accepts its paging parameters from the
    request body as well as from the query string, so POST requests can be
    paginated too. For both the page number and the page size the query
    string is consulted first and ``request.data`` second.
    """

    page_size = 10
    page_size_query_param = 'limit'

    @staticmethod
    def _get_request_param(request, name, default=None):
        """
        Look up ``name`` in the query string, then in the request body.

        Returns ``default`` when the parameter is present in neither place.
        """
        if name in request.query_params:
            return request.query_params.get(name)
        body = request.data
        # The parsed body is not always a mapping (a JSON array parses to a
        # list), so only treat it as one when it actually offers ``get``.
        if hasattr(body, 'get') and name in body:
            return body.get(name)
        return default

    def paginate_queryset(self, queryset, request, view=None):
        """
        Override so that POST requests can be paginated as well.

        Returns the requested page as a list, or ``None`` when pagination is
        disabled (no usable page size).

        Raises:
            NotFound: if the requested page number is not valid.
        """
        page_size = self.get_page_size(request)
        if not page_size:
            return None

        paginator = self.django_paginator_class(queryset, page_size)

        # Default to the first page when no page number was supplied.
        page_number = self._get_request_param(
            request, self.page_query_param, default=1
        )
        if page_number in self.last_page_strings:
            page_number = paginator.num_pages

        try:
            self.page = paginator.page(page_number)
        except InvalidPage as exc:
            msg = self.invalid_page_message.format(
                page_number=page_number, message=str(exc)
            )
            raise NotFound(msg) from exc

        if paginator.num_pages > 1 and self.template is not None:
            # The browsable API should display pagination controls.
            self.display_page_controls = True

        self.request = request
        return list(self.page)

    def get_page_size(self, request):
        """
        Return the page size requested by the client, else ``self.page_size``.

        The client value is read from ``page_size_query_param`` and is
        capped at ``self.max_page_size`` when that is set. A missing,
        non-numeric, zero or negative value falls back to ``self.page_size``.
        """
        if self.page_size_query_param:
            requested = self._get_request_param(
                request, self.page_size_query_param
            )
            try:
                return _positive_int(
                    requested,
                    strict=True,
                    cutoff=self.max_page_size
                )
            except (TypeError, ValueError):
                # TypeError: parameter absent (None) or not int-convertible;
                # ValueError: not a number, or not strictly positive.
                pass
        return self.page_size
class TreeSerializer(serializers.Serializer):
    """
    Serializer for one node of a tree.

    Emits three keys: ``id``, ``label`` (taken from the instance's ``name``
    attribute) and ``pid`` (the related object rendered as its primary key).
    """

    id = serializers.IntegerField()
    # Exposed as "label" in the output but read from the "name" attribute.
    label = serializers.CharField(source='name', max_length=20)
    pid = serializers.PrimaryKeyRelatedField(read_only=True)
class TreeAPIView(ListAPIView):
    """
    Custom tree-structure view.

    Serializes the filtered queryset and nests the rows into a tree: rows
    with a falsy ``pid`` become roots, every other row is appended to the
    ``children`` list of the row whose ``id`` equals its ``pid``.
    """

    serializer_class = TreeSerializer

    def list(self, request, *args, **kwargs):
        """
        Return the rows as a tree, or as a flat list if no tree can be built.

        The flat fallback is used when a row lacks an ``id``/``pid`` key or
        when a ``pid`` refers to a row that is not part of the result set.
        """
        queryset = self.filter_queryset(self.get_queryset())
        # NOTE(review): the page is only used to choose the response
        # envelope; the serializer below is fed the whole queryset, so a
        # paginated response still carries the complete tree -- confirm
        # this is intended.
        page = self.paginate_queryset(queryset)
        serializer = self.get_serializer(queryset, many=True)

        rows = serializer.data
        results = self._build_tree(rows)
        if results is None:
            results = rows

        if page is not None:
            return self.get_paginated_response(results)
        return Response(results)

    @staticmethod
    def _build_tree(rows):
        """
        Nest ``rows`` under their parents and return the list of root rows.

        Returns ``None``, leaving every row untouched, when the rows cannot
        form a tree. All parent links are checked before any ``children``
        list is created, so a failed build never leaves rows half-nested.
        """
        try:
            nodes = {row['id']: row for row in rows}
            parent_ids = {
                node_id: node['pid'] for node_id, node in nodes.items()
            }
        except KeyError:
            return None

        # Validate first: a dangling parent reference means no tree at all.
        for parent_id in parent_ids.values():
            if parent_id and parent_id not in nodes:
                return None

        roots = []
        for node_id, node in nodes.items():
            parent_id = parent_ids[node_id]
            if parent_id:
                nodes[parent_id].setdefault('children', []).append(node)
            else:
                roots.append(node)
        return roots