103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
import logging
|
|
import os
|
|
from typing import List, Optional
|
|
from opensearchpy import OpenSearch, RequestsHttpConnection
|
|
from core.config import settings
|
|
from core.models import DocumentChunk
|
|
|
|
# Module-level logger used by SearchRetriever (named "Retriever", not __name__).
logger = logging.getLogger("Retriever")
|
|
|
|
|
|
class SearchRetriever:
    """Semantic (k-NN) search over an OpenSearch index with per-user ACL filtering.

    Query text is embedded with a SentenceTransformer model and matched against
    the ``embedding`` field of ``index_name``; each hit's ``_source`` is turned
    into a ``DocumentChunk``.
    """

    def __init__(self, index_name: str = "poc_sharepoint_docs"):
        """Connect to OpenSearch and load the query embedding model.

        Args:
            index_name: Name of the OpenSearch index to search.
        """
        self.index_name = index_name

        # The configured host "opensearch" presumably only resolves inside the
        # Docker network; outside it (ENV != "docker") fall back to localhost.
        host = settings.opensearch_host
        if host == "opensearch" and os.environ.get("ENV") != "docker":
            host = "localhost"

        # NOTE(review): basic-auth credentials are sent over plain HTTP
        # (use_ssl=False) and certificates are not verified; acceptable only
        # for local/PoC deployments.
        self.client = OpenSearch(
            hosts=[{'host': host, 'port': settings.opensearch_port}],
            http_auth=(settings.opensearch_user, settings.opensearch_pass),
            use_ssl=False,
            verify_certs=False,
            connection_class=RequestsHttpConnection,
        )

        logger.info("Loading Embedding Model for Retriever...")
        # Imported here rather than at module top, presumably to defer the
        # heavy dependency until a retriever is actually constructed.
        from sentence_transformers import SentenceTransformer

        # NOTE(review): queries must be embedded with the same model that
        # produced the indexed `embedding` vectors; confirm against the indexer.
        self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert')

    @staticmethod
    def _build_search_query(
        query_vector: List[float],
        top_k: int,
        user_email: Optional[str],
        is_admin: bool,
    ) -> dict:
        """Build the OpenSearch request body: a k-NN query, ACL-filtered unless bypassed.

        Args:
            query_vector: Embedding of the query text.
            top_k: Number of hits to return (``size``).
            user_email: Email whose permissions restrict the results.
            is_admin: When True, no ACL filter is applied.

        Returns:
            The request body dict passed to ``OpenSearch.search``.
        """
        # Admin, or no user_email -> no filter.
        # NOTE(review): a missing/empty user_email bypasses the ACL exactly like
        # is_admin=True. Confirm callers never pass None for end users; a
        # fail-closed alternative would restrict such queries to permissions "*".
        if is_admin or not user_email:
            return {
                "size": top_k,
                "query": {
                    "knn": {"embedding": {"vector": query_vector, "k": top_k}},
                },
            }

        # Regular user: keep only documents that are public ("*") or shared with
        # this user. The ACL check is in `filter` context so it does not add a
        # term-match score on top of the vector similarity; as `should` clauses
        # it did, letting permission matches outweigh semantic relevance.
        # The k-NN clause presumably collects its k candidates before the filter
        # runs (post-filtering), hence k is doubled; fewer than top_k hits may
        # still come back.
        return {
            "size": top_k,
            "query": {
                "bool": {
                    "must": [
                        {"knn": {"embedding": {"vector": query_vector, "k": top_k * 2}}},
                    ],
                    "filter": [
                        {"terms": {"permissions": ["*", user_email.lower()]}},
                    ],
                }
            },
        }

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        user_email: Optional[str] = None,
        is_admin: bool = False,
    ) -> List[DocumentChunk]:
        """Run a semantic search with ACL filtering.

        Args:
            query: The user's question.
            top_k: Maximum number of results.
            user_email: The user's email, used to filter by permissions.
            is_admin: True bypasses the ACL and sees every document.

        Returns:
            Up to ``top_k`` DocumentChunk objects built from the hits' ``_source``.
            An empty list when ``top_k`` is not positive, or when the search or
            the conversion of a hit raises (the error is logged with traceback).
        """
        logger.info(
            "Search: '%s' (user=%s, admin=%s)",
            query[:80], user_email or 'none', is_admin,
        )

        # A non-positive top_k can never yield results; skip the embedding and
        # the request OpenSearch would reject or answer empty.
        if top_k <= 0:
            return []

        query_vector = self.embedder.encode(query).tolist()
        search_query = self._build_search_query(query_vector, top_k, user_email, is_admin)

        # Best-effort: any failure (connection, query, or building a
        # DocumentChunk from a hit) yields an empty list instead of raising.
        try:
            response = self.client.search(index=self.index_name, body=search_query)
            hits = response.get("hits", {}).get("hits", [])
            # NOTE(review): the whole _source (including the stored embedding,
            # if it is kept in _source) is passed to DocumentChunk; consider a
            # `_source` excludes list if DocumentChunk does not need it.
            results = [DocumentChunk(**hit["_source"]) for hit in hits]
        except Exception as e:
            # logger.exception also records the traceback, unlike logger.error.
            logger.exception("OpenSearch query error: %s", e)
            return []

        logger.info("Found %d chunks", len(results))
        return results
|