71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import requests
|
|
import os
|
|
from apps.utils.sql import execute_raw_sql
|
|
import json
|
|
from apps.utils.tools import MyJSONEncoder
|
|
from .utils import is_safe_sql
|
|
from rest_framework.views import APIView
|
|
from drf_yasg.utils import swagger_auto_schema
|
|
from rest_framework import serializers
|
|
from rest_framework.exceptions import ParseError
|
|
from rest_framework.response import Response
|
|
from django.conf import settings
|
|
|
|
# LLM endpoint configuration, read from Django settings (empty strings when unset).
LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
# Model name sent in every chat-completion payload. Overridable via
# settings.LLM_MODEL; defaults to the previously hard-coded "qwen14b".
MODEL = getattr(settings, "LLM_MODEL", "qwen14b")
# Headers for every LLM request: bearer-token auth and a JSON body.
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
# Directory of this module; prompt files are resolved relative to it.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
def load_promot(name):
    """Return the text of the prompt template ``promot/<name>.md`` beside this module.

    Args:
        name: Base file name of the prompt without the ``.md`` extension
            (this module calls it with ``'w_sql'`` and ``'w_ana'``).

    Returns:
        The whole file content as a string.

    Raises:
        FileNotFoundError: if the prompt file does not exist.
    """
    path = os.path.join(CUR_DIR, 'promot', f'{name}.md')
    # Decode explicitly as UTF-8 instead of the platform default encoding
    # (e.g. cp936/cp1252 on Windows), which varies between machines and can
    # fail or garble non-ASCII prompt text.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
|
|
|
|
|
|
def ask(input: str, p_name: str, timeout: float = 300):
    """Send ``input`` to the LLM under system prompt ``p_name``; return the reply text.

    The conversation is two messages: the system prompt loaded with
    ``load_promot(p_name)``, then ``input`` as the user message. It is posted to
    ``LLM_URL`` with model ``MODEL`` and ``temperature`` 0.

    Args:
        input: Content of the user message.
        p_name: Prompt template name passed to ``load_promot``.
        timeout: Seconds to wait for the endpoint. Without a timeout,
            ``requests`` would wait indefinitely on a stalled server.

    Returns:
        ``choices[0].message.content`` from the JSON response.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        requests.Timeout: if the endpoint does not respond within ``timeout``.
    """
    messages = [
        {"role": "system", "content": load_promot(p_name)},
        {"role": "user", "content": input},
    ]
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0,
    }
    response = requests.post(LLM_URL, headers=HEADERS, json=payload, timeout=timeout)
    # Fail with the real HTTP error rather than a KeyError on "choices"
    # when the endpoint returns an error body.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
|
|
|
|
def work_chain(input: str):
    """Answer a natural-language query: LLM -> SQL -> execute -> LLM analysis.

    Steps:
      1. Ask the LLM (prompt ``w_sql``) to translate ``input`` into SQL.
      2. If the model replied with its fixed refusal sentence
         ("please start with 查询 and restate your request"), return that sentence.
      3. If ``is_safe_sql`` rejects the SQL, return a fixed risk warning.
      4. Otherwise run it with ``execute_raw_sql``, JSON-serialize the result
         with ``MyJSONEncoder``, and ask the LLM (prompt ``w_ana``) to analyse it.

    Returns:
        The analysis text, or one of the two fixed messages above.
    """
    refusal = '请以 查询 开头,重新描述你的需求'
    # LLM output may carry surrounding whitespace (e.g. a trailing newline).
    # Without stripping, the exact refusal comparison would miss and the
    # refusal sentence would be sent to the SQL safety check and executor.
    sql = ask(input, 'w_sql').strip()
    if sql == refusal:
        return refusal
    if not is_safe_sql(sql):
        return '当前查询存在风险,请重新描述你的需求'
    result = execute_raw_sql(sql)
    return ask(json.dumps(result, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
|
|
|
|
class InputSerializer(serializers.Serializer):
    """Request body schema used by the ``WorkChain`` POST endpoint.

    Fields:
        input: Free-text description of the query to run
            (label "查询需求", i.e. "query requirement").
    """

    input = serializers.CharField(label='查询需求')
|
|
|
|
class WorkChain(APIView):
    """POST endpoint that runs ``work_chain`` on the submitted query and returns its text."""

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown fence (```html ... ``` or ``` ... ```) from ``text``.

        Surrounding whitespace is trimmed too; text without a fence is returned
        trimmed but otherwise unchanged.
        """
        text = text.strip()
        for opener in ('```html', '```'):
            if text.startswith(opener):
                text = text[len(opener):]
                break
        if text.endswith('```'):
            text = text[:-3]
        return text.strip()

    @swagger_auto_schema(
        operation_summary="查询工作",
        request_body=InputSerializer)
    def post(self, request):
        """Validate ``input``, run the query chain and return ``{'content': <text>}``.

        Raises:
            ParseError: if ``settings.LLM_ENABLED`` is falsy (feature disabled).
            ValidationError: if ``input`` is missing or blank (HTTP 400).
        """
        llm_enabled = getattr(settings, "LLM_ENABLED", False)
        if not llm_enabled:
            raise ParseError('LLM功能未启用')
        # Validate through the declared serializer so a missing/blank "input"
        # is rejected with 400 instead of sending None to the LLM.
        serializer = InputSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        query = serializer.validated_data['input']
        res_text = work_chain(query)
        # Remove the fence as an exact prefix/suffix. str.lstrip('```html ')
        # treats its argument as a character set and would also eat leading
        # 'h', 't', 'm', 'l' letters and spaces of the real content.
        res_text = self._strip_code_fence(res_text)
        return Response({'content': res_text})
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test.
    # NOTE(review): this module uses a relative import (`from .utils import
    # is_safe_sql`) and reads Django settings. Running it directly as a script
    # will therefore presumably fail before reaching this line. In practice,
    # exercise it from a configured Django shell, e.g.:
    #
    #     from apps.ichat.views2 import work_chain
    #     print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))
    #
    # Those lines are kept only as a comment. As unguarded module-level code,
    # they would issue an LLM request and a SQL query on every import.
    print(work_chain("查询外观检验工段在2025年6月的生产合格数等并形成报告"))