feat: shared and pinning agents + fix for streaming tools

This commit is contained in:
Siddhant Rai
2025-05-12 06:06:11 +05:30
parent 07fa656e7c
commit 6520be5b85
16 changed files with 1015 additions and 169 deletions

View File

@@ -88,19 +88,28 @@ def run_async_chain(chain, question, chat_history):
def get_agent_key(agent_id, user_id):
if not agent_id:
return None
return None, False, None
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)
if agent.get("user") == user_id:
is_owner = agent.get("user") == user_id
if is_owner:
agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
)
return str(agent["key"])
return str(agent["key"]), False, None
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if is_shared_with_user:
return str(agent["key"]), True, agent.get("shared_token")
raise Exception("Unauthorized access to the agent", 403)
@@ -153,6 +162,8 @@ def save_conversation(
index=None,
api_key=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
@@ -228,6 +239,9 @@ def save_conversation(
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
@@ -261,6 +275,8 @@ def complete_stream(
should_save_conversation=True,
attachments=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
@@ -325,6 +341,8 @@ def complete_stream(
index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
else:
conversation_id = None
@@ -433,7 +451,9 @@ class Stream(Resource):
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
agent_key = get_agent_key(agent_id, request.decoded_token.get("sub"))
agent_key, is_shared_usage, shared_token = get_agent_key(
agent_id, request.decoded_token.get("sub")
)
if agent_key:
data.update({"api_key": agent_key})
@@ -448,7 +468,10 @@ class Stream(Resource):
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}
if is_shared_usage:
decoded_token = request.decoded_token
else:
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
@@ -514,6 +537,8 @@ class Stream(Resource):
index=index,
should_save_conversation=save_conv,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
),
mimetype="text/event-stream",
)
@@ -881,6 +906,8 @@ def get_attachments_content(attachment_ids, user):
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(f"Error retrieving attachment {attachment_id}: {e}", exc_info=True)
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments

View File

@@ -41,6 +41,12 @@ shared_conversations_collections = db["shared_conversations"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
agents_collection.create_index(
[("shared", 1)],
name="shared_index",
background=True,
)
user = Blueprint("user", __name__)
user_ns = Namespace("user", description="User related operations", path="/")
api.add_namespace(user_ns)
@@ -166,6 +172,8 @@ class GetConversations(Resource):
"id": str(conversation["_id"]),
"name": conversation["name"],
"agent_id": conversation.get("agent_id", None),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
for conversation in conversations
]
@@ -208,6 +216,8 @@ class GetSingleConversation(Resource):
data = {
"queries": conversation["queries"],
"agent_id": conversation.get("agent_id"),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
return make_response(jsonify(data), 200)
@@ -1034,6 +1044,9 @@ class GetAgent(Resource):
else ""
),
"pinned": agent.get("pinned", False),
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
}
except Exception as err:
current_app.logger.error(f"Error retrieving agent: {err}", exc_info=True)
@@ -1077,6 +1090,9 @@ class GetAgents(Resource):
else ""
),
"pinned": agent.get("pinned", False),
"shared": agent.get("shared_publicly", False),
"shared_metadata": agent.get("shared_metadata", {}),
"shared_token": agent.get("shared_token", ""),
}
for agent in agents
if "source" in agent or "retriever" in agent
@@ -1478,6 +1494,195 @@ class PinAgent(Resource):
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/shared_agent")
class SharedAgent(Resource):
@api.doc(
params={
"token": "Shared token of the agent",
},
description="Get a shared agent by token or ID",
)
def get(self):
shared_token = request.args.get("token")
if not shared_token:
return make_response(
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
query = {}
query["shared_publicly"] = True
query["shared_token"] = shared_token
shared_agent = agents_collection.find_one(query)
if not shared_agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
data = {
"id": str(shared_agent["_id"]),
"user": shared_agent.get("user", ""),
"name": shared_agent.get("name", ""),
"description": shared_agent.get("description", ""),
"tools": shared_agent.get("tools", []),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
"shared_token": shared_agent.get("shared_token", ""),
"shared_metadata": shared_agent.get("shared_metadata", {}),
}
if data["tools"]:
enriched_tools = []
for tool in data["tools"]:
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
if tool_data:
enriched_tools.append(tool_data.get("displayName", ""))
data["tools"] = enriched_tools
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(data), 200)
@user_ns.route("/api/shared_agents")
class SharedAgents(Resource):
@api.doc(description="Get shared agents")
def get(self):
try:
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
shared_agents = agents_collection.find(
{"shared_publicly": True, "user": {"$ne": user}}
)
list_shared_agents = [
{
"id": str(shared_agent["_id"]),
"name": shared_agent.get("name", ""),
"description": shared_agent.get("description", ""),
"tools": shared_agent.get("tools", []),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
"shared_token": shared_agent.get("shared_token", ""),
"shared_metadata": shared_agent.get("shared_metadata", {}),
}
for shared_agent in shared_agents
]
except Exception as err:
current_app.logger.error(f"Error retrieving shared agents: {err}")
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_shared_agents), 200)
@user_ns.route("/api/share_agent")
class ShareAgent(Resource):
@api.expect(
api.model(
"ShareAgentModel",
{
"id": fields.String(required=True, description="ID of the agent"),
"shared": fields.Boolean(
required=True, description="Share or unshare the agent"
),
"username": fields.String(
required=False, description="Name of the user"
),
},
)
)
@api.doc(description="Share or unshare an agent")
def put(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(
jsonify({"success": False, "message": "Missing JSON body"}), 400
)
agent_id = data.get("id")
shared = data.get("shared")
username = data.get("username", "")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
if shared is None:
return make_response(
jsonify(
{
"success": False,
"message": "Shared parameter is required and must be true or false",
}
),
400,
)
try:
try:
agent_oid = ObjectId(agent_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid agent ID"}), 400
)
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(datetime.timezone.utc),
}
shared_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{
"$set": {
"shared_publicly": shared,
"shared_metadata": shared_metadata,
"shared_token": shared_token,
}
},
)
else:
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{"$set": {"shared_publicly": shared, "shared_token": None}},
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200
)
@user_ns.route("/api/agent_webhook")
class AgentWebhook(Resource):
@api.doc(