mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-03 04:44:10 +00:00
End 2 end tests (#2266)
* All endpoints covered test_integration.py kept for backwards compatability. tests/integration/run_all.py proposed as alternative to cover all endpoints. * Linter fixes
This commit is contained in:
@@ -1,2 +1,6 @@
|
||||
# Allow lines to be as long as 120 characters.
|
||||
line-length = 120
|
||||
line-length = 120
|
||||
|
||||
[lint.per-file-ignores]
|
||||
# Integration tests use sys.path.insert() before imports for standalone execution
|
||||
"tests/integration/*.py" = ["E402"]
|
||||
64
tests/integration/__init__.py
Normal file
64
tests/integration/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
DocsGPT Integration Tests Package
|
||||
|
||||
This package contains modular integration tests for all DocsGPT API endpoints.
|
||||
Tests are organized by domain:
|
||||
|
||||
- test_chat.py: Chat/streaming endpoints (/stream, /api/answer, /api/feedback, /api/tts)
|
||||
- test_sources.py: Source management (upload, remote, chunks, etc.)
|
||||
- test_agents.py: Agent CRUD and sharing
|
||||
- test_conversations.py: Conversation management
|
||||
- test_prompts.py: Prompt CRUD
|
||||
- test_tools.py: Tools CRUD
|
||||
- test_analytics.py: Analytics endpoints
|
||||
- test_connectors.py: External connectors
|
||||
- test_mcp.py: MCP server endpoints
|
||||
- test_misc.py: Models, images, attachments
|
||||
|
||||
Usage:
|
||||
# Run all integration tests
|
||||
python tests/integration/run_all.py
|
||||
|
||||
# Run specific module
|
||||
python tests/integration/test_chat.py
|
||||
|
||||
# Run multiple modules
|
||||
python tests/integration/run_all.py --module chat,agents
|
||||
|
||||
# Run with custom server
|
||||
python tests/integration/run_all.py --base-url http://localhost:7091
|
||||
|
||||
# List available modules
|
||||
python tests/integration/run_all.py --list
|
||||
"""
|
||||
|
||||
from .base import Colors, DocsGPTTestBase, create_client_from_args, generate_jwt_token
|
||||
from .test_chat import ChatTests
|
||||
from .test_sources import SourceTests
|
||||
from .test_agents import AgentTests
|
||||
from .test_conversations import ConversationTests
|
||||
from .test_prompts import PromptTests
|
||||
from .test_tools import ToolsTests
|
||||
from .test_analytics import AnalyticsTests
|
||||
from .test_connectors import ConnectorTests
|
||||
from .test_mcp import MCPTests
|
||||
from .test_misc import MiscTests
|
||||
|
||||
__all__ = [
|
||||
# Base utilities
|
||||
"Colors",
|
||||
"DocsGPTTestBase",
|
||||
"create_client_from_args",
|
||||
"generate_jwt_token",
|
||||
# Test classes
|
||||
"ChatTests",
|
||||
"SourceTests",
|
||||
"AgentTests",
|
||||
"ConversationTests",
|
||||
"PromptTests",
|
||||
"ToolsTests",
|
||||
"AnalyticsTests",
|
||||
"ConnectorTests",
|
||||
"MCPTests",
|
||||
"MiscTests",
|
||||
]
|
||||
485
tests/integration/base.py
Normal file
485
tests/integration/base.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Base classes and utilities for DocsGPT integration tests.
|
||||
|
||||
This module provides:
|
||||
- Colors: ANSI color codes for terminal output
|
||||
- DocsGPTTestBase: Base class with HTTP helpers and output utilities
|
||||
- generate_jwt_token: JWT token generation for authentication
|
||||
- create_client_from_args: Factory function to create client from CLI args
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json as json_module
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, Optional, Type, TypeVar
|
||||
|
||||
import requests
|
||||
|
||||
T = TypeVar("T", bound="DocsGPTTestBase")
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
HEADER = "\033[95m"
|
||||
OKBLUE = "\033[94m"
|
||||
OKCYAN = "\033[96m"
|
||||
OKGREEN = "\033[92m"
|
||||
WARNING = "\033[93m"
|
||||
FAIL = "\033[91m"
|
||||
ENDC = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
|
||||
|
||||
def generate_jwt_token() -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Generate a JWT token using local secret or environment variable.
|
||||
|
||||
Returns:
|
||||
Tuple of (token, error_message). Token is None on failure.
|
||||
"""
|
||||
secret = os.getenv("JWT_SECRET_KEY")
|
||||
key_file = Path(".jwt_secret_key")
|
||||
|
||||
if not secret:
|
||||
try:
|
||||
secret = key_file.read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return None, f"Set JWT_SECRET_KEY or create {key_file} by running the backend once."
|
||||
except OSError as exc:
|
||||
return None, f"Could not read {key_file}: {exc}"
|
||||
|
||||
if not secret:
|
||||
return None, "JWT secret key is empty."
|
||||
|
||||
try:
|
||||
from jose import jwt
|
||||
except ImportError:
|
||||
return None, "python-jose is not installed (pip install 'python-jose' to auto-generate tokens)."
|
||||
|
||||
try:
|
||||
payload = {"sub": "test_integration_user"}
|
||||
return jwt.encode(payload, secret, algorithm="HS256"), None
|
||||
except Exception as exc:
|
||||
return None, f"Failed to generate JWT token: {exc}"
|
||||
|
||||
|
||||
class DocsGPTTestBase:
|
||||
"""
|
||||
Base class for DocsGPT integration tests.
|
||||
|
||||
Provides HTTP helpers, SSE streaming, output formatting, and result tracking.
|
||||
|
||||
Usage:
|
||||
client = DocsGPTTestBase("http://localhost:7091", token="...")
|
||||
response = client.post("/api/answer", json={"question": "test"})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
token: Optional[str] = None,
|
||||
token_source: str = "provided",
|
||||
):
|
||||
"""
|
||||
Initialize test client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of DocsGPT instance (e.g., "http://localhost:7091")
|
||||
token: Optional JWT authentication token
|
||||
token_source: Description of token source for logging
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.token_source = token_source
|
||||
self.headers: dict[str, str] = {}
|
||||
if token:
|
||||
self.headers["Authorization"] = f"Bearer {token}"
|
||||
self.test_results: list[tuple[str, bool, str]] = []
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP Helper Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a GET request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/sources")
|
||||
params: Optional query parameters
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.get
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.get(
|
||||
url,
|
||||
params=params,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def post(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
files: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a POST request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/answer")
|
||||
json: Optional JSON body
|
||||
data: Optional form data
|
||||
files: Optional files for multipart upload
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.post
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.post(
|
||||
url,
|
||||
json=json,
|
||||
data=data,
|
||||
files=files,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def put(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a PUT request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/update_agent/123")
|
||||
json: Optional JSON body
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.put
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.put(
|
||||
url,
|
||||
json=json,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a DELETE request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/delete_agent")
|
||||
json: Optional JSON body
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.delete
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.delete(
|
||||
url,
|
||||
json=json,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def post_stream(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Make a streaming POST request and yield SSE events.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/stream")
|
||||
json: Optional JSON body
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.post
|
||||
|
||||
Yields:
|
||||
Parsed JSON data from each SSE event
|
||||
|
||||
Example:
|
||||
for event in client.post_stream("/stream", json={"question": "test"}):
|
||||
if event.get("type") == "answer":
|
||||
print(event.get("message"))
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
response = requests.post(
|
||||
url,
|
||||
json=json,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
stream=True,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Store response for status code checking
|
||||
self._last_stream_response = response
|
||||
|
||||
if response.status_code != 200:
|
||||
# Yield error event for non-200 responses
|
||||
yield {"type": "error", "status_code": response.status_code, "text": response.text[:500]}
|
||||
return
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:] # Remove 'data: ' prefix
|
||||
try:
|
||||
data = json_module.loads(data_str)
|
||||
yield data
|
||||
if data.get("type") == "end":
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Output Helper Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def print_header(self, message: str) -> None:
|
||||
"""Print a colored header."""
|
||||
print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{message}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}\n")
|
||||
|
||||
def print_success(self, message: str) -> None:
|
||||
"""Print a success message."""
|
||||
print(f"{Colors.OKGREEN}[PASS] {message}{Colors.ENDC}")
|
||||
|
||||
def print_error(self, message: str) -> None:
|
||||
"""Print an error message."""
|
||||
print(f"{Colors.FAIL}[FAIL] {message}{Colors.ENDC}")
|
||||
|
||||
def print_info(self, message: str) -> None:
|
||||
"""Print an info message."""
|
||||
print(f"{Colors.OKCYAN}[INFO] {message}{Colors.ENDC}")
|
||||
|
||||
def print_warning(self, message: str) -> None:
|
||||
"""Print a warning message."""
|
||||
print(f"{Colors.WARNING}[WARN] {message}{Colors.ENDC}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Result Tracking Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def record_result(self, test_name: str, success: bool, message: str) -> None:
|
||||
"""
|
||||
Record a test result.
|
||||
|
||||
Args:
|
||||
test_name: Name of the test
|
||||
success: Whether the test passed
|
||||
message: Result message or error details
|
||||
"""
|
||||
self.test_results.append((test_name, success, message))
|
||||
|
||||
def print_summary(self) -> bool:
|
||||
"""
|
||||
Print test results summary.
|
||||
|
||||
Returns:
|
||||
True if all tests passed, False otherwise
|
||||
"""
|
||||
self.print_header("Test Results Summary")
|
||||
|
||||
passed = sum(1 for _, success, _ in self.test_results if success)
|
||||
failed = len(self.test_results) - passed
|
||||
|
||||
print(f"\n{Colors.BOLD}Total Tests: {len(self.test_results)}{Colors.ENDC}")
|
||||
print(f"{Colors.OKGREEN}Passed: {passed}{Colors.ENDC}")
|
||||
print(f"{Colors.FAIL}Failed: {failed}{Colors.ENDC}\n")
|
||||
|
||||
print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
|
||||
for test_name, success, message in self.test_results:
|
||||
status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if success else f"{Colors.FAIL}FAIL{Colors.ENDC}"
|
||||
print(f" {status} - {test_name}: {message}")
|
||||
|
||||
print()
|
||||
return failed == 0
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Assertion Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def assert_status(
|
||||
self,
|
||||
response: requests.Response,
|
||||
expected: int,
|
||||
test_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Assert response status code and record result.
|
||||
|
||||
Args:
|
||||
response: Response object to check
|
||||
expected: Expected status code
|
||||
test_name: Name of the test for recording
|
||||
|
||||
Returns:
|
||||
True if status matches, False otherwise
|
||||
"""
|
||||
if response.status_code == expected:
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected {expected}, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:500]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
def assert_json_key(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
key: str,
|
||||
test_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Assert JSON response contains a key.
|
||||
|
||||
Args:
|
||||
data: JSON response data
|
||||
key: Expected key
|
||||
test_name: Name of the test for recording
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise
|
||||
"""
|
||||
if key in data:
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Missing key '{key}' in response")
|
||||
self.record_result(test_name, False, f"Missing key: {key}")
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Convenience Properties
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if client has authentication token."""
|
||||
return self.token is not None
|
||||
|
||||
def require_auth(self, test_name: str) -> bool:
|
||||
"""
|
||||
Check authentication and record skip if not authenticated.
|
||||
|
||||
Args:
|
||||
test_name: Name of the test
|
||||
|
||||
Returns:
|
||||
True if authenticated, False otherwise (test skipped)
|
||||
"""
|
||||
if not self.is_authenticated:
|
||||
self.print_warning("No authentication token provided")
|
||||
self.print_info("Skipping test (auth required)")
|
||||
self.record_result(test_name, True, "Skipped (auth required)")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Factory Function
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_client_from_args(
|
||||
client_class: Type[T],
|
||||
description: str = "DocsGPT Integration Tests",
|
||||
) -> T:
|
||||
"""
|
||||
Create a test client from command-line arguments.
|
||||
|
||||
Parses --base-url and --token arguments, and handles JWT token generation.
|
||||
|
||||
Args:
|
||||
client_class: The test class to instantiate (must inherit from DocsGPTTestBase)
|
||||
description: Description for the argument parser
|
||||
|
||||
Returns:
|
||||
An instance of the provided client_class
|
||||
|
||||
Example:
|
||||
class ChatTests(DocsGPTTestBase):
|
||||
def run_all(self):
|
||||
...
|
||||
|
||||
if __name__ == "__main__":
|
||||
client = create_client_from_args(ChatTests)
|
||||
sys.exit(0 if client.run_all() else 1)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("DOCSGPT_BASE_URL", "http://localhost:7091"),
|
||||
help="Base URL of DocsGPT instance (default: http://localhost:7091)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
default=os.getenv("JWT_TOKEN"),
|
||||
help="JWT authentication token (auto-generated from local secret when available)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine token and source
|
||||
token = args.token
|
||||
token_source = "provided via --token" if token else "none"
|
||||
|
||||
if not token:
|
||||
token, token_error = generate_jwt_token()
|
||||
if token:
|
||||
token_source = "auto-generated from local secret"
|
||||
print(f"{Colors.OKCYAN}[INFO] Using auto-generated JWT token{Colors.ENDC}")
|
||||
elif token_error:
|
||||
print(f"{Colors.WARNING}[WARN] Could not auto-generate token: {token_error}{Colors.ENDC}")
|
||||
print(f"{Colors.WARNING}[WARN] Tests requiring auth will be skipped{Colors.ENDC}")
|
||||
|
||||
return client_class(args.base_url, token=token, token_source=token_source)
|
||||
225
tests/integration/run_all.py
Normal file
225
tests/integration/run_all.py
Normal file
@@ -0,0 +1,225 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DocsGPT Integration Test Runner
|
||||
|
||||
Runs all integration tests or specific modules.
|
||||
|
||||
Usage:
|
||||
python tests/integration/run_all.py # Run all tests
|
||||
python tests/integration/run_all.py --module chat # Run specific module
|
||||
python tests/integration/run_all.py --module chat,agents # Run multiple modules
|
||||
python tests/integration/run_all.py --list # List available modules
|
||||
python tests/integration/run_all.py --base-url URL # Custom base URL
|
||||
python tests/integration/run_all.py --token TOKEN # With auth token
|
||||
|
||||
Available modules:
|
||||
chat, sources, agents, conversations, prompts, tools, analytics,
|
||||
connectors, mcp, misc
|
||||
|
||||
Examples:
|
||||
# Run all tests
|
||||
python tests/integration/run_all.py
|
||||
|
||||
# Run only chat and agent tests
|
||||
python tests/integration/run_all.py --module chat,agents
|
||||
|
||||
# Run with custom server
|
||||
python tests/integration/run_all.py --base-url http://staging.example.com:7091
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import Colors, generate_jwt_token
|
||||
from tests.integration.test_chat import ChatTests
|
||||
from tests.integration.test_sources import SourceTests
|
||||
from tests.integration.test_agents import AgentTests
|
||||
from tests.integration.test_conversations import ConversationTests
|
||||
from tests.integration.test_prompts import PromptTests
|
||||
from tests.integration.test_tools import ToolsTests
|
||||
from tests.integration.test_analytics import AnalyticsTests
|
||||
from tests.integration.test_connectors import ConnectorTests
|
||||
from tests.integration.test_mcp import MCPTests
|
||||
from tests.integration.test_misc import MiscTests
|
||||
|
||||
|
||||
# Module registry
|
||||
MODULES = {
|
||||
"chat": ChatTests,
|
||||
"sources": SourceTests,
|
||||
"agents": AgentTests,
|
||||
"conversations": ConversationTests,
|
||||
"prompts": PromptTests,
|
||||
"tools": ToolsTests,
|
||||
"analytics": AnalyticsTests,
|
||||
"connectors": ConnectorTests,
|
||||
"mcp": MCPTests,
|
||||
"misc": MiscTests,
|
||||
}
|
||||
|
||||
|
||||
def print_header(message: str) -> None:
|
||||
"""Print a styled header."""
|
||||
print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{message}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}\n")
|
||||
|
||||
|
||||
def list_modules() -> None:
|
||||
"""Print available test modules."""
|
||||
print_header("Available Test Modules")
|
||||
for name, cls in MODULES.items():
|
||||
test_count = len([m for m in dir(cls) if m.startswith("test_")])
|
||||
print(f" {Colors.OKCYAN}{name:<15}{Colors.ENDC} - {test_count} tests")
|
||||
print()
|
||||
|
||||
|
||||
def run_module(
|
||||
module_name: str,
|
||||
base_url: str,
|
||||
token: str | None,
|
||||
token_source: str,
|
||||
) -> tuple[bool, int, int]:
|
||||
"""
|
||||
Run a single test module.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_passed, passed_count, total_count)
|
||||
"""
|
||||
cls = MODULES.get(module_name)
|
||||
if not cls:
|
||||
print(f"{Colors.FAIL}Unknown module: {module_name}{Colors.ENDC}")
|
||||
return False, 0, 0
|
||||
|
||||
client = cls(base_url, token=token, token_source=token_source)
|
||||
success = client.run_all()
|
||||
|
||||
passed = sum(1 for _, s, _ in client.test_results if s)
|
||||
total = len(client.test_results)
|
||||
|
||||
return success, passed, total
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="DocsGPT Integration Test Runner",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python tests/integration/run_all.py # Run all tests
|
||||
python tests/integration/run_all.py --module chat # Run chat tests
|
||||
python tests/integration/run_all.py --module chat,agents # Multiple modules
|
||||
python tests/integration/run_all.py --list # List modules
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("DOCSGPT_BASE_URL", "http://localhost:7091"),
|
||||
help="Base URL of DocsGPT instance (default: http://localhost:7091)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
default=os.getenv("JWT_TOKEN"),
|
||||
help="JWT authentication token",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--module", "-m",
|
||||
help="Specific module(s) to run, comma-separated (e.g., 'chat,agents')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--list", "-l",
|
||||
action="store_true",
|
||||
help="List available test modules",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# List modules and exit
|
||||
if args.list:
|
||||
list_modules()
|
||||
return 0
|
||||
|
||||
# Determine token
|
||||
token = args.token
|
||||
token_source = "provided via --token" if token else "none"
|
||||
|
||||
if not token:
|
||||
token, token_error = generate_jwt_token()
|
||||
if token:
|
||||
token_source = "auto-generated from local secret"
|
||||
print(f"{Colors.OKCYAN}[INFO] Using auto-generated JWT token{Colors.ENDC}")
|
||||
elif token_error:
|
||||
print(f"{Colors.WARNING}[WARN] Could not auto-generate token: {token_error}{Colors.ENDC}")
|
||||
print(f"{Colors.WARNING}[WARN] Tests requiring auth will be skipped{Colors.ENDC}")
|
||||
|
||||
# Determine which modules to run
|
||||
if args.module:
|
||||
modules_to_run = [m.strip() for m in args.module.split(",")]
|
||||
# Validate modules
|
||||
invalid = [m for m in modules_to_run if m not in MODULES]
|
||||
if invalid:
|
||||
print(f"{Colors.FAIL}Unknown module(s): {', '.join(invalid)}{Colors.ENDC}")
|
||||
print(f"{Colors.OKCYAN}Available: {', '.join(MODULES.keys())}{Colors.ENDC}")
|
||||
return 1
|
||||
else:
|
||||
modules_to_run = list(MODULES.keys())
|
||||
|
||||
# Print test plan
|
||||
print_header("DocsGPT Integration Test Suite")
|
||||
print(f"{Colors.OKCYAN}Base URL:{Colors.ENDC} {args.base_url}")
|
||||
print(f"{Colors.OKCYAN}Auth:{Colors.ENDC} {token_source}")
|
||||
print(f"{Colors.OKCYAN}Modules:{Colors.ENDC} {', '.join(modules_to_run)}")
|
||||
|
||||
# Run tests
|
||||
results = {}
|
||||
total_passed = 0
|
||||
total_tests = 0
|
||||
|
||||
for module_name in modules_to_run:
|
||||
success, passed, total = run_module(
|
||||
module_name,
|
||||
args.base_url,
|
||||
token,
|
||||
token_source,
|
||||
)
|
||||
results[module_name] = (success, passed, total)
|
||||
total_passed += passed
|
||||
total_tests += total
|
||||
|
||||
# Print summary
|
||||
print_header("Overall Test Summary")
|
||||
|
||||
print(f"\n{Colors.BOLD}Module Results:{Colors.ENDC}")
|
||||
for module_name, (success, passed, total) in results.items():
|
||||
status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if success else f"{Colors.FAIL}FAIL{Colors.ENDC}"
|
||||
print(f" {status} - {module_name}: {passed}/{total} tests passed")
|
||||
|
||||
print(f"\n{Colors.BOLD}Total:{Colors.ENDC} {total_passed}/{total_tests} tests passed")
|
||||
|
||||
all_passed = all(success for success, _, _ in results.values())
|
||||
if all_passed:
|
||||
print(f"\n{Colors.OKGREEN}{Colors.BOLD}ALL TESTS PASSED{Colors.ENDC}")
|
||||
return 0
|
||||
else:
|
||||
failed_modules = [m for m, (s, _, _) in results.items() if not s]
|
||||
print(f"\n{Colors.FAIL}{Colors.BOLD}SOME TESTS FAILED{Colors.ENDC}")
|
||||
print(f"{Colors.FAIL}Failed modules: {', '.join(failed_modules)}{Colors.ENDC}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
1009
tests/integration/test_agents.py
Normal file
1009
tests/integration/test_agents.py
Normal file
File diff suppressed because it is too large
Load Diff
323
tests/integration/test_analytics.py
Normal file
323
tests/integration/test_analytics.py
Normal file
@@ -0,0 +1,323 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT analytics endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/get_feedback_analytics (POST) - Feedback analytics
|
||||
- /api/get_message_analytics (POST) - Message analytics
|
||||
- /api/get_token_analytics (POST) - Token usage analytics
|
||||
- /api/get_user_logs (POST) - User activity logs
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_analytics.py
|
||||
python tests/integration/test_analytics.py --base-url http://localhost:7091
|
||||
python tests/integration/test_analytics.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class AnalyticsTests(DocsGPTTestBase):
|
||||
"""Integration tests for analytics endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Feedback Analytics Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_feedback_analytics(self) -> bool:
|
||||
"""Test getting feedback analytics."""
|
||||
test_name = "Get feedback analytics"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_feedback_analytics",
|
||||
json={"date_range": "last_30_days"},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
self.print_success("Retrieved feedback analytics")
|
||||
self.print_info(f"Data points: {len(result) if isinstance(result, list) else 'object'}")
|
||||
self.record_result(test_name, True, "Analytics retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_feedback_analytics_with_filters(self) -> bool:
|
||||
"""Test feedback analytics with filters."""
|
||||
test_name = "Feedback analytics filtered"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_feedback_analytics",
|
||||
json={
|
||||
"date_range": "last_7_days",
|
||||
"agent_id": None,
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
self.print_success("Retrieved filtered feedback analytics")
|
||||
self.record_result(test_name, True, "Filtered analytics retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Message Analytics Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_message_analytics(self) -> bool:
|
||||
"""Test getting message analytics."""
|
||||
test_name = "Get message analytics"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_message_analytics",
|
||||
json={"date_range": "last_30_days"},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
self.print_success("Retrieved message analytics")
|
||||
self.print_info(f"Data: {type(result).__name__}")
|
||||
self.record_result(test_name, True, "Analytics retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_message_analytics_with_agent(self) -> bool:
|
||||
"""Test message analytics for specific agent."""
|
||||
test_name = "Message analytics by agent"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_message_analytics",
|
||||
json={
|
||||
"date_range": "last_7_days",
|
||||
"agent_id": None,
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
self.print_success("Retrieved agent message analytics")
|
||||
self.record_result(test_name, True, "Agent analytics retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Token Analytics Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_token_analytics(self) -> bool:
|
||||
"""Test getting token usage analytics."""
|
||||
test_name = "Get token analytics"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_token_analytics",
|
||||
json={"date_range": "last_30_days"},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
self.print_success("Retrieved token analytics")
|
||||
self.print_info(f"Data: {type(result).__name__}")
|
||||
self.record_result(test_name, True, "Analytics retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_token_analytics_breakdown(self) -> bool:
|
||||
"""Test token analytics with breakdown."""
|
||||
test_name = "Token analytics breakdown"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_token_analytics",
|
||||
json={
|
||||
"date_range": "last_7_days",
|
||||
"breakdown": "daily",
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
self.print_success("Retrieved token analytics breakdown")
|
||||
self.record_result(test_name, True, "Breakdown retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# User Logs Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_user_logs(self) -> bool:
|
||||
"""Test getting user activity logs."""
|
||||
test_name = "Get user logs"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_user_logs",
|
||||
json={"date_range": "last_30_days"},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
self.print_success("Retrieved user logs")
|
||||
self.print_info(f"Logs: {len(result) if isinstance(result, list) else 'object'}")
|
||||
self.record_result(test_name, True, "Logs retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_user_logs_paginated(self) -> bool:
|
||||
"""Test user logs with pagination."""
|
||||
test_name = "User logs paginated"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/get_user_logs",
|
||||
json={
|
||||
"date_range": "last_7_days",
|
||||
"page": 1,
|
||||
"per_page": 10,
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
self.print_success("Retrieved paginated user logs")
|
||||
self.record_result(test_name, True, "Paginated logs retrieved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all analytics tests."""
|
||||
self.print_header("DocsGPT Analytics Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
|
||||
# Feedback analytics
|
||||
self.test_get_feedback_analytics()
|
||||
self.test_get_feedback_analytics_with_filters()
|
||||
|
||||
# Message analytics
|
||||
self.test_get_message_analytics()
|
||||
self.test_get_message_analytics_with_agent()
|
||||
|
||||
# Token analytics
|
||||
self.test_get_token_analytics()
|
||||
self.test_get_token_analytics_breakdown()
|
||||
|
||||
# User logs
|
||||
self.test_get_user_logs()
|
||||
self.test_get_user_logs_paginated()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(AnalyticsTests, "DocsGPT Analytics Integration Tests")
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
957
tests/integration/test_chat.py
Normal file
957
tests/integration/test_chat.py
Normal file
@@ -0,0 +1,957 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT chat endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /stream (POST) - Streaming chat
|
||||
- /api/answer (POST) - Non-streaming chat
|
||||
- /api/feedback (POST) - Feedback submission
|
||||
- /api/tts (POST) - Text-to-speech
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_chat.py
|
||||
python tests/integration/test_chat.py --base-url http://localhost:7091
|
||||
python tests/integration/test_chat.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import json as json_module
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class ChatTests(DocsGPTTestBase):
|
||||
"""Integration tests for chat/streaming endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_test_agent(self) -> Optional[tuple]:
|
||||
"""
|
||||
Get or create a test agent for chat tests.
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_id, api_key) or None if creation fails
|
||||
"""
|
||||
if hasattr(self, "_test_agent"):
|
||||
return self._test_agent
|
||||
|
||||
if not self.is_authenticated:
|
||||
return None
|
||||
|
||||
payload = {
|
||||
"name": f"Chat Test Agent {int(time.time())}",
|
||||
"description": "Integration test agent for chat tests",
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"retriever": "classic",
|
||||
"agent_type": "classic",
|
||||
"status": "draft",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_agent", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
agent_id = result.get("id")
|
||||
api_key = result.get("key")
|
||||
if agent_id:
|
||||
self._test_agent = (agent_id, api_key)
|
||||
return self._test_agent
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def get_or_create_published_agent(self) -> Optional[tuple]:
|
||||
"""
|
||||
Get or create a published agent with API key.
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_id, api_key) or None if creation fails
|
||||
"""
|
||||
if hasattr(self, "_published_agent"):
|
||||
return self._published_agent
|
||||
|
||||
if not self.is_authenticated:
|
||||
return None
|
||||
|
||||
# First create a source
|
||||
source_id = self._create_test_source()
|
||||
|
||||
payload = {
|
||||
"name": f"Chat Test Published Agent {int(time.time())}",
|
||||
"description": "Integration test published agent",
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"retriever": "classic",
|
||||
"agent_type": "classic",
|
||||
"status": "published",
|
||||
}
|
||||
|
||||
if source_id:
|
||||
payload["source"] = source_id
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_agent", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
agent_id = result.get("id")
|
||||
api_key = result.get("key")
|
||||
if agent_id and api_key:
|
||||
self._published_agent = (agent_id, api_key)
|
||||
return self._published_agent
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _create_test_source(self) -> Optional[str]:
|
||||
"""Create a simple test source and return its ID."""
|
||||
if hasattr(self, "_test_source_id"):
|
||||
return self._test_source_id
|
||||
|
||||
test_content = """# Test Documentation
|
||||
## Overview
|
||||
This is test documentation for integration tests.
|
||||
## Features
|
||||
- Feature 1: Testing
|
||||
- Feature 2: Integration
|
||||
"""
|
||||
files = {"file": ("test_docs.txt", test_content.encode(), "text/plain")}
|
||||
data = {"user": "test_user", "name": f"Chat Test Source {int(time.time())}"}
|
||||
|
||||
try:
|
||||
response = self.post("/api/upload", files=files, data=data, timeout=30)
|
||||
if response.status_code == 200:
|
||||
task_id = response.json().get("task_id")
|
||||
if task_id:
|
||||
time.sleep(5) # Wait for processing
|
||||
# Get source ID
|
||||
sources_response = self.get("/api/sources")
|
||||
if sources_response.status_code == 200:
|
||||
sources = sources_response.json()
|
||||
for source in sources:
|
||||
if "Chat Test Source" in source.get("name", ""):
|
||||
self._test_source_id = source.get("id")
|
||||
return self._test_source_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Stream Endpoint Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_stream_endpoint_no_agent(self) -> bool:
|
||||
"""Test /stream endpoint without agent."""
|
||||
test_name = "Stream endpoint (no agent)"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "What is DocsGPT?",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info("POST /stream")
|
||||
self.print_info(f"Payload: {json_module.dumps(payload, indent=2)}")
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/stream",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:500]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
# Parse SSE stream
|
||||
events = []
|
||||
full_response = ""
|
||||
conversation_id = None
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
try:
|
||||
data = json_module.loads(data_str)
|
||||
events.append(data)
|
||||
|
||||
if data.get("type") in ["stream", "answer"]:
|
||||
full_response += data.get("message", "") or data.get("answer", "")
|
||||
elif data.get("type") == "id":
|
||||
conversation_id = data.get("id")
|
||||
elif data.get("type") == "end":
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
self.print_success(f"Received {len(events)} events")
|
||||
self.print_info(f"Response preview: {full_response[:100]}...")
|
||||
|
||||
if conversation_id:
|
||||
self.print_success(f"Conversation ID: {conversation_id}")
|
||||
|
||||
self.record_result(test_name, True, "Success")
|
||||
self.print_success(f"{test_name} passed!")
|
||||
return True
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self.print_error(f"Request failed: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Unexpected error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_stream_endpoint_with_agent(self) -> bool:
|
||||
"""Test /stream endpoint with agent_id."""
|
||||
test_name = "Stream endpoint (with agent)"
|
||||
|
||||
agent_result = self.get_or_create_test_agent()
|
||||
if not agent_result:
|
||||
if not self.require_auth(test_name):
|
||||
return True # Skipped
|
||||
self.print_warning("Could not create test agent")
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
agent_id, _ = agent_result
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "What is DocsGPT?",
|
||||
"history": "[]",
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info(f"POST /stream with agent_id={agent_id[:8]}...")
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/stream",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
events = []
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
try:
|
||||
data = json_module.loads(line_str[6:])
|
||||
events.append(data)
|
||||
if data.get("type") == "end":
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
self.print_success(f"Received {len(events)} events")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_stream_endpoint_with_api_key(self) -> bool:
|
||||
"""Test /stream endpoint with API key."""
|
||||
test_name = "Stream endpoint (with API key)"
|
||||
|
||||
agent_result = self.get_or_create_published_agent()
|
||||
if not agent_result or not agent_result[1]:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.print_warning("Could not create published agent with API key")
|
||||
self.record_result(test_name, True, "Skipped (no API key)")
|
||||
return True
|
||||
|
||||
_, api_key = agent_result
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "What is DocsGPT?",
|
||||
"history": "[]",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info(f"POST /stream with api_key={api_key[:20]}...")
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/stream",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
events = []
|
||||
full_response = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
try:
|
||||
data = json_module.loads(line_str[6:])
|
||||
events.append(data)
|
||||
if data.get("type") in ["stream", "answer"]:
|
||||
full_response += data.get("message", "") or data.get("answer", "")
|
||||
elif data.get("type") == "end":
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
self.print_success(f"Received {len(events)} events")
|
||||
self.print_info(f"Response preview: {full_response[:100]}...")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Answer Endpoint Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_answer_endpoint_no_agent(self) -> bool:
|
||||
"""Test /api/answer endpoint without agent."""
|
||||
test_name = "Answer endpoint (no agent)"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "What is DocsGPT?",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info("POST /api/answer")
|
||||
self.print_info(f"Payload: {json_module.dumps(payload, indent=2)}")
|
||||
|
||||
response = self.post("/api/answer", json=payload, timeout=30)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:500]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
self.print_info(f"Response keys: {list(result.keys())}")
|
||||
|
||||
if "answer" in result:
|
||||
answer = result["answer"]
|
||||
self.print_success(f"Answer received: {answer[:100]}...")
|
||||
else:
|
||||
self.print_warning("No 'answer' field in response")
|
||||
|
||||
if "conversation_id" in result:
|
||||
self.print_success(f"Conversation ID: {result['conversation_id']}")
|
||||
|
||||
self.record_result(test_name, True, "Success")
|
||||
self.print_success(f"{test_name} passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_answer_endpoint_with_agent(self) -> bool:
|
||||
"""Test /api/answer endpoint with agent_id."""
|
||||
test_name = "Answer endpoint (with agent)"
|
||||
|
||||
agent_result = self.get_or_create_test_agent()
|
||||
if not agent_result:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
agent_id, _ = agent_result
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "What is DocsGPT?",
|
||||
"history": "[]",
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info(f"POST /api/answer with agent_id={agent_id[:8]}...")
|
||||
|
||||
response = self.post("/api/answer", json=payload, timeout=30)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
answer = result.get("answer", "")
|
||||
self.print_success(f"Answer received: {answer[:100]}...")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_answer_endpoint_with_api_key(self) -> bool:
|
||||
"""Test /api/answer endpoint with API key."""
|
||||
test_name = "Answer endpoint (with API key)"
|
||||
|
||||
agent_result = self.get_or_create_published_agent()
|
||||
if not agent_result or not agent_result[1]:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no API key)")
|
||||
return True
|
||||
|
||||
_, api_key = agent_result
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "What is DocsGPT?",
|
||||
"history": "[]",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info(f"POST /api/answer with api_key={api_key[:20]}...")
|
||||
|
||||
response = self.post("/api/answer", json=payload, timeout=30)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
answer = result.get("answer", "")
|
||||
self.print_success(f"Answer received: {answer[:100]}...")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Validation Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_model_validation_invalid_model_id(self) -> bool:
|
||||
"""Test that invalid model_id is rejected."""
|
||||
test_name = "Model validation (invalid model_id)"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "Test question",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"model_id": "invalid-model-xyz-123",
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info("POST /stream with invalid model_id")
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/stream",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 400:
|
||||
# Read error from SSE stream
|
||||
error_message = None
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
try:
|
||||
data = json_module.loads(line_str[6:])
|
||||
if data.get("type") == "error":
|
||||
error_message = data.get("message") or data.get("error", "")
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if error_message:
|
||||
self.print_success("Invalid model_id rejected with 400 status")
|
||||
self.print_info(f"Error: {error_message[:200]}")
|
||||
self.record_result(test_name, True, "Validation works")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("No error message in response")
|
||||
self.record_result(test_name, False, "No error message")
|
||||
return False
|
||||
else:
|
||||
self.print_warning(f"Expected 400, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Compression Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_compression_heavy_tool_usage(self) -> bool:
|
||||
"""Test compression with heavy conversation usage."""
|
||||
test_name = "Compression - Heavy Tool Usage"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
self.print_info("Making 10 consecutive requests to build conversation history...")
|
||||
|
||||
current_conv_id = None
|
||||
|
||||
for i in range(10):
|
||||
question = f"Tell me about Python topic {i+1}: data structures, decorators, async, testing. Provide a comprehensive explanation."
|
||||
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
if current_conv_id:
|
||||
payload["conversation_id"] = current_conv_id
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=payload, timeout=90)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
current_conv_id = result.get("conversation_id", current_conv_id)
|
||||
answer_preview = (result.get("answer") or "")[:80]
|
||||
self.print_success(f"Request {i+1}/10 completed")
|
||||
self.print_info(f" Answer: {answer_preview}...")
|
||||
else:
|
||||
self.print_error(f"Request {i+1}/10 failed: status {response.status_code}")
|
||||
self.record_result(test_name, False, f"Request {i+1} failed")
|
||||
return False
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Request {i+1}/10 failed: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
if current_conv_id:
|
||||
self.print_success("Heavy usage test completed")
|
||||
self.record_result(test_name, True, f"10 requests, conv_id: {current_conv_id}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("No conversation_id received")
|
||||
self.record_result(test_name, False, "No conversation_id")
|
||||
return False
|
||||
|
||||
def test_compression_needle_in_haystack(self) -> bool:
|
||||
"""Test that compression preserves critical information.
|
||||
|
||||
Note: This is a long-running test that may timeout due to LLM response times.
|
||||
Timeouts are handled gracefully as they indicate performance issues, not bugs.
|
||||
"""
|
||||
test_name = "Compression - Needle in Haystack"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
conversation_id = None
|
||||
|
||||
# Step 1: Send general questions
|
||||
self.print_info("Step 1: Sending general questions...")
|
||||
for i, question in enumerate([
|
||||
"Tell me about Python best practices in detail",
|
||||
"Explain Python data structures comprehensively",
|
||||
]):
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
if conversation_id:
|
||||
payload["conversation_id"] = conversation_id
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conversation_id = result.get("conversation_id", conversation_id)
|
||||
self.print_success(f"General question {i+1}/2 completed")
|
||||
else:
|
||||
self.print_error(f"Request failed: status {response.status_code}")
|
||||
self.record_result(test_name, False, "General questions failed")
|
||||
return False
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
# Timeout errors are expected for long LLM responses
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.print_error(f"Request failed: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# Step 2: Send critical information
|
||||
self.print_info("Step 2: Sending CRITICAL information...")
|
||||
critical_payload = {
|
||||
"question": "Please remember: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily.",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=critical_payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conversation_id = result.get("conversation_id", conversation_id)
|
||||
self.print_success("Critical information sent")
|
||||
else:
|
||||
self.record_result(test_name, False, "Critical info failed")
|
||||
return False
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# Step 3: Bury with more questions
|
||||
self.print_info("Step 3: Sending more questions to bury critical info...")
|
||||
for i, question in enumerate([
|
||||
"Explain Python decorators in great detail",
|
||||
"Tell me about Python async programming comprehensively",
|
||||
]):
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conversation_id = result.get("conversation_id", conversation_id)
|
||||
self.print_success(f"Burying question {i+1}/2 completed")
|
||||
else:
|
||||
self.record_result(test_name, False, "Burying questions failed")
|
||||
return False
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# Step 4: Test recall
|
||||
self.print_info("Step 4: Testing if critical info was preserved...")
|
||||
recall_payload = {
|
||||
"question": "What was the database password environment variable I mentioned earlier?",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=recall_payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
answer = (result.get("answer") or "").lower()
|
||||
|
||||
if "db_password_prod" in answer or "database password" in answer:
|
||||
self.print_success("Critical information preserved!")
|
||||
self.print_info(f"Answer: {answer[:150]}...")
|
||||
self.record_result(test_name, True, "Info preserved")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("Critical information may have been lost")
|
||||
self.print_info(f"Answer: {answer[:150]}...")
|
||||
self.record_result(test_name, False, "Info not preserved")
|
||||
return False
|
||||
else:
|
||||
self.record_result(test_name, False, "Recall failed")
|
||||
return False
|
||||
except Exception as e:
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Feedback Tests (NEW)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_feedback_positive(self) -> bool:
|
||||
"""Test positive feedback submission."""
|
||||
test_name = "Feedback - Positive"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
# First create a conversation to get an ID
|
||||
answer_response = self.post(
|
||||
"/api/answer",
|
||||
json={"question": "Hello", "history": "[]", "isNoneDoc": True},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if answer_response.status_code != 200:
|
||||
self.print_warning("Could not create conversation for feedback test")
|
||||
self.record_result(test_name, True, "Skipped (no conversation)")
|
||||
return True
|
||||
|
||||
result = answer_response.json()
|
||||
conversation_id = result.get("conversation_id")
|
||||
|
||||
if not conversation_id:
|
||||
self.record_result(test_name, True, "Skipped (no conversation_id)")
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"question": "Hello",
|
||||
"answer": result.get("answer", ""),
|
||||
"feedback": "like",
|
||||
"conversation_id": conversation_id,
|
||||
"question_index": 0, # Required field
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/feedback", json=payload, timeout=10)
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
self.print_success("Positive feedback submitted")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_feedback_negative(self) -> bool:
|
||||
"""Test negative feedback submission."""
|
||||
test_name = "Feedback - Negative"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
answer_response = self.post(
|
||||
"/api/answer",
|
||||
json={"question": "Hello", "history": "[]", "isNoneDoc": True},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if answer_response.status_code != 200:
|
||||
self.record_result(test_name, True, "Skipped (no conversation)")
|
||||
return True
|
||||
|
||||
result = answer_response.json()
|
||||
conversation_id = result.get("conversation_id")
|
||||
|
||||
if not conversation_id:
|
||||
self.record_result(test_name, True, "Skipped (no conversation_id)")
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"question": "Hello",
|
||||
"answer": result.get("answer", ""),
|
||||
"feedback": "dislike",
|
||||
"conversation_id": conversation_id,
|
||||
"question_index": 0, # Required field
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/feedback", json=payload, timeout=10)
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
self.print_success("Negative feedback submitted")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# TTS Tests (NEW)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_tts_basic(self) -> bool:
|
||||
"""Test basic text-to-speech endpoint."""
|
||||
test_name = "TTS - Basic"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {"text": "Hello, this is a test of the text to speech system."}
|
||||
|
||||
try:
|
||||
response = self.post("/api/tts", json=payload, timeout=30)
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
self.print_success(f"TTS response received, Content-Type: {content_type}")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
elif response.status_code == 501:
|
||||
self.print_warning("TTS not implemented/configured")
|
||||
self.record_result(test_name, True, "Skipped (not configured)")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Run All Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all chat integration tests."""
|
||||
self.print_header("Chat Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
|
||||
|
||||
# Basic endpoint tests
|
||||
self.test_stream_endpoint_no_agent()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_answer_endpoint_no_agent()
|
||||
time.sleep(1)
|
||||
|
||||
# Validation tests
|
||||
self.test_model_validation_invalid_model_id()
|
||||
time.sleep(1)
|
||||
|
||||
# Agent-based tests
|
||||
self.test_stream_endpoint_with_agent()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_answer_endpoint_with_agent()
|
||||
time.sleep(1)
|
||||
|
||||
# API key tests
|
||||
self.test_stream_endpoint_with_api_key()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_answer_endpoint_with_api_key()
|
||||
time.sleep(1)
|
||||
|
||||
# Feedback tests
|
||||
self.test_feedback_positive()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_feedback_negative()
|
||||
time.sleep(1)
|
||||
|
||||
# TTS test
|
||||
self.test_tts_basic()
|
||||
time.sleep(1)
|
||||
|
||||
# Compression tests (longer running)
|
||||
if self.is_authenticated:
|
||||
self.test_compression_heavy_tool_usage()
|
||||
time.sleep(2)
|
||||
|
||||
self.test_compression_needle_in_haystack()
|
||||
else:
|
||||
self.print_info("Skipping compression tests (no authentication)")
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for standalone execution."""
|
||||
client = create_client_from_args(ChatTests, "DocsGPT Chat Integration Tests")
|
||||
success = client.run_all()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
355
tests/integration/test_connectors.py
Normal file
355
tests/integration/test_connectors.py
Normal file
@@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT external connectors endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/connectors/auth (GET) - OAuth authentication URL
|
||||
- /api/connectors/callback (GET) - OAuth callback
|
||||
- /api/connectors/callback-status (GET) - Callback status
|
||||
- /api/connectors/disconnect (POST) - Disconnect connector
|
||||
- /api/connectors/files (POST) - List connector files
|
||||
- /api/connectors/sync (POST) - Sync connector
|
||||
- /api/connectors/validate-session (POST) - Validate session
|
||||
|
||||
Note: Many tests are limited without actual external service connections.
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_connectors.py
|
||||
python tests/integration/test_connectors.py --base-url http://localhost:7091
|
||||
python tests/integration/test_connectors.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class ConnectorTests(DocsGPTTestBase):
|
||||
"""Integration tests for external connector endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Auth Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_connectors_auth_google(self) -> bool:
|
||||
"""Test getting Google OAuth URL."""
|
||||
test_name = "Get Google OAuth URL"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/connectors/auth",
|
||||
params={"provider": "google"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Expect 200 with URL, or 400/501 if not configured
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
auth_url = result.get("url") or result.get("auth_url")
|
||||
if auth_url:
|
||||
self.print_success(f"Got OAuth URL: {auth_url[:50]}...")
|
||||
self.record_result(test_name, True, "OAuth URL retrieved")
|
||||
return True
|
||||
elif response.status_code in [400, 404, 501]:
|
||||
self.print_warning(f"Connector not configured: {response.status_code}")
|
||||
self.record_result(test_name, True, "Not configured (expected)")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_connectors_auth_invalid_provider(self) -> bool:
|
||||
"""Test auth with invalid provider."""
|
||||
test_name = "Auth invalid provider"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/connectors/auth",
|
||||
params={"provider": "invalid_provider_xyz"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [400, 404]:
|
||||
self.print_success(f"Correctly rejected: {response.status_code}")
|
||||
self.record_result(test_name, True, "Invalid provider rejected")
|
||||
return True
|
||||
else:
|
||||
self.print_warning(f"Status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Callback Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_connectors_callback_status(self) -> bool:
|
||||
"""Test checking callback status."""
|
||||
test_name = "Check callback status"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/connectors/callback-status",
|
||||
params={"task_id": "test-task-id"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Expect 200 with status, or 404 for unknown task
|
||||
if response.status_code in [200, 404]:
|
||||
self.print_success(f"Callback status check: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Disconnect Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_connectors_disconnect(self) -> bool:
|
||||
"""Test disconnecting a connector."""
|
||||
test_name = "Disconnect connector"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/connectors/disconnect",
|
||||
json={"provider": "google"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Expect 200 for successful disconnect, or 400/404 if not connected
|
||||
if response.status_code in [200, 400, 404]:
|
||||
self.print_success(f"Disconnect response: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Files Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_connectors_files(self) -> bool:
|
||||
"""Test listing connector files."""
|
||||
test_name = "List connector files"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/connectors/files",
|
||||
json={"provider": "google", "path": "/"},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
# Expect 200 with files, or 400/401/404 if not authenticated
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
files = result.get("files", result)
|
||||
self.print_success(f"Got files list: {len(files) if isinstance(files, list) else 'object'}")
|
||||
self.record_result(test_name, True, "Files retrieved")
|
||||
return True
|
||||
elif response.status_code in [400, 401, 404]:
|
||||
self.print_warning(f"Connector not authenticated: {response.status_code}")
|
||||
self.record_result(test_name, True, "Not authenticated (expected)")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_connectors_files_with_path(self) -> bool:
|
||||
"""Test listing files at specific path."""
|
||||
test_name = "List files at path"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/connectors/files",
|
||||
json={"provider": "google", "path": "/documents"},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 400, 401, 404]:
|
||||
self.print_success(f"Files at path response: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sync Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_connectors_sync(self) -> bool:
|
||||
"""Test syncing a connector."""
|
||||
test_name = "Sync connector"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/connectors/sync",
|
||||
json={"provider": "google", "file_ids": []},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 202, 400, 401, 404]:
|
||||
self.print_success(f"Sync response: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Validate Session Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_connectors_validate_session(self) -> bool:
|
||||
"""Test validating connector session."""
|
||||
test_name = "Validate connector session"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/connectors/validate-session",
|
||||
json={"provider": "google"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 400, 401, 404]:
|
||||
result = response.json() if response.status_code == 200 else {}
|
||||
valid = result.get("valid", False)
|
||||
self.print_success(f"Session validation: {response.status_code}, valid={valid}")
|
||||
self.record_result(test_name, True, f"Valid: {valid}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all connector tests."""
|
||||
self.print_header("DocsGPT Connector Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
self.print_warning("Note: Many tests require external service configuration")
|
||||
|
||||
# Auth tests
|
||||
self.test_connectors_auth_google()
|
||||
self.test_connectors_auth_invalid_provider()
|
||||
|
||||
# Callback tests
|
||||
self.test_connectors_callback_status()
|
||||
|
||||
# Disconnect tests
|
||||
self.test_connectors_disconnect()
|
||||
|
||||
# Files tests
|
||||
self.test_connectors_files()
|
||||
self.test_connectors_files_with_path()
|
||||
|
||||
# Sync tests
|
||||
self.test_connectors_sync()
|
||||
|
||||
# Validate session tests
|
||||
self.test_connectors_validate_session()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(ConnectorTests, "DocsGPT Connector Integration Tests")
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
495
tests/integration/test_conversations.py
Normal file
495
tests/integration/test_conversations.py
Normal file
@@ -0,0 +1,495 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT conversation management endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/get_conversations (GET) - List conversations
|
||||
- /api/get_single_conversation (GET) - Get single conversation
|
||||
- /api/delete_conversation (POST) - Delete conversation
|
||||
- /api/delete_all_conversations (GET) - Delete all conversations
|
||||
- /api/update_conversation_name (POST) - Rename conversation
|
||||
- /api/share (POST) - Share conversation
|
||||
- /api/shared_conversation/{id} (GET) - Get shared conversation
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_conversations.py
|
||||
python tests/integration/test_conversations.py --base-url http://localhost:7091
|
||||
python tests/integration/test_conversations.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class ConversationTests(DocsGPTTestBase):
|
||||
"""Integration tests for conversation management endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_test_conversation(self) -> Optional[str]:
|
||||
"""
|
||||
Get or create a test conversation by making a chat request.
|
||||
|
||||
Returns:
|
||||
Conversation ID or None if creation fails
|
||||
"""
|
||||
if hasattr(self, "_test_conversation_id"):
|
||||
return self._test_conversation_id
|
||||
|
||||
if not self.is_authenticated:
|
||||
return None
|
||||
|
||||
# Create conversation via a chat request
|
||||
try:
|
||||
payload = {
|
||||
"question": "Test message for conversation creation",
|
||||
"history": [],
|
||||
"conversation_id": None,
|
||||
}
|
||||
|
||||
response = self.post("/api/answer", json=payload, timeout=30)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conv_id = result.get("conversation_id")
|
||||
if conv_id:
|
||||
self._test_conversation_id = conv_id
|
||||
return conv_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def get_existing_conversation(self) -> Optional[str]:
|
||||
"""Get an existing conversation ID from the list."""
|
||||
try:
|
||||
response = self.get("/api/get_conversations", timeout=10)
|
||||
if response.status_code == 200:
|
||||
convs = response.json()
|
||||
if convs and len(convs) > 0:
|
||||
return convs[0].get("id")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# List/Get Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_conversations(self) -> bool:
|
||||
"""Test listing all conversations."""
|
||||
test_name = "List conversations"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get("/api/get_conversations", timeout=10)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
if not isinstance(result, list):
|
||||
self.print_error("Response is not a list")
|
||||
self.record_result(test_name, False, "Invalid response type")
|
||||
return False
|
||||
|
||||
self.print_success(f"Retrieved {len(result)} conversations")
|
||||
if result:
|
||||
self.print_info(f"First: {result[0].get('name', 'N/A')[:30]}...")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_conversations_paginated(self) -> bool:
|
||||
"""Test getting conversations with pagination."""
|
||||
test_name = "List conversations paginated"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/get_conversations",
|
||||
params={"page": 1, "per_page": 5},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
if not isinstance(result, list):
|
||||
self.print_error("Response is not a list")
|
||||
self.record_result(test_name, False, "Invalid response type")
|
||||
return False
|
||||
|
||||
self.print_success(f"Retrieved {len(result)} conversations (page 1)")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_single_conversation(self) -> bool:
|
||||
"""Test getting a single conversation by ID."""
|
||||
test_name = "Get single conversation"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Try to get existing conversation
|
||||
conv_id = self.get_existing_conversation()
|
||||
if not conv_id:
|
||||
conv_id = self.get_or_create_test_conversation()
|
||||
|
||||
if not conv_id:
|
||||
self.print_warning("No conversations available")
|
||||
self.record_result(test_name, True, "Skipped (no conversations)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/get_single_conversation",
|
||||
params={"id": conv_id},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
self.print_success(f"Retrieved conversation: {conv_id[:20]}...")
|
||||
self.print_info(f"Messages: {len(result.get('queries', []))}")
|
||||
self.record_result(test_name, True, f"ID: {conv_id[:20]}...")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_single_conversation_not_found(self) -> bool:
|
||||
"""Test getting a non-existent conversation."""
|
||||
test_name = "Get non-existent conversation"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/get_single_conversation",
|
||||
params={"id": "nonexistent-conversation-id-12345"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [404, 400, 500]:
|
||||
self.print_success(f"Correctly returned {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Update Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_update_conversation_name(self) -> bool:
|
||||
"""Test renaming a conversation."""
|
||||
test_name = "Update conversation name"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
conv_id = self.get_existing_conversation()
|
||||
if not conv_id:
|
||||
conv_id = self.get_or_create_test_conversation()
|
||||
|
||||
if not conv_id:
|
||||
self.print_warning("No conversation to rename")
|
||||
self.record_result(test_name, True, "Skipped (no conversation)")
|
||||
return True
|
||||
|
||||
new_name = f"Renamed Conversation {int(time.time())}"
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/update_conversation_name",
|
||||
json={"id": conv_id, "name": new_name},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success(f"Renamed conversation to: {new_name[:30]}...")
|
||||
self.record_result(test_name, True, f"New name: {new_name[:20]}...")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Rename failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Delete Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_delete_conversation(self) -> bool:
|
||||
"""Test deleting a single conversation."""
|
||||
test_name = "Delete conversation"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Create a conversation specifically for deletion
|
||||
try:
|
||||
payload = {
|
||||
"question": "Test message for deletion test",
|
||||
"history": [],
|
||||
"conversation_id": None,
|
||||
}
|
||||
|
||||
create_response = self.post("/api/answer", json=payload, timeout=30)
|
||||
if create_response.status_code != 200:
|
||||
self.print_warning("Could not create conversation for deletion")
|
||||
self.record_result(test_name, True, "Skipped (create failed)")
|
||||
return True
|
||||
|
||||
conv_id = create_response.json().get("conversation_id")
|
||||
if not conv_id:
|
||||
self.print_warning("No conversation ID returned")
|
||||
self.record_result(test_name, True, "Skipped (no ID)")
|
||||
return True
|
||||
|
||||
# Delete the conversation
|
||||
response = self.post(
|
||||
"/api/delete_conversation",
|
||||
json={"id": conv_id},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 204]:
|
||||
self.print_success(f"Deleted conversation: {conv_id[:20]}...")
|
||||
self.record_result(test_name, True, "Conversation deleted")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Delete failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_delete_all_conversations(self) -> bool:
|
||||
"""Test the delete all conversations endpoint (without actually deleting all)."""
|
||||
test_name = "Delete all conversations endpoint"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
self.print_warning("Skipping actual deletion to preserve data")
|
||||
self.print_info("Verifying endpoint exists...")
|
||||
|
||||
try:
|
||||
# Just verify endpoint responds (don't actually call it)
|
||||
# We can test with a GET to see if endpoint exists
|
||||
response = self.get("/api/delete_all_conversations", timeout=10)
|
||||
|
||||
# Any response means endpoint exists
|
||||
self.print_success(f"Endpoint responded: {response.status_code}")
|
||||
self.record_result(test_name, True, "Endpoint verified")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Share Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_share_conversation(self) -> bool:
|
||||
"""Test sharing a conversation."""
|
||||
test_name = "Share conversation"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
conv_id = self.get_existing_conversation()
|
||||
if not conv_id:
|
||||
conv_id = self.get_or_create_test_conversation()
|
||||
|
||||
if not conv_id:
|
||||
self.print_warning("No conversation to share")
|
||||
self.record_result(test_name, True, "Skipped (no conversation)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/share",
|
||||
json={"conversation_id": conv_id},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
share_id = result.get("share_id") or result.get("id")
|
||||
self.print_success(f"Shared conversation: {share_id}")
|
||||
self._shared_conversation_id = share_id
|
||||
self.record_result(test_name, True, f"Share ID: {share_id}")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Share failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_shared_conversation(self) -> bool:
|
||||
"""Test getting a shared conversation."""
|
||||
test_name = "Get shared conversation"
|
||||
self.print_header(test_name)
|
||||
|
||||
# Use share ID from previous test if available
|
||||
share_id = getattr(self, "_shared_conversation_id", None)
|
||||
|
||||
if not share_id:
|
||||
self.print_warning("No shared conversation available")
|
||||
self.record_result(test_name, True, "Skipped (no shared conversation)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(f"/api/shared_conversation/{share_id}", timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
self.print_success("Retrieved shared conversation")
|
||||
self.print_info(f"Messages: {len(result.get('queries', []))}")
|
||||
self.record_result(test_name, True, f"Share ID: {share_id}")
|
||||
return True
|
||||
elif response.status_code == 404:
|
||||
self.print_warning("Shared conversation not found")
|
||||
self.record_result(test_name, True, "Not found (may be expected)")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Get failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_shared_conversation_not_found(self) -> bool:
|
||||
"""Test getting a non-existent shared conversation."""
|
||||
test_name = "Get non-existent shared conversation"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/shared_conversation/nonexistent-share-id-12345",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [404, 400]:
|
||||
self.print_success(f"Correctly returned {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all conversation tests."""
|
||||
self.print_header("DocsGPT Conversation Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
|
||||
# List/Get tests
|
||||
self.test_get_conversations()
|
||||
self.test_get_conversations_paginated()
|
||||
self.test_get_single_conversation()
|
||||
self.test_get_single_conversation_not_found()
|
||||
|
||||
# Update tests
|
||||
self.test_update_conversation_name()
|
||||
|
||||
# Delete tests
|
||||
self.test_delete_conversation()
|
||||
self.test_delete_all_conversations()
|
||||
|
||||
# Share tests
|
||||
self.test_share_conversation()
|
||||
self.test_get_shared_conversation()
|
||||
self.test_get_shared_conversation_not_found()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(
|
||||
ConversationTests, "DocsGPT Conversation Integration Tests"
|
||||
)
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
337
tests/integration/test_mcp.py
Normal file
337
tests/integration/test_mcp.py
Normal file
@@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT MCP (Model Context Protocol) server endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/mcp_server/callback (GET) - OAuth callback
|
||||
- /api/mcp_server/oauth_status/{task_id} (GET) - OAuth status
|
||||
- /api/mcp_server/save (POST) - Save MCP server config
|
||||
- /api/mcp_server/test (POST) - Test MCP server connection
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_mcp.py
|
||||
python tests/integration/test_mcp.py --base-url http://localhost:7091
|
||||
python tests/integration/test_mcp.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class MCPTests(DocsGPTTestBase):
|
||||
"""Integration tests for MCP server endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Callback Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_mcp_callback(self) -> bool:
|
||||
"""Test MCP OAuth callback endpoint."""
|
||||
test_name = "MCP OAuth callback"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/mcp_server/callback",
|
||||
params={"code": "test_code", "state": "test_state"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Expect various responses depending on configuration
|
||||
if response.status_code in [200, 302, 400, 404]:
|
||||
self.print_success(f"Callback response: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# OAuth Status Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_mcp_oauth_status(self) -> bool:
|
||||
"""Test getting MCP OAuth status."""
|
||||
test_name = "MCP OAuth status"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/mcp_server/oauth_status/test-task-id-123",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 404]:
|
||||
self.print_success(f"OAuth status check: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
self.print_info(f"Status: {result.get('status', 'N/A')}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_mcp_oauth_status_invalid_task(self) -> bool:
|
||||
"""Test OAuth status for invalid task ID."""
|
||||
test_name = "MCP OAuth status invalid"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/mcp_server/oauth_status/nonexistent-task-xyz",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [404, 400]:
|
||||
self.print_success(f"Correctly returned: {response.status_code}")
|
||||
self.record_result(test_name, True, "Invalid task handled")
|
||||
return True
|
||||
elif response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get("status") in ["not_found", "unknown", None]:
|
||||
self.print_success("Invalid task handled (status: not_found)")
|
||||
self.record_result(test_name, True, "Invalid task handled")
|
||||
return True
|
||||
|
||||
self.print_warning(f"Status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Save Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_mcp_save(self) -> bool:
|
||||
"""Test saving MCP server configuration."""
|
||||
test_name = "Save MCP server config"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"name": f"Test MCP Server {int(time.time())}",
|
||||
"url": "https://example.com/mcp",
|
||||
"config": {},
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/mcp_server/save",
|
||||
json=payload,
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
self.print_success(f"Saved MCP server: {result.get('id', 'N/A')}")
|
||||
self.record_result(test_name, True, "Config saved")
|
||||
return True
|
||||
elif response.status_code in [400, 422]:
|
||||
self.print_warning(f"Validation error: {response.status_code}")
|
||||
self.record_result(test_name, True, "Validation handled")
|
||||
return True
|
||||
|
||||
self.print_error(f"Save failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_mcp_save_invalid(self) -> bool:
|
||||
"""Test saving invalid MCP config."""
|
||||
test_name = "Save invalid MCP config"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"name": "", # Invalid empty name
|
||||
"url": "not-a-url", # Invalid URL
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/mcp_server/save",
|
||||
json=payload,
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if response.status_code in [400, 422]:
|
||||
self.print_success(f"Validation rejected: {response.status_code}")
|
||||
self.record_result(test_name, True, "Invalid config rejected")
|
||||
return True
|
||||
elif response.status_code in [200, 201]:
|
||||
self.print_warning("Server accepted invalid data (lenient validation)")
|
||||
self.record_result(test_name, True, "Lenient validation")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Connection Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_mcp_test_connection(self) -> bool:
|
||||
"""Test MCP server connection test."""
|
||||
test_name = "Test MCP connection"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"url": "https://example.com/mcp",
|
||||
"config": {},
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/mcp_server/test",
|
||||
json=payload,
|
||||
timeout=30, # Connection test may take time
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
success = result.get("success", result.get("connected", False))
|
||||
self.print_success(f"Connection test: success={success}")
|
||||
self.record_result(test_name, True, f"Connected: {success}")
|
||||
return True
|
||||
elif response.status_code in [400, 500, 502, 504]:
|
||||
# Connection failed (expected for non-existent server)
|
||||
self.print_warning(f"Connection failed: {response.status_code}")
|
||||
self.record_result(test_name, True, "Connection failed (expected)")
|
||||
return True
|
||||
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_mcp_test_connection_invalid(self) -> bool:
|
||||
"""Test MCP connection with invalid URL."""
|
||||
test_name = "Test MCP invalid URL"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"url": "invalid-url",
|
||||
"config": {},
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/mcp_server/test",
|
||||
json=payload,
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if response.status_code in [400, 422, 500]:
|
||||
self.print_success(f"Invalid URL rejected: {response.status_code}")
|
||||
self.record_result(test_name, True, "Invalid URL handled")
|
||||
return True
|
||||
elif response.status_code == 200:
|
||||
result = response.json()
|
||||
if not result.get("success", result.get("connected", True)):
|
||||
self.print_success("Connection correctly failed")
|
||||
self.record_result(test_name, True, "Connection failed")
|
||||
return True
|
||||
|
||||
self.print_warning(f"Status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all MCP tests."""
|
||||
self.print_header("DocsGPT MCP Server Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
|
||||
# Callback tests
|
||||
self.test_mcp_callback()
|
||||
|
||||
# OAuth status tests
|
||||
self.test_mcp_oauth_status()
|
||||
self.test_mcp_oauth_status_invalid_task()
|
||||
|
||||
# Save tests
|
||||
self.test_mcp_save()
|
||||
self.test_mcp_save_invalid()
|
||||
|
||||
# Test connection tests
|
||||
self.test_mcp_test_connection()
|
||||
self.test_mcp_test_connection_invalid()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(MCPTests, "DocsGPT MCP Server Integration Tests")
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
317
tests/integration/test_misc.py
Normal file
317
tests/integration/test_misc.py
Normal file
@@ -0,0 +1,317 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT miscellaneous endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/models (GET) - List available models
|
||||
- /api/images/{image_path} (GET) - Get images
|
||||
- /api/store_attachment (POST) - Store attachments
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_misc.py
|
||||
python tests/integration/test_misc.py --base-url http://localhost:7091
|
||||
python tests/integration/test_misc.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class MiscTests(DocsGPTTestBase):
|
||||
"""Integration tests for miscellaneous endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Models Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_models(self) -> bool:
|
||||
"""Test listing available models."""
|
||||
test_name = "List models"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
response = self.get("/api/models", timeout=10)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
# Handle both list and object responses
|
||||
if isinstance(result, list):
|
||||
self.print_success(f"Retrieved {len(result)} models")
|
||||
if result:
|
||||
first_model = result[0]
|
||||
if isinstance(first_model, dict):
|
||||
self.print_info(f"First: {first_model.get('name', first_model.get('id', 'N/A'))}")
|
||||
else:
|
||||
self.print_info(f"First: {first_model}")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
elif isinstance(result, dict):
|
||||
# May return object with models array
|
||||
models = result.get("models", result.get("data", []))
|
||||
if isinstance(models, list):
|
||||
self.print_success(f"Retrieved {len(models)} models")
|
||||
if models:
|
||||
first = models[0]
|
||||
name = first.get('name', first) if isinstance(first, dict) else first
|
||||
self.print_info(f"First: {name}")
|
||||
else:
|
||||
self.print_success("Retrieved models data")
|
||||
self.record_result(test_name, True, "Models retrieved")
|
||||
else:
|
||||
self.print_warning(f"Unexpected response type: {type(result)}")
|
||||
self.record_result(test_name, True, "Response received")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_models_with_filter(self) -> bool:
|
||||
"""Test listing models with filter parameters."""
|
||||
test_name = "List models filtered"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/models",
|
||||
params={"provider": "openai"}, # Filter by provider
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
if isinstance(result, list):
|
||||
self.print_success(f"Retrieved {len(result)} filtered models")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("Response format may vary")
|
||||
self.record_result(test_name, True, "Response received")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Images Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_image(self) -> bool:
|
||||
"""Test getting an image by path."""
|
||||
test_name = "Get image"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
# Test with a placeholder path
|
||||
response = self.get("/api/images/test.png", timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
content_type = response.headers.get("content-type", "")
|
||||
self.print_success(f"Image retrieved: {content_type}")
|
||||
self.record_result(test_name, True, f"Type: {content_type}")
|
||||
return True
|
||||
elif response.status_code == 404:
|
||||
self.print_warning("Image not found (expected for test)")
|
||||
self.record_result(test_name, True, "404 - Image not found")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_image_not_found(self) -> bool:
|
||||
"""Test getting a non-existent image."""
|
||||
test_name = "Get non-existent image"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/images/nonexistent-image-xyz-12345.png",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
self.print_success("Correctly returned 404")
|
||||
self.record_result(test_name, True, "404 returned")
|
||||
return True
|
||||
elif response.status_code in [400, 500]:
|
||||
self.print_warning(f"Error status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning(f"Status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Attachment Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_store_attachment(self) -> bool:
|
||||
"""Test storing an attachment."""
|
||||
test_name = "Store attachment"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Create a small test file content
|
||||
test_content = b"Test attachment content for integration test"
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/store_attachment",
|
||||
files={"file": ("test_attachment.txt", test_content, "text/plain")},
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
attachment_id = result.get("id") or result.get("attachment_id") or result.get("path")
|
||||
self.print_success(f"Stored attachment: {attachment_id}")
|
||||
self.record_result(test_name, True, f"ID: {attachment_id}")
|
||||
return True
|
||||
elif response.status_code in [400, 422]:
|
||||
self.print_warning(f"Validation: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Store failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_store_attachment_large(self) -> bool:
|
||||
"""Test storing a larger attachment."""
|
||||
test_name = "Store large attachment"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Create a larger test file (1KB)
|
||||
test_content = b"X" * 1024
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/store_attachment",
|
||||
files={"file": ("large_test.bin", test_content, "application/octet-stream")},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
response.json() # Validate JSON response
|
||||
self.print_success("Large attachment stored")
|
||||
self.record_result(test_name, True, "Attachment stored")
|
||||
return True
|
||||
elif response.status_code in [400, 413, 422]:
|
||||
self.print_warning(f"Size/validation: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Store failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Health/Info Tests (bonus)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_health_check(self) -> bool:
|
||||
"""Test basic health check (root or health endpoint)."""
|
||||
test_name = "Health check"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
# Try common health endpoints
|
||||
for path in ["/health", "/api/health", "/"]:
|
||||
response = self.get(path, timeout=5)
|
||||
if response.status_code == 200:
|
||||
self.print_success(f"Health check passed: {path}")
|
||||
self.record_result(test_name, True, f"Endpoint: {path}")
|
||||
return True
|
||||
|
||||
# If none worked, check if server responds at all
|
||||
self.print_warning("No standard health endpoint found")
|
||||
self.record_result(test_name, True, "Server responsive")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all miscellaneous tests."""
|
||||
self.print_header("DocsGPT Miscellaneous Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
|
||||
# Health check
|
||||
self.test_health_check()
|
||||
|
||||
# Models tests
|
||||
self.test_get_models()
|
||||
self.test_get_models_with_filter()
|
||||
|
||||
# Images tests
|
||||
self.test_get_image()
|
||||
self.test_get_image_not_found()
|
||||
|
||||
# Attachment tests
|
||||
self.test_store_attachment()
|
||||
self.test_store_attachment_large()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(MiscTests, "DocsGPT Miscellaneous Integration Tests")
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
432
tests/integration/test_prompts.py
Normal file
432
tests/integration/test_prompts.py
Normal file
@@ -0,0 +1,432 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT prompt management endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/create_prompt (POST) - Create prompt
|
||||
- /api/get_prompts (GET) - List prompts
|
||||
- /api/get_single_prompt (GET) - Get single prompt
|
||||
- /api/update_prompt (POST) - Update prompt
|
||||
- /api/delete_prompt (POST) - Delete prompt
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_prompts.py
|
||||
python tests/integration/test_prompts.py --base-url http://localhost:7091
|
||||
python tests/integration/test_prompts.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class PromptTests(DocsGPTTestBase):
|
||||
"""Integration tests for prompt management endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_test_prompt(self) -> Optional[str]:
|
||||
"""
|
||||
Get or create a test prompt.
|
||||
|
||||
Returns:
|
||||
Prompt ID or None if creation fails
|
||||
"""
|
||||
if hasattr(self, "_test_prompt_id"):
|
||||
return self._test_prompt_id
|
||||
|
||||
if not self.is_authenticated:
|
||||
return None
|
||||
|
||||
payload = {
|
||||
"name": f"Test Prompt {int(time.time())}",
|
||||
"content": "You are a helpful assistant. Answer questions accurately.",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_prompt", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
prompt_id = result.get("id")
|
||||
if prompt_id:
|
||||
self._test_prompt_id = prompt_id
|
||||
return prompt_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def cleanup_test_prompt(self, prompt_id: str) -> None:
|
||||
"""Delete a test prompt (cleanup helper)."""
|
||||
if not self.is_authenticated:
|
||||
return
|
||||
try:
|
||||
self.post("/api/delete_prompt", json={"id": prompt_id}, timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Create Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_create_prompt(self) -> bool:
|
||||
"""Test creating a prompt."""
|
||||
test_name = "Create prompt"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"name": f"Created Prompt {int(time.time())}",
|
||||
"content": "You are a test assistant created by integration tests.",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_prompt", json=payload, timeout=10)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
prompt_id = result.get("id")
|
||||
|
||||
if not prompt_id:
|
||||
self.print_error("No prompt ID returned")
|
||||
self.record_result(test_name, False, "No prompt ID")
|
||||
return False
|
||||
|
||||
self.print_success(f"Created prompt: {prompt_id}")
|
||||
self.print_info(f"Name: {payload['name']}")
|
||||
self.record_result(test_name, True, f"ID: {prompt_id}")
|
||||
|
||||
# Cleanup
|
||||
self.cleanup_test_prompt(prompt_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_create_prompt_validation(self) -> bool:
|
||||
"""Test prompt creation validation (missing required fields)."""
|
||||
test_name = "Create prompt validation"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Missing content field
|
||||
payload = {
|
||||
"name": "Invalid Prompt",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_prompt", json=payload, timeout=10)
|
||||
|
||||
# Expect validation error (400) or accept it if server provides defaults
|
||||
if response.status_code in [400, 422]:
|
||||
self.print_success(f"Validation error returned: {response.status_code}")
|
||||
self.record_result(test_name, True, "Validation works")
|
||||
return True
|
||||
elif response.status_code in [200, 201]:
|
||||
self.print_warning("Server accepted incomplete data (may have defaults)")
|
||||
result = response.json()
|
||||
if result.get("id"):
|
||||
self.cleanup_test_prompt(result["id"])
|
||||
self.record_result(test_name, True, "Server accepted with defaults")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Read Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_prompts(self) -> bool:
|
||||
"""Test listing all prompts."""
|
||||
test_name = "List prompts"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get("/api/get_prompts", timeout=10)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
if not isinstance(result, list):
|
||||
self.print_error("Response is not a list")
|
||||
self.record_result(test_name, False, "Invalid response type")
|
||||
return False
|
||||
|
||||
self.print_success(f"Retrieved {len(result)} prompts")
|
||||
if result:
|
||||
self.print_info(f"First: {result[0].get('name', 'N/A')}")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_prompts_with_pagination(self) -> bool:
|
||||
"""Test listing prompts with pagination params."""
|
||||
test_name = "List prompts paginated"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/get_prompts",
|
||||
params={"skip": 0, "limit": 10},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
if not isinstance(result, list):
|
||||
self.print_error("Response is not a list")
|
||||
self.record_result(test_name, False, "Invalid response type")
|
||||
return False
|
||||
|
||||
self.print_success(f"Retrieved {len(result)} prompts (paginated)")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_single_prompt(self) -> bool:
|
||||
"""Test getting a single prompt by ID."""
|
||||
test_name = "Get single prompt"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
prompt_id = self.get_or_create_test_prompt()
|
||||
if not prompt_id:
|
||||
self.print_warning("Could not create test prompt")
|
||||
self.record_result(test_name, True, "Skipped (no prompt)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/get_single_prompt",
|
||||
params={"id": prompt_id},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
# Handle different response formats (may have _id instead of id)
|
||||
returned_id = result.get("id") or result.get("_id")
|
||||
|
||||
if returned_id and returned_id != prompt_id:
|
||||
self.print_error(f"Wrong prompt returned: {returned_id}")
|
||||
self.record_result(test_name, False, "Wrong prompt ID")
|
||||
return False
|
||||
|
||||
self.print_success(f"Retrieved prompt: {result.get('name', 'N/A')}")
|
||||
self.record_result(test_name, True, f"Name: {result.get('name', 'N/A')}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_single_prompt_not_found(self) -> bool:
|
||||
"""Test getting a non-existent prompt."""
|
||||
test_name = "Get non-existent prompt"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get(
|
||||
"/api/get_single_prompt",
|
||||
params={"id": "nonexistent-prompt-id-12345"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [404, 400, 500]:
|
||||
self.print_success(f"Correctly returned {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, True, f"Status: {response.status_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Update Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_update_prompt(self) -> bool:
|
||||
"""Test updating a prompt."""
|
||||
test_name = "Update prompt"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
prompt_id = self.get_or_create_test_prompt()
|
||||
if not prompt_id:
|
||||
self.print_warning("Could not create test prompt")
|
||||
self.record_result(test_name, True, "Skipped (no prompt)")
|
||||
return True
|
||||
|
||||
new_content = f"Updated content at {int(time.time())}"
|
||||
new_name = f"Updated Prompt {int(time.time())}"
|
||||
|
||||
try:
|
||||
# UpdatePromptModel requires id, name, and content
|
||||
response = self.post(
|
||||
"/api/update_prompt",
|
||||
json={"id": prompt_id, "name": new_name, "content": new_content},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success("Prompt updated successfully")
|
||||
self.record_result(test_name, True, "Prompt updated")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Update failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Delete Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_delete_prompt(self) -> bool:
|
||||
"""Test deleting a prompt."""
|
||||
test_name = "Delete prompt"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Create a prompt specifically for deletion
|
||||
payload = {
|
||||
"name": f"Prompt to Delete {int(time.time())}",
|
||||
"content": "Will be deleted",
|
||||
}
|
||||
|
||||
try:
|
||||
create_response = self.post("/api/create_prompt", json=payload, timeout=10)
|
||||
if create_response.status_code not in [200, 201]:
|
||||
self.print_warning("Could not create prompt for deletion")
|
||||
self.record_result(test_name, True, "Skipped (create failed)")
|
||||
return True
|
||||
|
||||
prompt_id = create_response.json().get("id")
|
||||
|
||||
# Delete the prompt
|
||||
response = self.post("/api/delete_prompt", json={"id": prompt_id}, timeout=10)
|
||||
|
||||
if response.status_code in [200, 204]:
|
||||
self.print_success(f"Deleted prompt: {prompt_id}")
|
||||
self.record_result(test_name, True, "Prompt deleted")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Delete failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all prompt tests."""
|
||||
self.print_header("DocsGPT Prompt Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
|
||||
# Create tests
|
||||
self.test_create_prompt()
|
||||
self.test_create_prompt_validation()
|
||||
|
||||
# Read tests
|
||||
self.test_get_prompts()
|
||||
self.test_get_prompts_with_pagination()
|
||||
self.test_get_single_prompt()
|
||||
self.test_get_single_prompt_not_found()
|
||||
|
||||
# Update tests
|
||||
self.test_update_prompt()
|
||||
|
||||
# Delete tests
|
||||
self.test_delete_prompt()
|
||||
|
||||
# Cleanup
|
||||
if hasattr(self, "_test_prompt_id"):
|
||||
self.cleanup_test_prompt(self._test_prompt_id)
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(PromptTests, "DocsGPT Prompt Integration Tests")
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
675
tests/integration/test_sources.py
Normal file
675
tests/integration/test_sources.py
Normal file
@@ -0,0 +1,675 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT source management endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/upload (POST) - File upload
|
||||
- /api/remote (POST) - Remote source (crawler)
|
||||
- /api/sources (GET) - List sources
|
||||
- /api/sources/paginated (GET) - Paginated sources
|
||||
- /api/task_status (GET) - Task status
|
||||
- /api/add_chunk (POST) - Add chunk to source
|
||||
- /api/get_chunks (GET) - Get chunks from source
|
||||
- /api/update_chunk (PUT) - Update chunk
|
||||
- /api/delete_chunk (DELETE) - Delete chunk
|
||||
- /api/delete_by_ids (GET) - Delete sources by IDs
|
||||
- /api/delete_old (GET) - Delete old sources
|
||||
- /api/directory_structure (GET) - Get directory structure
|
||||
- /api/manage_source_files (POST) - Manage source files
|
||||
- /api/manage_sync (POST) - Manage sync
|
||||
- /api/combine (GET) - Combine sources
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_sources.py
|
||||
python tests/integration/test_sources.py --base-url http://localhost:7091
|
||||
python tests/integration/test_sources.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class SourceTests(DocsGPTTestBase):
|
||||
"""Integration tests for source management endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_test_source(self) -> Optional[dict]:
|
||||
"""
|
||||
Get or create a test source.
|
||||
|
||||
Returns:
|
||||
Dict with keys: id, task_id, name or None
|
||||
"""
|
||||
if hasattr(self, "_test_source"):
|
||||
return self._test_source
|
||||
|
||||
if not self.is_authenticated:
|
||||
return None
|
||||
|
||||
test_name = f"Source Test {int(time.time())}"
|
||||
test_content = """# Test Documentation
|
||||
|
||||
## Overview
|
||||
This is test documentation for source integration tests.
|
||||
|
||||
## Installation
|
||||
Run `pip install docsgpt` to install.
|
||||
|
||||
## Usage
|
||||
Import and use the library in your code.
|
||||
|
||||
## API Reference
|
||||
See the API documentation for details.
|
||||
"""
|
||||
|
||||
files = {"file": ("test_source.txt", test_content.encode(), "text/plain")}
|
||||
data = {"user": "test_user", "name": test_name}
|
||||
|
||||
try:
|
||||
response = self.post("/api/upload", files=files, data=data, timeout=30)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
task_id = result.get("task_id")
|
||||
if task_id:
|
||||
# Wait for processing
|
||||
time.sleep(5)
|
||||
|
||||
# Get source ID
|
||||
source_id = self._get_source_id_by_name(test_name)
|
||||
if source_id:
|
||||
self._test_source = {
|
||||
"id": source_id,
|
||||
"task_id": task_id,
|
||||
"name": test_name,
|
||||
}
|
||||
return self._test_source
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _get_source_id_by_name(self, name: str) -> Optional[str]:
|
||||
"""Get source ID by name from sources list."""
|
||||
try:
|
||||
response = self.get("/api/sources")
|
||||
if response.status_code == 200:
|
||||
sources = response.json()
|
||||
for source in sources:
|
||||
if source.get("name") == name:
|
||||
return source.get("id")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _wait_for_task(self, task_id: str, max_wait: int = 30) -> Optional[str]:
|
||||
"""Wait for task to complete and return status."""
|
||||
for _ in range(max_wait):
|
||||
try:
|
||||
response = self.get("/api/task_status", params={"task_id": task_id})
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
status = result.get("status")
|
||||
if status in ["SUCCESS", "FAILURE"]:
|
||||
return status
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(1)
|
||||
return None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Upload Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_upload_text_source(self) -> bool:
|
||||
"""Test uploading a text file source."""
|
||||
test_name = "Upload - Text Source"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
test_content = f"""# Upload Test {int(time.time())}
|
||||
This is a test document for upload testing.
|
||||
It contains multiple lines of text.
|
||||
"""
|
||||
|
||||
files = {"file": ("upload_test.txt", test_content.encode(), "text/plain")}
|
||||
data = {"user": "test_user", "name": f"Upload Test {int(time.time())}"}
|
||||
|
||||
try:
|
||||
self.print_info("POST /api/upload")
|
||||
response = self.post("/api/upload", files=files, data=data, timeout=30)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
task_id = result.get("task_id")
|
||||
|
||||
if task_id:
|
||||
self.print_success(f"Upload task started: {task_id}")
|
||||
self.record_result(test_name, True, f"Task: {task_id}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("No task_id returned")
|
||||
self.record_result(test_name, False, "No task_id")
|
||||
return False
|
||||
else:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_upload_markdown_source(self) -> bool:
|
||||
"""Test uploading a markdown file source."""
|
||||
test_name = "Upload - Markdown Source"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
test_content = f"""# Markdown Test Document
|
||||
|
||||
## Section 1
|
||||
This is the first section with **bold** and *italic* text.
|
||||
|
||||
## Section 2
|
||||
- Item 1
|
||||
- Item 2
|
||||
- Item 3
|
||||
|
||||
## Code Example
|
||||
```python
|
||||
def hello():
|
||||
print("Hello, World!")
|
||||
```
|
||||
|
||||
Created at: {int(time.time())}
|
||||
"""
|
||||
|
||||
files = {"file": ("test.md", test_content.encode(), "text/markdown")}
|
||||
data = {"user": "test_user", "name": f"Markdown Test {int(time.time())}"}
|
||||
|
||||
try:
|
||||
self.print_info("POST /api/upload (markdown)")
|
||||
response = self.post("/api/upload", files=files, data=data, timeout=30)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
task_id = result.get("task_id")
|
||||
if task_id:
|
||||
self.print_success(f"Markdown upload task started: {task_id}")
|
||||
self.record_result(test_name, True, f"Task: {task_id}")
|
||||
return True
|
||||
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Remote Source Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_remote_crawler_source(self) -> bool:
|
||||
"""Test remote crawler source upload."""
|
||||
test_name = "Remote - Crawler Source"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Use a small, fast-loading page
|
||||
payload = {
|
||||
"user": "test_user",
|
||||
"source": "crawler",
|
||||
"name": f"Crawler Test {int(time.time())}",
|
||||
"data": '{"url": "https://example.com/"}',
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info("POST /api/remote (crawler)")
|
||||
response = self.post("/api/remote", data=payload, timeout=30)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
task_id = result.get("task_id")
|
||||
if task_id:
|
||||
self.print_success(f"Crawler task started: {task_id}")
|
||||
self.record_result(test_name, True, f"Task: {task_id}")
|
||||
return True
|
||||
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Source Listing Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_sources(self) -> bool:
|
||||
"""Test getting list of sources."""
|
||||
test_name = "Sources - List All"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
self.print_info("GET /api/sources")
|
||||
response = self.get("/api/sources")
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
sources = response.json()
|
||||
self.print_success(f"Retrieved {len(sources)} sources")
|
||||
|
||||
if sources:
|
||||
first = sources[0]
|
||||
self.print_info(f"First source: {first.get('name', 'N/A')}")
|
||||
|
||||
self.record_result(test_name, True, f"{len(sources)} sources")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_sources_paginated(self) -> bool:
|
||||
"""Test getting paginated sources."""
|
||||
test_name = "Sources - Paginated"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
self.print_info("GET /api/sources/paginated")
|
||||
response = self.get("/api/sources/paginated", params={"page": 1, "per_page": 10})
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
self.print_success("Paginated sources retrieved")
|
||||
|
||||
if isinstance(result, dict):
|
||||
total = result.get("total", "N/A")
|
||||
self.print_info(f"Total sources: {total}")
|
||||
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Task Status Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_task_status(self) -> bool:
|
||||
"""Test getting task status."""
|
||||
test_name = "Task Status - Check"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# First upload a file to get a task_id
|
||||
test_content = "Test content for task status"
|
||||
files = {"file": ("task_test.txt", test_content.encode(), "text/plain")}
|
||||
data = {"user": "test_user", "name": f"Task Test {int(time.time())}"}
|
||||
|
||||
try:
|
||||
upload_response = self.post("/api/upload", files=files, data=data, timeout=30)
|
||||
|
||||
if upload_response.status_code != 200:
|
||||
self.record_result(test_name, True, "Skipped (upload failed)")
|
||||
return True
|
||||
|
||||
task_id = upload_response.json().get("task_id")
|
||||
if not task_id:
|
||||
self.record_result(test_name, True, "Skipped (no task_id)")
|
||||
return True
|
||||
|
||||
self.print_info(f"GET /api/task_status?task_id={task_id[:8]}...")
|
||||
response = self.get("/api/task_status", params={"task_id": task_id})
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
status = result.get("status", "UNKNOWN")
|
||||
self.print_success(f"Task status: {status}")
|
||||
self.record_result(test_name, True, f"Status: {status}")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Chunk Management Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_chunks(self) -> bool:
|
||||
"""Test getting chunks from a source."""
|
||||
test_name = "Chunks - Get"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
source = self.get_or_create_test_source()
|
||||
if not source:
|
||||
self.print_warning("Could not create test source")
|
||||
self.record_result(test_name, True, "Skipped (no source)")
|
||||
return True
|
||||
|
||||
try:
|
||||
# Swagger says param is 'id', not 'source_id'
|
||||
self.print_info(f"GET /api/get_chunks?id={source['id'][:8]}...")
|
||||
response = self.get("/api/get_chunks", params={"id": source["id"]})
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
chunks = result if isinstance(result, list) else result.get("chunks", [])
|
||||
self.print_success(f"Retrieved {len(chunks)} chunks")
|
||||
self.record_result(test_name, True, f"{len(chunks)} chunks")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_add_chunk(self) -> bool:
|
||||
"""Test adding a chunk to a source."""
|
||||
test_name = "Chunks - Add"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
source = self.get_or_create_test_source()
|
||||
if not source:
|
||||
self.record_result(test_name, True, "Skipped (no source)")
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"source_id": source["id"],
|
||||
"content": f"Test chunk content added at {int(time.time())}",
|
||||
"metadata": {"test": True},
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info("POST /api/add_chunk")
|
||||
response = self.post("/api/add_chunk", json=payload)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success("Chunk added successfully")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
# May not be supported or require specific format
|
||||
self.print_warning(f"Status {response.status_code}")
|
||||
self.record_result(test_name, True, f"Skipped (status {response.status_code})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Delete Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_delete_by_ids(self) -> bool:
|
||||
"""Test deleting documents by vector store IDs.
|
||||
|
||||
Note: This endpoint expects vector store document IDs (chunk IDs),
|
||||
not MongoDB source IDs. Testing with non-existent IDs returns 400.
|
||||
"""
|
||||
test_name = "Sources - Delete by IDs"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Test endpoint accessibility with a test ID
|
||||
# Note: This endpoint expects vector document IDs, not source IDs
|
||||
test_id = "test-document-id-12345"
|
||||
self.print_info(f"GET /api/delete_by_ids?path={test_id}")
|
||||
response = self.get("/api/delete_by_ids", params={"path": test_id})
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
self.print_success("Delete endpoint responded successfully")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
elif response.status_code == 400:
|
||||
# 400 is expected when document ID doesn't exist in vector store
|
||||
self.print_warning("Expected 400 (ID not in vector store)")
|
||||
self.record_result(test_name, True, "Endpoint works (ID not found)")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Directory Structure Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_directory_structure(self) -> bool:
|
||||
"""Test getting directory structure."""
|
||||
test_name = "Directory Structure"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
source = self.get_or_create_test_source()
|
||||
if not source:
|
||||
self.record_result(test_name, True, "Skipped (no source)")
|
||||
return True
|
||||
|
||||
try:
|
||||
self.print_info(f"GET /api/directory_structure?source_id={source['id'][:8]}...")
|
||||
response = self.get("/api/directory_structure", params={"source_id": source["id"]})
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
response.json() # Validate JSON response
|
||||
self.print_success("Directory structure retrieved")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
# May not be supported for all source types
|
||||
self.print_warning(f"Status {response.status_code}")
|
||||
self.record_result(test_name, True, f"Skipped (status {response.status_code})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Combine Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_combine(self) -> bool:
|
||||
"""Test combine endpoint."""
|
||||
test_name = "Sources - Combine"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
self.print_info("GET /api/combine")
|
||||
response = self.get("/api/combine")
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
self.print_success("Combine endpoint works")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Manage Source Files Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_manage_source_files(self) -> bool:
|
||||
"""Test managing source files."""
|
||||
test_name = "Manage Source Files"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
source = self.get_or_create_test_source()
|
||||
if not source:
|
||||
self.record_result(test_name, True, "Skipped (no source)")
|
||||
return True
|
||||
|
||||
payload = {
|
||||
"source_id": source["id"],
|
||||
"action": "list",
|
||||
}
|
||||
|
||||
try:
|
||||
self.print_info("POST /api/manage_source_files")
|
||||
response = self.post("/api/manage_source_files", json=payload)
|
||||
|
||||
self.print_info(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
self.print_success("Source files managed")
|
||||
self.record_result(test_name, True, "Success")
|
||||
return True
|
||||
else:
|
||||
# May require specific format
|
||||
self.print_warning(f"Status {response.status_code}")
|
||||
self.record_result(test_name, True, f"Skipped (status {response.status_code})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Run All Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all source integration tests."""
|
||||
self.print_header("Source Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
|
||||
|
||||
# Upload tests
|
||||
self.test_upload_text_source()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_upload_markdown_source()
|
||||
time.sleep(1)
|
||||
|
||||
# Remote source tests
|
||||
self.test_remote_crawler_source()
|
||||
time.sleep(1)
|
||||
|
||||
# Source listing tests
|
||||
self.test_get_sources()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_get_sources_paginated()
|
||||
time.sleep(1)
|
||||
|
||||
# Task status test
|
||||
self.test_task_status()
|
||||
time.sleep(1)
|
||||
|
||||
# Chunk tests
|
||||
self.test_get_chunks()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_add_chunk()
|
||||
time.sleep(1)
|
||||
|
||||
# Directory structure
|
||||
self.test_directory_structure()
|
||||
time.sleep(1)
|
||||
|
||||
# Combine
|
||||
self.test_combine()
|
||||
time.sleep(1)
|
||||
|
||||
# Manage source files
|
||||
self.test_manage_source_files()
|
||||
time.sleep(1)
|
||||
|
||||
# Delete test (last because it removes data)
|
||||
self.test_delete_by_ids()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for standalone execution."""
|
||||
client = create_client_from_args(SourceTests, "DocsGPT Source Integration Tests")
|
||||
success = client.run_all()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
519
tests/integration/test_tools.py
Normal file
519
tests/integration/test_tools.py
Normal file
@@ -0,0 +1,519 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for DocsGPT tools management endpoints.
|
||||
|
||||
Endpoints tested:
|
||||
- /api/create_tool (POST) - Create tool
|
||||
- /api/get_tools (GET) - List tools
|
||||
- /api/update_tool (POST) - Update tool
|
||||
- /api/delete_tool (POST) - Delete tool
|
||||
- /api/update_tool_actions (POST) - Update tool actions
|
||||
- /api/update_tool_config (POST) - Update tool config
|
||||
- /api/update_tool_status (POST) - Update tool status
|
||||
- /api/available_tools (GET) - List available tools
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_tools.py
|
||||
python tests/integration/test_tools.py --base-url http://localhost:7091
|
||||
python tests/integration/test_tools.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class ToolsTests(DocsGPTTestBase):
|
||||
"""Integration tests for tools management endpoints."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_test_tool(self) -> Optional[str]:
|
||||
"""
|
||||
Get or create a test tool.
|
||||
|
||||
Returns:
|
||||
Tool ID or None if creation fails
|
||||
"""
|
||||
if hasattr(self, "_test_tool_id"):
|
||||
return self._test_tool_id
|
||||
|
||||
if not self.is_authenticated:
|
||||
return None
|
||||
|
||||
# CreateToolModel: 'name' must be an available tool type (e.g., "duckduckgo")
|
||||
# Use a tool that doesn't require config (like duckduckgo)
|
||||
# Note: status must be a boolean (False = draft, True = active)
|
||||
payload = {
|
||||
"name": "duckduckgo", # Must match available tool name
|
||||
"displayName": f"Test DuckDuckGo {int(time.time())}",
|
||||
"description": "Integration test tool",
|
||||
"config": {},
|
||||
"status": False, # Boolean: False = draft
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_tool", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
tool_id = result.get("id")
|
||||
if tool_id:
|
||||
self._test_tool_id = tool_id
|
||||
return tool_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def cleanup_test_tool(self, tool_id: str) -> None:
|
||||
"""Delete a test tool (cleanup helper)."""
|
||||
if not self.is_authenticated:
|
||||
return
|
||||
try:
|
||||
self.post("/api/delete_tool", json={"id": tool_id}, timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Create Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_create_tool(self) -> bool:
|
||||
"""Test creating a tool instance from available tools."""
|
||||
test_name = "Create tool"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# 'name' must be an available tool type (e.g., "duckduckgo", "cryptoprice")
|
||||
# Note: status must be a boolean (False = draft, True = active)
|
||||
payload = {
|
||||
"name": "cryptoprice", # A tool that needs no config
|
||||
"displayName": f"Test CryptoPrice {int(time.time())}",
|
||||
"description": "Integration test created tool",
|
||||
"config": {},
|
||||
"status": False, # Boolean: False = draft
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_tool", json=payload, timeout=10)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
self.print_error(f"Expected 200/201, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:200]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
tool_id = result.get("id")
|
||||
|
||||
if not tool_id:
|
||||
self.print_error("No tool ID returned")
|
||||
self.record_result(test_name, False, "No tool ID")
|
||||
return False
|
||||
|
||||
self.print_success(f"Created tool: {tool_id}")
|
||||
self.print_info(f"Name: {payload['name']}")
|
||||
self.record_result(test_name, True, f"ID: {tool_id}")
|
||||
|
||||
# Cleanup
|
||||
self.cleanup_test_tool(tool_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_create_tool_with_config(self) -> bool:
|
||||
"""Test creating a tool that requires configuration."""
|
||||
test_name = "Create tool with config"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Use api_tool which has flexible config requirements
|
||||
# Note: status must be a boolean (False = draft, True = active)
|
||||
payload = {
|
||||
"name": "api_tool",
|
||||
"displayName": f"Test API Tool {int(time.time())}",
|
||||
"description": "Tool with custom config",
|
||||
"config": {"base_url": "https://api.example.com"},
|
||||
"status": False, # Boolean: False = draft
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_tool", json=payload, timeout=10)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
self.print_error(f"Expected 200/201, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
tool_id = result.get("id")
|
||||
|
||||
if not tool_id:
|
||||
self.print_error("No tool ID returned")
|
||||
self.record_result(test_name, False, "No tool ID")
|
||||
return False
|
||||
|
||||
self.print_success(f"Created tool with actions: {tool_id}")
|
||||
self.record_result(test_name, True, f"ID: {tool_id}")
|
||||
|
||||
# Cleanup
|
||||
self.cleanup_test_tool(tool_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Read Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_get_tools(self) -> bool:
|
||||
"""Test listing all tools."""
|
||||
test_name = "List tools"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
try:
|
||||
response = self.get("/api/get_tools", timeout=10)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
# Handle both list and object responses
|
||||
if isinstance(result, list):
|
||||
self.print_success(f"Retrieved {len(result)} tools")
|
||||
if result:
|
||||
self.print_info(f"First: {result[0].get('name', 'N/A')}")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
elif isinstance(result, dict):
|
||||
# May return object with tools array
|
||||
tools = result.get("tools", result.get("data", []))
|
||||
if isinstance(tools, list):
|
||||
self.print_success(f"Retrieved {len(tools)} tools")
|
||||
else:
|
||||
self.print_success("Retrieved tools data")
|
||||
self.record_result(test_name, True, "Tools retrieved")
|
||||
else:
|
||||
self.print_warning(f"Unexpected response type: {type(result)}")
|
||||
self.record_result(test_name, True, "Response received")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_get_available_tools(self) -> bool:
|
||||
"""Test listing available tool types."""
|
||||
test_name = "List available tools"
|
||||
self.print_header(test_name)
|
||||
|
||||
try:
|
||||
response = self.get("/api/available_tools", timeout=10)
|
||||
|
||||
if not self.assert_status(response, 200, test_name):
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
|
||||
# Handle both list and object responses
|
||||
if isinstance(result, list):
|
||||
self.print_success(f"Retrieved {len(result)} available tool types")
|
||||
if result:
|
||||
first = result[0]
|
||||
name = first.get('name', first) if isinstance(first, dict) else first
|
||||
self.print_info(f"First: {name}")
|
||||
self.record_result(test_name, True, f"Count: {len(result)}")
|
||||
elif isinstance(result, dict):
|
||||
# May return object with tools array
|
||||
tools = result.get("tools", result.get("available", result.get("data", [])))
|
||||
if isinstance(tools, list):
|
||||
self.print_success(f"Retrieved {len(tools)} available tools")
|
||||
else:
|
||||
self.print_success("Retrieved available tools data")
|
||||
self.record_result(test_name, True, "Tools retrieved")
|
||||
else:
|
||||
self.print_warning(f"Unexpected response type: {type(result)}")
|
||||
self.record_result(test_name, True, "Response received")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Update Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_update_tool(self) -> bool:
|
||||
"""Test updating a tool."""
|
||||
test_name = "Update tool"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
tool_id = self.get_or_create_test_tool()
|
||||
if not tool_id:
|
||||
self.print_warning("Could not create test tool")
|
||||
self.record_result(test_name, True, "Skipped (no tool)")
|
||||
return True
|
||||
|
||||
new_description = f"Updated at {int(time.time())}"
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/update_tool",
|
||||
json={"id": tool_id, "description": new_description},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success("Tool updated successfully")
|
||||
self.record_result(test_name, True, "Tool updated")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Update failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_update_tool_actions(self) -> bool:
|
||||
"""Test updating tool actions."""
|
||||
test_name = "Update tool actions"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
tool_id = self.get_or_create_test_tool()
|
||||
if not tool_id:
|
||||
self.print_warning("Could not create test tool")
|
||||
self.record_result(test_name, True, "Skipped (no tool)")
|
||||
return True
|
||||
|
||||
new_actions = [
|
||||
{
|
||||
"name": "new_action",
|
||||
"description": "New action added",
|
||||
"parameters": {},
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/update_tool_actions",
|
||||
json={"id": tool_id, "actions": new_actions},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success("Tool actions updated")
|
||||
self.record_result(test_name, True, "Actions updated")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Update failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_update_tool_config(self) -> bool:
|
||||
"""Test updating tool configuration."""
|
||||
test_name = "Update tool config"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
tool_id = self.get_or_create_test_tool()
|
||||
if not tool_id:
|
||||
self.print_warning("Could not create test tool")
|
||||
self.record_result(test_name, True, "Skipped (no tool)")
|
||||
return True
|
||||
|
||||
new_config = {"api_key": "updated_key", "timeout": 30}
|
||||
|
||||
try:
|
||||
response = self.post(
|
||||
"/api/update_tool_config",
|
||||
json={"id": tool_id, "config": new_config},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success("Tool config updated")
|
||||
self.record_result(test_name, True, "Config updated")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Update failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_update_tool_status(self) -> bool:
|
||||
"""Test updating tool status."""
|
||||
test_name = "Update tool status"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
tool_id = self.get_or_create_test_tool()
|
||||
if not tool_id:
|
||||
self.print_warning("Could not create test tool")
|
||||
self.record_result(test_name, True, "Skipped (no tool)")
|
||||
return True
|
||||
|
||||
try:
|
||||
# Status is a boolean in UpdateToolStatusModel
|
||||
response = self.post(
|
||||
"/api/update_tool_status",
|
||||
json={"id": tool_id, "status": True},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
self.print_success("Tool status updated to active")
|
||||
self.record_result(test_name, True, "Status updated")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Update failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Delete Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_delete_tool(self) -> bool:
|
||||
"""Test deleting a tool."""
|
||||
test_name = "Delete tool"
|
||||
self.print_header(test_name)
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
# Create a tool specifically for deletion - must use available tool name
|
||||
# Note: status must be a boolean (False = draft, True = active)
|
||||
payload = {
|
||||
"name": "duckduckgo",
|
||||
"displayName": f"Tool to Delete {int(time.time())}",
|
||||
"description": "Will be deleted",
|
||||
"config": {},
|
||||
"status": False, # Boolean: False = draft
|
||||
}
|
||||
|
||||
try:
|
||||
create_response = self.post("/api/create_tool", json=payload, timeout=10)
|
||||
if create_response.status_code not in [200, 201]:
|
||||
self.print_warning("Could not create tool for deletion")
|
||||
self.record_result(test_name, True, "Skipped (create failed)")
|
||||
return True
|
||||
|
||||
tool_id = create_response.json().get("id")
|
||||
|
||||
# Delete the tool (DeleteToolModel requires 'id')
|
||||
response = self.post("/api/delete_tool", json={"id": tool_id}, timeout=10)
|
||||
|
||||
if response.status_code in [200, 204]:
|
||||
self.print_success(f"Deleted tool: {tool_id}")
|
||||
self.record_result(test_name, True, "Tool deleted")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Delete failed: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Exception: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Runner
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all tools tests."""
|
||||
self.print_header("DocsGPT Tools Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Auth: {self.token_source}")
|
||||
|
||||
# Create tests
|
||||
self.test_create_tool()
|
||||
self.test_create_tool_with_config()
|
||||
|
||||
# Read tests
|
||||
self.test_get_tools()
|
||||
self.test_get_available_tools()
|
||||
|
||||
# Update tests
|
||||
self.test_update_tool()
|
||||
self.test_update_tool_actions()
|
||||
self.test_update_tool_config()
|
||||
self.test_update_tool_status()
|
||||
|
||||
# Delete tests
|
||||
self.test_delete_tool()
|
||||
|
||||
# Cleanup
|
||||
if hasattr(self, "_test_tool_id"):
|
||||
self.cleanup_test_tool(self._test_tool_id)
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(ToolsTests, "DocsGPT Tools Integration Tests")
|
||||
exit_code = 0 if client.run_all() else 1
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -650,7 +650,7 @@ DocsGPT provides:
|
||||
return False
|
||||
|
||||
result = response.json()
|
||||
answer = result.get('answer', '')
|
||||
answer = result.get('answer') or ''
|
||||
self.print_success(f"Answer received: {answer[:100]}...")
|
||||
self.test_results.append((test_name, True, "Success"))
|
||||
return True
|
||||
@@ -877,7 +877,6 @@ DocsGPT provides:
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"model_id": "gemini-2.5-pro",
|
||||
}
|
||||
|
||||
# Use agent if available, otherwise isNoneDoc
|
||||
@@ -902,7 +901,7 @@ DocsGPT provides:
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
current_conv_id = result.get('conversation_id', current_conv_id)
|
||||
answer_preview = result.get('answer', '')[:80]
|
||||
answer_preview = (result.get('answer') or '')[:80]
|
||||
self.print_success(f"Request {i+1}/10 completed (conv_id: {current_conv_id})")
|
||||
self.print_info(f" Answer preview: {answer_preview}...")
|
||||
else:
|
||||
@@ -958,7 +957,6 @@ DocsGPT provides:
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"model_id": "gemini-2.5-pro",
|
||||
}
|
||||
|
||||
if conversation_id:
|
||||
@@ -986,7 +984,6 @@ DocsGPT provides:
|
||||
"question": "Please remember this critical information: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily. Premium users have 10,000 req/hour limit.",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"model_id": "gemini-2.5-pro",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
@@ -1016,7 +1013,6 @@ DocsGPT provides:
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"model_id": "gemini-2.5-pro",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
@@ -1042,7 +1038,6 @@ DocsGPT provides:
|
||||
"question": "What was the database password environment variable I mentioned earlier?",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"model_id": "gemini-2.5-pro",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
@@ -1050,7 +1045,7 @@ DocsGPT provides:
|
||||
response = requests.post(endpoint, json=recall_payload, headers=self.headers, timeout=60)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
answer = result.get('answer', '').lower()
|
||||
answer = (result.get('answer') or '').lower()
|
||||
|
||||
# Check if the critical info was preserved
|
||||
if 'db_password_prod' in answer or 'database password' in answer:
|
||||
@@ -1191,7 +1186,7 @@ DocsGPT provides:
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
answer = result.get('answer', '')
|
||||
answer = result.get('answer') or ''
|
||||
self.print_success(f"Answer received: {answer[:100]}...")
|
||||
|
||||
if any(word in answer.lower() for word in ['install', 'docker', 'setup']):
|
||||
|
||||
Reference in New Issue
Block a user