mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
* feat: Implement model registry and capabilities for multi-provider support - Added ModelRegistry to manage available models and their capabilities. - Introduced ModelProvider enum for different LLM providers. - Created ModelCapabilities dataclass to define model features. - Implemented methods to load models based on API keys and settings. - Added utility functions for model management in model_utils.py. - Updated settings.py to include provider-specific API keys. - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry. - Enhanced utility functions to handle token limits and model validation. - Improved code structure and logging for better maintainability. * feat: Add model selection feature with API integration and UI component * feat: Add model selection and default model functionality in agent management * test: Update assertions and formatting in stream processing tests * refactor(llm): Standardize model identifier to model_id * fix tests --------- Co-authored-by: Alex <a@tushynski.me>
1141 lines
46 KiB
Python
1141 lines
46 KiB
Python
"""Agent management routes."""
|
|
|
|
import datetime
|
|
import json
|
|
import uuid
|
|
|
|
from bson.dbref import DBRef
|
|
from bson.objectid import ObjectId
|
|
from flask import current_app, jsonify, make_response, request
|
|
from flask_restx import fields, Namespace, Resource
|
|
|
|
from application.api import api
|
|
from application.api.user.base import (
|
|
agents_collection,
|
|
db,
|
|
ensure_user_doc,
|
|
handle_image_upload,
|
|
resolve_tool_details,
|
|
storage,
|
|
users_collection,
|
|
)
|
|
from application.core.settings import settings
|
|
from application.utils import (
|
|
check_required_fields,
|
|
generate_image_url,
|
|
validate_required_fields,
|
|
)
|
|
|
|
|
|
# flask-restx namespace that every agent-management route in this module is
# registered on; all of its routes are served under the "/api" path.
agents_ns = Namespace(
    name="agents",
    description="Agent management operations",
    path="/api",
)
|
|
|
|
|
|
@agents_ns.route("/get_agent")
class GetAgent(Resource):
    # Returns a single agent that belongs to the authenticated user.

    @staticmethod
    def _serialize_sources(source_refs):
        """Turn stored source references into JSON-friendly identifiers.

        Every ``DBRef`` is dereferenced exactly once (the lookup hits the
        database, so repeating it per reference is wasteful). References
        whose target document cannot be found are dropped, the literal
        ``"default"`` is passed through, and any other value is ignored.
        """
        serialized = []
        for source_ref in source_refs:
            if isinstance(source_ref, DBRef):
                source_doc = db.dereference(source_ref)
                if source_doc:
                    serialized.append(str(source_doc["_id"]))
            elif source_ref == "default":
                serialized.append(source_ref)
        return serialized

    @api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
    def get(self):
        # Responses:
        #   401 - no decoded token on the request
        #   400 - "id" query parameter missing, or the lookup/serialization raised
        #   404 - no agent with that ID is owned by the caller
        #   200 - the serialized agent
        if not (decoded_token := request.decoded_token):
            return {"success": False}, 401
        if not (agent_id := request.args.get("id")):
            return {"success": False, "message": "ID required"}, 400
        try:
            # Filtering on "user" keeps callers from reading other users' agents.
            agent = agents_collection.find_one(
                {"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
            )
            if not agent:
                return {"status": "Not found"}, 404

            # Legacy single-source field: only a resolvable DBRef is reported.
            legacy_source = agent.get("source")
            source_doc = (
                db.dereference(legacy_source)
                if isinstance(legacy_source, DBRef)
                else None
            )

            data = {
                "id": str(agent["_id"]),
                "name": agent["name"],
                "description": agent.get("description", ""),
                "image": (
                    generate_image_url(agent["image"]) if agent.get("image") else ""
                ),
                "source": str(source_doc["_id"]) if source_doc else "",
                "sources": self._serialize_sources(agent.get("sources", [])),
                "chunks": agent["chunks"],
                "retriever": agent.get("retriever", ""),
                "prompt_id": agent.get("prompt_id", ""),
                "tools": agent.get("tools", []),
                "tool_details": resolve_tool_details(agent.get("tools", [])),
                "agent_type": agent.get("agent_type", ""),
                "status": agent.get("status", ""),
                "json_schema": agent.get("json_schema"),
                "limited_token_mode": agent.get("limited_token_mode", False),
                "token_limit": agent.get(
                    "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
                ),
                "limited_request_mode": agent.get("limited_request_mode", False),
                "request_limit": agent.get(
                    "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
                ),
                "created_at": agent.get("createdAt", ""),
                "updated_at": agent.get("updatedAt", ""),
                "last_used_at": agent.get("lastUsedAt", ""),
                # Only the first and last four characters of the key are exposed.
                "key": (
                    f"{agent['key'][:4]}...{agent['key'][-4:]}"
                    if "key" in agent
                    else ""
                ),
                "pinned": agent.get("pinned", False),
                "shared": agent.get("shared_publicly", False),
                "shared_metadata": agent.get("shared_metadata", {}),
                "shared_token": agent.get("shared_token", ""),
                "models": agent.get("models", []),
                "default_model_id": agent.get("default_model_id", ""),
            }
            return make_response(jsonify(data), 200)
        except Exception as e:
            current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
            return {"success": False}, 400
|
|
|
|
|
|
@agents_ns.route("/get_agents")
class GetAgents(Resource):
    # Lists every agent owned by the authenticated user.

    @staticmethod
    def _serialize_sources(source_refs):
        """Turn stored source references into JSON-friendly identifiers.

        Every ``DBRef`` is dereferenced exactly once (the lookup hits the
        database, and this runs for each agent in the list). References
        whose target document cannot be found are dropped, the literal
        ``"default"`` is passed through, and any other value is ignored.
        """
        serialized = []
        for source_ref in source_refs:
            if source_ref == "default":
                serialized.append(source_ref)
            elif isinstance(source_ref, DBRef):
                source_doc = db.dereference(source_ref)
                if source_doc:
                    serialized.append(str(source_doc["_id"]))
        return serialized

    @staticmethod
    def _serialize_legacy_source(source_ref):
        """Return the legacy single-source value as a string.

        A resolvable ``DBRef`` becomes the target document's ID, the literal
        ``"default"`` is kept, and everything else becomes ``""``.
        """
        if isinstance(source_ref, DBRef):
            source_doc = db.dereference(source_ref)
            return str(source_doc["_id"]) if source_doc else ""
        return source_ref if source_ref == "default" else ""

    @classmethod
    def _serialize_agent(cls, agent, pinned_ids):
        """Build the response dictionary for one agent document.

        ``pinned_ids`` is the set of agent ID strings the user has pinned.
        """
        return {
            "id": str(agent["_id"]),
            "name": agent["name"],
            "description": agent.get("description", ""),
            "image": (
                generate_image_url(agent["image"]) if agent.get("image") else ""
            ),
            "source": cls._serialize_legacy_source(agent.get("source")),
            "sources": cls._serialize_sources(agent.get("sources", [])),
            "chunks": agent["chunks"],
            "retriever": agent.get("retriever", ""),
            "prompt_id": agent.get("prompt_id", ""),
            "tools": agent.get("tools", []),
            "tool_details": resolve_tool_details(agent.get("tools", [])),
            "agent_type": agent.get("agent_type", ""),
            "status": agent.get("status", ""),
            "json_schema": agent.get("json_schema"),
            "limited_token_mode": agent.get("limited_token_mode", False),
            "token_limit": agent.get(
                "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
            ),
            "limited_request_mode": agent.get("limited_request_mode", False),
            "request_limit": agent.get(
                "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
            ),
            "created_at": agent.get("createdAt", ""),
            "updated_at": agent.get("updatedAt", ""),
            "last_used_at": agent.get("lastUsedAt", ""),
            # Only the first and last four characters of the key are exposed.
            "key": (
                f"{agent['key'][:4]}...{agent['key'][-4:]}"
                if "key" in agent
                else ""
            ),
            # Pin state comes from the user's preferences, not the agent doc.
            "pinned": str(agent["_id"]) in pinned_ids,
            "shared": agent.get("shared_publicly", False),
            "shared_metadata": agent.get("shared_metadata", {}),
            "shared_token": agent.get("shared_token", ""),
            "models": agent.get("models", []),
            "default_model_id": agent.get("default_model_id", ""),
        }

    @api.doc(description="Retrieve agents for the user")
    def get(self):
        # Responses:
        #   401 - no decoded token on the request
        #   400 - loading or serializing the agents raised
        #   200 - JSON list of the user's agents
        if not (decoded_token := request.decoded_token):
            return {"success": False}, 401
        user = decoded_token.get("sub")
        try:
            user_doc = ensure_user_doc(user)
            pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))

            agents = agents_collection.find({"user": user})
            # Documents with neither a "source" nor a "retriever" key are skipped.
            list_agents = [
                self._serialize_agent(agent, pinned_ids)
                for agent in agents
                if "source" in agent or "retriever" in agent
            ]
        except Exception as err:
            current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(list_agents), 200)
|
|
|
|
|
|
@agents_ns.route("/create_agent")
class CreateAgent(Resource):
    # Creates a draft or published agent for the authenticated user.

    create_agent_model = api.model(
        "CreateAgentModel",
        {
            "name": fields.String(required=True, description="Name of the agent"),
            "description": fields.String(
                required=True, description="Description of the agent"
            ),
            "image": fields.Raw(
                required=False, description="Image file upload", type="file"
            ),
            "source": fields.String(
                required=False, description="Source ID (legacy single source)"
            ),
            "sources": fields.List(
                fields.String,
                required=False,
                description="List of source identifiers for multiple sources",
            ),
            "chunks": fields.Integer(required=True, description="Chunks count"),
            "retriever": fields.String(required=True, description="Retriever ID"),
            "prompt_id": fields.String(required=True, description="Prompt ID"),
            "tools": fields.List(
                fields.String, required=False, description="List of tool identifiers"
            ),
            "agent_type": fields.String(required=True, description="Type of the agent"),
            "status": fields.String(
                required=True, description="Status of the agent (draft or published)"
            ),
            "json_schema": fields.Raw(
                required=False,
                description="JSON schema for enforcing structured output format",
            ),
            "limited_token_mode": fields.Boolean(
                required=False, description="Whether the agent is in limited token mode"
            ),
            "token_limit": fields.Integer(
                required=False, description="Token limit for the agent in limited mode"
            ),
            "limited_request_mode": fields.Boolean(
                required=False,
                description="Whether the agent is in limited request mode",
            ),
            "request_limit": fields.Integer(
                required=False,
                description="Request limit for the agent in limited mode",
            ),
            "models": fields.List(
                fields.String,
                required=False,
                description="List of available model IDs for this agent",
            ),
            "default_model_id": fields.String(
                required=False, description="Default model ID for this agent"
            ),
        },
    )

    @staticmethod
    def _error(message, status=400):
        """Build the JSON failure response used for validation errors."""
        return make_response(jsonify({"success": False, "message": message}), status)

    @staticmethod
    def _as_bool(value):
        """Normalize a flag that is a real bool in JSON but text in form data.

        For strings only the exact text ``"True"`` counts as true.
        """
        return value == "True" if isinstance(value, str) else bool(value)

    @api.expect(create_agent_model)
    @api.doc(description="Create a new agent")
    def post(self):
        # Accepts either a JSON body or multipart/form data (the latter is
        # needed for the optional image upload).
        # Responses:
        #   401 - no decoded token on the request
        #   400 - validation failure, image upload failure, or insert error
        #   201 - {"id": <new agent id>, "key": <API key, "" for drafts>}
        if not (decoded_token := request.decoded_token):
            return {"success": False}, 401
        user = decoded_token.get("sub")

        # Substring match so "application/json; charset=utf-8" is also
        # treated as JSON, matching the update endpoint's check.
        if "application/json" in (request.content_type or ""):
            data = request.get_json()
        else:
            data = request.form.to_dict()
            # Form values are plain text, so structured fields arrive
            # JSON-encoded; undecodable text falls back to an empty value.
            for field in ("tools", "sources", "json_schema", "models"):
                if field not in data:
                    continue
                try:
                    data[field] = json.loads(data[field])
                except json.JSONDecodeError:
                    data[field] = None if field == "json_schema" else []

        # A JSON body such as `null` or a list cannot be processed below.
        if not isinstance(data, dict):
            return self._error("Invalid request data")
        # Field names only: the payload itself is not written to the log.
        current_app.logger.debug("Create agent request fields: %s", sorted(data))

        # Validate the JSON schema if one was provided.
        json_schema = data.get("json_schema")
        if json_schema:
            if not isinstance(json_schema, dict):
                return self._error("JSON schema must be a valid JSON object")
            # Accept either a wrapper with a 'schema' property or a bare
            # schema carrying a 'type' property.
            if "schema" not in json_schema and "type" not in json_schema:
                return self._error(
                    "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"
                )

        if data.get("status") not in ["draft", "published"]:
            return self._error("Status must be either 'draft' or 'published'")

        if data.get("status") == "published":
            required_fields = [
                "name",
                "description",
                "chunks",
                "retriever",
                "prompt_id",
                "agent_type",
            ]
            # At least one of 'source' / 'sources' must be supplied. When
            # both are given, 'sources' wins further down.
            if not data.get("source") and not data.get("sources"):
                return self._error(
                    "Either 'source' or 'sources' field is required for published agents"
                )
            validate_fields = ["name", "description", "prompt_id", "agent_type"]
        else:
            # Drafts only need a name.
            required_fields = ["name"]
            validate_fields = []

        missing_fields = check_required_fields(data, required_fields)
        invalid_fields = validate_required_fields(data, validate_fields)
        if missing_fields:
            return missing_fields
        if invalid_fields:
            return invalid_fields

        image_url, error = handle_image_upload(request, "", user, storage)
        if error:
            return self._error("Image upload failed")

        try:
            # Only published agents receive a key.
            key = str(uuid.uuid4()) if data.get("status") == "published" else ""

            sources_list = []
            requested_sources = data.get("sources")
            if requested_sources:
                # Multi-source mode: entries that are neither "default" nor a
                # valid ObjectId are silently skipped, and the legacy field
                # is left empty.
                for source_id in requested_sources:
                    if source_id == "default":
                        sources_list.append("default")
                    elif ObjectId.is_valid(source_id):
                        sources_list.append(DBRef("sources", ObjectId(source_id)))
                source_field = ""
            else:
                source_value = data.get("source", "")
                if source_value == "default":
                    source_field = "default"
                elif ObjectId.is_valid(source_value):
                    source_field = DBRef("sources", ObjectId(source_value))
                else:
                    source_field = ""

            now = datetime.datetime.now(datetime.timezone.utc)
            new_agent = {
                "user": user,
                "name": data.get("name"),
                "description": data.get("description", ""),
                "image": image_url,
                "source": source_field,
                "sources": sources_list,
                "chunks": data.get("chunks", ""),
                "retriever": data.get("retriever", ""),
                "prompt_id": data.get("prompt_id", ""),
                "tools": data.get("tools", []),
                "agent_type": data.get("agent_type", ""),
                "status": data.get("status"),
                "json_schema": data.get("json_schema"),
                "limited_token_mode": self._as_bool(
                    data.get("limited_token_mode", False)
                ),
                "token_limit": int(
                    data.get(
                        "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
                    )
                ),
                "limited_request_mode": self._as_bool(
                    data.get("limited_request_mode", False)
                ),
                "request_limit": int(
                    data.get(
                        "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
                    )
                ),
                "createdAt": now,
                "updatedAt": now,
                "lastUsedAt": None,
                "key": key,
                "models": data.get("models", []),
                "default_model_id": data.get("default_model_id", ""),
            }
            # Defaults applied when the caller left these blank.
            if new_agent["chunks"] == "":
                new_agent["chunks"] = "2"
            if (
                new_agent["source"] == ""
                and new_agent["retriever"] == ""
                and not new_agent["sources"]
            ):
                new_agent["retriever"] = "classic"

            resp = agents_collection.insert_one(new_agent)
            new_id = str(resp.inserted_id)
        except Exception as err:
            current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"id": new_id, "key": key}), 201)
|
|
|
|
|
|
@agents_ns.route("/update_agent/<string:agent_id>")
class UpdateAgent(Resource):
    # Applies a partial update to an agent owned by the authenticated user.

    update_agent_model = api.model(
        "UpdateAgentModel",
        {
            "name": fields.String(required=True, description="New name of the agent"),
            "description": fields.String(
                required=True, description="New description of the agent"
            ),
            "image": fields.String(
                required=False, description="New image URL or identifier"
            ),
            "source": fields.String(
                required=False, description="Source ID (legacy single source)"
            ),
            "sources": fields.List(
                fields.String,
                required=False,
                description="List of source identifiers for multiple sources",
            ),
            "chunks": fields.Integer(required=True, description="Chunks count"),
            "retriever": fields.String(required=True, description="Retriever ID"),
            "prompt_id": fields.String(required=True, description="Prompt ID"),
            "tools": fields.List(
                fields.String, required=False, description="List of tool identifiers"
            ),
            "agent_type": fields.String(required=True, description="Type of the agent"),
            "status": fields.String(
                required=True, description="Status of the agent (draft or published)"
            ),
            "json_schema": fields.Raw(
                required=False,
                description="JSON schema for enforcing structured output format",
            ),
            "limited_token_mode": fields.Boolean(
                required=False, description="Whether the agent is in limited token mode"
            ),
            "token_limit": fields.Integer(
                required=False, description="Token limit for the agent in limited mode"
            ),
            "limited_request_mode": fields.Boolean(
                required=False,
                description="Whether the agent is in limited request mode",
            ),
            "request_limit": fields.Integer(
                required=False,
                description="Request limit for the agent in limited mode",
            ),
            "models": fields.List(
                fields.String,
                required=False,
                description="List of available model IDs for this agent",
            ),
            "default_model_id": fields.String(
                required=False, description="Default model ID for this agent"
            ),
        },
    )

    @staticmethod
    def _error(message, status=400):
        """Build the JSON failure response used throughout this endpoint."""
        return make_response(jsonify({"success": False, "message": message}), status)

    @staticmethod
    def _as_bool(value):
        """Normalize a flag that is a real bool in JSON but text in form data.

        For strings only the exact text ``"True"`` counts as true, so the
        form value ``"False"`` is not mistaken for an enabled flag.
        """
        return value == "True" if isinstance(value, str) else bool(value)

    @api.expect(update_agent_model)
    @api.doc(description="Update an existing agent")
    def put(self, agent_id):
        # Only fields present in the request are changed. Accepts a JSON body
        # or multipart/form data (needed for the optional image upload).
        # Responses:
        #   401 - no decoded token on the request
        #   400 - malformed ID, unparsable body, or a field failed validation
        #   404 - agent does not exist or is not owned by the caller
        #   500 - database error while reading or writing the agent
        #   200 - update applied (includes "key" when one was generated)
        if not (decoded_token := request.decoded_token):
            return self._error("Unauthorized", 401)
        user = decoded_token.get("sub")

        if not ObjectId.is_valid(agent_id):
            return self._error("Invalid agent ID format")
        oid = ObjectId(agent_id)

        try:
            if request.content_type and "application/json" in request.content_type:
                data = request.get_json()
            else:
                data = request.form.to_dict()
                # Form values are plain text, so structured fields arrive
                # JSON-encoded. Empty values are left as they are.
                json_fields = ["tools", "sources", "json_schema", "models"]
                for field in json_fields:
                    if field in data and data[field]:
                        try:
                            data[field] = json.loads(data[field])
                        except json.JSONDecodeError:
                            return self._error(
                                f"Invalid JSON format for field: {field}"
                            )
        except Exception as err:
            current_app.logger.error(
                f"Error parsing request data: {err}", exc_info=True
            )
            return self._error("Invalid request data")

        # A JSON body such as `null` or a list cannot be processed below.
        if not isinstance(data, dict):
            return self._error("Invalid request data")

        try:
            # Filtering on "user" keeps callers from editing other users' agents.
            existing_agent = agents_collection.find_one({"_id": oid, "user": user})
        except Exception as err:
            current_app.logger.error(
                f"Error finding agent {agent_id}: {err}", exc_info=True
            )
            return self._error("Database error finding agent", 500)
        if not existing_agent:
            return self._error("Agent not found or not authorized", 404)

        image_url, error = handle_image_upload(
            request, existing_agent.get("image", ""), user, storage
        )
        if error:
            current_app.logger.error(
                f"Image upload error for agent {agent_id}: {error}"
            )
            return self._error(f"Image upload failed: {error}")

        update_fields = {}
        allowed_fields = [
            "name",
            "description",
            "image",
            "source",
            "sources",
            "chunks",
            "retriever",
            "prompt_id",
            "tools",
            "agent_type",
            "status",
            "json_schema",
            "limited_token_mode",
            "token_limit",
            "limited_request_mode",
            "request_limit",
            "models",
            "default_model_id",
        ]

        for field in allowed_fields:
            if field not in data:
                continue

            if field == "status":
                new_status = data.get("status")
                if new_status not in ["draft", "published"]:
                    return self._error(
                        "Invalid status value. Must be 'draft' or 'published'"
                    )
                update_fields[field] = new_status

            elif field == "source":
                source_id = data.get("source")
                if source_id == "default":
                    update_fields[field] = "default"
                elif source_id and ObjectId.is_valid(source_id):
                    update_fields[field] = DBRef("sources", ObjectId(source_id))
                elif source_id:
                    return self._error(f"Invalid source ID format: {source_id}")
                else:
                    # An empty value clears the legacy source.
                    update_fields[field] = ""

            elif field == "sources":
                sources_list = data.get("sources", [])
                if sources_list and isinstance(sources_list, list):
                    valid_sources = []
                    for source_id in sources_list:
                        if source_id == "default":
                            valid_sources.append("default")
                        elif ObjectId.is_valid(source_id):
                            valid_sources.append(DBRef("sources", ObjectId(source_id)))
                        else:
                            return self._error(
                                f"Invalid source ID in list: {source_id}"
                            )
                    update_fields[field] = valid_sources
                else:
                    update_fields[field] = []

            elif field == "chunks":
                chunks_value = data.get("chunks")
                if chunks_value == "" or chunks_value is None:
                    update_fields[field] = "2"
                else:
                    try:
                        chunks_int = int(chunks_value)
                    except (ValueError, TypeError):
                        return self._error(f"Invalid chunks value: {chunks_value}")
                    if chunks_int < 0:
                        return self._error(
                            "Chunks value must be a non-negative integer"
                        )
                    # The chunk count is stored as text.
                    update_fields[field] = str(chunks_int)

            elif field == "tools":
                tools_list = data.get("tools", [])
                if not isinstance(tools_list, list):
                    return self._error("Tools must be a list")
                update_fields[field] = tools_list

            elif field == "json_schema":
                json_schema = data.get("json_schema")
                if json_schema is not None and not isinstance(json_schema, dict):
                    return self._error("JSON schema must be a valid object")
                update_fields[field] = json_schema

            elif field == "limited_token_mode":
                bool_value = self._as_bool(data.get("limited_token_mode", False))
                update_fields[field] = bool_value
                if bool_value and data.get("token_limit") is None:
                    return self._error(
                        "Token limit must be provided when limited token mode is enabled"
                    )

            elif field == "limited_request_mode":
                bool_value = self._as_bool(data.get("limited_request_mode", False))
                update_fields[field] = bool_value
                if bool_value and data.get("request_limit") is None:
                    return self._error(
                        "Request limit must be provided when limited request mode is enabled"
                    )

            elif field == "token_limit":
                token_limit = data.get("token_limit")
                # Non-numeric input is a client error, not a server crash.
                try:
                    update_fields[field] = int(token_limit) if token_limit else 0
                except (ValueError, TypeError):
                    return self._error(f"Invalid token limit value: {token_limit}")
                # The mode flag is normalized the same way it is stored, so
                # the form value "False" counts as disabled here.
                if update_fields[field] > 0 and not self._as_bool(
                    data.get("limited_token_mode", False)
                ):
                    return self._error(
                        "Token limit cannot be set when limited token mode is disabled"
                    )

            elif field == "request_limit":
                request_limit = data.get("request_limit")
                try:
                    update_fields[field] = int(request_limit) if request_limit else 0
                except (ValueError, TypeError):
                    return self._error(
                        f"Invalid request limit value: {request_limit}"
                    )
                if update_fields[field] > 0 and not self._as_bool(
                    data.get("limited_request_mode", False)
                ):
                    return self._error(
                        "Request limit cannot be set when limited request mode is disabled"
                    )

            else:
                value = data[field]
                if field in ["name", "description", "prompt_id", "agent_type"]:
                    if not value or not str(value).strip():
                        return self._error(f"Field '{field}' cannot be empty")
                update_fields[field] = value

        # A freshly uploaded image overrides any "image" value in the body.
        if image_url:
            update_fields["image"] = image_url
        if not update_fields:
            return self._error("No valid update data provided")

        newly_generated_key = None
        final_status = update_fields.get("status", existing_agent.get("status"))

        if final_status == "published":
            # A published agent must be complete once the pending changes are
            # layered over the stored document.
            required_published_fields = {
                "name": "Agent name",
                "description": "Agent description",
                "chunks": "Chunks count",
                "prompt_id": "Prompt",
                "agent_type": "Agent type",
            }
            missing_published_fields = [
                field_label
                for req_field, field_label in required_published_fields.items()
                if not update_fields.get(req_field, existing_agent.get(req_field))
            ]

            source_val = update_fields.get("source", existing_agent.get("source"))
            sources_val = update_fields.get(
                "sources", existing_agent.get("sources", [])
            )
            has_valid_source = (
                isinstance(source_val, DBRef)
                or source_val == "default"
                or (isinstance(sources_val, list) and len(sources_val) > 0)
            )
            if not has_valid_source:
                missing_published_fields.append("Source")

            if missing_published_fields:
                return self._error(
                    f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}"
                )
            # A key is generated when the stored agent has none yet.
            if not existing_agent.get("key"):
                newly_generated_key = str(uuid.uuid4())
                update_fields["key"] = newly_generated_key

        update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)

        try:
            result = agents_collection.update_one(
                {"_id": oid, "user": user}, {"$set": update_fields}
            )
            if result.matched_count == 0:
                return self._error("Agent not found or update failed", 404)
            if result.modified_count == 0 and result.matched_count == 1:
                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "message": "No changes detected",
                            "id": agent_id,
                        }
                    ),
                    200,
                )
        except Exception as err:
            current_app.logger.error(
                f"Error updating agent {agent_id}: {err}", exc_info=True
            )
            return self._error("Database error during update", 500)

        response_data = {
            "success": True,
            "id": agent_id,
            "message": "Agent updated successfully",
        }
        if newly_generated_key:
            response_data["key"] = newly_generated_key
        return make_response(jsonify(response_data), 200)
|
|
|
|
|
|
@agents_ns.route("/delete_agent")
class DeleteAgent(Resource):
    # Deletes one agent owned by the authenticated user.

    @api.doc(params={"id": "ID of the agent"}, description="Delete an agent by ID")
    def delete(self):
        # Responses:
        #   401 - no decoded token on the request
        #   400 - "id" query parameter missing, or the delete raised
        #   404 - no agent with that ID is owned by the caller
        #   200 - {"id": <deleted agent id>}
        if not (token := request.decoded_token):
            return make_response(jsonify({"success": False}), 401)
        owner = token.get("sub")

        target_id = request.args.get("id")
        if not target_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )

        try:
            # Matching on "user" as well means only the owner can delete.
            removed = agents_collection.find_one_and_delete(
                {"_id": ObjectId(target_id), "user": owner}
            )
            if not removed:
                return make_response(
                    jsonify({"success": False, "message": "Agent not found"}), 404
                )
            deleted_id = str(removed["_id"])
        except Exception as err:
            current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        else:
            return make_response(jsonify({"id": deleted_id}), 200)
|
|
|
|
|
|
@agents_ns.route("/pinned_agents")
class PinnedAgents(Resource):
    # Lists the agents the authenticated user has pinned.

    @staticmethod
    def _source_id(agent):
        """Return the legacy source's document ID, or ``""`` if unresolved.

        The reference is dereferenced once; the lookup hits the database.
        """
        source_ref = agent.get("source")
        if isinstance(source_ref, DBRef):
            source_doc = db.dereference(source_ref)
            if source_doc is not None:
                return str(source_doc["_id"])
        return ""

    @api.doc(description="Get pinned agents for the user")
    def get(self):
        # Responses:
        #   401 - no decoded token on the request
        #   400 - loading or serializing the pinned agents raised
        #   200 - JSON list of pinned agents (empty when nothing is pinned)
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user_id = decoded_token.get("sub")

        try:
            user_doc = ensure_user_doc(user_id)
            pinned_ids = user_doc.get("agent_preferences", {}).get("pinned", [])
            if not pinned_ids:
                return make_response(jsonify([]), 200)

            # Malformed IDs are left out of the query instead of raising, so
            # one bad entry cannot fail the endpoint on every call; they are
            # removed with the other stale IDs below.
            pinned_object_ids = [
                ObjectId(agent_id)
                for agent_id in pinned_ids
                if ObjectId.is_valid(agent_id)
            ]

            # No "user" filter here: the query matches pinned agents
            # regardless of who owns them.
            pinned_agents = list(
                agents_collection.find({"_id": {"$in": pinned_object_ids}})
            )
            existing_ids = {str(agent["_id"]) for agent in pinned_agents}

            # Clean up pinned IDs that no longer match an agent.
            stale_ids = [
                agent_id for agent_id in pinned_ids if agent_id not in existing_ids
            ]
            if stale_ids:
                users_collection.update_one(
                    {"user_id": user_id},
                    {"$pullAll": {"agent_preferences.pinned": stale_ids}},
                )

            list_pinned_agents = [
                {
                    "id": str(agent["_id"]),
                    "name": agent.get("name", ""),
                    "description": agent.get("description", ""),
                    "image": (
                        generate_image_url(agent["image"]) if agent.get("image") else ""
                    ),
                    "source": self._source_id(agent),
                    "chunks": agent.get("chunks", ""),
                    "retriever": agent.get("retriever", ""),
                    "prompt_id": agent.get("prompt_id", ""),
                    "tools": agent.get("tools", []),
                    "tool_details": resolve_tool_details(agent.get("tools", [])),
                    "agent_type": agent.get("agent_type", ""),
                    "status": agent.get("status", ""),
                    "created_at": agent.get("createdAt", ""),
                    "updated_at": agent.get("updatedAt", ""),
                    "last_used_at": agent.get("lastUsedAt", ""),
                    # Only the first and last four characters are exposed.
                    "key": (
                        f"{agent['key'][:4]}...{agent['key'][-4:]}"
                        if "key" in agent
                        else ""
                    ),
                    "pinned": True,
                }
                for agent in pinned_agents
                if "source" in agent or "retriever" in agent
            ]
        except Exception as err:
            # Traceback included, matching the other listing endpoints.
            current_app.logger.error(
                f"Error retrieving pinned agents: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(list_pinned_agents), 200)
|
|
|
|
|
|
@agents_ns.route("/template_agents")
class GetTemplateAgents(Resource):
    # Lists the template agents, i.e. documents whose "user" is "system".

    @api.doc(description="Get template/premade agents")
    def get(self):
        # Responses:
        #   400 - the query or serialization raised
        #   200 - JSON list of {id, name, description, image}
        try:
            summaries = []
            for template in agents_collection.find({"user": "system"}):
                summaries.append(
                    {
                        "id": str(template["_id"]),
                        "name": template["name"],
                        "description": template["description"],
                        "image": template.get("image", ""),
                    }
                )
            return make_response(jsonify(summaries), 200)
        except Exception as e:
            current_app.logger.error(f"Template agents fetch error: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
|
|
|
|
|
|
@agents_ns.route("/adopt_agent")
class AdoptAgent(Resource):
    # Copies a template agent (owned by "system") into the caller's account.

    @api.doc(params={"id": "Agent ID"}, description="Adopt an agent by ID")
    def post(self):
        # Responses:
        #   401 - no decoded token on the request
        #   400 - "id" query parameter missing, or the copy raised
        #   404 - no template agent with that ID
        #   200 - {"success": True, "agent": <the new agent>}
        if not (decoded_token := request.decoded_token):
            return make_response(jsonify({"success": False}), 401)

        if not (agent_id := request.args.get("id")):
            return make_response(
                jsonify({"success": False, "message": "ID required"}), 400
            )

        try:
            # Only documents owned by "system" can be adopted.
            agent = agents_collection.find_one(
                {"_id": ObjectId(agent_id), "user": "system"}
            )
            if not agent:
                return make_response(jsonify({"status": "Not found"}), 404)

            # Copy the template without its _id so the insert assigns a new
            # one, then hand the copy to the caller with a new key.
            new_agent = agent.copy()
            new_agent.pop("_id", None)
            new_agent["user"] = decoded_token["sub"]
            new_agent["status"] = "published"
            new_agent["lastUsedAt"] = datetime.datetime.now(datetime.timezone.utc)
            new_agent["key"] = str(uuid.uuid4())
            insert_result = agents_collection.insert_one(new_agent)

            # Build the response from a copy so the inserted document stays
            # untouched; the _id is replaced by a string "id".
            response_agent = new_agent.copy()
            response_agent.pop("_id", None)
            response_agent["id"] = str(insert_result.inserted_id)
            response_agent["tool_details"] = resolve_tool_details(
                response_agent.get("tools", [])
            )
            # Source references are reported as ID strings. This covers the
            # legacy single "source" and every entry of "sources"; a DBRef
            # left in the payload would make jsonify fail after the agent
            # has already been inserted.
            if isinstance(response_agent.get("source"), DBRef):
                response_agent["source"] = str(response_agent["source"].id)
            if isinstance(response_agent.get("sources"), list):
                response_agent["sources"] = [
                    str(source_ref.id) if isinstance(source_ref, DBRef) else source_ref
                    for source_ref in response_agent["sources"]
                ]
            return make_response(
                jsonify({"success": True, "agent": response_agent}), 200
            )
        except Exception as e:
            current_app.logger.error(f"Agent adopt error: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
|
|
|
|
|
|
@agents_ns.route("/pin_agent")
class PinAgent(Resource):
    # Toggles whether an agent is pinned for the authenticated user.

    @api.doc(params={"id": "ID of the agent"}, description="Pin or unpin an agent")
    def post(self):
        # Responses:
        #   401 - no decoded token on the request
        #   400 - "id" query parameter missing
        #   404 - no agent with that ID
        #   500 - the lookup or the preference update raised
        #   200 - {"success": True, "action": "pinned" | "unpinned"}
        if not (token := request.decoded_token):
            return make_response(jsonify({"success": False}), 401)
        user_id = token.get("sub")

        agent_id = request.args.get("id")
        if not agent_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )

        try:
            # The agent only has to exist; ownership is not checked here.
            if not agents_collection.find_one({"_id": ObjectId(agent_id)}):
                return make_response(
                    jsonify({"success": False, "message": "Agent not found"}), 404
                )

            preferences = ensure_user_doc(user_id).get("agent_preferences", {})
            already_pinned = agent_id in preferences.get("pinned", [])

            # One update call; the operator depends on the current state.
            operator = "$pull" if already_pinned else "$addToSet"
            users_collection.update_one(
                {"user_id": user_id},
                {operator: {"agent_preferences.pinned": agent_id}},
            )
            action = "unpinned" if already_pinned else "pinned"
        except Exception as err:
            current_app.logger.error(f"Error pinning/unpinning agent: {err}")
            return make_response(
                jsonify({"success": False, "message": "Server error"}), 500
            )
        return make_response(jsonify({"success": True, "action": action}), 200)
|
|
|
|
|
|
@agents_ns.route("/remove_shared_agent")
class RemoveSharedAgent(Resource):
    # Drops a publicly shared agent from the caller's own preference lists.

    @api.doc(
        params={"id": "ID of the shared agent"},
        description="Remove a shared agent from the current user's shared list",
    )
    def delete(self):
        # The agent document itself is not modified; only the caller's
        # "shared_with_me" and "pinned" lists are updated.
        # Responses:
        #   401 - no decoded token on the request
        #   400 - "id" query parameter missing
        #   404 - no publicly shared agent with that ID
        #   500 - the lookup or the preference update raised
        #   200 - {"success": True, "action": "removed"}
        if not (token := request.decoded_token):
            return make_response(jsonify({"success": False}), 401)
        user_id = token.get("sub")

        agent_id = request.args.get("id")
        if not agent_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )

        try:
            shared_agent = agents_collection.find_one(
                {"_id": ObjectId(agent_id), "shared_publicly": True}
            )
            if not shared_agent:
                return make_response(
                    jsonify({"success": False, "message": "Shared agent not found"}),
                    404,
                )

            ensure_user_doc(user_id)
            # Remove the ID from both lists in a single update.
            removal = {
                "agent_preferences.shared_with_me": agent_id,
                "agent_preferences.pinned": agent_id,
            }
            users_collection.update_one({"user_id": user_id}, {"$pull": removal})

            return make_response(jsonify({"success": True, "action": "removed"}), 200)
        except Exception as err:
            current_app.logger.error(f"Error removing shared agent: {err}")
            return make_response(
                jsonify({"success": False, "message": "Server error"}), 500
            )
|