Files
poc_system/search/retriever.py
2026-05-09 10:31:28 +00:00

103 lines
3.5 KiB
Python

import logging
import os
from typing import List, Optional
from opensearchpy import OpenSearch, RequestsHttpConnection
from core.config import settings
from core.models import DocumentChunk
# Logger shared by this module, registered under the fixed name "Retriever".
logger = logging.getLogger(name="Retriever")
class SearchRetriever:
    """Semantic (k-NN) search over an OpenSearch index, with optional ACL filtering.

    Queries are embedded with a SentenceTransformer model and matched against
    the ``embedding`` field of ``index_name``; each hit's ``_source`` is
    turned into a ``DocumentChunk``.
    """

    def __init__(self, index_name: str = "poc_sharepoint_docs"):
        """Connect to OpenSearch and load the query-embedding model.

        Args:
            index_name: Name of the OpenSearch index to search.
        """
        self.index_name = index_name

        host = settings.opensearch_host
        # "opensearch" is presumably the docker-compose service name, which only
        # resolves inside the container network; fall back to localhost unless
        # ENV=docker. TODO confirm against the deployment setup.
        if host == "opensearch" and os.environ.get("ENV") != "docker":
            host = "localhost"

        # NOTE(review): basic-auth credentials travel over plain HTTP
        # (use_ssl=False) and certificates are not verified; only acceptable
        # for a local proof of concept.
        self.client = OpenSearch(
            hosts=[{"host": host, "port": settings.opensearch_port}],
            http_auth=(settings.opensearch_user, settings.opensearch_pass),
            use_ssl=False,
            verify_certs=False,
            connection_class=RequestsHttpConnection,
        )

        logger.info("Loading Embedding Model for Retriever...")
        # Imported here rather than at module level, so sentence_transformers
        # is only loaded when a retriever is actually constructed.
        from sentence_transformers import SentenceTransformer

        self.embedder = SentenceTransformer("keepitreal/vietnamese-sbert")

    @staticmethod
    def _build_query(
        query_vector: List[float],
        top_k: int,
        user_email: Optional[str] = None,
        is_admin: bool = False,
    ) -> dict:
        """Build the OpenSearch request body for a k-NN search.

        Admins, and calls without ``user_email``, get a plain k-NN query for
        ``top_k`` neighbours. Other users get a k-NN query over ``2 * top_k``
        candidates, restricted to documents whose ``permissions`` contain
        ``"*"`` or the lower-cased ``user_email``.

        The restriction is a ``filter`` clause, so it narrows the result set
        without adding term-match scores to the k-NN similarity (which would
        otherwise boost user-specific documents over more relevant public
        ones). OpenSearch applies such a bool filter after selecting the k-NN
        candidates, so fewer than ``top_k`` hits may come back.

        Args:
            query_vector: Embedding of the query text.
            top_k: Number of results requested (``size``).
            user_email: Email whose permissions restrict the results.
            is_admin: True bypasses the permission filter.

        Returns:
            The request body to pass as ``body=`` to ``OpenSearch.search``.
        """
        # NOTE(review): a missing/empty user_email disables ACL filtering
        # entirely (fail-open). Confirm callers never get here with an
        # unauthenticated user, or restrict that case to public ("*") docs.
        if is_admin or not user_email:
            return {
                "size": top_k,
                "query": {"knn": {"embedding": {"vector": query_vector, "k": top_k}}},
            }

        # Fetch extra candidates to compensate for those the ACL filter drops.
        knn_clause = {"knn": {"embedding": {"vector": query_vector, "k": top_k * 2}}}
        # `terms` matches if `permissions` contains ANY listed value, which is
        # equivalent to the two term clauses with minimum_should_match=1.
        acl_clause = {"terms": {"permissions": ["*", user_email.lower()]}}
        return {
            "size": top_k,
            "query": {"bool": {"must": [knn_clause], "filter": [acl_clause]}},
        }

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        user_email: Optional[str] = None,
        is_admin: bool = False,
    ) -> List["DocumentChunk"]:
        """Semantic search with ACL filtering.

        Args:
            query: The user's question.
            top_k: Maximum number of results; values <= 0 return ``[]``
                without querying OpenSearch.
            user_email: User email used to filter by document permissions.
                When omitted (and not admin) no permission filter is applied.
            is_admin: True bypasses the ACL and sees every document.

        Returns:
            Matching chunks in the order OpenSearch returned them. Returns
            ``[]`` if the search request itself fails; individual hits whose
            ``_source`` cannot be turned into a ``DocumentChunk`` are skipped.
        """
        if top_k <= 0:
            # OpenSearch rejects k < 1; nothing to fetch.
            return []

        logger.info(
            "Search: '%s' (user=%s, admin=%s)",
            query[:80],
            user_email or "none",
            is_admin,
        )
        query_vector = self.embedder.encode(query).tolist()
        search_query = self._build_query(query_vector, top_k, user_email, is_admin)

        try:
            response = self.client.search(index=self.index_name, body=search_query)
        except Exception as e:
            # Best-effort boundary: a failed search yields no results, but the
            # traceback is kept in the log for diagnosis.
            logger.exception("OpenSearch query error: %s", e)
            return []

        results = []
        for hit in response.get("hits", {}).get("hits", []):
            try:
                results.append(DocumentChunk(**hit["_source"]))
            except (KeyError, TypeError, ValueError):
                # One malformed document must not discard the whole result set.
                logger.warning(
                    "Skipping hit %s: cannot build DocumentChunk",
                    hit.get("_id"),
                    exc_info=True,
                )

        logger.info("Found %d chunks", len(results))
        return results