From 6f47aa802b245a1bf0955234d7b0791f8939f59d Mon Sep 17 00:00:00 2001 From: Ankit Matth Date: Sat, 16 Aug 2025 15:19:19 +0530 Subject: [PATCH] added support for multi select sources --- application/retriever/classic_rag.py | 62 +++++++++++-------- frontend/src/components/MessageInput.tsx | 10 +-- frontend/src/components/SourcesPopup.tsx | 23 ++++--- .../src/conversation/conversationHandlers.ts | 46 ++++++++++---- frontend/src/hooks/useDefaultDocument.ts | 4 +- frontend/src/preferences/preferenceApi.ts | 33 +++++----- frontend/src/preferences/preferenceSlice.ts | 6 +- 7 files changed, 115 insertions(+), 69 deletions(-) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 9416b4f7..a9e3dee7 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -20,7 +20,7 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): - self.original_question = "" + self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt self.chunks = chunks @@ -44,7 +44,18 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - self.vectorstore = source["active_docs"] if "active_docs" in source else None + if "active_docs" in source: + if isinstance(source["active_docs"], list): + self.vectorstores = source["active_docs"] + elif isinstance(source["active_docs"], str) and "," in source["active_docs"]: + # ✅ split multiple IDs from comma string + self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()] + else: + self.vectorstores = [source["active_docs"]] + else: + self.vectorstores = [] + + self.vectorstore = None self.question = self._rephrase_query() self.decoded_token = decoded_token @@ -79,29 +90,30 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): - if self.chunks == 0 or self.vectorstore is None: - docs = [] - else: - docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY - ) - docs_temp = docsearch.search(self.question, k=self.chunks) - docs = [ - { - "title": i.metadata.get( - "title", i.metadata.get("post_title", i.page_content) - ).split("/")[-1], - "text": i.page_content, - "source": ( - i.metadata.get("source") - if i.metadata.get("source") - else "local" - ), - } - for i in docs_temp - ] + if self.chunks == 0 or not self.vectorstores: + return [] - return docs + all_docs = [] + chunks_per_source = max(1, self.chunks // len(self.vectorstores)) + + for vectorstore in self.vectorstores: + if vectorstore: + try: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY + ) + docs_temp = docsearch.search(self.question, k=chunks_per_source) + for i in docs_temp: + all_docs.append({ + "title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1], + "text": i.page_content, + "source": i.metadata.get("source") or vectorstore, + }) + except Exception as e: + logging.error(f"Error searching vectorstore {vectorstore}: {e}") + continue + + return all_docs def gen(): pass @@ -116,7 +128,7 @@ class ClassicRAG(BaseRetriever): return { "question": self.original_question, "rephrased_question": self.question, - "source": self.vectorstore, + "sources": self.vectorstores, "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, diff --git a/frontend/src/components/MessageInput.tsx b/frontend/src/components/MessageInput.tsx index d9bcea3e..6ae5678e 100644 --- a/frontend/src/components/MessageInput.tsx +++ b/frontend/src/components/MessageInput.tsx @@ -368,8 +368,8 @@ export default function MessageInput({ className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]" onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)} title={ - selectedDocs - ? selectedDocs.name + selectedDocs && selectedDocs.length > 0 + ? selectedDocs.map(doc => doc.name).join(', ') : t('conversation.sources.title') } > @@ -379,8 +379,10 @@ export default function MessageInput({ className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4" /> - {selectedDocs - ? selectedDocs.name + {selectedDocs && selectedDocs.length > 0 + ? selectedDocs.length === 1 + ? selectedDocs[0].name + : `${selectedDocs.length} sources selected` : t('conversation.sources.title')} {!isTouch && ( diff --git a/frontend/src/components/SourcesPopup.tsx b/frontend/src/components/SourcesPopup.tsx index 906f75cd..c3a61e01 100644 --- a/frontend/src/components/SourcesPopup.tsx +++ b/frontend/src/components/SourcesPopup.tsx @@ -149,9 +149,10 @@ export default function SourcesPopup({ if (option.model === embeddingsName) { const isSelected = selectedDocs && - (option.id - ? selectedDocs.id === option.id - : selectedDocs.date === option.date); + Array.isArray(selectedDocs) && selectedDocs.length > 0 && + selectedDocs.some(doc => + option.id ? doc.id === option.id : doc.date === option.date + ); return (
{ if (isSelected) { - dispatch(setSelectedDocs(null)); - handlePostDocumentSelect(null); + const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) + ? selectedDocs.filter(doc => + option.id ? doc.id !== option.id : doc.date !== option.date + ) + : []; + dispatch(setSelectedDocs(updatedDocs.length > 0 ? updatedDocs : null)); + handlePostDocumentSelect(updatedDocs.length > 0 ? updatedDocs : null); } else { - dispatch(setSelectedDocs(option)); - handlePostDocumentSelect(option); + const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) + ? [...selectedDocs, option] + : [option]; + dispatch(setSelectedDocs(updatedDocs)); + handlePostDocumentSelect(updatedDocs); } }} > diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index fb6e1b59..ae60b070 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -7,7 +7,7 @@ export function handleFetchAnswer( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -52,10 +52,17 @@ export function handleFetchAnswer( payload.attachments = attachments; } - if (selectedDocs && 'id' in selectedDocs) { - payload.active_docs = selectedDocs.id as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } } - payload.retriever = selectedDocs?.retriever as string; return conversationService .answer(payload, token, signal) .then((response) => { @@ -84,7 +91,7 @@ export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -112,10 +119,17 @@ export function handleFetchAnswerSteaming( payload.attachments = attachments; } - if (selectedDocs && 'id' in selectedDocs) { - payload.active_docs = selectedDocs.id as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } } - payload.retriever = selectedDocs?.retriever as string; return new Promise((resolve, reject) => { conversationService @@ -171,7 +185,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversation_id: string | null, chunks: string, token_limit: number, @@ -183,9 +197,17 @@ export function handleSearch( token_limit: token_limit, isNoneDoc: selectedDocs === null, }; - if (selectedDocs && 'id' in selectedDocs) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs?.retriever as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } + } return conversationService .search(payload, token) .then((response) => response.json()) diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts index a2642dc5..004e4bb1 100644 --- a/frontend/src/hooks/useDefaultDocument.ts +++ b/frontend/src/hooks/useDefaultDocument.ts @@ -18,11 +18,11 @@ export default function useDefaultDocument() { const fetchDocs = () => { getDocs(token).then((data) => { dispatch(setSourceDocs(data)); - if (!selectedDoc) + if (!selectedDoc || (Array.isArray(selectedDoc) && selectedDoc.length === 0)) Array.isArray(data) && data?.forEach((doc: Doc) => { if (doc.model && doc.name === 'default') { - dispatch(setSelectedDocs(doc)); + dispatch(setSelectedDocs([doc])); } }); }); diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts index 7fb907b3..40dc4bcc 100644 --- a/frontend/src/preferences/preferenceApi.ts +++ b/frontend/src/preferences/preferenceApi.ts @@ -90,9 +90,9 @@ export function getLocalApiKey(): string | null { return key; } -export function getLocalRecentDocs(): string | null { - const doc = localStorage.getItem('DocsGPTRecentDocs'); - return doc; +export function getLocalRecentDocs(): Doc[] | null { + const docs = localStorage.getItem('DocsGPTRecentDocs'); + return docs ? JSON.parse(docs) as Doc[] : null; } export function getLocalPrompt(): string | null { @@ -108,19 +108,20 @@ export function setLocalPrompt(prompt: string): void { localStorage.setItem('DocsGPTPrompt', prompt); } -export function setLocalRecentDocs(doc: Doc | null): void { - localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(doc)); +export function setLocalRecentDocs(docs: Doc[] | null): void { + if (docs && docs.length > 0) { + localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(docs)); - let docPath = 'default'; - if (doc?.type === 'local') { - docPath = 'local' + '/' + doc.name + '/'; + docs.forEach((doc) => { + let docPath = 'default'; + if (doc.type === 'local') { + docPath = 'local' + '/' + doc.name + '/'; + } + userService + .checkDocs({ docs: docPath }, null) + .then((response) => response.json()); + }); + } else { + localStorage.removeItem('DocsGPTRecentDocs'); } - userService - .checkDocs( - { - docs: docPath, - }, - null, - ) - .then((response) => response.json()); } diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts index a0825039..6abbef4d 100644 --- a/frontend/src/preferences/preferenceSlice.ts +++ b/frontend/src/preferences/preferenceSlice.ts @@ -15,7 +15,7 @@ export interface Preference { prompt: { name: string; id: string; type: string }; chunks: string; token_limit: number; - selectedDocs: Doc | null; + selectedDocs: Doc[] | null; sourceDocs: Doc[] | null; conversations: { data: { name: string; id: string }[] | null; @@ -34,7 +34,7 @@ const initialState: Preference = { prompt: { name: 'default', id: 'default', type: 'public' }, chunks: '2', token_limit: 2000, - selectedDocs: { + selectedDocs: [{ id: 'default', name: 'default', type: 'remote', @@ -42,7 +42,7 @@ const initialState: Preference = { docLink: 'default', model: 'openai_text-embedding-ada-002', retriever: 'classic', - } as Doc, + }] as Doc[], sourceDocs: null, conversations: { data: null,