mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: finalize remote mcp
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
|
||||
_mcp_session_cache = {}
|
||||
@@ -33,18 +34,12 @@ class MCPTool(Tool):
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
|
||||
# Decrypt credentials if they are encrypted
|
||||
|
||||
self.auth_credentials = {}
|
||||
if config.get("encrypted_credentials") and user_id:
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
self.auth_credentials = decrypt_credentials(
|
||||
config["encrypted_credentials"], user_id
|
||||
)
|
||||
else:
|
||||
# Fallback to unencrypted credentials (for backward compatibility)
|
||||
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
self.available_tools = []
|
||||
self._session = requests.Session()
|
||||
@@ -52,10 +47,25 @@ class MCPTool(Tool):
|
||||
self._setup_authentication()
|
||||
self._cache_key = self._generate_cache_key()
|
||||
|
||||
def _setup_authentication(self):
|
||||
"""Setup authentication for the MCP server connection."""
|
||||
if self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||
if api_key:
|
||||
self._session.headers.update({header_name: api_key})
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get("bearer_token", "")
|
||||
if token:
|
||||
self._session.headers.update({"Authorization": f"Bearer {token}"})
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
password = self.auth_credentials.get("password", "")
|
||||
if username and password:
|
||||
self._session.auth = (username, password)
|
||||
|
||||
def _generate_cache_key(self) -> str:
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
# Use server URL + auth info to create unique key
|
||||
|
||||
auth_key = ""
|
||||
if self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get("bearer_token", "")
|
||||
@@ -76,13 +86,9 @@ class MCPTool(Tool):
|
||||
|
||||
if self._cache_key in _mcp_session_cache:
|
||||
session_data = _mcp_session_cache[self._cache_key]
|
||||
# Check if session is less than 30 minutes old
|
||||
|
||||
if time.time() - session_data["created_at"] < 1800: # 30 minutes
|
||||
if time.time() - session_data["created_at"] < 1800:
|
||||
return session_data["session_id"]
|
||||
else:
|
||||
# Remove expired session
|
||||
|
||||
del _mcp_session_cache[self._cache_key]
|
||||
return None
|
||||
|
||||
@@ -94,23 +100,6 @@ class MCPTool(Tool):
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
def _setup_authentication(self):
|
||||
"""Setup authentication for the MCP server connection."""
|
||||
if self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||
if api_key:
|
||||
self._session.headers.update({header_name: api_key})
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get("bearer_token", "")
|
||||
if token:
|
||||
self._session.headers.update({"Authorization": f"Bearer {token}"})
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
password = self.auth_credentials.get("password", "")
|
||||
if username and password:
|
||||
self._session.auth = (username, password)
|
||||
|
||||
def _initialize_mcp_connection(self) -> Dict:
|
||||
"""
|
||||
Initialize MCP connection with the server, using cached session if available.
|
||||
@@ -264,10 +253,7 @@ class MCPTool(Tool):
|
||||
"""
|
||||
self._ensure_valid_session()
|
||||
|
||||
# Prepare call parameters for MCP protocol
|
||||
|
||||
call_params = {"name": action_name, "arguments": kwargs}
|
||||
|
||||
try:
|
||||
result = self._make_mcp_request("tools/call", call_params)
|
||||
return result
|
||||
@@ -283,9 +269,6 @@ class MCPTool(Tool):
|
||||
"""
|
||||
actions = []
|
||||
for tool in self.available_tools:
|
||||
# Parse MCP tool schema according to MCP specification
|
||||
# Check multiple possible schema locations for compatibility
|
||||
|
||||
input_schema = (
|
||||
tool.get("inputSchema")
|
||||
or tool.get("input_schema")
|
||||
@@ -293,20 +276,14 @@ class MCPTool(Tool):
|
||||
or tool.get("parameters")
|
||||
)
|
||||
|
||||
# Default empty schema if no inputSchema provided
|
||||
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
# Parse the inputSchema if it exists
|
||||
|
||||
if input_schema:
|
||||
if isinstance(input_schema, dict):
|
||||
# Handle standard JSON Schema format
|
||||
|
||||
if "properties" in input_schema:
|
||||
parameters_schema = {
|
||||
"type": input_schema.get("type", "object"),
|
||||
@@ -314,14 +291,10 @@ class MCPTool(Tool):
|
||||
"required": input_schema.get("required", []),
|
||||
}
|
||||
|
||||
# Add additional schema properties if they exist
|
||||
|
||||
for key in ["additionalProperties", "description"]:
|
||||
if key in input_schema:
|
||||
parameters_schema[key] = input_schema[key]
|
||||
else:
|
||||
# Might be properties directly at root level
|
||||
|
||||
parameters_schema["properties"] = input_schema
|
||||
action = {
|
||||
"name": tool.get("name", ""),
|
||||
@@ -331,64 +304,6 @@ class MCPTool(Tool):
|
||||
actions.append(action)
|
||||
return actions
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
"""
|
||||
Get configuration requirements for the MCP tool.
|
||||
|
||||
Returns:
|
||||
Dictionary describing required configuration
|
||||
"""
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
|
||||
"required": True,
|
||||
},
|
||||
"auth_type": {
|
||||
"type": "string",
|
||||
"description": "Authentication type",
|
||||
"enum": ["none", "api_key", "bearer", "basic"],
|
||||
"default": "none",
|
||||
"required": True,
|
||||
},
|
||||
"auth_credentials": {
|
||||
"type": "object",
|
||||
"description": "Authentication credentials (varies by auth_type)",
|
||||
"properties": {
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"description": "API key for api_key auth",
|
||||
},
|
||||
"header_name": {
|
||||
"type": "string",
|
||||
"description": "Header name for API key (default: X-API-Key)",
|
||||
"default": "X-API-Key",
|
||||
},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "Bearer token for bearer auth",
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"description": "Username for basic auth",
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"description": "Password for basic auth",
|
||||
},
|
||||
},
|
||||
"required": False,
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30,
|
||||
"minimum": 1,
|
||||
"maximum": 300,
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
||||
def test_connection(self) -> Dict:
|
||||
"""
|
||||
Test the connection to the MCP server and validate functionality.
|
||||
@@ -411,9 +326,7 @@ class MCPTool(Tool):
|
||||
"message": message,
|
||||
"tools_count": len(tools),
|
||||
"session_id": self._mcp_session_id,
|
||||
"tools": [
|
||||
tool.get("name", "unknown") for tool in tools[:5]
|
||||
], # First 5 tool names
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
@@ -422,3 +335,32 @@ class MCPTool(Tool):
|
||||
"tools_count": 0,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
|
||||
"required": True,
|
||||
},
|
||||
"auth_type": {
|
||||
"type": "string",
|
||||
"description": "Authentication type",
|
||||
"enum": ["none", "api_key", "bearer", "basic"],
|
||||
"default": "none",
|
||||
"required": True,
|
||||
},
|
||||
"auth_credentials": {
|
||||
"type": "object",
|
||||
"description": "Authentication credentials (varies by auth_type)",
|
||||
"required": False,
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30,
|
||||
"minimum": 1,
|
||||
"maximum": 300,
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
# For MCP tools, pass the user_id for credential decryption
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
@@ -36,18 +35,11 @@ class ToolManager:
|
||||
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
# For MCP tools, they might not be pre-loaded, so load dynamically
|
||||
if tool_name == "mcp_tool":
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded and no config provided for dynamic loading")
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
|
||||
# For MCP tools, if user_id is provided, create a new instance with user context
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
# Load tool dynamically with user context for proper credential access
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
|
||||
@@ -3,11 +3,12 @@ import json
|
||||
import math
|
||||
import os
|
||||
import secrets
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
@@ -24,7 +25,10 @@ from flask_restx import fields, inputs, Namespace, Resource
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest,
|
||||
@@ -34,17 +38,17 @@ from application.api.user.tasks import (
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.api import api
|
||||
from application.security.encryption import encrypt_credentials, decrypt_credentials
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.utils import (
|
||||
check_required_fields,
|
||||
generate_image_url,
|
||||
num_tokens_from_string,
|
||||
safe_filename,
|
||||
validate_function_name,
|
||||
validate_required_fields,
|
||||
)
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
@@ -3435,31 +3439,6 @@ class CreateTool(Resource):
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
try:
|
||||
# Process config to encrypt credentials for MCP tools
|
||||
config = data["config"]
|
||||
if data["name"] == "mcp_tool":
|
||||
from application.security.encryption import encrypt_credentials
|
||||
|
||||
# Extract credentials from config
|
||||
credentials = {}
|
||||
if config.get("auth_type") == "bearer":
|
||||
credentials["bearer_token"] = config.get("bearer_token", "")
|
||||
elif config.get("auth_type") == "api_key":
|
||||
credentials["api_key"] = config.get("api_key", "")
|
||||
credentials["api_key_header"] = config.get("api_key_header", "")
|
||||
elif config.get("auth_type") == "basic":
|
||||
credentials["username"] = config.get("username", "")
|
||||
credentials["password"] = config.get("password", "")
|
||||
|
||||
# Encrypt credentials if any exist
|
||||
if credentials:
|
||||
config["encrypted_credentials"] = encrypt_credentials(
|
||||
credentials, user
|
||||
)
|
||||
# Remove plaintext credentials from config
|
||||
for key in credentials.keys():
|
||||
config.pop(key, None)
|
||||
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
@@ -3467,7 +3446,7 @@ class CreateTool(Resource):
|
||||
"description": data["description"],
|
||||
"customName": data.get("customName", ""),
|
||||
"actions": transformed_actions,
|
||||
"config": config,
|
||||
"config": data["config"],
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
@@ -3534,41 +3513,7 @@ class UpdateTool(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Handle MCP tool credential encryption
|
||||
config = data["config"]
|
||||
tool_name = data.get("name")
|
||||
if not tool_name:
|
||||
# Get the tool name from the database
|
||||
existing_tool = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
tool_name = existing_tool.get("name") if existing_tool else None
|
||||
|
||||
if tool_name == "mcp_tool":
|
||||
from application.security.encryption import encrypt_credentials
|
||||
|
||||
# Extract credentials from config
|
||||
credentials = {}
|
||||
if config.get("auth_type") == "bearer":
|
||||
credentials["bearer_token"] = config.get("bearer_token", "")
|
||||
elif config.get("auth_type") == "api_key":
|
||||
credentials["api_key"] = config.get("api_key", "")
|
||||
credentials["api_key_header"] = config.get("api_key_header", "")
|
||||
elif config.get("auth_type") == "basic":
|
||||
credentials["username"] = config.get("username", "")
|
||||
credentials["password"] = config.get("password", "")
|
||||
|
||||
# Encrypt credentials if any exist
|
||||
if credentials:
|
||||
config["encrypted_credentials"] = encrypt_credentials(
|
||||
credentials, user
|
||||
)
|
||||
# Remove plaintext credentials from config
|
||||
for key in credentials.keys():
|
||||
config.pop(key, None)
|
||||
|
||||
update_data["config"] = config
|
||||
update_data["config"] = data["config"]
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
user_tools_collection.update_one(
|
||||
@@ -4142,74 +4087,55 @@ class DirectoryStructure(Resource):
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_servers")
|
||||
class MCPServers(Resource):
|
||||
@api.doc(description="Get all MCP servers configured by the user")
|
||||
def get(self):
|
||||
@user_ns.route("/api/mcp_server/test")
|
||||
class TestMCPServerConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerTestModel",
|
||||
{
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration to test"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Test MCP server connection with provided configuration")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
# Find all MCP tools for this user
|
||||
mcp_tools = user_tools_collection.find({"user": user, "name": "mcp_tool"})
|
||||
config = data["config"]
|
||||
|
||||
servers = []
|
||||
for tool in mcp_tools:
|
||||
config = tool.get("config", {})
|
||||
servers.append(
|
||||
{
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("displayName", "MCP Server"),
|
||||
"server_url": config.get("server_url", ""),
|
||||
"auth_type": config.get("auth_type", "none"),
|
||||
"status": tool.get("status", False),
|
||||
"created_at": (
|
||||
tool.get("_id").generation_time.isoformat()
|
||||
if tool.get("_id")
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
|
||||
return make_response(jsonify({"success": True, "servers": servers}), 200)
|
||||
if auth_type == "api_key" and "api_key" in config:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer" and "bearer_token" in config:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config:
|
||||
auth_credentials["password"] = config["password"]
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving MCP servers: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/<string:server_id>/test")
|
||||
class TestMCPServer(Resource):
|
||||
@api.doc(description="Test connection to an MCP server")
|
||||
def post(self, server_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
# Find the MCP tool
|
||||
mcp_tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
|
||||
)
|
||||
|
||||
if not mcp_tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "MCP server not found"}), 404
|
||||
)
|
||||
|
||||
# Load the tool and test connection
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user)
|
||||
mcp_tool = MCPTool(test_config, user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
@@ -4220,38 +4146,86 @@ class TestMCPServer(Resource):
|
||||
)
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/<string:server_id>/tools")
|
||||
class MCPServerTools(Resource):
|
||||
@api.doc(description="Discover and get tools from an MCP server")
|
||||
def get(self, server_id):
|
||||
@user_ns.route("/api/mcp_server/save")
|
||||
class MCPServerSave(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerSaveModel",
|
||||
{
|
||||
"id": fields.String(
|
||||
required=False, description="Tool ID for updates (optional)"
|
||||
),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the MCP server"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=False, default=True, description="Tool status"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create or update MCP server with automatic tool discovery")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
# Find the MCP tool
|
||||
mcp_tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
|
||||
)
|
||||
data = request.get_json()
|
||||
|
||||
if not mcp_tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "MCP server not found"}), 404
|
||||
required_fields = ["displayName", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
mcp_config = config.copy()
|
||||
mcp_config["auth_credentials"] = auth_credentials
|
||||
|
||||
if auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(mcp_config, user)
|
||||
mcp_tool.discover_tools()
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
else:
|
||||
raise Exception(
|
||||
"No valid credentials provided for the selected authentication type"
|
||||
)
|
||||
|
||||
# Load the tool and discover tools
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
storage_config = config.copy()
|
||||
if auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||
|
||||
mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user)
|
||||
tools = mcp_tool.discover_tools()
|
||||
|
||||
# Get actions metadata and transform to match other tools format
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
transformed_actions = []
|
||||
|
||||
for action in actions_metadata:
|
||||
# Add active flag and transform parameters
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
@@ -4261,77 +4235,53 @@ class MCPServerTools(Resource):
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"displayName": data["displayName"],
|
||||
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": data.get("status", True),
|
||||
"user": user,
|
||||
}
|
||||
|
||||
# Update the stored actions in the database
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(server_id)}, {"$set": {"actions": transformed_actions}}
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "tools": tools, "actions": transformed_actions}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
tool_id = data.get("id")
|
||||
if tool_id:
|
||||
result = user_tools_collection.update_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
result = user_tools_collection.insert_one(tool_data)
|
||||
tool_id = str(result.inserted_id)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error discovering MCP tools: {e}", exc_info=True)
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Tool discovery failed: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@user_ns.route("/api/mcp_server/<string:server_id>/tools/<string:action_name>")
|
||||
class MCPServerToolAction(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPToolActionModel",
|
||||
{
|
||||
"parameters": fields.Raw(
|
||||
required=False, description="Parameters for the tool action"
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Execute a specific tool action on an MCP server")
|
||||
def post(self, server_id, action_name):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
parameters = data.get("parameters", {})
|
||||
|
||||
try:
|
||||
# Find the MCP tool
|
||||
mcp_tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
|
||||
)
|
||||
|
||||
if not mcp_tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "MCP server not found"}), 404
|
||||
)
|
||||
|
||||
# Load the tool and execute action
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user)
|
||||
result = mcp_tool.execute_action(action_name, **parameters)
|
||||
|
||||
return make_response(jsonify({"success": True, "result": result}), 200)
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error executing MCP tool action: {e}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Action execution failed: {str(e)}"}
|
||||
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
@@ -109,6 +109,9 @@ class Settings(BaseSettings):
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -2,6 +2,7 @@ anthropic==0.49.0
|
||||
boto3==1.38.18
|
||||
beautifulsoup4==4.13.4
|
||||
celery==5.4.0
|
||||
cryptography==42.0.8
|
||||
dataclasses-json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==7.5.2
|
||||
|
||||
@@ -1,97 +1,85 @@
|
||||
"""
|
||||
Simple encryption utility for securely storing sensitive credentials.
|
||||
Uses XOR encryption with a key derived from app secret and user ID.
|
||||
Note: This is basic obfuscation. For production, consider using cryptography library.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def _get_encryption_key(user_id: str) -> bytes:
|
||||
"""
|
||||
Generate a consistent encryption key for a specific user.
|
||||
Uses app secret + user ID to create a unique key per user.
|
||||
"""
|
||||
# Get app secret from environment or use a default (in production, always use env)
|
||||
app_secret = os.environ.get(
|
||||
"APP_SECRET_KEY", "default-docsgpt-secret-key-change-in-production"
|
||||
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
||||
app_secret = settings.ENCRYPTION_SECRET_KEY
|
||||
|
||||
password = f"{app_secret}#{user_id}".encode()
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
backend=default_backend(),
|
||||
)
|
||||
|
||||
# Combine app secret with user ID for user-specific encryption
|
||||
combined = f"{app_secret}#{user_id}"
|
||||
|
||||
# Create a 32-byte key
|
||||
key_material = hashlib.sha256(combined.encode()).digest()
|
||||
|
||||
return key_material
|
||||
|
||||
|
||||
def _xor_encrypt_decrypt(data: bytes, key: bytes) -> bytes:
|
||||
"""Simple XOR encryption/decryption."""
|
||||
result = bytearray()
|
||||
for i, byte in enumerate(data):
|
||||
result.append(byte ^ key[i % len(key)])
|
||||
return bytes(result)
|
||||
return kdf.derive(password)
|
||||
|
||||
|
||||
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||
"""
|
||||
Encrypt credentials dictionary for secure storage.
|
||||
|
||||
Args:
|
||||
credentials: Dictionary containing sensitive data
|
||||
user_id: User ID for creating user-specific encryption key
|
||||
|
||||
Returns:
|
||||
Base64 encoded encrypted string
|
||||
"""
|
||||
if not credentials:
|
||||
return ""
|
||||
|
||||
try:
|
||||
key = _get_encryption_key(user_id)
|
||||
salt = os.urandom(16)
|
||||
iv = os.urandom(16)
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
# Convert dict to JSON string and encrypt
|
||||
json_str = json.dumps(credentials)
|
||||
encrypted_data = _xor_encrypt_decrypt(json_str.encode(), key)
|
||||
|
||||
# Return base64 encoded for storage
|
||||
return base64.b64encode(encrypted_data).decode()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
padded_data = _pad_data(json_str.encode())
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
result = salt + iv + encrypted_data
|
||||
return base64.b64encode(result).decode()
|
||||
except Exception as e:
|
||||
# If encryption fails, store empty string (will require re-auth)
|
||||
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||
"""
|
||||
Decrypt credentials from storage.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64 encoded encrypted string
|
||||
user_id: User ID for creating user-specific encryption key
|
||||
|
||||
Returns:
|
||||
Dictionary containing decrypted credentials
|
||||
"""
|
||||
if not encrypted_data:
|
||||
return {}
|
||||
|
||||
try:
|
||||
key = _get_encryption_key(user_id)
|
||||
data = base64.b64decode(encrypted_data.encode())
|
||||
|
||||
# Decode and decrypt
|
||||
encrypted_bytes = base64.b64decode(encrypted_data.encode())
|
||||
decrypted_data = _xor_encrypt_decrypt(encrypted_bytes, key)
|
||||
salt = data[:16]
|
||||
iv = data[16:32]
|
||||
encrypted_content = data[32:]
|
||||
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
||||
decrypted_data = _unpad_data(decrypted_padded)
|
||||
|
||||
# Parse JSON back to dict
|
||||
return json.loads(decrypted_data.decode())
|
||||
|
||||
except Exception as e:
|
||||
# If decryption fails, return empty dict (will require re-auth)
|
||||
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _pad_data(data: bytes) -> bytes:
|
||||
block_size = 16
|
||||
padding_len = block_size - (len(data) % block_size)
|
||||
padding = bytes([padding_len]) * padding_len
|
||||
return data + padding
|
||||
|
||||
|
||||
def _unpad_data(data: bytes) -> bytes:
|
||||
padding_len = data[-1]
|
||||
return data[:-padding_len]
|
||||
|
||||
Reference in New Issue
Block a user