"""Natural-language "query work" endpoint backed by an LLM.

Flow: the user's request is sent to the LLM with the ``w_sql`` prompt, the
returned text is checked with ``is_safe_sql`` and executed as raw SQL, and the
rows are sent back to the LLM with the ``w_ana`` prompt to produce the answer.
"""
import json
import os

import requests
from django.conf import settings
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
from rest_framework.exceptions import ParseError
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.utils.sql import execute_raw_sql
from apps.utils.tools import MyJSONEncoder

from .utils import is_safe_sql

LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
MODEL = "qwen14b"
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
# Upper bound (seconds) for one LLM HTTP call; without it a stalled server
# would block the worker forever. Override with settings.LLM_TIMEOUT.
LLM_TIMEOUT = getattr(settings, "LLM_TIMEOUT", 120)
CUR_DIR = os.path.dirname(os.path.abspath(__file__))

# Sentinel reply that is compared against the output of the ``w_sql`` prompt.
_RETRY_HINT = '请以 查询 开头,重新描述你的需求'
_UNSAFE_SQL_HINT = '当前查询存在风险,请重新描述你的需求'
_FENCE = '```'


def load_promot(name):
    """Return the text of the prompt file ``promot/<name>.md`` next to this module.

    The file is read as UTF-8 explicitly so the result does not depend on the
    platform's default locale encoding.
    """
    with open(os.path.join(CUR_DIR, f'promot/{name}.md'), 'r', encoding='utf-8') as f:
        return f.read()


def ask(input: str, p_name: str) -> str:
    """Send ``input`` to the LLM using prompt ``p_name`` as the system message.

    Args:
        input: The user message (parameter name kept for keyword callers).
        p_name: Prompt file name, without extension, passed to ``load_promot``.

    Returns:
        The ``content`` of the first choice in the LLM response.

    Raises:
        requests.RequestException: On timeout, connection failure, or a
            non-success HTTP status.
    """
    messages = [
        {"role": "system", "content": load_promot(p_name)},
        {"role": "user", "content": input},
    ]
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0,
    }
    response = requests.post(
        LLM_URL, headers=HEADERS, json=payload, timeout=LLM_TIMEOUT
    )
    # Fail with a clear HTTP error instead of a KeyError on an error payload.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def work_chain(input: str) -> str:
    """Turn a natural-language request into SQL, run it, and return the LLM's answer.

    Returns the retry hint unchanged when the LLM asks the user to rephrase,
    and a risk warning when ``is_safe_sql`` rejects the generated statement.
    """
    # Surrounding whitespace from the model must not defeat the sentinel match.
    sql_text = ask(input, 'w_sql').strip()
    if sql_text == _RETRY_HINT:
        return _RETRY_HINT
    if not is_safe_sql(sql_text):
        return _UNSAFE_SQL_HINT
    rows = execute_raw_sql(sql_text)
    return ask(json.dumps(rows, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')


def _strip_code_fence(text: str) -> str:
    """Remove one surrounding Markdown code fence (``` or ```html) from ``text``.

    ``str.lstrip``/``str.rstrip`` take a *set of characters*, so stripping with
    them would also eat leading letters such as ``h``, ``t``, ``m`` or ``l``
    from unfenced replies; this helper only removes the exact fence markers.
    """
    cleaned = text.strip()
    if cleaned.startswith(_FENCE):
        cleaned = cleaned[len(_FENCE):]
        if cleaned.startswith('html'):
            cleaned = cleaned[len('html'):]
    if cleaned.endswith(_FENCE):
        cleaned = cleaned[:-len(_FENCE)]
    return cleaned.strip()


class InputSerializer(serializers.Serializer):
    """Request body for ``WorkChain``: a single free-text ``input`` field."""

    input = serializers.CharField(label="查询需求")


class WorkChain(APIView):
    """POST endpoint that answers a natural-language query via ``work_chain``."""

    @swagger_auto_schema(
        operation_summary="查询工作",
        request_body=InputSerializer)
    def post(self, request):
        """Validate the request, run the work chain and return ``{'content': ...}``.

        Raises:
            ParseError: When ``settings.LLM_ENABLED`` is falsy or missing.
            ValidationError: When ``input`` is missing or blank.
        """
        if not getattr(settings, "LLM_ENABLED", False):
            raise ParseError('LLM功能未启用')
        # Reject a missing/blank ``input`` up front instead of sending None to the LLM.
        serializer = InputSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        answer = work_chain(serializer.validated_data['input'])
        return Response({'content': _strip_code_fence(answer)})


if __name__ == "__main__":
    # Manual smoke test. From a Django shell the equivalent is:
    #     from apps.ichat.views2 import work_chain
    #     print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))
    print(work_chain("查询外观检验工段在2025年6月的生产合格数等并形成报告"))