mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-20 14:00:55 +00:00
test_integration.py kept for backwards compatibility. tests/integration/run_all.py proposed as alternative to cover all endpoints.
338 lines
12 KiB
Python
338 lines
12 KiB
Python
#!/usr/bin/env python3
"""
Integration tests for DocsGPT MCP (Model Context Protocol) server endpoints.

Endpoints tested:
- /api/mcp_server/callback (GET) - OAuth callback
- /api/mcp_server/oauth_status/{task_id} (GET) - OAuth status
- /api/mcp_server/save (POST) - Save MCP server config
- /api/mcp_server/test (POST) - Test MCP server connection

Usage:
    python tests/integration/test_mcp.py
    python tests/integration/test_mcp.py --base-url http://localhost:7091
    python tests/integration/test_mcp.py --token YOUR_JWT_TOKEN
"""

import sys
import time
from pathlib import Path

# Make the repository root importable so ``tests.integration.base`` resolves
# when this file is executed directly as a script (see the Usage section).
# ``resolve()`` turns ``__file__`` into an absolute path first. ``__file__`` can
# be relative (e.g. for the main script on Python < 3.9). Without it, the
# computed root would be relative to the current working directory and wrong
# whenever the script is launched from elsewhere.
_THIS_DIR = Path(__file__).resolve().parent  # .../tests/integration
_TESTS_DIR = _THIS_DIR.parent  # .../tests
_ROOT_DIR = _TESTS_DIR.parent  # repository root
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))
|
|
|
|
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
|
|
|
|
|
class MCPTests(DocsGPTTestBase):
    """Integration tests for the DocsGPT MCP server endpoints.

    Every ``test_*`` method follows the same pattern:

    - print a header;
    - send one request through the base-class helpers (``self.get`` / ``self.post``);
    - record the outcome via ``record_result``;
    - return ``True`` for a pass and ``False`` for a failure.

    Methods guarded by ``require_auth`` return ``True`` without sending a
    request when that check is falsy. Presumably that means no token is
    configured -- see ``DocsGPTTestBase``.
    """

    # -------------------------------------------------------------------------
    # Shared failure reporting
    # -------------------------------------------------------------------------

    def _fail_on_status(
        self, test_name: str, status_code: int, label: str = "Unexpected status"
    ) -> bool:
        """Print and record a failure for an HTTP status the test rejects.

        Args:
            test_name: Name under which the result is recorded.
            status_code: The HTTP status code that was received.
            label: Prefix of the printed error message.

        Returns:
            Always ``False``, so callers can ``return`` this directly.
        """
        self.print_error(f"{label}: {status_code}")
        self.record_result(test_name, False, f"Status: {status_code}")
        return False

    def _fail_on_exception(self, test_name: str, exc: Exception) -> bool:
        """Print and record a failure for an exception raised during a test.

        Returns:
            Always ``False``, so callers can ``return`` this directly.
        """
        self.print_error(f"Exception: {exc}")
        self.record_result(test_name, False, str(exc))
        return False

    # -------------------------------------------------------------------------
    # Callback Tests
    # -------------------------------------------------------------------------

    def test_mcp_callback(self) -> bool:
        """Hit the OAuth callback with a dummy ``code``/``state`` pair.

        No auth check is performed. 200, 302, 400 and 404 all pass because
        the outcome depends on server configuration. Because 404 is accepted,
        this test cannot detect a missing route.
        """
        test_name = "MCP OAuth callback"
        self.print_header(test_name)

        try:
            response = self.get(
                "/api/mcp_server/callback",
                params={"code": "test_code", "state": "test_state"},
                timeout=10,
            )

            # Expect various responses depending on configuration
            if response.status_code in (200, 302, 400, 404):
                self.print_success(f"Callback response: {response.status_code}")
                self.record_result(test_name, True, f"Status: {response.status_code}")
                return True

            return self._fail_on_status(test_name, response.status_code)

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    # -------------------------------------------------------------------------
    # OAuth Status Tests
    # -------------------------------------------------------------------------

    def test_mcp_oauth_status(self) -> bool:
        """Query the OAuth status of a placeholder task ID.

        Requires auth. Passes on 200 (printing the ``status`` field of the JSON
        body) or 404. Any other status fails.
        """
        test_name = "MCP OAuth status"
        self.print_header(test_name)

        if not self.require_auth(test_name):
            return True

        try:
            response = self.get(
                "/api/mcp_server/oauth_status/test-task-id-123",
                timeout=10,
            )

            if response.status_code not in (200, 404):
                return self._fail_on_status(test_name, response.status_code)

            self.print_success(f"OAuth status check: {response.status_code}")
            if response.status_code == 200:
                result = response.json()
                self.print_info(f"Status: {result.get('status', 'N/A')}")
            self.record_result(test_name, True, f"Status: {response.status_code}")
            return True

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    def test_mcp_oauth_status_invalid_task(self) -> bool:
        """Check that an unknown task ID is reported as such.

        Requires auth. Outcomes:

        - 400/404: pass.
        - 200 with a body ``status`` of ``"not_found"``, ``"unknown"`` or
          missing: pass.
        - 200 with any other ``status``: tolerated with a warning.
        - Any other HTTP status (e.g. 401 or a 5xx): fails, consistent with
          ``test_mcp_oauth_status``.
        """
        test_name = "MCP OAuth status invalid"
        self.print_header(test_name)

        if not self.require_auth(test_name):
            return True

        try:
            response = self.get(
                "/api/mcp_server/oauth_status/nonexistent-task-xyz",
                timeout=10,
            )
            status_code = response.status_code

            if status_code in (404, 400):
                self.print_success(f"Correctly returned: {status_code}")
                self.record_result(test_name, True, "Invalid task handled")
                return True

            # Server errors or auth failures are real problems, not a
            # "handled" invalid task.
            if status_code != 200:
                return self._fail_on_status(test_name, status_code)

            task_status = response.json().get("status")
            if task_status in ("not_found", "unknown", None):
                # Report the status actually returned rather than a fixed label.
                self.print_success(f"Invalid task handled (status: {task_status})")
                self.record_result(test_name, True, "Invalid task handled")
                return True

            # 200 with some other task status: lenient pass, but flag it.
            self.print_warning(f"Status: {status_code} (task status: {task_status})")
            self.record_result(test_name, True, f"Status: {status_code}")
            return True

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Save Tests
    # -------------------------------------------------------------------------

    def test_mcp_save(self) -> bool:
        """POST a minimal MCP server config to the save endpoint.

        Requires auth. Passes on 200/201 (printing the returned ``id``) and on
        400/422, which are treated as handled validation. Any other status
        fails.

        NOTE(review): nothing here deletes the config it saves. If the save
        succeeds, the entry is left behind -- TODO confirm whether a cleanup
        endpoint should be called.
        """
        test_name = "Save MCP server config"
        self.print_header(test_name)

        if not self.require_auth(test_name):
            return True

        payload = {
            # Timestamp suffix gives repeated runs distinct names (1 s resolution).
            "name": f"Test MCP Server {int(time.time())}",
            "url": "https://example.com/mcp",
            "config": {},
        }

        try:
            response = self.post(
                "/api/mcp_server/save",
                json=payload,
                timeout=15,
            )

            if response.status_code in (200, 201):
                result = response.json()
                self.print_success(f"Saved MCP server: {result.get('id', 'N/A')}")
                self.record_result(test_name, True, "Config saved")
                return True

            if response.status_code in (400, 422):
                self.print_warning(f"Validation error: {response.status_code}")
                self.record_result(test_name, True, "Validation handled")
                return True

            return self._fail_on_status(
                test_name, response.status_code, label="Save failed"
            )

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    def test_mcp_save_invalid(self) -> bool:
        """POST an MCP config with an empty name and a malformed URL.

        Requires auth. 400/422 (rejected) pass. 200/201 pass with a warning,
        since the server may validate leniently. Any other status fails.
        """
        test_name = "Save invalid MCP config"
        self.print_header(test_name)

        if not self.require_auth(test_name):
            return True

        payload = {
            "name": "",  # Invalid empty name
            "url": "not-a-url",  # Invalid URL
        }

        try:
            response = self.post(
                "/api/mcp_server/save",
                json=payload,
                timeout=15,
            )

            if response.status_code in (400, 422):
                self.print_success(f"Validation rejected: {response.status_code}")
                self.record_result(test_name, True, "Invalid config rejected")
                return True

            if response.status_code in (200, 201):
                self.print_warning("Server accepted invalid data (lenient validation)")
                self.record_result(test_name, True, "Lenient validation")
                return True

            return self._fail_on_status(test_name, response.status_code)

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Test Connection Tests
    # -------------------------------------------------------------------------

    def test_mcp_test_connection(self) -> bool:
        """Ask the server to test a connection to a placeholder MCP URL.

        Requires auth. A 200 passes and reports ``success`` (or, failing that,
        ``connected``) from the JSON body. 400/500/502/504 pass as an expected
        connection failure. Any other status fails.
        """
        test_name = "Test MCP connection"
        self.print_header(test_name)

        if not self.require_auth(test_name):
            return True

        payload = {
            "url": "https://example.com/mcp",
            "config": {},
        }

        try:
            response = self.post(
                "/api/mcp_server/test",
                json=payload,
                timeout=30,  # Connection test may take time
            )

            if response.status_code == 200:
                result = response.json()
                success = result.get("success", result.get("connected", False))
                self.print_success(f"Connection test: success={success}")
                self.record_result(test_name, True, f"Connected: {success}")
                return True

            if response.status_code in (400, 500, 502, 504):
                # Connection failed (expected for non-existent server)
                self.print_warning(f"Connection failed: {response.status_code}")
                self.record_result(test_name, True, "Connection failed (expected)")
                return True

            return self._fail_on_status(test_name, response.status_code)

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    def test_mcp_test_connection_invalid(self) -> bool:
        """Ask the server to test a connection to a malformed URL.

        Requires auth. Outcomes:

        - 400/422/500/502/504: pass. 502/504 count as connection failures, as
          in ``test_mcp_test_connection``.
        - 200 whose body reports failure: pass.
        - 200 reporting success: tolerated with a warning, mirroring the
          lenient-validation case of ``test_mcp_save_invalid``.
        - Any other status (e.g. 401/404): fails.
        """
        test_name = "Test MCP invalid URL"
        self.print_header(test_name)

        if not self.require_auth(test_name):
            return True

        payload = {
            "url": "invalid-url",
            "config": {},
        }

        try:
            response = self.post(
                "/api/mcp_server/test",
                json=payload,
                timeout=15,
            )
            status_code = response.status_code

            if status_code in (400, 422, 500, 502, 504):
                self.print_success(f"Invalid URL rejected: {status_code}")
                self.record_result(test_name, True, "Invalid URL handled")
                return True

            if status_code != 200:
                return self._fail_on_status(test_name, status_code)

            result = response.json()
            # A body with neither key is treated as "success" (default True).
            if not result.get("success", result.get("connected", True)):
                self.print_success("Connection correctly failed")
                self.record_result(test_name, True, "Connection failed")
                return True

            self.print_warning(f"Status: {status_code}")
            self.record_result(test_name, True, f"Status: {status_code}")
            return True

        except Exception as e:
            return self._fail_on_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Test Runner
    # -------------------------------------------------------------------------

    def run_all(self) -> bool:
        """Run every MCP test in a fixed order.

        Individual results are collected through ``record_result``. Returns
        whatever ``print_summary()`` returns; ``main()`` uses it as the overall
        pass/fail.
        """
        self.print_header("DocsGPT MCP Server Integration Tests")
        self.print_info(f"Base URL: {self.base_url}")
        self.print_info(f"Auth: {self.token_source}")

        # Callback tests
        self.test_mcp_callback()

        # OAuth status tests
        self.test_mcp_oauth_status()
        self.test_mcp_oauth_status_invalid_task()

        # Save tests
        self.test_mcp_save()
        self.test_mcp_save_invalid()

        # Test connection tests
        self.test_mcp_test_connection()
        self.test_mcp_test_connection_invalid()

        return self.print_summary()
|
|
|
|
|
|
def main():
    """Script entry point.

    Builds an ``MCPTests`` client via ``create_client_from_args`` and runs the
    full suite. Exits with status 0 when ``run_all()`` is truthy, 1 otherwise.
    """
    suite = create_client_from_args(MCPTests, "DocsGPT MCP Server Integration Tests")
    all_passed = suite.run_all()
    sys.exit(0 if all_passed else 1)


if __name__ == "__main__":
    main()
|