from rest_framework import generics
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework.exceptions import PermissionDenied

from .models import User
from .serializers import UserSerializer, UserCreateSerializer, CustomTokenObtainPairSerializer

# Value of ``User.role`` that grants administrative rights in these views.
_ADMIN_ROLE = 'admin'

# Fields a non-admin user may change on their own record. Any other field in
# the validated payload is silently dropped in UserDetailView.perform_update.
_SELF_EDITABLE_FIELDS = frozenset({'first_name', 'last_name', 'email', 'phone'})


def _is_admin(user):
    """Return True if *user* has the admin role.

    ``getattr`` is used so that a user object without a ``role`` attribute is
    treated as a non-admin instead of raising ``AttributeError``.
    """
    return getattr(user, 'role', None) == _ADMIN_ROLE


class CustomTokenObtainPairView(TokenObtainPairView):
    """Issue a JWT token pair using the project's custom token serializer.

    Open to anyone (``AllowAny``): callers are not yet authenticated when
    they request a token.
    """
    serializer_class = CustomTokenObtainPairSerializer
    permission_classes = [AllowAny]


class UserListView(generics.ListCreateAPIView):
    """List users (GET) and create users (POST).

    Admins see every user; any other authenticated user sees only their own
    record. Only admins may create users.
    """
    serializer_class = UserSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return all users for admins, otherwise only the requesting user."""
        if _is_admin(self.request.user):
            return User.objects.all()
        return User.objects.filter(id=self.request.user.id)

    def get_serializer_class(self):
        """Use ``UserCreateSerializer`` for POST and ``UserSerializer`` otherwise."""
        if self.request.method == 'POST':
            return UserCreateSerializer
        return UserSerializer

    def create(self, request, *args, **kwargs):
        """Reject non-admins, then delegate to the standard create flow.

        The check runs *before* the payload is validated. Non-admins
        therefore always get a 403 and cannot use validation errors
        (e.g. uniqueness failures, if the serializer reports them) to probe
        existing accounts.

        Raises:
            PermissionDenied: if the requesting user is not an admin.
        """
        if not _is_admin(request.user):
            raise PermissionDenied("只有管理员可以创建用户")
        return super().create(request, *args, **kwargs)

    def perform_create(self, serializer):
        """Save a new user; only admins may do this.

        The admin check is repeated here as defense in depth, in case this
        hook is reached without going through ``create()``.

        Raises:
            PermissionDenied: if the requesting user is not an admin.
        """
        if not _is_admin(self.request.user):
            raise PermissionDenied("只有管理员可以创建用户")
        serializer.save()


class UserDetailView(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update, or delete a single user.

    Non-admins can only reach their own record and may change only
    first_name, last_name, email and phone. Only admins may delete users.
    """
    # Kept so code that reads the class attribute (e.g. router basename
    # inference) still works; get_queryset() applies per-user restriction.
    queryset = User.objects.all()
    serializer_class = UserSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return all users for admins, otherwise only the requesting user.

        This mirrors UserListView. Without it, any authenticated user could
        retrieve any other user's record by id, because the lookup ran
        against ``User.objects.all()``.
        """
        user = self.request.user
        if _is_admin(user):
            return User.objects.all()
        return User.objects.filter(id=user.id)

    def perform_update(self, serializer):
        """Save an update, restricting what non-admins may change.

        Admins may update any field of any user. A non-admin may only
        update their own record, and any validated field outside the
        self-editable whitelist is dropped before saving.

        Raises:
            PermissionDenied: if a non-admin targets another user's record.
        """
        user = self.request.user
        if _is_admin(user):
            serializer.save()
            return
        # Defense in depth: get_queryset() already hides other users' records.
        # serializer.instance is the object update() loaded; reusing it avoids
        # a second get_object() query.
        if user.id != serializer.instance.id:
            raise PermissionDenied("无权修改其他用户信息")
        # Strip every field a regular user is not allowed to change.
        for field in set(serializer.validated_data) - _SELF_EDITABLE_FIELDS:
            serializer.validated_data.pop(field)
        serializer.save()

    def perform_destroy(self, instance):
        """Delete *instance*; only admins may delete users.

        Raises:
            PermissionDenied: if the requesting user is not an admin.
        """
        if not _is_admin(self.request.user):
            raise PermissionDenied("只有管理员可以删除用户")
        instance.delete()


@api_view(['GET'])
@permission_classes([IsAuthenticated])
def current_user(request):
    """Return the requesting user serialized with ``UserSerializer``."""
    serializer = UserSerializer(request.user)
    return Response(serializer.data)