feat: Enhance agent selection and conversation handling

- Added functionality to select agents in the Navigation component, allowing users to reset conversations and set the selected agent.
- Updated the MessageInput component to conditionally show source and tool buttons based on the selected agent.
- Modified the Conversation component to handle agent-specific queries and manage file uploads.
- Improved conversation fetching logic to include agent IDs and handle attachments.
- Introduced new types for conversation summaries and results to streamline API responses.
- Refactored Redux slices to manage selected agent state and improve overall state management.
- Enhanced error handling and loading states across components for better user experience.
This commit is contained in:
Siddhant Rai
2025-04-15 11:53:53 +05:30
parent fa1f9d7009
commit 7c69e99914
16 changed files with 445 additions and 237 deletions

View File

@@ -10,6 +10,7 @@ from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
from bson.objectid import ObjectId
class BaseAgent(ABC):
@@ -23,7 +24,7 @@ class BaseAgent(ABC):
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None,
attachments: Optional[List[Dict]]=None,
attachments: Optional[List[Dict]] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -58,6 +59,27 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -243,9 +265,11 @@ class BaseAgent(ABC):
tools_dict: Dict,
messages: List[Dict],
log_context: Optional[LogContext] = None,
attachments: Optional[List[Dict]] = None
attachments: Optional[List[Dict]] = None,
):
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments)
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages, attachments
)
if log_context:
data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data})

View File

@@ -5,21 +5,25 @@ from application.logging import LogContext
from application.retriever.base import BaseRetriever
import logging
logger = logging.getLogger(__name__)
class ClassicAgent(BaseAgent):
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)
tools_dict = self._get_user_tools(self.user)
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)
messages = self._build_messages(self.prompt, query, retrieved_data)
resp = self._llm_gen(messages, log_context)
attachments = self.attachments
if isinstance(resp, str):
@@ -33,7 +37,7 @@ class ClassicAgent(BaseAgent):
yield {"answer": resp.message.content}
return
resp = self._llm_handler(resp, tools_dict, messages, log_context,attachments)
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
if isinstance(resp, str):
yield {"answer": resp}

View File

@@ -30,7 +30,10 @@ class ReActAgent(BaseAgent):
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)
tools_dict = self._get_user_tools(self.user)
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])

View File

@@ -86,6 +86,20 @@ def run_async_chain(chain, question, chat_history):
return result
def get_agent_key(agent_id, user_id):
if not agent_id:
return None
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)
if agent.get("is_public") or agent.get("user") == user_id:
return str(agent["key"])
raise Exception("Unauthorized access to the agent", 403)
def get_data_from_api_key(api_key):
data = agents_collection.find_one({"key": api_key})
if not data:
@@ -129,6 +143,7 @@ def save_conversation(
decoded_token,
index=None,
api_key=None,
agent_id=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
@@ -202,6 +217,8 @@ def save_conversation(
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
@@ -234,6 +251,7 @@ def complete_stream(
index=None,
should_save_conversation=True,
attachments=None,
agent_id=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
@@ -297,6 +315,7 @@ def complete_stream(
decoded_token,
index,
api_key=user_api_key,
agent_id=agent_id,
)
else:
conversation_id = None
@@ -404,7 +423,14 @@ class Stream(Resource):
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
agent_key = get_agent_key(agent_id, request.decoded_token.get("sub"))
if agent_key:
data.update({"api_key": agent_key})
else:
agent_id = None
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
@@ -479,6 +505,7 @@ class Stream(Resource):
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
agent_id=agent_id,
),
mimetype="text/event-stream",
)

View File

@@ -138,14 +138,24 @@ class GetConversations(Resource):
try:
conversations = (
conversations_collection.find(
{"api_key": {"$exists": False}, "user": decoded_token.get("sub")}
{
"$or": [
{"api_key": {"$exists": False}},
{"agent_id": {"$exists": True}},
],
"user": decoded_token.get("sub"),
}
)
.sort("date", -1)
.limit(30)
)
list_conversations = [
{"id": str(conversation["_id"]), "name": conversation["name"]}
{
"id": str(conversation["_id"]),
"name": conversation["name"],
"agent_id": conversation.get("agent_id", None),
}
for conversation in conversations
]
except Exception as err:
@@ -179,7 +189,12 @@ class GetSingleConversation(Resource):
except Exception as err:
current_app.logger.error(f"Error retrieving conversation: {err}")
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(conversation["queries"]), 200)
data = {
"queries": conversation["queries"],
"agent_id": conversation.get("agent_id"),
}
return make_response(jsonify(data), 200)
@user_ns.route("/api/update_conversation_name")