From 6f47aa802b245a1bf0955234d7b0791f8939f59d Mon Sep 17 00:00:00 2001 From: Ankit Matth Date: Sat, 16 Aug 2025 15:19:19 +0530 Subject: [PATCH 01/13] added support for multi select sources --- application/retriever/classic_rag.py | 62 +++++++++++-------- frontend/src/components/MessageInput.tsx | 10 +-- frontend/src/components/SourcesPopup.tsx | 23 ++++--- .../src/conversation/conversationHandlers.ts | 46 ++++++++++---- frontend/src/hooks/useDefaultDocument.ts | 4 +- frontend/src/preferences/preferenceApi.ts | 33 +++++----- frontend/src/preferences/preferenceSlice.ts | 6 +- 7 files changed, 115 insertions(+), 69 deletions(-) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 9416b4f7..a9e3dee7 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -20,7 +20,7 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): - self.original_question = "" + self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt self.chunks = chunks @@ -44,7 +44,18 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - self.vectorstore = source["active_docs"] if "active_docs" in source else None + if "active_docs" in source: + if isinstance(source["active_docs"], list): + self.vectorstores = source["active_docs"] + elif isinstance(source["active_docs"], str) and "," in source["active_docs"]: + # ✅ split multiple IDs from comma string + self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()] + else: + self.vectorstores = [source["active_docs"]] + else: + self.vectorstores = [] + + self.vectorstore = None self.question = self._rephrase_query() self.decoded_token = decoded_token @@ -79,29 +90,30 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): - if self.chunks == 0 or self.vectorstore is None: - docs = [] - else: - docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY - ) - docs_temp = docsearch.search(self.question, k=self.chunks) - docs = [ - { - "title": i.metadata.get( - "title", i.metadata.get("post_title", i.page_content) - ).split("/")[-1], - "text": i.page_content, - "source": ( - i.metadata.get("source") - if i.metadata.get("source") - else "local" - ), - } - for i in docs_temp - ] + if self.chunks == 0 or not self.vectorstores: + return [] - return docs + all_docs = [] + chunks_per_source = max(1, self.chunks // len(self.vectorstores)) + + for vectorstore in self.vectorstores: + if vectorstore: + try: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY + ) + docs_temp = docsearch.search(self.question, k=chunks_per_source) + for i in docs_temp: + all_docs.append({ + "title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1], + "text": i.page_content, + "source": i.metadata.get("source") or vectorstore, + }) + except Exception as e: + logging.error(f"Error searching vectorstore {vectorstore}: {e}") + continue + + return all_docs def gen(): pass @@ -116,7 +128,7 @@ class ClassicRAG(BaseRetriever): return { "question": self.original_question, "rephrased_question": self.question, - "source": self.vectorstore, + "sources": self.vectorstores, "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, diff --git 
a/frontend/src/components/MessageInput.tsx b/frontend/src/components/MessageInput.tsx index d9bcea3e..6ae5678e 100644 --- a/frontend/src/components/MessageInput.tsx +++ b/frontend/src/components/MessageInput.tsx @@ -368,8 +368,8 @@ export default function MessageInput({ className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]" onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)} title={ - selectedDocs - ? selectedDocs.name + selectedDocs && selectedDocs.length > 0 + ? selectedDocs.map(doc => doc.name).join(', ') : t('conversation.sources.title') } > @@ -379,8 +379,10 @@ export default function MessageInput({ className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4" /> - {selectedDocs - ? selectedDocs.name + {selectedDocs && selectedDocs.length > 0 + ? selectedDocs.length === 1 + ? selectedDocs[0].name + : `${selectedDocs.length} sources selected` : t('conversation.sources.title')} {!isTouch && ( diff --git a/frontend/src/components/SourcesPopup.tsx b/frontend/src/components/SourcesPopup.tsx index 906f75cd..c3a61e01 100644 --- a/frontend/src/components/SourcesPopup.tsx +++ b/frontend/src/components/SourcesPopup.tsx @@ -149,9 +149,10 @@ export default function SourcesPopup({ if (option.model === embeddingsName) { const isSelected = selectedDocs && - (option.id - ? selectedDocs.id === option.id - : selectedDocs.date === option.date); + Array.isArray(selectedDocs) && selectedDocs.length > 0 && + selectedDocs.some(doc => + option.id ? doc.id === option.id : doc.date === option.date + ); return (
{ if (isSelected) { - dispatch(setSelectedDocs(null)); - handlePostDocumentSelect(null); + const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) + ? selectedDocs.filter(doc => + option.id ? doc.id !== option.id : doc.date !== option.date + ) + : []; + dispatch(setSelectedDocs(updatedDocs.length > 0 ? updatedDocs : null)); + handlePostDocumentSelect(updatedDocs.length > 0 ? updatedDocs : null); } else { - dispatch(setSelectedDocs(option)); - handlePostDocumentSelect(option); + const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) + ? [...selectedDocs, option] + : [option]; + dispatch(setSelectedDocs(updatedDocs)); + handlePostDocumentSelect(updatedDocs); } }} > diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index fb6e1b59..ae60b070 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -7,7 +7,7 @@ export function handleFetchAnswer( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -52,10 +52,17 @@ export function handleFetchAnswer( payload.attachments = attachments; } - if (selectedDocs && 'id' in selectedDocs) { - payload.active_docs = selectedDocs.id as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } } - payload.retriever = selectedDocs?.retriever as string; return conversationService .answer(payload, token, signal) .then((response) => { @@ -84,7 +91,7 @@ export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -112,10 +119,17 @@ export function handleFetchAnswerSteaming( payload.attachments = attachments; } - if (selectedDocs && 'id' in selectedDocs) { - payload.active_docs = selectedDocs.id as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } } - payload.retriever = selectedDocs?.retriever as string; return new Promise((resolve, reject) => { conversationService @@ -171,7 +185,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, token: string | null, - selectedDocs: Doc | null, + selectedDocs: Doc | Doc[] | null, conversation_id: string | null, chunks: string, token_limit: number, @@ -183,9 +197,17 @@ export function handleSearch( token_limit: token_limit, isNoneDoc: selectedDocs === null, }; - if (selectedDocs && 'id' in selectedDocs) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs?.retriever as string; + if (selectedDocs) { + if (Array.isArray(selectedDocs)) { + // Handle multiple 
documents + payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.retriever = selectedDocs[0]?.retriever as string; + } else if ('id' in selectedDocs) { + // Handle single document (backward compatibility) + payload.active_docs = selectedDocs.id as string; + payload.retriever = selectedDocs.retriever as string; + } + } return conversationService .search(payload, token) .then((response) => response.json()) diff --git a/frontend/src/hooks/useDefaultDocument.ts b/frontend/src/hooks/useDefaultDocument.ts index a2642dc5..004e4bb1 100644 --- a/frontend/src/hooks/useDefaultDocument.ts +++ b/frontend/src/hooks/useDefaultDocument.ts @@ -18,11 +18,11 @@ export default function useDefaultDocument() { const fetchDocs = () => { getDocs(token).then((data) => { dispatch(setSourceDocs(data)); - if (!selectedDoc) + if (!selectedDoc || (Array.isArray(selectedDoc) && selectedDoc.length === 0)) Array.isArray(data) && data?.forEach((doc: Doc) => { if (doc.model && doc.name === 'default') { - dispatch(setSelectedDocs(doc)); + dispatch(setSelectedDocs([doc])); } }); }); diff --git a/frontend/src/preferences/preferenceApi.ts b/frontend/src/preferences/preferenceApi.ts index 7fb907b3..40dc4bcc 100644 --- a/frontend/src/preferences/preferenceApi.ts +++ b/frontend/src/preferences/preferenceApi.ts @@ -90,9 +90,9 @@ export function getLocalApiKey(): string | null { return key; } -export function getLocalRecentDocs(): string | null { - const doc = localStorage.getItem('DocsGPTRecentDocs'); - return doc; +export function getLocalRecentDocs(): Doc[] | null { + const docs = localStorage.getItem('DocsGPTRecentDocs'); + return docs ? JSON.parse(docs) as Doc[] : null; } export function getLocalPrompt(): string | null { @@ -108,19 +108,20 @@ export function setLocalPrompt(prompt: string): void { localStorage.setItem('DocsGPTPrompt', prompt); } -export function setLocalRecentDocs(doc: Doc | null): void { - localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(doc)); +export function setLocalRecentDocs(docs: Doc[] | null): void { + if (docs && docs.length > 0) { + localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(docs)); - let docPath = 'default'; - if (doc?.type === 'local') { - docPath = 'local' + '/' + doc.name + '/'; + docs.forEach((doc) => { + let docPath = 'default'; + if (doc.type === 'local') { + docPath = 'local' + '/' + doc.name + '/'; + } + userService + .checkDocs({ docs: docPath }, null) + .then((response) => response.json()); + }); + } else { + localStorage.removeItem('DocsGPTRecentDocs'); } - userService - .checkDocs( - { - docs: docPath, - }, - null, - ) - .then((response) => response.json()); } diff --git a/frontend/src/preferences/preferenceSlice.ts b/frontend/src/preferences/preferenceSlice.ts index a0825039..6abbef4d 100644 --- a/frontend/src/preferences/preferenceSlice.ts +++ b/frontend/src/preferences/preferenceSlice.ts @@ -15,7 +15,7 @@ export interface Preference { prompt: { name: string; id: string; type: string }; chunks: string; token_limit: number; - selectedDocs: Doc | null; + selectedDocs: Doc[] | null; sourceDocs: Doc[] | null; conversations: { data: { name: string; id: string }[] | null; @@ -34,7 +34,7 @@ const initialState: Preference = { prompt: { name: 'default', id: 'default', type: 'public' }, chunks: '2', token_limit: 2000, - selectedDocs: { + selectedDocs: [{ id: 'default', name: 'default', type: 'remote', @@ -42,7 +42,7 @@ const initialState: Preference = { docLink: 'default', model: 'openai_text-embedding-ada-002', retriever: 'classic', - } as Doc, + }] as 
Doc[], sourceDocs: null, conversations: { data: null, From bd73fa9ae716e9c0a31604aea171591752ffaaef Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 20 Aug 2025 22:25:31 +0530 Subject: [PATCH 02/13] refactor: remove unused abstract method and improve retrievers --- application/retriever/base.py | 4 -- application/retriever/classic_rag.py | 80 +++++++++++++++++++++------- application/vectorstore/base.py | 77 +++++++++++++++++++------- 3 files changed, 121 insertions(+), 40 deletions(-) diff --git a/application/retriever/base.py b/application/retriever/base.py index fd99dbdd..36ac2e93 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -5,10 +5,6 @@ class BaseRetriever(ABC): def __init__(self): pass - @abstractmethod - def gen(self, *args, **kwargs): - pass - @abstractmethod def search(self, *args, **kwargs): pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index a9e3dee7..b558c8f0 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,4 +1,5 @@ import logging + from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever @@ -20,6 +21,7 @@ class ClassicRAG(BaseRetriever): api_key=settings.API_KEY, decoded_token=None, ): + """Initialize ClassicRAG retriever with vectorstore sources and LLM configuration""" self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt @@ -47,25 +49,46 @@ class ClassicRAG(BaseRetriever): if "active_docs" in source: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] - elif isinstance(source["active_docs"], str) and "," in source["active_docs"]: - # ✅ split multiple IDs from comma string - self.vectorstores = [doc_id.strip() for doc_id in source["active_docs"].split(",") if doc_id.strip()] + elif ( + isinstance(source["active_docs"], str) and "," in source["active_docs"] + ): + self.vectorstores = [ + doc_id.strip() + for doc_id in source["active_docs"].split(",") + if doc_id.strip() + ] else: self.vectorstores = [source["active_docs"]] else: self.vectorstores = [] - self.vectorstore = None self.question = self._rephrase_query() self.decoded_token = decoded_token + self._validate_vectorstore_config() + + def _validate_vectorstore_config(self): + """Validate vectorstore IDs and remove any empty/invalid entries""" + if not self.vectorstores: + logging.warning("No vectorstores configured for retrieval") + return + + invalid_ids = [ + vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip() + ] + if invalid_ids: + logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}") + self.vectorstores = [ + vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip() + ] def _rephrase_query(self): + """Rephrase user query with chat history context for better retrieval""" if ( not self.original_question or not self.chat_history or self.chat_history == [] or self.chunks == 0 - or self.vectorstore is None + or not self.vectorstores ): return self.original_question @@ -90,41 +113,62 @@ class ClassicRAG(BaseRetriever): return self.original_question def _get_data(self): + """Retrieve relevant documents from configured vectorstores""" if self.chunks == 0 or not self.vectorstores: return [] all_docs = [] chunks_per_source = max(1, self.chunks // len(self.vectorstores)) - for vectorstore in self.vectorstores: - if vectorstore: + for 
vectorstore_id in self.vectorstores: + if vectorstore_id: try: docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY + settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY ) docs_temp = docsearch.search(self.question, k=chunks_per_source) - for i in docs_temp: - all_docs.append({ - "title": i.metadata.get("title", i.metadata.get("post_title", i.page_content)).split("/")[-1], - "text": i.page_content, - "source": i.metadata.get("source") or vectorstore, - }) + + for doc in docs_temp: + if hasattr(doc, "page_content") and hasattr(doc, "metadata"): + page_content = doc.page_content + metadata = doc.metadata + else: + page_content = doc.get("text", doc.get("page_content", "")) + metadata = doc.get("metadata", {}) + + title = metadata.get( + "title", metadata.get("post_title", page_content) + ) + if isinstance(title, str): + title = title.split("/")[-1] + else: + title = str(title).split("/")[-1] + + all_docs.append( + { + "title": title, + "text": page_content, + "source": metadata.get("source") or vectorstore_id, + } + ) except Exception as e: - logging.error(f"Error searching vectorstore {vectorstore}: {e}") + logging.error( + f"Error searching vectorstore {vectorstore_id}: {e}", + exc_info=True, + ) continue return all_docs - def gen(): - pass - def search(self, query: str = ""): + """Search for documents using optional query override""" if query: self.original_question = query self.question = self._rephrase_query() return self._get_data() def get_params(self): + """Return current retriever configuration parameters""" return { "question": self.original_question, "rephrased_question": self.question, diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index a6b206c9..ea4885cd 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -1,20 +1,28 @@ -from abc import ABC, abstractmethod import os -from sentence_transformers import SentenceTransformer +from abc import ABC, abstractmethod + from langchain_openai import OpenAIEmbeddings +from sentence_transformers import SentenceTransformer + from application.core.settings import settings + class EmbeddingsWrapper: def __init__(self, model_name, *args, **kwargs): - self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs) + self.model = SentenceTransformer( + model_name, + config_kwargs={"allow_dangerous_deserialization": True}, + *args, + **kwargs + ) self.dimension = self.model.get_sentence_embedding_dimension() def embed_query(self, query: str): return self.model.encode(query).tolist() - + def embed_documents(self, documents: list): return self.model.encode(documents).tolist() - + def __call__(self, text): if isinstance(text, str): return self.embed_query(text) @@ -24,15 +32,14 @@ class EmbeddingsWrapper: raise ValueError("Input must be a string or a list of strings") - class EmbeddingsSingleton: _instances = {} @staticmethod def get_instance(embeddings_name, *args, **kwargs): if embeddings_name not in EmbeddingsSingleton._instances: - EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance( - embeddings_name, *args, **kwargs + EmbeddingsSingleton._instances[embeddings_name] = ( + EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs) ) return EmbeddingsSingleton._instances[embeddings_name] @@ -40,9 +47,15 @@ class EmbeddingsSingleton: def _create_instance(embeddings_name, *args, **kwargs): embeddings_factory = { 
"openai_text-embedding-ada-002": OpenAIEmbeddings, - "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"), - "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"), + "huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper( + "sentence-transformers/all-mpnet-base-v2" + ), + "huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper( + "hkunlp/instructor-large" + ), } if embeddings_name in embeddings_factory: @@ -50,34 +63,63 @@ class EmbeddingsSingleton: else: return EmbeddingsWrapper(embeddings_name, *args, **kwargs) + class BaseVectorStore(ABC): def __init__(self): pass @abstractmethod def search(self, *args, **kwargs): + """Search for similar documents/chunks in the vectorstore""" + pass + + @abstractmethod + def add_texts(self, texts, metadatas=None, *args, **kwargs): + """Add texts with their embeddings to the vectorstore""" + pass + + def delete_index(self, *args, **kwargs): + """Delete the entire index/collection""" + pass + + def save_local(self, *args, **kwargs): + """Save vectorstore to local storage""" + pass + + def get_chunks(self, *args, **kwargs): + """Get all chunks from the vectorstore""" + pass + + def add_chunk(self, text, metadata=None, *args, **kwargs): + """Add a single chunk to the vectorstore""" + pass + + def delete_chunk(self, chunk_id, *args, **kwargs): + """Delete a specific chunk from the vectorstore""" pass def is_azure_configured(self): - return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME + return ( + settings.OPENAI_API_BASE + and settings.OPENAI_API_VERSION + and settings.AZURE_DEPLOYMENT_NAME + ) def _get_embeddings(self, embeddings_name, embeddings_key=None): if embeddings_name == "openai_text-embedding-ada-002": if self.is_azure_configured(): os.environ["OPENAI_API_TYPE"] = "azure" embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME + embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME ) else: embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name, - openai_api_key=embeddings_key + embeddings_name, openai_api_key=embeddings_key ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": if os.path.exists("./models/all-mpnet-base-v2"): embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name = "./models/all-mpnet-base-v2", + embeddings_name="./models/all-mpnet-base-v2", ) else: embedding_instance = EmbeddingsSingleton.get_instance( @@ -87,4 +129,3 @@ class BaseVectorStore(ABC): embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) return embedding_instance - From 07d59b66406e4b5bfb193056f956abbbdb85c322 Mon Sep 17 00:00:00 2001 From: Ankit Matth Date: Sat, 23 Aug 2025 20:25:29 +0530 Subject: [PATCH 03/13] refactor: use list instead of string parsing --- application/retriever/classic_rag.py | 11 +---- frontend/src/components/SourcesPopup.tsx | 41 +++++++++++------- .../src/conversation/conversationHandlers.ts | 42 +++++++++---------- .../src/conversation/conversationModels.ts | 2 +- .../src/modals/ShareConversationModal.tsx | 14 +++---- 
frontend/src/preferences/preferenceSlice.ts | 19 +++++---- 6 files changed, 68 insertions(+), 61 deletions(-) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index b558c8f0..82423bb5 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -46,17 +46,10 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - if "active_docs" in source: + + if "active_docs" in source and source["active_docs"] is not None: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] - elif ( - isinstance(source["active_docs"], str) and "," in source["active_docs"] - ): - self.vectorstores = [ - doc_id.strip() - for doc_id in source["active_docs"].split(",") - if doc_id.strip() - ] else: self.vectorstores = [source["active_docs"]] else: diff --git a/frontend/src/components/SourcesPopup.tsx b/frontend/src/components/SourcesPopup.tsx index c3a61e01..f13ee25a 100644 --- a/frontend/src/components/SourcesPopup.tsx +++ b/frontend/src/components/SourcesPopup.tsx @@ -17,7 +17,7 @@ type SourcesPopupProps = { isOpen: boolean; onClose: () => void; anchorRef: React.RefObject; - handlePostDocumentSelect: (doc: Doc | null) => void; + handlePostDocumentSelect: (doc: Doc[] | null) => void; setUploadModalState: React.Dispatch>; }; @@ -149,9 +149,12 @@ export default function SourcesPopup({ if (option.model === embeddingsName) { const isSelected = selectedDocs && - Array.isArray(selectedDocs) && selectedDocs.length > 0 && - selectedDocs.some(doc => - option.id ? doc.id === option.id : doc.date === option.date + Array.isArray(selectedDocs) && + selectedDocs.length > 0 && + selectedDocs.some((doc) => + option.id + ? doc.id === option.id + : doc.date === option.date, ); return ( @@ -160,17 +163,27 @@ export default function SourcesPopup({ className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]" onClick={() => { if (isSelected) { - const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) - ? selectedDocs.filter(doc => - option.id ? doc.id !== option.id : doc.date !== option.date - ) - : []; - dispatch(setSelectedDocs(updatedDocs.length > 0 ? updatedDocs : null)); - handlePostDocumentSelect(updatedDocs.length > 0 ? updatedDocs : null); + const updatedDocs = + selectedDocs && Array.isArray(selectedDocs) + ? selectedDocs.filter((doc) => + option.id + ? doc.id !== option.id + : doc.date !== option.date, + ) + : []; + dispatch( + setSelectedDocs( + updatedDocs.length > 0 ? updatedDocs : null, + ), + ); + handlePostDocumentSelect( + updatedDocs.length > 0 ? updatedDocs : null, + ); } else { - const updatedDocs = (selectedDocs && Array.isArray(selectedDocs)) - ? [...selectedDocs, option] - : [option]; + const updatedDocs = + selectedDocs && Array.isArray(selectedDocs) + ? 
[...selectedDocs, option] + : [option]; dispatch(setSelectedDocs(updatedDocs)); handlePostDocumentSelect(updatedDocs); } diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts index ae60b070..63557924 100644 --- a/frontend/src/conversation/conversationHandlers.ts +++ b/frontend/src/conversation/conversationHandlers.ts @@ -7,7 +7,7 @@ export function handleFetchAnswer( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | Doc[] | null, + selectedDocs: Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -52,15 +52,15 @@ export function handleFetchAnswer( payload.attachments = attachments; } - if (selectedDocs) { - if (Array.isArray(selectedDocs)) { + if (selectedDocs && Array.isArray(selectedDocs)) { + if (selectedDocs.length > 1) { // Handle multiple documents - payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.active_docs = selectedDocs.map((doc) => doc.id!); payload.retriever = selectedDocs[0]?.retriever as string; - } else if ('id' in selectedDocs) { + } else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) { // Handle single document (backward compatibility) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs.retriever as string; + payload.active_docs = selectedDocs[0].id as string; + payload.retriever = selectedDocs[0].retriever as string; } } return conversationService @@ -91,7 +91,7 @@ export function handleFetchAnswerSteaming( question: string, signal: AbortSignal, token: string | null, - selectedDocs: Doc | Doc[] | null, + selectedDocs: Doc[] | null, conversationId: string | null, promptId: string | null, chunks: string, @@ -119,15 +119,15 @@ export function handleFetchAnswerSteaming( payload.attachments = attachments; } - if (selectedDocs) { - if (Array.isArray(selectedDocs)) { + if (selectedDocs && Array.isArray(selectedDocs)) { + if (selectedDocs.length > 1) { // Handle multiple documents - payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.active_docs = selectedDocs.map((doc) => doc.id!); payload.retriever = selectedDocs[0]?.retriever as string; - } else if ('id' in selectedDocs) { + } else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) { // Handle single document (backward compatibility) - payload.active_docs = selectedDocs.id as string; - payload.retriever = selectedDocs.retriever as string; + payload.active_docs = selectedDocs[0].id as string; + payload.retriever = selectedDocs[0].retriever as string; } } @@ -185,7 +185,7 @@ export function handleFetchAnswerSteaming( export function handleSearch( question: string, token: string | null, - selectedDocs: Doc | Doc[] | null, + selectedDocs: Doc[] | null, conversation_id: string | null, chunks: string, token_limit: number, @@ -197,15 +197,15 @@ export function handleSearch( token_limit: token_limit, isNoneDoc: selectedDocs === null, }; - if (selectedDocs) { - if (Array.isArray(selectedDocs)) { + if (selectedDocs && Array.isArray(selectedDocs)) { + if (selectedDocs.length > 1) { // Handle multiple documents - payload.active_docs = selectedDocs.map(doc => doc.id).join(','); + payload.active_docs = selectedDocs.map((doc) => doc.id!); payload.retriever = selectedDocs[0]?.retriever as string; - } else if ('id' in selectedDocs) { + } else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) { // Handle single document (backward compatibility) - payload.active_docs = selectedDocs.id as string; - 
payload.retriever = selectedDocs.retriever as string; + payload.active_docs = selectedDocs[0].id as string; + payload.retriever = selectedDocs[0].retriever as string; } } return conversationService diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index 08743e73..2b9f6ee3 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -54,7 +54,7 @@ export interface Query { export interface RetrievalPayload { question: string; - active_docs?: string; + active_docs?: string | string[]; retriever?: string; conversation_id: string | null; prompt_id?: string | null; diff --git a/frontend/src/modals/ShareConversationModal.tsx b/frontend/src/modals/ShareConversationModal.tsx index 99262f01..624d64f5 100644 --- a/frontend/src/modals/ShareConversationModal.tsx +++ b/frontend/src/modals/ShareConversationModal.tsx @@ -60,7 +60,7 @@ export const ShareConversationModal = ({ const [sourcePath, setSourcePath] = useState<{ label: string; value: string; - } | null>(preSelectedDoc ? extractDocPaths([preSelectedDoc])[0] : null); + } | null>(preSelectedDoc ? extractDocPaths(preSelectedDoc)[0] : null); const handleCopyKey = (url: string) => { navigator.clipboard.writeText(url); @@ -105,14 +105,14 @@ export const ShareConversationModal = ({ return (
[The remainder of this ShareConversationModal.tsx hunk (@@ -105,14 +105,14 @@) was garbled during extraction: the JSX tags were stripped, leaving only bare diff markers and loose text fragments such as {t('modals.shareConv.label')}, {t('modals.shareConv.note')}, the source dropdown's {t('modals.shareConv.option')} entries, the share link {`${domain}/share/${identifier ?? '....'}`}, and the status === 'fetched' copy-button branch. A few stray closing fragments of the JSX survive just before the next patch header.]
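Reviewer note, not part of the patch series: patches 01 through 03 converge on a single payload contract, and the hunks are easier to review with that contract spelled out. selectedDocs becomes Doc[] | null; with several sources selected, active_docs is sent as a list of source IDs, and with exactly one source it stays a plain string so older backends keep working. Below is a minimal TypeScript sketch of that contract; buildRetrievalPayload is an illustrative name rather than a function in these patches, and Doc is reduced to the two fields the handlers actually touch.

interface Doc {
  id?: string;
  retriever?: string;
}

interface RetrievalPayload {
  question: string;
  active_docs?: string | string[];
  retriever?: string;
  isNoneDoc: boolean;
}

// Mirrors the branching patch 03 adds to conversationHandlers.ts:
// several sources -> list of IDs, single source -> plain ID string.
function buildRetrievalPayload(
  question: string,
  selectedDocs: Doc[] | null,
): RetrievalPayload {
  const payload: RetrievalPayload = {
    question,
    isNoneDoc: selectedDocs === null,
  };
  if (selectedDocs && selectedDocs.length > 1) {
    payload.active_docs = selectedDocs.map((doc) => doc.id!);
    payload.retriever = selectedDocs[0]?.retriever;
  } else if (selectedDocs && selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
    payload.active_docs = selectedDocs[0].id as string;
    payload.retriever = selectedDocs[0].retriever;
  }
  return payload;
}

On the backend, ClassicRAG accepts either shape, normalizes it to a list, and splits the retrieval budget as max(1, chunks // len(vectorstores)) per source, so selecting more sources narrows each source's share of the context instead of growing it.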
+ + ) + ); +} From 1bf6af6eebce3fde514eea9bde38060f56281946 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 4 Sep 2025 15:10:12 +0530 Subject: [PATCH 05/13] feat: finalize remote mcp --- application/agents/tools/mcp_tool.py | 156 +++----- application/agents/tools/tool_manager.py | 8 - application/api/user/routes.py | 370 ++++++++---------- application/core/settings.py | 3 + application/requirements.txt | 1 + application/security/encryption.py | 120 +++--- frontend/public/toolIcons/tool_mcp_tool.svg | 4 + frontend/src/api/endpoints.ts | 2 + frontend/src/api/services/userService.ts | 9 +- frontend/src/locale/en.json | 31 +- frontend/src/modals/MCPServerModal.tsx | 395 +++++++++----------- 11 files changed, 453 insertions(+), 646 deletions(-) diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index fb47d0ed..c9133b96 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional import requests from application.agents.tools.base import Tool +from application.security.encryption import decrypt_credentials _mcp_session_cache = {} @@ -33,18 +34,12 @@ class MCPTool(Tool): self.auth_type = config.get("auth_type", "none") self.timeout = config.get("timeout", 30) - # Decrypt credentials if they are encrypted - self.auth_credentials = {} if config.get("encrypted_credentials") and user_id: - from application.security.encryption import decrypt_credentials - self.auth_credentials = decrypt_credentials( config["encrypted_credentials"], user_id ) else: - # Fallback to unencrypted credentials (for backward compatibility) - self.auth_credentials = config.get("auth_credentials", {}) self.available_tools = [] self._session = requests.Session() @@ -52,10 +47,25 @@ class MCPTool(Tool): self._setup_authentication() self._cache_key = self._generate_cache_key() + def _setup_authentication(self): + """Setup authentication for the MCP server connection.""" + if self.auth_type == "api_key": + api_key = self.auth_credentials.get("api_key", "") + header_name = self.auth_credentials.get("api_key_header", "X-API-Key") + if api_key: + self._session.headers.update({header_name: api_key}) + elif self.auth_type == "bearer": + token = self.auth_credentials.get("bearer_token", "") + if token: + self._session.headers.update({"Authorization": f"Bearer {token}"}) + elif self.auth_type == "basic": + username = self.auth_credentials.get("username", "") + password = self.auth_credentials.get("password", "") + if username and password: + self._session.auth = (username, password) + def _generate_cache_key(self) -> str: """Generate a unique cache key for this MCP server configuration.""" - # Use server URL + auth info to create unique key - auth_key = "" if self.auth_type == "bearer": token = self.auth_credentials.get("bearer_token", "") @@ -76,13 +86,9 @@ class MCPTool(Tool): if self._cache_key in _mcp_session_cache: session_data = _mcp_session_cache[self._cache_key] - # Check if session is less than 30 minutes old - - if time.time() - session_data["created_at"] < 1800: # 30 minutes + if time.time() - session_data["created_at"] < 1800: return session_data["session_id"] else: - # Remove expired session - del _mcp_session_cache[self._cache_key] return None @@ -94,23 +100,6 @@ class MCPTool(Tool): "created_at": time.time(), } - def _setup_authentication(self): - """Setup authentication for the MCP server connection.""" - if self.auth_type == "api_key": - api_key = self.auth_credentials.get("api_key", "") 
- header_name = self.auth_credentials.get("api_key_header", "X-API-Key") - if api_key: - self._session.headers.update({header_name: api_key}) - elif self.auth_type == "bearer": - token = self.auth_credentials.get("bearer_token", "") - if token: - self._session.headers.update({"Authorization": f"Bearer {token}"}) - elif self.auth_type == "basic": - username = self.auth_credentials.get("username", "") - password = self.auth_credentials.get("password", "") - if username and password: - self._session.auth = (username, password) - def _initialize_mcp_connection(self) -> Dict: """ Initialize MCP connection with the server, using cached session if available. @@ -264,10 +253,7 @@ class MCPTool(Tool): """ self._ensure_valid_session() - # Prepare call parameters for MCP protocol - call_params = {"name": action_name, "arguments": kwargs} - try: result = self._make_mcp_request("tools/call", call_params) return result @@ -283,9 +269,6 @@ class MCPTool(Tool): """ actions = [] for tool in self.available_tools: - # Parse MCP tool schema according to MCP specification - # Check multiple possible schema locations for compatibility - input_schema = ( tool.get("inputSchema") or tool.get("input_schema") @@ -293,20 +276,14 @@ class MCPTool(Tool): or tool.get("parameters") ) - # Default empty schema if no inputSchema provided - parameters_schema = { "type": "object", "properties": {}, "required": [], } - # Parse the inputSchema if it exists - if input_schema: if isinstance(input_schema, dict): - # Handle standard JSON Schema format - if "properties" in input_schema: parameters_schema = { "type": input_schema.get("type", "object"), @@ -314,14 +291,10 @@ class MCPTool(Tool): "required": input_schema.get("required", []), } - # Add additional schema properties if they exist - for key in ["additionalProperties", "description"]: if key in input_schema: parameters_schema[key] = input_schema[key] else: - # Might be properties directly at root level - parameters_schema["properties"] = input_schema action = { "name": tool.get("name", ""), @@ -331,64 +304,6 @@ class MCPTool(Tool): actions.append(action) return actions - def get_config_requirements(self) -> Dict: - """ - Get configuration requirements for the MCP tool. - - Returns: - Dictionary describing required configuration - """ - return { - "server_url": { - "type": "string", - "description": "URL of the remote MCP server (e.g., https://api.example.com)", - "required": True, - }, - "auth_type": { - "type": "string", - "description": "Authentication type", - "enum": ["none", "api_key", "bearer", "basic"], - "default": "none", - "required": True, - }, - "auth_credentials": { - "type": "object", - "description": "Authentication credentials (varies by auth_type)", - "properties": { - "api_key": { - "type": "string", - "description": "API key for api_key auth", - }, - "header_name": { - "type": "string", - "description": "Header name for API key (default: X-API-Key)", - "default": "X-API-Key", - }, - "token": { - "type": "string", - "description": "Bearer token for bearer auth", - }, - "username": { - "type": "string", - "description": "Username for basic auth", - }, - "password": { - "type": "string", - "description": "Password for basic auth", - }, - }, - "required": False, - }, - "timeout": { - "type": "integer", - "description": "Request timeout in seconds", - "default": 30, - "minimum": 1, - "maximum": 300, - "required": False, - }, - } - def test_connection(self) -> Dict: """ Test the connection to the MCP server and validate functionality. 
@@ -411,9 +326,7 @@ class MCPTool(Tool): "message": message, "tools_count": len(tools), "session_id": self._mcp_session_id, - "tools": [ - tool.get("name", "unknown") for tool in tools[:5] - ], # First 5 tool names + "tools": [tool.get("name", "unknown") for tool in tools[:5]], } except Exception as e: return { @@ -422,3 +335,32 @@ class MCPTool(Tool): "tools_count": 0, "error_type": type(e).__name__, } + + def get_config_requirements(self) -> Dict: + return { + "server_url": { + "type": "string", + "description": "URL of the remote MCP server (e.g., https://api.example.com)", + "required": True, + }, + "auth_type": { + "type": "string", + "description": "Authentication type", + "enum": ["none", "api_key", "bearer", "basic"], + "default": "none", + "required": True, + }, + "auth_credentials": { + "type": "object", + "description": "Authentication credentials (varies by auth_type)", + "required": False, + }, + "timeout": { + "type": "integer", + "description": "Request timeout in seconds", + "default": 30, + "minimum": 1, + "maximum": 300, + "required": False, + }, + } diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py index 890262bc..d602b762 100644 --- a/application/agents/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -28,7 +28,6 @@ class ToolManager: module = importlib.import_module(f"application.agents.tools.{tool_name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: - # For MCP tools, pass the user_id for credential decryption if tool_name == "mcp_tool" and user_id: return obj(tool_config, user_id) else: @@ -36,18 +35,11 @@ class ToolManager: def execute_action(self, tool_name, action_name, user_id=None, **kwargs): if tool_name not in self.tools: - # For MCP tools, they might not be pre-loaded, so load dynamically - if tool_name == "mcp_tool": - raise ValueError(f"Tool '{tool_name}' not loaded and no config provided for dynamic loading") raise ValueError(f"Tool '{tool_name}' not loaded") - - # For MCP tools, if user_id is provided, create a new instance with user context if tool_name == "mcp_tool" and user_id: - # Load tool dynamically with user context for proper credential access tool_config = self.config.get(tool_name, {}) tool = self.load_tool(tool_name, tool_config, user_id) return tool.execute_action(action_name, **kwargs) - return self.tools[tool_name].execute_action(action_name, **kwargs) def get_all_actions_metadata(self): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 8309a984..2af52521 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -3,11 +3,12 @@ import json import math import os import secrets +import tempfile import uuid +import zipfile from functools import wraps from typing import Optional, Tuple -import tempfile -import zipfile + from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId @@ -24,7 +25,10 @@ from flask_restx import fields, inputs, Namespace, Resource from pymongo import ReturnDocument from werkzeug.utils import secure_filename +from application.agents.tools.mcp_tool import MCPTool + from application.agents.tools.tool_manager import ToolManager +from application.api import api from application.api.user.tasks import ( ingest, @@ -34,17 +38,17 @@ from application.api.user.tasks import ( ) from application.core.mongo_db import MongoDB from application.core.settings import settings -from 
application.api import api +from application.security.encryption import encrypt_credentials, decrypt_credentials from application.storage.storage_creator import StorageCreator from application.tts.google_tts import GoogleTTS from application.utils import ( check_required_fields, generate_image_url, + num_tokens_from_string, safe_filename, validate_function_name, validate_required_fields, ) -from application.utils import num_tokens_from_string from application.vectorstore.vector_creator import VectorCreator storage = StorageCreator.get_storage() @@ -3435,31 +3439,6 @@ class CreateTool(Resource): param_details["value"] = "" transformed_actions.append(action) try: - # Process config to encrypt credentials for MCP tools - config = data["config"] - if data["name"] == "mcp_tool": - from application.security.encryption import encrypt_credentials - - # Extract credentials from config - credentials = {} - if config.get("auth_type") == "bearer": - credentials["bearer_token"] = config.get("bearer_token", "") - elif config.get("auth_type") == "api_key": - credentials["api_key"] = config.get("api_key", "") - credentials["api_key_header"] = config.get("api_key_header", "") - elif config.get("auth_type") == "basic": - credentials["username"] = config.get("username", "") - credentials["password"] = config.get("password", "") - - # Encrypt credentials if any exist - if credentials: - config["encrypted_credentials"] = encrypt_credentials( - credentials, user - ) - # Remove plaintext credentials from config - for key in credentials.keys(): - config.pop(key, None) - new_tool = { "user": user, "name": data["name"], @@ -3467,7 +3446,7 @@ class CreateTool(Resource): "description": data["description"], "customName": data.get("customName", ""), "actions": transformed_actions, - "config": config, + "config": data["config"], "status": data["status"], } resp = user_tools_collection.insert_one(new_tool) @@ -3534,41 +3513,7 @@ class UpdateTool(Resource): ), 400, ) - - # Handle MCP tool credential encryption - config = data["config"] - tool_name = data.get("name") - if not tool_name: - # Get the tool name from the database - existing_tool = user_tools_collection.find_one( - {"_id": ObjectId(data["id"]), "user": user} - ) - tool_name = existing_tool.get("name") if existing_tool else None - - if tool_name == "mcp_tool": - from application.security.encryption import encrypt_credentials - - # Extract credentials from config - credentials = {} - if config.get("auth_type") == "bearer": - credentials["bearer_token"] = config.get("bearer_token", "") - elif config.get("auth_type") == "api_key": - credentials["api_key"] = config.get("api_key", "") - credentials["api_key_header"] = config.get("api_key_header", "") - elif config.get("auth_type") == "basic": - credentials["username"] = config.get("username", "") - credentials["password"] = config.get("password", "") - - # Encrypt credentials if any exist - if credentials: - config["encrypted_credentials"] = encrypt_credentials( - credentials, user - ) - # Remove plaintext credentials from config - for key in credentials.keys(): - config.pop(key, None) - - update_data["config"] = config + update_data["config"] = data["config"] if "status" in data: update_data["status"] = data["status"] user_tools_collection.update_one( @@ -4142,74 +4087,55 @@ class DirectoryStructure(Resource): return make_response(jsonify({"success": False, "error": str(e)}), 500) -@user_ns.route("/api/mcp_servers") -class MCPServers(Resource): - @api.doc(description="Get all MCP servers configured by the user") - 
def get(self): +@user_ns.route("/api/mcp_server/test") +class TestMCPServerConfig(Resource): + @api.expect( + api.model( + "MCPServerTestModel", + { + "config": fields.Raw( + required=True, description="MCP server configuration to test" + ), + }, + ) + ) + @api.doc(description="Test MCP server connection with provided configuration") + def post(self): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user = decoded_token.get("sub") + data = request.get_json() + + required_fields = ["config"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields try: - # Find all MCP tools for this user - mcp_tools = user_tools_collection.find({"user": user, "name": "mcp_tool"}) + config = data["config"] - servers = [] - for tool in mcp_tools: - config = tool.get("config", {}) - servers.append( - { - "id": str(tool["_id"]), - "name": tool.get("displayName", "MCP Server"), - "server_url": config.get("server_url", ""), - "auth_type": config.get("auth_type", "none"), - "status": tool.get("status", False), - "created_at": ( - tool.get("_id").generation_time.isoformat() - if tool.get("_id") - else None - ), - } - ) + auth_credentials = {} + auth_type = config.get("auth_type", "none") - return make_response(jsonify({"success": True, "servers": servers}), 200) + if auth_type == "api_key" and "api_key" in config: + auth_credentials["api_key"] = config["api_key"] + if "api_key_header" in config: + auth_credentials["api_key_header"] = config["api_key_header"] + elif auth_type == "bearer" and "bearer_token" in config: + auth_credentials["bearer_token"] = config["bearer_token"] + elif auth_type == "basic": + if "username" in config: + auth_credentials["username"] = config["username"] + if "password" in config: + auth_credentials["password"] = config["password"] - except Exception as e: - current_app.logger.error( - f"Error retrieving MCP servers: {e}", exc_info=True - ) - return make_response(jsonify({"success": False, "error": str(e)}), 500) + test_config = config.copy() + test_config["auth_credentials"] = auth_credentials - -@user_ns.route("/api/mcp_server//test") -class TestMCPServer(Resource): - @api.doc(description="Test connection to an MCP server") - def post(self, server_id): - decoded_token = request.decoded_token - if not decoded_token: - return make_response(jsonify({"success": False}), 401) - - user = decoded_token.get("sub") - try: - # Find the MCP tool - mcp_tool_doc = user_tools_collection.find_one( - {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"} - ) - - if not mcp_tool_doc: - return make_response( - jsonify({"success": False, "error": "MCP server not found"}), 404 - ) - - # Load the tool and test connection - from application.agents.tools.mcp_tool import MCPTool - - mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user) + mcp_tool = MCPTool(test_config, user) result = mcp_tool.test_connection() return make_response(jsonify(result), 200) - except Exception as e: current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True) return make_response( @@ -4220,38 +4146,86 @@ class TestMCPServer(Resource): ) -@user_ns.route("/api/mcp_server//tools") -class MCPServerTools(Resource): - @api.doc(description="Discover and get tools from an MCP server") - def get(self, server_id): +@user_ns.route("/api/mcp_server/save") +class MCPServerSave(Resource): + @api.expect( + api.model( + "MCPServerSaveModel", + { + "id": fields.String( + required=False, description="Tool ID 
for updates (optional)" + ), + "displayName": fields.String( + required=True, description="Display name for the MCP server" + ), + "config": fields.Raw( + required=True, description="MCP server configuration" + ), + "status": fields.Boolean( + required=False, default=True, description="Tool status" + ), + }, + ) + ) + @api.doc(description="Create or update MCP server with automatic tool discovery") + def post(self): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - user = decoded_token.get("sub") - try: - # Find the MCP tool - mcp_tool_doc = user_tools_collection.find_one( - {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"} - ) + data = request.get_json() - if not mcp_tool_doc: - return make_response( - jsonify({"success": False, "error": "MCP server not found"}), 404 + required_fields = ["displayName", "config"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + try: + config = data["config"] + + auth_credentials = {} + auth_type = config.get("auth_type", "none") + if auth_type == "api_key": + if "api_key" in config and config["api_key"]: + auth_credentials["api_key"] = config["api_key"] + if "api_key_header" in config: + auth_credentials["api_key_header"] = config["api_key_header"] + elif auth_type == "bearer": + if "bearer_token" in config and config["bearer_token"]: + auth_credentials["bearer_token"] = config["bearer_token"] + elif auth_type == "basic": + if "username" in config and config["username"]: + auth_credentials["username"] = config["username"] + if "password" in config and config["password"]: + auth_credentials["password"] = config["password"] + mcp_config = config.copy() + mcp_config["auth_credentials"] = auth_credentials + + if auth_type == "none" or auth_credentials: + mcp_tool = MCPTool(mcp_config, user) + mcp_tool.discover_tools() + actions_metadata = mcp_tool.get_actions_metadata() + else: + raise Exception( + "No valid credentials provided for the selected authentication type" ) - # Load the tool and discover tools - from application.agents.tools.mcp_tool import MCPTool + storage_config = config.copy() + if auth_credentials: + encrypted_credentials_string = encrypt_credentials( + auth_credentials, user + ) + storage_config["encrypted_credentials"] = encrypted_credentials_string - mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user) - tools = mcp_tool.discover_tools() - - # Get actions metadata and transform to match other tools format - actions_metadata = mcp_tool.get_actions_metadata() + for field in [ + "api_key", + "bearer_token", + "username", + "password", + "api_key_header", + ]: + storage_config.pop(field, None) transformed_actions = [] - for action in actions_metadata: - # Add active flag and transform parameters action["active"] = True if "parameters" in action: if "properties" in action["parameters"]: @@ -4261,77 +4235,53 @@ class MCPServerTools(Resource): param_details["filled_by_llm"] = True param_details["value"] = "" transformed_actions.append(action) + tool_data = { + "name": "mcp_tool", + "displayName": data["displayName"], + "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}", + "config": storage_config, + "actions": transformed_actions, + "status": data.get("status", True), + "user": user, + } - # Update the stored actions in the database - user_tools_collection.update_one( - {"_id": ObjectId(server_id)}, {"$set": {"actions": transformed_actions}} - ) - - return make_response( - 
jsonify( - {"success": True, "tools": tools, "actions": transformed_actions} - ), - 200, - ) - + tool_id = data.get("id") + if tool_id: + result = user_tools_collection.update_one( + {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}, + {"$set": {k: v for k, v in tool_data.items() if k != "user"}}, + ) + if result.matched_count == 0: + return make_response( + jsonify( + { + "success": False, + "error": "Tool not found or access denied", + } + ), + 404, + ) + response_data = { + "success": True, + "id": tool_id, + "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.", + "tools_count": len(transformed_actions), + } + else: + result = user_tools_collection.insert_one(tool_data) + tool_id = str(result.inserted_id) + response_data = { + "success": True, + "id": tool_id, + "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.", + "tools_count": len(transformed_actions), + } + return make_response(jsonify(response_data), 200) except Exception as e: - current_app.logger.error(f"Error discovering MCP tools: {e}", exc_info=True) + current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True) return make_response( jsonify( - {"success": False, "error": f"Tool discovery failed: {str(e)}"} - ), - 500, - ) - - -@user_ns.route("/api/mcp_server//tools/") -class MCPServerToolAction(Resource): - @api.expect( - api.model( - "MCPToolActionModel", - { - "parameters": fields.Raw( - required=False, description="Parameters for the tool action" - ) - }, - ) - ) - @api.doc(description="Execute a specific tool action on an MCP server") - def post(self, server_id, action_name): - decoded_token = request.decoded_token - if not decoded_token: - return make_response(jsonify({"success": False}), 401) - - user = decoded_token.get("sub") - data = request.get_json() or {} - parameters = data.get("parameters", {}) - - try: - # Find the MCP tool - mcp_tool_doc = user_tools_collection.find_one( - {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"} - ) - - if not mcp_tool_doc: - return make_response( - jsonify({"success": False, "error": "MCP server not found"}), 404 - ) - - # Load the tool and execute action - from application.agents.tools.mcp_tool import MCPTool - - mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user) - result = mcp_tool.execute_action(action_name, **parameters) - - return make_response(jsonify({"success": True, "result": result}), 200) - - except Exception as e: - current_app.logger.error( - f"Error executing MCP tool action: {e}", exc_info=True - ) - return make_response( - jsonify( - {"success": False, "error": f"Action execution failed: {str(e)}"} + {"success": False, "error": f"Failed to save MCP server: {str(e)}"} ), 500, ) diff --git a/application/core/settings.py b/application/core/settings.py index 7c25084e..a8c6bfa3 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -109,6 +109,9 @@ class Settings(BaseSettings): JWT_SECRET_KEY: str = "" + # Encryption settings + ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key" + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/requirements.txt b/application/requirements.txt index 3778d941..f922a2cb 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -2,6 +2,7 @@ anthropic==0.49.0 boto3==1.38.18 beautifulsoup4==4.13.4 celery==5.4.0 +cryptography==42.0.8 dataclasses-json==0.6.7 
docx2txt==0.8
duckduckgo-search==7.5.2
diff --git a/application/security/encryption.py b/application/security/encryption.py
index 5cc891f6..4cb3a4d5 100644
--- a/application/security/encryption.py
+++ b/application/security/encryption.py
@@ -1,97 +1,85 @@
-"""
-Simple encryption utility for securely storing sensitive credentials.
-Uses XOR encryption with a key derived from app secret and user ID.
-Note: This is basic obfuscation. For production, consider using cryptography library.
-"""
-
 import base64
-import hashlib
-import os
 import json
+import os
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
+from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+
+from application.core.settings import settings
 
 
-def _get_encryption_key(user_id: str) -> bytes:
-    """
-    Generate a consistent encryption key for a specific user.
-    Uses app secret + user ID to create a unique key per user.
-    """
-    # Get app secret from environment or use a default (in production, always use env)
-    app_secret = os.environ.get(
-        "APP_SECRET_KEY", "default-docsgpt-secret-key-change-in-production"
+def _derive_key(user_id: str, salt: bytes) -> bytes:
+    app_secret = settings.ENCRYPTION_SECRET_KEY
+
+    password = f"{app_secret}#{user_id}".encode()
+
+    kdf = PBKDF2HMAC(
+        algorithm=hashes.SHA256(),
+        length=32,
+        salt=salt,
+        iterations=100000,
+        backend=default_backend(),
     )
-    # Combine app secret with user ID for user-specific encryption
-    combined = f"{app_secret}#{user_id}"
-
-    # Create a 32-byte key
-    key_material = hashlib.sha256(combined.encode()).digest()
-
-    return key_material
-
-
-def _xor_encrypt_decrypt(data: bytes, key: bytes) -> bytes:
-    """Simple XOR encryption/decryption."""
-    result = bytearray()
-    for i, byte in enumerate(data):
-        result.append(byte ^ key[i % len(key)])
-    return bytes(result)
+    return kdf.derive(password)
 
 
 def encrypt_credentials(credentials: dict, user_id: str) -> str:
-    """
-    Encrypt credentials dictionary for secure storage.
-
-    Args:
-        credentials: Dictionary containing sensitive data
-        user_id: User ID for creating user-specific encryption key
-
-    Returns:
-        Base64 encoded encrypted string
-    """
     if not credentials:
         return ""
-
-    try:
-        key = _get_encryption_key(user_id)
+    try:
+        salt = os.urandom(16)
+        iv = os.urandom(16)
+        key = _derive_key(user_id, salt)
 
-        # Convert dict to JSON string and encrypt
         json_str = json.dumps(credentials)
-        encrypted_data = _xor_encrypt_decrypt(json_str.encode(), key)
 
-        # Return base64 encoded for storage
-        return base64.b64encode(encrypted_data).decode()
+        cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
+        encryptor = cipher.encryptor()
+        padded_data = _pad_data(json_str.encode())
+        encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
+
+        result = salt + iv + encrypted_data
+        return base64.b64encode(result).decode()
     except Exception as e:
-        # If encryption fails, store empty string (will require re-auth)
         print(f"Warning: Failed to encrypt credentials: {e}")
         return ""
 
 
 def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
-    """
-    Decrypt credentials from storage.
- - Args: - encrypted_data: Base64 encoded encrypted string - user_id: User ID for creating user-specific encryption key - - Returns: - Dictionary containing decrypted credentials - """ if not encrypted_data: return {} - try: - key = _get_encryption_key(user_id) + data = base64.b64decode(encrypted_data.encode()) - # Decode and decrypt - encrypted_bytes = base64.b64decode(encrypted_data.encode()) - decrypted_data = _xor_encrypt_decrypt(encrypted_bytes, key) + salt = data[:16] + iv = data[16:32] + encrypted_content = data[32:] + + key = _derive_key(user_id, salt) + + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + decryptor = cipher.decryptor() + + decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize() + decrypted_data = _unpad_data(decrypted_padded) - # Parse JSON back to dict return json.loads(decrypted_data.decode()) - except Exception as e: - # If decryption fails, return empty dict (will require re-auth) print(f"Warning: Failed to decrypt credentials: {e}") return {} + + +def _pad_data(data: bytes) -> bytes: + block_size = 16 + padding_len = block_size - (len(data) % block_size) + padding = bytes([padding_len]) * padding_len + return data + padding + + +def _unpad_data(data: bytes) -> bytes: + padding_len = data[-1] + return data[:-padding_len] diff --git a/frontend/public/toolIcons/tool_mcp_tool.svg b/frontend/public/toolIcons/tool_mcp_tool.svg index e69de29b..22c980e3 100644 --- a/frontend/public/toolIcons/tool_mcp_tool.svg +++ b/frontend/public/toolIcons/tool_mcp_tool.svg @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 81d19c87..62f8ba92 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -56,6 +56,8 @@ const endpoints = { DIRECTORY_STRUCTURE: (docId: string) => `/api/directory_structure?id=${docId}`, MANAGE_SOURCE_FILES: '/api/manage_source_files', + MCP_TEST_CONNECTION: '/api/mcp_server/test', + MCP_SAVE_SERVER: '/api/mcp_server/save', }, CONVERSATION: { ANSWER: '/api/answer', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index af5e4f22..3f69f719 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -89,7 +89,10 @@ const userService = { path?: string, search?: string, ): Promise => - apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), token), + apiClient.get( + endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), + token, + ), addChunk: (data: any, token: string | null): Promise => apiClient.post(endpoints.USER.ADD_CHUNK, data, token), deleteChunk: ( @@ -104,6 +107,10 @@ const userService = { apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token), manageSourceFiles: (data: FormData, token: string | null): Promise => apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token), + testMCPConnection: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token), + saveMCPServer: (data: any, token: string | null): Promise => + apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token), }; export default userService; diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 1b3067c9..52c53254 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -187,47 +187,24 @@ "regularTools": "Regular Tools", "mcpTools": "MCP Tools", "mcp": { - "title": "MCP (Model Context Protocol) 
Servers", - "description": "Connect to remote MCP servers to access their tools and capabilities. Only remote servers are supported.", "addServer": "Add MCP Server", "editServer": "Edit Server", - "deleteServer": "Delete Server", - "delete": "Delete", "serverName": "Server Name", "serverUrl": "Server URL", - "authType": "Authentication Type", - "apiKey": "API Key", "headerName": "Header Name", - "bearerToken": "Bearer Token", - "username": "Username", - "password": "Password", "timeout": "Timeout (seconds)", "testConnection": "Test Connection", "testing": "Testing...", "saving": "Saving...", "save": "Save", "cancel": "Cancel", - "backToServers": "← Back to Servers", - "availableTools": "Available Tools", - "refreshTools": "Refresh Tools", - "refreshing": "Refreshing...", - "serverDisabled": "Server is disabled. Enable it to view available tools.", - "noToolsFound": "No tools found on this server.", - "noServersFound": "No MCP servers configured.", - "addFirstServer": "Add your first MCP server to get started.", - "parameters": "Parameters", - "active": "Active", - "inactive": "Inactive", "noAuth": "No Authentication", - "toggleServer": "Toggle {{serverName}}", - "deleteWarning": "Are you sure you want to delete the MCP server \"{{serverName}}\"? This action cannot be undone.", "placeholders": { - "serverName": "My MCP Server", "serverUrl": "https://api.example.com", - "apiKey": "Enter your API key", - "bearerToken": "Enter your bearer token", - "username": "Enter username", - "password": "Enter password" + "apiKey": "Your secret API key", + "bearerToken": "Your secret token", + "username": "Your username", + "password": "Your password" }, "errors": { "nameRequired": "Server name is required", diff --git a/frontend/src/modals/MCPServerModal.tsx b/frontend/src/modals/MCPServerModal.tsx index 32710712..5e916210 100644 --- a/frontend/src/modals/MCPServerModal.tsx +++ b/frontend/src/modals/MCPServerModal.tsx @@ -2,8 +2,8 @@ import { useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useSelector } from 'react-redux'; -import apiClient from '../api/client'; import userService from '../api/services/userService'; +import Dropdown from '../components/Dropdown'; import Input from '../components/Input'; import Spinner from '../components/Spinner'; import { useOutsideAlerter } from '../hooks'; @@ -19,10 +19,10 @@ interface MCPServerModalProps { } const authTypes = [ - { value: 'none', label: 'No Authentication' }, - { value: 'api_key', label: 'API Key' }, - { value: 'bearer', label: 'Bearer Token' }, - { value: 'basic', label: 'Basic Authentication' }, + { label: 'No Authentication', value: 'none' }, + { label: 'API Key', value: 'api_key' }, + { label: 'Bearer Token', value: 'bearer' }, + // { label: 'Basic Authentication', value: 'basic' }, ]; export default function MCPServerModal({ @@ -36,7 +36,7 @@ export default function MCPServerModal({ const modalRef = useRef(null); const [formData, setFormData] = useState({ - name: server?.name || 'My MCP Server', + name: server?.displayName || 'My MCP Server', server_url: server?.server_url || '', auth_type: server?.auth_type || 'none', api_key: '', @@ -44,7 +44,7 @@ export default function MCPServerModal({ bearer_token: '', username: '', password: '', - timeout: 30, + timeout: server?.timeout || 30, }); const [loading, setLoading] = useState(false); @@ -79,15 +79,37 @@ export default function MCPServerModal({ }; const validateForm = () => { + const requiredFields: { [key: string]: boolean } = { + name: 
!formData.name.trim(), + server_url: !formData.server_url.trim(), + }; + + const authFieldChecks: { [key: string]: () => void } = { + api_key: () => { + if (!formData.api_key.trim()) + newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired'); + }, + bearer: () => { + if (!formData.bearer_token.trim()) + newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired'); + }, + basic: () => { + if (!formData.username.trim()) + newErrors.username = t('settings.tools.mcp.errors.usernameRequired'); + if (!formData.password.trim()) + newErrors.password = t('settings.tools.mcp.errors.passwordRequired'); + }, + }; + const newErrors: { [key: string]: string } = {}; + Object.entries(requiredFields).forEach(([field, isEmpty]) => { + if (isEmpty) + newErrors[field] = t( + `settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`, + ); + }); - if (!formData.name.trim()) { - newErrors.name = t('settings.tools.mcp.errors.nameRequired'); - } - - if (!formData.server_url.trim()) { - newErrors.server_url = t('settings.tools.mcp.errors.urlRequired'); - } else { + if (formData.server_url.trim()) { try { new URL(formData.server_url); } catch { @@ -95,22 +117,15 @@ export default function MCPServerModal({ } } - if (formData.auth_type === 'api_key' && !formData.api_key.trim()) { - newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired'); - } + const timeoutValue = formData.timeout === '' ? 30 : formData.timeout; + if ( + typeof timeoutValue === 'number' && + (timeoutValue < 1 || timeoutValue > 300) + ) + newErrors.timeout = 'Timeout must be between 1 and 300 seconds'; - if (formData.auth_type === 'bearer' && !formData.bearer_token.trim()) { - newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired'); - } - - if (formData.auth_type === 'basic') { - if (!formData.username.trim()) { - newErrors.username = t('settings.tools.mcp.errors.usernameRequired'); - } - if (!formData.password.trim()) { - newErrors.password = t('settings.tools.mcp.errors.passwordRequired'); - } - } + if (authFieldChecks[formData.auth_type]) + authFieldChecks[formData.auth_type](); setErrors(newErrors); return Object.keys(newErrors).length === 0; @@ -128,10 +143,9 @@ export default function MCPServerModal({ const config: any = { server_url: formData.server_url.trim(), auth_type: formData.auth_type, - timeout: formData.timeout, + timeout: formData.timeout === '' ? 
30 : formData.timeout, }; - // Add credentials directly to config for encryption if (formData.auth_type === 'api_key') { config.api_key = formData.api_key.trim(); config.api_key_header = formData.header_name.trim() || 'X-API-Key'; @@ -141,59 +155,19 @@ export default function MCPServerModal({ config.username = formData.username.trim(); config.password = formData.password.trim(); } - return config; }; const testConnection = async () => { if (!validateForm()) return; - setTesting(true); setTestResult(null); - try { - // Create a temporary tool to test const config = buildToolConfig(); - - const testData = { - name: 'mcp_tool', - displayName: formData.name, - description: 'MCP Server Connection', - config, - actions: [], - status: false, - }; - - const response = await userService.createTool(testData, token); + const response = await userService.testMCPConnection({ config }, token); const result = await response.json(); - if (response.ok && result.id) { - // Test the connection - try { - const testResponse = await apiClient.post( - `/api/mcp_server/${result.id}/test`, - {}, - token, - ); - const testData = await testResponse.json(); - setTestResult(testData); - - // Clean up the temporary tool - await userService.deleteTool({ id: result.id }, token); - } catch (error) { - setTestResult({ - success: false, - message: t('settings.tools.mcp.errors.testFailed'), - }); - // Clean up the temporary tool - await userService.deleteTool({ id: result.id }, token); - } - } else { - setTestResult({ - success: false, - message: t('settings.tools.mcp.errors.testFailed'), - }); - } + setTestResult(result); } catch (error) { setTestResult({ success: false, @@ -206,73 +180,32 @@ export default function MCPServerModal({ const handleSave = async () => { if (!validateForm()) return; - setLoading(true); - try { const config = buildToolConfig(); - - const toolData = { - name: 'mcp_tool', + const serverData = { displayName: formData.name, - description: `MCP Server: ${formData.server_url}`, config, - actions: [], // Will be populated after tool creation status: true, + ...(server?.id && { id: server.id }), }; - let toolId: string; + const response = await userService.saveMCPServer(serverData, token); + const result = await response.json(); - if (server) { - // Update existing server - await userService.updateTool({ id: server.id, ...toolData }, token); - toolId = server.id; + if (response.ok && result.success) { + setTestResult({ + success: true, + message: result.message, + }); + onServerSaved(); + setModalState('INACTIVE'); + resetForm(); } else { - // Create new server - const response = await userService.createTool(toolData, token); - const result = await response.json(); - toolId = result.id; + setErrors({ + general: result.error || t('settings.tools.mcp.errors.saveFailed'), + }); } - - // Now fetch the MCP tools and update the actions - try { - const toolsResponse = await apiClient.get( - `/api/mcp_server/${toolId}/tools`, - token, - ); - - if (toolsResponse.success && toolsResponse.actions) { - // Update the tool with discovered actions (already formatted by backend) - await userService.updateTool( - { - id: toolId, - ...toolData, - actions: toolsResponse.actions, - }, - token, - ); - - console.log( - `Successfully discovered and saved ${toolsResponse.actions.length} MCP tools`, - ); - - // Show success message with tool count - setTestResult({ - success: true, - message: `MCP server saved successfully! 
Discovered ${toolsResponse.actions.length} tools.`, - }); - } - } catch (error) { - console.warn( - 'Warning: Could not fetch MCP tools immediately after creation:', - error, - ); - // Don't fail the save operation if tool discovery fails - } - - onServerSaved(); - setModalState('INACTIVE'); - resetForm(); } catch (error) { console.error('Error saving MCP server:', error); setErrors({ general: t('settings.tools.mcp.errors.saveFailed') }); @@ -285,52 +218,52 @@ export default function MCPServerModal({ switch (formData.auth_type) { case 'api_key': return ( -
-
- +
+
handleInputChange('api_key', e.target.value)} placeholder={t('settings.tools.mcp.placeholders.apiKey')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" /> {errors.api_key && (

{errors.api_key}

)}
-
- +
handleInputChange('header_name', e.target.value) } - placeholder="X-API-Key" + placeholder={t('settings.tools.mcp.headerName')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" />
); case 'bearer': return ( -
- +
handleInputChange('bearer_token', e.target.value) } placeholder={t('settings.tools.mcp.placeholders.bearerToken')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" /> {errors.bearer_token && (

{errors.bearer_token}

@@ -339,32 +272,32 @@ export default function MCPServerModal({ ); case 'basic': return ( -
-
- +
+
handleInputChange('username', e.target.value)} - placeholder={t('settings.tools.mcp.placeholders.username')} + placeholder={t('settings.tools.mcp.username')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" /> {errors.username && (

{errors.username}

)}
-
- +
handleInputChange('password', e.target.value)} - placeholder={t('settings.tools.mcp.placeholders.password')} + placeholder={t('settings.tools.mcp.password')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" /> {errors.password && (

{errors.password}

@@ -394,17 +327,17 @@ export default function MCPServerModal({ : t('settings.tools.mcp.addServer')}
- -
-
+
+
handleInputChange('name', e.target.value)} borderVariant="thin" - placeholder={t('settings.tools.mcp.placeholders.serverName')} + placeholder={t('settings.tools.mcp.serverName')} labelBgClassName="bg-white dark:bg-charleston-green-2" /> {errors.name && ( @@ -413,17 +346,17 @@ export default function MCPServerModal({
- handleInputChange('server_url', e.target.value) } - placeholder={t('settings.tools.mcp.placeholders.serverUrl')} + placeholder={t('settings.tools.mcp.serverUrl')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" /> {errors.server_url && (

@@ -432,106 +365,114 @@ export default function MCPServerModal({ )}

-
- - -
+ type.value === formData.auth_type) + ?.label || null + } + onSelect={(selection: { label: string; value: string }) => { + handleInputChange('auth_type', selection.value); + }} + options={authTypes} + size="w-full" + rounded="3xl" + border="border" + /> {renderAuthFields()}
- - handleInputChange('timeout', parseInt(e.target.value) || 30) - } - placeholder="30" + onChange={(e) => { + const value = e.target.value; + if (value === '') { + handleInputChange('timeout', ''); + } else { + const numValue = parseInt(value); + if (!isNaN(numValue) && numValue >= 1) { + handleInputChange('timeout', numValue); + } + } + }} + placeholder={t('settings.tools.mcp.timeout')} + borderVariant="thin" + labelBgClassName="bg-white dark:bg-charleston-green-2" /> + {errors.timeout && ( +

{errors.timeout}

+ )}
{testResult && (
{testResult.message}
)} - {errors.general && ( -
+
{errors.general}
)}
-
- - -
+
+
- + +
+ + +
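With the modal delegating to dedicated endpoints, the old create-test-delete dance around a temporary tool disappears. The same flow is reachable outside the UI; a rough sketch, assuming a locally running backend and a valid JWT (base URL, token, and server details are illustrative; the payload shapes mirror the diff above):

import requests

BASE = "http://localhost:7091"  # assumed local DocsGPT backend
HEADERS = {"Authorization": "Bearer <jwt>"}  # placeholder token

config = {
    "server_url": "https://mcp.example.com",
    "auth_type": "api_key",
    "api_key": "secret-key",
    "api_key_header": "X-API-Key",
    "timeout": 30,
}

# 1. Probe the server without persisting anything.
test = requests.post(f"{BASE}/api/mcp_server/test", json={"config": config}, headers=HEADERS)

# 2. Persist it; including an existing "id" switches the endpoint into update mode.
if test.json().get("success"):
    saved = requests.post(
        f"{BASE}/api/mcp_server/save",
        json={"displayName": "My MCP Server", "config": config, "status": True},
        headers=HEADERS,
    )
    print(saved.json().get("message"))  # e.g. "MCP server created successfully! Discovered N tools."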
From 0e4196f036354eaedac70f9b73b7aeb1df58a18e Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 4 Sep 2025 15:19:58 +0530 Subject: [PATCH 06/13] fix: remove unused tool labels from localization file --- frontend/src/locale/en.json | 2 -- 1 file changed, 2 deletions(-) diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 52c53254..6230229d 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -184,8 +184,6 @@ "addNew": "Add New", "name": "Name", "type": "Type", - "regularTools": "Regular Tools", - "mcpTools": "MCP Tools", "mcp": { "addServer": "Add MCP Server", "editServer": "Edit Server", From 2f88890c9404af7e89d7c3c194f9e4ffae33a45b Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 8 Sep 2025 22:10:08 +0530 Subject: [PATCH 07/13] feat: add support for multiple sources in agent configuration and update related components --- .../api/answer/services/stream_processor.py | 82 ++- application/api/user/routes.py | 482 +++++++++++++----- application/retriever/classic_rag.py | 14 +- frontend/src/agents/NewAgent.tsx | 171 ++++++- frontend/src/agents/types/index.ts | 1 + 5 files changed, 592 insertions(+), 158 deletions(-) diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index dfcfcdd2..6f57c2fc 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -69,11 +69,8 @@ class StreamProcessor: self.decoded_token.get("sub") if self.decoded_token is not None else None ) self.conversation_id = self.data.get("conversation_id") - self.source = ( - {"active_docs": self.data["active_docs"]} - if "active_docs" in self.data - else {} - ) + self.source = {} + self.all_sources = [] self.attachments = [] self.history = [] self.agent_config = {} @@ -86,6 +83,7 @@ class StreamProcessor: def initialize(self): """Initialize all required components for processing""" self._configure_agent() + self._configure_source() self._configure_retriever() self._load_conversation_history() self._process_attachments() @@ -171,12 +169,77 @@ class StreamProcessor: source = data.get("source") if isinstance(source, DBRef): source_doc = self.db.dereference(source) - data["source"] = str(source_doc["_id"]) - data["retriever"] = source_doc.get("retriever", data.get("retriever")) + if source_doc: + data["source"] = str(source_doc["_id"]) + data["retriever"] = source_doc.get("retriever", data.get("retriever")) + data["chunks"] = source_doc.get("chunks", data.get("chunks")) + else: + data["source"] = None + elif source == "default": + data["source"] = "default" else: - data["source"] = {} + data["source"] = None + # Handle multiple sources + + sources = data.get("sources", []) + if sources and isinstance(sources, list): + sources_list = [] + for i, source_ref in enumerate(sources): + if source_ref == "default": + processed_source = { + "id": "default", + "retriever": "classic", + "chunks": data.get("chunks", "2"), + } + sources_list.append(processed_source) + elif isinstance(source_ref, DBRef): + source_doc = self.db.dereference(source_ref) + if source_doc: + processed_source = { + "id": str(source_doc["_id"]), + "retriever": source_doc.get("retriever", "classic"), + "chunks": source_doc.get("chunks", data.get("chunks", "2")), + } + sources_list.append(processed_source) + data["sources"] = sources_list + else: + data["sources"] = [] return data + def _configure_source(self): + """Configure the source based on agent data""" + api_key = 
self.data.get("api_key") or self.agent_key + + if api_key: + agent_data = self._get_data_from_api_key(api_key) + + if agent_data.get("sources") and len(agent_data["sources"]) > 0: + source_ids = [ + source["id"] for source in agent_data["sources"] if source.get("id") + ] + if source_ids: + self.source = {"active_docs": source_ids} + else: + self.source = {} + self.all_sources = agent_data["sources"] + elif agent_data.get("source"): + self.source = {"active_docs": agent_data["source"]} + self.all_sources = [ + { + "id": agent_data["source"], + "retriever": agent_data.get("retriever", "classic"), + } + ] + else: + self.source = {} + self.all_sources = [] + return + if "active_docs" in self.data: + self.source = {"active_docs": self.data["active_docs"]} + return + self.source = {} + self.all_sources = [] + def _configure_agent(self): """Configure the agent based on request data""" agent_id = self.data.get("agent_id") @@ -230,7 +293,8 @@ class StreamProcessor: "token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY), } - if "isNoneDoc" in self.data and self.data["isNoneDoc"]: + api_key = self.data.get("api_key") or self.agent_key + if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]: self.retriever_config["chunks"] = 0 def create_agent(self): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 9a2febbc..2e9bae81 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -492,9 +492,9 @@ class DeleteOldIndexes(Resource): ) if not doc: return make_response(jsonify({"status": "not found"}), 404) - + storage = StorageCreator.get_storage() - + try: # Delete vector index if settings.VECTOR_STORE == "faiss": @@ -508,7 +508,7 @@ class DeleteOldIndexes(Resource): settings.VECTOR_STORE, source_id=str(doc["_id"]) ) vectorstore.delete_index() - + if "file_path" in doc and doc["file_path"]: file_path = doc["file_path"] if storage.is_directory(file_path): @@ -517,7 +517,7 @@ class DeleteOldIndexes(Resource): storage.delete_file(f) else: storage.delete_file(file_path) - + except FileNotFoundError: pass except Exception as err: @@ -525,7 +525,7 @@ class DeleteOldIndexes(Resource): f"Error deleting files and indexes: {err}", exc_info=True ) return make_response(jsonify({"success": False}), 400) - + sources_collection.delete_one({"_id": ObjectId(source_id)}) return make_response(jsonify({"success": True}), 200) @@ -573,55 +573,75 @@ class UploadFile(Resource): try: storage = StorageCreator.get_storage() - - + for file in files: original_filename = file.filename safe_file = safe_filename(original_filename) - + with tempfile.TemporaryDirectory() as temp_dir: temp_file_path = os.path.join(temp_dir, safe_file) file.save(temp_file_path) - + if zipfile.is_zipfile(temp_file_path): try: - with zipfile.ZipFile(temp_file_path, 'r') as zip_ref: + with zipfile.ZipFile(temp_file_path, "r") as zip_ref: zip_ref.extractall(path=temp_dir) - + # Walk through extracted files and upload them for root, _, files in os.walk(temp_dir): for extracted_file in files: - if os.path.join(root, extracted_file) == temp_file_path: + if ( + os.path.join(root, extracted_file) + == temp_file_path + ): continue - - rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir) + + rel_path = os.path.relpath( + os.path.join(root, extracted_file), temp_dir + ) storage_path = f"{base_path}/{rel_path}" - - with open(os.path.join(root, extracted_file), 'rb') as f: + + with open( + os.path.join(root, extracted_file), "rb" + ) as f: storage.save_file(f, 
storage_path) except Exception as e: - current_app.logger.error(f"Error extracting zip: {e}", exc_info=True) + current_app.logger.error( + f"Error extracting zip: {e}", exc_info=True + ) # If zip extraction fails, save the original zip file file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) else: # For non-zip files, save directly file_path = f"{base_path}/{safe_file}" - with open(temp_file_path, 'rb') as f: + with open(temp_file_path, "rb") as f: storage.save_file(f, file_path) - + task = ingest.delay( settings.UPLOAD_FOLDER, [ - ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", - ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png", - ".jpg", ".jpeg", + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ".json", + ".xlsx", + ".pptx", + ".png", + ".jpg", + ".jpeg", ], job_name, user, file_path=base_path, - filename=dir_name + filename=dir_name, ) except Exception as err: current_app.logger.error(f"Error uploading file: {err}", exc_info=True) @@ -635,12 +655,29 @@ class ManageSourceFiles(Resource): api.model( "ManageSourceFilesModel", { - "source_id": fields.String(required=True, description="Source ID to modify"), - "operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"), - "file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"), - "directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"), - "file": fields.Raw(required=False, description="Files to add (for add operation)"), - "parent_dir": fields.String(required=False, description="Parent directory path relative to source root"), + "source_id": fields.String( + required=True, description="Source ID to modify" + ), + "operation": fields.String( + required=True, + description="Operation: 'add', 'remove', or 'remove_directory'", + ), + "file_paths": fields.List( + fields.String, + required=False, + description="File paths to remove (for remove operation)", + ), + "directory_path": fields.String( + required=False, + description="Directory path to remove (for remove_directory operation)", + ), + "file": fields.Raw( + required=False, description="Files to add (for add operation)" + ), + "parent_dir": fields.String( + required=False, + description="Parent directory path relative to source root", + ), }, ) ) @@ -650,7 +687,9 @@ class ManageSourceFiles(Resource): def post(self): decoded_token = request.decoded_token if not decoded_token: - return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401) + return make_response( + jsonify({"success": False, "message": "Unauthorized"}), 401 + ) user = decoded_token.get("sub") source_id = request.form.get("source_id") @@ -658,12 +697,24 @@ class ManageSourceFiles(Resource): if not source_id or not operation: return make_response( - jsonify({"success": False, "message": "source_id and operation are required"}), 400 + jsonify( + { + "success": False, + "message": "source_id and operation are required", + } + ), + 400, ) if operation not in ["add", "remove", "remove_directory"]: return make_response( - jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400 + jsonify( + { + "success": False, + "message": "operation must be 'add', 'remove', or 'remove_directory'", + } + ), + 400, ) try: @@ -674,34 +725,53 @@ class 
ManageSourceFiles(Resource): ) try: - source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user}) + source = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": user} + ) if not source: return make_response( - jsonify({"success": False, "message": "Source not found or access denied"}), 404 + jsonify( + { + "success": False, + "message": "Source not found or access denied", + } + ), + 404, ) except Exception as err: current_app.logger.error(f"Error finding source: {err}", exc_info=True) - return make_response(jsonify({"success": False, "message": "Database error"}), 500) + return make_response( + jsonify({"success": False, "message": "Database error"}), 500 + ) try: storage = StorageCreator.get_storage() source_file_path = source.get("file_path", "") - parent_dir = request.form.get("parent_dir", "") - + parent_dir = request.form.get("parent_dir", "") + if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir): return make_response( - jsonify({"success": False, "message": "Invalid parent directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid parent directory path"} + ), + 400, ) if operation == "add": files = request.files.getlist("file") if not files or all(file.filename == "" for file in files): return make_response( - jsonify({"success": False, "message": "No files provided for add operation"}), 400 + jsonify( + { + "success": False, + "message": "No files provided for add operation", + } + ), + 400, ) added_files = [] - + target_dir = source_file_path if parent_dir: target_dir = f"{source_file_path}/{parent_dir}" @@ -720,26 +790,44 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Added {len(added_files)} files", - "added_files": added_files, - "parent_dir": parent_dir, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Added {len(added_files)} files", + "added_files": added_files, + "parent_dir": parent_dir, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove": file_paths_str = request.form.get("file_paths") if not file_paths_str: return make_response( - jsonify({"success": False, "message": "file_paths required for remove operation"}), 400 + jsonify( + { + "success": False, + "message": "file_paths required for remove operation", + } + ), + 400, ) try: - file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str + file_paths = ( + json.loads(file_paths_str) + if isinstance(file_paths_str, str) + else file_paths_str + ) except Exception: return make_response( - jsonify({"success": False, "message": "Invalid file_paths format"}), 400 + jsonify( + {"success": False, "message": "Invalid file_paths format"} + ), + 400, ) # Remove files from storage and directory structure @@ -757,18 +845,29 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Removed {len(removed_files)} files", - "removed_files": removed_files, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Removed {len(removed_files)} files", + "removed_files": removed_files, + "reingest_task_id": task.id, + } + ), + 200, + ) elif operation == "remove_directory": directory_path = request.form.get("directory_path") if not directory_path: return 
make_response( - jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400 + jsonify( + { + "success": False, + "message": "directory_path required for remove_directory operation", + } + ), + 400, ) # Validate directory path (prevent path traversal) @@ -778,10 +877,17 @@ class ManageSourceFiles(Resource): f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}" ) return make_response( - jsonify({"success": False, "message": "Invalid directory path"}), 400 + jsonify( + {"success": False, "message": "Invalid directory path"} + ), + 400, ) - full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path + full_directory_path = ( + f"{source_file_path}/{directory_path}" + if directory_path + else source_file_path + ) if not storage.is_directory(full_directory_path): current_app.logger.warning( @@ -790,7 +896,13 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404 + jsonify( + { + "success": False, + "message": "Directory not found or is not a directory", + } + ), + 404, ) success = storage.remove_directory(full_directory_path) @@ -802,7 +914,10 @@ class ManageSourceFiles(Resource): f"Full path: {full_directory_path}" ) return make_response( - jsonify({"success": False, "message": "Failed to remove directory"}), 500 + jsonify( + {"success": False, "message": "Failed to remove directory"} + ), + 500, ) current_app.logger.info( @@ -816,12 +931,17 @@ class ManageSourceFiles(Resource): task = reingest_source_task.delay(source_id=source_id, user=user) - return make_response(jsonify({ - "success": True, - "message": f"Successfully removed directory: {directory_path}", - "removed_directory": directory_path, - "reingest_task_id": task.id - }), 200) + return make_response( + jsonify( + { + "success": True, + "message": f"Successfully removed directory: {directory_path}", + "removed_directory": directory_path, + "reingest_task_id": task.id, + } + ), + 200, + ) except Exception as err: error_context = f"operation={operation}, user={user}, source_id={source_id}" @@ -835,8 +955,12 @@ class ManageSourceFiles(Resource): parent_dir = request.form.get("parent_dir", "") error_context += f", parent_dir={parent_dir}" - current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True) - return make_response(jsonify({"success": False, "message": "Operation failed"}), 500) + current_app.logger.error( + f"Error managing source files: {err} ({error_context})", exc_info=True + ) + return make_response( + jsonify({"success": False, "message": "Operation failed"}), 500 + ) @user_ns.route("/api/remote") @@ -984,7 +1108,7 @@ class PaginatedSources(Resource): "tokens": doc.get("tokens", ""), "retriever": doc.get("retriever", "classic"), "syncFrequency": doc.get("sync_frequency", ""), - "isNested": bool(doc.get("directory_structure")) + "isNested": bool(doc.get("directory_structure")), } paginated_docs.append(doc_data) response = { @@ -1032,7 +1156,7 @@ class CombinedJson(Resource): "tokens": index.get("tokens", ""), "retriever": index.get("retriever", "classic"), "syncFrequency": index.get("sync_frequency", ""), - "is_nested": bool(index.get("directory_structure")) + "is_nested": bool(index.get("directory_structure")), } ) except Exception as err: @@ -1272,6 +1396,16 @@ class GetAgent(Resource): and (source_doc := db.dereference(agent.get("source"))) else "" ), + 
"sources": [ + ( + str(db.dereference(source_ref)["_id"]) + if isinstance(source_ref, DBRef) and db.dereference(source_ref) + else source_ref + ) + for source_ref in agent.get("sources", []) + if (isinstance(source_ref, DBRef) and db.dereference(source_ref)) + or source_ref == "default" + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1325,8 +1459,24 @@ class GetAgents(Resource): str(source_doc["_id"]) if isinstance(agent.get("source"), DBRef) and (source_doc := db.dereference(agent.get("source"))) - else "" + else ( + agent.get("source", "") + if agent.get("source") == "default" + else "" + ) ), + "sources": [ + ( + source_ref + if source_ref == "default" + else str(db.dereference(source_ref)["_id"]) + ) + for source_ref in agent.get("sources", []) + if source_ref == "default" + or ( + isinstance(source_ref, DBRef) and db.dereference(source_ref) + ) + ], "chunks": agent["chunks"], "retriever": agent.get("retriever", ""), "prompt_id": agent.get("prompt_id", ""), @@ -1351,6 +1501,7 @@ class GetAgents(Resource): for agent in agents if "source" in agent or "retriever" in agent ] + except Exception as err: current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True) return make_response(jsonify({"success": False}), 400) @@ -1369,7 +1520,14 @@ class CreateAgent(Resource): "image": fields.Raw( required=False, description="Image file upload", type="file" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1381,7 +1539,8 @@ class CreateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1401,13 +1560,18 @@ class CreateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) except json.JSONDecodeError: data["json_schema"] = None print(f"Received data: {data}") - + # Validate JSON schema if provided if data.get("json_schema"): try: @@ -1415,20 +1579,32 @@ class CreateAgent(Resource): json_schema = data.get("json_schema") if not isinstance(json_schema, dict): return make_response( - jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must be a valid JSON object", + } + ), + 400, ) - + # Validate that it has either a 'schema' property or is itself a schema if "schema" not in json_schema and "type" not in json_schema: return make_response( - jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}), - 400 + jsonify( + { + "success": False, + "message": "JSON schema must contain either a 
'schema' property or be a valid JSON schema with 'type' property", + } + ), + 400, ) except Exception as e: return make_response( - jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}), - 400 + jsonify( + {"success": False, "message": f"Invalid JSON schema: {str(e)}"} + ), + 400, ) if data.get("status") not in ["draft", "published"]: @@ -1446,12 +1622,22 @@ class CreateAgent(Resource): required_fields = [ "name", "description", - "source", "chunks", "retriever", "prompt_id", "agent_type", ] + # Require either source or sources (but not both) + if not data.get("source") and not data.get("sources"): + return make_response( + jsonify( + { + "success": False, + "message": "Either 'source' or 'sources' field is required for published agents", + } + ), + 400, + ) validate_fields = ["name", "description", "prompt_id", "agent_type"] else: required_fields = ["name"] @@ -1471,16 +1657,31 @@ class CreateAgent(Resource): try: key = str(uuid.uuid4()) if data.get("status") == "published" else "" + + sources_list = [] + if data.get("sources") and len(data.get("sources", [])) > 0: + for source_id in data.get("sources", []): + if source_id == "default": + sources_list.append("default") + elif ObjectId.is_valid(source_id): + sources_list.append(DBRef("sources", ObjectId(source_id))) + source_field = "" + else: + source_value = data.get("source", "") + if source_value == "default": + source_field = "default" + elif ObjectId.is_valid(source_value): + source_field = DBRef("sources", ObjectId(source_value)) + else: + source_field = "" + new_agent = { "user": user, "name": data.get("name"), "description": data.get("description", ""), "image": image_url, - "source": ( - DBRef("sources", ObjectId(data.get("source"))) - if ObjectId.is_valid(data.get("source")) - else "" - ), + "source": source_field, + "sources": sources_list, "chunks": data.get("chunks", ""), "retriever": data.get("retriever", ""), "prompt_id": data.get("prompt_id", ""), @@ -1495,7 +1696,11 @@ class CreateAgent(Resource): } if new_agent["chunks"] == "": new_agent["chunks"] = "0" - if new_agent["source"] == "" and new_agent["retriever"] == "": + if ( + new_agent["source"] == "" + and new_agent["retriever"] == "" + and not new_agent["sources"] + ): new_agent["retriever"] = "classic" resp = agents_collection.insert_one(new_agent) new_id = str(resp.inserted_id) @@ -1517,7 +1722,14 @@ class UpdateAgent(Resource): "image": fields.String( required=False, description="New image URL or identifier" ), - "source": fields.String(required=True, description="Source ID"), + "source": fields.String( + required=False, description="Source ID (legacy single source)" + ), + "sources": fields.List( + fields.String, + required=False, + description="List of source identifiers for multiple sources", + ), "chunks": fields.Integer(required=True, description="Chunks count"), "retriever": fields.String(required=True, description="Retriever ID"), "prompt_id": fields.String(required=True, description="Prompt ID"), @@ -1529,7 +1741,8 @@ class UpdateAgent(Resource): required=True, description="Status of the agent (draft or published)" ), "json_schema": fields.Raw( - required=False, description="JSON schema for enforcing structured output format" + required=False, + description="JSON schema for enforcing structured output format", ), }, ) @@ -1549,6 +1762,11 @@ class UpdateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "sources" in data: + try: + data["sources"] = json.loads(data["sources"]) + 
except json.JSONDecodeError: + data["sources"] = [] if "json_schema" in data: try: data["json_schema"] = json.loads(data["json_schema"]) @@ -1593,6 +1811,7 @@ class UpdateAgent(Resource): "description", "image", "source", + "sources", "chunks", "retriever", "prompt_id", @@ -1616,7 +1835,10 @@ class UpdateAgent(Resource): update_fields[field] = new_status elif field == "source": source_id = data.get("source") - if source_id and ObjectId.is_valid(source_id): + if source_id == "default": + # Handle special "default" source + update_fields[field] = "default" + elif source_id and ObjectId.is_valid(source_id): update_fields[field] = DBRef("sources", ObjectId(source_id)) elif source_id: return make_response( @@ -1630,6 +1852,30 @@ class UpdateAgent(Resource): ) else: update_fields[field] = "" + elif field == "sources": + sources_list = data.get("sources", []) + if sources_list and isinstance(sources_list, list): + valid_sources = [] + for source_id in sources_list: + if source_id == "default": + valid_sources.append("default") + elif ObjectId.is_valid(source_id): + valid_sources.append( + DBRef("sources", ObjectId(source_id)) + ) + else: + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid source ID format: {source_id}", + } + ), + 400, + ) + update_fields[field] = valid_sources + else: + update_fields[field] = [] elif field == "chunks": chunks_value = data.get("chunks") if chunks_value == "": @@ -3532,7 +3778,7 @@ class GetChunks(Resource): "page": "Page number for pagination", "per_page": "Number of chunks per page", "path": "Optional: Filter chunks by relative file path", - "search": "Optional: Search term to filter chunks by title or content" + "search": "Optional: Search term to filter chunks by title or content", }, ) def get(self): @@ -3556,7 +3802,7 @@ class GetChunks(Resource): try: store = get_vector_store(doc_id) chunks = store.get_chunks() - + filtered_chunks = [] for chunk in chunks: metadata = chunk.get("metadata", {}) @@ -3577,9 +3823,9 @@ class GetChunks(Resource): continue filtered_chunks.append(chunk) - + chunks = filtered_chunks - + total_chunks = len(chunks) start = (page - 1) * per_page end = start + per_page @@ -3593,7 +3839,7 @@ class GetChunks(Resource): "total": total_chunks, "chunks": paginated_chunks, "path": path if path else None, - "search": search_term if search_term else None + "search": search_term if search_term else None, } ), 200, @@ -3602,6 +3848,7 @@ class GetChunks(Resource): current_app.logger.error(f"Error getting chunks: {e}", exc_info=True) return make_response(jsonify({"success": False}), 500) + @user_ns.route("/api/add_chunk") class AddChunk(Resource): @api.expect( @@ -3768,7 +4015,9 @@ class UpdateChunk(Resource): deleted = store.delete_chunk(chunk_id) if not deleted: - current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created") + current_app.logger.warning( + f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created" + ) return make_response( jsonify( @@ -3900,39 +4149,38 @@ class DirectoryStructure(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - + user = decoded_token.get("sub") doc_id = request.args.get("id") - + if not doc_id: - return make_response( - jsonify({"error": "Document ID is required"}), 400 - ) - + return make_response(jsonify({"error": "Document ID is required"}), 400) + if not ObjectId.is_valid(doc_id): return make_response(jsonify({"error": "Invalid 
document ID"}), 400) - + try: doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user}) if not doc: return make_response( jsonify({"error": "Document not found or access denied"}), 404 ) - + directory_structure = doc.get("directory_structure", {}) - + return make_response( - jsonify({ - "success": True, - "directory_structure": directory_structure, - "base_path": doc.get("file_path", "") - }), 200 + jsonify( + { + "success": True, + "directory_structure": directory_structure, + "base_path": doc.get("file_path", ""), + } + ), + 200, ) - + except Exception as e: current_app.logger.error( f"Error retrieving directory structure: {e}", exc_info=True ) - return make_response( - jsonify({"success": False, "error": str(e)}), 500 - ) + return make_response(jsonify({"success": False, "error": str(e)}), 500) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 82423bb5..ce1b937b 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -46,7 +46,7 @@ class ClassicRAG(BaseRetriever): user_api_key=self.user_api_key, decoded_token=decoded_token, ) - + if "active_docs" in source and source["active_docs"] is not None: if isinstance(source["active_docs"], list): self.vectorstores = source["active_docs"] @@ -54,7 +54,6 @@ class ClassicRAG(BaseRetriever): self.vectorstores = [source["active_docs"]] else: self.vectorstores = [] - self.question = self._rephrase_query() self.decoded_token = decoded_token self._validate_vectorstore_config() @@ -64,7 +63,6 @@ class ClassicRAG(BaseRetriever): if not self.vectorstores: logging.warning("No vectorstores configured for retrieval") return - invalid_ids = [ vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip() ] @@ -84,12 +82,16 @@ class ClassicRAG(BaseRetriever): or not self.vectorstores ): return self.original_question - prompt = f"""Given the following conversation history: + {self.chat_history} + + Rephrase the following user question to be a standalone search query + that captures all relevant context from the conversation: + """ messages = [ @@ -109,7 +111,6 @@ class ClassicRAG(BaseRetriever): """Retrieve relevant documents from configured vectorstores""" if self.chunks == 0 or not self.vectorstores: return [] - all_docs = [] chunks_per_source = max(1, self.chunks // len(self.vectorstores)) @@ -128,7 +129,6 @@ class ClassicRAG(BaseRetriever): else: page_content = doc.get("text", doc.get("page_content", "")) metadata = doc.get("metadata", {}) - title = metadata.get( "title", metadata.get("post_title", page_content) ) @@ -136,7 +136,6 @@ class ClassicRAG(BaseRetriever): title = title.split("/")[-1] else: title = str(title).split("/")[-1] - all_docs.append( { "title": title, @@ -150,7 +149,6 @@ class ClassicRAG(BaseRetriever): exc_info=True, ) continue - return all_docs def search(self, query: str = ""): diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx index da8cef5d..f1fc5e50 100644 --- a/frontend/src/agents/NewAgent.tsx +++ b/frontend/src/agents/NewAgent.tsx @@ -45,6 +45,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { description: '', image: '', source: '', + sources: [], chunks: '', retriever: '', prompt_id: 'default', @@ -150,7 +151,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + 
if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const formData = new FormData(); formData.append('name', agent.name); formData.append('description', agent.description); - formData.append('source', agent.source); + + if (selectedSourceIds.size > 1) { + const sourcesArray = Array.from(selectedSourceIds) + .map((id) => { + const sourceDoc = sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ); + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) { + return 'default'; + } + return sourceDoc?.id || id; + }) + .filter(Boolean); + formData.append('sources', JSON.stringify(sourcesArray)); + formData.append('source', ''); + } else if (selectedSourceIds.size === 1) { + const singleSourceId = Array.from(selectedSourceIds)[0]; + const sourceDoc = sourceDocs?.find( + (source) => + source.id === singleSourceId || + source.retriever === singleSourceId || + source.name === singleSourceId, + ); + let finalSourceId; + if (sourceDoc?.name === 'Default' && !sourceDoc?.id) + finalSourceId = 'default'; + else finalSourceId = sourceDoc?.id || singleSourceId; + formData.append('source', String(finalSourceId)); + formData.append('sources', JSON.stringify([])); + } else { + formData.append('source', ''); + formData.append('sources', JSON.stringify([])); + } + formData.append('chunks', agent.chunks); formData.append('retriever', agent.retriever); formData.append('prompt_id', agent.prompt_id); @@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { throw new Error('Failed to fetch agent'); } const data = await response.json(); - if (data.source) setSelectedSourceIds(new Set([data.source])); - else if (data.retriever) + + if (data.sources && data.sources.length > 0) { + const mappedSources = data.sources.map((sourceId: string) => { + if (sourceId === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + return defaultSource?.retriever || 'classic'; + } + return sourceId; + }); + setSelectedSourceIds(new Set(mappedSources)); + } else if (data.source) { + if (data.source === 'default') { + const defaultSource = sourceDocs?.find( + (source) => source.name === 'Default', + ); + setSelectedSourceIds( + new 
Set([defaultSource?.retriever || 'classic']), + ); + } else { + setSelectedSourceIds(new Set([data.source])); + } + } else if (data.retriever) { setSelectedSourceIds(new Set([data.retriever])); + } + if (data.tools) setSelectedToolIds(new Set(data.tools)); if (data.status === 'draft') setEffectiveMode('draft'); if (data.json_schema) { @@ -311,25 +404,57 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { }, [agentId, mode, token]); useEffect(() => { - const selectedSource = Array.from(selectedSourceIds).map((id) => - sourceDocs?.find( - (source) => - source.id === id || source.retriever === id || source.name === id, - ), - ); - if (selectedSource[0]?.model === embeddingsName) { - if (selectedSource[0] && 'id' in selectedSource[0]) { + const selectedSources = Array.from(selectedSourceIds) + .map((id) => + sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ), + ) + .filter(Boolean); + + if (selectedSources.length > 0) { + // Handle multiple sources + if (selectedSources.length > 1) { + // Multiple sources selected - store in sources array + const sourceIds = selectedSources + .map((source) => source?.id) + .filter((id): id is string => Boolean(id)); setAgent((prev) => ({ ...prev, - source: selectedSource[0]?.id || 'default', + sources: sourceIds, + source: '', // Clear single source for multiple sources retriever: '', })); - } else - setAgent((prev) => ({ - ...prev, - source: '', - retriever: selectedSource[0]?.retriever || 'classic', - })); + } else { + // Single source selected - maintain backward compatibility + const selectedSource = selectedSources[0]; + if (selectedSource?.model === embeddingsName) { + if (selectedSource && 'id' in selectedSource) { + setAgent((prev) => ({ + ...prev, + source: selectedSource?.id || 'default', + sources: [], // Clear sources array for single source + retriever: '', + })); + } else { + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], // Clear sources array + retriever: selectedSource?.retriever || 'classic', + })); + } + } + } + } else { + // No sources selected + setAgent((prev) => ({ + ...prev, + source: '', + sources: [], + retriever: '', + })); } }, [selectedSourceIds]); @@ -510,7 +635,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { ) .filter(Boolean) .join(', ') - : 'Select source'} + : 'Select sources'} ) => { setSelectedSourceIds(newSelectedIds); - setIsSourcePopupOpen(false); }} - title="Select Source" + title="Select Sources" searchPlaceholder="Search sources..." - noOptionsMessage="No source available" - singleSelect={true} + noOptionsMessage="No sources available" />
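On the retrieval side both shapes collapse to the same thing: a list of vectorstore ids with the chunk budget divided across them, which is what _configure_source and the ClassicRAG changes above implement. A condensed sketch of that contract (the helper name is illustrative):

def resolve_active_docs(agent_data: dict) -> list:
    # New-style agents carry a "sources" list; "default" passes through as-is.
    ids = [s["id"] for s in agent_data.get("sources", []) if s.get("id")]
    if ids:
        return ids
    # Legacy agents carry a single "source" id.
    return [agent_data["source"]] if agent_data.get("source") else []

stores = resolve_active_docs({"sources": [{"id": "a"}, {"id": "b"}, {"id": "c"}]})
chunks_per_source = max(1, 6 // len(stores))  # chunks=6 over 3 sources -> 2 each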
diff --git a/frontend/src/agents/types/index.ts b/frontend/src/agents/types/index.ts index e841cb0a..442097a1 100644 --- a/frontend/src/agents/types/index.ts +++ b/frontend/src/agents/types/index.ts @@ -10,6 +10,7 @@ export type Agent = { description: string; image: string; source: string; + sources?: string[]; chunks: string; retriever: string; prompt_id: string; From 816f660be3349e8235c089bc7dfefb007873ac27 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 10 Sep 2025 13:08:14 +0530 Subject: [PATCH 08/13] fix: enhance MCPTool request handling and tool discovery logic --- application/agents/tools/mcp_tool.py | 58 +++++++++++++++++++++------- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index c9133b96..72a26482 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -156,15 +156,20 @@ class MCPTool(Tool): ) -> Dict: """Execute MCP request with optional retry on session failure.""" try: - headers = {"Content-Type": "application/json", "Accept": "application/json"} - headers.update(self._session.headers) + final_headers = self._session.headers.copy() + final_headers.update( + { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + ) if self._mcp_session_id: - headers["Mcp-Session-Id"] = self._mcp_session_id + final_headers["Mcp-Session-Id"] = self._mcp_session_id response = self._session.post( self.server_url.rstrip("/"), json=mcp_message, - headers=headers, + headers=final_headers, timeout=self.timeout, ) @@ -175,10 +180,26 @@ class MCPTool(Tool): if method.startswith("notifications/"): return {} - try: - result = response.json() - except json.JSONDecodeError: - raise Exception(f"Invalid JSON response: {response.text}") + response_text = response.text.strip() + if response_text.startswith("event:") and "data:" in response_text: + lines = response_text.split("\n") + data_line = None + for line in lines: + if line.startswith("data:"): + data_line = line[5:].strip() + break + if data_line: + try: + result = json.loads(data_line) + except json.JSONDecodeError: + raise Exception(f"Invalid JSON in SSE data: {data_line}") + else: + raise Exception(f"No data found in SSE response: {response_text}") + else: + try: + result = response.json() + except json.JSONDecodeError: + raise Exception(f"Invalid JSON response: {response.text}") if "error" in result: error_msg = result["error"] if isinstance(error_msg, dict): @@ -228,15 +249,24 @@ class MCPTool(Tool): response = self._make_mcp_request("tools/list") - if isinstance(response, dict) and "tools" in response: - self.available_tools = response["tools"] - return self.available_tools + # Handle both formats: response with 'tools' key or response that IS the tools list + + if isinstance(response, dict): + if "tools" in response: + self.available_tools = response["tools"] + elif ( + "result" in response + and isinstance(response["result"], dict) + and "tools" in response["result"] + ): + self.available_tools = response["result"]["tools"] + else: + self.available_tools = [response] if response else [] elif isinstance(response, list): self.available_tools = response - return self.available_tools else: self.available_tools = [] - return self.available_tools + return self.available_tools except Exception as e: raise Exception(f"Failed to discover tools from MCP server: {str(e)}") @@ -312,6 +342,8 @@ class MCPTool(Tool): Dictionary with connection test results including tool count """ 
try: + self._mcp_session_id = None + init_result = self._initialize_mcp_connection() tools = self.discover_tools() From b052e3280532ca3e10ec69e05417d8be1d0c71a8 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 10 Sep 2025 14:15:51 +0530 Subject: [PATCH 09/13] feat: enhance MCP tool configuration handling and authentication logic --- application/api/user/routes.py | 56 +++++++++++++++++- frontend/src/settings/ToolConfig.tsx | 88 ++++++++++++++++++++++++---- 2 files changed, 133 insertions(+), 11 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 2af52521..b0554461 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -3513,7 +3513,60 @@ class UpdateTool(Resource): ), 400, ) - update_data["config"] = data["config"] + tool_doc = user_tools_collection.find_one( + {"_id": ObjectId(data["id"]), "user": user} + ) + if tool_doc and tool_doc.get("name") == "mcp_tool": + config = data["config"] + existing_config = tool_doc.get("config", {}) + storage_config = existing_config.copy() + + storage_config.update(config) + existing_credentials = {} + if "encrypted_credentials" in existing_config: + existing_credentials = decrypt_credentials( + existing_config["encrypted_credentials"], user + ) + auth_credentials = existing_credentials.copy() + auth_type = storage_config.get("auth_type", "none") + if auth_type == "api_key": + if "api_key" in config and config["api_key"]: + auth_credentials["api_key"] = config["api_key"] + if "api_key_header" in config: + auth_credentials["api_key_header"] = config[ + "api_key_header" + ] + elif auth_type == "bearer": + if "bearer_token" in config and config["bearer_token"]: + auth_credentials["bearer_token"] = config["bearer_token"] + elif "encrypted_token" in config and config["encrypted_token"]: + auth_credentials["bearer_token"] = config["encrypted_token"] + elif auth_type == "basic": + if "username" in config and config["username"]: + auth_credentials["username"] = config["username"] + if "password" in config and config["password"]: + auth_credentials["password"] = config["password"] + if auth_type != "none" and auth_credentials: + encrypted_credentials_string = encrypt_credentials( + auth_credentials, user + ) + storage_config["encrypted_credentials"] = ( + encrypted_credentials_string + ) + elif auth_type == "none": + storage_config.pop("encrypted_credentials", None) + for field in [ + "api_key", + "bearer_token", + "encrypted_token", + "username", + "password", + "api_key_header", + ]: + storage_config.pop(field, None) + update_data["config"] = storage_config + else: + update_data["config"] = data["config"] if "status" in data: update_data["status"] = data["status"] user_tools_collection.update_one( @@ -4238,6 +4291,7 @@ class MCPServerSave(Resource): tool_data = { "name": "mcp_tool", "displayName": data["displayName"], + "customName": data["displayName"], "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}", "config": storage_config, "actions": transformed_actions, diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx index 61a1d850..bca5c6ce 100644 --- a/frontend/src/settings/ToolConfig.tsx +++ b/frontend/src/settings/ToolConfig.tsx @@ -30,9 +30,22 @@ export default function ToolConfig({ handleGoBack: () => void; }) { const token = useSelector(selectToken); - const [authKey, setAuthKey] = React.useState( - 'token' in tool.config ? 
tool.config.token : '', - ); + const [authKey, setAuthKey] = React.useState(() => { + if (tool.name === 'mcp_tool') { + const config = tool.config as any; + if (config.auth_type === 'api_key') { + return config.api_key || ''; + } else if (config.auth_type === 'bearer') { + return config.encrypted_token || ''; + } else if (config.auth_type === 'basic') { + return config.password || ''; + } + return ''; + } else if ('token' in tool.config) { + return tool.config.token; + } + return ''; + }); const [customName, setCustomName] = React.useState( tool.customName || '', ); @@ -97,6 +110,26 @@ export default function ToolConfig({ }; const handleSaveChanges = () => { + let configToSave; + if (tool.name === 'api_tool') { + configToSave = tool.config; + } else if (tool.name === 'mcp_tool') { + configToSave = { ...tool.config } as any; + const mcpConfig = tool.config as any; + + if (authKey.trim()) { + if (mcpConfig.auth_type === 'api_key') { + configToSave.api_key = authKey; + } else if (mcpConfig.auth_type === 'bearer') { + configToSave.encrypted_token = authKey; + } else if (mcpConfig.auth_type === 'basic') { + configToSave.password = authKey; + } + } + } else { + configToSave = { token: authKey }; + } + userService .updateTool( { @@ -105,7 +138,7 @@ export default function ToolConfig({ displayName: tool.displayName, customName: customName, description: tool.description, - config: tool.name === 'api_tool' ? tool.config : { token: authKey }, + config: configToSave, actions: 'actions' in tool ? tool.actions : [], status: tool.status, }, @@ -196,7 +229,15 @@ export default function ToolConfig({
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (

- {t('settings.tools.authentication')} + {tool.name === 'mcp_tool' + ? (tool.config as any)?.auth_type === 'bearer' + ? 'Bearer Token' + : (tool.config as any)?.auth_type === 'api_key' + ? 'API Key' + : (tool.config as any)?.auth_type === 'basic' + ? 'Password' + : t('settings.tools.authentication') + : t('settings.tools.authentication')}

)}
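The hard-coded labels above mirror how the UpdateTool route in this patch maps auth_type onto credential fields: bearer auth carries a token, api_key a key plus an optional header name, and basic a username/password pair. A minimal Python sketch of that mapping, for orientation only (the helper name and return shape are illustrative, not part of the patch):

    def credential_fields(auth_type: str) -> list[str]:
        # Which config fields the UpdateTool handler reads for each auth type.
        return {
            "api_key": ["api_key", "api_key_header"],
            "bearer": ["bearer_token", "encrypted_token"],
            "basic": ["username", "password"],
        }.get(auth_type, [])  # "none" and unknown types carry no credentials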
@@ -208,7 +249,17 @@ export default function ToolConfig({ value={authKey} onChange={(e) => setAuthKey(e.target.value)} borderVariant="thin" - placeholder={t('modals.configTool.apiKeyPlaceholder')} + placeholder={ + tool.name === 'mcp_tool' + ? (tool.config as any)?.auth_type === 'bearer' + ? 'Bearer Token' + : (tool.config as any)?.auth_type === 'api_key' + ? 'API Key' + : (tool.config as any)?.auth_type === 'basic' + ? 'Password' + : t('modals.configTool.apiKeyPlaceholder') + : t('modals.configTool.apiKeyPlaceholder') + } />
)} @@ -450,6 +501,26 @@ export default function ToolConfig({ setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')} submitLabel={t('settings.tools.saveAndLeave')} handleSubmit={() => { + let configToSave; + if (tool.name === 'api_tool') { + configToSave = tool.config; + } else if (tool.name === 'mcp_tool') { + configToSave = { ...tool.config } as any; + const mcpConfig = tool.config as any; + + if (authKey.trim()) { + if (mcpConfig.auth_type === 'api_key') { + configToSave.api_key = authKey; + } else if (mcpConfig.auth_type === 'bearer') { + configToSave.encrypted_token = authKey; + } else if (mcpConfig.auth_type === 'basic') { + configToSave.password = authKey; + } + } + } else { + configToSave = { token: authKey }; + } + userService .updateTool( { @@ -458,10 +529,7 @@ export default function ToolConfig({ displayName: tool.displayName, customName: customName, description: tool.description, - config: - tool.name === 'api_tool' - ? tool.config - : { token: authKey }, + config: configToSave, actions: 'actions' in tool ? tool.actions : [], status: tool.status, }, From adcdce8d764ca4de31009af52487fe40581156fd Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 10 Sep 2025 22:10:11 +0530 Subject: [PATCH 10/13] fix: handle invalid chunks value in StreamProcessor and ClassicRAG --- .../api/answer/services/stream_processor.py | 16 ++++++++++++++-- application/retriever/classic_rag.py | 11 ++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index a04020cb..f6e639ef 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -266,7 +266,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 elif self.agent_key: data_key = self._get_data_from_api_key(self.agent_key) self.agent_config.update( @@ -287,7 +293,13 @@ class StreamProcessor: if data_key.get("retriever"): self.retriever_config["retriever_name"] = data_key["retriever"] if data_key.get("chunks") is not None: - self.retriever_config["chunks"] = data_key["chunks"] + try: + self.retriever_config["chunks"] = int(data_key["chunks"]) + except (ValueError, TypeError): + logger.warning( + f"Invalid chunks value: {data_key['chunks']}, using default value 2" + ) + self.retriever_config["chunks"] = 2 else: self.agent_config.update( { diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index ce1b937b..2ce863c2 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -25,7 +25,16 @@ class ClassicRAG(BaseRetriever): self.original_question = source.get("question", "") self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt - self.chunks = chunks + if isinstance(chunks, str): + try: + self.chunks = int(chunks) + except ValueError: + logging.warning( + f"Invalid chunks value '{chunks}', using default value 2" + ) + self.chunks = 2 + else: + self.chunks = chunks self.gpt_model = gpt_model self.token_limit = ( token_limit From 
188d118fc0c689870b0c89ff3d32d5721f24c757 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 10 Sep 2025 22:14:31 +0530 Subject: [PATCH 11/13] refactor: remove unused logging import from routes.py --- application/api/connector/routes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/application/api/connector/routes.py b/application/api/connector/routes.py index f203a703..1647aa78 100644 --- a/application/api/connector/routes.py +++ b/application/api/connector/routes.py @@ -1,6 +1,5 @@ import datetime import json -import logging from bson.objectid import ObjectId From 09b9576eef0a344a310d5052fe374199554418bb Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 11 Sep 2025 17:54:46 +0530 Subject: [PATCH 12/13] feat: enhance message and schema cleaning for Google AI integration --- application/llm/google_ai.py | 81 ++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 8 deletions(-) diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 91065b74..54567f6f 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM): raise def _clean_messages_google(self, messages): + """Convert OpenAI format messages to Google AI format.""" cleaned_messages = [] for message in messages: role = message.get("role") @@ -150,6 +151,15 @@ class GoogleLLM(BaseLLM): if role == "assistant": role = "model" + elif role == "system": + continue + elif role == "tool": + continue + elif role not in ["user", "model"]: + logging.warning( + f"GoogleLLM: Converting unsupported role '{role}' to 'user'" + ) + role = "user" parts = [] if role and content is not None: @@ -188,11 +198,63 @@ class GoogleLLM(BaseLLM): else: raise ValueError(f"Unexpected content type: {type(content)}") - cleaned_messages.append(types.Content(role=role, parts=parts)) + if parts: + cleaned_messages.append(types.Content(role=role, parts=parts)) return cleaned_messages + def _clean_schema(self, schema_obj): + """ + Recursively remove unsupported fields from schema objects + and validate required properties. 
+ """ + if not isinstance(schema_obj, dict): + return schema_obj + allowed_fields = { + "type", + "description", + "items", + "properties", + "required", + "enum", + "pattern", + "minimum", + "maximum", + "nullable", + "default", + } + + cleaned = {} + for key, value in schema_obj.items(): + if key not in allowed_fields: + continue + elif key == "type" and isinstance(value, str): + cleaned[key] = value.upper() + elif isinstance(value, dict): + cleaned[key] = self._clean_schema(value) + elif isinstance(value, list): + cleaned[key] = [self._clean_schema(item) for item in value] + else: + cleaned[key] = value + + # Validate that required properties actually exist in properties + if "required" in cleaned and "properties" in cleaned: + valid_required = [] + properties_keys = set(cleaned["properties"].keys()) + for required_prop in cleaned["required"]: + if required_prop in properties_keys: + valid_required.append(required_prop) + if valid_required: + cleaned["required"] = valid_required + else: + cleaned.pop("required", None) + elif "required" in cleaned and "properties" not in cleaned: + cleaned.pop("required", None) + + return cleaned + def _clean_tools_format(self, tools_list): + """Convert OpenAI format tools to Google AI format.""" genai_tools = [] for tool_data in tools_list: if tool_data["type"] == "function": @@ -201,18 +263,16 @@ class GoogleLLM(BaseLLM): properties = parameters.get("properties", {}) if properties: + cleaned_properties = {} + for k, v in properties.items(): + cleaned_properties[k] = self._clean_schema(v) + genai_function = dict( name=function["name"], description=function["description"], parameters={ "type": "OBJECT", - "properties": { - k: { - **v, - "type": v["type"].upper() if v["type"] else None, - } - for k, v in properties.items() - }, + "properties": cleaned_properties, "required": ( parameters["required"] if "required" in parameters @@ -242,6 +302,7 @@ class GoogleLLM(BaseLLM): response_schema=None, **kwargs, ): + """Generate content using Google AI API without streaming.""" client = genai.Client(api_key=self.api_key) if formatting == "openai": messages = self._clean_messages_google(messages) @@ -281,6 +342,7 @@ class GoogleLLM(BaseLLM): response_schema=None, **kwargs, ): + """Generate content using Google AI API with streaming.""" client = genai.Client(api_key=self.api_key) if formatting == "openai": messages = self._clean_messages_google(messages) @@ -331,12 +393,15 @@ class GoogleLLM(BaseLLM): yield chunk.text def _supports_tools(self): + """Return whether this LLM supports function calling.""" return True def _supports_structured_output(self): + """Return whether this LLM supports structured JSON output.""" return True def prepare_structured_output_format(self, json_schema): + """Convert JSON schema to Google AI structured output format.""" if not json_schema: return None From 641cf5a4c1d167956d4bcafbfcac98e1b9d4ac5f Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 11 Sep 2025 19:04:10 +0530 Subject: [PATCH 13/13] feat: skip empty fields in mcp tool call + improve error handling and response --- application/agents/base.py | 12 +++---- application/agents/tools/mcp_tool.py | 9 ++++- application/core/settings.py | 2 +- application/llm/google_ai.py | 9 +---- application/llm/handlers/base.py | 36 +++++++++++++------ application/llm/handlers/google.py | 14 ++++---- .../src/conversation/ConversationBubble.tsx | 21 +++++++++-- frontend/src/conversation/types/index.ts | 3 +- 8 files changed, 69 insertions(+), 37 deletions(-) diff --git 
a/application/agents/base.py b/application/agents/base.py index dff191a3..068b2a3c 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -142,28 +142,28 @@ class BaseAgent(ABC): tool_id, action_name, call_args = parser.parse_args(call) call_id = getattr(call, "id", None) or str(uuid.uuid4()) - + # Check if parsing failed if tool_id is None or action_name is None: error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}" logger.error(error_message) - + tool_call_data = { "tool_name": "unknown", "call_id": call_id, - "action_name": getattr(call, 'name', 'unknown'), + "action_name": getattr(call, "name", "unknown"), "arguments": call_args or {}, "result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}", } yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} self.tool_calls.append(tool_call_data) return f"Failed to parse tool call.", call_id - + # Check if tool_id exists in available tools if tool_id not in tools_dict: error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}" logger.error(error_message) - + # Return error result tool_call_data = { "tool_name": "unknown", @@ -175,7 +175,7 @@ class BaseAgent(ABC): yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} self.tool_calls.append(tool_call_data) return f"Tool with ID {tool_id} not found.", call_id - + tool_call_data = { "tool_name": tools_dict[tool_id]["name"], "call_id": call_id, diff --git a/application/agents/tools/mcp_tool.py b/application/agents/tools/mcp_tool.py index 72a26482..dc689367 100644 --- a/application/agents/tools/mcp_tool.py +++ b/application/agents/tools/mcp_tool.py @@ -283,7 +283,14 @@ class MCPTool(Tool): """ self._ensure_valid_session() - call_params = {"name": action_name, "arguments": kwargs} + # Skipping empty/None values - letting the server use defaults + + cleaned_kwargs = {} + for key, value in kwargs.items(): + if value == "" or value is None: + continue + cleaned_kwargs[key] = value + call_params = {"name": action_name, "arguments": cleaned_kwargs} try: result = self._make_mcp_request("tools/call", call_params) return result diff --git a/application/core/settings.py b/application/core/settings.py index a8c6bfa3..f1563569 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -26,7 +26,7 @@ class Settings(BaseSettings): "gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5, - "gemini-2.0-flash-exp": 1e6, + "gemini-2.5-flash": 1e6, } UPLOAD_FOLDER: str = "inputs" PARSE_PDF_AS_IMAGE: bool = False diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 54567f6f..b88e1d9f 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -151,15 +151,8 @@ class GoogleLLM(BaseLLM): if role == "assistant": role = "model" - elif role == "system": - continue elif role == "tool": - continue - elif role not in ["user", "model"]: - logging.warning( - f"GoogleLLM: Converting unsupported role '{role}' to 'user'" - ) - role = "user" + role = "model" parts = [] if role and content is not None: diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index 43205472..96ed4c00 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -205,7 +205,6 @@ class LLMHandler(ABC): except StopIteration as e: tool_response, call_id = e.value break - updated_messages.append( { 
"role": "assistant", @@ -222,17 +221,36 @@ class LLMHandler(ABC): ) updated_messages.append(self.create_tool_message(call, tool_response)) - except Exception as e: logger.error(f"Error executing tool: {str(e)}", exc_info=True) - updated_messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call.id, - } + error_call = ToolCall( + id=call.id, name=call.name, arguments=call.arguments ) + error_response = f"Error executing tool: {str(e)}" + error_message = self.create_tool_message(error_call, error_response) + updated_messages.append(error_message) + call_parts = call.name.split("_") + if len(call_parts) >= 2: + tool_id = call_parts[-1] # Last part is tool ID (e.g., "1") + action_name = "_".join(call_parts[:-1]) + tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool") + full_action_name = f"{action_name}_{tool_id}" + else: + tool_name = "unknown_tool" + action_name = call.name + full_action_name = call.name + yield { + "type": "tool_call", + "data": { + "tool_name": tool_name, + "call_id": call.id, + "action_name": full_action_name, + "arguments": call.arguments, + "error": error_response, + "status": "error", + }, + } return updated_messages def handle_non_streaming( @@ -263,13 +281,11 @@ class LLMHandler(ABC): except StopIteration as e: messages = e.value break - response = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) parsed = self.parse_response(response) self.llm_calls.append(build_stack_data(agent.llm)) - return parsed.content def handle_streaming( diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py index b43f2a16..7fa44cb6 100644 --- a/application/llm/handlers/google.py +++ b/application/llm/handlers/google.py @@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler): finish_reason="stop", raw_response=response, ) - if hasattr(response, "candidates"): parts = response.candidates[0].content.parts if response.candidates else [] tool_calls = [ @@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler): finish_reason="tool_calls" if tool_calls else "stop", raw_response=response, ) - else: tool_calls = [] if hasattr(response, "function_call"): @@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler): def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: """Create Google-style tool message.""" - from google.genai import types return { - "role": "tool", + "role": "model", "content": [ - types.Part.from_function_response( - name=tool_call.name, response={"result": result} - ).to_json_dict() + { + "function_response": { + "name": tool_call.name, + "response": {"result": result}, + } + } ], } diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 3be40df7..bbdf5e00 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -1,6 +1,6 @@ import 'katex/dist/katex.min.css'; -import { forwardRef, Fragment, useRef, useState, useEffect } from 'react'; +import { forwardRef, Fragment, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import ReactMarkdown from 'react-markdown'; import { useSelector } from 'react-redux'; @@ -12,12 +12,13 @@ import { import rehypeKatex from 'rehype-katex'; import remarkGfm from 'remark-gfm'; import remarkMath from 'remark-math'; -import DocumentationDark from '../assets/documentation-dark.svg'; + import ChevronDown from '../assets/chevron-down.svg'; import Cloud from 
'../assets/cloud.svg'; import DocsGPT3 from '../assets/cute_docsgpt3.svg'; import Dislike from '../assets/dislike.svg?react'; import Document from '../assets/document.svg'; +import DocumentationDark from '../assets/documentation-dark.svg'; import Edit from '../assets/edit.svg'; import Like from '../assets/like.svg?react'; import Link from '../assets/link.svg'; @@ -761,7 +762,11 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { Response {' '}

{toolCall.status === 'pending' && ( @@ -779,6 +784,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {

)} + {toolCall.status === 'error' && ( +

+ + {toolCall.error} + +

+ )}
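The error badge above renders the `tool_call` events that handlers/base.py emits in this patch when a tool raises. A sketch of the event shape the frontend now expects, with field names taken from that hunk and values that are purely hypothetical:

    # Hypothetical values, mirroring the fields built in handlers/base.py:
    error_event = {
        "type": "tool_call",
        "data": {
            "tool_name": "brave_search",      # from tools_dict, else "unknown_tool"
            "call_id": "call_123",
            "action_name": "web_search_1",    # action name plus the tool ID suffix
            "arguments": {"query": "docs"},
            "error": "Error executing tool: timeout",
            "status": "error",                # drives the 'error' branch above
        },
    }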
diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts
index 4ccb04a1..d962e4bc 100644
--- a/frontend/src/conversation/types/index.ts
+++ b/frontend/src/conversation/types/index.ts
@@ -4,5 +4,6 @@ export type ToolCallsType = {
   call_id: string;
   arguments: Record<string, any>;
   result?: Record<string, any>;
-  status?: 'pending' | 'completed';
+  error?: string;
+  status?: 'pending' | 'completed' | 'error';
 };
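Taken together with the UI change above, the argument cleaning added to MCPTool in this patch means empty optional fields never reach the server, while falsy-but-meaningful values do. A standalone sketch of the same filter (the function name is illustrative):

    def clean_tool_args(kwargs: dict) -> dict:
        # Skip empty-string and None values so the MCP server applies its own
        # defaults; keep 0 and False, which are valid arguments.
        return {k: v for k, v in kwargs.items() if v is not None and v != ""}

    assert clean_tool_args({"query": "docs", "cursor": "", "limit": None}) == {"query": "docs"}
    assert clean_tool_args({"count": 0, "flag": False}) == {"count": 0, "flag": False}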