Files
DocsGPT/tests/integration/base.py
2026-01-16 13:10:27 +03:00

486 lines
15 KiB
Python

"""
Base classes and utilities for DocsGPT integration tests.
This module provides:
- Colors: ANSI color codes for terminal output
- DocsGPTTestBase: Base class with HTTP helpers and output utilities
- generate_jwt_token: JWT token generation for authentication
- create_client_from_args: Factory function to create client from CLI args
"""
import argparse
import json as json_module
import os
from pathlib import Path
from typing import Any, Iterator, Optional, Type, TypeVar
import requests
# Type variable bound to DocsGPTTestBase. create_client_from_args() takes
# Type[T] and returns T, so callers get back the same subclass they passed in.
# The bound is a string (forward reference) because the class is defined below.
T = TypeVar("T", bound="DocsGPTTestBase")
class Colors:
    """
    ANSI escape sequences for colorizing terminal output.

    Wrap text as ``f"{Colors.OKGREEN}text{Colors.ENDC}"``; ENDC resets every
    attribute (color and bold) back to the terminal default.
    """

    # Bright foreground colors (SGR codes 91-96).
    HEADER = "\033[95m"   # bright magenta
    OKBLUE = "\033[94m"   # bright blue
    OKCYAN = "\033[96m"   # bright cyan
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"     # bright red

    # Text attributes.
    ENDC = "\033[0m"      # reset all attributes
    BOLD = "\033[1m"      # bold / increased intensity
def generate_jwt_token() -> tuple[Optional[str], Optional[str]]:
    """
    Generate a JWT for the integration-test user using a local secret.

    The signing secret comes from the ``JWT_SECRET_KEY`` environment variable.
    When that is unset or empty, it is read from a ``.jwt_secret_key`` file
    resolved relative to the current working directory (surrounding whitespace
    is stripped from the file contents).

    Returns:
        Tuple of (token, error_message). On success the error is None; on
        failure the token is None and the message explains why.
    """
    secret = os.getenv("JWT_SECRET_KEY")
    key_file = Path(".jwt_secret_key")
    if not secret:
        try:
            # Explicit encoding: the default is locale-dependent.
            secret = key_file.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return None, f"Set JWT_SECRET_KEY or create {key_file} by running the backend once."
        except UnicodeDecodeError as exc:
            # A ValueError, not an OSError -- would otherwise escape and crash the runner.
            return None, f"Could not decode {key_file} as UTF-8: {exc}"
        except OSError as exc:
            return None, f"Could not read {key_file}: {exc}"
    if not secret:
        return None, "JWT secret key is empty."
    try:
        # Optional dependency, imported lazily so the suite can still start
        # (with auth-only tests skipped) when it is not installed.
        from jose import jwt
    except ImportError:
        return None, "python-jose is not installed (pip install 'python-jose' to auto-generate tokens)."
    try:
        payload = {"sub": "test_integration_user"}
        return jwt.encode(payload, secret, algorithm="HS256"), None
    except Exception as exc:  # report any signing failure instead of crashing
        return None, f"Failed to generate JWT token: {exc}"
class DocsGPTTestBase:
    """
    Base class for DocsGPT integration tests.

    Provides HTTP helpers, SSE streaming, output formatting, and result tracking.

    Usage:
        client = DocsGPTTestBase("http://localhost:7091", token="...")
        response = client.post("/api/answer", json={"question": "test"})
    """

    def __init__(
        self,
        base_url: str,
        token: Optional[str] = None,
        token_source: str = "provided",
    ):
        """
        Initialize test client.

        Args:
            base_url: Base URL of DocsGPT instance (e.g., "http://localhost:7091").
                A trailing "/" is stripped so request paths can start with "/".
            token: Optional JWT authentication token. A non-empty token is sent
                as an ``Authorization: Bearer <token>`` header on every request.
            token_source: Description of token source for logging
        """
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.token_source = token_source
        self.headers: dict[str, str] = {}
        if token:
            self.headers["Authorization"] = f"Bearer {token}"
        # (test_name, passed, message) tuples in the order they were recorded.
        self.test_results: list[tuple[str, bool, str]] = []
        # Response from the most recent post_stream() call; None until one runs.
        self._last_stream_response: Optional[requests.Response] = None

    # -------------------------------------------------------------------------
    # HTTP Helper Methods
    # -------------------------------------------------------------------------

    def _url(self, path: str) -> str:
        """Append ``path`` (expected to start with "/") to the base URL."""
        return f"{self.base_url}{path}"

    def _merge_headers(self, kwargs: dict[str, Any]) -> dict[str, str]:
        """
        Pop any caller-supplied ``headers`` from ``kwargs`` and layer them over
        the client's default headers (caller values win). ``headers=None`` is
        treated as "no extra headers" rather than raising.
        """
        extra = kwargs.pop("headers", None) or {}
        return {**self.headers, **extra}

    def get(
        self,
        path: str,
        params: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a GET request.

        Args:
            path: API path (e.g., "/api/sources")
            params: Optional query parameters
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.get; a ``headers``
                entry is merged over the client's default headers.

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.get(
            self._url(path),
            params=params,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def post(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        data: Optional[dict[str, Any]] = None,
        files: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a POST request.

        Args:
            path: API path (e.g., "/api/answer")
            json: Optional JSON body
            data: Optional form data
            files: Optional files for multipart upload
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.post; a ``headers``
                entry is merged over the client's default headers.

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.post(
            self._url(path),
            json=json,
            data=data,
            files=files,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def put(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a PUT request.

        Args:
            path: API path (e.g., "/api/update_agent/123")
            json: Optional JSON body
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.put; a ``headers``
                entry is merged over the client's default headers.

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.put(
            self._url(path),
            json=json,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def delete(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a DELETE request.

        Args:
            path: API path (e.g., "/api/delete_agent")
            json: Optional JSON body
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.delete; a
                ``headers`` entry is merged over the client's default headers.

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.delete(
            self._url(path),
            json=json,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def post_stream(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs: Any,
    ) -> Iterator[dict[str, Any]]:
        """
        Make a streaming POST request and yield SSE events.

        The request is sent when iteration starts (this is a generator). A
        non-200 response yields a single ``{"type": "error", "status_code",
        "text"}`` event (body truncated to 500 chars). Otherwise each
        ``data:`` line holding a JSON object is yielded; iteration stops after
        an event whose ``type`` is ``"end"``. The underlying connection is
        closed however iteration finishes.

        Args:
            path: API path (e.g., "/stream")
            json: Optional JSON body
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.post; a ``headers``
                entry is merged over the client's default headers.

        Yields:
            Parsed JSON object from each SSE ``data:`` line

        Example:
            for event in client.post_stream("/stream", json={"question": "test"}):
                if event.get("type") == "answer":
                    print(event.get("message"))
        """
        headers = self._merge_headers(kwargs)
        response = requests.post(
            self._url(path),
            json=json,
            headers=headers,
            stream=True,
            timeout=timeout,
            **kwargs,
        )
        # Store response for status code checking
        self._last_stream_response = response
        # A streamed response keeps its connection open until closed. The
        # ``with`` releases it on every exit path: stream exhausted, "end"
        # event, an exception, or the caller abandoning the generator early
        # (generator close() raises GeneratorExit at the pending ``yield``).
        with response:
            if response.status_code != 200:
                yield {"type": "error", "status_code": response.status_code, "text": response.text[:500]}
                return
            for raw_line in response.iter_lines():
                event = self._parse_sse_line(raw_line)
                if event is None:
                    continue
                yield event
                if event.get("type") == "end":
                    break

    @staticmethod
    def _parse_sse_line(raw_line: bytes) -> Optional[dict[str, Any]]:
        """
        Parse one SSE line into a JSON object, or return None to skip it.

        Only ``data:`` fields are considered; one optional space after the
        colon is dropped, as the SSE format allows. Blank lines, other fields,
        payloads that are not valid JSON, and JSON values that are not objects
        are skipped. Each line is parsed on its own: multi-line ``data:``
        events are not reassembled.
        """
        if not raw_line:
            return None
        # errors="replace": a stray invalid byte should not abort the stream.
        line = raw_line.decode("utf-8", errors="replace")
        if not line.startswith("data:"):
            return None
        payload = line[len("data:"):]
        if payload.startswith(" "):
            payload = payload[1:]
        try:
            data = json_module.loads(payload)
        except json_module.JSONDecodeError:
            return None
        # Callers treat events as dicts (event.get(...)); skip any other JSON value.
        return data if isinstance(data, dict) else None

    # -------------------------------------------------------------------------
    # Output Helper Methods
    # -------------------------------------------------------------------------

    def print_header(self, message: str) -> None:
        """Print a colored header."""
        print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}")
        print(f"{Colors.HEADER}{Colors.BOLD}{message}{Colors.ENDC}")
        print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}\n")

    def print_success(self, message: str) -> None:
        """Print a success message."""
        print(f"{Colors.OKGREEN}[PASS] {message}{Colors.ENDC}")

    def print_error(self, message: str) -> None:
        """Print an error message."""
        print(f"{Colors.FAIL}[FAIL] {message}{Colors.ENDC}")

    def print_info(self, message: str) -> None:
        """Print an info message."""
        print(f"{Colors.OKCYAN}[INFO] {message}{Colors.ENDC}")

    def print_warning(self, message: str) -> None:
        """Print a warning message."""
        print(f"{Colors.WARNING}[WARN] {message}{Colors.ENDC}")

    # -------------------------------------------------------------------------
    # Result Tracking Methods
    # -------------------------------------------------------------------------

    def record_result(self, test_name: str, success: bool, message: str) -> None:
        """
        Record a test result.

        Args:
            test_name: Name of the test
            success: Whether the test passed
            message: Result message or error details
        """
        self.test_results.append((test_name, success, message))

    def print_summary(self) -> bool:
        """
        Print test results summary.

        Returns:
            True if all tests passed (or none were recorded), False otherwise
        """
        self.print_header("Test Results Summary")
        passed = sum(1 for _, success, _ in self.test_results if success)
        failed = len(self.test_results) - passed
        print(f"\n{Colors.BOLD}Total Tests: {len(self.test_results)}{Colors.ENDC}")
        print(f"{Colors.OKGREEN}Passed: {passed}{Colors.ENDC}")
        print(f"{Colors.FAIL}Failed: {failed}{Colors.ENDC}\n")
        print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
        for test_name, success, message in self.test_results:
            status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if success else f"{Colors.FAIL}FAIL{Colors.ENDC}"
            print(f"  {status} - {test_name}: {message}")
        print()
        return failed == 0

    # -------------------------------------------------------------------------
    # Assertion Helpers
    # -------------------------------------------------------------------------

    def assert_status(
        self,
        response: requests.Response,
        expected: int,
        test_name: str,
    ) -> bool:
        """
        Assert response status code; record a failure if it does not match.

        A match records nothing -- the caller records the success itself.

        Args:
            response: Response object to check
            expected: Expected status code
            test_name: Name of the test for recording

        Returns:
            True if status matches, False otherwise
        """
        if response.status_code == expected:
            return True
        self.print_error(f"Expected {expected}, got {response.status_code}")
        self.print_error(f"Response: {response.text[:500]}")
        self.record_result(test_name, False, f"Status {response.status_code}")
        return False

    def assert_json_key(
        self,
        data: dict[str, Any],
        key: str,
        test_name: str,
    ) -> bool:
        """
        Assert JSON response is an object containing ``key``; record a failure otherwise.

        A non-dict ``data`` (e.g. None or a JSON list) counts as a missing key
        instead of raising.

        Args:
            data: JSON response data
            key: Expected key
            test_name: Name of the test for recording

        Returns:
            True if key exists, False otherwise
        """
        if isinstance(data, dict) and key in data:
            return True
        if not isinstance(data, dict):
            self.print_error(f"Expected a JSON object, got {type(data).__name__}")
        self.print_error(f"Missing key '{key}' in response")
        self.record_result(test_name, False, f"Missing key: {key}")
        return False

    # -------------------------------------------------------------------------
    # Convenience Properties
    # -------------------------------------------------------------------------

    @property
    def is_authenticated(self) -> bool:
        """
        Check if client has an authentication token.

        Uses the same truthiness test as __init__, so an empty token (which
        sends no Authorization header) is not treated as authenticated.
        """
        return bool(self.token)

    def require_auth(self, test_name: str) -> bool:
        """
        Check authentication and record skip if not authenticated.

        A skip is recorded as a passing result so it does not fail the run.

        Args:
            test_name: Name of the test

        Returns:
            True if authenticated, False otherwise (test skipped)
        """
        if not self.is_authenticated:
            self.print_warning("No authentication token provided")
            self.print_info("Skipping test (auth required)")
            self.record_result(test_name, True, "Skipped (auth required)")
            return False
        return True
# -----------------------------------------------------------------------------
# Factory Function
# -----------------------------------------------------------------------------
def create_client_from_args(
    client_class: Type[T],
    description: str = "DocsGPT Integration Tests",
    argv: Optional[list[str]] = None,
) -> T:
    """
    Create a test client from command-line arguments.

    Parses --base-url and --token arguments, and handles JWT token generation.
    Defaults come from the DOCSGPT_BASE_URL and JWT_TOKEN environment
    variables. When no (non-empty) token is supplied, one is generated from the
    local secret via generate_jwt_token(); if that fails a warning is printed
    and the client is created without a token, so auth-only tests skip.

    Args:
        client_class: The test class to instantiate (must inherit from DocsGPTTestBase)
        description: Description for the argument parser
        argv: Argument list to parse instead of ``sys.argv[1:]``; None (the
            default) keeps the usual command-line behavior.

    Returns:
        An instance of the provided client_class

    Example:
        class ChatTests(DocsGPTTestBase):
            def run_all(self):
                ...

        if __name__ == "__main__":
            client = create_client_from_args(ChatTests)
            sys.exit(0 if client.run_all() else 1)
    """
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--base-url",
        default=os.getenv("DOCSGPT_BASE_URL", "http://localhost:7091"),
        help="Base URL of DocsGPT instance (default: $DOCSGPT_BASE_URL or http://localhost:7091)",
    )
    parser.add_argument(
        "--token",
        default=os.getenv("JWT_TOKEN"),
        help=(
            "JWT authentication token (default: $JWT_TOKEN; otherwise "
            "auto-generated from local secret when available)"
        ),
    )
    args = parser.parse_args(argv)

    # Determine token and source
    token = args.token
    token_source = "provided via --token" if token else "none"
    if not token:
        token, token_error = generate_jwt_token()
        if token:
            token_source = "auto-generated from local secret"
            print(f"{Colors.OKCYAN}[INFO] Using auto-generated JWT token{Colors.ENDC}")
        elif token_error:
            print(f"{Colors.WARNING}[WARN] Could not auto-generate token: {token_error}{Colors.ENDC}")
            print(f"{Colors.WARNING}[WARN] Tests requiring auth will be skipped{Colors.ENDC}")

    return client_class(args.base_url, token=token, token_source=token_source)