feat: tests and approval gate

This commit is contained in:
Alex
2026-04-01 12:49:32 +01:00
parent e5586b6f20
commit e04baa7ed8
8 changed files with 985 additions and 192 deletions

View File

@@ -90,7 +90,7 @@ class ToolExecutor:
Args:
tools_dict: The mutable server tools dict (will be modified in place).
client_tools: List of tool definitions in OpenAI function-calling format.
client_tools: List of tool definitions in function-calling format.
Returns:
The updated *tools_dict* (same reference, for convenience).
@@ -138,7 +138,7 @@ class ToolExecutor:
if not action.get("active", True):
continue
# Client-side tools already have parameters in OpenAI format
# Client-side tools already have parameters in standard format
if is_client:
params = action.get("parameters", {})
else:
@@ -185,7 +185,7 @@ class ToolExecutor:
tool_data = tools_dict[tool_id]
# Phase 2: client-side tools
# Client-side tools
if tool_data.get("client_side"):
return {
"call_id": call_id,
@@ -198,7 +198,7 @@ class ToolExecutor:
"thought_signature": getattr(call, "thought_signature", None),
}
# Phase 3: approval required
# Approval required
if tool_data["name"] == "api_tool":
action_data = tool_data.get("config", {}).get("actions", {}).get(
action_name, {}

View File

@@ -294,36 +294,80 @@ class BaseAnswerResource:
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation and conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
if continuation:
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
conversation_id = (
self.conversation_service.save_conversation(
None,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
)
except Exception as e:
logger.error(
f"Failed to create conversation for continuation: {e}",
exc_info=True,
)
if conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)

View File

@@ -75,7 +75,7 @@ class ContinuationService:
tools_dict: Serializable tools configuration dict.
tool_schemas: LLM-formatted tool schemas (agent.tools).
agent_config: Config needed to recreate the agent on resume.
client_tools: Client-provided tool schemas (Phase 2).
client_tools: Client-provided tool schemas for client-side execution.
Returns:
The string ID of the saved state document.

View File

@@ -956,7 +956,7 @@ class StreamProcessor:
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
# Pass client-side tools (Phase 2) so they get merged in get_tools()
# Pass client-side tools so they get merged in get_tools()
client_tools = self.data.get("client_tools")
if client_tools:
tool_executor.client_tools = client_tools

View File

@@ -145,7 +145,7 @@ def translate_request(
"save_conversation": True,
}
# Client tools (Phase 2)
# Client tools
if data.get("tools"):
result["client_tools"] = data["tools"]

View File

@@ -890,6 +890,97 @@ function AllSources(sources: AllSourcesProps) {
}
export default ConversationBubble;
/**
 * Compact approval bar for a single tool call that is awaiting user approval.
 *
 * Shows the tool name, the action label and a one-line argument preview,
 * together with Approve / Deny buttons. The details panel (full arguments and
 * an optional deny reason) can be toggled with the chevron button.
 *
 * Deny is a two-step action while the panel is collapsed: the first click only
 * opens the panel so a reason can be entered; clicking Deny while the panel is
 * open sends the decision (with the reason, if one was typed).
 */
function ToolCallApprovalBar({
  toolCall,
  onToolAction,
}: {
  toolCall: ToolCallsType;
  onToolAction?: (
    callId: string,
    decision: 'approved' | 'denied',
    comment?: string,
  ) => void;
}) {
  const [expanded, setExpanded] = useState(false);
  const [comment, setComment] = useState('');

  // Only the part before the last underscore is displayed. When there is no
  // underscore, lastIndexOf returns -1 and substring(0, -1) would produce an
  // empty label, so fall back to the full action name in that case.
  const separatorIndex = toolCall.action_name.lastIndexOf('_');
  const actionLabel =
    separatorIndex > 0
      ? toolCall.action_name.substring(0, separatorIndex)
      : toolCall.action_name;

  // JSON.stringify returns undefined for undefined input; default to an empty
  // string so the length check below cannot throw.
  const argPreview = JSON.stringify(toolCall.arguments) ?? '';
  const truncated =
    argPreview.length > 60 ? argPreview.slice(0, 57) + '...' : argPreview;

  const handleDeny = () => {
    if (!expanded) {
      // First click: reveal the reason field instead of denying right away.
      setExpanded(true);
      return;
    }
    // An empty or whitespace-only reason is sent as "no comment".
    onToolAction?.(toolCall.call_id, 'denied', comment.trim() || undefined);
  };

  return (
    <div className="border-border bg-muted dark:bg-card mb-2 w-full overflow-hidden rounded-2xl border">
      <div className="flex items-center gap-3 px-4 py-2.5">
        <div className="flex min-w-0 flex-1 items-center gap-2">
          <span className="text-sm font-semibold whitespace-nowrap">
            {toolCall.tool_name}
          </span>
          <span className="text-muted-foreground text-xs">{actionLabel}</span>
          <span
            className="text-muted-foreground hidden min-w-0 truncate font-mono text-xs md:block"
            title={argPreview}
          >
            {truncated}
          </span>
        </div>
        <div className="flex items-center gap-2">
          <button
            type="button"
            className="bg-primary hover:bg-primary/90 rounded-full px-4 py-1 text-xs font-medium text-white"
            onClick={() => onToolAction?.(toolCall.call_id, 'approved')}
          >
            Approve
          </button>
          <button
            type="button"
            className="hover:bg-accent text-muted-foreground rounded-full border px-4 py-1 text-xs font-medium"
            onClick={handleDeny}
          >
            Deny
          </button>
          <button
            type="button"
            className="text-muted-foreground hover:text-foreground flex h-6 w-6 items-center justify-center rounded-full transition-colors"
            onClick={() => setExpanded((prev) => !prev)}
            title="Details"
            aria-expanded={expanded}
          >
            <img
              src={ChevronDown}
              alt="expand"
              className={`h-3.5 w-3.5 transition-transform duration-200 dark:invert ${expanded ? 'rotate-180' : ''}`}
            />
          </button>
        </div>
      </div>
      {expanded && (
        <div className="border-border border-t px-4 py-3">
          <p className="text-muted-foreground mb-1 text-xs font-medium">
            Arguments
          </p>
          <pre className="bg-background dark:bg-background/50 mb-2 max-h-40 overflow-auto rounded-lg p-2 font-mono text-xs">
            {JSON.stringify(toolCall.arguments, null, 2)}
          </pre>
          <input
            type="text"
            placeholder="Optional reason for denying..."
            className="border-border bg-background w-full rounded-lg border px-3 py-1.5 text-sm"
            value={comment}
            onChange={(e) => setComment(e.target.value)}
          />
        </div>
      )}
    </div>
  );
}
function ToolCalls({
toolCalls,
onToolAction,
@@ -902,170 +993,145 @@ function ToolCalls({
) => void;
}) {
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
const [denyComments, setDenyComments] = useState<Record<string, string>>({});
const hasAwaitingApproval = toolCalls.some(
const awaitingCalls = toolCalls.filter(
(tc) => tc.status === 'awaiting_approval',
);
const resolvedCalls = toolCalls.filter(
(tc) => tc.status !== 'awaiting_approval',
);
return (
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
{/* Approval bars — always visible, compact inline */}
{awaitingCalls.length > 0 && (
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
{awaitingCalls.map((tc) => (
<ToolCallApprovalBar
key={`approval-${tc.call_id}`}
toolCall={tc}
onToolAction={onToolAction}
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">
Tool Calls
{hasAwaitingApproval && (
<span className="ml-2 text-xs font-normal text-yellow-600 dark:text-yellow-400">
(approval needed)
</span>
)}
</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen || hasAwaitingApproval ? 'rotate-180' : ''}`}
/>
</button>
</div>
{(isToolCallsOpen || hasAwaitingApproval) && (
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
<div className="grid grid-cols-1 gap-2">
{toolCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
titleClassName="px-6 py-2 text-sm font-semibold"
open={toolCall.status === 'awaiting_approval'}
>
<div className="flex flex-col gap-1">
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Arguments
</span>{' '}
<CopyButton
textToCopy={JSON.stringify(toolCall.arguments, null, 2)}
/>
</p>
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
</span>
</p>
</div>
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Response
</span>{' '}
<CopyButton
textToCopy={
toolCall.status === 'error'
? toolCall.error || 'Unknown error'
: JSON.stringify(toolCall.result, null, 2)
}
/>
</p>
{toolCall.status === 'pending' && (
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
<Spinner size="small" />
</span>
)}
{toolCall.status === 'completed' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
</span>
</p>
)}
{toolCall.status === 'error' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="leading-[23px] text-red-500 dark:text-red-400"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
</span>
</p>
)}
{toolCall.status === 'awaiting_approval' && (
<div className="dark:bg-card flex flex-col gap-2 rounded-b-2xl p-3">
<p className="text-sm text-yellow-600 dark:text-yellow-400">
This tool requires your approval before executing.
</p>
<input
type="text"
placeholder="Optional comment (for deny)..."
className="border-border bg-background w-full rounded-lg border px-3 py-1.5 text-sm"
value={denyComments[toolCall.call_id] || ''}
onChange={(e) =>
setDenyComments((prev) => ({
...prev,
[toolCall.call_id]: e.target.value,
}))
}
/>
<div className="flex gap-2">
<button
className="rounded-lg bg-green-600 px-4 py-1.5 text-sm font-medium text-white hover:bg-green-700"
onClick={() =>
onToolAction?.(toolCall.call_id, 'approved')
}
>
Approve
</button>
<button
className="rounded-lg bg-red-600 px-4 py-1.5 text-sm font-medium text-white hover:bg-red-700"
onClick={() =>
onToolAction?.(
toolCall.call_id,
'denied',
denyComments[toolCall.call_id],
)
}
>
Deny
</button>
</div>
</div>
)}
{toolCall.status === 'denied' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="leading-[23px] text-orange-500 dark:text-orange-400"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
Denied by user
</span>
</p>
)}
</div>
</div>
</Accordion>
))}
</div>
))}
</div>
)}
{/* Regular tool calls accordion */}
{resolvedCalls.length > 0 && (
<>
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">Tool Calls</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
/>
</button>
</div>
{isToolCallsOpen && (
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
<div className="grid grid-cols-1 gap-2">
{resolvedCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
titleClassName="px-6 py-2 text-sm font-semibold"
>
<div className="flex flex-col gap-1">
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Arguments
</span>{' '}
<CopyButton
textToCopy={JSON.stringify(
toolCall.arguments,
null,
2,
)}
/>
</p>
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
</span>
</p>
</div>
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Response
</span>{' '}
<CopyButton
textToCopy={
toolCall.status === 'error'
? toolCall.error || 'Unknown error'
: JSON.stringify(toolCall.result, null, 2)
}
/>
</p>
{toolCall.status === 'pending' && (
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
<Spinner size="small" />
</span>
)}
{toolCall.status === 'completed' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
</span>
</p>
)}
{toolCall.status === 'error' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-destructive leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
</span>
</p>
)}
{toolCall.status === 'denied' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-muted-foreground leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
Denied by user
</span>
</p>
)}
</div>
</div>
</Accordion>
))}
</div>
</div>
)}
</>
)}
</div>
);
}

View File

@@ -50,6 +50,7 @@ from tests.integration.test_analytics import AnalyticsTests
from tests.integration.test_connectors import ConnectorTests
from tests.integration.test_mcp import MCPTests
from tests.integration.test_misc import MiscTests
from tests.integration.test_v1_api import V1ApiTests
# Module registry
@@ -64,6 +65,7 @@ MODULES = {
"connectors": ConnectorTests,
"mcp": MCPTests,
"misc": MiscTests,
"v1_api": V1ApiTests,
}

View File

@@ -0,0 +1,681 @@
#!/usr/bin/env python3
"""
Integration tests for the /v1/ chat completions API.
Endpoints tested:
- /v1/chat/completions (POST) - Standard chat completions (streaming & non-streaming)
- /v1/models (GET) - List available agent models
Usage:
python tests/integration/test_v1_api.py
python tests/integration/test_v1_api.py --base-url http://localhost:7091
python tests/integration/test_v1_api.py --token YOUR_JWT_TOKEN
"""
import json as json_module
import sys
import time
from pathlib import Path
from typing import Optional
import requests
# Standalone execution support: when this file is run directly
# (python tests/integration/test_v1_api.py) the directory two levels above
# this file's directory must be on sys.path for `tests.integration...` imports.
_THIS_DIR = Path(__file__).parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if sys.path.count(str(_ROOT_DIR)) == 0:
    sys.path.insert(0, str(_ROOT_DIR))
from tests.integration.base import DocsGPTTestBase, create_client_from_args
class V1ApiTests(DocsGPTTestBase):
"""Integration tests for /v1/ chat completions API."""
# -------------------------------------------------------------------------
# Test Data Helpers
# -------------------------------------------------------------------------
def get_or_create_agent_key(self) -> Optional[str]:
    """Return the API key of a test agent, creating the agent on first use.

    The creation request (``POST /api/create_agent``) is sent at most once per
    instance. Its outcome is remembered, so a failed attempt is not repeated
    by every test that needs a key. Repeating it would also overwrite
    ``self._agent_id`` whenever an agent was created without a key, leaving
    the earlier agent out of reach of ``cleanup``.

    Returns:
        The agent API key, or ``None`` when no key could be obtained.
    """
    cached_key = getattr(self, "_agent_key", None)
    if cached_key:
        return cached_key
    if getattr(self, "_agent_creation_attempted", False):
        # A previous attempt already failed; callers treat None as "skip".
        return None
    self._agent_creation_attempted = True

    # Published agents need a source to get an API key.
    payload = {
        "name": f"V1 Test Agent {int(time.time())}",
        "description": "Integration test agent for v1 API tests",
        "prompt_id": "default",
        "chunks": 2,
        "retriever": "classic",
        "agent_type": "classic",
        "status": "published",
        "source": "default",
    }
    try:
        response = self.post("/api/create_agent", json=payload, timeout=10)
        if response.status_code not in (200, 201):
            self.print_warning(f"Agent creation returned {response.status_code}: {response.text[:200]}")
            return None
        result = response.json()
        api_key = result.get("key")
        # Remember the id even without a key so cleanup can delete the agent.
        self._agent_id = result.get("id")
        if not api_key:
            self.print_warning("Agent created but no API key returned")
            return None
        self._agent_key = api_key
        self.print_info(f"Created test agent with key: {api_key[:8]}...")
        return api_key
    except Exception as e:
        # Best effort: tests that need a key skip themselves when None is returned.
        self.print_error(f"Failed to create agent: {e}")
    return None
def _v1_headers(self, api_key: str) -> dict:
"""Build headers for v1 API requests."""
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
# -------------------------------------------------------------------------
# /v1/chat/completions — Auth Tests
# -------------------------------------------------------------------------
def test_no_auth_returns_401(self) -> bool:
    """Verify that POST /v1/chat/completions without an Authorization header yields 401.

    Returns:
        True when the server answered 401; False for any other status or
        when the request itself raised.
    """
    test_name = "v1 chat completions - no auth"
    self.print_header(f"Testing {test_name}")
    try:
        # Deliberately sent with requests directly and only a content type,
        # so no credentials are attached.
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "Hi"}]},
            headers={"Content-Type": "application/json"},
            timeout=10,
        )
        if response.status_code != 401:
            self.print_error(f"Expected 401, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        self.print_success("Correctly returned 401 for missing auth")
        self.record_result(test_name, True, "401 as expected")
        return True
    except Exception as exc:
        self.print_error(f"Request failed: {exc}")
        self.record_result(test_name, False, str(exc))
        return False
def test_invalid_key_returns_error(self) -> bool:
    """Verify that an unknown API key is rejected by /v1/chat/completions.

    Any of the statuses 400, 401 or 500 counts as a rejection.

    Returns:
        True when one of the accepted error statuses came back; False for
        any other status or when the request itself raised.
    """
    test_name = "v1 chat completions - invalid key"
    self.print_header(f"Testing {test_name}")
    accepted_statuses = (400, 401, 500)
    try:
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "Hi"}]},
            headers=self._v1_headers("invalid-key-12345"),
            timeout=30,
        )
        if response.status_code not in accepted_statuses:
            self.print_error(f"Unexpected status: {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        self.print_success(f"Correctly returned {response.status_code} for invalid key")
        self.record_result(test_name, True, f"Error as expected ({response.status_code})")
        return True
    except Exception as exc:
        self.print_error(f"Request failed: {exc}")
        self.record_result(test_name, False, str(exc))
        return False
def test_missing_messages_returns_400(self) -> bool:
    """Verify that a request body without ``messages`` is answered with 400.

    When no agent key is available the test counts as skipped and returns
    True. ``require_auth`` comes from the base class; presumably it reports
    the skip itself when it returns a falsy value (confirm against the base
    class).

    Returns:
        True on a 400 response or when skipped; False otherwise.
    """
    test_name = "v1 chat completions - missing messages"
    self.print_header(f"Testing {test_name}")
    api_key = self.get_or_create_agent_key()
    if not api_key:
        if self.require_auth(test_name):
            self.record_result(test_name, True, "Skipped (no agent)")
        return True
    try:
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json={"stream": False},
            headers=self._v1_headers(api_key),
            timeout=10,
        )
        if response.status_code != 400:
            self.print_error(f"Expected 400, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        self.print_success("Correctly returned 400 for missing messages")
        self.record_result(test_name, True, "400 as expected")
        return True
    except Exception as exc:
        self.print_error(f"Request failed: {exc}")
        self.record_result(test_name, False, str(exc))
        return False
# -------------------------------------------------------------------------
# /v1/chat/completions — Non-streaming
# -------------------------------------------------------------------------
def test_non_streaming_basic(self) -> bool:
    """Send a non-streaming chat completion and validate the response shape.

    Every expectation is reported on its own line. Nested values are read
    defensively so that a malformed body (missing or empty ``choices``,
    missing ``message``) shows up as failed checks instead of aborting the
    whole test with a bare KeyError/IndexError.

    Returns:
        True when all checks pass or the test is skipped for lack of an
        agent key; False otherwise.
    """
    test_name = "v1 chat completions - non-streaming"
    self.print_header(f"Testing {test_name}")
    api_key = self.get_or_create_agent_key()
    if not api_key:
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True
    try:
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hello in one word."}],
                "stream": False,
            },
            headers=self._v1_headers(api_key),
            timeout=60,
        )
        self.print_info(f"Status: {response.status_code}")
        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.print_error(f"Response: {response.text[:300]}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        data = response.json()
        if not isinstance(data, dict):
            self.print_error(f"Expected a JSON object, got {type(data).__name__}")
            self.record_result(test_name, False, "Response is not a JSON object")
            return False

        # Resolve the nested values once, tolerating missing pieces.
        choices = data.get("choices")
        has_choices = isinstance(choices, list) and len(choices) > 0
        first_choice = choices[0] if has_choices and isinstance(choices[0], dict) else {}
        message = first_choice.get("message")
        if not isinstance(message, dict):
            message = {}
        content = message.get("content")

        # Verify standard format
        checks = [
            ("id" in data, "has id"),
            (data.get("object") == "chat.completion", "object is chat.completion"),
            ("choices" in data, "has choices"),
            (has_choices, "choices not empty"),
            (message.get("role") == "assistant", "role is assistant"),
            (content is not None, "has content"),
            (first_choice.get("finish_reason") == "stop", "finish_reason is stop"),
            ("usage" in data, "has usage"),
        ]
        all_passed = True
        for check, label in checks:
            if check:
                self.print_success(f" {label}")
            else:
                self.print_error(f" {label}")
                all_passed = False
        if content is not None:
            self.print_info(f"Response: {str(content)[:100]}")
        # Check docsgpt extension (informational only, does not affect the result)
        if "docsgpt" in data:
            self.print_success(" has docsgpt extension")
            extension = data["docsgpt"]
            if isinstance(extension, dict) and "conversation_id" in extension:
                self.print_success(f" conversation_id: {str(extension['conversation_id'])[:8]}...")
        self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
        return all_passed
    except Exception as e:
        self.print_error(f"Error: {e}")
        self.record_result(test_name, False, str(e))
        return False
# -------------------------------------------------------------------------
# /v1/chat/completions — Streaming
# -------------------------------------------------------------------------
def test_streaming_basic(self) -> bool:
    """Send a streaming chat completion and validate the server-sent events.

    The stream is read line by line. ``data: `` lines are parsed as JSON
    chunks until the ``[DONE]`` sentinel is seen. The test requires at least
    one chunk, at least one content delta, a ``finish_reason`` of ``stop``
    and the sentinel. A ``docsgpt`` extension chunk of type ``id`` is
    reported but not required.

    The response is used as a context manager so the connection is released
    even though reading stops early at the sentinel. Chunks with an empty
    ``choices`` list, a null ``delta`` or null content are tolerated instead
    of aborting the test.

    Returns:
        True when all checks pass or the test is skipped for lack of an
        agent key; False otherwise.
    """
    test_name = "v1 chat completions - streaming"
    self.print_header(f"Testing {test_name}")
    api_key = self.get_or_create_agent_key()
    if not api_key:
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True
    try:
        chunks = []
        content_pieces = []
        got_done = False
        got_stop = False
        got_id = False
        with requests.post(
            f"{self.base_url}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hi briefly."}],
                "stream": True,
            },
            headers=self._v1_headers(api_key),
            stream=True,
            timeout=60,
        ) as response:
            self.print_info(f"Status: {response.status_code}")
            if response.status_code != 200:
                self.print_error(f"Expected 200, got {response.status_code}")
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode("utf-8")
                if not line_str.startswith("data: "):
                    continue
                data_str = line_str[6:]
                if data_str.strip() == "[DONE]":
                    got_done = True
                    break
                try:
                    chunk = json_module.loads(data_str)
                except json_module.JSONDecodeError:
                    # Non-JSON data lines are ignored.
                    continue
                chunks.append(chunk)
                if not isinstance(chunk, dict):
                    continue
                # Standard chunks
                choices = chunk.get("choices")
                if isinstance(choices, list) and choices and isinstance(choices[0], dict):
                    first_choice = choices[0]
                    delta = first_choice.get("delta")
                    if isinstance(delta, dict) and isinstance(delta.get("content"), str):
                        content_pieces.append(delta["content"])
                    if first_choice.get("finish_reason") == "stop":
                        got_stop = True
                # Extension chunks
                ext = chunk.get("docsgpt")
                if isinstance(ext, dict) and ext.get("type") == "id":
                    got_id = True
        full_content = "".join(content_pieces)
        checks = [
            (len(chunks) > 0, f"received {len(chunks)} chunks"),
            (len(content_pieces) > 0, f"got content: {full_content[:50]}..."),
            (got_stop, "got finish_reason=stop"),
            (got_done, "got [DONE] sentinel"),
        ]
        all_passed = True
        for check, label in checks:
            if check:
                self.print_success(f" {label}")
            else:
                self.print_error(f" {label}")
                all_passed = False
        if got_id:
            self.print_success(" got conversation_id via docsgpt extension")
        self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
        return all_passed
    except Exception as e:
        self.print_error(f"Error: {e}")
        self.record_result(test_name, False, str(e))
        return False
# -------------------------------------------------------------------------
# /v1/chat/completions — Multi-turn conversation
# -------------------------------------------------------------------------
def test_multi_turn_conversation(self) -> bool:
    """Send a three-message history and verify that a non-empty answer comes back.

    The test passes when the assistant message has non-empty content. Whether
    the answer mentions "TestBot" from the history is reported as a warning
    only, because that depends on the model's wording.

    Returns:
        True when content was returned or the test is skipped for lack of an
        agent key; False otherwise.
    """
    test_name = "v1 chat completions - multi-turn"
    self.print_header(f"Testing {test_name}")
    api_key = self.get_or_create_agent_key()
    if not api_key:
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True
    try:
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json={
                "messages": [
                    {"role": "user", "content": "My name is TestBot."},
                    {"role": "assistant", "content": "Hello TestBot!"},
                    {"role": "user", "content": "What is my name?"},
                ],
                "stream": False,
            },
            headers=self._v1_headers(api_key),
            timeout=60,
        )
        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        data = response.json()
        # A null content is treated like an empty answer rather than a crash.
        content = data["choices"][0]["message"]["content"] or ""
        self.print_info(f"Response: {content[:150]}")
        has_content = bool(content)
        if has_content:
            self.print_success(f" Got response with {len(content)} chars")
            # The response should reference "TestBot" from the history;
            # report when it does not, without failing on model wording.
            if "testbot" not in str(content).lower():
                self.print_warning(" Response does not mention 'TestBot' from the history")
            self.record_result(test_name, True, "Multi-turn works")
        else:
            self.print_error(" Response content is empty")
            self.record_result(test_name, False, "Empty response content")
        return has_content
    except Exception as e:
        self.print_error(f"Error: {e}")
        self.record_result(test_name, False, str(e))
        return False
# -------------------------------------------------------------------------
# /v1/models
# -------------------------------------------------------------------------
def test_list_models(self) -> bool:
    """Call GET /v1/models with an agent key and validate the listing.

    The envelope must have ``object == "list"`` and a non-empty ``data``
    array; the first entry must have an ``id``, ``object == "model"`` and
    ``owned_by == "docsgpt"``.

    Returns:
        True when all checks pass or the test is skipped for lack of an
        agent key; False otherwise.
    """
    test_name = "v1 models - list"
    self.print_header(f"Testing {test_name}")
    api_key = self.get_or_create_agent_key()
    if not api_key:
        if self.require_auth(test_name):
            self.record_result(test_name, True, "Skipped (no agent)")
        return True

    def report(pairs) -> bool:
        # Print one line per (passed, label) pair; True only if every pair passed.
        ok = True
        for passed, label in pairs:
            if passed:
                self.print_success(f" {label}")
            else:
                self.print_error(f" {label}")
                ok = False
        return ok

    try:
        response = requests.get(
            f"{self.base_url}/v1/models",
            headers=self._v1_headers(api_key),
            timeout=10,
        )
        self.print_info(f"Status: {response.status_code}")
        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        data = response.json()
        all_passed = report(
            [
                (data.get("object") == "list", "object is list"),
                ("data" in data, "has data array"),
                (len(data.get("data", [])) > 0, f"has {len(data.get('data', []))} model(s)"),
            ]
        )
        if data.get("data"):
            # Only the first listed model is inspected.
            model = data["data"][0]
            model_ok = report(
                [
                    ("id" in model, "model has id"),
                    (model.get("object") == "model", "model object is 'model'"),
                    (model.get("owned_by") == "docsgpt", "owned_by is docsgpt"),
                ]
            )
            all_passed = all_passed and model_ok
        self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
        return all_passed
    except Exception as exc:
        self.print_error(f"Error: {exc}")
        self.record_result(test_name, False, str(exc))
        return False
def test_models_no_auth(self) -> bool:
    """Verify that GET /v1/models without any headers is answered with 401.

    Returns:
        True when the server answered 401; False for any other status or
        when the request itself raised.
    """
    test_name = "v1 models - no auth"
    self.print_header(f"Testing {test_name}")
    try:
        response = requests.get(
            f"{self.base_url}/v1/models",
            timeout=10,
        )
        if response.status_code != 401:
            self.print_error(f"Expected 401, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False
        self.print_success("Correctly returned 401")
        self.record_result(test_name, True, "401 as expected")
        return True
    except Exception as exc:
        self.print_error(f"Error: {exc}")
        self.record_result(test_name, False, str(exc))
        return False
# -------------------------------------------------------------------------
# Backward Compatibility — old endpoints still work
# -------------------------------------------------------------------------
def test_old_stream_endpoint_still_works(self) -> bool:
    """Verify that the pre-v1 /stream endpoint still emits answer and end events.

    ``data: `` lines of the streamed response are parsed as JSON events until
    an event of type ``end`` is seen. The test requires at least one event,
    at least one ``answer`` event and the ``end`` event.

    The response is used as a context manager so the connection is released
    even though reading stops early at the ``end`` event. Events that are not
    JSON objects are counted but otherwise ignored.

    Returns:
        True when all checks pass; False otherwise.
    """
    test_name = "Backward compat - /stream"
    self.print_header(f"Testing {test_name}")
    payload = {
        "question": "Say hello briefly.",
        "history": "[]",
        "isNoneDoc": True,
    }
    try:
        events = []
        got_end = False
        got_answer = False
        with requests.post(
            f"{self.base_url}/stream",
            json=payload,
            headers=self.headers,
            stream=True,
            timeout=60,
        ) as response:
            if response.status_code != 200:
                self.print_error(f"Expected 200, got {response.status_code}")
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode("utf-8")
                if not line_str.startswith("data: "):
                    continue
                try:
                    data = json_module.loads(line_str[6:])
                except json_module.JSONDecodeError:
                    # Non-JSON data lines are ignored.
                    continue
                events.append(data)
                event_type = data.get("type") if isinstance(data, dict) else None
                if event_type == "answer":
                    got_answer = True
                if event_type == "end":
                    got_end = True
                    break
        checks = [
            (len(events) > 0, f"received {len(events)} events"),
            (got_answer, "got answer event"),
            (got_end, "got end event"),
        ]
        all_passed = True
        for check, label in checks:
            if check:
                self.print_success(f" {label}")
            else:
                self.print_error(f" {label}")
                all_passed = False
        self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
        return all_passed
    except Exception as e:
        self.print_error(f"Error: {e}")
        self.record_result(test_name, False, str(e))
        return False
def test_old_answer_endpoint_still_works(self) -> bool:
    """Verify that the pre-v1 /api/answer endpoint still returns its JSON fields.

    The response must have status 200 and contain both ``answer`` and
    ``conversation_id``.

    Returns:
        True when both fields are present; False otherwise.
    """
    test_name = "Backward compat - /api/answer"
    self.print_header(f"Testing {test_name}")
    payload = {
        "question": "Say hi.",
        "history": "[]",
        "isNoneDoc": True,
    }
    try:
        response = requests.post(
            f"{self.base_url}/api/answer",
            json=payload,
            headers=self.headers,
            timeout=60,
        )
        status = response.status_code
        if status != 200:
            self.print_error(f"Expected 200, got {status}")
            self.record_result(test_name, False, f"Status {status}")
            return False
        data = response.json()
        # Evaluate both presence checks first, then report them in order.
        outcomes = [(key in data, f"has {key}") for key in ("answer", "conversation_id")]
        all_passed = True
        for present, label in outcomes:
            if present:
                self.print_success(f" {label}")
            else:
                self.print_error(f" {label}")
                all_passed = False
        self.print_info(f"Answer: {data.get('answer', '')[:100]}")
        verdict = "Old endpoint works" if all_passed else "Regression"
        self.record_result(test_name, all_passed, verdict)
        return all_passed
    except Exception as exc:
        self.print_error(f"Error: {exc}")
        self.record_result(test_name, False, str(exc))
        return False
# -------------------------------------------------------------------------
# Cleanup
# -------------------------------------------------------------------------
def cleanup(self):
    """Delete the agent created for this run, if any (best effort).

    Nothing happens when no agent id was recorded or the client is not
    authenticated. A failed request or a non-2xx response is reported as a
    warning and never raised, so cleanup cannot mask test results. The
    "Cleaned up" message is printed only when the delete request succeeded.
    """
    agent_id = getattr(self, "_agent_id", None)
    if not agent_id or not self.is_authenticated:
        return
    short_id = str(agent_id)[:8]
    try:
        response = self.post(f"/api/delete_agent?id={agent_id}", json={})
    except Exception as e:
        self.print_warning(f"Failed to clean up test agent {short_id}...: {e}")
        return
    status = getattr(response, "status_code", None)
    if status is not None and 200 <= status < 300:
        self.print_info(f"Cleaned up test agent {short_id}...")
    else:
        self.print_warning(f"Cleanup of test agent {short_id}... returned status {status}")
# -------------------------------------------------------------------------
# Run All
# -------------------------------------------------------------------------
def run_all(self) -> bool:
    """Run every v1 API test in a fixed order, then clean up and summarize.

    Each test is followed by a short pause. ``cleanup`` runs even if a test
    raises.

    Returns:
        Whatever ``print_summary`` returns.
    """
    self.print_header("V1 Chat Completions API Integration Tests")
    self.print_info(f"Base URL: {self.base_url}")
    self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
    try:
        # (test, pause in seconds after it), in execution order.
        schedule = (
            # Auth tests (no agent needed)
            (self.test_no_auth_returns_401, 0.5),
            (self.test_models_no_auth, 0.5),
            (self.test_invalid_key_returns_error, 0.5),
            (self.test_missing_messages_returns_400, 0.5),
            # Non-streaming
            (self.test_non_streaming_basic, 1),
            # Streaming
            (self.test_streaming_basic, 1),
            # Multi-turn
            (self.test_multi_turn_conversation, 1),
            # Models
            (self.test_list_models, 0.5),
            # Backward compatibility
            (self.test_old_stream_endpoint_still_works, 1),
            (self.test_old_answer_endpoint_still_works, 1),
        )
        for run_test, pause in schedule:
            run_test()
            time.sleep(pause)
    finally:
        self.cleanup()
    return self.print_summary()
def main():
    """Command-line entry point: run the V1 API suite and exit with its status.

    The process exits with code 0 when ``run_all`` reports success and with
    code 1 otherwise.
    """
    suite = create_client_from_args(V1ApiTests, "DocsGPT V1 API Integration Tests")
    exit_code = 0 if suite.run_all() else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    main()