mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-20 14:00:55 +00:00
676 lines
23 KiB
Python
676 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Integration tests for DocsGPT source management endpoints.
|
|
|
|
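Endpoints in scope (update_chunk, delete_chunk, delete_old and manage_sync are listed but not yet exercised by any test below):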
|
|
- /api/upload (POST) - File upload
|
|
- /api/remote (POST) - Remote source (crawler)
|
|
- /api/sources (GET) - List sources
|
|
- /api/sources/paginated (GET) - Paginated sources
|
|
- /api/task_status (GET) - Task status
|
|
- /api/add_chunk (POST) - Add chunk to source
|
|
- /api/get_chunks (GET) - Get chunks from source
|
|
- /api/update_chunk (PUT) - Update chunk
|
|
- /api/delete_chunk (DELETE) - Delete chunk
|
|
- /api/delete_by_ids (GET) - Delete sources by IDs
|
|
- /api/delete_old (GET) - Delete old sources
|
|
- /api/directory_structure (GET) - Get directory structure
|
|
- /api/manage_source_files (POST) - Manage source files
|
|
- /api/manage_sync (POST) - Manage sync
|
|
- /api/combine (GET) - Combine sources
|
|
|
|
Usage:
|
|
python tests/integration/test_sources.py
|
|
python tests/integration/test_sources.py --base-url http://localhost:7091
|
|
python tests/integration/test_sources.py --token YOUR_JWT_TOKEN
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
# Standalone execution support: the `tests.integration` import further down
# only resolves when the repository root (two levels above this directory)
# is on sys.path, so prepend it unless it is already listed.
_THIS_DIR = Path(__file__).parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if sys.path.count(str(_ROOT_DIR)) == 0:
    sys.path.insert(0, str(_ROOT_DIR))
|
|
|
|
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
|
|
|
|
|
class SourceTests(DocsGPTTestBase):
    """Integration tests for source management endpoints.

    Every ``test_*`` method prints its progress, stores its outcome with
    ``record_result`` and returns ``True`` for a pass or a tolerated skip and
    ``False`` for a failure. Tests that need credentials return ``True``
    without issuing any request when ``require_auth`` declines.
    """

    # Static document uploaded once as the shared fixture source.
    _TEST_SOURCE_CONTENT = """# Test Documentation

## Overview
This is test documentation for source integration tests.

## Installation
Run `pip install docsgpt` to install.

## Usage
Import and use the library in your code.

## API Reference
See the API documentation for details.
"""

    # Markdown upload body; ``{created_at}`` is filled with a Unix timestamp.
    _MARKDOWN_TEMPLATE = """# Markdown Test Document

## Section 1
This is the first section with **bold** and *italic* text.

## Section 2
- Item 1
- Item 2
- Item 3

## Code Example
```python
def hello():
    print("Hello, World!")
```

Created at: {created_at}
"""

    # -------------------------------------------------------------------------
    # Test Data Helpers
    # -------------------------------------------------------------------------

    def get_or_create_test_source(self) -> Optional[dict]:
        """
        Get or create the shared test source.

        The outcome of the first authenticated attempt is cached on the
        instance, including a failed attempt, so dependent tests neither
        re-upload a fresh source nor wait again after a failure.

        Returns:
            Dict with keys: id, task_id, name or None
        """
        if hasattr(self, "_test_source"):
            return self._test_source

        if not self.is_authenticated:
            return None

        self._test_source = self._create_test_source()
        return self._test_source

    def _create_test_source(self) -> Optional[dict]:
        """Upload the fixture document and resolve the ID of the new source."""
        source_name = f"Source Test {int(time.time())}"
        files = {
            "file": (
                "test_source.txt",
                self._TEST_SOURCE_CONTENT.encode(),
                "text/plain",
            )
        }
        data = {"user": "test_user", "name": source_name}

        try:
            response = self.post("/api/upload", files=files, data=data, timeout=30)
            if response.status_code != 200:
                self.print_warning(
                    f"Test source upload returned status {response.status_code}"
                )
                return None
            task_id = response.json().get("task_id")
        except Exception as e:
            # Best-effort fixture: report the problem and let callers skip.
            self.print_warning(f"Test source upload failed: {e}")
            return None

        if not task_id:
            self.print_warning("Test source upload returned no task_id")
            return None

        # Poll the task instead of sleeping a fixed time. A timeout (None)
        # still falls through to the lookup in case the source is listed.
        if self._wait_for_task(task_id) == "FAILURE":
            self.print_warning(f"Test source task {task_id} failed")
            return None

        source_id = self._get_source_id_by_name(source_name)
        if not source_id:
            self.print_warning(f"Test source '{source_name}' not found in /api/sources")
            return None

        return {"id": source_id, "task_id": task_id, "name": source_name}

    def _get_source_id_by_name(self, name: str) -> Optional[str]:
        """Get source ID by name from sources list.

        Returns the ``id`` of the first entry of ``GET /api/sources`` whose
        ``name`` equals ``name``, or None when the request fails, the payload
        cannot be read, or nothing matches.
        """
        try:
            response = self.get("/api/sources")
            if response.status_code == 200:
                for source in response.json():
                    if source.get("name") == name:
                        return source.get("id")
        except Exception as e:
            # Lookup is best-effort; surface the reason instead of hiding it.
            self.print_warning(f"Could not look up source '{name}': {e}")
        return None

    def _wait_for_task(self, task_id: str, max_wait: int = 30) -> Optional[str]:
        """Wait for task to complete and return status.

        Polls ``GET /api/task_status`` up to ``max_wait`` times, pausing one
        second between polls.

        Returns:
            "SUCCESS" or "FAILURE" once the task reports it, or None when no
            terminal status was seen within ``max_wait`` polls.
        """
        for attempt in range(max_wait):
            if attempt:
                time.sleep(1)
            try:
                response = self.get("/api/task_status", params={"task_id": task_id})
                if response.status_code == 200:
                    status = response.json().get("status")
                    if status in ("SUCCESS", "FAILURE"):
                        return status
            except Exception:
                # A failed poll is treated as "not finished yet"; keep polling.
                continue
        return None

    # -------------------------------------------------------------------------
    # Result Recording Helpers
    # -------------------------------------------------------------------------

    def _record_exception(self, test_name: str, error: Exception) -> bool:
        """Print and record an unexpected exception as a failure; return False."""
        self.print_error(f"Error: {str(error)}")
        self.record_result(test_name, False, str(error))
        return False

    def _record_status_failure(
        self, test_name: str, response, expected: str = "200"
    ) -> bool:
        """Print and record an unexpected HTTP status as a failure; return False."""
        self.print_error(f"Expected {expected}, got {response.status_code}")
        self.record_result(test_name, False, f"Status {response.status_code}")
        return False

    def _record_tolerated_status(self, test_name: str, response) -> bool:
        """Record a non-success status as a skip (counted as passed); return True."""
        self.print_warning(f"Status {response.status_code}")
        self.record_result(test_name, True, f"Skipped (status {response.status_code})")
        return True

    def _record_no_source(self, test_name: str) -> bool:
        """Record a skip because the shared test source is unavailable; return True."""
        self.print_warning("Could not create test source")
        self.record_result(test_name, True, "Skipped (no source)")
        return True

    def _record_task_start(self, test_name: str, response, started_msg: str) -> bool:
        """Record the outcome of a request that should start a background task.

        Passes only when the response is 200 and its JSON body carries a
        truthy ``task_id``. A 200 without ``task_id`` is recorded as
        "No task_id" rather than as a status failure.
        """
        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            return self._record_status_failure(test_name, response)

        task_id = response.json().get("task_id")
        if not task_id:
            self.print_warning("No task_id returned")
            self.record_result(test_name, False, "No task_id")
            return False

        self.print_success(f"{started_msg}: {task_id}")
        self.record_result(test_name, True, f"Task: {task_id}")
        return True

    # -------------------------------------------------------------------------
    # Upload Tests
    # -------------------------------------------------------------------------

    def test_upload_text_source(self) -> bool:
        """Test uploading a text file source."""
        test_name = "Upload - Text Source"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        timestamp = int(time.time())
        test_content = (
            f"# Upload Test {timestamp}\n"
            "This is a test document for upload testing.\n"
            "It contains multiple lines of text.\n"
        )
        files = {"file": ("upload_test.txt", test_content.encode(), "text/plain")}
        data = {"user": "test_user", "name": f"Upload Test {timestamp}"}

        try:
            self.print_info("POST /api/upload")
            response = self.post("/api/upload", files=files, data=data, timeout=30)
            return self._record_task_start(test_name, response, "Upload task started")
        except Exception as e:
            return self._record_exception(test_name, e)

    def test_upload_markdown_source(self) -> bool:
        """Test uploading a markdown file source."""
        test_name = "Upload - Markdown Source"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        timestamp = int(time.time())
        test_content = self._MARKDOWN_TEMPLATE.format(created_at=timestamp)
        files = {"file": ("test.md", test_content.encode(), "text/markdown")}
        data = {"user": "test_user", "name": f"Markdown Test {timestamp}"}

        try:
            self.print_info("POST /api/upload (markdown)")
            response = self.post("/api/upload", files=files, data=data, timeout=30)
            return self._record_task_start(
                test_name, response, "Markdown upload task started"
            )
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Remote Source Tests
    # -------------------------------------------------------------------------

    def test_remote_crawler_source(self) -> bool:
        """Test remote crawler source upload."""
        test_name = "Remote - Crawler Source"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        # Use a small, fast-loading page
        payload = {
            "user": "test_user",
            "source": "crawler",
            "name": f"Crawler Test {int(time.time())}",
            "data": '{"url": "https://example.com/"}',
        }

        try:
            self.print_info("POST /api/remote (crawler)")
            response = self.post("/api/remote", data=payload, timeout=30)
            return self._record_task_start(test_name, response, "Crawler task started")
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Source Listing Tests
    # -------------------------------------------------------------------------

    def test_get_sources(self) -> bool:
        """Test getting list of sources."""
        test_name = "Sources - List All"
        self.print_header(f"Testing {test_name}")

        try:
            self.print_info("GET /api/sources")
            response = self.get("/api/sources")
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code != 200:
                return self._record_status_failure(test_name, response)

            sources = response.json()
            self.print_success(f"Retrieved {len(sources)} sources")
            if sources:
                self.print_info(f"First source: {sources[0].get('name', 'N/A')}")

            self.record_result(test_name, True, f"{len(sources)} sources")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    def test_get_sources_paginated(self) -> bool:
        """Test getting paginated sources."""
        test_name = "Sources - Paginated"
        self.print_header(f"Testing {test_name}")

        try:
            self.print_info("GET /api/sources/paginated")
            response = self.get(
                "/api/sources/paginated", params={"page": 1, "per_page": 10}
            )
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code != 200:
                return self._record_status_failure(test_name, response)

            result = response.json()
            self.print_success("Paginated sources retrieved")
            if isinstance(result, dict):
                self.print_info(f"Total sources: {result.get('total', 'N/A')}")

            self.record_result(test_name, True, "Success")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Task Status Tests
    # -------------------------------------------------------------------------

    def test_task_status(self) -> bool:
        """Test getting task status.

        Uploads a small file to obtain a task ID first. If that upload does
        not return 200 or yields no ``task_id``, the test is recorded as
        skipped (passed) because the status endpoint could not be exercised.
        """
        test_name = "Task Status - Check"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        # First upload a file to get a task_id
        test_content = "Test content for task status"
        files = {"file": ("task_test.txt", test_content.encode(), "text/plain")}
        data = {"user": "test_user", "name": f"Task Test {int(time.time())}"}

        try:
            upload_response = self.post(
                "/api/upload", files=files, data=data, timeout=30
            )
            if upload_response.status_code != 200:
                self.record_result(test_name, True, "Skipped (upload failed)")
                return True

            task_id = upload_response.json().get("task_id")
            if not task_id:
                self.record_result(test_name, True, "Skipped (no task_id)")
                return True

            self.print_info(f"GET /api/task_status?task_id={task_id[:8]}...")
            response = self.get("/api/task_status", params={"task_id": task_id})
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code != 200:
                return self._record_status_failure(test_name, response)

            status = response.json().get("status", "UNKNOWN")
            self.print_success(f"Task status: {status}")
            self.record_result(test_name, True, f"Status: {status}")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Chunk Management Tests
    # -------------------------------------------------------------------------

    def test_get_chunks(self) -> bool:
        """Test getting chunks from a source."""
        test_name = "Chunks - Get"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        source = self.get_or_create_test_source()
        if not source:
            return self._record_no_source(test_name)

        try:
            # Swagger says param is 'id', not 'source_id'
            self.print_info(f"GET /api/get_chunks?id={source['id'][:8]}...")
            response = self.get("/api/get_chunks", params={"id": source["id"]})
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code != 200:
                return self._record_status_failure(test_name, response)

            result = response.json()
            # The body is either a bare list or an object with a "chunks" list.
            chunks = result if isinstance(result, list) else result.get("chunks", [])
            self.print_success(f"Retrieved {len(chunks)} chunks")
            self.record_result(test_name, True, f"{len(chunks)} chunks")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    def test_add_chunk(self) -> bool:
        """Test adding a chunk to a source.

        Any status other than 200/201 is recorded as a skip, not a failure.
        """
        test_name = "Chunks - Add"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        source = self.get_or_create_test_source()
        if not source:
            return self._record_no_source(test_name)

        # NOTE(review): get_chunks takes 'id' rather than 'source_id'; confirm
        # against the API spec whether add_chunk expects different field names
        # than the ones sent here.
        payload = {
            "source_id": source["id"],
            "content": f"Test chunk content added at {int(time.time())}",
            "metadata": {"test": True},
        }

        try:
            self.print_info("POST /api/add_chunk")
            response = self.post("/api/add_chunk", json=payload)
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code in (200, 201):
                self.print_success("Chunk added successfully")
                self.record_result(test_name, True, "Success")
                return True

            # May not be supported or require specific format
            return self._record_tolerated_status(test_name, response)
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Delete Tests
    # -------------------------------------------------------------------------

    def test_delete_by_ids(self) -> bool:
        """Test deleting documents by vector store IDs.

        Note: This endpoint expects vector store document IDs (chunk IDs),
        not MongoDB source IDs. Testing with non-existent IDs returns 400,
        so both 200 and 400 count as a pass.
        """
        test_name = "Sources - Delete by IDs"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        try:
            # Test endpoint accessibility with a test ID
            # Note: This endpoint expects vector document IDs, not source IDs
            test_id = "test-document-id-12345"
            self.print_info(f"GET /api/delete_by_ids?path={test_id}")
            response = self.get("/api/delete_by_ids", params={"path": test_id})
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code == 200:
                self.print_success("Delete endpoint responded successfully")
                self.record_result(test_name, True, "Success")
                return True

            if response.status_code == 400:
                # 400 is expected when document ID doesn't exist in vector store
                self.print_warning("Expected 400 (ID not in vector store)")
                self.record_result(test_name, True, "Endpoint works (ID not found)")
                return True

            return self._record_status_failure(
                test_name, response, expected="200 or 400"
            )
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Directory Structure Tests
    # -------------------------------------------------------------------------

    def test_directory_structure(self) -> bool:
        """Test getting directory structure.

        A non-200 status is recorded as a skip, not a failure.
        """
        test_name = "Directory Structure"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        source = self.get_or_create_test_source()
        if not source:
            return self._record_no_source(test_name)

        try:
            self.print_info(
                f"GET /api/directory_structure?source_id={source['id'][:8]}..."
            )
            response = self.get(
                "/api/directory_structure", params={"source_id": source["id"]}
            )
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code == 200:
                response.json()  # Validate JSON response
                self.print_success("Directory structure retrieved")
                self.record_result(test_name, True, "Success")
                return True

            # May not be supported for all source types
            return self._record_tolerated_status(test_name, response)
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Combine Tests
    # -------------------------------------------------------------------------

    def test_combine(self) -> bool:
        """Test combine endpoint."""
        test_name = "Sources - Combine"
        self.print_header(f"Testing {test_name}")

        try:
            self.print_info("GET /api/combine")
            response = self.get("/api/combine")
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code != 200:
                return self._record_status_failure(test_name, response)

            self.print_success("Combine endpoint works")
            self.record_result(test_name, True, "Success")
            return True
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Manage Source Files Tests
    # -------------------------------------------------------------------------

    def test_manage_source_files(self) -> bool:
        """Test managing source files.

        A non-200 status is recorded as a skip, not a failure.
        """
        test_name = "Manage Source Files"
        self.print_header(f"Testing {test_name}")

        if not self.require_auth(test_name):
            return True

        source = self.get_or_create_test_source()
        if not source:
            return self._record_no_source(test_name)

        payload = {
            "source_id": source["id"],
            "action": "list",
        }

        try:
            self.print_info("POST /api/manage_source_files")
            response = self.post("/api/manage_source_files", json=payload)
            self.print_info(f"Status Code: {response.status_code}")

            if response.status_code == 200:
                self.print_success("Source files managed")
                self.record_result(test_name, True, "Success")
                return True

            # May require specific format
            return self._record_tolerated_status(test_name, response)
        except Exception as e:
            return self._record_exception(test_name, e)

    # -------------------------------------------------------------------------
    # Run All Tests
    # -------------------------------------------------------------------------

    def run_all(self) -> bool:
        """Run all source integration tests.

        Tests run in a fixed order with a one-second pause between them.

        Returns:
            Whatever ``print_summary`` returns for the recorded results.
        """
        self.print_header("Source Integration Tests")
        self.print_info(f"Base URL: {self.base_url}")
        self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")

        ordered_tests = (
            # Upload tests
            self.test_upload_text_source,
            self.test_upload_markdown_source,
            # Remote source tests
            self.test_remote_crawler_source,
            # Source listing tests
            self.test_get_sources,
            self.test_get_sources_paginated,
            # Task status test
            self.test_task_status,
            # Chunk tests
            self.test_get_chunks,
            self.test_add_chunk,
            # Directory structure
            self.test_directory_structure,
            # Combine
            self.test_combine,
            # Manage source files
            self.test_manage_source_files,
            # Delete test (last because it removes data)
            self.test_delete_by_ids,
        )

        for position, run_test in enumerate(ordered_tests):
            if position:
                time.sleep(1)
            run_test()

        return self.print_summary()
|
|
|
|
|
|
def main():
    """Build the test client from CLI arguments, run the suite, and exit.

    The process exit code is 0 when ``run_all`` reports success and 1 otherwise.
    """
    runner = create_client_from_args(SourceTests, "DocsGPT Source Integration Tests")
    exit_code = 0 if runner.run_all() else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
|