From bd73fa9ae716e9c0a31604aea171591752ffaaef Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 20 Aug 2025 22:25:31 +0530 Subject: [PATCH] refactor: remove unused abstract method and improve retrievers --- application/retriever/base.py | 4 -- application/retriever/classic_rag.py | 80 +++++++++++++++++++++------- application/vectorstore/base.py | 77 +++++++++++++++++++------- 3 files changed, 121 insertions(+), 40 deletions(-) diff --git a/application/retriever/base.py b/application/retriever/base.py index fd99dbdd..36ac2e93 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -5,10 +5,6 @@ class BaseRetriever(ABC): def __init__(self): pass - @abstractmethod - def gen(self, *args, **kwargs): - pass - @abstractmethod def search(self, *args, **kwargs): pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index a9e3dee7..b558c8f0 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,5 @@ import logging + from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever @@ -20,6 +21,7 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): + """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration""" self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt @@ -47,25 +49,46 @@ class ClassicRAG(BaseRetriever): if "active_docs" in source: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] - elif isinstance(source["active_docs"], str) and "," in source["active_docs"]: - # ✅ split multiple IDs from comma string - self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()] + elif ( + isinstance(source["active_docs"], str) and "," in source["active_docs"] + ): + self.vectorstores = [ + doc_id.strip() + for doc_id in source["active_docs"].split(",") + if doc_id.strip() + ] else: self.vectorstores = [source["active_docs"]] else: self.vectorstores = [] - self.vectorstore = None self.question = self._rephrase_query() self.decoded_token = decoded_token + self._validate_vectorstore_config() + + def _validate_vectorstore_config(self): + """Validate vectorstore IDs and remove any empty/invalid entries""" + if not self.vectorstores: + logging.warning("No vectorstores configured for retrieval") + return + + invalid_ids = [ + vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip() + ] + if invalid_ids: + logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}") + self.vectorstores = [ + vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip() + ] def _rephrase_query(self): + """Rephrase user query with chat history context for better retrieval""" if ( not self.original_question or not self.chat_history or self.chat_history == [] or self.chunks == 0 - or self.vectorstore is None + or not self.vectorstores ): return self.original_question @@ -90,41 +113,62 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): + """Retrieve relevant documents from configured vectorstores""" if self.chunks == 0 or not self.vectorstores: return [] all_docs = [] chunks_per_source = max(1, self.chunks // len(self.vectorstores)) - for vectorstore in self.vectorstores: - if vectorstore: + for vectorstore_id in self.vectorstores: + if vectorstore_id: try: docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY + settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY ) docs_temp = docsearch.search(self.question, k=chunks_per_source) - for i in docs_temp: - all_docs.append({ - "title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1], - "text": i.page_content, - "source": i.metadata.get("source") or vectorstore, - }) + + for doc in docs_temp: + if hasattr(doc, "page_content") and hasattr(doc, "metadata"): + page_content = doc.page_content + metadata = doc.metadata + else: + page_content = doc.get("text", doc.get("page_content", "")) + metadata = doc.get("metadata", {}) + + title = metadata.get( + "title", metadata.get("post_title", page_content) + ) + if isinstance(title, str): + title = title.split("/")[-1] + else: + title = str(title).split("/")[-1] + + all_docs.append( + { + "title": title, + "text": page_content, + "source": metadata.get("source") or vectorstore_id, + } + ) except Exception as e: - logging.error(f"Error searching vectorstore {vectorstore}: {e}") + logging.error( + f"Error searching vectorstore {vectorstore_id}: {e}", + exc_info=True, + ) continue return all_docs - def gen(): - pass - def search(self, query: str = ""): + """Search for documents using optional query override""" if query: self.original_question = query self.question = self._rephrase_query() return self._get_data() def get_params(self): + """Return current retriever configuration parameters""" return { "question": self.original_question, "rephrased_question": self.question, diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index a6b206c9..ea4885cd 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -1,20 +1,28 @@ -from abc import ABC, abstractmethod import os -from sentence_transformers import SentenceTransformer +from abc import ABC, abstractmethod + from langchain_openai import OpenAIEmbeddings +from sentence_transformers import SentenceTransformer + from application.core.settings import settings + class EmbeddingsWrapper: def __init__(self, model_name, *args, **kwargs): - self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs) + self.model = SentenceTransformer( + model_name, + config_kwargs={"allow_dangerous_deserialization": True}, + *args, + **kwargs + ) self.dimension = self.model.get_sentence_embedding_dimension() def embed_query(self, query: str): return self.model.encode(query).tolist() - + def embed_documents(self, documents: list): return self.model.encode(documents).tolist() - + def __call__(self, text): if isinstance(text, str): return self.embed_query(text) @@ -24,15 +32,14 @@ class EmbeddingsWrapper: raise ValueError("Input must be a string or a list of strings") - class EmbeddingsSingleton: _instances = {} @staticmethod def get_instance(embeddings_name, *args, **kwargs): if embeddings_name not in EmbeddingsSingleton._instances: - EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance( - embeddings_name, *args, **kwargs + EmbeddingsSingleton._instances[embeddings_name] = ( + EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs) ) return EmbeddingsSingleton._instances[embeddings_name] @@ -40,9 +47,15 @@ class EmbeddingsSingleton: def _create_instance(embeddings_name, *args, **kwargs): embeddings_factory = { "openai_text-embedding-ada-002": OpenAIEmbeddings, - "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"), + "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper( + "hkunlp/instructor-large" + ), } if embeddings_name in embeddings_factory: @@ -50,34 +63,63 @@ class EmbeddingsSingleton: else: return EmbeddingsWrapper(embeddings_name, *args, **kwargs) + class BaseVectorStore(ABC): def __init__(self): pass @abstractmethod def search(self, *args, **kwargs): + """Search for similar documents/chunks in the vectorstore""" + pass + + @abstractmethod + def add_texts(self, texts, metadatas=None, *args, **kwargs): + """Add texts with their embeddings to the vectorstore""" + pass + + def delete_index(self, *args, **kwargs): + """Delete the entire index/collection""" + pass + + def save_local(self, *args, **kwargs): + """Save vectorstore to local storage""" + pass + + def get_chunks(self, *args, **kwargs): + """Get all chunks from the vectorstore""" + pass + + def add_chunk(self, text, metadata=None, *args, **kwargs): + """Add a single chunk to the vectorstore""" + pass + + def delete_chunk(self, chunk_id, *args, **kwargs): + """Delete a specific chunk from the vectorstore""" pass def is_azure_configured(self): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) def _get_embeddings(self, embeddings_name, embeddings_key=None): if embeddings_name == "openai_text-embedding-ada-002": if self.is_azure_configured(): os.environ["OPENAI_API_TYPE"] = "azure" embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME + embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME ) else: embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - openai_api_key=embeddings_key + embeddings_name, openai_api_key=embeddings_key ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": if os.path.exists("./models/all-mpnet-base-v2"): embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name = "./models/all-mpnet-base-v2", + embeddings_name="./models/all-mpnet-base-v2", ) else: embedding_instance = EmbeddingsSingleton.get_instance( @@ -87,4 +129,3 @@ class BaseVectorStore(ABC): embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) return embedding_instance -