feat: workchain通过线程执行
This commit is contained in:
parent
4955ad1f11
commit
e774992d3a
|
@ -10,6 +10,10 @@ from rest_framework import serializers
|
||||||
from rest_framework.exceptions import ParseError
|
from rest_framework.exceptions import ParseError
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from apps.utils.mixins import MyLoggingMixin
|
||||||
|
from django.core.cache import cache
|
||||||
|
import uuid
|
||||||
|
from apps.utils.thread import MyThread
|
||||||
|
|
||||||
LLM_URL = getattr(settings, "LLM_URL", "")
|
LLM_URL = getattr(settings, "LLM_URL", "")
|
||||||
API_KEY = getattr(settings, "LLM_API_KEY", "")
|
API_KEY = getattr(settings, "LLM_API_KEY", "")
|
||||||
|
@ -25,47 +29,101 @@ def load_promot(name):
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
def ask(input:str, p_name:str):
|
def ask(input:str, p_name:str, stream=False):
|
||||||
his = [{"role":"system", "content": load_promot(p_name)}]
|
his = [{"role":"system", "content": load_promot(p_name)}]
|
||||||
his.append({"role":"user", "content": input})
|
his.append({"role":"user", "content": input})
|
||||||
payload = {
|
payload = {
|
||||||
"model": MODEL,
|
"model": MODEL,
|
||||||
"messages": his,
|
"messages": his,
|
||||||
"temperature": 0
|
"temperature": 0,
|
||||||
|
"stream": stream
|
||||||
}
|
}
|
||||||
response = requests.post(LLM_URL, headers=HEADERS, json=payload)
|
response = requests.post(LLM_URL, headers=HEADERS, json=payload, stream=stream)
|
||||||
return response.json()["choices"][0]["message"]["content"]
|
if not stream:
|
||||||
|
return response.json()["choices"][0]["message"]["content"]
|
||||||
|
else:
|
||||||
|
# 处理流式响应
|
||||||
|
full_content = ""
|
||||||
|
for chunk in response.iter_lines():
|
||||||
|
if chunk:
|
||||||
|
# 通常流式响应是SSE格式(data: {...})
|
||||||
|
decoded_chunk = chunk.decode('utf-8')
|
||||||
|
if decoded_chunk.startswith("data:"):
|
||||||
|
json_str = decoded_chunk[5:].strip()
|
||||||
|
if json_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk_data = json.loads(json_str)
|
||||||
|
if "choices" in chunk_data and chunk_data["choices"]:
|
||||||
|
delta = chunk_data["choices"][0].get("delta", {})
|
||||||
|
if "content" in delta:
|
||||||
|
print(delta["content"])
|
||||||
|
full_content += delta["content"]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
return full_content
|
||||||
|
|
||||||
def work_chain(input:str):
|
def work_chain(input:str, t_key:str):
|
||||||
|
pdict = {"state": "progress", "steps": [{"state":"ok", "msg":"正在生成查询语句"}]}
|
||||||
|
cache.set(t_key, pdict)
|
||||||
res_text = ask(input, 'w_sql')
|
res_text = ask(input, 'w_sql')
|
||||||
if res_text == '请以 查询 开头,重新描述你的需求':
|
if res_text == '请以 查询 开头,重新描述你的需求':
|
||||||
return '请以 查询 开头,重新描述你的需求'
|
pdict["state"] = "error"
|
||||||
|
pdict["steps"].append({"state":"error", "msg":res_text})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
|
pdict["steps"].append({"state":"ok", "msg":"查询语句生成成功", "content":res_text})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
if not is_safe_sql(res_text):
|
if not is_safe_sql(res_text):
|
||||||
return '当前查询存在风险,请重新描述你的需求'
|
pdict["state"] = "error"
|
||||||
|
pdict["steps"].append({"state":"error", "msg":"当前查询存在风险,请重新描述你的需求"})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
|
return
|
||||||
|
pdict["steps"].append({"state":"ok", "msg":"正在执行查询语句"})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
res = execute_raw_sql(res_text)
|
res = execute_raw_sql(res_text)
|
||||||
|
pdict["steps"].append({"state":"ok", "msg":"查询语句执行成功", "content":res})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
|
pdict["steps"].append({"state":"ok", "msg":"正在生成报告"})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
|
res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
|
||||||
return res2
|
content = res2.lstrip('```html ').rstrip('```')
|
||||||
|
pdict["state"] = "done"
|
||||||
|
pdict["content"] = content
|
||||||
|
pdict["steps"].append({"state":"ok", "msg":"报告生成成功", "content": content})
|
||||||
|
cache.set(t_key, pdict)
|
||||||
|
return
|
||||||
|
|
||||||
class InputSerializer(serializers.Serializer):
|
class InputSerializer(serializers.Serializer):
|
||||||
input = serializers.CharField(label="查询需求")
|
input = serializers.CharField(label="查询需求")
|
||||||
|
|
||||||
class WorkChain(APIView):
|
class WorkChain(MyLoggingMixin, APIView):
|
||||||
|
|
||||||
@swagger_auto_schema(
|
@swagger_auto_schema(
|
||||||
operation_summary="查询工作",
|
operation_summary="提交查询需求",
|
||||||
request_body=InputSerializer)
|
request_body=InputSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
llm_enabled = getattr(settings, "LLM_ENABLED", False)
|
llm_enabled = getattr(settings, "LLM_ENABLED", False)
|
||||||
if not llm_enabled:
|
if not llm_enabled:
|
||||||
raise ParseError('LLM功能未启用')
|
raise ParseError('LLM功能未启用')
|
||||||
input = request.data.get('input')
|
input = request.data.get('input')
|
||||||
res_text = work_chain(input)
|
t_key = f'ichat_{uuid.uuid4()}'
|
||||||
res_text = res_text.lstrip('```html ').rstrip('```')
|
MyThread(target=work_chain, args=(input, t_key)).start()
|
||||||
return Response({'content': res_text})
|
return Response({'ichat_tid': t_key})
|
||||||
|
|
||||||
|
@swagger_auto_schema(
|
||||||
|
operation_summary="获取查询进度")
|
||||||
|
def get(self, request):
|
||||||
|
llm_enabled = getattr(settings, "LLM_ENABLED", False)
|
||||||
|
if not llm_enabled:
|
||||||
|
raise ParseError('LLM功能未启用')
|
||||||
|
ichat_tid = request.GET.get('ichat_tid')
|
||||||
|
if ichat_tid:
|
||||||
|
return Response(cache.get(ichat_tid))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(work_chain("查询外观检验工段在2025年6月的生产合格数等并形成报告"))
|
print(work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告"))
|
||||||
|
|
||||||
from apps.ichat.views2 import work_chain
|
from apps.ichat.views2 import work_chain
|
||||||
print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))
|
print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))
|
Loading…
Reference in New Issue