初始化

This commit is contained in:
caoqianming 2024-09-24 11:33:54 +08:00
commit 03c467159c
3 changed files with 208 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
.venv/
__pycache__/
app.log
conf.py

198
main.py Normal file
View File

@ -0,0 +1,198 @@
from fastapi import FastAPI, File, UploadFile, Form
from typing import Literal
from fastapi.responses import JSONResponse
from fastapi.exceptions import HTTPException
from PIL import Image
from paddleocr import PaddleOCR
import numpy as np
import uvicorn
import re
import io
from pdf2image import convert_from_bytes
import uuid
import os
import requests
import json
import paddle
import sys
from uvicorn.config import LOGGING_CONFIG
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, CUR_DIR)
import conf
LOGGING_CONFIG["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
LOGGING_CONFIG["handlers"]["file"] = {
"class": "logging.handlers.RotatingFileHandler",
"filename": os.path.join(CUR_DIR, "app.log"),
"maxBytes": 5 * 1024 * 1024, # 5 MB
"backupCount": 5,
"formatter": "default",
"encoding": "utf-8"
}
LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"]=["file"]
gpu_available = paddle.device.is_compiled_with_cuda()
print("GPU available:", gpu_available)
ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=True)
app = FastAPI()
ALLOWED_FILE_TYPES = {"application/pdf", "image/jpeg", "image/png", "image/jpg"}
def save_imgs(images):
"""
保存图像到本地
"""
if not os.path.exists('save_imgs'):
os.makedirs('save_imgs')
paths = []
for i, image in enumerate(images):
image_path = os.path.join('saved_imgs', f'{uuid.uuid4()}.png')
image.save(image_path)
paths.append(image_path)
return paths
def check_file_format(file: UploadFile):
if file.content_type not in ALLOWED_FILE_TYPES:
raise HTTPException(400, "File format not supported.")
def perform_ocr(ocr, images):
"""
使用 OCR 提取图像文件中的文字.
"""
all_text = []
for image in images:
# Convert PIL Image to numpy array
image_np = np.array(image)
result = ocr.ocr(image_np, cls=True)
all_text.extend(result)
return all_text
def extract_standard_info(text, patterns):
"""
从标准文件文本中提取信息.
"""
info = {}
for label, pattern in patterns:
if label in ("归口单位", "中文名称"):
matches = re.findall(pattern, text, re.MULTILINE)
else:
matches = re.findall(pattern, text, re.DOTALL)
if matches:
if label in ["发布日期", "实施日期"]:
info[label] = re.sub(r'', '-', matches[0].replace(' ', ''))
elif label in ["起草单位", "起草人"]:
info[label] = matches[0].strip()
elif label == "提出单位":
info[label] = matches[-1].strip()
elif label == "标准号":
info[label] = matches[0].strip()
elif label == "发布部门":
if matches[0][1]:
info[label] = matches[0][0] + '' + matches[0][1]
else:
info[label] = matches[0][0] + matches[0][1]
info[label] = re.sub(r'\n', '', info[label])
else:
info[label] = matches[-1] if matches else None
return info
def extract_patent_info(text, patterns):
"""
从专利文件文本中提取信息.
"""
info = {}
for label, pattern in patterns:
matches = re.findall(pattern, text, re.DOTALL)
info[label] = matches[-1] if matches else None
return info
def clean_info(info: dict):
for key in info:
if info[key]:
info[key] = info[key].replace('\n', '').replace('\r', '').replace('\t', '').strip()
# 定义标准文件用的正则表达式模式
standard_patterns = [
("标准号", r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}"), # 仅提取了第一个标准号(若有更多标准号可能错误)
("中文名称", r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n" # 定位英文名称前的中文名称
r"(?:[^\u4e00-\u9fa5]+)\n" # 定位发布日期前的英文名称部分
r"(?:\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布"), # 定位发布日期
("英文名称", r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布"), # 匹配发布日期前面的非中文部分
("发布部门", r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)"), # 发布部门是通过匹配发布的上下行的中文
("发布日期", r"(\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布"),
("实施日期", r"(\d{4}(?:-|)\d{2}(?:-|)\d{2}) *实施"),
("提出单位", r"由(.+?)\s?提出。"),
("归口单位", r"由(.+?)\s?归口。"),
("起草单位", r"起草单位[:](.+?)(?:。|起草人)"),
("起草人", r"起草人[:](.+?)(?:。)")
]
# 定义专利文件用的正则表达式模式
patent_patterns = [
("专利名称", r"名称[:](.+?)\n"),
("发明人", r"发[\s]*明[\s]*人[\s]*[:](.+?)\n"),
("专利号", r"专[\s]*利[\s]*号[\s]*[:](.+?)\n"),
("专利申请日", r"专利申请日[:](.+?)\n"),
("专利权人", r"专[\s]*利[\s]*权[\s]*人[\s]*[:](.+?)\n"),
("地址", r"地[\s]*址[\s]*[:](.+?)\n"),
("授权公告日", r"授权公告日:(.+?)\n"),
("授权公告号", r"授权公告号:(.+?)\n")
]
@app.get("/error")
async def create_error():
raise ValueError("This is a test error!")
@app.post("/extract")
async def extract_info(
file_type: Literal["standard", "patent"] = Form(..., description="Specify the type of file to extract."),
extract_method: Literal["re", "chat"] = Form(..., description="Specify how to extract. 正则或对话模型"),
file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file.")
):
check_file_format(file)
content = await file.read() # 读取文件内容
if file.filename.lower().endswith('.pdf'):
images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300)
else:
images = [Image.open(io.BytesIO(content))]
result = perform_ocr(ocr, images)
ocr_text = "\n".join([line[1][0] for res in result if res is not None for line in res])
# 提取信息
if extract_method == "re":
if file_type == "patent":
info = extract_patent_info(ocr_text, patent_patterns)
clean_info(info)
elif file_type == "standard":
info = extract_standard_info(ocr_text, standard_patterns)
clean_info(info)
else:
raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.")
else:
if file_type == 'patent':
prompt = f'我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人授权公告号授权公告日为key的json格式返回数据,注意只返回json数据。文本如下{ocr_text}'
elif file_type == "standard":
prompt = f'我有以下文本,是一个标准的内容。请按标准号中文名称英文名称发布部门发布日期实施日期提出单位归口单位起草单位起草人为key的json格式返回数据,注意只返回json数据。文本如下{ocr_text}'
else:
raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.")
r = requests.post(conf.CHAT_API, json={
"model": "llama3.1",
"prompt": prompt,
"stream": False
})
info = json.loads((r.json()['response']))
return JSONResponse(content=info)
if __name__ == "__main__":
uvicorn.run(app="main:app", host="0.0.0.0", port=8000, reload=True, log_config=LOGGING_CONFIG)

6
requirements.txt Normal file
View File

@ -0,0 +1,6 @@
fastapi
uvicorn
python-multipart
paddlepaddle
paddleocr
pdf2image