add: 新增社会力量奖 和 建筑奖

This commit is contained in:
zty 2024-10-12 15:49:40 +08:00
parent a70ac00bd9
commit 3462cac98f
1 changed files with 59 additions and 20 deletions

53
main.py
View File

@ -169,7 +169,31 @@ patent_patterns = [
("授权公告号", r"授权公告号:(.+?)\n"),
]
# Regex patterns for national award certificates.
# Each entry is a (field name, pattern) pair; capture group 1 holds the value.
country_patterns = [
    # Project title: lazily captured up to the run of non-Chinese characters
    # that immediately precedes the "奖励等级" label.
    (
        "项目名称",
        r"项目名称[:](.*?)[^\u4e00-\u9fa5]+奖励等级",
    ),
    # Award level: the rest of the line after the label.
    (
        "奖励等级",
        r"奖励等级[:](.+?)\n",
    ),
    # Award winner: the rest of the line after the label.
    (
        "获奖者",
        r"获奖者[:](.+?)\n",
    ),
    # Certificate number: word characters and hyphens only.
    (
        "证书号",
        r"证书号[:]([\w-]+)",
    ),
]
# Regex patterns for building-materials industry award certificates.
# Each entry is a (field name, pattern) pair; capture group 1 holds the value.
build_patterns = [
    # Project title: lazily captured up to the run of non-Chinese characters
    # that immediately precedes the "奖励等级" label.
    (
        "项目名称",
        r"项目名称[:](.*?)[^\u4e00-\u9fa5]+奖励等级",
    ),
    # Award level: the rest of the line after the label.
    (
        "奖励等级",
        r"奖励等级[:](.+?)\n",
    ),
    # Winning organisation: captured up to the non-Chinese run that precedes
    # the "证书编号" label.
    (
        "获奖单位",
        r"获奖单位[:](.+?)[^\u4e00-\u9fa5]+证书编号",
    ),
    # Certificate number: word characters and hyphens only.
    (
        "证书编号",
        r"证书编号[:]([\w-]+)",
    ),
]
# Regex patterns for social-organisation award certificates.
# Each entry is a (field name, pattern) pair; capture group 1 holds the value.
social_patterns = [
    # Project title as labelled "获奖项目": captured up to the non-Chinese run
    # that precedes the next "获奖..." label.
    ("获奖项目", r"获奖项目[:](.*?)[^\u4e00-\u9fa5]+获奖"),
    # Project title as labelled "项目名称": captured up to the non-Chinese run
    # that precedes the "奖励等级" label.
    ("项目名称", r"项目名称[:](.*?)[^\u4e00-\u9fa5]+奖励等级"),
    # The winning organisation is listed twice because certificate layouts
    # differ: it is followed either by "奖励等级" or by another "获奖..." label.
    ("获奖单位", r"获奖单位[:](.+?)[^\u4e00-\u9fa5]+奖励等级"),
    ("获奖_单位", r"获奖单位[:](.+?)[^\u4e00-\u9fa5]+获奖"),
    # Award level: the rest of the line after the label.
    ("奖励等级", r"奖励等级[:](.+?)\n"),
    # Award winner(s); optional whitespace is tolerated between the label
    # characters. The lazy group is anchored at the end of the line (or of the
    # text): a trailing bare "(.+?)" would capture only a single character.
    # NOTE(review): assumes the names fit on one line - confirm against real
    # certificates.
    ("获奖人", r"获[\s]*奖[\s]*人[:](.+?)(?:\n|$)"),
    # Award year: four digits followed by "年".
    ("奖励年度", r"奖励年度[:]([\d]{4})年"),
    # Certificate number: word characters and hyphens only.
    ("证书编号", r"证书编号[:]([\w-]+)"),
]
@app.post(
"/extract",
summary="提取专利/标准等文件里的信息",
@ -191,22 +215,23 @@ patent_patterns = [
"起草人": "张三、李四",
}
},
},
}
},
400: {
"description": "错误示例",
'description': '错误示例',
"content": {
"application/json": {
"example": {
"detail": "Invalid file type. Please choose 'standard' or 'patent'.",
}
}
},
},
}
}
},
)
async def extract_info(
file_type: Literal["standard", "patent"] = Form(
file_type: Literal["standard", "patent", "country", "building", "social"] = Form(
..., description="Specify the type of file to extract."
),
extract_method: Literal["re", "chat"] = Form(
@ -218,13 +243,21 @@ async def extract_info(
content = await file.read() # 读取文件内容
if file.filename.lower().endswith(".pdf"):
images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300)
else:
# 如果是建筑行业奖再进行图片截取提高识别率
if file_type == "building":
img= Image.open(io.BytesIO(content))
# 定义截取的区域,格式为 (左, 上, 右, 下)
box = (730, 0, 1500, 1000) # 根据实际图片大小调整
cropped_img = img.crop(box)
images = [cropped_img]
else:
images = [Image.open(io.BytesIO(content))]
result = perform_ocr(ocr, images)
ocr_text = "\n".join(
[line[1][0] for res in result if res is not None for line in res]
)
print("ocr_text", ocr_text)
# 提取信息
if extract_method == "re":
if file_type == "patent":
@ -233,6 +266,12 @@ async def extract_info(
elif file_type == "standard":
info = extract_standard_info(ocr_text, standard_patterns)
clean_info(info)
elif file_type == "country":
info = extract_patent_info(ocr_text, country_patterns)
clean_info(info)
elif file_type == "building":
info = extract_patent_info(ocr_text, build_patterns)
clean_info(info)
else:
raise HTTPException(
400, detail="Invalid file type. Please choose 'standard' or 'patent'."