import base64
import configparser
import os
import re
from io import BytesIO
from pathlib import Path

import fitz
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    PdfPipelineOptions,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling_core.types.doc.base import ImageRefMode
from docling_core.types.io import DocumentStream
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

from llm import set_api_key, get_model_name, get_temperature


def save_md_images(
    output: str | Path,
    md_content: str,
    images: dict[str, bytes],
    md_name: str = "index.md",
    images_dirname: str = "images",
) -> None:
    """Write Markdown content and its referenced images to a directory.

    Each entry of *images* is saved under ``output/images_dirname`` and the
    corresponding ``](name)`` link in *md_content* is rewritten to the saved
    file's path relative to the Markdown file.

    Args:
        output: Destination directory (created if missing).
        md_content: Markdown text whose image links use keys of *images*.
        images: Mapping of image name -> raw image bytes.
        md_name: Filename of the Markdown file inside *output*.
        images_dirname: Name of the subdirectory that receives the images.
    """
    output = Path(output)
    md_path = output / md_name
    md_path.parent.mkdir(exist_ok=True, parents=True)
    images_dir = output / images_dirname
    images_dir.mkdir(exist_ok=True, parents=True)

    for image_name, image_bytes in images.items():
        image_path = images_dir / Path(image_name).name
        image_path.write_bytes(image_bytes)
        # Point the Markdown link at the file just written.
        # NOTE: Path.relative_to(..., walk_up=True) requires Python 3.12+.
        md_content = md_content.replace(
            f"]({image_name})",
            f"]({image_path.relative_to(md_path.parent, walk_up=True)})",
        )

    # Explicit UTF-8 instead of the platform-default encoding.
    md_path.write_text(md_content, encoding="utf-8")


def load_md_file(md_path: str | Path) -> tuple[str, dict[str, bytes]]:
    """Load a Markdown file and collect the images it references.

    Inline ``data:image/png;base64,...`` images are decoded to raw PNG bytes
    and keyed ``"<index>.png"``; file-based images are read relative to the
    Markdown file and keyed by their original link path.

    Returns:
        ``(markdown_text, {image_name: image_bytes})``.
    """
    md_path = Path(md_path)
    md = md_path.read_text(encoding="utf-8")

    image_refs: list[str] = re.findall(r"!\[.*?\]\((.*?)\)", md)
    image_dict: dict[str, bytes] = {}
    for i, ref in enumerate(image_refs):
        if ref.startswith("data:image/png;base64,"):
            # BUG FIX: decode the base64 payload to raw PNG bytes. The old
            # code stored the base64 *text* (.encode("UTF-8")), which
            # save_md_images would then write out as an invalid .png file,
            # and which disagreed with convert_pdf_to_markdown's decoding.
            image_dict[f"{i}.png"] = base64.b64decode(
                ref.removeprefix("data:image/png;base64,")
            )
            # NOTE(review): the returned Markdown still links the data URI,
            # not "<i>.png" — confirm whether callers expect rewritten links
            # as convert_pdf_to_markdown produces.
        else:
            with open(md_path.parent / ref, "rb") as image_file:
                image_dict[ref] = image_file.read()
    return (md, image_dict)


def convert_pdf_to_markdown(pdf: bytes) -> tuple[str, dict[str, bytes]]:
    """Convert a PDF document (raw bytes) to Markdown plus extracted images.

    Runs the docling pipeline with OCR, table-structure recognition and
    picture extraction, then rewrites the embedded base64 images into a
    ``{name: bytes}`` dict and replaces the inline data URIs in the Markdown
    with the generated names.

    Returns:
        ``(markdown_text, {image_name: image_bytes})``.
    """
    accelerator_options = AcceleratorOptions(
        num_threads=16, device=AcceleratorDevice.CUDA
    )
    pipeline_options = PdfPipelineOptions()
    pipeline_options.accelerator_options = accelerator_options
    pipeline_options.do_ocr = True
    pipeline_options.do_table_structure = True
    pipeline_options.table_structure_options.do_cell_matching = True
    pipeline_options.generate_page_images = True
    pipeline_options.generate_picture_images = True

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_options=pipeline_options,
            )
        }
    )

    # Enable profiling to measure the time spent in the pipeline.
    settings.debug.profile_pipeline_timings = True

    # Convert the in-memory document.
    conversion_result = converter.convert(
        source=DocumentStream(name="", stream=BytesIO(pdf))
    )
    doc = conversion_result.document
    # (removed dead `doc.pictures` expression statement — it had no effect)
    md = doc.export_to_markdown(
        image_mode=ImageRefMode.EMBEDDED,
    )

    # Extract the embedded data URIs and replace them with stable names.
    data_uris: list[str] = re.findall(r"!\[Image\]\((.*?)\)", md)
    image_dict: dict[str, bytes] = {}
    for i, uri in enumerate(data_uris):
        payload = uri.removeprefix("data:image/png;base64,")
        image_dict[f"{i}.png"] = base64.b64decode(payload)
        md = md.replace(uri, f"{i}.png")
    return (md, image_dict)


def _init_llm(provider: str):
    """Build the chat model for *provider* ("gemini", "ollama" or "openai").

    Raises:
        RuntimeError: if the model cannot be initialized (wraps the original
            error, including an unsupported provider name).
    """
    try:
        if provider == "gemini":
            return ChatGoogleGenerativeAI(
                model=get_model_name(), temperature=get_temperature()
            )
        elif provider == "ollama":
            return ChatOllama(
                model=get_model_name(),
                temperature=get_temperature(),
                base_url=os.environ["OLLAMA_BASE_URL"],
                num_ctx=256000,
                num_predict=-1,
            )
        elif provider == "openai":
            return ChatOpenAI(
                model=get_model_name(),
                temperature=get_temperature(),
            )
        else:
            raise ValueError(f"Unsupported LLM provider: {provider}")
    except Exception as e:
        # BUG FIX: was `raise BaseException(...)` — raising BaseException
        # bypasses `except Exception` handlers; chain the cause instead.
        raise RuntimeError(
            f"Error initializing LLM. Make sure your LLM configuration is correct. Error: {e}"
        ) from e


def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
    """Refine the Markdown content using an LLM.

    Sends the Markdown, every extracted image and the original PDF (in the
    provider-appropriate format) together with the prompt from
    ``pdf_convertor_prompt.md`` to the configured LLM and returns the
    cleaned-up response.

    Raises:
        RuntimeError: on LLM initialization/invocation failure or an empty
            response.
    """
    config = configparser.ConfigParser()
    config.read("config.ini")
    provider = config.get("llm", "PROVIDER", fallback="gemini")
    set_api_key()

    llm = _init_llm(provider)

    with open("pdf_convertor_prompt.md", "r", encoding="utf-8") as f:
        prompt = f.read()

    # Add the Markdown.
    human_message_parts = [
        {
            "type": "text",
            "text": md,
        }
    ]

    # Add the images.
    for image_name, image_bytes in images.items():
        human_message_parts.append(
            {
                "type": "text",
                "text": f"This is image: '{image_name}':\n",
            }
        )
        human_message_parts.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
                },
            }
        )

    # Attach the original PDF in whatever form the provider accepts.
    if provider == "gemini":
        human_message_parts.extend(
            [
                {
                    "type": "text",
                    "text": "This is original PDF file:\n",
                },
                {
                    "type": "media",
                    "mime_type": "application/pdf",
                    "data": base64.b64encode(pdf).decode("utf-8"),
                },
            ]
        )
    if provider == "openai":
        human_message_parts.extend(
            [
                {
                    "type": "file",
                    "file": {
                        "filename": "origin.pdf",
                        "file_data": f"data:application/pdf;base64,{base64.b64encode(pdf).decode('utf-8')}",
                    },
                },
            ]
        )
    if provider == "ollama":
        # Ollama has no native PDF input: rasterize each page to PNG.
        doc = fitz.open(stream=pdf, filetype="pdf")
        for page_num in range(doc.page_count):
            page = doc.load_page(page_num)
            pix = page.get_pixmap()
            img_bytes = pix.tobytes("png")
            human_message_parts.append(
                {
                    "type": "text",
                    "text": f"This is page {page_num + 1} of the original PDF file:\n",
                }
            )
            # Normalized to the same {"url": ...} dict form used for the
            # extracted images above (was a bare string here).
            human_message_parts.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{base64.b64encode(img_bytes).decode('utf-8')}"
                    },
                }
            )
        doc.close()

    message_content = [
        SystemMessage(content=prompt),
        HumanMessage(content=human_message_parts),  # type: ignore
    ]
    print(
        f"Sending request to {provider} with the PDF, Markdown and referenced images... This may take a moment."
    )
    try:
        response = llm.invoke(message_content)
        refined_content = response.content
    except Exception as e:
        # BUG FIX: was `raise BaseException(...)`.
        raise RuntimeError(f"An error occurred while invoking the LLM: {e}") from e

    if str(refined_content) == "":
        raise RuntimeError(f"Response of {provider} is empty")

    return fix_output(str(refined_content))


def fix_output(md: str) -> str:
    """Undo a stringified-list response shape some models produce.

    If the text looks like ``"['...', '...']"`` (a Python list printed as a
    string), strip the wrapping, unescape newlines and join the elements
    with blank lines; otherwise return the input untouched.
    """
    if not md.startswith("['") or not md.endswith("']"):
        return md
    md = md.removeprefix("['")
    md = md.removesuffix("']")
    md = md.replace("\\n", "\n")
    # The element separators may mix single and double quotes.
    md = md.replace("\", '", "\n\n")
    md = md.replace("', \"", "\n\n")
    md = md.replace("', '", "\n\n")
    return md