mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
feat: tests and approval gate
This commit is contained in:
@@ -90,7 +90,7 @@ class ToolExecutor:
|
||||
|
||||
Args:
|
||||
tools_dict: The mutable server tools dict (will be modified in place).
|
||||
client_tools: List of tool definitions in OpenAI function-calling format.
|
||||
client_tools: List of tool definitions in function-calling format.
|
||||
|
||||
Returns:
|
||||
The updated *tools_dict* (same reference, for convenience).
|
||||
@@ -138,7 +138,7 @@ class ToolExecutor:
|
||||
if not action.get("active", True):
|
||||
continue
|
||||
|
||||
# Client-side tools already have parameters in OpenAI format
|
||||
# Client-side tools already have parameters in standard format
|
||||
if is_client:
|
||||
params = action.get("parameters", {})
|
||||
else:
|
||||
@@ -185,7 +185,7 @@ class ToolExecutor:
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
|
||||
# Phase 2: client-side tools
|
||||
# Client-side tools
|
||||
if tool_data.get("client_side"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
@@ -198,7 +198,7 @@ class ToolExecutor:
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
# Phase 3: approval required
|
||||
# Approval required
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_data = tool_data.get("config", {}).get("actions", {}).get(
|
||||
action_name, {}
|
||||
|
||||
@@ -294,36 +294,80 @@ class BaseAnswerResource:
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation and conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
conversation_id = (
|
||||
self.conversation_service.save_conversation(
|
||||
None,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create conversation for continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
|
||||
@@ -75,7 +75,7 @@ class ContinuationService:
|
||||
tools_dict: Serializable tools configuration dict.
|
||||
tool_schemas: LLM-formatted tool schemas (agent.tools).
|
||||
agent_config: Config needed to recreate the agent on resume.
|
||||
client_tools: Client-provided tool schemas (Phase 2).
|
||||
client_tools: Client-provided tool schemas for client-side execution.
|
||||
|
||||
Returns:
|
||||
The string ID of the saved state document.
|
||||
|
||||
@@ -956,7 +956,7 @@ class StreamProcessor:
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
tool_executor.conversation_id = self.conversation_id
|
||||
# Pass client-side tools (Phase 2) so they get merged in get_tools()
|
||||
# Pass client-side tools so they get merged in get_tools()
|
||||
client_tools = self.data.get("client_tools")
|
||||
if client_tools:
|
||||
tool_executor.client_tools = client_tools
|
||||
|
||||
@@ -145,7 +145,7 @@ def translate_request(
|
||||
"save_conversation": True,
|
||||
}
|
||||
|
||||
# Client tools (Phase 2)
|
||||
# Client tools
|
||||
if data.get("tools"):
|
||||
result["client_tools"] = data["tools"]
|
||||
|
||||
|
||||
@@ -890,6 +890,97 @@ function AllSources(sources: AllSourcesProps) {
|
||||
}
|
||||
export default ConversationBubble;
|
||||
|
||||
function ToolCallApprovalBar({
|
||||
toolCall,
|
||||
onToolAction,
|
||||
}: {
|
||||
toolCall: ToolCallsType;
|
||||
onToolAction?: (
|
||||
callId: string,
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [comment, setComment] = useState('');
|
||||
const actionLabel = toolCall.action_name.substring(
|
||||
0,
|
||||
toolCall.action_name.lastIndexOf('_'),
|
||||
);
|
||||
const argPreview = JSON.stringify(toolCall.arguments);
|
||||
const truncated =
|
||||
argPreview.length > 60 ? argPreview.slice(0, 57) + '...' : argPreview;
|
||||
|
||||
return (
|
||||
<div className="border-border bg-muted dark:bg-card mb-2 w-full overflow-hidden rounded-2xl border">
|
||||
<div className="flex items-center gap-3 px-4 py-2.5">
|
||||
<div className="flex min-w-0 flex-1 items-center gap-2">
|
||||
<span className="text-sm font-semibold whitespace-nowrap">
|
||||
{toolCall.tool_name}
|
||||
</span>
|
||||
<span className="text-muted-foreground text-xs">{actionLabel}</span>
|
||||
<span
|
||||
className="text-muted-foreground hidden min-w-0 truncate font-mono text-xs md:block"
|
||||
title={argPreview}
|
||||
>
|
||||
{truncated}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
className="bg-primary hover:bg-primary/90 rounded-full px-4 py-1 text-xs font-medium text-white"
|
||||
onClick={() => onToolAction?.(toolCall.call_id, 'approved')}
|
||||
>
|
||||
Approve
|
||||
</button>
|
||||
<button
|
||||
className="hover:bg-accent text-muted-foreground rounded-full border px-4 py-1 text-xs font-medium"
|
||||
onClick={() => {
|
||||
if (expanded && comment) {
|
||||
onToolAction?.(toolCall.call_id, 'denied', comment);
|
||||
} else if (expanded) {
|
||||
onToolAction?.(toolCall.call_id, 'denied');
|
||||
} else {
|
||||
setExpanded(true);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Deny
|
||||
</button>
|
||||
<button
|
||||
className="text-muted-foreground hover:text-foreground flex h-6 w-6 items-center justify-center rounded-full transition-colors"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
title="Details"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt="expand"
|
||||
className={`h-3.5 w-3.5 transition-transform duration-200 dark:invert ${expanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{expanded && (
|
||||
<div className="border-border border-t px-4 py-3">
|
||||
<p className="text-muted-foreground mb-1 text-xs font-medium">
|
||||
Arguments
|
||||
</p>
|
||||
<pre className="bg-background dark:bg-background/50 mb-2 max-h-40 overflow-auto rounded-lg p-2 font-mono text-xs">
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</pre>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Optional reason for denying..."
|
||||
className="border-border bg-background w-full rounded-lg border px-3 py-1.5 text-sm"
|
||||
value={comment}
|
||||
onChange={(e) => setComment(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolCalls({
|
||||
toolCalls,
|
||||
onToolAction,
|
||||
@@ -902,170 +993,145 @@ function ToolCalls({
|
||||
) => void;
|
||||
}) {
|
||||
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
|
||||
const [denyComments, setDenyComments] = useState<Record<string, string>>({});
|
||||
|
||||
const hasAwaitingApproval = toolCalls.some(
|
||||
const awaitingCalls = toolCalls.filter(
|
||||
(tc) => tc.status === 'awaiting_approval',
|
||||
);
|
||||
const resolvedCalls = toolCalls.filter(
|
||||
(tc) => tc.status !== 'awaiting_approval',
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
alt={'ToolCalls'}
|
||||
className="h-full w-full object-fill"
|
||||
{/* Approval bars — always visible, compact inline */}
|
||||
{awaitingCalls.length > 0 && (
|
||||
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
{awaitingCalls.map((tc) => (
|
||||
<ToolCallApprovalBar
|
||||
key={`approval-${tc.call_id}`}
|
||||
toolCall={tc}
|
||||
onToolAction={onToolAction}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<button
|
||||
className="flex flex-row items-center gap-2"
|
||||
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
|
||||
>
|
||||
<p className="text-base font-semibold">
|
||||
Tool Calls
|
||||
{hasAwaitingApproval && (
|
||||
<span className="ml-2 text-xs font-normal text-yellow-600 dark:text-yellow-400">
|
||||
(approval needed)
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt="ChevronDown"
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen || hasAwaitingApproval ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
{(isToolCallsOpen || hasAwaitingApproval) && (
|
||||
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
<div className="grid grid-cols-1 gap-2">
|
||||
{toolCalls.map((toolCall, index) => (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
open={toolCall.status === 'awaiting_approval'}
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(toolCall.arguments, null, 2)}
|
||||
/>
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
{toolCall.status === 'pending' && (
|
||||
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
|
||||
<Spinner size="small" />
|
||||
</span>
|
||||
)}
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="leading-[23px] text-red-500 dark:text-red-400"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'awaiting_approval' && (
|
||||
<div className="dark:bg-card flex flex-col gap-2 rounded-b-2xl p-3">
|
||||
<p className="text-sm text-yellow-600 dark:text-yellow-400">
|
||||
This tool requires your approval before executing.
|
||||
</p>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Optional comment (for deny)..."
|
||||
className="border-border bg-background w-full rounded-lg border px-3 py-1.5 text-sm"
|
||||
value={denyComments[toolCall.call_id] || ''}
|
||||
onChange={(e) =>
|
||||
setDenyComments((prev) => ({
|
||||
...prev,
|
||||
[toolCall.call_id]: e.target.value,
|
||||
}))
|
||||
}
|
||||
/>
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
className="rounded-lg bg-green-600 px-4 py-1.5 text-sm font-medium text-white hover:bg-green-700"
|
||||
onClick={() =>
|
||||
onToolAction?.(toolCall.call_id, 'approved')
|
||||
}
|
||||
>
|
||||
Approve
|
||||
</button>
|
||||
<button
|
||||
className="rounded-lg bg-red-600 px-4 py-1.5 text-sm font-medium text-white hover:bg-red-700"
|
||||
onClick={() =>
|
||||
onToolAction?.(
|
||||
toolCall.call_id,
|
||||
'denied',
|
||||
denyComments[toolCall.call_id],
|
||||
)
|
||||
}
|
||||
>
|
||||
Deny
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="leading-[23px] text-orange-500 dark:text-orange-400"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Accordion>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Regular tool calls accordion */}
|
||||
{resolvedCalls.length > 0 && (
|
||||
<>
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
alt={'ToolCalls'}
|
||||
className="h-full w-full object-fill"
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<button
|
||||
className="flex flex-row items-center gap-2"
|
||||
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
|
||||
>
|
||||
<p className="text-base font-semibold">Tool Calls</p>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt="ChevronDown"
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
{isToolCallsOpen && (
|
||||
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
<div className="grid grid-cols-1 gap-2">
|
||||
{resolvedCalls.map((toolCall, index) => (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(
|
||||
toolCall.arguments,
|
||||
null,
|
||||
2,
|
||||
)}
|
||||
/>
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
{toolCall.status === 'pending' && (
|
||||
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
|
||||
<Spinner size="small" />
|
||||
</span>
|
||||
)}
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-destructive leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-muted-foreground leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Accordion>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ from tests.integration.test_analytics import AnalyticsTests
|
||||
from tests.integration.test_connectors import ConnectorTests
|
||||
from tests.integration.test_mcp import MCPTests
|
||||
from tests.integration.test_misc import MiscTests
|
||||
from tests.integration.test_v1_api import V1ApiTests
|
||||
|
||||
|
||||
# Module registry
|
||||
@@ -64,6 +65,7 @@ MODULES = {
|
||||
"connectors": ConnectorTests,
|
||||
"mcp": MCPTests,
|
||||
"misc": MiscTests,
|
||||
"v1_api": V1ApiTests,
|
||||
}
|
||||
|
||||
|
||||
|
||||
681
tests/integration/test_v1_api.py
Normal file
681
tests/integration/test_v1_api.py
Normal file
@@ -0,0 +1,681 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for the /v1/ chat completions API (Phase 4).
|
||||
|
||||
Endpoints tested:
|
||||
- /v1/chat/completions (POST) - Standard chat completions (streaming & non-streaming)
|
||||
- /v1/models (GET) - List available agent models
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_v1_api.py
|
||||
python tests/integration/test_v1_api.py --base-url http://localhost:7091
|
||||
python tests/integration/test_v1_api.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import json as json_module
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class V1ApiTests(DocsGPTTestBase):
|
||||
"""Integration tests for /v1/ chat completions API."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_agent_key(self) -> Optional[str]:
|
||||
"""Get or create a test agent and return its API key."""
|
||||
if hasattr(self, "_agent_key") and self._agent_key:
|
||||
return self._agent_key
|
||||
|
||||
# Try both authenticated and unauthenticated creation.
|
||||
# Published agents need a source to get an API key.
|
||||
payload = {
|
||||
"name": f"V1 Test Agent {int(time.time())}",
|
||||
"description": "Integration test agent for v1 API tests",
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"retriever": "classic",
|
||||
"agent_type": "classic",
|
||||
"status": "published",
|
||||
"source": "default",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_agent", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
api_key = result.get("key")
|
||||
self._agent_id = result.get("id")
|
||||
if api_key:
|
||||
self._agent_key = api_key
|
||||
self.print_info(f"Created test agent with key: {api_key[:8]}...")
|
||||
return api_key
|
||||
else:
|
||||
self.print_warning("Agent created but no API key returned")
|
||||
else:
|
||||
self.print_warning(f"Agent creation returned {response.status_code}: {response.text[:200]}")
|
||||
except Exception as e:
|
||||
self.print_error(f"Failed to create agent: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _v1_headers(self, api_key: str) -> dict:
|
||||
"""Build headers for v1 API requests."""
|
||||
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Auth Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_no_auth_returns_401(self) -> bool:
|
||||
"""Test that /v1/chat/completions without auth returns 401."""
|
||||
test_name = "v1 chat completions - no auth"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "Hi"}]},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
self.print_success("Correctly returned 401 for missing auth")
|
||||
self.record_result(test_name, True, "401 as expected")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 401, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Request failed: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_invalid_key_returns_error(self) -> bool:
|
||||
"""Test that invalid API key returns error."""
|
||||
test_name = "v1 chat completions - invalid key"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "Hi"}]},
|
||||
headers=self._v1_headers("invalid-key-12345"),
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
# Should return 400 or 500 (agent not found)
|
||||
if response.status_code in [400, 401, 500]:
|
||||
self.print_success(f"Correctly returned {response.status_code} for invalid key")
|
||||
self.record_result(test_name, True, f"Error as expected ({response.status_code})")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Request failed: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_missing_messages_returns_400(self) -> bool:
|
||||
"""Test that missing messages field returns 400."""
|
||||
test_name = "v1 chat completions - missing messages"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={"stream": False},
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 400:
|
||||
self.print_success("Correctly returned 400 for missing messages")
|
||||
self.record_result(test_name, True, "400 as expected")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 400, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Request failed: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Non-streaming
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_non_streaming_basic(self) -> bool:
|
||||
"""Test basic non-streaming chat completion."""
|
||||
test_name = "v1 chat completions - non-streaming"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello in one word."}],
|
||||
"stream": False,
|
||||
},
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
self.print_info(f"Status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:300]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Verify standard format
|
||||
checks = [
|
||||
("id" in data, "has id"),
|
||||
(data.get("object") == "chat.completion", "object is chat.completion"),
|
||||
("choices" in data, "has choices"),
|
||||
(len(data["choices"]) > 0, "choices not empty"),
|
||||
(data["choices"][0].get("message", {}).get("role") == "assistant", "role is assistant"),
|
||||
(data["choices"][0].get("message", {}).get("content") is not None, "has content"),
|
||||
(data["choices"][0].get("finish_reason") == "stop", "finish_reason is stop"),
|
||||
("usage" in data, "has usage"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
self.print_info(f"Response: {content[:100]}")
|
||||
|
||||
# Check docsgpt extension
|
||||
if "docsgpt" in data:
|
||||
self.print_success(" has docsgpt extension")
|
||||
if "conversation_id" in data["docsgpt"]:
|
||||
self.print_success(f" conversation_id: {data['docsgpt']['conversation_id'][:8]}...")
|
||||
|
||||
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Streaming
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_streaming_basic(self) -> bool:
|
||||
"""Test basic streaming chat completion."""
|
||||
test_name = "v1 chat completions - streaming"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hi briefly."}],
|
||||
"stream": True,
|
||||
},
|
||||
headers=self._v1_headers(api_key),
|
||||
stream=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
self.print_info(f"Status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
chunks = []
|
||||
content_pieces = []
|
||||
got_done = False
|
||||
got_stop = False
|
||||
got_id = False
|
||||
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8")
|
||||
if not line_str.startswith("data: "):
|
||||
continue
|
||||
|
||||
data_str = line_str[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
got_done = True
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json_module.loads(data_str)
|
||||
chunks.append(chunk)
|
||||
|
||||
# Standard chunks
|
||||
if "choices" in chunk:
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
if "content" in delta:
|
||||
content_pieces.append(delta["content"])
|
||||
if chunk["choices"][0].get("finish_reason") == "stop":
|
||||
got_stop = True
|
||||
|
||||
# Extension chunks
|
||||
if "docsgpt" in chunk:
|
||||
ext = chunk["docsgpt"]
|
||||
if ext.get("type") == "id":
|
||||
got_id = True
|
||||
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
full_content = "".join(content_pieces)
|
||||
|
||||
checks = [
|
||||
(len(chunks) > 0, f"received {len(chunks)} chunks"),
|
||||
(len(content_pieces) > 0, f"got content: {full_content[:50]}..."),
|
||||
(got_stop, "got finish_reason=stop"),
|
||||
(got_done, "got [DONE] sentinel"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
if got_id:
|
||||
self.print_success(" got conversation_id via docsgpt extension")
|
||||
|
||||
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Multi-turn conversation
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_multi_turn_conversation(self) -> bool:
|
||||
"""Test multi-turn conversation with history in messages."""
|
||||
test_name = "v1 chat completions - multi-turn"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "user", "content": "My name is TestBot."},
|
||||
{"role": "assistant", "content": "Hello TestBot!"},
|
||||
{"role": "user", "content": "What is my name?"},
|
||||
],
|
||||
"stream": False,
|
||||
},
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
self.print_info(f"Response: {content[:150]}")
|
||||
|
||||
# The response should reference "TestBot" from the history
|
||||
has_content = bool(content)
|
||||
self.print_success(f" Got response with {len(content)} chars")
|
||||
self.record_result(test_name, has_content, "Multi-turn works")
|
||||
return has_content
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/models
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_list_models(self) -> bool:
|
||||
"""Test GET /v1/models endpoint."""
|
||||
test_name = "v1 models - list"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/v1/models",
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
self.print_info(f"Status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
|
||||
checks = [
|
||||
(data.get("object") == "list", "object is list"),
|
||||
("data" in data, "has data array"),
|
||||
(len(data.get("data", [])) > 0, f"has {len(data.get('data', []))} model(s)"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
if data.get("data"):
|
||||
model = data["data"][0]
|
||||
model_checks = [
|
||||
("id" in model, "model has id"),
|
||||
(model.get("object") == "model", "model object is 'model'"),
|
||||
(model.get("owned_by") == "docsgpt", "owned_by is docsgpt"),
|
||||
]
|
||||
for check, label in model_checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_models_no_auth(self) -> bool:
|
||||
"""Test that /v1/models without auth returns 401."""
|
||||
test_name = "v1 models - no auth"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/v1/models",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
self.print_success("Correctly returned 401")
|
||||
self.record_result(test_name, True, "401 as expected")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 401, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Backward Compatibility — old endpoints still work
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_old_stream_endpoint_still_works(self) -> bool:
|
||||
"""Verify the old /stream endpoint still works after v1 changes."""
|
||||
test_name = "Backward compat - /stream"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "Say hello briefly.",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/stream",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
stream=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
events = []
|
||||
got_end = False
|
||||
got_answer = False
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
try:
|
||||
data = json_module.loads(line_str[6:])
|
||||
events.append(data)
|
||||
if data.get("type") == "answer":
|
||||
got_answer = True
|
||||
if data.get("type") == "end":
|
||||
got_end = True
|
||||
break
|
||||
except json_module.JSONDecodeError:
|
||||
pass
|
||||
|
||||
checks = [
|
||||
(len(events) > 0, f"received {len(events)} events"),
|
||||
(got_answer, "got answer event"),
|
||||
(got_end, "got end event"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_old_answer_endpoint_still_works(self) -> bool:
|
||||
"""Verify the old /api/answer endpoint still works."""
|
||||
test_name = "Backward compat - /api/answer"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "Say hi.",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/api/answer",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
checks = [
|
||||
("answer" in data, "has answer"),
|
||||
("conversation_id" in data, "has conversation_id"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
self.print_info(f"Answer: {data.get('answer', '')[:100]}")
|
||||
self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up test resources."""
|
||||
if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated:
|
||||
try:
|
||||
self.post(f"/api/delete_agent?id={self._agent_id}", json={})
|
||||
self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Run All
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def run_all(self) -> bool:
|
||||
"""Run all v1 API integration tests."""
|
||||
self.print_header("V1 Chat Completions API Integration Tests")
|
||||
self.print_info(f"Base URL: {self.base_url}")
|
||||
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
|
||||
|
||||
try:
|
||||
# Auth tests (no agent needed)
|
||||
self.test_no_auth_returns_401()
|
||||
time.sleep(0.5)
|
||||
|
||||
self.test_models_no_auth()
|
||||
time.sleep(0.5)
|
||||
|
||||
self.test_invalid_key_returns_error()
|
||||
time.sleep(0.5)
|
||||
|
||||
self.test_missing_messages_returns_400()
|
||||
time.sleep(0.5)
|
||||
|
||||
# Non-streaming
|
||||
self.test_non_streaming_basic()
|
||||
time.sleep(1)
|
||||
|
||||
# Streaming
|
||||
self.test_streaming_basic()
|
||||
time.sleep(1)
|
||||
|
||||
# Multi-turn
|
||||
self.test_multi_turn_conversation()
|
||||
time.sleep(1)
|
||||
|
||||
# Models
|
||||
self.test_list_models()
|
||||
time.sleep(0.5)
|
||||
|
||||
# Backward compatibility
|
||||
self.test_old_stream_endpoint_still_works()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_old_answer_endpoint_still_works()
|
||||
time.sleep(1)
|
||||
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
client = create_client_from_args(V1ApiTests, "DocsGPT V1 API Integration Tests")
|
||||
success = client.run_all()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user