factory/apps/utils/mixins.py

514 lines
18 KiB
Python
Executable File

import uuid
from rest_framework.mixins import CreateModelMixin, UpdateModelMixin, DestroyModelMixin, ListModelMixin, RetrieveModelMixin
import ast
import ipaddress
import traceback
from apps.monitor.models import DrfRequestLog
from django.db import connection
from django.utils.timezone import now
from user_agents import parse
import logging
from rest_framework.response import Response
from rest_framework.exceptions import ParseError, ValidationError
from apps.utils.errors import PKS_ERROR
from rest_framework.generics import get_object_or_404
from drf_yasg.utils import swagger_auto_schema
from drf_yasg import openapi
from apps.utils.serializers import PkSerializer
from rest_framework.decorators import action
from apps.utils.serializers import ComplexSerializer
from django.db.models import F
from django.db import transaction
# Module-level logger instance, bound to the logger named 'log'.
myLogger = logging.getLogger('log')
class CreateUpdateModelAMixin:
    """Audit mixin for business base table "A".

    Records the requesting user as ``create_by`` when an object is created
    and as ``update_by`` when it is updated.
    """

    def perform_create(self, serializer):
        """Save a new object with the current user as its creator."""
        current_user = self.request.user
        serializer.save(create_by=current_user)

    def perform_update(self, serializer):
        """Save changes with the current user as the last editor."""
        current_user = self.request.user
        serializer.save(update_by=current_user)
class CreateUpdateModelBMixin:
    """Audit mixin for business base table "B".

    Like the "A" variant, but creation additionally stores the requesting
    user's department as ``belong_dept``.
    """

    def perform_create(self, serializer):
        """Save a new object with the current user and that user's department."""
        current_user = self.request.user
        serializer.save(create_by=current_user, belong_dept=current_user.dept)

    def perform_update(self, serializer):
        """Save changes with the current user as the last editor."""
        current_user = self.request.user
        serializer.save(update_by=current_user)
class CreateUpdateCustomMixin:
    """Combined audit mixin.

    Always records ``create_by`` / ``update_by``; ``belong_dept`` is only
    filled in when the view's queryset model has such an attribute.
    """

    def perform_create(self, serializer):
        """Save a new object, adding the department when the model supports it."""
        model_has_dept = hasattr(self.queryset.model, 'belong_dept')
        current_user = self.request.user
        if not model_has_dept:
            serializer.save(create_by=current_user)
            return
        serializer.save(create_by=current_user, belong_dept=current_user.dept)

    def perform_update(self, serializer):
        """Save changes with the current user as the last editor."""
        current_user = self.request.user
        serializer.save(update_by=current_user)
class CustomCreateModelMixin(CreateModelMixin):
    """Create mixin that stamps the creator (and department, if applicable)."""

    def perform_create(self, serializer):
        """Save with ``create_by``; add ``belong_dept`` when the model has it."""
        creator = self.request.user
        save_kwargs = {'create_by': creator}
        if hasattr(self.queryset.model, 'belong_dept'):
            # Only models exposing ``belong_dept`` receive the user's department.
            save_kwargs['belong_dept'] = creator.dept
        serializer.save(**save_kwargs)
class CustomUpdateModelMixin(UpdateModelMixin):
    """Update mixin that stamps the editing user."""

    def perform_update(self, serializer):
        """Save changes with the current user recorded as ``update_by``."""
        editor = self.request.user
        serializer.save(update_by=editor)
class BulkCreateModelMixin(CreateModelMixin):
    """Create mixin that accepts either a single object or a list of objects."""

    def after_bulk_create(self, objs):
        """Hook called after a list payload has been saved.

        Receives the serializer's ``data`` for the created objects.
        Does nothing by default.
        """

    @transaction.atomic
    def create(self, request, *args, **kwargs):
        """Create (bulk supported)

        Create (bulk supported)
        """
        payload = request.data
        # A list body switches the serializer into many=True mode.
        is_bulk = isinstance(payload, list)
        serializer = self.get_serializer(data=payload, many=is_bulk)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        if is_bulk:
            self.after_bulk_create(serializer.data)
        return Response(serializer.data, status=201)
class BulkUpdateModelMixin(UpdateModelMixin):
    """Update mixin that also supports updating several objects at once.

    When the URL lookup value is the literal ``bulk`` the request body must be
    a list of objects, each carrying the ``id`` of the row to update.
    Any other lookup value is handled as a regular single-object update.
    """

    def after_bulk_update(self, objs):
        """Hook called after all items of a bulk update have been saved.

        Receives the list of serialized (``serializer.data``) results.
        Does nothing by default.
        """

    @transaction.atomic
    def partial_update(self, request, *args, **kwargs):
        """Partial update (bulk supported)

        Partial update (bulk supported)
        """
        kwargs['partial'] = True
        return self.update(request, *args, **kwargs)

    @transaction.atomic
    def update(self, request, *args, **kwargs):
        """Update (bulk supported)

        Update (bulk supported)

        Raises:
            ParseError: bulk mode was requested but the body is not a list.
            ValidationError: a bulk item is not an object, has no ``id``, or
                fails serializer validation. Errors are keyed by the item's
                1-based position in the submitted list.
        """
        partial = kwargs.pop('partial', False)
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

        if kwargs.get(lookup_url_kwarg) != 'bulk':
            # Regular single-object update.
            instance = self.get_object()
            serializer = self.get_serializer(instance, data=request.data, partial=partial)
            serializer.is_valid(raise_exception=True)
            self.perform_update(serializer)
            if getattr(instance, '_prefetched_objects_cache', None):
                # Same as DRF's stock UpdateModelMixin: drop stale prefetched
                # relations so the response reflects the saved state.
                instance._prefetched_objects_cache = {}
            return Response(serializer.data)

        # Bulk update.
        if not isinstance(request.data, list):
            raise ParseError('提交数据非列表')
        queryset = self.filter_queryset(self.get_queryset())
        objs = []
        for ind, item in enumerate(request.data):
            position = f'{ind+1}'
            if not isinstance(item, dict) or 'id' not in item:
                # Report malformed items as a client error instead of letting
                # a KeyError/TypeError surface as a server error.
                raise ValidationError({position: {'id': ['缺少id']}})
            obj = get_object_or_404(queryset, id=item['id'])
            sr = self.get_serializer(obj, data=item, partial=partial)
            if not sr.is_valid():
                raise ValidationError({position: sr.errors})
            # Go through perform_update so subclasses' extra work still runs.
            self.perform_update(sr)
            objs.append(sr.data)
        self.after_bulk_update(objs)
        return Response(objs)
class BulkDestroyModelMixin(DestroyModelMixin):
    """Destroy mixin that also supports deleting several objects at once.

    When the URL lookup value is the literal ``bulk`` the request body is
    expected to carry ``ids`` (the primary keys to delete) and optionally
    ``soft`` (boolean, defaults to ``True``).
    """

    @swagger_auto_schema(request_body=PkSerializer)
    @transaction.atomic
    def destroy(self, request, *args, **kwargs):
        """Delete (bulk supported)

        Delete (supports bulk deletion and hard deletion, which requires administrator rights)

        Raises:
            ParseError: ``soft`` is not a boolean, a non-superuser asked for a
                hard delete, or the single object still has children.
            ValidationError: bulk mode without a usable ``ids`` value.
        """
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

        if kwargs.get(lookup_url_kwarg) != 'bulk':
            # Regular single-object delete.
            instance = self.get_object()
            # Objects exposing ``parent`` must not be deleted while children reference them.
            if hasattr(instance, "parent") and instance.__class__.objects.filter(parent=instance).exists():
                raise ParseError('存在子级,无法删除')
            self.perform_destroy(instance)
            return Response(status=204)

        # Bulk delete.
        queryset = self.filter_queryset(self.get_queryset())
        data = request.data
        if not hasattr(data, 'get'):
            # e.g. a JSON list body: there is no place to read ``ids`` from.
            raise ValidationError(**PKS_ERROR)
        ids = data.get('ids', None)
        soft = data.get('soft', True)
        if not isinstance(soft, bool):
            # Values such as "false" previously slipped past the permission
            # check and both delete branches, answering 204 without deleting.
            raise ParseError('soft参数必须为布尔值')
        if not soft and not request.user.is_superuser:
            raise ParseError('非管理员不支持物理删除')
        if not ids:
            raise ValidationError(**PKS_ERROR)

        if soft:
            queryset.filter(id__in=ids).delete()
        else:
            try:
                queryset.model.objects.get_queryset(
                    all=True).filter(id__in=ids).delete(soft=False)
            except Exception:
                # Best-effort fallback kept from the original implementation,
                # presumably for models whose manager/queryset does not take
                # the ``all`` / ``soft`` arguments.
                # NOTE(review): this also hides genuine delete errors — confirm
                # whether the except clause can be narrowed.
                queryset.filter(id__in=ids).delete()
        return Response(status=204)
class CustomRetrieveModelMixin(RetrieveModelMixin):
    """Retrieve mixin with a hook for enriching the serialized object."""

    def retrieve(self, request, *args, **kwargs):
        """Return the serialized object after passing it through the hook."""
        obj = self.get_object()
        serialized = self.get_serializer(obj).data
        return Response(self.add_info_for_item(serialized))

    def add_info_for_item(self, data):
        """Add extra information to the returned dict

        Add extra information to the returned dict
        """
        return data
class CustomListModelMixin(ListModelMixin):
    """List mixin with a hook for enriching the serialized list."""

    @swagger_auto_schema(manual_parameters=[
        openapi.Parameter(name="query", in_=openapi.IN_QUERY, description="定制返回数据",
                          type=openapi.TYPE_STRING, required=False),
        openapi.Parameter(name="with_children", in_=openapi.IN_QUERY, description="带有children(yes/no/count)",
                          type=openapi.TYPE_STRING, required=False),
    ])
    def list(self, request, *args, **kwargs):
        """Return the filtered (and, when enabled, paginated) list."""
        filtered = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(filtered)
        paginated = page is not None
        # Serialize the page when pagination is active, the whole queryset otherwise.
        serializer = self.get_serializer(page if paginated else filtered, many=True)
        data = self.add_info_for_list(serializer.data)
        if paginated:
            return self.get_paginated_response(data)
        return Response(data)

    def add_info_for_list(self, data):
        """Add extra information to the returned list

        Add extra information to the returned list
        """
        return data
class ComplexQueryMixin:
    """Complex query

    Adds a ``cquery`` POST action that builds a queryset from client-supplied
    condition groups: conditions inside a group are chained (AND), the groups
    themselves are combined with ``|`` (OR).
    """

    @swagger_auto_schema(request_body=ComplexSerializer, responses={200: {}})
    @action(methods=['post'], detail=False, perms_map={'post': '*'})
    def cquery(self, request):
        """Complex query

        Complex query

        Raises:
            ParseError: any error raised while applying the submitted
                conditions (its message is passed on to the client).
        """
        sr = ComplexSerializer(data=request.data)
        sr.is_valid(raise_exception=True)
        vdata = sr.validated_data
        queryset = self.get_queryset()
        querys = vdata.get('querys', [])
        annotate_field_list = vdata.get('annotate_field_list', [])

        if not querys:
            new_qs = queryset
        else:
            new_qs = queryset.none()
            try:
                for group in querys:
                    one_qs = queryset
                    for cond in group:
                        # NOTE(review): field names and lookups come straight
                        # from the client — confirm ComplexSerializer (or the
                        # view) restricts which fields may be queried.
                        compare = cond['compare']
                        if compare == '!':  # exclusion condition
                            one_qs = one_qs.exclude(**{cond['field']: cond['value']})
                        elif compare == '':  # plain equality
                            one_qs = one_qs.filter(**{cond['field']: cond['value']})
                        else:  # any other value is used as the lookup suffix
                            lookup = cond['field'] + '__' + compare
                            one_qs = one_qs.filter(**{lookup: cond['value']})
                    new_qs = new_qs | one_qs
            except Exception as e:
                # Keep the original exception as the cause for server-side debugging.
                raise ParseError(str(e)) from e

        if annotate_field_list:
            annotate_dict = getattr(self, "annotate_dict", {})
            if annotate_dict:
                # Only annotations the view declares in ``annotate_dict`` can be requested.
                requested = {key: annotate_dict[key] for key in annotate_field_list if key in annotate_dict}
                new_qs = new_qs.annotate(**requested)

        # Request ordering wins; otherwise fall back to the view's ``ordering``.
        ordering = vdata.get('ordering', None) or getattr(self, 'ordering', None)
        if isinstance(ordering, str):
            ordering = ordering.replace('\n', '').replace(' ', '').split(',')

        order_fields = []
        for item in ordering or []:
            if not item:
                # Skip empty tokens (e.g. from "a,,b" or a trailing comma);
                # they would otherwise become F('') and fail at query time.
                continue
            if '__' in item:
                # JSONField ordering only works with plain strings, so paths
                # containing '__' are passed through unchanged instead of via F().
                order_fields.append(item)
            elif item.startswith('-'):
                order_fields.append(F(item[1:]).desc(nulls_last=True))
            else:
                order_fields.append(F(item).asc(nulls_last=True))
        if order_fields:
            new_qs = new_qs.order_by(*order_fields)

        page = self.paginate_queryset(new_qs)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(new_qs, many=True)
        return Response(serializer.data)
class MyLoggingMixin(object):
    """Mixin to log requests.

    Collects request/response metadata into ``self.log`` across the view's
    ``initial`` / ``handle_exception`` / ``finalize_response`` hooks and hands
    it to :meth:`handle_log` once the response is final.
    """

    # Value written in place of anything considered sensitive.
    CLEANED_SUBSTITUTE = "********************"
    # HTTP methods to log; '__all__' makes should_log() accept every request.
    logging_methods = '__all__'
    # Extra sensitive field names; a view may override this with a set,
    # e.g. sensitive_fields = {'field1', 'field2'}.
    sensitive_fields = {}

    def __init__(self, *args, **kwargs):
        assert isinstance(
            self.CLEANED_SUBSTITUTE, str
        ), "CLEANED_SUBSTITUTE must be a string."
        super().__init__(*args, **kwargs)

    def initial(self, request, *args, **kwargs):
        """Start the log entry and capture the (cleaned) request payload."""
        request_id = uuid.uuid4()
        self.log = {"requested_at": now(), "id": request_id}
        setattr(request, 'request_id', request_id)
        if not getattr(self, "decode_request_body", False):
            self.log["data"] = ""
        else:
            self.log["data"] = self._clean_data(request.body)
        super().initial(request, *args, **kwargs)
        try:
            # Accessing request.data *for the first time* parses the request body, which may raise
            # ParseError and UnsupportedMediaType exceptions. It's important not to swallow these,
            # as (depending on implementation details) they may only get raised this once, and
            # DRF logic needs them to be raised by the view for error handling to work correctly.
            data = self.request.data.dict()
        except AttributeError:
            # Payloads without a .dict() method are logged as they are.
            data = self.request.data
        self.log["data"] = self._clean_data(data)

    def handle_exception(self, exc):
        """Record the formatted traceback of an exception the view handled."""
        response = super().handle_exception(exc)
        self.log["errors"] = traceback.format_exc()
        return response

    def finalize_response(self, request, response, *args, **kwargs):
        """Complete the log entry and pass it to :meth:`handle_log`."""
        response = super().finalize_response(
            request, response, *args, **kwargs
        )
        self.log["response_ms"] = self._get_response_ms()
        # Ensure backward compatibility for those using _should_log hook
        should_log = (
            self._should_log if hasattr(self, "_should_log") else self.should_log
        )
        if should_log(request, response):
            if (connection.settings_dict.get("ATOMIC_REQUESTS") and
                    getattr(response, "exception", None) and connection.in_atomic_block):
                # response with exception (HTTP status like: 401, 404, etc)
                # pointwise disable atomic block for handle log (TransactionManagementError)
                connection.set_rollback(True)
                connection.set_rollback(False)
            if response.streaming:
                rendered_content = None
            elif hasattr(response, "rendered_content"):
                rendered_content = response.rendered_content
            else:
                rendered_content = response.getvalue()
            self.log.update(
                {
                    "remote_addr": self._get_ip_address(request),
                    "view": self._get_view_name(request),
                    "view_method": self._get_view_method(request),
                    "path": self._get_path(request),
                    "host": request.get_host(),
                    "method": request.method,
                    "query_params": self._clean_data(request.query_params.dict()),
                    "user": self._get_user(request),
                    "response": self._clean_data(rendered_content),
                    "status_code": response.status_code,
                    "agent": self._get_agent(request),
                }
            )
            try:
                self.handle_log()
            except Exception:
                # ensure that all exceptions raised by handle_log
                # doesn't prevent API call to continue as expected
                myLogger.exception("Logging API call raise exception!")
        return response

    def handle_log(self):
        """
        Hook to define what happens with the log.
        Defaults on saving the data on the db.
        """
        DrfRequestLog(**self.log).save()

    def _get_path(self, request):
        """Get the request path."""
        return request.path

    def _get_ip_address(self, request):
        """Get the remote ip address the request was generated from."""
        ipaddr = request.META.get("HTTP_X_FORWARDED_FOR", None)
        if ipaddr:
            # The header may list several hops; the first entry is used.
            ipaddr = ipaddr.split(",")[0]
        else:
            ipaddr = request.META.get("REMOTE_ADDR", "")
        # Surrounding whitespace would make an otherwise valid address unparsable.
        ipaddr = ipaddr.strip()
        # Account for IPv4 and IPv6 addresses, each possibly with port appended. Possibilities are:
        # <ipv4 address>
        # <ipv6 address>
        # <ipv4 address>:port
        # [<ipv6 address>]:port
        # Note that ipv6 addresses are colon separated hex numbers
        possibles = (ipaddr.lstrip("[").split("]")[0], ipaddr.split(":")[0])
        for addr in possibles:
            try:
                return str(ipaddress.ip_address(addr))
            except ValueError:
                pass
        # Nothing parsed as an IP address: fall back to the raw value.
        return ipaddr

    def _get_view_name(self, request):
        """Get view name as "<module>.<class>", or None if there is no handler."""
        method = request.method.lower()
        try:
            attributes = getattr(self, method)
            return (
                type(attributes.__self__).__module__ + "." + type(attributes.__self__).__name__
            )
        except AttributeError:
            return None

    def _get_view_method(self, request):
        """Get view method (the viewset action when available)."""
        if hasattr(self, "action"):
            return self.action or None
        return request.method.lower()

    def _get_user(self, request):
        """Get user, or None for anonymous requests."""
        user = request.user
        if user.is_anonymous:
            return None
        return user

    def _get_agent(self, request):
        """Get the parsed user-agent description as a string."""
        # The header is optional; a missing one must not break the response.
        return str(parse(request.META.get('HTTP_USER_AGENT', '')))

    def _get_response_ms(self):
        """
        Get the duration of the request response cycle is milliseconds.
        In case of negative duration 0 is returned.
        """
        response_timedelta = now() - self.log["requested_at"]
        response_ms = int(response_timedelta.total_seconds() * 1000)
        return max(response_ms, 0)

    def should_log(self, request, response):
        """
        Method that should return a value that evaluated to True if the request should be logged.
        Logs when logging_methods is "__all__", when the status is 400 or above 404,
        when the method is in logging_methods and the status is not 401/403/404,
        or when the response took longer than 2000 ms.
        """
        return self.logging_methods == "__all__" or response.status_code > 404 or response.status_code == 400 \
            or (request.method in self.logging_methods and response.status_code not in [401, 403, 404])\
            or (self.log.get("response_ms", 0) > 2000)

    def _clean_data(self, data):
        """
        Clean a dictionary of data of potentially sensitive info before
        sending to the database.
        Function based on the "_clean_credentials" function of django
        (https://github.com/django/django/blob/stable/1.11.x/django/contrib/auth/__init__.py#L50)
        Fields defined by django are by default cleaned with this function
        You can define your own sensitive fields in your view by defining a set
        eg: sensitive_fields = {'field1', 'field2'}

        Bytes are decoded, lists are cleaned element by element, dicts are
        copied and cleaned recursively; any other value is returned unchanged.
        """
        if isinstance(data, bytes):
            data = data.decode(errors="replace")
        if isinstance(data, list):
            return [self._clean_data(d) for d in data]
        if isinstance(data, dict):
            SENSITIVE_FIELDS = {
                "api",
                "token",
                "key",
                "secret",
                "password",
                "signature",
            }
            data = dict(data)
            if self.sensitive_fields:
                SENSITIVE_FIELDS = SENSITIVE_FIELDS | {
                    field.lower() for field in self.sensitive_fields
                }
            for key, value in data.items():
                try:
                    # Values holding a Python-literal string (e.g. "{'a': 1}")
                    # are expanded so nested sensitive fields get cleaned too.
                    value = ast.literal_eval(value)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    # literal_eval may raise any of these on malformed input;
                    # such values are simply kept as they are.
                    pass
                if isinstance(value, (list, dict)):
                    data[key] = self._clean_data(value)
                # Keys of literal-evaluated dicts are not necessarily strings.
                if str(key).lower() in SENSITIVE_FIELDS:
                    data[key] = self.CLEANED_SUBSTITUTE
        return data