129 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			129 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
import requests
 | 
						||
import os
 | 
						||
from apps.utils.sql import execute_raw_sql
 | 
						||
import json
 | 
						||
from apps.utils.tools import MyJSONEncoder
 | 
						||
from .utils import is_safe_sql
 | 
						||
from rest_framework.views import APIView
 | 
						||
from drf_yasg.utils import swagger_auto_schema
 | 
						||
from rest_framework import serializers
 | 
						||
from rest_framework.exceptions import ParseError
 | 
						||
from rest_framework.response import Response
 | 
						||
from django.conf import settings
 | 
						||
from apps.utils.mixins import MyLoggingMixin
 | 
						||
from django.core.cache import cache
 | 
						||
import uuid
 | 
						||
from apps.utils.thread import MyThread
 | 
						||
 | 
						||
# Connection settings for the chat-completions endpoint used by ask().
# All values come from Django settings so they can differ per deployment;
# the defaults keep the module importable when the LLM is not configured.
LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
# Model name sent with every request; overridable via settings.LLM_MODEL
# (falls back to the previously hard-coded "qwen14b").
MODEL = getattr(settings, "LLM_MODEL", "qwen14b")
# Built once at import time, so a changed API key needs a process restart.
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
# Directory of this module; prompt templates are read from CUR_DIR/promot/.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
 | 
						||
 | 
						||
def load_promot(name):
    """Return the text of the prompt template ``promot/<name>.md``.

    The path is resolved relative to this module's directory (CUR_DIR).
    The file is decoded as UTF-8 explicitly so that non-ASCII prompt text
    does not depend on the server's locale encoding.

    Raises:
        FileNotFoundError: if no template with that name exists.
    """
    path = os.path.join(CUR_DIR, "promot", f"{name}.md")
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
 | 
						||
 | 
						||
 | 
						||
def ask(input: str, p_name: str, stream=False, timeout=None):
    """Send one system+user exchange to the LLM and return the reply text.

    Args:
        input: User message content.
        p_name: Prompt template name (loaded with load_promot) that is sent
            as the system message.
        stream: If True, request a streamed (SSE) response and assemble the
            content deltas; otherwise read the single JSON reply.
        timeout: Passed straight to ``requests.post``. ``None`` (the default)
            waits indefinitely, which matches the previous behaviour.

    Returns:
        The assistant message content as a string.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
    """
    messages = [
        {"role": "system", "content": load_promot(p_name)},
        {"role": "user", "content": input},
    ]
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0,
        "stream": stream,
    }
    # The context manager releases the pooled connection, including when a
    # streamed body is abandoned or an exception is raised mid-read.
    with requests.post(LLM_URL, headers=HEADERS, json=payload,
                       stream=stream, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of a KeyError on "choices"
        # when the server returns an error body.
        response.raise_for_status()
        if not stream:
            return response.json()["choices"][0]["message"]["content"]
        return _collect_stream_content(response.iter_lines())


def _collect_stream_content(lines):
    """Concatenate the ``delta.content`` pieces of an SSE chat stream.

    ``lines`` yields raw byte lines such as ``b'data: {...}'``. Blank lines,
    lines without the ``data:`` prefix and unparsable JSON are skipped, and
    ``data: [DONE]`` ends the stream.
    """
    parts = []
    for raw in lines:
        if not raw:
            continue
        line = raw.decode("utf-8")
        if not line.startswith("data:"):
            continue
        json_str = line[5:].strip()
        if json_str == "[DONE]":
            break
        try:
            chunk_data = json.loads(json_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(chunk_data, dict):
            continue
        choices = chunk_data.get("choices")
        if not choices:
            continue
        delta = choices[0].get("delta") or {}
        # Servers may send "content": null (e.g. on the final chunk); skip
        # it rather than failing on str + None.
        piece = delta.get("content")
        if piece:
            parts.append(piece)
    return "".join(parts)
 | 
						||
 | 
						||
def _strip_code_fence(text):
    """Remove a Markdown code fence wrapped around ``text``, if present.

    Strips an opening triple-backtick fence with an optional ASCII language
    tag (e.g. html, sql) and a closing triple-backtick fence, then trims
    surrounding whitespace. Text without a fence comes back only
    whitespace-trimmed. This replaces ``str.lstrip('```html ')``, which
    treats its argument as a character set and removed leading letters
    such as "h", "t", "m" or "l" from unfenced output.
    """
    import re
    body = text.strip()
    body = re.sub(r"\A```[A-Za-z0-9_+-]*", "", body)
    body = re.sub(r"```\Z", "", body)
    return body.strip()


def work_chain(input: str, t_key: str):
    """Run the query -> SQL -> report pipeline and publish progress to the cache.

    WorkChain.post starts this in a background thread (MyThread). Progress is
    stored under ``t_key`` as a dict shaped like::

        {"state": "progress" | "error" | "done",
         "steps": [{"state": "ok" | "error", "msg": str, "content": ...}],
         "content": str}  # "content" keys only where noted below

    Pipeline:
        1. ask the 'w_sql' prompt for a SQL query;
        2. stop if the model returned its refusal sentence or
           ``is_safe_sql`` rejects the query;
        3. run the query with ``execute_raw_sql``;
        4. ask the 'w_ana' prompt to turn the rows into the final report,
           stored as top-level "content" when state becomes "done".

    Any unexpected exception (LLM/HTTP failure, SQL error, ...) is logged and
    recorded as an error step, so pollers reach a terminal state instead of
    seeing "progress" forever. Returns None.
    """
    pdict = {"state": "progress", "steps": []}

    def add_step(msg, state="ok", **extra):
        # Append one step and republish the whole progress dict.
        pdict["steps"].append({"state": state, "msg": msg, **extra})
        cache.set(t_key, pdict)

    def fail(msg):
        pdict["state"] = "error"
        add_step(msg, state="error")

    try:
        add_step("正在生成查询语句")
        sql_text = _strip_code_fence(ask(input, 'w_sql'))
        # The w_sql prompt answers with exactly this sentence when the
        # request is not phrased as a query; compare after trimming.
        if sql_text == '请以 查询 开头,重新描述你的需求':
            fail(sql_text)
            return
        add_step("查询语句生成成功", content=sql_text)

        if not is_safe_sql(sql_text):
            fail("当前查询存在风险,请重新描述你的需求")
            return

        add_step("正在执行查询语句")
        rows = execute_raw_sql(sql_text)
        add_step("查询语句执行成功", content=rows)

        add_step("正在生成报告")
        report = ask(json.dumps(rows, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
        content = _strip_code_fence(report)
        pdict["state"] = "done"
        pdict["content"] = content
        add_step("报告生成成功", content=content)
    except Exception as exc:
        # Top-level boundary of a background thread: record the failure
        # for the poller instead of dying silently.
        import logging
        logging.getLogger(__name__).exception("work_chain %s failed", t_key)
        fail(f"处理失败: {exc}")
 | 
						||
    
 | 
						||
class InputSerializer(serializers.Serializer):
    """Request body for submitting a natural-language query.

    Carries a single required text field, ``input``, labelled "查询需求"
    in the generated API schema.
    """

    input = serializers.CharField(required=True, label="查询需求")
 | 
						||
 | 
						||
class WorkChain(MyLoggingMixin, APIView):
    """Natural-language query endpoint backed by the LLM pipeline.

    POST submits a request and starts ``work_chain`` in a background thread.
    GET polls its progress by the ``ichat_tid`` returned from POST. Both
    verbs raise ParseError unless ``settings.LLM_ENABLED`` is truthy.
    """

    @staticmethod
    def _ensure_llm_enabled():
        # Shared guard for both verbs; LLM_ENABLED defaults to False.
        if not getattr(settings, "LLM_ENABLED", False):
            raise ParseError('LLM功能未启用')

    @swagger_auto_schema(
        operation_summary="提交查询需求",
        request_body=InputSerializer)
    def post(self, request):
        """Validate the query, start the pipeline, and return its task key.

        Returns ``{"ichat_tid": <key>}``; poll GET with that key.
        Raises ValidationError (HTTP 400) if ``input`` is missing or blank.
        """
        self._ensure_llm_enabled()
        serializer = InputSerializer(data=request.data)
        # Reject missing/blank input up front instead of handing None to the LLM.
        serializer.is_valid(raise_exception=True)
        query = serializer.validated_data['input']
        t_key = f'ichat_{uuid.uuid4()}'
        MyThread(target=work_chain, args=(query, t_key)).start()
        return Response({'ichat_tid': t_key})

    @swagger_auto_schema(
        operation_summary="获取查询进度")
    def get(self, request):
        """Return the progress dict that work_chain stored for ``ichat_tid``.

        The body is ``null`` when the key is not present in the cache.
        Raises ParseError if the ``ichat_tid`` query parameter is missing.
        """
        self._ensure_llm_enabled()
        ichat_tid = request.GET.get('ichat_tid')
        if not ichat_tid:
            # Falling through would return None, which DRF rejects with a
            # 500 because a view must return a Response.
            raise ParseError('缺少参数ichat_tid')
        return Response(cache.get(ichat_tid))
 | 
						||
 | 
						||
if __name__ == "__main__":
    # Manual smoke test. work_chain publishes its result to the cache instead
    # of returning it, so run it synchronously under a fresh key and print
    # the stored progress. Needs a configured Django environment; presumably
    # this module is importable as apps.ichat.views2 for use from
    # `python manage.py shell` -- TODO confirm.
    for demo_query in (
        "查询 一次超洗 工段在2025年6月的生产合格数等并形成报告",
        "查询外观检验工段在2025年6月的生产合格数等并形成报告",
    ):
        demo_key = f'ichat_{uuid.uuid4()}'
        work_chain(demo_query, demo_key)
        print(cache.get(demo_key))