diff --git a/application/agents/base.py b/application/agents/base.py index 14ecad49..64fac17b 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -10,6 +10,7 @@ from application.core.mongo_db import MongoDB from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever +from bson.objectid import ObjectId class BaseAgent(ABC): @@ -23,7 +24,7 @@ class BaseAgent(ABC): prompt: str = "", chat_history: Optional[List[Dict]] = None, decoded_token: Optional[Dict] = None, - attachments: Optional[List[Dict]]=None, + attachments: Optional[List[Dict]] = None, ): self.endpoint = endpoint self.llm_name = llm_name @@ -58,6 +59,27 @@ class BaseAgent(ABC): ) -> Generator[Dict, None, None]: pass + def _get_tools(self, api_key: str = None) -> Dict[str, Dict]: + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + agents_collection = db["agents"] + tools_collection = db["user_tools"] + + agent_data = agents_collection.find_one({"key": api_key or self.user_api_key}) + tool_ids = agent_data.get("tools", []) if agent_data else [] + + tools = ( + tools_collection.find( + {"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}} + ) + if tool_ids + else [] + ) + tools = list(tools) + tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {} + + return tools_by_id + def _get_user_tools(self, user="local"): mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -243,9 +265,11 @@ class BaseAgent(ABC): tools_dict: Dict, messages: List[Dict], log_context: Optional[LogContext] = None, - attachments: Optional[List[Dict]] = None + attachments: Optional[List[Dict]] = None, ): - resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments) + resp = self.llm_handler.handle_response( + self, resp, tools_dict, messages, attachments + ) if log_context: data = build_stack_data(self.llm_handler) log_context.stacks.append({"component": "llm_handler", "data": data}) diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 8446347c..bf472cd0 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -5,21 +5,25 @@ from application.logging import LogContext from application.retriever.base import BaseRetriever import logging + logger = logging.getLogger(__name__) + class ClassicAgent(BaseAgent): def _gen_inner( self, query: str, retriever: BaseRetriever, log_context: LogContext ) -> Generator[Dict, None, None]: retrieved_data = self._retriever_search(retriever, query, log_context) - - tools_dict = self._get_user_tools(self.user) + if self.user_api_key: + tools_dict = self._get_tools(self.user_api_key) + else: + tools_dict = self._get_user_tools(self.user) self._prepare_tools(tools_dict) messages = self._build_messages(self.prompt, query, retrieved_data) resp = self._llm_gen(messages, log_context) - + attachments = self.attachments if isinstance(resp, str): @@ -33,7 +37,7 @@ class ClassicAgent(BaseAgent): yield {"answer": resp.message.content} return - resp = self._llm_handler(resp, tools_dict, messages, log_context,attachments) + resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments) if isinstance(resp, str): yield {"answer": resp} diff --git a/application/agents/react_agent.py b/application/agents/react_agent.py index f721b487..3fae1fda 100644 --- a/application/agents/react_agent.py +++ b/application/agents/react_agent.py @@ -30,7 +30,10 @@ class ReActAgent(BaseAgent): ) -> Generator[Dict, None, None]: retrieved_data = self._retriever_search(retriever, query, log_context) - tools_dict = self._get_user_tools(self.user) + if self.user_api_key: + tools_dict = self._get_tools(self.user_api_key) + else: + tools_dict = self._get_user_tools(self.user) self._prepare_tools(tools_dict) docs_together = "\n".join([doc["text"] for doc in retrieved_data]) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index ef9b7381..8f44385b 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -27,7 +27,7 @@ db = mongo["docsgpt"] conversations_collection = db["conversations"] sources_collection = db["sources"] prompts_collection = db["prompts"] -api_key_collection = db["api_keys"] +agents_collection = db["agents"] user_logs_collection = db["user_logs"] attachments_collection = db["attachments"] @@ -86,19 +86,42 @@ def run_async_chain(chain, question, chat_history): return result -def get_data_from_api_key(api_key): - data = api_key_collection.find_one({"key": api_key}) - # # Raise custom exception if the API key is not found - if data is None: - raise Exception("Invalid API Key, please generate new key", 401) +def get_agent_key(agent_id, user_id): + if not agent_id: + return None - if "source" in data and isinstance(data["source"], DBRef): - source_doc = db.dereference(data["source"]) + try: + agent = agents_collection.find_one({"_id": ObjectId(agent_id)}) + if agent is None: + raise Exception("Agent not found", 404) + + if agent.get("user") == user_id: + agents_collection.update_one( + {"_id": ObjectId(agent_id)}, + {"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}}, + ) + return str(agent["key"]) + + raise Exception("Unauthorized access to the agent", 403) + + except Exception as e: + logger.error(f"Error in get_agent_key: {str(e)}") + raise + + +def get_data_from_api_key(api_key): + data = agents_collection.find_one({"key": api_key}) + if not data: + raise Exception("Invalid API Key, please generate a new key", 401) + + source = data.get("source") + if isinstance(source, DBRef): + source_doc = db.dereference(source) data["source"] = str(source_doc["_id"]) - if "retriever" in source_doc: - data["retriever"] = source_doc["retriever"] + data["retriever"] = source_doc.get("retriever", data.get("retriever")) else: data["source"] = {} + return data @@ -128,7 +151,8 @@ def save_conversation( llm, decoded_token, index=None, - api_key=None + api_key=None, + agent_id=None, ): current_time = datetime.datetime.now(datetime.timezone.utc) if conversation_id is not None and index is not None: @@ -202,7 +226,9 @@ def save_conversation( ], } if api_key: - api_key_doc = api_key_collection.find_one({"key": api_key}) + if agent_id: + conversation_data["agent_id"] = agent_id + api_key_doc = agents_collection.find_one({"key": api_key}) if api_key_doc: conversation_data["api_key"] = api_key_doc["key"] conversation_id = conversations_collection.insert_one( @@ -234,6 +260,7 @@ def complete_stream( index=None, should_save_conversation=True, attachments=None, + agent_id=None, ): try: response_full, thought, source_log_docs, tool_calls = "", "", [], [] @@ -241,7 +268,9 @@ def complete_stream( if attachments: attachment_ids = [attachment["id"] for attachment in attachments] - logger.info(f"Processing request with {len(attachments)} attachments: {attachment_ids}") + logger.info( + f"Processing request with {len(attachments)} attachments: {attachment_ids}" + ) answer = agent.gen(query=question, retriever=retriever) @@ -294,7 +323,8 @@ def complete_stream( llm, decoded_token, index, - api_key=user_api_key + api_key=user_api_key, + agent_id=agent_id, ) else: conversation_id = None @@ -366,7 +396,9 @@ class Stream(Resource): required=False, description="Index of the query to update" ), "save_conversation": fields.Boolean( - required=False, default=True, description="Whether to save the conversation" + required=False, + default=True, + description="Whether to save the conversation", ), "attachments": fields.List( fields.String, required=False, description="List of attachment IDs" @@ -400,6 +432,14 @@ class Stream(Resource): chunks = int(data.get("chunks", 2)) token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) retriever_name = data.get("retriever", "classic") + agent_id = data.get("agent_id", None) + agent_type = settings.AGENT_NAME + agent_key = get_agent_key(agent_id, request.decoded_token.get("sub")) + + if agent_key: + data.update({"api_key": agent_key}) + else: + agent_id = None if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) @@ -408,6 +448,7 @@ class Stream(Resource): source = {"active_docs": data_key.get("source")} retriever_name = data_key.get("retriever", retriever_name) user_api_key = data["api_key"] + agent_type = data_key.get("agent_type", agent_type) decoded_token = {"sub": data_key.get("user")} elif "active_docs" in data: @@ -423,8 +464,10 @@ class Stream(Resource): if not decoded_token: return make_response({"error": "Unauthorized"}, 401) - - attachments = get_attachments_content(attachment_ids, decoded_token.get("sub")) + + attachments = get_attachments_content( + attachment_ids, decoded_token.get("sub") + ) logger.info( f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}", @@ -436,7 +479,7 @@ class Stream(Resource): chunks = 0 agent = AgentCreator.create_agent( - settings.AGENT_NAME, + agent_type, endpoint="stream", llm_name=settings.LLM_NAME, gpt_model=gpt_model, @@ -471,6 +514,7 @@ class Stream(Resource): isNoneDoc=data.get("isNoneDoc"), index=index, should_save_conversation=save_conv, + agent_id=agent_id, ), mimetype="text/event-stream", ) @@ -552,6 +596,7 @@ class Answer(Resource): chunks = int(data.get("chunks", 2)) token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) retriever_name = data.get("retriever", "classic") + agent_type = settings.AGENT_NAME if "api_key" in data: data_key = get_data_from_api_key(data["api_key"]) @@ -560,6 +605,7 @@ class Answer(Resource): source = {"active_docs": data_key.get("source")} retriever_name = data_key.get("retriever", retriever_name) user_api_key = data["api_key"] + agent_type = data_key.get("agent_type", agent_type) decoded_token = {"sub": data_key.get("user")} elif "active_docs" in data: @@ -584,7 +630,7 @@ class Answer(Resource): ) agent = AgentCreator.create_agent( - settings.AGENT_NAME, + agent_type, endpoint="api/answer", llm_name=settings.LLM_NAME, gpt_model=gpt_model, @@ -815,28 +861,27 @@ class Search(Resource): def get_attachments_content(attachment_ids, user): """ Retrieve content from attachment documents based on their IDs. - + Args: attachment_ids (list): List of attachment document IDs user (str): User identifier to verify ownership - + Returns: list: List of dictionaries containing attachment content and metadata """ if not attachment_ids: return [] - + attachments = [] for attachment_id in attachment_ids: try: - attachment_doc = attachments_collection.find_one({ - "_id": ObjectId(attachment_id), - "user": user - }) - + attachment_doc = attachments_collection.find_one( + {"_id": ObjectId(attachment_id), "user": user} + ) + if attachment_doc: attachments.append(attachment_doc) except Exception as e: logger.error(f"Error retrieving attachment {attachment_id}: {e}") - + return attachments diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 91b028d5..8876be6b 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -28,7 +28,7 @@ conversations_collection = db["conversations"] sources_collection = db["sources"] prompts_collection = db["prompts"] feedback_collection = db["feedback"] -api_key_collection = db["api_keys"] +agents_collection = db["agents"] token_usage_collection = db["token_usage"] shared_conversations_collections = db["shared_conversations"] user_logs_collection = db["user_logs"] @@ -138,14 +138,24 @@ class GetConversations(Resource): try: conversations = ( conversations_collection.find( - {"api_key": {"$exists": False}, "user": decoded_token.get("sub")} + { + "$or": [ + {"api_key": {"$exists": False}}, + {"agent_id": {"$exists": True}}, + ], + "user": decoded_token.get("sub"), + } ) .sort("date", -1) .limit(30) ) list_conversations = [ - {"id": str(conversation["_id"]), "name": conversation["name"]} + { + "id": str(conversation["_id"]), + "name": conversation["name"], + "agent_id": conversation.get("agent_id", None), + } for conversation in conversations ] except Exception as err: @@ -179,7 +189,12 @@ class GetSingleConversation(Resource): except Exception as err: current_app.logger.error(f"Error retrieving conversation: {err}") return make_response(jsonify({"success": False}), 400) - return make_response(jsonify(conversation["queries"]), 200) + + data = { + "queries": conversation["queries"], + "agent_id": conversation.get("agent_id"), + } + return make_response(jsonify(data), 200) @user_ns.route("/api/update_conversation_name") @@ -920,124 +935,398 @@ class UpdatePrompt(Resource): return make_response(jsonify({"success": True}), 200) -@user_ns.route("/api/get_api_keys") -class GetApiKeys(Resource): - @api.doc(description="Retrieve API keys for the user") +@user_ns.route("/api/get_agent") +class GetAgent(Resource): + @api.doc(params={"id": "ID of the agent"}, description="Get a single agent by ID") + def get(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + agent_id = request.args.get("id") + if not agent_id: + return make_response( + jsonify({"success": False, "message": "ID is required"}), 400 + ) + + try: + agent = agents_collection.find_one( + {"_id": ObjectId(agent_id), "user": user} + ) + if not agent: + return make_response(jsonify({"status": "Not found"}), 404) + data = { + "id": str(agent["_id"]), + "name": agent["name"], + "description": agent["description"], + "source": ( + str(db.dereference(agent["source"])["_id"]) + if "source" in agent and isinstance(agent["source"], DBRef) + else "" + ), + "chunks": agent["chunks"], + "retriever": agent.get("retriever", ""), + "prompt_id": agent["prompt_id"], + "tools": agent.get("tools", []), + "agent_type": agent["agent_type"], + "status": agent["status"], + "createdAt": agent["createdAt"], + "updatedAt": agent["updatedAt"], + "lastUsedAt": agent["lastUsedAt"], + "key": f"{agent['key'][:4]}...{agent['key'][-4:]}", + } + except Exception as err: + current_app.logger.error(f"Error retrieving agent: {err}") + return make_response(jsonify({"success": False}), 400) + + return make_response(jsonify(data), 200) + + +@user_ns.route("/api/get_agents") +class GetAgents(Resource): + @api.doc(description="Retrieve agents for the user") def get(self): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) user = decoded_token.get("sub") try: - keys = api_key_collection.find({"user": user}) - list_keys = [] - for key in keys: - if "source" in key and isinstance(key["source"], DBRef): - source = db.dereference(key["source"]) - if source is None: - continue - source_name = source["name"] - elif "retriever" in key: - source_name = key["retriever"] - else: - continue - - list_keys.append( - { - "id": str(key["_id"]), - "name": key["name"], - "key": key["key"][:4] + "..." + key["key"][-4:], - "source": source_name, - "prompt_id": key["prompt_id"], - "chunks": key["chunks"], - } - ) + agents = agents_collection.find({"user": user}) + list_agents = [ + { + "id": str(agent["_id"]), + "name": agent["name"], + "description": agent["description"], + "source": ( + str(db.dereference(agent["source"])["_id"]) + if "source" in agent and isinstance(agent["source"], DBRef) + else "" + ), + "chunks": agent["chunks"], + "retriever": agent.get("retriever", ""), + "prompt_id": agent["prompt_id"], + "tools": agent.get("tools", []), + "agent_type": agent["agent_type"], + "status": agent["status"], + "created_at": agent["createdAt"], + "updated_at": agent["updatedAt"], + "last_used_at": agent["lastUsedAt"], + "key": f"{agent['key'][:4]}...{agent['key'][-4:]}", + } + for agent in agents + if "source" in agent or "retriever" in agent + ] except Exception as err: - current_app.logger.error(f"Error retrieving API keys: {err}") + current_app.logger.error(f"Error retrieving agents: {err}") return make_response(jsonify({"success": False}), 400) - return make_response(jsonify(list_keys), 200) + return make_response(jsonify(list_agents), 200) -@user_ns.route("/api/create_api_key") -class CreateApiKey(Resource): - create_api_key_model = api.model( - "CreateApiKeyModel", +@user_ns.route("/api/create_agent") +class CreateAgent(Resource): + create_agent_model = api.model( + "CreateAgentModel", { - "name": fields.String(required=True, description="Name of the API key"), - "prompt_id": fields.String(required=True, description="Prompt ID"), + "name": fields.String(required=True, description="Name of the agent"), + "description": fields.String( + required=True, description="Description of the agent" + ), + "image": fields.String( + required=False, description="Image URL or identifier" + ), + "source": fields.String(required=True, description="Source ID"), "chunks": fields.Integer(required=True, description="Chunks count"), - "source": fields.String(description="Source ID (optional)"), - "retriever": fields.String(description="Retriever (optional)"), + "retriever": fields.String(required=True, description="Retriever ID"), + "prompt_id": fields.String(required=True, description="Prompt ID"), + "tools": fields.List( + fields.String, required=False, description="List of tool identifiers" + ), + "agent_type": fields.String(required=True, description="Type of the agent"), + "status": fields.String( + required=True, description="Status of the agent (draft or published)" + ), }, ) - @api.expect(create_api_key_model) - @api.doc(description="Create a new API key") + @api.expect(create_agent_model) + @api.doc(description="Create a new agent") def post(self): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) user = decoded_token.get("sub") data = request.get_json() - required_fields = ["name", "prompt_id", "chunks"] + + if data.get("status") not in ["draft", "published"]: + return make_response( + jsonify({"success": False, "message": "Invalid status"}), 400 + ) + + required_fields = [] + if data.get("status") == "published": + required_fields = [ + "name", + "description", + "source", + "chunks", + "retriever", + "prompt_id", + "agent_type", + ] + else: + required_fields = ["name"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: key = str(uuid.uuid4()) - new_api_key = { - "name": data["name"], - "key": key, + new_agent = { "user": user, - "prompt_id": data["prompt_id"], - "chunks": data["chunks"], + "name": data.get("name"), + "description": data.get("description", ""), + "image": data.get("image", ""), + "source": ( + DBRef("sources", ObjectId(data.get("source"))) + if ObjectId.is_valid(data.get("source")) + else "" + ), + "chunks": data.get("chunks", ""), + "retriever": data.get("retriever", ""), + "prompt_id": data.get("prompt_id", ""), + "tools": data.get("tools", []), + "agent_type": data.get("agent_type", ""), + "status": data.get("status"), + "createdAt": datetime.datetime.now(datetime.timezone.utc), + "updatedAt": datetime.datetime.now(datetime.timezone.utc), + "lastUsedAt": None, + "key": key, } - if "source" in data and ObjectId.is_valid(data["source"]): - new_api_key["source"] = DBRef("sources", ObjectId(data["source"])) - if "retriever" in data: - new_api_key["retriever"] = data["retriever"] - resp = api_key_collection.insert_one(new_api_key) + resp = agents_collection.insert_one(new_agent) new_id = str(resp.inserted_id) except Exception as err: - current_app.logger.error(f"Error creating API key: {err}") + current_app.logger.error(f"Error creating agent: {err}") return make_response(jsonify({"success": False}), 400) return make_response(jsonify({"id": new_id, "key": key}), 201) -@user_ns.route("/api/delete_api_key") -class DeleteApiKey(Resource): - delete_api_key_model = api.model( - "DeleteApiKeyModel", - {"id": fields.String(required=True, description="API Key ID to delete")}, +@user_ns.route("/api/update_agent/") +class UpdateAgent(Resource): + update_agent_model = api.model( + "UpdateAgentModel", + { + "name": fields.String(required=True, description="New name of the agent"), + "description": fields.String( + required=True, description="New description of the agent" + ), + "image": fields.String( + required=False, description="New image URL or identifier" + ), + "source": fields.String(required=True, description="Source ID"), + "chunks": fields.Integer(required=True, description="Chunks count"), + "retriever": fields.String(required=True, description="Retriever ID"), + "prompt_id": fields.String(required=True, description="Prompt ID"), + "tools": fields.List( + fields.String, required=False, description="List of tool identifiers" + ), + "agent_type": fields.String(required=True, description="Type of the agent"), + "status": fields.String( + required=True, description="Status of the agent (draft or published)" + ), + }, ) - @api.expect(delete_api_key_model) - @api.doc(description="Delete an API key by ID") - def post(self): + @api.expect(update_agent_model) + @api.doc(description="Update an existing agent") + def put(self, agent_id): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) user = decoded_token.get("sub") data = request.get_json() - required_fields = ["id"] - missing_fields = check_required_fields(data, required_fields) - if missing_fields: - return missing_fields + + if not ObjectId.is_valid(agent_id): + return make_response( + jsonify({"success": False, "message": "Invalid agent ID format"}), 400 + ) + oid = ObjectId(agent_id) try: - result = api_key_collection.delete_one( - {"_id": ObjectId(data["id"]), "user": user} - ) - if result.deleted_count == 0: - return {"success": False, "message": "API Key not found"}, 404 + existing_agent = agents_collection.find_one({"_id": oid, "user": user}) except Exception as err: - current_app.logger.error(f"Error deleting API key: {err}") - return {"success": False}, 400 + return make_response( + current_app.logger.error(f"Error finding agent {agent_id}: {err}"), + jsonify({"success": False, "message": "Database error finding agent"}), + 500, + ) - return {"success": True}, 200 + if not existing_agent: + return make_response( + jsonify( + {"success": False, "message": "Agent not found or not authorized"} + ), + 404, + ) + + update_fields = {} + allowed_fields = [ + "name", + "description", + "image", + "source", + "chunks", + "retriever", + "prompt_id", + "tools", + "agent_type", + "status", + ] + + for field in allowed_fields: + if field in data: + if field == "status": + new_status = data.get("status") + if new_status not in ["draft", "published"]: + return make_response( + jsonify( + {"success": False, "message": "Invalid status value"} + ), + 400, + ) + update_fields[field] = new_status + elif field == "source": + source_id = data.get("source") + if source_id and ObjectId.is_valid(source_id): + update_fields[field] = DBRef("sources", ObjectId(source_id)) + elif source_id: + return make_response( + jsonify( + { + "success": False, + "message": "Invalid source ID format provided", + } + ), + 400, + ) + else: + update_fields[field] = "" + else: + update_fields[field] = data[field] + + if not update_fields: + return make_response( + jsonify({"success": False, "message": "No update data provided"}), 400 + ) + + final_status = update_fields.get("status", existing_agent.get("status")) + if final_status == "published": + required_published_fields = [ + "name", + "description", + "source", + "chunks", + "retriever", + "prompt_id", + "agent_type", + ] + missing_published_fields = [] + for req_field in required_published_fields: + final_value = update_fields.get( + req_field, existing_agent.get(req_field) + ) + if req_field == "source" and final_value: + if not isinstance(final_value, DBRef): + missing_published_fields.append(req_field) + + if missing_published_fields: + return make_response( + jsonify( + { + "success": False, + "message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}", + } + ), + 400, + ) + + update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc) + + try: + result = agents_collection.update_one( + {"_id": oid, "user": user}, {"$set": update_fields} + ) + + if result.matched_count == 0: + return make_response( + jsonify( + { + "success": False, + "message": "Agent not found or update failed unexpectedly", + } + ), + 404, + ) + if result.modified_count == 0 and result.matched_count == 1: + return make_response( + jsonify( + { + "success": True, + "message": "Agent found, but no changes were applied.", + } + ), + 304, + ) + + except Exception as err: + current_app.logger.error(f"Error updating agent {agent_id}: {err}") + return make_response( + jsonify({"success": False, "message": "Database error during update"}), + 500, + ) + + return make_response( + jsonify( + { + "success": True, + "id": agent_id, + "message": "Agent updated successfully", + } + ), + 200, + ) + + +@user_ns.route("/api/delete_agent") +class DeleteAgent(Resource): + @api.doc(params={"id": "ID of the agent"}, description="Delete an agent by ID") + def delete(self): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user = decoded_token.get("sub") + agent_id = request.args.get("id") + if not agent_id: + return make_response( + jsonify({"success": False, "message": "ID is required"}), 400 + ) + + try: + deleted_agent = agents_collection.find_one_and_delete( + {"_id": ObjectId(agent_id), "user": user} + ) + if not deleted_agent: + return make_response( + jsonify({"success": False, "message": "Agent not found"}), 404 + ) + deleted_id = str(deleted_agent["_id"]) + + except Exception as err: + current_app.logger.error(f"Error deleting agent: {err}") + return make_response(jsonify({"success": False}), 400) + + return make_response(jsonify({"id": deleted_id}), 200) @user_ns.route("/api/share") @@ -1112,9 +1401,7 @@ class ShareConversation(Resource): if "retriever" in data: new_api_key_data["retriever"] = data["retriever"] - pre_existing_api_document = api_key_collection.find_one( - new_api_key_data - ) + pre_existing_api_document = agents_collection.find_one(new_api_key_data) if pre_existing_api_document: api_uuid = pre_existing_api_document["key"] pre_existing = shared_conversations_collections.find_one( @@ -1173,7 +1460,7 @@ class ShareConversation(Resource): if "retriever" in data: new_api_key_data["retriever"] = data["retriever"] - api_key_collection.insert_one(new_api_key_data) + agents_collection.insert_one(new_api_key_data) shared_conversations_collections.insert_one( { "uuid": explicit_binary, @@ -1331,9 +1618,9 @@ class GetMessageAnalytics(Resource): try: api_key = ( - api_key_collection.find_one( - {"_id": ObjectId(api_key_id), "user": user} - )["key"] + agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[ + "key" + ] if api_key_id else None ) @@ -1375,7 +1662,7 @@ class GetMessageAnalytics(Resource): } if api_key: match_stage["$match"]["api_key"] = api_key - + pipeline = [ match_stage, {"$unwind": "$queries"}, @@ -1455,9 +1742,9 @@ class GetTokenAnalytics(Resource): try: api_key = ( - api_key_collection.find_one( - {"_id": ObjectId(api_key_id), "user": user} - )["key"] + agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[ + "key" + ] if api_key_id else None ) @@ -1614,9 +1901,9 @@ class GetFeedbackAnalytics(Resource): try: api_key = ( - api_key_collection.find_one( - {"_id": ObjectId(api_key_id), "user": user} - )["key"] + agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[ + "key" + ] if api_key_id else None ) @@ -1779,7 +2066,7 @@ class GetUserLogs(Resource): try: api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"] if api_key_id else None ) @@ -2493,10 +2780,10 @@ class StoreAttachment(Resource): decoded_token = request.decoded_token if not decoded_token: return make_response(jsonify({"success": False}), 401) - + # Get single file instead of list file = request.files.get("file") - + if not file or file.filename == "": return make_response( jsonify({"status": "error", "message": "Missing file"}), @@ -2504,46 +2791,43 @@ class StoreAttachment(Resource): ) user = secure_filename(decoded_token.get("sub")) - + try: attachment_id = ObjectId() original_filename = secure_filename(file.filename) - + save_dir = os.path.join( - current_dir, + current_dir, settings.UPLOAD_FOLDER, user, - "attachments", - str(attachment_id) + "attachments", + str(attachment_id), ) os.makedirs(save_dir, exist_ok=True) - + file_path = os.path.join(save_dir, original_filename) - - + file.save(file_path) file_info = { "filename": original_filename, - "attachment_id": str(attachment_id) + "attachment_id": str(attachment_id), } current_app.logger.info(f"Saved file: {file_path}") - + # Start async task to process single file - task = store_attachment.delay( - save_dir, - file_info, - user - ) - + task = store_attachment.delay(save_dir, file_info, user) + return make_response( - jsonify({ - "success": True, - "task_id": task.id, - "message": "File uploaded successfully. Processing started." - }), - 200 + jsonify( + { + "success": True, + "task_id": task.id, + "message": "File uploaded successfully. Processing started.", + } + ), + 200, ) - + except Exception as err: current_app.logger.error(f"Error storing attachment: {err}") return make_response(jsonify({"success": False, "error": str(err)}), 400) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index d70a202f..043bbf58 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -49,8 +49,8 @@ "husky": "^8.0.0", "lint-staged": "^15.3.0", "postcss": "^8.4.49", - "prettier": "^3.4.2", - "prettier-plugin-tailwindcss": "^0.6.9", + "prettier": "^3.5.3", + "prettier-plugin-tailwindcss": "^0.6.11", "tailwindcss": "^3.4.17", "typescript": "^5.7.2", "vite": "^5.4.14", @@ -1635,7 +1635,7 @@ "version": "18.3.0", "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.0.tgz", "integrity": "sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==", - "dev": true, + "devOptional": true, "dependencies": { "@types/react": "*" } @@ -7648,10 +7648,11 @@ } }, "node_modules/prettier": { - "version": "3.4.2", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.4.2.tgz", - "integrity": "sha512-e9MewbtFo+Fevyuxn/4rrcDAaq0IYxPGLvObpQjiZBMAzB9IGmzlnG9RZy3FFas+eBMu2vA0CszMeduow5dIuQ==", + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.5.3.tgz", + "integrity": "sha512-QQtaxnoDJeAkDvDKWCLiwIXkTgRhwYDEQCghU9Z6q03iyek/rxRh/2lC3HB7P8sWT2xC/y5JDctPLBIGzHKbhw==", "dev": true, + "license": "MIT", "bin": { "prettier": "bin/prettier.cjs" }, @@ -7675,10 +7676,11 @@ } }, "node_modules/prettier-plugin-tailwindcss": { - "version": "0.6.9", - "resolved": "https://registry.npmjs.org/prettier-plugin-tailwindcss/-/prettier-plugin-tailwindcss-0.6.9.tgz", - "integrity": "sha512-r0i3uhaZAXYP0At5xGfJH876W3HHGHDp+LCRUJrs57PBeQ6mYHMwr25KH8NPX44F2yGTvdnH7OqCshlQx183Eg==", + "version": "0.6.11", + "resolved": "https://registry.npmjs.org/prettier-plugin-tailwindcss/-/prettier-plugin-tailwindcss-0.6.11.tgz", + "integrity": "sha512-YxaYSIvZPAqhrrEpRtonnrXdghZg1irNg4qrjboCXrpybLWVs55cW2N3juhspVJiO0JBvYJT8SYsJpc8OQSnsA==", "dev": true, + "license": "MIT", "engines": { "node": ">=14.21.3" }, @@ -7687,7 +7689,7 @@ "@prettier/plugin-pug": "*", "@shopify/prettier-plugin-liquid": "*", "@trivago/prettier-plugin-sort-imports": "*", - "@zackad/prettier-plugin-twig-melody": "*", + "@zackad/prettier-plugin-twig": "*", "prettier": "^3.0", "prettier-plugin-astro": "*", "prettier-plugin-css-order": "*", @@ -7714,7 +7716,7 @@ "@trivago/prettier-plugin-sort-imports": { "optional": true }, - "@zackad/prettier-plugin-twig-melody": { + "@zackad/prettier-plugin-twig": { "optional": true }, "prettier-plugin-astro": { @@ -9376,7 +9378,7 @@ "version": "5.7.2", "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.2.tgz", "integrity": "sha512-i5t66RHxDvVN40HfDd1PsEThGNnlMCMT3jMUuoh9/0TaqWevNontacunWyN02LA9/fIbEWlcHZcgTKb9QoaLfg==", - "dev": true, + "devOptional": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/frontend/package.json b/frontend/package.json index 89a55b04..45058e98 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -28,8 +28,8 @@ "react-chartjs-2": "^5.3.0", "react-copy-to-clipboard": "^5.1.0", "react-dom": "^18.3.1", - "react-helmet": "^6.1.0", "react-dropzone": "^14.3.5", + "react-helmet": "^6.1.0", "react-i18next": "^15.4.0", "react-markdown": "^9.0.1", "react-redux": "^8.0.5", @@ -60,8 +60,8 @@ "husky": "^8.0.0", "lint-staged": "^15.3.0", "postcss": "^8.4.49", - "prettier": "^3.4.2", - "prettier-plugin-tailwindcss": "^0.6.9", + "prettier": "^3.5.3", + "prettier-plugin-tailwindcss": "^0.6.11", "tailwindcss": "^3.4.17", "typescript": "^5.7.2", "vite": "^5.4.14", diff --git a/frontend/prettier.config.cjs b/frontend/prettier.config.cjs index c92ea504..8b38ecfa 100644 --- a/frontend/prettier.config.cjs +++ b/frontend/prettier.config.cjs @@ -4,4 +4,5 @@ module.exports = { semi: true, singleQuote: true, printWidth: 80, -} + plugins: ['prettier-plugin-tailwindcss'], +}; diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 33c66bd1..41b05ac0 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -12,13 +12,14 @@ import useTokenAuth from './hooks/useTokenAuth'; import Navigation from './Navigation'; import PageNotFound from './PageNotFound'; import Setting from './settings'; +import Agents from './agents'; function AuthWrapper({ children }: { children: React.ReactNode }) { const { isAuthLoading } = useTokenAuth(); if (isAuthLoading) { return ( -
+
); @@ -31,7 +32,7 @@ function MainLayout() { const [navOpen, setNavOpen] = useState(!isMobile); return ( -
+
; } return ( -
+
} /> } /> } /> + } /> } /> } /> diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index 0068c3d3..0e357a6d 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -3,6 +3,7 @@ import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; import { NavLink, useNavigate } from 'react-router-dom'; +import { Agent } from './agents/types'; import conversationService from './api/services/conversationService'; import userService from './api/services/userService'; import Add from './assets/add.svg'; @@ -12,11 +13,12 @@ import Expand from './assets/expand.svg'; import Github from './assets/github.svg'; import Hamburger from './assets/hamburger.svg'; import openNewChat from './assets/openNewChat.svg'; +import Robot from './assets/robot.svg'; import SettingGear from './assets/settingGear.svg'; +import Spark from './assets/spark.svg'; import SpinnerDark from './assets/spinner-dark.svg'; import Spinner from './assets/spinner.svg'; import Twitter from './assets/TwitterX.svg'; -import UploadIcon from './assets/upload.svg'; import Help from './components/Help'; import { handleAbort, @@ -33,13 +35,15 @@ import JWTModal from './modals/JWTModal'; import { ActiveState } from './models/misc'; import { getConversations } from './preferences/preferenceApi'; import { - selectApiKeyStatus, selectConversationId, selectConversations, selectModalStateDeleteConv, + selectSelectedAgent, selectToken, setConversations, setModalStateDeleteConv, + setSelectedAgent, + setAgents, } from './preferences/preferenceSlice'; import Upload from './upload/Upload'; @@ -50,36 +54,28 @@ interface NavigationProps { export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const dispatch = useDispatch(); + const navigate = useNavigate(); + + const { t } = useTranslation(); + const token = useSelector(selectToken); const queries = useSelector(selectQueries); const conversations = useSelector(selectConversations); - const modalStateDeleteConv = useSelector(selectModalStateDeleteConv); const conversationId = useSelector(selectConversationId); - const [isDeletingConversation, setIsDeletingConversation] = useState(false); + const modalStateDeleteConv = useSelector(selectModalStateDeleteConv); + const selectedAgent = useSelector(selectSelectedAgent); const { isMobile } = useMediaQuery(); const [isDarkTheme] = useDarkTheme(); - const { t } = useTranslation(); - const isApiKeySet = useSelector(selectApiKeyStatus); - const { showTokenModal, handleTokenSubmit } = useTokenAuth(); + const [isDeletingConversation, setIsDeletingConversation] = useState(false); const [uploadModalState, setUploadModalState] = useState('INACTIVE'); + const [recentAgents, setRecentAgents] = useState([]); const navRef = useRef(null); - const navigate = useNavigate(); - - useEffect(() => { - if (!conversations?.data) { - fetchConversations(); - } - if (queries.length === 0) { - resetConversation(); - } - }, [conversations?.data, dispatch]); - async function fetchConversations() { dispatch(setConversations({ ...conversations, loading: true })); return await getConversations(token) @@ -92,6 +88,29 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { }); } + async function getAgents() { + const response = await userService.getAgents(token); + if (!response.ok) throw new Error('Failed to fetch agents'); + const data: Agent[] = await response.json(); + dispatch(setAgents(data)); + setRecentAgents( + data + .filter((agent: Agent) => agent.status === 'published') + .sort( + (a: Agent, b: Agent) => + new Date(b.last_used_at ?? 0).getTime() - + new Date(a.last_used_at ?? 0).getTime(), + ) + .slice(0, 3), + ); + } + + useEffect(() => { + if (recentAgents.length === 0) getAgents(); + if (!conversations?.data) fetchConversations(); + if (queries.length === 0) resetConversation(); + }, [conversations?.data, dispatch]); + const handleDeleteAllConversations = () => { setIsDeletingConversation(true); conversationService @@ -113,18 +132,34 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { .catch((error) => console.error(error)); }; + const handleAgentClick = (agent: Agent) => { + resetConversation(); + dispatch(setSelectedAgent(agent)); + if (isMobile) setNavOpen(!navOpen); + navigate('/'); + }; + const handleConversationClick = (index: string) => { conversationService .getConversation(index, token) .then((response) => response.json()) .then((data) => { navigate('/'); - dispatch(setConversation(data)); + dispatch(setConversation(data.queries)); dispatch( updateConversationId({ query: { conversationId: index }, }), ); + if (data.agent_id) { + userService.getAgent(data.agent_id, token).then((response) => { + if (response.ok) { + response.json().then((agent: Agent) => { + dispatch(setSelectedAgent(agent)); + }); + } + }); + } else dispatch(setSelectedAgent(null)); }); }; @@ -136,6 +171,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { query: { conversationId: null }, }), ); + dispatch(setSelectedAgent(null)); }; const newChat = () => { @@ -170,8 +206,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { return ( <> {!navOpen && ( -
-
+
+
)} -
+
DocsGPT
@@ -208,13 +244,13 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { ref={navRef} className={`${ !navOpen && '-ml-96 md:-ml-[18rem]' - } duration-20 fixed top-0 z-20 flex h-full w-72 flex-col border-r-[1px] border-b-0 bg-lotion dark:bg-chinese-black transition-all dark:border-r-purple-taupe dark:text-white`} + } duration-20 fixed top-0 z-20 flex h-full w-72 flex-col border-b-0 border-r-[1px] bg-lotion transition-all dark:border-r-purple-taupe dark:bg-chinese-black dark:text-white`} >
{ if (isMobile) { setNavOpen(!navOpen); @@ -252,7 +288,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { className={({ isActive }) => `${ isActive ? 'bg-transparent' : '' - } group sticky mx-4 mt-4 flex cursor-pointer gap-2.5 rounded-3xl border border-silver p-3 hover:border-rainy-gray dark:border-purple-taupe dark:text-white hover:bg-transparent` + } group sticky mx-4 mt-4 flex cursor-pointer gap-2.5 rounded-3xl border border-silver p-3 hover:border-rainy-gray hover:bg-transparent dark:border-purple-taupe dark:text-white` } > {conversations?.loading && !isDeletingConversation && ( -
+
)} - {conversations?.data && conversations.data.length > 0 ? ( + {recentAgents?.length > 0 ? (
-
-

{t('chats')}

+
+

Agents

+
+
+
+ {recentAgents.map((agent, idx) => ( +
handleAgentClick(agent)} + > +
+ agent-logo +
+

+ {agent.name} +

+
+ ))} +
+
{ + dispatch(setSelectedAgent(null)); + navigate('/agents'); + }} + > +
+ manage-agents +
+

+ Manage Agents +

+
+
+
+ ) : ( +
navigate('/agents')} + > +
+ manage-agents +
+

+ Manage Agents +

+
+ )} + {conversations?.data && conversations.data.length > 0 ? ( +
+
+

{t('chats')}

{conversations.data?.map((conversation) => ( @@ -316,7 +419,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { }} to="/settings" className={({ isActive }) => - `my-auto mx-4 flex h-9 cursor-pointer gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${ + `mx-4 my-auto flex h-9 cursor-pointer gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${ isActive ? 'bg-gray-3000 dark:bg-transparent' : '' }` } @@ -324,15 +427,15 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { Settings -

+

{t('settings.label')}

-
+
@@ -381,9 +484,9 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
-
+
-
DocsGPT
+
DocsGPT
+
+ +

+ Back to all agents +

+
+
+

+ Agent Logs +

+
+ + +
+ ); +} diff --git a/frontend/src/agents/AgentPreview.tsx b/frontend/src/agents/AgentPreview.tsx new file mode 100644 index 00000000..5eaf10a9 --- /dev/null +++ b/frontend/src/agents/AgentPreview.tsx @@ -0,0 +1,153 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; + +import MessageInput from '../components/MessageInput'; +import ConversationMessages from '../conversation/ConversationMessages'; +import { Query } from '../conversation/conversationModels'; +import { + addQuery, + fetchAnswer, + handleAbort, + resendQuery, + resetConversation, + selectQueries, + selectStatus, +} from '../conversation/conversationSlice'; +import { selectSelectedAgent } from '../preferences/preferenceSlice'; +import { AppDispatch } from '../store'; + +export default function AgentPreview() { + const dispatch = useDispatch(); + + const queries = useSelector(selectQueries); + const status = useSelector(selectStatus); + const selectedAgent = useSelector(selectSelectedAgent); + + const [input, setInput] = useState(''); + const [lastQueryReturnedErr, setLastQueryReturnedErr] = useState(false); + + const fetchStream = useRef(null); + + const handleFetchAnswer = useCallback( + ({ question, index }: { question: string; index?: number }) => { + fetchStream.current = dispatch( + fetchAnswer({ question, indx: index, isPreview: true }), + ); + }, + [dispatch], + ); + + const handleQuestion = useCallback( + ({ + question, + isRetry = false, + index = undefined, + }: { + question: string; + isRetry?: boolean; + index?: number; + }) => { + const trimmedQuestion = question.trim(); + if (trimmedQuestion === '') return; + + if (index !== undefined) { + if (!isRetry) dispatch(resendQuery({ index, prompt: trimmedQuestion })); + handleFetchAnswer({ question: trimmedQuestion, index }); + } else { + if (!isRetry) { + const newQuery: Query = { prompt: trimmedQuestion }; + dispatch(addQuery(newQuery)); + } + handleFetchAnswer({ question: trimmedQuestion, index: undefined }); + } + }, + [dispatch, handleFetchAnswer], + ); + + const handleQuestionSubmission = ( + updatedQuestion?: string, + updated?: boolean, + indx?: number, + ) => { + if ( + updated === true && + updatedQuestion !== undefined && + indx !== undefined + ) { + handleQuestion({ + question: updatedQuestion, + index: indx, + isRetry: false, + }); + } else if (input.trim() && status !== 'loading') { + const currentInput = input.trim(); + if (lastQueryReturnedErr && queries.length > 0) { + const lastQueryIndex = queries.length - 1; + handleQuestion({ + question: currentInput, + isRetry: true, + index: lastQueryIndex, + }); + } else { + handleQuestion({ + question: currentInput, + isRetry: false, + index: undefined, + }); + } + setInput(''); + } + }; + + const handleKeyDown = (event: React.KeyboardEvent) => { + if (event.key === 'Enter' && !event.shiftKey) { + event.preventDefault(); + handleQuestionSubmission(); + } + }; + + useEffect(() => { + dispatch(resetConversation()); + return () => { + if (fetchStream.current) fetchStream.current.abort(); + handleAbort(); + dispatch(resetConversation()); + }; + }, [dispatch]); + + useEffect(() => { + if (queries.length > 0) { + const lastQuery = queries[queries.length - 1]; + setLastQueryReturnedErr(!!lastQuery.error); + } else setLastQueryReturnedErr(false); + }, [queries]); + return ( +
+
+
+ +
+
+ setInput(e.target.value)} + onSubmit={() => handleQuestionSubmission()} + loading={status === 'loading'} + showSourceButton={selectedAgent ? false : true} + showToolButton={selectedAgent ? false : true} + /> +

+ This is a preview of the agent. You can publish it to start using it + in conversations. +

+
+
+
+ ); +} diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx new file mode 100644 index 00000000..37466a86 --- /dev/null +++ b/frontend/src/agents/NewAgent.tsx @@ -0,0 +1,614 @@ +import React, { useEffect, useRef, useState } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; +import { useNavigate, useParams } from 'react-router-dom'; + +import userService from '../api/services/userService'; +import ArrowLeft from '../assets/arrow-left.svg'; +import SourceIcon from '../assets/source.svg'; +import Dropdown from '../components/Dropdown'; +import MultiSelectPopup, { OptionType } from '../components/MultiSelectPopup'; +import AgentDetailsModal from '../modals/AgentDetailsModal'; +import ConfirmationModal from '../modals/ConfirmationModal'; +import { ActiveState, Doc, Prompt } from '../models/misc'; +import { + selectSelectedAgent, + selectSourceDocs, + selectToken, + setSelectedAgent, +} from '../preferences/preferenceSlice'; +import PromptsModal from '../preferences/PromptsModal'; +import { UserToolType } from '../settings/types'; +import AgentPreview from './AgentPreview'; +import { Agent } from './types'; + +const embeddingsName = + import.meta.env.VITE_EMBEDDINGS_NAME || + 'huggingface_sentence-transformers/all-mpnet-base-v2'; + +export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { + const navigate = useNavigate(); + const dispatch = useDispatch(); + const { agentId } = useParams(); + + const token = useSelector(selectToken); + const sourceDocs = useSelector(selectSourceDocs); + const selectedAgent = useSelector(selectSelectedAgent); + + const [effectiveMode, setEffectiveMode] = useState(mode); + const [agent, setAgent] = useState({ + id: agentId || '', + name: '', + description: '', + image: '', + source: '', + chunks: '', + retriever: '', + prompt_id: '', + tools: [], + agent_type: '', + status: '', + }); + const [prompts, setPrompts] = useState< + { name: string; id: string; type: string }[] + >([]); + const [userTools, setUserTools] = useState([]); + const [isSourcePopupOpen, setIsSourcePopupOpen] = useState(false); + const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false); + const [selectedSourceIds, setSelectedSourceIds] = useState< + Set + >(new Set()); + const [selectedToolIds, setSelectedToolIds] = useState>( + new Set(), + ); + const [deleteConfirmation, setDeleteConfirmation] = + useState('INACTIVE'); + const [agentDetails, setAgentDetails] = useState('INACTIVE'); + const [addPromptModal, setAddPromptModal] = useState('INACTIVE'); + + const sourceAnchorButtonRef = useRef(null); + const toolAnchorButtonRef = useRef(null); + + const modeConfig = { + new: { + heading: 'New Agent', + buttonText: 'Create Agent', + showDelete: false, + showSaveDraft: true, + showLogs: false, + showAccessDetails: false, + }, + edit: { + heading: 'Edit Agent', + buttonText: 'Save Changes', + showDelete: true, + showSaveDraft: false, + showLogs: true, + showAccessDetails: true, + }, + draft: { + heading: 'New Agent (Draft)', + buttonText: 'Publish Draft', + showDelete: true, + showSaveDraft: true, + showLogs: false, + showAccessDetails: false, + }, + }; + const chunks = ['0', '2', '4', '6', '8', '10']; + const agentTypes = [ + { label: 'Classic', value: 'classic' }, + { label: 'ReAct', value: 'react' }, + ]; + + const isPublishable = () => { + return ( + agent.name && + agent.description && + (agent.source || agent.retriever) && + agent.chunks && + agent.prompt_id && + agent.agent_type + ); + }; + + const handleCancel = () => { + if (selectedAgent) dispatch(setSelectedAgent(null)); + navigate('/agents'); + }; + + const handleDelete = async (agentId: string) => { + const response = await userService.deleteAgent(agentId, token); + if (!response.ok) throw new Error('Failed to delete agent'); + navigate('/agents'); + }; + + const handleSaveDraft = async () => { + const response = + effectiveMode === 'new' + ? await userService.createAgent({ ...agent, status: 'draft' }, token) + : await userService.updateAgent( + agent.id || '', + { ...agent, status: 'draft' }, + token, + ); + if (!response.ok) throw new Error('Failed to create agent draft'); + const data = await response.json(); + if (effectiveMode === 'new') { + setEffectiveMode('draft'); + setAgent((prev) => ({ ...prev, id: data.id })); + } + }; + + const handlePublish = async () => { + const response = + effectiveMode === 'new' + ? await userService.createAgent( + { ...agent, status: 'published' }, + token, + ) + : await userService.updateAgent( + agent.id || '', + { ...agent, status: 'published' }, + token, + ); + if (!response.ok) throw new Error('Failed to publish agent'); + const data = await response.json(); + if (data.id) setAgent((prev) => ({ ...prev, id: data.id })); + if (data.key) setAgent((prev) => ({ ...prev, key: data.key })); + if (effectiveMode === 'new') { + setAgentDetails('ACTIVE'); + setEffectiveMode('edit'); + } + }; + + useEffect(() => { + const getTools = async () => { + const response = await userService.getUserTools(token); + if (!response.ok) throw new Error('Failed to fetch tools'); + const data = await response.json(); + const tools: OptionType[] = data.tools.map((tool: UserToolType) => ({ + id: tool.id, + label: tool.displayName, + icon: `/toolIcons/tool_${tool.name}.svg`, + })); + setUserTools(tools); + }; + const getPrompts = async () => { + const response = await userService.getPrompts(token); + if (!response.ok) { + throw new Error('Failed to fetch prompts'); + } + const data = await response.json(); + setPrompts(data); + }; + getTools(); + getPrompts(); + }, [token]); + + useEffect(() => { + if ((mode === 'edit' || mode === 'draft') && agentId) { + const getAgent = async () => { + const response = await userService.getAgent(agentId, token); + if (!response.ok) { + navigate('/agents'); + throw new Error('Failed to fetch agent'); + } + const data = await response.json(); + if (data.source) setSelectedSourceIds(new Set([data.source])); + else if (data.retriever) + setSelectedSourceIds(new Set([data.retriever])); + if (data.tools) setSelectedToolIds(new Set(data.tools)); + if (data.status === 'draft') setEffectiveMode('draft'); + setAgent(data); + }; + getAgent(); + } + }, [agentId, mode, token]); + + useEffect(() => { + const selectedSource = Array.from(selectedSourceIds).map((id) => + sourceDocs?.find( + (source) => + source.id === id || source.retriever === id || source.name === id, + ), + ); + if (selectedSource[0]?.model === embeddingsName) { + if (selectedSource[0] && 'id' in selectedSource[0]) { + setAgent((prev) => ({ + ...prev, + source: selectedSource[0]?.id || 'default', + retriever: '', + })); + } else + setAgent((prev) => ({ + ...prev, + source: '', + retriever: selectedSource[0]?.retriever || 'classic', + })); + } + }, [selectedSourceIds]); + + useEffect(() => { + const selectedTool = Array.from(selectedToolIds).map((id) => + userTools.find((tool) => tool.id === id), + ); + setAgent((prev) => ({ + ...prev, + tools: selectedTool + .map((tool) => tool?.id) + .filter((id): id is string => typeof id === 'string'), + })); + }, [selectedToolIds]); + + useEffect(() => { + if (isPublishable()) dispatch(setSelectedAgent(agent)); + }, [agent, dispatch]); + return ( +
+
+ +

+ Back to all agents +

+
+
+

+ {modeConfig[effectiveMode].heading} +

+
+ + {modeConfig[effectiveMode].showDelete && agent.id && ( + + )} + {modeConfig[effectiveMode].showSaveDraft && ( + + )} + {modeConfig[effectiveMode].showAccessDetails && ( + + )} + {modeConfig[effectiveMode].showAccessDetails && ( + + )} + +
+
+
+
+
+

Meta

+ setAgent({ ...agent, name: e.target.value })} + /> +