refactor: remove unused abstract method and improve retrievers

This commit is contained in:
Siddhant Rai
2025-08-20 22:25:31 +05:30
parent 6f47aa802b
commit bd73fa9ae7
3 changed files with 121 additions and 40 deletions

View File

@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
def __init__(self):
pass
@abstractmethod
def gen(self, *args, **kwargs):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass

View File

@@ -1,4 +1,5 @@
import logging
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
@@ -20,6 +21,7 @@ class ClassicRAG(BaseRetriever):
api_key=settings.API_KEY,
decoded_token=None,
):
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
self.original_question = source.get("question", "")
self.chat_history = chat_history if chat_history is not None else []
self.prompt = prompt
@@ -47,25 +49,46 @@ class ClassicRAG(BaseRetriever):
if "active_docs" in source:
if isinstance(source["active_docs"], list):
self.vectorstores = source["active_docs"]
elif isinstance(source["active_docs"], str) and "," in source["active_docs"]:
# ✅ split multiple IDs from comma string
self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()]
elif (
isinstance(source["active_docs"], str) and "," in source["active_docs"]
):
self.vectorstores = [
doc_id.strip()
for doc_id in source["active_docs"].split(",")
if doc_id.strip()
]
else:
self.vectorstores = [source["active_docs"]]
else:
self.vectorstores = []
self.vectorstore = None
self.question = self._rephrase_query()
self.decoded_token = decoded_token
self._validate_vectorstore_config()
def _validate_vectorstore_config(self):
"""Validate vectorstore IDs and remove any empty/invalid entries"""
if not self.vectorstores:
logging.warning("No vectorstores configured for retrieval")
return
invalid_ids = [
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
]
if invalid_ids:
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
self.vectorstores = [
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
]
def _rephrase_query(self):
"""Rephrase user query with chat history context for better retrieval"""
if (
not self.original_question
or not self.chat_history
or self.chat_history == []
or self.chunks == 0
or self.vectorstore is None
or not self.vectorstores
):
return self.original_question
@@ -90,41 +113,62 @@ class ClassicRAG(BaseRetriever):
return self.original_question
def _get_data(self):
"""Retrieve relevant documents from configured vectorstores"""
if self.chunks == 0 or not self.vectorstores:
return []
all_docs = []
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
for vectorstore in self.vectorstores:
if vectorstore:
for vectorstore_id in self.vectorstores:
if vectorstore_id:
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=chunks_per_source)
for i in docs_temp:
all_docs.append({
"title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1],
"text": i.page_content,
"source": i.metadata.get("source") or vectorstore,
})
for doc in docs_temp:
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
title = metadata.get(
"title", metadata.get("post_title", page_content)
)
if isinstance(title, str):
title = title.split("/")[-1]
else:
title = str(title).split("/")[-1]
all_docs.append(
{
"title": title,
"text": page_content,
"source": metadata.get("source") or vectorstore_id,
}
)
except Exception as e:
logging.error(f"Error searching vectorstore {vectorstore}: {e}")
logging.error(
f"Error searching vectorstore {vectorstore_id}: {e}",
exc_info=True,
)
continue
return all_docs
def gen():
pass
def search(self, query: str = ""):
"""Search for documents using optional query override"""
if query:
self.original_question = query
self.question = self._rephrase_query()
return self._get_data()
def get_params(self):
"""Return current retriever configuration parameters"""
return {
"question": self.original_question,
"rephrased_question": self.question,

View File

@@ -1,20 +1,28 @@
from abc import ABC, abstractmethod
import os
from sentence_transformers import SentenceTransformer
from abc import ABC, abstractmethod
from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs
)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
return self.model.encode(query).tolist()
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
@@ -24,15 +32,14 @@ class EmbeddingsWrapper:
raise ValueError("Input must be a string or a list of strings")
class EmbeddingsSingleton:
_instances = {}
@staticmethod
def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
embeddings_name, *args, **kwargs
EmbeddingsSingleton._instances[embeddings_name] = (
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
)
return EmbeddingsSingleton._instances[embeddings_name]
@@ -40,9 +47,15 @@ class EmbeddingsSingleton:
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
"hkunlp/instructor-large"
),
}
if embeddings_name in embeddings_factory:
@@ -50,34 +63,63 @@ class EmbeddingsSingleton:
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
pass
@abstractmethod
def search(self, *args, **kwargs):
"""Search for similar documents/chunks in the vectorstore"""
pass
@abstractmethod
def add_texts(self, texts, metadatas=None, *args, **kwargs):
"""Add texts with their embeddings to the vectorstore"""
pass
def delete_index(self, *args, **kwargs):
"""Delete the entire index/collection"""
pass
def save_local(self, *args, **kwargs):
"""Save vectorstore to local storage"""
pass
def get_chunks(self, *args, **kwargs):
"""Get all chunks from the vectorstore"""
pass
def add_chunk(self, text, metadata=None, *args, **kwargs):
"""Add a single chunk to the vectorstore"""
pass
def delete_chunk(self, chunk_id, *args, **kwargs):
"""Delete a specific chunk from the vectorstore"""
pass
def is_azure_configured(self):
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def _get_embeddings(self, embeddings_name, embeddings_key=None):
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
openai_api_key=embeddings_key
embeddings_name, openai_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./models/all-mpnet-base-v2"):
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name = "./models/all-mpnet-base-v2",
embeddings_name="./models/all-mpnet-base-v2",
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance