refactor: 统一使用 OpenAI 兼容 API,支持自定义 base_url/key/model
- 移除 Gemini 和 Ollama 独立适配,统一使用 ChatOpenAI + base_url - config.ini 简化为 BASE_URL / API_KEY / MODEL / TEMPERATURE / MAX_RETRIES - 新增 config.example.ini 示例配置 - 移除 langchain-google-genai / langchain-ollama / pymupdf 依赖 - main.py 新增断点续跑:跳过已有 index.md / index_refined.md - LLM 请求支持 max_retries 自动重试(默认 3 次) - 优化 README
This commit is contained in:
+24
-62
@@ -1,6 +1,5 @@
|
||||
import re
|
||||
import base64
|
||||
import os
|
||||
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.datamodel.pipeline_options import (
|
||||
@@ -11,14 +10,11 @@ from docling.datamodel.settings import settings
|
||||
from docling.document_converter import DocumentConverter, PdfFormatOption
|
||||
from docling_core.types.doc.base import ImageRefMode
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import ChatOpenAI
|
||||
from llm import set_api_key, get_model_name, get_temperature
|
||||
from llm import set_api_key, get_model_name, get_temperature, get_base_url, get_max_retries
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
import configparser
|
||||
import fitz
|
||||
import base64
|
||||
|
||||
|
||||
def save_md_images(
|
||||
@@ -117,35 +113,21 @@ def convert_pdf_to_markdown(pdf: bytes) -> tuple[str, dict[str, bytes]]:
|
||||
def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
|
||||
"""Refines the Markdown content using an LLM."""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini")
|
||||
provider = config.get("llm", "PROVIDER", fallback="gemini")
|
||||
|
||||
set_api_key()
|
||||
|
||||
try:
|
||||
if provider == "gemini":
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model=get_model_name(), temperature=get_temperature()
|
||||
)
|
||||
elif provider == "ollama":
|
||||
llm = ChatOllama(
|
||||
model=get_model_name(),
|
||||
temperature=get_temperature(),
|
||||
base_url=os.environ["OLLAMA_BASE_URL"],
|
||||
num_ctx=256000,
|
||||
num_predict=-1,
|
||||
)
|
||||
elif provider == "openai":
|
||||
llm = ChatOpenAI(
|
||||
model=get_model_name(),
|
||||
temperature=get_temperature(),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
kwargs = {
|
||||
"model": get_model_name(),
|
||||
"temperature": get_temperature(),
|
||||
}
|
||||
base_url = get_base_url()
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
kwargs["max_retries"] = get_max_retries()
|
||||
llm = ChatOpenAI(**kwargs)
|
||||
except Exception as e:
|
||||
raise BaseException(
|
||||
f"Error initializing LLM. Make sure your LLM configuration is correct. Error: {e}"
|
||||
f"Error initializing LLM. Make sure your configuration is correct. Error: {e}"
|
||||
)
|
||||
|
||||
with open("pdf_convertor_prompt.md", "r") as f:
|
||||
@@ -204,36 +186,16 @@ def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
|
||||
}
|
||||
)
|
||||
|
||||
if provider == "gemini" or provider == "openai":
|
||||
human_message_parts.extend(
|
||||
[
|
||||
{
|
||||
"type": "file",
|
||||
"base64": base64.b64encode(pdf).decode("utf-8"),
|
||||
"mime_type": "application/pdf",
|
||||
"filename": "origin.pdf",
|
||||
},
|
||||
]
|
||||
)
|
||||
if provider == "ollama":
|
||||
doc = fitz.open(stream=pdf, filetype="pdf")
|
||||
for page_num in range(doc.page_count):
|
||||
page = doc.load_page(page_num)
|
||||
pix = page.get_pixmap()
|
||||
img_bytes = pix.tobytes("png")
|
||||
human_message_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"This is page {page_num + 1} of the original PDF file:\n",
|
||||
}
|
||||
)
|
||||
human_message_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": f"data:image/png;base64,{base64.b64encode(img_bytes).decode('utf-8')}",
|
||||
}
|
||||
)
|
||||
doc.close()
|
||||
human_message_parts.extend(
|
||||
[
|
||||
{
|
||||
"type": "file",
|
||||
"base64": base64.b64encode(pdf).decode("utf-8"),
|
||||
"mime_type": "application/pdf",
|
||||
"filename": "origin.pdf",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
message_content = [
|
||||
SystemMessage(content=prompt),
|
||||
@@ -241,7 +203,7 @@ def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
|
||||
]
|
||||
|
||||
print(
|
||||
f"Sending request to {provider} with the PDF, Markdown and referenced images... This may take a moment."
|
||||
"Sending request to LLM with the PDF, Markdown and referenced images... This may take a moment."
|
||||
)
|
||||
try:
|
||||
response = llm.invoke(message_content)
|
||||
@@ -250,7 +212,7 @@ def refine_content(md: str, images: dict[str, bytes], pdf: bytes) -> str:
|
||||
raise BaseException(f"An error occurred while invoking the LLM: {e}")
|
||||
|
||||
if str(refined_content) == "":
|
||||
raise BaseException(f"Response of {provider} is empty")
|
||||
raise BaseException("Response of LLM is empty")
|
||||
|
||||
return fix_output(str(refined_content))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user