diff --git a/main.py b/main.py index 01f6ce0..f146978 100644 --- a/main.py +++ b/main.py @@ -169,7 +169,31 @@ patent_patterns = [ ("授权公告号", r"授权公告号:(.+?)\n"), ] - +# 定义国家奖正则表达式 +country_patterns = [ + ("项目名称", r"项目名称[::](.*?)[^\u4e00-\u9fa5]+奖励等级"), + ("奖励等级", r"奖励等级[::](.+?)\n"), + ("获奖者", r"获奖者[::](.+?)\n"), + ("证书号", r"证书号[::]([\w-]+)"), +] +# 定义建材行业奖正则表达式 +build_patterns = [ + ("项目名称", r"项目名称[::](.*?)[^\u4e00-\u9fa5]+奖励等级"), + ("奖励等级", r"奖励等级[::](.+?)\n"), + ("获奖单位", r"获奖单位[::](.+?)[^\u4e00-\u9fa5]+证书编号"), + ("证书编号", r"证书编号[::]([\w-]+)"), +] +# 定义社会力量奖正则表达式、 +social_patterns = [ + ("获奖项目", r"获奖项目[::](.*?)[^\u4e00-\u9fa5]+获奖"), + ("项目名称", r"项目名称[::](.*?)[^\u4e00-\u9fa5]+奖励等级"), # 此处两个获奖单位是因为证书样式不同。 + ("获奖单位", r"获奖单位[::](.+?)[^\u4e00-\u9fa5]+奖励等级"), + ("获奖_单位", r"获奖单位[::](.+?)[^\u4e00-\u9fa5]+获奖"), + ("奖励等级", r"奖励等级[::](.+?)\n"), + ("获奖人", r"获[\s]*奖[\s]*人[::](.+?)"), + ("奖励年度", r"奖励年度[::]([\d]{4})年"), + ("证书编号", r"证书编号[::]([\w-]+)"), +] @app.post( "/extract", summary="提取专利/标准等文件里的信息", @@ -178,35 +202,36 @@ patent_patterns = [ "description": "成功示例", "content": { "application/json": { - "example": { - "标准号": "GB/T 1234-2020", - "中文名称": "低热矿渣硅酸盐水泥", - "英文名称": "Low-heat portland slag cement", - "发布部门": "国家质量监督检验检疫总局", - "发布日期": "2020-01-01", - "实施日期": "2020-07-01", - "提出单位": "中国建筑材料联合会", - "归口单位": "国家建筑材料工业技术监督研究中心", - "起草单位": "中国建筑材料科学研究总院", - "起草人": "张三、李四", - } - }, + "example": { + "标准号": "GB/T 1234-2020", + "中文名称": "低热矿渣硅酸盐水泥", + "英文名称": "Low-heat portland slag cement", + "发布部门": "国家质量监督检验检疫总局", + "发布日期": "2020-01-01", + "实施日期": "2020-07-01", + "提出单位": "中国建筑材料联合会", + "归口单位": "国家建筑材料工业技术监督研究中心", + "起草单位": "中国建筑材料科学研究总院", + "起草人": "张三、李四", + } }, + } + }, 400: { - "description": "错误示例", + 'description': '错误示例', "content": { "application/json": { "example": { "detail": "Invalid file type. Please choose 'standard' or 'patent'.", } } - }, - }, + } + } }, ) async def extract_info( - file_type: Literal["standard", "patent"] = Form( + file_type: Literal["standard", "patent", "country", "building", "social"] = Form( ..., description="Specify the type of file to extract." ), extract_method: Literal["re", "chat"] = Form( @@ -219,12 +244,20 @@ async def extract_info( if file.filename.lower().endswith(".pdf"): images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300) else: - images = [Image.open(io.BytesIO(content))] + # 如果是建筑行业奖再进行图片截取提高识别率 + if file_type == "building": + img= Image.open(io.BytesIO(content)) + # 定义截取的区域,格式为 (左, 上, 右, 下) + box = (730, 0, 1500, 1000) # 根据实际图片大小调整 + cropped_img = img.crop(box) + images = [cropped_img] + else: + images = [Image.open(io.BytesIO(content))] result = perform_ocr(ocr, images) ocr_text = "\n".join( [line[1][0] for res in result if res is not None for line in res] ) - + print("ocr_text", ocr_text) # 提取信息 if extract_method == "re": if file_type == "patent": @@ -233,6 +266,12 @@ async def extract_info( elif file_type == "standard": info = extract_standard_info(ocr_text, standard_patterns) clean_info(info) + elif file_type == "country": + info = extract_patent_info(ocr_text, country_patterns) + clean_info(info) + elif file_type == "building": + info = extract_patent_info(ocr_text, build_patterns) + clean_info(info) else: raise HTTPException( 400, detail="Invalid file type. Please choose 'standard' or 'patent'."