# cert_recog/main.py

from fastapi import FastAPI, File, UploadFile, Form
from typing import Literal
from fastapi.responses import JSONResponse
from fastapi.exceptions import HTTPException
from PIL import Image
from paddleocr import PaddleOCR
import numpy as np
import uvicorn
import re
import io
from pdf2image import convert_from_bytes
import uuid
import os
import requests
import json
import paddle
import sys
from uvicorn.config import LOGGING_CONFIG
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, CUR_DIR)
import conf
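# conf is a local settings module (imported from CUR_DIR); the attributes used
# below are GPU_ENABLE (passed to PaddleOCR), CHAT_API and CHAT_MODEL (the chat
# service called when extract_method == "chat"), and APP_RELOAD (uvicorn's reload flag).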
LOGGING_CONFIG["formatters"]["default"]["fmt"] = (
"%(asctime)s - %(levelname)s - %(message)s"
)
LOGGING_CONFIG["handlers"]["file"] = {
"class": "logging.handlers.RotatingFileHandler",
"filename": os.path.join(CUR_DIR, "app.log"),
"maxBytes": 5 * 1024 * 1024, # 5 MB
"backupCount": 5,
"formatter": "default",
"encoding": "utf-8",
}
LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"] = ["file"]
gpu_available = paddle.device.is_compiled_with_cuda()
print("GPU available:", gpu_available)
ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=conf.GPU_ENABLE)
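# use_angle_cls=True enables PaddleOCR's text-angle classifier (handles rotated text);
# lang="ch" loads the Chinese/English recognition model. use_gpu comes from conf.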
app = FastAPI()
ALLOWED_FILE_TYPES = {"application/pdf", "image/jpeg", "image/png", "image/jpg"}
def save_imgs(images):
    """
    Save images into a local "saved_imgs" directory and return their file paths.
    """
    if not os.path.exists("saved_imgs"):
        os.makedirs("saved_imgs")
    paths = []
    for image in images:
        image_path = os.path.join("saved_imgs", f"{uuid.uuid4()}.png")
        image.save(image_path)
        paths.append(image_path)
    return paths
def check_file_format(file: UploadFile):
    if file.content_type not in ALLOWED_FILE_TYPES:
        raise HTTPException(400, "File format not supported.")
def perform_ocr(ocr, images):
    """
    Extract text from image files with OCR.
    """
    all_text = []
    for image in images:
        # Convert PIL Image to numpy array
        image_np = np.array(image)
        # ocr.ocr() returns one result per page: a list of lines, each shaped [box, (text, confidence)]
        result = ocr.ocr(image_np, cls=True)
        all_text.extend(result)
    return all_text
def extract_standard_info(text, patterns):
    """
    Extract information from the text of a standards document.
    """
    info = {}
    for label, pattern in patterns:
        if label in ("归口单位", "中文名称"):
            matches = re.findall(pattern, text, re.MULTILINE)
        else:
            matches = re.findall(pattern, text, re.DOTALL)
        if matches:
            if label in ["发布日期", "实施日期"]:
                # Normalize OCR dash variants (full-width hyphen, en/em dash) to "-" and drop spaces.
                info[label] = re.sub(r"[－–—]", "-", matches[0].replace(" ", ""))
            elif label in ["起草单位", "起草人"]:
                info[label] = matches[0].strip()
            elif label == "提出单位":
                info[label] = matches[-1].strip()
            elif label == "标准号":
                info[label] = matches[0].strip()
            elif label == "发布部门":
                if matches[0][1]:
                    # Two issuing departments: join them with a separator.
                    info[label] = matches[0][0] + "，" + matches[0][1]
                else:
                    info[label] = matches[0][0] + matches[0][1]
                info[label] = re.sub(r"\n", "", info[label])
            else:
                info[label] = matches[-1] if matches else None
    return info
def extract_patent_info(text, patterns):
    """
    Extract information from the text of a patent document.
    """
    info = {}
    for label, pattern in patterns:
        matches = re.findall(pattern, text, re.DOTALL)
        info[label] = matches[-1] if matches else None
    return info
def extract_social_info(text, patterns):
    """
    Extract information from the text of an award or registration certificate.
    """
    info = {}
    for label, pattern in patterns:
        matches = re.findall(pattern, text, re.DOTALL)
        if matches and isinstance(matches[-1], tuple):
            # Patterns with multiple capture groups return tuples; keep the last group.
            info[label] = matches[-1][-1]
        else:
            info[label] = matches[-1] if matches else None
    return info
def clean_info(info: dict):
    for key in info:
        if info[key]:
            info[key] = (
                info[key].replace("\n", "").replace("\r", "").replace("\t", "").strip()
            )
# Regex patterns for standards documents
standard_patterns = [
    (
        "标准号",
        r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}",
    ),  # only the first standard number is extracted (may be wrong if there are more)
    (
        "中文名称",
        r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n"  # the Chinese name located before the English name
        r"(?:[^\u4e00-\u9fa5]+)\n"  # the English-name block before the publication date
        r"(?:\d{4}(?:[-－]|)\d{2}(?:[-－]|)\d{2}) *发布",
    ),  # anchored on the publication date
    (
        "英文名称",
        r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:[-－]|)\d{2}(?:[-－]|)\d{2}) *发布",
    ),  # the non-Chinese block right before the publication date
    (
        "发布部门",
        r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)",
    ),  # the issuing department is the Chinese text on the lines around "发布"
    ("发布日期", r"(\d{4}(?:[-－]|)\d{2}(?:[-－]|)\d{2}) *发布"),
    ("实施日期", r"(\d{4}(?:[-－]|)\d{2}(?:[-－]|)\d{2}) *实施"),
    ("提出单位", r"由(.+?)\s?提出。"),
    ("归口单位", r"由(.+?)\s?归口。"),
    ("起草单位", r"起草单位[:：](.+?)(?:。|起草人)"),
    ("起草人", r"起草人[:：](.+?)(?:。)"),
]
# Regex patterns for patent certificates
patent_patterns = [
    ("专利名称", r"名称[:：](.+?)\n"),
    ("发明人", r"发[\s]*明[\s]*人[\s]*[:：](.+?)\n"),
    ("专利号", r"专[\s]*利[\s]*号[\s]*[:：](.+?)\n"),
    ("专利申请日", r"专利申请日[:：](.+?)\n"),
    ("专利权人", r"专[\s]*利[\s]*权[\s]*人[\s]*[:：](.+?)\n"),
    ("地址", r"地[\s]*址[\s]*[:：](.+?)\n"),
    ("授权公告日", r"授权公告日[:：](.+?)\n"),
    ("授权公告号", r"授权公告号[:：](.+?)\n"),
]
# Regex patterns for national award certificates
country_patterns = [
    ("项目名称", r"项目名称[:：](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    ("奖励等级", r"奖励等级[:：](.+?)\n"),
    ("获奖者", r"(获奖(?:者|单位)[:：](.+?))(?:$|\n)"),
    ("证书号", r"证书号[:：]([\w-]+)"),
    ("颁发日期", r"\d{4}年?\d{1,2}月\d{1,2}日"),
]
# Regex patterns for building-materials industry award certificates
build_patterns = [
    ("项目名称", r"项目名称[:：](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    ("奖励等级", r"奖励等级[:：](.+?)\n"),
    ("获奖单位", r"获奖单位[:：](.+?)[^\u4e00-\u9fa5]+证书编号"),
    ("证书编号", r"证书编号[:：]([\w-]+)"),
    ("颁发日期", r"二[〇一二三四五六七八九]{1,2}年[一至十二月一二三四五六七八九十]{2,3}"),
]
# Regex patterns for society-sponsored award certificates
social_patterns = [
    # ("获奖项目", r"获奖项目[:](.*?)[^\u4e00-\u9fa5]"),
    ("项目名称", r"(?:项目名称|获奖项目)[:：](.*?)\n"),  # two label variants because certificate layouts differ
    # ("获奖单位", r"获奖单位[:]([\u4e00-\u9fa5A-Za-z0-9、\(\)\s]+)"),
    ("获奖单位", r"获奖单位[:：](.+?)[^\u4e00-\u9fa5]"),
    ("奖励等级", r"(?:奖励等级|获奖等级)[:：](\s*)([^\s].+?)\n"),
    ("获奖人", r"获奖人[:：]([\u4e00-\u9fa5、\n]+)"),
    ("奖励年度", r"奖励年度[:：]([\d]{4})"),
    ("证书编号", r"证书[ _]*编?号[:：]([\w-]+)"),
    ("颁发日期", r"二[〇一二三四五六七八九]{1,2}年[一至十二月一二三四五六七八九十]{2,3}"),
]
# Regex patterns for provincial/ministerial science and technology progress awards
province_patterns = [
    ("项目名称", r"项目名称[:：](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    ("奖励等级", r"奖励等级[:：](.+?)\n"),
    ("获奖者", r"(获奖(?:者|单位)[:：](.+?))(?:$|\n)"),
    ("证书编号", r"([\w-]+)"),
    ("颁发日期", r"\d{4}年?\d{1,2}月\d{1,2}日"),
]
# Regex patterns for software copyright registration certificates
software_patterns = [
    ("证书号", r"证书号[:：]([\w-]+)"),
    ("软件名称", r"软件名称[:：](.*?)[^\u4e00-\u9fa5]\n"),
    ("著作权人", r"著[\s]*作[\s]*权[\s]*人[\s]*[:：](.+?)\n"),
    ("开发完成日期", r"开发完成日期[:：](.+?)\n"),
    ("首次发表日期", r"首次发表日期[:：](.+?)\n"),
    ("权利取得方式", r"权利取得方式[:：](.+?)\n"),
    ("权利范围", r"权利范围[:：](.+?)\n"),
    ("登记号", r"登[\s]*记[\s]*号[:：](.+?)\n"),
    ("颁发日期", r"\d{4}年?\d{1,2}月\d{1,2}日"),
]
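# Pattern list used per file_type by the "re" extraction method in /extract below:
#   standard -> standard_patterns (extract_standard_info)
#   patent   -> patent_patterns   (extract_patent_info)
#   country  -> country_patterns  (extract_social_info)
#   building -> build_patterns    (extract_patent_info)
#   social   -> social_patterns   (extract_social_info)
#   province -> province_patterns (extract_social_info)
#   software -> software_patterns (extract_social_info)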
@app.post(
    "/extract",
    summary="提取专利/标准等文件里的信息",
    responses={
        200: {
            "description": "成功示例",
            "content": {
                "application/json": {
                    "example": {
                        "标准号": "GB/T 1234-2020",
                        "中文名称": "低热矿渣硅酸盐水泥",
                        "英文名称": "Low-heat portland slag cement",
                        "发布部门": "国家质量监督检验检疫总局",
                        "发布日期": "2020-01-01",
                        "实施日期": "2020-07-01",
                        "提出单位": "中国建筑材料联合会",
                        "归口单位": "国家建筑材料工业技术监督研究中心",
                        "起草单位": "中国建筑材料科学研究总院",
                        "起草人": "张三、李四",
                    }
                },
            },
        },
        400: {
            "description": "错误示例",
            "content": {
                "application/json": {
                    "example": {
                        "detail": "Invalid file type. Please choose one of 'standard', 'patent', 'country', 'building', 'social', 'province' or 'software'.",
                    }
                }
            },
        },
    },
)
async def extract_info(
    file_type: Literal[
        "standard", "patent", "country", "building", "social", "province", "software"
    ] = Form(..., description="Specify the type of file to extract."),
    extract_method: Literal["re", "chat"] = Form(
        ..., description="Specify how to extract: 're' (regex) or 'chat' (chat model)."
    ),
    file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file."),
):
    check_file_format(file)
    content = await file.read()  # read the uploaded file contents
    if file.filename.lower().endswith(".pdf"):
        images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300)
    else:
        # For building-industry award images, crop first to improve recognition.
        if file_type == "building":
            img = Image.open(io.BytesIO(content))
            # Crop box in (left, upper, right, lower) format
            box = (730, 0, 1500, 1000)  # adjust to the actual image size
            cropped_img = img.crop(box)
            images = [cropped_img]
        else:
            images = [Image.open(io.BytesIO(content))]
    result = perform_ocr(ocr, images)
    # Each page result is a list of lines shaped [box, (text, confidence)]; keep only the text.
    ocr_text = "\n".join(
        [line[1][0] for res in result if res is not None for line in res]
    )
    # Extract the information
    if extract_method == "re":
        if file_type == "patent":
            info = extract_patent_info(ocr_text, patent_patterns)
            clean_info(info)
        elif file_type == "standard":
            info = extract_standard_info(ocr_text, standard_patterns)
            clean_info(info)
        elif file_type == "country":
            info = extract_social_info(ocr_text, country_patterns)
            clean_info(info)
        elif file_type == "building":
            info = extract_patent_info(ocr_text, build_patterns)
            clean_info(info)
        elif file_type == "social":
            info = extract_social_info(ocr_text, social_patterns)
            clean_info(info)
        elif file_type == "province":
            info = extract_social_info(ocr_text, province_patterns)
            clean_info(info)
        elif file_type == "software":
            info = extract_social_info(ocr_text, software_patterns)
            clean_info(info)
        else:
            raise HTTPException(
                400,
                detail="Invalid file type. Please choose one of 'standard', 'patent', 'country', 'building', 'social', 'province' or 'software'.",
            )
    else:
        if file_type == "patent":
            prompt = f"我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日,专利权人,授权公告号,授权公告日为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}"
        elif file_type == "standard":
            prompt = f"我有以下文本,是一个标准的内容。请按标准号,中文名称,英文名称,发布部门,发布日期,实施日期,提出单位,归口单位,起草单位,起草人为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}"
        elif file_type == "country":
            prompt = f"我有以下文本,是一个国家奖的内容。请按项目名称,奖励等级,获奖者,证书号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}"
        elif file_type == "province":
            prompt = f"我有以下文本,是一个省部级科技进步奖的内容。请按项目名称,奖励等级,获奖者,证书编号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}"
        elif file_type == "social":
            prompt = f"我有以下文本,是一个社会力量奖的内容。请按项目名称,获奖单位,奖励等级,获奖人,奖励年度,证书编号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}"
        elif file_type == "building":
            prompt = f"我有以下文本,是一个建材行业奖的内容。请按项目名称,奖励等级,获奖单位,证书编号,颁发日期为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}"
        else:
            raise HTTPException(
                400,
                detail="Invalid file type. Please choose one of 'standard', 'patent', 'country', 'building', 'social', 'province' or 'software'.",
            )
        r = requests.post(
            conf.CHAT_API,
            json={"model": conf.CHAT_MODEL, "prompt": prompt, "stream": False},
        )
        try:
            info = json.loads(r.json().get("response", "{}"))
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError: {e}")
            info = {}
    return JSONResponse(content=info)
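# Example client call (a minimal sketch; assumes the service is reachable at
# http://localhost:2260 and that "cert.pdf" is a local file you supply):
#
#   import requests
#   with open("cert.pdf", "rb") as f:
#       resp = requests.post(
#           "http://localhost:2260/extract",
#           data={"file_type": "standard", "extract_method": "re"},
#           files={"file": ("cert.pdf", f, "application/pdf")},
#       )
#   print(resp.json())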
if __name__ == "__main__":
    uvicorn.run(
        app="main:app",
        host="0.0.0.0",
        port=2260,
        reload=conf.APP_RELOAD,
        log_config=LOGGING_CONFIG,
    )