mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-02 06:56:29 +00:00
feat: Enhance agent selection and conversation handling
- Added functionality to select agents in the Navigation component, allowing users to reset conversations and set the selected agent. - Updated the MessageInput component to conditionally show source and tool buttons based on the selected agent. - Modified the Conversation component to handle agent-specific queries and manage file uploads. - Improved conversation fetching logic to include agent IDs and handle attachments. - Introduced new types for conversation summaries and results to streamline API responses. - Refactored Redux slices to manage selected agent state and improve overall state management. - Enhanced error handling and loading states across components for better user experience.
This commit is contained in:
@@ -10,6 +10,7 @@ from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
@@ -23,7 +24,7 @@ class BaseAgent(ABC):
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]]=None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
@@ -58,6 +59,27 @@ class BaseAgent(ABC):
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
agents_collection = db["agents"]
|
||||
tools_collection = db["user_tools"]
|
||||
|
||||
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
|
||||
tools = (
|
||||
tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
|
||||
)
|
||||
if tool_ids
|
||||
else []
|
||||
)
|
||||
tools = list(tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
|
||||
|
||||
return tools_by_id
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
@@ -243,9 +265,11 @@ class BaseAgent(ABC):
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
log_context: Optional[LogContext] = None,
|
||||
attachments: Optional[List[Dict]] = None
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
):
|
||||
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments)
|
||||
resp = self.llm_handler.handle_response(
|
||||
self, resp, tools_dict, messages, attachments
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm_handler)
|
||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||
|
||||
@@ -5,21 +5,25 @@ from application.logging import LogContext
|
||||
|
||||
from application.retriever.base import BaseRetriever
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
if self.user_api_key:
|
||||
tools_dict = self._get_tools(self.user_api_key)
|
||||
else:
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
|
||||
resp = self._llm_gen(messages, log_context)
|
||||
|
||||
|
||||
attachments = self.attachments
|
||||
|
||||
if isinstance(resp, str):
|
||||
@@ -33,7 +37,7 @@ class ClassicAgent(BaseAgent):
|
||||
yield {"answer": resp.message.content}
|
||||
return
|
||||
|
||||
resp = self._llm_handler(resp, tools_dict, messages, log_context,attachments)
|
||||
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
|
||||
@@ -30,7 +30,10 @@ class ReActAgent(BaseAgent):
|
||||
) -> Generator[Dict, None, None]:
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
if self.user_api_key:
|
||||
tools_dict = self._get_tools(self.user_api_key)
|
||||
else:
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
|
||||
@@ -86,6 +86,20 @@ def run_async_chain(chain, question, chat_history):
|
||||
return result
|
||||
|
||||
|
||||
def get_agent_key(agent_id, user_id):
|
||||
if not agent_id:
|
||||
return None
|
||||
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if agent is None:
|
||||
raise Exception("Agent not found", 404)
|
||||
|
||||
if agent.get("is_public") or agent.get("user") == user_id:
|
||||
return str(agent["key"])
|
||||
|
||||
raise Exception("Unauthorized access to the agent", 403)
|
||||
|
||||
|
||||
def get_data_from_api_key(api_key):
|
||||
data = agents_collection.find_one({"key": api_key})
|
||||
if not data:
|
||||
@@ -129,6 +143,7 @@ def save_conversation(
|
||||
decoded_token,
|
||||
index=None,
|
||||
api_key=None,
|
||||
agent_id=None,
|
||||
):
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
if conversation_id is not None and index is not None:
|
||||
@@ -202,6 +217,8 @@ def save_conversation(
|
||||
],
|
||||
}
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
api_key_doc = agents_collection.find_one({"key": api_key})
|
||||
if api_key_doc:
|
||||
conversation_data["api_key"] = api_key_doc["key"]
|
||||
@@ -234,6 +251,7 @@ def complete_stream(
|
||||
index=None,
|
||||
should_save_conversation=True,
|
||||
attachments=None,
|
||||
agent_id=None,
|
||||
):
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
@@ -297,6 +315,7 @@ def complete_stream(
|
||||
decoded_token,
|
||||
index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
@@ -404,7 +423,14 @@ class Stream(Resource):
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
agent_id = data.get("agent_id", None)
|
||||
agent_type = settings.AGENT_NAME
|
||||
agent_key = get_agent_key(agent_id, request.decoded_token.get("sub"))
|
||||
|
||||
if agent_key:
|
||||
data.update({"api_key": agent_key})
|
||||
else:
|
||||
agent_id = None
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
@@ -479,6 +505,7 @@ class Stream(Resource):
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
should_save_conversation=save_conv,
|
||||
agent_id=agent_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -138,14 +138,24 @@ class GetConversations(Resource):
|
||||
try:
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{"api_key": {"$exists": False}, "user": decoded_token.get("sub")}
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{"id": str(conversation["_id"]), "name": conversation["name"]}
|
||||
{
|
||||
"id": str(conversation["_id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
@@ -179,7 +189,12 @@ class GetSingleConversation(Resource):
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving conversation: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(conversation["queries"]), 200)
|
||||
|
||||
data = {
|
||||
"queries": conversation["queries"],
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@user_ns.route("/api/update_conversation_name")
|
||||
|
||||
Reference in New Issue
Block a user