This commit refactors the content refinement process to leverage `SystemMessage` for the primary prompt, enhancing clarity and adherence to LLM best practices.

`pdf_convertor.py` was updated to:
- Import `SystemMessage` from `langchain_core.messages`.
- Modify the `refine_content` function to use `SystemMessage` for the main prompt, moving the prompt content out of `human_message_parts`.
- Adjust `human_message_parts` to contain only the Markdown and image data for the `HumanMessage`.

`pdf_convertor_prompt.md` was updated to:
- Reformat the prompt with clearer headings and instructions for each task.
- Improve the clarity and conciseness of the instructions for cleaning up characters, explaining image content, and correcting list formatting.

Additionally, `.gitignore` was updated to include `.vscode/` so IDE-specific files are not committed.

These changes improve the structure of the LLM interaction and make the prompt more readable and maintainable.
255 lines
8.2 KiB
Python
Executable File
255 lines
8.2 KiB
Python
Executable File
import re
|
|
import base64
|
|
import os
|
|
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
|
from docling.datamodel.base_models import InputFormat
|
|
from docling.datamodel.pipeline_options import (
|
|
PdfPipelineOptions,
|
|
)
|
|
from docling_core.types.io import DocumentStream
|
|
from docling.datamodel.settings import settings
|
|
from docling.document_converter import DocumentConverter, PdfFormatOption
|
|
from docling_core.types.doc.base import ImageRefMode
|
|
from langchain_core.messages import HumanMessage, SystemMessage
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain_ollama import ChatOllama
|
|
from llm import set_api_key, get_model_name, get_temperature
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
import configparser
|
|
import fitz
|
|
|
|
|
|
def save_md_images(
|
|
output: str | Path,
|
|
md_content: str,
|
|
images: dict[str, bytes],
|
|
md_name: str = "index.md",
|
|
images_dirname: str = "images",
|
|
):
|
|
output = Path(output)
|
|
md_path = output.joinpath(md_name)
|
|
md_path.parent.mkdir(exist_ok=True, parents=True)
|
|
images_dir = output.joinpath(images_dirname)
|
|
images_dir.mkdir(exist_ok=True, parents=True)
|
|
for image_name in images.keys():
|
|
image_path = images_dir.joinpath(Path(image_name).name)
|
|
with open(image_path, "wb") as image_file:
|
|
image_file.write(images[image_name])
|
|
md_content = md_content.replace(
|
|
f"]({image_name})",
|
|
f"]({image_path.relative_to(md_path.parent, walk_up=True)})",
|
|
)
|
|
with open(md_path, "w") as md_file:
|
|
md_file.write(md_content)
|
|
|
|
|
|
def load_md_file(md_path: str | Path) -> tuple[str, dict[str, bytes]]:
|
|
md_path = Path(md_path)
|
|
with open(md_path, "r") as md_file:
|
|
md = md_file.read()
|
|
images: list[str] = re.findall(r"!\[.*?\]\((.*?)\)", md)
|
|
image_dict: dict[str, bytes] = dict()
|
|
for i in range(len(images)):
|
|
image_path = images[i]
|
|
if image_path.startswith("data:image/png;base64,"):
|
|
image_dict[f"{i}.png"] = image_path.removeprefix(
|
|
"data:image/png;base64,"
|
|
).encode("UTF-8")
|
|
else:
|
|
with open(
|
|
Path(md_path.parent).joinpath(image_path), "rb"
|
|
) as image_file:
|
|
image_dict[image_path] = image_file.read()
|
|
return (md, image_dict)
|
|
|
|
|
|
def convert_pdf_to_markdown(pdf: bytes) -> tuple[str, dict[str, bytes]]:
    """Converts a PDF document to Markdown format.

    Runs the docling PDF pipeline (OCR, table structure, page/picture
    images) over the raw *pdf* bytes, then pulls every embedded base64
    PNG out of the exported Markdown into a dict and rewrites each
    reference to a short ``{i}.png`` name.

    Returns:
        Tuple of (markdown_text, {"<i>.png": raw_png_bytes}).
    """
    pipeline_options = PdfPipelineOptions()
    pipeline_options.accelerator_options = AcceleratorOptions(
        num_threads=16, device=AcceleratorDevice.CUDA
    )
    pipeline_options.do_ocr = True
    pipeline_options.do_table_structure = True
    pipeline_options.table_structure_options.do_cell_matching = True
    pipeline_options.generate_page_images = True
    pipeline_options.generate_picture_images = True

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_options=pipeline_options,
            )
        }
    )

    # Profile the time spent in each pipeline step.
    settings.debug.profile_pipeline_timings = True

    conversion_result = converter.convert(
        source=DocumentStream(name="", stream=BytesIO(pdf))
    )
    document = conversion_result.document

    # NOTE(review): bare attribute access — presumably forces lazy picture
    # materialization in docling; confirm before removing.
    document.pictures

    markdown = document.export_to_markdown(
        image_mode=ImageRefMode.EMBEDDED,
    )

    # Extract every embedded data URI and swap in a short file name.
    embedded_uris: list[str] = re.findall(r"!\[Image\]\((.*?)\)", markdown)
    extracted: dict[str, bytes] = {}
    for index, uri in enumerate(embedded_uris):
        name = f"{index}.png"
        extracted[name] = base64.b64decode(
            uri.removeprefix("data:image/png;base64,")
        )
        markdown = markdown.replace(uri, name)

    return (markdown, extracted)
|
|
|
|
|
|
def _init_llm(provider: str):
    """Build the chat model for *provider* ('gemini' or 'ollama')."""
    if provider == "gemini":
        return ChatGoogleGenerativeAI(
            model=get_model_name(), temperature=get_temperature()
        )
    if provider == "ollama":
        return ChatOllama(
            model=get_model_name(),
            temperature=get_temperature(),
            base_url=os.environ["OLLAMA_BASE_URL"],
            num_ctx=256000,
            num_predict=-1,  # no cap on generated tokens
        )
    raise ValueError(f"Unsupported LLM provider: {provider}")


def _png_part(provider: str, png: bytes) -> dict:
    """One message part holding a PNG, in the provider's expected shape."""
    encoded = base64.b64encode(png).decode("utf-8")
    if provider == "gemini":
        return {"type": "media", "mime_type": "image/png", "data": encoded}
    # ollama expects a data-URI image_url part
    return {"type": "image_url", "image_url": f"data:image/png;base64,{encoded}"}


def _build_human_parts(
    provider: str, md: str, images: dict[str, bytes], pdf: bytes
) -> list[dict]:
    """Assemble the HumanMessage content: Markdown, images, then the PDF."""
    parts: list[dict] = []

    # Markdown
    if provider == "gemini":
        parts.append(
            {
                "type": "media",
                "mime_type": "text/markdown",
                "data": base64.b64encode(md.encode("UTF-8")).decode("utf-8"),
            }
        )
    elif provider == "ollama":
        parts.append({"type": "text", "text": md})

    # Images referenced by the Markdown
    for image_name, image_bytes in images.items():
        parts.append(
            {
                "type": "text",
                "text": f"This is image: '{image_name}':\n",
            }
        )
        parts.append(_png_part(provider, image_bytes))

    # Original PDF: gemini accepts it directly; ollama gets one PNG per page.
    if provider == "gemini":
        parts.append({"type": "text", "text": "This is original PDF file:\n"})
        parts.append(
            {
                "type": "media",
                "mime_type": "application/pdf",
                "data": base64.b64encode(pdf).decode("utf-8"),
            }
        )
    if provider == "ollama":
        doc = fitz.open(stream=pdf, filetype="pdf")
        try:
            for page_num in range(doc.page_count):
                pix = doc.load_page(page_num).get_pixmap()
                parts.append(
                    {
                        "type": "text",
                        "text": f"This is page {page_num + 1} of the original PDF file:\n",
                    }
                )
                parts.append(_png_part(provider, pix.tobytes("png")))
        finally:
            # Ensure the fitz document is released even if rendering fails.
            doc.close()

    return parts


def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
    """Refines the Markdown content using an LLM.

    Sends the prompt (as a SystemMessage) plus the Markdown, its images,
    and the original PDF (as a HumanMessage) to the configured provider
    and returns the cleaned-up Markdown.

    Args:
        md: Markdown produced by ``convert_pdf_to_markdown``.
        images: Mapping of image name to raw PNG bytes.
        pdf: The original PDF bytes.

    Returns:
        The refined Markdown text.

    Raises:
        RuntimeError: If the LLM cannot be initialized, the invocation
            fails, or the response is empty.
    """
    config = configparser.ConfigParser()
    config.read("config.ini")
    provider = config.get("llm", "PROVIDER", fallback="gemini")

    set_api_key(provider)

    try:
        llm = _init_llm(provider)
    except Exception as e:
        # RuntimeError (not BaseException) so callers' `except Exception`
        # handlers can catch configuration problems.
        raise RuntimeError(
            f"Error initializing LLM. Make sure your LLM configuration is correct. Error: {e}"
        ) from e

    with open("pdf_convertor_prompt.md", "r", encoding="utf-8") as f:
        prompt = f.read()

    messages = [
        SystemMessage(content=prompt),
        HumanMessage(content=_build_human_parts(provider, md, images, pdf)),  # type: ignore
    ]

    print(
        f"Sending request to {provider} with the PDF, Markdown and referenced images... This may take a moment."
    )
    try:
        response = llm.invoke(messages)
        refined_content = response.content
    except Exception as e:
        raise RuntimeError(f"An error occurred while invoking the LLM: {e}") from e

    if str(refined_content) == "":
        raise RuntimeError(f"Response of {provider} is empty")

    return fix_output(str(refined_content))
|
|
|
|
|
|
def fix_output(md: str) -> str:
    """Repair LLM output that arrived as a stringified Python list.

    Some responses come back looking like ``"['line one', 'line two']"``.
    When *md* is wrapped that way, strip the brackets, turn escaped
    newlines into real ones, and join the list items with blank lines.
    Anything else is returned untouched.
    """
    wrapped = md.startswith("['") and md.endswith("']")
    if not wrapped:
        return md

    body = md.removeprefix("['").removesuffix("']")
    body = body.replace("\\n", "\n")
    # The element separators vary with the quoting of adjacent items;
    # replacement order matters, so keep it fixed.
    for separator in ("\", '", "', \"", "', '"):
        body = body.replace(separator, "\n\n")
    return body
|