# Mirror of https://github.com/arc53/DocsGPT.git
# Synced 2026-01-20 14:00:55 +00:00 (486 lines, 15 KiB, Python)
"""
|
|
Base classes and utilities for DocsGPT integration tests.
|
|
|
|
This module provides:
|
|
- Colors: ANSI color codes for terminal output
|
|
- DocsGPTTestBase: Base class with HTTP helpers and output utilities
|
|
- generate_jwt_token: JWT token generation for authentication
|
|
- create_client_from_args: Factory function to create client from CLI args
|
|
"""
|
|
|
|
import argparse
|
|
import json as json_module
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Iterator, Optional, Type, TypeVar
|
|
|
|
import requests
|
|
|
|
# Type variable bound to DocsGPTTestBase so that create_client_from_args()
# is typed as returning the same subclass it was given in `client_class`.
T = TypeVar("T", bound="DocsGPTTestBase")
|
|
|
|
|
|
class Colors:
    """Namespace of ANSI escape sequences used to colorize terminal output.

    Wrap text as ``f"{Colors.OKGREEN}text{Colors.ENDC}"``; ``ENDC`` resets
    every attribute so colors do not bleed into following output.
    """

    # Foreground colors (bright variants).
    HEADER = "\033[95m"   # magenta
    OKBLUE = "\033[94m"   # blue
    OKCYAN = "\033[96m"   # cyan
    OKGREEN = "\033[92m"  # green
    WARNING = "\033[93m"  # yellow
    FAIL = "\033[91m"     # red

    # Text attributes.
    ENDC = "\033[0m"      # reset all attributes
    BOLD = "\033[1m"
|
|
|
|
|
|
def generate_jwt_token() -> tuple[Optional[str], Optional[str]]:
    """
    Generate a JWT token using local secret or environment variable.

    The signing secret is taken from the JWT_SECRET_KEY environment variable.
    If that is unset or empty, it is read (and whitespace-stripped) from
    ``.jwt_secret_key`` in the current working directory.

    Returns:
        Tuple of (token, error_message). On success the error is None; on
        failure the token is None and the message explains what went wrong.
        This function never raises for a missing or unreadable secret.
    """
    secret = os.getenv("JWT_SECRET_KEY")
    key_file = Path(".jwt_secret_key")

    if not secret:
        try:
            # Explicit encoding: don't depend on the platform/locale default.
            secret = key_file.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return None, f"Set JWT_SECRET_KEY or create {key_file} by running the backend once."
        except (OSError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError is a ValueError, not an OSError; without it a
            # corrupt key file would crash the caller instead of being reported.
            return None, f"Could not read {key_file}: {exc}"

    if not secret:
        return None, "JWT secret key is empty."

    # Imported lazily so the rest of the test suite works without python-jose.
    try:
        from jose import jwt
    except ImportError:
        return None, "python-jose is not installed (pip install 'python-jose' to auto-generate tokens)."

    try:
        payload = {"sub": "test_integration_user"}
        return jwt.encode(payload, secret, algorithm="HS256"), None
    except Exception as exc:
        return None, f"Failed to generate JWT token: {exc}"
|
|
|
|
|
|
class DocsGPTTestBase:
    """
    Base class for DocsGPT integration tests.

    Provides HTTP helpers, SSE streaming, output formatting, and result tracking.

    Usage:
        client = DocsGPTTestBase("http://localhost:7091", token="...")
        response = client.post("/api/answer", json={"question": "test"})
    """

    def __init__(
        self,
        base_url: str,
        token: Optional[str] = None,
        token_source: str = "provided",
    ):
        """
        Initialize test client.

        Args:
            base_url: Base URL of DocsGPT instance (e.g., "http://localhost:7091").
                Trailing slashes are stripped so API paths can start with "/".
            token: Optional JWT authentication token. An empty string is
                treated the same as no token.
            token_source: Description of token source for logging
        """
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.token_source = token_source
        self.headers: dict[str, str] = {}
        if token:
            self.headers["Authorization"] = f"Bearer {token}"
        self.test_results: list[tuple[str, bool, str]] = []
        # Most recent response opened by post_stream(); None until a stream starts.
        self._last_stream_response: Optional[requests.Response] = None

    # -------------------------------------------------------------------------
    # HTTP Helper Methods
    # -------------------------------------------------------------------------

    def _url(self, path: str) -> str:
        """Join the configured base URL with an API path."""
        return f"{self.base_url}{path}"

    def _merge_headers(self, kwargs: dict[str, Any]) -> dict[str, str]:
        """Pop caller-supplied headers out of kwargs and layer them over the defaults.

        Caller headers win over the client's default headers. ``headers=None``
        is accepted and treated as "no extra headers".
        """
        extra = kwargs.pop("headers", None) or {}
        return {**self.headers, **extra}

    def get(
        self,
        path: str,
        params: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a GET request.

        Args:
            path: API path (e.g., "/api/sources")
            params: Optional query parameters
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.get

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.get(
            self._url(path),
            params=params,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def post(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        data: Optional[dict[str, Any]] = None,
        files: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a POST request.

        Args:
            path: API path (e.g., "/api/answer")
            json: Optional JSON body
            data: Optional form data
            files: Optional files for multipart upload
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.post

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.post(
            self._url(path),
            json=json,
            data=data,
            files=files,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def put(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a PUT request.

        Args:
            path: API path (e.g., "/api/update_agent/123")
            json: Optional JSON body
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.put

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.put(
            self._url(path),
            json=json,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def delete(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs: Any,
    ) -> requests.Response:
        """
        Make a DELETE request.

        Args:
            path: API path (e.g., "/api/delete_agent")
            json: Optional JSON body
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.delete

        Returns:
            Response object
        """
        headers = self._merge_headers(kwargs)
        return requests.delete(
            self._url(path),
            json=json,
            headers=headers,
            timeout=timeout,
            **kwargs,
        )

    def post_stream(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        timeout: int = 60,
        **kwargs: Any,
    ) -> Iterator[dict[str, Any]]:
        """
        Make a streaming POST request and yield SSE events.

        This is a generator: the request is sent when iteration starts. The
        underlying response is always closed when the generator finishes, is
        exhausted, or is closed early, so the connection is released.

        Args:
            path: API path (e.g., "/stream")
            json: Optional JSON body
            timeout: Request timeout in seconds
            **kwargs: Additional arguments passed to requests.post

        Yields:
            Parsed JSON objects from each SSE ``data:`` line. Iteration stops
            after an event whose "type" is "end". Payloads that are not valid
            JSON objects are skipped. A non-200 response yields a single
            ``{"type": "error", "status_code": ..., "text": ...}`` event.

        Example:
            for event in client.post_stream("/stream", json={"question": "test"}):
                if event.get("type") == "answer":
                    print(event.get("message"))
        """
        headers = self._merge_headers(kwargs)
        response = requests.post(
            self._url(path),
            json=json,
            headers=headers,
            stream=True,
            timeout=timeout,
            **kwargs,
        )

        # Store response for status code checking
        self._last_stream_response = response

        try:
            if response.status_code != 200:
                # Yield error event for non-200 responses
                yield {"type": "error", "status_code": response.status_code, "text": response.text[:500]}
                return

            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode("utf-8")
                if not line_str.startswith("data:"):
                    continue
                # Per the SSE format the space after "data:" is optional.
                data_str = line_str[len("data:"):]
                if data_str.startswith(" "):
                    data_str = data_str[1:]
                try:
                    data = json_module.loads(data_str)
                except json_module.JSONDecodeError:
                    # Best-effort: ignore payloads that are not JSON.
                    continue
                if not isinstance(data, dict):
                    # Callers treat events as dicts; skip scalar/array payloads.
                    continue
                yield data
                if data.get("type") == "end":
                    break
        finally:
            # stream=True keeps the connection checked out until closed.
            response.close()

    # -------------------------------------------------------------------------
    # Output Helper Methods
    # -------------------------------------------------------------------------

    def print_header(self, message: str) -> None:
        """Print a colored header."""
        print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}")
        print(f"{Colors.HEADER}{Colors.BOLD}{message}{Colors.ENDC}")
        print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}\n")

    def print_success(self, message: str) -> None:
        """Print a success message."""
        print(f"{Colors.OKGREEN}[PASS] {message}{Colors.ENDC}")

    def print_error(self, message: str) -> None:
        """Print an error message."""
        print(f"{Colors.FAIL}[FAIL] {message}{Colors.ENDC}")

    def print_info(self, message: str) -> None:
        """Print an info message."""
        print(f"{Colors.OKCYAN}[INFO] {message}{Colors.ENDC}")

    def print_warning(self, message: str) -> None:
        """Print a warning message."""
        print(f"{Colors.WARNING}[WARN] {message}{Colors.ENDC}")

    # -------------------------------------------------------------------------
    # Result Tracking Methods
    # -------------------------------------------------------------------------

    def record_result(self, test_name: str, success: bool, message: str) -> None:
        """
        Record a test result.

        Args:
            test_name: Name of the test
            success: Whether the test passed
            message: Result message or error details
        """
        self.test_results.append((test_name, success, message))

    def print_summary(self) -> bool:
        """
        Print test results summary.

        Skipped tests are recorded with success=True, so they count as passed.

        Returns:
            True if all tests passed, False otherwise
        """
        self.print_header("Test Results Summary")

        passed = sum(1 for _, success, _ in self.test_results if success)
        failed = len(self.test_results) - passed

        print(f"\n{Colors.BOLD}Total Tests: {len(self.test_results)}{Colors.ENDC}")
        print(f"{Colors.OKGREEN}Passed: {passed}{Colors.ENDC}")
        print(f"{Colors.FAIL}Failed: {failed}{Colors.ENDC}\n")

        print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}")
        for test_name, success, message in self.test_results:
            status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if success else f"{Colors.FAIL}FAIL{Colors.ENDC}"
            print(f"  {status} - {test_name}: {message}")

        print()
        return failed == 0

    # -------------------------------------------------------------------------
    # Assertion Helpers
    # -------------------------------------------------------------------------

    def assert_status(
        self,
        response: requests.Response,
        expected: int,
        test_name: str,
    ) -> bool:
        """
        Assert response status code and record result.

        Only a mismatch is recorded (as a failure); on success nothing is
        recorded, so the caller is expected to record the final outcome.

        Args:
            response: Response object to check
            expected: Expected status code
            test_name: Name of the test for recording

        Returns:
            True if status matches, False otherwise
        """
        if response.status_code == expected:
            return True
        self.print_error(f"Expected {expected}, got {response.status_code}")
        self.print_error(f"Response: {response.text[:500]}")
        self.record_result(test_name, False, f"Status {response.status_code}")
        return False

    def assert_json_key(
        self,
        data: dict[str, Any],
        key: str,
        test_name: str,
    ) -> bool:
        """
        Assert JSON response contains a key.

        Only a missing key is recorded (as a failure). Non-dict data (e.g. a
        JSON array or None) is reported as missing the key.

        Args:
            data: JSON response data
            key: Expected key
            test_name: Name of the test for recording

        Returns:
            True if key exists, False otherwise
        """
        if isinstance(data, dict) and key in data:
            return True
        self.print_error(f"Missing key '{key}' in response")
        self.record_result(test_name, False, f"Missing key: {key}")
        return False

    # -------------------------------------------------------------------------
    # Convenience Properties
    # -------------------------------------------------------------------------

    @property
    def is_authenticated(self) -> bool:
        """Check if client has authentication token.

        Uses the same truthiness test as __init__, so an empty token (which
        never produces an Authorization header) counts as unauthenticated.
        """
        return bool(self.token)

    def require_auth(self, test_name: str) -> bool:
        """
        Check authentication and record skip if not authenticated.

        Args:
            test_name: Name of the test

        Returns:
            True if authenticated, False otherwise (test skipped)
        """
        if not self.is_authenticated:
            self.print_warning("No authentication token provided")
            self.print_info("Skipping test (auth required)")
            self.record_result(test_name, True, "Skipped (auth required)")
            return False
        return True
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Factory Function
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
def create_client_from_args(
    client_class: Type[T],
    description: str = "DocsGPT Integration Tests",
    argv: Optional[list[str]] = None,
) -> T:
    """
    Create a test client from command-line arguments.

    Parses --base-url and --token arguments, and handles JWT token generation.
    Token resolution order: --token, then the JWT_TOKEN environment variable,
    then a token auto-generated by generate_jwt_token(). If none is
    available, the client is created without a token.

    Args:
        client_class: The test class to instantiate (must inherit from DocsGPTTestBase)
        description: Description for the argument parser
        argv: Argument list to parse instead of the process command line
            (None, the default, parses sys.argv[1:] as before)

    Returns:
        An instance of the provided client_class

    Example:
        class ChatTests(DocsGPTTestBase):
            def run_all(self):
                ...

        if __name__ == "__main__":
            client = create_client_from_args(ChatTests)
            sys.exit(0 if client.run_all() else 1)
    """
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--base-url",
        # `or` instead of a getenv default so an empty DOCSGPT_BASE_URL still
        # falls back to localhost rather than producing URLs like "/api/...".
        default=os.getenv("DOCSGPT_BASE_URL") or "http://localhost:7091",
        help="Base URL of DocsGPT instance (env: DOCSGPT_BASE_URL, default: http://localhost:7091)",
    )

    parser.add_argument(
        "--token",
        default=None,
        help="JWT authentication token (env: JWT_TOKEN; auto-generated from local secret when available)",
    )

    args = parser.parse_args(argv)

    # Determine token and source; the env var is resolved here (not as the
    # argparse default) so the reported source is accurate.
    token = args.token
    if token:
        token_source = "provided via --token"
    else:
        token = os.getenv("JWT_TOKEN")
        token_source = "provided via JWT_TOKEN" if token else "none"

    if not token:
        token, token_error = generate_jwt_token()
        if token:
            token_source = "auto-generated from local secret"
            print(f"{Colors.OKCYAN}[INFO] Using auto-generated JWT token{Colors.ENDC}")
        elif token_error:
            print(f"{Colors.WARNING}[WARN] Could not auto-generate token: {token_error}{Colors.ENDC}")
            print(f"{Colors.WARNING}[WARN] Tests requiring auth will be skipped{Colors.ENDC}")

    return client_class(args.base_url, token=token, token_source=token_source)
|