feat:修改正则匹配 国家奖 社会奖

This commit is contained in:
zty 2024-10-18 11:17:55 +08:00
parent 3462cac98f
commit d97a4d436f
1 changed files with 45 additions and 12 deletions

57
main.py
View File

@ -120,6 +120,20 @@ def extract_patent_info(text, patterns):
return info
def extract_social_info(text, patterns):
"""
从专利文件文本中提取信息.
"""
info = {}
for label, pattern in patterns:
matches = re.findall(pattern, text, re.DOTALL)
if matches and isinstance(matches[-1], tuple):
info[label] = matches[-1][-1]
else:
info[label] = matches[-1] if matches else None
return info
def clean_info(info: dict):
for key in info:
@ -173,8 +187,9 @@ patent_patterns = [
country_patterns = [
("项目名称", r"项目名称[:](.*?)[^\u4e00-\u9fa5]+奖励等级"),
("奖励等级", r"奖励等级[:](.+?)\n"),
("获奖者", r"获奖者[:](.+?)\n"),
("获奖者", r"(获奖(?:|单位)[:](.+?))(?:$|\n)"),
("证书号", r"证书号[:]([\w-]+)"),
("颁发日期", r"\d{4}\d{1,2}月\d{1,2}日"),
]
# 定义建材行业奖正则表达式
build_patterns = [
@ -182,18 +197,30 @@ build_patterns = [
("奖励等级", r"奖励等级[:](.+?)\n"),
("获奖单位", r"获奖单位[:](.+?)[^\u4e00-\u9fa5]+证书编号"),
("证书编号", r"证书编号[:]([\w-]+)"),
("颁发日期", r"二[〇一二三四五六七八九]{1,2}年[一至十二月一二三四五六七八九十]{2,3}"),
]
# 定义社会力量奖正则表达式、
social_patterns = [
("获奖项目", r"获奖项目[:](.*?)[^\u4e00-\u9fa5]+获奖"),
("项目名称", r"项目名称[:](.*?)[^\u4e00-\u9fa5]+奖励等级"), # 此处两个获奖单位是因为证书样式不同。
("获奖单位", r"获奖单位[:](.+?)[^\u4e00-\u9fa5]+奖励等级"),
("获奖_单位", r"获奖单位[:](.+?)[^\u4e00-\u9fa5]+获奖"),
("奖励等级", r"奖励等级[:](.+?)\n"),
("获奖人", r"获[\s]*奖[\s]*人[:](.+?)"),
("奖励年度", r"奖励年度[:]([\d]{4})年"),
("证书编号", r"证书编号[:]([\w-]+)"),
# ("获奖项目", r"获奖项目[:](.*?)[^\u4e00-\u9fa5]"),
("项目名称", r"[项目名称|获奖项目][:](.*?)\n"), # 此处两个获奖单位是因为证书样式不同。
# ("获奖单位", r"获奖单位[:]([\u4e00-\u9fa5A-Za-z0-9、\(\)\s]+)"),
("获奖单位", r"获奖单位[:](.+?)[^\u4e00-\u9fa5]"),
("奖励等级", r"[奖励等级|获奖等级][:](\s*)([^\s].+?)\n"),
("获奖人", r"获奖人[:]([\u4e00-\u9fa5、\n]+)"),
("奖励年度", r"奖励年度[:]([\d]{4})"),
("证书编号", r"证书[ _]*[编号号][:]([\w-]+)"),
("颁发日期", r"二[〇一二三四五六七八九]{1,2}年[一至十二月一二三四五六七八九十]{2,3}"),
]
# 省部级科技进步奖
province_patterns = [
("项目名称", r"项目名称[:](.*?)[^\u4e00-\u9fa5]+奖励等级"),
("奖励等级", r"奖励等级[:](.+?)\n",),
("获奖者", r"(获奖(?:者|单位)[:](.+?))(?:$|\n)"),
("证书编号", r"([\w-]+)"),
("颁发日期", r"\d{4}\d{1,2}月\d{1,2}日"),
]
@app.post(
"/extract",
summary="提取专利/标准等文件里的信息",
@ -231,7 +258,7 @@ social_patterns = [
},
)
async def extract_info(
file_type: Literal["standard", "patent", "country", "building", "social"] = Form(
file_type: Literal["standard", "patent", "country", "building", "social", "province"] = Form(
..., description="Specify the type of file to extract."
),
extract_method: Literal["re", "chat"] = Form(
@ -257,7 +284,7 @@ async def extract_info(
ocr_text = "\n".join(
[line[1][0] for res in result if res is not None for line in res]
)
print("ocr_text", ocr_text)
# print("ocr_text", ocr_text)
# 提取信息
if extract_method == "re":
if file_type == "patent":
@ -267,11 +294,17 @@ async def extract_info(
info = extract_standard_info(ocr_text, standard_patterns)
clean_info(info)
elif file_type == "country":
info = extract_patent_info(ocr_text, country_patterns)
info = extract_social_info(ocr_text, country_patterns)
clean_info(info)
elif file_type == "building":
info = extract_patent_info(ocr_text, build_patterns)
clean_info(info)
elif file_type == "social":
info = extract_social_info(ocr_text, social_patterns)
clean_info(info)
elif file_type == "province":
info = extract_social_info(ocr_text, province_patterns)
clean_info(info)
else:
raise HTTPException(
400, detail="Invalid file type. Please choose 'standard' or 'patent'."