From 6f47aa802b245a1bf0955234d7b0791f8939f59d Mon Sep 17 00:00:00 2001 From: Ankit Matth Date: Sat, 16 Aug 2025 15:19:19 +0530 Subject: [PATCH 1/6] added support for multi select sources --- application/retriever/classic_rag.py | 62 +++++++++++-------- frontend/src/components/MessageInput.tsx | 10 +-- frontend/src/components/SourcesPopup.tsx | 23 ++++--- .../src/conversation/conversationHandlers.ts | 46 ++++++++++---- frontend/src/hooks/useDefaultDocument.ts | 4 +- frontend/src/preferences/preferenceApi.ts | 33 +++++----- frontend/src/preferences/preferenceSlice.ts | 6 +- 7 files changed, 115 insertions(+), 69 deletions(-) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 9416b4f7..a9e3dee7 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -20,7 +20,7 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): - self.original_question = "" + self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt self.chunks = chunks @@ -44,7 +44,18 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - self.vectorstore = source["active_docs"] if "active_docs" in source else None + if "active_docs" in source: + if isinstance(source["active_docs"], list): + self.vectorstores = source["active_docs"] + elif isinstance(source["active_docs"], str) and "," in source["active_docs"]: + # ✅ split multiple IDs from comma string + self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()] + else: + self.vectorstores = [source["active_docs"]] + else: + self.vectorstores = [] + + self.vectorstore = None self.question = self._rephrase_query() self.decoded_token = decoded_token @@ -79,29 +90,30 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): - if self.chunks == 0 or self.vectorstore is None: - docs = [] - else: - docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY - ) - docs_temp = docsearch.search(self.question, k=self.chunks) - docs = [ - { - "title": i.metadata.get( - "title", i.metadata.get("post_title", i.page_content) - ).split("/")[-1], - "text": i.page_content, - "source": ( - i.metadata.get("source") - if i.metadata.get("source") - else "local" - ), - } - for i in docs_temp - ] + if self.chunks == 0 or not self.vectorstores: + return [] - return docs + all_docs = [] + chunks_per_source = max(1, self.chunks // len(self.vectorstores)) + + for vectorstore in self.vectorstores: + if vectorstore: + try: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY + ) + docs_temp = docsearch.search(self.question, k=chunks_per_source) + for i in docs_temp: + all_docs.append({ + "title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1], + "text": i.page_content, + "source": i.metadata.get("source") or vectorstore, + }) + except Exception as e: + logging.error(f"Error searching vectorstore {vectorstore}: {e}") + continue + + return all_docs def gen(): pass @@ -116,7 +128,7 @@ class ClassicRAG(BaseRetriever): return { "question": self.original_question, "rephrased_question": self.question, - "source": self.vectorstore, + "sources": self.vectorstores, "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, diff --git a/frontend/src/components/MessageInput.tsx b/frontend/src/components/MessageInput.tsx index d9bcea3e..6ae5678e 100644 --- a/frontend/src/components/MessageInput.tsx +++ b/frontend/src/components/MessageInput.tsx @@ -368,8 +368,8 @@ export default function MessageInput({ className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]" onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)} title={ - selectedDocs - ? selectedDocs.name + selectedDocs && selectedDocs.length > 0 + ? selectedDocs.map(doc => doc.name).join(', ') : t('conversation.sources.title') } > @@ -379,8 +379,10 @@ export default function MessageInput({ className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4" /> - {selectedDocs - ? selectedDocs.name + {selectedDocs && selectedDocs.length > 0 + ? selectedDocs.length === 1 + ? selectedDocs[0].name + : `${selectedDocs.length} sources selected` : t('conversation.sources.title')} {!isTouch && ( diff --git a/frontend/src/components/SourcesPopup.tsx b/frontend/src/components/SourcesPopup.tsx index 906f75cd..c3a61e01 100644 --- a/frontend/src/components/SourcesPopup.tsx +++ b/frontend/src/components/SourcesPopup.tsx @@ -149,9 +149,10 @@ export default function SourcesPopup({ if (option.model === embeddingsName) { const isSelected = selectedDocs && - (option.id - ? selectedDocs.id === option.id - : selectedDocs.date === option.date); + Array.isArray(selectedDocs) && selectedDocs.length > 0 && + selectedDocs.some(doc => + option.id ? doc.id === option.id : doc.date === option.date + ); return (
{ if (isSelected) { - dispatch(setSelectedDocs(null)); - handlePostDocumentSelect(null); + const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) + ? selectedDocs.filter(doc => + option.id ? doc.id !== option.id : doc.date !== option.date + ) + : []; + dispatch(setSelectedDocs(updatedDocs.length > 0 ? updatedDocs : null)); + handlePostDocumentSelect(updatedDocs.length > 0 ? updatedDocs : null); } else { - dispatch(setSelectedDocs(option)); - handlePostDocumentSelect(option); + const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) + ? [...selectedDocs, option] + : [option]; + dispatch(setSelectedDocs(updatedDocs)); + handlePostDocumentSelect(updatedDocs); } }} > diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index fb6e1b59..ae60b070 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -7,7 +7,7 @@ export function handleFetchAnswer( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -52,10 +52,17 @@ export function handleFetchAnswer( payload.attachments = attachments; } - if (selectedDocs && 'id' in selectedDocs) { - payload.active_docs = selectedDocs.id as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } } - payload.retriever = selectedDocs?.retriever as string; return conversationService .answer(payload, token, signal) .then((response) => { @@ -84,7 +91,7 @@ export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -112,10 +119,17 @@ export function handleFetchAnswerSteaming( payload.attachments = attachments; } - if (selectedDocs && 'id' in selectedDocs) { - payload.active_docs = selectedDocs.id as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } } - payload.retriever = selectedDocs?.retriever as string; return new Promise((resolve, reject) => { conversationService @@ -171,7 +185,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversation_id: string | null, chunks: string, token_limit: number, @@ -183,9 +197,17 @@ export function handleSearch( token_limit: token_limit, isNoneDoc: selectedDocs === null, }; - if (selectedDocs && 'id' in selectedDocs) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs?.retriever as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } + } return conversationService .search(payload, token) .then((response) => response.json()) diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts index a2642dc5..004e4bb1 100644 --- a/frontend/src/hooks/useDefaultDocument.ts +++ b/frontend/src/hooks/useDefaultDocument.ts @@ -18,11 +18,11 @@ export default function useDefaultDocument() { const fetchDocs = () => { getDocs(token).then((data) => { dispatch(setSourceDocs(data)); - if (!selectedDoc) + if (!selectedDoc || (Array.isArray(selectedDoc) && selectedDoc.length === 0)) Array.isArray(data) && data?.forEach((doc: Doc) => { if (doc.model && doc.name === 'default') { - dispatch(setSelectedDocs(doc)); + dispatch(setSelectedDocs([doc])); } }); }); diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts index 7fb907b3..40dc4bcc 100644 --- a/frontend/src/preferences/preferenceApi.ts +++ b/frontend/src/preferences/preferenceApi.ts @@ -90,9 +90,9 @@ export function getLocalApiKey(): string | null { return key; } -export function getLocalRecentDocs(): string | null { - const doc = localStorage.getItem('DocsGPTRecentDocs'); - return doc; +export function getLocalRecentDocs(): Doc[] | null { + const docs = localStorage.getItem('DocsGPTRecentDocs'); + return docs ? JSON.parse(docs) as Doc[] : null; } export function getLocalPrompt(): string | null { @@ -108,19 +108,20 @@ export function setLocalPrompt(prompt: string): void { localStorage.setItem('DocsGPTPrompt', prompt); } -export function setLocalRecentDocs(doc: Doc | null): void { - localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(doc)); +export function setLocalRecentDocs(docs: Doc[] | null): void { + if (docs && docs.length > 0) { + localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(docs)); - let docPath = 'default'; - if (doc?.type === 'local') { - docPath = 'local' + '/' + doc.name + '/'; + docs.forEach((doc) => { + let docPath = 'default'; + if (doc.type === 'local') { + docPath = 'local' + '/' + doc.name + '/'; + } + userService + .checkDocs({ docs: docPath }, null) + .then((response) => response.json()); + }); + } else { + localStorage.removeItem('DocsGPTRecentDocs'); } - userService - .checkDocs( - { - docs: docPath, - }, - null, - ) - .then((response) => response.json()); } diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts index a0825039..6abbef4d 100644 --- a/frontend/src/preferences/preferenceSlice.ts +++ b/frontend/src/preferences/preferenceSlice.ts @@ -15,7 +15,7 @@ export interface Preference { prompt: { name: string; id: string; type: string }; chunks: string; token_limit: number; - selectedDocs: Doc | null; + selectedDocs: Doc[] | null; sourceDocs: Doc[] | null; conversations: { data: { name: string; id: string }[] | null; @@ -34,7 +34,7 @@ const initialState: Preference = { prompt: { name: 'default', id: 'default', type: 'public' }, chunks: '2', token_limit: 2000, - selectedDocs: { + selectedDocs: [{ id: 'default', name: 'default', type: 'remote', @@ -42,7 +42,7 @@ const initialState: Preference = { docLink: 'default', model: 'openai_text-embedding-ada-002', retriever: 'classic', - } as Doc, + }] as Doc[], sourceDocs: null, conversations: { data: null, From bd73fa9ae716e9c0a31604aea171591752ffaaef Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 20 Aug 2025 22:25:31 +0530 Subject: [PATCH 2/6] refactor: remove unused abstract method and improve retrievers --- application/retriever/base.py | 4 -- application/retriever/classic_rag.py | 80 +++++++++++++++++++++------- application/vectorstore/base.py | 77 +++++++++++++++++++------- 3 files changed, 121 insertions(+), 40 deletions(-) diff --git a/application/retriever/base.py b/application/retriever/base.py index fd99dbdd..36ac2e93 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -5,10 +5,6 @@ class BaseRetriever(ABC): def __init__(self): pass - @abstractmethod - def gen(self, *args, **kwargs): - pass - @abstractmethod def search(self, *args, **kwargs): pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index a9e3dee7..b558c8f0 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,5 @@ import logging + from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever @@ -20,6 +21,7 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): + """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration""" self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt @@ -47,25 +49,46 @@ class ClassicRAG(BaseRetriever): if "active_docs" in source: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] - elif isinstance(source["active_docs"], str) and "," in source["active_docs"]: - # ✅ split multiple IDs from comma string - self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()] + elif ( + isinstance(source["active_docs"], str) and "," in source["active_docs"] + ): + self.vectorstores = [ + doc_id.strip() + for doc_id in source["active_docs"].split(",") + if doc_id.strip() + ] else: self.vectorstores = [source["active_docs"]] else: self.vectorstores = [] - self.vectorstore = None self.question = self._rephrase_query() self.decoded_token = decoded_token + self._validate_vectorstore_config() + + def _validate_vectorstore_config(self): + """Validate vectorstore IDs and remove any empty/invalid entries""" + if not self.vectorstores: + logging.warning("No vectorstores configured for retrieval") + return + + invalid_ids = [ + vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip() + ] + if invalid_ids: + logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}") + self.vectorstores = [ + vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip() + ] def _rephrase_query(self): + """Rephrase user query with chat history context for better retrieval""" if ( not self.original_question or not self.chat_history or self.chat_history == [] or self.chunks == 0 - or self.vectorstore is None + or not self.vectorstores ): return self.original_question @@ -90,41 +113,62 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): + """Retrieve relevant documents from configured vectorstores""" if self.chunks == 0 or not self.vectorstores: return [] all_docs = [] chunks_per_source = max(1, self.chunks // len(self.vectorstores)) - for vectorstore in self.vectorstores: - if vectorstore: + for vectorstore_id in self.vectorstores: + if vectorstore_id: try: docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY + settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY ) docs_temp = docsearch.search(self.question, k=chunks_per_source) - for i in docs_temp: - all_docs.append({ - "title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1], - "text": i.page_content, - "source": i.metadata.get("source") or vectorstore, - }) + + for doc in docs_temp: + if hasattr(doc, "page_content") and hasattr(doc, "metadata"): + page_content = doc.page_content + metadata = doc.metadata + else: + page_content = doc.get("text", doc.get("page_content", "")) + metadata = doc.get("metadata", {}) + + title = metadata.get( + "title", metadata.get("post_title", page_content) + ) + if isinstance(title, str): + title = title.split("/")[-1] + else: + title = str(title).split("/")[-1] + + all_docs.append( + { + "title": title, + "text": page_content, + "source": metadata.get("source") or vectorstore_id, + } + ) except Exception as e: - logging.error(f"Error searching vectorstore {vectorstore}: {e}") + logging.error( + f"Error searching vectorstore {vectorstore_id}: {e}", + exc_info=True, + ) continue return all_docs - def gen(): - pass - def search(self, query: str = ""): + """Search for documents using optional query override""" if query: self.original_question = query self.question = self._rephrase_query() return self._get_data() def get_params(self): + """Return current retriever configuration parameters""" return { "question": self.original_question, "rephrased_question": self.question, diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index a6b206c9..ea4885cd 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -1,20 +1,28 @@ -from abc import ABC, abstractmethod import os -from sentence_transformers import SentenceTransformer +from abc import ABC, abstractmethod + from langchain_openai import OpenAIEmbeddings +from sentence_transformers import SentenceTransformer + from application.core.settings import settings + class EmbeddingsWrapper: def __init__(self, model_name, *args, **kwargs): - self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs) + self.model = SentenceTransformer( + model_name, + config_kwargs={"allow_dangerous_deserialization": True}, + *args, + **kwargs + ) self.dimension = self.model.get_sentence_embedding_dimension() def embed_query(self, query: str): return self.model.encode(query).tolist() - + def embed_documents(self, documents: list): return self.model.encode(documents).tolist() - + def __call__(self, text): if isinstance(text, str): return self.embed_query(text) @@ -24,15 +32,14 @@ class EmbeddingsWrapper: raise ValueError("Input must be a string or a list of strings") - class EmbeddingsSingleton: _instances = {} @staticmethod def get_instance(embeddings_name, *args, **kwargs): if embeddings_name not in EmbeddingsSingleton._instances: - EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance( - embeddings_name, *args, **kwargs + EmbeddingsSingleton._instances[embeddings_name] = ( + EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs) ) return EmbeddingsSingleton._instances[embeddings_name] @@ -40,9 +47,15 @@ class EmbeddingsSingleton: def _create_instance(embeddings_name, *args, **kwargs): embeddings_factory = { "openai_text-embedding-ada-002": OpenAIEmbeddings, - "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"), + "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper( + "hkunlp/instructor-large" + ), } if embeddings_name in embeddings_factory: @@ -50,34 +63,63 @@ class EmbeddingsSingleton: else: return EmbeddingsWrapper(embeddings_name, *args, **kwargs) + class BaseVectorStore(ABC): def __init__(self): pass @abstractmethod def search(self, *args, **kwargs): + """Search for similar documents/chunks in the vectorstore""" + pass + + @abstractmethod + def add_texts(self, texts, metadatas=None, *args, **kwargs): + """Add texts with their embeddings to the vectorstore""" + pass + + def delete_index(self, *args, **kwargs): + """Delete the entire index/collection""" + pass + + def save_local(self, *args, **kwargs): + """Save vectorstore to local storage""" + pass + + def get_chunks(self, *args, **kwargs): + """Get all chunks from the vectorstore""" + pass + + def add_chunk(self, text, metadata=None, *args, **kwargs): + """Add a single chunk to the vectorstore""" + pass + + def delete_chunk(self, chunk_id, *args, **kwargs): + """Delete a specific chunk from the vectorstore""" pass def is_azure_configured(self): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) def _get_embeddings(self, embeddings_name, embeddings_key=None): if embeddings_name == "openai_text-embedding-ada-002": if self.is_azure_configured(): os.environ["OPENAI_API_TYPE"] = "azure" embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME + embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME ) else: embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - openai_api_key=embeddings_key + embeddings_name, openai_api_key=embeddings_key ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": if os.path.exists("./models/all-mpnet-base-v2"): embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name = "./models/all-mpnet-base-v2", + embeddings_name="./models/all-mpnet-base-v2", ) else: embedding_instance = EmbeddingsSingleton.get_instance( @@ -87,4 +129,3 @@ class BaseVectorStore(ABC): embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) return embedding_instance - From 07d59b66406e4b5bfb193056f956abbbdb85c322 Mon Sep 17 00:00:00 2001 From: Ankit Matth Date: Sat, 23 Aug 2025 20:25:29 +0530 Subject: [PATCH 3/6] refactor: use list instead of string parsing --- application/retriever/classic_rag.py | 11 +---- frontend/src/components/SourcesPopup.tsx | 41 +++++++++++------- .../src/conversation/conversationHandlers.ts | 42 +++++++++---------- .../src/conversation/conversationModels.ts | 2 +- .../src/modals/ShareConversationModal.tsx | 14 +++---- frontend/src/preferences/preferenceSlice.ts | 19 +++++---- 6 files changed, 68 insertions(+), 61 deletions(-) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b558c8f0..82423bb5 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -46,17 +46,10 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - if "active_docs" in source: + + if "active_docs" in source and source["active_docs"] is not None: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] - elif ( - isinstance(source["active_docs"], str) and "," in source["active_docs"] - ): - self.vectorstores = [ - doc_id.strip() - for doc_id in source["active_docs"].split(",") - if doc_id.strip() - ] else: self.vectorstores = [source["active_docs"]] else: diff --git a/frontend/src/components/SourcesPopup.tsx b/frontend/src/components/SourcesPopup.tsx index c3a61e01..f13ee25a 100644 --- a/frontend/src/components/SourcesPopup.tsx +++ b/frontend/src/components/SourcesPopup.tsx @@ -17,7 +17,7 @@ type SourcesPopupProps = { isOpen: boolean; onClose: () => void; anchorRef: React.RefObject; - handlePostDocumentSelect: (doc: Doc | null) => void; + handlePostDocumentSelect: (doc: Doc[] | null) => void; setUploadModalState: React.Dispatch>; }; @@ -149,9 +149,12 @@ export default function SourcesPopup({ if (option.model === embeddingsName) { const isSelected = selectedDocs && - Array.isArray(selectedDocs) && selectedDocs.length > 0 && - selectedDocs.some(doc => - option.id ? doc.id === option.id : doc.date === option.date + Array.isArray(selectedDocs) && + selectedDocs.length > 0 && + selectedDocs.some((doc) => + option.id + ? doc.id === option.id + : doc.date === option.date, ); return ( @@ -160,17 +163,27 @@ export default function SourcesPopup({ className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]" onClick={() => { if (isSelected) { - const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) - ? selectedDocs.filter(doc => - option.id ? doc.id !== option.id : doc.date !== option.date - ) - : []; - dispatch(setSelectedDocs(updatedDocs.length > 0 ? updatedDocs : null)); - handlePostDocumentSelect(updatedDocs.length > 0 ? updatedDocs : null); + const updatedDocs = + selectedDocs && Array.isArray(selectedDocs) + ? selectedDocs.filter((doc) => + option.id + ? doc.id !== option.id + : doc.date !== option.date, + ) + : []; + dispatch( + setSelectedDocs( + updatedDocs.length > 0 ? updatedDocs : null, + ), + ); + handlePostDocumentSelect( + updatedDocs.length > 0 ? updatedDocs : null, + ); } else { - const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) - ? [...selectedDocs, option] - : [option]; + const updatedDocs = + selectedDocs && Array.isArray(selectedDocs) + ? [...selectedDocs, option] + : [option]; dispatch(setSelectedDocs(updatedDocs)); handlePostDocumentSelect(updatedDocs); } diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index ae60b070..63557924 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -7,7 +7,7 @@ export function handleFetchAnswer( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | Doc[] | null, + selectedDocs: Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -52,15 +52,15 @@ export function handleFetchAnswer( payload.attachments = attachments; } - if (selectedDocs) { - if (Array.isArray(selectedDocs)) { + if (selectedDocs && Array.isArray(selectedDocs)) { + if (selectedDocs.length > 1) { // Handle multiple documents - payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.active_docs = selectedDocs.map((doc) => doc.id!); payload.retriever = selectedDocs[0]?.retriever as string; - } else if ('id' in selectedDocs) { + } else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) { // Handle single document (backward compatibility) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs.retriever as string; + payload.active_docs = selectedDocs[0].id as string; + payload.retriever = selectedDocs[0].retriever as string; } } return conversationService @@ -91,7 +91,7 @@ export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | Doc[] | null, + selectedDocs: Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -119,15 +119,15 @@ export function handleFetchAnswerSteaming( payload.attachments = attachments; } - if (selectedDocs) { - if (Array.isArray(selectedDocs)) { + if (selectedDocs && Array.isArray(selectedDocs)) { + if (selectedDocs.length > 1) { // Handle multiple documents - payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.active_docs = selectedDocs.map((doc) => doc.id!); payload.retriever = selectedDocs[0]?.retriever as string; - } else if ('id' in selectedDocs) { + } else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) { // Handle single document (backward compatibility) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs.retriever as string; + payload.active_docs = selectedDocs[0].id as string; + payload.retriever = selectedDocs[0].retriever as string; } } @@ -185,7 +185,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, token: string | null, - selectedDocs: Doc | Doc[] | null, + selectedDocs: Doc[] | null, conversation_id: string | null, chunks: string, token_limit: number, @@ -197,15 +197,15 @@ export function handleSearch( token_limit: token_limit, isNoneDoc: selectedDocs === null, }; - if (selectedDocs) { - if (Array.isArray(selectedDocs)) { + if (selectedDocs && Array.isArray(selectedDocs)) { + if (selectedDocs.length > 1) { // Handle multiple documents - payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.active_docs = selectedDocs.map((doc) => doc.id!); payload.retriever = selectedDocs[0]?.retriever as string; - } else if ('id' in selectedDocs) { + } else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) { // Handle single document (backward compatibility) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs.retriever as string; + payload.active_docs = selectedDocs[0].id as string; + payload.retriever = selectedDocs[0].retriever as string; } } return conversationService diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index 08743e73..2b9f6ee3 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -54,7 +54,7 @@ export interface Query { export interface RetrievalPayload { question: string; - active_docs?: string; + active_docs?: string | string[]; retriever?: string; conversation_id: string | null; prompt_id?: string | null; diff --git a/frontend/src/modals/ShareConversationModal.tsx b/frontend/src/modals/ShareConversationModal.tsx index 99262f01..624d64f5 100644 --- a/frontend/src/modals/ShareConversationModal.tsx +++ b/frontend/src/modals/ShareConversationModal.tsx @@ -60,7 +60,7 @@ export const ShareConversationModal = ({ const [sourcePath, setSourcePath] = useState<{ label: string; value: string; - } | null>(preSelectedDoc ? extractDocPaths([preSelectedDoc])[0] : null); + } | null>(preSelectedDoc ? extractDocPaths(preSelectedDoc)[0] : null); const handleCopyKey = (url: string) => { navigator.clipboard.writeText(url); @@ -105,14 +105,14 @@ export const ShareConversationModal = ({ return (
-

+

{t('modals.shareConv.label')}

-

+

{t('modals.shareConv.note')}

- + {t('modals.shareConv.option')} )}
- + {`${domain}/share/${identifier ?? '....'}`} {status === 'fetched' ? ( ) : ( ) => { setSelectedSourceIds(newSelectedIds); - setIsSourcePopupOpen(false); }} - title="Select Source" + title="Select Sources" searchPlaceholder="Search sources..." - noOptionsMessage="No source available" - singleSelect={true} + noOptionsMessage="No sources available" />
diff --git a/frontend/src/agents/types/index.ts b/frontend/src/agents/types/index.ts index e841cb0a..442097a1 100644 --- a/frontend/src/agents/types/index.ts +++ b/frontend/src/agents/types/index.ts @@ -10,6 +10,7 @@ export type Agent = { description: string; image: string; source: string; + sources?: string[]; chunks: string; retriever: string; prompt_id: string; From adcdce8d764ca4de31009af52487fe40581156fd Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 10 Sep 2025 22:10:11 +0530 Subject: [PATCH 5/6] fix: handle invalid chunks value in StreamProcessor and ClassicRAG --- .../api/answer/services/stream_processor.py | 16 ++++++++++++++-- application/retriever/classic_rag.py | 11 ++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index a04020cb..f6e639ef 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -266,7 +266,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 elif self.agent_key: data_key = self._get_data_from_api_key(self.agent_key) self.agent_config.update( @@ -287,7 +293,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 else: self.agent_config.update( { diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index ce1b937b..2ce863c2 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -25,7 +25,16 @@ class ClassicRAG(BaseRetriever): self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt - self.chunks = chunks + if isinstance(chunks, str): + try: + self.chunks = int(chunks) + except ValueError: + logging.warning( + f"Invalid chunks value '{chunks}', using default value 2" + ) + self.chunks = 2 + else: + self.chunks = chunks self.gpt_model = gpt_model self.token_limit = ( token_limit From 188d118fc0c689870b0c89ff3d32d5721f24c757 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 10 Sep 2025 22:14:31 +0530 Subject: [PATCH 6/6] refactor: remove unused logging import from routes.py --- application/api/connector/routes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index f203a703..1647aa78 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -1,6 +1,5 @@ import datetime import json -import logging from bson.objectid import ObjectId