mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-20 22:10:54 +00:00
test_integration.py kept for backwards compatability. tests/integration/run_all.py proposed as alternative to cover all endpoints.
958 lines
35 KiB
Python
958 lines
35 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Integration tests for DocsGPT chat endpoints.
|
|
|
|
Endpoints tested:
|
|
- /stream (POST) - Streaming chat
|
|
- /api/answer (POST) - Non-streaming chat
|
|
- /api/feedback (POST) - Feedback submission
|
|
- /api/tts (POST) - Text-to-speech
|
|
|
|
Usage:
|
|
python tests/integration/test_chat.py
|
|
python tests/integration/test_chat.py --base-url http://localhost:7091
|
|
python tests/integration/test_chat.py --token YOUR_JWT_TOKEN
|
|
"""
|
|
|
|
import json as json_module
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
# Add parent directory to path for standalone execution
|
|
_THIS_DIR = Path(__file__).parent
|
|
_TESTS_DIR = _THIS_DIR.parent
|
|
_ROOT_DIR = _TESTS_DIR.parent
|
|
if str(_ROOT_DIR) not in sys.path:
|
|
sys.path.insert(0, str(_ROOT_DIR))
|
|
|
|
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
|
|
|
|
|
class ChatTests(DocsGPTTestBase):
|
|
"""Integration tests for chat/streaming endpoints."""
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Test Data Helpers
|
|
# -------------------------------------------------------------------------
|
|
|
|
def get_or_create_test_agent(self) -> Optional[tuple]:
|
|
"""
|
|
Get or create a test agent for chat tests.
|
|
|
|
Returns:
|
|
Tuple of (agent_id, api_key) or None if creation fails
|
|
"""
|
|
if hasattr(self, "_test_agent"):
|
|
return self._test_agent
|
|
|
|
if not self.is_authenticated:
|
|
return None
|
|
|
|
payload = {
|
|
"name": f"Chat Test Agent {int(time.time())}",
|
|
"description": "Integration test agent for chat tests",
|
|
"prompt_id": "default",
|
|
"chunks": 2,
|
|
"retriever": "classic",
|
|
"agent_type": "classic",
|
|
"status": "draft",
|
|
}
|
|
|
|
try:
|
|
response = self.post("/api/create_agent", json=payload, timeout=10)
|
|
if response.status_code in [200, 201]:
|
|
result = response.json()
|
|
agent_id = result.get("id")
|
|
api_key = result.get("key")
|
|
if agent_id:
|
|
self._test_agent = (agent_id, api_key)
|
|
return self._test_agent
|
|
except Exception:
|
|
pass
|
|
|
|
return None
|
|
|
|
def get_or_create_published_agent(self) -> Optional[tuple]:
|
|
"""
|
|
Get or create a published agent with API key.
|
|
|
|
Returns:
|
|
Tuple of (agent_id, api_key) or None if creation fails
|
|
"""
|
|
if hasattr(self, "_published_agent"):
|
|
return self._published_agent
|
|
|
|
if not self.is_authenticated:
|
|
return None
|
|
|
|
# First create a source
|
|
source_id = self._create_test_source()
|
|
|
|
payload = {
|
|
"name": f"Chat Test Published Agent {int(time.time())}",
|
|
"description": "Integration test published agent",
|
|
"prompt_id": "default",
|
|
"chunks": 2,
|
|
"retriever": "classic",
|
|
"agent_type": "classic",
|
|
"status": "published",
|
|
}
|
|
|
|
if source_id:
|
|
payload["source"] = source_id
|
|
|
|
try:
|
|
response = self.post("/api/create_agent", json=payload, timeout=10)
|
|
if response.status_code in [200, 201]:
|
|
result = response.json()
|
|
agent_id = result.get("id")
|
|
api_key = result.get("key")
|
|
if agent_id and api_key:
|
|
self._published_agent = (agent_id, api_key)
|
|
return self._published_agent
|
|
except Exception:
|
|
pass
|
|
|
|
return None
|
|
|
|
def _create_test_source(self) -> Optional[str]:
|
|
"""Create a simple test source and return its ID."""
|
|
if hasattr(self, "_test_source_id"):
|
|
return self._test_source_id
|
|
|
|
test_content = """# Test Documentation
|
|
## Overview
|
|
This is test documentation for integration tests.
|
|
## Features
|
|
- Feature 1: Testing
|
|
- Feature 2: Integration
|
|
"""
|
|
files = {"file": ("test_docs.txt", test_content.encode(), "text/plain")}
|
|
data = {"user": "test_user", "name": f"Chat Test Source {int(time.time())}"}
|
|
|
|
try:
|
|
response = self.post("/api/upload", files=files, data=data, timeout=30)
|
|
if response.status_code == 200:
|
|
task_id = response.json().get("task_id")
|
|
if task_id:
|
|
time.sleep(5) # Wait for processing
|
|
# Get source ID
|
|
sources_response = self.get("/api/sources")
|
|
if sources_response.status_code == 200:
|
|
sources = sources_response.json()
|
|
for source in sources:
|
|
if "Chat Test Source" in source.get("name", ""):
|
|
self._test_source_id = source.get("id")
|
|
return self._test_source_id
|
|
except Exception:
|
|
pass
|
|
|
|
return None
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Stream Endpoint Tests
|
|
# -------------------------------------------------------------------------
|
|
|
|
def test_stream_endpoint_no_agent(self) -> bool:
|
|
"""Test /stream endpoint without agent."""
|
|
test_name = "Stream endpoint (no agent)"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "What is DocsGPT?",
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
}
|
|
|
|
try:
|
|
self.print_info("POST /stream")
|
|
self.print_info(f"Payload: {json_module.dumps(payload, indent=2)}")
|
|
|
|
response = requests.post(
|
|
f"{self.base_url}/stream",
|
|
json=payload,
|
|
headers=self.headers,
|
|
stream=True,
|
|
timeout=30,
|
|
)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code != 200:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.print_error(f"Response: {response.text[:500]}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
# Parse SSE stream
|
|
events = []
|
|
full_response = ""
|
|
conversation_id = None
|
|
|
|
for line in response.iter_lines():
|
|
if line:
|
|
line_str = line.decode("utf-8")
|
|
if line_str.startswith("data: "):
|
|
data_str = line_str[6:]
|
|
try:
|
|
data = json_module.loads(data_str)
|
|
events.append(data)
|
|
|
|
if data.get("type") in ["stream", "answer"]:
|
|
full_response += data.get("message", "") or data.get("answer", "")
|
|
elif data.get("type") == "id":
|
|
conversation_id = data.get("id")
|
|
elif data.get("type") == "end":
|
|
break
|
|
except json_module.JSONDecodeError:
|
|
pass
|
|
|
|
self.print_success(f"Received {len(events)} events")
|
|
self.print_info(f"Response preview: {full_response[:100]}...")
|
|
|
|
if conversation_id:
|
|
self.print_success(f"Conversation ID: {conversation_id}")
|
|
|
|
self.record_result(test_name, True, "Success")
|
|
self.print_success(f"{test_name} passed!")
|
|
return True
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
self.print_error(f"Request failed: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
except Exception as e:
|
|
self.print_error(f"Unexpected error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
def test_stream_endpoint_with_agent(self) -> bool:
|
|
"""Test /stream endpoint with agent_id."""
|
|
test_name = "Stream endpoint (with agent)"
|
|
|
|
agent_result = self.get_or_create_test_agent()
|
|
if not agent_result:
|
|
if not self.require_auth(test_name):
|
|
return True # Skipped
|
|
self.print_warning("Could not create test agent")
|
|
self.record_result(test_name, True, "Skipped (no agent)")
|
|
return True
|
|
|
|
agent_id, _ = agent_result
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "What is DocsGPT?",
|
|
"history": "[]",
|
|
"agent_id": agent_id,
|
|
}
|
|
|
|
try:
|
|
self.print_info(f"POST /stream with agent_id={agent_id[:8]}...")
|
|
|
|
response = requests.post(
|
|
f"{self.base_url}/stream",
|
|
json=payload,
|
|
headers=self.headers,
|
|
stream=True,
|
|
timeout=30,
|
|
)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code != 200:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
events = []
|
|
for line in response.iter_lines():
|
|
if line:
|
|
line_str = line.decode("utf-8")
|
|
if line_str.startswith("data: "):
|
|
try:
|
|
data = json_module.loads(line_str[6:])
|
|
events.append(data)
|
|
if data.get("type") == "end":
|
|
break
|
|
except json_module.JSONDecodeError:
|
|
pass
|
|
|
|
self.print_success(f"Received {len(events)} events")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
def test_stream_endpoint_with_api_key(self) -> bool:
|
|
"""Test /stream endpoint with API key."""
|
|
test_name = "Stream endpoint (with API key)"
|
|
|
|
agent_result = self.get_or_create_published_agent()
|
|
if not agent_result or not agent_result[1]:
|
|
if not self.require_auth(test_name):
|
|
return True
|
|
self.print_warning("Could not create published agent with API key")
|
|
self.record_result(test_name, True, "Skipped (no API key)")
|
|
return True
|
|
|
|
_, api_key = agent_result
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "What is DocsGPT?",
|
|
"history": "[]",
|
|
"api_key": api_key,
|
|
}
|
|
|
|
try:
|
|
self.print_info(f"POST /stream with api_key={api_key[:20]}...")
|
|
|
|
response = requests.post(
|
|
f"{self.base_url}/stream",
|
|
json=payload,
|
|
headers=self.headers,
|
|
stream=True,
|
|
timeout=30,
|
|
)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code != 200:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
events = []
|
|
full_response = ""
|
|
for line in response.iter_lines():
|
|
if line:
|
|
line_str = line.decode("utf-8")
|
|
if line_str.startswith("data: "):
|
|
try:
|
|
data = json_module.loads(line_str[6:])
|
|
events.append(data)
|
|
if data.get("type") in ["stream", "answer"]:
|
|
full_response += data.get("message", "") or data.get("answer", "")
|
|
elif data.get("type") == "end":
|
|
break
|
|
except json_module.JSONDecodeError:
|
|
pass
|
|
|
|
self.print_success(f"Received {len(events)} events")
|
|
self.print_info(f"Response preview: {full_response[:100]}...")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Answer Endpoint Tests
|
|
# -------------------------------------------------------------------------
|
|
|
|
def test_answer_endpoint_no_agent(self) -> bool:
|
|
"""Test /api/answer endpoint without agent."""
|
|
test_name = "Answer endpoint (no agent)"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "What is DocsGPT?",
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
}
|
|
|
|
try:
|
|
self.print_info("POST /api/answer")
|
|
self.print_info(f"Payload: {json_module.dumps(payload, indent=2)}")
|
|
|
|
response = self.post("/api/answer", json=payload, timeout=30)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code != 200:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.print_error(f"Response: {response.text[:500]}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
result = response.json()
|
|
self.print_info(f"Response keys: {list(result.keys())}")
|
|
|
|
if "answer" in result:
|
|
answer = result["answer"]
|
|
self.print_success(f"Answer received: {answer[:100]}...")
|
|
else:
|
|
self.print_warning("No 'answer' field in response")
|
|
|
|
if "conversation_id" in result:
|
|
self.print_success(f"Conversation ID: {result['conversation_id']}")
|
|
|
|
self.record_result(test_name, True, "Success")
|
|
self.print_success(f"{test_name} passed!")
|
|
return True
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
def test_answer_endpoint_with_agent(self) -> bool:
|
|
"""Test /api/answer endpoint with agent_id."""
|
|
test_name = "Answer endpoint (with agent)"
|
|
|
|
agent_result = self.get_or_create_test_agent()
|
|
if not agent_result:
|
|
if not self.require_auth(test_name):
|
|
return True
|
|
self.record_result(test_name, True, "Skipped (no agent)")
|
|
return True
|
|
|
|
agent_id, _ = agent_result
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "What is DocsGPT?",
|
|
"history": "[]",
|
|
"agent_id": agent_id,
|
|
}
|
|
|
|
try:
|
|
self.print_info(f"POST /api/answer with agent_id={agent_id[:8]}...")
|
|
|
|
response = self.post("/api/answer", json=payload, timeout=30)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code != 200:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
result = response.json()
|
|
answer = result.get("answer", "")
|
|
self.print_success(f"Answer received: {answer[:100]}...")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
def test_answer_endpoint_with_api_key(self) -> bool:
|
|
"""Test /api/answer endpoint with API key."""
|
|
test_name = "Answer endpoint (with API key)"
|
|
|
|
agent_result = self.get_or_create_published_agent()
|
|
if not agent_result or not agent_result[1]:
|
|
if not self.require_auth(test_name):
|
|
return True
|
|
self.record_result(test_name, True, "Skipped (no API key)")
|
|
return True
|
|
|
|
_, api_key = agent_result
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "What is DocsGPT?",
|
|
"history": "[]",
|
|
"api_key": api_key,
|
|
}
|
|
|
|
try:
|
|
self.print_info(f"POST /api/answer with api_key={api_key[:20]}...")
|
|
|
|
response = self.post("/api/answer", json=payload, timeout=30)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code != 200:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
result = response.json()
|
|
answer = result.get("answer", "")
|
|
self.print_success(f"Answer received: {answer[:100]}...")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Validation Tests
|
|
# -------------------------------------------------------------------------
|
|
|
|
def test_model_validation_invalid_model_id(self) -> bool:
|
|
"""Test that invalid model_id is rejected."""
|
|
test_name = "Model validation (invalid model_id)"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {
|
|
"question": "Test question",
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
"model_id": "invalid-model-xyz-123",
|
|
}
|
|
|
|
try:
|
|
self.print_info("POST /stream with invalid model_id")
|
|
|
|
response = requests.post(
|
|
f"{self.base_url}/stream",
|
|
json=payload,
|
|
headers=self.headers,
|
|
stream=True,
|
|
timeout=10,
|
|
)
|
|
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code == 400:
|
|
# Read error from SSE stream
|
|
error_message = None
|
|
for line in response.iter_lines():
|
|
if line:
|
|
line_str = line.decode("utf-8")
|
|
if line_str.startswith("data: "):
|
|
try:
|
|
data = json_module.loads(line_str[6:])
|
|
if data.get("type") == "error":
|
|
error_message = data.get("message") or data.get("error", "")
|
|
break
|
|
except json_module.JSONDecodeError:
|
|
pass
|
|
|
|
if error_message:
|
|
self.print_success("Invalid model_id rejected with 400 status")
|
|
self.print_info(f"Error: {error_message[:200]}")
|
|
self.record_result(test_name, True, "Validation works")
|
|
return True
|
|
else:
|
|
self.print_warning("No error message in response")
|
|
self.record_result(test_name, False, "No error message")
|
|
return False
|
|
else:
|
|
self.print_warning(f"Expected 400, got {response.status_code}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Compression Tests
|
|
# -------------------------------------------------------------------------
|
|
|
|
def test_compression_heavy_tool_usage(self) -> bool:
|
|
"""Test compression with heavy conversation usage."""
|
|
test_name = "Compression - Heavy Tool Usage"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
if not self.require_auth(test_name):
|
|
return True
|
|
|
|
self.print_info("Making 10 consecutive requests to build conversation history...")
|
|
|
|
current_conv_id = None
|
|
|
|
for i in range(10):
|
|
question = f"Tell me about Python topic {i+1}: data structures, decorators, async, testing. Provide a comprehensive explanation."
|
|
|
|
payload = {
|
|
"question": question,
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
}
|
|
|
|
if current_conv_id:
|
|
payload["conversation_id"] = current_conv_id
|
|
|
|
try:
|
|
response = self.post("/api/answer", json=payload, timeout=90)
|
|
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
current_conv_id = result.get("conversation_id", current_conv_id)
|
|
answer_preview = (result.get("answer") or "")[:80]
|
|
self.print_success(f"Request {i+1}/10 completed")
|
|
self.print_info(f" Answer: {answer_preview}...")
|
|
else:
|
|
self.print_error(f"Request {i+1}/10 failed: status {response.status_code}")
|
|
self.record_result(test_name, False, f"Request {i+1} failed")
|
|
return False
|
|
|
|
time.sleep(2)
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Request {i+1}/10 failed: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
if current_conv_id:
|
|
self.print_success("Heavy usage test completed")
|
|
self.record_result(test_name, True, f"10 requests, conv_id: {current_conv_id}")
|
|
return True
|
|
else:
|
|
self.print_warning("No conversation_id received")
|
|
self.record_result(test_name, False, "No conversation_id")
|
|
return False
|
|
|
|
def test_compression_needle_in_haystack(self) -> bool:
|
|
"""Test that compression preserves critical information.
|
|
|
|
Note: This is a long-running test that may timeout due to LLM response times.
|
|
Timeouts are handled gracefully as they indicate performance issues, not bugs.
|
|
"""
|
|
test_name = "Compression - Needle in Haystack"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
if not self.require_auth(test_name):
|
|
return True
|
|
|
|
conversation_id = None
|
|
|
|
# Step 1: Send general questions
|
|
self.print_info("Step 1: Sending general questions...")
|
|
for i, question in enumerate([
|
|
"Tell me about Python best practices in detail",
|
|
"Explain Python data structures comprehensively",
|
|
]):
|
|
payload = {
|
|
"question": question,
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
}
|
|
if conversation_id:
|
|
payload["conversation_id"] = conversation_id
|
|
|
|
try:
|
|
response = self.post("/api/answer", json=payload, timeout=90)
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
conversation_id = result.get("conversation_id", conversation_id)
|
|
self.print_success(f"General question {i+1}/2 completed")
|
|
else:
|
|
self.print_error(f"Request failed: status {response.status_code}")
|
|
self.record_result(test_name, False, "General questions failed")
|
|
return False
|
|
time.sleep(2)
|
|
except Exception as e:
|
|
# Timeout errors are expected for long LLM responses
|
|
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
|
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
|
self.record_result(test_name, True, "Skipped (timeout)")
|
|
return True
|
|
self.print_error(f"Request failed: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# Step 2: Send critical information
|
|
self.print_info("Step 2: Sending CRITICAL information...")
|
|
critical_payload = {
|
|
"question": "Please remember: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily.",
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
"conversation_id": conversation_id,
|
|
}
|
|
|
|
try:
|
|
response = self.post("/api/answer", json=critical_payload, timeout=90)
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
conversation_id = result.get("conversation_id", conversation_id)
|
|
self.print_success("Critical information sent")
|
|
else:
|
|
self.record_result(test_name, False, "Critical info failed")
|
|
return False
|
|
time.sleep(2)
|
|
except Exception as e:
|
|
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
|
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
|
self.record_result(test_name, True, "Skipped (timeout)")
|
|
return True
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# Step 3: Bury with more questions
|
|
self.print_info("Step 3: Sending more questions to bury critical info...")
|
|
for i, question in enumerate([
|
|
"Explain Python decorators in great detail",
|
|
"Tell me about Python async programming comprehensively",
|
|
]):
|
|
payload = {
|
|
"question": question,
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
"conversation_id": conversation_id,
|
|
}
|
|
|
|
try:
|
|
response = self.post("/api/answer", json=payload, timeout=90)
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
conversation_id = result.get("conversation_id", conversation_id)
|
|
self.print_success(f"Burying question {i+1}/2 completed")
|
|
else:
|
|
self.record_result(test_name, False, "Burying questions failed")
|
|
return False
|
|
time.sleep(2)
|
|
except Exception as e:
|
|
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
|
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
|
self.record_result(test_name, True, "Skipped (timeout)")
|
|
return True
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# Step 4: Test recall
|
|
self.print_info("Step 4: Testing if critical info was preserved...")
|
|
recall_payload = {
|
|
"question": "What was the database password environment variable I mentioned earlier?",
|
|
"history": "[]",
|
|
"isNoneDoc": True,
|
|
"conversation_id": conversation_id,
|
|
}
|
|
|
|
try:
|
|
response = self.post("/api/answer", json=recall_payload, timeout=90)
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
answer = (result.get("answer") or "").lower()
|
|
|
|
if "db_password_prod" in answer or "database password" in answer:
|
|
self.print_success("Critical information preserved!")
|
|
self.print_info(f"Answer: {answer[:150]}...")
|
|
self.record_result(test_name, True, "Info preserved")
|
|
return True
|
|
else:
|
|
self.print_warning("Critical information may have been lost")
|
|
self.print_info(f"Answer: {answer[:150]}...")
|
|
self.record_result(test_name, False, "Info not preserved")
|
|
return False
|
|
else:
|
|
self.record_result(test_name, False, "Recall failed")
|
|
return False
|
|
except Exception as e:
|
|
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
|
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
|
self.record_result(test_name, True, "Skipped (timeout)")
|
|
return True
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Feedback Tests (NEW)
|
|
# -------------------------------------------------------------------------
|
|
|
|
def test_feedback_positive(self) -> bool:
|
|
"""Test positive feedback submission."""
|
|
test_name = "Feedback - Positive"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
# First create a conversation to get an ID
|
|
answer_response = self.post(
|
|
"/api/answer",
|
|
json={"question": "Hello", "history": "[]", "isNoneDoc": True},
|
|
timeout=30,
|
|
)
|
|
|
|
if answer_response.status_code != 200:
|
|
self.print_warning("Could not create conversation for feedback test")
|
|
self.record_result(test_name, True, "Skipped (no conversation)")
|
|
return True
|
|
|
|
result = answer_response.json()
|
|
conversation_id = result.get("conversation_id")
|
|
|
|
if not conversation_id:
|
|
self.record_result(test_name, True, "Skipped (no conversation_id)")
|
|
return True
|
|
|
|
payload = {
|
|
"question": "Hello",
|
|
"answer": result.get("answer", ""),
|
|
"feedback": "like",
|
|
"conversation_id": conversation_id,
|
|
"question_index": 0, # Required field
|
|
}
|
|
|
|
try:
|
|
response = self.post("/api/feedback", json=payload, timeout=10)
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code == 200:
|
|
self.print_success("Positive feedback submitted")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
else:
|
|
self.print_error(f"Expected 200, got {response.status_code}")
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
self.print_error(f"Error: {str(e)}")
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
def test_feedback_negative(self) -> bool:
|
|
"""Test negative feedback submission."""
|
|
test_name = "Feedback - Negative"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
answer_response = self.post(
|
|
"/api/answer",
|
|
json={"question": "Hello", "history": "[]", "isNoneDoc": True},
|
|
timeout=30,
|
|
)
|
|
|
|
if answer_response.status_code != 200:
|
|
self.record_result(test_name, True, "Skipped (no conversation)")
|
|
return True
|
|
|
|
result = answer_response.json()
|
|
conversation_id = result.get("conversation_id")
|
|
|
|
if not conversation_id:
|
|
self.record_result(test_name, True, "Skipped (no conversation_id)")
|
|
return True
|
|
|
|
payload = {
|
|
"question": "Hello",
|
|
"answer": result.get("answer", ""),
|
|
"feedback": "dislike",
|
|
"conversation_id": conversation_id,
|
|
"question_index": 0, # Required field
|
|
}
|
|
|
|
try:
|
|
response = self.post("/api/feedback", json=payload, timeout=10)
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code == 200:
|
|
self.print_success("Negative feedback submitted")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
else:
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# -------------------------------------------------------------------------
|
|
# TTS Tests (NEW)
|
|
# -------------------------------------------------------------------------
|
|
|
|
def test_tts_basic(self) -> bool:
|
|
"""Test basic text-to-speech endpoint."""
|
|
test_name = "TTS - Basic"
|
|
self.print_header(f"Testing {test_name}")
|
|
|
|
payload = {"text": "Hello, this is a test of the text to speech system."}
|
|
|
|
try:
|
|
response = self.post("/api/tts", json=payload, timeout=30)
|
|
self.print_info(f"Status Code: {response.status_code}")
|
|
|
|
if response.status_code == 200:
|
|
content_type = response.headers.get("Content-Type", "")
|
|
self.print_success(f"TTS response received, Content-Type: {content_type}")
|
|
self.record_result(test_name, True, "Success")
|
|
return True
|
|
elif response.status_code == 501:
|
|
self.print_warning("TTS not implemented/configured")
|
|
self.record_result(test_name, True, "Skipped (not configured)")
|
|
return True
|
|
else:
|
|
self.record_result(test_name, False, f"Status {response.status_code}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
self.record_result(test_name, False, str(e))
|
|
return False
|
|
|
|
# -------------------------------------------------------------------------
|
|
# Run All Tests
|
|
# -------------------------------------------------------------------------
|
|
|
|
def run_all(self) -> bool:
|
|
"""Run all chat integration tests."""
|
|
self.print_header("Chat Integration Tests")
|
|
self.print_info(f"Base URL: {self.base_url}")
|
|
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
|
|
|
|
# Basic endpoint tests
|
|
self.test_stream_endpoint_no_agent()
|
|
time.sleep(1)
|
|
|
|
self.test_answer_endpoint_no_agent()
|
|
time.sleep(1)
|
|
|
|
# Validation tests
|
|
self.test_model_validation_invalid_model_id()
|
|
time.sleep(1)
|
|
|
|
# Agent-based tests
|
|
self.test_stream_endpoint_with_agent()
|
|
time.sleep(1)
|
|
|
|
self.test_answer_endpoint_with_agent()
|
|
time.sleep(1)
|
|
|
|
# API key tests
|
|
self.test_stream_endpoint_with_api_key()
|
|
time.sleep(1)
|
|
|
|
self.test_answer_endpoint_with_api_key()
|
|
time.sleep(1)
|
|
|
|
# Feedback tests
|
|
self.test_feedback_positive()
|
|
time.sleep(1)
|
|
|
|
self.test_feedback_negative()
|
|
time.sleep(1)
|
|
|
|
# TTS test
|
|
self.test_tts_basic()
|
|
time.sleep(1)
|
|
|
|
# Compression tests (longer running)
|
|
if self.is_authenticated:
|
|
self.test_compression_heavy_tool_usage()
|
|
time.sleep(2)
|
|
|
|
self.test_compression_needle_in_haystack()
|
|
else:
|
|
self.print_info("Skipping compression tests (no authentication)")
|
|
|
|
return self.print_summary()
|
|
|
|
|
|
def main():
|
|
"""Main entry point for standalone execution."""
|
|
client = create_client_from_args(ChatTests, "DocsGPT Chat Integration Tests")
|
|
success = client.run_all()
|
|
sys.exit(0 if success else 1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|