Xu ly SSO
This commit is contained in:
263
api/main.py
263
api/main.py
@@ -1,17 +1,29 @@
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import secrets
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, status
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import uvicorn
|
||||
import msal
|
||||
|
||||
# Đảm bảo đường dẫn module
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from chat.rag_engine import RAGEngine
|
||||
from core.config import settings
|
||||
from core.models import IngestedDocument, ProcessingPolicy
|
||||
from ingestion.providers.sharepoint_provider import SharePointProvider
|
||||
from ingestion.sync import SyncEngine
|
||||
from extraction.dce import DocumentClassificationEngine
|
||||
from extraction.ocr_service import OCRService
|
||||
from extraction.text_extractor import TextExtractor
|
||||
from chunking.markdown_chunker import MarkdownChunker
|
||||
from indexing.vector_store import VectorStore
|
||||
|
||||
# --- Cấu hình Logging chuyên nghiệp ---
|
||||
logging.basicConfig(
|
||||
@@ -29,8 +41,30 @@ app = FastAPI(
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# Add CORS configuration so the frontend can call the API.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers must not be combined with allow_credentials=True. Replace "*" with
# the real frontend origin(s) before this leaves the PoC stage.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow every origin (acceptable for the PoC)
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Singleton Engine Instance ---
# The /chat endpoint rejects requests while this is still unset.
rag_engine = None
# Snapshot of the SharePoint sync; replaced by run_sync_background() and
# returned as-is by GET /sync/status.
sync_status = {"running": False, "last_run": None, "processed": 0, "skipped": 0, "errors": []}

# --- Azure AD SSO Config ---
# Redirect URI sent to Azure AD for the authorization-code flow. It can be
# overridden per environment through SSO_REDIRECT_URI; the default keeps the
# previous hard-coded local value.
REDIRECT_URI = os.getenv("SSO_REDIRECT_URI", "http://localhost:8000/auth/callback")
AUTHORITY = f"https://login.microsoftonline.com/{settings.tenant_id}"
SCOPE = ["User.Read"]
|
||||
|
||||
def _build_msal_app():
    """Create an MSAL confidential client for the configured Azure AD app.

    A new client is constructed on every call from ``settings.client_id``,
    ``settings.client_secret`` and the module-level ``AUTHORITY``.
    """
    confidential_client = msal.ConfidentialClientApplication(
        settings.client_id,
        client_credential=settings.client_secret,
        authority=AUTHORITY,
    )
    return confidential_client
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
@@ -89,6 +123,18 @@ class ChatResponse(BaseModel):
|
||||
sources: List[SourceCitation] = Field(default_factory=list, description="Danh sách các nguồn trích dẫn từ tài liệu")
|
||||
context_used: Optional[str] = Field(None, description="Ngữ cảnh thực tế đã được trích xuất từ VectorDB (Dùng cho Debug/UI)")
|
||||
|
||||
class SyncResponse(BaseModel):
    # Response body of POST /sync.
    # status: short machine-readable state ("started" / "already_running").
    status: str
    # message: human-readable explanation shown to the caller.
    message: str
|
||||
|
||||
class LoginRequest(BaseModel):
    # Request body of POST /auth/login-email: the e-mail address to log in with.
    email: str = Field(
        ...,
        description="Email người dùng",
    )
|
||||
|
||||
class LoginResponse(BaseModel):
    # Response body of POST /auth/login-email.
    # email: the normalised (trimmed, lower-cased) address.
    email: str
    # display_name: name derived from the local part of the address.
    display_name: str
    # role: "admin" or "user".
    role: str
|
||||
|
||||
# --- ENDPOINTS ---
|
||||
|
||||
@app.get("/health", tags=["System"])
|
||||
@@ -103,11 +149,82 @@ async def health_check():
|
||||
}
|
||||
}
|
||||
|
||||
# Name of the short-lived cookie that carries the OAuth "state" value from
# /auth/login to /auth/callback.
_SSO_STATE_COOKIE = "sso_state"


@app.get("/auth/login", tags=["Auth"])
async def sso_login():
    """
    Redirect the browser to the Azure AD login page.

    Uses the same App Registration as the SharePoint ingestion.

    A random ``state`` value is sent to Azure AD and also stored in an
    HttpOnly cookie so that the callback can verify that the response belongs
    to a login started from this browser (login-CSRF protection).
    """
    state = secrets.token_hex(16)
    msal_app = _build_msal_app()
    auth_url = msal_app.get_authorization_request_url(
        SCOPE,
        redirect_uri=REDIRECT_URI,
        state=state,
    )

    response = RedirectResponse(url=auth_url)
    # SameSite=Lax still sends the cookie on the top-level GET navigation that
    # brings the user back from Azure AD.
    response.set_cookie(
        _SSO_STATE_COOKIE,
        state,
        max_age=600,
        httponly=True,
        samesite="lax",
    )
    return response


@app.get("/auth/callback", tags=["Auth"])
async def sso_callback(request: Request):
    """
    Handle the Azure AD redirect that carries the authorization code.

    Verifies the ``state`` value, exchanges the code for a token, reads the
    user information from the ID-token claims and redirects to the frontend.

    Raises:
        HTTPException 400: the authorization code is missing, or the returned
            ``state`` does not match the one stored by ``/auth/login``.
        HTTPException 401: the token exchange failed or the token carries no
            e-mail / username claim.
    """
    import json
    import urllib.parse

    code = request.query_params.get("code")
    if not code:
        raise HTTPException(status_code=400, detail="Missing authorization code")

    # Compare the state echoed by Azure AD with the one issued at login time.
    expected_state = request.cookies.get(_SSO_STATE_COOKIE, "")
    returned_state = request.query_params.get("state", "")
    state_matches = bool(expected_state) and secrets.compare_digest(
        expected_state.encode("utf-8"),
        returned_state.encode("utf-8"),
    )
    if not state_matches:
        logger.warning("SSO callback rejected: state mismatch")
        raise HTTPException(status_code=400, detail="Invalid SSO state")

    msal_app = _build_msal_app()
    result = msal_app.acquire_token_by_authorization_code(
        code,
        scopes=SCOPE,
        redirect_uri=REDIRECT_URI,
    )

    if "error" in result:
        logger.error(f"SSO error: {result.get('error_description', result.get('error'))}")
        raise HTTPException(status_code=401, detail="Authentication failed")

    # Read the user information from the ID-token claims.
    id_token_claims = result.get("id_token_claims", {})
    email = id_token_claims.get("preferred_username", id_token_claims.get("email", ""))
    if not email:
        # Without an identity there is nobody to log in.
        logger.error("SSO error: token carries no preferred_username/email claim")
        raise HTTPException(status_code=401, detail="Authentication failed")
    name = id_token_claims.get("name", email.split("@")[0])

    # Determine the role.
    # NOTE(review): substring matching on the address is a PoC heuristic (any
    # address containing "admin" becomes admin); replace with a real source of
    # truth such as Azure AD group or app-role claims.
    role = "admin" if "admin" in email.lower() else "user"

    logger.info(f"SSO login: {email} (role={role})")

    # Redirect to the frontend with the user info.
    # NOTE(review): the user info travels unsigned in the query string, so the
    # frontend cannot distinguish it from a hand-crafted URL.
    user_data = json.dumps({"email": email, "display_name": name, "role": role})
    encoded = urllib.parse.quote(user_data)
    response = RedirectResponse(url=f"http://localhost:8000?user={encoded}")
    # The state value is single-use.
    response.delete_cookie(_SSO_STATE_COOKIE)
    return response
|
||||
|
||||
@app.post("/auth/login-email", response_model=LoginResponse, tags=["Auth"])
async def login_email_endpoint(request: LoginRequest):
    """
    Log in with an e-mail address (fallback when SSO is not used).

    The address is trimmed and lower-cased. The display name is derived from
    the local part (dots become spaces, words are title-cased).

    Raises:
        HTTPException 400: the address has no "@", or its local part or its
            domain is empty.
    """
    email = request.email.strip().lower()

    # Require a non-empty part on both sides of the first "@".
    local_part, _, domain = email.partition("@")
    if not local_part or not domain:
        raise HTTPException(status_code=400, detail="Email không hợp lệ.")

    display_name = local_part.replace(".", " ").title()
    # NOTE(review): no credential is checked here and the role is derived from
    # the address text alone; acceptable only as a PoC fallback.
    role = "admin" if "admin" in email else "user"

    logger.info(f"Email login: {email} (role={role})")

    return LoginResponse(email=email, display_name=display_name, role=role)
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse, tags=["RAG"], status_code=status.HTTP_200_OK)
|
||||
async def chat_endpoint(request: ChatRequest):
|
||||
async def chat_endpoint(request: ChatRequest, http_request: Request):
|
||||
"""
|
||||
Điểm cuối xử lý hội thoại RAG.
|
||||
Hệ thống sẽ tự động trích xuất ngữ cảnh từ OpenSearch và sử dụng Provider đã cấu hình để trả lời.
|
||||
Header 'X-User-Email' (optional): Email user để filter quyền.
|
||||
Header 'X-User-Role' (optional): "admin" = bypass ACL.
|
||||
"""
|
||||
if not rag_engine:
|
||||
raise HTTPException(
|
||||
@@ -116,11 +233,14 @@ async def chat_endpoint(request: ChatRequest):
|
||||
)
|
||||
|
||||
try:
|
||||
# Chuyển đổi ChatHistoryItem sang format dict cho RAGEngine
|
||||
user_email = http_request.headers.get("X-User-Email")
|
||||
user_role = http_request.headers.get("X-User-Role", "user")
|
||||
is_admin = user_role == "admin" or not user_email
|
||||
|
||||
history_data = [item.dict() for item in request.history]
|
||||
|
||||
logger.info(f"Xử lý truy vấn: {request.query[:50]}...")
|
||||
result = rag_engine.chat(request.query, history=history_data)
|
||||
logger.info(f"Chat query: {request.query[:50]} (user={user_email or 'none'}, role={user_role})")
|
||||
result = rag_engine.chat(request.query, history=history_data, user_email=user_email, is_admin=is_admin)
|
||||
|
||||
return ChatResponse(
|
||||
answer=result["answer"],
|
||||
@@ -134,5 +254,136 @@ async def chat_endpoint(request: ChatRequest):
|
||||
detail="Đã xảy ra lỗi nội bộ trong quá trình xử lý ngôn ngữ."
|
||||
)
|
||||
|
||||
|
||||
def extract_text_from_pdf_bytes(pdf_bytes: bytes) -> str:
    """Extract text directly from a PDF that has a text layer.

    Args:
        pdf_bytes: Raw content of the PDF file.

    Returns:
        The text of all pages joined by blank lines, or an empty string when
        the text cannot be extracted (PyMuPDF missing, unreadable data, ...).
        Extraction is best-effort: failures are logged, never raised.
    """
    try:
        # Imported lazily so a missing PyMuPDF degrades to "no text".
        import fitz

        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        try:
            return "\n\n".join(page.get_text() for page in doc)
        finally:
            # Release the document even if a page fails to render.
            doc.close()
    except Exception as exc:
        logging.getLogger(__name__).warning("PDF text extraction failed: %s", exc)
        return ""
|
||||
|
||||
|
||||
def run_sync_background():
    """Run the sync pipeline: SharePoint -> DCE -> OCR/Extract -> Chunk -> Index.

    Replaces the module-level ``sync_status`` with a fresh record and updates
    its ``processed`` / ``skipped`` / ``errors`` fields while iterating over
    the items returned by the SharePoint provider. ``last_run`` ends up as
    "completed" or "failed"; ``running`` is always reset to False at the end.

    Download problems are recorded per file and do not stop the run; any
    other exception aborts the run and is recorded in ``errors``.
    """
    global sync_status
    sync_status = {"running": True, "last_run": None, "processed": 0, "skipped": 0, "errors": []}

    # Policies for which the file content is never downloaded or indexed.
    skip_policies = (
        ProcessingPolicy.UNSUPPORTED,
        ProcessingPolicy.METADATA_ONLY,
        ProcessingPolicy.REQUIRES_REVIEW,
    )

    try:
        provider = SharePointProvider()
        dce = DocumentClassificationEngine(provider=provider)
        ocr = OCRService()
        chunker = MarkdownChunker(max_chunk_size=1000, overlap=100)
        vector_db = VectorStore(index_name="poc_sharepoint_docs")

        items, _ = provider.fetch_changes({})
        logger.info(f"Sync: Found {len(items)} items from SharePoint")

        for item in items:
            if item.get("is_folder") or item.get("is_deleted"):
                continue

            name = item.get("name", "")
            item_id = item.get("id", "")

            item_details = provider.get_item_details(item_id)
            permissions = provider.get_item_permissions(item_id)
            doc = IngestedDocument(
                site_id=settings.sharepoint_site_id,
                drive_id="",
                item_id=item_id,
                name=name,
                web_url=item_details.get("web_url", ""),
                download_url=item_details.get("download_url"),
                is_folder=False,
                size=item.get("size", 0),
            )

            classification = dce.classify(doc, target_item=item)

            if classification.processing_policy in skip_policies:
                sync_status["skipped"] += 1
                continue

            try:
                file_bytes = provider.download_file(item)
            except Exception as e:
                # Keep the cause in the log; the status record stays terse.
                logger.warning(f"Sync: download failed for {name}: {e}")
                sync_status["errors"].append(f"{name}: download failed")
                continue

            if not file_bytes:
                sync_status["errors"].append(f"{name}: empty file")
                continue

            pages = []
            ext = name.lower().rsplit(".", 1)[-1] if "." in name else ""

            if classification.processing_policy == ProcessingPolicy.SKIP_OCR:
                # Text can be read directly; pick the extractor by extension.
                if ext == "pdf":
                    text = extract_text_from_pdf_bytes(file_bytes)
                    if text.strip():
                        from core.models import OCRPageResult
                        pages = [OCRPageResult(page=1, text=text, confidence=1.0)]
                elif ext in ("docx", "doc"):
                    pages = TextExtractor.extract_from_docx(file_bytes)
                elif ext in ("xlsx", "xls"):
                    pages = TextExtractor.extract_from_xlsx(file_bytes)
                elif ext in ("txt", "md", "csv"):
                    pages = TextExtractor.extract_from_text(file_bytes)
            elif classification.processing_policy == ProcessingPolicy.REQUIRES_OCR:
                pages = ocr.process_pdf_bytes(file_bytes)

            if not pages:
                sync_status["skipped"] += 1
                continue

            metadata = {
                "item_id": item_id,
                "name": name,
                "web_url": item_details.get("web_url"),
                "download_url": item_details.get("download_url"),
                "site_id": settings.sharepoint_site_id,
                "permissions": permissions,
            }
            chunks = chunker.chunk_document(pages, metadata)

            if chunks:
                # Drop the previously indexed chunks of this file before
                # indexing the new ones.
                vector_db.delete_by_file_id(item_id)
                vector_db.embed_and_index(chunks)
                sync_status["processed"] += 1
                logger.info(f"Sync: Indexed {name} → {len(chunks)} chunks")
            else:
                sync_status["skipped"] += 1

        sync_status["last_run"] = "completed"
        logger.info(f"Sync completed: {sync_status['processed']} processed, {sync_status['skipped']} skipped")

    except Exception as e:
        sync_status["last_run"] = "failed"
        sync_status["errors"].append(str(e))
        # logger.exception logs at ERROR level and adds the traceback.
        logger.exception(f"Sync failed: {e}")
    finally:
        sync_status["running"] = False
|
||||
|
||||
|
||||
@app.post("/sync", response_model=SyncResponse, tags=["Ingestion"])
async def sync_endpoint(background_tasks: BackgroundTasks):
    """
    Trigger a data synchronisation from SharePoint.

    The sync runs in the background; the current state is returned
    immediately. If a sync is already running, no new one is scheduled.
    """
    if sync_status["running"]:
        return SyncResponse(status="already_running", message="Đồng bộ đang chạy, vui lòng đợi.")

    # Mark the sync as running before scheduling it. The background task only
    # sets this flag once it starts, so without this a second request arriving
    # in between would schedule a duplicate run.
    sync_status["running"] = True
    background_tasks.add_task(run_sync_background)
    return SyncResponse(status="started", message="Đồng bộ đã bắt đầu trong background.")
|
||||
|
||||
|
||||
@app.get("/sync/status", tags=["Ingestion"])
async def sync_status_endpoint():
    """Report the current synchronisation status record."""
    current_status = sync_status
    return current_status
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8000, listening
    # on all interfaces, with auto-reload enabled.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
|
||||
|
||||
Reference in New Issue
Block a user