feat: finalize remote mcp

This commit is contained in:
Siddhant Rai
2025-09-04 15:10:12 +05:30
parent 7c23f43c63
commit 1bf6af6eeb
11 changed files with 453 additions and 646 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional
import requests
from application.agents.tools.base import Tool
from application.security.encryption import decrypt_credentials
_mcp_session_cache = {}
@@ -33,18 +34,12 @@ class MCPTool(Tool):
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
# Decrypt credentials if they are encrypted
self.auth_credentials = {}
if config.get("encrypted_credentials") and user_id:
from application.security.encryption import decrypt_credentials
self.auth_credentials = decrypt_credentials(
config["encrypted_credentials"], user_id
)
else:
# Fallback to unencrypted credentials (for backward compatibility)
self.auth_credentials = config.get("auth_credentials", {})
self.available_tools = []
self._session = requests.Session()
@@ -52,10 +47,25 @@ class MCPTool(Tool):
self._setup_authentication()
self._cache_key = self._generate_cache_key()
def _setup_authentication(self):
"""Setup authentication for the MCP server connection."""
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
self._session.headers.update({header_name: api_key})
elif self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
if token:
self._session.headers.update({"Authorization": f"Bearer {token}"})
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
self._session.auth = (username, password)
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
# Use server URL + auth info to create unique key
auth_key = ""
if self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
@@ -76,13 +86,9 @@ class MCPTool(Tool):
if self._cache_key in _mcp_session_cache:
session_data = _mcp_session_cache[self._cache_key]
# Check if session is less than 30 minutes old
if time.time() - session_data["created_at"] < 1800: # 30 minutes
if time.time() - session_data["created_at"] < 1800:
return session_data["session_id"]
else:
# Remove expired session
del _mcp_session_cache[self._cache_key]
return None
@@ -94,23 +100,6 @@ class MCPTool(Tool):
"created_at": time.time(),
}
def _setup_authentication(self):
"""Setup authentication for the MCP server connection."""
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
self._session.headers.update({header_name: api_key})
elif self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
if token:
self._session.headers.update({"Authorization": f"Bearer {token}"})
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
self._session.auth = (username, password)
def _initialize_mcp_connection(self) -> Dict:
"""
Initialize MCP connection with the server, using cached session if available.
@@ -264,10 +253,7 @@ class MCPTool(Tool):
"""
self._ensure_valid_session()
# Prepare call parameters for MCP protocol
call_params = {"name": action_name, "arguments": kwargs}
try:
result = self._make_mcp_request("tools/call", call_params)
return result
@@ -283,9 +269,6 @@ class MCPTool(Tool):
"""
actions = []
for tool in self.available_tools:
# Parse MCP tool schema according to MCP specification
# Check multiple possible schema locations for compatibility
input_schema = (
tool.get("inputSchema")
or tool.get("input_schema")
@@ -293,20 +276,14 @@ class MCPTool(Tool):
or tool.get("parameters")
)
# Default empty schema if no inputSchema provided
parameters_schema = {
"type": "object",
"properties": {},
"required": [],
}
# Parse the inputSchema if it exists
if input_schema:
if isinstance(input_schema, dict):
# Handle standard JSON Schema format
if "properties" in input_schema:
parameters_schema = {
"type": input_schema.get("type", "object"),
@@ -314,14 +291,10 @@ class MCPTool(Tool):
"required": input_schema.get("required", []),
}
# Add additional schema properties if they exist
for key in ["additionalProperties", "description"]:
if key in input_schema:
parameters_schema[key] = input_schema[key]
else:
# Might be properties directly at root level
parameters_schema["properties"] = input_schema
action = {
"name": tool.get("name", ""),
@@ -331,64 +304,6 @@ class MCPTool(Tool):
actions.append(action)
return actions
def get_config_requirements(self) -> Dict:
"""
Get configuration requirements for the MCP tool.
Returns:
Dictionary describing required configuration
"""
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
"required": True,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "api_key", "bearer", "basic"],
"default": "none",
"required": True,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"properties": {
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"header_name": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
"default": "X-API-Key",
},
"token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
}
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
@@ -411,9 +326,7 @@ class MCPTool(Tool):
"message": message,
"tools_count": len(tools),
"session_id": self._mcp_session_id,
"tools": [
tool.get("name", "unknown") for tool in tools[:5]
], # First 5 tool names
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
}
except Exception as e:
return {
@@ -422,3 +335,32 @@ class MCPTool(Tool):
"tools_count": 0,
"error_type": type(e).__name__,
}
def get_config_requirements(self) -> Dict:
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
"required": True,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "api_key", "bearer", "basic"],
"default": "none",
"required": True,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
}

View File

@@ -28,7 +28,6 @@ class ToolManager:
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
# For MCP tools, pass the user_id for credential decryption
if tool_name == "mcp_tool" and user_id:
return obj(tool_config, user_id)
else:
@@ -36,18 +35,11 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
# For MCP tools, they might not be pre-loaded, so load dynamically
if tool_name == "mcp_tool":
raise ValueError(f"Tool '{tool_name}' not loaded and no config provided for dynamic loading")
raise ValueError(f"Tool '{tool_name}' not loaded")
# For MCP tools, if user_id is provided, create a new instance with user context
if tool_name == "mcp_tool" and user_id:
# Load tool dynamically with user context for proper credential access
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):