mirror of https://github.com/arc53/DocsGPT.git (synced 2025-12-02 01:53:14 +00:00)
feat: Add MCP Server management functionality
- Implemented encryption utility for securely storing sensitive credentials.
- Added MCPServerModal component for managing MCP server configurations.
- Updated AddToolModal to include MCP server management.
- Enhanced localization with new strings for MCP server features.
- Introduced SVG icons for MCP tools in the frontend.
- Created a new settings section for MCP server configurations in the application.
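For context, a sketch of the shape of an MCP tool config as the code below stores it (keys per MCPTool.__init__ and the CreateTool endpoint; the URL and values here are invented for illustration):

# Hypothetical stored config for an "mcp_tool" document.
mcp_config = {
    "server_url": "https://mcp.example.com",  # remote MCP server (invented URL)
    "auth_type": "bearer",                    # one of: none, api_key, bearer, basic
    "timeout": 30,                            # request timeout in seconds
    # Plaintext credentials (e.g. bearer_token) are stripped at save time and
    # replaced by this field, produced by application.security.encryption:
    "encrypted_credentials": "<base64 ciphertext>",
}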
@@ -227,6 +227,7 @@ class BaseAgent(ABC):
                     if tool_data["name"] == "api_tool"
                     else tool_data["config"]
                 ),
+                user_id=self.user,  # Pass user ID for MCP tools credential decryption
             )
             if tool_data["name"] == "api_tool":
                 print(
application/agents/tools/mcp_tool.py (new file, 424 lines)
@@ -0,0 +1,424 @@
import json
import time
from typing import Any, Dict, List, Optional

import requests

from application.agents.tools.base import Tool


_mcp_session_cache = {}


class MCPTool(Tool):
    """
    MCP Tool
    Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
    """

    def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
        """
        Initialize the MCP Tool with configuration.

        Args:
            config: Dictionary containing MCP server configuration:
                - server_url: URL of the remote MCP server
                - auth_type: Type of authentication (api_key, bearer, basic, none)
                - encrypted_credentials: Encrypted credentials (if available)
                - timeout: Request timeout in seconds (default: 30)
            user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
        """
        self.config = config
        self.server_url = config.get("server_url", "")
        self.auth_type = config.get("auth_type", "none")
        self.timeout = config.get("timeout", 30)

        # Decrypt credentials if they are encrypted
        self.auth_credentials = {}
        if config.get("encrypted_credentials") and user_id:
            from application.security.encryption import decrypt_credentials

            self.auth_credentials = decrypt_credentials(
                config["encrypted_credentials"], user_id
            )
        else:
            # Fallback to unencrypted credentials (for backward compatibility)
            self.auth_credentials = config.get("auth_credentials", {})
        self.available_tools = []
        self._session = requests.Session()
        self._mcp_session_id = None
        self._setup_authentication()
        self._cache_key = self._generate_cache_key()

    def _generate_cache_key(self) -> str:
        """Generate a unique cache key for this MCP server configuration."""
        # Use server URL + auth info to create unique key
        auth_key = ""
        if self.auth_type == "bearer":
            token = self.auth_credentials.get("bearer_token", "")
            auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
        elif self.auth_type == "api_key":
            api_key = self.auth_credentials.get("api_key", "")
            auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
        elif self.auth_type == "basic":
            username = self.auth_credentials.get("username", "")
            auth_key = f"basic:{username}"
        else:
            auth_key = "none"
        return f"{self.server_url}#{auth_key}"

    def _get_cached_session(self) -> Optional[str]:
        """Get cached session ID if available and not expired."""
        global _mcp_session_cache

        if self._cache_key in _mcp_session_cache:
            session_data = _mcp_session_cache[self._cache_key]
            # Check if session is less than 30 minutes old
            if time.time() - session_data["created_at"] < 1800:  # 30 minutes
                return session_data["session_id"]
            else:
                # Remove expired session
                del _mcp_session_cache[self._cache_key]
        return None

    def _cache_session(self, session_id: str):
        """Cache the session ID for reuse."""
        global _mcp_session_cache
        _mcp_session_cache[self._cache_key] = {
            "session_id": session_id,
            "created_at": time.time(),
        }

    def _setup_authentication(self):
        """Setup authentication for the MCP server connection."""
        if self.auth_type == "api_key":
            api_key = self.auth_credentials.get("api_key", "")
            header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
            if api_key:
                self._session.headers.update({header_name: api_key})
        elif self.auth_type == "bearer":
            token = self.auth_credentials.get("bearer_token", "")
            if token:
                self._session.headers.update({"Authorization": f"Bearer {token}"})
        elif self.auth_type == "basic":
            username = self.auth_credentials.get("username", "")
            password = self.auth_credentials.get("password", "")
            if username and password:
                self._session.auth = (username, password)

    def _initialize_mcp_connection(self) -> Dict:
        """
        Initialize MCP connection with the server, using cached session if available.

        Returns:
            Server capabilities and information
        """
        cached_session = self._get_cached_session()
        if cached_session:
            self._mcp_session_id = cached_session
            return {"cached": True}
        try:
            init_params = {
                "protocolVersion": "2024-11-05",
                "capabilities": {"roots": {"listChanged": True}, "sampling": {}},
                "clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
            }
            response = self._make_mcp_request("initialize", init_params)
            self._make_mcp_request("notifications/initialized")

            return response
        except Exception as e:
            return {"error": str(e), "fallback": True}

    def _ensure_valid_session(self):
        """Ensure we have a valid MCP session, reinitializing if needed."""
        if not self._mcp_session_id:
            self._initialize_mcp_connection()

    def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
        """
        Make an MCP protocol request to the server with automatic session recovery.

        Args:
            method: MCP method name (e.g., "tools/list", "tools/call")
            params: Parameters for the MCP method

        Returns:
            Response data as dictionary

        Raises:
            Exception: If request fails after retry
        """
        mcp_message = {"jsonrpc": "2.0", "method": method}

        if not method.startswith("notifications/"):
            mcp_message["id"] = 1
        if params:
            mcp_message["params"] = params
        return self._execute_mcp_request(mcp_message, method)

    def _execute_mcp_request(
        self, mcp_message: Dict, method: str, is_retry: bool = False
    ) -> Dict:
        """Execute MCP request with optional retry on session failure."""
        try:
            headers = {"Content-Type": "application/json", "Accept": "application/json"}
            headers.update(self._session.headers)

            if self._mcp_session_id:
                headers["Mcp-Session-Id"] = self._mcp_session_id
            response = self._session.post(
                self.server_url.rstrip("/"),
                json=mcp_message,
                headers=headers,
                timeout=self.timeout,
            )

            if "mcp-session-id" in response.headers:
                self._mcp_session_id = response.headers["mcp-session-id"]
                self._cache_session(self._mcp_session_id)
            response.raise_for_status()

            if method.startswith("notifications/"):
                return {}
            try:
                result = response.json()
            except json.JSONDecodeError:
                raise Exception(f"Invalid JSON response: {response.text}")
            if "error" in result:
                error_msg = result["error"]
                if isinstance(error_msg, dict):
                    error_msg = error_msg.get("message", str(error_msg))
                raise Exception(f"MCP server error: {error_msg}")
            return result.get("result", result)
        except requests.exceptions.RequestException as e:
            if not is_retry and self._should_retry_with_new_session(e):
                self._invalidate_and_refresh_session()
                return self._execute_mcp_request(mcp_message, method, is_retry=True)
            raise Exception(f"MCP server request failed: {str(e)}")

    def _should_retry_with_new_session(self, error: Exception) -> bool:
        """Check if error indicates session invalidation and retry is warranted."""
        error_str = str(error).lower()
        return (
            any(
                indicator in error_str
                for indicator in [
                    "invalid session",
                    "session expired",
                    "unauthorized",
                    "401",
                    "403",
                ]
            )
            and self._mcp_session_id is not None
        )

    def _invalidate_and_refresh_session(self) -> None:
        """Invalidate current session and create a new one."""
        global _mcp_session_cache
        if self._cache_key in _mcp_session_cache:
            del _mcp_session_cache[self._cache_key]
        self._mcp_session_id = None
        self._initialize_mcp_connection()

    def discover_tools(self) -> List[Dict]:
        """
        Discover available tools from the MCP server using MCP protocol.

        Returns:
            List of tool definitions from the server
        """
        try:
            self._ensure_valid_session()

            response = self._make_mcp_request("tools/list")

            if isinstance(response, dict) and "tools" in response:
                self.available_tools = response["tools"]
                return self.available_tools
            elif isinstance(response, list):
                self.available_tools = response
                return self.available_tools
            else:
                self.available_tools = []
                return self.available_tools
        except Exception as e:
            raise Exception(f"Failed to discover tools from MCP server: {str(e)}")

    def execute_action(self, action_name: str, **kwargs) -> Any:
        """
        Execute an action on the remote MCP server using MCP protocol.

        Args:
            action_name: Name of the action to execute
            **kwargs: Parameters for the action

        Returns:
            Result from the MCP server
        """
        self._ensure_valid_session()

        # Prepare call parameters for MCP protocol
        call_params = {"name": action_name, "arguments": kwargs}

        try:
            result = self._make_mcp_request("tools/call", call_params)
            return result
        except Exception as e:
            raise Exception(f"Failed to execute action '{action_name}': {str(e)}")

    def get_actions_metadata(self) -> List[Dict]:
        """
        Get metadata for all available actions.

        Returns:
            List of action metadata dictionaries
        """
        actions = []
        for tool in self.available_tools:
            # Parse MCP tool schema according to MCP specification
            # Check multiple possible schema locations for compatibility
            input_schema = (
                tool.get("inputSchema")
                or tool.get("input_schema")
                or tool.get("schema")
                or tool.get("parameters")
            )

            # Default empty schema if no inputSchema provided
            parameters_schema = {
                "type": "object",
                "properties": {},
                "required": [],
            }

            # Parse the inputSchema if it exists
            if input_schema:
                if isinstance(input_schema, dict):
                    # Handle standard JSON Schema format
                    if "properties" in input_schema:
                        parameters_schema = {
                            "type": input_schema.get("type", "object"),
                            "properties": input_schema.get("properties", {}),
                            "required": input_schema.get("required", []),
                        }

                        # Add additional schema properties if they exist
                        for key in ["additionalProperties", "description"]:
                            if key in input_schema:
                                parameters_schema[key] = input_schema[key]
                    else:
                        # Might be properties directly at root level
                        parameters_schema["properties"] = input_schema
            action = {
                "name": tool.get("name", ""),
                "description": tool.get("description", ""),
                "parameters": parameters_schema,
            }
            actions.append(action)
        return actions

    def get_config_requirements(self) -> Dict:
        """
        Get configuration requirements for the MCP tool.

        Returns:
            Dictionary describing required configuration
        """
        return {
            "server_url": {
                "type": "string",
                "description": "URL of the remote MCP server (e.g., https://api.example.com)",
                "required": True,
            },
            "auth_type": {
                "type": "string",
                "description": "Authentication type",
                "enum": ["none", "api_key", "bearer", "basic"],
                "default": "none",
                "required": True,
            },
            "auth_credentials": {
                "type": "object",
                "description": "Authentication credentials (varies by auth_type)",
                "properties": {
                    "api_key": {
                        "type": "string",
                        "description": "API key for api_key auth",
                    },
                    "header_name": {
                        "type": "string",
                        "description": "Header name for API key (default: X-API-Key)",
                        "default": "X-API-Key",
                    },
                    "token": {
                        "type": "string",
                        "description": "Bearer token for bearer auth",
                    },
                    "username": {
                        "type": "string",
                        "description": "Username for basic auth",
                    },
                    "password": {
                        "type": "string",
                        "description": "Password for basic auth",
                    },
                },
                "required": False,
            },
            "timeout": {
                "type": "integer",
                "description": "Request timeout in seconds",
                "default": 30,
                "minimum": 1,
                "maximum": 300,
                "required": False,
            },
        }

    def test_connection(self) -> Dict:
        """
        Test the connection to the MCP server and validate functionality.

        Returns:
            Dictionary with connection test results including tool count
        """
        try:
            init_result = self._initialize_mcp_connection()

            tools = self.discover_tools()

            message = f"Successfully connected to MCP server. Found {len(tools)} tools."
            if init_result.get("cached"):
                message += " (Using cached session)"
            elif init_result.get("fallback"):
                message += " (No formal initialization required)"
            return {
                "success": True,
                "message": message,
                "tools_count": len(tools),
                "session_id": self._mcp_session_id,
                "tools": [
                    tool.get("name", "unknown") for tool in tools[:5]
                ],  # First 5 tool names
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Connection failed: {str(e)}",
                "tools_count": 0,
                "error_type": type(e).__name__,
            }
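A minimal usage sketch of the class above, assuming a reachable server at an invented URL; every method shown is defined in this file:

# Sketch: instantiate against an unauthenticated server and exercise the MCP flow.
from application.agents.tools.mcp_tool import MCPTool

config = {"server_url": "https://mcp.example.com", "auth_type": "none", "timeout": 30}
tool = MCPTool(config)         # user_id is only needed to decrypt encrypted_credentials
print(tool.test_connection())  # initializes a session and reports the tool count
tool.discover_tools()          # "tools/list"; populates tool.available_tools
for action in tool.get_actions_metadata():
    print(action["name"], action["parameters"])
# "summarize" is a hypothetical tool name the server would have advertised:
result = tool.execute_action("summarize", text="hello")  # "tools/call" under the hood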
@@ -23,16 +23,31 @@ class ToolManager:
         tool_config = self.config.get(name, {})
         self.tools[name] = obj(tool_config)
 
-    def load_tool(self, tool_name, tool_config):
+    def load_tool(self, tool_name, tool_config, user_id=None):
         self.config[tool_name] = tool_config
         module = importlib.import_module(f"application.agents.tools.{tool_name}")
         for member_name, obj in inspect.getmembers(module, inspect.isclass):
             if issubclass(obj, Tool) and obj is not Tool:
-                return obj(tool_config)
+                # For MCP tools, pass the user_id for credential decryption
+                if tool_name == "mcp_tool" and user_id:
+                    return obj(tool_config, user_id)
+                else:
+                    return obj(tool_config)
 
-    def execute_action(self, tool_name, action_name, **kwargs):
+    def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
         if tool_name not in self.tools:
+            # For MCP tools, they might not be pre-loaded, so load dynamically
+            if tool_name == "mcp_tool":
+                raise ValueError(f"Tool '{tool_name}' not loaded and no config provided for dynamic loading")
             raise ValueError(f"Tool '{tool_name}' not loaded")
 
+        # For MCP tools, if user_id is provided, create a new instance with user context
+        if tool_name == "mcp_tool" and user_id:
+            # Load tool dynamically with user context for proper credential access
+            tool_config = self.config.get(tool_name, {})
+            tool = self.load_tool(tool_name, tool_config, user_id)
+            return tool.execute_action(action_name, **kwargs)
+
         return self.tools[tool_name].execute_action(action_name, **kwargs)
 
     def get_all_actions_metadata(self):
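How a caller might drive the updated manager (a sketch; the manager instance, the config, and the "search_docs" action name are assumptions):

# Assumes `tm` is a ToolManager whose config has an "mcp_tool" entry.
tool = tm.load_tool("mcp_tool", mcp_config, user_id="user-123")  # user_id reaches MCPTool.__init__
result = tm.execute_action("mcp_tool", "search_docs", user_id="user-123", query="MCP")
# Without user_id, execute_action falls back to the pre-loaded instance, which
# cannot decrypt encrypted_credentials and so only suits unauthenticated servers.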
@@ -492,9 +492,9 @@ class DeleteOldIndexes(Resource):
         )
         if not doc:
             return make_response(jsonify({"status": "not found"}), 404)
 
         storage = StorageCreator.get_storage()
 
         try:
             # Delete vector index
             if settings.VECTOR_STORE == "faiss":
@@ -508,7 +508,7 @@ class DeleteOldIndexes(Resource):
                     settings.VECTOR_STORE, source_id=str(doc["_id"])
                 )
                 vectorstore.delete_index()
 
             if "file_path" in doc and doc["file_path"]:
                 file_path = doc["file_path"]
                 if storage.is_directory(file_path):
@@ -517,7 +517,7 @@ class DeleteOldIndexes(Resource):
                         storage.delete_file(f)
                 else:
                     storage.delete_file(file_path)
 
         except FileNotFoundError:
             pass
         except Exception as err:
@@ -525,7 +525,7 @@ class DeleteOldIndexes(Resource):
                 f"Error deleting files and indexes: {err}", exc_info=True
             )
             return make_response(jsonify({"success": False}), 400)
 
         sources_collection.delete_one({"_id": ObjectId(source_id)})
         return make_response(jsonify({"success": True}), 200)
@@ -573,55 +573,75 @@ class UploadFile(Resource):
 
         try:
             storage = StorageCreator.get_storage()
 
             for file in files:
                 original_filename = file.filename
                 safe_file = safe_filename(original_filename)
 
                 with tempfile.TemporaryDirectory() as temp_dir:
                     temp_file_path = os.path.join(temp_dir, safe_file)
                     file.save(temp_file_path)
 
                     if zipfile.is_zipfile(temp_file_path):
                         try:
-                            with zipfile.ZipFile(temp_file_path, 'r') as zip_ref:
+                            with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
                                 zip_ref.extractall(path=temp_dir)
 
                                 # Walk through extracted files and upload them
                                 for root, _, files in os.walk(temp_dir):
                                     for extracted_file in files:
-                                        if os.path.join(root, extracted_file) == temp_file_path:
+                                        if (
+                                            os.path.join(root, extracted_file)
+                                            == temp_file_path
+                                        ):
                                             continue
 
-                                        rel_path = os.path.relpath(os.path.join(root, extracted_file), temp_dir)
+                                        rel_path = os.path.relpath(
+                                            os.path.join(root, extracted_file), temp_dir
+                                        )
                                         storage_path = f"{base_path}/{rel_path}"
-                                        with open(os.path.join(root, extracted_file), 'rb') as f:
+
+                                        with open(
+                                            os.path.join(root, extracted_file), "rb"
+                                        ) as f:
                                             storage.save_file(f, storage_path)
                         except Exception as e:
-                            current_app.logger.error(f"Error extracting zip: {e}", exc_info=True)
+                            current_app.logger.error(
+                                f"Error extracting zip: {e}", exc_info=True
+                            )
                             # If zip extraction fails, save the original zip file
                             file_path = f"{base_path}/{safe_file}"
-                            with open(temp_file_path, 'rb') as f:
+                            with open(temp_file_path, "rb") as f:
                                 storage.save_file(f, file_path)
                     else:
                         # For non-zip files, save directly
                         file_path = f"{base_path}/{safe_file}"
-                        with open(temp_file_path, 'rb') as f:
+                        with open(temp_file_path, "rb") as f:
                             storage.save_file(f, file_path)
 
             task = ingest.delay(
                 settings.UPLOAD_FOLDER,
                 [
-                    ".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
-                    ".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
-                    ".jpg", ".jpeg",
+                    ".rst",
+                    ".md",
+                    ".pdf",
+                    ".txt",
+                    ".docx",
+                    ".csv",
+                    ".epub",
+                    ".html",
+                    ".mdx",
+                    ".json",
+                    ".xlsx",
+                    ".pptx",
+                    ".png",
+                    ".jpg",
+                    ".jpeg",
                 ],
                 job_name,
                 user,
                 file_path=base_path,
-                filename=dir_name
+                filename=dir_name,
             )
         except Exception as err:
             current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -635,12 +655,29 @@ class ManageSourceFiles(Resource):
     api.model(
         "ManageSourceFilesModel",
         {
-            "source_id": fields.String(required=True, description="Source ID to modify"),
-            "operation": fields.String(required=True, description="Operation: 'add', 'remove', or 'remove_directory'"),
-            "file_paths": fields.List(fields.String, required=False, description="File paths to remove (for remove operation)"),
-            "directory_path": fields.String(required=False, description="Directory path to remove (for remove_directory operation)"),
-            "file": fields.Raw(required=False, description="Files to add (for add operation)"),
-            "parent_dir": fields.String(required=False, description="Parent directory path relative to source root"),
+            "source_id": fields.String(
+                required=True, description="Source ID to modify"
+            ),
+            "operation": fields.String(
+                required=True,
+                description="Operation: 'add', 'remove', or 'remove_directory'",
+            ),
+            "file_paths": fields.List(
+                fields.String,
+                required=False,
+                description="File paths to remove (for remove operation)",
+            ),
+            "directory_path": fields.String(
+                required=False,
+                description="Directory path to remove (for remove_directory operation)",
+            ),
+            "file": fields.Raw(
+                required=False, description="Files to add (for add operation)"
+            ),
+            "parent_dir": fields.String(
+                required=False,
+                description="Parent directory path relative to source root",
+            ),
         },
     )
 )
@@ -650,7 +687,9 @@ class ManageSourceFiles(Resource):
     def post(self):
         decoded_token = request.decoded_token
         if not decoded_token:
-            return make_response(jsonify({"success": False, "message": "Unauthorized"}), 401)
+            return make_response(
+                jsonify({"success": False, "message": "Unauthorized"}), 401
+            )
 
         user = decoded_token.get("sub")
         source_id = request.form.get("source_id")
@@ -658,12 +697,24 @@ class ManageSourceFiles(Resource):
 
         if not source_id or not operation:
             return make_response(
-                jsonify({"success": False, "message": "source_id and operation are required"}), 400
+                jsonify(
+                    {
+                        "success": False,
+                        "message": "source_id and operation are required",
+                    }
+                ),
+                400,
             )
 
         if operation not in ["add", "remove", "remove_directory"]:
             return make_response(
-                jsonify({"success": False, "message": "operation must be 'add', 'remove', or 'remove_directory'"}), 400
+                jsonify(
+                    {
+                        "success": False,
+                        "message": "operation must be 'add', 'remove', or 'remove_directory'",
+                    }
+                ),
+                400,
             )
 
         try:
@@ -674,34 +725,53 @@ class ManageSourceFiles(Resource):
             )
 
         try:
-            source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
+            source = sources_collection.find_one(
+                {"_id": ObjectId(source_id), "user": user}
+            )
             if not source:
                 return make_response(
-                    jsonify({"success": False, "message": "Source not found or access denied"}), 404
+                    jsonify(
+                        {
+                            "success": False,
+                            "message": "Source not found or access denied",
+                        }
+                    ),
+                    404,
                 )
         except Exception as err:
             current_app.logger.error(f"Error finding source: {err}", exc_info=True)
-            return make_response(jsonify({"success": False, "message": "Database error"}), 500)
+            return make_response(
+                jsonify({"success": False, "message": "Database error"}), 500
+            )
 
         try:
             storage = StorageCreator.get_storage()
             source_file_path = source.get("file_path", "")
-            parent_dir = request.form.get("parent_dir", "")
 
+            parent_dir = request.form.get("parent_dir", "")
+
             if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
                 return make_response(
-                    jsonify({"success": False, "message": "Invalid parent directory path"}), 400
+                    jsonify(
+                        {"success": False, "message": "Invalid parent directory path"}
+                    ),
+                    400,
                 )
 
             if operation == "add":
                 files = request.files.getlist("file")
                 if not files or all(file.filename == "" for file in files):
                     return make_response(
-                        jsonify({"success": False, "message": "No files provided for add operation"}), 400
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "No files provided for add operation",
+                            }
+                        ),
+                        400,
                     )
 
                 added_files = []
 
                 target_dir = source_file_path
                 if parent_dir:
                     target_dir = f"{source_file_path}/{parent_dir}"
@@ -720,26 +790,44 @@ class ManageSourceFiles(Resource):
 
                 task = reingest_source_task.delay(source_id=source_id, user=user)
 
-                return make_response(jsonify({
-                    "success": True,
-                    "message": f"Added {len(added_files)} files",
-                    "added_files": added_files,
-                    "parent_dir": parent_dir,
-                    "reingest_task_id": task.id
-                }), 200)
+                return make_response(
+                    jsonify(
+                        {
+                            "success": True,
+                            "message": f"Added {len(added_files)} files",
+                            "added_files": added_files,
+                            "parent_dir": parent_dir,
+                            "reingest_task_id": task.id,
+                        }
+                    ),
+                    200,
+                )
 
             elif operation == "remove":
                 file_paths_str = request.form.get("file_paths")
                 if not file_paths_str:
                     return make_response(
-                        jsonify({"success": False, "message": "file_paths required for remove operation"}), 400
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "file_paths required for remove operation",
+                            }
+                        ),
+                        400,
                     )
 
                 try:
-                    file_paths = json.loads(file_paths_str) if isinstance(file_paths_str, str) else file_paths_str
+                    file_paths = (
+                        json.loads(file_paths_str)
+                        if isinstance(file_paths_str, str)
+                        else file_paths_str
+                    )
                 except Exception:
                     return make_response(
-                        jsonify({"success": False, "message": "Invalid file_paths format"}), 400
+                        jsonify(
+                            {"success": False, "message": "Invalid file_paths format"}
+                        ),
+                        400,
                     )
 
                 # Remove files from storage and directory structure
@@ -757,18 +845,29 @@ class ManageSourceFiles(Resource):
 
                 task = reingest_source_task.delay(source_id=source_id, user=user)
 
-                return make_response(jsonify({
-                    "success": True,
-                    "message": f"Removed {len(removed_files)} files",
-                    "removed_files": removed_files,
-                    "reingest_task_id": task.id
-                }), 200)
+                return make_response(
+                    jsonify(
+                        {
+                            "success": True,
+                            "message": f"Removed {len(removed_files)} files",
+                            "removed_files": removed_files,
+                            "reingest_task_id": task.id,
+                        }
+                    ),
+                    200,
+                )
 
             elif operation == "remove_directory":
                 directory_path = request.form.get("directory_path")
                 if not directory_path:
                     return make_response(
-                        jsonify({"success": False, "message": "directory_path required for remove_directory operation"}), 400
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "directory_path required for remove_directory operation",
+                            }
+                        ),
+                        400,
                     )
 
                 # Validate directory path (prevent path traversal)
@@ -778,10 +877,17 @@ class ManageSourceFiles(Resource):
                         f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
                     )
                     return make_response(
-                        jsonify({"success": False, "message": "Invalid directory path"}), 400
+                        jsonify(
+                            {"success": False, "message": "Invalid directory path"}
+                        ),
+                        400,
                     )
 
-                full_directory_path = f"{source_file_path}/{directory_path}" if directory_path else source_file_path
+                full_directory_path = (
+                    f"{source_file_path}/{directory_path}"
+                    if directory_path
+                    else source_file_path
+                )
 
                 if not storage.is_directory(full_directory_path):
                     current_app.logger.warning(
@@ -790,7 +896,13 @@ class ManageSourceFiles(Resource):
                         f"Full path: {full_directory_path}"
                     )
                     return make_response(
-                        jsonify({"success": False, "message": "Directory not found or is not a directory"}), 404
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "Directory not found or is not a directory",
+                            }
+                        ),
+                        404,
                     )
 
                 success = storage.remove_directory(full_directory_path)
@@ -802,7 +914,10 @@ class ManageSourceFiles(Resource):
                         f"Full path: {full_directory_path}"
                     )
                     return make_response(
-                        jsonify({"success": False, "message": "Failed to remove directory"}), 500
+                        jsonify(
+                            {"success": False, "message": "Failed to remove directory"}
+                        ),
+                        500,
                     )
 
                 current_app.logger.info(
@@ -816,12 +931,17 @@ class ManageSourceFiles(Resource):
 
                 task = reingest_source_task.delay(source_id=source_id, user=user)
 
-                return make_response(jsonify({
-                    "success": True,
-                    "message": f"Successfully removed directory: {directory_path}",
-                    "removed_directory": directory_path,
-                    "reingest_task_id": task.id
-                }), 200)
+                return make_response(
+                    jsonify(
+                        {
+                            "success": True,
+                            "message": f"Successfully removed directory: {directory_path}",
+                            "removed_directory": directory_path,
+                            "reingest_task_id": task.id,
+                        }
+                    ),
+                    200,
+                )
 
         except Exception as err:
             error_context = f"operation={operation}, user={user}, source_id={source_id}"
@@ -835,8 +955,12 @@ class ManageSourceFiles(Resource):
                 parent_dir = request.form.get("parent_dir", "")
                 error_context += f", parent_dir={parent_dir}"
 
-            current_app.logger.error(f"Error managing source files: {err} ({error_context})", exc_info=True)
-            return make_response(jsonify({"success": False, "message": "Operation failed"}), 500)
+            current_app.logger.error(
+                f"Error managing source files: {err} ({error_context})", exc_info=True
+            )
+            return make_response(
+                jsonify({"success": False, "message": "Operation failed"}), 500
+            )
 
 
 @user_ns.route("/api/remote")
@@ -984,7 +1108,7 @@ class PaginatedSources(Resource):
                 "tokens": doc.get("tokens", ""),
                 "retriever": doc.get("retriever", "classic"),
                 "syncFrequency": doc.get("sync_frequency", ""),
-                "isNested": bool(doc.get("directory_structure"))
+                "isNested": bool(doc.get("directory_structure")),
             }
             paginated_docs.append(doc_data)
         response = {
@@ -1032,7 +1156,7 @@ class CombinedJson(Resource):
                         "tokens": index.get("tokens", ""),
                         "retriever": index.get("retriever", "classic"),
                         "syncFrequency": index.get("sync_frequency", ""),
-                        "is_nested": bool(index.get("directory_structure"))
+                        "is_nested": bool(index.get("directory_structure")),
                     }
                 )
         except Exception as err:
@@ -1381,7 +1505,8 @@ class CreateAgent(Resource):
                 required=True, description="Status of the agent (draft or published)"
             ),
             "json_schema": fields.Raw(
-                required=False, description="JSON schema for enforcing structured output format"
+                required=False,
+                description="JSON schema for enforcing structured output format",
             ),
         },
     )
@@ -1407,7 +1532,7 @@ class CreateAgent(Resource):
             except json.JSONDecodeError:
                 data["json_schema"] = None
         print(f"Received data: {data}")
 
         # Validate JSON schema if provided
         if data.get("json_schema"):
             try:
@@ -1415,20 +1540,32 @@ class CreateAgent(Resource):
                 json_schema = data.get("json_schema")
                 if not isinstance(json_schema, dict):
                     return make_response(
-                        jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}),
-                        400
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "JSON schema must be a valid JSON object",
+                            }
+                        ),
+                        400,
                     )
 
                 # Validate that it has either a 'schema' property or is itself a schema
                 if "schema" not in json_schema and "type" not in json_schema:
                     return make_response(
-                        jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}),
-                        400
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
+                            }
+                        ),
+                        400,
                     )
             except Exception as e:
                 return make_response(
-                    jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}),
-                    400
+                    jsonify(
+                        {"success": False, "message": f"Invalid JSON schema: {str(e)}"}
+                    ),
+                    400,
                 )
 
         if data.get("status") not in ["draft", "published"]:
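The check above accepts either a bare JSON schema (with "type") or a wrapper (with "schema"); two payloads that would pass validation, values invented for illustration:

# Bare JSON Schema form: "type" present at the top level.
schema_a = {"type": "object", "properties": {"answer": {"type": "string"}}}
# Wrapper form: the schema nested under a "schema" property.
schema_b = {"name": "answer_format", "schema": schema_a}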
@@ -1529,7 +1666,8 @@ class UpdateAgent(Resource):
                 required=True, description="Status of the agent (draft or published)"
             ),
             "json_schema": fields.Raw(
-                required=False, description="JSON schema for enforcing structured output format"
+                required=False,
+                description="JSON schema for enforcing structured output format",
             ),
         },
     )
@@ -3297,6 +3435,31 @@ class CreateTool(Resource):
                         param_details["value"] = ""
                 transformed_actions.append(action)
         try:
+            # Process config to encrypt credentials for MCP tools
+            config = data["config"]
+            if data["name"] == "mcp_tool":
+                from application.security.encryption import encrypt_credentials
+
+                # Extract credentials from config
+                credentials = {}
+                if config.get("auth_type") == "bearer":
+                    credentials["bearer_token"] = config.get("bearer_token", "")
+                elif config.get("auth_type") == "api_key":
+                    credentials["api_key"] = config.get("api_key", "")
+                    credentials["api_key_header"] = config.get("api_key_header", "")
+                elif config.get("auth_type") == "basic":
+                    credentials["username"] = config.get("username", "")
+                    credentials["password"] = config.get("password", "")
+
+                # Encrypt credentials if any exist
+                if credentials:
+                    config["encrypted_credentials"] = encrypt_credentials(
+                        credentials, user
+                    )
+                    # Remove plaintext credentials from config
+                    for key in credentials.keys():
+                        config.pop(key, None)
+
             new_tool = {
                 "user": user,
                 "name": data["name"],
@@ -3304,7 +3467,7 @@
                 "description": data["description"],
                 "customName": data.get("customName", ""),
                 "actions": transformed_actions,
-                "config": data["config"],
+                "config": config,
                 "status": data["status"],
             }
             resp = user_tools_collection.insert_one(new_tool)
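Effect of the encryption block above on a saved config, sketched with an invented token value:

# Incoming config with a plaintext credential:
config = {"server_url": "https://mcp.example.com", "auth_type": "bearer",
          "bearer_token": "sk-secret"}
# After the block runs, the plaintext key is popped and ciphertext is stored instead:
# {"server_url": "https://mcp.example.com", "auth_type": "bearer",
#  "encrypted_credentials": "<base64 ciphertext bound to this user>"}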
@@ -3371,7 +3534,41 @@ class UpdateTool(Resource):
                     ),
                     400,
                 )
-            update_data["config"] = data["config"]
+
+            # Handle MCP tool credential encryption
+            config = data["config"]
+            tool_name = data.get("name")
+            if not tool_name:
+                # Get the tool name from the database
+                existing_tool = user_tools_collection.find_one(
+                    {"_id": ObjectId(data["id"]), "user": user}
+                )
+                tool_name = existing_tool.get("name") if existing_tool else None
+
+            if tool_name == "mcp_tool":
+                from application.security.encryption import encrypt_credentials
+
+                # Extract credentials from config
+                credentials = {}
+                if config.get("auth_type") == "bearer":
+                    credentials["bearer_token"] = config.get("bearer_token", "")
+                elif config.get("auth_type") == "api_key":
+                    credentials["api_key"] = config.get("api_key", "")
+                    credentials["api_key_header"] = config.get("api_key_header", "")
+                elif config.get("auth_type") == "basic":
+                    credentials["username"] = config.get("username", "")
+                    credentials["password"] = config.get("password", "")
+
+                # Encrypt credentials if any exist
+                if credentials:
+                    config["encrypted_credentials"] = encrypt_credentials(
+                        credentials, user
+                    )
+                    # Remove plaintext credentials from config
+                    for key in credentials.keys():
+                        config.pop(key, None)
+
+            update_data["config"] = config
         if "status" in data:
             update_data["status"] = data["status"]
         user_tools_collection.update_one(
@@ -3537,7 +3734,7 @@ class GetChunks(Resource):
             "page": "Page number for pagination",
             "per_page": "Number of chunks per page",
             "path": "Optional: Filter chunks by relative file path",
-            "search": "Optional: Search term to filter chunks by title or content"
+            "search": "Optional: Search term to filter chunks by title or content",
         },
     )
     def get(self):
@@ -3561,7 +3758,7 @@ class GetChunks(Resource):
         try:
             store = get_vector_store(doc_id)
             chunks = store.get_chunks()
 
             filtered_chunks = []
             for chunk in chunks:
                 metadata = chunk.get("metadata", {})
@@ -3582,9 +3779,9 @@ class GetChunks(Resource):
                         continue
 
                 filtered_chunks.append(chunk)
 
             chunks = filtered_chunks
 
             total_chunks = len(chunks)
             start = (page - 1) * per_page
             end = start + per_page
@@ -3598,7 +3795,7 @@ class GetChunks(Resource):
                         "total": total_chunks,
                         "chunks": paginated_chunks,
                         "path": path if path else None,
-                        "search": search_term if search_term else None
+                        "search": search_term if search_term else None,
                     }
                 ),
                 200,
@@ -3607,6 +3804,7 @@ class GetChunks(Resource):
             current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
             return make_response(jsonify({"success": False}), 500)
 
+
 @user_ns.route("/api/add_chunk")
 class AddChunk(Resource):
     @api.expect(
@@ -3773,7 +3971,9 @@ class UpdateChunk(Resource):
 
             deleted = store.delete_chunk(chunk_id)
             if not deleted:
-                current_app.logger.warning(f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created")
+                current_app.logger.warning(
+                    f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
+                )
 
             return make_response(
                 jsonify(
@@ -3905,39 +4105,233 @@ class DirectoryStructure(Resource):
         decoded_token = request.decoded_token
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
 
         user = decoded_token.get("sub")
         doc_id = request.args.get("id")
 
         if not doc_id:
-            return make_response(
-                jsonify({"error": "Document ID is required"}), 400
-            )
+            return make_response(jsonify({"error": "Document ID is required"}), 400)
 
         if not ObjectId.is_valid(doc_id):
             return make_response(jsonify({"error": "Invalid document ID"}), 400)
 
         try:
             doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
             if not doc:
                 return make_response(
                     jsonify({"error": "Document not found or access denied"}), 404
                 )
 
             directory_structure = doc.get("directory_structure", {})
 
             return make_response(
-                jsonify({
-                    "success": True,
-                    "directory_structure": directory_structure,
-                    "base_path": doc.get("file_path", "")
-                }), 200
+                jsonify(
+                    {
+                        "success": True,
+                        "directory_structure": directory_structure,
+                        "base_path": doc.get("file_path", ""),
+                    }
+                ),
+                200,
             )
 
         except Exception as e:
             current_app.logger.error(
                 f"Error retrieving directory structure: {e}", exc_info=True
             )
-            return make_response(
-                jsonify({"success": False, "error": str(e)}), 500
-            )
+            return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@user_ns.route("/api/mcp_servers")
+class MCPServers(Resource):
+    @api.doc(description="Get all MCP servers configured by the user")
+    def get(self):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+
+        user = decoded_token.get("sub")
+        try:
+            # Find all MCP tools for this user
+            mcp_tools = user_tools_collection.find({"user": user, "name": "mcp_tool"})
+
+            servers = []
+            for tool in mcp_tools:
+                config = tool.get("config", {})
+                servers.append(
+                    {
+                        "id": str(tool["_id"]),
+                        "name": tool.get("displayName", "MCP Server"),
+                        "server_url": config.get("server_url", ""),
+                        "auth_type": config.get("auth_type", "none"),
+                        "status": tool.get("status", False),
+                        "created_at": (
+                            tool.get("_id").generation_time.isoformat()
+                            if tool.get("_id")
+                            else None
+                        ),
+                    }
+                )
+
+            return make_response(jsonify({"success": True, "servers": servers}), 200)
+
+        except Exception as e:
+            current_app.logger.error(
+                f"Error retrieving MCP servers: {e}", exc_info=True
+            )
+            return make_response(jsonify({"success": False, "error": str(e)}), 500)
+
+
+@user_ns.route("/api/mcp_server/<string:server_id>/test")
+class TestMCPServer(Resource):
+    @api.doc(description="Test connection to an MCP server")
+    def post(self, server_id):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+
+        user = decoded_token.get("sub")
+        try:
+            # Find the MCP tool
+            mcp_tool_doc = user_tools_collection.find_one(
+                {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
+            )
+
+            if not mcp_tool_doc:
+                return make_response(
+                    jsonify({"success": False, "error": "MCP server not found"}), 404
+                )
+
+            # Load the tool and test connection
+            from application.agents.tools.mcp_tool import MCPTool
+
+            mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user)
+            result = mcp_tool.test_connection()
+
+            return make_response(jsonify(result), 200)
+
+        except Exception as e:
+            current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
+            return make_response(
+                jsonify(
+                    {"success": False, "error": f"Connection test failed: {str(e)}"}
+                ),
+                500,
+            )
+
+
+@user_ns.route("/api/mcp_server/<string:server_id>/tools")
+class MCPServerTools(Resource):
+    @api.doc(description="Discover and get tools from an MCP server")
+    def get(self, server_id):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+
+        user = decoded_token.get("sub")
+        try:
+            # Find the MCP tool
+            mcp_tool_doc = user_tools_collection.find_one(
+                {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
+            )
+
+            if not mcp_tool_doc:
+                return make_response(
+                    jsonify({"success": False, "error": "MCP server not found"}), 404
+                )
+
+            # Load the tool and discover tools
+            from application.agents.tools.mcp_tool import MCPTool
+
+            mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user)
+            tools = mcp_tool.discover_tools()
+
+            # Get actions metadata and transform to match other tools format
+            actions_metadata = mcp_tool.get_actions_metadata()
+            transformed_actions = []
+
+            for action in actions_metadata:
+                # Add active flag and transform parameters
+                action["active"] = True
+                if "parameters" in action:
+                    if "properties" in action["parameters"]:
+                        for param_name, param_details in action["parameters"][
+                            "properties"
+                        ].items():
+                            param_details["filled_by_llm"] = True
+                            param_details["value"] = ""
+                transformed_actions.append(action)
+
+            # Update the stored actions in the database
+            user_tools_collection.update_one(
+                {"_id": ObjectId(server_id)}, {"$set": {"actions": transformed_actions}}
+            )
+
+            return make_response(
+                jsonify(
+                    {"success": True, "tools": tools, "actions": transformed_actions}
+                ),
+                200,
+            )
+
+        except Exception as e:
+            current_app.logger.error(f"Error discovering MCP tools: {e}", exc_info=True)
+            return make_response(
+                jsonify(
+                    {"success": False, "error": f"Tool discovery failed: {str(e)}"}
+                ),
+                500,
+            )
+
+
+@user_ns.route("/api/mcp_server/<string:server_id>/tools/<string:action_name>")
+class MCPServerToolAction(Resource):
+    @api.expect(
+        api.model(
+            "MCPToolActionModel",
+            {
+                "parameters": fields.Raw(
+                    required=False, description="Parameters for the tool action"
+                )
+            },
+        )
+    )
+    @api.doc(description="Execute a specific tool action on an MCP server")
+    def post(self, server_id, action_name):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+
+        user = decoded_token.get("sub")
+        data = request.get_json() or {}
+        parameters = data.get("parameters", {})
+
+        try:
+            # Find the MCP tool
+            mcp_tool_doc = user_tools_collection.find_one(
+                {"_id": ObjectId(server_id), "user": user, "name": "mcp_tool"}
+            )
+
+            if not mcp_tool_doc:
+                return make_response(
+                    jsonify({"success": False, "error": "MCP server not found"}), 404
+                )
+
+            # Load the tool and execute action
+            from application.agents.tools.mcp_tool import MCPTool
+
+            mcp_tool = MCPTool(mcp_tool_doc.get("config", {}), user)
+            result = mcp_tool.execute_action(action_name, **parameters)
+
+            return make_response(jsonify({"success": True, "result": result}), 200)
+
+        except Exception as e:
+            current_app.logger.error(
+                f"Error executing MCP tool action: {e}", exc_info=True
+            )
+            return make_response(
+                jsonify(
+                    {"success": False, "error": f"Action execution failed: {str(e)}"}
+                ),
+                500,
+            )
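A sketch of exercising the new endpoints with the requests library; the host, JWT, document id, and "search_docs" action name are all placeholders, while the routes are the ones registered above:

import requests

BASE = "http://localhost:7091"               # assumed local DocsGPT API host
HEADERS = {"Authorization": "Bearer <jwt>"}  # whatever auth fills request.decoded_token
sid = "<mcp_tool document _id>"

requests.get(f"{BASE}/api/mcp_servers", headers=HEADERS)             # list configured servers
requests.post(f"{BASE}/api/mcp_server/{sid}/test", headers=HEADERS)  # runs test_connection()
requests.get(f"{BASE}/api/mcp_server/{sid}/tools", headers=HEADERS)  # discovers and persists actions
requests.post(
    f"{BASE}/api/mcp_server/{sid}/tools/search_docs",                # executes one action
    headers=HEADERS,
    json={"parameters": {"query": "hello"}},
)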
@@ -89,7 +89,7 @@ class Settings(BaseSettings):
     QDRANT_HOST: Optional[str] = None
     QDRANT_PATH: Optional[str] = None
     QDRANT_DISTANCE_FUNC: str = "Cosine"
 
     # PGVector vectorstore config
     PGVECTOR_CONNECTION_STRING: Optional[str] = None
     # Milvus vectorstore config
application/security/__init__.py (new file, 0 lines)
application/security/encryption.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""
|
||||
Simple encryption utility for securely storing sensitive credentials.
|
||||
Uses XOR encryption with a key derived from app secret and user ID.
|
||||
Note: This is basic obfuscation. For production, consider using cryptography library.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
def _get_encryption_key(user_id: str) -> bytes:
|
||||
"""
|
||||
Generate a consistent encryption key for a specific user.
|
||||
Uses app secret + user ID to create a unique key per user.
|
||||
"""
|
||||
# Get app secret from environment or use a default (in production, always use env)
|
||||
app_secret = os.environ.get(
|
||||
"APP_SECRET_KEY", "default-docsgpt-secret-key-change-in-production"
|
||||
)
|
||||
|
||||
# Combine app secret with user ID for user-specific encryption
|
||||
combined = f"{app_secret}#{user_id}"
|
||||
|
||||
# Create a 32-byte key
|
||||
key_material = hashlib.sha256(combined.encode()).digest()
|
||||
|
||||
return key_material
|
||||
|
||||
|
||||
def _xor_encrypt_decrypt(data: bytes, key: bytes) -> bytes:
|
||||
"""Simple XOR encryption/decryption."""
|
||||
result = bytearray()
|
||||
for i, byte in enumerate(data):
|
||||
result.append(byte ^ key[i % len(key)])
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||
"""
|
||||
Encrypt credentials dictionary for secure storage.
|
||||
|
||||
Args:
|
||||
credentials: Dictionary containing sensitive data
|
||||
user_id: User ID for creating user-specific encryption key
|
||||
|
||||
Returns:
|
||||
Base64 encoded encrypted string
|
||||
"""
|
||||
if not credentials:
|
||||
return ""
|
||||
|
||||
try:
|
||||
key = _get_encryption_key(user_id)
|
||||
|
||||
# Convert dict to JSON string and encrypt
|
||||
json_str = json.dumps(credentials)
|
||||
encrypted_data = _xor_encrypt_decrypt(json_str.encode(), key)
|
||||
|
||||
# Return base64 encoded for storage
|
||||
return base64.b64encode(encrypted_data).decode()
|
||||
|
||||
except Exception as e:
|
||||
# If encryption fails, store empty string (will require re-auth)
|
||||
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||
"""
|
||||
Decrypt credentials from storage.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64 encoded encrypted string
|
||||
user_id: User ID for creating user-specific encryption key
|
||||
|
||||
Returns:
|
||||
Dictionary containing decrypted credentials
|
||||
"""
|
||||
if not encrypted_data:
|
||||
return {}
|
||||
|
||||
try:
|
||||
key = _get_encryption_key(user_id)
|
||||
|
||||
# Decode and decrypt
|
||||
encrypted_bytes = base64.b64decode(encrypted_data.encode())
|
||||
decrypted_data = _xor_encrypt_decrypt(encrypted_bytes, key)
|
||||
|
||||
# Parse JSON back to dict
|
||||
return json.loads(decrypted_data.decode())
|
||||
|
||||
except Exception as e:
|
||||
# If decryption fails, return empty dict (will require re-auth)
|
||||
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||
return {}
|
||||
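A round-trip sketch of the utility; as the module docstring itself notes, XOR keyed from APP_SECRET_KEY plus the user id is obfuscation rather than real cryptography:

from application.security.encryption import encrypt_credentials, decrypt_credentials

creds = {"bearer_token": "sk-secret"}                  # invented value
blob = encrypt_credentials(creds, user_id="user-123")  # base64 ciphertext for storage
assert decrypt_credentials(blob, user_id="user-123") == creds
# A different user id derives a different key; the garbled bytes almost certainly
# fail JSON parsing, so the helper returns {} and the user must re-authenticate.
print(decrypt_credentials(blob, user_id="someone-else"))  # -> {}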