feat: shared and pinning agents + fix for streaming tools

This commit is contained in:
Siddhant Rai
2025-05-12 06:06:11 +05:30
parent 07fa656e7c
commit 6520be5b85
16 changed files with 1015 additions and 169 deletions

View File

@@ -88,19 +88,28 @@ def run_async_chain(chain, question, chat_history):
def get_agent_key(agent_id, user_id):
if not agent_id:
return None
return None, False, None
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)
if agent.get("user") == user_id:
is_owner = agent.get("user") == user_id
if is_owner:
agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
)
return str(agent["key"])
return str(agent["key"]), False, None
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if is_shared_with_user:
return str(agent["key"]), True, agent.get("shared_token")
raise Exception("Unauthorized access to the agent", 403)
@@ -153,6 +162,8 @@ def save_conversation(
index=None,
api_key=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
@@ -228,6 +239,9 @@ def save_conversation(
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
@@ -261,6 +275,8 @@ def complete_stream(
should_save_conversation=True,
attachments=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
@@ -325,6 +341,8 @@ def complete_stream(
index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
else:
conversation_id = None
@@ -433,7 +451,9 @@ class Stream(Resource):
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
agent_key = get_agent_key(agent_id, request.decoded_token.get("sub"))
agent_key, is_shared_usage, shared_token = get_agent_key(
agent_id, request.decoded_token.get("sub")
)
if agent_key:
data.update({"api_key": agent_key})
@@ -448,7 +468,10 @@ class Stream(Resource):
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}
if is_shared_usage:
decoded_token = request.decoded_token
else:
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
@@ -514,6 +537,8 @@ class Stream(Resource):
index=index,
should_save_conversation=save_conv,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
),
mimetype="text/event-stream",
)
@@ -881,6 +906,8 @@ def get_attachments_content(attachment_ids, user):
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(f"Error retrieving attachment {attachment_id}: {e}", exc_info=True)
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments