# File-viewer header residue: Python source, 79 lines, 3.2 KiB.
import os
import re
from typing import Any, List, Optional
from urllib.parse import quote, quote_plus

import requests
# fastapi
from fastapi import FastAPI
from langchain_community.utilities import SQLDatabase
from langchain_core.language_models import LLM
from langchain_core.outputs import LLMResult, Generation
from langchain_experimental.sql import SQLDatabaseChain
from pydantic import Field
from pydantic import BaseModel
from rest_framework.response import Response
from rest_framework.views import APIView

from server.conf import DATABASES
from serializers import CustomLLMrequestSerializer


# Connection settings for the default database (from server.conf.DATABASES).
db_conf = DATABASES['default']


def _build_db_uri(conf):
    """Return a ``postgresql+psycopg2`` SQLAlchemy URI built from *conf*.

    *conf* is indexed with ``user``, ``password``, ``host`` and ``name``; an
    optional truthy ``port`` (read with ``conf.get``, so presumably a dict —
    verify) is appended to the host.

    The user name and password are percent-encoded with ``quote(..., safe="")``
    because they may contain characters such as ``@``, ``:`` or ``/``.
    ``quote_plus`` is avoided: it encodes a space as ``+``, and parsers that
    decode with ``unquote`` read that back as a literal plus sign.
    """
    user = quote(conf['user'], safe="")
    password = quote(conf['password'], safe="")
    port = conf.get('port')
    host = f"{conf['host']}:{port}" if port else conf['host']
    return f"postgresql+psycopg2://{user}:{password}@{host}/{conf['name']}"


# NOTE(review): include_tables limits which tables SQLDatabase describes to the
# LLM; it does not prevent generated SQL from touching other tables. Run this
# connection under a read-only database role.
db = SQLDatabase.from_uri(
    _build_db_uri(db_conf),
    include_tables=["enm_mpoint", "enm_mpointstat"],
)

# OpenAI-compatible chat-completions endpoint used by CustomLLM; the
# LLM_MODEL_URL environment variable overrides the default below.
model_url = os.environ.get(
    "LLM_MODEL_URL", "http://139.159.180.64:11434/v1/chat/completions"
)


class CustomLLM(LLM):
    """LangChain ``LLM`` that asks an OpenAI-compatible chat endpoint for SQL.

    Each prompt is sent as the user message after a fixed system prompt that
    instructs the model to answer with bare PostgreSQL. The reply has Markdown
    code fences removed and is cut at the first stop sequence, so it can be
    passed straight to ``SQLDatabaseChain``.
    """

    # Full URL of the chat-completions endpoint; requests are POSTed as JSON.
    model_url: str
    # Value sent in the payload's "model" field.
    model: str = "glm4"
    # Timeout (seconds) handed to requests.post.
    request_timeout: float = 600

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """POST *prompt* to ``model_url`` and return the cleaned SQL text.

        Args:
            prompt: The user message (for SQLDatabaseChain, its rendered prompt).
            stop: Optional stop sequences. They are forwarded in the request
                payload and also enforced locally by cutting the reply at the
                earliest one, since SQLDatabaseChain relies on them.
            run_manager: Accepted for LangChain compatibility; unused.
            **kwargs: Accepted for LangChain compatibility; unused.

        Returns:
            The reply with Markdown fences removed, truncated at the first stop
            sequence, and stripped of surrounding whitespace.

        Raises:
            requests.HTTPError: The endpoint answered with an error status.
            requests.RequestException: Connection failure or timeout.
            KeyError, IndexError: The JSON lacks ``choices[0].message.content``.
        """
        data = {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    # Each rule ends with "\n" so the rules stay on separate lines.
                    "content": "你是一个 SQL 助手,严格遵循以下规则:\n"
                                "1. 只返回 PostgreSQL 语法 SQL 语句。\n"
                                "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
                                "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n"
                                "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
                },
                {"role": "user", "content": prompt},
            ],
            "stream": False,
        }
        if stop:
            data["stop"] = list(stop)
        response = requests.post(self.model_url, json=data, timeout=self.request_timeout)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        return self._truncate_at_stop(self.strip_sql_markdown(content), stop)

    def _generate(
        self,
        prompts: list,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call ``_call`` once per prompt; one Generation list per prompt."""
        return LLMResult(
            generations=[
                [Generation(text=self._call(prompt, stop, run_manager, **kwargs))]
                for prompt in prompts
            ]
        )

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut *text* at the earliest occurrence of any non-empty stop sequence."""
        for token in stop or ():
            if not token:
                continue  # "" would match at index 0 and erase everything
            idx = text.find(token)
            if idx != -1:
                text = text[:idx]
        return text.strip()

    def strip_sql_markdown(self, content: str) -> str:
        """Return the SQL wrapped in Markdown code fences in *content*.

        Tried in order:
        1. a closed fence tagged with an SQL dialect (```sql, ```postgresql,
           ```pgsql, ...; case-insensitive) -> its body;
        2. any other closed fence -> its body;
        3. otherwise a lone opening fence (optionally SQL-tagged) at the start
           and/or a lone closing fence at the end are removed.
        The result is stripped of surrounding whitespace.
        """
        # \b keeps e.g. "sqlite" from being read as tag "sql" + body "ite".
        sql_tag = r"(?:sql|sqlite|postgres(?:ql)?|pgsql|psql|plpgsql)\b"
        match = re.search(
            rf"```[ \t]*{sql_tag}\s*(.*?)```", content, re.DOTALL | re.IGNORECASE
        )
        if match:
            return match.group(1).strip()
        match = re.search(r"```\s*(.*?)```", content, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Unclosed fence, e.g. a reply truncated before its closing ```.
        text = content.strip()
        opening = re.match(rf"```(?:[ \t]*{sql_tag})?", text, re.IGNORECASE)
        if opening:
            text = text[opening.end():]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @property
    def _identifying_params(self) -> dict:
        """Parameters reported to LangChain as identifying this LLM instance."""
        return {"model_url": self.model_url, "model": self.model}

    @property
    def _llm_type(self) -> str:
        """Short type tag reported to LangChain."""
        return "custom_llm"


class QueryLLMview(APIView):
    """Answer a natural-language question by generating and running SQL.

    The POST body is validated with ``CustomLLMrequestSerializer``. Its
    ``prompt`` field is run through a ``SQLDatabaseChain`` built over the
    module-level ``db`` and a ``CustomLLM`` pointed at ``model_url``.
    """

    def post(self, request):
        """Run the SQL chain on the validated ``prompt`` and return its output.

        Returns:
            A Response holding the chain's output mapping, or a 502 Response
            with ``{"detail": ...}`` when the model endpoint request fails.

        Raises:
            rest_framework.exceptions.ValidationError: The request body is
                invalid (DRF renders it as a 400 response).
        """
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']

        llm = CustomLLM(model_url=model_url)
        # NOTE(review): the model-generated SQL is executed against ``db``
        # without review; ensure that connection uses a read-only role.
        chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
        try:
            result = chain.invoke(prompt)
        except requests.RequestException:
            # Upstream model endpoint failed; don't leak its URL/error text.
            return Response({"detail": "LLM service request failed."}, status=502)
        # APIView must return a Response; a bare dict fails DRF's response check.
        return Response(result)