Files
DocsGPT/tests/integration/test_sources.py
2026-01-16 13:10:27 +03:00

676 lines
23 KiB
Python

#!/usr/bin/env python3
"""
Integration tests for DocsGPT source management endpoints.
Endpoints exercised by this suite:
- /api/upload (POST) - File upload (plain text and markdown)
- /api/remote (POST) - Remote source (crawler)
- /api/sources (GET) - List sources
- /api/sources/paginated (GET) - Paginated sources
- /api/task_status (GET) - Task status
- /api/add_chunk (POST) - Add chunk to source
- /api/get_chunks (GET) - Get chunks from source
- /api/delete_by_ids (GET) - Delete vector-store documents by ID (probed with
  a non-existent ID, so nothing is actually deleted)
- /api/directory_structure (GET) - Get directory structure
- /api/manage_source_files (POST) - Manage source files
- /api/combine (GET) - Combine sources
Endpoints not yet covered by any test in this suite:
- /api/update_chunk (PUT) - Update chunk
- /api/delete_chunk (DELETE) - Delete chunk
- /api/delete_old (GET) - Delete old sources
- /api/manage_sync (POST) - Manage sync
Usage:
python tests/integration/test_sources.py
python tests/integration/test_sources.py --base-url http://localhost:7091
python tests/integration/test_sources.py --token YOUR_JWT_TOKEN
"""
import sys
import time
from pathlib import Path
from typing import Optional
# When this file is executed directly (not collected by a test runner), the
# repository root is not on sys.path, so `tests.integration.base` would not be
# importable. Resolve the root as two levels above this file's directory and
# prepend it once.
_THIS_DIR = Path(__file__).parent
_TESTS_DIR, _ROOT_DIR = _THIS_DIR.parent, _THIS_DIR.parent.parent
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))
from tests.integration.base import DocsGPTTestBase, create_client_from_args
class SourceTests(DocsGPTTestBase):
    """Integration tests for source management endpoints.

    Each ``test_*`` method prints its progress, records exactly one result via
    ``record_result`` and returns ``True`` (pass or skip) or ``False`` (fail).
    Tests guarded by ``require_auth`` return ``True`` immediately when it
    returns a falsy value.
    """

    # -------------------------------------------------------------------------
    # Test Data Helpers
    # -------------------------------------------------------------------------

    def get_or_create_test_source(self) -> Optional[dict]:
        """
        Get or create the shared test source, uploading it at most once.

        The first authenticated call uploads a small text document, polls the
        ingest task with ``_wait_for_task`` and looks the source up by name.
        The outcome, including a failure (``None``), is cached on the instance.
        Later tests therefore neither re-upload nor repeat the wait.

        Returns:
            Dict with keys: id, task_id, name, or None when unauthenticated or
            when the source could not be created and located.
        """
        if hasattr(self, "_test_source"):
            return self._test_source
        if not self.is_authenticated:
            return None
        self._test_source = self._create_test_source()
        return self._test_source

    def _create_test_source(self) -> Optional[dict]:
        """Upload the shared test document; return its descriptor or None."""
        test_name = f"Source Test {int(time.time())}"
        test_content = """# Test Documentation
## Overview
This is test documentation for source integration tests.
## Installation
Run `pip install docsgpt` to install.
## Usage
Import and use the library in your code.
## API Reference
See the API documentation for details.
"""
        files = {"file": ("test_source.txt", test_content.encode(), "text/plain")}
        data = {"user": "test_user", "name": test_name}
        try:
            response = self.post("/api/upload", files=files, data=data, timeout=30)
            if response.status_code != 200:
                self.print_warning(
                    f"Test source upload returned status {response.status_code}"
                )
                return None
            task_id = response.json().get("task_id")
        except Exception as e:
            # Best-effort fixture: report why, then let dependent tests skip.
            self.print_warning(f"Test source upload failed: {e}")
            return None
        if not task_id:
            self.print_warning("Test source upload returned no task_id")
            return None
        # Poll the ingest task instead of sleeping a fixed time. A slow worker
        # would otherwise leave the source absent from /api/sources.
        status = self._wait_for_task(task_id)
        if status == "FAILURE":
            self.print_warning(f"Test source ingest task {task_id} failed")
            return None
        # On SUCCESS or on a poll timeout (None), the source may still be listed.
        source_id = self._get_source_id_by_name(test_name)
        if not source_id:
            self.print_warning(f"Test source '{test_name}' not found after ingest")
            return None
        return {"id": source_id, "task_id": task_id, "name": test_name}

    def _get_source_id_by_name(self, name: str) -> Optional[str]:
        """Return the id of the first listed source named ``name``, or None."""
        try:
            response = self.get("/api/sources")
            if response.status_code != 200:
                return None
            sources = response.json()
        except Exception as e:
            self.print_warning(f"Could not list sources: {e}")
            return None
        # The listing test treats /api/sources as a JSON list. Guard anyway so
        # an unexpected shape yields None instead of an attribute error.
        if not isinstance(sources, list):
            return None
        return next(
            (
                source.get("id")
                for source in sources
                if isinstance(source, dict) and source.get("name") == name
            ),
            None,
        )

    def _wait_for_task(self, task_id: str, max_wait: int = 30) -> Optional[str]:
        """Poll /api/task_status about once per second until the task finishes.

        Args:
            task_id: Task identifier returned by an upload/remote call.
            max_wait: Maximum number of polls (each followed by a 1 s sleep).

        Returns:
            "SUCCESS" or "FAILURE" once reached, or None if neither status was
            observed within ``max_wait`` polls.
        """
        for _ in range(max_wait):
            try:
                response = self.get("/api/task_status", params={"task_id": task_id})
                if response.status_code == 200:
                    status = response.json().get("status")
                    if status in ("SUCCESS", "FAILURE"):
                        return status
            except Exception:
                # Transient polling errors are retried on the next iteration.
                pass
            time.sleep(1)
        return None

    def _record_task_started(self, test_name: str, response, label: str) -> bool:
        """Record the result of an endpoint that should start a background task.

        Passes only when ``response`` is HTTP 200 with a truthy ``task_id`` in
        its JSON body. Otherwise it records a failure that names the real cause.
        """
        self.print_info(f"Status Code: {response.status_code}")
        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        task_id = response.json().get("task_id")
        if not task_id:
            self.print_warning("No task_id returned")
            self.record_result(test_name, False, "No task_id")
            return False
        self.print_success(f"{label} task started: {task_id}")
        self.record_result(test_name, True, f"Task: {task_id}")
        return True

    def _record_exception(self, test_name: str, error: Exception) -> bool:
        """Print and record an unexpected exception as a failed test."""
        self.print_error(f"Error: {str(error)}")
        self.record_result(test_name, False, str(error))
        return False

    # -------------------------------------------------------------------------
    # Upload Tests
    # -------------------------------------------------------------------------

    def test_upload_text_source(self) -> bool:
        """Test uploading a text file source."""
        test_name = "Upload - Text Source"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        stamp = int(time.time())
        test_content = f"""# Upload Test {stamp}
This is a test document for upload testing.
It contains multiple lines of text.
"""
        files = {"file": ("upload_test.txt", test_content.encode(), "text/plain")}
        data = {"user": "test_user", "name": f"Upload Test {stamp}"}
        try:
            self.print_info("POST /api/upload")
            response = self.post("/api/upload", files=files, data=data, timeout=30)
            return self._record_task_started(test_name, response, "Upload")
        except Exception as e:
            return self._record_exception(test_name, e)

    def test_upload_markdown_source(self) -> bool:
        """Test uploading a markdown file source."""
        test_name = "Upload - Markdown Source"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        stamp = int(time.time())
        test_content = f"""# Markdown Test Document
## Section 1
This is the first section with **bold** and *italic* text.
## Section 2
- Item 1
- Item 2
- Item 3
## Code Example
```python
def hello():
    print("Hello, World!")
```
Created at: {stamp}
"""
        files = {"file": ("test.md", test_content.encode(), "text/markdown")}
        data = {"user": "test_user", "name": f"Markdown Test {stamp}"}
        try:
            self.print_info("POST /api/upload (markdown)")
            response = self.post("/api/upload", files=files, data=data, timeout=30)
            return self._record_task_started(test_name, response, "Markdown upload")
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Remote Source Tests
    # -------------------------------------------------------------------------

    def test_remote_crawler_source(self) -> bool:
        """Test remote crawler source upload."""
        test_name = "Remote - Crawler Source"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        # Use a small, fast-loading page
        payload = {
            "user": "test_user",
            "source": "crawler",
            "name": f"Crawler Test {int(time.time())}",
            "data": '{"url": "https://example.com/"}',
        }
        try:
            self.print_info("POST /api/remote (crawler)")
            response = self.post("/api/remote", data=payload, timeout=30)
            return self._record_task_started(test_name, response, "Crawler")
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Source Listing Tests
    # -------------------------------------------------------------------------

    def test_get_sources(self) -> bool:
        """Test getting list of sources."""
        test_name = "Sources - List All"
        self.print_header(f"Testing {test_name}")
        try:
            self.print_info("GET /api/sources")
            response = self.get("/api/sources")
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code != 200:
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False
            sources = response.json()
            self.print_success(f"Retrieved {len(sources)} sources")
            if sources:
                self.print_info(f"First source: {sources[0].get('name', 'N/A')}")
            self.record_result(test_name, True, f"{len(sources)} sources")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    def test_get_sources_paginated(self) -> bool:
        """Test getting paginated sources."""
        test_name = "Sources - Paginated"
        self.print_header(f"Testing {test_name}")
        try:
            self.print_info("GET /api/sources/paginated")
            response = self.get(
                "/api/sources/paginated", params={"page": 1, "per_page": 10}
            )
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code != 200:
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False
            result = response.json()
            self.print_success("Paginated sources retrieved")
            if isinstance(result, dict):
                self.print_info(f"Total sources: {result.get('total', 'N/A')}")
            self.record_result(test_name, True, "Success")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Task Status Tests
    # -------------------------------------------------------------------------

    def test_task_status(self) -> bool:
        """Test getting task status."""
        test_name = "Task Status - Check"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        # First upload a file to get a task_id
        test_content = "Test content for task status"
        files = {"file": ("task_test.txt", test_content.encode(), "text/plain")}
        data = {"user": "test_user", "name": f"Task Test {int(time.time())}"}
        try:
            upload_response = self.post("/api/upload", files=files, data=data, timeout=30)
            if upload_response.status_code != 200:
                self.record_result(test_name, True, "Skipped (upload failed)")
                return True
            task_id = upload_response.json().get("task_id")
            if not task_id:
                self.record_result(test_name, True, "Skipped (no task_id)")
                return True
            self.print_info(f"GET /api/task_status?task_id={task_id[:8]}...")
            response = self.get("/api/task_status", params={"task_id": task_id})
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code != 200:
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False
            status = response.json().get("status", "UNKNOWN")
            self.print_success(f"Task status: {status}")
            self.record_result(test_name, True, f"Status: {status}")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Chunk Management Tests
    # -------------------------------------------------------------------------

    def test_get_chunks(self) -> bool:
        """Test getting chunks from a source."""
        test_name = "Chunks - Get"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        source = self.get_or_create_test_source()
        if not source:
            self.print_warning("Could not create test source")
            self.record_result(test_name, True, "Skipped (no source)")
            return True
        try:
            # Swagger says param is 'id', not 'source_id'
            self.print_info(f"GET /api/get_chunks?id={source['id'][:8]}...")
            response = self.get("/api/get_chunks", params={"id": source["id"]})
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code != 200:
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False
            result = response.json()
            # Accept either a bare list or an object wrapping it under "chunks".
            chunks = result if isinstance(result, list) else result.get("chunks", [])
            self.print_success(f"Retrieved {len(chunks)} chunks")
            self.record_result(test_name, True, f"{len(chunks)} chunks")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    def test_add_chunk(self) -> bool:
        """Test adding a chunk to a source."""
        test_name = "Chunks - Add"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        source = self.get_or_create_test_source()
        if not source:
            self.record_result(test_name, True, "Skipped (no source)")
            return True
        payload = {
            "source_id": source["id"],
            "content": f"Test chunk content added at {int(time.time())}",
            "metadata": {"test": True},
        }
        try:
            self.print_info("POST /api/add_chunk")
            response = self.post("/api/add_chunk", json=payload)
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code in (200, 201):
                self.print_success("Chunk added successfully")
                self.record_result(test_name, True, "Success")
                return True
            # May not be supported or require specific format
            self.print_warning(f"Status {response.status_code}")
            self.record_result(test_name, True, f"Skipped (status {response.status_code})")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Delete Tests
    # -------------------------------------------------------------------------

    def test_delete_by_ids(self) -> bool:
        """Test deleting documents by vector store IDs.

        Note: This endpoint expects vector store document IDs (chunk IDs),
        not MongoDB source IDs. The probe uses a fabricated ID, so nothing is
        deleted. A 400 response (ID not found) counts as a pass.
        """
        test_name = "Sources - Delete by IDs"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        try:
            test_id = "test-document-id-12345"
            self.print_info(f"GET /api/delete_by_ids?path={test_id}")
            response = self.get("/api/delete_by_ids", params={"path": test_id})
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code == 200:
                self.print_success("Delete endpoint responded successfully")
                self.record_result(test_name, True, "Success")
                return True
            if response.status_code == 400:
                # 400 is expected when document ID doesn't exist in vector store
                self.print_warning("Expected 400 (ID not in vector store)")
                self.record_result(test_name, True, "Endpoint works (ID not found)")
                return True
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Directory Structure Tests
    # -------------------------------------------------------------------------

    def test_directory_structure(self) -> bool:
        """Test getting directory structure."""
        test_name = "Directory Structure"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        source = self.get_or_create_test_source()
        if not source:
            self.record_result(test_name, True, "Skipped (no source)")
            return True
        try:
            self.print_info(f"GET /api/directory_structure?source_id={source['id'][:8]}...")
            response = self.get(
                "/api/directory_structure", params={"source_id": source["id"]}
            )
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code == 200:
                response.json()  # Validate JSON response
                self.print_success("Directory structure retrieved")
                self.record_result(test_name, True, "Success")
                return True
            # May not be supported for all source types
            self.print_warning(f"Status {response.status_code}")
            self.record_result(test_name, True, f"Skipped (status {response.status_code})")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Combine Tests
    # -------------------------------------------------------------------------

    def test_combine(self) -> bool:
        """Test combine endpoint."""
        test_name = "Sources - Combine"
        self.print_header(f"Testing {test_name}")
        try:
            self.print_info("GET /api/combine")
            response = self.get("/api/combine")
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code == 200:
                self.print_success("Combine endpoint works")
                self.record_result(test_name, True, "Success")
                return True
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Manage Source Files Tests
    # -------------------------------------------------------------------------

    def test_manage_source_files(self) -> bool:
        """Test managing source files."""
        test_name = "Manage Source Files"
        self.print_header(f"Testing {test_name}")
        if not self.require_auth(test_name):
            return True
        source = self.get_or_create_test_source()
        if not source:
            self.record_result(test_name, True, "Skipped (no source)")
            return True
        payload = {
            "source_id": source["id"],
            "action": "list",
        }
        try:
            self.print_info("POST /api/manage_source_files")
            response = self.post("/api/manage_source_files", json=payload)
            self.print_info(f"Status Code: {response.status_code}")
            if response.status_code == 200:
                self.print_success("Source files managed")
                self.record_result(test_name, True, "Success")
                return True
            # May require specific format
            self.print_warning(f"Status {response.status_code}")
            self.record_result(test_name, True, f"Skipped (status {response.status_code})")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Run All Tests
    # -------------------------------------------------------------------------

    def run_all(self) -> bool:
        """Run all source integration tests, pausing 1 s between them.

        Returns:
            The value of ``print_summary()`` after every test has run.
        """
        self.print_header("Source Integration Tests")
        self.print_info(f"Base URL: {self.base_url}")
        self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
        tests = [
            # Upload tests
            self.test_upload_text_source,
            self.test_upload_markdown_source,
            # Remote source tests
            self.test_remote_crawler_source,
            # Source listing tests
            self.test_get_sources,
            self.test_get_sources_paginated,
            # Task status test
            self.test_task_status,
            # Chunk tests
            self.test_get_chunks,
            self.test_add_chunk,
            # Directory structure
            self.test_directory_structure,
            # Combine
            self.test_combine,
            # Manage source files
            self.test_manage_source_files,
            # Delete probe: it uses a fabricated ID and removes nothing, but it
            # stays last so a future real deletion cannot affect earlier tests.
            self.test_delete_by_ids,
        ]
        for index, test in enumerate(tests):
            if index:
                # Brief pause between tests to avoid hammering the server.
                time.sleep(1)
            test()
        return self.print_summary()
def main():
    """Standalone entry point: build the suite from CLI args and run it.

    The process exits with status 0 when ``run_all()`` returns a truthy
    value and with status 1 otherwise.
    """
    suite = create_client_from_args(SourceTests, "DocsGPT Source Integration Tests")
    exit_code = 0 if suite.run_all() else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    main()