feat:ichat 修改接口 去掉langchain
This commit is contained in:
parent
89c8cac7c1
commit
a5a862f7fb
|
@ -14,4 +14,4 @@ class Message(BaseModel):
|
||||||
"""
|
"""
|
||||||
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, verbose_name='对话')
|
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, verbose_name='对话')
|
||||||
content = models.TextField(verbose_name='消息内容')
|
content = models.TextField(verbose_name='消息内容')
|
||||||
role = models.CharField("角色", max_length=10, help_text="system/user")
|
role = models.CharField("角色", max_length=10, default='user', help_text="system/user")
|
||||||
|
|
|
@ -1,4 +1,18 @@
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
from .models import Conversation, Message
|
||||||
|
from apps.utils.constants import EXCLUDE_FIELDS
|
||||||
|
|
||||||
class CustomLLMrequestSerializer(serializers.Serializer):
|
|
||||||
prompt = serializers.CharField()
|
class MessageSerializer(serializers.ModelSerializer):
    """Model serializer for a single chat ``Message``."""

    class Meta:
        # Columns that are never writable through the API come from the
        # shared EXCLUDE_FIELDS constant.
        read_only_fields = EXCLUDE_FIELDS
        model = Message
        fields = ['id', 'conversation', 'mode', 'content', 'role']
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationSerializer(serializers.ModelSerializer):
    """Model serializer for a ``Conversation`` with its messages nested read-only."""

    # Message.conversation is declared without a related_name, so the reverse
    # accessor on Conversation is Django's default ``message_set``. Without an
    # explicit source DRF would look up ``conversation.messages`` and fail.
    messages = MessageSerializer(many=True, read_only=True, source='message_set')

    class Meta:
        model = Conversation
        fields = ['id', 'title', 'messages']
        read_only_fields = EXCLUDE_FIELDS
|
|
@ -1,8 +1,10 @@
|
||||||
|
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
from apps.ichat.views import QueryLLMview
|
from apps.ichat.views import QueryLLMview, ConversationView
|
||||||
|
|
||||||
API_BASE_URL = 'api/llm/ichat/'
|
API_BASE_URL = 'api/ichat/'
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path(API_BASE_URL + 'query/', QueryLLMview.as_view(), name='llm_query'),
|
path(API_BASE_URL + 'query/', QueryLLMview.as_view(), name='llm_query'),
|
||||||
|
path(API_BASE_URL + 'conversation/', ConversationView.as_view(), name='conversation')
|
||||||
|
|
||||||
]
|
]
|
|
@ -0,0 +1,87 @@
|
||||||
|
import requests
|
||||||
|
from langchain_core.language_models import LLM
|
||||||
|
from langchain_core.outputs import LLMResult, Generation
|
||||||
|
from langchain_experimental.sql import SQLDatabaseChain
|
||||||
|
from langchain_community.utilities import SQLDatabase
|
||||||
|
from server.conf import DATABASES
|
||||||
|
from apps.ichat.serializers import CustomLLMrequestSerializer
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
from rest_framework.response import Response
|
||||||
|
|
||||||
|
|
||||||
|
db_conf = DATABASES['default']
# Credentials must be URL-encoded because they may contain reserved
# characters such as "@".
user_encoded = quote_plus(db_conf['USER'])
password_encodeed = quote_plus(db_conf['PASSWORD'])

# Append the port only when one is configured; otherwise the driver default
# is used, which is what the previous URI always did.
db_host = db_conf['HOST']
if db_conf.get('PORT'):
    db_host = f"{db_host}:{db_conf['PORT']}"

# Only a subset of tables is exposed to the SQL chain.
db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{user_encoded}:{password_encodeed}@{db_host}/{db_conf['NAME']}",
    include_tables=["enm_mpoint", "enm_mpointstat"],
)
# model_url = "http://14.22.88.72:11025/v1/chat/completions"
model_url = "http://139.159.180.64:11434/v1/chat/completions"
|
||||||
|
|
||||||
|
class CustomLLM(LLM):
    """LangChain LLM wrapper around an OpenAI-style ``/chat/completions`` endpoint.

    ``mode`` selects the system prompt and the post-processing of the reply:
    ``'sql'`` asks for bare PostgreSQL and strips Markdown fences from the
    answer; any other value behaves as a plain chat assistant.
    """

    # URL of the chat-completions endpoint that requests are POSTed to.
    model_url: str
    # 'sql' or 'chat' (default).
    mode: str = 'chat'

    def _call(self, prompt: str, stop: list = None) -> str:
        """Send ``prompt`` to the model and return the cleaned reply text.

        ``stop`` is accepted for interface compatibility and is not forwarded
        to the endpoint. Raises ``requests.HTTPError`` on a non-2xx response.
        """
        import logging

        data = {
            "model": "glm4",
            "messages": self.build_message(prompt),
            "stream": False,
        }
        response = requests.post(self.model_url, json=data, timeout=600)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        # Debug-level logging instead of print(), so model output does not
        # end up on the server's stdout.
        logging.getLogger(__name__).debug("LLM content: %s", content)
        if self.mode == 'sql':
            return self.strip_sql_markdown(content)
        return content.strip()

    def _generate(self, prompts: list, stop: list = None) -> LLMResult:
        """Run ``_call`` for every prompt and wrap the texts in an ``LLMResult``."""
        generations = [[Generation(text=self._call(prompt, stop))] for prompt in prompts]
        return LLMResult(generations=generations)

    def strip_sql_markdown(self, content: str) -> str:
        """Return the text inside a ```sql or bare ``` fence, else ``content`` stripped."""
        import re

        # Prefer an explicitly tagged ```sql fence.
        match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()
        # Fall back to an untagged ``` fence, which models emit just as often.
        match = re.search(r"```\s*(.*?)```", content, re.DOTALL)
        if match:
            return match.group(1).strip()
        return content.strip()

    def build_message(self, prompt: str) -> list:
        """Build the system + user message list for the current ``mode``."""
        if self.mode == 'sql':
            system_prompt = (
                "你是一个 SQL 助手,严格遵循以下规则:\n"
                "1. 只返回 PostgreSQL 语法 SQL 语句。\n"
                "2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
                "3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。\n"
                "4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
            )
        else:
            system_prompt = "你是一个聊天助手,请根据用户的问题,提供简洁明了的答案。"
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "custom_llm"
|
||||||
|
|
||||||
|
|
||||||
|
class QueryLLMview(APIView):
    """POST endpoint that forwards a prompt to the LLM in chat or SQL mode."""

    def post(self, request):
        """Validate the prompt, run it in the requested mode and return the result.

        Request body: ``prompt`` (required) and ``mode`` (optional, defaults
        to ``'chat'``; ``'sql'`` routes the prompt through ``SQLDatabaseChain``).
        """
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']
        # The serializer only declares `prompt` (as far as its visible
        # definition shows), so `mode` never reaches validated_data and the
        # SQL branch was unreachable. Read it from the raw payload instead.
        mode = request.data.get('mode', 'chat')
        llm = CustomLLM(model_url=model_url, mode=mode)
        if mode == 'sql':
            chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
            result = chain.invoke(prompt)
        else:
            result = llm._call(prompt)
        return Response({"result": result})
|
|
@ -1,76 +1,173 @@
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.language_models import LLM
|
import psycopg2
|
||||||
from langchain_core.outputs import LLMResult, Generation
|
|
||||||
from langchain_experimental.sql import SQLDatabaseChain
|
|
||||||
from langchain_community.utilities import SQLDatabase
|
|
||||||
from server.conf import DATABASES
|
|
||||||
from apps.ichat.serializers import CustomLLMrequestSerializer
|
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
from urllib.parse import quote_plus
|
from apps.ichat.serializers import MessageSerializer, ConversationSerializer
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
from ichat.models import Conversation, Message
|
||||||
|
from rest_framework.generics import get_object_or_404
|
||||||
|
import os

# LLM endpoint configuration.
#
# Credentials must never be committed to the repository: the API key is read
# from the environment, and any key that was previously hard-coded here
# should be treated as leaked and revoked.
#
# Backends that have been used with this module (all OpenAI-compatible):
#   - DashScope:   https://dashscope.aliyuncs.com/compatible-mode/v1  (qwen-plus)
#   - self-hosted: http://106.0.4.200:9000/v1  (Qwen/Qwen2.5-14B-Instruct)
#   - OpenRouter:  https://openrouter.ai/api/v1
#                  (google/gemini-2.0-flash-exp:free, deepseek/deepseek-chat-v3-0324:free)
# Switch between them with ICHAT_LLM_API_BASE / ICHAT_LLM_MODEL.
API_KEY = os.environ.get("ICHAT_LLM_API_KEY", "")
API_BASE = os.environ.get("ICHAT_LLM_API_BASE", "https://openrouter.ai/api/v1")
MODEL = os.environ.get("ICHAT_LLM_MODEL", "google/gemini-2.0-flash-exp:free")
|
||||||
|
|
||||||
|
|
||||||
db_conf = DATABASES['default']
|
TABLES = ["enm_mpoint", "enm_mpointstat", "enm_mplogx"] # 如果整个数据库全都给模型,准确率下降,所以只给模型部分表
|
||||||
# 密码需要 URL 编码(因为有特殊字符如 @)
|
# 数据库连接
|
||||||
password_encodeed = quote_plus(db_conf['PASSWORD'])
|
def connect_db():
    """Open and return a new psycopg2 connection to the ``default`` database.

    The connection parameters are read from ``server.conf.DATABASES``. The
    caller owns the returned connection and is responsible for closing it.
    """
    from server.conf import DATABASES

    settings = DATABASES['default']
    return psycopg2.connect(
        host=settings['HOST'],
        port=settings['PORT'],
        user=settings['USER'],
        password=settings['PASSWORD'],
        database=settings['NAME'],
    )
|
||||||
|
|
||||||
db = SQLDatabase.from_uri(f"postgresql+psycopg2://{db_conf['USER']}:{password_encodeed}@{db_conf['HOST']}/{db_conf['NAME']}", include_tables=["enm_mpoint", "enm_mpointstat"])
|
def get_schema_text(conn, table_names:list):
|
||||||
# model_url = "http://14.22.88.72:11025/v1/chat/completions"
|
cur = conn.cursor()
|
||||||
model_url = "http://139.159.180.64:11434/v1/chat/completions"
|
query = """
|
||||||
|
SELECT
|
||||||
|
table_name, column_name, data_type
|
||||||
|
FROM
|
||||||
|
information_schema.columns
|
||||||
|
WHERE
|
||||||
|
table_schema = 'public'
|
||||||
|
and table_name in %s;
|
||||||
|
"""
|
||||||
|
cur.execute(query, (tuple(table_names), ))
|
||||||
|
|
||||||
class CustomLLM(LLM):
|
schema = {}
|
||||||
model_url: str
|
for table_name, column_name, data_type in cur.fetchall():
|
||||||
def _call(self, prompt: str, stop: list = None) -> str:
|
if table_name not in schema:
|
||||||
data = {
|
schema[table_name] = []
|
||||||
"model": "glm4",
|
schema[table_name].append(f"{column_name} ({data_type})")
|
||||||
"messages": [
|
cur.close()
|
||||||
{
|
schema_text = ""
|
||||||
"role": "system",
|
for table_name, columns in schema.items():
|
||||||
"content": "你是一个 SQL 助手,严格遵循以下规则:\n"
|
schema_text += f"表{table_name} 包含列:{', '.join(columns)}\n"
|
||||||
"1. 只返回 PostgreSQL 语法 SQL 语句。\n"
|
return schema_text
|
||||||
"2. 严格禁止添加任何解释、注释、Markdown 代码块标记(包括 ```sql 和 ```)。\n"
|
|
||||||
"3. 输出必须是纯 SQL,且可直接执行,无需任何额外处理。"
|
|
||||||
"4. 在 SQL 中如有多个表,请始终使用表名前缀引用字段,避免字段歧义。"
|
|
||||||
},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
],
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
response = requests.post(self.model_url, json=data, timeout=600)
|
|
||||||
response.raise_for_status()
|
|
||||||
content = response.json()["choices"][0]["message"]["content"]
|
|
||||||
clean_sql = self.strip_sql_markdown(content)
|
|
||||||
return clean_sql
|
|
||||||
|
|
||||||
def _generate(self, prompts: list, stop: list = None) -> LLMResult:
|
|
||||||
generations = []
|
# 调用大模型生成sql
|
||||||
for prompt in prompts:
|
def call_llm_api(prompt, api_key=API_KEY, api_base=API_BASE, model=MODEL):
|
||||||
text = self._call(prompt, stop)
|
url = f"{api_base}/chat/completions"
|
||||||
generations.append([Generation(text=text)])
|
headers = {
|
||||||
return LLMResult(generations=generations)
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}"
|
||||||
def strip_sql_markdown(self, content: str) -> str:
|
}
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0,
|
||||||
|
}
|
||||||
|
response = requests.post(url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
print("\n大模型返回:\n", response.json())
|
||||||
|
return response.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def execute_sql(conn, sql_query):
    """Execute ``sql_query`` on ``conn`` and return its result.

    Args:
        conn: an open DB-API connection (psycopg2 in this module).
        sql_query: the SQL text to run.

    Returns:
        A list of ``{column: value}`` dicts when the statement produces a
        result set, otherwise the cursor's status message (e.g. ``"UPDATE 1"``).

    Raises:
        Whatever the driver raises for a failing statement. Before the error
        propagates the transaction is rolled back, so the same connection can
        be reused for a corrected query.
    """
    cur = conn.cursor()
    try:
        cur.execute(sql_query)
        # description is None for statements that return no rows; checking it
        # avoids using an exception from fetchall() as control flow.
        if cur.description is None:
            return cur.statusmessage
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]
    except Exception:
        # A failed statement leaves the transaction aborted; without a
        # rollback every later statement on this connection is rejected.
        conn.rollback()
        raise
    finally:
        # Always release the cursor, including on the error path.
        cur.close()
|
||||||
|
|
||||||
|
|
||||||
|
def strip_sql_markdown(content: str) -> str:
|
||||||
import re
|
import re
|
||||||
# 去掉包裹在 ```sql 或 ``` 中的内容
|
# 去掉包裹在 ```sql 或 ``` 中的内容
|
||||||
match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
|
match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
|
||||||
if match:
|
if match:
|
||||||
return match.group(1).strip()
|
return match.group(1).strip()
|
||||||
match = re.search(r"```\s*(.*?)```", content, re.DOTALL)
|
else:
|
||||||
if match:
|
return None
|
||||||
return match.group(1).strip()
|
|
||||||
return content.strip()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _llm_type(self) -> str:
|
|
||||||
return "custom_llm"
|
|
||||||
|
|
||||||
|
|
||||||
class QueryLLMview(APIView):
|
class QueryLLMview(APIView):
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
serializer = CustomLLMrequestSerializer(data=request.data)
|
serializer = MessageSerializer(data=request.data)
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
|
serializer.save()
|
||||||
prompt = serializer.validated_data['prompt']
|
prompt = serializer.validated_data['prompt']
|
||||||
llm = CustomLLM(model_url=model_url)
|
conn = connect_db()
|
||||||
chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
|
# 数据库表结构
|
||||||
result = chain.invoke(prompt)
|
schema_text = get_schema_text(conn, TABLES)
|
||||||
return Response({"result": result})
|
user_prompt = f"""你是可能是一个专业的数据库工程师,根据以下数据库结构:
|
||||||
|
{schema_text}
|
||||||
|
请根据我的需求生成一条标准的PostgreSQL SQL语句,直接返回SQL,不要额外解释。
|
||||||
|
需求是:{prompt}
|
||||||
|
"""
|
||||||
|
llm_data = call_llm_api(user_prompt)
|
||||||
|
# 判断是否生成的是sql 如果不是直接返回message
|
||||||
|
generated_sql = strip_sql_markdown(llm_data)
|
||||||
|
if generated_sql:
|
||||||
|
try:
|
||||||
|
result = execute_sql(conn, generated_sql)
|
||||||
|
return Response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
print("\n第一次执行SQL报错了,错误信息:", str(e))
|
||||||
|
# 如果第一次执行SQL报错,则重新生成SQL
|
||||||
|
fix_prompt = f"""刚才你生成的SQL出现了错误,错误信息是:{str(e)}
|
||||||
|
请根据这个错误修正你的SQL,返回新的正确的SQL,直接给出SQL,不要解释。
|
||||||
|
数据库结构如下:
|
||||||
|
{schema_text}
|
||||||
|
用户需求是:{prompt}
|
||||||
|
"""
|
||||||
|
fixed_sql = call_llm_api(fix_prompt)
|
||||||
|
fixed_sql = strip_sql_markdown(fixed_sql)
|
||||||
|
try:
|
||||||
|
results = execute_sql(conn, fixed_sql)
|
||||||
|
print("\n修正后的查询结果:")
|
||||||
|
print(results)
|
||||||
|
return Response({"result": results})
|
||||||
|
except Exception as e2:
|
||||||
|
print("\n修正后的SQL仍然报错,错误信息:", str(e2))
|
||||||
|
return Response({"error": "SQL执行失败", "detail": str(e2)}, status=400)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
else:
|
||||||
|
return Response({"result": llm_data})
|
||||||
|
|
||||||
|
|
||||||
|
# 先新建对话 生成对话session_id
|
||||||
|
class ConversationView(APIView):
    """List, create and update ``Conversation`` objects."""

    def get(self, request):
        """Return every conversation, serialized with ``ConversationSerializer``."""
        conversation = Conversation.objects.all()
        serializer = ConversationSerializer(conversation, many=True)
        return Response(serializer.data)

    def post(self, request):
        """Create a conversation from the request body and return it."""
        serializer = ConversationSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)

    def put(self, request, pk=None):
        """Update the conversation identified by ``pk`` and return it.

        NOTE(review): the only route visible for this view captures no ``pk``,
        so a required ``pk`` argument made every PUT fail with a TypeError.
        When the URL supplies none, the ``id`` from the request body is used;
        confirm this matches the intended API. Responds 404 when no
        conversation matches.
        """
        if pk is None:
            pk = request.data.get('id')
        conversation = get_object_or_404(Conversation, pk=pk)
        serializer = ConversationSerializer(conversation, data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)
|
Loading…
Reference in New Issue