# factory/apps/ichat/views2.py

import requests
import os
from apps.utils.sql import execute_raw_sql
import json
from apps.utils.tools import MyJSONEncoder
from .utils import is_safe_sql
from rest_framework.views import APIView
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
from rest_framework.exceptions import ParseError
from rest_framework.response import Response
from django.conf import settings
# LLM backend configuration, read from Django settings.  Every value falls
# back to a default when the corresponding setting is not defined.
# NOTE(review): the response parsing in ask() (choices[0].message.content)
# suggests an OpenAI-compatible chat-completions endpoint — confirm.
LLM_URL = getattr(settings, "LLM_URL", "")
API_KEY = getattr(settings, "LLM_API_KEY", "")
# Model name sent with every request.  Overridable via settings.LLM_MODEL so a
# deployment can switch models without a code change; the default keeps the
# previously hard-coded value.
MODEL = getattr(settings, "LLM_MODEL", "qwen14b")
# Built once at import time: changing settings.LLM_API_KEY after this module
# has been imported does not affect the headers sent.
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}
# Directory containing this module; prompt templates live in CUR_DIR/promot/.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))


def load_promot(name):
    """Return the text of the prompt template ``promot/<name>.md``.

    The path is resolved relative to this module's directory.  ("promot" is
    the existing spelling of both the directory and this function and is kept
    so callers and the on-disk layout keep working.)

    Args:
        name: Template file name without the ``.md`` extension.

    Returns:
        The full file contents as a string.

    Raises:
        FileNotFoundError: If no such template exists.
    """
    path = os.path.join(CUR_DIR, 'promot', f'{name}.md')
    # Decode explicitly: the platform default encoding is not guaranteed to be
    # UTF-8 (e.g. on Windows), while the prompts may contain Chinese text.
    # Assumes the template files are saved as UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def ask(input: str, p_name: str, timeout: float = 300) -> str:
    """Send a single-turn chat request to the LLM and return the reply text.

    The system message is the prompt template *p_name* (loaded with
    ``load_promot``); the user message is *input*.  Temperature is fixed at 0.

    Args:
        input: User message content.
        p_name: Name of the prompt template under ``promot/``.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled LLM backend cannot block the calling worker forever.

    Returns:
        The ``content`` of the first choice in the JSON response.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If connecting, or waiting for response data, takes
            longer than *timeout* seconds.
    """
    messages = [
        {"role": "system", "content": load_promot(p_name)},
        {"role": "user", "content": input},
    ]
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0,
    }
    response = requests.post(LLM_URL, headers=HEADERS, json=payload, timeout=timeout)
    # Surface backend failures as a clear HTTP error instead of a confusing
    # KeyError / JSON decode error when parsing an error body below.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def work_chain(input: str) -> str:
    """Answer a natural-language query by generating, running and analysing SQL.

    Steps:
      1. Ask the LLM (prompt ``w_sql``) to turn *input* into a SQL statement.
      2. If the LLM instead replies with the fixed "please start your request
         with 查询 (query) and describe it again" message, return that message.
      3. Refuse the statement unless ``is_safe_sql`` accepts it.
      4. Execute it with ``execute_raw_sql`` and ask the LLM (prompt ``w_ana``)
         to analyse the JSON-serialised result.

    Args:
        input: The user's query text.

    Returns:
        One of the fixed user-facing messages, or the LLM's analysis text.
    """
    retry_msg = '请以 查询 开头,重新描述你的需求'
    # Strip surrounding whitespace so a trailing newline from the model does
    # not defeat the exact comparison with the sentinel reply.
    sql = ask(input, 'w_sql').strip()
    if sql == retry_msg:
        return retry_msg
    # The statement is model-generated and executed verbatim; is_safe_sql is
    # the only gate in this function.  NOTE(review): running it under a
    # read-only database account would add defence in depth.
    if not is_safe_sql(sql):
        return '当前查询存在风险,请重新描述你的需求'
    result = execute_raw_sql(sql)
    return ask(json.dumps(result, cls=MyJSONEncoder, ensure_ascii=False), 'w_ana')
class InputSerializer(serializers.Serializer):
    """Request body accepted by ``WorkChain.post``: one free-text query."""

    input = serializers.CharField(
        label="查询需求",
    )
def _strip_code_fence(text):
    """Remove a Markdown code fence (```html ... ``` or ``` ... ```) around *text*.

    ``str.lstrip``/``str.rstrip`` treat their argument as a *set* of characters,
    so ``lstrip('```html ')`` would also eat leading ``h``/``t``/``m``/``l``,
    spaces and backticks from output that has no fence at all.  This helper
    removes only the literal fence markers, then trims surrounding whitespace.
    """
    text = text.strip()
    for opener in ('```html', '```'):
        if text.startswith(opener):
            text = text[len(opener):]
            break
    if text.endswith('```'):
        text = text[:-len('```')]
    return text.strip()


class WorkChain(APIView):
    """API endpoint that answers a natural-language query via ``work_chain``."""

    @swagger_auto_schema(
        operation_summary="查询工作",
        request_body=InputSerializer)
    def post(self, request):
        """Run ``work_chain`` on the posted ``input`` and return its text.

        The LLM's reply may be wrapped in a Markdown code fence; the fence is
        removed before the text is returned as ``{'content': ...}``.

        Raises:
            ParseError: If ``settings.LLM_ENABLED`` is not set or is falsy.
            ValidationError: If ``input`` is missing or blank (HTTP 400).
        """
        if not getattr(settings, "LLM_ENABLED", False):
            raise ParseError('LLM功能未启用')
        # Validate with the serializer that is advertised in the Swagger schema
        # instead of reading request.data directly; otherwise a missing input
        # would be forwarded to the LLM as None.
        serializer = InputSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        query = serializer.validated_data['input']
        content = _strip_code_fence(work_chain(query))
        return Response({'content': content})
if __name__ == "__main__":
    # Manual smoke test.  Running this file as a plain script fails on the
    # relative import (`from .utils import ...`), and the module reads Django
    # settings at import time, so in practice use `python manage.py shell`:
    #     from apps.ichat.views2 import work_chain
    #     print(work_chain('查询外观检验工段在2025年6月的生产合格数等并形成报告'))
    print(work_chain("查询外观检验工段在2025年6月的生产合格数等并形成报告"))