Files
DocsGPT/tests/integration/test_mcp.py
Pavel 58f27ed141 All endpoints covered
test_integration.py kept for backwards compatibility.
tests/integration/run_all.py proposed as alternative to cover all endpoints.
2026-01-14 23:14:05 +03:00

338 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Integration tests for DocsGPT MCP (Model Context Protocol) server endpoints.
Endpoints tested:
- /api/mcp_server/callback (GET) - OAuth callback
- /api/mcp_server/oauth_status/{task_id} (GET) - OAuth status
- /api/mcp_server/save (POST) - Save MCP server config
- /api/mcp_server/test (POST) - Test MCP server connection
Usage:
python tests/integration/test_mcp.py
python tests/integration/test_mcp.py --base-url http://localhost:7091
python tests/integration/test_mcp.py --token YOUR_JWT_TOKEN
"""
import sys
import time
from pathlib import Path
# Make the repository root (two levels above this file's directory) importable
# so ``tests.integration.base`` resolves when this file is run directly as a
# script. ``resolve()`` makes the path absolute first: for a bare relative
# ``__file__`` such as ``test_mcp.py``, ``Path(...).parent`` is ``.`` and
# further ``.parent`` calls stay at ``.``, never reaching the repository root.
_THIS_DIR = Path(__file__).resolve().parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))
from tests.integration.base import DocsGPTTestBase, create_client_from_args
class MCPTests(DocsGPTTestBase):
    """Integration tests for the DocsGPT MCP server endpoints.

    Every ``test_*`` method prints a header, sends one HTTP request, records
    exactly one result through ``record_result`` and returns ``True`` for a
    pass or ``False`` for a failure. Tests that need authentication return
    ``True`` early when ``self.require_auth`` returns a falsy value.
    """

    # Values of the ``status`` field accepted as "task not found" when the
    # OAuth status endpoint answers 200 for an unknown task ID.
    _NOT_FOUND_TASK_STATUSES = ("not_found", "unknown", None)

    # Responses meaning the server rejected our credentials. The endpoint was
    # never actually exercised, so the lenient tests must not count these as
    # passes.
    _AUTH_REJECTED_STATUSES = (401, 403)

    # -------------------------------------------------------------------------
    # Result helpers
    # -------------------------------------------------------------------------
    def _pass_result(self, test_name: str, detail: str) -> bool:
        """Record a passing result for *test_name* and return ``True``."""
        self.record_result(test_name, True, detail)
        return True

    def _fail_result(self, test_name: str, detail: str) -> bool:
        """Record a failing result for *test_name* and return ``False``."""
        self.record_result(test_name, False, detail)
        return False

    def _fail_status(
        self, test_name: str, status_code: int, label: str = "Unexpected status"
    ) -> bool:
        """Print and record a failure caused by an unaccepted HTTP status.

        Args:
            test_name: Name the result is recorded under.
            status_code: The HTTP status the server returned.
            label: Prefix for the printed error line.

        Returns:
            Always ``False``.
        """
        self.print_error(f"{label}: {status_code}")
        return self._fail_result(test_name, f"Status: {status_code}")

    def _fail_exception(self, test_name: str, exc: Exception) -> bool:
        """Print and record a failure caused by an exception; return ``False``."""
        self.print_error(f"Exception: {exc}")
        return self._fail_result(test_name, str(exc))

    @staticmethod
    def _connection_flag(result: dict, default):
        """Return the connection outcome from a ``/api/mcp_server/test`` body.

        Reads ``success``, falling back to ``connected`` and then *default*.
        """
        return result.get("success", result.get("connected", default))

    # -------------------------------------------------------------------------
    # Callback Tests
    # -------------------------------------------------------------------------
    def test_mcp_callback(self) -> bool:
        """Hit the OAuth callback with dummy ``code``/``state`` query params.

        Does not call ``require_auth``. 200, 302, 400 and 404 are all accepted
        because the outcome depends on server configuration; anything else
        fails.
        """
        test_name = "MCP OAuth callback"
        self.print_header(test_name)
        try:
            response = self.get(
                "/api/mcp_server/callback",
                params={"code": "test_code", "state": "test_state"},
                timeout=10,
            )
            status_code = response.status_code
            if status_code in (200, 302, 400, 404):
                self.print_success(f"Callback response: {status_code}")
                return self._pass_result(test_name, f"Status: {status_code}")
            return self._fail_status(test_name, status_code)
        except Exception as e:
            return self._fail_exception(test_name, e)

    # -------------------------------------------------------------------------
    # OAuth Status Tests
    # -------------------------------------------------------------------------
    def test_mcp_oauth_status(self) -> bool:
        """Query the OAuth status endpoint for a sample task ID.

        Accepts 200 (printing the body's ``status`` field) or 404; any other
        status fails.
        """
        test_name = "MCP OAuth status"
        self.print_header(test_name)
        if not self.require_auth(test_name):
            return True
        try:
            response = self.get(
                "/api/mcp_server/oauth_status/test-task-id-123",
                timeout=10,
            )
            status_code = response.status_code
            if status_code not in (200, 404):
                return self._fail_status(test_name, status_code)
            self.print_success(f"OAuth status check: {status_code}")
            if status_code == 200:
                result = response.json()
                self.print_info(f"Status: {result.get('status', 'N/A')}")
            return self._pass_result(test_name, f"Status: {status_code}")
        except Exception as e:
            return self._fail_exception(test_name, e)

    def test_mcp_oauth_status_invalid_task(self) -> bool:
        """Check that an unknown OAuth task ID is reported as not found.

        Passes on 400/404, or on 200 when the body's ``status`` is
        ``not_found``, ``unknown`` or missing. A 401/403 fails because the
        token was rejected and nothing was exercised. Any other response is
        logged as a warning and recorded as a lenient pass.
        """
        test_name = "MCP OAuth status invalid"
        self.print_header(test_name)
        if not self.require_auth(test_name):
            return True
        try:
            response = self.get(
                "/api/mcp_server/oauth_status/nonexistent-task-xyz",
                timeout=10,
            )
            status_code = response.status_code
            if status_code in (400, 404):
                self.print_success(f"Correctly returned: {status_code}")
                return self._pass_result(test_name, "Invalid task handled")
            if status_code == 200:
                task_status = response.json().get("status")
                if task_status in self._NOT_FOUND_TASK_STATUSES:
                    self.print_success(
                        f"Invalid task handled (status: {task_status})"
                    )
                    return self._pass_result(test_name, "Invalid task handled")
            elif status_code in self._AUTH_REJECTED_STATUSES:
                return self._fail_status(test_name, status_code, "Auth rejected")
            # Lenient: other responses are surfaced but not counted as failures.
            self.print_warning(f"Status: {status_code}")
            return self._pass_result(test_name, f"Status: {status_code}")
        except Exception as e:
            return self._fail_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Save Tests
    # -------------------------------------------------------------------------
    def test_mcp_save(self) -> bool:
        """Save an MCP server config pointing at ``https://example.com/mcp``.

        200/201 pass and print the returned ``id``; 400/422 are treated as a
        handled validation error (pass with a warning); anything else fails.
        """
        test_name = "Save MCP server config"
        self.print_header(test_name)
        if not self.require_auth(test_name):
            return True
        # The timestamp suffix gives each run a distinct server name.
        # NOTE(review): the saved config is not deleted afterwards.
        payload = {
            "name": f"Test MCP Server {int(time.time())}",
            "url": "https://example.com/mcp",
            "config": {},
        }
        try:
            response = self.post(
                "/api/mcp_server/save",
                json=payload,
                timeout=15,
            )
            status_code = response.status_code
            if status_code in (200, 201):
                result = response.json()
                self.print_success(f"Saved MCP server: {result.get('id', 'N/A')}")
                return self._pass_result(test_name, "Config saved")
            if status_code in (400, 422):
                self.print_warning(f"Validation error: {status_code}")
                return self._pass_result(test_name, "Validation handled")
            return self._fail_status(test_name, status_code, "Save failed")
        except Exception as e:
            return self._fail_exception(test_name, e)

    def test_mcp_save_invalid(self) -> bool:
        """Submit an empty name and a malformed URL to the save endpoint.

        400/422 pass as a proper rejection; 200/201 pass with a warning
        (lenient server-side validation); anything else fails.
        """
        test_name = "Save invalid MCP config"
        self.print_header(test_name)
        if not self.require_auth(test_name):
            return True
        payload = {
            "name": "",  # Invalid empty name
            "url": "not-a-url",  # Invalid URL
        }
        try:
            response = self.post(
                "/api/mcp_server/save",
                json=payload,
                timeout=15,
            )
            status_code = response.status_code
            if status_code in (400, 422):
                self.print_success(f"Validation rejected: {status_code}")
                return self._pass_result(test_name, "Invalid config rejected")
            if status_code in (200, 201):
                self.print_warning("Server accepted invalid data (lenient validation)")
                return self._pass_result(test_name, "Lenient validation")
            return self._fail_status(test_name, status_code)
        except Exception as e:
            return self._fail_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Test Connection Tests
    # -------------------------------------------------------------------------
    def test_mcp_test_connection(self) -> bool:
        """Ask the server to test a connection to ``https://example.com/mcp``.

        200 passes and prints the reported connection flag; 400/500/502/504
        pass as an expected connection failure; anything else fails.
        """
        test_name = "Test MCP connection"
        self.print_header(test_name)
        if not self.require_auth(test_name):
            return True
        payload = {
            "url": "https://example.com/mcp",
            "config": {},
        }
        try:
            response = self.post(
                "/api/mcp_server/test",
                json=payload,
                timeout=30,  # Connection test may take time
            )
            status_code = response.status_code
            if status_code == 200:
                connected = self._connection_flag(response.json(), False)
                self.print_success(f"Connection test: success={connected}")
                return self._pass_result(test_name, f"Connected: {connected}")
            if status_code in (400, 500, 502, 504):
                # Connection failed (expected for a non-existent server).
                self.print_warning(f"Connection failed: {status_code}")
                return self._pass_result(test_name, "Connection failed (expected)")
            return self._fail_status(test_name, status_code)
        except Exception as e:
            return self._fail_exception(test_name, e)

    def test_mcp_test_connection_invalid(self) -> bool:
        """Ask the server to test a connection to a malformed URL.

        400/422/500 pass as a rejection; 200 passes when the reported
        connection flag is falsy. A 401/403 fails because the token was
        rejected and nothing was exercised. Any other response is logged as a
        warning and recorded as a lenient pass.
        """
        test_name = "Test MCP invalid URL"
        self.print_header(test_name)
        if not self.require_auth(test_name):
            return True
        payload = {
            "url": "invalid-url",
            "config": {},
        }
        try:
            response = self.post(
                "/api/mcp_server/test",
                json=payload,
                timeout=15,
            )
            status_code = response.status_code
            if status_code in (400, 422, 500):
                self.print_success(f"Invalid URL rejected: {status_code}")
                return self._pass_result(test_name, "Invalid URL handled")
            if status_code == 200:
                if not self._connection_flag(response.json(), True):
                    self.print_success("Connection correctly failed")
                    return self._pass_result(test_name, "Connection failed")
            elif status_code in self._AUTH_REJECTED_STATUSES:
                return self._fail_status(test_name, status_code, "Auth rejected")
            # Lenient: other responses are surfaced but not counted as failures.
            self.print_warning(f"Status: {status_code}")
            return self._pass_result(test_name, f"Status: {status_code}")
        except Exception as e:
            return self._fail_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Test Runner
    # -------------------------------------------------------------------------
    def run_all(self) -> bool:
        """Run every MCP test in order and return ``print_summary()``'s result."""
        self.print_header("DocsGPT MCP Server Integration Tests")
        self.print_info(f"Base URL: {self.base_url}")
        self.print_info(f"Auth: {self.token_source}")
        for run_test in (
            # Callback tests
            self.test_mcp_callback,
            # OAuth status tests
            self.test_mcp_oauth_status,
            self.test_mcp_oauth_status_invalid_task,
            # Save tests
            self.test_mcp_save,
            self.test_mcp_save_invalid,
            # Test connection tests
            self.test_mcp_test_connection,
            self.test_mcp_test_connection_invalid,
        ):
            run_test()
        return self.print_summary()
def main():
    """Command-line entry point for the MCP integration suite.

    Builds an ``MCPTests`` client from the command-line arguments, runs every
    test, and terminates the process with status 0 when ``run_all`` reports
    success or 1 otherwise.
    """
    suite = create_client_from_args(MCPTests, "DocsGPT MCP Server Integration Tests")
    if suite.run_all():
        sys.exit(0)
    sys.exit(1)


if __name__ == "__main__":
    main()