factory/apps/ichat/view_bak.py

87 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from urllib.parse import quote_plus

import requests
from langchain_community.utilities import SQLDatabase
from langchain_core.language_models import LLM
from langchain_core.outputs import LLMResult, Generation
from langchain_experimental.sql import SQLDatabaseChain
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.serializers import CustomLLMrequestSerializer
from server.conf import DATABASES
db_conf = DATABASES['default']
# Credentials must be URL-encoded: characters such as '@', ':' or '/' in the
# password (or user name) would otherwise break the connection URI.
password_encodeed = quote_plus(db_conf['PASSWORD'])
_user_encoded = quote_plus(db_conf['USER'])
# Include the port only when one is configured; an empty or missing PORT lets
# the driver fall back to its default port.
_db_port = db_conf.get('PORT')
_db_netloc = f"{db_conf['HOST']}:{_db_port}" if _db_port else db_conf['HOST']
db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{_user_encoded}:{password_encodeed}@{_db_netloc}/{db_conf['NAME']}",
    # Limit the SQLDatabase (and the schema it describes to the LLM) to these tables.
    include_tables=["enm_mpoint", "enm_mpointstat"],
)

# Chat-completions endpoint used by CustomLLM (expects an OpenAI-style
# /v1/chat/completions API).
# Alternative endpoint used previously: http://14.22.88.72:11025/v1/chat/completions
model_url = "http://139.159.180.64:11434/v1/chat/completions"
class CustomLLM(LLM):
    """LangChain ``LLM`` backed by an OpenAI-style chat-completions HTTP endpoint.

    Each prompt is sent as a single-turn conversation (system prompt + user
    message) to ``model_url``, and the assistant reply text is returned.

    Attributes:
        model_url: Full URL of the chat-completions endpoint to POST to.
        mode: ``'sql'`` selects an SQL-only system prompt and strips any
            Markdown code fence from the reply. Any other value (default
            ``'chat'``) uses a general chat system prompt and only trims
            surrounding whitespace.
    """

    model_url: str
    mode: str = 'chat'

    def _call(self, prompt: str, stop: list = None) -> str:
        """Send *prompt* to the model and return the cleaned reply text.

        Args:
            prompt: The user message.
            stop: Optional stop sequences. They are forwarded to the endpoint
                and also enforced locally: the reply is cut at the earliest
                occurrence of any non-empty stop sequence, so the behavior
                holds even if the server ignores the field.

        Returns:
            In ``'sql'`` mode, the reply with a surrounding Markdown code fence
            removed (see :meth:`strip_sql_markdown`). Otherwise, the reply with
            surrounding whitespace stripped.

        Raises:
            requests.HTTPError: If the endpoint returns an error status.
        """
        data = {
            "model": "glm4",
            "messages": self.build_message(prompt),
            "stream": False,
        }
        if stop:
            data["stop"] = stop
        response = requests.post(self.model_url, json=data, timeout=600)
        response.raise_for_status()
        # Guard against a null "content" so the string handling below cannot fail.
        content = response.json()["choices"][0]["message"]["content"] or ""
        logging.getLogger(__name__).debug("LLM content: %s", content)

        # Enforce stop sequences locally: keep only the text before the
        # earliest occurrence of any of them. Empty strings are skipped
        # because "".find("") is 0 and would discard everything.
        cut = len(content)
        for token in stop or ():
            pos = content.find(token) if token else -1
            if pos != -1:
                cut = min(cut, pos)
        content = content[:cut]

        if self.mode == 'sql':
            return self.strip_sql_markdown(content)
        return content.strip()

    def _generate(self, prompts: list, stop: list = None) -> LLMResult:
        """Call :meth:`_call` for each prompt and wrap the replies in an ``LLMResult``.

        One single-``Generation`` list is produced per prompt, in input order.
        """
        return LLMResult(
            generations=[[Generation(text=self._call(prompt, stop))] for prompt in prompts]
        )

    def strip_sql_markdown(self, content: str) -> str:
        """Return the body of the first Markdown code fence in *content*.

        Both untagged fences and fences tagged ``sql``, ``postgresql``,
        ``postgres`` or ``pgsql`` (case-insensitive) are recognised. If no
        complete fence is found, *content* is returned with surrounding
        whitespace stripped.
        """
        import re

        # The optional tag is tried before the body is captured, so
        # "```postgresql\nSELECT 1\n```" yields "SELECT 1" rather than
        # "postgresql\nSELECT 1".
        match = re.search(
            r"```(?:sql|postgresql|postgres|pgsql)?\s*(.*?)```",
            content,
            re.DOTALL | re.IGNORECASE,
        )
        return match.group(1).strip() if match else content.strip()

    def build_message(self, prompt: str) -> list:
        """Build the chat ``messages`` list: a mode-specific system prompt plus *prompt*.

        In ``'sql'`` mode, the system prompt (Chinese) tells the model to return
        only directly executable PostgreSQL, with no explanations, comments or
        Markdown fences, and to prefix columns with table names. Otherwise it
        asks for concise chat answers.
        """
        if self.mode == 'sql':
            system_prompt = (
                "你是一个 SQL 助手,严格遵循以下规则:\n"
                "1. 只返回 PostgreSQL 语法 SQL 语句。\n"
                "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
                "3. 输出必须是纯 SQL且可直接执行无需任何额外处理。\n"
                "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
            )
        else:
            system_prompt = "你是一个聊天助手,请根据用户的问题,提供简洁明了的答案。"
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom_llm"
class QueryLLMview(APIView):
    """Answer a prompt with the LLM, either as free-form chat or via text-to-SQL.

    The POST body is validated by ``CustomLLMrequestSerializer``. It must
    provide ``prompt``; ``mode`` is optional and defaults to ``'chat'``.

    * ``mode == 'sql'``: a ``SQLDatabaseChain`` over the module-level ``db``
      turns the prompt into SQL and runs it. The value returned by
      ``chain.invoke`` is sent back. NOTE(review): this is presumably a dict
      rather than a plain string; confirm clients expect that.
    * any other mode: the prompt goes straight to the LLM and its reply text
      is sent back.

    Response body: ``{"result": <chain output or reply text>}``.
    """

    def post(self, request):
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']
        mode = serializer.validated_data.get('mode', 'chat')
        logging.getLogger(__name__).debug("prompt=%r mode=%r", prompt, mode)

        llm = CustomLLM(model_url=model_url, mode=mode)
        if mode == 'sql':
            # SECURITY NOTE(review): the SQL executed here is generated by the
            # LLM from user-supplied text. ``include_tables`` on ``db`` limits
            # the schema shown to the model, but it should not be relied on to
            # restrict what the generated statement may do. Consider running
            # this connection under a read-only database role.
            chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
            result = chain.invoke(prompt)
        else:
            # Public Runnable API. It runs LangChain callbacks and returns the
            # same reply text as calling the private _call directly.
            result = llm.invoke(prompt)
        return Response({"result": result})