From e774992d3a20fa897f953a14b136e691c3038900 Mon Sep 17 00:00:00 2001
From: caoqianming
Date: Fri, 13 Jun 2025 16:23:35 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20workchain=E9=80=9A=E8=BF=87=E7=BA=BF?=
 =?UTF-8?q?=E7=A8=8B=E6=89=A7=E8=A1=8C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/ichat/views2.py | 86 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 72 insertions(+), 14 deletions(-)

diff --git a/apps/ichat/views2.py b/apps/ichat/views2.py
index 1f0b2397..9b2766bb 100644
--- a/apps/ichat/views2.py
+++ b/apps/ichat/views2.py
@@ -10,6 +10,10 @@ from rest_framework import serializers
 from rest_framework.exceptions import ParseError
 from rest_framework.response import Response
 from django.conf import settings
+from apps.utils.mixins import MyLoggingMixin
+from django.core.cache import cache
+import uuid
+from apps.utils.thread import MyThread
 
 LLM_URL = getattr(settings, "LLM_URL", "")
 API_KEY = getattr(settings, "LLM_API_KEY", "")
@@ -25,47 +29,101 @@ def load_promot(name):
         return f.read()
 
 
-def ask(input:str, p_name:str):
+def ask(input:str, p_name:str, stream=False):
     his = [{"role":"system", "content": load_promot(p_name)}]
     his.append({"role":"user", "content": input})
     payload = {
         "model": MODEL,
         "messages": his,
-        "temperature": 0
+        "temperature": 0,
+        "stream": stream
     }
-    response = requests.post(LLM_URL, headers=HEADERS, json=payload)
-    return response.json()["choices"][0]["message"]["content"]
+    response = requests.post(LLM_URL, headers=HEADERS, json=payload, stream=stream)
+    if not stream:
+        return response.json()["choices"][0]["message"]["content"]
+    else:
+        # 处理流式响应
+        full_content = ""
+        for chunk in response.iter_lines():
+            if chunk:
+                # 通常流式响应是SSE格式(data: {...})
+                decoded_chunk = chunk.decode('utf-8')
+                if decoded_chunk.startswith("data:"):
+                    json_str = decoded_chunk[5:].strip()
+                    if json_str == "[DONE]":
+                        break
+                    try:
+                        chunk_data = json.loads(json_str)
+                        if "choices" in chunk_data and chunk_data["choices"]:
+                            delta = chunk_data["choices"][0].get("delta", {})
+                            if "content" in delta:
+                                print(delta["content"])
+                                full_content += delta["content"]
+                    except json.JSONDecodeError:
+                        continue
+        return full_content
 
 
-def work_chain(input:str):
+def work_chain(input:str, t_key:str):
+    pdict = {"state": "progress", "steps": [{"state":"ok", "msg":"正在生成查询语句"}]}
+    cache.set(t_key, pdict)
     res_text = ask(input, 'w_sql')
     if res_text == '请以 查询 开头,重新描述你的需求':
-        return '请以 查询 开头,重新描述你的需求'
+        pdict["state"] = "error"
+        pdict["steps"].append({"state":"error", "msg":res_text})
+        cache.set(t_key, pdict)
+        return
     else:
+        pdict["steps"].append({"state":"ok", "msg":"查询语句生成成功", "content":res_text})
+        cache.set(t_key, pdict)
         if not is_safe_sql(res_text):
-            return '当前查询存在风险,请重新描述你的需求'
+            pdict["state"] = "error"
+            pdict["steps"].append({"state":"error", "msg":"当前查询存在风险,请重新描述你的需求"})
+            cache.set(t_key, pdict)
+            return
+        pdict["steps"].append({"state":"ok", "msg":"正在执行查询语句"})
+        cache.set(t_key, pdict)
         res = execute_raw_sql(res_text)
+        pdict["steps"].append({"state":"ok", "msg":"查询语句执行成功", "content":res})
+        cache.set(t_key, pdict)
+        pdict["steps"].append({"state":"ok", "msg":"正在生成报告"})
+        cache.set(t_key, pdict)
         res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
-        return res2
+        content = res2.lstrip('```html ').rstrip('```')
+        pdict["state"] = "done"
+        pdict["content"] = content
+        pdict["steps"].append({"state":"ok", "msg":"报告生成成功", "content": content})
+        cache.set(t_key, pdict)
+        return
 
 
 class InputSerializer(serializers.Serializer):
     input = serializers.CharField(label="查询需求")
 
 
-class WorkChain(APIView):
+class WorkChain(MyLoggingMixin, APIView):
     @swagger_auto_schema(
-        operation_summary="查询工作",
+        operation_summary="提交查询需求",
         request_body=InputSerializer)
     def post(self, request):
         llm_enabled = getattr(settings, "LLM_ENABLED", False)
         if not llm_enabled:
             raise ParseError('LLM功能未启用')
         input = request.data.get('input')
-        res_text = work_chain(input)
-        res_text = res_text.lstrip('```html ').rstrip('```')
-        return Response({'content': res_text})
+        t_key = f'ichat_{uuid.uuid4()}'
+        MyThread(target=work_chain, args=(input, t_key)).start()
+        return Response({'ichat_tid': t_key})
+
+    @swagger_auto_schema(
+        operation_summary="获取查询进度")
+    def get(self, request):
+        llm_enabled = getattr(settings, "LLM_ENABLED", False)
+        if not llm_enabled:
+            raise ParseError('LLM功能未启用')
+        ichat_tid = request.GET.get('ichat_tid')
+        if ichat_tid:
+            return Response(cache.get(ichat_tid))
 
 
 if __name__ == "__main__":
-    print(work_chain("查询外观检验工段在2025年6月的生产合格数等并形成报告"))
+    print(work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告"))
from apps.ichat.views2 import work_chain
print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))