mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generator, List, Optional
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.core.json_schema_utils import (
|
||||
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.base import ToolCall
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
@@ -113,6 +115,153 @@ class BaseAgent(ABC):
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
def gen_continuation(
    self,
    messages: List[Dict],
    tools_dict: Dict,
    pending_tool_calls: List[Dict],
    tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
    """Resume generation after tool actions are resolved.

    Processes the client-provided *tool_actions* (approvals, denials,
    or client-side results), appends the resulting messages, then
    hands back to the LLM to continue the conversation.

    Args:
        messages: The saved messages array from the pause point.
        tools_dict: The saved tools dictionary.
        pending_tool_calls: The pending tool call descriptors from the pause.
        tool_actions: Client-provided actions resolving the pending calls.

    Yields:
        Streaming event dicts: ``tool_call`` status events, the resumed
        LLM output, then final ``sources`` and ``tool_calls`` summaries.
    """
    self._prepare_tools(tools_dict)

    actions_by_id = {a["call_id"]: a for a in tool_actions}

    # Build a single assistant message containing all tool calls so
    # the message history matches the format LLM providers expect
    # (one assistant message with N tool_calls, followed by N tool results).
    tc_objects: List[Dict[str, Any]] = []
    for pending in pending_tool_calls:
        call_id = pending["call_id"]
        args = pending["arguments"]
        # Providers expect "arguments" as a JSON string, not a dict.
        args_str = (
            json.dumps(args) if isinstance(args, dict) else (args or "{}")
        )
        tc_obj: Dict[str, Any] = {
            "id": call_id,
            "type": "function",
            "function": {
                "name": pending["name"],
                "arguments": args_str,
            },
        }
        # Preserved so providers that require it (Gemini-style thought
        # signatures) can round-trip it — NOTE(review): confirm consumer.
        if pending.get("thought_signature"):
            tc_obj["thought_signature"] = pending["thought_signature"]
        tc_objects.append(tc_obj)

    messages.append({
        "role": "assistant",
        "content": None,
        "tool_calls": tc_objects,
    })

    # Now process each pending call and append tool result messages
    for pending in pending_tool_calls:
        call_id = pending["call_id"]
        args = pending["arguments"]
        action = actions_by_id.get(call_id)
        if not action:
            # No client decision for this call — treat it as denied so
            # every tool_call in the assistant message still gets a
            # matching tool-result message.
            action = {
                "call_id": call_id,
                "decision": "denied",
                "comment": "No response provided",
            }

        if action.get("decision") == "approved":
            # Execute the tool server-side
            tc = ToolCall(
                id=call_id,
                name=pending["name"],
                arguments=(
                    json.dumps(args) if isinstance(args, dict) else args
                ),
            )
            tool_gen = self._execute_tool_action(tools_dict, tc)
            tool_response = None
            # Drain the executor generator, forwarding its status events;
            # the final (result, call_id) pair arrives via
            # StopIteration.value when the generator returns.
            while True:
                try:
                    event = next(tool_gen)
                    yield event
                except StopIteration as e:
                    tool_response, _ = e.value
                    break
            messages.append(
                self.llm_handler.create_tool_message(tc, tool_response)
            )

        elif action.get("decision") == "denied":
            comment = action.get("comment", "")
            denial = (
                f"Tool execution denied by user. Reason: {comment}"
                if comment
                else "Tool execution denied by user."
            )
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(tc, denial)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": pending.get("llm_name", pending["name"]),
                    "arguments": args,
                    "status": "denied",
                },
            }

        elif "result" in action:
            # The client executed the tool itself and sent back the result.
            result = action["result"]
            result_str = (
                json.dumps(result)
                if not isinstance(result, str)
                else result
            )
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(tc, result_str)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": pending.get("llm_name", pending["name"]),
                    "arguments": args,
                    # Truncated for the status event only; the full result
                    # was already appended to the message history above.
                    "result": (
                        result_str[:50] + "..."
                        if len(result_str) > 50
                        else result_str
                    ),
                    "status": "completed",
                },
            }

    # Resume the LLM loop with the updated messages
    llm_response = self._llm_gen(messages)
    yield from self._handle_response(
        llm_response, tools_dict, messages, None
    )

    yield {"sources": self.retrieved_docs}
    yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||
|
||||
@property
|
||||
@@ -267,28 +416,35 @@ class BaseAgent(ABC):
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
|
||||
|
||||
@@ -593,16 +593,22 @@ class ResearchAgent(BaseAgent):
|
||||
)
|
||||
result = result_str
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_content]}
|
||||
import json as _json
|
||||
|
||||
args_str = (
|
||||
_json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else call.arguments
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": call.name, "arguments": args_str},
|
||||
}],
|
||||
})
|
||||
tool_message = self.llm_handler.create_tool_message(call, result)
|
||||
messages.append(tool_message)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
@@ -31,12 +32,23 @@ class ToolExecutor:
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
|
||||
def get_tools(self) -> Dict[str, Dict]:
    """Load tool configs from DB based on user context.

    If *client_tools* have been set on this executor, they are
    automatically merged into the returned dict.
    """
    # API-key context takes precedence over the per-user tool list.
    if self.user_api_key:
        tools = self._get_tools_by_api_key(self.user_api_key)
    else:
        tools = self._get_user_tools(self.user or "local")
    # Merge client-declared tools in-place so they appear alongside
    # the server-side ones in the returned dict.
    if self.client_tools:
        self.merge_client_tools(tools, self.client_tools)
    return tools
|
||||
|
||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||
mongo = MongoDB.get_client()
|
||||
@@ -65,29 +77,123 @@ class ToolExecutor:
|
||||
user_tools = list(user_tools)
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
|
||||
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
||||
"""Convert tool configs to LLM function schemas."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": self._build_tool_parameters(action),
|
||||
},
|
||||
def merge_client_tools(
    self, tools_dict: Dict, client_tools: List[Dict]
) -> Dict:
    """Merge client-provided tool definitions into *tools_dict*.

    Client tools arrive in the standard function-calling format::

        [{"type": "function", "function": {"name": "get_weather",
          "description": "...", "parameters": {...}}}]

    Each definition is stored under a synthetic ``ct<i>`` key with
    ``client_side: True`` so that ``check_pause`` emits a pause signal
    instead of attempting server-side execution.

    Args:
        tools_dict: The mutable server tools dict (modified in place).
        client_tools: Tool definitions in function-calling format.

    Returns:
        The updated *tools_dict* (same reference, for convenience).
    """
    for index, definition in enumerate(client_tools):
        # Tolerate a bare {"name": ...} payload with no "function" wrapper.
        spec = definition.get("function", definition)
        action_name = spec.get("name", f"clienttool{index}")
        entry = {
            "name": action_name,
            "client_side": True,
            "actions": [
                {
                    "name": action_name,
                    "description": spec.get("description", ""),
                    "active": True,
                    "parameters": spec.get("parameters", {}),
                }
            ],
        }
        tools_dict[f"ct{index}"] = entry
    return tools_dict
|
||||
|
||||
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
    """Convert tool configs to LLM function schemas.

    Action names are kept clean for the LLM:
    - Unique action names appear as-is (e.g. ``get_weather``).
    - Duplicate action names get numbered suffixes (e.g. ``search_1``,
      ``search_2``).

    A reverse mapping is stored in ``_name_to_tool`` so that tool calls
    can be routed back to the correct ``(tool_id, action_name)`` without
    brittle string splitting.
    """
    # Pass 1: collect entries and count action name occurrences
    entries: List[Tuple[str, str, Dict, bool]] = []  # (tool_id, action_name, action, is_client)
    name_counts: Counter = Counter()

    for tool_id, tool in tools_dict.items():
        is_api = tool["name"] == "api_tool"
        is_client = tool.get("client_side", False)

        # Skip tools that declare no actions at all (api_tool keeps its
        # actions under config; everything else at the top level).
        if is_api and "actions" not in tool.get("config", {}):
            continue
        if not is_api and "actions" not in tool:
            continue

        actions = (
            tool["config"]["actions"].values()
            if is_api
            else tool["actions"]
        )

        for action in actions:
            # Inactive actions are hidden from the LLM entirely.
            if not action.get("active", True):
                continue
            entries.append((tool_id, action["name"], action, is_client))
            name_counts[action["name"]] += 1

    # Pass 2: assign LLM-visible names and build mappings
    self._name_to_tool = {}
    self._tool_to_name = {}
    collision_counters: Dict[str, int] = {}
    all_llm_names: set = set()

    result = []
    for tool_id, action_name, action, is_client in entries:
        if name_counts[action_name] == 1:
            llm_name = action_name
        else:
            counter = collision_counters.get(action_name, 1)
            candidate = f"{action_name}_{counter}"
            # Skip if candidate collides with a unique action name
            # (or one already assigned in this pass).
            while candidate in all_llm_names or (
                candidate in name_counts and name_counts[candidate] == 1
            ):
                counter += 1
                candidate = f"{action_name}_{counter}"
            collision_counters[action_name] = counter + 1
            llm_name = candidate

        all_llm_names.add(llm_name)
        self._name_to_tool[llm_name] = (tool_id, action_name)
        self._tool_to_name[(tool_id, action_name)] = llm_name

        # Client tools already carry a JSON-schema "parameters" block;
        # server tools are converted from their stored config.
        if is_client:
            params = action.get("parameters", {})
        else:
            params = self._build_tool_parameters(action)

        result.append({
            "type": "function",
            "function": {
                "name": llm_name,
                "description": action.get("description", ""),
                "parameters": params,
            },
        })
    return result
|
||||
|
||||
def _build_tool_parameters(self, action: Dict) -> Dict:
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
@@ -104,23 +210,81 @@ class ToolExecutor:
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def check_pause(
    self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
    """Check if a tool call requires pausing for approval or client execution.

    Args:
        tools_dict: Mapping of tool_id -> tool config.
        call: Provider-specific tool-call object; ``id``, ``name`` and
            ``thought_signature`` are read defensively via getattr.
        llm_class_name: LLM class name used to pick the argument parser.

    Returns:
        A dict describing the pending action if pause is needed, None otherwise.
    """
    parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
    tool_id, action_name, call_args = parser.parse_args(call)
    call_id = getattr(call, "id", None) or str(uuid.uuid4())
    llm_name = getattr(call, "name", "")

    if tool_id is None or action_name is None or tool_id not in tools_dict:
        return None  # Will be handled as error by execute()

    tool_data = tools_dict[tool_id]

    # Client-side tools always pause: the client must execute them.
    if tool_data.get("client_side"):
        return {
            "call_id": call_id,
            "name": llm_name,
            "tool_name": tool_data.get("name", "unknown"),
            "tool_id": tool_id,
            "action_name": action_name,
            "llm_name": llm_name,
            "arguments": call_args if isinstance(call_args, dict) else {},
            "pause_type": "requires_client_execution",
            "thought_signature": getattr(call, "thought_signature", None),
        }

    # Approval required — look up the matching action config
    # (api_tool keeps actions keyed by name under config).
    if tool_data["name"] == "api_tool":
        action_data = tool_data.get("config", {}).get("actions", {}).get(
            action_name, {}
        )
    else:
        action_data = next(
            (a for a in tool_data.get("actions", []) if a["name"] == action_name),
            {},
        )

    if action_data.get("require_approval"):
        return {
            "call_id": call_id,
            "name": llm_name,
            "tool_name": tool_data.get("name", "unknown"),
            "tool_id": tool_id,
            "action_name": action_name,
            "llm_name": llm_name,
            "arguments": call_args if isinstance(call_args, dict) else {},
            "pause_type": "awaiting_approval",
            "thought_signature": getattr(call, "thought_signature", None),
        }

    return None
|
||||
|
||||
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||
parser = ToolActionParser(llm_class_name)
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
llm_name = getattr(call, "name", "unknown")
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": getattr(call, "name", "unknown"),
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
@@ -133,7 +297,7 @@ class ToolExecutor:
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||
}
|
||||
@@ -144,7 +308,7 @@ class ToolExecutor:
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
internal: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -20,6 +20,8 @@ class InternalSearchTool(Tool):
|
||||
- list_files action: browse the file/folder structure
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.retrieved_docs: List[Dict] = []
|
||||
|
||||
@@ -36,6 +36,8 @@ class ThinkTool(Tool):
|
||||
The reasoning content is captured in tool_call data for transparency.
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config=None):
|
||||
pass
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolActionParser:
|
||||
def __init__(self, llm_type):
|
||||
def __init__(self, llm_type, name_mapping=None):
|
||||
self.llm_type = llm_type
|
||||
self.name_mapping = name_mapping
|
||||
self.parsers = {
|
||||
"OpenAILLM": self._parse_openai_llm,
|
||||
"GoogleLLM": self._parse_google_llm,
|
||||
@@ -16,22 +17,33 @@ class ToolActionParser:
|
||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||
return parser(call)
|
||||
|
||||
def _resolve_via_mapping(self, call_name):
|
||||
"""Look up (tool_id, action_name) from the name mapping if available."""
|
||||
if self.name_mapping and call_name in self.name_mapping:
|
||||
return self.name_mapping[call_name]
|
||||
return None
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
return resolved[0], resolved[1], call_args
|
||||
|
||||
# Fallback: legacy split on "_" for backward compatibility
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
f"Invalid tool name format: {call.name}. "
|
||||
"Could not resolve via mapping or legacy parsing."
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
@@ -45,19 +57,24 @@ class ToolActionParser:
|
||||
def _parse_google_llm(self, call):
|
||||
try:
|
||||
call_args = call.arguments
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
return resolved[0], resolved[1], call_args
|
||||
|
||||
# Fallback: legacy split on "_" for backward compatibility
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
f"Invalid tool name format: {call.name}. "
|
||||
"Could not resolve via mapping or legacy parsing."
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
|
||||
@@ -19,7 +19,7 @@ class ToolManager:
|
||||
continue
|
||||
module = importlib.import_module(f"application.agents.tools.{name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
|
||||
@@ -74,57 +74,72 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
stream = self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if len(stream_result) == 7:
|
||||
(
|
||||
conversation_id,
|
||||
response,
|
||||
sources,
|
||||
tool_calls,
|
||||
thought,
|
||||
error,
|
||||
structured_info,
|
||||
) = stream_result
|
||||
else:
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
stream_result
|
||||
)
|
||||
structured_info = None
|
||||
if stream_result["error"]:
|
||||
return make_response({"error": stream_result["error"]}, 400)
|
||||
|
||||
if error:
|
||||
return make_response({"error": error}, 400)
|
||||
result = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"conversation_id": stream_result["conversation_id"],
|
||||
"answer": stream_result["answer"],
|
||||
"sources": stream_result["sources"],
|
||||
"tool_calls": stream_result["tool_calls"],
|
||||
"thought": stream_result["thought"],
|
||||
}
|
||||
|
||||
if structured_info:
|
||||
result.update(structured_info)
|
||||
extra_info = stream_result.get("extra")
|
||||
if extra_info:
|
||||
result.update(extra_info)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
@@ -39,7 +40,16 @@ class BaseAnswerResource:
|
||||
def validate_request(
|
||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||
) -> Optional[Response]:
|
||||
"""Common request validation"""
|
||||
"""Common request validation.
|
||||
|
||||
Continuation requests (``tool_actions`` present) require
|
||||
``conversation_id`` but not ``question``.
|
||||
"""
|
||||
if data.get("tool_actions"):
|
||||
# Continuation mode — question is not required
|
||||
if missing := check_required_fields(data, ["conversation_id"]):
|
||||
return missing
|
||||
return None
|
||||
required_fields = ["question"]
|
||||
if require_conversation_id:
|
||||
required_fields.append("conversation_id")
|
||||
@@ -177,6 +187,7 @@ class BaseAnswerResource:
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
_continuation: Optional[Dict] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -207,8 +218,19 @@ class BaseAnswerResource:
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
|
||||
for line in agent.gen(query=question):
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
messages=_continuation["messages"],
|
||||
tools_dict=_continuation["tools_dict"],
|
||||
pending_tool_calls=_continuation["pending_tool_calls"],
|
||||
tool_actions=_continuation["tool_actions"],
|
||||
)
|
||||
else:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
@@ -244,15 +266,21 @@ class BaseAnswerResource:
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "error":
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield f"data: {data}\n\n"
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -262,6 +290,93 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
conversation_id = (
|
||||
self.conversation_service.save_conversation(
|
||||
None,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create conversation for continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
@@ -425,8 +540,13 @@ class BaseAnswerResource:
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream):
|
||||
"""Process the stream response for non-streaming endpoint"""
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
"""Process the stream response for non-streaming endpoint.
|
||||
|
||||
Returns:
|
||||
Dict with keys: conversation_id, answer, sources, tool_calls,
|
||||
thought, error, and optional extra.
|
||||
"""
|
||||
conversation_id = ""
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
@@ -435,6 +555,7 @@ class BaseAnswerResource:
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
pending_tool_calls = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
@@ -453,11 +574,22 @@ class BaseAnswerResource:
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "tool_calls_pending":
|
||||
pending_tool_calls = event.get("data", {}).get(
|
||||
"pending_tool_calls", []
|
||||
)
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"], None
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": event["error"],
|
||||
}
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -465,18 +597,30 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
thought,
|
||||
None,
|
||||
)
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": "Stream ended unexpectedly",
|
||||
}
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if pending_tool_calls is not None:
|
||||
result["extra"] = {"pending_tool_calls": pending_tool_calls}
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
result["extra"] = {"structured": True, "schema": schema_info}
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -79,7 +79,39 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data["question"])
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
@@ -49,28 +50,35 @@ class MessageBuilder:
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
@@ -180,28 +188,35 @@ class MessageBuilder:
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
rebuilt_messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
rebuilt_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
|
||||
141
application/api/answer/services/continuation_service.py
Normal file
141
application/api/answer/services/continuation_service.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Service for saving and restoring tool-call continuation state.
|
||||
|
||||
When a stream pauses (tool needs approval or client-side execution),
|
||||
the full execution state is persisted to MongoDB so the client can
|
||||
resume later by sending tool_actions.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
    """Recursively replace non-JSON-safe values with JSON-safe equivalents.

    ObjectIds become their string form, bytes are decoded as UTF-8 (with
    replacement for undecodable sequences), and dicts/lists are rebuilt
    with every nested value converted. Dict keys are coerced to ``str``
    so the document round-trips through JSON.
    """
    if isinstance(obj, ObjectId):
        return str(obj)
    if isinstance(obj, dict):
        return {str(key): _make_serializable(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [_make_serializable(item) for item in obj]
    if isinstance(obj, bytes):
        return obj.decode("utf-8", errors="replace")
    return obj
|
||||
|
||||
|
||||
class ContinuationService:
    """Manages pending tool-call state in MongoDB."""

    def __init__(self):
        client = MongoDB.get_client()
        database = client[settings.MONGO_DB_NAME]
        self.collection = database["pending_tool_state"]
        self._ensure_indexes()

    def _ensure_indexes(self):
        """Best-effort creation of the TTL and uniqueness indexes."""
        try:
            self.collection.create_index("expires_at", expireAfterSeconds=0)
            self.collection.create_index(
                [("conversation_id", 1), ("user", 1)], unique=True
            )
        except Exception:
            # Indexes may already exist or mongomock doesn't support TTL
            pass

    def save_state(
        self,
        conversation_id: str,
        user: str,
        messages: List[Dict],
        pending_tool_calls: List[Dict],
        tools_dict: Dict,
        tool_schemas: List[Dict],
        agent_config: Dict,
        client_tools: Optional[List[Dict]] = None,
    ) -> str:
        """Save execution state for later continuation.

        Args:
            conversation_id: The conversation this state belongs to.
            user: Owner user ID.
            messages: Full messages array at the pause point.
            pending_tool_calls: Tool calls awaiting client action.
            tools_dict: Serializable tools configuration dict.
            tool_schemas: LLM-formatted tool schemas (agent.tools).
            agent_config: Config needed to recreate the agent on resume.
            client_tools: Client-provided tool schemas for client-side execution.

        Returns:
            The string ID of the saved state document.
        """
        created = datetime.datetime.now(datetime.timezone.utc)

        document = {
            "conversation_id": conversation_id,
            "user": user,
            "messages": _make_serializable(messages),
            "pending_tool_calls": _make_serializable(pending_tool_calls),
            "tools_dict": _make_serializable(tools_dict),
            "tool_schemas": _make_serializable(tool_schemas),
            "agent_config": _make_serializable(agent_config),
            "client_tools": _make_serializable(client_tools) if client_tools else None,
            "created_at": created,
            # TTL index on this field removes stale documents automatically.
            "expires_at": created
            + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS),
        }

        # Upsert — only one pending state per conversation per user
        outcome = self.collection.replace_one(
            {"conversation_id": conversation_id, "user": user},
            document,
            upsert=True,
        )
        logger.info(
            f"Saved continuation state for conversation {conversation_id} "
            f"with {len(pending_tool_calls)} pending tool call(s)"
        )
        return str(outcome.upserted_id) if outcome.upserted_id else conversation_id

    def load_state(
        self, conversation_id: str, user: str
    ) -> Optional[Dict[str, Any]]:
        """Load pending continuation state.

        Returns:
            The state dict, or None if no pending state exists.
        """
        document = self.collection.find_one(
            {"conversation_id": conversation_id, "user": user}
        )
        if document is None:
            return None
        # Stringify the ObjectId so the state is JSON-safe for callers.
        document["_id"] = str(document["_id"])
        return document

    def delete_state(self, conversation_id: str, user: str) -> bool:
        """Delete pending state after successful resumption.

        Returns:
            True if a document was deleted.
        """
        outcome = self.collection.delete_one(
            {"conversation_id": conversation_id, "user": user}
        )
        if outcome.deleted_count:
            logger.info(
                f"Deleted continuation state for conversation {conversation_id}"
            )
        return outcome.deleted_count > 0
|
||||
@@ -771,6 +771,121 @@ class StreamProcessor:
|
||||
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||
return None
|
||||
|
||||
def resume_from_tool_actions(
    self,
    tool_actions: list,
    conversation_id: str,
):
    """Resume a paused agent from saved continuation state.

    Loads the pending state from MongoDB, recreates the agent with
    the saved configuration, and returns an agent ready to call
    ``gen_continuation()``. The saved state is deleted before
    returning so the same continuation cannot be replayed.

    Args:
        tool_actions: Client-provided actions (approvals / results).
        conversation_id: The conversation being resumed.

    Returns:
        Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).

    Raises:
        ValueError: If no pending tool state exists for this
            conversation/user pair.
    """
    # Imported here (not at module top) — presumably to avoid circular
    # imports between the answer services and the agents package.
    from application.api.answer.services.continuation_service import (
        ContinuationService,
    )
    from application.agents.agent_creator import AgentCreator
    from application.agents.tool_executor import ToolExecutor
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.llm.llm_creator import LLMCreator

    cont_service = ContinuationService()
    state = cont_service.load_state(conversation_id, self.initial_user_id)
    if not state:
        raise ValueError("No pending tool state found for this conversation")

    messages = state["messages"]
    pending_tool_calls = state["pending_tool_calls"]
    tools_dict = state["tools_dict"]
    tool_schemas = state.get("tool_schemas", [])
    agent_config = state["agent_config"]

    model_id = agent_config.get("model_id")
    llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
    api_key = agent_config.get("api_key")
    user_api_key = agent_config.get("user_api_key")
    agent_id = agent_config.get("agent_id")
    prompt = agent_config.get("prompt", "")
    json_schema = agent_config.get("json_schema")
    retriever_config = agent_config.get("retriever_config")

    # Recreate dependencies
    system_api_key = api_key or get_api_key_for_provider(llm_name)
    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        user_api_key=user_api_key,
        decoded_token=self.decoded_token,
        model_id=model_id,
        agent_id=agent_id,
    )
    llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=self.initial_user_id,
        decoded_token=self.decoded_token,
    )
    tool_executor.conversation_id = conversation_id
    # Restore client tools so they stay available for subsequent LLM calls
    saved_client_tools = state.get("client_tools")
    if saved_client_tools:
        tool_executor.client_tools = saved_client_tools
        # Re-merge into tools_dict (they may have been stripped during serialization)
        tool_executor.merge_client_tools(tools_dict, saved_client_tools)

    agent_type = agent_config.get("agent_type", "ClassicAgent")
    # Map class names back to agent creator keys
    type_map = {
        "ClassicAgent": "classic",
        "AgenticAgent": "agentic",
        "ResearchAgent": "research",
        "WorkflowAgent": "workflow",
    }
    # Unknown class names fall back to the classic agent.
    agent_key = type_map.get(agent_type, "classic")

    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": model_id,
        "api_key": system_api_key,
        "agent_id": agent_id,
        "user_api_key": user_api_key,
        "prompt": prompt,
        "chat_history": [],
        "decoded_token": self.decoded_token,
        "json_schema": json_schema,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }

    # Only retriever-backed agent types receive a retriever_config kwarg.
    if agent_key in ("agentic", "research") and retriever_config:
        agent_kwargs["retriever_config"] = retriever_config

    agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
    agent.conversation_id = conversation_id
    agent.initial_user_id = self.initial_user_id
    agent.tools = tool_schemas

    # Store config for the route layer
    self.model_id = model_id
    self.agent_id = agent_id
    self.agent_config["user_api_key"] = user_api_key
    self.conversation_id = conversation_id

    # Delete state so it can't be replayed
    cont_service.delete_state(conversation_id, self.initial_user_id)

    return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
docs_together: Optional[str] = None,
|
||||
@@ -841,6 +956,10 @@ class StreamProcessor:
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
tool_executor.conversation_id = self.conversation_id
|
||||
# Pass client-side tools so they get merged in get_tools()
|
||||
client_tools = self.data.get("client_tools")
|
||||
if client_tools:
|
||||
tool_executor.client_tools = client_tools
|
||||
|
||||
# Base agent kwargs
|
||||
agent_kwargs = {
|
||||
|
||||
3
application/api/v1/__init__.py
Normal file
3
application/api/v1/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from application.api.v1.routes import v1_bp
|
||||
|
||||
__all__ = ["v1_bp"]
|
||||
314
application/api/v1/routes.py
Normal file
314
application/api/v1/routes.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""Standard chat completions API routes.
|
||||
|
||||
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
|
||||
follow the widely-adopted chat completions protocol so external tools
|
||||
(opencode, continue, etc.) can connect to DocsGPT agents.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, make_response, request, Response
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.api.v1.translator import (
|
||||
translate_request,
|
||||
translate_response,
|
||||
translate_stream_event,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
|
||||
|
||||
|
||||
def _extract_bearer_token() -> Optional[str]:
    """Return the API key from an ``Authorization: Bearer`` header, if present."""
    header_value = request.headers.get("Authorization", "")
    prefix = "Bearer "
    if not header_value.startswith(prefix):
        return None
    return header_value[len(prefix):].strip()
|
||||
|
||||
|
||||
def _lookup_agent(api_key: str) -> Optional[Dict]:
    """Look up the agent document for this API key.

    Returns None both when no agent matches and when the database
    lookup itself fails (the failure is logged, not raised).
    """
    try:
        database = MongoDB.get_client()[settings.MONGO_DB_NAME]
        return database["agents"].find_one({"key": api_key})
    except Exception:
        logger.warning("Failed to look up agent for API key", exc_info=True)
        return None
|
||||
|
||||
|
||||
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
|
||||
"""Return agent name for display as model name."""
|
||||
if agent:
|
||||
return agent.get("name", api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
class _V1AnswerHelper(BaseAnswerResource):
    """Thin wrapper to access complete_stream / process_response_stream.

    The /v1 routes are plain Flask view functions, so this subclass exists
    only to instantiate BaseAnswerResource's streaming helpers without
    registering a RESTful resource endpoint.
    """
    pass
|
||||
|
||||
|
||||
@v1_bp.route("/chat/completions", methods=["POST"])
def chat_completions():
    """Handle POST /v1/chat/completions.

    The bearer token is a DocsGPT agent API key. The body is translated
    into DocsGPT's internal format and dispatched either as a tool-call
    continuation (when translation yields ``tool_actions``) or as a
    normal new turn. Returns an SSE stream when ``stream`` is true,
    otherwise a single JSON completion object.
    """
    api_key = _extract_bearer_token()
    if not api_key:
        return make_response(
            jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
            401,
        )

    data = request.get_json()
    if not data or not data.get("messages"):
        return make_response(
            jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
            400,
        )

    is_stream = data.get("stream", False)
    agent_doc = _lookup_agent(api_key)
    model_name = _get_model_name(agent_doc, api_key)

    try:
        internal_data = translate_request(data, api_key)
    except Exception as e:
        logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
        return make_response(
            jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
            400,
        )

    # Link decoded_token to the agent's owner so continuation state,
    # logs, and tool execution use the correct user identity.
    agent_user = agent_doc.get("user") if agent_doc else None
    decoded_token = {"sub": agent_user or "api_key_user"}

    try:
        processor = StreamProcessor(internal_data, decoded_token)

        if internal_data.get("tool_actions"):
            # Continuation mode: resume a previously paused agent.
            conversation_id = internal_data.get("conversation_id")
            if not conversation_id:
                return make_response(
                    jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
                    400,
                )
            (
                agent,
                messages,
                tools_dict,
                pending_tool_calls,
                tool_actions,
            ) = processor.resume_from_tool_actions(
                internal_data["tool_actions"], conversation_id
            )
            continuation = {
                "messages": messages,
                "tools_dict": tools_dict,
                "pending_tool_calls": pending_tool_calls,
                "tool_actions": tool_actions,
            }
            # Continuations carry no new user question.
            question = ""
        else:
            # Normal mode
            question = internal_data.get("question", "")
            agent = processor.build_agent(question)
            continuation = None

        if not processor.decoded_token:
            return make_response(
                jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
                401,
            )

        helper = _V1AnswerHelper()
        usage_error = helper.check_usage(processor.agent_config)
        if usage_error:
            return usage_error

        if is_stream:
            return Response(
                _stream_response(
                    helper, question, agent, processor, model_name, continuation
                ),
                mimetype="text/event-stream",
                headers={
                    # Disable caching and reverse-proxy buffering for SSE.
                    "Cache-Control": "no-cache",
                    "X-Accel-Buffering": "no",
                },
            )
        else:
            return _non_stream_response(
                helper, question, agent, processor, model_name, continuation
            )

    except ValueError as e:
        # e.g. no pending tool state for a continuation request.
        logger.error(
            f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
            extra={"error": str(e)},
        )
        return make_response(
            jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
            400,
        )
    except Exception as e:
        logger.error(
            f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
            extra={"error": str(e)},
        )
        return make_response(
            jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
            500,
        )
|
||||
|
||||
|
||||
def _stream_response(
    helper: _V1AnswerHelper,
    question: str,
    agent: Any,
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
) -> Generator[str, None, None]:
    """Generate translated SSE chunks for a streaming response.

    Consumes DocsGPT's internal SSE stream, parses each event, and
    re-emits it as standard chat-completions chunks.

    Args:
        helper: Wrapper exposing ``complete_stream``.
        question: The user question ("" for continuations).
        agent: The prepared agent instance.
        processor: StreamProcessor carrying conversation/agent config.
        model_name: Model name to report in translated chunks.
        continuation: Saved continuation payload, or None for a new turn.

    Yields:
        Translated SSE chunk strings.
    """
    completion_id = f"chatcmpl-{int(time.time())}"

    internal_stream = helper.complete_stream(
        question=question,
        agent=agent,
        conversation_id=processor.conversation_id,
        user_api_key=processor.agent_config.get("user_api_key"),
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        _continuation=continuation,
    )

    for line in internal_stream:
        if not line.strip():
            continue
        # Strip only the leading SSE "data: " prefix. A blanket
        # str.replace would also delete any "data: " occurrences inside
        # the JSON payload (e.g. in answer text), corrupting the event.
        event_str = line.strip()
        if event_str.startswith("data: "):
            event_str = event_str[len("data: "):]
        try:
            event_data = json.loads(event_str)
        except (json.JSONDecodeError, TypeError):
            # Skip non-JSON keepalives or malformed lines.
            continue

        # Update completion_id when we get the conversation id
        if event_data.get("type") == "id":
            conv_id = event_data.get("id", "")
            if conv_id:
                completion_id = f"chatcmpl-{conv_id}"

        # Translate to standard format
        translated = translate_stream_event(event_data, completion_id, model_name)
        for chunk in translated:
            yield chunk
|
||||
|
||||
|
||||
def _non_stream_response(
    helper: _V1AnswerHelper,
    question: str,
    agent: Any,
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
) -> Response:
    """Drain the internal stream and return a single JSON completion.

    Args:
        helper: Wrapper exposing ``complete_stream`` / ``process_response_stream``.
        question: The user question ("" for continuations).
        agent: The prepared agent instance.
        processor: StreamProcessor carrying conversation/agent config.
        model_name: Model name to report in the response.
        continuation: Saved continuation payload, or None for a new turn.
    """
    internal_stream = helper.complete_stream(
        question=question,
        agent=agent,
        conversation_id=processor.conversation_id,
        user_api_key=processor.agent_config.get("user_api_key"),
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        _continuation=continuation,
    )

    result = helper.process_response_stream(internal_stream)

    if result["error"]:
        error_payload = {"error": {"message": result["error"], "type": "server_error"}}
        return make_response(jsonify(error_payload), 500)

    # "extra" may carry pending tool calls when the stream paused.
    extra = result.get("extra")
    pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None

    body = translate_response(
        conversation_id=result["conversation_id"],
        answer=result["answer"] or "",
        sources=result["sources"],
        tool_calls=result["tool_calls"],
        thought=result["thought"] or "",
        model_name=model_name,
        pending_tool_calls=pending,
    )
    return make_response(jsonify(body), 200)
|
||||
|
||||
|
||||
@v1_bp.route("/models", methods=["GET"])
def list_models():
    """Handle GET /v1/models — return agents as models.

    Authenticates the bearer token against the agents collection, then
    lists every agent belonging to the same owner as a "model".
    """
    api_key = _extract_bearer_token()
    if not api_key:
        return make_response(
            jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
            401,
        )

    try:
        database = MongoDB.get_client()[settings.MONGO_DB_NAME]
        agents_collection = database["agents"]

        # The key authenticates a single agent; use it to find the owner.
        agent = agents_collection.find_one({"key": api_key})
        if not agent:
            return make_response(
                jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
                401,
            )

        owner = agent.get("user")

        # Expose every agent this user owns as a selectable model.
        models = []
        for doc in agents_collection.find({"user": owner}):
            created = doc.get("createdAt")
            models.append({
                "id": str(doc.get("key", "")),
                "object": "model",
                # Fall back to "now" when the agent has no creation date.
                "created": int(created.timestamp()) if created else int(time.time()),
                "owned_by": "docsgpt",
                "name": doc.get("name", ""),
                "description": doc.get("description", ""),
            })

        return make_response(jsonify({"object": "list", "data": models}), 200)
    except Exception as e:
        logger.error(f"/v1/models error: {e}", exc_info=True)
        return make_response(
            jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
            500,
        )
|
||||
415
application/api/v1/translator.py
Normal file
415
application/api/v1/translator.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""Translate between standard chat completions format and DocsGPT internals.
|
||||
|
||||
This module handles:
|
||||
- Request translation (chat completions -> DocsGPT internal format)
|
||||
- Response translation (DocsGPT response -> chat completions format)
|
||||
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
def _get_client_tool_name(tc: Dict) -> str:
|
||||
"""Return the original tool name for client-facing responses.
|
||||
|
||||
For client-side tools the ``tool_name`` field carries the name the
|
||||
client originally registered. Fall back to ``action_name`` (which
|
||||
is now the clean LLM-visible name) or ``name``.
|
||||
"""
|
||||
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request translation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_continuation(messages: List[Dict]) -> bool:
    """Check if messages represent a tool-call continuation.

    A continuation is detected when any trailing ``role: "tool"``
    messages (possibly zero) are immediately preceded by an assistant
    message carrying non-empty ``tool_calls``.
    """
    if not messages:
        return False
    # Skip the trailing run of tool messages, then inspect what precedes it.
    idx = len(messages) - 1
    while idx >= 0 and messages[idx].get("role") == "tool":
        idx -= 1
    if idx < 0:
        return False
    preceding = messages[idx]
    return preceding.get("role") == "assistant" and bool(preceding.get("tool_calls"))
|
||||
|
||||
|
||||
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
    """Extract tool results from trailing tool messages for continuation.

    Returns a list of ``tool_actions`` dicts with ``call_id`` and
    ``result``, preserving the original message order. String contents
    are JSON-decoded when possible, otherwise passed through as-is.
    """
    # Find where the trailing run of tool messages starts.
    start = len(messages)
    while start > 0 and messages[start - 1].get("role") == "tool":
        start -= 1

    actions = []
    for msg in messages[start:]:
        payload = msg.get("content", "")
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except (json.JSONDecodeError, TypeError):
                # Keep non-JSON results verbatim.
                pass
        actions.append({"call_id": msg.get("tool_call_id", ""), "result": payload})
    return actions
|
||||
|
||||
|
||||
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
    """Try to extract conversation_id from the last assistant message.

    A previous response cycle may have stored the id in a ``docsgpt``
    extension field on the assistant message. Only the most recent
    assistant message is consulted; returns None if it carries no id.
    """
    last_assistant = next(
        (m for m in reversed(messages) if m.get("role") == "assistant"), None
    )
    if last_assistant is None:
        return None
    return last_assistant.get("docsgpt", {}).get("conversation_id")
|
||||
|
||||
|
||||
def convert_history(messages: List[Dict]) -> List[Dict]:
    """Convert a chat completions messages array to DocsGPT history format.

    DocsGPT history is a list of ``{prompt, response}`` dicts built from
    consecutive user/assistant pairs. System messages are ignored, and a
    trailing user message without an assistant reply is excluded (it
    becomes the ``question``).
    """
    history: List[Dict] = []
    idx, total = 0, len(messages)
    while idx < total:
        current = messages[idx]
        if (
            current.get("role") == "user"
            and idx + 1 < total
            and messages[idx + 1].get("role") == "assistant"
        ):
            # Paired user prompt + assistant reply becomes one history entry.
            reply = messages[idx + 1].get("content") or ""
            history.append({"prompt": current.get("content", ""), "response": reply})
            idx += 2
        else:
            # System messages, the unanswered trailing user message (the
            # question), and any other roles are simply skipped.
            idx += 1
    return history
|
||||
|
||||
|
||||
def translate_request(
    data: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
    """Translate a chat completions request to DocsGPT internal format.

    Args:
        data: The incoming request body.
        api_key: Agent API key from the Authorization header.

    Returns:
        Dict suitable for passing to ``StreamProcessor``.
    """
    messages = data.get("messages", [])
    client_tools = data.get("tools")

    # Continuation: trailing tool results after an assistant tool_calls turn.
    if is_continuation(messages):
        payload: Dict[str, Any] = {
            "conversation_id": (
                extract_conversation_id(messages) or data.get("conversation_id")
            ),
            "tool_actions": extract_tool_results(messages),
            "api_key": api_key,
        }
        # Carry tools forward for the next iteration
        if client_tools:
            payload["client_tools"] = client_tools
        return payload

    # Normal request — the question is the last user message.
    question = next(
        (m.get("content", "") for m in reversed(messages) if m.get("role") == "user"),
        "",
    )

    payload = {
        "question": question,
        "api_key": api_key,
        "history": json.dumps(convert_history(messages)),
        "save_conversation": True,
    }

    # Client tools
    if client_tools:
        payload["client_tools"] = client_tools

    # DocsGPT extensions
    docsgpt = data.get("docsgpt", {})
    if docsgpt.get("attachments"):
        payload["attachments"] = docsgpt["attachments"]

    return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response translation (non-streaming)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def translate_response(
    conversation_id: str,
    answer: str,
    sources: Optional[List[Dict]],
    tool_calls: Optional[List[Dict]],
    thought: str,
    model_name: str,
    pending_tool_calls: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
    """Translate DocsGPT response to chat completions format.

    Args:
        conversation_id: The DocsGPT conversation ID.
        answer: The assistant's text response.
        sources: RAG retrieval sources.
        tool_calls: Completed tool call results.
        thought: Reasoning/thinking tokens.
        model_name: Model/agent identifier.
        pending_tool_calls: Pending client-side tool calls (if paused).

    Returns:
        Dict in the standard chat completions response format.
    """
    created = int(time.time())
    completion_id = (
        f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}"
    )

    message: Dict[str, Any] = {"role": "assistant"}

    if pending_tool_calls:
        # Tool calls pending — return them so the client can execute them.
        message["content"] = None
        surfaced_calls = []
        for tc in pending_tool_calls:
            raw_args = tc.get("arguments")
            surfaced_calls.append(
                {
                    "id": tc.get("call_id", ""),
                    "type": "function",
                    "function": {
                        "name": _get_client_tool_name(tc),
                        "arguments": (
                            json.dumps(raw_args)
                            if isinstance(raw_args, dict)
                            else tc.get("arguments", "{}")
                        ),
                    },
                }
            )
        message["tool_calls"] = surfaced_calls
        finish_reason = "tool_calls"
    else:
        message["content"] = answer
        if thought:
            message["reasoning_content"] = thought
        finish_reason = "stop"

    result: Dict[str, Any] = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_name,
        "choices": [
            {"index": 0, "message": message, "finish_reason": finish_reason}
        ],
        # Token accounting is not tracked by DocsGPT; zeros keep the schema valid.
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }

    # DocsGPT extensions — only attached when there is something to report.
    docsgpt: Dict[str, Any] = {}
    if conversation_id:
        docsgpt["conversation_id"] = conversation_id
    if sources:
        docsgpt["sources"] = sources
    if tool_calls:
        docsgpt["tool_calls"] = tool_calls
    if docsgpt:
        result["docsgpt"] = docsgpt

    return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming event translation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_chunk(
    completion_id: str,
    model_name: str,
    delta: Dict[str, Any],
    finish_reason: Optional[str] = None,
) -> str:
    """Serialize a single SSE chunk in the standard streaming format."""
    payload = json.dumps(
        {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {"index": 0, "delta": delta, "finish_reason": finish_reason}
            ],
        }
    )
    return f"data: {payload}\n\n"
|
||||
|
||||
|
||||
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
    """Build a DocsGPT extension SSE chunk wrapping *data* under "docsgpt"."""
    body = json.dumps({'docsgpt': data})
    return "data: " + body + "\n\n"
|
||||
|
||||
|
||||
def translate_stream_event(
    event_data: Dict[str, Any],
    completion_id: str,
    model_name: str,
) -> List[str]:
    """Translate a DocsGPT SSE event dict to standard streaming chunks.

    May return 0, 1, or 2 chunks per input event. For example, a completed
    tool call produces both a docsgpt extension chunk and nothing on the
    standard side (since server-side tool calls aren't surfaced in standard
    format).

    Args:
        event_data: Parsed DocsGPT event dict.
        completion_id: The completion ID for this response.
        model_name: Model/agent identifier.

    Returns:
        List of SSE-formatted strings to send to the client.
    """
    out: List[str] = []
    event_type = event_data.get("type")

    if event_type == "answer":
        out.append(
            _make_chunk(
                completion_id, model_name, {"content": event_data.get("answer", "")}
            )
        )

    elif event_type == "thought":
        out.append(
            _make_chunk(
                completion_id,
                model_name,
                {"reasoning_content": event_data.get("thought", "")},
            )
        )

    elif event_type == "source":
        out.append(
            _make_docsgpt_chunk(
                {"type": "source", "sources": event_data.get("source", [])}
            )
        )

    elif event_type == "tool_call":
        tc_data = event_data.get("data", {})
        status = tc_data.get("status")
        if status == "requires_client_execution":
            # Standard: stream as a tool_calls delta for client execution.
            args = tc_data.get("arguments", {})
            args_str = json.dumps(args) if isinstance(args, dict) else str(args)
            delta = {
                "tool_calls": [
                    {
                        "index": 0,
                        "id": tc_data.get("call_id", ""),
                        "type": "function",
                        "function": {
                            "name": _get_client_tool_name(tc_data),
                            "arguments": args_str,
                        },
                    }
                ]
            }
            out.append(_make_chunk(completion_id, model_name, delta))
        elif status in (
            "awaiting_approval",
            "completed",
            "pending",
            "error",
            "denied",
            "skipped",
        ):
            # Extension: approval request or tool-call progress update.
            out.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))

    elif event_type == "tool_calls_pending":
        # Standard: finish_reason = tool_calls, plus a docsgpt extension echo.
        out.append(
            _make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
        )
        pending = event_data.get("data", {}).get("pending_tool_calls", [])
        out.append(
            _make_docsgpt_chunk(
                {"type": "tool_calls_pending", "pending_tool_calls": pending}
            )
        )

    elif event_type == "end":
        out.append(_make_chunk(completion_id, model_name, {}, finish_reason="stop"))
        out.append("data: [DONE]\n\n")

    elif event_type == "id":
        out.append(
            _make_docsgpt_chunk(
                {"type": "id", "conversation_id": event_data.get("id", "")}
            )
        )

    elif event_type == "error":
        # Emit as standard error (non-standard but widely supported)
        error_payload = {
            "error": {
                "message": event_data.get("error", "An error occurred"),
                "type": "server_error",
            }
        }
        out.append(f"data: {json.dumps(error_payload)}\n\n")

    elif event_type == "structured_answer":
        out.append(
            _make_chunk(
                completion_id, model_name, {"content": event_data.get("answer", "")}
            )
        )

    # Skipped event types: tool_calls (redundant), research_plan, research_progress
    return out
|
||||
@@ -17,6 +17,7 @@ from application.api.answer import answer # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
from application.api.v1 import v1_bp # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.stt.upload_limits import ( # noqa: E402
|
||||
@@ -36,6 +37,7 @@ app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
app.config.update(
|
||||
UPLOAD_FOLDER="inputs",
|
||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||
|
||||
@@ -167,6 +167,8 @@ class GoogleLLM(BaseLLM):
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
import json as _json
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
@@ -180,9 +182,66 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
|
||||
# Standard format: assistant message with tool_calls array
|
||||
msg_tool_calls = message.get("tool_calls")
|
||||
if msg_tool_calls and role == "model":
|
||||
for tc in msg_tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = _json.loads(args)
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
cleaned_args = self._remove_null_values(args)
|
||||
thought_sig = tc.get("thought_signature")
|
||||
if thought_sig:
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
name=func.get("name", ""),
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=thought_sig,
|
||||
)
|
||||
)
|
||||
else:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=func.get("name", ""),
|
||||
args=cleaned_args,
|
||||
)
|
||||
)
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
continue
|
||||
|
||||
# Standard format: tool message with tool_call_id
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if role == "tool" and tool_call_id is not None:
|
||||
result_content = content
|
||||
if isinstance(result_content, str):
|
||||
try:
|
||||
result_content = _json.loads(result_content)
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# Google expects function_response name — extract from tool_call_id context
|
||||
# We use a placeholder name since Google API doesn't require exact match
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
name="tool_result",
|
||||
response={"result": result_content},
|
||||
)
|
||||
)
|
||||
cleaned_messages.append(types.Content(role="model", parts=parts))
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
role = "model"
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
@@ -191,15 +250,11 @@ class GoogleLLM(BaseLLM):
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(text=item["text"]))
|
||||
elif "function_call" in item:
|
||||
# Remove null values from args to avoid API errors
|
||||
|
||||
# Legacy format support
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
# Create function call part with thought_signature if present
|
||||
# For Gemini 3 models, we need to include thought_signature
|
||||
if "thought_signature" in item:
|
||||
# Use Part constructor with functionCall and thoughtSignature
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
@@ -210,7 +265,6 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use helper method when no thought_signature
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -315,10 +316,34 @@ class LLMHandler(ABC):
|
||||
current_prompt = self._extract_text_from_content(content)
|
||||
|
||||
elif role in {"assistant", "model"}:
|
||||
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
|
||||
# Standard format: tool_calls array on assistant message
|
||||
msg_tool_calls = message.get("tool_calls")
|
||||
if msg_tool_calls:
|
||||
for tc in msg_tool_calls:
|
||||
call_id = tc.get("id") or str(uuid.uuid4())
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": func.get("name"),
|
||||
"arguments": args,
|
||||
"result": None,
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
continue
|
||||
|
||||
# Legacy format: function_call/function_response in content list
|
||||
if isinstance(content, list):
|
||||
has_fc = False
|
||||
for item in content:
|
||||
if "function_call" in item:
|
||||
has_fc = True
|
||||
fc = item["function_call"]
|
||||
call_id = fc.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
@@ -329,37 +354,30 @@ class LLMHandler(ABC):
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
elif "function_response" in item:
|
||||
fr = item["function_response"]
|
||||
call_id = fr.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fr.get("name"),
|
||||
"arguments": None,
|
||||
"result": fr.get("response", {}).get("result"),
|
||||
"status": "completed",
|
||||
"call_id": call_id,
|
||||
}
|
||||
# No direct assistant text here; continue to next message
|
||||
continue
|
||||
if has_fc:
|
||||
continue
|
||||
|
||||
response_text = self._extract_text_from_content(content)
|
||||
_commit_query(response_text)
|
||||
|
||||
elif role == "tool":
|
||||
# Attach tool outputs to the latest pending tool call if possible
|
||||
# Standard format: tool_call_id on tool message
|
||||
call_id = message.get("tool_call_id")
|
||||
tool_text = self._extract_text_from_content(content)
|
||||
# Attempt to parse function_response style
|
||||
call_id = None
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item and item["function_response"].get("call_id"):
|
||||
call_id = item["function_response"]["call_id"]
|
||||
break
|
||||
|
||||
if call_id and call_id in current_tool_calls:
|
||||
current_tool_calls[call_id]["result"] = tool_text
|
||||
current_tool_calls[call_id]["status"] = "completed"
|
||||
elif queries:
|
||||
# Legacy: function_response in content list
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item:
|
||||
legacy_id = item["function_response"].get("call_id")
|
||||
if legacy_id and legacy_id in current_tool_calls:
|
||||
current_tool_calls[legacy_id]["result"] = tool_text
|
||||
current_tool_calls[legacy_id]["status"] = "completed"
|
||||
break
|
||||
elif call_id is None and queries:
|
||||
queries[-1].setdefault("tool_calls", []).append(
|
||||
{
|
||||
"tool_name": "unknown_tool",
|
||||
@@ -648,6 +666,13 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
Execute tool calls and update conversation history.
|
||||
|
||||
When a tool requires approval or client-side execution, it is
|
||||
collected as a pending action instead of being executed. The
|
||||
generator returns ``(updated_messages, pending_actions)`` where
|
||||
*pending_actions* is ``None`` when every tool was executed
|
||||
normally, or a list of dicts describing actions the client must
|
||||
resolve before the LLM loop can continue.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
tool_calls: List of tool calls to execute
|
||||
@@ -655,9 +680,11 @@ class LLMHandler(ABC):
|
||||
messages: Current conversation history
|
||||
|
||||
Returns:
|
||||
Updated messages list
|
||||
Tuple of (updated_messages, pending_actions).
|
||||
pending_actions is None if all tools executed, otherwise a list.
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
pending_actions: List[Dict] = []
|
||||
|
||||
for i, call in enumerate(tool_calls):
|
||||
# Check context limit before executing tool call
|
||||
@@ -763,6 +790,29 @@ class LLMHandler(ABC):
|
||||
# Set flag on agent
|
||||
agent.context_limit_reached = True
|
||||
break
|
||||
|
||||
# ---- Pause check: approval / client-side execution ----
|
||||
llm_class = agent.llm.__class__.__name__
|
||||
pause_info = agent.tool_executor.check_pause(
|
||||
tools_dict, call, llm_class
|
||||
)
|
||||
if pause_info:
|
||||
# Yield pause event so the client knows this tool is waiting
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pause_info["tool_name"],
|
||||
"call_id": pause_info["call_id"],
|
||||
"action_name": pause_info.get("llm_name", pause_info["name"]),
|
||||
"arguments": pause_info["arguments"],
|
||||
"status": pause_info["pause_type"],
|
||||
},
|
||||
}
|
||||
pending_actions.append(pause_info)
|
||||
# Do NOT add messages for pending tools here.
|
||||
# They will be added on resume to keep call/result pairs together.
|
||||
continue
|
||||
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
@@ -772,25 +822,30 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
# Include thought_signature for Google Gemini 3 models
|
||||
# It should be at the same level as function_call, not inside it
|
||||
if call.thought_signature:
|
||||
function_call_content["thought_signature"] = call.thought_signature
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_content],
|
||||
}
|
||||
)
|
||||
|
||||
# Standard internal format: assistant message with tool_calls array
|
||||
args_str = (
|
||||
json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else call.arguments
|
||||
)
|
||||
tool_call_obj = {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call.name,
|
||||
"arguments": args_str,
|
||||
},
|
||||
}
|
||||
# Preserve thought_signature for Google Gemini 3 models
|
||||
if call.thought_signature:
|
||||
tool_call_obj["thought_signature"] = call.thought_signature
|
||||
|
||||
updated_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call_obj],
|
||||
})
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
except Exception as e:
|
||||
@@ -802,16 +857,15 @@ class LLMHandler(ABC):
|
||||
error_message = self.create_tool_message(error_call, error_response)
|
||||
updated_messages.append(error_message)
|
||||
|
||||
call_parts = call.name.split("_")
|
||||
if len(call_parts) >= 2:
|
||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||
action_name = "_".join(call_parts[:-1])
|
||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||
full_action_name = f"{action_name}_{tool_id}"
|
||||
mapping = agent.tool_executor._name_to_tool
|
||||
if call.name in mapping:
|
||||
resolved_tool_id, _ = mapping[call.name]
|
||||
tool_name = tools_dict.get(resolved_tool_id, {}).get(
|
||||
"name", "unknown_tool"
|
||||
)
|
||||
else:
|
||||
tool_name = "unknown_tool"
|
||||
action_name = call.name
|
||||
full_action_name = call.name
|
||||
full_action_name = call.name
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
@@ -823,7 +877,7 @@ class LLMHandler(ABC):
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
return updated_messages
|
||||
return updated_messages, pending_actions if pending_actions else None
|
||||
|
||||
def handle_non_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
@@ -851,8 +905,22 @@ class LLMHandler(ABC):
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, pending_actions = e.value
|
||||
break
|
||||
|
||||
# If tools need approval or client execution, pause the loop
|
||||
if pending_actions:
|
||||
agent._pending_continuation = {
|
||||
"messages": messages,
|
||||
"pending_tool_calls": pending_actions,
|
||||
"tools_dict": tools_dict,
|
||||
}
|
||||
yield {
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": pending_actions},
|
||||
}
|
||||
return ""
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
)
|
||||
@@ -913,10 +981,23 @@ class LLMHandler(ABC):
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, pending_actions = e.value
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
# If tools need approval or client execution, pause the loop
|
||||
if pending_actions:
|
||||
agent._pending_continuation = {
|
||||
"messages": messages,
|
||||
"pending_tool_calls": pending_actions,
|
||||
"tools_dict": tools_dict,
|
||||
}
|
||||
yield {
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": pending_actions},
|
||||
}
|
||||
return
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
|
||||
@@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler):
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create Google-style tool message."""
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
return {
|
||||
"role": "model",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
}
|
||||
}
|
||||
],
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
|
||||
@@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler):
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create OpenAI-style tool message."""
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
"call_id": tool_call.id,
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
|
||||
@@ -91,16 +91,52 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
|
||||
# Standard format: assistant message with tool_calls (passthrough)
|
||||
tool_calls = message.get("tool_calls")
|
||||
if tool_calls and role == "assistant":
|
||||
cleaned_tcs = []
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(self._remove_null_values(args))
|
||||
elif isinstance(args, str):
|
||||
try:
|
||||
parsed = json.loads(args)
|
||||
args = json.dumps(self._remove_null_values(parsed))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
cleaned_tcs.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {"name": func.get("name", ""), "arguments": args},
|
||||
})
|
||||
cleaned_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": cleaned_tcs,
|
||||
})
|
||||
continue
|
||||
|
||||
# Standard format: tool message with tool_call_id (passthrough)
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if role == "tool" and tool_call_id is not None:
|
||||
cleaned_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content if isinstance(content, str) else json.dumps(content),
|
||||
})
|
||||
continue
|
||||
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
elif isinstance(content, list):
|
||||
# Collect all content parts into a single message
|
||||
content_parts = []
|
||||
|
||||
for item in content:
|
||||
# Legacy format support: function_call / function_response
|
||||
if "function_call" in item:
|
||||
# Function calls need their own message
|
||||
args = item["function_call"]["args"]
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
@@ -116,28 +152,20 @@ class OpenAILLM(BaseLLM):
|
||||
"arguments": json.dumps(cleaned_args),
|
||||
},
|
||||
}
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
)
|
||||
cleaned_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call],
|
||||
})
|
||||
elif "function_response" in item:
|
||||
# Function responses need their own message
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item["function_response"][
|
||||
"call_id"
|
||||
],
|
||||
"content": json.dumps(
|
||||
item["function_response"]["response"]["result"]
|
||||
),
|
||||
}
|
||||
)
|
||||
cleaned_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": item["function_response"]["call_id"],
|
||||
"content": json.dumps(
|
||||
item["function_response"]["response"]["result"]
|
||||
),
|
||||
})
|
||||
elif isinstance(item, dict):
|
||||
# Collect content parts (text, images, files) into a single message
|
||||
if "type" in item and item["type"] == "text" and "text" in item:
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "file" and "file" in item:
|
||||
@@ -145,10 +173,7 @@ class OpenAILLM(BaseLLM):
|
||||
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
||||
content_parts.append(item)
|
||||
elif "text" in item and "type" not in item:
|
||||
# Legacy format: {"text": "..."} without type
|
||||
content_parts.append({"type": "text", "text": item["text"]})
|
||||
|
||||
# Add the collected content parts as a single message
|
||||
if content_parts:
|
||||
cleaned_messages.append({"role": role, "content": content_parts})
|
||||
else:
|
||||
|
||||
@@ -77,6 +77,10 @@ const endpoints = {
|
||||
WORKFLOWS: '/api/workflows',
|
||||
WORKFLOW: (id: string) => `/api/workflows/${id}`,
|
||||
},
|
||||
V1: {
|
||||
CHAT_COMPLETIONS: '/v1/chat/completions',
|
||||
MODELS: '/v1/models',
|
||||
},
|
||||
CONVERSATION: {
|
||||
ANSWER: '/api/answer',
|
||||
ANSWER_STREAMING: '/stream',
|
||||
|
||||
@@ -54,6 +54,18 @@ const conversationService = {
|
||||
apiClient.get(endpoints.CONVERSATION.DELETE_ALL, token, {}),
|
||||
update: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.CONVERSATION.UPDATE, data, token, {}),
|
||||
chatCompletions: (
|
||||
data: any,
|
||||
agentApiKey: string,
|
||||
signal: AbortSignal,
|
||||
): Promise<any> =>
|
||||
apiClient.post(
|
||||
endpoints.V1.CHAT_COMPLETIONS,
|
||||
data,
|
||||
null,
|
||||
{ Authorization: `Bearer ${agentApiKey}` },
|
||||
signal,
|
||||
),
|
||||
};
|
||||
|
||||
export default conversationService;
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
resendQuery,
|
||||
selectQueries,
|
||||
selectStatus,
|
||||
submitToolActions,
|
||||
updateQuery,
|
||||
} from './conversationSlice';
|
||||
import { selectCompletedAttachments } from '../upload/uploadSlice';
|
||||
@@ -41,6 +42,17 @@ export default function Conversation() {
|
||||
const [lastQueryReturnedErr, setLastQueryReturnedErr] =
|
||||
useState<boolean>(false);
|
||||
|
||||
const handleToolAction = useCallback(
|
||||
(callId: string, decision: 'approved' | 'denied', comment?: string) => {
|
||||
dispatch(
|
||||
submitToolActions({
|
||||
toolActions: [{ call_id: callId, decision, comment }],
|
||||
}),
|
||||
);
|
||||
},
|
||||
[dispatch],
|
||||
);
|
||||
|
||||
const lastAutoOpenedArtifactId = useRef<string | null>(null);
|
||||
const didInitArtifactAutoOpen = useRef(false);
|
||||
const prevConversationId = useRef<string | null>(conversationId);
|
||||
@@ -233,6 +245,7 @@ export default function Conversation() {
|
||||
status={status}
|
||||
showHeroOnEmpty={selectedAgent ? false : true}
|
||||
onOpenArtifact={handleOpenArtifact}
|
||||
onToolAction={handleToolAction}
|
||||
isSplitView={isSplitArtifactOpen}
|
||||
headerContent={
|
||||
selectedAgent ? (
|
||||
|
||||
@@ -65,6 +65,11 @@ const ConversationBubble = forwardRef<
|
||||
) => void;
|
||||
filesAttached?: { id: string; fileName: string }[];
|
||||
onOpenArtifact?: (artifact: { id: string; toolName: string }) => void;
|
||||
onToolAction?: (
|
||||
callId: string,
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
}
|
||||
>(function ConversationBubble(
|
||||
{
|
||||
@@ -83,6 +88,7 @@ const ConversationBubble = forwardRef<
|
||||
handleUpdatedQuestionSubmission,
|
||||
filesAttached,
|
||||
onOpenArtifact,
|
||||
onToolAction,
|
||||
},
|
||||
ref,
|
||||
) {
|
||||
@@ -411,7 +417,7 @@ const ConversationBubble = forwardRef<
|
||||
)}
|
||||
{research && <ResearchProgress research={research} />}
|
||||
{toolCalls && toolCalls.length > 0 && (
|
||||
<ToolCalls toolCalls={toolCalls} />
|
||||
<ToolCalls toolCalls={toolCalls} onToolAction={onToolAction} />
|
||||
)}
|
||||
{!message && primaryArtifactCall?.artifact_id && onOpenArtifact && (
|
||||
<div className="my-2 ml-2 flex justify-start">
|
||||
@@ -884,108 +890,263 @@ function AllSources(sources: AllSourcesProps) {
|
||||
}
|
||||
export default ConversationBubble;
|
||||
|
||||
function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
||||
function ToolCallApprovalBar({
|
||||
toolCall,
|
||||
onToolAction,
|
||||
}: {
|
||||
toolCall: ToolCallsType;
|
||||
onToolAction?: (
|
||||
callId: string,
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [comment, setComment] = useState('');
|
||||
const actionLabel = toolCall.action_name.substring(
|
||||
0,
|
||||
toolCall.action_name.lastIndexOf('_'),
|
||||
);
|
||||
const argPreview = JSON.stringify(toolCall.arguments);
|
||||
const truncated =
|
||||
argPreview.length > 60 ? argPreview.slice(0, 57) + '...' : argPreview;
|
||||
|
||||
return (
|
||||
<div className="border-border bg-muted dark:bg-card mb-2 w-full overflow-hidden rounded-2xl border">
|
||||
<div className="flex items-center gap-3 px-4 py-2.5">
|
||||
<div className="flex min-w-0 flex-1 items-center gap-2">
|
||||
<span className="text-sm font-semibold whitespace-nowrap">
|
||||
{toolCall.tool_name}
|
||||
</span>
|
||||
<span className="text-muted-foreground text-xs">{actionLabel}</span>
|
||||
<span
|
||||
className="text-muted-foreground hidden min-w-0 truncate font-mono text-xs md:block"
|
||||
title={argPreview}
|
||||
>
|
||||
{truncated}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
className={`rounded-full px-4 py-1 text-xs font-medium transition-colors ${
|
||||
comment
|
||||
? 'bg-muted text-muted-foreground cursor-default opacity-50'
|
||||
: 'bg-primary hover:bg-primary/90 text-white'
|
||||
}`}
|
||||
onClick={() => {
|
||||
if (!comment) onToolAction?.(toolCall.call_id, 'approved');
|
||||
}}
|
||||
>
|
||||
Approve
|
||||
</button>
|
||||
<button
|
||||
className={`rounded-full border px-4 py-1 text-xs font-medium transition-colors ${
|
||||
comment
|
||||
? 'border-destructive bg-destructive/10 text-destructive font-semibold'
|
||||
: 'hover:bg-accent text-muted-foreground'
|
||||
}`}
|
||||
onClick={() => {
|
||||
if (expanded && comment) {
|
||||
onToolAction?.(toolCall.call_id, 'denied', comment);
|
||||
} else if (expanded) {
|
||||
onToolAction?.(toolCall.call_id, 'denied');
|
||||
} else {
|
||||
setExpanded(true);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Deny
|
||||
</button>
|
||||
<button
|
||||
className="text-muted-foreground hover:text-foreground flex h-6 w-6 items-center justify-center rounded-full transition-colors"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
title="Details"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt="expand"
|
||||
className={`h-3.5 w-3.5 transition-transform duration-200 dark:invert ${expanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{expanded && (
|
||||
<div className="border-border border-t px-4 py-3">
|
||||
<p className="text-muted-foreground mb-1 text-xs font-medium">
|
||||
Arguments
|
||||
</p>
|
||||
<pre className="bg-background dark:bg-background/50 mb-2 max-h-40 overflow-auto rounded-lg p-2 font-mono text-xs">
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</pre>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Optional reason for denying..."
|
||||
className="border-border bg-background w-full rounded-lg border px-3 py-1.5 text-sm"
|
||||
value={comment}
|
||||
onChange={(e) => setComment(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' && comment) {
|
||||
onToolAction?.(toolCall.call_id, 'denied', comment);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolCalls({
|
||||
toolCalls,
|
||||
onToolAction,
|
||||
}: {
|
||||
toolCalls: ToolCallsType[];
|
||||
onToolAction?: (
|
||||
callId: string,
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
}) {
|
||||
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
|
||||
|
||||
const awaitingCalls = toolCalls.filter(
|
||||
(tc) => tc.status === 'awaiting_approval',
|
||||
);
|
||||
const resolvedCalls = toolCalls.filter(
|
||||
(tc) => tc.status !== 'awaiting_approval',
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
alt={'ToolCalls'}
|
||||
className="h-full w-full object-fill"
|
||||
{/* Approval bars — always visible, compact inline */}
|
||||
{awaitingCalls.length > 0 && (
|
||||
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
{awaitingCalls.map((tc) => (
|
||||
<ToolCallApprovalBar
|
||||
key={`approval-${tc.call_id}`}
|
||||
toolCall={tc}
|
||||
onToolAction={onToolAction}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<button
|
||||
className="flex flex-row items-center gap-2"
|
||||
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
|
||||
>
|
||||
<p className="text-base font-semibold">Tool Calls</p>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt="ChevronDown"
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
{isToolCallsOpen && (
|
||||
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
<div className="grid grid-cols-1 gap-2">
|
||||
{toolCalls.map((toolCall, index) => (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(toolCall.arguments, null, 2)}
|
||||
/>
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
{toolCall.status === 'pending' && (
|
||||
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
|
||||
<Spinner size="small" />
|
||||
</span>
|
||||
)}
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="leading-[23px] text-red-500 dark:text-red-400"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Accordion>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Regular tool calls accordion */}
|
||||
{resolvedCalls.length > 0 && (
|
||||
<>
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
alt={'ToolCalls'}
|
||||
className="h-full w-full object-fill"
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<button
|
||||
className="flex flex-row items-center gap-2"
|
||||
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
|
||||
>
|
||||
<p className="text-base font-semibold">Tool Calls</p>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt="ChevronDown"
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
{isToolCallsOpen && (
|
||||
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
<div className="grid grid-cols-1 gap-2">
|
||||
{resolvedCalls.map((toolCall, index) => (
|
||||
<Accordion
|
||||
key={`tool-call-${index}`}
|
||||
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
|
||||
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
|
||||
titleClassName="px-6 py-2 text-sm font-semibold"
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Arguments
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(
|
||||
toolCall.arguments,
|
||||
null,
|
||||
2,
|
||||
)}
|
||||
/>
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
<div className="border-border flex flex-col rounded-2xl border">
|
||||
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
|
||||
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
{toolCall.status === 'pending' && (
|
||||
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
|
||||
<Spinner size="small" />
|
||||
</span>
|
||||
)}
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-destructive leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-muted-foreground leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Accordion>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -38,6 +38,11 @@ type ConversationMessagesProps = {
|
||||
showHeroOnEmpty?: boolean;
|
||||
headerContent?: ReactNode;
|
||||
onOpenArtifact?: (artifact: { id: string; toolName: string }) => void;
|
||||
onToolAction?: (
|
||||
callId: string,
|
||||
decision: 'approved' | 'denied',
|
||||
comment?: string,
|
||||
) => void;
|
||||
isSplitView?: boolean;
|
||||
};
|
||||
|
||||
@@ -50,6 +55,7 @@ export default function ConversationMessages({
|
||||
showHeroOnEmpty = true,
|
||||
headerContent,
|
||||
onOpenArtifact,
|
||||
onToolAction,
|
||||
isSplitView = false,
|
||||
}: ConversationMessagesProps) {
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
@@ -154,6 +160,7 @@ export default function ConversationMessages({
|
||||
toolCalls={query.tool_calls}
|
||||
research={query.research}
|
||||
onOpenArtifact={onOpenArtifact}
|
||||
onToolAction={onToolAction}
|
||||
feedback={query.feedback}
|
||||
isStreaming={isCurrentlyStreaming}
|
||||
handleFeedback={
|
||||
|
||||
@@ -188,6 +188,264 @@ export function handleFetchAnswerSteaming(
|
||||
});
|
||||
}
|
||||
|
||||
export function handleSubmitToolActions(
|
||||
conversationId: string,
|
||||
toolActions: {
|
||||
call_id: string;
|
||||
decision?: 'approved' | 'denied';
|
||||
comment?: string;
|
||||
result?: Record<string, any>;
|
||||
}[],
|
||||
token: string | null,
|
||||
signal: AbortSignal,
|
||||
onEvent: (event: MessageEvent) => void,
|
||||
): Promise<Answer> {
|
||||
const payload = {
|
||||
conversation_id: conversationId,
|
||||
tool_actions: toolActions,
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
.answerStream(payload, token, signal)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Tool actions submission failed:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a chat completion via the /v1/chat/completions endpoint.
|
||||
*
|
||||
* Translates the standard streaming format (choices[0].delta) back into
|
||||
* the internal DocsGPT event shape so the existing Redux reducers can
|
||||
* consume the events without any changes.
|
||||
*/
|
||||
export function handleV1ChatCompletionStreaming(
|
||||
question: string,
|
||||
signal: AbortSignal,
|
||||
agentApiKey: string,
|
||||
history: { prompt: string; response: string }[],
|
||||
onEvent: (event: MessageEvent) => void,
|
||||
tools?: any[],
|
||||
attachments?: string[],
|
||||
): Promise<Answer> {
|
||||
// Build messages array from history + current question
|
||||
const messages: any[] = [];
|
||||
for (const h of history) {
|
||||
messages.push({ role: 'user', content: h.prompt });
|
||||
messages.push({ role: 'assistant', content: h.response });
|
||||
}
|
||||
messages.push({ role: 'user', content: question });
|
||||
|
||||
const payload: any = {
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
if (tools && tools.length > 0) {
|
||||
payload.tools = tools;
|
||||
}
|
||||
if (attachments && attachments.length > 0) {
|
||||
payload.docsgpt = { attachments };
|
||||
}
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
.chatCompletions(payload, agentApiKey, signal)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (!event.trim().startsWith('data:')) continue;
|
||||
|
||||
const dataLine = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const trimmed = dataLine.trim();
|
||||
|
||||
// Handle [DONE] sentinel
|
||||
if (trimmed === '[DONE]') {
|
||||
onEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({ type: 'end' }),
|
||||
}),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
// Translate standard format to DocsGPT internal events
|
||||
const translated = translateV1ChunkToInternalEvents(parsed);
|
||||
for (const evt of translated) {
|
||||
onEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify(evt),
|
||||
}),
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
// Skip unparseable chunks
|
||||
}
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('V1 chat completion stream failed:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Translate a single v1 streaming chunk to internal DocsGPT event(s).
|
||||
*
|
||||
* Standard format: {"choices": [{"delta": {"content": "chunk"}, ...}]}
|
||||
* Extension format: {"docsgpt": {"type": "source", ...}}
|
||||
*/
|
||||
function translateV1ChunkToInternalEvents(
|
||||
chunk: any,
|
||||
): { type: string; [key: string]: any }[] {
|
||||
const events: { type: string; [key: string]: any }[] = [];
|
||||
|
||||
// DocsGPT extension chunks
|
||||
if (chunk.docsgpt) {
|
||||
const ext = chunk.docsgpt;
|
||||
if (ext.type === 'source') {
|
||||
events.push({ type: 'source', source: ext.sources });
|
||||
} else if (ext.type === 'tool_call') {
|
||||
events.push({ type: 'tool_call', data: ext.data });
|
||||
} else if (ext.type === 'tool_calls_pending') {
|
||||
events.push({
|
||||
type: 'tool_calls_pending',
|
||||
data: { pending_tool_calls: ext.pending_tool_calls },
|
||||
});
|
||||
} else if (ext.type === 'id') {
|
||||
events.push({ type: 'id', id: ext.conversation_id });
|
||||
}
|
||||
return events;
|
||||
}
|
||||
|
||||
// Error chunks
|
||||
if (chunk.error) {
|
||||
events.push({ type: 'error', error: chunk.error.message || 'Error' });
|
||||
return events;
|
||||
}
|
||||
|
||||
// Standard choices chunks
|
||||
const choice = chunk.choices?.[0];
|
||||
if (!choice) return events;
|
||||
|
||||
const delta = choice.delta || {};
|
||||
const finishReason = choice.finish_reason;
|
||||
|
||||
if (delta.content) {
|
||||
events.push({ type: 'answer', answer: delta.content });
|
||||
}
|
||||
|
||||
if (delta.reasoning_content) {
|
||||
events.push({ type: 'thought', thought: delta.reasoning_content });
|
||||
}
|
||||
|
||||
if (delta.tool_calls) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
let parsedArgs: Record<string, any> = {};
|
||||
if (tc.function?.arguments) {
|
||||
try {
|
||||
parsedArgs = JSON.parse(tc.function.arguments);
|
||||
} catch {
|
||||
// Arguments may arrive as fragments during streaming;
|
||||
// keep the raw string so downstream can accumulate it.
|
||||
parsedArgs = { _raw: tc.function.arguments };
|
||||
}
|
||||
}
|
||||
events.push({
|
||||
type: 'tool_call',
|
||||
data: {
|
||||
call_id: tc.id,
|
||||
action_name: tc.function?.name || '',
|
||||
tool_name: tc.function?.name || '',
|
||||
arguments: parsedArgs,
|
||||
status: 'requires_client_execution',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (finishReason === 'stop') {
|
||||
events.push({ type: 'end' });
|
||||
} else if (finishReason === 'tool_calls') {
|
||||
events.push({
|
||||
type: 'tool_calls_pending',
|
||||
data: { pending_tool_calls: [] },
|
||||
});
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
export function handleSearch(
|
||||
question: string,
|
||||
token: string | null,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { ToolCallsType } from './types';
|
||||
|
||||
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
|
||||
export type Status = 'idle' | 'loading' | 'failed';
|
||||
export type Status = 'idle' | 'loading' | 'failed' | 'awaiting_tool_actions';
|
||||
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
|
||||
|
||||
export interface Message {
|
||||
|
||||
@@ -10,6 +10,8 @@ import {
|
||||
import {
|
||||
handleFetchAnswer,
|
||||
handleFetchAnswerSteaming,
|
||||
handleSubmitToolActions,
|
||||
handleV1ChatCompletionStreaming,
|
||||
} from './conversationHandlers';
|
||||
import {
|
||||
Answer,
|
||||
@@ -27,6 +29,7 @@ const initialState: ConversationState = {
|
||||
};
|
||||
|
||||
const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';
|
||||
const USE_V1_API = import.meta.env.VITE_USE_V1_API === 'true';
|
||||
|
||||
let abortController: AbortController | null = null;
|
||||
export function handleAbort() {
|
||||
@@ -60,7 +63,102 @@ export const fetchAnswer = createAsyncThunk<
|
||||
state.preference.selectedModel?.id;
|
||||
|
||||
if (state.preference) {
|
||||
if (API_STREAMING) {
|
||||
const agentKey = state.preference.selectedAgent?.key;
|
||||
if (USE_V1_API && agentKey) {
|
||||
// Build history from prior queries for v1 format
|
||||
const v1History = state.conversation.queries
|
||||
.filter((q) => q.response)
|
||||
.map((q) => ({ prompt: q.prompt, response: q.response || '' }));
|
||||
|
||||
await handleV1ChatCompletionStreaming(
|
||||
question,
|
||||
signal,
|
||||
agentKey,
|
||||
v1History,
|
||||
(event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
const targetIndex = indx ?? state.conversation.queries.length - 1;
|
||||
|
||||
if (currentConversationId === state.conversation.conversationId) {
|
||||
if (data.type === 'end') {
|
||||
dispatch(conversationSlice.actions.setStatus('idle'));
|
||||
getConversations(state.preference.token)
|
||||
.then((fetchedConversations) => {
|
||||
dispatch(setConversations(fetchedConversations));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to fetch conversations: ', error);
|
||||
});
|
||||
if (!isSourceUpdated) {
|
||||
dispatch(
|
||||
updateStreamingSource({
|
||||
conversationId: currentConversationId,
|
||||
index: targetIndex,
|
||||
query: { sources: [] },
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else if (data.type === 'id') {
|
||||
const currentState = getState() as RootState;
|
||||
if (currentState.conversation.conversationId === null) {
|
||||
dispatch(
|
||||
updateConversationId({
|
||||
query: { conversationId: data.id },
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else if (data.type === 'thought') {
|
||||
dispatch(
|
||||
updateThought({
|
||||
conversationId: currentConversationId,
|
||||
index: targetIndex,
|
||||
query: { thought: data.thought },
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'source') {
|
||||
isSourceUpdated = true;
|
||||
dispatch(
|
||||
updateStreamingSource({
|
||||
conversationId: currentConversationId,
|
||||
index: targetIndex,
|
||||
query: { sources: data.source ?? [] },
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'tool_call') {
|
||||
dispatch(
|
||||
updateToolCall({
|
||||
index: targetIndex,
|
||||
tool_call: data.data as ToolCallsType,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'tool_calls_pending') {
|
||||
dispatch(
|
||||
conversationSlice.actions.setStatus('awaiting_tool_actions'),
|
||||
);
|
||||
} else if (data.type === 'error') {
|
||||
dispatch(conversationSlice.actions.setStatus('failed'));
|
||||
dispatch(
|
||||
conversationSlice.actions.raiseError({
|
||||
conversationId: currentConversationId,
|
||||
index: targetIndex,
|
||||
message: data.error,
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
conversationId: currentConversationId,
|
||||
index: targetIndex,
|
||||
query: { response: data.answer },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
undefined,
|
||||
attachmentIds.length > 0 ? attachmentIds : undefined,
|
||||
);
|
||||
} else if (API_STREAMING) {
|
||||
await handleFetchAnswerSteaming(
|
||||
question,
|
||||
signal,
|
||||
@@ -138,6 +236,10 @@ export const fetchAnswer = createAsyncThunk<
|
||||
tool_call: data.data as ToolCallsType,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'tool_calls_pending') {
|
||||
dispatch(
|
||||
conversationSlice.actions.setStatus('awaiting_tool_actions'),
|
||||
);
|
||||
} else if (data.type === 'research_plan') {
|
||||
dispatch(
|
||||
updateResearchPlan({
|
||||
@@ -260,6 +362,94 @@ export const fetchAnswer = createAsyncThunk<
|
||||
};
|
||||
});
|
||||
|
||||
export const submitToolActions = createAsyncThunk<
|
||||
void,
|
||||
{
|
||||
toolActions: {
|
||||
call_id: string;
|
||||
decision?: 'approved' | 'denied';
|
||||
comment?: string;
|
||||
result?: Record<string, any>;
|
||||
}[];
|
||||
}
|
||||
>('submitToolActions', async ({ toolActions }, { dispatch, getState }) => {
|
||||
if (abortController) abortController.abort();
|
||||
abortController = new AbortController();
|
||||
const { signal } = abortController;
|
||||
|
||||
const state = getState() as RootState;
|
||||
const conversationId = state.conversation.conversationId;
|
||||
if (!conversationId) return;
|
||||
|
||||
dispatch(conversationSlice.actions.setStatus('loading'));
|
||||
|
||||
await handleSubmitToolActions(
|
||||
conversationId,
|
||||
toolActions,
|
||||
state.preference.token,
|
||||
signal,
|
||||
(event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
const targetIndex = state.conversation.queries.length - 1;
|
||||
|
||||
if (data.type === 'end') {
|
||||
dispatch(conversationSlice.actions.setStatus('idle'));
|
||||
getConversations(state.preference.token)
|
||||
.then((fetchedConversations) => {
|
||||
dispatch(setConversations(fetchedConversations));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to fetch conversations: ', error);
|
||||
});
|
||||
} else if (data.type === 'id') {
|
||||
// conversation ID already set
|
||||
} else if (data.type === 'thought') {
|
||||
dispatch(
|
||||
updateThought({
|
||||
conversationId,
|
||||
index: targetIndex,
|
||||
query: { thought: data.thought },
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'source') {
|
||||
dispatch(
|
||||
updateStreamingSource({
|
||||
conversationId,
|
||||
index: targetIndex,
|
||||
query: { sources: data.source ?? [] },
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'tool_call') {
|
||||
dispatch(
|
||||
updateToolCall({
|
||||
index: targetIndex,
|
||||
tool_call: data.data as ToolCallsType,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'tool_calls_pending') {
|
||||
dispatch(conversationSlice.actions.setStatus('awaiting_tool_actions'));
|
||||
} else if (data.type === 'error') {
|
||||
dispatch(conversationSlice.actions.setStatus('failed'));
|
||||
dispatch(
|
||||
conversationSlice.actions.raiseError({
|
||||
conversationId,
|
||||
index: targetIndex,
|
||||
message: data.error,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'answer') {
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
conversationId,
|
||||
index: targetIndex,
|
||||
query: { response: data.answer },
|
||||
}),
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
export const conversationSlice = createSlice({
|
||||
name: 'conversation',
|
||||
initialState,
|
||||
|
||||
@@ -5,6 +5,12 @@ export type ToolCallsType = {
|
||||
arguments: Record<string, any>;
|
||||
result?: Record<string, any>;
|
||||
error?: string;
|
||||
status?: 'pending' | 'completed' | 'error';
|
||||
status?:
|
||||
| 'pending'
|
||||
| 'completed'
|
||||
| 'error'
|
||||
| 'awaiting_approval'
|
||||
| 'denied'
|
||||
| 'requires_client_execution';
|
||||
artifact_id?: string;
|
||||
};
|
||||
|
||||
@@ -487,9 +487,33 @@ export default function ToolConfig({
|
||||
)}
|
||||
</div>
|
||||
<div
|
||||
className="flex items-center gap-2"
|
||||
className="flex items-center gap-3"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="flex items-center gap-1">
|
||||
<span className="text-xs text-gray-500 dark:text-gray-400">
|
||||
{t('settings.tools.requireApproval', 'Approval')}
|
||||
</span>
|
||||
<ToggleSwitch
|
||||
checked={action.require_approval ?? false}
|
||||
onChange={(checked) => {
|
||||
setTool({
|
||||
...tool,
|
||||
actions: tool.actions.map((act, index) => {
|
||||
if (index === originalIndex) {
|
||||
return {
|
||||
...act,
|
||||
require_approval: checked,
|
||||
};
|
||||
}
|
||||
return act;
|
||||
}),
|
||||
});
|
||||
}}
|
||||
size="small"
|
||||
id={`approvalToggle-${originalIndex}`}
|
||||
/>
|
||||
</div>
|
||||
<ToggleSwitch
|
||||
checked={action.active}
|
||||
onChange={(checked) => {
|
||||
@@ -926,6 +950,35 @@ function APIToolConfig({
|
||||
className="h-4 w-4 opacity-40 transition-opacity hover:opacity-100"
|
||||
/>
|
||||
</button>
|
||||
<div className="flex items-center gap-1">
|
||||
<span className="text-xs text-gray-500 dark:text-gray-400">
|
||||
{t('settings.tools.requireApproval', 'Approval')}
|
||||
</span>
|
||||
<ToggleSwitch
|
||||
checked={action.require_approval ?? false}
|
||||
onChange={() => {
|
||||
setApiTool((prevApiTool) => {
|
||||
const updatedActions = {
|
||||
...prevApiTool.config.actions,
|
||||
};
|
||||
updatedActions[actionName] = {
|
||||
...updatedActions[actionName],
|
||||
require_approval:
|
||||
!updatedActions[actionName].require_approval,
|
||||
};
|
||||
return {
|
||||
...prevApiTool,
|
||||
config: {
|
||||
...prevApiTool.config,
|
||||
actions: updatedActions,
|
||||
},
|
||||
};
|
||||
});
|
||||
}}
|
||||
size="small"
|
||||
id={`approvalToggle-${actionIndex}`}
|
||||
/>
|
||||
</div>
|
||||
<ToggleSwitch
|
||||
checked={action.active}
|
||||
onChange={() => handleActionToggle(actionName)}
|
||||
|
||||
@@ -69,6 +69,7 @@ export type UserToolType = {
|
||||
type: string;
|
||||
};
|
||||
active: boolean;
|
||||
require_approval?: boolean;
|
||||
}[];
|
||||
};
|
||||
|
||||
@@ -81,6 +82,7 @@ export type APIActionType = {
|
||||
headers: ParameterGroupType;
|
||||
body: ParameterGroupType;
|
||||
active: boolean;
|
||||
require_approval?: boolean;
|
||||
body_content_type?:
|
||||
| 'application/json'
|
||||
| 'application/x-www-form-urlencoded'
|
||||
|
||||
@@ -341,7 +341,7 @@ class TestBaseAgentTools:
|
||||
|
||||
assert len(agent.tools) == 1
|
||||
assert agent.tools[0]["type"] == "function"
|
||||
assert agent.tools[0]["function"]["name"] == "get_data_1"
|
||||
assert agent.tools[0]["function"]["name"] == "get_data"
|
||||
|
||||
def test_prepare_tools_with_regular_tool(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
@@ -365,7 +365,7 @@ class TestBaseAgentTools:
|
||||
agent._prepare_tools(tools_dict)
|
||||
|
||||
assert len(agent.tools) == 1
|
||||
assert agent.tools[0]["function"]["name"] == "action1_1"
|
||||
assert agent.tools[0]["function"]["name"] == "action1"
|
||||
|
||||
def test_prepare_tools_filters_inactive_actions(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
@@ -395,7 +395,7 @@ class TestBaseAgentTools:
|
||||
agent._prepare_tools(tools_dict)
|
||||
|
||||
assert len(agent.tools) == 1
|
||||
assert agent.tools[0]["function"]["name"] == "active_action_1"
|
||||
assert agent.tools[0]["function"]["name"] == "active_action"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
@@ -202,3 +202,69 @@ class TestToolActionParser:
|
||||
assert action_name == "create_record"
|
||||
assert call_args["data"]["name"] == "John"
|
||||
assert call_args["data"]["age"] == 30
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolActionParserWithMapping:
|
||||
"""Tests for the mapping-based lookup path."""
|
||||
|
||||
def test_openai_mapping_resolves_clean_name(self):
|
||||
mapping = {"get_weather": ("ct0", "get_weather")}
|
||||
parser = ToolActionParser("OpenAILLM", name_mapping=mapping)
|
||||
|
||||
call = Mock()
|
||||
call.name = "get_weather"
|
||||
call.arguments = '{"city": "SF"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
assert tool_id == "ct0"
|
||||
assert action_name == "get_weather"
|
||||
assert call_args == {"city": "SF"}
|
||||
|
||||
def test_openai_mapping_resolves_numbered_suffix(self):
|
||||
mapping = {"search_1": ("t1", "search"), "search_2": ("t2", "search")}
|
||||
parser = ToolActionParser("OpenAILLM", name_mapping=mapping)
|
||||
|
||||
call = Mock()
|
||||
call.name = "search_1"
|
||||
call.arguments = '{"q": "test"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
assert tool_id == "t1"
|
||||
assert action_name == "search"
|
||||
|
||||
def test_google_mapping_resolves(self):
|
||||
mapping = {"get_weather": ("ct0", "get_weather")}
|
||||
parser = ToolActionParser("GoogleLLM", name_mapping=mapping)
|
||||
|
||||
call = Mock()
|
||||
call.name = "get_weather"
|
||||
call.arguments = {"city": "SF"}
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
assert tool_id == "ct0"
|
||||
assert action_name == "get_weather"
|
||||
|
||||
def test_fallback_to_split_when_not_in_mapping(self):
|
||||
mapping = {"get_weather": ("ct0", "get_weather")}
|
||||
parser = ToolActionParser("OpenAILLM", name_mapping=mapping)
|
||||
|
||||
call = Mock()
|
||||
call.name = "unknown_action_99"
|
||||
call.arguments = "{}"
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
# Falls back to legacy split
|
||||
assert tool_id == "99"
|
||||
assert action_name == "unknown_action"
|
||||
|
||||
def test_no_mapping_uses_legacy_split(self):
|
||||
parser = ToolActionParser("OpenAILLM", name_mapping=None)
|
||||
|
||||
call = Mock()
|
||||
call.name = "action_123"
|
||||
call.arguments = '{"k": "v"}'
|
||||
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
assert tool_id == "123"
|
||||
assert action_name == "action"
|
||||
|
||||
@@ -80,7 +80,7 @@ class TestToolExecutorPrepare:
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == "function"
|
||||
assert result[0]["function"]["name"] == "do_thing_t1"
|
||||
assert result[0]["function"]["name"] == "do_thing"
|
||||
assert "query" in result[0]["function"]["parameters"]["properties"]
|
||||
|
||||
def test_prepare_tools_skips_inactive_actions(self):
|
||||
@@ -97,7 +97,96 @@ class TestToolExecutorPrepare:
|
||||
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
assert len(result) == 1
|
||||
assert result[0]["function"]["name"] == "active_one_t1"
|
||||
assert result[0]["function"]["name"] == "active_one"
|
||||
|
||||
def test_prepare_tools_builds_name_mapping(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"actions": [
|
||||
{"name": "do_thing", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
executor.prepare_tools_for_llm(tools_dict)
|
||||
assert executor._name_to_tool["do_thing"] == ("t1", "do_thing")
|
||||
assert executor._tool_to_name[("t1", "do_thing")] == "do_thing"
|
||||
|
||||
def test_prepare_tools_duplicate_names_get_numbered_suffixes(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "tool_a",
|
||||
"actions": [
|
||||
{"name": "search", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
"t2": {
|
||||
"name": "tool_b",
|
||||
"actions": [
|
||||
{"name": "search", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
}
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
names = [r["function"]["name"] for r in result]
|
||||
assert "search_1" in names
|
||||
assert "search_2" in names
|
||||
assert executor._name_to_tool["search_1"][1] == "search"
|
||||
assert executor._name_to_tool["search_2"][1] == "search"
|
||||
|
||||
def test_prepare_tools_unique_name_no_suffix(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "tool_a",
|
||||
"actions": [
|
||||
{"name": "get_weather", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
"t2": {
|
||||
"name": "tool_b",
|
||||
"actions": [
|
||||
{"name": "send_email", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
}
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
names = [r["function"]["name"] for r in result]
|
||||
assert "get_weather" in names
|
||||
assert "send_email" in names
|
||||
|
||||
def test_prepare_tools_suffix_skips_collision_with_unique_name(self):
|
||||
"""If action 'foo_1' exists as unique and 'foo' is duplicated, skip '_1'."""
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "tool_a",
|
||||
"actions": [
|
||||
{"name": "foo", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
"t2": {
|
||||
"name": "tool_b",
|
||||
"actions": [
|
||||
{"name": "foo", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
"t3": {
|
||||
"name": "tool_c",
|
||||
"actions": [
|
||||
{"name": "foo_1", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
],
|
||||
},
|
||||
}
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
names = [r["function"]["name"] for r in result]
|
||||
# foo_1 is taken by the unique action, so duplicates skip to _2 and _3
|
||||
assert "foo_1" in names # The unique action
|
||||
assert "foo_2" in names
|
||||
assert "foo_3" in names
|
||||
assert executor._name_to_tool["foo_1"] == ("t3", "foo_1")
|
||||
|
||||
def test_build_tool_parameters_filters_non_llm_fields(self):
|
||||
executor = ToolExecutor()
|
||||
@@ -128,6 +217,68 @@ class TestToolExecutorPrepare:
|
||||
assert "value" not in result["properties"]["query"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckPause:
|
||||
|
||||
def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"):
|
||||
call = Mock()
|
||||
call.name = name
|
||||
call.id = call_id
|
||||
call.arguments = arguments
|
||||
call.thought_signature = None
|
||||
return call
|
||||
|
||||
def test_client_side_tool_returns_llm_name(self):
|
||||
"""check_pause returns the clean LLM-facing name and llm_name field."""
|
||||
executor = ToolExecutor()
|
||||
|
||||
tools_dict = {
|
||||
"ct0": {
|
||||
"name": "write_file",
|
||||
"client_side": True,
|
||||
"actions": [
|
||||
{"name": "write_file", "description": "Write a file", "active": True, "parameters": {}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
# Prepare tools so the mapping is built
|
||||
executor.prepare_tools_for_llm(tools_dict)
|
||||
|
||||
call = self._make_call(name="write_file")
|
||||
result = executor.check_pause(tools_dict, call, "OpenAILLM")
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "write_file"
|
||||
assert result["llm_name"] == "write_file"
|
||||
assert result["action_name"] == "write_file"
|
||||
assert result["tool_id"] == "ct0"
|
||||
|
||||
def test_approval_required_returns_llm_name(self):
|
||||
"""check_pause for approval-required tools returns clean LLM name."""
|
||||
executor = ToolExecutor()
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "dangerous_tool",
|
||||
"actions": [
|
||||
{"name": "delete_all", "description": "Deletes everything", "active": True,
|
||||
"require_approval": True, "parameters": {}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
executor.prepare_tools_for_llm(tools_dict)
|
||||
|
||||
call = self._make_call(name="delete_all")
|
||||
result = executor.check_pause(tools_dict, call, "OpenAILLM")
|
||||
|
||||
assert result is not None
|
||||
assert result["name"] == "delete_all"
|
||||
assert result["llm_name"] == "delete_all"
|
||||
assert result["action_name"] == "delete_all"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolExecutorExecute:
|
||||
|
||||
@@ -143,7 +294,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=(None, None, {}))),
|
||||
lambda _cls, **kw: Mock(parse_args=Mock(return_value=(None, None, {}))),
|
||||
)
|
||||
|
||||
call = self._make_call(name="bad")
|
||||
@@ -167,7 +318,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))),
|
||||
lambda _cls, **kw: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))),
|
||||
)
|
||||
|
||||
call = self._make_call()
|
||||
@@ -190,7 +341,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))),
|
||||
lambda _cls, **kw: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
@@ -244,7 +395,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))),
|
||||
lambda _cls, **kw: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
@@ -284,7 +435,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "get_users", {"body_param": "val"}))
|
||||
),
|
||||
)
|
||||
@@ -331,7 +482,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "act", {}))
|
||||
),
|
||||
)
|
||||
@@ -376,7 +527,7 @@ class TestToolExecutorExecute:
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "act", {"q": "v"}))
|
||||
),
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestAnswerResourcePost:
|
||||
),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(conv_id, "Hello", [], [], "", None),
|
||||
return_value={"conversation_id": conv_id, "answer": "Hello", "sources": [], "tool_calls": [], "thought": "", "error": None},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
@@ -129,7 +129,7 @@ class TestAnswerResourcePost:
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(None, None, None, None, None, "Stream error"),
|
||||
return_value={"conversation_id": None, "answer": None, "sources": None, "tool_calls": None, "thought": None, "error": "Stream error"},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
@@ -173,15 +173,7 @@ class TestAnswerResourcePost:
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
'{"key": "val"}',
|
||||
[],
|
||||
[],
|
||||
"",
|
||||
None,
|
||||
{"structured": True, "schema": {"type": "object"}},
|
||||
),
|
||||
return_value={"conversation_id": conv_id, "answer": '{"key": "val"}', "sources": [], "tool_calls": [], "thought": "", "error": None, "extra": {"structured": True, "schema": {"type": "object"}}},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
@@ -208,14 +200,7 @@ class TestAnswerResourcePost:
|
||||
return_value=iter([]),
|
||||
), patch(
|
||||
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
|
||||
return_value=(
|
||||
conv_id,
|
||||
"answer text",
|
||||
[{"title": "src"}],
|
||||
[{"tool": "t"}],
|
||||
"thinking...",
|
||||
None,
|
||||
),
|
||||
return_value={"conversation_id": conv_id, "answer": "answer text", "sources": [{"title": "src"}], "tool_calls": [{"tool": "t"}], "thought": "thinking...", "error": None},
|
||||
):
|
||||
resp = answer_client.post(
|
||||
"/api/answer",
|
||||
|
||||
@@ -481,10 +481,10 @@ class TestProcessResponseStream:
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert result[0] == conv_id
|
||||
assert result[1] == "Hello world"
|
||||
assert result[2] == [{"title": "doc1"}]
|
||||
assert result[5] is None
|
||||
assert result["conversation_id"] == conv_id
|
||||
assert result["answer"] == "Hello world"
|
||||
assert result["sources"] == [{"title": "doc1"}]
|
||||
assert result["error"] is None
|
||||
|
||||
def test_handles_stream_error(self, mock_mongo_db, flask_app):
|
||||
import json
|
||||
@@ -500,10 +500,8 @@ class TestProcessResponseStream:
|
||||
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
|
||||
assert len(result) == 6
|
||||
assert result[0] is None
|
||||
assert result[4] == "Test error"
|
||||
assert result[5] is None
|
||||
assert result["conversation_id"] is None
|
||||
assert result["error"] == "Test error"
|
||||
|
||||
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
@@ -79,9 +79,10 @@ class TestBuildFromCompressedContext:
|
||||
# system + user + assistant + tool_call_assistant + tool_response = 5
|
||||
assert len(messages) == 5
|
||||
assert messages[3]["role"] == "assistant"
|
||||
assert "function_call" in messages[3]["content"][0]
|
||||
assert messages[3].get("tool_calls") is not None
|
||||
assert messages[3]["tool_calls"][0]["function"]["name"] == "search"
|
||||
assert messages[4]["role"] == "tool"
|
||||
assert "function_response" in messages[4]["content"][0]
|
||||
assert messages[4].get("tool_call_id") == "call-1"
|
||||
|
||||
def test_tool_calls_not_included_by_default(self):
|
||||
queries = [
|
||||
@@ -127,8 +128,8 @@ class TestBuildFromCompressedContext:
|
||||
recent_queries=queries,
|
||||
include_tool_calls=True,
|
||||
)
|
||||
tool_msg = messages[3]["content"][0]
|
||||
call_id = tool_msg["function_call"]["call_id"]
|
||||
assistant_msg = messages[3]
|
||||
call_id = assistant_msg["tool_calls"][0]["id"]
|
||||
assert call_id is not None
|
||||
assert len(call_id) > 0
|
||||
|
||||
|
||||
@@ -295,10 +295,8 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[1] == "{}"
|
||||
# Structured output adds extra tuple element
|
||||
assert len(result) == 7
|
||||
assert result[6]["structured"] is True
|
||||
assert result["answer"] == "{}"
|
||||
assert result.get("extra", {}).get("structured") is True
|
||||
|
||||
def test_handles_tool_calls_event(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
@@ -312,7 +310,7 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[3] == [{"name": "t1"}]
|
||||
assert result["tool_calls"] == [{"name": "t1"}]
|
||||
|
||||
def test_incomplete_stream(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
@@ -323,7 +321,7 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[4] == "Stream ended unexpectedly"
|
||||
assert result["error"] == "Stream ended unexpectedly"
|
||||
|
||||
def test_handles_thought_event(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
@@ -335,7 +333,7 @@ class TestProcessResponseStreamExtended:
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result[4] == "thinking..."
|
||||
assert result["thought"] == "thinking..."
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
@@ -50,6 +50,7 @@ from tests.integration.test_analytics import AnalyticsTests
|
||||
from tests.integration.test_connectors import ConnectorTests
|
||||
from tests.integration.test_mcp import MCPTests
|
||||
from tests.integration.test_misc import MiscTests
|
||||
from tests.integration.test_v1_api import V1ApiTests
|
||||
|
||||
|
||||
# Module registry
|
||||
@@ -64,6 +65,7 @@ MODULES = {
|
||||
"connectors": ConnectorTests,
|
||||
"mcp": MCPTests,
|
||||
"misc": MiscTests,
|
||||
"v1_api": V1ApiTests,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1036,208 +1036,7 @@ This is test documentation for integration tests.
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Compression Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_compression_heavy_tool_usage(self) -> bool:
|
||||
"""Test compression with heavy conversation usage."""
|
||||
test_name = "Compression - Heavy Tool Usage"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
self.print_info("Making 10 consecutive requests to build conversation history...")
|
||||
|
||||
current_conv_id = None
|
||||
|
||||
for i in range(10):
|
||||
question = f"Tell me about Python topic {i+1}: data structures, decorators, async, testing. Provide a comprehensive explanation."
|
||||
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
if current_conv_id:
|
||||
payload["conversation_id"] = current_conv_id
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=payload, timeout=90)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
current_conv_id = result.get("conversation_id", current_conv_id)
|
||||
answer_preview = (result.get("answer") or "")[:80]
|
||||
self.print_success(f"Request {i+1}/10 completed")
|
||||
self.print_info(f" Answer: {answer_preview}...")
|
||||
else:
|
||||
self.print_error(f"Request {i+1}/10 failed: status {response.status_code}")
|
||||
self.record_result(test_name, False, f"Request {i+1} failed")
|
||||
return False
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Request {i+1}/10 failed: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
if current_conv_id:
|
||||
self.print_success("Heavy usage test completed")
|
||||
self.record_result(test_name, True, f"10 requests, conv_id: {current_conv_id}")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("No conversation_id received")
|
||||
self.record_result(test_name, False, "No conversation_id")
|
||||
return False
|
||||
|
||||
def test_compression_needle_in_haystack(self) -> bool:
|
||||
"""Test that compression preserves critical information.
|
||||
|
||||
Note: This is a long-running test that may timeout due to LLM response times.
|
||||
Timeouts are handled gracefully as they indicate performance issues, not bugs.
|
||||
"""
|
||||
test_name = "Compression - Needle in Haystack"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
|
||||
conversation_id = None
|
||||
|
||||
# Step 1: Send general questions
|
||||
self.print_info("Step 1: Sending general questions...")
|
||||
for i, question in enumerate([
|
||||
"Tell me about Python best practices in detail",
|
||||
"Explain Python data structures comprehensively",
|
||||
]):
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
if conversation_id:
|
||||
payload["conversation_id"] = conversation_id
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conversation_id = result.get("conversation_id", conversation_id)
|
||||
self.print_success(f"General question {i+1}/2 completed")
|
||||
else:
|
||||
self.print_error(f"Request failed: status {response.status_code}")
|
||||
self.record_result(test_name, False, "General questions failed")
|
||||
return False
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
# Timeout errors are expected for long LLM responses
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.print_error(f"Request failed: {str(e)}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# Step 2: Send critical information
|
||||
self.print_info("Step 2: Sending CRITICAL information...")
|
||||
critical_payload = {
|
||||
"question": "Please remember: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily.",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=critical_payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conversation_id = result.get("conversation_id", conversation_id)
|
||||
self.print_success("Critical information sent")
|
||||
else:
|
||||
self.record_result(test_name, False, "Critical info failed")
|
||||
return False
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# Step 3: Bury with more questions
|
||||
self.print_info("Step 3: Sending more questions to bury critical info...")
|
||||
for i, question in enumerate([
|
||||
"Explain Python decorators in great detail",
|
||||
"Tell me about Python async programming comprehensively",
|
||||
]):
|
||||
payload = {
|
||||
"question": question,
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
conversation_id = result.get("conversation_id", conversation_id)
|
||||
self.print_success(f"Burying question {i+1}/2 completed")
|
||||
else:
|
||||
self.record_result(test_name, False, "Burying questions failed")
|
||||
return False
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# Step 4: Test recall
|
||||
self.print_info("Step 4: Testing if critical info was preserved...")
|
||||
recall_payload = {
|
||||
"question": "What was the database password environment variable I mentioned earlier?",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/answer", json=recall_payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
answer = (result.get("answer") or "").lower()
|
||||
|
||||
if "db_password_prod" in answer or "database password" in answer:
|
||||
self.print_success("Critical information preserved!")
|
||||
self.print_info(f"Answer: {answer[:150]}...")
|
||||
self.record_result(test_name, True, "Info preserved")
|
||||
return True
|
||||
else:
|
||||
self.print_warning("Critical information may have been lost")
|
||||
self.print_info(f"Answer: {answer[:150]}...")
|
||||
self.record_result(test_name, False, "Info not preserved")
|
||||
return False
|
||||
else:
|
||||
self.record_result(test_name, False, "Recall failed")
|
||||
return False
|
||||
except Exception as e:
|
||||
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
|
||||
self.print_warning(f"Request timed out: {str(e)[:50]}")
|
||||
self.record_result(test_name, True, "Skipped (timeout)")
|
||||
return True
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Feedback Tests (NEW)
|
||||
# Feedback Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_feedback_positive(self) -> bool:
|
||||
@@ -1435,15 +1234,6 @@ This is test documentation for integration tests.
|
||||
self.test_tts_basic()
|
||||
time.sleep(1)
|
||||
|
||||
# Compression tests (longer running)
|
||||
if self.is_authenticated:
|
||||
self.test_compression_heavy_tool_usage()
|
||||
time.sleep(2)
|
||||
|
||||
self.test_compression_needle_in_haystack()
|
||||
else:
|
||||
self.print_info("Skipping compression tests (no authentication)")
|
||||
|
||||
return self.print_summary()
|
||||
|
||||
|
||||
|
||||
681
tests/integration/test_v1_api.py
Normal file
681
tests/integration/test_v1_api.py
Normal file
@@ -0,0 +1,681 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for the /v1/ chat completions API (Phase 4).
|
||||
|
||||
Endpoints tested:
|
||||
- /v1/chat/completions (POST) - Standard chat completions (streaming & non-streaming)
|
||||
- /v1/models (GET) - List available agent models
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_v1_api.py
|
||||
python tests/integration/test_v1_api.py --base-url http://localhost:7091
|
||||
python tests/integration/test_v1_api.py --token YOUR_JWT_TOKEN
|
||||
"""
|
||||
|
||||
import json as json_module
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
# Add parent directory to path for standalone execution
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
_TESTS_DIR = _THIS_DIR.parent
|
||||
_ROOT_DIR = _TESTS_DIR.parent
|
||||
if str(_ROOT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT_DIR))
|
||||
|
||||
from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
|
||||
class V1ApiTests(DocsGPTTestBase):
|
||||
"""Integration tests for /v1/ chat completions API."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Test Data Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_agent_key(self) -> Optional[str]:
|
||||
"""Get or create a test agent and return its API key."""
|
||||
if hasattr(self, "_agent_key") and self._agent_key:
|
||||
return self._agent_key
|
||||
|
||||
# Try both authenticated and unauthenticated creation.
|
||||
# Published agents need a source to get an API key.
|
||||
payload = {
|
||||
"name": f"V1 Test Agent {int(time.time())}",
|
||||
"description": "Integration test agent for v1 API tests",
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"retriever": "classic",
|
||||
"agent_type": "classic",
|
||||
"status": "published",
|
||||
"source": "default",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_agent", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
api_key = result.get("key")
|
||||
self._agent_id = result.get("id")
|
||||
if api_key:
|
||||
self._agent_key = api_key
|
||||
self.print_info(f"Created test agent with key: {api_key[:8]}...")
|
||||
return api_key
|
||||
else:
|
||||
self.print_warning("Agent created but no API key returned")
|
||||
else:
|
||||
self.print_warning(f"Agent creation returned {response.status_code}: {response.text[:200]}")
|
||||
except Exception as e:
|
||||
self.print_error(f"Failed to create agent: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _v1_headers(self, api_key: str) -> dict:
|
||||
"""Build headers for v1 API requests."""
|
||||
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Auth Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_no_auth_returns_401(self) -> bool:
|
||||
"""Test that /v1/chat/completions without auth returns 401."""
|
||||
test_name = "v1 chat completions - no auth"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "Hi"}]},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
self.print_success("Correctly returned 401 for missing auth")
|
||||
self.record_result(test_name, True, "401 as expected")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 401, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Request failed: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_invalid_key_returns_error(self) -> bool:
|
||||
"""Test that invalid API key returns error."""
|
||||
test_name = "v1 chat completions - invalid key"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "Hi"}]},
|
||||
headers=self._v1_headers("invalid-key-12345"),
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
# Should return 400 or 500 (agent not found)
|
||||
if response.status_code in [400, 401, 500]:
|
||||
self.print_success(f"Correctly returned {response.status_code} for invalid key")
|
||||
self.record_result(test_name, True, f"Error as expected ({response.status_code})")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Unexpected status: {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Request failed: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_missing_messages_returns_400(self) -> bool:
|
||||
"""Test that missing messages field returns 400."""
|
||||
test_name = "v1 chat completions - missing messages"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={"stream": False},
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 400:
|
||||
self.print_success("Correctly returned 400 for missing messages")
|
||||
self.record_result(test_name, True, "400 as expected")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 400, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Request failed: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Non-streaming
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_non_streaming_basic(self) -> bool:
|
||||
"""Test basic non-streaming chat completion."""
|
||||
test_name = "v1 chat completions - non-streaming"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello in one word."}],
|
||||
"stream": False,
|
||||
},
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
self.print_info(f"Status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.print_error(f"Response: {response.text[:300]}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Verify standard format
|
||||
checks = [
|
||||
("id" in data, "has id"),
|
||||
(data.get("object") == "chat.completion", "object is chat.completion"),
|
||||
("choices" in data, "has choices"),
|
||||
(len(data["choices"]) > 0, "choices not empty"),
|
||||
(data["choices"][0].get("message", {}).get("role") == "assistant", "role is assistant"),
|
||||
(data["choices"][0].get("message", {}).get("content") is not None, "has content"),
|
||||
(data["choices"][0].get("finish_reason") == "stop", "finish_reason is stop"),
|
||||
("usage" in data, "has usage"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
self.print_info(f"Response: {content[:100]}")
|
||||
|
||||
# Check docsgpt extension
|
||||
if "docsgpt" in data:
|
||||
self.print_success(" has docsgpt extension")
|
||||
if "conversation_id" in data["docsgpt"]:
|
||||
self.print_success(f" conversation_id: {data['docsgpt']['conversation_id'][:8]}...")
|
||||
|
||||
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Streaming
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
    def test_streaming_basic(self) -> bool:
        """Test basic streaming chat completion.

        Reads the SSE stream line by line, accumulates content deltas, and
        verifies the stream produced chunks, some content, a
        finish_reason=stop chunk, and a terminating [DONE] sentinel.
        Returns True when every check passes.
        """
        test_name = "v1 chat completions - streaming"
        self.print_header(f"Testing {test_name}")

        api_key = self.get_or_create_agent_key()
        if not api_key:
            if not self.require_auth(test_name):
                return True
            self.record_result(test_name, True, "Skipped (no agent)")
            return True

        try:
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                json={
                    "messages": [{"role": "user", "content": "Say hi briefly."}],
                    "stream": True,
                },
                headers=self._v1_headers(api_key),
                stream=True,  # keep the connection open so iter_lines yields SSE lines
                timeout=60,
            )

            self.print_info(f"Status: {response.status_code}")

            if response.status_code != 200:
                self.print_error(f"Expected 200, got {response.status_code}")
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False

            chunks = []            # every parsed JSON chunk
            content_pieces = []    # text deltas, joined below
            got_done = False       # saw the "[DONE]" sentinel
            got_stop = False       # saw finish_reason == "stop"
            got_id = False         # saw a docsgpt extension chunk of type "id"

            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode("utf-8")
                # SSE payload lines are prefixed with "data: "
                if not line_str.startswith("data: "):
                    continue

                data_str = line_str[6:]
                if data_str.strip() == "[DONE]":
                    got_done = True
                    break

                try:
                    chunk = json_module.loads(data_str)
                    chunks.append(chunk)

                    # Standard chunks
                    if "choices" in chunk:
                        delta = chunk["choices"][0].get("delta", {})
                        if "content" in delta:
                            content_pieces.append(delta["content"])
                        if chunk["choices"][0].get("finish_reason") == "stop":
                            got_stop = True

                    # Extension chunks
                    if "docsgpt" in chunk:
                        ext = chunk["docsgpt"]
                        if ext.get("type") == "id":
                            got_id = True

                except json_module.JSONDecodeError:
                    # Tolerate malformed lines; they simply don't count
                    pass

            full_content = "".join(content_pieces)

            checks = [
                (len(chunks) > 0, f"received {len(chunks)} chunks"),
                (len(content_pieces) > 0, f"got content: {full_content[:50]}..."),
                (got_stop, "got finish_reason=stop"),
                (got_done, "got [DONE] sentinel"),
            ]

            all_passed = True
            for check, label in checks:
                if check:
                    self.print_success(f" {label}")
                else:
                    self.print_error(f" {label}")
                    all_passed = False

            if got_id:
                self.print_success(" got conversation_id via docsgpt extension")

            self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
            return all_passed

        except Exception as e:
            self.print_error(f"Error: {e}")
            self.record_result(test_name, False, str(e))
            return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/chat/completions — Multi-turn conversation
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_multi_turn_conversation(self) -> bool:
|
||||
"""Test multi-turn conversation with history in messages."""
|
||||
test_name = "v1 chat completions - multi-turn"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "user", "content": "My name is TestBot."},
|
||||
{"role": "assistant", "content": "Hello TestBot!"},
|
||||
{"role": "user", "content": "What is my name?"},
|
||||
],
|
||||
"stream": False,
|
||||
},
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
self.print_info(f"Response: {content[:150]}")
|
||||
|
||||
# The response should reference "TestBot" from the history
|
||||
has_content = bool(content)
|
||||
self.print_success(f" Got response with {len(content)} chars")
|
||||
self.record_result(test_name, has_content, "Multi-turn works")
|
||||
return has_content
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# /v1/models
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_list_models(self) -> bool:
|
||||
"""Test GET /v1/models endpoint."""
|
||||
test_name = "v1 models - list"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/v1/models",
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
self.print_info(f"Status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
|
||||
checks = [
|
||||
(data.get("object") == "list", "object is list"),
|
||||
("data" in data, "has data array"),
|
||||
(len(data.get("data", [])) > 0, f"has {len(data.get('data', []))} model(s)"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
if data.get("data"):
|
||||
model = data["data"][0]
|
||||
model_checks = [
|
||||
("id" in model, "model has id"),
|
||||
(model.get("object") == "model", "model object is 'model'"),
|
||||
(model.get("owned_by") == "docsgpt", "owned_by is docsgpt"),
|
||||
]
|
||||
for check, label in model_checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_models_no_auth(self) -> bool:
|
||||
"""Test that /v1/models without auth returns 401."""
|
||||
test_name = "v1 models - no auth"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/v1/models",
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
self.print_success("Correctly returned 401")
|
||||
self.record_result(test_name, True, "401 as expected")
|
||||
return True
|
||||
else:
|
||||
self.print_error(f"Expected 401, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Backward Compatibility — old endpoints still work
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
    def test_old_stream_endpoint_still_works(self) -> bool:
        """Verify the old /stream endpoint still works after v1 changes.

        Streams the legacy SSE format and checks that at least one event,
        an "answer" event, and a terminating "end" event arrived.
        """
        test_name = "Backward compat - /stream"
        self.print_header(f"Testing {test_name}")

        payload = {
            "question": "Say hello briefly.",
            "history": "[]",
            "isNoneDoc": True,
        }

        try:
            response = requests.post(
                f"{self.base_url}/stream",
                json=payload,
                headers=self.headers,
                stream=True,  # iterate the SSE response incrementally
                timeout=60,
            )

            if response.status_code != 200:
                self.print_error(f"Expected 200, got {response.status_code}")
                self.record_result(test_name, False, f"Status {response.status_code}")
                return False

            events = []          # all parsed SSE events
            got_end = False      # legacy stream terminator
            got_answer = False   # at least one answer event

            for line in response.iter_lines():
                if line:
                    line_str = line.decode("utf-8")
                    # SSE payload lines are prefixed with "data: "
                    if line_str.startswith("data: "):
                        try:
                            data = json_module.loads(line_str[6:])
                            events.append(data)
                            if data.get("type") == "answer":
                                got_answer = True
                            if data.get("type") == "end":
                                got_end = True
                                break
                        except json_module.JSONDecodeError:
                            # Tolerate malformed lines
                            pass

            checks = [
                (len(events) > 0, f"received {len(events)} events"),
                (got_answer, "got answer event"),
                (got_end, "got end event"),
            ]

            all_passed = True
            for check, label in checks:
                if check:
                    self.print_success(f" {label}")
                else:
                    self.print_error(f" {label}")
                    all_passed = False

            self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
            return all_passed

        except Exception as e:
            self.print_error(f"Error: {e}")
            self.record_result(test_name, False, str(e))
            return False
|
||||
|
||||
def test_old_answer_endpoint_still_works(self) -> bool:
|
||||
"""Verify the old /api/answer endpoint still works."""
|
||||
test_name = "Backward compat - /api/answer"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
payload = {
|
||||
"question": "Say hi.",
|
||||
"history": "[]",
|
||||
"isNoneDoc": True,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/api/answer",
|
||||
json=payload,
|
||||
headers=self.headers,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self.print_error(f"Expected 200, got {response.status_code}")
|
||||
self.record_result(test_name, False, f"Status {response.status_code}")
|
||||
return False
|
||||
|
||||
data = response.json()
|
||||
checks = [
|
||||
("answer" in data, "has answer"),
|
||||
("conversation_id" in data, "has conversation_id"),
|
||||
]
|
||||
|
||||
all_passed = True
|
||||
for check, label in checks:
|
||||
if check:
|
||||
self.print_success(f" {label}")
|
||||
else:
|
||||
self.print_error(f" {label}")
|
||||
all_passed = False
|
||||
|
||||
self.print_info(f"Answer: {data.get('answer', '')[:100]}")
|
||||
self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
|
||||
return all_passed
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up test resources."""
|
||||
if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated:
|
||||
try:
|
||||
self.post(f"/api/delete_agent?id={self._agent_id}", json={})
|
||||
self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Run All
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
    def run_all(self) -> bool:
        """Run all v1 API integration tests.

        Executes the suite in a fixed order (auth checks, non-streaming,
        streaming, multi-turn, models, backward compatibility), with short
        sleeps between requests to avoid hammering the server. The test
        agent is always cleaned up, and the aggregated pass/fail summary
        is returned.
        """
        self.print_header("V1 Chat Completions API Integration Tests")
        self.print_info(f"Base URL: {self.base_url}")
        self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")

        try:
            # Auth tests (no agent needed)
            self.test_no_auth_returns_401()
            time.sleep(0.5)

            self.test_models_no_auth()
            time.sleep(0.5)

            self.test_invalid_key_returns_error()
            time.sleep(0.5)

            self.test_missing_messages_returns_400()
            time.sleep(0.5)

            # Non-streaming
            self.test_non_streaming_basic()
            time.sleep(1)

            # Streaming
            self.test_streaming_basic()
            time.sleep(1)

            # Multi-turn
            self.test_multi_turn_conversation()
            time.sleep(1)

            # Models
            self.test_list_models()
            time.sleep(0.5)

            # Backward compatibility
            self.test_old_stream_endpoint_still_works()
            time.sleep(1)

            self.test_old_answer_endpoint_still_works()
            time.sleep(1)

        finally:
            # Always delete the test agent, even when a test raised
            self.cleanup()

        return self.print_summary()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run the v1 API suite and exit nonzero on failure."""
    client = create_client_from_args(V1ApiTests, "DocsGPT V1 API Integration Tests")
    sys.exit(0 if client.run_all() else 1)
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
539
tests/integration/test_v1_tool_calls.py
Normal file
539
tests/integration/test_v1_tool_calls.py
Normal file
@@ -0,0 +1,539 @@
|
||||
#!/usr/bin/env python3
|
||||
r"""
|
||||
Integration tests for the /v1/ chat completions API — client tool-call flow.
|
||||
|
||||
Tests the full lifecycle:
|
||||
1. Send request with client tools → LLM triggers a tool call
|
||||
2. Verify response returns clean tool names (no internal _ct\d+ suffix)
|
||||
3. Send continuation with tool results + top-level conversation_id
|
||||
4. Verify the continuation completes successfully
|
||||
|
||||
Usage:
|
||||
python tests/integration/test_v1_tool_calls.py
|
||||
python tests/integration/test_v1_tool_calls.py --base-url http://localhost:7091
|
||||
"""
|
||||
|
||||
import json as json_module
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
# Make the repository root importable so `tests.integration` resolves when
# this file is executed directly as a script (not via pytest from the root).
_THIS_DIR = Path(__file__).parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))

from tests.integration.base import DocsGPTTestBase, create_client_from_args
|
||||
|
||||
# Internal suffix pattern that should NOT appear in client responses
# (a name like "create_ct1" would leak the server's internal tool-name
# mangling back to the API consumer).
_CT_SUFFIX_RE = re.compile(r"_ct\d+$")
|
||||
|
||||
|
||||
class V1ToolCallTests(DocsGPTTestBase):
|
||||
"""Integration tests for /v1/ client tool-call flows."""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_or_create_agent_key(self) -> Optional[str]:
|
||||
"""Get or create a test agent and return its API key."""
|
||||
if hasattr(self, "_agent_key") and self._agent_key:
|
||||
return self._agent_key
|
||||
|
||||
payload = {
|
||||
"name": f"V1 ToolCall Test {int(time.time())}",
|
||||
"description": "Integration test agent for tool-call flow",
|
||||
"prompt_id": "default",
|
||||
"chunks": 2,
|
||||
"retriever": "classic",
|
||||
"agent_type": "classic",
|
||||
"status": "published",
|
||||
"source": "default",
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.post("/api/create_agent", json=payload, timeout=10)
|
||||
if response.status_code in [200, 201]:
|
||||
result = response.json()
|
||||
api_key = result.get("key")
|
||||
self._agent_id = result.get("id")
|
||||
if api_key:
|
||||
self._agent_key = api_key
|
||||
self.print_info(f"Created test agent with key: {api_key[:8]}...")
|
||||
return api_key
|
||||
except Exception as e:
|
||||
self.print_error(f"Failed to create agent: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _v1_headers(self, api_key: str) -> dict:
|
||||
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
    # A simple client tool definition in OpenAI format.
    # Shared by every test below; the LLM is prompted to call "create",
    # which pauses the stream so the client can supply a tool result.
    _CLIENT_TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "create",
                "description": "Create a new todo item",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "title": {
                            "type": "string",
                            "description": "The title of the new todo item",
                        }
                    },
                    "required": ["title"],
                },
            },
        }
    ]
|
||||
|
||||
    def _send_streaming_request(
        self,
        api_key: str,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        conversation_id: Optional[str] = None,
    ) -> Tuple[List[Dict], str, Optional[Dict]]:
        """Send a streaming request and collect all events.

        Returns:
            (all_chunks, full_content, tool_call_info)
            tool_call_info is a dict with 'name', 'arguments', 'call_id'
            if the response paused for a client tool call, else None.

        Raises:
            RuntimeError: if the HTTP status is not 200.
        """
        body: Dict[str, Any] = {
            "messages": messages,
            "stream": True,
        }
        if tools:
            body["tools"] = tools
        if conversation_id:
            body["conversation_id"] = conversation_id

        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=body,
            headers=self._v1_headers(api_key),
            stream=True,  # iterate the SSE response incrementally
            timeout=120,
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"Expected 200, got {response.status_code}: {response.text[:300]}"
            )

        chunks: List[Dict] = []
        content_pieces: List[str] = []
        tool_call_info: Optional[Dict] = None
        conversation_id_from_response: Optional[str] = None

        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode("utf-8")
            # SSE payload lines are prefixed with "data: "
            if not line_str.startswith("data: "):
                continue

            data_str = line_str[6:]
            if data_str.strip() == "[DONE]":
                break

            try:
                chunk = json_module.loads(data_str)
                chunks.append(chunk)

                # Standard chunks
                if "choices" in chunk:
                    delta = chunk["choices"][0].get("delta", {})
                    if "content" in delta:
                        content_pieces.append(delta["content"])

                    # Tool call delta — only the first entry of each delta
                    # is tracked; a later delta overwrites an earlier one
                    if "tool_calls" in delta:
                        tc = delta["tool_calls"][0]
                        tool_call_info = {
                            "call_id": tc.get("id", ""),
                            "name": tc["function"]["name"],
                            "arguments": tc["function"].get("arguments", "{}"),
                        }

                # Extension chunks
                if "docsgpt" in chunk:
                    ext = chunk["docsgpt"]
                    if ext.get("type") == "id":
                        conversation_id_from_response = ext.get("conversation_id")

            except json_module.JSONDecodeError:
                # Tolerate malformed lines; they simply don't count
                pass

        full_content = "".join(content_pieces)

        # Attach conversation_id to tool_call_info for convenience
        if tool_call_info and conversation_id_from_response:
            tool_call_info["conversation_id"] = conversation_id_from_response

        return chunks, full_content, tool_call_info
|
||||
|
||||
def _send_non_streaming_request(
|
||||
self,
|
||||
api_key: str,
|
||||
messages: List[Dict],
|
||||
tools: Optional[List[Dict]] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""Send a non-streaming request and return parsed JSON."""
|
||||
body: Dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
}
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
if conversation_id:
|
||||
body["conversation_id"] = conversation_id
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=body,
|
||||
headers=self._v1_headers(api_key),
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Expected 200, got {response.status_code}: {response.text[:300]}"
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_streaming_tool_call_clean_name(self) -> bool:
|
||||
"""Streaming: tool names returned to client must not have _ct suffixes."""
|
||||
test_name = "v1 streaming tool call - clean name"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
messages = [
|
||||
{"role": "user", "content": "Use the create tool to add a todo item titled 'Test integration'. Call the tool now."},
|
||||
]
|
||||
chunks, content, tool_call_info = self._send_streaming_request(
|
||||
api_key, messages, tools=self._CLIENT_TOOLS
|
||||
)
|
||||
|
||||
if not tool_call_info:
|
||||
# LLM didn't trigger the tool — could happen, not a failure of our code
|
||||
self.print_warning("LLM did not trigger a tool call (may need prompt tuning)")
|
||||
self.print_info(f"Got text response instead: {content[:100]}")
|
||||
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
|
||||
return True
|
||||
|
||||
tool_name = tool_call_info["name"]
|
||||
self.print_info(f"Tool call name: {tool_name}")
|
||||
|
||||
has_suffix = bool(_CT_SUFFIX_RE.search(tool_name))
|
||||
if has_suffix:
|
||||
self.print_error(f"Tool name has internal suffix: {tool_name}")
|
||||
self.record_result(test_name, False, f"Suffix leak: {tool_name}")
|
||||
return False
|
||||
|
||||
self.print_success(f"Tool name is clean: {tool_name}")
|
||||
self.record_result(test_name, True, f"Clean name: {tool_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
def test_non_streaming_tool_call_clean_name(self) -> bool:
|
||||
"""Non-streaming: tool names returned to client must not have _ct suffixes."""
|
||||
test_name = "v1 non-streaming tool call - clean name"
|
||||
self.print_header(f"Testing {test_name}")
|
||||
|
||||
api_key = self.get_or_create_agent_key()
|
||||
if not api_key:
|
||||
if not self.require_auth(test_name):
|
||||
return True
|
||||
self.record_result(test_name, True, "Skipped (no agent)")
|
||||
return True
|
||||
|
||||
try:
|
||||
messages = [
|
||||
{"role": "user", "content": "Use the create tool to add a todo item titled 'Test non-stream'. Call the tool now."},
|
||||
]
|
||||
data = self._send_non_streaming_request(
|
||||
api_key, messages, tools=self._CLIENT_TOOLS
|
||||
)
|
||||
|
||||
message = data["choices"][0]["message"]
|
||||
tool_calls = message.get("tool_calls")
|
||||
|
||||
if not tool_calls:
|
||||
content = message.get("content", "")
|
||||
self.print_warning("LLM did not trigger a tool call")
|
||||
self.print_info(f"Got text response: {content[:100]}")
|
||||
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
|
||||
return True
|
||||
|
||||
tool_name = tool_calls[0]["function"]["name"]
|
||||
self.print_info(f"Tool call name: {tool_name}")
|
||||
|
||||
has_suffix = bool(_CT_SUFFIX_RE.search(tool_name))
|
||||
if has_suffix:
|
||||
self.print_error(f"Tool name has internal suffix: {tool_name}")
|
||||
self.record_result(test_name, False, f"Suffix leak: {tool_name}")
|
||||
return False
|
||||
|
||||
self.print_success(f"Tool name is clean: {tool_name}")
|
||||
self.record_result(test_name, True, f"Clean name: {tool_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.print_error(f"Error: {e}")
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
    def test_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool:
        """Full tool-call round-trip: trigger → get conversation_id → continue with top-level id.

        Step 1 prompts the LLM to call the client tool and captures the
        conversation_id from the docsgpt extension. Step 2 replays the
        history plus a fabricated tool result, in standard OpenAI message
        format, with the conversation_id as a top-level request field.
        """
        test_name = "v1 streaming tool continuation - top-level conversation_id"
        self.print_header(f"Testing {test_name}")

        api_key = self.get_or_create_agent_key()
        if not api_key:
            if not self.require_auth(test_name):
                return True
            self.record_result(test_name, True, "Skipped (no agent)")
            return True

        try:
            # Step 1: trigger a tool call
            messages = [
                {"role": "user", "content": "Use the create tool to add a todo item titled 'Round trip test'. Call the tool now."},
            ]
            chunks, content, tool_call_info = self._send_streaming_request(
                api_key, messages, tools=self._CLIENT_TOOLS
            )

            if not tool_call_info:
                # Not a failure of our code; the model simply didn't comply
                self.print_warning("LLM did not trigger a tool call")
                self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
                return True

            conversation_id = tool_call_info.get("conversation_id")
            if not conversation_id:
                self.print_error("No conversation_id returned in stream")
                self.record_result(test_name, False, "Missing conversation_id")
                return False

            self.print_info(f"Got conversation_id: {conversation_id[:12]}...")
            self.print_info(f"Tool call: {tool_call_info['name']}({tool_call_info['arguments']})")

            # Step 2: send continuation with tool result + top-level conversation_id
            # (standard OpenAI format — no docsgpt field in assistant message)
            continuation_messages = [
                *messages,
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": tool_call_info["call_id"],
                            "type": "function",
                            "function": {
                                "name": tool_call_info["name"],
                                "arguments": tool_call_info["arguments"],
                            },
                        }
                    ],
                },
                {
                    # Fabricated client-side result for the pending tool call
                    "role": "tool",
                    "tool_call_id": tool_call_info["call_id"],
                    "content": json_module.dumps({"id": 99, "title": "Round trip test", "status": "created"}),
                },
            ]

            chunks2, content2, tool_call_info2 = self._send_streaming_request(
                api_key,
                continuation_messages,
                tools=self._CLIENT_TOOLS,
                conversation_id=conversation_id,
            )

            checks = [
                (len(chunks2) > 0, f"continuation returned {len(chunks2)} chunks"),
                (bool(content2) or tool_call_info2 is not None, "got content or another tool call"),
            ]

            all_passed = True
            for check, label in checks:
                if check:
                    self.print_success(f" {label}")
                else:
                    self.print_error(f" {label}")
                    all_passed = False

            if content2:
                self.print_info(f"Continuation response: {content2[:150]}")

            self.record_result(
                test_name,
                all_passed,
                "Full round-trip works" if all_passed else "Continuation failed",
            )
            return all_passed

        except Exception as e:
            self.print_error(f"Error: {e}")
            self.record_result(test_name, False, str(e))
            return False
|
||||
|
||||
def test_non_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool:
    """Non-streaming full round-trip with top-level conversation_id."""
    test_name = "v1 non-streaming tool continuation - top-level conversation_id"
    self.print_header(f"Testing {test_name}")

    api_key = self.get_or_create_agent_key()
    if not api_key:
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True

    try:
        # Step 1: trigger a tool call
        opening_messages = [
            {"role": "user", "content": "Use the create tool to add a todo item titled 'Non-stream round trip'. Call the tool now."},
        ]
        first_reply = self._send_non_streaming_request(
            api_key, opening_messages, tools=self._CLIENT_TOOLS
        )

        assistant_msg = first_reply["choices"][0]["message"]
        tool_calls = assistant_msg.get("tool_calls")

        if not tool_calls:
            # The model is free not to call the tool; treat as a skip, not a failure.
            self.print_warning("LLM did not trigger a tool call")
            self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
            return True

        conversation_id = first_reply.get("docsgpt", {}).get("conversation_id")
        if not conversation_id:
            self.print_error("No conversation_id in response")
            self.record_result(test_name, False, "Missing conversation_id")
            return False

        first_call = tool_calls[0]
        self.print_info(f"Got tool call: {first_call['function']['name']}")
        self.print_info(f"conversation_id: {conversation_id[:12]}...")

        # Step 2: send continuation (standard format, top-level conversation_id)
        tool_result = json_module.dumps(
            {"id": 100, "title": "Non-stream round trip", "status": "created"}
        )
        continuation_messages = opening_messages + [
            {"role": "assistant", "content": None, "tool_calls": [first_call]},
            {"role": "tool", "tool_call_id": first_call["id"], "content": tool_result},
        ]

        second_reply = self._send_non_streaming_request(
            api_key,
            continuation_messages,
            tools=self._CLIENT_TOOLS,
            conversation_id=conversation_id,
        )

        final_msg = second_reply["choices"][0]["message"]
        has_response = bool(final_msg.get("content")) or bool(final_msg.get("tool_calls"))

        if has_response:
            self.print_success("Continuation returned a response")
            final_text = final_msg.get("content", "")
            if final_text:
                self.print_info(f"Response: {final_text[:150]}")
        else:
            self.print_error("Continuation returned empty response")

        self.record_result(
            test_name,
            has_response,
            "Round-trip works" if has_response else "Empty continuation response",
        )
        return has_response

    except Exception as e:
        self.print_error(f"Error: {e}")
        self.record_result(test_name, False, str(e))
        return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Cleanup & Run All
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def cleanup(self):
    """Best-effort removal of the agent created during the test run."""
    agent_id = getattr(self, "_agent_id", None)
    if not (agent_id and self.is_authenticated):
        return
    try:
        self.post(f"/api/delete_agent?id={agent_id}", json={})
        self.print_info(f"Cleaned up test agent {agent_id[:8]}...")
    except Exception:
        # Cleanup must never fail the suite.
        pass
|
||||
|
||||
def run_all(self) -> bool:
    """Run every v1 tool-call test in order, clean up, and print a summary."""
    self.print_header("V1 Tool-Call Flow Integration Tests")
    self.print_info(f"Base URL: {self.base_url}")
    self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")

    # Streaming / non-streaming single-call tests, then full round-trips.
    suite = (
        self.test_streaming_tool_call_clean_name,
        self.test_non_streaming_tool_call_clean_name,
        self.test_streaming_tool_continuation_with_top_level_conversation_id,
        self.test_non_streaming_tool_continuation_with_top_level_conversation_id,
    )
    try:
        for run_test in suite:
            run_test()
            time.sleep(1)  # brief pause between tests
    finally:
        self.cleanup()

    return self.print_summary()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run the full suite and exit with a status code."""
    runner = create_client_from_args(V1ToolCallTests, "DocsGPT V1 Tool-Call Integration Tests")
    sys.exit(0 if runner.run_all() else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -196,9 +196,9 @@ class TestGoogleLLMHandler:
|
||||
assert result.finish_reason == "tool_calls"
|
||||
|
||||
def test_create_tool_message(self):
|
||||
"""Test creating tool message."""
|
||||
"""Test creating tool message in standard format."""
|
||||
handler = GoogleLLMHandler()
|
||||
|
||||
|
||||
tool_call = ToolCall(
|
||||
id="call_123",
|
||||
name="get_weather",
|
||||
@@ -206,35 +206,26 @@ class TestGoogleLLMHandler:
|
||||
index=0
|
||||
)
|
||||
result = {"temperature": "25C", "condition": "cloudy"}
|
||||
|
||||
|
||||
message = handler.create_tool_message(tool_call, result)
|
||||
|
||||
expected = {
|
||||
"role": "model",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": "get_weather",
|
||||
"response": {"result": result},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
assert message == expected
|
||||
|
||||
assert message["role"] == "tool"
|
||||
assert message["tool_call_id"] == "call_123"
|
||||
import json
|
||||
assert json.loads(message["content"]) == result
|
||||
|
||||
def test_create_tool_message_string_result(self):
|
||||
"""Test creating tool message with string result."""
|
||||
handler = GoogleLLMHandler()
|
||||
|
||||
|
||||
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
|
||||
result = "2023-12-01 15:30:00 JST"
|
||||
|
||||
|
||||
message = handler.create_tool_message(tool_call, result)
|
||||
|
||||
assert message["role"] == "model"
|
||||
assert message["content"][0]["function_response"]["response"]["result"] == result
|
||||
assert message["content"][0]["function_response"]["name"] == "get_time"
|
||||
|
||||
assert message["role"] == "tool"
|
||||
assert message["tool_call_id"] == "call_456"
|
||||
assert message["content"] == result
|
||||
|
||||
def test_iterate_stream(self):
|
||||
"""Test stream iteration."""
|
||||
|
||||
@@ -621,6 +621,7 @@ class TestHandleToolCalls:
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
def fake_execute(tools_dict, call):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
@@ -641,7 +642,7 @@ class TestHandleToolCalls:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, _pending = e.value
|
||||
|
||||
assert any(e.get("type") == "tool_call" for e in events)
|
||||
assert len(messages) >= 2 # function_call + tool_message
|
||||
@@ -675,6 +676,9 @@ class TestHandleToolCalls:
|
||||
agent = Mock()
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
agent.tool_executor._name_to_tool = {}
|
||||
agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error"))
|
||||
|
||||
call = ToolCall(id="c1", name="action_1", arguments="{}")
|
||||
@@ -704,18 +708,17 @@ class TestHandleToolCalls:
|
||||
while True:
|
||||
next(gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, _pending = e.value
|
||||
|
||||
# Standard format: thought_signature is on tool_calls items
|
||||
assistant_msgs = [
|
||||
m for m in messages
|
||||
if m.get("role") == "assistant"
|
||||
and isinstance(m.get("content"), list)
|
||||
if m.get("role") == "assistant" and m.get("tool_calls")
|
||||
]
|
||||
assert any(
|
||||
"thought_signature" in item
|
||||
tc.get("thought_signature") == "sig"
|
||||
for m in assistant_msgs
|
||||
for item in m["content"]
|
||||
if isinstance(item, dict)
|
||||
for tc in m["tool_calls"]
|
||||
)
|
||||
|
||||
|
||||
@@ -751,6 +754,7 @@ class TestHandleNonStreaming:
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
# First response requires tool call, second is final
|
||||
call_count = {"n": 0}
|
||||
@@ -856,6 +860,7 @@ class TestHandleStreaming:
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
# First chunk has partial tool call, second completes it
|
||||
chunk1 = LLMResponse(
|
||||
@@ -907,6 +912,7 @@ class TestHandleStreaming:
|
||||
agent.context_limit_reached = True
|
||||
agent._check_context_limit = Mock(return_value=True)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
# Chunk finishes with tool_calls
|
||||
chunk = LLMResponse(
|
||||
@@ -929,7 +935,7 @@ class TestHandleStreaming:
|
||||
def fake_handle_tool_calls(agent, calls, tools_dict, messages):
|
||||
agent.context_limit_reached = True
|
||||
yield {"type": "tool_call", "data": {"status": "skipped"}}
|
||||
return messages
|
||||
return messages, None
|
||||
|
||||
handler.handle_tool_calls = fake_handle_tool_calls
|
||||
|
||||
@@ -1501,6 +1507,7 @@ class TestHandleToolCallsCompressionSuccess:
|
||||
agent._check_context_limit = Mock(side_effect=check_limit)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
def fake_execute(tools_dict, call):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
@@ -1538,6 +1545,7 @@ class TestHandleToolCallsCompressionSuccess:
|
||||
agent = Mock()
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
exec_count = {"n": 0}
|
||||
|
||||
|
||||
@@ -128,9 +128,9 @@ class TestOpenAILLMHandler:
|
||||
assert result.finish_reason == ""
|
||||
|
||||
def test_create_tool_message(self):
|
||||
"""Test creating tool message."""
|
||||
"""Test creating tool message in standard format."""
|
||||
handler = OpenAILLMHandler()
|
||||
|
||||
|
||||
tool_call = ToolCall(
|
||||
id="call_123",
|
||||
name="get_weather",
|
||||
@@ -138,36 +138,26 @@ class TestOpenAILLMHandler:
|
||||
index=0
|
||||
)
|
||||
result = {"temperature": "72F", "condition": "sunny"}
|
||||
|
||||
|
||||
message = handler.create_tool_message(tool_call, result)
|
||||
|
||||
expected = {
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": "get_weather",
|
||||
"response": {"result": result},
|
||||
"call_id": "call_123",
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
assert message == expected
|
||||
|
||||
assert message["role"] == "tool"
|
||||
assert message["tool_call_id"] == "call_123"
|
||||
import json
|
||||
assert json.loads(message["content"]) == result
|
||||
|
||||
def test_create_tool_message_string_result(self):
|
||||
"""Test creating tool message with string result."""
|
||||
handler = OpenAILLMHandler()
|
||||
|
||||
|
||||
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
|
||||
result = "2023-12-01 10:30:00"
|
||||
|
||||
|
||||
message = handler.create_tool_message(tool_call, result)
|
||||
|
||||
|
||||
assert message["role"] == "tool"
|
||||
assert message["content"][0]["function_response"]["response"]["result"] == result
|
||||
assert message["content"][0]["function_response"]["call_id"] == "call_456"
|
||||
assert message["tool_call_id"] == "call_456"
|
||||
assert message["content"] == result
|
||||
|
||||
def test_iterate_stream(self):
|
||||
"""Test stream iteration."""
|
||||
|
||||
@@ -478,11 +478,14 @@ class TestHandleToolCallsErrors:
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock()
|
||||
agent._check_context_limit = MagicMock(return_value=False)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent.tool_executor._name_to_tool = {"search": ("1", "search")}
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=RuntimeError("tool failed")
|
||||
)
|
||||
|
||||
tool_call = ToolCall(id="tc1", name="search_1", arguments={"q": "test"})
|
||||
tool_call = ToolCall(id="tc1", name="search", arguments={"q": "test"})
|
||||
tools_dict = {"1": {"name": "search_tool"}}
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
|
||||
@@ -506,6 +509,9 @@ class TestHandleToolCallsErrors:
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock()
|
||||
agent._check_context_limit = MagicMock(return_value=False)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent.tool_executor._name_to_tool = {}
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=RuntimeError("tool failed")
|
||||
)
|
||||
@@ -1169,12 +1175,15 @@ class TestHandleToolCallsErrorsAdditional:
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock()
|
||||
agent._check_context_limit = MagicMock(return_value=False)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent.tool_executor._name_to_tool = {"do_thing": ("42", "do_thing")}
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=RuntimeError("broken tool")
|
||||
)
|
||||
|
||||
tool_call = ToolCall(
|
||||
id="tc1", name="do_thing_42", arguments={"x": 1}
|
||||
id="tc1", name="do_thing", arguments={"x": 1}
|
||||
)
|
||||
tools_dict = {"42": {"name": "my_tool"}}
|
||||
messages = [{"role": "user", "content": "go"}]
|
||||
@@ -1188,7 +1197,7 @@ class TestHandleToolCallsErrorsAdditional:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
final_messages = e.value
|
||||
final_messages, _pending = e.value
|
||||
|
||||
# Verify the error message was appended
|
||||
error_msgs = [
|
||||
@@ -1205,12 +1214,17 @@ class TestHandleToolCallsErrorsAdditional:
|
||||
]
|
||||
assert len(error_events) == 1
|
||||
assert error_events[0]["data"]["tool_name"] == "my_tool"
|
||||
assert error_events[0]["data"]["action_name"] == "do_thing_42"
|
||||
assert error_events[0]["data"]["action_name"] == "do_thing"
|
||||
|
||||
def test_tool_error_with_no_context_check(self):
|
||||
"""Cover line 660: messages.copy() at start of handle_tool_calls."""
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock(spec=[]) # No _check_context_limit attribute
|
||||
agent.llm = MagicMock()
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor = MagicMock()
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent.tool_executor._name_to_tool = {}
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=ValueError("bad args")
|
||||
)
|
||||
|
||||
@@ -176,6 +176,9 @@ class TestLLMHandlerTokenTracking:
|
||||
# Create mock agent that hits limit on second tool
|
||||
mock_agent = Mock()
|
||||
mock_agent.context_limit_reached = False
|
||||
mock_agent.llm.__class__.__name__ = "MockLLM"
|
||||
mock_agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
mock_agent.tool_executor._name_to_tool = {}
|
||||
|
||||
call_count = [0]
|
||||
|
||||
@@ -235,6 +238,9 @@ class TestLLMHandlerTokenTracking:
|
||||
mock_agent = Mock()
|
||||
mock_agent.context_limit_reached = False
|
||||
mock_agent._check_context_limit = Mock(return_value=False)
|
||||
mock_agent.llm.__class__.__name__ = "MockLLM"
|
||||
mock_agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
mock_agent.tool_executor._name_to_tool = {}
|
||||
mock_agent._execute_tool_action = Mock(
|
||||
return_value=iter([{"type": "tool_call", "data": {}}])
|
||||
)
|
||||
@@ -300,7 +306,7 @@ class TestLLMHandlerTokenTracking:
|
||||
|
||||
def tool_handler_gen(*args):
|
||||
yield {"type": "tool", "data": {}}
|
||||
return []
|
||||
return [], None
|
||||
|
||||
# Mock handle_tool_calls to return messages and set flag
|
||||
with patch.object(
|
||||
|
||||
430
tests/test_client_tools.py
Normal file
430
tests/test_client_tools.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""Tests for client-side tools (Phase 2).
|
||||
|
||||
Covers merge_client_tools, prepare_tools_for_llm with client tools,
|
||||
check_pause for client-side tools, and the full flow through the handler.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolExecutor.merge_client_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestMergeClientTools:
    """ToolExecutor.merge_client_tools: folding client-declared tools into tools_dict."""

    def test_merge_single_tool(self):
        """A single OpenAI-style declaration is registered under key 'ct0'."""
        executor = ToolExecutor()
        registry = {}
        declared = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {"type": "string", "description": "City name"}
                        },
                        "required": ["city"],
                    },
                },
            }
        ]

        merged = executor.merge_client_tools(registry, declared)

        assert "ct0" in merged
        entry = merged["ct0"]
        assert entry["name"] == "get_weather"
        assert entry["client_side"] is True
        assert len(entry["actions"]) == 1
        action = entry["actions"][0]
        assert action["name"] == "get_weather"
        assert action["active"] is True
        assert "city" in action["parameters"]["properties"]

    def test_merge_multiple_tools(self):
        """Each declared tool gets its own ct<i> key; existing tools are kept."""
        executor = ToolExecutor()
        registry = {"0": {"name": "existing_tool", "actions": []}}
        declared = [
            {"type": "function", "function": {"name": "tool_a", "description": "A"}},
            {"type": "function", "function": {"name": "tool_b", "description": "B"}},
        ]

        merged = executor.merge_client_tools(registry, declared)

        # Original tool still present
        assert "0" in merged
        # Client tools added
        assert "ct0" in merged
        assert "ct1" in merged
        assert merged["ct0"]["name"] == "tool_a"
        assert merged["ct1"]["name"] == "tool_b"

    def test_merge_bare_format(self):
        """Accept simplified format without the outer 'function' wrapper."""
        executor = ToolExecutor()
        merged = executor.merge_client_tools(
            {},
            [{"name": "simple_tool", "description": "Simple", "parameters": {}}],
        )

        assert "ct0" in merged
        assert merged["ct0"]["name"] == "simple_tool"

    def test_merge_preserves_existing_tools(self):
        """Merging mutates the registry in place without touching server tools."""
        executor = ToolExecutor()
        registry = {
            "abc123": {
                "name": "brave",
                "actions": [{"name": "search", "active": True}],
            }
        }
        declared = [
            {"type": "function", "function": {"name": "my_tool", "description": "D"}},
        ]

        executor.merge_client_tools(registry, declared)

        assert "abc123" in registry
        assert registry["abc123"]["name"] == "brave"
        assert "ct0" in registry

    def test_merge_empty_list(self):
        """An empty client-tool list leaves the registry unchanged."""
        executor = ToolExecutor()
        registry = {"0": {"name": "existing"}}

        executor.merge_client_tools(registry, [])

        assert len(registry) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prepare_tools_for_llm with client tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestPrepareClientToolsForLlm:
    """prepare_tools_for_llm: client tools are emitted as OpenAI function schemas."""

    def test_client_tools_included_in_llm_schema(self):
        """A merged client tool appears in the schema list with its parameters intact."""
        executor = ToolExecutor()
        registry = {}
        declared = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ]
        executor.merge_client_tools(registry, declared)

        schemas = executor.prepare_tools_for_llm(registry)

        assert len(schemas) == 1
        schema = schemas[0]
        assert schema["type"] == "function"
        fn = schema["function"]
        assert fn["name"] == "get_weather"
        assert fn["description"] == "Get weather"
        # Parameters passed through directly (not filtered by _build_tool_parameters)
        assert "city" in fn["parameters"]["properties"]
        assert fn["parameters"]["required"] == ["city"]

    def test_mixed_server_and_client_tools(self):
        """Server-tool actions and client tools both surface in the schema list."""
        executor = ToolExecutor()
        registry = {
            "t1": {
                "name": "test_tool",
                "actions": [
                    {
                        "name": "do_thing",
                        "description": "Does a thing",
                        "active": True,
                        "parameters": {
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "filled_by_llm": True,
                                    "required": True,
                                }
                            }
                        },
                    }
                ],
            }
        }
        declared = [
            {
                "type": "function",
                "function": {
                    "name": "local_fn",
                    "description": "Local function",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ]
        executor.merge_client_tools(registry, declared)

        schemas = executor.prepare_tools_for_llm(registry)

        assert len(schemas) == 2
        names = {schema["function"]["name"] for schema in schemas}
        assert "do_thing" in names
        assert "local_fn" in names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_tools auto-merges client_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestGetToolsAutoMerge:
    """get_tools() automatically merges any client_tools set on the executor."""

    def test_get_tools_merges_client_tools(self, mock_mongo_db):
        executor = ToolExecutor(user="alice")
        executor.client_tools = [
            {
                "type": "function",
                "function": {"name": "my_fn", "description": "test"},
            }
        ]

        merged = executor.get_tools()

        flagged = [tool.get("client_side") is True for tool in merged.values()]
        assert any(flagged), "Client tools should be merged into tools_dict"

    def test_get_tools_no_client_tools(self, mock_mongo_db):
        executor = ToolExecutor(user="alice")

        merged = executor.get_tools()

        # Without declared client tools, nothing should carry the client_side flag.
        assert all(not tool.get("client_side") for tool in merged.values())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_pause for client-side tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestCheckPauseClientTools:
    """check_pause: client-side tools pause execution, server-side tools do not."""

    def _make_call(self, name="action_0", call_id="c1"):
        """Build a minimal mock tool call with the given name and id."""
        mock_call = Mock()
        mock_call.name = name
        mock_call.id = call_id
        mock_call.arguments = "{}"
        mock_call.thought_signature = None
        return mock_call

    def test_client_tool_triggers_pause(self):
        """A client_side tool yields a requires_client_execution pause descriptor."""
        executor = ToolExecutor()
        registry = {
            "ct0": {
                "name": "get_weather",
                "client_side": True,
                "actions": [
                    {"name": "get_weather", "active": True, "parameters": {}},
                ],
            }
        }
        executor.prepare_tools_for_llm(registry)

        pause = executor.check_pause(
            registry, self._make_call(name="get_weather"), "OpenAILLM"
        )

        assert pause is not None
        assert pause["pause_type"] == "requires_client_execution"
        assert pause["tool_name"] == "get_weather"
        assert pause["tool_id"] == "ct0"

    def test_server_tool_no_pause(self):
        """Ordinary server tools run straight through — check_pause returns None."""
        executor = ToolExecutor()
        registry = {
            "0": {
                "name": "brave",
                "actions": [
                    {"name": "search", "active": True, "parameters": {}},
                ],
            }
        }
        executor.prepare_tools_for_llm(registry)

        pause = executor.check_pause(
            registry, self._make_call(name="search"), "OpenAILLM"
        )

        assert pause is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler flow: client tool causes pause
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteHandler(LLMHandler):
    """Minimal concrete handler for testing."""

    def parse_response(self, response):
        """Wrap any raw response in a terminal (finish_reason='stop') LLMResponse."""
        return LLMResponse(
            content=str(response),
            tool_calls=[],
            finish_reason="stop",
            raw_response=response,
        )

    def create_tool_message(self, tool_call, result):
        """Return a plain tool-role message carrying the stringified result."""
        return {"role": "tool", "content": str(result)}

    def _iterate_stream(self, response):
        """Pass stream chunks through unchanged."""
        yield from response
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestHandlerClientToolPause:
    """Handler flow: a client-side tool call pauses the run instead of executing."""

    def test_client_tool_pauses_stream(self):
        """When LLM calls a client-side tool, handler yields tool_calls_pending."""
        handler = ConcreteHandler()

        agent = Mock()
        agent.llm = Mock()
        agent.model_id = "test"
        agent.tools = []
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"

        # check_pause reports the call as requiring client execution.
        pause_info = {
            "call_id": "c1",
            "name": "get_weather",
            "tool_name": "get_weather",
            "tool_id": "ct0",
            "action_name": "get_weather",
            "llm_name": "get_weather",
            "arguments": {"city": "SF"},
            "pause_type": "requires_client_execution",
            "thought_signature": None,
        }
        agent.tool_executor.check_pause = Mock(return_value=pause_info)
        agent.tool_executor._name_to_tool = {"get_weather": ("ct0", "get_weather")}

        # Simulate streaming: one chunk finishing with tool_calls.
        chunk = LLMResponse(
            content="",
            tool_calls=[ToolCall(id="c1", name="get_weather", arguments='{"city": "SF"}', index=0)],
            finish_reason="tool_calls",
            raw_response={},
        )
        handler.parse_response = lambda c: c
        handler._iterate_stream = lambda r: iter(r)

        stream = handler.handle_streaming(
            agent, [chunk], {"ct0": {"name": "get_weather", "client_side": True}}, []
        )
        emitted = list(stream)

        def matching(event_type, status=None):
            """Collect emitted dict events of a given type (and optional status)."""
            found = []
            for event in emitted:
                if not isinstance(event, dict) or event.get("type") != event_type:
                    continue
                if status is not None and event.get("data", {}).get("status") != status:
                    continue
                found.append(event)
            return found

        # Exactly one requires_client_execution event and one pending event.
        assert len(matching("tool_call", status="requires_client_execution")) == 1
        assert len(matching("tool_calls_pending")) == 1

    def test_mixed_server_and_client_tools_in_batch(self):
        """Server tool executes, client tool pauses."""
        handler = ConcreteHandler()

        agent = Mock()
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"

        seen = {"n": 0}

        def check_pause_fn(tools_dict, call, llm_class):
            seen["n"] += 1
            if seen["n"] != 2:
                return None
            # Second tool in the batch is client-side.
            return {
                "call_id": "c2",
                "name": "get_weather",
                "tool_name": "get_weather",
                "tool_id": "ct0",
                "action_name": "get_weather",
                "llm_name": "get_weather",
                "arguments": {},
                "pause_type": "requires_client_execution",
                "thought_signature": None,
            }

        agent.tool_executor.check_pause = Mock(side_effect=check_pause_fn)
        agent.tool_executor._name_to_tool = {
            "search": ("0", "search"),
            "get_weather": ("ct0", "get_weather"),
        }

        def fake_execute(tools_dict, call):
            yield {"type": "tool_call", "data": {"status": "pending"}}
            return ("server result", call.id)

        agent._execute_tool_action = Mock(side_effect=fake_execute)

        batch = [
            ToolCall(id="c1", name="search", arguments="{}"),
            ToolCall(id="c2", name="get_weather", arguments="{}"),
        ]
        registry = {
            "0": {"name": "search"},
            "ct0": {"name": "get_weather", "client_side": True},
        }

        gen = handler.handle_tool_calls(agent, batch, registry, [])

        pending = None
        try:
            while True:
                next(gen)
        except StopIteration as stop:
            _messages, pending = stop.value

        # Server tool executed exactly once; client tool left pending.
        assert agent._execute_tool_action.call_count == 1
        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["pause_type"] == "requires_client_execution"
|
||||
667
tests/test_continuation.py
Normal file
667
tests/test_continuation.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""Tests for the continuation infrastructure (Phase 1).
|
||||
|
||||
Covers ContinuationService, ToolExecutor.check_pause, handler pause
|
||||
signaling, BaseAgent.gen_continuation, and request validation.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContinuationService
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestContinuationService:
    """ContinuationService: save, load, delete, and upsert of paused-run state."""

    @staticmethod
    def _service():
        # Imported lazily so the mongo mock fixture is active before import.
        from application.api.answer.services.continuation_service import (
            ContinuationService,
        )

        return ContinuationService()

    def test_save_and_load(self, mock_mongo_db):
        service = self._service()
        service.save_state(
            conversation_id="conv-1",
            user="alice",
            messages=[{"role": "user", "content": "hi"}],
            pending_tool_calls=[{"call_id": "c1", "pause_type": "awaiting_approval"}],
            tools_dict={"0": {"name": "test_tool"}},
            tool_schemas=[{"type": "function", "function": {"name": "act_0"}}],
            agent_config={"model_id": "gpt-4"},
        )

        state = service.load_state("conv-1", "alice")

        assert state is not None
        assert state["conversation_id"] == "conv-1"
        assert state["user"] == "alice"
        assert len(state["messages"]) == 1
        assert len(state["pending_tool_calls"]) == 1
        assert state["agent_config"]["model_id"] == "gpt-4"

    def test_load_returns_none_when_missing(self, mock_mongo_db):
        assert self._service().load_state("nonexistent", "alice") is None

    def test_delete_state(self, mock_mongo_db):
        service = self._service()
        service.save_state(
            conversation_id="conv-2",
            user="bob",
            messages=[],
            pending_tool_calls=[],
            tools_dict={},
            tool_schemas=[],
            agent_config={},
        )

        assert service.delete_state("conv-2", "bob") is True
        assert service.load_state("conv-2", "bob") is None

    def test_delete_nonexistent(self, mock_mongo_db):
        assert self._service().delete_state("nope", "nope") is False

    def test_upsert_replaces_existing(self, mock_mongo_db):
        """Saving twice for the same conversation replaces the stored state."""
        service = self._service()
        for version, pending in (("v1", []), ("v2", [{"call_id": "c2"}])):
            service.save_state(
                conversation_id="conv-3",
                user="carol",
                messages=[{"role": "user", "content": version}],
                pending_tool_calls=pending,
                tools_dict={},
                tool_schemas=[],
                agent_config={},
            )

        state = service.load_state("conv-3", "carol")
        assert state["messages"][0]["content"] == "v2"
        assert len(state["pending_tool_calls"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolExecutor.check_pause
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestCheckPause:
    """ToolExecutor.check_pause pause-detection behavior."""

    def _make_call(self, name="action_0", call_id="c1", arguments="{}"):
        # Minimal stand-in for a parsed LLM tool call.
        fake_call = Mock()
        fake_call.name = name
        fake_call.id = call_id
        fake_call.arguments = arguments
        fake_call.thought_signature = None
        return fake_call

    def test_returns_none_for_normal_tool(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "brave",
                "actions": [
                    {"name": "search", "active": True, "parameters": {}},
                ],
            }
        }
        tool_call = self._make_call(name="search_0")
        pause = executor.check_pause(tools_dict, tool_call, "OpenAILLM")
        assert pause is None

    def test_returns_pause_for_client_side_tool(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "get_weather",
                "client_side": True,
                "actions": [
                    {"name": "get_weather", "active": True, "parameters": {}},
                ],
            }
        }
        tool_call = self._make_call(name="get_weather_0")
        pause = executor.check_pause(tools_dict, tool_call, "OpenAILLM")
        assert pause is not None
        assert pause["pause_type"] == "requires_client_execution"
        assert pause["call_id"] == "c1"
        assert pause["tool_id"] == "0"

    def test_returns_pause_for_approval_required(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "telegram",
                "actions": [
                    {
                        "name": "send_msg",
                        "active": True,
                        "require_approval": True,
                        "parameters": {},
                    },
                ],
            }
        }
        tool_call = self._make_call(name="send_msg_0")
        pause = executor.check_pause(tools_dict, tool_call, "OpenAILLM")
        assert pause is not None
        assert pause["pause_type"] == "awaiting_approval"

    def test_returns_none_when_parse_fails(self):
        executor = ToolExecutor()
        # Bad arguments will cause parse error -> None
        tool_call = self._make_call(name="bad_name_no_id", arguments="not json")
        pause = executor.check_pause({}, tool_call, "OpenAILLM")
        assert pause is None

    def test_returns_none_when_tool_not_in_dict(self):
        executor = ToolExecutor()
        tool_call = self._make_call(name="action_99")
        pause = executor.check_pause({"0": {"name": "t"}}, tool_call, "OpenAILLM")
        assert pause is None

    def test_api_tool_approval(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "api_tool",
                "config": {
                    "actions": {
                        "delete_user": {
                            "name": "delete_user",
                            "require_approval": True,
                            "url": "http://example.com",
                            "method": "DELETE",
                            "active": True,
                        }
                    }
                },
            }
        }
        tool_call = self._make_call(name="delete_user_0")
        pause = executor.check_pause(tools_dict, tool_call, "OpenAILLM")
        assert pause is not None
        assert pause["pause_type"] == "awaiting_approval"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler pause signaling (handle_tool_calls returns pending_actions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteHandler(LLMHandler):
    """Minimal concrete handler for testing."""

    def parse_response(self, response):
        # Wrap the raw response without extracting any tool calls.
        return LLMResponse(
            content=str(response),
            tool_calls=[],
            finish_reason="stop",
            raw_response=response,
        )

    def create_tool_message(self, tool_call, result):
        # Gemini-style function_response payload keyed by call id.
        return {
            "role": "tool",
            "content": [
                {
                    "function_response": {
                        "name": tool_call.name,
                        "response": {"result": result},
                        "call_id": tool_call.id,
                    }
                }
            ],
        }

    def _iterate_stream(self, response):
        yield from response
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestHandlerPauseSignaling:
    """handle_tool_calls returns pending_actions when a pause is signaled."""

    def _make_agent(self):
        agent = Mock()
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"
        agent.tool_executor.check_pause = Mock(return_value=None)

        def fake_execute(tools_dict, call):
            # Emit one progress event, then return (result, call_id).
            yield {"type": "tool_call", "data": {"status": "pending"}}
            return ("tool result", call.id)

        agent._execute_tool_action = Mock(side_effect=fake_execute)
        return agent

    def test_no_pause_returns_none_pending(self):
        handler = ConcreteHandler()
        agent = self._make_agent()
        tool_call = ToolCall(id="c1", name="action_0", arguments="{}")

        gen = handler.handle_tool_calls(agent, [tool_call], {"0": {"name": "t"}}, [])
        events = []
        messages = None
        pending = "NOT_SET"
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                messages, pending = stop.value
                break

        assert pending is None
        assert messages is not None

    def test_pause_returns_pending_actions(self):
        handler = ConcreteHandler()
        agent = self._make_agent()
        agent.tool_executor.check_pause = Mock(
            return_value={
                "call_id": "c1",
                "name": "send_msg_0",
                "tool_name": "telegram",
                "tool_id": "0",
                "action_name": "send_msg",
                "arguments": {"text": "hello"},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }
        )

        tool_call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}')
        gen = handler.handle_tool_calls(
            agent, [tool_call], {"0": {"name": "telegram"}}, []
        )

        events = []
        pending = None
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                messages, pending = stop.value
                break

        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["pause_type"] == "awaiting_approval"

        # A tool_call event with awaiting_approval status should be emitted.
        pause_events = [
            event
            for event in events
            if event.get("type") == "tool_call"
            and event.get("data", {}).get("status") == "awaiting_approval"
        ]
        assert len(pause_events) == 1

    def test_mixed_execute_and_pause(self):
        """One tool executes, another needs approval."""
        handler = ConcreteHandler()
        agent = self._make_agent()

        call_count = {"n": 0}

        def selective_pause(tools_dict, call, llm_class):
            call_count["n"] += 1
            if call_count["n"] == 2:
                return {
                    "call_id": "c2",
                    "name": "danger_0",
                    "tool_name": "danger",
                    "tool_id": "0",
                    "action_name": "danger",
                    "arguments": {},
                    "pause_type": "awaiting_approval",
                    "thought_signature": None,
                }
            return None

        agent.tool_executor.check_pause = Mock(side_effect=selective_pause)

        calls = [
            ToolCall(id="c1", name="safe_0", arguments="{}"),
            ToolCall(id="c2", name="danger_0", arguments="{}"),
        ]
        gen = handler.handle_tool_calls(agent, calls, {"0": {"name": "multi"}}, [])

        events = []
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                messages, pending = stop.value
                break

        # First tool was executed normally.
        assert agent._execute_tool_action.call_count == 1
        # Second tool is pending.
        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["call_id"] == "c2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_streaming yields tool_calls_pending
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestStreamingPause:
    """handle_streaming surfaces pending tool calls to the stream consumer."""

    def test_streaming_yields_tool_calls_pending(self):
        handler = ConcreteHandler()
        agent = Mock()
        agent.llm = Mock()
        agent.model_id = "test"
        agent.tools = []
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"

        pause_info = {
            "call_id": "c1",
            "name": "fn_0",
            "tool_name": "test",
            "tool_id": "0",
            "action_name": "fn",
            "arguments": {},
            "pause_type": "awaiting_approval",
            "thought_signature": None,
        }
        agent.tool_executor.check_pause = Mock(return_value=pause_info)

        chunk = LLMResponse(
            content="",
            tool_calls=[ToolCall(id="c1", name="fn_0", arguments="{}", index=0)],
            finish_reason="tool_calls",
            raw_response={},
        )
        # Pass chunks straight through the parse/iterate hooks.
        handler.parse_response = lambda c: c

        def fake_iterate(response):
            yield from response

        handler._iterate_stream = fake_iterate

        gen = handler.handle_streaming(agent, [chunk], {"0": {"name": "t"}}, [])
        events = list(gen)

        # Exactly one tool_calls_pending event carrying one pending call.
        pending_events = [
            event
            for event in events
            if isinstance(event, dict) and event.get("type") == "tool_calls_pending"
        ]
        assert len(pending_events) == 1
        assert len(pending_events[0]["data"]["pending_tool_calls"]) == 1

        # Agent should have _pending_continuation set
        assert hasattr(agent, "_pending_continuation")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseAgent.gen_continuation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestGenContinuation:
    """BaseAgent.gen_continuation resolves pending tool calls and resumes."""

    def _build_agent(self, stream_text, tool_message):
        # Assemble a ClassicAgent wired with mocks; returns the agent plus
        # the executor and handler mocks for per-test assertions.
        from application.agents.classic_agent import ClassicAgent

        mock_llm = Mock()
        mock_llm._supports_tools = True
        mock_llm.gen_stream = Mock(return_value=iter([stream_text]))
        mock_llm._supports_structured_output = Mock(return_value=False)
        mock_llm.__class__.__name__ = "MockLLM"

        mock_handler = Mock()
        mock_handler.process_message_flow = Mock(return_value=iter([]))
        mock_handler.create_tool_message = Mock(return_value=tool_message)

        mock_executor = Mock()
        mock_executor.tool_calls = []
        mock_executor.prepare_tools_for_llm = Mock(return_value=[])
        mock_executor.get_truncated_tool_calls = Mock(return_value=[])

        agent = ClassicAgent(
            endpoint="stream",
            llm_name="openai",
            model_id="gpt-4",
            api_key="test",
            llm=mock_llm,
            llm_handler=mock_handler,
            tool_executor=mock_executor,
        )
        return agent, mock_executor, mock_handler

    def test_approved_tool_executes(self):
        """When a tool action is approved, the tool is executed."""
        agent, mock_executor, mock_handler = self._build_agent(
            "Final answer",
            {
                "role": "tool",
                "content": [
                    {
                        "function_response": {
                            "name": "act_0",
                            "response": {"result": "done"},
                            "call_id": "c1",
                        }
                    }
                ],
            },
        )

        def fake_execute(tools_dict, call, llm_class):
            yield {"type": "tool_call", "data": {"status": "pending"}}
            return ("result_data", "c1")

        mock_executor.execute = Mock(side_effect=fake_execute)

        messages = [{"role": "system", "content": "You are helpful."}]
        tools_dict = {"0": {"name": "test_tool"}}
        pending = [
            {
                "call_id": "c1",
                "name": "act_0",
                "tool_name": "test_tool",
                "tool_id": "0",
                "action_name": "act",
                "arguments": {"q": "test"},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }
        ]
        tool_actions = [{"call_id": "c1", "decision": "approved"}]

        list(agent.gen_continuation(messages, tools_dict, pending, tool_actions))

        # Tool should have been executed
        assert mock_executor.execute.called

    def test_denied_tool_sends_denial(self):
        """When a tool action is denied, a denial message is added."""
        agent, mock_executor, mock_handler = self._build_agent(
            "Answer", {"role": "tool", "content": "denied"}
        )

        messages = [{"role": "system", "content": "test"}]
        pending = [
            {
                "call_id": "c1",
                "name": "danger_0",
                "tool_name": "danger",
                "tool_id": "0",
                "action_name": "danger",
                "arguments": {},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }
        ]
        tool_actions = [
            {"call_id": "c1", "decision": "denied", "comment": "too risky"}
        ]

        events = list(
            agent.gen_continuation(
                messages, {"0": {"name": "danger"}}, pending, tool_actions
            )
        )

        # Should have a denied tool_call event
        denied = [
            event
            for event in events
            if isinstance(event, dict)
            and event.get("type") == "tool_call"
            and event.get("data", {}).get("status") == "denied"
        ]
        assert len(denied) == 1

        # The denial (with the reviewer's comment) is relayed to the LLM.
        denial_arg = mock_handler.create_tool_message.call_args[0][1]
        assert "denied" in denial_arg.lower()
        assert "too risky" in denial_arg

    def test_client_result_appended(self):
        """Client-provided tool result is added to messages."""
        agent, mock_executor, mock_handler = self._build_agent(
            "Done", {"role": "tool", "content": "client result"}
        )

        messages = [{"role": "system", "content": "test"}]
        pending = [
            {
                "call_id": "c1",
                "name": "weather_0",
                "tool_name": "weather",
                "tool_id": "0",
                "action_name": "weather",
                "arguments": {"city": "SF"},
                "pause_type": "requires_client_execution",
                "thought_signature": None,
            }
        ]
        tool_actions = [{"call_id": "c1", "result": {"temp": "72F"}}]

        events = list(
            agent.gen_continuation(
                messages, {"0": {"name": "weather"}}, pending, tool_actions
            )
        )

        # create_tool_message was called with the client result
        result_arg = mock_handler.create_tool_message.call_args[0][1]
        assert "72F" in result_arg

        # Should have a completed tool_call event
        completed = [
            event
            for event in events
            if isinstance(event, dict)
            and event.get("type") == "tool_call"
            and event.get("data", {}).get("status") == "completed"
        ]
        assert len(completed) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestValidateRequest:
    """BaseAnswerResource.validate_request accepts continuation payloads."""

    @pytest.fixture(autouse=True)
    def _app_context(self):
        # validate_request runs inside a Flask app context.
        from flask import Flask

        app = Flask(__name__)
        with app.app_context():
            yield

    def test_continuation_request_without_question(self, mock_mongo_db):
        from application.api.answer.routes.base import BaseAnswerResource

        base = BaseAnswerResource()
        data = {
            "conversation_id": "conv-1",
            "tool_actions": [{"call_id": "c1", "decision": "approved"}],
        }
        result = base.validate_request(data)
        assert result is None  # Valid

    def test_continuation_request_missing_conversation_id(self, mock_mongo_db):
        from application.api.answer.routes.base import BaseAnswerResource

        base = BaseAnswerResource()
        data = {
            "tool_actions": [{"call_id": "c1", "decision": "approved"}],
        }
        result = base.validate_request(data)
        assert result is not None  # Error — missing conversation_id

    def test_normal_request_still_requires_question(self, mock_mongo_db):
        from application.api.answer.routes.base import BaseAnswerResource

        base = BaseAnswerResource()
        data = {"conversation_id": "conv-1"}
        result = base.validate_request(data)
        assert result is not None  # Error — missing question
|
||||
@@ -1238,7 +1238,8 @@ class TestEmbeddingPipelineAddDocWithRetry:
|
||||
# NUL characters should be removed
|
||||
assert "\x00" not in doc.page_content
|
||||
|
||||
def test_add_text_to_store_with_retry_failure(self):
|
||||
@patch("time.sleep", return_value=None)
|
||||
def test_add_text_to_store_with_retry_failure(self, _mock_sleep):
|
||||
from application.parser.embedding_pipeline import add_text_to_store_with_retry
|
||||
|
||||
mock_store = MagicMock()
|
||||
|
||||
481
tests/test_tool_approval.py
Normal file
481
tests/test_tool_approval.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""Tests for tool approval (Phase 3).
|
||||
|
||||
Covers require_approval flag, check_pause for approval, the handler
|
||||
pause/resume flow, and gen_continuation with approved/denied actions.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_pause with require_approval
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestCheckPauseApproval:
    """check_pause honors the require_approval flag on tool actions."""

    def _make_call(self, name="action_0", call_id="c1"):
        # Minimal stand-in for a parsed LLM tool call.
        fake_call = Mock()
        fake_call.name = name
        fake_call.id = call_id
        fake_call.arguments = "{}"
        fake_call.thought_signature = None
        return fake_call

    def test_approval_required_triggers_pause(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "telegram",
                "actions": [
                    {
                        "name": "send_msg",
                        "active": True,
                        "require_approval": True,
                        "parameters": {},
                    },
                ],
            }
        }
        tool_call = self._make_call(name="send_msg_0")
        pause = executor.check_pause(tools_dict, tool_call, "OpenAILLM")

        assert pause is not None
        assert pause["pause_type"] == "awaiting_approval"
        assert pause["tool_name"] == "telegram"
        assert pause["action_name"] == "send_msg"
        assert pause["tool_id"] == "0"

    def test_approval_not_required_no_pause(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "brave",
                "actions": [
                    {
                        "name": "search",
                        "active": True,
                        "require_approval": False,
                        "parameters": {},
                    },
                ],
            }
        }
        tool_call = self._make_call(name="search_0")
        assert executor.check_pause(tools_dict, tool_call, "OpenAILLM") is None

    def test_approval_absent_defaults_to_false(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "brave",
                "actions": [
                    {
                        "name": "search",
                        "active": True,
                        "parameters": {},
                    },
                ],
            }
        }
        tool_call = self._make_call(name="search_0")
        assert executor.check_pause(tools_dict, tool_call, "OpenAILLM") is None

    def test_api_tool_approval(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "api_tool",
                "config": {
                    "actions": {
                        "delete_user": {
                            "name": "delete_user",
                            "require_approval": True,
                            "url": "http://example.com",
                            "method": "DELETE",
                            "active": True,
                        }
                    }
                },
            }
        }
        tool_call = self._make_call(name="delete_user_0")
        pause = executor.check_pause(tools_dict, tool_call, "OpenAILLM")
        assert pause is not None
        assert pause["pause_type"] == "awaiting_approval"

    def test_api_tool_no_approval(self):
        executor = ToolExecutor()
        tools_dict = {
            "0": {
                "name": "api_tool",
                "config": {
                    "actions": {
                        "list_users": {
                            "name": "list_users",
                            "url": "http://example.com",
                            "method": "GET",
                            "active": True,
                        }
                    }
                },
            }
        }
        tool_call = self._make_call(name="list_users_0")
        assert executor.check_pause(tools_dict, tool_call, "OpenAILLM") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler: approval tool causes pause signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteHandler(LLMHandler):
    """Minimal concrete handler used to exercise the approval flow."""

    def parse_response(self, response):
        # Wrap the raw response without extracting any tool calls.
        return LLMResponse(
            content=str(response),
            tool_calls=[],
            finish_reason="stop",
            raw_response=response,
        )

    def create_tool_message(self, tool_call, result):
        # OpenAI-style tool message; non-string results are JSON-encoded.
        import json as _json

        if isinstance(result, str):
            content = result
        else:
            content = _json.dumps(result)
        return {"role": "tool", "tool_call_id": tool_call.id, "content": content}

    def _iterate_stream(self, response):
        yield from response
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestHandlerApprovalPause:
    """Approval-required tools pause instead of executing."""

    def _make_agent(self, pause_return):
        agent = Mock()
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"
        agent.tool_executor.check_pause = Mock(return_value=pause_return)

        def fake_execute(tools_dict, call):
            # Emit one progress event, then return (result, call_id).
            yield {"type": "tool_call", "data": {"status": "pending"}}
            return ("tool result", call.id)

        agent._execute_tool_action = Mock(side_effect=fake_execute)
        return agent

    def test_approval_tool_pauses(self):
        handler = ConcreteHandler()
        pause_info = {
            "call_id": "c1",
            "name": "send_msg_0",
            "tool_name": "telegram",
            "tool_id": "0",
            "action_name": "send_msg",
            "arguments": {"text": "hello"},
            "pause_type": "awaiting_approval",
            "thought_signature": None,
        }
        agent = self._make_agent(pause_info)

        tool_call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}')
        gen = handler.handle_tool_calls(
            agent, [tool_call], {"0": {"name": "telegram"}}, []
        )

        events = []
        pending = None
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                messages, pending = stop.value
                break

        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["pause_type"] == "awaiting_approval"

        # Should NOT have executed the tool
        assert agent._execute_tool_action.call_count == 0

        # Should have yielded awaiting_approval status
        approval_events = [
            event
            for event in events
            if event.get("type") == "tool_call"
            and event.get("data", {}).get("status") == "awaiting_approval"
        ]
        assert len(approval_events) == 1

    def test_mixed_normal_and_approval(self):
        """First tool runs normally, second needs approval."""
        handler = ConcreteHandler()

        call_count = {"n": 0}

        def selective_pause(tools_dict, call, llm_class):
            call_count["n"] += 1
            if call_count["n"] == 2:
                return {
                    "call_id": "c2",
                    "name": "send_msg_0",
                    "tool_name": "telegram",
                    "tool_id": "0",
                    "action_name": "send_msg",
                    "arguments": {},
                    "pause_type": "awaiting_approval",
                    "thought_signature": None,
                }
            return None

        agent = self._make_agent(None)
        agent.tool_executor.check_pause = Mock(side_effect=selective_pause)

        calls = [
            ToolCall(id="c1", name="search_0", arguments="{}"),
            ToolCall(id="c2", name="send_msg_0", arguments="{}"),
        ]

        gen = handler.handle_tool_calls(agent, calls, {"0": {"name": "multi"}}, [])

        events = []
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                messages, pending = stop.value
                break

        # First tool executed; second is held pending approval.
        assert agent._execute_tool_action.call_count == 1
        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["call_id"] == "c2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# gen_continuation: approval and denial flows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenContinuationApproval:
|
||||
|
||||
def _make_agent(self):
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm._supports_tools = True
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
mock_llm._supports_structured_output = Mock(return_value=False)
|
||||
mock_llm.__class__.__name__ = "MockLLM"
|
||||
|
||||
mock_handler = Mock()
|
||||
mock_handler.process_message_flow = Mock(return_value=iter([]))
|
||||
mock_handler.create_tool_message = Mock(
|
||||
return_value={"role": "tool", "tool_call_id": "c1", "content": "result"}
|
||||
)
|
||||
|
||||
mock_executor = Mock()
|
||||
mock_executor.tool_calls = []
|
||||
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
|
||||
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
|
||||
|
||||
def fake_execute(tools_dict, call, llm_class):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
return ("executed_result", "c1")
|
||||
|
||||
mock_executor.execute = Mock(side_effect=fake_execute)
|
||||
|
||||
agent = ClassicAgent(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="test",
|
||||
llm=mock_llm,
|
||||
llm_handler=mock_handler,
|
||||
tool_executor=mock_executor,
|
||||
)
|
||||
return agent, mock_executor, mock_handler
|
||||
|
||||
def test_approved_tool_executes(self):
|
||||
agent, mock_executor, mock_handler = self._make_agent()
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "send_msg_0",
|
||||
"tool_name": "telegram",
|
||||
"tool_id": "0",
|
||||
"action_name": "send_msg",
|
||||
"arguments": {"text": "hello"},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
tool_actions = [{"call_id": "c1", "decision": "approved"}]
|
||||
|
||||
list(agent.gen_continuation(
|
||||
messages, {"0": {"name": "telegram"}}, pending, tool_actions
|
||||
))
|
||||
|
||||
# Tool should have been executed
|
||||
assert mock_executor.execute.called
|
||||
|
||||
def test_denied_tool_sends_denial_to_llm(self):
|
||||
agent, mock_executor, mock_handler = self._make_agent()
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "send_msg_0",
|
||||
"tool_name": "telegram",
|
||||
"tool_id": "0",
|
||||
"action_name": "send_msg",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
tool_actions = [
|
||||
{"call_id": "c1", "decision": "denied", "comment": "not safe"},
|
||||
]
|
||||
|
||||
events = list(agent.gen_continuation(
|
||||
messages, {"0": {"name": "telegram"}}, pending, tool_actions
|
||||
))
|
||||
|
||||
# Tool should NOT have been executed
|
||||
assert not mock_executor.execute.called
|
||||
|
||||
# Should have a denied event
|
||||
denied = [
|
||||
e for e in events
|
||||
if isinstance(e, dict)
|
||||
and e.get("type") == "tool_call"
|
||||
and e.get("data", {}).get("status") == "denied"
|
||||
]
|
||||
assert len(denied) == 1
|
||||
|
||||
# create_tool_message should have been called with denial text
|
||||
denial_text = mock_handler.create_tool_message.call_args[0][1]
|
||||
assert "denied" in denial_text.lower()
|
||||
assert "not safe" in denial_text
|
||||
|
||||
def test_denied_without_comment(self):
|
||||
agent, mock_executor, mock_handler = self._make_agent()
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "act_0",
|
||||
"tool_name": "tool",
|
||||
"tool_id": "0",
|
||||
"action_name": "act",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
tool_actions = [{"call_id": "c1", "decision": "denied"}]
|
||||
|
||||
list(agent.gen_continuation(
|
||||
messages, {"0": {"name": "tool"}}, pending, tool_actions
|
||||
))
|
||||
|
||||
denial_text = mock_handler.create_tool_message.call_args[0][1]
|
||||
assert "denied" in denial_text.lower()
|
||||
|
||||
def test_mixed_approve_deny_batch(self):
|
||||
"""Two tools: one approved, one denied."""
|
||||
agent, mock_executor, mock_handler = self._make_agent()
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "safe_0",
|
||||
"tool_name": "safe",
|
||||
"tool_id": "0",
|
||||
"action_name": "safe",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
},
|
||||
{
|
||||
"call_id": "c2",
|
||||
"name": "danger_0",
|
||||
"tool_name": "danger",
|
||||
"tool_id": "0",
|
||||
"action_name": "danger",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
},
|
||||
]
|
||||
tool_actions = [
|
||||
{"call_id": "c1", "decision": "approved"},
|
||||
{"call_id": "c2", "decision": "denied", "comment": "too risky"},
|
||||
]
|
||||
|
||||
events = list(agent.gen_continuation(
|
||||
messages, {"0": {"name": "multi"}}, pending, tool_actions
|
||||
))
|
||||
|
||||
# First tool executed, second denied
|
||||
assert mock_executor.execute.call_count == 1
|
||||
|
||||
denied = [
|
||||
e for e in events
|
||||
if isinstance(e, dict)
|
||||
and e.get("type") == "tool_call"
|
||||
and e.get("data", {}).get("status") == "denied"
|
||||
]
|
||||
assert len(denied) == 1
|
||||
|
||||
def test_missing_action_defaults_to_denial(self):
|
||||
"""If client doesn't respond for a pending tool, treat as denied."""
|
||||
agent, mock_executor, mock_handler = self._make_agent()
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "act_0",
|
||||
"tool_name": "tool",
|
||||
"tool_id": "0",
|
||||
"action_name": "act",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
# Empty tool_actions — no response for c1
|
||||
tool_actions = []
|
||||
|
||||
events = list(agent.gen_continuation(
|
||||
messages, {"0": {"name": "tool"}}, pending, tool_actions
|
||||
))
|
||||
|
||||
# Should have been treated as denied
|
||||
assert not mock_executor.execute.called
|
||||
denied = [
|
||||
e for e in events
|
||||
if isinstance(e, dict)
|
||||
and e.get("type") == "tool_call"
|
||||
and e.get("data", {}).get("status") == "denied"
|
||||
]
|
||||
assert len(denied) == 1
|
||||
577
tests/test_v1_translator.py
Normal file
577
tests/test_v1_translator.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""Tests for the v1 API translator (Phase 4).
|
||||
|
||||
Covers request translation, response translation, streaming event
|
||||
translation, continuation detection, and history conversion.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.v1.translator import (
|
||||
_get_client_tool_name,
|
||||
convert_history,
|
||||
extract_tool_results,
|
||||
is_continuation,
|
||||
translate_request,
|
||||
translate_response,
|
||||
translate_stream_event,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_continuation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestIsContinuation:
    """Detection of OpenAI-style tool-result continuation requests."""

    def test_normal_messages_not_continuation(self):
        assert is_continuation([{"role": "user", "content": "Hello"}]) is False

    def test_tool_after_assistant_tool_calls_is_continuation(self):
        call = {
            "id": "c1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{}"},
        }
        convo = [
            {"role": "user", "content": "What's the weather?"},
            {"role": "assistant", "content": None, "tool_calls": [call]},
            {"role": "tool", "tool_call_id": "c1", "content": '{"temp": "72F"}'},
        ]
        assert is_continuation(convo) is True

    def test_assistant_without_tool_calls_not_continuation(self):
        # The assistant turn carries no tool_calls, so the trailing tool
        # message is not a valid continuation.
        convo = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
            {"role": "tool", "tool_call_id": "c1", "content": "result"},
        ]
        assert is_continuation(convo) is False

    def test_empty_messages(self):
        assert is_continuation([]) is False

    def test_multiple_tool_results(self):
        calls = [
            {"id": cid, "type": "function", "function": {"name": fn, "arguments": "{}"}}
            for cid, fn in (("c1", "a"), ("c2", "b"))
        ]
        convo = [
            {"role": "user", "content": "Do stuff"},
            {"role": "assistant", "tool_calls": calls},
            {"role": "tool", "tool_call_id": "c1", "content": "r1"},
            {"role": "tool", "tool_call_id": "c2", "content": "r2"},
        ]
        assert is_continuation(convo) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_tool_results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestExtractToolResults:
    """Extraction of tool-result payloads from an OpenAI message list."""

    def test_extracts_results(self):
        convo = [
            {"role": "assistant", "tool_calls": [{"id": "c1"}]},
            {"role": "tool", "tool_call_id": "c1", "content": '{"temp": "72F"}'},
        ]
        extracted = extract_tool_results(convo)
        assert len(extracted) == 1
        assert extracted[0]["call_id"] == "c1"
        # JSON content is decoded into a dict.
        assert extracted[0]["result"] == {"temp": "72F"}

    def test_string_content(self):
        # Non-JSON content is passed through as the raw string.
        extracted = extract_tool_results(
            [{"role": "tool", "tool_call_id": "c1", "content": "plain text"}]
        )
        assert extracted[0]["result"] == "plain text"

    def test_multiple_results(self):
        convo = [
            {"role": "assistant", "tool_calls": []},
            {"role": "tool", "tool_call_id": "c1", "content": "r1"},
            {"role": "tool", "tool_call_id": "c2", "content": "r2"},
        ]
        extracted = extract_tool_results(convo)
        # Both results are returned in message order.
        assert [r["call_id"] for r in extracted] == ["c1", "c2"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# convert_history
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestConvertHistory:
    """Conversion of OpenAI messages into DocsGPT prompt/response history."""

    def test_user_assistant_pairs(self):
        convo = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "I'm good"},
            # The final user turn is the live question, not history.
            {"role": "user", "content": "What's 2+2?"},
        ]
        history = convert_history(convo)
        assert [(h["prompt"], h["response"]) for h in history] == [
            ("Hello", "Hi there"),
            ("How are you?", "I'm good"),
        ]

    def test_single_user_message(self):
        assert convert_history([{"role": "user", "content": "Hi"}]) == []

    def test_system_messages_skipped(self):
        convo = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "Question"},
        ]
        assert convert_history(convo) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# translate_request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestTranslateRequest:
    """Translation of OpenAI chat requests into DocsGPT /stream payloads."""

    def test_normal_request(self):
        payload = {
            "messages": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
                {"role": "user", "content": "What's 2+2?"},
            ],
        }
        translated = translate_request(payload, "test-key")
        assert translated["question"] == "What's 2+2?"
        assert translated["api_key"] == "test-key"
        assert translated["save_conversation"] is True
        # History is serialized JSON containing the earlier exchange only.
        past = json.loads(translated["history"])
        assert len(past) == 1
        assert past[0]["prompt"] == "Hello"

    def test_continuation_request(self):
        payload = {
            "messages": [
                {"role": "user", "content": "Search for X"},
                {
                    "role": "assistant",
                    "tool_calls": [
                        {"id": "c1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
                    ],
                },
                {"role": "tool", "tool_call_id": "c1", "content": '{"found": true}'},
            ],
        }
        translated = translate_request(payload, "key")
        assert "tool_actions" in translated
        actions = translated["tool_actions"]
        assert len(actions) == 1
        assert actions[0]["call_id"] == "c1"

    def test_continuation_with_top_level_conversation_id(self):
        """Standard clients send conversation_id at request level, not in messages."""
        payload = {
            "conversation_id": "conv-top-level",
            "messages": [
                {"role": "user", "content": "Do stuff"},
                {
                    "role": "assistant",
                    "tool_calls": [
                        {"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}
                    ],
                },
                {"role": "tool", "tool_call_id": "c1", "content": "done"},
            ],
        }
        assert translate_request(payload, "key")["conversation_id"] == "conv-top-level"

    def test_continuation_in_message_conversation_id_takes_precedence(self):
        """When both in-message and top-level conversation_id exist, in-message wins."""
        payload = {
            "conversation_id": "conv-top-level",
            "messages": [
                {"role": "user", "content": "Do stuff"},
                {
                    "role": "assistant",
                    "tool_calls": [
                        {"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}
                    ],
                    "docsgpt": {"conversation_id": "conv-in-message"},
                },
                {"role": "tool", "tool_call_id": "c1", "content": "done"},
            ],
        }
        assert translate_request(payload, "key")["conversation_id"] == "conv-in-message"

    def test_client_tools_passed_through(self):
        tools = [{"type": "function", "function": {"name": "my_tool"}}]
        payload = {
            "messages": [{"role": "user", "content": "Hi"}],
            "tools": tools,
        }
        assert translate_request(payload, "key")["client_tools"] == tools

    def test_docsgpt_attachments(self):
        payload = {
            "messages": [{"role": "user", "content": "Hi"}],
            "docsgpt": {"attachments": ["att1", "att2"]},
        }
        assert translate_request(payload, "key")["attachments"] == ["att1", "att2"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# translate_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestTranslateResponse:
    """Translation of agent output into OpenAI chat.completion responses."""

    @staticmethod
    def _respond(**overrides):
        # Common defaults; individual tests override only what they probe.
        kwargs = dict(
            conversation_id="c1",
            answer="",
            sources=[],
            tool_calls=[],
            thought="",
            model_name="agent",
        )
        kwargs.update(overrides)
        return translate_response(**kwargs)

    def test_basic_response(self):
        resp = self._respond(
            conversation_id="conv-1", answer="Hello!", model_name="my-agent"
        )
        assert resp["id"] == "chatcmpl-conv-1"
        assert resp["object"] == "chat.completion"
        assert resp["model"] == "my-agent"
        message = resp["choices"][0]["message"]
        assert message["content"] == "Hello!"
        assert resp["choices"][0]["finish_reason"] == "stop"
        # No thought was supplied, so no reasoning_content field appears.
        assert "reasoning_content" not in message

    def test_response_with_thought(self):
        resp = self._respond(answer="Result", thought="Thinking about it...")
        assert (
            resp["choices"][0]["message"]["reasoning_content"]
            == "Thinking about it..."
        )

    def test_response_with_sources(self):
        docs = [{"title": "doc.txt", "text": "content", "source": "/doc.txt"}]
        resp = self._respond(answer="Found it", sources=docs)
        assert resp["docsgpt"]["sources"] == docs

    def test_response_with_tool_calls(self):
        executed = [{"tool_name": "notes", "call_id": "c1", "artifact_id": "a1"}]
        resp = self._respond(answer="Done", tool_calls=executed)
        assert resp["docsgpt"]["tool_calls"] == executed

    def test_pending_tool_calls_uses_tool_name(self):
        """Client tool responses use the original tool_name, not the LLM-visible action_name."""
        pending = [
            {
                "call_id": "c1",
                "tool_name": "get_weather",
                "action_name": "get_weather",
                "arguments": {"city": "SF"},
            }
        ]
        resp = self._respond(pending_tool_calls=pending)
        first = resp["choices"][0]["message"]["tool_calls"][0]
        assert first["function"]["name"] == "get_weather"

    def test_pending_tool_calls_tool_name_takes_precedence(self):
        """When tool_name differs from action_name, tool_name is used."""
        pending = [
            {
                "call_id": "c1",
                "tool_name": "search",
                "action_name": "search_1",
                "arguments": {"q": "test"},
            }
        ]
        resp = self._respond(pending_tool_calls=pending)
        first = resp["choices"][0]["message"]["tool_calls"][0]
        assert first["function"]["name"] == "search"

    def test_pending_tool_calls(self):
        pending = [
            {"call_id": "c1", "name": "get_weather", "arguments": {"city": "SF"}}
        ]
        resp = self._respond(pending_tool_calls=pending)
        choice = resp["choices"][0]
        # Pending calls flip the finish reason and null out the content.
        assert choice["finish_reason"] == "tool_calls"
        assert choice["message"]["content"] is None
        calls = choice["message"]["tool_calls"]
        assert len(calls) == 1
        assert calls[0]["id"] == "c1"
        assert calls[0]["function"]["name"] == "get_weather"

    def test_no_docsgpt_when_empty(self):
        resp = self._respond(
            conversation_id="", answer="Hi", sources=None, tool_calls=None
        )
        assert "docsgpt" not in resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# translate_stream_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestTranslateStreamEvent:
    """Translation of DocsGPT stream events into SSE chat.completion chunks."""

    @staticmethod
    def _decode(chunk):
        # Strip the SSE "data: " prefix and parse the JSON payload.
        return json.loads(chunk.replace("data: ", "").strip())

    def test_answer_event(self):
        chunks = translate_stream_event(
            {"type": "answer", "answer": "Hello"},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        payload = self._decode(chunks[0])
        assert payload["choices"][0]["delta"]["content"] == "Hello"

    def test_thought_event(self):
        chunks = translate_stream_event(
            {"type": "thought", "thought": "reasoning"},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        payload = self._decode(chunks[0])
        assert payload["choices"][0]["delta"]["reasoning_content"] == "reasoning"

    def test_source_event(self):
        chunks = translate_stream_event(
            {"type": "source", "source": [{"title": "t", "text": "x"}]},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        payload = self._decode(chunks[0])
        assert payload["docsgpt"]["type"] == "source"
        assert len(payload["docsgpt"]["sources"]) == 1

    def test_end_event(self):
        chunks = translate_stream_event(
            {"type": "end"},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 2
        # First chunk carries the terminal finish_reason...
        assert self._decode(chunks[0])["choices"][0]["finish_reason"] == "stop"
        # ...and the stream closes with the [DONE] sentinel.
        assert chunks[1].strip() == "data: [DONE]"

    def test_tool_call_client_execution(self):
        chunks = translate_stream_event(
            {
                "type": "tool_call",
                "data": {
                    "call_id": "c1",
                    "action_name": "get_weather",
                    "arguments": {"city": "SF"},
                    "status": "requires_client_execution",
                },
            },
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        delta_call = self._decode(chunks[0])["choices"][0]["delta"]["tool_calls"][0]
        assert delta_call["id"] == "c1"
        assert delta_call["function"]["name"] == "get_weather"

    def test_tool_call_client_execution_uses_tool_name(self):
        """Streaming tool calls use tool_name (original name) for client responses."""
        chunks = translate_stream_event(
            {
                "type": "tool_call",
                "data": {
                    "call_id": "c1",
                    "tool_name": "create",
                    "action_name": "create",
                    "arguments": {"title": "test"},
                    "status": "requires_client_execution",
                },
            },
            "chatcmpl-1", "agent",
        )
        delta_call = self._decode(chunks[0])["choices"][0]["delta"]["tool_calls"][0]
        assert delta_call["function"]["name"] == "create"

    def test_tool_call_completed(self):
        chunks = translate_stream_event(
            {
                "type": "tool_call",
                "data": {
                    "call_id": "c1",
                    "status": "completed",
                    "result": "done",
                    "artifact_id": "a1",
                },
            },
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        payload = self._decode(chunks[0])
        assert payload["docsgpt"]["type"] == "tool_call"
        assert payload["docsgpt"]["data"]["artifact_id"] == "a1"

    def test_tool_calls_pending(self):
        chunks = translate_stream_event(
            {
                "type": "tool_calls_pending",
                "data": {"pending_tool_calls": [{"call_id": "c1"}]},
            },
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 2
        # Standard chunk signals the tool_calls finish reason.
        assert self._decode(chunks[0])["choices"][0]["finish_reason"] == "tool_calls"
        # Extension chunk carries the docsgpt-specific payload.
        assert self._decode(chunks[1])["docsgpt"]["type"] == "tool_calls_pending"

    def test_id_event(self):
        chunks = translate_stream_event(
            {"type": "id", "id": "conv-123"},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        assert self._decode(chunks[0])["docsgpt"]["conversation_id"] == "conv-123"

    def test_error_event(self):
        chunks = translate_stream_event(
            {"type": "error", "error": "Something went wrong"},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        assert self._decode(chunks[0])["error"]["message"] == "Something went wrong"

    def test_tool_calls_event_skipped(self):
        """The aggregate tool_calls event is redundant and should be skipped."""
        chunks = translate_stream_event(
            {"type": "tool_calls", "tool_calls": [{"call_id": "c1"}]},
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 0

    def test_research_events_skipped(self):
        for event_type in ("research_plan", "research_progress"):
            assert translate_stream_event(
                {"type": event_type, "data": {}}, "id", "m"
            ) == []

    def test_awaiting_approval_as_extension(self):
        chunks = translate_stream_event(
            {
                "type": "tool_call",
                "data": {"call_id": "c1", "status": "awaiting_approval"},
            },
            "chatcmpl-1", "agent",
        )
        assert len(chunks) == 1
        assert self._decode(chunks[0])["docsgpt"]["type"] == "tool_call"

    def test_standard_clients_can_ignore_docsgpt(self):
        """Standard clients parse only 'choices' — docsgpt namespace is ignored."""
        chunks = translate_stream_event(
            {"type": "source", "source": [{"title": "t"}]},
            "chatcmpl-1", "agent",
        )
        payload = self._decode(chunks[0])
        # No "choices" key — standard parsers skip this chunk entirely.
        assert "choices" not in payload
        # The docsgpt extension key is present for aware clients.
        assert "docsgpt" in payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_client_tool_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestGetClientToolName:
    """Resolution order of the client-facing tool name."""

    def test_uses_tool_name_when_present(self):
        data = {"tool_name": "create", "action_name": "create_1"}
        assert _get_client_tool_name(data) == "create"

    def test_falls_back_to_action_name(self):
        assert _get_client_tool_name({"action_name": "get_weather"}) == "get_weather"

    def test_falls_back_to_name(self):
        assert _get_client_tool_name({"name": "search"}) == "search"

    def test_returns_empty_when_no_fields(self):
        # No naming field at all yields the empty string.
        assert _get_client_tool_name({}) == ""
|
||||
Reference in New Issue
Block a user