Refactor image data passing in `pdf_convertor.py` to use a direct base64 + `mime_type` format, aligning with updated API requirements for vision models. Additionally, `pdf_convertor_prompt.md` has been significantly refined to improve the clarity and specificity of the instructions for the AI model, particularly concerning:

- **Image Content Explanation:** Added detailed rules ensuring the AI only processes existing image references, preserves their paths, and focuses on descriptive text.
- **Mathematical Formulas:** Clarified conversion to LaTeX notation.
- **Heading Structure:** Enhanced rules and examples for adjusting heading levels and merging adjacent or duplicate headings to ensure a logical document flow.
256 lines
8.1 KiB
Python
Executable File
256 lines
8.1 KiB
Python
Executable File
import re
|
|
import base64
|
|
import os
|
|
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
|
from docling.datamodel.base_models import InputFormat
|
|
from docling.datamodel.pipeline_options import (
|
|
PdfPipelineOptions,
|
|
)
|
|
from docling_core.types.io import DocumentStream
|
|
from docling.datamodel.settings import settings
|
|
from docling.document_converter import DocumentConverter, PdfFormatOption
|
|
from docling_core.types.doc.base import ImageRefMode
|
|
from langchain_core.messages import HumanMessage, SystemMessage
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain_ollama import ChatOllama
|
|
from langchain_openai import ChatOpenAI
|
|
from llm import set_api_key, get_model_name, get_temperature
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
import configparser
|
|
import fitz
|
|
|
|
|
|
def save_md_images(
|
|
output: str | Path,
|
|
md_content: str,
|
|
images: dict[str, bytes],
|
|
md_name: str = "index.md",
|
|
images_dirname: str = "images",
|
|
):
|
|
output = Path(output)
|
|
md_path = output.joinpath(md_name)
|
|
md_path.parent.mkdir(exist_ok=True, parents=True)
|
|
images_dir = output.joinpath(images_dirname)
|
|
images_dir.mkdir(exist_ok=True, parents=True)
|
|
for image_name in images.keys():
|
|
image_path = images_dir.joinpath(Path(image_name).name)
|
|
with open(image_path, "wb") as image_file:
|
|
image_file.write(images[image_name])
|
|
md_content = md_content.replace(
|
|
f"]({image_name})",
|
|
f"]({image_path.relative_to(md_path.parent, walk_up=True)})",
|
|
)
|
|
with open(md_path, "w") as md_file:
|
|
md_file.write(md_content)
|
|
|
|
|
|
def load_md_file(md_path: str | Path) -> tuple[str, dict[str, bytes]]:
|
|
md_path = Path(md_path)
|
|
with open(md_path, "r") as md_file:
|
|
md = md_file.read()
|
|
images: list[str] = re.findall(r"!\[.*?\]\((.*?)\)", md)
|
|
image_dict: dict[str, bytes] = dict()
|
|
for i in range(len(images)):
|
|
image_path = images[i]
|
|
if image_path.startswith("data:image/png;base64,"):
|
|
image_dict[f"{i}.png"] = image_path.removeprefix(
|
|
"data:image/png;base64,"
|
|
).encode("UTF-8")
|
|
else:
|
|
with open(
|
|
Path(md_path.parent).joinpath(image_path), "rb"
|
|
) as image_file:
|
|
image_dict[image_path] = image_file.read()
|
|
return (md, image_dict)
|
|
|
|
|
|
def convert_pdf_to_markdown(
    pdf: bytes,
    *,
    num_threads: int = 16,
    device: "AcceleratorDevice | None" = None,
) -> tuple[str, dict[str, bytes]]:
    """Convert a PDF document to Markdown format using Docling.

    The pipeline runs OCR and table-structure recognition (with cell matching)
    and generates page and picture images.  Pictures are exported embedded as
    data URIs.  Each ``![Image](...)`` target is then replaced by ``"<i>.png"``
    and its decoded bytes are returned under that key.

    Args:
        pdf: Raw bytes of the PDF file.
        num_threads: Thread count passed to Docling's ``AcceleratorOptions``.
        device: Docling accelerator device.  ``None`` selects
            ``AcceleratorDevice.CUDA``, the previously hard-coded choice.

    Returns:
        ``(markdown, images)``, where ``images`` maps ``"<i>.png"`` to the
        decoded image bytes.

    Note:
        As a side effect, this sets
        ``settings.debug.profile_pipeline_timings = True`` on Docling's global
        settings.
    """
    if device is None:
        device = AcceleratorDevice.CUDA
    accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)

    pipeline_options = PdfPipelineOptions()
    pipeline_options.accelerator_options = accelerator_options
    pipeline_options.do_ocr = True
    pipeline_options.do_table_structure = True
    pipeline_options.table_structure_options.do_cell_matching = True
    pipeline_options.generate_page_images = True
    pipeline_options.generate_picture_images = True

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_options=pipeline_options,
            )
        }
    )

    # Enable the profiling to measure the time spent
    settings.debug.profile_pipeline_timings = True

    # Convert the document
    conversion_result = converter.convert(
        source=DocumentStream(name="", stream=BytesIO(pdf))
    )
    md = conversion_result.document.export_to_markdown(
        image_mode=ImageRefMode.EMBEDDED,
    )

    image_dict: dict[str, bytes] = {}
    for i, uri in enumerate(re.findall(r"!\[Image\]\((.*?)\)", md)):
        # Strip any "data:<mime>;base64," header.  If left in place,
        # b64decode would silently drop its non-alphabet characters and
        # decode the rest into corrupt bytes.
        payload = uri.partition(",")[2] if uri.startswith("data:") else uri
        image_dict[f"{i}.png"] = base64.b64decode(payload)
        md = md.replace(uri, f"{i}.png")

    return (md, image_dict)
|
|
|
|
|
|
def _create_llm(provider: str):
    """Instantiate the LangChain chat model configured for ``provider``.

    Raises:
        ValueError: If ``provider`` is not gemini, ollama or openai.
        KeyError: For ``ollama`` when ``OLLAMA_BASE_URL`` is not set.
    """
    if provider == "gemini":
        return ChatGoogleGenerativeAI(
            model=get_model_name(), temperature=get_temperature()
        )
    if provider == "ollama":
        return ChatOllama(
            model=get_model_name(),
            temperature=get_temperature(),
            base_url=os.environ["OLLAMA_BASE_URL"],
            num_ctx=256000,
            num_predict=-1,
        )
    if provider == "openai":
        return ChatOpenAI(
            model=get_model_name(),
            temperature=get_temperature(),
        )
    raise ValueError(f"Unsupported LLM provider: {provider}")


def _pdf_message_parts(provider: str, pdf: bytes) -> list[dict]:
    """Build the message parts that attach the original PDF for ``provider``.

    Gemini receives the PDF as a base64 ``media`` part and OpenAI as a
    ``file`` part.  Ollama receives one PNG rendering per page, produced with
    PyMuPDF.  Any other provider gets no parts.
    """
    if provider == "gemini":
        return [
            {
                "type": "text",
                "text": "This is original PDF file:\n",
            },
            {
                "type": "media",
                "mime_type": "application/pdf",
                "data": base64.b64encode(pdf).decode("utf-8"),
            },
        ]
    if provider == "openai":
        return [
            {
                "type": "file",
                "file": {
                    "filename": "origin.pdf",
                    "file_data": f"data:application/pdf;base64,{base64.b64encode(pdf).decode('utf-8')}",
                },
            },
        ]
    if provider == "ollama":
        parts: list[dict] = []
        # The context manager closes the document even if rendering a page
        # raises.
        with fitz.open(stream=pdf, filetype="pdf") as doc:
            for page_num in range(doc.page_count):
                img_bytes = doc.load_page(page_num).get_pixmap().tobytes("png")
                parts.append(
                    {
                        "type": "text",
                        "text": f"This is page {page_num + 1} of the original PDF file:\n",
                    }
                )
                parts.append(
                    {
                        "type": "image_url",
                        "image_url": f"data:image/png;base64,{base64.b64encode(img_bytes).decode('utf-8')}",
                    }
                )
        return parts
    return []


def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
    """Refines the Markdown content using an LLM.

    The provider is read from ``[llm] PROVIDER`` in ``config.ini`` (default
    ``gemini``; case and surrounding whitespace are ignored).  The system
    prompt is read from ``pdf_convertor_prompt.md``.  The human message
    contains the Markdown, each image labelled with its name, and the original
    PDF in the provider-specific form built by ``_pdf_message_parts``.

    Args:
        md: Markdown to refine.
        images: Mapping of image reference name to raw image bytes.
        pdf: Raw bytes of the original PDF.

    Returns:
        The model's response text, post-processed by ``fix_output``.

    Raises:
        RuntimeError: If the LLM cannot be initialised, the request fails, or
            the response is empty.  ``RuntimeError`` is a ``BaseException``
            subclass, so existing ``except BaseException`` handlers still
            catch it.
    """
    import mimetypes

    config = configparser.ConfigParser()
    config.read("config.ini")
    provider = config.get("llm", "PROVIDER", fallback="gemini").strip().lower()

    set_api_key()

    try:
        llm = _create_llm(provider)
    except Exception as e:
        raise RuntimeError(
            f"Error initializing LLM. Make sure your LLM configuration is correct. Error: {e}"
        ) from e

    # The prompt is stored as UTF-8; don't depend on the platform's locale.
    with open("pdf_convertor_prompt.md", "r", encoding="utf-8") as f:
        prompt = f.read()

    # Add the Markdown.
    human_message_parts: list[dict] = [{"type": "text", "text": md}]

    # Add the images, each preceded by its name so the model can match it to
    # the reference in the Markdown.
    for image_name, image_bytes in images.items():
        # Referenced files may be JPEG etc.; fall back to PNG (the previous
        # fixed value) when the name gives no image type.
        mime_type = mimetypes.guess_type(image_name)[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"
        human_message_parts.append(
            {
                "type": "text",
                "text": f"This is image: '{image_name}':\n",
            }
        )
        human_message_parts.append(
            {
                "type": "image",
                "base64": base64.b64encode(image_bytes).decode("utf-8"),
                "mime_type": mime_type,
            }
        )

    # Add the PDF.
    human_message_parts.extend(_pdf_message_parts(provider, pdf))

    message_content = [
        SystemMessage(content=prompt),
        HumanMessage(content=human_message_parts),  # type: ignore
    ]

    print(
        f"Sending request to {provider} with the PDF, Markdown and referenced images... This may take a moment."
    )
    try:
        response = llm.invoke(message_content)
    except Exception as e:
        raise RuntimeError(f"An error occurred while invoking the LLM: {e}") from e
    refined_content = response.content

    # Covers both an empty string and an empty list of content parts.
    if not refined_content:
        raise RuntimeError(f"Response of {provider} is empty")

    return fix_output(str(refined_content))
|
|
|
|
|
|
def fix_output(md: str) -> str:
    """Undo accidental ``str(list)`` stringification of an LLM response.

    When a message's content is a list of text chunks and is passed through
    ``str()``, the result looks like ``"['chunk 1', 'chunk 2']"``.  Such text
    is parsed back with ``ast.literal_eval``, which restores every escape
    correctly, including LaTeX backslashes such as ``\\frac`` or ``\\nu``.
    The chunks are then joined with blank lines.

    If the text cannot be parsed as a non-empty list of strings, the legacy
    replace-based heuristic is applied to ``['...']`` input.  Any other input
    is returned unchanged.

    Args:
        md: Raw response text.

    Returns:
        The cleaned Markdown.
    """
    import ast

    if md.startswith("[") and md.endswith("]"):
        try:
            chunks = ast.literal_eval(md)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            chunks = None  # Not a Python literal, e.g. Markdown like "[link]".
        if (
            isinstance(chunks, list)
            and chunks
            and all(isinstance(chunk, str) for chunk in chunks)
        ):
            return "\n\n".join(chunks)

    # Legacy heuristic for ['...'] text that is not a valid literal.
    if not md.startswith("['") or not md.endswith("']"):
        return md
    md = md.removeprefix("['")
    md = md.removesuffix("']")
    md = md.replace("\\n", "\n")
    md = md.replace("\", '", "\n\n")
    md = md.replace("', \"", "\n\n")
    md = md.replace("', '", "\n\n")
    return md
|