mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-28 07:11:21 +00:00
End 2 end tests (#2266)
* All endpoints covered test_integration.py kept for backwards compatability. tests/integration/run_all.py proposed as alternative to cover all endpoints. * Linter fixes
This commit is contained in:
485
tests/integration/base.py
Normal file
485
tests/integration/base.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Base classes and utilities for DocsGPT integration tests.
|
||||
|
||||
This module provides:
|
||||
- Colors: ANSI color codes for terminal output
|
||||
- DocsGPTTestBase: Base class with HTTP helpers and output utilities
|
||||
- generate_jwt_token: JWT token generation for authentication
|
||||
- create_client_from_args: Factory function to create client from CLI args
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json as json_module
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, Optional, Type, TypeVar
|
||||
|
||||
import requests
|
||||
|
||||
T = TypeVar("T", bound="DocsGPTTestBase")
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
HEADER = "\033[95m"
|
||||
OKBLUE = "\033[94m"
|
||||
OKCYAN = "\033[96m"
|
||||
OKGREEN = "\033[92m"
|
||||
WARNING = "\033[93m"
|
||||
FAIL = "\033[91m"
|
||||
ENDC = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
|
||||
|
||||
def generate_jwt_token() -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Generate a JWT token using local secret or environment variable.
|
||||
|
||||
Returns:
|
||||
Tuple of (token, error_message). Token is None on failure.
|
||||
"""
|
||||
secret = os.getenv("JWT_SECRET_KEY")
|
||||
key_file = Path(".jwt_secret_key")
|
||||
|
||||
if not secret:
|
||||
try:
|
||||
secret = key_file.read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return None, f"Set JWT_SECRET_KEY or create {key_file} by running the backend once."
|
||||
except OSError as exc:
|
||||
return None, f"Could not read {key_file}: {exc}"
|
||||
|
||||
if not secret:
|
||||
return None, "JWT secret key is empty."
|
||||
|
||||
try:
|
||||
from jose import jwt
|
||||
except ImportError:
|
||||
return None, "python-jose is not installed (pip install 'python-jose' to auto-generate tokens)."
|
||||
|
||||
try:
|
||||
payload = {"sub": "test_integration_user"}
|
||||
return jwt.encode(payload, secret, algorithm="HS256"), None
|
||||
except Exception as exc:
|
||||
return None, f"Failed to generate JWT token: {exc}"
|
||||
|
||||
|
||||
class DocsGPTTestBase:
|
||||
"""
|
||||
Base class for DocsGPT integration tests.
|
||||
|
||||
Provides HTTP helpers, SSE streaming, output formatting, and result tracking.
|
||||
|
||||
Usage:
|
||||
client = DocsGPTTestBase("http://localhost:7091", token="...")
|
||||
response = client.post("/api/answer", json={"question": "test"})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
token: Optional[str] = None,
|
||||
token_source: str = "provided",
|
||||
):
|
||||
"""
|
||||
Initialize test client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of DocsGPT instance (e.g., "http://localhost:7091")
|
||||
token: Optional JWT authentication token
|
||||
token_source: Description of token source for logging
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.token_source = token_source
|
||||
self.headers: dict[str, str] = {}
|
||||
if token:
|
||||
self.headers["Authorization"] = f"Bearer {token}"
|
||||
self.test_results: list[tuple[str, bool, str]] = []
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# HTTP Helper Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a GET request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/sources")
|
||||
params: Optional query parameters
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.get
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.get(
|
||||
url,
|
||||
params=params,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def post(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
files: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a POST request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/answer")
|
||||
json: Optional JSON body
|
||||
data: Optional form data
|
||||
files: Optional files for multipart upload
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.post
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.post(
|
||||
url,
|
||||
json=json,
|
||||
data=data,
|
||||
files=files,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def put(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a PUT request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/update_agent/123")
|
||||
json: Optional JSON body
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.put
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.put(
|
||||
url,
|
||||
json=json,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a DELETE request.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/api/delete_agent")
|
||||
json: Optional JSON body
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.delete
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
return requests.delete(
|
||||
url,
|
||||
json=json,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def post_stream(
|
||||
self,
|
||||
path: str,
|
||||
json: Optional[dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
Make a streaming POST request and yield SSE events.
|
||||
|
||||
Args:
|
||||
path: API path (e.g., "/stream")
|
||||
json: Optional JSON body
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.post
|
||||
|
||||
Yields:
|
||||
Parsed JSON data from each SSE event
|
||||
|
||||
Example:
|
||||
for event in client.post_stream("/stream", json={"question": "test"}):
|
||||
if event.get("type") == "answer":
|
||||
print(event.get("message"))
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
response = requests.post(
|
||||
url,
|
||||
json=json,
|
||||
headers={**self.headers, **kwargs.pop("headers", {})},
|
||||
stream=True,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Store response for status code checking
|
||||
self._last_stream_response = response
|
||||
|
||||
if response.status_code != 200:
|
||||
# Yield error event for non-200 responses
|
||||
yield {"type": "error", "status_code": response.status_code, "text": response.text[:500]}
|
||||
return
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:] # Remove 'data: ' prefix
|
||||
try:
|
||||
data = json_module.loads(data_str)
|
||||
yield data
|
||||
if data.get("type") == "end":
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Output Helper Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def print_header(self, message: str) -> None:
|
||||
"""Print a colored header."""
|
||||
print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{message}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}\n")
|
||||
|
||||
def print_success(self, message: str) -> None:
|
||||
"""Print a success message."""
|
||||
print(f"{Colors.OKGREEN}[PASS] {message}{Colors.ENDC}")
|
||||
|
||||
def print_error(self, message: str) -> None:
|
||||
"""Print an error message."""
|
||||
print(f"{Colors.FAIL}[FAIL] {message}{Colors.ENDC}")
|
||||
|
||||
def print_info(self, message: str) -> None:
|
||||
"""Print an info message."""
|
||||
print(f"{Colors.OKCYAN}[INFO] {message}{Colors.ENDC}")
|
||||
|
||||
def print_warning(self, message: str) -> None:
|
||||
"""Print a warning message."""
|
||||
print(f"{Colors.WARNING}[WARN] {message}{Colors.ENDC}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Result Tracking Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def record_result(self, test_name: str, success: bool, message: str) -> None:
|
||||
"""
|
||||
Record a test result.
|
||||
|
||||
Args:
|
||||
test_name: Name of the test
|
||||
success: Whether the test passed
|
||||
message: Result message or error details
|
||||
"""
|
||||
self.test_results.append((test_name, success, message))
|
||||
|
||||
def print_summary(self) -> bool:
|
||||
"""
|
||||
Print test results summary.
|
||||
|
||||
Returns:
|
||||
True if all tests passed, False otherwise
|
||||
"""
|
||||
self.print_header("Test Results Summary")
|
||||
|
||||
passed = sum(1 for _, success, _ in self.test_results if success)
|
||||
failed = len(self.test_results) - passed
|
||||
|
||||
print(f"\n{Colors.BOLD}Total Tests: {len(self.test_results)}{Colors.ENDC}")
|
||||
print(f"{Colors.OKGREEN}Passed: {passed}{Colors.ENDC}")
|
||||
print(f"{Colors.FAIL}Failed: {failed}{Colors.ENDC}\n")
|
||||
|
||||
print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
|
||||
for test_name, success, message in self.test_results:
|
||||
status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if success else f"{Colors.FAIL}FAIL{Colors.ENDC}"
|
||||
print(f" {status} - {test_name}: {message}")
|
||||
|
||||
print()
|
||||
return failed == 0
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Assertion Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def assert_status(
|
||||
self,
|
||||
response: requests.Response,
|
||||
expected: int,
|
||||
test_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Assert response status code and record result.
|
||||
|
||||
Args:
|
||||
response: Response object to check
|
||||
expected: Expected status code
|
||||
test_name: Name of the test for recording
|
||||
|
||||
Returns:
|
||||
True if status matches, False otherwise
|
||||
"""
|
||||
if response.status_code == expected:
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected {expected}, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:500]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
def assert_json_key(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
key: str,
|
||||
test_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Assert JSON response contains a key.
|
||||
|
||||
Args:
|
||||
data: JSON response data
|
||||
key: Expected key
|
||||
test_name: Name of the test for recording
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise
|
||||
"""
|
||||
if key in data:
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Missing key '{key}' in response")
|
||||
self.record_result(test_name, False, f"Missing key: {key}")
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Convenience Properties
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if client has authentication token."""
|
||||
return self.token is not None
|
||||
|
||||
def require_auth(self, test_name: str) -> bool:
|
||||
"""
|
||||
Check authentication and record skip if not authenticated.
|
||||
|
||||
Args:
|
||||
test_name: Name of the test
|
||||
|
||||
Returns:
|
||||
True if authenticated, False otherwise (test skipped)
|
||||
"""
|
||||
if not self.is_authenticated:
|
||||
self.print_warning("No authentication token provided")
|
||||
self.print_info("Skipping test (auth required)")
|
||||
self.record_result(test_name, True, "Skipped (auth required)")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Factory Function
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_client_from_args(
|
||||
client_class: Type[T],
|
||||
description: str = "DocsGPT Integration Tests",
|
||||
) -> T:
|
||||
"""
|
||||
Create a test client from command-line arguments.
|
||||
|
||||
Parses --base-url and --token arguments, and handles JWT token generation.
|
||||
|
||||
Args:
|
||||
client_class: The test class to instantiate (must inherit from DocsGPTTestBase)
|
||||
description: Description for the argument parser
|
||||
|
||||
Returns:
|
||||
An instance of the provided client_class
|
||||
|
||||
Example:
|
||||
class ChatTests(DocsGPTTestBase):
|
||||
def run_all(self):
|
||||
...
|
||||
|
||||
if __name__ == "__main__":
|
||||
client = create_client_from_args(ChatTests)
|
||||
sys.exit(0 if client.run_all() else 1)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
default=os.getenv("DOCSGPT_BASE_URL", "http://localhost:7091"),
|
||||
help="Base URL of DocsGPT instance (default: http://localhost:7091)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
default=os.getenv("JWT_TOKEN"),
|
||||
help="JWT authentication token (auto-generated from local secret when available)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine token and source
|
||||
token = args.token
|
||||
token_source = "provided via --token" if token else "none"
|
||||
|
||||
if not token:
|
||||
token, token_error = generate_jwt_token()
|
||||
if token:
|
||||
token_source = "auto-generated from local secret"
|
||||
print(f"{Colors.OKCYAN}[INFO] Using auto-generated JWT token{Colors.ENDC}")
|
||||
elif token_error:
|
||||
print(f"{Colors.WARNING}[WARN] Could not auto-generate token: {token_error}{Colors.ENDC}")
|
||||
print(f"{Colors.WARNING}[WARN] Tests requiring auth will be skipped{Colors.ENDC}")
|
||||
|
||||
return client_class(args.base_url, token=token, token_source=token_source)
|
||||
Reference in New Issue
Block a user