commit 03c467159c64d261ad677c36c6189ce470637137 Author: caoqianming Date: Tue Sep 24 11:33:54 2024 +0800 初始化 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8cb5a6c --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.venv/ +__pycache__/ +app.log +conf.py \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..0f457b8 --- /dev/null +++ b/main.py @@ -0,0 +1,198 @@ +from fastapi import FastAPI, File, UploadFile, Form +from typing import Literal +from fastapi.responses import JSONResponse +from fastapi.exceptions import HTTPException +from PIL import Image +from paddleocr import PaddleOCR +import numpy as np +import uvicorn +import re +import io +from pdf2image import convert_from_bytes +import uuid +import os +import requests +import json +import paddle +import sys +from uvicorn.config import LOGGING_CONFIG + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, CUR_DIR) +import conf + +LOGGING_CONFIG["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" +LOGGING_CONFIG["handlers"]["file"] = { + "class": "logging.handlers.RotatingFileHandler", + "filename": os.path.join(CUR_DIR, "app.log"), + "maxBytes": 5 * 1024 * 1024, # 5 MB + "backupCount": 5, + "formatter": "default", + "encoding": "utf-8" + } +LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"]=["file"] + + + +gpu_available = paddle.device.is_compiled_with_cuda() +print("GPU available:", gpu_available) + + +ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=True) +app = FastAPI() + +ALLOWED_FILE_TYPES = {"application/pdf", "image/jpeg", "image/png", "image/jpg"} + +def save_imgs(images): + """ + 保存图像到本地 + """ + if not os.path.exists('save_imgs'): + os.makedirs('save_imgs') + paths = [] + for i, image in enumerate(images): + image_path = os.path.join('saved_imgs', f'{uuid.uuid4()}.png') + image.save(image_path) + paths.append(image_path) + return paths + +def check_file_format(file: UploadFile): + if file.content_type not in ALLOWED_FILE_TYPES: + raise HTTPException(400, "File format not supported.") + +def perform_ocr(ocr, images): + """ + 使用 OCR 提取图像文件中的文字. + """ + all_text = [] + for image in images: + # Convert PIL Image to numpy array + image_np = np.array(image) + result = ocr.ocr(image_np, cls=True) + all_text.extend(result) + return all_text + +def extract_standard_info(text, patterns): + """ + 从标准文件文本中提取信息. + """ + info = {} + for label, pattern in patterns: + if label in ("归口单位", "中文名称"): + matches = re.findall(pattern, text, re.MULTILINE) + else: + matches = re.findall(pattern, text, re.DOTALL) + if matches: + if label in ["发布日期", "实施日期"]: + info[label] = re.sub(r'-', '-', matches[0].replace(' ', '')) + elif label in ["起草单位", "起草人"]: + info[label] = matches[0].strip() + elif label == "提出单位": + info[label] = matches[-1].strip() + elif label == "标准号": + info[label] = matches[0].strip() + elif label == "发布部门": + if matches[0][1]: + info[label] = matches[0][0] + '、' + matches[0][1] + else: + info[label] = matches[0][0] + matches[0][1] + info[label] = re.sub(r'\n', '', info[label]) + else: + info[label] = matches[-1] if matches else None + + return info + + +def extract_patent_info(text, patterns): + """ + 从专利文件文本中提取信息. + """ + info = {} + + for label, pattern in patterns: + matches = re.findall(pattern, text, re.DOTALL) + info[label] = matches[-1] if matches else None + + return info + +def clean_info(info: dict): + for key in info: + if info[key]: + info[key] = info[key].replace('\n', '').replace('\r', '').replace('\t', '').strip() + +# 定义标准文件用的正则表达式模式 +standard_patterns = [ + ("标准号", r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}"), # 仅提取了第一个标准号(若有更多标准号可能错误) + ("中文名称", r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n" # 定位英文名称前的中文名称 + r"(?:[^\u4e00-\u9fa5]+)\n" # 定位发布日期前的英文名称部分 + r"(?:\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布"), # 定位发布日期 + ("英文名称", r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布"), # 匹配发布日期前面的非中文部分 + ("发布部门", r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)"), # 发布部门是通过匹配发布的上下行的中文 + ("发布日期", r"(\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *发布"), + ("实施日期", r"(\d{4}(?:-|-)\d{2}(?:-|-)\d{2}) *实施"), + ("提出单位", r"由(.+?)\s?提出。"), + ("归口单位", r"由(.+?)\s?归口。"), + ("起草单位", r"起草单位[::](.+?)(?:。|起草人)"), + ("起草人", r"起草人[::](.+?)(?:。)") +] + +# 定义专利文件用的正则表达式模式 +patent_patterns = [ + ("专利名称", r"名称[::](.+?)\n"), + ("发明人", r"发[\s]*明[\s]*人[\s]*[::](.+?)\n"), + ("专利号", r"专[\s]*利[\s]*号[\s]*[::](.+?)\n"), + ("专利申请日", r"专利申请日[::](.+?)\n"), + ("专利权人", r"专[\s]*利[\s]*权[\s]*人[\s]*[::](.+?)\n"), + ("地址", r"地[\s]*址[\s]*[::](.+?)\n"), + ("授权公告日", r"授权公告日:(.+?)\n"), + ("授权公告号", r"授权公告号:(.+?)\n") +] + +@app.get("/error") +async def create_error(): + raise ValueError("This is a test error!") + + +@app.post("/extract") +async def extract_info( + file_type: Literal["standard", "patent"] = Form(..., description="Specify the type of file to extract."), + extract_method: Literal["re", "chat"] = Form(..., description="Specify how to extract. 正则或对话模型"), + file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file.") +): + + check_file_format(file) + content = await file.read() # 读取文件内容 + if file.filename.lower().endswith('.pdf'): + images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300) + else: + images = [Image.open(io.BytesIO(content))] + result = perform_ocr(ocr, images) + ocr_text = "\n".join([line[1][0] for res in result if res is not None for line in res]) + + # 提取信息 + if extract_method == "re": + if file_type == "patent": + info = extract_patent_info(ocr_text, patent_patterns) + clean_info(info) + elif file_type == "standard": + info = extract_standard_info(ocr_text, standard_patterns) + clean_info(info) + else: + raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.") + else: + if file_type == 'patent': + prompt = f'我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人,授权公告号,授权公告日为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}' + elif file_type == "standard": + prompt = f'我有以下文本,是一个标准的内容。请按标准号,中文名称,英文名称,发布部门,发布日期,实施日期,提出单位,归口单位,起草单位,起草人为key的json格式返回数据,注意只返回json数据。文本如下:{ocr_text}' + else: + raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.") + r = requests.post(conf.CHAT_API, json={ + "model": "llama3.1", + "prompt": prompt, + "stream": False + }) + info = json.loads((r.json()['response'])) + return JSONResponse(content=info) + +if __name__ == "__main__": + uvicorn.run(app="main:app", host="0.0.0.0", port=8000, reload=True, log_config=LOGGING_CONFIG) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..48c087c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +fastapi +uvicorn +python-multipart +paddlepaddle +paddleocr +pdf2image