76 lines
3.2 KiB
Python
76 lines
3.2 KiB
Python
import logging
import os
import re
from typing import Any, Dict, List, Optional
from urllib.parse import quote_plus

import requests
from langchain_community.utilities import SQLDatabase
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult
from langchain_experimental.sql import SQLDatabaseChain
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.serializers import CustomLLMrequestSerializer
from server.conf import DATABASES

db_conf = DATABASES['default']

# URL-encode the credentials: characters such as '@', ':' or '/' in the user
# name or password would otherwise break the connection URI.
# The misspelled name ``password_encodeed`` is kept unchanged for backward
# compatibility.
password_encodeed = quote_plus(db_conf['PASSWORD'])
_user_encoded = quote_plus(db_conf['USER'])

# Include the configured port when there is one; leaving it out makes the
# driver use its default port and silently ignore the configured value.
_db_port = db_conf.get('PORT')
_db_netloc = f"{db_conf['HOST']}:{_db_port}" if _db_port else db_conf['HOST']

# NOTE(review): the text-to-SQL chain executes model-generated SQL through
# this connection. ``include_tables`` restricts the schema shown to the model;
# it is not obviously an execution sandbox. TODO confirm this database account
# is read-only.
db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{_user_encoded}:{password_encodeed}@{_db_netloc}/{db_conf['NAME']}",
    include_tables=["enm_mpoint", "enm_mpointstat"],
)

# OpenAI-compatible chat-completions endpoint used by ``CustomLLM``.
# Set ICHAT_MODEL_URL to target a different server without editing code.
model_url = os.environ.get(
    "ICHAT_MODEL_URL", "http://139.159.180.64:11434/v1/chat/completions"
)


# System prompt sent with every request. It tells the model to reply with bare,
# directly executable PostgreSQL: no explanations, no comments, no Markdown
# fences, and table-qualified column names when several tables are involved.
_SYSTEM_PROMPT = (
    "你是一个 SQL 助手,严格遵循以下规则:\n"
    "1. 只返回 PostgreSQL 语法 SQL 语句。\n"
    "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
    "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n"
    "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
)

# Fenced block explicitly tagged ``sql`` (any case). ``\b`` stops tags such as
# ``sqlite`` from matching here and leaking their tail into the query.
_SQL_FENCE_RE = re.compile(r"```sql\b\s*(.*?)```", re.DOTALL | re.IGNORECASE)
# Any fenced block. A language tag (e.g. ``postgresql``) is only dropped when
# it sits alone on the opening fence line, so ```SELECT 1``` stays intact.
_ANY_FENCE_RE = re.compile(r"```(?:[\w+-]*[ \t]*\n)?\s*(.*?)```", re.DOTALL)
# An opening fence with no closing fence, e.g. when the reply was cut short.
_OPEN_FENCE_RE = re.compile(r"^\s*```(?:[\w+-]*[ \t]*\n)?")


class CustomLLM(LLM):
    """LangChain ``LLM`` backed by an OpenAI-compatible chat-completions endpoint.

    Each prompt is sent as the user message together with a fixed system
    prompt that asks for bare PostgreSQL. The reply is cut at the first stop
    sequence and stripped of Markdown code fences before it is returned.

    NOTE(review): if ``SQLDatabaseChain`` also uses this LLM to phrase the
    final answer, that call gets the same SQL-only system prompt. Confirm this
    is intended.
    """

    # URL of the chat-completions endpoint.
    model_url: str
    # Model name sent in the request body.
    model: str = "glm4"
    # Seconds to wait for the HTTP response (passed to ``requests.post``).
    timeout: float = 600

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to the chat-completions endpoint and return cleaned SQL.

        Args:
            prompt: Text sent as the user message.
            stop: Optional stop sequences. They are forwarded to the server as
                ``"stop"`` and also enforced locally, so no text after the
                first stop sequence is returned even if the server ignores it.
            run_manager: LangChain callback manager; accepted for API
                compatibility, not used.
            **kwargs: Extra LangChain invocation arguments; not used.

        Returns:
            The reply, truncated at the first stop sequence, with Markdown
            code fences removed.

        Raises:
            requests.RequestException: On connection failure, timeout or a
                non-2xx HTTP status.
            ValueError: If the body is not JSON or lacks
                ``choices[0].message.content``.
        """
        data: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "stream": False,
        }
        if stop:
            data["stop"] = list(stop)

        response = requests.post(self.model_url, json=data, timeout=self.timeout)
        response.raise_for_status()
        payload = response.json()
        try:
            content = payload["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(
                f"Unexpected response from {self.model_url}: {str(payload)[:500]!r}"
            ) from err

        text = self._truncate_at_stop(content or "", stop)
        return self.strip_sql_markdown(text)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call :meth:`_call` once per prompt, in order, and wrap the results.

        Returns:
            An ``LLMResult`` holding one single-``Generation`` list per prompt.
        """
        generations = [
            [Generation(text=self._call(prompt, stop, run_manager=run_manager, **kwargs))]
            for prompt in prompts
        ]
        return LLMResult(generations=generations)

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Return *text* up to (excluding) the earliest occurrence of any stop sequence."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
        return text[:min(hits)] if hits else text

    def strip_sql_markdown(self, content: str) -> str:
        """Return *content* with Markdown code fences removed.

        A block tagged ``sql`` is preferred. Otherwise the first fenced block
        is used, minus any language tag on its opening line. If there is an
        opening fence but no closing one, only the opening fence is dropped.
        Text without fences is returned stripped of surrounding whitespace.
        """
        match = _SQL_FENCE_RE.search(content) or _ANY_FENCE_RE.search(content)
        if match:
            return match.group(1).strip()
        return _OPEN_FENCE_RE.sub("", content, count=1).strip()

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that distinguish this LLM instance (used by LangChain tracing/caching)."""
        return {"model_url": self.model_url, "model": self.model}

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom_llm"


class QueryLLMview(APIView):
    """Answer a natural-language question by having the LLM write and run SQL.

    The request body is validated with ``CustomLLMrequestSerializer``. Its
    ``prompt`` is passed to a ``SQLDatabaseChain`` built from ``CustomLLM``
    and the module-level ``db``.
    """

    def post(self, request):
        """Run the text-to-SQL chain for the submitted prompt.

        Returns:
            200 with ``{"result": <chain output>}`` on success.
            400 (raised by the serializer) when the body is invalid.
            502 with ``{"error": ...}`` when the model endpoint cannot be
            reached, times out, or returns an HTTP error.
        """
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']

        # NOTE(review): ``prompt`` is caller-controlled, and the chain executes
        # whatever SQL the model generates against ``db``. Make sure that
        # connection cannot modify data (read-only role) — TODO confirm.
        # ``verbose=True`` echoes prompts and query results to stdout.
        llm = CustomLLM(model_url=model_url)
        chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
        try:
            result = chain.invoke(prompt)
        except requests.RequestException:
            # Log the details server-side; don't expose the internal URL to clients.
            logging.getLogger(__name__).exception("LLM request to %s failed", model_url)
            return Response(
                {"error": "The language model service is unavailable."},
                status=status.HTTP_502_BAD_GATEWAY,
            )
        return Response({"result": result})