diff --git a/apps/bi/migrations/0002_dataset_cache_seconds.py b/apps/bi/migrations/0002_dataset_cache_seconds.py new file mode 100644 index 00000000..058f1b4d --- /dev/null +++ b/apps/bi/migrations/0002_dataset_cache_seconds.py @@ -0,0 +1,18 @@ +# Generated by Django 3.2.12 on 2023-05-31 01:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bi', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='dataset', + name='cache_seconds', + field=models.PositiveIntegerField(blank=True, default=10, verbose_name='缓存秒数'), + ), + ] diff --git a/apps/bi/models.py b/apps/bi/models.py index b54a4b8c..693478d7 100644 --- a/apps/bi/models.py +++ b/apps/bi/models.py @@ -8,6 +8,7 @@ class Dataset(CommonBDModel): description = models.TextField('描述说明', default='', blank=True) sql_query = models.TextField('sql查询语句', default='', blank=True) echart_options = models.TextField(default='', blank=True) + cache_seconds = models.PositiveIntegerField('缓存秒数', default=10, blank=True) # class Report(CommonBDModel): diff --git a/apps/bi/serializers.py b/apps/bi/serializers.py index b70de58c..22a88f6c 100644 --- a/apps/bi/serializers.py +++ b/apps/bi/serializers.py @@ -30,5 +30,4 @@ class DatasetSerializer(CustomModelSerializer): class DataExecSerializer(serializers.Serializer): - query = serializers.JSONField(label="查询字典参数", required=False, allow_null=True) - return_type = serializers.IntegerField(label="返回格式", required=False, default=2) \ No newline at end of file + query = serializers.JSONField(label="查询字典参数", required=False, allow_null=True) \ No newline at end of file diff --git a/apps/bi/views.py b/apps/bi/views.py index 8e6b76ab..b9eb0816 100644 --- a/apps/bi/views.py +++ b/apps/bi/views.py @@ -27,39 +27,42 @@ class DatasetViewSet(CustomModelViewSet): dt = self.get_object() rdata = DatasetSerializer(instance=dt).data query = request.data.get('query', {}) - return_type = request.data.get('return_type', 2) 
query['r_user'] = request.user.id query['r_dept'] = request.user.belong_dept.id if request.user.belong_dept else '' results = {} - seconds = 10 + results2 = {} + can_cache = True + if dt.sql_query: sql_f_ = check_sql_safe(dt.sql_query.format(**query)) sql_f_l = sql_f_.strip(';').split(';') + hash_k = 'bi_dataset_%s:%s' % (dt.pk, sql_f_.strip(';')) + hash_v = cache.get(hash_k, None) + if hash_v: + return Response(hash_v) with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor: # 多线程运行并返回字典结果 fun_ps = [] for ind, val in enumerate(sql_f_l): - res = cache.get(val, None) - if isinstance(res, tuple): - results[f'ds{ind}'] = format_sqldata(res[0], res[1], return_type) - else: - fun_ps.append((f'ds{ind}', execute_raw_sql, val)) + fun_ps.append((f'ds{ind}', execute_raw_sql, val)) # 生成执行函数 futures = {executor.submit(i[1], i[2]): i for i in fun_ps} for future in concurrent.futures.as_completed(futures): name, *_, sql_f = futures[future] # 获取对应的键 try: res = future.result() - results[name] = format_sqldata(res[0], res[1], return_type) - if seconds: - cache.set(sql_f, res, seconds) + results[name], results2[name] = format_sqldata(res[0], res[1]) except Exception as e: results[name] = 'error: ' + str(e) + can_cache = False rdata['data'] = results + rdata['data2'] = results2 if rdata['echart_options']: for key in results: if isinstance(results[key], str): raise ParseError(results[key]) rdata['echart_options'] = format_json_with_placeholders(rdata['echart_options'], **results) + if results and can_cache: + cache.set(hash_k, rdata, dt.cache_seconds) return Response(rdata) @action(methods=['get'], detail=False, perms_map={'get': '*'}) diff --git a/apps/utils/sql.py b/apps/utils/sql.py index f90de8a5..06567427 100644 --- a/apps/utils/sql.py +++ b/apps/utils/sql.py @@ -17,11 +17,8 @@ def execute_raw_sql(sql: str, params=None): rows = cursor.fetchall() return columns, rows -def format_sqldata(columns, rows, return_type=2): - if return_type == 2: - return [columns] + rows - elif return_type == 1: - 
return [dict(zip(columns, row)) for row in rows] +def format_sqldata(columns, rows): + return [columns] + rows, [dict(zip(columns, row)) for row in rows] def query_all_dict(sql, params=None):