71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import logging
|
|
from typing import List
|
|
from opensearchpy import OpenSearch, RequestsHttpConnection
|
|
from core.config import settings
|
|
from core.models import DocumentChunk
|
|
|
|
logger = logging.getLogger("Retriever")
|
|
|
|
class SearchRetriever:
|
|
def __init__(self, index_name: str = "poc_sharepoint_docs"):
|
|
self.index_name = index_name
|
|
|
|
# Kết nối OpenSearch
|
|
self.client = OpenSearch(
|
|
hosts=[{'host': settings.opensearch_host, 'port': settings.opensearch_port}],
|
|
http_auth=(settings.opensearch_user, settings.opensearch_pass),
|
|
use_ssl=False,
|
|
verify_certs=False,
|
|
connection_class=RequestsHttpConnection
|
|
)
|
|
|
|
# Load Local Embedding Model (để biến câu hỏi thành vector cùng không gian với dữ liệu)
|
|
logger.info("Đang nạp Embedding Model cho Retriever...")
|
|
from sentence_transformers import SentenceTransformer
|
|
self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert')
|
|
|
|
def retrieve(self, query: str, top_k: int = 5) -> List[DocumentChunk]:
|
|
"""
|
|
Tìm kiếm ngữ nghĩa (Semantic Search) dựa trên Vector k-NN
|
|
"""
|
|
logger.info(f"Đang tìm kiếm ngữ nghĩa cho câu hỏi: '{query}'")
|
|
|
|
# 1. Chuyển câu hỏi thành Vector
|
|
query_vector = self.embedder.encode(query).tolist()
|
|
|
|
# 2. Xây dựng k-NN Query cho OpenSearch
|
|
# Ta có thể kết hợp Hybrid Search (Vector + Text) ở đây nếu muốn
|
|
search_query = {
|
|
"size": top_k,
|
|
"query": {
|
|
"knn": {
|
|
"embedding": {
|
|
"vector": query_vector,
|
|
"k": top_k
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
try:
|
|
response = self.client.search(
|
|
index=self.index_name,
|
|
body=search_query
|
|
)
|
|
|
|
hits = response.get("hits", {}).get("hits", [])
|
|
results = []
|
|
|
|
for hit in hits:
|
|
source = hit["_source"]
|
|
# Chuyển từ JSON sang DocumentChunk model
|
|
chunk = DocumentChunk(**source)
|
|
results.append(chunk)
|
|
|
|
logger.info(f"Tìm thấy {len(results)} đoạn văn phù hợp nhất.")
|
|
return results
|
|
|
|
except Exception as e:
|
|
logger.error(f"Lỗi khi truy vấn OpenSearch: {e}")
|
|
return []
|