feat: ichat update the LLM (large model) interface
parent 84380931fd
commit 8f6a6eb973
@@ -0,0 +1,4 @@
+from rest_framework import serializers
+
+class CustomLLMrequestSerializer(serializers.Serializer):
+    prompt = serializers.CharField()
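For reference, the new serializer declares a single required prompt field. A minimal sketch of how it validates a request payload inside the project (the payload text and the import path are illustrative assumptions, not part of the commit):

from apps.ichat.serializers import CustomLLMrequestSerializer  # assumed module path

serializer = CustomLLMrequestSerializer(data={"prompt": "list all metering points"})
serializer.is_valid(raise_exception=True)   # raises ValidationError if 'prompt' is missing or blank
prompt = serializer.validated_data["prompt"]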
@@ -0,0 +1,12 @@
+from apps.ichat.views import QueryLLMview
+from django.urls import path, include
+from rest_framework.routers import DefaultRouter
+
+API_BASE_URL = 'api/hrm/'
+HTML_BASE_URL = 'dhtml/hrm/'
+
+router = DefaultRouter()
+router.register('llm/query/', QueryLLMview, basename='llm_query')
+urlpatterns = [
+    path(API_BASE_URL, include(router.urls)),
+]
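One thing to keep in mind with this routing: DRF's DefaultRouter.register expects a ViewSet, while QueryLLMview is declared as a plain APIView in the views change further down, so the router has no action mapping to build for it. A minimal sketch of the usual path()-based wiring instead, assuming the same URL layout (the route name is illustrative):

from django.urls import path

from apps.ichat.views import QueryLLMview

API_BASE_URL = 'api/hrm/'

urlpatterns = [
    # APIView subclasses are exposed through as_view()
    path(API_BASE_URL + 'llm/query/', QueryLLMview.as_view(), name='llm_query'),
]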
@@ -4,11 +4,20 @@ from langchain_core.language_models import LLM
 from langchain_core.outputs import LLMResult, Generation
 from langchain_experimental.sql import SQLDatabaseChain
 from langchain_community.utilities import SQLDatabase
+from server.conf import DATABASES
+from serializers import CustomLLMrequestSerializer
+from rest_framework.views import APIView
+from urllib.parse import quote_plus
 # fastapi
 from fastapi import FastAPI
 from pydantic import BaseModel

-db = SQLDatabase.from_uri("postgresql+pg8000://postgres:zcDsj%402024@127.0.0.1:5432/factory", include_tables=["enm_mpoint", "enm_mpointstat"])
+db_conf = DATABASES['default']
+# the password needs URL encoding (it contains special characters such as @)
+password_encodeed = quote_plus(db_conf['password'])
+
+db = SQLDatabase.from_uri(f"postgresql+psycopg2://{db_conf['user']}:{password_encodeed}@{db_conf['host']}/{db_conf['name']}", include_tables=["enm_mpoint", "enm_mpointstat"])
+
 # model_url = "http://14.22.88.72:11025/v1/chat/completions"
 model_url = "http://139.159.180.64:11434/v1/chat/completions"
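The quote_plus call mirrors the %40 that was hard-coded in the removed URI: SQLAlchemy parses the connection string as a URL, so a raw '@' in the password would be read as the host separator. A small illustration with a placeholder password (the real credentials come from DATABASES['default']); note the new URI omits the port, so the driver falls back to the default 5432:

from urllib.parse import quote_plus

password_encodeed = quote_plus("p@ssw0rd")   # placeholder password, not the project's real one
print(password_encodeed)                     # p%40ssw0rd, safe to embed in the SQLAlchemy URL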
@@ -59,17 +68,12 @@ class CustomLLM(LLM):
         return "custom_llm"


-# instantiate
-app = FastAPI()
-
-class CustomLLMRequest(BaseModel):
-    prompt: str
-
-@app.post("/llm/query/")
-def query(custom_llm_request: CustomLLMRequest):
-    prompt = custom_llm_request.prompt
-    llm = CustomLLM(model_url=model_url)
-    db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
-    result = db_chain.invoke(prompt)
-    print('result--', result, prompt)
-    return {"result": result}
+class QueryLLMview(APIView):
+    def post(self, request):
+        serializer = CustomLLMrequestSerializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+        prompt = serializer.validated_data['prompt']
+        llm = CustomLLM(model_url=model_url)
+        chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)
+        result = chain.invoke(prompt)
+        return result
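Two follow-ups on the new view, sketched under the assumption that the structure above is otherwise kept as-is: a DRF APIView has to return a rest_framework Response (handing back the chain's raw output makes DRF's finalize_response assertion fail), and chain.invoke(...) returns a dict rather than a bare string. Recent langchain_experimental releases also treat passing llm= directly to SQLDatabaseChain as deprecated in favour of from_llm, which is what the removed FastAPI handler used. A hedged version of the post() body:

from rest_framework.response import Response
from rest_framework.views import APIView


class QueryLLMview(APIView):
    def post(self, request):
        serializer = CustomLLMrequestSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        prompt = serializer.validated_data['prompt']

        llm = CustomLLM(model_url=model_url)
        chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)  # documented constructor

        result = chain.invoke(prompt)   # dict output, e.g. {"query": "...", "result": "..."}
        return Response({"result": result.get("result")})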