feat: 格式调整

This commit is contained in:
caoqianming 2024-09-24 12:01:28 +08:00
parent f8c3c56248
commit 69ff37e19a
1 changed files with 111 additions and 49 deletions

160
main.py
View File

@ -21,20 +21,21 @@ CUR_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, CUR_DIR)
import conf
LOGGING_CONFIG["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s"
LOGGING_CONFIG["formatters"]["default"]["fmt"] = (
"%(asctime)s - %(levelname)s - %(message)s"
)
LOGGING_CONFIG["handlers"]["file"] = {
"class": "logging.handlers.RotatingFileHandler",
"filename": os.path.join(CUR_DIR, "app.log"),
"maxBytes": 5 * 1024 * 1024, # 5 MB
"backupCount": 5,
"formatter": "default",
"encoding": "utf-8"
}
LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"]=["file"]
"class": "logging.handlers.RotatingFileHandler",
"filename": os.path.join(CUR_DIR, "app.log"),
"maxBytes": 5 * 1024 * 1024, # 5 MB
"backupCount": 5,
"formatter": "default",
"encoding": "utf-8",
}
LOGGING_CONFIG["loggers"]["uvicorn.error"]["handlers"] = ["file"]
gpu_available = paddle.device.is_compiled_with_cuda()
gpu_available = paddle.device.is_compiled_with_cuda()
print("GPU available:", gpu_available)
@ -43,23 +44,26 @@ app = FastAPI()
ALLOWED_FILE_TYPES = {"application/pdf", "image/jpeg", "image/png", "image/jpg"}
def save_imgs(images):
"""
保存图像到本地
"""
if not os.path.exists('save_imgs'):
os.makedirs('save_imgs')
if not os.path.exists("save_imgs"):
os.makedirs("save_imgs")
paths = []
for i, image in enumerate(images):
image_path = os.path.join('saved_imgs', f'{uuid.uuid4()}.png')
image_path = os.path.join("saved_imgs", f"{uuid.uuid4()}.png")
image.save(image_path)
paths.append(image_path)
return paths
def check_file_format(file: UploadFile):
if file.content_type not in ALLOWED_FILE_TYPES:
raise HTTPException(400, "File format not supported.")
def perform_ocr(ocr, images):
"""
使用 OCR 提取图像文件中的文字.
@ -72,6 +76,7 @@ def perform_ocr(ocr, images):
all_text.extend(result)
return all_text
def extract_standard_info(text, patterns):
"""
从标准文件文本中提取信息.
@ -84,7 +89,7 @@ def extract_standard_info(text, patterns):
matches = re.findall(pattern, text, re.DOTALL)
if matches:
if label in ["发布日期", "实施日期"]:
info[label] = re.sub(r'', '-', matches[0].replace(' ', ''))
info[label] = re.sub(r"", "-", matches[0].replace(" ", ""))
elif label in ["起草单位", "起草人"]:
info[label] = matches[0].strip()
elif label == "提出单位":
@ -93,10 +98,10 @@ def extract_standard_info(text, patterns):
info[label] = matches[0].strip()
elif label == "发布部门":
if matches[0][1]:
info[label] = matches[0][0] + '' + matches[0][1]
info[label] = matches[0][0] + "" + matches[0][1]
else:
info[label] = matches[0][0] + matches[0][1]
info[label] = re.sub(r'\n', '', info[label])
info[label] = re.sub(r"\n", "", info[label])
else:
info[label] = matches[-1] if matches else None
@ -115,25 +120,41 @@ def extract_patent_info(text, patterns):
return info
def clean_info(info: dict):
for key in info:
if info[key]:
info[key] = info[key].replace('\n', '').replace('\r', '').replace('\t', '').strip()
info[key] = (
info[key].replace("\n", "").replace("\r", "").replace("\t", "").strip()
)
# 定义标准文件用的正则表达式模式
standard_patterns = [
("标准号", r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}"), # 仅提取了第一个标准号(若有更多标准号可能错误)
("中文名称", r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n" # 定位英文名称前的中文名称
r"(?:[^\u4e00-\u9fa5]+)\n" # 定位发布日期前的英文名称部分
r"(?:\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布"), # 定位发布日期
("英文名称", r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布"), # 匹配发布日期前面的非中文部分
("发布部门", r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)"), # 发布部门是通过匹配发布的上下行的中文
(
"标准号",
r"[A-Z][A-Z/\ 0-9]{1,}[0-9\.]+[\-—-]\d{3,4}",
), # 仅提取了第一个标准号(若有更多标准号可能错误)
(
"中文名称",
r"(^[\u4e00-\u9fa5][\w::、√—\-\ \nn]+)\n" # 定位英文名称前的中文名称
r"(?:[^\u4e00-\u9fa5]+)\n" # 定位发布日期前的英文名称部分
r"(?:\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布",
), # 定位发布日期
(
"英文名称",
r"\n([^\u4e00-\u9fa5]+)\n(?:\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布",
), # 匹配发布日期前面的非中文部分
(
"发布部门",
r"([\u4e00-\u9fa5]+)\n发布\n([\u4e00-\u9fa5\n]*)",
), # 发布部门是通过匹配发布的上下行的中文
("发布日期", r"(\d{4}(?:-|)\d{2}(?:-|)\d{2}) *发布"),
("实施日期", r"(\d{4}(?:-|)\d{2}(?:-|)\d{2}) *实施"),
("提出单位", r"由(.+?)\s?提出。"),
("归口单位", r"由(.+?)\s?归口。"),
("起草单位", r"起草单位[:](.+?)(?:。|起草人)"),
("起草人", r"起草人[:](.+?)(?:。)")
("起草人", r"起草人[:](.+?)(?:。)"),
]
# 定义专利文件用的正则表达式模式
@ -145,29 +166,60 @@ patent_patterns = [
("专利权人", r"专[\s]*利[\s]*权[\s]*人[\s]*[:](.+?)\n"),
("地址", r"地[\s]*址[\s]*[:](.+?)\n"),
("授权公告日", r"授权公告日:(.+?)\n"),
("授权公告号", r"授权公告号:(.+?)\n")
("授权公告号", r"授权公告号:(.+?)\n"),
]
@app.get("/error")
async def create_error():
raise ValueError("This is a test error!")
@app.post("/extract")
@app.post(
"/extract",
summary="提取专利/标准等文件里的信息",
responses={
200: {
"description": "成功示例",
"examples": {
"application/json": {
"标准号": "GB/T 1234-2020",
"中文名称": "低热矿渣硅酸盐水泥",
"英文名称": "Low-heat portland slag cement",
"发布部门": "国家质量监督检验检疫总局",
"发布日期": "2020-01-01",
"实施日期": "2020-07-01",
"提出单位": "中国建筑材料联合会",
"归口单位": "国家建筑材料工业技术监督研究中心",
"起草单位": "中国建筑材料科学研究总院",
"起草人": "张三、李四",
}
},
},
400: {
'description': '错误示例',
"examples": {
"application/json": {
"detail": "Invalid file type. Please choose 'standard' or 'patent'.",
}
},
}
},
)
async def extract_info(
file_type: Literal["standard", "patent"] = Form(..., description="Specify the type of file to extract."),
extract_method: Literal["re", "chat"] = Form(..., description="Specify how to extract. 正则或对话模型"),
file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file.")
file_type: Literal["standard", "patent"] = Form(
..., description="Specify the type of file to extract."
),
extract_method: Literal["re", "chat"] = Form(
..., description="Specify how to extract. 正则或对话模型"
),
file: UploadFile = File(..., description="Upload a PDF, JPEG, or PNG file."),
):
check_file_format(file)
content = await file.read() # 读取文件内容
if file.filename.lower().endswith('.pdf'):
if file.filename.lower().endswith(".pdf"):
images = convert_from_bytes(content, first_page=1, last_page=6, dpi=300)
else:
images = [Image.open(io.BytesIO(content))]
result = perform_ocr(ocr, images)
ocr_text = "\n".join([line[1][0] for res in result if res is not None for line in res])
ocr_text = "\n".join(
[line[1][0] for res in result if res is not None for line in res]
)
# 提取信息
if extract_method == "re":
@ -178,21 +230,31 @@ async def extract_info(
info = extract_standard_info(ocr_text, standard_patterns)
clean_info(info)
else:
raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.")
raise HTTPException(
400, detail="Invalid file type. Please choose 'standard' or 'patent'."
)
else:
if file_type == 'patent':
prompt = f'我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人授权公告号授权公告日为key的json格式返回数据,注意只返回json数据。文本如下{ocr_text}'
if file_type == "patent":
prompt = f"我有以下文本,是一个专利的内容。请按专利名称,发明人,专利号,专利申请日, 专利权人授权公告号授权公告日为key的json格式返回数据,注意只返回json数据。文本如下{ocr_text}"
elif file_type == "standard":
prompt = f'我有以下文本,是一个标准的内容。请按标准号中文名称英文名称发布部门发布日期实施日期提出单位归口单位起草单位起草人为key的json格式返回数据,注意只返回json数据。文本如下{ocr_text}'
prompt = f"我有以下文本,是一个标准的内容。请按标准号中文名称英文名称发布部门发布日期实施日期提出单位归口单位起草单位起草人为key的json格式返回数据,注意只返回json数据。文本如下{ocr_text}"
else:
raise HTTPException(400, detail="Invalid file type. Please choose 'standard' or 'patent'.")
r = requests.post(conf.CHAT_API, json={
"model": conf.CHAT_MODEL,
"prompt": prompt,
"stream": False
})
info = json.loads((r.json()['response']))
raise HTTPException(
400, detail="Invalid file type. Please choose 'standard' or 'patent'."
)
r = requests.post(
conf.CHAT_API,
json={"model": conf.CHAT_MODEL, "prompt": prompt, "stream": False},
)
info = json.loads((r.json()["response"]))
return JSONResponse(content=info)
if __name__ == "__main__":
uvicorn.run(app="main:app", host="0.0.0.0", port=2260, reload=conf.APP_RELOAD, log_config=LOGGING_CONFIG)
uvicorn.run(
app="main:app",
host="0.0.0.0",
port=2260,
reload=conf.APP_RELOAD,
log_config=LOGGING_CONFIG,
)