87 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			87 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
import logging
import re
from typing import Any, Optional
from urllib.parse import quote_plus

import requests
from langchain_community.utilities import SQLDatabase
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult
from langchain_experimental.sql import SQLDatabaseChain
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.ichat.serializers import CustomLLMrequestSerializer
from server.conf import DATABASES
 | 
						||
 | 
						||
 | 
						||
db_conf = DATABASES['default']

# Credentials are URL-encoded because they may contain URI-reserved characters
# (e.g. "@" or ":") that would otherwise corrupt the connection string.
password_encodeed = quote_plus(db_conf['PASSWORD'])
_db_user_encoded = quote_plus(db_conf['USER'])

# Include the configured port; previously it was dropped, so a non-default PORT
# was silently ignored. An empty or missing PORT keeps the driver default.
_db_port = db_conf.get('PORT')
_db_host = f"{db_conf['HOST']}:{_db_port}" if _db_port else db_conf['HOST']

# NOTE(review): this is evaluated at import time. SQLDatabase.from_uri presumably
# connects to inspect the schema, so importing this module likely requires a
# reachable database — confirm. The chain is restricted via include_tables to
# the two tables listed below.
db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{_db_user_encoded}:{password_encodeed}@{_db_host}/{db_conf['NAME']}",
    include_tables=["enm_mpoint", "enm_mpointstat"],
)

# OpenAI-style chat-completions endpoint that QueryLLMview hands to CustomLLM.
model_url = "http://139.159.180.64:11434/v1/chat/completions"
 | 
						||
 | 
						||
# Matches a Markdown code fence that is tagged sql/postgresql/postgres (any case)
# or untagged, capturing its body. A missing closing fence runs to end of text.
_SQL_FENCE_RE = re.compile(
    r"```(?:postgresql|postgres|sql)?\s*(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)


class CustomLLM(LLM):
    """LangChain ``LLM`` backed by an OpenAI-style chat-completions endpoint.

    Each prompt is sent as a single-turn conversation: a system message chosen
    by ``mode``, followed by the user prompt. The assistant's reply is returned
    as plain text.

    Attributes:
        model_url: Full URL that the request body is POSTed to.
        mode: ``'sql'`` asks for bare PostgreSQL and strips Markdown code fences
            from the reply; any other value uses the general chat prompt.
        model: Model name sent in the request body.
        timeout: Seconds to wait for the HTTP response.
    """

    model_url: str
    mode: str = 'chat'
    # Formerly hard-coded in the request body; the defaults keep the old behaviour.
    model: str = "glm4"
    timeout: float = 600

    def _call(
        self,
        prompt: str,
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to the endpoint and return the cleaned reply text.

        Args:
            prompt: User message; it is wrapped with a mode-specific system prompt.
            stop: Optional stop sequences (LangChain chains may pass these). The
                reply is cut at the earliest occurrence of any of them.
            run_manager: LangChain callback manager; accepted for interface
                compatibility and not used.
            **kwargs: Extra LangChain invocation arguments; ignored.

        Returns:
            The reply with surrounding whitespace removed. In ``'sql'`` mode, the
            body of a Markdown code fence is extracted when one is present.

        Raises:
            requests.RequestException: On connection errors or timeouts.
                ``requests.HTTPError`` is raised for non-2xx responses.
            KeyError, IndexError: If the JSON lacks ``choices[0].message.content``.
        """
        payload = {
            "model": self.model,
            "messages": self.build_message(prompt),
            "stream": False,
        }
        response = requests.post(self.model_url, json=payload, timeout=self.timeout)
        response.raise_for_status()
        # Some servers send "content": null; treat that as an empty reply
        # instead of crashing on .strip().
        content = response.json()["choices"][0]["message"]["content"] or ""
        logging.getLogger(__name__).debug("LLM raw content: %s", content)
        content = self._truncate_at_stop(content, stop)
        return self.strip_sql_markdown(content) if self.mode == 'sql' else content.strip()

    def _generate(
        self,
        prompts: list[str],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run :meth:`_call` once per prompt, in order, and wrap the replies."""
        return LLMResult(
            generations=[
                [Generation(text=self._call(prompt, stop=stop, run_manager=run_manager, **kwargs))]
                for prompt in prompts
            ]
        )

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[list[str]]) -> str:
        """Return *text* cut at the earliest occurrence of any non-empty stop sequence."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
        return text[:min(hits)] if hits else text

    def strip_sql_markdown(self, content: str) -> str:
        """Return the body of the first Markdown code fence, or *content* stripped.

        Fences may be tagged ``sql``/``postgresql``/``postgres`` (any case) or be
        untagged. A missing closing fence is tolerated.
        """
        match = _SQL_FENCE_RE.search(content)
        return match.group(1).strip() if match else content.strip()

    def build_message(self, prompt: str) -> list:
        """Build the chat ``messages`` list: a mode-specific system prompt, then *prompt*.

        The ``'sql'`` system prompt tells the model to return only PostgreSQL,
        with no explanations or Markdown fences, and to qualify columns with
        table names. Any other mode uses a concise chat-assistant prompt.
        """
        if self.mode == 'sql':
            system_prompt = (
                "你是一个 SQL 助手,严格遵循以下规则:\n"
                "1. 只返回 PostgreSQL 语法 SQL 语句。\n"
                "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
                "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n"
                "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
            )
        else:
            system_prompt = "你是一个聊天助手,请根据用户的问题,提供简洁明了的答案。"
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Parameters that distinguish this LLM configuration.

        Included so that ``'sql'`` and ``'chat'`` replies for the same prompt
        text are not treated as interchangeable.
        """
        return {"model_url": self.model_url, "model": self.model, "mode": self.mode}

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain reports for this LLM implementation."""
        return "custom_llm"
 | 
						||
    
 | 
						||
 | 
						||
class QueryLLMview(APIView):
    """Answer a user prompt with ``CustomLLM``, either as chat or via text-to-SQL.

    The request body is validated by ``CustomLLMrequestSerializer``. It carries
    ``prompt`` and an optional ``mode``, which defaults to ``'chat'``. The
    response body is ``{"result": ...}``.
    """

    def post(self, request):
        """Run the validated prompt and return ``{"result": ...}``.

        In ``'sql'`` mode, ``result`` is whatever ``SQLDatabaseChain.invoke``
        returns (NOTE(review): presumably a dict rather than a string — confirm
        what clients expect). In every other mode it is the model's text reply.
        An invalid body raises a DRF validation error (``raise_exception=True``).
        """
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']
        mode = serializer.validated_data.get('mode', 'chat')
        # A debug log replaces print() so user prompts do not go to stdout
        # unconditionally.
        logging.getLogger(__name__).debug("LLM request mode=%s prompt=%s", mode, prompt)

        llm = CustomLLM(model_url=model_url, mode=mode)
        if mode == 'sql':
            # SECURITY NOTE(review): nothing in this file limits the generated
            # SQL to read-only statements, and it runs with the credentials from
            # DATABASES['default']. Consider a read-only DB role and/or
            # restricting who can call this endpoint.
            chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
            result = chain.invoke(prompt)
        else:
            # Use the public Runnable API rather than the private _call; it
            # returns the same text and also goes through LangChain callbacks.
            result = llm.invoke(prompt)
        return Response({"result": result})