185 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			185 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
from django.shortcuts import render
 | 
						|
from apps.utils.viewsets import CustomModelViewSet
 | 
						|
from rest_framework.decorators import action
 | 
						|
from rest_framework.response import Response
 | 
						|
from apps.bi.models import Dataset
 | 
						|
from apps.bi.serializers import DatasetSerializer, DatasetCreateUpdateSerializer, DataExecSerializer
 | 
						|
from django.apps import apps
 | 
						|
from rest_framework import serializers
 | 
						|
import concurrent.futures
 | 
						|
from django.core.cache import cache
 | 
						|
from apps.utils.sql import execute_raw_sql, format_sqldata
 | 
						|
from apps.bi.services import check_sql_safe, format_json_with_placeholders
 | 
						|
from rest_framework.exceptions import ParseError
 | 
						|
from rest_framework.generics import get_object_or_404
 | 
						|
# Create your views here.
 | 
						|
 | 
						|
 | 
						|
class DatasetViewSet(CustomModelViewSet):
    """CRUD endpoints for BI datasets, plus SQL execution (`exec`) and a model/field listing (`base`)."""
    queryset = Dataset.objects.all()
    serializer_class = DatasetSerializer
    create_serializer_class = DatasetCreateUpdateSerializer
    update_serializer_class = DatasetCreateUpdateSerializer
    search_fields = ['name', 'code']
    ordering = ['name', 'code', 'id']

    def get_object(self):
        """
        Return the dataset addressed by the URL, looked up by primary key or by ``code``.

        The URL keyword value is first matched against ``self.lookup_field``.
        If that lookup raises ``Http404`` (DRF's ``get_object_or_404`` also
        converts wrong-type values, e.g. a non-numeric code passed as pk, into
        ``Http404``), the same value is retried against the ``code`` field.

        Raises:
            Http404: if neither the lookup field nor ``code`` matches.
            PermissionDenied: raised by ``check_object_permissions``.
        """
        from django.http import Http404

        queryset = self.filter_queryset(self.get_queryset())

        # Perform the lookup filtering.
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

        assert lookup_url_kwarg in self.kwargs, (
            'Expected view %s to be called with a URL keyword argument '
            'named "%s". Fix your URL conf, or set the `.lookup_field` '
            'attribute on the view correctly.' %
            (self.__class__.__name__, lookup_url_kwarg)
        )

        lookup_value = self.kwargs[lookup_url_kwarg]
        try:
            obj = get_object_or_404(queryset, **{self.lookup_field: lookup_value})
        except Http404:
            # Only "not found" falls back to the code lookup; any other error
            # (DB failure, programming error) propagates instead of being hidden.
            obj = get_object_or_404(queryset, code=lookup_value)

        # May raise a permission denied
        self.check_object_permissions(self.request, obj)
        return obj

    @staticmethod
    def _exec_cache_key(dataset_id, sql):
        """Build a process-independent cache key for one dataset's rendered SQL.

        ``hash()`` of a str is randomized per interpreter process, so it cannot
        be shared through an external cache; a sha256 digest can. The dataset id
        is included because the cached value is the whole dataset response, not
        just the SQL result.
        """
        import hashlib

        digest = hashlib.sha256(sql.encode('utf-8')).hexdigest()
        return f'bi_dataset_exec_{dataset_id}_{digest}'

    @staticmethod
    def _run_statements(statements, raise_exception):
        """Execute SQL statements concurrently and collect formatted results.

        Each non-blank statement is run with ``execute_raw_sql`` in a thread pool
        and its output passed through ``format_sqldata``. Results are keyed
        ``ds<index>`` by the statement's position in ``statements``; blank
        statements are skipped but keep their index so later names stay stable.

        Returns:
            (results, results2, can_cache): the two dicts produced from
            ``format_sqldata``'s pair of return values, and False if any
            statement failed (error text is then stored in ``results``).

        Raises:
            ParseError: on the first failing statement when ``raise_exception``.
        """
        results = {}
        results2 = {}
        can_cache = True
        with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
            futures = {
                executor.submit(execute_raw_sql, statement): f'ds{ind}'
                for ind, statement in enumerate(statements)
                if statement.strip()
            }
            for future in concurrent.futures.as_completed(futures):
                name = futures[future]
                try:
                    res = future.result()
                    results[name], results2[name] = format_sqldata(res[0], res[1])
                except Exception as e:
                    if raise_exception:
                        raise ParseError(f'查询异常:{str(e)}') from e
                    results[name] = 'error: ' + str(e)
                    can_cache = False
        return results, results2, can_cache

    @action(methods=['post'], detail=True, perms_map={'post': 'dataset.exec'}, serializer_class=DataExecSerializer, cache_seconds=0, logging_methods=[])
    def exec(self, request, pk=None):
        """Execute the dataset's SQL query.

        The dataset can be addressed by id or by code (see ``get_object``).

        Request body keys read here:
            query: dict of format parameters merged over ``default_param``
                (``r_user`` / ``r_dept`` are always overwritten server-side).
            is_test: when truthy, ``test_param`` is merged instead of ``query``.
            raise_exception: when truthy (default), the first failing
                statement aborts the request with ``ParseError``; otherwise
                the error text is returned in ``data`` and nothing is cached.

        The response is the serialized dataset plus ``data`` / ``data2``
        (per-statement results keyed ``ds<index>``). ``echart_options`` has
        its placeholders filled from ``data`` unless it is a JS ``function``
        string.
        """
        dt: Dataset = self.get_object()
        rdata = DatasetSerializer(instance=dt).data
        xquery = request.data.get('query') or {}
        if not isinstance(xquery, dict):
            raise ParseError('查询参数query需为字典')
        # Copy so that the request payload is not mutated.
        xquery = dict(xquery)
        is_test = request.data.get('is_test', False)
        raise_exception = request.data.get('raise_exception', True)
        xquery['r_user'] = request.user.id
        xquery['r_dept'] = request.user.belong_dept.id if request.user.belong_dept else ''
        results = {}
        results2 = {}
        can_cache = True
        cache_key = None
        # Copy so per-request parameters never leak into the model instance.
        query = dict(dt.default_param or {})
        if dt.sql_query:
            query.update((dt.test_param or {}) if is_test else xquery)
            try:
                # NOTE(review): values are interpolated textually via str.format;
                # protection against SQL injection relies on check_sql_safe.
                sql_f_ = check_sql_safe(dt.sql_query.format(**query))
            except KeyError as e:
                raise ParseError(f'需指定查询参数_{str(e)}') from e
            # Strip surrounding whitespace first so a trailing "; \n" does not
            # leave an empty statement behind.
            sql_f_strip = sql_f_.strip().strip(';')
            cache_key = self._exec_cache_key(dt.pk, sql_f_strip)
            cached = cache.get(cache_key, None)
            if cached:
                return Response(cached)
            results, results2, can_cache = self._run_statements(
                sql_f_strip.split(';'), raise_exception)
        rdata['data'] = results
        rdata['data2'] = results2
        echart_options = rdata['echart_options']
        if echart_options and not echart_options.startswith('function'):
            # A failed statement (stored as an error string) cannot fill placeholders.
            for value in results.values():
                if isinstance(value, str):
                    raise ParseError(value)
            rdata['echart_options'] = format_json_with_placeholders(
                echart_options, **results)
        if results and can_cache:
            cache.set(cache_key, rdata, dt.cache_seconds)
        return Response(rdata)

    @action(methods=['get'], detail=False, perms_map={'get': '*'})
    def base(self, request, pk=None):
        """List every installed model's table with its fields' names and internal types.

        Returns ``{db_table: [{'name': ..., 'type': ...}, ...]}``. Models sharing
        a table name (e.g. proxy models) overwrite each other; the last one wins.

        NOTE(review): perms_map '*' exposes the whole schema to any caller that
        passes the base permission check; confirm this is intended.
        """
        rdict = {}
        for model in apps.get_models():
            field_infos = []
            for field in model._meta.get_fields():
                # Some fields returned by get_fields() may lack get_internal_type
                # (e.g. GenericForeignKey on some Django versions); fall back
                # to the class name instead of failing the whole listing.
                get_type = getattr(field, 'get_internal_type', None)
                field_type = get_type() if get_type is not None else type(field).__name__
                field_infos.append({'name': field.name, 'type': field_type})
            rdict[model._meta.db_table] = field_infos
        return Response(rdict)
 | 
						|
 | 
						|
 | 
						|
# class ReportViewSet(CustomModelViewSet):   # 暂时不用了
 | 
						|
#     queryset = Report.objects.all()
 | 
						|
#     serializer_class = ReportSerializer
 | 
						|
#     search_fields = ['name', 'code']
 | 
						|
 | 
						|
#     @action(methods=['post'], detail=True, perms_map={'post': 'report.exec'}, serializer_class=DataExecSerializer, cache_seconds=0)
 | 
						|
#     def exec(self, request, pk=None):
 | 
						|
#         """执行报表查询
 | 
						|
 | 
						|
#         执行报表查询并用于返回前端渲染
 | 
						|
#         """
 | 
						|
#         report = self.get_object()
 | 
						|
#         rdata = ReportSerializer(instance=report).data
 | 
						|
#         query = request.data.get('query', {})
 | 
						|
#         return_type = request.data.get('return_type', 2)
 | 
						|
#         query['r_user'] = request.user.id
 | 
						|
#         query['r_dept'] = request.user.belong_dept.id if request.user.belong_dept else ''
 | 
						|
#         datasets = report.datasets.all()
 | 
						|
#         results = {}
 | 
						|
#         seconds = 10   # 缓存秒数
 | 
						|
 | 
						|
#         with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:   # 多线程运行并返回字典结果
 | 
						|
#             fun_ps = []
 | 
						|
#             for ds in datasets:
 | 
						|
#                 sql_query = ds.sql_query
 | 
						|
#                 if sql_query:
 | 
						|
#                     sql_f = check_sql_safe(sql_query.format(**query))  # 有风险先这样处理一下
 | 
						|
#                     res = cache.get(sql_f, None)
 | 
						|
#                     if isinstance(res, tuple):
 | 
						|
#                         results[ds.name] = format_sqldata(res[0], res[1], return_type)
 | 
						|
#                     else:
 | 
						|
#                         fun_ps.append((ds.name, execute_raw_sql, sql_f))
 | 
						|
#             # 生成执行函数
 | 
						|
#             futures = {executor.submit(i[1], i[2]): i for i in fun_ps}
 | 
						|
#             for future in concurrent.futures.as_completed(futures):
 | 
						|
#                 name, *_, sql_f = futures[future]  # 获取对应的键
 | 
						|
#                 try:
 | 
						|
#                     res = future.result()
 | 
						|
#                     results[name] = format_sqldata(res[0], res[1], return_type)
 | 
						|
#                     if seconds:
 | 
						|
#                         cache.set(sql_f, res, seconds)
 | 
						|
#                 except Exception as e:
 | 
						|
#                     results[name] = 'error: ' + str(e)
 | 
						|
 | 
						|
#         rdata['data'] = results
 | 
						|
#         return Response(rdata)
 |