From 69ff37e19aa898a3ec6466b1947607f46ec590bd Mon Sep 17 00:00:00 2001 From: caoqianming Date: Tue, 24 Sep 2024 12:01:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A0=BC=E5=BC=8F=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 160 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 111 insertions(+), 49 deletions(-) diff --git a/main.py b/main.py index 1fb3a2e..a16345a 100644 --- a/main.py +++ b/main.py @@ -21,20 +21,21 @@ CUR_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, CUR_DIR) import conf -LOGGING_CONFIG["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" +LOGGING_CONFIG["formatters"]["default"]["fmt"] = ( + "%(asctime)s - %(levelname)s - %(message)s" +) LOGGING_CONFIG["handlers"]["file"] = { - "class": "logging.handlers.RotatingFileHandler", - "filename": os.path.join(CUR_DIR, "app.log"), - "maxBytes": 5 * 1024 * 1024, # 5 MB - "backupCount": 5, - "formatter": "default", - "encoding": "utf-8" - } -LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"]=["file"] + "class": "logging.handlers.RotatingFileHandler", + "filename": os.path.join(CUR_DIR, "app.log"), + "maxBytes": 5 * 1024 * 1024, # 5 MB + "backupCount": 5, + "formatter": "default", + "encoding": "utf-8", +} +LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"] = ["file"] - -gpu_available = paddle.device.is_compiled_with_cuda() +gpu_available = paddle.device.is_compiled_with_cuda() print("GPU available:", gpu_available) @@ -43,23 +44,26 @@ app = FastAPI() ALLOWED_FILE_TYPES = {"application/pdf", "image/jpeg", "image/png", "image/jpg"} + def save_imgs(images): """ 保存图像到本地 """ - if not os.path.exists('save_imgs'): - os.makedirs('save_imgs') + if not os.path.exists("save_imgs"): + os.makedirs("save_imgs") paths = [] for i, image in enumerate(images): - image_path = os.path.join('saved_imgs', f'{uuid.uuid4()}.png') + image_path = os.path.join("saved_imgs", f"{uuid.uuid4()}.png") image.save(image_path) paths.append(image_path) return paths + def check_file_format(file: UploadFile): if file.content_type not in ALLOWED_FILE_TYPES: raise HTTPException(400, "File format not supported.") + def perform_ocr(ocr, images): """ 使用 OCR 提取图像文件中的文字. @@ -72,6 +76,7 @@ def perform_ocr(ocr, images): all_text.extend(result) return all_text + def extract_standard_info(text, patterns): """ 从标准文件文本中提取信息. @@ -84,7 +89,7 @@ def extract_standard_info(text, patterns): matches = re.findall(pattern, text, re.DOTALL) if matches: if label in ["发布日期", "实施日期"]: - info[label] = re.sub(r'-', '-', matches[0].replace(' ', '')) + info[label] = re.sub(r"-", "-", matches[0].replace(" ", "")) elif label in ["起草单位", "起草人"]: info[label] = matches[0].strip() elif label == "提出单位": @@ -93,10 +98,10 @@ def extract_standard_info(text, patterns): info[label] = matches[0].strip() elif label == "发布部门": if matches[0][1]: - info[label] = matches[0][0] + '、' + matches[0][1] + info[label] = matches[0][0] + "、" + matches[0][1] else: info[label] = matches[0][0] + matches[0][1] - info[label] = re.sub(r'\n', '', info[label]) + info[label] = re.sub(r"\n", "", info[label]) else: info[label] = matches[-1] if matches else None @@ -115,25 +120,41 @@ def extract_patent_info(text, patterns): return info + def clean_info(info: dict): for key in info: if info[key]: - info[key] = info[key].replace('\n', '').replace('\r', '').replace('\t', '').strip() + info[key] = ( + info[key].replace("\n", "").replace("\r", "").replace("\t", "").strip() + ) + # 定义标准文件用的正则表达式模式 standard_patterns = [ - ("标准号", r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}"), # 仅提取了第一个标准号(若有更多标准号可能错误) - ("中文名称", r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n" # 定位英文名称前的中文名称 - r"(?:[^\u4e00-\u9fa5]+)\n" # 定位发布日期前的英文名称部分 - r"(?:\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布"), # 定位发布日期 - ("英文名称", r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布"), # 匹配发布日期前面的非中文部分 - ("发布部门", r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)"), # 发布部门是通过匹配发布的上下行的中文 + ( + "标准号", + r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}", + ), # 仅提取了第一个标准号(若有更多标准号可能错误) + ( + "中文名称", + r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n" # 定位英文名称前的中文名称 + r"(?:[^\u4e00-\u9fa5]+)\n" # 定位发布日期前的英文名称部分 + r"(?:\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布", + ), # 定位发布日期 + ( + "英文名称", + r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布", + ), # 匹配发布日期前面的非中文部分 + ( + "发布部门", + r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)", + ), # 发布部门是通过匹配发布的上下行的中文 ("发布日期", r"(\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布"), ("实施日期", r"(\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *实施"), ("提出单位", r"由(.+?)\s?提出。"), ("归口单位", r"由(.+?)\s?归口。"), ("起草单位", r"起草单位[::](.+?)(?:。|起草人)"), - ("起草人", r"起草人[::](.+?)(?:。)") + ("起草人", r"起草人[::](.+?)(?:。)"), ] # 定义专利文件用的正则表达式模式 @@ -145,29 +166,60 @@ patent_patterns = [ ("专利权人", r"专[\s]*利[\s]*权[\s]*人[\s]*[::](.+?)\n"), ("地址", r"地[\s]*址[\s]*[::](.+?)\n"), ("授权公告日", r"授权公告日:(.+?)\n"), - ("授权公告号", r"授权公告号:(.+?)\n") + ("授权公告号", r"授权公告号:(.+?)\n"), ] - -@app.get("/error") -async def create_error(): - raise ValueError("This is a test error!") -@app.post("/extract") +@app.post( + "/extract", + summary="提取专利/标准等文件里的信息", + responses={ + 200: { + "description": "成功示例", + "examples": { + "application/json": { + "标准号": "GB/T 1234-2020", + "中文名称": "低热矿渣硅酸盐水泥", + "英文名称": "Low-heat portland slag cement", + "发布部门": "国家质量监督检验检疫总局", + "发布日期": "2020-01-01", + "实施日期": "2020-07-01", + "提出单位": "中国建筑材料联合会", + "归口单位": "国家建筑材料工业技术监督研究中心", + "起草单位": "中国建筑材料科学研究总院", + "起草人": "张三、李四", + } + }, + }, + 400: { + 'description': '错误示例', + "examples": { + "application/json": { + "detail": "Invalid file type. Please choose 'standard' or 'patent'.", + } + }, + } + }, +) async def extract_info( - file_type: Literal["standard", "patent"] = Form(..., description="Specify the type of file to extract."), - extract_method: Literal["re", "chat"] = Form(..., description="Specify how to extract. 正则或对话模型"), - file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file.") + file_type: Literal["standard", "patent"] = Form( + ..., description="Specify the type of file to extract." + ), + extract_method: Literal["re", "chat"] = Form( + ..., description="Specify how to extract. 正则或对话模型" + ), + file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file."), ): - check_file_format(file) content = await file.read() # 读取文件内容 - if file.filename.lower().endswith('.pdf'): + if file.filename.lower().endswith(".pdf"): images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300) else: images = [Image.open(io.BytesIO(content))] result = perform_ocr(ocr, images) - ocr_text = "\n".join([line[1][0] for res in result if res is not None for line in res]) + ocr_text = "\n".join( + [line[1][0] for res in result if res is not None for line in res] + ) # 提取信息 if extract_method == "re": @@ -178,21 +230,31 @@ async def extract_info( info = extract_standard_info(ocr_text, standard_patterns) clean_info(info) else: - raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.") + raise HTTPException( + 400, detail="Invalid file type. Please choose 'standard' or 'patent'." + ) else: - if file_type == 'patent': - prompt = f'我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人,授权公告号,授权公告日为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}' + if file_type == "patent": + prompt = f"我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人,授权公告号,授权公告日为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}" elif file_type == "standard": - prompt = f'我有以下文本,是一个标准的内容。请按标准号,中文名称,英文名称,发布部门,发布日期,实施日期,提出单位,归口单位,起草单位,起草人为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}' + prompt = f"我有以下文本,是一个标准的内容。请按标准号,中文名称,英文名称,发布部门,发布日期,实施日期,提出单位,归口单位,起草单位,起草人为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}" else: - raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.") - r = requests.post(conf.CHAT_API, json={ - "model": conf.CHAT_MODEL, - "prompt": prompt, - "stream": False - }) - info = json.loads((r.json()['response'])) + raise HTTPException( + 400, detail="Invalid file type. Please choose 'standard' or 'patent'." + ) + r = requests.post( + conf.CHAT_API, + json={"model": conf.CHAT_MODEL, "prompt": prompt, "stream": False}, + ) + info = json.loads((r.json()["response"])) return JSONResponse(content=info) + if __name__ == "__main__": - uvicorn.run(app="main:app", host="0.0.0.0", port=2260, reload=conf.APP_RELOAD, log_config=LOGGING_CONFIG) \ No newline at end of file + uvicorn.run( + app="main:app", + host="0.0.0.0", + port=2260, + reload=conf.APP_RELOAD, + log_config=LOGGING_CONFIG, + )