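"""FastAPI service that OCRs uploaded PDF/JPEG/PNG files with PaddleOCR and
extracts structured fields from the recognised text, either with regular
expressions or by prompting a chat model."""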
from fastapi import FastAPI, File, UploadFile, Form
|
||
from typing import Literal
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.exceptions import HTTPException
|
||
from PIL import Image
|
||
from paddleocr import PaddleOCR
|
||
import numpy as np
|
||
import uvicorn
|
||
import re
|
||
import io
|
||
from pdf2image import convert_from_bytes
|
||
import uuid
|
||
import os
|
||
import requests
|
||
import json
|
||
import paddle
|
||
import sys
|
||
from uvicorn.config import LOGGING_CONFIG
|
||
|
||
# Put the directory that holds this file at the front of the import path
# before importing ``conf`` (presumably a sibling module; confirm).
_THIS_FILE = os.path.abspath(__file__)
CUR_DIR = os.path.dirname(_THIS_FILE)
sys.path.insert(0, CUR_DIR)

import conf
|
||
|
||
# Prefix every record emitted through uvicorn's default formatter with a
# timestamp and the level name.
LOGGING_CONFIG["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"

# Register a size-rotated log file stored next to this module.
LOGGING_CONFIG["handlers"].update(
    {
        "file": {
            "class": "logging.handlers.RotatingFileHandler",
            "filename": os.path.join(CUR_DIR, "app.log"),
            "maxBytes": 5 * 1024 * 1024,  # 5 MB per file
            "backupCount": 5,
            "formatter": "default",
            "encoding": "utf-8",
        }
    }
)

# Route the "uvicorn.error" logger to that file handler only.
LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"] = ["file"]
|
||
|
||
|
||
# Report whether the installed paddle build was compiled with CUDA support.
# The value is only printed; the OCR engine below is configured from conf.
gpu_available = paddle.device.is_compiled_with_cuda()
print("GPU available:", gpu_available)

# One shared OCR engine (Chinese model, angle classifier on) and the web app.
ocr = PaddleOCR(
    lang="ch",
    use_angle_cls=True,
    use_gpu=conf.GPU_ENABLE,
)
app = FastAPI()

# MIME types accepted by the upload endpoint.
ALLOWED_FILE_TYPES = {
    "application/pdf",
    "image/jpeg",
    "image/png",
    "image/jpg",
}
|
||
|
||
|
||
def save_imgs(images, output_dir="saved_imgs"):
    """
    Save images to local disk as PNG files named with a random UUID.

    Args:
        images: Iterable of objects exposing ``save(path)`` (e.g. PIL images).
        output_dir: Directory to write into; it is created when missing.

    Returns:
        List of the written file paths, in input order.
    """
    # Create the same directory that is written to. The previous code created
    # "save_imgs" but saved into "saved_imgs", which failed when the latter
    # did not exist.
    os.makedirs(output_dir, exist_ok=True)
    paths = []
    for image in images:
        image_path = os.path.join(output_dir, f"{uuid.uuid4()}.png")
        image.save(image_path)
        paths.append(image_path)
    return paths
|
||
|
||
|
||
def check_file_format(file: UploadFile):
    """
    Reject uploads whose declared content type is not in ALLOWED_FILE_TYPES.

    Raises:
        HTTPException: status 400 when the content type is not allowed.
    """
    if file.content_type in ALLOWED_FILE_TYPES:
        return
    raise HTTPException(400, "File format not supported.")
|
||
|
||
|
||
def perform_ocr(ocr, images):
    """
    Run OCR over each image and collect the results in one flat list.

    Every image is converted to a numpy array, passed to ``ocr.ocr`` with the
    angle classifier enabled, and the items of each returned list are
    appended in order.
    """
    return [
        page_result
        for image in images
        for page_result in ocr.ocr(np.array(image), cls=True)
    ]
|
||
|
||
|
||
def extract_standard_info(text, patterns):
    """
    Extract fields from the OCR text of a standard document.

    Args:
        text: OCR text to search.
        patterns: Iterable of ``(label, regex)`` pairs.

    Returns:
        Dict mapping each label that matched to its extracted value. Labels
        whose pattern finds nothing are omitted (not set to None).
    """
    info = {}
    for label, pattern in patterns:
        # These two labels rely on "^" anchoring per line; all other patterns
        # need "." to span line breaks.
        flags = re.MULTILINE if label in ("归口单位", "中文名称") else re.DOTALL
        matches = re.findall(pattern, text, flags)
        if not matches:
            continue

        first, last = matches[0], matches[-1]
        if label in ("发布日期", "实施日期"):
            # Drop spaces and normalise full-width hyphens (U+FF0D) to ASCII.
            # The previous substitution replaced "-" with "-", a no-op.
            info[label] = first.replace(" ", "").replace("\uff0d", "-")
        elif label in ("起草单位", "起草人", "标准号"):
            info[label] = first.strip()
        elif label == "提出单位":
            info[label] = last.strip()
        elif label == "发布部门":
            # Two capture groups: the text above and below the "发布" line.
            upper, lower = first[0], first[1]
            joined = upper + "、" + lower if lower else upper
            info[label] = joined.replace("\n", "")
        else:
            info[label] = last

    return info
|
||
|
||
|
||
def extract_patent_info(text, patterns):
    """
    Extract fields from the OCR text of a patent document.

    For every ``(label, regex)`` pair the text is searched with ``re.DOTALL``
    and the last match is kept; labels without a match map to None.
    """

    def _last_hit(pattern):
        hits = re.findall(pattern, text, re.DOTALL)
        return hits[-1] if hits else None

    return {label: _last_hit(pattern) for label, pattern in patterns}
|
||
|
||
def extract_social_info(text, patterns):
    """
    Extract fields from OCR text using ``(label, regex)`` pairs.

    The last match of each pattern is kept. When a pattern has several
    capture groups (the match is a tuple), the last group is used. Labels
    without a match map to None.
    """
    info = {}
    for label, pattern in patterns:
        found = re.findall(pattern, text, re.DOTALL)
        if not found:
            info[label] = None
            continue
        latest = found[-1]
        info[label] = latest[-1] if isinstance(latest, tuple) else latest

    return info
|
||
|
||
|
||
def clean_info(info: dict):
    """
    Strip line breaks, tabs and surrounding whitespace from every truthy
    value of ``info``, in place. Falsy values (None, "") are left untouched.
    """
    for key, value in info.items():
        if not value:
            continue
        for unwanted in ("\n", "\r", "\t"):
            value = value.replace(unwanted, "")
        info[key] = value.strip()
|
||
|
||
|
||
# Regex patterns for standard documents.
# OCR of Chinese text commonly yields full-width punctuation, so colons and
# hyphens are matched in both ASCII and full-width form (U+FF1A / U+FF0D).
# The escapes keep the full-width characters from being silently normalised.
standard_patterns = [
    (
        "标准号",
        r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—\uff0d]\d{3,4}",
    ),  # only the first standard number is used (may be wrong if there are several)
    (
        "中文名称",
        r"(^[\u4e00-\u9fa5][\w:\uff1a、√—\-\ \nn]+)\n"  # Chinese title preceding the English title
        r"(?:[^\u4e00-\u9fa5]+)\n"  # English title preceding the issue date
        r"(?:\d{4}[-\uff0d]\d{2}[-\uff0d]\d{2}) *发布",
    ),  # anchored on the issue date
    (
        "英文名称",
        r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}[-\uff0d]\d{2}[-\uff0d]\d{2}) *发布",
    ),  # non-Chinese text right before the issue date
    (
        "发布部门",
        r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)",
    ),  # Chinese text on the lines above and below "发布"
    ("发布日期", r"(\d{4}[-\uff0d]\d{2}[-\uff0d]\d{2}) *发布"),
    ("实施日期", r"(\d{4}[-\uff0d]\d{2}[-\uff0d]\d{2}) *实施"),
    ("提出单位", r"由(.+?)\s?提出。"),
    ("归口单位", r"由(.+?)\s?归口。"),
    ("起草单位", r"起草单位[:\uff1a](.+?)(?:。|起草人)"),
    ("起草人", r"起草人[:\uff1a](.+?)(?:。)"),
]
|
||
|
||
# Regex patterns for patent documents.
# Every label accepts an ASCII or a full-width (U+FF1A) colon. The two
# "授权公告" labels previously accepted only an ASCII colon.
patent_patterns = [
    ("专利名称", r"名称[:\uff1a](.+?)\n"),
    ("发明人", r"发[\s]*明[\s]*人[\s]*[:\uff1a](.+?)\n"),
    ("专利号", r"专[\s]*利[\s]*号[\s]*[:\uff1a](.+?)\n"),
    ("专利申请日", r"专利申请日[:\uff1a](.+?)\n"),
    ("专利权人", r"专[\s]*利[\s]*权[\s]*人[\s]*[:\uff1a](.+?)\n"),
    ("地址", r"地[\s]*址[\s]*[:\uff1a](.+?)\n"),
    ("授权公告日", r"授权公告日[:\uff1a](.+?)\n"),
    ("授权公告号", r"授权公告号[:\uff1a](.+?)\n"),
]
|
||
|
||
# Regex patterns for national award certificates.
# Colons are matched in ASCII or full-width (U+FF1A) form.
country_patterns = [
    ("项目名称", r"项目名称[:\uff1a](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    ("奖励等级", r"奖励等级[:\uff1a](.+?)\n"),
    ("获奖者", r"(获奖(?:者|单位)[:\uff1a](.+?))(?:$|\n)"),
    ("证书号", r"证书号[:\uff1a]([\w-]+)"),
    # The "年" was missing, so a date such as "2019年12月18日" never matched.
    # It is optional so that any text matched before still matches.
    ("颁发日期", r"\d{4}年?\d{1,2}月\d{1,2}日"),
]
|
||
# Regex patterns for building-materials industry award certificates.
# Colons are matched in ASCII or full-width (U+FF1A) form.
build_patterns = [
    ("项目名称", r"项目名称[:\uff1a](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    ("奖励等级", r"奖励等级[:\uff1a](.+?)\n"),
    ("获奖单位", r"获奖单位[:\uff1a](.+?)[^\u4e00-\u9fa5]+证书编号"),
    ("证书编号", r"证书编号[:\uff1a]([\w-]+)"),
    # A four-character year such as "二〇二〇" is "二" plus three numerals.
    # The old {1,2} quantifier truncated it to "二〇年…".
    ("颁发日期", r"二[〇一二三四五六七八九]{1,3}年[一至十二月一二三四五六七八九十]{2,3}"),
]
|
||
# Regex patterns for social-organisation award certificates.
# Colons are matched in ASCII or full-width (U+FF1A) form.
social_patterns = [
    # Two label spellings because certificate layouts differ. This must be a
    # group: "[项目名称|获奖项目]" was a character class matching one character.
    ("项目名称", r"(?:项目名称|获奖项目)[:\uff1a](.*?)\n"),
    ("获奖单位", r"获奖单位[:\uff1a](.+?)[^\u4e00-\u9fa5]"),
    # Both capture groups are kept; the extractor uses the last one.
    ("奖励等级", r"(?:奖励等级|获奖等级)[:\uff1a](\s*)([^\s].+?)\n"),
    ("获奖人", r"获奖人[:\uff1a]([\u4e00-\u9fa5、\n]+)"),
    ("奖励年度", r"奖励年度[:\uff1a]([\d]{4})"),
    # "[编号号]" matched exactly one character and so could not match "证书编号".
    ("证书编号", r"证书[ _]*[编号]{1,2}[:\uff1a]([\w-]+)"),
    # A four-character year such as "二〇二〇" is "二" plus three numerals.
    ("颁发日期", r"二[〇一二三四五六七八九]{1,3}年[一至十二月一二三四五六七八九十]{2,3}"),
]
|
||
# Regex patterns for provincial/ministerial science and technology awards.
# Colons are matched in ASCII or full-width (U+FF1A) form.
province_patterns = [
    ("项目名称", r"项目名称[:\uff1a](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    ("奖励等级", r"奖励等级[:\uff1a](.+?)\n"),
    ("获奖者", r"(获奖(?:者|单位)[:\uff1a](.+?))(?:$|\n)"),
    # NOTE(review): this pattern has no label, so it matches every word-like
    # token and the extractor keeps the last one in the text. Presumably the
    # certificate number is printed last without a label; confirm.
    ("证书编号", r"([\w-]+)"),
    ("颁发日期", r"\d{4}年\d{1,2}月\d{1,2}日"),
]
|
||
|
||
# Regex patterns for software copyright certificates.
# Colons are matched in ASCII or full-width (U+FF1A) form.
software_patterns = [
    ("证书号", r"证书号[:\uff1a]([\w-]+)"),
    ("软件名称", r"软件名称[:\uff1a](.*?)[^\u4e00-\u9fa5]\n"),
    ("著作权人", r"著[\s]*作[\s]*权[\s]*人[\s]*[:\uff1a](.+?)\n"),
    ("开发完成日期", r"开发完成日期[:\uff1a](.+?)\n"),
    ("首次发表日期", r"首次发表日期[:\uff1a](.+?)\n"),
    ("权利取得方式", r"权利取得方式[:\uff1a](.+?)\n"),
    ("权利范围", r"权利范围[:\uff1a](.+?)\n"),
    ("登记号", r"登[\s]*记[\s]*号[:\uff1a](.+?)\n"),
    ("颁发日期", r"\d{4}年\d{1,2}月\d{1,2}日"),
]
|
||
|
||
# Regex extraction dispatch: file_type -> (extractor function, pattern table).
_RE_EXTRACTORS = {
    "patent": (extract_patent_info, patent_patterns),
    "standard": (extract_standard_info, standard_patterns),
    "country": (extract_social_info, country_patterns),
    "building": (extract_patent_info, build_patterns),
    "social": (extract_social_info, social_patterns),
    "province": (extract_social_info, province_patterns),
    "software": (extract_social_info, software_patterns),
}

# Chat-model prompt prefixes per file_type; the OCR text is appended verbatim.
_CHAT_PROMPTS = {
    "patent": "我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人,授权公告号,授权公告日为key的json格式返回数据,注意只返回json数据。文本如下:",
    "standard": "我有以下文本,是一个标准的内容。请按标准号,中文名称,英文名称,发布部门,发布日期,实施日期,提出单位,归口单位,起草单位,起草人为key的json格式返回数据,注意只返回json数据。文本如下:",
    "country": "我有以下文本,是一个国家标准的内容。请按项目名称,奖励等级,获奖者,证书号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:",
    "province": "我有以下文本,是一个省级标准的内容。请按项目名称,奖励等级,获奖者,证书编号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:",
    "social": "我有以下文本,是一个社会标准的内容。请按项目名称,获奖单位,奖励等级,获奖人,奖励年度,证书编号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:",
    "building": "我有以下文本,是一个建筑标准的内容。请按项目名称,奖励等级,获奖单位,证书编号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:",
    # "software" is accepted by the endpoint but previously had no chat
    # prompt, so the request was rejected with 400.
    "software": "我有以下文本,是一个软件著作权证书的内容。请按证书号,软件名称,著作权人,开发完成日期,首次发表日期,权利取得方式,权利范围,登记号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:",
}

_INVALID_TYPE_DETAIL = "Invalid file type. Please choose 'standard' or 'patent'."

# Upper bound in seconds for the chat request, so a stalled backend cannot
# hang the request forever.
_CHAT_TIMEOUT_SECONDS = 120


def _load_images(file, content, file_type):
    """Turn the uploaded bytes into a list of PIL images ready for OCR."""
    # file.filename can be None; also trust the declared PDF content type.
    is_pdf = (
        (file.filename or "").lower().endswith(".pdf")
        or file.content_type == "application/pdf"
    )
    if is_pdf:
        # Only the first six pages are rendered.
        return convert_from_bytes(content, first_page=1, last_page=6, dpi=300)

    try:
        img = Image.open(io.BytesIO(content))
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(400, detail="Unable to read the uploaded image.") from exc

    if file_type == "building":
        # Building-industry awards: crop to a fixed region to improve
        # recognition. Box is (left, top, right, bottom); adjust it to the
        # actual image size.
        box = (730, 0, 1500, 1000)
        return [img.crop(box)]
    return [img]


def _ask_chat_model(prompt):
    """Send the prompt to the configured chat API and return the parsed JSON answer."""
    try:
        r = requests.post(
            conf.CHAT_API,
            json={"model": conf.CHAT_MODEL, "prompt": prompt, "stream": False},
            timeout=_CHAT_TIMEOUT_SECONDS,
        )
        r.raise_for_status()
    except requests.RequestException as exc:
        print(f"Chat request failed: {exc}")
        raise HTTPException(502, detail="Chat model request failed.") from exc

    try:
        # The "response" field is expected to hold a JSON document as text.
        # A missing field becomes "" so it fails as a decode error, not a
        # TypeError.
        return json.loads(r.json().get("response") or "")
    except (ValueError, TypeError, AttributeError) as exc:
        # Previously only logged, leaving `info` unbound at the return below.
        print(f"JSONDecodeError: {exc}")
        raise HTTPException(502, detail="Chat model did not return valid JSON.") from exc


@app.post(
    "/extract",
    summary="提取专利/标准等文件里的信息",
    responses={
        200: {
            "description": "成功示例",
            "content": {
                "application/json": {
                    "example": {
                        "标准号": "GB/T 1234-2020",
                        "中文名称": "低热矿渣硅酸盐水泥",
                        "英文名称": "Low-heat portland slag cement",
                        "发布部门": "国家质量监督检验检疫总局",
                        "发布日期": "2020-01-01",
                        "实施日期": "2020-07-01",
                        "提出单位": "中国建筑材料联合会",
                        "归口单位": "国家建筑材料工业技术监督研究中心",
                        "起草单位": "中国建筑材料科学研究总院",
                        "起草人": "张三、李四",
                    }
                },
            },
        },
        400: {
            "description": "错误示例",
            "content": {
                "application/json": {
                    "example": {
                        "detail": "Invalid file type. Please choose 'standard' or 'patent'.",
                    }
                }
            },
        },
    },
)
async def extract_info(
    file_type: Literal["standard", "patent", "country", "building", "social", "province", "software"] = Form(
        ..., description="Specify the type of file to extract."
    ),
    extract_method: Literal["re", "chat"] = Form(
        ..., description="Specify how to extract. 正则或对话模型"
    ),
    file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file."),
):
    """
    OCR an uploaded PDF/image and return the extracted fields as JSON.

    Args:
        file_type: Kind of document; selects the pattern table or chat prompt.
        extract_method: "re" for regex extraction, "chat" for the chat model.
        file: The uploaded document.

    Raises:
        HTTPException: 400 for an unsupported upload or file type, 502 when
            the chat backend fails or returns something that is not JSON.
    """
    # NOTE(review): OCR and the chat request are blocking calls inside an
    # async handler; consider moving them to a worker thread.
    check_file_format(file)
    content = await file.read()
    images = _load_images(file, content, file_type)

    result = perform_ocr(ocr, images)
    # Each OCR line is indexed as line[1][0] for its text; None pages are skipped.
    ocr_text = "\n".join(
        [line[1][0] for res in result if res is not None for line in res]
    )

    if extract_method == "re":
        handler = _RE_EXTRACTORS.get(file_type)
        if handler is None:
            raise HTTPException(400, detail=_INVALID_TYPE_DETAIL)
        extractor, patterns = handler
        info = extractor(ocr_text, patterns)
        clean_info(info)
    else:
        prompt_prefix = _CHAT_PROMPTS.get(file_type)
        if prompt_prefix is None:
            raise HTTPException(400, detail=_INVALID_TYPE_DETAIL)
        info = _ask_chat_model(prompt_prefix + ocr_text)

    return JSONResponse(content=info)
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Serve the app on all interfaces, port 2260, with the customised
    # logging configuration; auto-reload is controlled by conf.
    server_options = {
        "app": "main:app",
        "host": "0.0.0.0",
        "port": 2260,
        "reload": conf.APP_RELOAD,
        "log_config": LOGGING_CONFIG,
    }
    uvicorn.run(**server_options)
|