129 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			129 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
| import requests
 | ||
| import os
 | ||
| from apps.utils.sql import execute_raw_sql
 | ||
| import json
 | ||
| from apps.utils.tools import MyJSONEncoder
 | ||
| from .utils import is_safe_sql
 | ||
| from rest_framework.views import APIView
 | ||
| from drf_yasg.utils import swagger_auto_schema
 | ||
| from rest_framework import serializers
 | ||
| from rest_framework.exceptions import ParseError
 | ||
| from rest_framework.response import Response
 | ||
| from django.conf import settings
 | ||
| from apps.utils.mixins import MyLoggingMixin
 | ||
| from django.core.cache import cache
 | ||
| import uuid
 | ||
| from apps.utils.thread import MyThread
 | ||
| 
 | ||
# LLM endpoint configuration, read once from Django settings at import time.
LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
# Model name sent in every chat-completion payload; overridable via
# settings.LLM_MODEL, defaulting to the previously hard-coded "qwen14b".
MODEL = getattr(settings, "LLM_MODEL", "qwen14b")
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}
# Directory of this module; load_promot resolves prompt templates relative to it.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
 | ||
| 
 | ||
def load_promot(name, base_dir=None):
    """Return the text of the prompt template ``promot/<name>.md``.

    Args:
        name: Template file name without the ``.md`` extension
            (e.g. ``"w_sql"``).
        base_dir: Directory that contains the ``promot`` folder. Defaults to
            ``CUR_DIR`` (this module's directory).

    Returns:
        The full template text.

    Raises:
        FileNotFoundError: If the template file does not exist.
    """
    if base_dir is None:
        base_dir = CUR_DIR
    # Prompt files may contain non-ASCII (e.g. Chinese) text; decode as UTF-8
    # explicitly instead of relying on the platform's locale encoding.
    # NOTE(review): assumes the .md templates are saved as UTF-8 — confirm.
    path = os.path.join(base_dir, 'promot', f'{name}.md')
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
 | ||
| 
 | ||
| 
 | ||
def ask(input: str, p_name: str, stream=False, timeout=None):
    """Send *input* to the chat-completion endpoint with a system prompt.

    Args:
        input: User message content.
        p_name: Prompt template name passed to ``load_promot`` and used as
            the system message.
        stream: When True, request a streamed (SSE) response and
            concatenate the ``delta.content`` pieces.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, as before.

    Returns:
        The assistant message content as a string.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    messages = [
        {"role": "system", "content": load_promot(p_name)},
        {"role": "user", "content": input},
    ]
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0,
        "stream": stream,
    }
    # Using the response as a context manager releases the connection even
    # when a streamed body is abandoned early (e.g. on "[DONE]").
    with requests.post(LLM_URL, headers=HEADERS, json=payload,
                       stream=stream, timeout=timeout) as response:
        # Fail loudly on HTTP errors instead of a confusing KeyError below.
        response.raise_for_status()
        if not stream:
            return response.json()["choices"][0]["message"]["content"]
        return _collect_stream_content(response)


def _collect_stream_content(response):
    """Concatenate the ``delta.content`` pieces of an SSE chat-completion stream."""
    parts = []
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        # Streamed lines are expected in SSE form: ``data: {...}``.
        line = raw_line.decode('utf-8')
        if not line.startswith("data:"):
            continue
        json_str = line[5:].strip()
        if json_str == "[DONE]":
            break
        try:
            chunk_data = json.loads(json_str)
        except json.JSONDecodeError:
            # Skip malformed/partial chunks rather than aborting the stream.
            continue
        if not isinstance(chunk_data, dict):
            continue
        choices = chunk_data.get("choices")
        if not choices:
            continue
        content = choices[0].get("delta", {}).get("content")
        # Some servers send "content": null on role/finish chunks.
        if content:
            parts.append(content)
    return "".join(parts)
 | ||
| 
 | ||
def _strip_code_fence(text):
    """Remove one surrounding Markdown code fence (```lang ... ```) from *text*."""
    import re

    body = text.strip()
    # Opening fence with optional language tag (```html, ```sql, ...).
    body = re.sub(r"\A```[\w+-]*[ \t]*\n?", "", body)
    # Closing fence.
    body = re.sub(r"\n?```\Z", "", body)
    return body.strip()


def _publish_step(pdict, t_key, step):
    """Append *step* to the progress dict and store it in the cache."""
    pdict["steps"].append(step)
    cache.set(t_key, pdict)


def _publish_error(pdict, t_key, msg):
    """Record a failed step, set the overall state to "error" and store it."""
    pdict["state"] = "error"
    _publish_step(pdict, t_key, {"state": "error", "msg": msg})


def work_chain(input: str, t_key: str):
    """Run the question -> SQL -> query -> report pipeline.

    Progress is written to the Django cache under *t_key* after every step
    (``WorkChain.get`` reads it back). The steps are:

    1. Ask the LLM for SQL (prompt ``w_sql``).
    2. Reject it if ``is_safe_sql`` fails.
    3. Run it with ``execute_raw_sql``.
    4. Ask the LLM for a report of the rows (prompt ``w_ana``).

    The cached dict ends with ``state == "done"`` and ``content`` set to the
    report, or with ``state == "error"`` and the failing step appended to
    ``steps``.

    Args:
        input: The user's natural-language request.
        t_key: Cache key to publish progress under.

    Returns:
        None. Meant to run in a background thread.
    """
    pdict = {"state": "progress", "steps": [{"state": "ok", "msg": "正在生成查询语句"}]}
    cache.set(t_key, pdict)
    try:
        # Tolerate whitespace or a ```sql fence around the generated SQL.
        res_text = _strip_code_fence(ask(input, 'w_sql'))
        if res_text == '请以 查询 开头,重新描述你的需求':
            _publish_error(pdict, t_key, res_text)
            return
        _publish_step(pdict, t_key, {"state": "ok", "msg": "查询语句生成成功", "content": res_text})
        if not is_safe_sql(res_text):
            _publish_error(pdict, t_key, "当前查询存在风险,请重新描述你的需求")
            return
        _publish_step(pdict, t_key, {"state": "ok", "msg": "正在执行查询语句"})
        res = execute_raw_sql(res_text)
        _publish_step(pdict, t_key, {"state": "ok", "msg": "查询语句执行成功", "content": res})
        _publish_step(pdict, t_key, {"state": "ok", "msg": "正在生成报告"})
        res2 = ask(json.dumps(res, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
        # str.lstrip/rstrip strip character sets, not prefixes, so remove the
        # ```html fence properly instead.
        content = _strip_code_fence(res2)
        pdict["state"] = "done"
        pdict["content"] = content
        _publish_step(pdict, t_key, {"state": "ok", "msg": "报告生成成功", "content": content})
    except Exception as exc:
        # Top-level boundary of a background thread. Without this, any
        # failure leaves the cached state at "progress" and pollers wait forever.
        import logging
        logging.getLogger(__name__).exception("work_chain failed for %s", t_key)
        _publish_error(pdict, t_key, f"处理失败: {exc}")
 | ||
|     
 | ||
class InputSerializer(serializers.Serializer):
    """Request body for submitting a query: one natural-language description."""

    # Mandatory free-text query; required=True restates the DRF default.
    input = serializers.CharField(required=True, label="查询需求")
 | ||
| 
 | ||
class WorkChain(MyLoggingMixin, APIView):
    """Submit a natural-language query and poll its progress.

    ``POST`` validates the body, starts ``work_chain`` in a background
    ``MyThread`` and returns the cache key (``ichat_tid``) that progress is
    published under. ``GET`` returns the progress dict stored under that key.
    """

    @staticmethod
    def _ensure_llm_enabled():
        """Raise ``ParseError`` unless ``settings.LLM_ENABLED`` is truthy."""
        if not getattr(settings, "LLM_ENABLED", False):
            raise ParseError('LLM功能未启用')

    @swagger_auto_schema(
        operation_summary="提交查询需求",
        request_body=InputSerializer)
    def post(self, request):
        self._ensure_llm_enabled()
        # Validate with the documented serializer so a missing or blank
        # input is rejected with a 400 instead of being sent to the LLM.
        serializer = InputSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        user_input = serializer.validated_data['input']
        t_key = f'ichat_{uuid.uuid4()}'
        MyThread(target=work_chain, args=(user_input, t_key)).start()
        return Response({'ichat_tid': t_key})

    @swagger_auto_schema(
        operation_summary="获取查询进度")
    def get(self, request):
        self._ensure_llm_enabled()
        ichat_tid = request.GET.get('ichat_tid')
        # Accept only keys minted by post(). Otherwise a client could read
        # arbitrary cache entries, and a missing parameter would make the
        # view return None (a DRF assertion error / HTTP 500).
        if not ichat_tid or not ichat_tid.startswith('ichat_'):
            raise ParseError('ichat_tid参数无效')
        return Response(cache.get(ichat_tid))
 | ||
| 
 | ||
if __name__ == "__main__":
    # Manual smoke test. It needs configured Django settings and cache
    # (e.g. run the calls from ``python manage.py shell``).
    # work_chain requires a cache key and returns None, so pass a key and
    # print the cached progress dict instead of the return value.
    demo_key = f"ichat_{uuid.uuid4()}"
    work_chain("查询 一次超洗 工段在2025年6月的生产合格数等并形成报告", demo_key)
    print(cache.get(demo_key))

    # Compares against the alternative implementation in views2; its
    # signature is not visible here, so this call is kept as originally written.
    from apps.ichat.views2 import work_chain
    print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))