feat: Add MCP Server management functionality

- Implemented encryption utility for securely storing sensitive credentials.
- Added MCPServerModal component for managing MCP server configurations.
- Updated AddToolModal to include MCP server management.
- Enhanced localization with new strings for MCP server features.
- Introduced SVG icons for MCP tools in the frontend.
- Created a new settings section for MCP server configurations in the application.
This commit is contained in:
Siddhant Rai
2025-09-03 15:41:59 +05:30
parent 44b8a11c04
commit 7c23f43c63
12 changed files with 1658 additions and 112 deletions

View File

@@ -227,6 +227,7 @@ class BaseAgent(ABC):
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(

View File

@@ -0,0 +1,424 @@
import json
import time
from typing import Any, Dict, List, Optional
import requests
from application.agents.tools.base import Tool
# Module-level cache mapping a server/auth cache key (see
# MCPTool._generate_cache_key) to {"session_id": str, "created_at": epoch_seconds}.
# Shared across MCPTool instances so repeated requests to the same server reuse
# an existing MCP session instead of re-initializing (entries expire after 30 min).
# NOTE(review): plain dict, no locking — assumes single-threaded access; confirm
# or guard with a lock if tools can run concurrently.
_mcp_session_cache = {}
class MCPTool(Tool):
    """
    MCP Tool

    Connect to remote Model Context Protocol (MCP) servers to access dynamic
    tools and resources. Supports various authentication methods and provides
    secure access to external services through the MCP protocol.
    """

    def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
        """
        Initialize the MCP Tool with configuration.

        Args:
            config: Dictionary containing MCP server configuration:
                - server_url: URL of the remote MCP server
                - auth_type: Type of authentication (api_key, bearer, basic, none)
                - encrypted_credentials: Encrypted credentials (if available)
                - timeout: Request timeout in seconds (default: 30)
            user_id: User ID for decrypting credentials (required if
                encrypted_credentials exist)
        """
        self.config = config
        self.server_url = config.get("server_url", "")
        self.auth_type = config.get("auth_type", "none")
        self.timeout = config.get("timeout", 30)

        # Decrypt credentials if they are encrypted. Import is kept local so
        # the encryption module is only loaded when encrypted credentials exist.
        self.auth_credentials = {}
        if config.get("encrypted_credentials") and user_id:
            from application.security.encryption import decrypt_credentials

            self.auth_credentials = decrypt_credentials(
                config["encrypted_credentials"], user_id
            )
        else:
            # Fallback to unencrypted credentials (for backward compatibility)
            self.auth_credentials = config.get("auth_credentials", {})

        self.available_tools = []
        self._session = requests.Session()
        self._mcp_session_id = None
        self._setup_authentication()
        self._cache_key = self._generate_cache_key()

    def _generate_cache_key(self) -> str:
        """Generate a unique cache key for this MCP server configuration.

        The key combines the server URL with a short, non-reversible hint of
        the credentials so different users/tokens against the same server do
        not share sessions.
        """
        # Use server URL + auth info to create unique key
        auth_key = ""
        if self.auth_type == "bearer":
            token = self.auth_credentials.get("bearer_token", "")
            # Only a 10-char prefix is embedded to avoid storing full secrets
            # in the module-level cache key.
            auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
        elif self.auth_type == "api_key":
            api_key = self.auth_credentials.get("api_key", "")
            auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
        elif self.auth_type == "basic":
            username = self.auth_credentials.get("username", "")
            auth_key = f"basic:{username}"
        else:
            auth_key = "none"
        return f"{self.server_url}#{auth_key}"

    def _get_cached_session(self) -> Optional[str]:
        """Get cached session ID if available and not expired (30 min TTL)."""
        global _mcp_session_cache
        if self._cache_key in _mcp_session_cache:
            session_data = _mcp_session_cache[self._cache_key]
            # Check if session is less than 30 minutes old
            if time.time() - session_data["created_at"] < 1800:  # 30 minutes
                return session_data["session_id"]
            else:
                # Remove expired session
                del _mcp_session_cache[self._cache_key]
        return None

    def _cache_session(self, session_id: str):
        """Cache the session ID for reuse by other instances of this config."""
        global _mcp_session_cache
        _mcp_session_cache[self._cache_key] = {
            "session_id": session_id,
            "created_at": time.time(),
        }

    def _setup_authentication(self):
        """Setup authentication for the MCP server connection.

        Reads credential fields from ``self.auth_credentials``:
        api_key/api_key_header for ``api_key`` auth, bearer_token for
        ``bearer`` auth, username/password for ``basic`` auth.
        """
        if self.auth_type == "api_key":
            api_key = self.auth_credentials.get("api_key", "")
            header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
            if api_key:
                self._session.headers.update({header_name: api_key})
        elif self.auth_type == "bearer":
            token = self.auth_credentials.get("bearer_token", "")
            if token:
                self._session.headers.update({"Authorization": f"Bearer {token}"})
        elif self.auth_type == "basic":
            username = self.auth_credentials.get("username", "")
            password = self.auth_credentials.get("password", "")
            if username and password:
                self._session.auth = (username, password)

    def _initialize_mcp_connection(self) -> Dict:
        """
        Initialize MCP connection with the server, using cached session if available.

        Returns:
            Server capabilities and information, ``{"cached": True}`` when a
            cached session was reused, or ``{"error": ..., "fallback": True}``
            when initialization failed (callers treat this as best-effort).
        """
        cached_session = self._get_cached_session()
        if cached_session:
            self._mcp_session_id = cached_session
            return {"cached": True}
        try:
            init_params = {
                "protocolVersion": "2024-11-05",
                "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
                "clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
            }
            response = self._make_mcp_request("initialize", init_params)
            # Per the MCP handshake, acknowledge initialization with a
            # notification (no response expected).
            self._make_mcp_request("notifications/initialized")
            return response
        except Exception as e:
            return {"error": str(e), "fallback": True}

    def _ensure_valid_session(self):
        """Ensure we have a valid MCP session, reinitializing if needed."""
        if not self._mcp_session_id:
            self._initialize_mcp_connection()

    def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
        """
        Make an MCP protocol request to the server with automatic session recovery.

        Args:
            method: MCP method name (e.g., "tools/list", "tools/call")
            params: Parameters for the MCP method

        Returns:
            Response data as dictionary

        Raises:
            Exception: If request fails after retry
        """
        mcp_message = {"jsonrpc": "2.0", "method": method}
        # JSON-RPC notifications must not carry an "id"; regular requests do.
        if not method.startswith("notifications/"):
            mcp_message["id"] = 1
        if params:
            mcp_message["params"] = params
        return self._execute_mcp_request(mcp_message, method)

    def _execute_mcp_request(
        self, mcp_message: Dict, method: str, is_retry: bool = False
    ) -> Dict:
        """Execute MCP request with optional one-shot retry on session failure."""
        try:
            headers = {"Content-Type": "application/json", "Accept": "application/json"}
            headers.update(self._session.headers)
            if self._mcp_session_id:
                headers["Mcp-Session-Id"] = self._mcp_session_id
            response = self._session.post(
                self.server_url.rstrip("/"),
                json=mcp_message,
                headers=headers,
                timeout=self.timeout,
            )
            # Capture the session id before raise_for_status so a session is
            # retained even when the server replies with an error status.
            if "mcp-session-id" in response.headers:
                self._mcp_session_id = response.headers["mcp-session-id"]
                self._cache_session(self._mcp_session_id)
            response.raise_for_status()
            if method.startswith("notifications/"):
                # Notifications have no response body to parse.
                return {}
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"Invalid JSON response: {response.text}")
            if "error" in result:
                error_msg = result["error"]
                if isinstance(error_msg, dict):
                    error_msg = error_msg.get("message", str(error_msg))
                raise Exception(f"MCP server error: {error_msg}")
            return result.get("result", result)
        except requests.exceptions.RequestException as e:
            if not is_retry and self._should_retry_with_new_session(e):
                self._invalidate_and_refresh_session()
                return self._execute_mcp_request(mcp_message, method, is_retry=True)
            raise Exception(f"MCP server request failed: {str(e)}")

    def _should_retry_with_new_session(self, error: Exception) -> bool:
        """Check if error indicates session invalidation and retry is warranted."""
        error_str = str(error).lower()
        return (
            any(
                indicator in error_str
                for indicator in [
                    "invalid session",
                    "session expired",
                    "unauthorized",
                    "401",
                    "403",
                ]
            )
            and self._mcp_session_id is not None
        )

    def _invalidate_and_refresh_session(self) -> None:
        """Invalidate current session and create a new one."""
        global _mcp_session_cache
        if self._cache_key in _mcp_session_cache:
            del _mcp_session_cache[self._cache_key]
        self._mcp_session_id = None
        self._initialize_mcp_connection()

    def discover_tools(self) -> List[Dict]:
        """
        Discover available tools from the MCP server using MCP protocol.

        Returns:
            List of tool definitions from the server (also stored on
            ``self.available_tools``)

        Raises:
            Exception: If the server cannot be reached or returns an error
        """
        try:
            self._ensure_valid_session()
            response = self._make_mcp_request("tools/list")
            # Servers may return {"tools": [...]} or a bare list.
            if isinstance(response, dict) and "tools" in response:
                self.available_tools = response["tools"]
                return self.available_tools
            elif isinstance(response, list):
                self.available_tools = response
                return self.available_tools
            else:
                self.available_tools = []
                return self.available_tools
        except Exception as e:
            raise Exception(f"Failed to discover tools from MCP server: {str(e)}")

    def execute_action(self, action_name: str, **kwargs) -> Any:
        """
        Execute an action on the remote MCP server using MCP protocol.

        Args:
            action_name: Name of the action to execute
            **kwargs: Parameters for the action

        Returns:
            Result from the MCP server

        Raises:
            Exception: If the call fails
        """
        self._ensure_valid_session()
        # Prepare call parameters for MCP protocol
        call_params = {"name": action_name, "arguments": kwargs}
        try:
            result = self._make_mcp_request("tools/call", call_params)
            return result
        except Exception as e:
            raise Exception(f"Failed to execute action '{action_name}': {str(e)}")

    def get_actions_metadata(self) -> List[Dict]:
        """
        Get metadata for all available actions.

        Returns:
            List of action metadata dictionaries with ``name``,
            ``description`` and a JSON-Schema ``parameters`` object
        """
        actions = []
        for tool in self.available_tools:
            # Parse MCP tool schema according to MCP specification
            # Check multiple possible schema locations for compatibility
            input_schema = (
                tool.get("inputSchema")
                or tool.get("input_schema")
                or tool.get("schema")
                or tool.get("parameters")
            )
            # Default empty schema if no inputSchema provided
            parameters_schema = {
                "type": "object",
                "properties": {},
                "required": [],
            }
            # Parse the inputSchema if it exists
            if input_schema:
                if isinstance(input_schema, dict):
                    # Handle standard JSON Schema format
                    if "properties" in input_schema:
                        parameters_schema = {
                            "type": input_schema.get("type", "object"),
                            "properties": input_schema.get("properties", {}),
                            "required": input_schema.get("required", []),
                        }
                        # Add additional schema properties if they exist
                        for key in ["additionalProperties", "description"]:
                            if key in input_schema:
                                parameters_schema[key] = input_schema[key]
                    else:
                        # Might be properties directly at root level
                        parameters_schema["properties"] = input_schema
            action = {
                "name": tool.get("name", ""),
                "description": tool.get("description", ""),
                "parameters": parameters_schema,
            }
            actions.append(action)
        return actions

    def get_config_requirements(self) -> Dict:
        """
        Get configuration requirements for the MCP tool.

        Returns:
            Dictionary describing required configuration
        """
        return {
            "server_url": {
                "type": "string",
                "description": "URL of the remote MCP server (e.g., https://api.example.com)",
                "required": True,
            },
            "auth_type": {
                "type": "string",
                "description": "Authentication type",
                "enum": ["none", "api_key", "bearer", "basic"],
                "default": "none",
                "required": True,
            },
            "auth_credentials": {
                "type": "object",
                "description": "Authentication credentials (varies by auth_type)",
                # Key names below match what _setup_authentication() reads
                # from auth_credentials (api_key_header / bearer_token), so a
                # config built from this schema actually takes effect.
                "properties": {
                    "api_key": {
                        "type": "string",
                        "description": "API key for api_key auth",
                    },
                    "api_key_header": {
                        "type": "string",
                        "description": "Header name for API key (default: X-API-Key)",
                        "default": "X-API-Key",
                    },
                    "bearer_token": {
                        "type": "string",
                        "description": "Bearer token for bearer auth",
                    },
                    "username": {
                        "type": "string",
                        "description": "Username for basic auth",
                    },
                    "password": {
                        "type": "string",
                        "description": "Password for basic auth",
                    },
                },
                "required": False,
            },
            "timeout": {
                "type": "integer",
                "description": "Request timeout in seconds",
                "default": 30,
                "minimum": 1,
                "maximum": 300,
                "required": False,
            },
        }

    def test_connection(self) -> Dict:
        """
        Test the connection to the MCP server and validate functionality.

        Returns:
            Dictionary with connection test results including tool count
        """
        try:
            init_result = self._initialize_mcp_connection()
            tools = self.discover_tools()
            message = f"Successfully connected to MCP server. Found {len(tools)} tools."
            if init_result.get("cached"):
                message += " (Using cached session)"
            elif init_result.get("fallback"):
                message += " (No formal initialization required)"
            return {
                "success": True,
                "message": message,
                "tools_count": len(tools),
                "session_id": self._mcp_session_id,
                "tools": [
                    tool.get("name", "unknown") for tool in tools[:5]
                ],  # First 5 tool names
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Connection failed: {str(e)}",
                "tools_count": 0,
                "error_type": type(e).__name__,
            }

View File

@@ -23,16 +23,31 @@ class ToolManager:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
def load_tool(self, tool_name, tool_config, user_id=None):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
# For MCP tools, pass the user_id for credential decryption
if tool_name == "mcp_tool" and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
# For MCP tools, they might not be pre-loaded, so load dynamically
if tool_name == "mcp_tool":
raise ValueError(f"Tool '{tool_name}' not loaded and no config provided for dynamic loading")
raise ValueError(f"Tool '{tool_name}' not loaded")
# For MCP tools, if user_id is provided, create a new instance with user context
if tool_name == "mcp_tool" and user_id:
# Load tool dynamically with user context for proper credential access
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):

View File

@@ -492,9 +492,9 @@ class DeleteOldIndexes(Resource):
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
@@ -508,7 +508,7 @@ class DeleteOldIndexes(Resource):
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
@@ -517,7 +517,7 @@ class DeleteOldIndexes(Resource):
storage.delete_file(f)
else:
storage.delete_file(file_path)
except FileNotFoundError:
pass
except Exception as err:
@@ -525,7 +525,7 @@ class DeleteOldIndexes(Resource):
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@@ -573,55 +573,75 @@ class UploadFile(Resource):
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = file.filename
safe_file = safe_filename(original_filename)
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
if zipfile.is_zipfile(temp_file_path):
try:
with zipfile.ZipFile(temp_file_path, 'r') as zip_ref:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
# Walk through extracted files and upload them
for root, _, files in os.walk(temp_dir):
for extracted_file in files:
if os.path.join(root, extracted_file) == temp_file_path:
if (
os.path.join(root, extracted_file)
== temp_file_path
):
continue
rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir)
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
with open(os.path.join(root, extracted_file), 'rb') as f:
with open(
os.path.join(root, extracted_file), "rb"
) as f:
storage.save_file(f, storage_path)
except Exception as e:
current_app.logger.error(f"Error extracting zip: {e}", exc_info=True)
current_app.logger.error(
f"Error extracting zip: {e}", exc_info=True
)
# If zip extraction fails, save the original zip file
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, 'rb') as f:
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
else:
# For non-zip files, save directly
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, 'rb') as f:
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
".jpg", ".jpeg",
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
user,
file_path=base_path,
filename=dir_name
filename=dir_name,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -635,12 +655,29 @@ class ManageSourceFiles(Resource):
api.model(
"ManageSourceFilesModel",
{
"source_id": fields.String(required=True, description="Source ID to modify"),
"operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"),
"file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"),
"directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"),
"file": fields.Raw(required=False, description="Files to add (for add operation)"),
"parent_dir": fields.String(required=False, description="Parent directory path relative to source root"),
"source_id": fields.String(
required=True, description="Source ID to modify"
),
"operation": fields.String(
required=True,
description="Operation: 'add', 'remove', or 'remove_directory'",
),
"file_paths": fields.List(
fields.String,
required=False,
description="File paths to remove (for remove operation)",
),
"directory_path": fields.String(
required=False,
description="Directory path to remove (for remove_directory operation)",
),
"file": fields.Raw(
required=False, description="Files to add (for add operation)"
),
"parent_dir": fields.String(
required=False,
description="Parent directory path relative to source root",
),
},
)
)
@@ -650,7 +687,9 @@ class ManageSourceFiles(Resource):
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401)
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
@@ -658,12 +697,24 @@ class ManageSourceFiles(Resource):
if not source_id or not operation:
return make_response(
jsonify({"success": False, "message": "source_id and operation are required"}), 400
jsonify(
{
"success": False,
"message": "source_id and operation are required",
}
),
400,
)
if operation not in ["add", "remove", "remove_directory"]:
return make_response(
jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400
jsonify(
{
"success": False,
"message": "operation must be 'add', 'remove', or 'remove_directory'",
}
),
400,
)
try:
@@ -674,34 +725,53 @@ class ManageSourceFiles(Resource):
)
try:
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not source:
return make_response(
jsonify({"success": False, "message": "Source not found or access denied"}), 404
jsonify(
{
"success": False,
"message": "Source not found or access denied",
}
),
404,
)
except Exception as err:
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": "Database error"}), 500)
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
parent_dir = request.form.get("parent_dir", "")
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
jsonify({"success": False, "message": "Invalid parent directory path"}), 400
jsonify(
{"success": False, "message": "Invalid parent directory path"}
),
400,
)
if operation == "add":
files = request.files.getlist("file")
if not files or all(file.filename == "" for file in files):
return make_response(
jsonify({"success": False, "message": "No files provided for add operation"}), 400
jsonify(
{
"success": False,
"message": "No files provided for add operation",
}
),
400,
)
added_files = []
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
@@ -720,26 +790,44 @@ class ManageSourceFiles(Resource):
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(jsonify({
"success": True,
"message": f"Added {len(added_files)} files",
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id
}), 200)
return make_response(
jsonify(
{
"success": True,
"message": f"Added {len(added_files)} files",
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
return make_response(
jsonify({"success": False, "message": "file_paths required for remove operation"}), 400
jsonify(
{
"success": False,
"message": "file_paths required for remove operation",
}
),
400,
)
try:
file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str
file_paths = (
json.loads(file_paths_str)
if isinstance(file_paths_str, str)
else file_paths_str
)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid file_paths format"}), 400
jsonify(
{"success": False, "message": "Invalid file_paths format"}
),
400,
)
# Remove files from storage and directory structure
@@ -757,18 +845,29 @@ class ManageSourceFiles(Resource):
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(jsonify({
"success": True,
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id
}), 200)
return make_response(
jsonify(
{
"success": True,
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
return make_response(
jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400
jsonify(
{
"success": False,
"message": "directory_path required for remove_directory operation",
}
),
400,
)
# Validate directory path (prevent path traversal)
@@ -778,10 +877,17 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
)
return make_response(
jsonify({"success": False, "message": "Invalid directory path"}), 400
jsonify(
{"success": False, "message": "Invalid directory path"}
),
400,
)
full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
else source_file_path
)
if not storage.is_directory(full_directory_path):
current_app.logger.warning(
@@ -790,7 +896,13 @@ class ManageSourceFiles(Resource):
f"Full path: {full_directory_path}"
)
return make_response(
jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404
jsonify(
{
"success": False,
"message": "Directory not found or is not a directory",
}
),
404,
)
success = storage.remove_directory(full_directory_path)
@@ -802,7 +914,10 @@ class ManageSourceFiles(Resource):
f"Full path: {full_directory_path}"
)
return make_response(
jsonify({"success": False, "message": "Failed to remove directory"}), 500
jsonify(
{"success": False, "message": "Failed to remove directory"}
),
500,
)
current_app.logger.info(
@@ -816,12 +931,17 @@ class ManageSourceFiles(Resource):
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(jsonify({
"success": True,
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id
}), 200)
return make_response(
jsonify(
{
"success": True,
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
}
),
200,
)
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
@@ -835,8 +955,12 @@ class ManageSourceFiles(Resource):
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True)
return make_response(jsonify({"success": False, "message": "Operation failed"}), 500)
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Operation failed"}), 500
)
@user_ns.route("/api/remote")
@@ -984,7 +1108,7 @@ class PaginatedSources(Resource):
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"isNested": bool(doc.get("directory_structure"))
"isNested": bool(doc.get("directory_structure")),
}
paginated_docs.append(doc_data)
response = {
@@ -1032,7 +1156,7 @@ class CombinedJson(Resource):
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"is_nested": bool(index.get("directory_structure"))
"is_nested": bool(index.get("directory_structure")),
}
)
except Exception as err:
@@ -1381,7 +1505,8 @@ class CreateAgent(Resource):
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False, description="JSON schema for enforcing structured output format"
required=False,
description="JSON schema for enforcing structured output format",
),
},
)
@@ -1407,7 +1532,7 @@ class CreateAgent(Resource):
except json.JSONDecodeError:
data["json_schema"] = None
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
try:
@@ -1415,20 +1540,32 @@ class CreateAgent(Resource):
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}),
400
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}),
400
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
return make_response(
jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}),
400
jsonify(
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -1529,7 +1666,8 @@ class UpdateAgent(Resource):
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False, description="JSON schema for enforcing structured output format"
required=False,
description="JSON schema for enforcing structured output format",
),
},
)
@@ -3297,6 +3435,31 @@ class CreateTool(Resource):
param_details["value"] = ""
transformed_actions.append(action)
try:
# Process config to encrypt credentials for MCP tools
config = data["config"]
if data["name"] == "mcp_tool":
from application.security.encryption import encrypt_credentials
# Extract credentials from config
credentials = {}
if config.get("auth_type") == "bearer":
credentials["bearer_token"] = config.get("bearer_token", "")
elif config.get("auth_type") == "api_key":
credentials["api_key"] = config.get("api_key", "")
credentials["api_key_header"] = config.get("api_key_header", "")
elif config.get("auth_type") == "basic":
credentials["username"] = config.get("username", "")
credentials["password"] = config.get("password", "")
# Encrypt credentials if any exist
if credentials:
config["encrypted_credentials"] = encrypt_credentials(
credentials, user
)
# Remove plaintext credentials from config
for key in credentials.keys():
config.pop(key, None)
new_tool = {
"user": user,
"name": data["name"],
@@ -3304,7 +3467,7 @@ class CreateTool(Resource):
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": data["config"],
"config": config,
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
@@ -3371,7 +3534,41 @@ class UpdateTool(Resource):
),
400,
)
update_data["config"] = data["config"]
# Handle MCP tool credential encryption
config = data["config"]
tool_name = data.get("name")
if not tool_name:
# Get the tool name from the database
existing_tool = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
tool_name = existing_tool.get("name") if existing_tool else None
if tool_name == "mcp_tool":
from application.security.encryption import encrypt_credentials
# Extract credentials from config
credentials = {}
if config.get("auth_type") == "bearer":
credentials["bearer_token"] = config.get("bearer_token", "")
elif config.get("auth_type") == "api_key":
credentials["api_key"] = config.get("api_key", "")
credentials["api_key_header"] = config.get("api_key_header", "")
elif config.get("auth_type") == "basic":
credentials["username"] = config.get("username", "")
credentials["password"] = config.get("password", "")
# Encrypt credentials if any exist
if credentials:
config["encrypted_credentials"] = encrypt_credentials(
credentials, user
)
# Remove plaintext credentials from config
for key in credentials.keys():
config.pop(key, None)
update_data["config"] = config
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
@@ -3537,7 +3734,7 @@ class GetChunks(Resource):
"page": "Page number for pagination",
"per_page": "Number of chunks per page",
"path": "Optional: Filter chunks by relative file path",
"search": "Optional: Search term to filter chunks by title or content"
"search": "Optional: Search term to filter chunks by title or content",
},
)
def get(self):
@@ -3561,7 +3758,7 @@ class GetChunks(Resource):
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
filtered_chunks = []
for chunk in chunks:
metadata = chunk.get("metadata", {})
@@ -3582,9 +3779,9 @@ class GetChunks(Resource):
continue
filtered_chunks.append(chunk)
chunks = filtered_chunks
total_chunks = len(chunks)
start = (page - 1) * per_page
end = start + per_page
@@ -3598,7 +3795,7 @@ class GetChunks(Resource):
"total": total_chunks,
"chunks": paginated_chunks,
"path": path if path else None,
"search": search_term if search_term else None
"search": search_term if search_term else None,
}
),
200,
@@ -3607,6 +3804,7 @@ class GetChunks(Resource):
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@user_ns.route("/api/add_chunk")
class AddChunk(Resource):
@api.expect(
@@ -3773,7 +3971,9 @@ class UpdateChunk(Resource):
deleted = store.delete_chunk(chunk_id)
if not deleted:
current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created")
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
return make_response(
jsonify(
@@ -3905,39 +4105,233 @@ class DirectoryStructure(Resource):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(
jsonify({"error": "Document ID is required"}), 400
)
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
directory_structure = doc.get("directory_structure", {})
return make_response(
jsonify({
"success": True,
"directory_structure": directory_structure,
"base_path": doc.get("file_path", "")
}), 200
jsonify(
{
"success": True,
"directory_structure": directory_structure,
"base_path": doc.get("file_path", ""),
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(
jsonify({"success": False, "error": str(e)}), 500
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@user_ns.route("/api/mcp_servers")
class MCPServers(Resource):
    @api.doc(description="Get all MCP servers configured by the user")
    def get(self):
        """Return every MCP server (tool docs named "mcp_tool") owned by the caller."""
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            cursor = user_tools_collection.find({"user": user, "name": "mcp_tool"})
            servers = []
            for doc in cursor:
                cfg = doc.get("config", {})
                oid = doc.get("_id")
                servers.append(
                    {
                        "id": str(doc["_id"]),
                        "name": doc.get("displayName", "MCP Server"),
                        "server_url": cfg.get("server_url", ""),
                        "auth_type": cfg.get("auth_type", "none"),
                        "status": doc.get("status", False),
                        # An ObjectId embeds its creation timestamp.
                        "created_at": oid.generation_time.isoformat() if oid else None,
                    }
                )
            return make_response(jsonify({"success": True, "servers": servers}), 200)
        except Exception as e:
            current_app.logger.error(
                f"Error retrieving MCP servers: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False, "error": str(e)}), 500)
@user_ns.route("/api/mcp_server/<string:server_id>/test")
class TestMCPServer(Resource):
    @api.doc(description="Test connection to an MCP server")
    def post(self, server_id):
        """Run a connectivity check against the stored MCP server configuration."""
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            tool_doc = user_tools_collection.find_one(
                {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
            )
            if not tool_doc:
                return make_response(
                    jsonify({"success": False, "error": "MCP server not found"}), 404
                )
            # Imported lazily to avoid a module-level import cycle.
            from application.agents.tools.mcp_tool import MCPTool

            result = MCPTool(tool_doc.get("config", {}), user).test_connection()
            return make_response(jsonify(result), 200)
        except Exception as e:
            current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
            return make_response(
                jsonify(
                    {"success": False, "error": f"Connection test failed: {str(e)}"}
                ),
                500,
            )
@user_ns.route("/api/mcp_server/<string:server_id>/tools")
class MCPServerTools(Resource):
    @api.doc(description="Discover and get tools from an MCP server")
    def get(self, server_id):
        """Discover the server's tools and persist them as action metadata."""
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            tool_doc = user_tools_collection.find_one(
                {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
            )
            if not tool_doc:
                return make_response(
                    jsonify({"success": False, "error": "MCP server not found"}), 404
                )
            # Imported lazily to avoid a module-level import cycle.
            from application.agents.tools.mcp_tool import MCPTool

            mcp_tool = MCPTool(tool_doc.get("config", {}), user)
            tools = mcp_tool.discover_tools()

            # Normalize discovered actions to the shape used by other tool types:
            # mark each action active and default every parameter to LLM-filled.
            actions = []
            for action in mcp_tool.get_actions_metadata():
                action["active"] = True
                if "parameters" in action:
                    if "properties" in action["parameters"]:
                        for details in action["parameters"]["properties"].values():
                            details["filled_by_llm"] = True
                            details["value"] = ""
                actions.append(action)

            # Persist the normalized actions on the stored tool document.
            user_tools_collection.update_one(
                {"_id": ObjectId(server_id)}, {"$set": {"actions": actions}}
            )
            return make_response(
                jsonify({"success": True, "tools": tools, "actions": actions}),
                200,
            )
        except Exception as e:
            current_app.logger.error(f"Error discovering MCP tools: {e}", exc_info=True)
            return make_response(
                jsonify(
                    {"success": False, "error": f"Tool discovery failed: {str(e)}"}
                ),
                500,
            )
@user_ns.route("/api/mcp_server/<string:server_id>/tools/<string:action_name>")
class MCPServerToolAction(Resource):
    @api.expect(
        api.model(
            "MCPToolActionModel",
            {
                "parameters": fields.Raw(
                    required=False, description="Parameters for the tool action"
                )
            },
        )
    )
    @api.doc(description="Execute a specific tool action on an MCP server")
    def post(self, server_id, action_name):
        """Invoke one named action on the MCP server with caller-supplied parameters."""
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        payload = request.get_json() or {}
        parameters = payload.get("parameters", {})
        try:
            tool_doc = user_tools_collection.find_one(
                {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
            )
            if not tool_doc:
                return make_response(
                    jsonify({"success": False, "error": "MCP server not found"}), 404
                )
            # Imported lazily to avoid a module-level import cycle.
            from application.agents.tools.mcp_tool import MCPTool

            result = MCPTool(tool_doc.get("config", {}), user).execute_action(
                action_name, **parameters
            )
            return make_response(jsonify({"success": True, "result": result}), 200)
        except Exception as e:
            current_app.logger.error(
                f"Error executing MCP tool action: {e}", exc_info=True
            )
            return make_response(
                jsonify(
                    {"success": False, "error": f"Action execution failed: {str(e)}"}
                ),
                500,
            )

View File

@@ -89,7 +89,7 @@ class Settings(BaseSettings):
QDRANT_HOST: Optional[str] = None
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config

View File

View File

@@ -0,0 +1,97 @@
"""
Simple encryption utility for securely storing sensitive credentials.
Uses XOR encryption with a key derived from app secret and user ID.
Note: This is basic obfuscation. For production, consider using cryptography library.
"""
import base64
import hashlib
import json
import logging
import os
logger = logging.getLogger(__name__)


def _get_encryption_key(user_id: str) -> bytes:
    """Derive a deterministic 32-byte encryption key unique to ``user_id``.

    Combines the app secret (``APP_SECRET_KEY`` env var, with an insecure
    development fallback) with the user ID so every user gets a distinct key.
    """
    app_secret = os.environ.get(
        "APP_SECRET_KEY", "default-docsgpt-secret-key-change-in-production"
    )
    combined = f"{app_secret}#{user_id}"
    # SHA-256 digest is exactly 32 bytes, used as the repeating XOR key.
    return hashlib.sha256(combined.encode()).digest()


def _xor_encrypt_decrypt(data: bytes, key: bytes) -> bytes:
    """XOR ``data`` against a repeating ``key``; applying it twice round-trips."""
    return bytes(byte ^ key[i % len(key)] for i, byte in enumerate(data))


def encrypt_credentials(credentials: dict, user_id: str) -> str:
    """Encrypt a credentials dictionary for storage.

    Args:
        credentials: Sensitive key/value pairs; falsy input yields "".
        user_id: User ID used to derive the per-user encryption key.

    Returns:
        Base64-encoded ciphertext, or "" when the input is empty or encryption
        fails (callers treat "" as "re-authentication required").
    """
    if not credentials:
        return ""
    try:
        key = _get_encryption_key(user_id)
        encrypted = _xor_encrypt_decrypt(json.dumps(credentials).encode(), key)
        return base64.b64encode(encrypted).decode()
    except Exception:
        # Fail closed: store nothing rather than plaintext; the user re-auths.
        # Log with traceback instead of print() so failures reach the app logs.
        logger.warning("Failed to encrypt credentials", exc_info=True)
        return ""


def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
    """Decrypt credentials produced by :func:`encrypt_credentials`.

    Args:
        encrypted_data: Base64-encoded ciphertext ("" yields an empty dict).
        user_id: User ID used to derive the per-user encryption key.

    Returns:
        The decrypted credentials dict, or {} on any failure (bad base64,
        wrong key, invalid JSON) so the caller falls back to re-authentication.
    """
    if not encrypted_data:
        return {}
    try:
        key = _get_encryption_key(user_id)
        decrypted = _xor_encrypt_decrypt(base64.b64decode(encrypted_data.encode()), key)
        return json.loads(decrypted.decode())
    except Exception:
        # Log with traceback instead of print() so failures reach the app logs.
        logger.warning("Failed to decrypt credentials", exc_info=True)
        return {}