This commit improves the structure and clarity of the prompt sent to the LLM (Gemini/OpenAI) in the `refine_content` function. Changes include: * Adding explicit introductory text for the Markdown, individual images, and PDF sections to guide the LLM on the purpose of each input. * Introducing clear "START OF IMAGE" and "END OF IMAGE" delimiters for each image to better define their boundaries. * Unifying the PDF attachment mechanism for both Gemini and OpenAI providers, simplifying the code and ensuring consistent handling of PDF input. These changes aim to improve the LLM's understanding of the provided content, leading to more accurate and relevant refinements.
268 lines · 8.5 KiB · Python · Executable File
import re
|
|
import base64
|
|
import os
|
|
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
|
from docling.datamodel.base_models import InputFormat
|
|
from docling.datamodel.pipeline_options import (
|
|
PdfPipelineOptions,
|
|
)
|
|
from docling_core.types.io import DocumentStream
|
|
from docling.datamodel.settings import settings
|
|
from docling.document_converter import DocumentConverter, PdfFormatOption
|
|
from docling_core.types.doc.base import ImageRefMode
|
|
from langchain_core.messages import HumanMessage, SystemMessage
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain_ollama import ChatOllama
|
|
from langchain_openai import ChatOpenAI
|
|
from llm import set_api_key, get_model_name, get_temperature
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
import configparser
|
|
import fitz
|
|
|
|
|
|
def save_md_images(
|
|
output: str | Path,
|
|
md_content: str,
|
|
images: dict[str, bytes],
|
|
md_name: str = "index.md",
|
|
images_dirname: str = "images",
|
|
):
|
|
output = Path(output)
|
|
md_path = output.joinpath(md_name)
|
|
md_path.parent.mkdir(exist_ok=True, parents=True)
|
|
images_dir = output.joinpath(images_dirname)
|
|
images_dir.mkdir(exist_ok=True, parents=True)
|
|
for image_name in images.keys():
|
|
image_path = images_dir.joinpath(Path(image_name).name)
|
|
with open(image_path, "wb") as image_file:
|
|
image_file.write(images[image_name])
|
|
md_content = md_content.replace(
|
|
f"]({image_name})",
|
|
f"]({image_path.relative_to(md_path.parent, walk_up=True)})",
|
|
)
|
|
with open(md_path, "w") as md_file:
|
|
md_file.write(md_content)
|
|
|
|
|
|
def load_md_file(md_path: str | Path) -> tuple[str, dict[str, bytes]]:
|
|
md_path = Path(md_path)
|
|
with open(md_path, "r") as md_file:
|
|
md = md_file.read()
|
|
images: list[str] = re.findall(r"!\[.*?\]\((.*?)\)", md)
|
|
image_dict: dict[str, bytes] = dict()
|
|
for i in range(len(images)):
|
|
image_path = images[i]
|
|
if image_path.startswith("data:image/png;base64,"):
|
|
image_dict[f"{i}.png"] = image_path.removeprefix(
|
|
"data:image/png;base64,"
|
|
).encode("UTF-8")
|
|
else:
|
|
with open(
|
|
Path(md_path.parent).joinpath(image_path), "rb"
|
|
) as image_file:
|
|
image_dict[image_path] = image_file.read()
|
|
return (md, image_dict)
|
|
|
|
|
|
def convert_pdf_to_markdown(pdf: bytes) -> tuple[str, dict[str, bytes]]:
    """Convert a PDF document (raw bytes) to Markdown plus extracted images.

    Runs docling with OCR, table-structure recognition, and page/picture
    image generation enabled, then rewrites the embedded base64 data-URI
    images in the exported Markdown into numbered ``<i>.png`` references.

    Args:
        pdf: Raw bytes of the PDF document.

    Returns:
        A tuple ``(markdown_text, {image_name: png_bytes})``.
    """
    accelerator_options = AcceleratorOptions(
        num_threads=16, device=AcceleratorDevice.CUDA
    )

    pipeline_options = PdfPipelineOptions()
    pipeline_options.accelerator_options = accelerator_options
    pipeline_options.do_ocr = True
    pipeline_options.do_table_structure = True
    pipeline_options.table_structure_options.do_cell_matching = True
    pipeline_options.generate_page_images = True
    pipeline_options.generate_picture_images = True

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_options=pipeline_options,
            )
        }
    )

    # Enable profiling to measure the time spent in each pipeline stage.
    settings.debug.profile_pipeline_timings = True

    # Convert the document from an in-memory stream.
    conversion_result = converter.convert(
        source=DocumentStream(name="", stream=BytesIO(pdf))
    )
    doc = conversion_result.document

    # NOTE(review): bare attribute access kept from the original — it appears
    # intended to force materialization of the pictures before export; confirm
    # whether it is actually required by docling.
    doc.pictures

    md = doc.export_to_markdown(
        image_mode=ImageRefMode.EMBEDDED,
    )

    # Replace each embedded data-URI image with a numbered file reference and
    # collect the decoded PNG bytes under that name.
    embedded: list[str] = re.findall(r"!\[Image\]\((.*?)\)", md)
    image_dict: dict[str, bytes] = {}
    for i, uri in enumerate(embedded):
        image_dict[f"{i}.png"] = base64.b64decode(
            uri.removeprefix("data:image/png;base64,")
        )
        md = md.replace(uri, f"{i}.png")

    return (md, image_dict)
|
|
|
|
|
|
def _init_llm(provider: str):
    """Create the chat model for *provider* ('gemini', 'ollama' or 'openai').

    Raises:
        RuntimeError: If the provider is unsupported or initialization fails.
    """
    try:
        if provider == "gemini":
            return ChatGoogleGenerativeAI(
                model=get_model_name(), temperature=get_temperature()
            )
        if provider == "ollama":
            return ChatOllama(
                model=get_model_name(),
                temperature=get_temperature(),
                base_url=os.environ["OLLAMA_BASE_URL"],
                num_ctx=256000,
                num_predict=-1,
            )
        if provider == "openai":
            return ChatOpenAI(
                model=get_model_name(),
                temperature=get_temperature(),
            )
        raise ValueError(f"Unsupported LLM provider: {provider}")
    except Exception as e:
        # Raise RuntimeError rather than BaseException: BaseException is
        # reserved for interpreter-level exits and would slip past any
        # `except Exception` handler in the caller.
        raise RuntimeError(
            f"Error initializing LLM. Make sure your LLM configuration is correct. Error: {e}"
        ) from e


def _build_human_parts(
    md: str, images: dict[str, bytes], pdf: bytes, provider: str
) -> list[dict]:
    """Assemble the multimodal message parts: Markdown, images, then the PDF."""
    # Introduce the Markdown section.
    parts: list[dict] = [
        {
            "type": "text",
            "text": "I will provide you with: \n1. Original Markdown text (with image references)\n2. Individual images with their exact filenames\n3. Original PDF for reference\n---\nOriginal Markdown text:\n",
        },
        {
            "type": "text",
            "text": md,
        },
        {
            "type": "text",
            "text": "\n\n---\n\nIndividual images (use ONLY these to describe the corresponding Markdown image references):\n\n",
        },
    ]

    # Wrap every image in explicit START/END delimiters so the model can
    # associate each image with its filename in the Markdown.
    for image_name, image_bytes in images.items():
        parts.append(
            {
                "type": "text",
                "text": f"=== START OF IMAGE: '{image_name}' ===\n",
            }
        )
        parts.append(
            {
                "type": "image",
                "base64": base64.b64encode(image_bytes).decode("utf-8"),
                "mime_type": "image/png",
            }
        )
        parts.append(
            {
                "type": "text",
                "text": f"=== END OF IMAGE: '{image_name}' ===\n\n",
            }
        )

    # Introduce the PDF section.
    parts.append(
        {
            "type": "text",
            "text": "\n---\n\nOriginal PDF (for overall layout and context reference only):\n\n",
        }
    )

    if provider in ("gemini", "openai"):
        # Both providers accept the PDF as a single base64 file attachment.
        parts.append(
            {
                "type": "file",
                "base64": base64.b64encode(pdf).decode("utf-8"),
                "mime_type": "application/pdf",
                "filename": "origin.pdf",
            }
        )
    elif provider == "ollama":
        # Ollama has no native PDF input: rasterize each page to a PNG and
        # attach the pages as images instead.
        doc = fitz.open(stream=pdf, filetype="pdf")
        try:
            for page_num in range(doc.page_count):
                pix = doc.load_page(page_num).get_pixmap()
                img_b64 = base64.b64encode(pix.tobytes("png")).decode("utf-8")
                parts.append(
                    {
                        "type": "text",
                        "text": f"This is page {page_num + 1} of the original PDF file:\n",
                    }
                )
                parts.append(
                    {
                        "type": "image_url",
                        "image_url": f"data:image/png;base64,{img_b64}",
                    }
                )
        finally:
            # Release the fitz document even if page rendering fails.
            doc.close()
    return parts


def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
    """Refine the Markdown content using an LLM.

    Sends the Markdown text, the individual referenced images, and the
    original PDF to the configured LLM provider and returns the refined
    Markdown.

    Args:
        md: Markdown text produced by the PDF conversion.
        images: Mapping of image name -> raw PNG bytes referenced by *md*.
        pdf: Raw bytes of the original PDF, supplied for layout context.

    Returns:
        The refined Markdown text.

    Raises:
        RuntimeError: If LLM initialization or invocation fails, or the
            response is empty.
    """
    config = configparser.ConfigParser()
    config.read("config.ini")
    provider = config.get("llm", "PROVIDER", fallback="gemini")

    set_api_key()
    llm = _init_llm(provider)

    with open("pdf_convertor_prompt.md", "r", encoding="utf-8") as f:
        prompt = f.read()

    message_content = [
        SystemMessage(content=prompt),
        HumanMessage(content=_build_human_parts(md, images, pdf, provider)),  # type: ignore
    ]

    print(
        f"Sending request to {provider} with the PDF, Markdown and referenced images... This may take a moment."
    )
    try:
        response = llm.invoke(message_content)
        refined_content = response.content
    except Exception as e:
        raise RuntimeError(f"An error occurred while invoking the LLM: {e}") from e

    if not str(refined_content):
        raise RuntimeError(f"Response of {provider} is empty")

    return fix_output(str(refined_content))
|
|
|
|
|
|
def fix_output(md: str) -> str:
    """Undo accidental list-repr formatting in LLM output.

    Some models return their answer looking like the ``repr`` of a Python
    list of strings (``"['part one', 'part two']"``).  When that shape is
    detected, strip the surrounding brackets/quotes, restore real newlines,
    and join the parts with blank lines; any other input is returned
    unchanged.
    """
    looks_like_list_repr = md.startswith("['") and md.endswith("']")
    if not looks_like_list_repr:
        return md

    inner = md.removeprefix("['").removesuffix("']")
    # Order matters: literal "\n" escapes first, then the quote-comma
    # separators between list elements.
    replacements = (
        ("\\n", "\n"),
        ("\", '", "\n\n"),
        ("', \"", "\n\n"),
        ("', '", "\n\n"),
    )
    for old, new in replacements:
        inner = inner.replace(old, new)
    return inner
|