Merge pull request #2349 from arc53/messages-format

Messages format
Alex
2026-04-03 16:26:57 +01:00
committed by GitHub
55 changed files with 6486 additions and 695 deletions

View File

@@ -1,7 +1,8 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional
from application.agents.tool_executor import ToolExecutor
from application.core.json_schema_utils import (
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.llm.handlers.base import ToolCall
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
@@ -113,6 +115,153 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
def gen_continuation(
self,
messages: List[Dict],
tools_dict: Dict,
pending_tool_calls: List[Dict],
tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
"""Resume generation after tool actions are resolved.
Processes the client-provided *tool_actions* (approvals, denials,
or client-side results), appends the resulting messages, then
hands back to the LLM to continue the conversation.
Args:
messages: The saved messages array from the pause point.
tools_dict: The saved tools dictionary.
pending_tool_calls: The pending tool call descriptors from the pause.
tool_actions: Client-provided actions resolving the pending calls.
"""
self._prepare_tools(tools_dict)
actions_by_id = {a["call_id"]: a for a in tool_actions}
# Build a single assistant message containing all tool calls so
# the message history matches the format LLM providers expect
# (one assistant message with N tool_calls, followed by N tool results).
tc_objects: List[Dict[str, Any]] = []
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
args_str = (
json.dumps(args) if isinstance(args, dict) else (args or "{}")
)
tc_obj: Dict[str, Any] = {
"id": call_id,
"type": "function",
"function": {
"name": pending["name"],
"arguments": args_str,
},
}
if pending.get("thought_signature"):
tc_obj["thought_signature"] = pending["thought_signature"]
tc_objects.append(tc_obj)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tc_objects,
})
# Now process each pending call and append tool result messages
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
action = actions_by_id.get(call_id)
if not action:
action = {
"call_id": call_id,
"decision": "denied",
"comment": "No response provided",
}
if action.get("decision") == "approved":
# Execute the tool server-side
tc = ToolCall(
id=call_id,
name=pending["name"],
arguments=(
json.dumps(args) if isinstance(args, dict) else args
),
)
tool_gen = self._execute_tool_action(tools_dict, tc)
tool_response = None
while True:
try:
event = next(tool_gen)
yield event
except StopIteration as e:
tool_response, _ = e.value
break
messages.append(
self.llm_handler.create_tool_message(tc, tool_response)
)
elif action.get("decision") == "denied":
comment = action.get("comment", "")
denial = (
f"Tool execution denied by user. Reason: {comment}"
if comment
else "Tool execution denied by user."
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, denial)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"status": "denied",
},
}
elif "result" in action:
result = action["result"]
result_str = (
json.dumps(result)
if not isinstance(result, str)
else result
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, result_str)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"result": (
result_str[:50] + "..."
if len(result_str) > 50
else result_str
),
"status": "completed",
},
}
# Resume the LLM loop with the updated messages
llm_response = self._llm_gen(messages)
yield from self._handle_response(
llm_response, tools_dict, messages, None
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
@@ -267,28 +416,35 @@ class BaseAgent(ABC):
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
messages.append({"role": "user", "content": query})
return messages
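# Illustrative sketch: the history shape this rewrite produces, i.e. one
# assistant message carrying the call, then one tool message per result.
# All IDs and payloads below are hypothetical.
example_messages = [
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }],
    },
    {"role": "tool", "tool_call_id": "call_abc123", "content": '{"temp_c": 18}'},
]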

View File

@@ -593,16 +593,22 @@ class ResearchAgent(BaseAgent):
)
result = result_str
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_content]}
import json as _json
args_str = (
_json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {"name": call.name, "arguments": args_str},
}],
})
tool_message = self.llm_handler.create_tool_message(call, result)
messages.append(tool_message)

View File

@@ -1,6 +1,7 @@
import logging
import uuid
from typing import Dict, List, Optional
from collections import Counter
from typing import Dict, List, Optional, Tuple
from bson.objectid import ObjectId
@@ -31,12 +32,23 @@ class ToolExecutor:
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
self.client_tools: Optional[List[Dict]] = None
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
self._tool_to_name: Dict[Tuple[str, str], str] = {}
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context."""
"""Load tool configs from DB based on user context.
If *client_tools* have been set on this executor, they are
automatically merged into the returned dict.
"""
if self.user_api_key:
return self._get_tools_by_api_key(self.user_api_key)
return self._get_user_tools(self.user or "local")
tools = self._get_tools_by_api_key(self.user_api_key)
else:
tools = self._get_user_tools(self.user or "local")
if self.client_tools:
self.merge_client_tools(tools, self.client_tools)
return tools
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
@@ -65,29 +77,123 @@ class ToolExecutor:
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas."""
return [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
def merge_client_tools(
self, tools_dict: Dict, client_tools: List[Dict]
) -> Dict:
"""Merge client-provided tool definitions into tools_dict.
Client tools use the standard function-calling format::
[{"type": "function", "function": {"name": "get_weather",
"description": "...", "parameters": {...}}}]
They are stored in *tools_dict* with ``client_side: True`` so that
:meth:`check_pause` returns a pause signal instead of trying to
execute them server-side.
Args:
tools_dict: The mutable server tools dict (will be modified in place).
client_tools: List of tool definitions in function-calling format.
Returns:
The updated *tools_dict* (same reference, for convenience).
"""
for i, ct in enumerate(client_tools):
func = ct.get("function", ct) # tolerate bare {"name":..} too
name = func.get("name", f"clienttool{i}")
tool_id = f"ct{i}"
tools_dict[tool_id] = {
"name": name,
"client_side": True,
"actions": [
{
"name": name,
"description": func.get("description", ""),
"active": True,
"parameters": func.get("parameters", {}),
}
],
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
return tools_dict
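# A minimal sketch of the merge result, assuming one hypothetical client
# tool in function-calling format:
client_tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get current weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]
# After merge_client_tools(tools_dict, client_tools), tools_dict gains:
# {"ct0": {"name": "get_weather", "client_side": True,
#          "actions": [{"name": "get_weather", "description": "Get current weather",
#                       "active": True, "parameters": {...}}]}}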
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas.
Action names are kept clean for the LLM:
- Unique action names appear as-is (e.g. ``get_weather``).
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
``search_2``).
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
can be routed back to the correct ``(tool_id, action_name)`` without
brittle string splitting.
"""
# Pass 1: collect entries and count action name occurrences
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
name_counts: Counter = Counter()
for tool_id, tool in tools_dict.items():
is_api = tool["name"] == "api_tool"
is_client = tool.get("client_side", False)
if is_api and "actions" not in tool.get("config", {}):
continue
if not is_api and "actions" not in tool:
continue
actions = (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
if is_api
else tool["actions"]
)
if action.get("active", True)
]
for action in actions:
if not action.get("active", True):
continue
entries.append((tool_id, action["name"], action, is_client))
name_counts[action["name"]] += 1
# Pass 2: assign LLM-visible names and build mappings
self._name_to_tool = {}
self._tool_to_name = {}
collision_counters: Dict[str, int] = {}
all_llm_names: set = set()
result = []
for tool_id, action_name, action, is_client in entries:
if name_counts[action_name] == 1:
llm_name = action_name
else:
counter = collision_counters.get(action_name, 1)
candidate = f"{action_name}_{counter}"
# Skip if candidate collides with a unique action name
while candidate in all_llm_names or (
candidate in name_counts and name_counts[candidate] == 1
):
counter += 1
candidate = f"{action_name}_{counter}"
collision_counters[action_name] = counter + 1
llm_name = candidate
all_llm_names.add(llm_name)
self._name_to_tool[llm_name] = (tool_id, action_name)
self._tool_to_name[(tool_id, action_name)] = llm_name
if is_client:
params = action.get("parameters", {})
else:
params = self._build_tool_parameters(action)
result.append({
"type": "function",
"function": {
"name": llm_name,
"description": action.get("description", ""),
"parameters": params,
},
})
return result
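# To make the naming concrete, a standalone sketch of the two-pass scheme
# (hypothetical tool IDs and action names; the collision-with-unique-name
# check from the loop above is omitted for brevity):
from collections import Counter

entries = [("0", "search"), ("1", "search"), ("2", "get_weather")]
name_counts = Counter(name for _, name in entries)

name_to_tool = {}
collision_counters = {}
for tool_id, action_name in entries:
    if name_counts[action_name] == 1:
        llm_name = action_name                    # unique name: exposed as-is
    else:
        n = collision_counters.get(action_name, 1)
        llm_name = f"{action_name}_{n}"           # duplicate: numbered suffix
        collision_counters[action_name] = n + 1
    name_to_tool[llm_name] = (tool_id, action_name)

assert name_to_tool == {
    "search_1": ("0", "search"),
    "search_2": ("1", "search"),
    "get_weather": ("2", "get_weather"),
}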
def _build_tool_parameters(self, action: Dict) -> Dict:
params = {"type": "object", "properties": {}, "required": []}
@@ -104,23 +210,81 @@ class ToolExecutor:
params["required"].append(k)
return params
def check_pause(
self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
"""Check if a tool call requires pausing for approval or client execution.
Returns a dict describing the pending action if pause is needed, None otherwise.
"""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
llm_name = getattr(call, "name", "")
if tool_id is None or action_name is None or tool_id not in tools_dict:
return None # Will be handled as error by execute()
tool_data = tools_dict[tool_id]
# Client-side tools
if tool_data.get("client_side"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "requires_client_execution",
"thought_signature": getattr(call, "thought_signature", None),
}
# Approval required
if tool_data["name"] == "api_tool":
action_data = tool_data.get("config", {}).get("actions", {}).get(
action_name, {}
)
else:
action_data = next(
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
{},
)
if action_data.get("require_approval"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "awaiting_approval",
"thought_signature": getattr(call, "thought_signature", None),
}
return None
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name)
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
llm_name = getattr(call, "name", "unknown")
call_id = getattr(call, "id", None) or str(uuid.uuid4())
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"action_name": llm_name,
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
@@ -133,7 +297,7 @@ class ToolExecutor:
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"action_name": llm_name,
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
@@ -144,7 +308,7 @@ class ToolExecutor:
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"action_name": llm_name,
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}

View File

@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
class Tool(ABC):
internal: bool = False
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass

View File

@@ -20,6 +20,8 @@ class InternalSearchTool(Tool):
- list_files action: browse the file/folder structure
"""
internal = True
def __init__(self, config: Dict):
self.config = config
self.retrieved_docs: List[Dict] = []

View File

@@ -36,6 +36,8 @@ class ThinkTool(Tool):
The reasoning content is captured in tool_call data for transparency.
"""
internal = True
def __init__(self, config=None):
pass

View File

@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
class ToolActionParser:
def __init__(self, llm_type):
def __init__(self, llm_type, name_mapping=None):
self.llm_type = llm_type
self.name_mapping = name_mapping
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
@@ -16,22 +17,33 @@ class ToolActionParser:
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _resolve_via_mapping(self, call_name):
"""Look up (tool_id, action_name) from the name mapping if available."""
if self.name_mapping and call_name in self.name_mapping:
return self.name_mapping[call_name]
return None
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
@@ -45,19 +57,24 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."

View File

@@ -19,7 +19,7 @@ class ToolManager:
continue
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)

View File

@@ -74,57 +74,72 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
stream = self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
)
else:
# ---- Normal mode ----
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if stream_result["error"]:
return make_response({"error": stream_result["error"]}, 400)
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
"conversation_id": stream_result["conversation_id"],
"answer": stream_result["answer"],
"sources": stream_result["sources"],
"tool_calls": stream_result["tool_calls"],
"thought": stream_result["thought"],
}
if structured_info:
result.update(structured_info)
extra_info = stream_result.get("extra")
if extra_info:
result.update(extra_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
@@ -39,7 +40,16 @@ class BaseAnswerResource:
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
"""Common request validation.
Continuation requests (``tool_actions`` present) require
``conversation_id`` but not ``question``.
"""
if data.get("tool_actions"):
# Continuation mode — question is not required
if missing := check_required_fields(data, ["conversation_id"]):
return missing
return None
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
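# For reference, the two request shapes this validation distinguishes
# (values hypothetical; the tool_actions keys mirror gen_continuation):
normal_request = {"question": "What's the weather in Paris?"}
continuation_request = {
    "conversation_id": "665f1c...",
    "tool_actions": [
        {"call_id": "call_1", "decision": "approved"},
        {"call_id": "call_2", "decision": "denied", "comment": "Not allowed"},
        {"call_id": "call_3", "result": {"temp_c": 18}},  # client-side result
    ],
}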
@@ -177,6 +187,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -207,8 +218,19 @@ class BaseAnswerResource:
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
for line in agent.gen(query=question):
if _continuation:
gen_iter = agent.gen_continuation(
messages=_continuation["messages"],
tools_dict=_continuation["tools_dict"],
pending_tool_calls=_continuation["pending_tool_calls"],
tool_actions=_continuation["tool_actions"],
)
else:
gen_iter = agent.gen(query=question)
for line in gen_iter:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
@@ -244,15 +266,21 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "error":
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -262,6 +290,93 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
conversation_id = (
self.conversation_service.save_conversation(
None,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
)
except Exception as e:
logger.error(
f"Failed to create conversation for continuation: {e}",
exc_info=True,
)
if conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
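# Taken together, a client consuming this stream at a pause point might
# observe a sequence like the following before the stream closes
# (conversation ID hypothetical):
#
#   data: {"type": "tool_calls_pending", "data": {"pending_tool_calls": [...]}}
#   data: {"type": "id", "id": "665f1c..."}
#   data: {"type": "end"}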
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
@@ -425,8 +540,13 @@ class BaseAnswerResource:
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
def process_response_stream(self, stream) -> Dict[str, Any]:
"""Process the stream response for non-streaming endpoint.
Returns:
Dict with keys: conversation_id, answer, sources, tool_calls,
thought, error, and optional extra.
"""
conversation_id = ""
response_full = ""
source_log_docs = []
@@ -435,6 +555,7 @@ class BaseAnswerResource:
stream_ended = False
is_structured = False
schema_info = None
pending_tool_calls = None
for line in stream:
try:
@@ -453,11 +574,22 @@ class BaseAnswerResource:
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "tool_calls_pending":
pending_tool_calls = event.get("data", {}).get(
"pending_tool_calls", []
)
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"], None
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": event["error"],
}
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -465,18 +597,30 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": "Stream ended unexpectedly",
}
result: Dict[str, Any] = {
"conversation_id": conversation_id,
"answer": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls,
"thought": thought,
"error": None,
}
if pending_tool_calls is not None:
result["extra"] = {"pending_tool_calls": pending_tool_calls}
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
result["extra"] = {"structured": True, "schema": schema_info}
return result
def error_stream_generate(self, err_response):

View File

@@ -79,7 +79,39 @@ class StreamResource(Resource, BaseAnswerResource):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
return Response(
self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
),
mimetype="text/event-stream",
)
# ---- Normal mode ----
agent = processor.build_agent(data["question"])
if not processor.decoded_token:
return Response(

View File

@@ -1,5 +1,6 @@
"""Message reconstruction utilities for compression."""
import json
import logging
import uuid
from typing import Dict, List, Optional
@@ -49,28 +50,35 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
@@ -180,28 +188,35 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
rebuilt_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
rebuilt_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:

View File

@@ -0,0 +1,141 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to MongoDB so the client can
resume later by sending tool_actions.
"""
import datetime
import logging
from typing import Any, Dict, List, Optional
from bson import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
def _make_serializable(obj: Any) -> Any:
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
if isinstance(obj, ObjectId):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_make_serializable(v) for v in obj]
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return obj
class ContinuationService:
"""Manages pending tool-call state in MongoDB."""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.collection = db["pending_tool_state"]
self._ensure_indexes()
def _ensure_indexes(self):
try:
self.collection.create_index(
"expires_at", expireAfterSeconds=0
)
self.collection.create_index(
[("conversation_id", 1), ("user", 1)], unique=True
)
except Exception:
# Indexes may already exist or mongomock doesn't support TTL
pass
def save_state(
self,
conversation_id: str,
user: str,
messages: List[Dict],
pending_tool_calls: List[Dict],
tools_dict: Dict,
tool_schemas: List[Dict],
agent_config: Dict,
client_tools: Optional[List[Dict]] = None,
) -> str:
"""Save execution state for later continuation.
Args:
conversation_id: The conversation this state belongs to.
user: Owner user ID.
messages: Full messages array at the pause point.
pending_tool_calls: Tool calls awaiting client action.
tools_dict: Serializable tools configuration dict.
tool_schemas: LLM-formatted tool schemas (agent.tools).
agent_config: Config needed to recreate the agent on resume.
client_tools: Client-provided tool schemas for client-side execution.
Returns:
The string ID of the saved state document.
"""
now = datetime.datetime.now(datetime.timezone.utc)
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
doc = {
"conversation_id": conversation_id,
"user": user,
"messages": _make_serializable(messages),
"pending_tool_calls": _make_serializable(pending_tool_calls),
"tools_dict": _make_serializable(tools_dict),
"tool_schemas": _make_serializable(tool_schemas),
"agent_config": _make_serializable(agent_config),
"client_tools": _make_serializable(client_tools) if client_tools else None,
"created_at": now,
"expires_at": expires_at,
}
# Upsert — only one pending state per conversation per user
result = self.collection.replace_one(
{"conversation_id": conversation_id, "user": user},
doc,
upsert=True,
)
state_id = str(result.upserted_id) if result.upserted_id else conversation_id
logger.info(
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
return state_id
def load_state(
self, conversation_id: str, user: str
) -> Optional[Dict[str, Any]]:
"""Load pending continuation state.
Returns:
The state dict, or None if no pending state exists.
"""
doc = self.collection.find_one(
{"conversation_id": conversation_id, "user": user}
)
if not doc:
return None
doc["_id"] = str(doc["_id"])
return doc
def delete_state(self, conversation_id: str, user: str) -> bool:
"""Delete pending state after successful resumption.
Returns:
True if a document was deleted.
"""
result = self.collection.delete_one(
{"conversation_id": conversation_id, "user": user}
)
if result.deleted_count:
logger.info(
f"Deleted continuation state for conversation {conversation_id}"
)
return result.deleted_count > 0
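# A minimal lifecycle sketch for this service (all values hypothetical):
svc = ContinuationService()
svc.save_state(
    conversation_id="665f1c...",
    user="local",
    messages=[],                      # full history at the pause point
    pending_tool_calls=[{"call_id": "call_1", "name": "get_weather", "arguments": {}}],
    tools_dict={},
    tool_schemas=[],
    agent_config={"model_id": None, "agent_type": "ClassicAgent"},
)
state = svc.load_state("665f1c...", "local")    # dict, or None if expired
svc.delete_state("665f1c...", "local")          # one-shot: prevents replay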

View File

@@ -771,6 +771,121 @@ class StreamProcessor:
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
return None
def resume_from_tool_actions(
self,
tool_actions: list,
conversation_id: str,
):
"""Resume a paused agent from saved continuation state.
Loads the pending state from MongoDB, recreates the agent with
the saved configuration, and returns an agent ready to call
``gen_continuation()``.
Args:
tool_actions: Client-provided actions (approvals / results).
conversation_id: The conversation being resumed.
Returns:
Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).
"""
from application.api.answer.services.continuation_service import (
ContinuationService,
)
from application.agents.agent_creator import AgentCreator
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
cont_service = ContinuationService()
state = cont_service.load_state(conversation_id, self.initial_user_id)
if not state:
raise ValueError("No pending tool state found for this conversation")
messages = state["messages"]
pending_tool_calls = state["pending_tool_calls"]
tools_dict = state["tools_dict"]
tool_schemas = state.get("tool_schemas", [])
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
agent_id = agent_config.get("agent_id")
prompt = agent_config.get("prompt", "")
json_schema = agent_config.get("json_schema")
retriever_config = agent_config.get("retriever_config")
# Recreate dependencies
system_api_key = api_key or get_api_key_for_provider(llm_name)
llm = LLMCreator.create_llm(
llm_name,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.initial_user_id,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = conversation_id
# Restore client tools so they stay available for subsequent LLM calls
saved_client_tools = state.get("client_tools")
if saved_client_tools:
tool_executor.client_tools = saved_client_tools
# Re-merge into tools_dict (they may have been stripped during serialization)
tool_executor.merge_client_tools(tools_dict, saved_client_tools)
agent_type = agent_config.get("agent_type", "ClassicAgent")
# Map class names back to agent creator keys
type_map = {
"ClassicAgent": "classic",
"AgenticAgent": "agentic",
"ResearchAgent": "research",
"WorkflowAgent": "workflow",
}
agent_key = type_map.get(agent_type, "classic")
agent_kwargs = {
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
"prompt": prompt,
"chat_history": [],
"decoded_token": self.decoded_token,
"json_schema": json_schema,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
if agent_key in ("agentic", "research") and retriever_config:
agent_kwargs["retriever_config"] = retriever_config
agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
agent.conversation_id = conversation_id
agent.initial_user_id = self.initial_user_id
agent.tools = tool_schemas
# Store config for the route layer
self.model_id = model_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
# Delete state so it can't be replayed
cont_service.delete_state(conversation_id, self.initial_user_id)
return agent, messages, tools_dict, pending_tool_calls, tool_actions
def create_agent(
self,
docs_together: Optional[str] = None,
@@ -841,6 +956,10 @@ class StreamProcessor:
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
# Pass client-side tools so they get merged in get_tools()
client_tools = self.data.get("client_tools")
if client_tools:
tool_executor.client_tools = client_tools
# Base agent kwargs
agent_kwargs = {

View File

@@ -0,0 +1,3 @@
from application.api.v1.routes import v1_bp
__all__ = ["v1_bp"]

View File

@@ -0,0 +1,314 @@
"""Standard chat completions API routes.
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
follow the widely-adopted chat completions protocol so external tools
(opencode, continue, etc.) can connect to DocsGPT agents.
"""
import json
import logging
import time
import traceback
from typing import Any, Dict, Generator, Optional
from flask import Blueprint, jsonify, make_response, request, Response
from application.api.answer.routes.base import BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
from application.api.v1.translator import (
translate_request,
translate_response,
translate_stream_event,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
def _extract_bearer_token() -> Optional[str]:
"""Extract API key from Authorization: Bearer header."""
auth = request.headers.get("Authorization", "")
if auth.startswith("Bearer "):
return auth[7:].strip()
return None
def _lookup_agent(api_key: str) -> Optional[Dict]:
"""Look up the agent document for this API key."""
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
return db["agents"].find_one({"key": api_key})
except Exception:
logger.warning("Failed to look up agent for API key", exc_info=True)
return None
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
"""Return agent name for display as model name."""
if agent:
return agent.get("name", api_key)
return api_key
class _V1AnswerHelper(BaseAnswerResource):
"""Thin wrapper to access complete_stream / process_response_stream."""
pass
@v1_bp.route("/chat/completions", methods=["POST"])
def chat_completions():
"""Handle POST /v1/chat/completions."""
api_key = _extract_bearer_token()
if not api_key:
return make_response(
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
401,
)
data = request.get_json()
if not data or not data.get("messages"):
return make_response(
jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
400,
)
is_stream = data.get("stream", False)
agent_doc = _lookup_agent(api_key)
model_name = _get_model_name(agent_doc, api_key)
try:
internal_data = translate_request(data, api_key)
except Exception as e:
logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
return make_response(
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
400,
)
# Link decoded_token to the agent's owner so continuation state,
# logs, and tool execution use the correct user identity.
agent_user = agent_doc.get("user") if agent_doc else None
decoded_token = {"sub": agent_user or "api_key_user"}
try:
processor = StreamProcessor(internal_data, decoded_token)
if internal_data.get("tool_actions"):
# Continuation mode
conversation_id = internal_data.get("conversation_id")
if not conversation_id:
return make_response(
jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
400,
)
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
internal_data["tool_actions"], conversation_id
)
continuation = {
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
}
question = ""
else:
# Normal mode
question = internal_data.get("question", "")
agent = processor.build_agent(question)
continuation = None
if not processor.decoded_token:
return make_response(
jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
401,
)
helper = _V1AnswerHelper()
usage_error = helper.check_usage(processor.agent_config)
if usage_error:
return usage_error
if is_stream:
return Response(
_stream_response(
helper, question, agent, processor, model_name, continuation
),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
else:
return _non_stream_response(
helper, question, agent, processor, model_name, continuation
)
except ValueError as e:
logger.error(
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
extra={"error": str(e)},
)
return make_response(
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
400,
)
except Exception as e:
logger.error(
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
extra={"error": str(e)},
)
return make_response(
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
500,
)
def _stream_response(
helper: _V1AnswerHelper,
question: str,
agent: Any,
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
) -> Generator[str, None, None]:
"""Generate translated SSE chunks for streaming response."""
completion_id = f"chatcmpl-{int(time.time())}"
internal_stream = helper.complete_stream(
question=question,
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation=continuation,
)
for line in internal_stream:
if not line.strip():
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")
if conv_id:
completion_id = f"chatcmpl-{conv_id}"
# Translate to standard format
translated = translate_stream_event(event_data, completion_id, model_name)
for chunk in translated:
yield chunk
def _non_stream_response(
helper: _V1AnswerHelper,
question: str,
agent: Any,
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
) -> Response:
"""Collect full response and return as single JSON."""
stream = helper.complete_stream(
question=question,
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation=continuation,
)
result = helper.process_response_stream(stream)
if result["error"]:
return make_response(
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
500,
)
extra = result.get("extra")
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
response = translate_response(
conversation_id=result["conversation_id"],
answer=result["answer"] or "",
sources=result["sources"],
tool_calls=result["tool_calls"],
thought=result["thought"] or "",
model_name=model_name,
pending_tool_calls=pending,
)
return make_response(jsonify(response), 200)
@v1_bp.route("/models", methods=["GET"])
def list_models():
"""Handle GET /v1/models — return agents as models."""
api_key = _extract_bearer_token()
if not api_key:
return make_response(
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
401,
)
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
# Find the agent for this api_key
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
401,
)
user = agent.get("user")
# Return all agents belonging to this user
user_agents = list(agents_collection.find({"user": user}))
models = []
for ag in user_agents:
created = ag.get("createdAt")
created_ts = int(created.timestamp()) if created else int(time.time())
models.append({
"id": str(ag.get("key", "")),
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",
"name": ag.get("name", ""),
"description": ag.get("description", ""),
})
return make_response(
jsonify({"object": "list", "data": models}),
200,
)
except Exception as e:
logger.error(f"/v1/models error: {e}", exc_info=True)
return make_response(
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
500,
)
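# A hedged usage sketch: calling the endpoint from Python, assuming a
# DocsGPT backend on localhost:7091 and an agent API key (both hypothetical):
import requests

resp = requests.post(
    "http://localhost:7091/v1/chat/completions",
    headers={"Authorization": "Bearer <agent-api-key>"},
    json={
        "messages": [{"role": "user", "content": "Summarize the docs"}],
        "stream": False,
    },
)
print(resp.json()["choices"][0]["message"]["content"])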

View File

@@ -0,0 +1,415 @@
"""Translate between standard chat completions format and DocsGPT internals.
This module handles:
- Request translation (chat completions -> DocsGPT internal format)
- Response translation (DocsGPT response -> chat completions format)
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
"""
import json
import time
from typing import Any, Dict, List, Optional
def _get_client_tool_name(tc: Dict) -> str:
"""Return the original tool name for client-facing responses.
For client-side tools the ``tool_name`` field carries the name the
client originally registered. Fall back to ``action_name`` (which
is now the clean LLM-visible name) or ``name``.
"""
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
# ---------------------------------------------------------------------------
# Request translation
# ---------------------------------------------------------------------------
def is_continuation(messages: List[Dict]) -> bool:
"""Check if messages represent a tool-call continuation.
A continuation is detected when the last message(s) have ``role: "tool"``
immediately after an assistant message with ``tool_calls``.
"""
if not messages:
return False
# Walk backwards past the trailing tool messages; if the message
# immediately before them is an assistant message with tool_calls,
# this is a continuation.
i = len(messages) - 1
while i >= 0 and messages[i].get("role") == "tool":
i -= 1
if i < 0:
return False
return (
messages[i].get("role") == "assistant"
and bool(messages[i].get("tool_calls"))
)
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
"""Extract tool results from trailing tool messages for continuation.
Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``.
"""
results = []
for msg in reversed(messages):
if msg.get("role") != "tool":
break
call_id = msg.get("tool_call_id", "")
content = msg.get("content", "")
if isinstance(content, str):
try:
content = json.loads(content)
except (json.JSONDecodeError, TypeError):
pass
results.append({"call_id": call_id, "result": content})
results.reverse()
return results
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
"""Try to extract conversation_id from the assistant message before tool results.
The conversation_id may be stored in a custom field on the assistant message
from a previous response cycle.
"""
for msg in reversed(messages):
if msg.get("role") == "assistant":
# Check docsgpt extension
return msg.get("docsgpt", {}).get("conversation_id")
return None
def convert_history(messages: List[Dict]) -> List[Dict]:
"""Convert chat completions messages array to DocsGPT history format.
DocsGPT history is a list of ``{prompt, response}`` dicts.
Excludes the last user message (that becomes the ``question``).
"""
history = []
i = 0
while i < len(messages):
msg = messages[i]
if msg.get("role") == "system":
i += 1
continue
if msg.get("role") == "user":
# Look ahead for assistant response
if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
content = messages[i + 1].get("content") or ""
history.append({
"prompt": msg.get("content", ""),
"response": content,
})
i += 2
continue
# Last user message without response — skip (it's the question)
i += 1
continue
i += 1
return history
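# A small worked example (messages hypothetical):
msgs = [
    {"role": "system", "content": "You are helpful."},   # skipped
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is DocsGPT?"},     # becomes the question
]
# convert_history(msgs) -> [{"prompt": "Hi", "response": "Hello!"}]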
def translate_request(
data: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
"""Translate a chat completions request to DocsGPT internal format.
Args:
data: The incoming request body.
api_key: Agent API key from the Authorization header.
Returns:
Dict suitable for passing to ``StreamProcessor``.
"""
messages = data.get("messages", [])
# Check for continuation (tool results after assistant tool_calls)
if is_continuation(messages):
tool_actions = extract_tool_results(messages)
conversation_id = extract_conversation_id(messages)
if not conversation_id:
conversation_id = data.get("conversation_id")
result = {
"conversation_id": conversation_id,
"tool_actions": tool_actions,
"api_key": api_key,
}
# Carry tools forward for next iteration
if data.get("tools"):
result["client_tools"] = data["tools"]
return result
# Normal request — extract question from last user message
question = ""
for msg in reversed(messages):
if msg.get("role") == "user":
question = msg.get("content", "")
break
history = convert_history(messages)
result = {
"question": question,
"api_key": api_key,
"history": json.dumps(history),
"save_conversation": True,
}
# Client tools
if data.get("tools"):
result["client_tools"] = data["tools"]
# DocsGPT extensions
docsgpt = data.get("docsgpt", {})
if docsgpt.get("attachments"):
result["attachments"] = docsgpt["attachments"]
return result
# ---------------------------------------------------------------------------
# Response translation (non-streaming)
# ---------------------------------------------------------------------------
def translate_response(
conversation_id: str,
answer: str,
sources: Optional[List[Dict]],
tool_calls: Optional[List[Dict]],
thought: str,
model_name: str,
pending_tool_calls: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
"""Translate DocsGPT response to chat completions format.
Args:
conversation_id: The DocsGPT conversation ID.
answer: The assistant's text response.
sources: RAG retrieval sources.
tool_calls: Completed tool call results.
thought: Reasoning/thinking tokens.
model_name: Model/agent identifier.
pending_tool_calls: Pending client-side tool calls (if paused).
Returns:
Dict in the standard chat completions response format.
"""
created = int(time.time())
completion_id = f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}"
# Build message
message: Dict[str, Any] = {"role": "assistant"}
if pending_tool_calls:
# Tool calls pending — return them for client execution
message["content"] = None
message["tool_calls"] = [
{
"id": tc.get("call_id", ""),
"type": "function",
"function": {
"name": _get_client_tool_name(tc),
"arguments": (
json.dumps(tc["arguments"])
if isinstance(tc.get("arguments"), dict)
else tc.get("arguments", "{}")
),
},
}
for tc in pending_tool_calls
]
finish_reason = "tool_calls"
else:
message["content"] = answer
if thought:
message["reasoning_content"] = thought
finish_reason = "stop"
result: Dict[str, Any] = {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": model_name,
"choices": [
{
"index": 0,
"message": message,
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
# DocsGPT extensions
docsgpt: Dict[str, Any] = {}
if conversation_id:
docsgpt["conversation_id"] = conversation_id
if sources:
docsgpt["sources"] = sources
if tool_calls:
docsgpt["tool_calls"] = tool_calls
if docsgpt:
result["docsgpt"] = docsgpt
return result
# ---------------------------------------------------------------------------
# Streaming event translation
# ---------------------------------------------------------------------------
def _make_chunk(
completion_id: str,
model_name: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None,
) -> str:
"""Build a single SSE chunk in the standard streaming format."""
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(chunk)}\n\n"
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
"""Build a DocsGPT extension SSE chunk."""
return f"data: {json.dumps({'docsgpt': data})}\n\n"
def translate_stream_event(
event_data: Dict[str, Any],
completion_id: str,
model_name: str,
) -> List[str]:
"""Translate a DocsGPT SSE event dict to standard streaming chunks.
May return zero, one, or two chunks per input event. For example, a
completed tool call produces a single docsgpt extension chunk and
nothing on the standard side, since server-side tool calls aren't
surfaced in the standard format.
Args:
event_data: Parsed DocsGPT event dict.
completion_id: The completion ID for this response.
model_name: Model/agent identifier.
Returns:
List of SSE-formatted strings to send to the client.
"""
event_type = event_data.get("type")
chunks: List[str] = []
if event_type == "answer":
chunks.append(
_make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")})
)
elif event_type == "thought":
chunks.append(
_make_chunk(
completion_id, model_name,
{"reasoning_content": event_data.get("thought", "")},
)
)
elif event_type == "source":
chunks.append(
_make_docsgpt_chunk({
"type": "source",
"sources": event_data.get("source", []),
})
)
elif event_type == "tool_call":
tc_data = event_data.get("data", {})
status = tc_data.get("status")
if status == "requires_client_execution":
# Standard: stream as tool_calls delta
args = tc_data.get("arguments", {})
args_str = json.dumps(args) if isinstance(args, dict) else str(args)
chunks.append(
_make_chunk(completion_id, model_name, {
"tool_calls": [{
"index": 0,
"id": tc_data.get("call_id", ""),
"type": "function",
"function": {
"name": _get_client_tool_name(tc_data),
"arguments": args_str,
},
}],
})
)
elif status == "awaiting_approval":
# Extension: approval needed
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
elif status in ("completed", "pending", "error", "denied", "skipped"):
# Extension: tool call progress
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
elif event_type == "tool_calls_pending":
# Standard: finish_reason = tool_calls
chunks.append(
_make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
)
# Also emit as docsgpt extension
chunks.append(
_make_docsgpt_chunk({
"type": "tool_calls_pending",
"pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []),
})
)
elif event_type == "end":
chunks.append(
_make_chunk(completion_id, model_name, {}, finish_reason="stop")
)
chunks.append("data: [DONE]\n\n")
elif event_type == "id":
chunks.append(
_make_docsgpt_chunk({
"type": "id",
"conversation_id": event_data.get("id", ""),
})
)
elif event_type == "error":
# Emit an error payload (not part of the standard schema, but widely supported by clients)
error_data = {
"error": {
"message": event_data.get("error", "An error occurred"),
"type": "server_error",
}
}
chunks.append(f"data: {json.dumps(error_data)}\n\n")
elif event_type == "structured_answer":
chunks.append(
_make_chunk(
completion_id, model_name,
{"content": event_data.get("answer", "")},
)
)
# Skip: tool_calls (redundant), research_plan, research_progress
return chunks
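# Usage sketch (completion ID and model name are hypothetical); one input
# event can fan out to several SSE strings:
#
#     translate_stream_event({"type": "answer", "answer": "Hi"},
#                            "chatcmpl-1", "docsgpt-agent")
#     -> ['data: {"id": "chatcmpl-1", "object": "chat.completion.chunk",
#          ..., "choices": [{"index": 0, "delta": {"content": "Hi"},
#          "finish_reason": null}]}\n\n']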

View File

@@ -17,6 +17,7 @@ from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.api.v1 import v1_bp # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
@@ -36,6 +37,7 @@ app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,

View File

@@ -167,6 +167,8 @@ class GoogleLLM(BaseLLM):
return "\n".join(parts)
return ""
import json as _json
for message in messages:
role = message.get("role")
content = message.get("content")
@@ -180,9 +182,66 @@ class GoogleLLM(BaseLLM):
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
# Standard format: assistant message with tool_calls array
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls and role == "model":
for tc in msg_tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, str):
try:
args = _json.loads(args)
except (_json.JSONDecodeError, TypeError):
args = {}
cleaned_args = self._remove_null_values(args)
thought_sig = tc.get("thought_signature")
if thought_sig:
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=func.get("name", ""),
args=cleaned_args,
),
thoughtSignature=thought_sig,
)
)
else:
parts.append(
types.Part.from_function_call(
name=func.get("name", ""),
args=cleaned_args,
)
)
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
continue
# Standard format: tool message with tool_call_id
tool_call_id = message.get("tool_call_id")
if role == "tool" and tool_call_id is not None:
result_content = content
if isinstance(result_content, str):
try:
result_content = _json.loads(result_content)
except (_json.JSONDecodeError, TypeError):
pass
# Google expects a name on the function_response, but the standard tool
# message only carries tool_call_id, so we use a placeholder name; the
# Google API does not require an exact match.
parts.append(
types.Part.from_function_response(
name="tool_result",
response={"result": result_content},
)
)
cleaned_messages.append(types.Content(role="model", parts=parts))
continue
if role == "tool":
role = "model"
if role and content is not None:
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
@@ -191,15 +250,11 @@ class GoogleLLM(BaseLLM):
if "text" in item:
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
# Remove null values from args to avoid API errors
# Legacy format support
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
@@ -210,7 +265,6 @@ class GoogleLLM(BaseLLM):
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],

View File

@@ -1,3 +1,4 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
@@ -315,10 +316,34 @@ class LLMHandler(ABC):
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
# Standard format: tool_calls array on assistant message
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls:
for tc in msg_tool_calls:
call_id = tc.get("id") or str(uuid.uuid4())
func = tc.get("function", {})
args = func.get("arguments")
if isinstance(args, str):
try:
args = json.loads(args)
except (json.JSONDecodeError, TypeError):
pass
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": func.get("name"),
"arguments": args,
"result": None,
"status": "called",
"call_id": call_id,
}
continue
# Legacy format: function_call/function_response in content list
if isinstance(content, list):
has_fc = False
for item in content:
if "function_call" in item:
has_fc = True
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
@@ -329,37 +354,30 @@ class LLMHandler(ABC):
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
if has_fc:
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
# Standard format: tool_call_id on tool message
call_id = message.get("tool_call_id")
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
# Legacy: function_response in content list
elif isinstance(content, list):
for item in content:
if "function_response" in item:
legacy_id = item["function_response"].get("call_id")
if legacy_id and legacy_id in current_tool_calls:
current_tool_calls[legacy_id]["result"] = tool_text
current_tool_calls[legacy_id]["status"] = "completed"
break
elif call_id is None and queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
@@ -648,6 +666,13 @@ class LLMHandler(ABC):
"""
Execute tool calls and update conversation history.
When a tool requires approval or client-side execution, it is
collected as a pending action instead of being executed. The
generator returns ``(updated_messages, pending_actions)`` where
*pending_actions* is ``None`` when every tool was executed
normally, or a list of dicts describing actions the client must
resolve before the LLM loop can continue.
Args:
agent: The agent instance
tool_calls: List of tool calls to execute
@@ -655,9 +680,11 @@ class LLMHandler(ABC):
messages: Current conversation history
Returns:
Updated messages list
Tuple of (updated_messages, pending_actions).
pending_actions is None if all tools executed, otherwise a list.
"""
updated_messages = messages.copy()
pending_actions: List[Dict] = []
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
@@ -763,6 +790,29 @@ class LLMHandler(ABC):
# Set flag on agent
agent.context_limit_reached = True
break
# ---- Pause check: approval / client-side execution ----
llm_class = agent.llm.__class__.__name__
pause_info = agent.tool_executor.check_pause(
tools_dict, call, llm_class
)
if pause_info:
# Yield pause event so the client knows this tool is waiting
yield {
"type": "tool_call",
"data": {
"tool_name": pause_info["tool_name"],
"call_id": pause_info["call_id"],
"action_name": pause_info.get("llm_name", pause_info["name"]),
"arguments": pause_info["arguments"],
"status": pause_info["pause_type"],
},
}
pending_actions.append(pause_info)
# Do NOT add messages for pending tools here.
# They will be added on resume to keep call/result pairs together.
continue
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -772,25 +822,30 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [function_call_content],
}
)
# Standard internal format: assistant message with tool_calls array
args_str = (
json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
tool_call_obj = {
"id": call_id,
"type": "function",
"function": {
"name": call.name,
"arguments": args_str,
},
}
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
tool_call_obj["thought_signature"] = call.thought_signature
updated_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tool_call_obj],
})
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
@@ -802,16 +857,15 @@ class LLMHandler(ABC):
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
mapping = agent.tool_executor._name_to_tool
if call.name in mapping:
resolved_tool_id, _ = mapping[call.name]
tool_name = tools_dict.get(resolved_tool_id, {}).get(
"name", "unknown_tool"
)
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
@@ -823,7 +877,7 @@ class LLMHandler(ABC):
"status": "error",
},
}
return updated_messages
return updated_messages, pending_actions if pending_actions else None
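# For reference, each entry in pending_actions is the pause_info dict from
# ToolExecutor.check_pause. Based on the fields read above and in the tests,
# it carries at least (a sketch; the exact keys come from check_pause):
#     {"tool_name": ..., "call_id": ..., "name": ..., "llm_name": ...,
#      "action_name": ..., "tool_id": ..., "arguments": ...,
#      "pause_type": "awaiting_approval" | "requires_client_execution"}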
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
@@ -851,8 +905,22 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
messages, pending_actions = e.value
break
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return ""
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
)
@@ -913,10 +981,23 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
messages, pending_actions = e.value
break
tool_calls = {}
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit

View File

@@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler):
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
"""Create a tool result message in the standard internal format."""
import json as _json
content = (
_json.dumps(result)
if not isinstance(result, str)
else result
)
return {
"role": "model",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
],
"role": "tool",
"tool_call_id": tool_call.id,
"content": content,
}
def _iterate_stream(self, response: Any) -> Generator:

View File

@@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler):
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create OpenAI-style tool message."""
"""Create a tool result message in the standard internal format."""
import json as _json
content = (
_json.dumps(result)
if not isinstance(result, str)
else result
)
return {
"role": "tool",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
"call_id": tool_call.id,
}
}
],
"tool_call_id": tool_call.id,
"content": content,
}
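# Both handlers now emit the same shape; for a hypothetical call:
#     create_tool_message(
#         ToolCall(id="call_1", name="get_weather", arguments="{}"),
#         {"temp": 21},
#     )
#     -> {"role": "tool", "tool_call_id": "call_1", "content": '{"temp": 21}'}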
def _iterate_stream(self, response: Any) -> Generator:

View File

@@ -91,16 +91,52 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
# Standard format: assistant message with tool_calls (passthrough)
tool_calls = message.get("tool_calls")
if tool_calls and role == "assistant":
cleaned_tcs = []
for tc in tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, dict):
args = json.dumps(self._remove_null_values(args))
elif isinstance(args, str):
try:
parsed = json.loads(args)
args = json.dumps(self._remove_null_values(parsed))
except (json.JSONDecodeError, TypeError):
pass
cleaned_tcs.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": func.get("name", ""), "arguments": args},
})
cleaned_messages.append({
"role": "assistant",
"content": None,
"tool_calls": cleaned_tcs,
})
continue
# Standard format: tool message with tool_call_id (passthrough)
tool_call_id = message.get("tool_call_id")
if role == "tool" and tool_call_id is not None:
cleaned_messages.append({
"role": "tool",
"tool_call_id": tool_call_id,
"content": content if isinstance(content, str) else json.dumps(content),
})
continue
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
# Collect all content parts into a single message
content_parts = []
for item in content:
# Legacy format support: function_call / function_response
if "function_call" in item:
# Function calls need their own message
args = item["function_call"]["args"]
if isinstance(args, str):
try:
@@ -116,28 +152,20 @@ class OpenAILLM(BaseLLM):
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
cleaned_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
})
elif "function_response" in item:
# Function responses need their own message
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
cleaned_messages.append({
"role": "tool",
"tool_call_id": item["function_response"]["call_id"],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
})
elif isinstance(item, dict):
# Collect content parts (text, images, files) into a single message
if "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
@@ -145,10 +173,7 @@ class OpenAILLM(BaseLLM):
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
content_parts.append(item)
elif "text" in item and "type" not in item:
# Legacy format: {"text": "..."} without type
content_parts.append({"type": "text", "text": item["text"]})
# Add the collected content parts as a single message
if content_parts:
cleaned_messages.append({"role": role, "content": content_parts})
else:

View File

@@ -77,6 +77,10 @@ const endpoints = {
WORKFLOWS: '/api/workflows',
WORKFLOW: (id: string) => `/api/workflows/${id}`,
},
V1: {
CHAT_COMPLETIONS: '/v1/chat/completions',
MODELS: '/v1/models',
},
CONVERSATION: {
ANSWER: '/api/answer',
ANSWER_STREAMING: '/stream',

View File

@@ -54,6 +54,18 @@ const conversationService = {
apiClient.get(endpoints.CONVERSATION.DELETE_ALL, token, {}),
update: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.CONVERSATION.UPDATE, data, token, {}),
chatCompletions: (
data: any,
agentApiKey: string,
signal: AbortSignal,
): Promise<any> =>
apiClient.post(
endpoints.V1.CHAT_COMPLETIONS,
data,
null,
{ Authorization: `Bearer ${agentApiKey}` },
signal,
),
};
export default conversationService;
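// Usage sketch (the agent API key is hypothetical):
//   conversationService.chatCompletions(
//     { messages: [{ role: 'user', content: 'Hi' }], stream: true },
//     '<agent-api-key>',
//     new AbortController().signal,
//   );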

View File

@@ -22,6 +22,7 @@ import {
resendQuery,
selectQueries,
selectStatus,
submitToolActions,
updateQuery,
} from './conversationSlice';
import { selectCompletedAttachments } from '../upload/uploadSlice';
@@ -41,6 +42,17 @@ export default function Conversation() {
const [lastQueryReturnedErr, setLastQueryReturnedErr] =
useState<boolean>(false);
const handleToolAction = useCallback(
(callId: string, decision: 'approved' | 'denied', comment?: string) => {
dispatch(
submitToolActions({
toolActions: [{ call_id: callId, decision, comment }],
}),
);
},
[dispatch],
);
const lastAutoOpenedArtifactId = useRef<string | null>(null);
const didInitArtifactAutoOpen = useRef(false);
const prevConversationId = useRef<string | null>(conversationId);
@@ -233,6 +245,7 @@ export default function Conversation() {
status={status}
showHeroOnEmpty={selectedAgent ? false : true}
onOpenArtifact={handleOpenArtifact}
onToolAction={handleToolAction}
isSplitView={isSplitArtifactOpen}
headerContent={
selectedAgent ? (

View File

@@ -65,6 +65,11 @@ const ConversationBubble = forwardRef<
) => void;
filesAttached?: { id: string; fileName: string }[];
onOpenArtifact?: (artifact: { id: string; toolName: string }) => void;
onToolAction?: (
callId: string,
decision: 'approved' | 'denied',
comment?: string,
) => void;
}
>(function ConversationBubble(
{
@@ -83,6 +88,7 @@ const ConversationBubble = forwardRef<
handleUpdatedQuestionSubmission,
filesAttached,
onOpenArtifact,
onToolAction,
},
ref,
) {
@@ -411,7 +417,7 @@ const ConversationBubble = forwardRef<
)}
{research && <ResearchProgress research={research} />}
{toolCalls && toolCalls.length > 0 && (
<ToolCalls toolCalls={toolCalls} />
<ToolCalls toolCalls={toolCalls} onToolAction={onToolAction} />
)}
{!message && primaryArtifactCall?.artifact_id && onOpenArtifact && (
<div className="my-2 ml-2 flex justify-start">
@@ -884,108 +890,263 @@ function AllSources(sources: AllSourcesProps) {
}
export default ConversationBubble;
function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
function ToolCallApprovalBar({
toolCall,
onToolAction,
}: {
toolCall: ToolCallsType;
onToolAction?: (
callId: string,
decision: 'approved' | 'denied',
comment?: string,
) => void;
}) {
const [expanded, setExpanded] = useState(false);
const [comment, setComment] = useState('');
const actionLabel = toolCall.action_name.substring(
0,
toolCall.action_name.lastIndexOf('_'),
);
const argPreview = JSON.stringify(toolCall.arguments);
const truncated =
argPreview.length > 60 ? argPreview.slice(0, 57) + '...' : argPreview;
return (
<div className="border-border bg-muted dark:bg-card mb-2 w-full overflow-hidden rounded-2xl border">
<div className="flex items-center gap-3 px-4 py-2.5">
<div className="flex min-w-0 flex-1 items-center gap-2">
<span className="text-sm font-semibold whitespace-nowrap">
{toolCall.tool_name}
</span>
<span className="text-muted-foreground text-xs">{actionLabel}</span>
<span
className="text-muted-foreground hidden min-w-0 truncate font-mono text-xs md:block"
title={argPreview}
>
{truncated}
</span>
</div>
<div className="flex items-center gap-2">
<button
className={`rounded-full px-4 py-1 text-xs font-medium transition-colors ${
comment
? 'bg-muted text-muted-foreground cursor-default opacity-50'
: 'bg-primary hover:bg-primary/90 text-white'
}`}
onClick={() => {
if (!comment) onToolAction?.(toolCall.call_id, 'approved');
}}
>
Approve
</button>
<button
className={`rounded-full border px-4 py-1 text-xs font-medium transition-colors ${
comment
? 'border-destructive bg-destructive/10 text-destructive font-semibold'
: 'hover:bg-accent text-muted-foreground'
}`}
onClick={() => {
if (expanded && comment) {
onToolAction?.(toolCall.call_id, 'denied', comment);
} else if (expanded) {
onToolAction?.(toolCall.call_id, 'denied');
} else {
setExpanded(true);
}
}}
>
Deny
</button>
<button
className="text-muted-foreground hover:text-foreground flex h-6 w-6 items-center justify-center rounded-full transition-colors"
onClick={() => setExpanded(!expanded)}
title="Details"
>
<img
src={ChevronDown}
alt="expand"
className={`h-3.5 w-3.5 transition-transform duration-200 dark:invert ${expanded ? 'rotate-180' : ''}`}
/>
</button>
</div>
</div>
{expanded && (
<div className="border-border border-t px-4 py-3">
<p className="text-muted-foreground mb-1 text-xs font-medium">
Arguments
</p>
<pre className="bg-background dark:bg-background/50 mb-2 max-h-40 overflow-auto rounded-lg p-2 font-mono text-xs">
{JSON.stringify(toolCall.arguments, null, 2)}
</pre>
<input
type="text"
placeholder="Optional reason for denying..."
className="border-border bg-background w-full rounded-lg border px-3 py-1.5 text-sm"
value={comment}
onChange={(e) => setComment(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter' && comment) {
onToolAction?.(toolCall.call_id, 'denied', comment);
}
}}
/>
</div>
)}
</div>
);
}
function ToolCalls({
toolCalls,
onToolAction,
}: {
toolCalls: ToolCallsType[];
onToolAction?: (
callId: string,
decision: 'approved' | 'denied',
comment?: string,
) => void;
}) {
const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
const awaitingCalls = toolCalls.filter(
(tc) => tc.status === 'awaiting_approval',
);
const resolvedCalls = toolCalls.filter(
(tc) => tc.status !== 'awaiting_approval',
);
return (
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
{/* Approval bars — always visible, compact inline */}
{awaitingCalls.length > 0 && (
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
{awaitingCalls.map((tc) => (
<ToolCallApprovalBar
key={`approval-${tc.call_id}`}
toolCall={tc}
onToolAction={onToolAction}
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">Tool Calls</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
/>
</button>
</div>
{isToolCallsOpen && (
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
<div className="grid grid-cols-1 gap-2">
{toolCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
titleClassName="px-6 py-2 text-sm font-semibold"
>
<div className="flex flex-col gap-1">
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Arguments
</span>{' '}
<CopyButton
textToCopy={JSON.stringify(toolCall.arguments, null, 2)}
/>
</p>
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
</span>
</p>
</div>
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Response
</span>{' '}
<CopyButton
textToCopy={
toolCall.status === 'error'
? toolCall.error || 'Unknown error'
: JSON.stringify(toolCall.result, null, 2)
}
/>
</p>
{toolCall.status === 'pending' && (
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
<Spinner size="small" />
</span>
)}
{toolCall.status === 'completed' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
</span>
</p>
)}
{toolCall.status === 'error' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="leading-[23px] text-red-500 dark:text-red-400"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
</span>
</p>
)}
</div>
</div>
</Accordion>
))}
</div>
))}
</div>
)}
{/* Regular tool calls accordion */}
{resolvedCalls.length > 0 && (
<>
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
alt={'ToolCalls'}
className="h-full w-full object-fill"
/>
}
/>
<button
className="flex flex-row items-center gap-2"
onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}
>
<p className="text-base font-semibold">Tool Calls</p>
<img
src={ChevronDown}
alt="ChevronDown"
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${isToolCallsOpen ? 'rotate-180' : ''}`}
/>
</button>
</div>
{isToolCallsOpen && (
<div className="fade-in mr-5 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
<div className="grid grid-cols-1 gap-2">
{resolvedCalls.map((toolCall, index) => (
<Accordion
key={`tool-call-${index}`}
title={`${toolCall.tool_name} - ${toolCall.action_name.substring(0, toolCall.action_name.lastIndexOf('_'))}`}
className="bg-muted dark:bg-answer-bubble w-full rounded-4xl"
titleClassName="px-6 py-2 text-sm font-semibold"
>
<div className="flex flex-col gap-1">
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Arguments
</span>{' '}
<CopyButton
textToCopy={JSON.stringify(
toolCall.arguments,
null,
2,
)}
/>
</p>
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
</span>
</p>
</div>
<div className="border-border flex flex-col rounded-2xl border">
<p className="dark:bg-background flex flex-row items-center justify-between rounded-t-2xl bg-black/10 px-2 py-1 text-sm font-semibold wrap-break-word">
<span style={{ fontFamily: 'IBMPlexMono-Medium' }}>
Response
</span>{' '}
<CopyButton
textToCopy={
toolCall.status === 'error'
? toolCall.error || 'Unknown error'
: JSON.stringify(toolCall.result, null, 2)
}
/>
</p>
{toolCall.status === 'pending' && (
<span className="dark:bg-card flex w-full items-center justify-center rounded-b-2xl p-2">
<Spinner size="small" />
</span>
)}
{toolCall.status === 'completed' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
</span>
</p>
)}
{toolCall.status === 'error' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-destructive leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
</span>
</p>
)}
{toolCall.status === 'denied' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-muted-foreground leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
Denied by user
</span>
</p>
)}
</div>
</div>
</Accordion>
))}
</div>
</div>
)}
</>
)}
</div>
);
}

View File

@@ -38,6 +38,11 @@ type ConversationMessagesProps = {
showHeroOnEmpty?: boolean;
headerContent?: ReactNode;
onOpenArtifact?: (artifact: { id: string; toolName: string }) => void;
onToolAction?: (
callId: string,
decision: 'approved' | 'denied',
comment?: string,
) => void;
isSplitView?: boolean;
};
@@ -50,6 +55,7 @@ export default function ConversationMessages({
showHeroOnEmpty = true,
headerContent,
onOpenArtifact,
onToolAction,
isSplitView = false,
}: ConversationMessagesProps) {
const [isDarkTheme] = useDarkTheme();
@@ -154,6 +160,7 @@ export default function ConversationMessages({
toolCalls={query.tool_calls}
research={query.research}
onOpenArtifact={onOpenArtifact}
onToolAction={onToolAction}
feedback={query.feedback}
isStreaming={isCurrentlyStreaming}
handleFeedback={

View File

@@ -188,6 +188,264 @@ export function handleFetchAnswerSteaming(
});
}
export function handleSubmitToolActions(
conversationId: string,
toolActions: {
call_id: string;
decision?: 'approved' | 'denied';
comment?: string;
result?: Record<string, any>;
}[],
token: string | null,
signal: AbortSignal,
onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
const payload = {
conversation_id: conversationId,
tool_actions: toolActions,
};
return new Promise<Answer>((resolve, reject) => {
conversationService
.answerStream(payload, token, signal)
.then((response) => {
if (!response.body) throw Error('No response body');
let buffer = '';
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
const processStream = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>) => {
if (done) return;
const chunk = decoder.decode(value);
buffer += chunk;
const events = buffer.split('\n\n');
buffer = events.pop() ?? '';
for (const event of events) {
if (event.trim().startsWith('data:')) {
const dataLine: string = event
.split('\n')
.map((line: string) => line.replace(/^data:\s?/, ''))
.join('');
const messageEvent = new MessageEvent('message', {
data: dataLine.trim(),
});
onEvent(messageEvent);
}
}
reader.read().then(processStream).catch(reject);
};
reader.read().then(processStream).catch(reject);
})
.catch((error) => {
console.error('Tool actions submission failed:', error);
reject(error);
});
});
}
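// Usage sketch resolving two pending calls (conversation and call IDs are
// hypothetical):
//   handleSubmitToolActions(
//     'conv123',
//     [
//       { call_id: 'call_1', decision: 'approved' },
//       { call_id: 'call_2', decision: 'denied', comment: 'Not allowed' },
//     ],
//     token,
//     new AbortController().signal,
//     (event) => console.log(event.data),
//   );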
/**
* Stream a chat completion via the /v1/chat/completions endpoint.
*
* Translates the standard streaming format (choices[0].delta) back into
* the internal DocsGPT event shape so the existing Redux reducers can
* consume the events without any changes.
*/
export function handleV1ChatCompletionStreaming(
question: string,
signal: AbortSignal,
agentApiKey: string,
history: { prompt: string; response: string }[],
onEvent: (event: MessageEvent) => void,
tools?: any[],
attachments?: string[],
): Promise<Answer> {
// Build messages array from history + current question
const messages: any[] = [];
for (const h of history) {
messages.push({ role: 'user', content: h.prompt });
messages.push({ role: 'assistant', content: h.response });
}
messages.push({ role: 'user', content: question });
const payload: any = {
messages,
stream: true,
};
if (tools && tools.length > 0) {
payload.tools = tools;
}
if (attachments && attachments.length > 0) {
payload.docsgpt = { attachments };
}
return new Promise<Answer>((resolve, reject) => {
conversationService
.chatCompletions(payload, agentApiKey, signal)
.then((response) => {
if (!response.body) throw Error('No response body');
let buffer = '';
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
const processStream = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>) => {
if (done) return;
const chunk = decoder.decode(value);
buffer += chunk;
const events = buffer.split('\n\n');
buffer = events.pop() ?? '';
for (const event of events) {
if (!event.trim().startsWith('data:')) continue;
const dataLine = event
.split('\n')
.map((line: string) => line.replace(/^data:\s?/, ''))
.join('');
const trimmed = dataLine.trim();
// Handle [DONE] sentinel
if (trimmed === '[DONE]') {
onEvent(
new MessageEvent('message', {
data: JSON.stringify({ type: 'end' }),
}),
);
continue;
}
try {
const parsed = JSON.parse(trimmed);
// Translate standard format to DocsGPT internal events
const translated = translateV1ChunkToInternalEvents(parsed);
for (const evt of translated) {
onEvent(
new MessageEvent('message', {
data: JSON.stringify(evt),
}),
);
}
} catch {
// Skip unparseable chunks
}
}
reader.read().then(processStream).catch(reject);
};
reader.read().then(processStream).catch(reject);
})
.catch((error) => {
console.error('V1 chat completion stream failed:', error);
reject(error);
});
});
}
/**
* Translate a single v1 streaming chunk to internal DocsGPT event(s).
*
* Standard format: {"choices": [{"delta": {"content": "chunk"}, ...}]}
* Extension format: {"docsgpt": {"type": "source", ...}}
*/
function translateV1ChunkToInternalEvents(
chunk: any,
): { type: string; [key: string]: any }[] {
const events: { type: string; [key: string]: any }[] = [];
// DocsGPT extension chunks
if (chunk.docsgpt) {
const ext = chunk.docsgpt;
if (ext.type === 'source') {
events.push({ type: 'source', source: ext.sources });
} else if (ext.type === 'tool_call') {
events.push({ type: 'tool_call', data: ext.data });
} else if (ext.type === 'tool_calls_pending') {
events.push({
type: 'tool_calls_pending',
data: { pending_tool_calls: ext.pending_tool_calls },
});
} else if (ext.type === 'id') {
events.push({ type: 'id', id: ext.conversation_id });
}
return events;
}
// Error chunks
if (chunk.error) {
events.push({ type: 'error', error: chunk.error.message || 'Error' });
return events;
}
// Standard choices chunks
const choice = chunk.choices?.[0];
if (!choice) return events;
const delta = choice.delta || {};
const finishReason = choice.finish_reason;
if (delta.content) {
events.push({ type: 'answer', answer: delta.content });
}
if (delta.reasoning_content) {
events.push({ type: 'thought', thought: delta.reasoning_content });
}
if (delta.tool_calls) {
for (const tc of delta.tool_calls) {
let parsedArgs: Record<string, any> = {};
if (tc.function?.arguments) {
try {
parsedArgs = JSON.parse(tc.function.arguments);
} catch {
// Arguments may arrive as fragments during streaming;
// keep the raw string so downstream can accumulate it.
parsedArgs = { _raw: tc.function.arguments };
}
}
events.push({
type: 'tool_call',
data: {
call_id: tc.id,
action_name: tc.function?.name || '',
tool_name: tc.function?.name || '',
arguments: parsedArgs,
status: 'requires_client_execution',
},
});
}
}
if (finishReason === 'stop') {
events.push({ type: 'end' });
} else if (finishReason === 'tool_calls') {
events.push({
type: 'tool_calls_pending',
data: { pending_tool_calls: [] },
});
}
return events;
}
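// Worked example (hypothetical chunk): a single content delta maps to one
// internal 'answer' event, and finish_reason 'tool_calls' maps to a
// 'tool_calls_pending' event:
//   translateV1ChunkToInternalEvents({
//     choices: [{ delta: { content: 'Hel' }, finish_reason: null }],
//   })
//   -> [{ type: 'answer', answer: 'Hel' }]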
export function handleSearch(
question: string,
token: string | null,

View File

@@ -1,7 +1,7 @@
import { ToolCallsType } from './types';
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
export type Status = 'idle' | 'loading' | 'failed';
export type Status = 'idle' | 'loading' | 'failed' | 'awaiting_tool_actions';
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
export interface Message {

View File

@@ -10,6 +10,8 @@ import {
import {
handleFetchAnswer,
handleFetchAnswerSteaming,
handleSubmitToolActions,
handleV1ChatCompletionStreaming,
} from './conversationHandlers';
import {
Answer,
@@ -27,6 +29,7 @@ const initialState: ConversationState = {
};
const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';
const USE_V1_API = import.meta.env.VITE_USE_V1_API === 'true';
let abortController: AbortController | null = null;
export function handleAbort() {
@@ -60,7 +63,102 @@ export const fetchAnswer = createAsyncThunk<
state.preference.selectedModel?.id;
if (state.preference) {
if (API_STREAMING) {
const agentKey = state.preference.selectedAgent?.key;
if (USE_V1_API && agentKey) {
// Build history from prior queries for v1 format
const v1History = state.conversation.queries
.filter((q) => q.response)
.map((q) => ({ prompt: q.prompt, response: q.response || '' }));
await handleV1ChatCompletionStreaming(
question,
signal,
agentKey,
v1History,
(event) => {
const data = JSON.parse(event.data);
const targetIndex = indx ?? state.conversation.queries.length - 1;
if (currentConversationId === state.conversation.conversationId) {
if (data.type === 'end') {
dispatch(conversationSlice.actions.setStatus('idle'));
getConversations(state.preference.token)
.then((fetchedConversations) => {
dispatch(setConversations(fetchedConversations));
})
.catch((error) => {
console.error('Failed to fetch conversations: ', error);
});
if (!isSourceUpdated) {
dispatch(
updateStreamingSource({
conversationId: currentConversationId,
index: targetIndex,
query: { sources: [] },
}),
);
}
} else if (data.type === 'id') {
const currentState = getState() as RootState;
if (currentState.conversation.conversationId === null) {
dispatch(
updateConversationId({
query: { conversationId: data.id },
}),
);
}
} else if (data.type === 'thought') {
dispatch(
updateThought({
conversationId: currentConversationId,
index: targetIndex,
query: { thought: data.thought },
}),
);
} else if (data.type === 'source') {
isSourceUpdated = true;
dispatch(
updateStreamingSource({
conversationId: currentConversationId,
index: targetIndex,
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_call') {
dispatch(
updateToolCall({
index: targetIndex,
tool_call: data.data as ToolCallsType,
}),
);
} else if (data.type === 'tool_calls_pending') {
dispatch(
conversationSlice.actions.setStatus('awaiting_tool_actions'),
);
} else if (data.type === 'error') {
dispatch(conversationSlice.actions.setStatus('failed'));
dispatch(
conversationSlice.actions.raiseError({
conversationId: currentConversationId,
index: targetIndex,
message: data.error,
}),
);
} else {
dispatch(
updateStreamingQuery({
conversationId: currentConversationId,
index: targetIndex,
query: { response: data.answer },
}),
);
}
}
},
undefined,
attachmentIds.length > 0 ? attachmentIds : undefined,
);
} else if (API_STREAMING) {
await handleFetchAnswerSteaming(
question,
signal,
@@ -138,6 +236,10 @@ export const fetchAnswer = createAsyncThunk<
tool_call: data.data as ToolCallsType,
}),
);
} else if (data.type === 'tool_calls_pending') {
dispatch(
conversationSlice.actions.setStatus('awaiting_tool_actions'),
);
} else if (data.type === 'research_plan') {
dispatch(
updateResearchPlan({
@@ -260,6 +362,94 @@ export const fetchAnswer = createAsyncThunk<
};
});
export const submitToolActions = createAsyncThunk<
void,
{
toolActions: {
call_id: string;
decision?: 'approved' | 'denied';
comment?: string;
result?: Record<string, any>;
}[];
}
>('submitToolActions', async ({ toolActions }, { dispatch, getState }) => {
if (abortController) abortController.abort();
abortController = new AbortController();
const { signal } = abortController;
const state = getState() as RootState;
const conversationId = state.conversation.conversationId;
if (!conversationId) return;
dispatch(conversationSlice.actions.setStatus('loading'));
await handleSubmitToolActions(
conversationId,
toolActions,
state.preference.token,
signal,
(event) => {
const data = JSON.parse(event.data);
const targetIndex = state.conversation.queries.length - 1;
if (data.type === 'end') {
dispatch(conversationSlice.actions.setStatus('idle'));
getConversations(state.preference.token)
.then((fetchedConversations) => {
dispatch(setConversations(fetchedConversations));
})
.catch((error) => {
console.error('Failed to fetch conversations: ', error);
});
} else if (data.type === 'id') {
// conversation ID already set
} else if (data.type === 'thought') {
dispatch(
updateThought({
conversationId,
index: targetIndex,
query: { thought: data.thought },
}),
);
} else if (data.type === 'source') {
dispatch(
updateStreamingSource({
conversationId,
index: targetIndex,
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_call') {
dispatch(
updateToolCall({
index: targetIndex,
tool_call: data.data as ToolCallsType,
}),
);
} else if (data.type === 'tool_calls_pending') {
dispatch(conversationSlice.actions.setStatus('awaiting_tool_actions'));
} else if (data.type === 'error') {
dispatch(conversationSlice.actions.setStatus('failed'));
dispatch(
conversationSlice.actions.raiseError({
conversationId,
index: targetIndex,
message: data.error,
}),
);
} else if (data.type === 'answer') {
dispatch(
updateStreamingQuery({
conversationId,
index: targetIndex,
query: { response: data.answer },
}),
);
}
},
);
});
export const conversationSlice = createSlice({
name: 'conversation',
initialState,

View File

@@ -5,6 +5,12 @@ export type ToolCallsType = {
arguments: Record<string, any>;
result?: Record<string, any>;
error?: string;
status?: 'pending' | 'completed' | 'error';
status?:
| 'pending'
| 'completed'
| 'error'
| 'awaiting_approval'
| 'denied'
| 'requires_client_execution';
artifact_id?: string;
};

View File

@@ -487,9 +487,33 @@ export default function ToolConfig({
)}
</div>
<div
className="flex items-center gap-2"
className="flex items-center gap-3"
onClick={(e) => e.stopPropagation()}
>
<div className="flex items-center gap-1">
<span className="text-xs text-gray-500 dark:text-gray-400">
{t('settings.tools.requireApproval', 'Approval')}
</span>
<ToggleSwitch
checked={action.require_approval ?? false}
onChange={(checked) => {
setTool({
...tool,
actions: tool.actions.map((act, index) => {
if (index === originalIndex) {
return {
...act,
require_approval: checked,
};
}
return act;
}),
});
}}
size="small"
id={`approvalToggle-${originalIndex}`}
/>
</div>
<ToggleSwitch
checked={action.active}
onChange={(checked) => {
@@ -926,6 +950,35 @@ function APIToolConfig({
className="h-4 w-4 opacity-40 transition-opacity hover:opacity-100"
/>
</button>
<div className="flex items-center gap-1">
<span className="text-xs text-gray-500 dark:text-gray-400">
{t('settings.tools.requireApproval', 'Approval')}
</span>
<ToggleSwitch
checked={action.require_approval ?? false}
onChange={() => {
setApiTool((prevApiTool) => {
const updatedActions = {
...prevApiTool.config.actions,
};
updatedActions[actionName] = {
...updatedActions[actionName],
require_approval:
!updatedActions[actionName].require_approval,
};
return {
...prevApiTool,
config: {
...prevApiTool.config,
actions: updatedActions,
},
};
});
}}
size="small"
id={`approvalToggle-${actionIndex}`}
/>
</div>
<ToggleSwitch
checked={action.active}
onChange={() => handleActionToggle(actionName)}

View File

@@ -69,6 +69,7 @@ export type UserToolType = {
type: string;
};
active: boolean;
require_approval?: boolean;
}[];
};
@@ -81,6 +82,7 @@ export type APIActionType = {
headers: ParameterGroupType;
body: ParameterGroupType;
active: boolean;
require_approval?: boolean;
body_content_type?:
| 'application/json'
| 'application/x-www-form-urlencoded'

View File

@@ -341,7 +341,7 @@ class TestBaseAgentTools:
assert len(agent.tools) == 1
assert agent.tools[0]["type"] == "function"
assert agent.tools[0]["function"]["name"] == "get_data_1"
assert agent.tools[0]["function"]["name"] == "get_data"
def test_prepare_tools_with_regular_tool(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
@@ -365,7 +365,7 @@ class TestBaseAgentTools:
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["function"]["name"] == "action1_1"
assert agent.tools[0]["function"]["name"] == "action1"
def test_prepare_tools_filters_inactive_actions(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
@@ -395,7 +395,7 @@ class TestBaseAgentTools:
agent._prepare_tools(tools_dict)
assert len(agent.tools) == 1
assert agent.tools[0]["function"]["name"] == "active_action_1"
assert agent.tools[0]["function"]["name"] == "active_action"
@pytest.mark.unit

View File

@@ -202,3 +202,69 @@ class TestToolActionParser:
assert action_name == "create_record"
assert call_args["data"]["name"] == "John"
assert call_args["data"]["age"] == 30
@pytest.mark.unit
class TestToolActionParserWithMapping:
"""Tests for the mapping-based lookup path."""
def test_openai_mapping_resolves_clean_name(self):
mapping = {"get_weather": ("ct0", "get_weather")}
parser = ToolActionParser("OpenAILLM", name_mapping=mapping)
call = Mock()
call.name = "get_weather"
call.arguments = '{"city": "SF"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "ct0"
assert action_name == "get_weather"
assert call_args == {"city": "SF"}
def test_openai_mapping_resolves_numbered_suffix(self):
mapping = {"search_1": ("t1", "search"), "search_2": ("t2", "search")}
parser = ToolActionParser("OpenAILLM", name_mapping=mapping)
call = Mock()
call.name = "search_1"
call.arguments = '{"q": "test"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "t1"
assert action_name == "search"
def test_google_mapping_resolves(self):
mapping = {"get_weather": ("ct0", "get_weather")}
parser = ToolActionParser("GoogleLLM", name_mapping=mapping)
call = Mock()
call.name = "get_weather"
call.arguments = {"city": "SF"}
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "ct0"
assert action_name == "get_weather"
def test_fallback_to_split_when_not_in_mapping(self):
mapping = {"get_weather": ("ct0", "get_weather")}
parser = ToolActionParser("OpenAILLM", name_mapping=mapping)
call = Mock()
call.name = "unknown_action_99"
call.arguments = "{}"
tool_id, action_name, call_args = parser.parse_args(call)
# Falls back to legacy split
assert tool_id == "99"
assert action_name == "unknown_action"
def test_no_mapping_uses_legacy_split(self):
parser = ToolActionParser("OpenAILLM", name_mapping=None)
call = Mock()
call.name = "action_123"
call.arguments = '{"k": "v"}'
tool_id, action_name, call_args = parser.parse_args(call)
assert tool_id == "123"
assert action_name == "action"
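# The tests above pin the lookup order; a minimal sketch of that logic,
# assuming the parser holds a name_mapping dict (real internals may differ):
def _sketch_resolve(name, name_mapping):
    if name_mapping and name in name_mapping:
        return name_mapping[name]  # (tool_id, action_name)
    action_name, _, tool_id = name.rpartition("_")  # legacy split fallback
    return tool_id, action_name
# _sketch_resolve("search_1", {"search_1": ("t1", "search")}) == ("t1", "search")
# _sketch_resolve("action_123", None) == ("123", "action")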

View File

@@ -80,7 +80,7 @@ class TestToolExecutorPrepare:
result = executor.prepare_tools_for_llm(tools_dict)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["function"]["name"] == "do_thing_t1"
assert result[0]["function"]["name"] == "do_thing"
assert "query" in result[0]["function"]["parameters"]["properties"]
def test_prepare_tools_skips_inactive_actions(self):
@@ -97,7 +97,96 @@ class TestToolExecutorPrepare:
result = executor.prepare_tools_for_llm(tools_dict)
assert len(result) == 1
assert result[0]["function"]["name"] == "active_one_t1"
assert result[0]["function"]["name"] == "active_one"
def test_prepare_tools_builds_name_mapping(self):
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "test_tool",
"actions": [
{"name": "do_thing", "description": "D", "active": True, "parameters": {"properties": {}}},
],
}
}
executor.prepare_tools_for_llm(tools_dict)
assert executor._name_to_tool["do_thing"] == ("t1", "do_thing")
assert executor._tool_to_name[("t1", "do_thing")] == "do_thing"
def test_prepare_tools_duplicate_names_get_numbered_suffixes(self):
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "tool_a",
"actions": [
{"name": "search", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
"t2": {
"name": "tool_b",
"actions": [
{"name": "search", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
}
result = executor.prepare_tools_for_llm(tools_dict)
names = [r["function"]["name"] for r in result]
assert "search_1" in names
assert "search_2" in names
assert executor._name_to_tool["search_1"][1] == "search"
assert executor._name_to_tool["search_2"][1] == "search"
def test_prepare_tools_unique_name_no_suffix(self):
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "tool_a",
"actions": [
{"name": "get_weather", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
"t2": {
"name": "tool_b",
"actions": [
{"name": "send_email", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
}
result = executor.prepare_tools_for_llm(tools_dict)
names = [r["function"]["name"] for r in result]
assert "get_weather" in names
assert "send_email" in names
def test_prepare_tools_suffix_skips_collision_with_unique_name(self):
"""If action 'foo_1' exists as unique and 'foo' is duplicated, skip '_1'."""
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "tool_a",
"actions": [
{"name": "foo", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
"t2": {
"name": "tool_b",
"actions": [
{"name": "foo", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
"t3": {
"name": "tool_c",
"actions": [
{"name": "foo_1", "description": "D", "active": True, "parameters": {"properties": {}}},
],
},
}
result = executor.prepare_tools_for_llm(tools_dict)
names = [r["function"]["name"] for r in result]
# foo_1 is taken by the unique action, so duplicates skip to _2 and _3
assert "foo_1" in names # The unique action
assert "foo_2" in names
assert "foo_3" in names
assert executor._name_to_tool["foo_1"] == ("t3", "foo_1")
def test_build_tool_parameters_filters_non_llm_fields(self):
executor = ToolExecutor()
@@ -128,6 +217,68 @@ class TestToolExecutorPrepare:
assert "value" not in result["properties"]["query"]
@pytest.mark.unit
class TestCheckPause:
def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"):
call = Mock()
call.name = name
call.id = call_id
call.arguments = arguments
call.thought_signature = None
return call
def test_client_side_tool_returns_llm_name(self):
"""check_pause returns the clean LLM-facing name and llm_name field."""
executor = ToolExecutor()
tools_dict = {
"ct0": {
"name": "write_file",
"client_side": True,
"actions": [
{"name": "write_file", "description": "Write a file", "active": True, "parameters": {}},
],
}
}
# Prepare tools so the mapping is built
executor.prepare_tools_for_llm(tools_dict)
call = self._make_call(name="write_file")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["name"] == "write_file"
assert result["llm_name"] == "write_file"
assert result["action_name"] == "write_file"
assert result["tool_id"] == "ct0"
def test_approval_required_returns_llm_name(self):
"""check_pause for approval-required tools returns clean LLM name."""
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "dangerous_tool",
"actions": [
{"name": "delete_all", "description": "Deletes everything", "active": True,
"require_approval": True, "parameters": {}},
],
}
}
executor.prepare_tools_for_llm(tools_dict)
call = self._make_call(name="delete_all")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["name"] == "delete_all"
assert result["llm_name"] == "delete_all"
assert result["action_name"] == "delete_all"
@pytest.mark.unit
class TestToolExecutorExecute:
@@ -143,7 +294,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=(None, None, {}))),
lambda _cls, **kw: Mock(parse_args=Mock(return_value=(None, None, {}))),
)
call = self._make_call(name="bad")
@@ -167,7 +318,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))),
lambda _cls, **kw: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))),
)
call = self._make_call()
@@ -190,7 +341,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))),
lambda _cls, **kw: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))),
)
tools_dict = {
@@ -244,7 +395,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))),
lambda _cls, **kw: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))),
)
tools_dict = {
@@ -284,7 +435,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
lambda _cls, **kw: Mock(
parse_args=Mock(return_value=("t1", "get_users", {"body_param": "val"}))
),
)
@@ -331,7 +482,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
lambda _cls, **kw: Mock(
parse_args=Mock(return_value=("t1", "act", {}))
),
)
@@ -376,7 +527,7 @@ class TestToolExecutorExecute:
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(
lambda _cls, **kw: Mock(
parse_args=Mock(return_value=("t1", "act", {"q": "v"}))
),
)

View File

@@ -73,7 +73,7 @@ class TestAnswerResourcePost:
),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(conv_id, "Hello", [], [], "", None),
return_value={"conversation_id": conv_id, "answer": "Hello", "sources": [], "tool_calls": [], "thought": "", "error": None},
):
resp = answer_client.post(
"/api/answer",
@@ -129,7 +129,7 @@ class TestAnswerResourcePost:
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(None, None, None, None, None, "Stream error"),
return_value={"conversation_id": None, "answer": None, "sources": None, "tool_calls": None, "thought": None, "error": "Stream error"},
):
resp = answer_client.post(
"/api/answer",
@@ -173,15 +173,7 @@ class TestAnswerResourcePost:
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(
conv_id,
'{"key": "val"}',
[],
[],
"",
None,
{"structured": True, "schema": {"type": "object"}},
),
return_value={"conversation_id": conv_id, "answer": '{"key": "val"}', "sources": [], "tool_calls": [], "thought": "", "error": None, "extra": {"structured": True, "schema": {"type": "object"}}},
):
resp = answer_client.post(
"/api/answer",
@@ -208,14 +200,7 @@ class TestAnswerResourcePost:
return_value=iter([]),
), patch(
"application.api.answer.routes.answer.AnswerResource.process_response_stream",
return_value=(
conv_id,
"answer text",
[{"title": "src"}],
[{"tool": "t"}],
"thinking...",
None,
),
return_value={"conversation_id": conv_id, "answer": "answer text", "sources": [{"title": "src"}], "tool_calls": [{"tool": "t"}], "thought": "thinking...", "error": None},
):
resp = answer_client.post(
"/api/answer",

View File

@@ -481,10 +481,10 @@ class TestProcessResponseStream:
result = resource.process_response_stream(iter(stream))
assert result[0] == conv_id
assert result[1] == "Hello world"
assert result[2] == [{"title": "doc1"}]
assert result[5] is None
assert result["conversation_id"] == conv_id
assert result["answer"] == "Hello world"
assert result["sources"] == [{"title": "doc1"}]
assert result["error"] is None
def test_handles_stream_error(self, mock_mongo_db, flask_app):
import json
@@ -500,10 +500,8 @@ class TestProcessResponseStream:
result = resource.process_response_stream(iter(stream))
assert len(result) == 6
assert result[0] is None
assert result[4] == "Test error"
assert result[5] is None
assert result["conversation_id"] is None
assert result["error"] == "Test error"
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource

View File

@@ -79,9 +79,10 @@ class TestBuildFromCompressedContext:
# system + user + assistant + tool_call_assistant + tool_response = 5
assert len(messages) == 5
assert messages[3]["role"] == "assistant"
assert "function_call" in messages[3]["content"][0]
assert messages[3].get("tool_calls") is not None
assert messages[3]["tool_calls"][0]["function"]["name"] == "search"
assert messages[4]["role"] == "tool"
assert "function_response" in messages[4]["content"][0]
assert messages[4].get("tool_call_id") == "call-1"
def test_tool_calls_not_included_by_default(self):
queries = [
@@ -127,8 +128,8 @@ class TestBuildFromCompressedContext:
recent_queries=queries,
include_tool_calls=True,
)
tool_msg = messages[3]["content"][0]
call_id = tool_msg["function_call"]["call_id"]
assistant_msg = messages[3]
call_id = assistant_msg["tool_calls"][0]["id"]
assert call_id is not None
assert len(call_id) > 0

View File

@@ -295,10 +295,8 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[1] == "{}"
# Structured output adds extra tuple element
assert len(result) == 7
assert result[6]["structured"] is True
assert result["answer"] == "{}"
assert result.get("extra", {}).get("structured") is True
def test_handles_tool_calls_event(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
@@ -312,7 +310,7 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[3] == [{"name": "t1"}]
assert result["tool_calls"] == [{"name": "t1"}]
def test_incomplete_stream(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
@@ -323,7 +321,7 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[4] == "Stream ended unexpectedly"
assert result["error"] == "Stream ended unexpectedly"
def test_handles_thought_event(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
@@ -335,7 +333,7 @@ class TestProcessResponseStreamExtended:
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result[4] == "thinking..."
assert result["thought"] == "thinking..."
@pytest.mark.unit

View File

@@ -50,6 +50,7 @@ from tests.integration.test_analytics import AnalyticsTests
from tests.integration.test_connectors import ConnectorTests
from tests.integration.test_mcp import MCPTests
from tests.integration.test_misc import MiscTests
from tests.integration.test_v1_api import V1ApiTests
# Module registry
@@ -64,6 +65,7 @@ MODULES = {
"connectors": ConnectorTests,
"mcp": MCPTests,
"misc": MiscTests,
"v1_api": V1ApiTests,
}

View File

@@ -1036,208 +1036,7 @@ This is test documentation for integration tests.
return False
# -------------------------------------------------------------------------
# Compression Tests
# -------------------------------------------------------------------------
def test_compression_heavy_tool_usage(self) -> bool:
"""Test compression with heavy conversation usage."""
test_name = "Compression - Heavy Tool Usage"
self.print_header(f"Testing {test_name}")
if not self.require_auth(test_name):
return True
self.print_info("Making 10 consecutive requests to build conversation history...")
current_conv_id = None
for i in range(10):
question = f"Tell me about Python topic {i+1}: data structures, decorators, async, testing. Provide a comprehensive explanation."
payload = {
"question": question,
"history": "[]",
"isNoneDoc": True,
}
if current_conv_id:
payload["conversation_id"] = current_conv_id
try:
response = self.post("/api/answer", json=payload, timeout=90)
if response.status_code == 200:
result = response.json()
current_conv_id = result.get("conversation_id", current_conv_id)
answer_preview = (result.get("answer") or "")[:80]
self.print_success(f"Request {i+1}/10 completed")
self.print_info(f" Answer: {answer_preview}...")
else:
self.print_error(f"Request {i+1}/10 failed: status {response.status_code}")
self.record_result(test_name, False, f"Request {i+1} failed")
return False
time.sleep(2)
except Exception as e:
self.print_error(f"Request {i+1}/10 failed: {str(e)}")
self.record_result(test_name, False, str(e))
return False
if current_conv_id:
self.print_success("Heavy usage test completed")
self.record_result(test_name, True, f"10 requests, conv_id: {current_conv_id}")
return True
else:
self.print_warning("No conversation_id received")
self.record_result(test_name, False, "No conversation_id")
return False
def test_compression_needle_in_haystack(self) -> bool:
"""Test that compression preserves critical information.
Note: This is a long-running test that may timeout due to LLM response times.
Timeouts are handled gracefully as they indicate performance issues, not bugs.
"""
test_name = "Compression - Needle in Haystack"
self.print_header(f"Testing {test_name}")
if not self.require_auth(test_name):
return True
conversation_id = None
# Step 1: Send general questions
self.print_info("Step 1: Sending general questions...")
for i, question in enumerate([
"Tell me about Python best practices in detail",
"Explain Python data structures comprehensively",
]):
payload = {
"question": question,
"history": "[]",
"isNoneDoc": True,
}
if conversation_id:
payload["conversation_id"] = conversation_id
try:
response = self.post("/api/answer", json=payload, timeout=90)
if response.status_code == 200:
result = response.json()
conversation_id = result.get("conversation_id", conversation_id)
self.print_success(f"General question {i+1}/2 completed")
else:
self.print_error(f"Request failed: status {response.status_code}")
self.record_result(test_name, False, "General questions failed")
return False
time.sleep(2)
except Exception as e:
# Timeout errors are expected for long LLM responses
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
self.print_warning(f"Request timed out: {str(e)[:50]}")
self.record_result(test_name, True, "Skipped (timeout)")
return True
self.print_error(f"Request failed: {str(e)}")
self.record_result(test_name, False, str(e))
return False
# Step 2: Send critical information
self.print_info("Step 2: Sending CRITICAL information...")
critical_payload = {
"question": "Please remember: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily.",
"history": "[]",
"isNoneDoc": True,
"conversation_id": conversation_id,
}
try:
response = self.post("/api/answer", json=critical_payload, timeout=90)
if response.status_code == 200:
result = response.json()
conversation_id = result.get("conversation_id", conversation_id)
self.print_success("Critical information sent")
else:
self.record_result(test_name, False, "Critical info failed")
return False
time.sleep(2)
except Exception as e:
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
self.print_warning(f"Request timed out: {str(e)[:50]}")
self.record_result(test_name, True, "Skipped (timeout)")
return True
self.record_result(test_name, False, str(e))
return False
# Step 3: Bury with more questions
self.print_info("Step 3: Sending more questions to bury critical info...")
for i, question in enumerate([
"Explain Python decorators in great detail",
"Tell me about Python async programming comprehensively",
]):
payload = {
"question": question,
"history": "[]",
"isNoneDoc": True,
"conversation_id": conversation_id,
}
try:
response = self.post("/api/answer", json=payload, timeout=90)
if response.status_code == 200:
result = response.json()
conversation_id = result.get("conversation_id", conversation_id)
self.print_success(f"Burying question {i+1}/2 completed")
else:
self.record_result(test_name, False, "Burying questions failed")
return False
time.sleep(2)
except Exception as e:
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
self.print_warning(f"Request timed out: {str(e)[:50]}")
self.record_result(test_name, True, "Skipped (timeout)")
return True
self.record_result(test_name, False, str(e))
return False
# Step 4: Test recall
self.print_info("Step 4: Testing if critical info was preserved...")
recall_payload = {
"question": "What was the database password environment variable I mentioned earlier?",
"history": "[]",
"isNoneDoc": True,
"conversation_id": conversation_id,
}
try:
response = self.post("/api/answer", json=recall_payload, timeout=90)
if response.status_code == 200:
result = response.json()
answer = (result.get("answer") or "").lower()
if "db_password_prod" in answer or "database password" in answer:
self.print_success("Critical information preserved!")
self.print_info(f"Answer: {answer[:150]}...")
self.record_result(test_name, True, "Info preserved")
return True
else:
self.print_warning("Critical information may have been lost")
self.print_info(f"Answer: {answer[:150]}...")
self.record_result(test_name, False, "Info not preserved")
return False
else:
self.record_result(test_name, False, "Recall failed")
return False
except Exception as e:
if "timed out" in str(e).lower() or "timeout" in str(e).lower():
self.print_warning(f"Request timed out: {str(e)[:50]}")
self.record_result(test_name, True, "Skipped (timeout)")
return True
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Feedback Tests (NEW)
# Feedback Tests
# -------------------------------------------------------------------------
def test_feedback_positive(self) -> bool:
@@ -1435,15 +1234,6 @@ This is test documentation for integration tests.
self.test_tts_basic()
time.sleep(1)
# Compression tests (longer running)
if self.is_authenticated:
self.test_compression_heavy_tool_usage()
time.sleep(2)
self.test_compression_needle_in_haystack()
else:
self.print_info("Skipping compression tests (no authentication)")
return self.print_summary()

View File

@@ -0,0 +1,681 @@
#!/usr/bin/env python3
"""
Integration tests for the /v1/ chat completions API (Phase 4).
Endpoints tested:
- /v1/chat/completions (POST) - Standard chat completions (streaming & non-streaming)
- /v1/models (GET) - List available agent models
Usage:
python tests/integration/test_v1_api.py
python tests/integration/test_v1_api.py --base-url http://localhost:7091
python tests/integration/test_v1_api.py --token YOUR_JWT_TOKEN
"""
import json as json_module
import sys
import time
from pathlib import Path
from typing import Optional
import requests
# Add parent directory to path for standalone execution
_THIS_DIR = Path(__file__).parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if str(_ROOT_DIR) not in sys.path:
sys.path.insert(0, str(_ROOT_DIR))
from tests.integration.base import DocsGPTTestBase, create_client_from_args
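# Shape assumed for a non-streaming completion, inferred from the assertions
# below (a sketch of the OpenAI-compatible format, not a spec):
#
#   POST /v1/chat/completions
#   {"messages": [{"role": "user", "content": "Hi"}], "stream": false}
#   -> {"id": "...", "object": "chat.completion",
#       "choices": [{"message": {"role": "assistant", "content": "..."},
#                    "finish_reason": "stop"}],
#       "usage": {...},
#       "docsgpt": {"conversation_id": "..."}}  # DocsGPT extension field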
class V1ApiTests(DocsGPTTestBase):
"""Integration tests for /v1/ chat completions API."""
# -------------------------------------------------------------------------
# Test Data Helpers
# -------------------------------------------------------------------------
def get_or_create_agent_key(self) -> Optional[str]:
"""Get or create a test agent and return its API key."""
if hasattr(self, "_agent_key") and self._agent_key:
return self._agent_key
# Creation is attempted with whatever auth the client has configured.
# Published agents need a source to get an API key.
payload = {
"name": f"V1 Test Agent {int(time.time())}",
"description": "Integration test agent for v1 API tests",
"prompt_id": "default",
"chunks": 2,
"retriever": "classic",
"agent_type": "classic",
"status": "published",
"source": "default",
}
try:
response = self.post("/api/create_agent", json=payload, timeout=10)
if response.status_code in [200, 201]:
result = response.json()
api_key = result.get("key")
self._agent_id = result.get("id")
if api_key:
self._agent_key = api_key
self.print_info(f"Created test agent with key: {api_key[:8]}...")
return api_key
else:
self.print_warning("Agent created but no API key returned")
else:
self.print_warning(f"Agent creation returned {response.status_code}: {response.text[:200]}")
except Exception as e:
self.print_error(f"Failed to create agent: {e}")
return None
def _v1_headers(self, api_key: str) -> dict:
"""Build headers for v1 API requests."""
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
# -------------------------------------------------------------------------
# /v1/chat/completions — Auth Tests
# -------------------------------------------------------------------------
def test_no_auth_returns_401(self) -> bool:
"""Test that /v1/chat/completions without auth returns 401."""
test_name = "v1 chat completions - no auth"
self.print_header(f"Testing {test_name}")
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json={"messages": [{"role": "user", "content": "Hi"}]},
headers={"Content-Type": "application/json"},
timeout=10,
)
if response.status_code == 401:
self.print_success("Correctly returned 401 for missing auth")
self.record_result(test_name, True, "401 as expected")
return True
else:
self.print_error(f"Expected 401, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
except Exception as e:
self.print_error(f"Request failed: {e}")
self.record_result(test_name, False, str(e))
return False
def test_invalid_key_returns_error(self) -> bool:
"""Test that invalid API key returns error."""
test_name = "v1 chat completions - invalid key"
self.print_header(f"Testing {test_name}")
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json={"messages": [{"role": "user", "content": "Hi"}]},
headers=self._v1_headers("invalid-key-12345"),
timeout=30,
)
# Should return 400, 401, or 500 (invalid or unknown agent key)
if response.status_code in [400, 401, 500]:
self.print_success(f"Correctly returned {response.status_code} for invalid key")
self.record_result(test_name, True, f"Error as expected ({response.status_code})")
return True
else:
self.print_error(f"Unexpected status: {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
except Exception as e:
self.print_error(f"Request failed: {e}")
self.record_result(test_name, False, str(e))
return False
def test_missing_messages_returns_400(self) -> bool:
"""Test that missing messages field returns 400."""
test_name = "v1 chat completions - missing messages"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json={"stream": False},
headers=self._v1_headers(api_key),
timeout=10,
)
if response.status_code == 400:
self.print_success("Correctly returned 400 for missing messages")
self.record_result(test_name, True, "400 as expected")
return True
else:
self.print_error(f"Expected 400, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
except Exception as e:
self.print_error(f"Request failed: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# /v1/chat/completions — Non-streaming
# -------------------------------------------------------------------------
def test_non_streaming_basic(self) -> bool:
"""Test basic non-streaming chat completion."""
test_name = "v1 chat completions - non-streaming"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Say hello in one word."}],
"stream": False,
},
headers=self._v1_headers(api_key),
timeout=60,
)
self.print_info(f"Status: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.print_error(f"Response: {response.text[:300]}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
data = response.json()
# Verify standard format; guarded access so a malformed response
# fails the itemized checks below instead of raising KeyError.
first = (data.get("choices") or [{}])[0]
checks = [
("id" in data, "has id"),
(data.get("object") == "chat.completion", "object is chat.completion"),
("choices" in data, "has choices"),
(len(data.get("choices", [])) > 0, "choices not empty"),
(first.get("message", {}).get("role") == "assistant", "role is assistant"),
(first.get("message", {}).get("content") is not None, "has content"),
(first.get("finish_reason") == "stop", "finish_reason is stop"),
("usage" in data, "has usage"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
content = data["choices"][0]["message"]["content"]
self.print_info(f"Response: {content[:100]}")
# Check docsgpt extension
if "docsgpt" in data:
self.print_success(" has docsgpt extension")
if "conversation_id" in data["docsgpt"]:
self.print_success(f" conversation_id: {data['docsgpt']['conversation_id'][:8]}...")
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# /v1/chat/completions — Streaming
# -------------------------------------------------------------------------
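# Wire format assumed by the parser below: server-sent events, one JSON
# chunk per "data: " line, closed by a literal "data: [DONE]" sentinel:
#
#   data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}
#   data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}
#   data: [DONE]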
def test_streaming_basic(self) -> bool:
"""Test basic streaming chat completion."""
test_name = "v1 chat completions - streaming"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Say hi briefly."}],
"stream": True,
},
headers=self._v1_headers(api_key),
stream=True,
timeout=60,
)
self.print_info(f"Status: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
chunks = []
content_pieces = []
got_done = False
got_stop = False
got_id = False
for line in response.iter_lines():
if not line:
continue
line_str = line.decode("utf-8")
if not line_str.startswith("data: "):
continue
data_str = line_str[6:]
if data_str.strip() == "[DONE]":
got_done = True
break
try:
chunk = json_module.loads(data_str)
chunks.append(chunk)
# Standard chunks
if "choices" in chunk:
delta = chunk["choices"][0].get("delta", {})
if "content" in delta:
content_pieces.append(delta["content"])
if chunk["choices"][0].get("finish_reason") == "stop":
got_stop = True
# Extension chunks
if "docsgpt" in chunk:
ext = chunk["docsgpt"]
if ext.get("type") == "id":
got_id = True
except json_module.JSONDecodeError:
pass
full_content = "".join(content_pieces)
checks = [
(len(chunks) > 0, f"received {len(chunks)} chunks"),
(len(content_pieces) > 0, f"got content: {full_content[:50]}..."),
(got_stop, "got finish_reason=stop"),
(got_done, "got [DONE] sentinel"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
if got_id:
self.print_success(" got conversation_id via docsgpt extension")
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# /v1/chat/completions — Multi-turn conversation
# -------------------------------------------------------------------------
def test_multi_turn_conversation(self) -> bool:
"""Test multi-turn conversation with history in messages."""
test_name = "v1 chat completions - multi-turn"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json={
"messages": [
{"role": "user", "content": "My name is TestBot."},
{"role": "assistant", "content": "Hello TestBot!"},
{"role": "user", "content": "What is my name?"},
],
"stream": False,
},
headers=self._v1_headers(api_key),
timeout=60,
)
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
data = response.json()
content = data["choices"][0]["message"]["content"]
self.print_info(f"Response: {content[:150]}")
# Ideally the answer references "TestBot" from the history; to keep
# the test deterministic we only assert a non-empty response.
has_content = bool(content)
self.print_success(f" Got response with {len(content)} chars")
self.record_result(test_name, has_content, "Multi-turn works")
return has_content
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# /v1/models
# -------------------------------------------------------------------------
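# Response shape assumed from the checks below:
#   {"object": "list",
#    "data": [{"id": "<agent model id>", "object": "model",
#              "owned_by": "docsgpt"}]}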
def test_list_models(self) -> bool:
"""Test GET /v1/models endpoint."""
test_name = "v1 models - list"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
response = requests.get(
f"{self.base_url}/v1/models",
headers=self._v1_headers(api_key),
timeout=10,
)
self.print_info(f"Status: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
data = response.json()
checks = [
(data.get("object") == "list", "object is list"),
("data" in data, "has data array"),
(len(data.get("data", [])) > 0, f"has {len(data.get('data', []))} model(s)"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
if data.get("data"):
model = data["data"][0]
model_checks = [
("id" in model, "model has id"),
(model.get("object") == "model", "model object is 'model'"),
(model.get("owned_by") == "docsgpt", "owned_by is docsgpt"),
]
for check, label in model_checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed")
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_models_no_auth(self) -> bool:
"""Test that /v1/models without auth returns 401."""
test_name = "v1 models - no auth"
self.print_header(f"Testing {test_name}")
try:
response = requests.get(
f"{self.base_url}/v1/models",
timeout=10,
)
if response.status_code == 401:
self.print_success("Correctly returned 401")
self.record_result(test_name, True, "401 as expected")
return True
else:
self.print_error(f"Expected 401, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Backward Compatibility — old endpoints still work
# -------------------------------------------------------------------------
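# The legacy /stream endpoint emits DocsGPT-native events rather than
# OpenAI chunks; the parser below assumes this shape:
#   data: {"type": "answer", "answer": "..."}
#   data: {"type": "end"}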
def test_old_stream_endpoint_still_works(self) -> bool:
"""Verify the old /stream endpoint still works after v1 changes."""
test_name = "Backward compat - /stream"
self.print_header(f"Testing {test_name}")
payload = {
"question": "Say hello briefly.",
"history": "[]",
"isNoneDoc": True,
}
try:
response = requests.post(
f"{self.base_url}/stream",
json=payload,
headers=self.headers,
stream=True,
timeout=60,
)
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
events = []
got_end = False
got_answer = False
for line in response.iter_lines():
if line:
line_str = line.decode("utf-8")
if line_str.startswith("data: "):
try:
data = json_module.loads(line_str[6:])
events.append(data)
if data.get("type") == "answer":
got_answer = True
if data.get("type") == "end":
got_end = True
break
except json_module.JSONDecodeError:
pass
checks = [
(len(events) > 0, f"received {len(events)} events"),
(got_answer, "got answer event"),
(got_end, "got end event"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_old_answer_endpoint_still_works(self) -> bool:
"""Verify the old /api/answer endpoint still works."""
test_name = "Backward compat - /api/answer"
self.print_header(f"Testing {test_name}")
payload = {
"question": "Say hi.",
"history": "[]",
"isNoneDoc": True,
}
try:
response = requests.post(
f"{self.base_url}/api/answer",
json=payload,
headers=self.headers,
timeout=60,
)
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
data = response.json()
checks = [
("answer" in data, "has answer"),
("conversation_id" in data, "has conversation_id"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
self.print_info(f"Answer: {data.get('answer', '')[:100]}")
self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression")
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Cleanup
# -------------------------------------------------------------------------
def cleanup(self):
"""Clean up test resources."""
if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated:
try:
self.post(f"/api/delete_agent?id={self._agent_id}", json={})
self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...")
except Exception:
pass
# -------------------------------------------------------------------------
# Run All
# -------------------------------------------------------------------------
def run_all(self) -> bool:
"""Run all v1 API integration tests."""
self.print_header("V1 Chat Completions API Integration Tests")
self.print_info(f"Base URL: {self.base_url}")
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
try:
# Auth tests (no agent needed)
self.test_no_auth_returns_401()
time.sleep(0.5)
self.test_models_no_auth()
time.sleep(0.5)
self.test_invalid_key_returns_error()
time.sleep(0.5)
self.test_missing_messages_returns_400()
time.sleep(0.5)
# Non-streaming
self.test_non_streaming_basic()
time.sleep(1)
# Streaming
self.test_streaming_basic()
time.sleep(1)
# Multi-turn
self.test_multi_turn_conversation()
time.sleep(1)
# Models
self.test_list_models()
time.sleep(0.5)
# Backward compatibility
self.test_old_stream_endpoint_still_works()
time.sleep(1)
self.test_old_answer_endpoint_still_works()
time.sleep(1)
finally:
self.cleanup()
return self.print_summary()
def main():
"""Main entry point."""
client = create_client_from_args(V1ApiTests, "DocsGPT V1 API Integration Tests")
success = client.run_all()
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,539 @@
#!/usr/bin/env python3
r"""
Integration tests for the /v1/ chat completions API — client tool-call flow.
Tests the full lifecycle:
1. Send request with client tools → LLM triggers a tool call
2. Verify response returns clean tool names (no internal _ct\d+ suffix)
3. Send continuation with tool results + top-level conversation_id
4. Verify the continuation completes successfully
Usage:
python tests/integration/test_v1_tool_calls.py
python tests/integration/test_v1_tool_calls.py --base-url http://localhost:7091
"""
import json as json_module
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import requests
_THIS_DIR = Path(__file__).parent
_TESTS_DIR = _THIS_DIR.parent
_ROOT_DIR = _TESTS_DIR.parent
if str(_ROOT_DIR) not in sys.path:
sys.path.insert(0, str(_ROOT_DIR))
from tests.integration.base import DocsGPTTestBase, create_client_from_args
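# Round-trip message sequence exercised by these tests (standard OpenAI
# tool-call shape; the only DocsGPT-specific piece is the top-level
# conversation_id sent alongside the messages):
#
#   user
#   -> assistant {"tool_calls": [{"id": ..., "type": "function",
#                                 "function": {"name": ..., "arguments": ...}}]}
#   -> tool {"tool_call_id": <id>, "content": <client-side result JSON>}
#   -> assistant (final answer)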
# Internal suffix pattern that should NOT appear in client responses
_CT_SUFFIX_RE = re.compile(r"_ct\d+$")
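# e.g. "create_ct0" and "search_ct12" match (internal); "create" does not.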
class V1ToolCallTests(DocsGPTTestBase):
"""Integration tests for /v1/ client tool-call flows."""
# -------------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------------
def get_or_create_agent_key(self) -> Optional[str]:
"""Get or create a test agent and return its API key."""
if hasattr(self, "_agent_key") and self._agent_key:
return self._agent_key
payload = {
"name": f"V1 ToolCall Test {int(time.time())}",
"description": "Integration test agent for tool-call flow",
"prompt_id": "default",
"chunks": 2,
"retriever": "classic",
"agent_type": "classic",
"status": "published",
"source": "default",
}
try:
response = self.post("/api/create_agent", json=payload, timeout=10)
if response.status_code in [200, 201]:
result = response.json()
api_key = result.get("key")
self._agent_id = result.get("id")
if api_key:
self._agent_key = api_key
self.print_info(f"Created test agent with key: {api_key[:8]}...")
return api_key
except Exception as e:
self.print_error(f"Failed to create agent: {e}")
return None
def _v1_headers(self, api_key: str) -> dict:
return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
# A simple client tool definition in OpenAI format
_CLIENT_TOOLS = [
{
"type": "function",
"function": {
"name": "create",
"description": "Create a new todo item",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "The title of the new todo item",
}
},
"required": ["title"],
},
},
}
]
def _send_streaming_request(
self,
api_key: str,
messages: List[Dict],
tools: Optional[List[Dict]] = None,
conversation_id: Optional[str] = None,
) -> Tuple[List[Dict], str, Optional[Dict]]:
"""Send a streaming request and collect all events.
Returns:
(all_chunks, full_content, tool_call_info)
tool_call_info is a dict with 'name', 'arguments', 'call_id'
(and 'conversation_id' when the stream supplied one) if the
response paused for a client tool call, else None.
"""
body: Dict[str, Any] = {
"messages": messages,
"stream": True,
}
if tools:
body["tools"] = tools
if conversation_id:
body["conversation_id"] = conversation_id
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json=body,
headers=self._v1_headers(api_key),
stream=True,
timeout=120,
)
if response.status_code != 200:
raise RuntimeError(
f"Expected 200, got {response.status_code}: {response.text[:300]}"
)
chunks: List[Dict] = []
content_pieces: List[str] = []
tool_call_info: Optional[Dict] = None
conversation_id_from_response: Optional[str] = None
for line in response.iter_lines():
if not line:
continue
line_str = line.decode("utf-8")
if not line_str.startswith("data: "):
continue
data_str = line_str[6:]
if data_str.strip() == "[DONE]":
break
try:
chunk = json_module.loads(data_str)
chunks.append(chunk)
# Standard chunks
if "choices" in chunk:
delta = chunk["choices"][0].get("delta", {})
if "content" in delta:
content_pieces.append(delta["content"])
# Tool call delta
if "tool_calls" in delta:
tc = delta["tool_calls"][0]
tool_call_info = {
"call_id": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"].get("arguments", "{}"),
}
# Extension chunks
if "docsgpt" in chunk:
ext = chunk["docsgpt"]
if ext.get("type") == "id":
conversation_id_from_response = ext.get("conversation_id")
except json_module.JSONDecodeError:
pass
full_content = "".join(content_pieces)
# Attach conversation_id to tool_call_info for convenience
if tool_call_info and conversation_id_from_response:
tool_call_info["conversation_id"] = conversation_id_from_response
return chunks, full_content, tool_call_info
def _send_non_streaming_request(
self,
api_key: str,
messages: List[Dict],
tools: Optional[List[Dict]] = None,
conversation_id: Optional[str] = None,
) -> Dict:
"""Send a non-streaming request and return parsed JSON."""
body: Dict[str, Any] = {
"messages": messages,
"stream": False,
}
if tools:
body["tools"] = tools
if conversation_id:
body["conversation_id"] = conversation_id
response = requests.post(
f"{self.base_url}/v1/chat/completions",
json=body,
headers=self._v1_headers(api_key),
timeout=120,
)
if response.status_code != 200:
raise RuntimeError(
f"Expected 200, got {response.status_code}: {response.text[:300]}"
)
return response.json()
# -------------------------------------------------------------------------
# Tests
# -------------------------------------------------------------------------
def test_streaming_tool_call_clean_name(self) -> bool:
"""Streaming: tool names returned to client must not have _ct suffixes."""
test_name = "v1 streaming tool call - clean name"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Test integration'. Call the tool now."},
]
chunks, content, tool_call_info = self._send_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
if not tool_call_info:
# LLM didn't trigger the tool — could happen, not a failure of our code
self.print_warning("LLM did not trigger a tool call (may need prompt tuning)")
self.print_info(f"Got text response instead: {content[:100]}")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
tool_name = tool_call_info["name"]
self.print_info(f"Tool call name: {tool_name}")
has_suffix = bool(_CT_SUFFIX_RE.search(tool_name))
if has_suffix:
self.print_error(f"Tool name has internal suffix: {tool_name}")
self.record_result(test_name, False, f"Suffix leak: {tool_name}")
return False
self.print_success(f"Tool name is clean: {tool_name}")
self.record_result(test_name, True, f"Clean name: {tool_name}")
return True
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_non_streaming_tool_call_clean_name(self) -> bool:
"""Non-streaming: tool names returned to client must not have _ct suffixes."""
test_name = "v1 non-streaming tool call - clean name"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Test non-stream'. Call the tool now."},
]
data = self._send_non_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
message = data["choices"][0]["message"]
tool_calls = message.get("tool_calls")
if not tool_calls:
content = message.get("content", "")
self.print_warning("LLM did not trigger a tool call")
self.print_info(f"Got text response: {content[:100]}")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
tool_name = tool_calls[0]["function"]["name"]
self.print_info(f"Tool call name: {tool_name}")
has_suffix = bool(_CT_SUFFIX_RE.search(tool_name))
if has_suffix:
self.print_error(f"Tool name has internal suffix: {tool_name}")
self.record_result(test_name, False, f"Suffix leak: {tool_name}")
return False
self.print_success(f"Tool name is clean: {tool_name}")
self.record_result(test_name, True, f"Clean name: {tool_name}")
return True
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool:
"""Full tool-call round-trip: trigger → get conversation_id → continue with top-level id."""
test_name = "v1 streaming tool continuation - top-level conversation_id"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
# Step 1: trigger a tool call
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Round trip test'. Call the tool now."},
]
chunks, content, tool_call_info = self._send_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
if not tool_call_info:
self.print_warning("LLM did not trigger a tool call")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
conversation_id = tool_call_info.get("conversation_id")
if not conversation_id:
self.print_error("No conversation_id returned in stream")
self.record_result(test_name, False, "Missing conversation_id")
return False
self.print_info(f"Got conversation_id: {conversation_id[:12]}...")
self.print_info(f"Tool call: {tool_call_info['name']}({tool_call_info['arguments']})")
# Step 2: send continuation with tool result + top-level conversation_id
# (standard OpenAI format — no docsgpt field in assistant message)
continuation_messages = [
*messages,
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": tool_call_info["call_id"],
"type": "function",
"function": {
"name": tool_call_info["name"],
"arguments": tool_call_info["arguments"],
},
}
],
},
{
"role": "tool",
"tool_call_id": tool_call_info["call_id"],
"content": json_module.dumps({"id": 99, "title": "Round trip test", "status": "created"}),
},
]
chunks2, content2, tool_call_info2 = self._send_streaming_request(
api_key,
continuation_messages,
tools=self._CLIENT_TOOLS,
conversation_id=conversation_id,
)
checks = [
(len(chunks2) > 0, f"continuation returned {len(chunks2)} chunks"),
(bool(content2) or tool_call_info2 is not None, "got content or another tool call"),
]
all_passed = True
for check, label in checks:
if check:
self.print_success(f" {label}")
else:
self.print_error(f" {label}")
all_passed = False
if content2:
self.print_info(f"Continuation response: {content2[:150]}")
self.record_result(
test_name,
all_passed,
"Full round-trip works" if all_passed else "Continuation failed",
)
return all_passed
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
def test_non_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool:
"""Non-streaming full round-trip with top-level conversation_id."""
test_name = "v1 non-streaming tool continuation - top-level conversation_id"
self.print_header(f"Testing {test_name}")
api_key = self.get_or_create_agent_key()
if not api_key:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
try:
# Step 1: trigger a tool call
messages = [
{"role": "user", "content": "Use the create tool to add a todo item titled 'Non-stream round trip'. Call the tool now."},
]
data = self._send_non_streaming_request(
api_key, messages, tools=self._CLIENT_TOOLS
)
message = data["choices"][0]["message"]
tool_calls = message.get("tool_calls")
if not tool_calls:
self.print_warning("LLM did not trigger a tool call")
self.record_result(test_name, True, "Skipped (LLM didn't call tool)")
return True
conversation_id = data.get("docsgpt", {}).get("conversation_id")
if not conversation_id:
self.print_error("No conversation_id in response")
self.record_result(test_name, False, "Missing conversation_id")
return False
tc = tool_calls[0]
self.print_info(f"Got tool call: {tc['function']['name']}")
self.print_info(f"conversation_id: {conversation_id[:12]}...")
# Step 2: send continuation (standard format, top-level conversation_id)
continuation_messages = [
*messages,
{
"role": "assistant",
"content": None,
"tool_calls": [tc],
},
{
"role": "tool",
"tool_call_id": tc["id"],
"content": json_module.dumps({"id": 100, "title": "Non-stream round trip", "status": "created"}),
},
]
data2 = self._send_non_streaming_request(
api_key,
continuation_messages,
tools=self._CLIENT_TOOLS,
conversation_id=conversation_id,
)
message2 = data2["choices"][0]["message"]
has_response = bool(message2.get("content")) or bool(message2.get("tool_calls"))
if has_response:
self.print_success("Continuation returned a response")
content2 = message2.get("content", "")
if content2:
self.print_info(f"Response: {content2[:150]}")
else:
self.print_error("Continuation returned empty response")
self.record_result(
test_name,
has_response,
"Round-trip works" if has_response else "Empty continuation response",
)
return has_response
except Exception as e:
self.print_error(f"Error: {e}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Cleanup & Run All
# -------------------------------------------------------------------------
def cleanup(self):
if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated:
try:
self.post(f"/api/delete_agent?id={self._agent_id}", json={})
self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...")
except Exception:
pass
def run_all(self) -> bool:
self.print_header("V1 Tool-Call Flow Integration Tests")
self.print_info(f"Base URL: {self.base_url}")
self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}")
try:
# Streaming tests
self.test_streaming_tool_call_clean_name()
time.sleep(1)
self.test_non_streaming_tool_call_clean_name()
time.sleep(1)
# Full round-trip tests
self.test_streaming_tool_continuation_with_top_level_conversation_id()
time.sleep(1)
self.test_non_streaming_tool_continuation_with_top_level_conversation_id()
time.sleep(1)
finally:
self.cleanup()
return self.print_summary()
def main():
client = create_client_from_args(V1ToolCallTests, "DocsGPT V1 Tool-Call Integration Tests")
success = client.run_all()
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()

View File

@@ -196,9 +196,9 @@ class TestGoogleLLMHandler:
assert result.finish_reason == "tool_calls"
def test_create_tool_message(self):
"""Test creating tool message."""
"""Test creating tool message in standard format."""
handler = GoogleLLMHandler()
tool_call = ToolCall(
id="call_123",
name="get_weather",
@@ -206,35 +206,26 @@ class TestGoogleLLMHandler:
index=0
)
result = {"temperature": "25C", "condition": "cloudy"}
message = handler.create_tool_message(tool_call, result)
expected = {
"role": "model",
"content": [
{
"function_response": {
"name": "get_weather",
"response": {"result": result},
}
}
],
}
assert message == expected
assert message["role"] == "tool"
assert message["tool_call_id"] == "call_123"
import json
assert json.loads(message["content"]) == result
def test_create_tool_message_string_result(self):
"""Test creating tool message with string result."""
handler = GoogleLLMHandler()
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
result = "2023-12-01 15:30:00 JST"
message = handler.create_tool_message(tool_call, result)
assert message["role"] == "model"
assert message["content"][0]["function_response"]["response"]["result"] == result
assert message["content"][0]["function_response"]["name"] == "get_time"
assert message["role"] == "tool"
assert message["tool_call_id"] == "call_456"
assert message["content"] == result
def test_iterate_stream(self):
"""Test stream iteration."""

View File

@@ -621,6 +621,7 @@ class TestHandleToolCalls:
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
@@ -641,7 +642,7 @@ class TestHandleToolCalls:
while True:
events.append(next(gen))
except StopIteration as e:
messages = e.value
messages, _pending = e.value
assert any(e.get("type") == "tool_call" for e in events)
assert len(messages) >= 2 # function_call + tool_message
@@ -675,6 +676,9 @@ class TestHandleToolCalls:
agent = Mock()
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
agent.tool_executor._name_to_tool = {}
agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error"))
call = ToolCall(id="c1", name="action_1", arguments="{}")
@@ -704,18 +708,17 @@ class TestHandleToolCalls:
while True:
next(gen)
except StopIteration as e:
messages = e.value
messages, _pending = e.value
# Standard format: thought_signature is on tool_calls items
assistant_msgs = [
m for m in messages
if m.get("role") == "assistant"
and isinstance(m.get("content"), list)
if m.get("role") == "assistant" and m.get("tool_calls")
]
assert any(
"thought_signature" in item
tc.get("thought_signature") == "sig"
for m in assistant_msgs
for item in m["content"]
if isinstance(item, dict)
for tc in m["tool_calls"]
)
@@ -751,6 +754,7 @@ class TestHandleNonStreaming:
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
# First response requires tool call, second is final
call_count = {"n": 0}
@@ -856,6 +860,7 @@ class TestHandleStreaming:
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
# First chunk has partial tool call, second completes it
chunk1 = LLMResponse(
@@ -907,6 +912,7 @@ class TestHandleStreaming:
agent.context_limit_reached = True
agent._check_context_limit = Mock(return_value=True)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
# Chunk finishes with tool_calls
chunk = LLMResponse(
@@ -929,7 +935,7 @@ class TestHandleStreaming:
def fake_handle_tool_calls(agent, calls, tools_dict, messages):
agent.context_limit_reached = True
yield {"type": "tool_call", "data": {"status": "skipped"}}
return messages
return messages, None
handler.handle_tool_calls = fake_handle_tool_calls
@@ -1501,6 +1507,7 @@ class TestHandleToolCallsCompressionSuccess:
agent._check_context_limit = Mock(side_effect=check_limit)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
@@ -1538,6 +1545,7 @@ class TestHandleToolCallsCompressionSuccess:
agent = Mock()
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
exec_count = {"n": 0}

View File

@@ -128,9 +128,9 @@ class TestOpenAILLMHandler:
assert result.finish_reason == ""
def test_create_tool_message(self):
"""Test creating tool message."""
"""Test creating tool message in standard format."""
handler = OpenAILLMHandler()
tool_call = ToolCall(
id="call_123",
name="get_weather",
@@ -138,36 +138,26 @@ class TestOpenAILLMHandler:
index=0
)
result = {"temperature": "72F", "condition": "sunny"}
message = handler.create_tool_message(tool_call, result)
expected = {
"role": "tool",
"content": [
{
"function_response": {
"name": "get_weather",
"response": {"result": result},
"call_id": "call_123",
}
}
],
}
assert message == expected
assert message["role"] == "tool"
assert message["tool_call_id"] == "call_123"
import json
assert json.loads(message["content"]) == result
def test_create_tool_message_string_result(self):
"""Test creating tool message with string result."""
handler = OpenAILLMHandler()
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
result = "2023-12-01 10:30:00"
message = handler.create_tool_message(tool_call, result)
assert message["role"] == "tool"
assert message["content"][0]["function_response"]["response"]["result"] == result
assert message["content"][0]["function_response"]["call_id"] == "call_456"
assert message["tool_call_id"] == "call_456"
assert message["content"] == result
def test_iterate_stream(self):
"""Test stream iteration."""

View File

@@ -478,11 +478,14 @@ class TestHandleToolCallsErrors:
handler = ConcreteHandler()
agent = MagicMock()
agent._check_context_limit = MagicMock(return_value=False)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent.tool_executor._name_to_tool = {"search": ("1", "search")}
agent._execute_tool_action = MagicMock(
side_effect=RuntimeError("tool failed")
)
tool_call = ToolCall(id="tc1", name="search_1", arguments={"q": "test"})
tool_call = ToolCall(id="tc1", name="search", arguments={"q": "test"})
tools_dict = {"1": {"name": "search_tool"}}
messages = [{"role": "user", "content": "hi"}]
@@ -506,6 +509,9 @@ class TestHandleToolCallsErrors:
handler = ConcreteHandler()
agent = MagicMock()
agent._check_context_limit = MagicMock(return_value=False)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent.tool_executor._name_to_tool = {}
agent._execute_tool_action = MagicMock(
side_effect=RuntimeError("tool failed")
)
@@ -1169,12 +1175,15 @@ class TestHandleToolCallsErrorsAdditional:
handler = ConcreteHandler()
agent = MagicMock()
agent._check_context_limit = MagicMock(return_value=False)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent.tool_executor._name_to_tool = {"do_thing": ("42", "do_thing")}
agent._execute_tool_action = MagicMock(
side_effect=RuntimeError("broken tool")
)
tool_call = ToolCall(
id="tc1", name="do_thing_42", arguments={"x": 1}
id="tc1", name="do_thing", arguments={"x": 1}
)
tools_dict = {"42": {"name": "my_tool"}}
messages = [{"role": "user", "content": "go"}]
@@ -1188,7 +1197,7 @@ class TestHandleToolCallsErrorsAdditional:
while True:
events.append(next(gen))
except StopIteration as e:
final_messages = e.value
final_messages, _pending = e.value
# Verify the error message was appended
error_msgs = [
@@ -1205,12 +1214,17 @@ class TestHandleToolCallsErrorsAdditional:
]
assert len(error_events) == 1
assert error_events[0]["data"]["tool_name"] == "my_tool"
assert error_events[0]["data"]["action_name"] == "do_thing_42"
assert error_events[0]["data"]["action_name"] == "do_thing"
def test_tool_error_with_no_context_check(self):
"""Cover line 660: messages.copy() at start of handle_tool_calls."""
handler = ConcreteHandler()
agent = MagicMock(spec=[]) # No _check_context_limit attribute
agent.llm = MagicMock()
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor = MagicMock()
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent.tool_executor._name_to_tool = {}
agent._execute_tool_action = MagicMock(
side_effect=ValueError("bad args")
)

View File

@@ -176,6 +176,9 @@ class TestLLMHandlerTokenTracking:
# Create mock agent that hits limit on second tool
mock_agent = Mock()
mock_agent.context_limit_reached = False
mock_agent.llm.__class__.__name__ = "MockLLM"
mock_agent.tool_executor.check_pause = Mock(return_value=None)
mock_agent.tool_executor._name_to_tool = {}
call_count = [0]
@@ -235,6 +238,9 @@ class TestLLMHandlerTokenTracking:
mock_agent = Mock()
mock_agent.context_limit_reached = False
mock_agent._check_context_limit = Mock(return_value=False)
mock_agent.llm.__class__.__name__ = "MockLLM"
mock_agent.tool_executor.check_pause = Mock(return_value=None)
mock_agent.tool_executor._name_to_tool = {}
mock_agent._execute_tool_action = Mock(
return_value=iter([{"type": "tool_call", "data": {}}])
)
@@ -300,7 +306,7 @@ class TestLLMHandlerTokenTracking:
def tool_handler_gen(*args):
yield {"type": "tool", "data": {}}
return []
return [], None
# Mock handle_tool_calls to return messages and set flag
with patch.object(

tests/test_client_tools.py
View File

@@ -0,0 +1,430 @@
"""Tests for client-side tools (Phase 2).
Covers merge_client_tools, prepare_tools_for_llm with client tools,
check_pause for client-side tools, and the full flow through the handler.
"""
from unittest.mock import Mock
import pytest
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
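# tools_dict entry produced by merge_client_tools, as inferred from the
# assertions in these tests (a sketch, not a schema):
#
#   {"ct0": {"name": "get_weather", "client_side": True,
#            "actions": [{"name": "get_weather", "active": True,
#                         "parameters": {...}}]}}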
# ---------------------------------------------------------------------------
# ToolExecutor.merge_client_tools
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestMergeClientTools:
def test_merge_single_tool(self):
executor = ToolExecutor()
tools_dict = {}
client_tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "City name"}
},
"required": ["city"],
},
},
}
]
result = executor.merge_client_tools(tools_dict, client_tools)
assert "ct0" in result
tool = result["ct0"]
assert tool["name"] == "get_weather"
assert tool["client_side"] is True
assert len(tool["actions"]) == 1
assert tool["actions"][0]["name"] == "get_weather"
assert tool["actions"][0]["active"] is True
assert "city" in tool["actions"][0]["parameters"]["properties"]
def test_merge_multiple_tools(self):
executor = ToolExecutor()
tools_dict = {"0": {"name": "existing_tool", "actions": []}}
client_tools = [
{"type": "function", "function": {"name": "tool_a", "description": "A"}},
{"type": "function", "function": {"name": "tool_b", "description": "B"}},
]
result = executor.merge_client_tools(tools_dict, client_tools)
# Original tool still present
assert "0" in result
# Client tools added
assert "ct0" in result
assert "ct1" in result
assert result["ct0"]["name"] == "tool_a"
assert result["ct1"]["name"] == "tool_b"
def test_merge_bare_format(self):
"""Accept simplified format without the outer 'function' wrapper."""
executor = ToolExecutor()
tools_dict = {}
client_tools = [
{"name": "simple_tool", "description": "Simple", "parameters": {}},
]
result = executor.merge_client_tools(tools_dict, client_tools)
assert "ct0" in result
assert result["ct0"]["name"] == "simple_tool"
def test_merge_preserves_existing_tools(self):
executor = ToolExecutor()
tools_dict = {
"abc123": {
"name": "brave",
"actions": [{"name": "search", "active": True}],
}
}
client_tools = [
{"type": "function", "function": {"name": "my_tool", "description": "D"}},
]
executor.merge_client_tools(tools_dict, client_tools)
assert "abc123" in tools_dict
assert tools_dict["abc123"]["name"] == "brave"
assert "ct0" in tools_dict
def test_merge_empty_list(self):
executor = ToolExecutor()
tools_dict = {"0": {"name": "existing"}}
executor.merge_client_tools(tools_dict, [])
assert len(tools_dict) == 1
# ---------------------------------------------------------------------------
# prepare_tools_for_llm with client tools
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestPrepareClientToolsForLlm:
def test_client_tools_included_in_llm_schema(self):
executor = ToolExecutor()
tools_dict = {}
client_tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"}
},
"required": ["city"],
},
},
}
]
executor.merge_client_tools(tools_dict, client_tools)
schemas = executor.prepare_tools_for_llm(tools_dict)
assert len(schemas) == 1
assert schemas[0]["type"] == "function"
assert schemas[0]["function"]["name"] == "get_weather"
assert schemas[0]["function"]["description"] == "Get weather"
# Parameters passed through directly (not filtered by _build_tool_parameters)
assert "city" in schemas[0]["function"]["parameters"]["properties"]
assert schemas[0]["function"]["parameters"]["required"] == ["city"]
def test_mixed_server_and_client_tools(self):
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "test_tool",
"actions": [
{
"name": "do_thing",
"description": "Does a thing",
"active": True,
"parameters": {
"properties": {
"query": {
"type": "string",
"filled_by_llm": True,
"required": True,
}
}
},
}
],
}
}
client_tools = [
{
"type": "function",
"function": {
"name": "local_fn",
"description": "Local function",
"parameters": {"type": "object", "properties": {}},
},
}
]
executor.merge_client_tools(tools_dict, client_tools)
schemas = executor.prepare_tools_for_llm(tools_dict)
assert len(schemas) == 2
names = {s["function"]["name"] for s in schemas}
assert "do_thing" in names
assert "local_fn" in names
# ---------------------------------------------------------------------------
# get_tools auto-merges client_tools
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGetToolsAutoMerge:
def test_get_tools_merges_client_tools(self, mock_mongo_db):
executor = ToolExecutor(user="alice")
executor.client_tools = [
{
"type": "function",
"function": {"name": "my_fn", "description": "test"},
}
]
tools = executor.get_tools()
assert any(
t.get("client_side") is True for t in tools.values()
), "Client tools should be merged into tools_dict"
def test_get_tools_no_client_tools(self, mock_mongo_db):
executor = ToolExecutor(user="alice")
tools = executor.get_tools()
assert not any(
t.get("client_side") for t in tools.values()
)
# ---------------------------------------------------------------------------
# check_pause for client-side tools
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCheckPauseClientTools:
def _make_call(self, name="action_0", call_id="c1"):
call = Mock()
call.name = name
call.id = call_id
call.arguments = "{}"
call.thought_signature = None
return call
def test_client_tool_triggers_pause(self):
executor = ToolExecutor()
tools_dict = {
"ct0": {
"name": "get_weather",
"client_side": True,
"actions": [
{"name": "get_weather", "active": True, "parameters": {}},
],
}
}
executor.prepare_tools_for_llm(tools_dict)
call = self._make_call(name="get_weather")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["pause_type"] == "requires_client_execution"
assert result["tool_name"] == "get_weather"
assert result["tool_id"] == "ct0"
def test_server_tool_no_pause(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "brave",
"actions": [
{"name": "search", "active": True, "parameters": {}},
],
}
}
executor.prepare_tools_for_llm(tools_dict)
call = self._make_call(name="search")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is None
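# A sketch of the client-side branch of check_pause as exercised above
# (hypothetical; the real lookup resolves the call through the name map
# built by prepare_tools_for_llm):
def _check_pause_client_sketch(tools_dict, call):
    for tool_id, tool in tools_dict.items():
        if tool.get("client_side") and any(
            a["name"] == call.name for a in tool.get("actions", [])
        ):
            return {
                "call_id": call.id,
                "tool_name": tool["name"],
                "tool_id": tool_id,
                "pause_type": "requires_client_execution",
            }
    return None  # server-side tools execute normally, no pause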
# ---------------------------------------------------------------------------
# Handler flow: client tool causes pause
# ---------------------------------------------------------------------------
class ConcreteHandler(LLMHandler):
"""Minimal concrete handler for testing."""
def parse_response(self, response):
return LLMResponse(
content=str(response), tool_calls=[], finish_reason="stop",
raw_response=response,
)
def create_tool_message(self, tool_call, result):
return {"role": "tool", "content": str(result)}
def _iterate_stream(self, response):
for chunk in response:
yield chunk
@pytest.mark.unit
class TestHandlerClientToolPause:
def test_client_tool_pauses_stream(self):
"""When LLM calls a client-side tool, handler yields tool_calls_pending."""
handler = ConcreteHandler()
agent = Mock()
agent.llm = Mock()
agent.model_id = "test"
agent.tools = []
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
# check_pause returns pause info for a client-side tool
agent.tool_executor.check_pause = Mock(return_value={
"call_id": "c1",
"name": "get_weather",
"tool_name": "get_weather",
"tool_id": "ct0",
"action_name": "get_weather",
"llm_name": "get_weather",
"arguments": {"city": "SF"},
"pause_type": "requires_client_execution",
"thought_signature": None,
})
agent.tool_executor._name_to_tool = {"get_weather": ("ct0", "get_weather")}
# Simulate streaming: one chunk with tool_calls finish_reason
chunk = LLMResponse(
content="",
tool_calls=[ToolCall(id="c1", name="get_weather", arguments='{"city": "SF"}', index=0)],
finish_reason="tool_calls",
raw_response={},
)
handler.parse_response = lambda c: c
handler._iterate_stream = lambda r: iter(r)
gen = handler.handle_streaming(
agent, [chunk], {"ct0": {"name": "get_weather", "client_side": True}}, []
)
events = list(gen)
# Should have a requires_client_execution event
client_events = [
e for e in events
if isinstance(e, dict)
and e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "requires_client_execution"
]
assert len(client_events) == 1
# Should have a tool_calls_pending event
pending_events = [
e for e in events
if isinstance(e, dict) and e.get("type") == "tool_calls_pending"
]
assert len(pending_events) == 1
def test_mixed_server_and_client_tools_in_batch(self):
"""Server tool executes, client tool pauses."""
handler = ConcreteHandler()
agent = Mock()
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
call_count = {"n": 0}
def check_pause_fn(tools_dict, call, llm_class):
call_count["n"] += 1
if call_count["n"] == 2: # Second tool is client-side
return {
"call_id": "c2",
"name": "get_weather",
"tool_name": "get_weather",
"tool_id": "ct0",
"action_name": "get_weather",
"llm_name": "get_weather",
"arguments": {},
"pause_type": "requires_client_execution",
"thought_signature": None,
}
return None
agent.tool_executor.check_pause = Mock(side_effect=check_pause_fn)
agent.tool_executor._name_to_tool = {
"search": ("0", "search"),
"get_weather": ("ct0", "get_weather"),
}
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
return ("server result", call.id)
agent._execute_tool_action = Mock(side_effect=fake_execute)
calls = [
ToolCall(id="c1", name="search", arguments="{}"),
ToolCall(id="c2", name="get_weather", arguments="{}"),
]
gen = handler.handle_tool_calls(
agent,
calls,
{
"0": {"name": "search"},
"ct0": {"name": "get_weather", "client_side": True},
},
[],
)
events = []
messages = None
pending = None
try:
while True:
events.append(next(gen))
except StopIteration as e:
messages, pending = e.value
# Server tool executed
assert agent._execute_tool_action.call_count == 1
# Client tool pending
assert pending is not None
assert len(pending) == 1
assert pending[0]["pause_type"] == "requires_client_execution"

667
tests/test_continuation.py Normal file
View File

@@ -0,0 +1,667 @@
"""Tests for the continuation infrastructure (Phase 1).
Covers ContinuationService, ToolExecutor.check_pause, handler pause
signaling, BaseAgent.gen_continuation, and request validation.
"""
from unittest.mock import Mock
import pytest
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
# ---------------------------------------------------------------------------
# ContinuationService
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestContinuationService:
def test_save_and_load(self, mock_mongo_db):
from application.api.answer.services.continuation_service import (
ContinuationService,
)
svc = ContinuationService()
svc.save_state(
conversation_id="conv-1",
user="alice",
messages=[{"role": "user", "content": "hi"}],
pending_tool_calls=[{"call_id": "c1", "pause_type": "awaiting_approval"}],
tools_dict={"0": {"name": "test_tool"}},
tool_schemas=[{"type": "function", "function": {"name": "act_0"}}],
agent_config={"model_id": "gpt-4"},
)
state = svc.load_state("conv-1", "alice")
assert state is not None
assert state["conversation_id"] == "conv-1"
assert state["user"] == "alice"
assert len(state["messages"]) == 1
assert len(state["pending_tool_calls"]) == 1
assert state["agent_config"]["model_id"] == "gpt-4"
def test_load_returns_none_when_missing(self, mock_mongo_db):
from application.api.answer.services.continuation_service import (
ContinuationService,
)
svc = ContinuationService()
assert svc.load_state("nonexistent", "alice") is None
def test_delete_state(self, mock_mongo_db):
from application.api.answer.services.continuation_service import (
ContinuationService,
)
svc = ContinuationService()
svc.save_state(
conversation_id="conv-2",
user="bob",
messages=[],
pending_tool_calls=[],
tools_dict={},
tool_schemas=[],
agent_config={},
)
assert svc.delete_state("conv-2", "bob") is True
assert svc.load_state("conv-2", "bob") is None
def test_delete_nonexistent(self, mock_mongo_db):
from application.api.answer.services.continuation_service import (
ContinuationService,
)
svc = ContinuationService()
assert svc.delete_state("nope", "nope") is False
def test_upsert_replaces_existing(self, mock_mongo_db):
from application.api.answer.services.continuation_service import (
ContinuationService,
)
svc = ContinuationService()
svc.save_state(
conversation_id="conv-3",
user="carol",
messages=[{"role": "user", "content": "v1"}],
pending_tool_calls=[],
tools_dict={},
tool_schemas=[],
agent_config={},
)
svc.save_state(
conversation_id="conv-3",
user="carol",
messages=[{"role": "user", "content": "v2"}],
pending_tool_calls=[{"call_id": "c2"}],
tools_dict={},
tool_schemas=[],
agent_config={},
)
state = svc.load_state("conv-3", "carol")
assert state["messages"][0]["content"] == "v2"
assert len(state["pending_tool_calls"]) == 1
# ---------------------------------------------------------------------------
# ToolExecutor.check_pause
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCheckPause:
def _make_call(self, name="action_0", call_id="c1", arguments="{}"):
call = Mock()
call.name = name
call.id = call_id
call.arguments = arguments
call.thought_signature = None
return call
def test_returns_none_for_normal_tool(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "brave",
"actions": [
{"name": "search", "active": True, "parameters": {}},
],
}
}
call = self._make_call(name="search_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is None
def test_returns_pause_for_client_side_tool(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "get_weather",
"client_side": True,
"actions": [
{"name": "get_weather", "active": True, "parameters": {}},
],
}
}
call = self._make_call(name="get_weather_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["pause_type"] == "requires_client_execution"
assert result["call_id"] == "c1"
assert result["tool_id"] == "0"
def test_returns_pause_for_approval_required(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "telegram",
"actions": [
{
"name": "send_msg",
"active": True,
"require_approval": True,
"parameters": {},
},
],
}
}
call = self._make_call(name="send_msg_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["pause_type"] == "awaiting_approval"
def test_returns_none_when_parse_fails(self):
executor = ToolExecutor()
call = self._make_call(name="bad_name_no_id", arguments="not json")
# Unparseable arguments cause a parse error, so check_pause returns None
result = executor.check_pause({}, call, "OpenAILLM")
assert result is None
def test_returns_none_when_tool_not_in_dict(self):
executor = ToolExecutor()
call = self._make_call(name="action_99")
result = executor.check_pause({"0": {"name": "t"}}, call, "OpenAILLM")
assert result is None
def test_api_tool_approval(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "api_tool",
"config": {
"actions": {
"delete_user": {
"name": "delete_user",
"require_approval": True,
"url": "http://example.com",
"method": "DELETE",
"active": True,
}
}
},
}
}
call = self._make_call(name="delete_user_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["pause_type"] == "awaiting_approval"
# ---------------------------------------------------------------------------
# Handler pause signaling (handle_tool_calls returns pending_actions)
# ---------------------------------------------------------------------------
class ConcreteHandler(LLMHandler):
"""Minimal concrete handler for testing."""
def parse_response(self, response):
return LLMResponse(
content=str(response), tool_calls=[], finish_reason="stop",
raw_response=response,
)
def create_tool_message(self, tool_call, result):
return {
"role": "tool",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
"call_id": tool_call.id,
}
}
],
}
def _iterate_stream(self, response):
for chunk in response:
yield chunk
@pytest.mark.unit
class TestHandlerPauseSignaling:
def _make_agent(self):
agent = Mock()
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
return ("tool result", call.id)
agent._execute_tool_action = Mock(side_effect=fake_execute)
return agent
def test_no_pause_returns_none_pending(self):
handler = ConcreteHandler()
agent = self._make_agent()
call = ToolCall(id="c1", name="action_0", arguments="{}")
gen = handler.handle_tool_calls(agent, [call], {"0": {"name": "t"}}, [])
events = []
messages = None
pending = "NOT_SET"
try:
while True:
events.append(next(gen))
except StopIteration as e:
messages, pending = e.value
assert pending is None
assert messages is not None
def test_pause_returns_pending_actions(self):
handler = ConcreteHandler()
agent = self._make_agent()
agent.tool_executor.check_pause = Mock(return_value={
"call_id": "c1",
"name": "send_msg_0",
"tool_name": "telegram",
"tool_id": "0",
"action_name": "send_msg",
"arguments": {"text": "hello"},
"pause_type": "awaiting_approval",
"thought_signature": None,
})
call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}')
gen = handler.handle_tool_calls(
agent, [call], {"0": {"name": "telegram"}}, []
)
events = []
pending = None
try:
while True:
events.append(next(gen))
except StopIteration as e:
messages, pending = e.value
assert pending is not None
assert len(pending) == 1
assert pending[0]["pause_type"] == "awaiting_approval"
# Should have yielded a tool_call event with awaiting_approval status
pause_events = [
e for e in events
if e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "awaiting_approval"
]
assert len(pause_events) == 1
def test_mixed_execute_and_pause(self):
"""One tool executes, another needs approval."""
handler = ConcreteHandler()
agent = self._make_agent()
call_count = {"n": 0}
def selective_pause(tools_dict, call, llm_class):
call_count["n"] += 1
if call_count["n"] == 2:
return {
"call_id": "c2",
"name": "danger_0",
"tool_name": "danger",
"tool_id": "0",
"action_name": "danger",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
return None
agent.tool_executor.check_pause = Mock(side_effect=selective_pause)
calls = [
ToolCall(id="c1", name="safe_0", arguments="{}"),
ToolCall(id="c2", name="danger_0", arguments="{}"),
]
gen = handler.handle_tool_calls(
agent, calls, {"0": {"name": "multi"}}, []
)
events = []
try:
while True:
events.append(next(gen))
except StopIteration as e:
messages, pending = e.value
# First tool was executed normally
assert agent._execute_tool_action.call_count == 1
# Second tool is pending
assert pending is not None
assert len(pending) == 1
assert pending[0]["call_id"] == "c2"
# ---------------------------------------------------------------------------
# handle_streaming yields tool_calls_pending
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestStreamingPause:
def test_streaming_yields_tool_calls_pending(self):
handler = ConcreteHandler()
agent = Mock()
agent.llm = Mock()
agent.model_id = "test"
agent.tools = []
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
pause_info = {
"call_id": "c1",
"name": "fn_0",
"tool_name": "test",
"tool_id": "0",
"action_name": "fn",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
agent.tool_executor.check_pause = Mock(return_value=pause_info)
chunk = LLMResponse(
content="",
tool_calls=[ToolCall(id="c1", name="fn_0", arguments="{}", index=0)],
finish_reason="tool_calls",
raw_response={},
)
handler.parse_response = lambda c: c
def fake_iterate(response):
yield from response
handler._iterate_stream = fake_iterate
gen = handler.handle_streaming(agent, [chunk], {"0": {"name": "t"}}, [])
events = list(gen)
# Should contain a tool_calls_pending event
pending_events = [
e for e in events
if isinstance(e, dict) and e.get("type") == "tool_calls_pending"
]
assert len(pending_events) == 1
assert len(pending_events[0]["data"]["pending_tool_calls"]) == 1
# Agent should have _pending_continuation set
assert hasattr(agent, "_pending_continuation")
# ---------------------------------------------------------------------------
# BaseAgent.gen_continuation
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGenContinuation:
def test_approved_tool_executes(self):
"""When a tool action is approved, the tool is executed."""
from application.agents.classic_agent import ClassicAgent
mock_llm = Mock()
mock_llm._supports_tools = True
mock_llm.gen_stream = Mock(return_value=iter(["Final answer"]))
mock_llm._supports_structured_output = Mock(return_value=False)
mock_llm.__class__.__name__ = "MockLLM"
mock_handler = Mock()
mock_handler.process_message_flow = Mock(return_value=iter([]))
mock_handler.create_tool_message = Mock(
return_value={"role": "tool", "content": [{"function_response": {
"name": "act_0", "response": {"result": "done"}, "call_id": "c1"
}}]}
)
mock_executor = Mock()
mock_executor.tool_calls = []
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
def fake_execute(tools_dict, call, llm_class):
yield {"type": "tool_call", "data": {"status": "pending"}}
return ("result_data", "c1")
mock_executor.execute = Mock(side_effect=fake_execute)
agent = ClassicAgent(
endpoint="stream",
llm_name="openai",
model_id="gpt-4",
api_key="test",
llm=mock_llm,
llm_handler=mock_handler,
tool_executor=mock_executor,
)
messages = [{"role": "system", "content": "You are helpful."}]
tools_dict = {"0": {"name": "test_tool"}}
pending = [
{
"call_id": "c1",
"name": "act_0",
"tool_name": "test_tool",
"tool_id": "0",
"action_name": "act",
"arguments": {"q": "test"},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
]
tool_actions = [{"call_id": "c1", "decision": "approved"}]
list(agent.gen_continuation(messages, tools_dict, pending, tool_actions))
# Tool should have been executed
assert mock_executor.execute.called
def test_denied_tool_sends_denial(self):
"""When a tool action is denied, a denial message is added."""
from application.agents.classic_agent import ClassicAgent
mock_llm = Mock()
mock_llm._supports_tools = True
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
mock_llm._supports_structured_output = Mock(return_value=False)
mock_llm.__class__.__name__ = "MockLLM"
mock_handler = Mock()
mock_handler.process_message_flow = Mock(return_value=iter([]))
mock_handler.create_tool_message = Mock(
return_value={"role": "tool", "content": "denied"}
)
mock_executor = Mock()
mock_executor.tool_calls = []
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
agent = ClassicAgent(
endpoint="stream",
llm_name="openai",
model_id="gpt-4",
api_key="test",
llm=mock_llm,
llm_handler=mock_handler,
tool_executor=mock_executor,
)
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "danger_0",
"tool_name": "danger",
"tool_id": "0",
"action_name": "danger",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
]
tool_actions = [
{"call_id": "c1", "decision": "denied", "comment": "too risky"}
]
events = list(
agent.gen_continuation(messages, {"0": {"name": "danger"}}, pending, tool_actions)
)
# Should have a denied tool_call event
denied = [
e for e in events
if isinstance(e, dict)
and e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "denied"
]
assert len(denied) == 1
# create_tool_message should have been called with denial text
denial_arg = mock_handler.create_tool_message.call_args[0][1]
assert "denied" in denial_arg.lower()
assert "too risky" in denial_arg
def test_client_result_appended(self):
"""Client-provided tool result is added to messages."""
from application.agents.classic_agent import ClassicAgent
mock_llm = Mock()
mock_llm._supports_tools = True
mock_llm.gen_stream = Mock(return_value=iter(["Done"]))
mock_llm._supports_structured_output = Mock(return_value=False)
mock_llm.__class__.__name__ = "MockLLM"
mock_handler = Mock()
mock_handler.process_message_flow = Mock(return_value=iter([]))
mock_handler.create_tool_message = Mock(
return_value={"role": "tool", "content": "client result"}
)
mock_executor = Mock()
mock_executor.tool_calls = []
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
agent = ClassicAgent(
endpoint="stream",
llm_name="openai",
model_id="gpt-4",
api_key="test",
llm=mock_llm,
llm_handler=mock_handler,
tool_executor=mock_executor,
)
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "weather_0",
"tool_name": "weather",
"tool_id": "0",
"action_name": "weather",
"arguments": {"city": "SF"},
"pause_type": "requires_client_execution",
"thought_signature": None,
}
]
tool_actions = [{"call_id": "c1", "result": {"temp": "72F"}}]
events = list(
agent.gen_continuation(messages, {"0": {"name": "weather"}}, pending, tool_actions)
)
# create_tool_message was called with the client result
result_arg = mock_handler.create_tool_message.call_args[0][1]
assert "72F" in result_arg
# Should have a completed tool_call event
completed = [
e for e in events
if isinstance(e, dict)
and e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "completed"
]
assert len(completed) == 1
# ---------------------------------------------------------------------------
# validate_request
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestValidateRequest:
@pytest.fixture(autouse=True)
def _app_context(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
yield
def test_continuation_request_without_question(self, mock_mongo_db):
from application.api.answer.routes.base import BaseAnswerResource
base = BaseAnswerResource()
data = {
"conversation_id": "conv-1",
"tool_actions": [{"call_id": "c1", "decision": "approved"}],
}
result = base.validate_request(data)
assert result is None # Valid
def test_continuation_request_missing_conversation_id(self, mock_mongo_db):
from application.api.answer.routes.base import BaseAnswerResource
base = BaseAnswerResource()
data = {
"tool_actions": [{"call_id": "c1", "decision": "approved"}],
}
result = base.validate_request(data)
assert result is not None # Error — missing conversation_id
def test_normal_request_still_requires_question(self, mock_mongo_db):
from application.api.answer.routes.base import BaseAnswerResource
base = BaseAnswerResource()
data = {"conversation_id": "conv-1"}
result = base.validate_request(data)
assert result is not None # Error — missing question
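# The validation rule these tests encode, as a sketch (a hypothetical
# helper, not the actual BaseAnswerResource code): a request carrying
# tool_actions is a continuation and needs a conversation_id instead of a
# question.
def _validate_sketch(data):
    if data.get("tool_actions"):
        return None if data.get("conversation_id") else "missing conversation_id"
    return None if data.get("question") else "missing question"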

View File

@@ -1238,7 +1238,8 @@ class TestEmbeddingPipelineAddDocWithRetry:
# NUL characters should be removed
assert "\x00" not in doc.page_content
def test_add_text_to_store_with_retry_failure(self):
@patch("time.sleep", return_value=None)
def test_add_text_to_store_with_retry_failure(self, _mock_sleep):
from application.parser.embedding_pipeline import add_text_to_store_with_retry
mock_store = MagicMock()

481
tests/test_tool_approval.py Normal file
View File

@@ -0,0 +1,481 @@
"""Tests for tool approval (Phase 3).
Covers the require_approval flag, check_pause for approval, the handler
pause/resume flow, and gen_continuation with approved/denied actions.
"""
from unittest.mock import Mock
import pytest
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
# ---------------------------------------------------------------------------
# check_pause with require_approval
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCheckPauseApproval:
def _make_call(self, name="action_0", call_id="c1"):
call = Mock()
call.name = name
call.id = call_id
call.arguments = "{}"
call.thought_signature = None
return call
def test_approval_required_triggers_pause(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "telegram",
"actions": [
{
"name": "send_msg",
"active": True,
"require_approval": True,
"parameters": {},
},
],
}
}
call = self._make_call(name="send_msg_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["pause_type"] == "awaiting_approval"
assert result["tool_name"] == "telegram"
assert result["action_name"] == "send_msg"
assert result["tool_id"] == "0"
def test_approval_not_required_no_pause(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "brave",
"actions": [
{
"name": "search",
"active": True,
"require_approval": False,
"parameters": {},
},
],
}
}
call = self._make_call(name="search_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is None
def test_approval_absent_defaults_to_false(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "brave",
"actions": [
{
"name": "search",
"active": True,
"parameters": {},
},
],
}
}
call = self._make_call(name="search_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is None
def test_api_tool_approval(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "api_tool",
"config": {
"actions": {
"delete_user": {
"name": "delete_user",
"require_approval": True,
"url": "http://example.com",
"method": "DELETE",
"active": True,
}
}
},
}
}
call = self._make_call(name="delete_user_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is not None
assert result["pause_type"] == "awaiting_approval"
def test_api_tool_no_approval(self):
executor = ToolExecutor()
tools_dict = {
"0": {
"name": "api_tool",
"config": {
"actions": {
"list_users": {
"name": "list_users",
"url": "http://example.com",
"method": "GET",
"active": True,
}
}
},
}
}
call = self._make_call(name="list_users_0")
result = executor.check_pause(tools_dict, call, "OpenAILLM")
assert result is None
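# A sketch of the approval lookup these tests describe (hypothetical): the
# flag lives on the action for regular tools, under config["actions"] for
# API tools, and defaults to False when absent.
def _requires_approval_sketch(tool, action_name):
    for action in tool.get("actions", []):
        if action["name"] == action_name:
            return bool(action.get("require_approval", False))
    api_actions = tool.get("config", {}).get("actions", {})
    if action_name in api_actions:
        return bool(api_actions[action_name].get("require_approval", False))
    return False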
# ---------------------------------------------------------------------------
# Handler: approval tool causes pause signal
# ---------------------------------------------------------------------------
class ConcreteHandler(LLMHandler):
def parse_response(self, response):
return LLMResponse(
content=str(response), tool_calls=[], finish_reason="stop",
raw_response=response,
)
def create_tool_message(self, tool_call, result):
import json as _json
content = _json.dumps(result) if not isinstance(result, str) else result
return {"role": "tool", "tool_call_id": tool_call.id, "content": content}
def _iterate_stream(self, response):
for chunk in response:
yield chunk
@pytest.mark.unit
class TestHandlerApprovalPause:
def _make_agent(self, pause_return):
agent = Mock()
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=pause_return)
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
return ("tool result", call.id)
agent._execute_tool_action = Mock(side_effect=fake_execute)
return agent
def test_approval_tool_pauses(self):
handler = ConcreteHandler()
pause_info = {
"call_id": "c1",
"name": "send_msg_0",
"tool_name": "telegram",
"tool_id": "0",
"action_name": "send_msg",
"arguments": {"text": "hello"},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
agent = self._make_agent(pause_info)
call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}')
gen = handler.handle_tool_calls(
agent, [call], {"0": {"name": "telegram"}}, []
)
events = []
pending = None
try:
while True:
events.append(next(gen))
except StopIteration as e:
messages, pending = e.value
assert pending is not None
assert len(pending) == 1
assert pending[0]["pause_type"] == "awaiting_approval"
# Should NOT have executed the tool
assert agent._execute_tool_action.call_count == 0
# Should have yielded awaiting_approval status
approval_events = [
e for e in events
if e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "awaiting_approval"
]
assert len(approval_events) == 1
def test_mixed_normal_and_approval(self):
"""First tool runs normally, second needs approval."""
handler = ConcreteHandler()
call_count = {"n": 0}
def selective_pause(tools_dict, call, llm_class):
call_count["n"] += 1
if call_count["n"] == 2:
return {
"call_id": "c2",
"name": "send_msg_0",
"tool_name": "telegram",
"tool_id": "0",
"action_name": "send_msg",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
return None
agent = self._make_agent(None)
agent.tool_executor.check_pause = Mock(side_effect=selective_pause)
calls = [
ToolCall(id="c1", name="search_0", arguments="{}"),
ToolCall(id="c2", name="send_msg_0", arguments="{}"),
]
gen = handler.handle_tool_calls(
agent, calls, {"0": {"name": "multi"}}, []
)
events = []
try:
while True:
events.append(next(gen))
except StopIteration as e:
messages, pending = e.value
# First tool executed
assert agent._execute_tool_action.call_count == 1
# Second tool is pending
assert pending is not None
assert len(pending) == 1
assert pending[0]["call_id"] == "c2"
# ---------------------------------------------------------------------------
# gen_continuation: approval and denial flows
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGenContinuationApproval:
def _make_agent(self):
from application.agents.classic_agent import ClassicAgent
mock_llm = Mock()
mock_llm._supports_tools = True
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
mock_llm._supports_structured_output = Mock(return_value=False)
mock_llm.__class__.__name__ = "MockLLM"
mock_handler = Mock()
mock_handler.process_message_flow = Mock(return_value=iter([]))
mock_handler.create_tool_message = Mock(
return_value={"role": "tool", "tool_call_id": "c1", "content": "result"}
)
mock_executor = Mock()
mock_executor.tool_calls = []
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
def fake_execute(tools_dict, call, llm_class):
yield {"type": "tool_call", "data": {"status": "pending"}}
return ("executed_result", "c1")
mock_executor.execute = Mock(side_effect=fake_execute)
agent = ClassicAgent(
endpoint="stream",
llm_name="openai",
model_id="gpt-4",
api_key="test",
llm=mock_llm,
llm_handler=mock_handler,
tool_executor=mock_executor,
)
return agent, mock_executor, mock_handler
def test_approved_tool_executes(self):
agent, mock_executor, mock_handler = self._make_agent()
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "send_msg_0",
"tool_name": "telegram",
"tool_id": "0",
"action_name": "send_msg",
"arguments": {"text": "hello"},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
]
tool_actions = [{"call_id": "c1", "decision": "approved"}]
list(agent.gen_continuation(
messages, {"0": {"name": "telegram"}}, pending, tool_actions
))
# Tool should have been executed
assert mock_executor.execute.called
def test_denied_tool_sends_denial_to_llm(self):
agent, mock_executor, mock_handler = self._make_agent()
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "send_msg_0",
"tool_name": "telegram",
"tool_id": "0",
"action_name": "send_msg",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
]
tool_actions = [
{"call_id": "c1", "decision": "denied", "comment": "not safe"},
]
events = list(agent.gen_continuation(
messages, {"0": {"name": "telegram"}}, pending, tool_actions
))
# Tool should NOT have been executed
assert not mock_executor.execute.called
# Should have a denied event
denied = [
e for e in events
if isinstance(e, dict)
and e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "denied"
]
assert len(denied) == 1
# create_tool_message should have been called with denial text
denial_text = mock_handler.create_tool_message.call_args[0][1]
assert "denied" in denial_text.lower()
assert "not safe" in denial_text
def test_denied_without_comment(self):
agent, mock_executor, mock_handler = self._make_agent()
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "act_0",
"tool_name": "tool",
"tool_id": "0",
"action_name": "act",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
]
tool_actions = [{"call_id": "c1", "decision": "denied"}]
list(agent.gen_continuation(
messages, {"0": {"name": "tool"}}, pending, tool_actions
))
denial_text = mock_handler.create_tool_message.call_args[0][1]
assert "denied" in denial_text.lower()
def test_mixed_approve_deny_batch(self):
"""Two tools: one approved, one denied."""
agent, mock_executor, mock_handler = self._make_agent()
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "safe_0",
"tool_name": "safe",
"tool_id": "0",
"action_name": "safe",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
},
{
"call_id": "c2",
"name": "danger_0",
"tool_name": "danger",
"tool_id": "0",
"action_name": "danger",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
},
]
tool_actions = [
{"call_id": "c1", "decision": "approved"},
{"call_id": "c2", "decision": "denied", "comment": "too risky"},
]
events = list(agent.gen_continuation(
messages, {"0": {"name": "multi"}}, pending, tool_actions
))
# First tool executed, second denied
assert mock_executor.execute.call_count == 1
denied = [
e for e in events
if isinstance(e, dict)
and e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "denied"
]
assert len(denied) == 1
def test_missing_action_defaults_to_denial(self):
"""If client doesn't respond for a pending tool, treat as denied."""
agent, mock_executor, mock_handler = self._make_agent()
messages = [{"role": "system", "content": "test"}]
pending = [
{
"call_id": "c1",
"name": "act_0",
"tool_name": "tool",
"tool_id": "0",
"action_name": "act",
"arguments": {},
"pause_type": "awaiting_approval",
"thought_signature": None,
}
]
# Empty tool_actions — no response for c1
tool_actions = []
events = list(agent.gen_continuation(
messages, {"0": {"name": "tool"}}, pending, tool_actions
))
# Should have been treated as denied
assert not mock_executor.execute.called
denied = [
e for e in events
if isinstance(e, dict)
and e.get("type") == "tool_call"
and e.get("data", {}).get("status") == "denied"
]
assert len(denied) == 1

577
tests/test_v1_translator.py Normal file
View File

@@ -0,0 +1,577 @@
"""Tests for the v1 API translator (Phase 4).
Covers request translation, response translation, streaming event
translation, continuation detection, and history conversion.
"""
import json
import pytest
from application.api.v1.translator import (
_get_client_tool_name,
convert_history,
extract_tool_results,
is_continuation,
translate_request,
translate_response,
translate_stream_event,
)
# ---------------------------------------------------------------------------
# is_continuation
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestIsContinuation:
def test_normal_messages_not_continuation(self):
messages = [
{"role": "user", "content": "Hello"},
]
assert is_continuation(messages) is False
def test_tool_after_assistant_tool_calls_is_continuation(self):
messages = [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}],
},
{"role": "tool", "tool_call_id": "c1", "content": '{"temp": "72F"}'},
]
assert is_continuation(messages) is True
def test_assistant_without_tool_calls_not_continuation(self):
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
{"role": "tool", "tool_call_id": "c1", "content": "result"},
]
# assistant has no tool_calls — not a valid continuation
assert is_continuation(messages) is False
def test_empty_messages(self):
assert is_continuation([]) is False
def test_multiple_tool_results(self):
messages = [
{"role": "user", "content": "Do stuff"},
{
"role": "assistant",
"tool_calls": [
{"id": "c1", "type": "function", "function": {"name": "a", "arguments": "{}"}},
{"id": "c2", "type": "function", "function": {"name": "b", "arguments": "{}"}},
],
},
{"role": "tool", "tool_call_id": "c1", "content": "r1"},
{"role": "tool", "tool_call_id": "c2", "content": "r2"},
]
assert is_continuation(messages) is True
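# A sketch of the detection rule these tests pin down (hypothetical): the
# last message must be a tool result and an earlier assistant message must
# actually carry tool_calls.
def _is_continuation_sketch(messages):
    if not messages or messages[-1].get("role") != "tool":
        return False
    return any(
        m.get("role") == "assistant" and m.get("tool_calls")
        for m in messages
    )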
# ---------------------------------------------------------------------------
# extract_tool_results
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestExtractToolResults:
def test_extracts_results(self):
messages = [
{"role": "assistant", "tool_calls": [{"id": "c1"}]},
{"role": "tool", "tool_call_id": "c1", "content": '{"temp": "72F"}'},
]
results = extract_tool_results(messages)
assert len(results) == 1
assert results[0]["call_id"] == "c1"
assert results[0]["result"] == {"temp": "72F"}
def test_string_content(self):
messages = [
{"role": "tool", "tool_call_id": "c1", "content": "plain text"},
]
results = extract_tool_results(messages)
assert results[0]["result"] == "plain text"
def test_multiple_results(self):
messages = [
{"role": "assistant", "tool_calls": []},
{"role": "tool", "tool_call_id": "c1", "content": "r1"},
{"role": "tool", "tool_call_id": "c2", "content": "r2"},
]
results = extract_tool_results(messages)
assert len(results) == 2
assert results[0]["call_id"] == "c1"
assert results[1]["call_id"] == "c2"
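# A sketch of the extraction rule (hypothetical): every tool message becomes
# a {"call_id", "result"} pair, with JSON content decoded when possible and
# plain strings passed through unchanged.
def _extract_tool_results_sketch(messages):
    results = []
    for m in messages:
        if m.get("role") != "tool":
            continue
        try:
            result = json.loads(m.get("content"))
        except (TypeError, ValueError):
            result = m.get("content")
        results.append({"call_id": m.get("tool_call_id"), "result": result})
    return results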
# ---------------------------------------------------------------------------
# convert_history
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestConvertHistory:
def test_user_assistant_pairs(self):
messages = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
{"role": "assistant", "content": "I'm good"},
{"role": "user", "content": "What's 2+2?"}, # Last user = question
]
history = convert_history(messages)
assert len(history) == 2
assert history[0]["prompt"] == "Hello"
assert history[0]["response"] == "Hi there"
assert history[1]["prompt"] == "How are you?"
assert history[1]["response"] == "I'm good"
def test_single_user_message(self):
messages = [{"role": "user", "content": "Hi"}]
history = convert_history(messages)
assert history == []
def test_system_messages_skipped(self):
messages = [
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "Question"},
]
history = convert_history(messages)
assert history == []
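# A sketch of the pairing rule (hypothetical): system messages are skipped,
# user/assistant turns are zipped into {"prompt", "response"} pairs, and a
# trailing unanswered user message is dropped (it becomes the question).
def _convert_history_sketch(messages):
    history, prompt = [], None
    for m in messages:
        if m["role"] == "user":
            prompt = m["content"]
        elif m["role"] == "assistant" and prompt is not None:
            history.append({"prompt": prompt, "response": m["content"]})
            prompt = None
    return history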
# ---------------------------------------------------------------------------
# translate_request
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTranslateRequest:
def test_normal_request(self):
data = {
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
{"role": "user", "content": "What's 2+2?"},
],
}
result = translate_request(data, "test-key")
assert result["question"] == "What's 2+2?"
assert result["api_key"] == "test-key"
assert result["save_conversation"] is True
history = json.loads(result["history"])
assert len(history) == 1
assert history[0]["prompt"] == "Hello"
def test_continuation_request(self):
data = {
"messages": [
{"role": "user", "content": "Search for X"},
{
"role": "assistant",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "search", "arguments": "{}"}}],
},
{"role": "tool", "tool_call_id": "c1", "content": '{"found": true}'},
],
}
result = translate_request(data, "key")
assert "tool_actions" in result
assert len(result["tool_actions"]) == 1
assert result["tool_actions"][0]["call_id"] == "c1"
def test_continuation_with_top_level_conversation_id(self):
"""Standard clients send conversation_id at request level, not in messages."""
data = {
"conversation_id": "conv-top-level",
"messages": [
{"role": "user", "content": "Do stuff"},
{
"role": "assistant",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}],
},
{"role": "tool", "tool_call_id": "c1", "content": "done"},
],
}
result = translate_request(data, "key")
assert result["conversation_id"] == "conv-top-level"
def test_continuation_in_message_conversation_id_takes_precedence(self):
"""When both in-message and top-level conversation_id exist, in-message wins."""
data = {
"conversation_id": "conv-top-level",
"messages": [
{"role": "user", "content": "Do stuff"},
{
"role": "assistant",
"tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}],
"docsgpt": {"conversation_id": "conv-in-message"},
},
{"role": "tool", "tool_call_id": "c1", "content": "done"},
],
}
result = translate_request(data, "key")
assert result["conversation_id"] == "conv-in-message"
def test_client_tools_passed_through(self):
data = {
"messages": [{"role": "user", "content": "Hi"}],
"tools": [{"type": "function", "function": {"name": "my_tool"}}],
}
result = translate_request(data, "key")
assert result["client_tools"] == data["tools"]
def test_docsgpt_attachments(self):
data = {
"messages": [{"role": "user", "content": "Hi"}],
"docsgpt": {"attachments": ["att1", "att2"]},
}
result = translate_request(data, "key")
assert result["attachments"] == ["att1", "att2"]
# ---------------------------------------------------------------------------
# translate_response
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTranslateResponse:
def test_basic_response(self):
resp = translate_response(
conversation_id="conv-1",
answer="Hello!",
sources=[],
tool_calls=[],
thought="",
model_name="my-agent",
)
assert resp["id"] == "chatcmpl-conv-1"
assert resp["object"] == "chat.completion"
assert resp["model"] == "my-agent"
assert resp["choices"][0]["message"]["content"] == "Hello!"
assert resp["choices"][0]["finish_reason"] == "stop"
assert "reasoning_content" not in resp["choices"][0]["message"]
def test_response_with_thought(self):
resp = translate_response(
conversation_id="c1",
answer="Result",
sources=[],
tool_calls=[],
thought="Thinking about it...",
model_name="agent",
)
assert resp["choices"][0]["message"]["reasoning_content"] == "Thinking about it..."
def test_response_with_sources(self):
sources = [{"title": "doc.txt", "text": "content", "source": "/doc.txt"}]
resp = translate_response(
conversation_id="c1",
answer="Found it",
sources=sources,
tool_calls=[],
thought="",
model_name="agent",
)
assert resp["docsgpt"]["sources"] == sources
def test_response_with_tool_calls(self):
tool_calls = [{"tool_name": "notes", "call_id": "c1", "artifact_id": "a1"}]
resp = translate_response(
conversation_id="c1",
answer="Done",
sources=[],
tool_calls=tool_calls,
thought="",
model_name="agent",
)
assert resp["docsgpt"]["tool_calls"] == tool_calls
def test_pending_tool_calls_uses_tool_name(self):
"""Client tool responses use the original tool_name, not the LLM-visible action_name."""
pending = [
{
"call_id": "c1",
"tool_name": "get_weather",
"action_name": "get_weather",
"arguments": {"city": "SF"},
}
]
resp = translate_response(
conversation_id="c1",
answer="",
sources=[],
tool_calls=[],
thought="",
model_name="agent",
pending_tool_calls=pending,
)
tc = resp["choices"][0]["message"]["tool_calls"][0]
assert tc["function"]["name"] == "get_weather"
def test_pending_tool_calls_tool_name_takes_precedence(self):
"""When tool_name differs from action_name, tool_name is used."""
pending = [
{
"call_id": "c1",
"tool_name": "search",
"action_name": "search_1",
"arguments": {"q": "test"},
}
]
resp = translate_response(
conversation_id="c1",
answer="",
sources=[],
tool_calls=[],
thought="",
model_name="agent",
pending_tool_calls=pending,
)
tc = resp["choices"][0]["message"]["tool_calls"][0]
assert tc["function"]["name"] == "search"
def test_pending_tool_calls(self):
pending = [
{
"call_id": "c1",
"name": "get_weather",
"arguments": {"city": "SF"},
}
]
resp = translate_response(
conversation_id="c1",
answer="",
sources=[],
tool_calls=[],
thought="",
model_name="agent",
pending_tool_calls=pending,
)
assert resp["choices"][0]["finish_reason"] == "tool_calls"
assert resp["choices"][0]["message"]["content"] is None
assert len(resp["choices"][0]["message"]["tool_calls"]) == 1
tc = resp["choices"][0]["message"]["tool_calls"][0]
assert tc["id"] == "c1"
assert tc["function"]["name"] == "get_weather"
def test_no_docsgpt_when_empty(self):
resp = translate_response(
conversation_id="",
answer="Hi",
sources=None,
tool_calls=None,
thought="",
model_name="agent",
)
assert "docsgpt" not in resp
# ---------------------------------------------------------------------------
# translate_stream_event
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTranslateStreamEvent:
def test_answer_event(self):
chunks = translate_stream_event(
{"type": "answer", "answer": "Hello"},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["choices"][0]["delta"]["content"] == "Hello"
def test_thought_event(self):
chunks = translate_stream_event(
{"type": "thought", "thought": "reasoning"},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["choices"][0]["delta"]["reasoning_content"] == "reasoning"
def test_source_event(self):
chunks = translate_stream_event(
{"type": "source", "source": [{"title": "t", "text": "x"}]},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["docsgpt"]["type"] == "source"
assert len(parsed["docsgpt"]["sources"]) == 1
def test_end_event(self):
chunks = translate_stream_event(
{"type": "end"},
"chatcmpl-1", "agent",
)
assert len(chunks) == 2
# First chunk: finish_reason stop
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["choices"][0]["finish_reason"] == "stop"
# Second chunk: [DONE]
assert chunks[1].strip() == "data: [DONE]"
def test_tool_call_client_execution(self):
chunks = translate_stream_event(
{
"type": "tool_call",
"data": {
"call_id": "c1",
"action_name": "get_weather",
"arguments": {"city": "SF"},
"status": "requires_client_execution",
},
},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
tc = parsed["choices"][0]["delta"]["tool_calls"][0]
assert tc["id"] == "c1"
assert tc["function"]["name"] == "get_weather"
def test_tool_call_client_execution_uses_tool_name(self):
"""Streaming tool calls use tool_name (original name) for client responses."""
chunks = translate_stream_event(
{
"type": "tool_call",
"data": {
"call_id": "c1",
"tool_name": "create",
"action_name": "create",
"arguments": {"title": "test"},
"status": "requires_client_execution",
},
},
"chatcmpl-1", "agent",
)
parsed = json.loads(chunks[0].replace("data: ", "").strip())
tc = parsed["choices"][0]["delta"]["tool_calls"][0]
assert tc["function"]["name"] == "create"
def test_tool_call_completed(self):
chunks = translate_stream_event(
{
"type": "tool_call",
"data": {
"call_id": "c1",
"status": "completed",
"result": "done",
"artifact_id": "a1",
},
},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["docsgpt"]["type"] == "tool_call"
assert parsed["docsgpt"]["data"]["artifact_id"] == "a1"
def test_tool_calls_pending(self):
chunks = translate_stream_event(
{
"type": "tool_calls_pending",
"data": {"pending_tool_calls": [{"call_id": "c1"}]},
},
"chatcmpl-1", "agent",
)
assert len(chunks) == 2
# Standard chunk with finish_reason tool_calls
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["choices"][0]["finish_reason"] == "tool_calls"
# Extension chunk
ext = json.loads(chunks[1].replace("data: ", "").strip())
assert ext["docsgpt"]["type"] == "tool_calls_pending"
def test_id_event(self):
chunks = translate_stream_event(
{"type": "id", "id": "conv-123"},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["docsgpt"]["conversation_id"] == "conv-123"
def test_error_event(self):
chunks = translate_stream_event(
{"type": "error", "error": "Something went wrong"},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["error"]["message"] == "Something went wrong"
def test_tool_calls_event_skipped(self):
"""The aggregate tool_calls event is redundant and should be skipped."""
chunks = translate_stream_event(
{"type": "tool_calls", "tool_calls": [{"call_id": "c1"}]},
"chatcmpl-1", "agent",
)
assert len(chunks) == 0
def test_research_events_skipped(self):
assert translate_stream_event(
{"type": "research_plan", "data": {}}, "id", "m"
) == []
assert translate_stream_event(
{"type": "research_progress", "data": {}}, "id", "m"
) == []
def test_awaiting_approval_as_extension(self):
chunks = translate_stream_event(
{
"type": "tool_call",
"data": {"call_id": "c1", "status": "awaiting_approval"},
},
"chatcmpl-1", "agent",
)
assert len(chunks) == 1
parsed = json.loads(chunks[0].replace("data: ", "").strip())
assert parsed["docsgpt"]["type"] == "tool_call"
def test_standard_clients_can_ignore_docsgpt(self):
"""Standard clients parse only 'choices' — docsgpt namespace is ignored."""
chunks = translate_stream_event(
{"type": "source", "source": [{"title": "t"}]},
"chatcmpl-1", "agent",
)
parsed = json.loads(chunks[0].replace("data: ", "").strip())
# No "choices" key — standard parsers skip this chunk entirely
assert "choices" not in parsed
# docsgpt key is present
assert "docsgpt" in parsed
# ---------------------------------------------------------------------------
# _get_client_tool_name
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGetClientToolName:
def test_uses_tool_name_when_present(self):
assert _get_client_tool_name({"tool_name": "create", "action_name": "create_1"}) == "create"
def test_falls_back_to_action_name(self):
assert _get_client_tool_name({"action_name": "get_weather"}) == "get_weather"
def test_falls_back_to_name(self):
assert _get_client_tool_name({"name": "search"}) == "search"
def test_returns_empty_when_no_fields(self):
assert _get_client_tool_name({}) == ""
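# The fallback chain these tests pin down, as a one-line sketch:
def _get_client_tool_name_sketch(d):
    return d.get("tool_name") or d.get("action_name") or d.get("name") or ""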