import requests
import os
import json
import uuid

from django.conf import settings
from django.core.cache import cache
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
from rest_framework.exceptions import ParseError
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.utils.mixins import MyLoggingMixin
from apps.utils.sql import execute_raw_sql
from apps.utils.thread import MyThread
from apps.utils.tools import MyJSONEncoder
from .utils import is_safe_sql

# LLM endpoint configuration; empty-string defaults keep import safe when unset.
LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
MODEL = "qwen14b"
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
CUR_DIR = os.path.dirname(os.path.abspath(__file__))


def load_promot(name: str) -> str:
    """Return the prompt template ``promot/<name>.md`` as text.

    NOTE: "promot" (sic) matches the on-disk directory name, so the
    misspelling is deliberately preserved.
    """
    # Explicit UTF-8: the prompt files contain Chinese text and must not
    # depend on the platform's default encoding.
    with open(os.path.join(CUR_DIR, f'promot/{name}.md'), 'r', encoding='utf-8') as f:
        return f.read()


def ask(input: str, p_name: str, stream=False) -> str:
    """Send *input* to the LLM using the system prompt named *p_name*.

    :param input: user message content.
    :param p_name: prompt template name (file under ``promot/``).
    :param stream: when True, consume an SSE stream and return the
        concatenated delta contents; otherwise return the single reply.
    :return: the model's answer as a string.
    """
    messages = [
        {"role": "system", "content": load_promot(p_name)},
        {"role": "user", "content": input},
    ]
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0,
        "stream": stream,
    }
    # timeout prevents the worker thread from hanging forever; generous value
    # because LLM generation can legitimately be slow.
    response = requests.post(LLM_URL, headers=HEADERS, json=payload,
                             stream=stream, timeout=600)
    if not stream:
        return response.json()["choices"][0]["message"]["content"]

    # Streaming replies arrive as SSE lines of the form ``data: {...}``.
    full_content = ""
    for chunk in response.iter_lines():
        if not chunk:
            continue
        decoded_chunk = chunk.decode('utf-8')
        if not decoded_chunk.startswith("data:"):
            continue
        json_str = decoded_chunk[5:].strip()
        if json_str == "[DONE]":
            break
        try:
            chunk_data = json.loads(json_str)
        except json.JSONDecodeError:
            # Skip malformed keep-alive / partial lines.
            continue
        if chunk_data.get("choices"):
            delta = chunk_data["choices"][0].get("delta", {})
            if "content" in delta:
                full_content += delta["content"]
    return full_content


def _strip_html_fence(text: str) -> str:
    """Remove a surrounding markdown ```` ```html ... ``` ```` fence.

    The previous ``lstrip('```html ')`` / ``rstrip('```')`` stripped
    *character sets*, which also ate leading ``h``/``t``/``m``/``l``
    characters of the actual report; this strips the fence markers only
    as whole prefixes / suffixes.
    """
    content = text.strip()
    if content.startswith('```html'):
        content = content[len('```html'):].lstrip()
    elif content.startswith('```'):
        content = content[3:].lstrip()
    if content.endswith('```'):
        content = content[:-3].rstrip()
    return content


def work_chain(input: str, t_key: str) -> None:
    """Run the query pipeline: natural language -> SQL -> execute -> report.

    Progress is published to the Django cache under *t_key* as a dict with
    ``state`` ("progress" / "error" / "done"), a ``steps`` list and, on
    success, ``content`` (the HTML report). Intended to run in a worker
    thread started by :class:`WorkChain`.
    """
    pdict = {"state": "progress", "steps": [{"state": "ok", "msg": "正在生成查询语句"}]}
    cache.set(t_key, pdict)

    res_text = ask(input, 'w_sql')
    # The w_sql prompt makes the model answer with this exact sentence when
    # the request is not phrased as a query; surface it as a user error.
    # NOTE(review): exact-match on model output is fragile — confirm the
    # prompt pins this sentence verbatim.
    if res_text == '请以 查询 开头,重新描述你的需求':
        pdict["state"] = "error"
        pdict["steps"].append({"state": "error", "msg": res_text})
        cache.set(t_key, pdict)
        return
    pdict["steps"].append({"state": "ok", "msg": "查询语句生成成功", "content": res_text})
    cache.set(t_key, pdict)

    # Reject anything that is not safe SQL before touching the database.
    if not is_safe_sql(res_text):
        pdict["state"] = "error"
        pdict["steps"].append({"state": "error", "msg": "当前查询存在风险,请重新描述你的需求"})
        cache.set(t_key, pdict)
        return

    pdict["steps"].append({"state": "ok", "msg": "正在执行查询语句"})
    cache.set(t_key, pdict)
    res = execute_raw_sql(res_text)
    pdict["steps"].append({"state": "ok", "msg": "查询语句执行成功", "content": res})
    cache.set(t_key, pdict)

    pdict["steps"].append({"state": "ok", "msg": "正在生成报告"})
    cache.set(t_key, pdict)
    res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
    content = _strip_html_fence(res2)
    pdict["state"] = "done"
    pdict["content"] = content
    pdict["steps"].append({"state": "ok", "msg": "报告生成成功", "content": content})
    cache.set(t_key, pdict)


class InputSerializer(serializers.Serializer):
    # Natural-language description of the query the user wants to run.
    input = serializers.CharField(label="查询需求")


class WorkChain(MyLoggingMixin, APIView):
    """Async NL-query workflow: POST submits a task, GET polls its progress."""

    @staticmethod
    def _check_llm_enabled():
        # Both endpoints are meaningless while the LLM integration is off.
        if not getattr(settings, "LLM_ENABLED", False):
            raise ParseError('LLM功能未启用')

    @swagger_auto_schema(
        operation_summary="提交查询需求",
        request_body=InputSerializer)
    def post(self, request):
        """Start a work_chain task in a background thread; return its cache key."""
        self._check_llm_enabled()
        # Previously unvalidated: a missing "input" was silently handed to the
        # worker thread. Now it fails fast with a 400.
        serializer = InputSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        input = serializer.validated_data['input']
        t_key = f'ichat_{uuid.uuid4()}'
        # Fire-and-forget: progress is reported through the cache under t_key.
        MyThread(target=work_chain, args=(input, t_key)).start()
        return Response({'ichat_tid': t_key})

    @swagger_auto_schema(
        operation_summary="获取查询进度")
    def get(self, request):
        """Return the cached progress dict for a previously submitted task."""
        self._check_llm_enabled()
        ichat_tid = request.GET.get('ichat_tid')
        if not ichat_tid:
            # Previously fell through and returned None, which crashed DRF.
            raise ParseError('缺少参数ichat_tid')
        return Response(cache.get(ichat_tid))


if __name__ == "__main__":
    # Manual smoke test; requires DJANGO_SETTINGS_MODULE to be configured.
    # The original called work_chain() with one argument (TypeError): it
    # needs a cache key, and its result must be read back from the cache.
    t_key = f'ichat_{uuid.uuid4()}'
    work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告", t_key)
    print(cache.get(t_key))
    # Aliased import to avoid shadowing this module's work_chain.
    from apps.ichat.views2 import work_chain as work_chain_v2
    print(work_chain_v2('查询外观检验工段在2025年6月的生产合格数等并形成报告'))