From ba49eea23d6f25701781e5138c4c1ac79580f880 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 1 Oct 2025 13:56:31 +0530 Subject: [PATCH] Refactor agent creation and update logic to improve error handling and default values; enhance logging for better traceability --- application/api/user/routes.py | 337 +++++++++++++++++---------- application/retriever/classic_rag.py | 36 ++- application/vectorstore/base.py | 49 +++- frontend/src/agents/NewAgent.tsx | 58 ++++- frontend/src/components/Table.tsx | 74 +++--- 5 files changed, 373 insertions(+), 181 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 281664d3..4311e7f6 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1726,7 +1726,7 @@ class CreateAgent(Resource): "key": key, } if new_agent["chunks"] == "": - new_agent["chunks"] = "0" + new_agent["chunks"] = "2" if ( new_agent["source"] == "" and new_agent["retriever"] == "" @@ -1782,43 +1782,56 @@ class UpdateAgent(Resource): @api.doc(description="Update an existing agent") def put(self, agent_id): if not (decoded_token := request.decoded_token): - return {"success": False}, 401 + return make_response( + jsonify({"success": False, "message": "Unauthorized"}), 401 + ) user = decoded_token.get("sub") - if request.content_type == "application/json": - data = request.get_json() - else: - data = request.form.to_dict() - if "tools" in data: - try: - data["tools"] = json.loads(data["tools"]) - except json.JSONDecodeError: - data["tools"] = [] - if "sources" in data: - try: - data["sources"] = json.loads(data["sources"]) - except json.JSONDecodeError: - data["sources"] = [] - if "json_schema" in data: - try: - data["json_schema"] = json.loads(data["json_schema"]) - except json.JSONDecodeError: - data["json_schema"] = None + if not ObjectId.is_valid(agent_id): return make_response( jsonify({"success": False, "message": "Invalid agent ID format"}), 400 ) oid = ObjectId(agent_id) + try: + if request.content_type and "application/json" in request.content_type: + data = request.get_json() + else: + data = request.form.to_dict() + json_fields = ["tools", "sources", "json_schema"] + for field in json_fields: + if field in data and data[field]: + try: + data[field] = json.loads(data[field]) + except json.JSONDecodeError: + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid JSON format for field: {field}", + } + ), + 400, + ) + except Exception as err: + current_app.logger.error( + f"Error parsing request data: {err}", exc_info=True + ) + return make_response( + jsonify({"success": False, "message": "Invalid request data"}), 400 + ) + try: existing_agent = agents_collection.find_one({"_id": oid, "user": user}) except Exception as err: + current_app.logger.error( + f"Error finding agent {agent_id}: {err}", exc_info=True + ) return make_response( - current_app.logger.error( - f"Error finding agent {agent_id}: {err}", exc_info=True - ), jsonify({"success": False, "message": "Database error finding agent"}), 500, ) + if not existing_agent: return make_response( jsonify( @@ -1826,13 +1839,19 @@ class UpdateAgent(Resource): ), 404, ) + image_url, error = handle_image_upload( request, existing_agent.get("image", ""), user, storage ) if error: - return make_response( - jsonify({"success": False, "message": "Image upload failed"}), 400 + current_app.logger.error( + f"Image upload error for agent {agent_id}: {error}" ) + return make_response( + jsonify({"success": False, "message": f"Image upload failed: {error}"}), + 400, + ) + update_fields = {} allowed_fields = [ "name", @@ -1850,116 +1869,189 @@ class UpdateAgent(Resource): ] for field in allowed_fields: - if field in data: - if field == "status": - new_status = data.get("status") - if new_status not in ["draft", "published"]: - return make_response( - jsonify( - {"success": False, "message": "Invalid status value"} - ), - 400, - ) - update_fields[field] = new_status - elif field == "source": - source_id = data.get("source") - if source_id == "default": - # Handle special "default" source + if field not in data: + continue - update_fields[field] = "default" - elif source_id and ObjectId.is_valid(source_id): - update_fields[field] = DBRef("sources", ObjectId(source_id)) - elif source_id: - return make_response( - jsonify( - { - "success": False, - "message": "Invalid source ID format provided", - } - ), - 400, - ) - else: - update_fields[field] = "" - elif field == "sources": - sources_list = data.get("sources", []) - if sources_list and isinstance(sources_list, list): - valid_sources = [] - for source_id in sources_list: - if source_id == "default": - valid_sources.append("default") - elif ObjectId.is_valid(source_id): - valid_sources.append( - DBRef("sources", ObjectId(source_id)) - ) - else: - return make_response( - jsonify( - { - "success": False, - "message": f"Invalid source ID format: {source_id}", - } - ), - 400, - ) - update_fields[field] = valid_sources - else: - update_fields[field] = [] - elif field == "chunks": - chunks_value = data.get("chunks") - if chunks_value == "": - update_fields[field] = "0" - else: - try: - if int(chunks_value) < 0: - return make_response( - jsonify( - { - "success": False, - "message": "Chunks value must be a positive integer", - } - ), - 400, - ) - update_fields[field] = chunks_value - except ValueError: + if field == "status": + new_status = data.get("status") + if new_status not in ["draft", "published"]: + return make_response( + jsonify( + { + "success": False, + "message": "Invalid status value. Must be 'draft' or 'published'", + } + ), + 400, + ) + update_fields[field] = new_status + + elif field == "source": + source_id = data.get("source") + if source_id == "default": + update_fields[field] = "default" + elif source_id and ObjectId.is_valid(source_id): + update_fields[field] = DBRef("sources", ObjectId(source_id)) + elif source_id: + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid source ID format: {source_id}", + } + ), + 400, + ) + else: + update_fields[field] = "" + + elif field == "sources": + sources_list = data.get("sources", []) + if sources_list and isinstance(sources_list, list): + valid_sources = [] + for source_id in sources_list: + if source_id == "default": + valid_sources.append("default") + elif ObjectId.is_valid(source_id): + valid_sources.append(DBRef("sources", ObjectId(source_id))) + else: return make_response( jsonify( { "success": False, - "message": "Invalid chunks value provided", + "message": f"Invalid source ID in list: {source_id}", } ), 400, ) + update_fields[field] = valid_sources else: - update_fields[field] = data[field] + update_fields[field] = [] + + elif field == "chunks": + chunks_value = data.get("chunks") + if chunks_value == "" or chunks_value is None: + update_fields[field] = "2" + else: + try: + chunks_int = int(chunks_value) + if chunks_int < 0: + return make_response( + jsonify( + { + "success": False, + "message": "Chunks value must be a non-negative integer", + } + ), + 400, + ) + update_fields[field] = str(chunks_int) + except (ValueError, TypeError): + return make_response( + jsonify( + { + "success": False, + "message": f"Invalid chunks value: {chunks_value}", + } + ), + 400, + ) + + elif field == "tools": + tools_list = data.get("tools", []) + if isinstance(tools_list, list): + update_fields[field] = tools_list + else: + return make_response( + jsonify( + { + "success": False, + "message": "Tools must be a list", + } + ), + 400, + ) + + elif field == "json_schema": + json_schema = data.get("json_schema") + if json_schema is not None: + if not isinstance(json_schema, dict): + return make_response( + jsonify( + { + "success": False, + "message": "JSON schema must be a valid object", + } + ), + 400, + ) + update_fields[field] = json_schema + else: + update_fields[field] = None + + else: + value = data[field] + if field in ["name", "description", "prompt_id", "agent_type"]: + if not value or not str(value).strip(): + return make_response( + jsonify( + { + "success": False, + "message": f"Field '{field}' cannot be empty", + } + ), + 400, + ) + update_fields[field] = value + if image_url: update_fields["image"] = image_url + if not update_fields: return make_response( - jsonify({"success": False, "message": "No update data provided"}), 400 + jsonify( + { + "success": False, + "message": "No valid update data provided", + } + ), + 400, ) + newly_generated_key = None final_status = update_fields.get("status", existing_agent.get("status")) + if final_status == "published": - required_published_fields = [ - "name", - "description", - "source", - "chunks", - "retriever", - "prompt_id", - "agent_type", - ] + required_published_fields = { + "name": "Agent name", + "description": "Agent description", + "chunks": "Chunks count", + "prompt_id": "Prompt", + "agent_type": "Agent type", + } + missing_published_fields = [] - for req_field in required_published_fields: + for req_field, field_label in required_published_fields.items(): final_value = update_fields.get( req_field, existing_agent.get(req_field) ) - if req_field == "source" and final_value: - if not isinstance(final_value, DBRef): - missing_published_fields.append(req_field) + if not final_value: + missing_published_fields.append(field_label) + + source_val = update_fields.get("source", existing_agent.get("source")) + sources_val = update_fields.get( + "sources", existing_agent.get("sources", []) + ) + + has_valid_source = ( + isinstance(source_val, DBRef) + or source_val == "default" + or (isinstance(sources_val, list) and len(sources_val) > 0) + ) + + if not has_valid_source: + missing_published_fields.append("Source") + if missing_published_fields: return make_response( jsonify( @@ -1970,9 +2062,11 @@ class UpdateAgent(Resource): ), 400, ) + if not existing_agent.get("key"): newly_generated_key = str(uuid.uuid4()) update_fields["key"] = newly_generated_key + update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc) try: @@ -1985,20 +2079,22 @@ class UpdateAgent(Resource): jsonify( { "success": False, - "message": "Agent not found or update failed unexpectedly", + "message": "Agent not found or update failed", } ), 404, ) + if result.modified_count == 0 and result.matched_count == 1: return make_response( jsonify( { "success": True, - "message": "Agent found, but no changes were applied", + "message": "No changes detected", + "id": agent_id, } ), - 304, + 200, ) except Exception as err: current_app.logger.error( @@ -2008,6 +2104,7 @@ class UpdateAgent(Resource): jsonify({"success": False, "message": "Database error during update"}), 500, ) + response_data = { "success": True, "id": agent_id, @@ -2015,10 +2112,8 @@ class UpdateAgent(Resource): } if newly_generated_key: response_data["key"] = newly_generated_key - return make_response( - jsonify(response_data), - 200, - ) + + return make_response(jsonify(response_data), 200) @user_ns.route("/api/delete_agent") diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index f90a751c..d34b47d5 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -36,6 +36,11 @@ class ClassicRAG(BaseRetriever): self.chunks = 2 else: self.chunks = chunks + user_identifier = user_api_key if user_api_key else "default" + logging.info( + f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, " + f"sources={'active_docs' in source and source['active_docs'] is not None}" + ) self.gpt_model = gpt_model self.token_limit = ( token_limit @@ -92,17 +97,12 @@ class ClassicRAG(BaseRetriever): or not self.vectorstores ): return self.original_question - prompt = f"""Given the following conversation history: - - {self.chat_history} - - - - Rephrase the following user question to be a standalone search query - - that captures all relevant context from the conversation: - - """ + prompt = ( + "Given the following conversation history:\n" + f"{self.chat_history}\n\n" + "Rephrase the following user question to be a standalone search query " + "that captures all relevant context from the conversation:\n" + ) messages = [ {"role": "system", "content": prompt}, @@ -120,10 +120,20 @@ class ClassicRAG(BaseRetriever): def _get_data(self): """Retrieve relevant documents from configured vectorstores""" if self.chunks == 0 or not self.vectorstores: + logging.info( + f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, " + f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}" + ) return [] all_docs = [] chunks_per_source = max(1, self.chunks // len(self.vectorstores)) + logging.info( + f"ClassicRAG._get_data: Starting retrieval with chunks={self.chunks}, " + f"vectorstores={self.vectorstores}, chunks_per_source={chunks_per_source}, " + f"query='{self.question[:50]}...'" + ) + for vectorstore_id in self.vectorstores: if vectorstore_id: try: @@ -172,6 +182,10 @@ class ClassicRAG(BaseRetriever): exc_info=True, ) continue + logging.info( + f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents " + f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source})" + ) return all_docs def search(self, query: str = ""): diff --git a/application/vectorstore/base.py b/application/vectorstore/base.py index ea4885cd..84839059 100644 --- a/application/vectorstore/base.py +++ b/application/vectorstore/base.py @@ -1,3 +1,4 @@ +import logging import os from abc import ABC, abstractmethod @@ -9,13 +10,27 @@ from application.core.settings import settings class EmbeddingsWrapper: def __init__(self, model_name, *args, **kwargs): - self.model = SentenceTransformer( - model_name, - config_kwargs={"allow_dangerous_deserialization": True}, - *args, - **kwargs - ) - self.dimension = self.model.get_sentence_embedding_dimension() + logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}") + try: + kwargs.setdefault("trust_remote_code", True) + self.model = SentenceTransformer( + model_name, + config_kwargs={"allow_dangerous_deserialization": True}, + *args, + **kwargs, + ) + if self.model is None or self.model._first_module() is None: + raise ValueError( + f"SentenceTransformer model failed to load properly for: {model_name}" + ) + self.dimension = self.model.get_sentence_embedding_dimension() + logging.info(f"Successfully loaded model with dimension: {self.dimension}") + except Exception as e: + logging.error( + f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}", + exc_info=True, + ) + raise def embed_query(self, query: str): return self.model.encode(query).tolist() @@ -117,15 +132,29 @@ class BaseVectorStore(ABC): embeddings_name, openai_api_key=embeddings_key ) elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": - if os.path.exists("./models/all-mpnet-base-v2"): + possible_paths = [ + "/app/models/all-mpnet-base-v2", # Docker absolute path + "./models/all-mpnet-base-v2", # Relative path + ] + local_model_path = None + for path in possible_paths: + if os.path.exists(path): + local_model_path = path + logging.info(f"Found local model at path: {path}") + break + else: + logging.info(f"Path does not exist: {path}") + if local_model_path: embedding_instance = EmbeddingsSingleton.get_instance( - embeddings_name="./models/all-mpnet-base-v2", + local_model_path, ) else: + logging.warning( + f"Local model not found in any of the paths: {possible_paths}. Falling back to HuggingFace download." + ) embedding_instance = EmbeddingsSingleton.get_instance( embeddings_name, ) else: embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) - return embedding_instance diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx index 92e8b961..7313cdfc 100644 --- a/frontend/src/agents/NewAgent.tsx +++ b/frontend/src/agents/NewAgent.tsx @@ -46,11 +46,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { image: '', source: '', sources: [], - chunks: '', - retriever: '', + chunks: '2', + retriever: 'classic', prompt_id: 'default', tools: [], - agent_type: '', + agent_type: 'classic', status: '', json_schema: undefined, }); @@ -122,7 +122,8 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { agent.name && agent.description && agent.prompt_id && agent.agent_type; const isJsonSchemaValidOrEmpty = jsonSchemaText.trim() === '' || jsonSchemaValid; - return hasRequiredFields && isJsonSchemaValidOrEmpty; + const hasSource = selectedSourceIds.size > 0; + return hasRequiredFields && isJsonSchemaValidOrEmpty && hasSource; }; const isJsonSchemaInvalid = () => { @@ -353,6 +354,26 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { getPrompts(); }, [token]); + // Auto-select default source if none selected + useEffect(() => { + if (sourceDocs && sourceDocs.length > 0 && selectedSourceIds.size === 0) { + const defaultSource = sourceDocs.find((s) => s.name === 'Default'); + if (defaultSource) { + setSelectedSourceIds( + new Set([ + defaultSource.id || defaultSource.retriever || defaultSource.name, + ]), + ); + } else { + setSelectedSourceIds( + new Set([ + sourceDocs[0].id || sourceDocs[0].retriever || sourceDocs[0].name, + ]), + ); + } + } + }, [sourceDocs, selectedSourceIds.size]); + useEffect(() => { if ((mode === 'edit' || mode === 'draft') && agentId) { const getAgent = async () => { @@ -650,7 +671,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { } selectedIds={selectedSourceIds} onSelectionChange={(newSelectedIds: Set) => { - setSelectedSourceIds(newSelectedIds); + if ( + newSelectedIds.size === 0 && + sourceDocs && + sourceDocs.length > 0 + ) { + const defaultSource = sourceDocs.find( + (s) => s.name === 'Default', + ); + if (defaultSource) { + setSelectedSourceIds( + new Set([ + defaultSource.id || + defaultSource.retriever || + defaultSource.name, + ]), + ); + } else { + setSelectedSourceIds( + new Set([ + sourceDocs[0].id || + sourceDocs[0].retriever || + sourceDocs[0].name, + ]), + ); + } + } else { + setSelectedSourceIds(newSelectedIds); + } }} title="Select Sources" searchPlaceholder="Search sources..." diff --git a/frontend/src/components/Table.tsx b/frontend/src/components/Table.tsx index 9e60f622..5c8878ff 100644 --- a/frontend/src/components/Table.tsx +++ b/frontend/src/components/Table.tsx @@ -6,7 +6,6 @@ interface TableProps { minWidth?: string; } - interface TableContainerProps { children: React.ReactNode; className?: string; @@ -34,45 +33,44 @@ interface TableCellProps { align?: 'left' | 'right' | 'center'; } -const TableContainer = React.forwardRef(({ - children, - className = '', - height = 'auto', - bordered = true -}, ref) => { - return ( -
-
- {children} +const TableContainer = React.forwardRef( + ({ children, className = '', height = 'auto', bordered = true }, ref) => { + return ( +
+
+ {children} +
-
- ); -});; + ); + }, +); +TableContainer.displayName = 'TableContainer'; + const Table: React.FC = ({ children, className = '', - minWidth = 'min-w-[600px]' + minWidth = 'min-w-[600px]', }) => { return ( - +
{children}
); }; const TableHead: React.FC = ({ children, className = '' }) => { return ( - + {children} ); @@ -86,12 +84,20 @@ const TableBody: React.FC = ({ children, className = '' }) => { ); }; -const TableRow: React.FC = ({ children, className = '', onClick }) => { - const baseClasses = "border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"; - const cursorClass = onClick ? "cursor-pointer" : ""; +const TableRow: React.FC = ({ + children, + className = '', + onClick, +}) => { + const baseClasses = + 'border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]'; + const cursorClass = onClick ? 'cursor-pointer' : ''; return ( - + {children} ); @@ -102,7 +108,7 @@ const TableHeader: React.FC = ({ className = '', minWidth, width, - align = 'left' + align = 'left', }) => { const getAlignmentClass = () => { switch (align) { @@ -133,7 +139,7 @@ const TableCell: React.FC = ({ className = '', minWidth, width, - align = 'left' + align = 'left', }) => { const getAlignmentClass = () => { switch (align) {