diff --git a/apps/utils/mixins.py b/apps/utils/mixins.py index c3df566c..66aa9b1d 100755 --- a/apps/utils/mixins.py +++ b/apps/utils/mixins.py @@ -19,6 +19,8 @@ from rest_framework.decorators import action from apps.utils.serializers import ComplexSerializer from django.db.models import F from django.db import transaction +from collections import defaultdict +from django.db import models # 实例化myLogger myLogger = logging.getLogger('log') @@ -214,17 +216,44 @@ class CustomListModelMixin(ListModelMixin): type=openapi.TYPE_STRING, required=False), ]) def list(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - + model = queryset.model page = self.paginate_queryset(queryset) if page is not None: - serializer = self.get_serializer(page, many=True) - data = self.add_info_for_list(serializer.data) - return self.get_paginated_response(data) - - serializer = self.get_serializer(queryset, many=True) + objs = page + else: + objs = queryset + # ===== 默认空映射(不支持 parent 的 model 也能正常返回) ===== + children_map = {} + count_map = {} + # ===== 只在 model 有 parent FK 时才构建 ===== + if self.request.query_params.get('with_children', 'no') in ['yes', 'count']: + has_parent = any( + f.name == 'parent' and isinstance(f, models.ForeignKey) + for f in model._meta.get_fields() + ) + if has_parent: + parent_ids = [obj.id for obj in objs] + children_map = defaultdict(list) + count_map = defaultdict(int) + if parent_ids: + children_qs = self.get_queryset().filter(parent_id__in=parent_ids) + for child in children_qs: + children_map[child.parent_id].append(child) + count_map[child.parent_id] += 1 + # ===== 序列化 ===== + serializer = self.get_serializer( + objs, + many=True, + context={ + 'request': request, + 'children_map': children_map, + 'count_map': count_map, + } + ) data = self.add_info_for_list(serializer.data) + if page is not None: + return self.get_paginated_response(data) return Response(data) def add_info_for_list(self, data): diff --git a/apps/utils/serializers.py b/apps/utils/serializers.py index 263ec5f2..d79aae1f 100755 --- a/apps/utils/serializers.py +++ b/apps/utils/serializers.py @@ -25,21 +25,20 @@ class TreeSerializerMixin: super().__init__(*args, **kwargs) request = self.context.get('request', None) self.with_children = request.query_params.get('with_children', 'no') if request else 'no' - if self.with_children in ['yes', 'count']: - if 'children' not in self.fields: - self.fields['children'] = serializers.SerializerMethodField() - if 'children_count' not in self.fields: - self.fields['children_count'] = serializers.SerializerMethodField() + if self.with_children in ('yes', 'count'): + self.fields.setdefault('children_count', serializers.SerializerMethodField()) + if self.with_children == 'yes': + self.fields.setdefault('children', serializers.SerializerMethodField()) def get_children(self, obj): if hasattr(obj, 'parent') and self.with_children == 'yes': - serializer_class = self.__class__ - return serializer_class(obj.__class__.objects.filter(parent=obj), many=True, context=self.context).data + children = self.context.get('children_map', {}).get(obj.id, []) + return self.__class__(children, many=True, context=self.context).data return [] def get_children_count(self, obj): if hasattr(obj, 'parent') and self.with_children in ['yes', 'count']: - return obj.__class__.objects.filter(parent=obj).count() + return self.context.get('count_map', {}).get(obj.id, 0) return 0 class CustomModelSerializer(DynamicFieldsMixin, TreeSerializerMixin, serializers.ModelSerializer):