"""HTTP endpoint that forwards prompts to a chat-completions model.

In ``sql`` mode the prompt is routed through a LangChain ``SQLDatabaseChain``
bound to the project's default database; in any other mode the model's reply
is returned directly.
"""
import logging
import re
from typing import Any, List, Optional
from urllib.parse import quote_plus

import requests
from langchain_community.utilities import SQLDatabase
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult
from langchain_experimental.sql import SQLDatabaseChain
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.serializers import CustomLLMrequestSerializer
from server.conf import DATABASES

logger = logging.getLogger(__name__)

db_conf = DATABASES['default']
# The password must be URL-encoded because it may contain special characters
# such as "@", which would otherwise break the connection URI.
password_encodeed = quote_plus(db_conf['PASSWORD'])

db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{db_conf['USER']}:{password_encodeed}@{db_conf['HOST']}/{db_conf['NAME']}",
    include_tables=["enm_mpoint", "enm_mpointstat"],
)

# Alternative endpoint kept for reference:
# model_url = "http://14.22.88.72:11025/v1/chat/completions"
model_url = "http://139.159.180.64:11434/v1/chat/completions"

# Matches a Markdown code fence, with or without the "sql" language tag, and
# captures its contents.
_SQL_FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
    """Return *text* cut off at the earliest occurrence of any stop sequence."""
    if not stop:
        return text
    cut = len(text)
    for token in stop:
        if not token:
            continue
        position = text.find(token)
        if position != -1:
            cut = min(cut, position)
    return text[:cut]


class CustomLLM(LLM):
    """LangChain LLM wrapper around a chat-completions HTTP endpoint.

    Attributes are pydantic fields:
        model_url: URL the completion request is POSTed to.
        mode: ``'sql'`` selects the SQL-only system prompt and strips Markdown
            fences from the reply; any other value selects the chat prompt.
    """

    model_url: str
    mode: str = 'chat'

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to the model and return the cleaned reply text.

        Args:
            prompt: User prompt placed in the ``user`` message.
            stop: Optional stop sequences; the reply is truncated at the first
                occurrence of any of them.
            run_manager: Accepted for compatibility with the LangChain base
                signature; not used.
            **kwargs: Accepted for compatibility; not used.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        data = {
            "model": "glm4",
            "messages": self.build_message(prompt),
            "stream": False,
        }
        response = requests.post(self.model_url, json=data, timeout=600)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        logger.debug("content--- %s", content)
        # Honour stop sequences on the client side so the behaviour does not
        # depend on what the serving endpoint supports.
        content = _truncate_at_stop(content, stop)
        if self.mode == 'sql':
            return self.strip_sql_markdown(content)
        return content.strip()

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run ``_call`` on every prompt and wrap the replies in an ``LLMResult``."""
        generations = [
            [Generation(text=self._call(prompt, stop, run_manager=run_manager, **kwargs))]
            for prompt in prompts
        ]
        return LLMResult(generations=generations)

    def strip_sql_markdown(self, content: str) -> str:
        """Return the contents of the first ```sql or bare ``` fence in *content*.

        If no fence is present, the whitespace-stripped input is returned.
        """
        match = _SQL_FENCE_RE.search(content)
        if match:
            return match.group(1).strip()
        return content.strip()

    def build_message(self, prompt: str) -> list:
        """Build the system + user message list for the current mode."""
        if self.mode == 'sql':
            system_prompt = (
                "你是一个 SQL 助手,严格遵循以下规则:\n"
                "1. 只返回 PostgreSQL 语法 SQL 语句。\n"
                "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
                "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n"
                "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
            )
        else:
            system_prompt = "你是一个聊天助手,请根据用户的问题,提供简洁明了的答案。"
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "custom_llm"


class QueryLLMview(APIView):
    """POST endpoint that answers a prompt in ``chat`` or ``sql`` mode."""

    def post(self, request):
        """Validate the request and return the model's answer under ``result``.

        The payload is validated by ``CustomLLMrequestSerializer``; ``mode``
        defaults to ``'chat'`` when absent.
        """
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']
        mode = serializer.validated_data.get('mode', 'chat')
        llm = CustomLLM(model_url=model_url, mode=mode)
        logger.debug("prompt--- %s %s", prompt, mode)
        if mode == 'sql':
            # NOTE(review): the chain runs model-generated SQL through the
            # default database connection; confirm that the DB role is
            # restricted appropriately for user-supplied prompts.
            chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
            result = chain.invoke(prompt)
        else:
            # Use the public Runnable API instead of the private ``_call``.
            result = llm.invoke(prompt)
        return Response({"result": result})