Phase 7: Hoàn thiện Modular RAG Backend với FastAPI và Đa LLM Provider
This commit is contained in:
99
extraction/dce.py
Normal file
99
extraction/dce.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import httpx
|
||||
import logging
|
||||
from core.models import IngestedDocument, DocumentClassificationResult, DocumentType, ProcessingPolicy, PdfType
|
||||
from extraction.magic_numbers import MagicNumberValidator
|
||||
from extraction.pdf_inspector import PDFInspector
|
||||
|
||||
# Module-level logger for the classification engine, registered under the fixed name "DCE".
logger = logging.getLogger("DCE")
|
||||
|
||||
class DocumentClassificationEngine:
    """
    Document Classification Engine (DCE).

    Maps an ingested document to a document type and a processing policy,
    using the file extension and, for PDFs, an inspection of the downloaded
    content.
    """

    # Extension -> (document type, processing policy, reason) for every
    # supported non-PDF format. Extensions missing here are routed as binary.
    _EXTENSION_ROUTES = {
        **dict.fromkeys(
            (".docx", ".doc", ".txt", ".md"),
            (DocumentType.TEXTUAL_DOCUMENT, ProcessingPolicy.SKIP_OCR, "Standard textual document format"),
        ),
        **dict.fromkeys(
            (".xlsx", ".xls", ".csv"),
            (DocumentType.SPREADSHEET, ProcessingPolicy.SKIP_OCR, "Spreadsheet document format"),
        ),
        **dict.fromkeys(
            (".dwg", ".dxf", ".cad"),
            (DocumentType.DRAWING, ProcessingPolicy.METADATA_ONLY, "Native CAD drawing format"),
        ),
    }

    def __init__(self):
        self.pdf_inspector = PDFInspector()

    def classify(self, document: IngestedDocument) -> DocumentClassificationResult:
        """
        Classify one document and return its classification result.

        The magic-number check only produces log output; the routing decision
        itself is driven by the lower-cased file extension of ``document.name``.
        """
        logger.info(f"Classifying document: {document.name} (ID: {document.item_id})")

        _, suffix = os.path.splitext(document.name)
        ext = suffix.lower()

        # 1. Magic Number Validation
        self._log_magic_number_check(document)

        # 2. Routing Rules
        if ext == ".pdf":
            doc_type, policy, reason = self._route_pdf(document)
        else:
            unsupported = (
                DocumentType.BINARY,
                ProcessingPolicy.UNSUPPORTED,
                f"Unsupported or binary extension: {ext}",
            )
            doc_type, policy, reason = self._EXTENSION_ROUTES.get(ext, unsupported)

        result = DocumentClassificationResult(
            item_id=document.item_id,
            doc_type=doc_type,
            processing_policy=policy,
            file_extension=ext,
            is_supported=policy != ProcessingPolicy.UNSUPPORTED,
            reason=reason,
        )

        logger.info(f"Result -> Type: {doc_type.value}, Policy: {policy.value}, Reason: {reason}")
        return result

    def _log_magic_number_check(self, document: IngestedDocument) -> None:
        """Fetch the file header and log whether it matches a known signature."""
        if not document.download_url:
            return

        header = MagicNumberValidator.fetch_header_bytes(document.download_url)
        matched, _detected_type, description = MagicNumberValidator.validate_from_bytes(header)
        if not matched:
            logger.warning(f"Could not verify magic number for {document.name}. Trusting extension fallback.")
            return
        logger.info(f"Magic Number match: {description}")

    def _detect_pdf_type(self, document: IngestedDocument) -> PdfType:
        """Download the PDF and inspect it; fall back to SCAN_PDF on any failure."""
        if not document.download_url:
            logger.warning("No download_url available for PDF. Defaulting to SCAN_PDF.")
            return PdfType.SCAN_PDF

        logger.info("Downloading PDF into memory for PyMuPDF inspection...")
        try:
            with httpx.Client() as client:
                resp = client.get(document.download_url)
                resp.raise_for_status()
                pdf_bytes = resp.content
            return self.pdf_inspector.inspect_pdf_from_bytes(pdf_bytes)
        except Exception as e:
            logger.error(f"Failed to download/inspect PDF: {e}")
            return PdfType.SCAN_PDF

    def _route_pdf(self, document: IngestedDocument):
        """Return (document type, processing policy, reason) for a PDF document."""
        pdf_type = self._detect_pdf_type(document)

        if pdf_type == PdfType.TEXT_PDF:
            return DocumentType.TEXTUAL_DOCUMENT, ProcessingPolicy.SKIP_OCR, "PDF has text layer (TEXT_PDF)"

        if pdf_type == PdfType.DRAWING_PDF:
            return (
                DocumentType.DRAWING,
                ProcessingPolicy.METADATA_ONLY,
                "PDF has large vector dimensions (DRAWING_PDF)",
            )

        if pdf_type == PdfType.AMBIGUOUS_PDF:
            return (
                DocumentType.UNKNOWN,
                ProcessingPolicy.REQUIRES_REVIEW,
                "Kích thước PDF lớn bất thường (khổ A3/A2 hoặc DPI cao), cần con người xác nhận là bản Scan hay Bản vẽ",
            )

        # Every remaining PDF type is handled as a scan that needs OCR.
        return DocumentType.TEXTUAL_DOCUMENT, ProcessingPolicy.REQUIRES_OCR, "PDF has no text layer (SCAN_PDF)"
|
||||
39
extraction/magic_numbers.py
Normal file
39
extraction/magic_numbers.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Dict, Any, Tuple
|
||||
import httpx
|
||||
import logging
|
||||
from core.models import IngestedDocument, DocumentClassificationResult, DocumentType, ProcessingPolicy
|
||||
|
||||
# Messages from this module are emitted under the fixed logger name "DCE"
# rather than a module-specific name.
logger = logging.getLogger("DCE")
|
||||
|
||||
class MagicNumberValidator:
    """Validates file types using magic numbers (file signatures)."""

    # Leading byte signature -> (document type, human-readable description).
    SIGNATURES = {
        b"%PDF-": (DocumentType.TEXTUAL_DOCUMENT, "PDF Document"),
        # ZIP container: the signature alone cannot tell which Office format it is.
        b"PK\x03\x04": (DocumentType.UNKNOWN, "ZIP Archive / Office Open XML"),
        b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1": (DocumentType.UNKNOWN, "Legacy Office Document"),
        # Prefix used to recognise AutoCAD DWG files; add further CAD signatures here if needed.
        b"AC10": (DocumentType.DRAWING, "AutoCAD Drawing (DWG)"),
    }

    @classmethod
    def validate_from_bytes(cls, header_bytes: bytes) -> Tuple[bool, DocumentType, str]:
        """
        Check whether ``header_bytes`` starts with a known signature.

        Returns:
            ``(True, document type, description)`` for the first matching
            signature, otherwise ``(False, DocumentType.UNKNOWN, "Unknown Signature")``.
            Empty or missing input is reported as no match.
        """
        # fetch_header_bytes returns b"" on failure; treat any falsy header
        # (including None) as "no match" instead of raising.
        if not header_bytes:
            return False, DocumentType.UNKNOWN, "Unknown Signature"

        for signature, (doc_type, description) in cls.SIGNATURES.items():
            if header_bytes.startswith(signature):
                return True, doc_type, description
        return False, DocumentType.UNKNOWN, "Unknown Signature"

    @classmethod
    def fetch_header_bytes(cls, download_url: str, num_bytes: int = 256) -> bytes:
        """
        Fetch at most the first ``num_bytes`` bytes of a remote file.

        A ``Range`` header asks the server for just the prefix. The body is
        streamed and reading stops once enough bytes have arrived, so a server
        that ignores ``Range`` and sends the whole file does not cause a full
        download.

        Returns:
            Up to ``num_bytes`` bytes, or ``b""`` if ``num_bytes`` is not
            positive or the request fails (the failure is logged).
        """
        if num_bytes <= 0:
            # "bytes=0--1" would be a malformed Range header.
            return b""

        try:
            headers = {"Range": f"bytes=0-{num_bytes - 1}"}
            collected = bytearray()
            with httpx.Client() as client:
                with client.stream("GET", download_url, headers=headers) as response:
                    response.raise_for_status()
                    for chunk in response.iter_bytes():
                        collected.extend(chunk)
                        if len(collected) >= num_bytes:
                            # Leaving the stream context closes the response
                            # without reading the rest of the body.
                            break
            return bytes(collected[:num_bytes])
        except Exception as e:
            # Best effort: callers fall back to extension-based routing.
            logger.error(f"Failed to fetch header bytes: {e}")
            return b""
|
||||
111
extraction/ocr_service.py
Normal file
111
extraction/ocr_service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import httpx
|
||||
import fitz
|
||||
from PIL import Image
|
||||
from typing import List, Tuple
|
||||
from core.models import OCRPageResult
|
||||
from core.config import settings
|
||||
|
||||
# Module-level logger for the OCR service, registered under the fixed name "OCRService".
logger = logging.getLogger("OCRService")
|
||||
|
||||
class OCRService:
    """
    OCR Service implementation acting as a VLM client.

    Each PDF page is rendered to an image, sent to the chat-completion
    endpoint configured in ``settings.VLM_ENDPOINT``, and the returned text is
    collected as one ``OCRPageResult`` per page.
    """

    # Zoom factor used when rendering a page. Kept low to reduce the number of
    # image tokens and avoid HTTP 500 errors caused by exceeding the
    # llama.cpp context window.
    _RENDER_ZOOM = 1.2

    # Prompt sent with every page image (runtime text, must stay as-is).
    _PROMPT = "Hãy trích xuất chính xác toàn bộ văn bản có trong hình ảnh này. Giữ nguyên định dạng và các dấu câu tiếng Việt."

    def __init__(self):
        self.vlm_url = settings.VLM_ENDPOINT
        logger.info(f"Initialized VLM OCR Service connecting to {self.vlm_url}")

    def _image_to_base64(self, img: Image.Image) -> str:
        """Convert a PIL image into a base64-encoded JPEG data URL."""
        buffered = io.BytesIO()
        # Convert to RGB first, e.g. when the image carries an alpha channel.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img.save(buffered, format="JPEG", quality=85)
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/jpeg;base64,{img_str}"

    def _build_payload(self, b64_image: str) -> dict:
        """Build the chat-completion request body for one page image."""
        return {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": b64_image
                            }
                        },
                        {
                            "type": "text",
                            "text": self._PROMPT
                        }
                    ]
                }
            ],
            "temperature": settings.VLM_TEMPERATURE,
            "max_tokens": settings.VLM_MAX_TOKENS
        }

    def _request_page_text(self, payload: dict) -> str:
        """POST one payload to the VLM endpoint and return the stripped text of the first choice."""
        with httpx.Client(timeout=settings.VLM_TIMEOUT) as client:
            response = client.post(self.vlm_url, json=payload)
            response.raise_for_status()
            data = response.json()
        return data['choices'][0]['message']['content'].strip()

    def _render_page_as_data_url(self, page) -> str:
        """Render one PDF page to an image and return it as a base64 JPEG data URL."""
        import gc

        matrix = fitz.Matrix(self._RENDER_ZOOM, self._RENDER_ZOOM)
        pix = page.get_pixmap(matrix=matrix)
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

        # Drop the pixmap right away to keep memory low on large documents.
        del pix
        gc.collect()

        return self._image_to_base64(img)

    def process_pdf_bytes(self, pdf_bytes: bytes) -> List[OCRPageResult]:
        """
        Process a PDF from memory using Vintern-3B VLM via LAN.

        Args:
            pdf_bytes: Raw bytes of the PDF file.

        Returns:
            One ``OCRPageResult`` per page, in page order. A page whose VLM
            request fails gets a placeholder result with ``confidence=0.0``
            and processing continues. An empty list is returned when the
            input is empty or when opening/rendering the PDF fails.
        """
        if not pdf_bytes:
            logger.warning("Empty PDF bytes received.")
            return []

        results = []
        try:
            # The context manager closes the document on success and on error.
            with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
                total_pages = len(doc)
                for page_number in range(1, total_pages + 1):
                    logger.info(f"VLM Processing page {page_number}/{total_pages} via LAN...")

                    b64_image = self._render_page_as_data_url(doc[page_number - 1])
                    payload = self._build_payload(b64_image)

                    try:
                        vlm_text = self._request_page_text(payload)
                        results.append(OCRPageResult(
                            page=page_number,
                            text=vlm_text,
                            # The VLM does not return per-character confidence, so it is hard-coded.
                            confidence=0.99,
                            # Legacy comparison columns, left empty.
                            paddle_text="",
                            paddle_confidence=0.0
                        ))
                        logger.info(f"VLM extraction successful for page {page_number}")
                    except Exception as api_err:
                        logger.error(f"VLM API Error: {api_err}")
                        # Record the failed page but keep going with the following pages.
                        results.append(OCRPageResult(
                            page=page_number,
                            text=f"[LỖI KẾT NỐI VLM: {api_err}]",
                            confidence=0.0,
                            paddle_text="",
                            paddle_confidence=0.0
                        ))

                return results
        except Exception as e:
            import traceback
            logger.error(f"Failed to process PDF: {e}\n{traceback.format_exc()}")
            return []
|
||||
61
extraction/pdf_inspector.py
Normal file
61
extraction/pdf_inspector.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from core.models import DocumentType, ProcessingPolicy, PdfType
|
||||
|
||||
# Module-level logger for PDF inspection, registered under the fixed name "PDFInspector".
logger = logging.getLogger("PDFInspector")
|
||||
|
||||
class PDFInspector:
    """
    Inspects PDF files to determine if they are TEXT, SCAN, DRAWING or AMBIGUOUS.
    """

    def __init__(
        self,
        text_density_threshold: int = 100,
        drawing_dimension_threshold: float = 3000,
        ambiguous_dimension_threshold: float = 1000,
        max_pages_to_check: int = 3,
    ):
        """
        Args:
            text_density_threshold: Minimum average number of extracted text
                characters per inspected page for a PDF to count as TEXT_PDF.
            drawing_dimension_threshold: A page whose longer side (as reported
                by ``page.rect``) exceeds this value marks the PDF as a
                drawing candidate.
            ambiguous_dimension_threshold: A page whose longer side exceeds
                this value, but not the drawing threshold, marks the PDF as
                an ambiguous-size candidate.
            max_pages_to_check: Upper bound on how many leading pages are
                inspected.
        """
        self.text_density_threshold = text_density_threshold
        self.drawing_dimension_threshold = drawing_dimension_threshold
        self.ambiguous_dimension_threshold = ambiguous_dimension_threshold
        self.max_pages_to_check = max_pages_to_check

    def inspect_pdf_from_bytes(self, pdf_bytes: bytes) -> PdfType:
        """
        Deep inspects a PDF file from a byte stream.

        Decision order: enough extracted text -> TEXT_PDF; otherwise any very
        large page -> DRAWING_PDF; otherwise any moderately large page ->
        AMBIGUOUS_PDF; otherwise SCAN_PDF.

        Returns:
            The detected ``PdfType``. ``PdfType.SCAN_PDF`` is also returned
            when PyMuPDF is missing, the document has no pages, or inspection
            raises.
        """
        try:
            import fitz  # PyMuPDF
        except ImportError:
            logger.error("PyMuPDF (fitz) is not installed. Returning default SCAN_PDF.")
            return PdfType.SCAN_PDF

        try:
            # The context manager closes the document on every exit path.
            with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
                pages_to_check = min(self.max_pages_to_check, len(doc))
                if pages_to_check <= 0:
                    # Avoid dividing by zero below; nothing can be inspected.
                    logger.warning("PDF has no pages to inspect. Returning default SCAN_PDF.")
                    return PdfType.SCAN_PDF

                total_text_length = 0
                is_huge = False
                is_ambiguous_size = False

                for index in range(pages_to_check):
                    page = doc[index]
                    rect = page.rect
                    max_dim = max(rect.width, rect.height)

                    if max_dim > self.drawing_dimension_threshold:
                        is_huge = True
                    elif max_dim > self.ambiguous_dimension_threshold:
                        is_ambiguous_size = True

                    total_text_length += len(page.get_text().strip())

            avg_text = total_text_length / pages_to_check

            # A text layer takes precedence over any page-size heuristic.
            if avg_text >= self.text_density_threshold:
                return PdfType.TEXT_PDF

            if is_huge:
                return PdfType.DRAWING_PDF

            if is_ambiguous_size:
                return PdfType.AMBIGUOUS_PDF

            return PdfType.SCAN_PDF

        except Exception as e:
            logger.error(f"Error inspecting PDF stream: {e}")
            return PdfType.SCAN_PDF
|
||||
Reference in New Issue
Block a user