diff --git a/application/agents/base.py b/application/agents/base.py
index 15735c8c..596caa5a 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -1,7 +1,8 @@
+import json
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import Dict, Generator, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 from application.agents.tool_executor import ToolExecutor
 from application.core.json_schema_utils import (
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
     normalize_json_schema_payload,
 )
 from application.core.settings import settings
+from application.llm.handlers.base import ToolCall
 from application.llm.handlers.handler_creator import LLMHandlerCreator
 from application.llm.llm_creator import LLMCreator
 from application.logging import build_stack_data, log_activity, LogContext
@@ -113,6 +115,153 @@ class BaseAgent(ABC):
     ) -> Generator[Dict, None, None]:
         pass
 
+    def gen_continuation(
+        self,
+        messages: List[Dict],
+        tools_dict: Dict,
+        pending_tool_calls: List[Dict],
+        tool_actions: List[Dict],
+    ) -> Generator[Dict, None, None]:
+        """Resume generation after tool actions are resolved.
+
+        Processes the client-provided *tool_actions* (approvals, denials,
+        or client-side results), appends the resulting messages, then
+        hands back to the LLM to continue the conversation.
+
+        Args:
+            messages: The saved messages array from the pause point.
+            tools_dict: The saved tools dictionary.
+            pending_tool_calls: The pending tool call descriptors from the pause.
+            tool_actions: Client-provided actions resolving the pending calls.
+        """
+        self._prepare_tools(tools_dict)
+
+        actions_by_id = {a["call_id"]: a for a in tool_actions}
+
+        # Build a single assistant message containing all tool calls so
+        # the message history matches the format LLM providers expect
+        # (one assistant message with N tool_calls, followed by N tool results).
+        tc_objects: List[Dict[str, Any]] = []
+        for pending in pending_tool_calls:
+            call_id = pending["call_id"]
+            args = pending["arguments"]
+            args_str = (
+                json.dumps(args) if isinstance(args, dict) else (args or "{}")
+            )
+            tc_obj: Dict[str, Any] = {
+                "id": call_id,
+                "type": "function",
+                "function": {
+                    "name": pending["name"],
+                    "arguments": args_str,
+                },
+            }
+            if pending.get("thought_signature"):
+                tc_obj["thought_signature"] = pending["thought_signature"]
+            tc_objects.append(tc_obj)
+
+        messages.append({
+            "role": "assistant",
+            "content": None,
+            "tool_calls": tc_objects,
+        })
+
+        # Now process each pending call and append tool result messages
+        for pending in pending_tool_calls:
+            call_id = pending["call_id"]
+            args = pending["arguments"]
+            action = actions_by_id.get(call_id)
+            if not action:
+                action = {
+                    "call_id": call_id,
+                    "decision": "denied",
+                    "comment": "No response provided",
+                }
+
+            if action.get("decision") == "approved":
+                # Execute the tool server-side
+                tc = ToolCall(
+                    id=call_id,
+                    name=pending["name"],
+                    arguments=(
+                        json.dumps(args) if isinstance(args, dict) else args
+                    ),
+                )
+                tool_gen = self._execute_tool_action(tools_dict, tc)
+                tool_response = None
+                while True:
+                    try:
+                        event = next(tool_gen)
+                        yield event
+                    except StopIteration as e:
+                        tool_response, _ = e.value
+                        break
+                messages.append(
+                    self.llm_handler.create_tool_message(tc, tool_response)
+                )
+
+            elif action.get("decision") == "denied":
+                comment = action.get("comment", "")
+                denial = (
+                    f"Tool execution denied by user. Reason: {comment}"
+                    if comment
+                    else "Tool execution denied by user."
+                )
+                tc = ToolCall(
+                    id=call_id, name=pending["name"], arguments=args
+                )
+                messages.append(
+                    self.llm_handler.create_tool_message(tc, denial)
+                )
+                yield {
+                    "type": "tool_call",
+                    "data": {
+                        "tool_name": pending.get("tool_name", "unknown"),
+                        "call_id": call_id,
+                        "action_name": pending.get("llm_name", pending["name"]),
+                        "arguments": args,
+                        "status": "denied",
+                    },
+                }
+
+            elif "result" in action:
+                result = action["result"]
+                result_str = (
+                    json.dumps(result)
+                    if not isinstance(result, str)
+                    else result
+                )
+                tc = ToolCall(
+                    id=call_id, name=pending["name"], arguments=args
+                )
+                messages.append(
+                    self.llm_handler.create_tool_message(tc, result_str)
+                )
+                yield {
+                    "type": "tool_call",
+                    "data": {
+                        "tool_name": pending.get("tool_name", "unknown"),
+                        "call_id": call_id,
+                        "action_name": pending.get("llm_name", pending["name"]),
+                        "arguments": args,
+                        "result": (
+                            result_str[:50] + "..."
+                            if len(result_str) > 50
+                            else result_str
+                        ),
+                        "status": "completed",
+                    },
+                }
+
+        # Resume the LLM loop with the updated messages
+        llm_response = self._llm_gen(messages)
+        yield from self._handle_response(
+            llm_response, tools_dict, messages, None
+        )
+
+        yield {"sources": self.retrieved_docs}
+        yield {"tool_calls": self._get_truncated_tool_calls()}
+
     # ---- Tool delegation (thin wrappers around ToolExecutor) ----
 
     @property
@@ -267,28 +416,35 @@
             if "tool_calls" in i:
                 for tool_call in i["tool_calls"]:
                     call_id = tool_call.get("call_id") or str(uuid.uuid4())
-
-                    function_call_dict = {
-                        "function_call": {
-                            "name": tool_call.get("action_name"),
-                            "args": tool_call.get("arguments"),
-                            "call_id": call_id,
-                        }
-                    }
-                    function_response_dict = {
-                        "function_response": {
-                            "name": tool_call.get("action_name"),
-                            "response": {"result": tool_call.get("result")},
-                            "call_id": call_id,
-                        }
-                    }
-
-                    messages.append(
-                        {"role": "assistant", "content": [function_call_dict]}
+                    args = tool_call.get("arguments")
+                    args_str = (
+                        json.dumps(args)
+                        if isinstance(args, dict)
+                        else (args or "{}")
                     )
-                    messages.append(
-                        {"role": "tool", "content": [function_response_dict]}
+                    messages.append({
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [{
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": tool_call.get("action_name", ""),
+                                "arguments": args_str,
+                            },
+                        }],
+                    })
+                    result = tool_call.get("result")
+                    result_str = (
+                        json.dumps(result)
+                        if not isinstance(result, str)
+                        else (result or "")
                     )
+                    messages.append({
+                        "role": "tool",
+                        "tool_call_id": call_id,
+                        "content": result_str,
+                    })
 
         messages.append({"role": "user", "content": query})
         return messages
diff --git a/application/agents/research_agent.py b/application/agents/research_agent.py
index 280fa2cd..9c7212b1 100644
--- a/application/agents/research_agent.py
+++ b/application/agents/research_agent.py
@@ -593,16 +593,22 @@ class ResearchAgent(BaseAgent):
             )
             result = result_str
 
-        function_call_content = {
-            "function_call": {
-                "name": call.name,
-                "args": call.arguments,
-                "call_id": call_id,
-            }
-        }
-        messages.append(
-            {"role": "assistant", "content": [function_call_content]}
+        import json as _json
+
+        args_str = (
+            _json.dumps(call.arguments)
+            if isinstance(call.arguments, dict)
+            else call.arguments
         )
+        messages.append({
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [{
+                "id": call_id,
+                "type": "function",
+                "function": {"name": call.name, "arguments": args_str},
+            }],
+        })
 
         tool_message = self.llm_handler.create_tool_message(call, result)
         messages.append(tool_message)
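With both `BaseAgent._build_messages` and `ResearchAgent` emitting the same shape, a single completed call is replayed into history as one assistant message plus one tool message. A minimal sketch of that pair (the IDs and payloads here are illustrative):

```python
import json

call_id = "call_abc123"  # illustrative id
history_pair = [
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": call_id,
            "type": "function",
            "function": {"name": "get_weather", "arguments": json.dumps({"city": "Paris"})},
        }],
    },
    # one tool result message per tool_call, linked by tool_call_id
    {"role": "tool", "tool_call_id": call_id, "content": json.dumps({"temp_c": 18})},
]
```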
diff --git a/application/agents/tool_executor.py b/application/agents/tool_executor.py
index 69739076..11095b10 100644
--- a/application/agents/tool_executor.py
+++ b/application/agents/tool_executor.py
@@ -1,6 +1,7 @@
 import logging
 import uuid
-from typing import Dict, List, Optional
+from collections import Counter
+from typing import Dict, List, Optional, Tuple
 
 from bson.objectid import ObjectId
 
@@ -31,12 +32,23 @@ class ToolExecutor:
         self.tool_calls: List[Dict] = []
         self._loaded_tools: Dict[str, object] = {}
         self.conversation_id: Optional[str] = None
+        self.client_tools: Optional[List[Dict]] = None
+        self._name_to_tool: Dict[str, Tuple[str, str]] = {}
+        self._tool_to_name: Dict[Tuple[str, str], str] = {}
 
     def get_tools(self) -> Dict[str, Dict]:
-        """Load tool configs from DB based on user context."""
+        """Load tool configs from DB based on user context.
+
+        If *client_tools* have been set on this executor, they are
+        automatically merged into the returned dict.
+        """
         if self.user_api_key:
-            return self._get_tools_by_api_key(self.user_api_key)
-        return self._get_user_tools(self.user or "local")
+            tools = self._get_tools_by_api_key(self.user_api_key)
+        else:
+            tools = self._get_user_tools(self.user or "local")
+        if self.client_tools:
+            self.merge_client_tools(tools, self.client_tools)
+        return tools
 
     def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
         mongo = MongoDB.get_client()
@@ -65,29 +77,123 @@ class ToolExecutor:
         user_tools = list(user_tools)
         return {str(i): tool for i, tool in enumerate(user_tools)}
 
-    def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
-        """Convert tool configs to LLM function schemas."""
-        return [
-            {
-                "type": "function",
-                "function": {
-                    "name": f"{action['name']}_{tool_id}",
-                    "description": action["description"],
-                    "parameters": self._build_tool_parameters(action),
-                },
+    def merge_client_tools(
+        self, tools_dict: Dict, client_tools: List[Dict]
+    ) -> Dict:
+        """Merge client-provided tool definitions into tools_dict.
+
+        Client tools use the standard function-calling format::
+
+            [{"type": "function", "function": {"name": "get_weather",
+              "description": "...", "parameters": {...}}}]
+
+        They are stored in *tools_dict* with ``client_side: True`` so that
+        :meth:`check_pause` returns a pause signal instead of trying to
+        execute them server-side.
+
+        Args:
+            tools_dict: The mutable server tools dict (will be modified in place).
+            client_tools: List of tool definitions in function-calling format.
+
+        Returns:
+            The updated *tools_dict* (same reference, for convenience).
+        """
+        for i, ct in enumerate(client_tools):
+            func = ct.get("function", ct)  # tolerate bare {"name": ...} too
+            name = func.get("name", f"clienttool{i}")
+            tool_id = f"ct{i}"
+
+            tools_dict[tool_id] = {
+                "name": name,
+                "client_side": True,
+                "actions": [
+                    {
+                        "name": name,
+                        "description": func.get("description", ""),
+                        "active": True,
+                        "parameters": func.get("parameters", {}),
+                    }
+                ],
             }
-            for tool_id, tool in tools_dict.items()
-            if (
-                (tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
-                or (tool["name"] != "api_tool" and "actions" in tool)
-            )
-            for action in (
+        return tools_dict
+
+    def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
+        """Convert tool configs to LLM function schemas.
+
+        Action names are kept clean for the LLM:
+        - Unique action names appear as-is (e.g. ``get_weather``).
+        - Duplicate action names get numbered suffixes (e.g. ``search_1``,
+          ``search_2``).
+
+        A reverse mapping is stored in ``_name_to_tool`` so that tool calls
+        can be routed back to the correct ``(tool_id, action_name)`` without
+        brittle string splitting.
+        """
+        # Pass 1: collect entries and count action name occurrences
+        entries: List[Tuple[str, str, Dict, bool]] = []  # (tool_id, action_name, action, is_client)
+        name_counts: Counter = Counter()
+
+        for tool_id, tool in tools_dict.items():
+            is_api = tool["name"] == "api_tool"
+            is_client = tool.get("client_side", False)
+
+            if is_api and "actions" not in tool.get("config", {}):
+                continue
+            if not is_api and "actions" not in tool:
+                continue
+
+            actions = (
                 tool["config"]["actions"].values()
-                if tool["name"] == "api_tool"
+                if is_api
                 else tool["actions"]
             )
-            if action.get("active", True)
-        ]
+
+            for action in actions:
+                if not action.get("active", True):
+                    continue
+                entries.append((tool_id, action["name"], action, is_client))
+                name_counts[action["name"]] += 1
+
+        # Pass 2: assign LLM-visible names and build mappings
+        self._name_to_tool = {}
+        self._tool_to_name = {}
+        collision_counters: Dict[str, int] = {}
+        all_llm_names: set = set()
+
+        result = []
+        for tool_id, action_name, action, is_client in entries:
+            if name_counts[action_name] == 1:
+                llm_name = action_name
+            else:
+                counter = collision_counters.get(action_name, 1)
+                candidate = f"{action_name}_{counter}"
+                # Skip if candidate collides with a unique action name
+                while candidate in all_llm_names or (
+                    candidate in name_counts and name_counts[candidate] == 1
+                ):
+                    counter += 1
+                    candidate = f"{action_name}_{counter}"
+                collision_counters[action_name] = counter + 1
+                llm_name = candidate
+
+            all_llm_names.add(llm_name)
+            self._name_to_tool[llm_name] = (tool_id, action_name)
+            self._tool_to_name[(tool_id, action_name)] = llm_name
+
+            if is_client:
+                params = action.get("parameters", {})
+            else:
+                params = self._build_tool_parameters(action)
+
+            result.append({
+                "type": "function",
+                "function": {
+                    "name": llm_name,
+                    "description": action.get("description", ""),
+                    "parameters": params,
+                },
+            })
+        return result
 
     def _build_tool_parameters(self, action: Dict) -> Dict:
         params = {"type": "object", "properties": {}, "required": []}
@@ -104,23 +210,81 @@ class ToolExecutor:
                 params["required"].append(k)
         return params
+ """ + parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool) + tool_id, action_name, call_args = parser.parse_args(call) + call_id = getattr(call, "id", None) or str(uuid.uuid4()) + llm_name = getattr(call, "name", "") + + if tool_id is None or action_name is None or tool_id not in tools_dict: + return None # Will be handled as error by execute() + + tool_data = tools_dict[tool_id] + + # Client-side tools + if tool_data.get("client_side"): + return { + "call_id": call_id, + "name": llm_name, + "tool_name": tool_data.get("name", "unknown"), + "tool_id": tool_id, + "action_name": action_name, + "llm_name": llm_name, + "arguments": call_args if isinstance(call_args, dict) else {}, + "pause_type": "requires_client_execution", + "thought_signature": getattr(call, "thought_signature", None), + } + + # Approval required + if tool_data["name"] == "api_tool": + action_data = tool_data.get("config", {}).get("actions", {}).get( + action_name, {} + ) + else: + action_data = next( + (a for a in tool_data.get("actions", []) if a["name"] == action_name), + {}, + ) + + if action_data.get("require_approval"): + return { + "call_id": call_id, + "name": llm_name, + "tool_name": tool_data.get("name", "unknown"), + "tool_id": tool_id, + "action_name": action_name, + "llm_name": llm_name, + "arguments": call_args if isinstance(call_args, dict) else {}, + "pause_type": "awaiting_approval", + "thought_signature": getattr(call, "thought_signature", None), + } + + return None + def execute(self, tools_dict: Dict, call, llm_class_name: str): """Execute a tool call. Yields status events, returns (result, call_id).""" - parser = ToolActionParser(llm_class_name) + parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool) tool_id, action_name, call_args = parser.parse_args(call) + llm_name = getattr(call, "name", "unknown") call_id = getattr(call, "id", None) or str(uuid.uuid4()) if tool_id is None or action_name is None: - error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}" + error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}" logger.error(error_message) tool_call_data = { "tool_name": "unknown", "call_id": call_id, - "action_name": getattr(call, "name", "unknown"), + "action_name": llm_name, "arguments": call_args or {}, - "result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}", + "result": f"Failed to parse tool call. Invalid tool name format: {llm_name}", } yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}} self.tool_calls.append(tool_call_data) @@ -133,7 +297,7 @@ class ToolExecutor: tool_call_data = { "tool_name": "unknown", "call_id": call_id, - "action_name": f"{action_name}_{tool_id}", + "action_name": llm_name, "arguments": call_args, "result": f"Tool with ID {tool_id} not found. 
                "result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
             }
@@ -144,7 +308,7 @@
         tool_call_data = {
             "tool_name": tools_dict[tool_id]["name"],
             "call_id": call_id,
-            "action_name": f"{action_name}_{tool_id}",
+            "action_name": llm_name,
             "arguments": call_args,
         }
         yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
diff --git a/application/agents/tools/base.py b/application/agents/tools/base.py
index fd7b4a85..dfe8c85d 100644
--- a/application/agents/tools/base.py
+++ b/application/agents/tools/base.py
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
 
 
 class Tool(ABC):
+    internal: bool = False
+
     @abstractmethod
     def execute_action(self, action_name: str, **kwargs):
         pass
diff --git a/application/agents/tools/internal_search.py b/application/agents/tools/internal_search.py
index 2cd7915b..78001bf5 100644
--- a/application/agents/tools/internal_search.py
+++ b/application/agents/tools/internal_search.py
@@ -20,6 +20,8 @@ class InternalSearchTool(Tool):
     - list_files action: browse the file/folder structure
     """
 
+    internal = True
+
     def __init__(self, config: Dict):
         self.config = config
         self.retrieved_docs: List[Dict] = []
diff --git a/application/agents/tools/think.py b/application/agents/tools/think.py
index 7c1fc2b6..24af553e 100644
--- a/application/agents/tools/think.py
+++ b/application/agents/tools/think.py
@@ -36,6 +36,8 @@ class ThinkTool(Tool):
     The reasoning content is captured in tool_call data for transparency.
     """
 
+    internal = True
+
     def __init__(self, config=None):
         pass
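The `internal` flag gives built-in tools like `ThinkTool` a way to opt out of user-facing discovery while still subclassing `Tool`. A minimal sketch of a new internal tool (the class and return value are hypothetical, and this assumes `execute_action` is the only abstract method that must be implemented):

```python
from application.agents.tools.base import Tool


class ScratchpadTool(Tool):
    """Hypothetical example: usable in code, hidden from ToolManager discovery."""

    internal = True  # ToolManager.load_tools() skips classes with internal=True

    def __init__(self, config=None):
        self.notes = []

    def execute_action(self, action_name: str, **kwargs):
        self.notes.append(kwargs)
        return "noted"
```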
diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py
index aff7d2e2..2c31ad45 100644
--- a/application/agents/tools/tool_action_parser.py
+++ b/application/agents/tools/tool_action_parser.py
@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
 
 
 class ToolActionParser:
-    def __init__(self, llm_type):
+    def __init__(self, llm_type, name_mapping=None):
         self.llm_type = llm_type
+        self.name_mapping = name_mapping
         self.parsers = {
             "OpenAILLM": self._parse_openai_llm,
             "GoogleLLM": self._parse_google_llm,
@@ -16,22 +17,33 @@ class ToolActionParser:
         parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
         return parser(call)
 
+    def _resolve_via_mapping(self, call_name):
+        """Look up (tool_id, action_name) from the name mapping if available."""
+        if self.name_mapping and call_name in self.name_mapping:
+            return self.name_mapping[call_name]
+        return None
+
     def _parse_openai_llm(self, call):
         try:
             call_args = json.loads(call.arguments)
+
+            resolved = self._resolve_via_mapping(call.name)
+            if resolved:
+                return resolved[0], resolved[1], call_args
+
+            # Fallback: legacy split on "_" for backward compatibility
             tool_parts = call.name.split("_")
-            # If the tool name doesn't contain an underscore, it's likely a hallucinated tool
             if len(tool_parts) < 2:
                 logger.warning(
-                    f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
+                    f"Invalid tool name format: {call.name}. "
+                    "Could not resolve via mapping or legacy parsing."
                 )
                 return None, None, None
             tool_id = tool_parts[-1]
             action_name = "_".join(tool_parts[:-1])
 
-            # Validate that tool_id looks like a numerical ID
             if not tool_id.isdigit():
                 logger.warning(
                     f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
@@ -45,19 +57,24 @@ class ToolActionParser:
     def _parse_google_llm(self, call):
         try:
             call_args = call.arguments
+
+            resolved = self._resolve_via_mapping(call.name)
+            if resolved:
+                return resolved[0], resolved[1], call_args
+
+            # Fallback: legacy split on "_" for backward compatibility
             tool_parts = call.name.split("_")
-            # If the tool name doesn't contain an underscore, it's likely a hallucinated tool
             if len(tool_parts) < 2:
                 logger.warning(
-                    f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
+                    f"Invalid tool name format: {call.name}. "
+                    "Could not resolve via mapping or legacy parsing."
                )
                 return None, None, None
             tool_id = tool_parts[-1]
             action_name = "_".join(tool_parts[:-1])
 
-            # Validate that tool_id looks like a numerical ID
             if not tool_id.isdigit():
                 logger.warning(
                     f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py
index 08ef30a4..41970eac 100644
--- a/application/agents/tools/tool_manager.py
+++ b/application/agents/tools/tool_manager.py
@@ -19,7 +19,7 @@ class ToolManager:
                 continue
             module = importlib.import_module(f"application.agents.tools.{name}")
             for member_name, obj in inspect.getmembers(module, inspect.isclass):
-                if issubclass(obj, Tool) and obj is not Tool:
+                if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
                     tool_config = self.config.get(name, {})
                     self.tools[name] = obj(tool_config)
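The name mapping built by `prepare_tools_for_llm` is what lets the parser route clean names without the legacy `action_tool_id` split. A quick sketch of both paths (mapping hit and legacy fallback; the names and IDs are illustrative):

```python
from types import SimpleNamespace

from application.agents.tools.tool_action_parser import ToolActionParser

# Mapping shape as built by ToolExecutor.prepare_tools_for_llm
name_mapping = {"search_1": ("0", "search"), "search_2": ("ct0", "search")}
parser = ToolActionParser("OpenAILLM", name_mapping=name_mapping)

call = SimpleNamespace(name="search_1", arguments='{"q": "docs"}')
print(parser.parse_args(call))      # ('0', 'search', {'q': 'docs'})

legacy = SimpleNamespace(name="get_weather_3", arguments="{}")
print(parser.parse_args(legacy))    # ('3', 'get_weather', {}) via the legacy split
```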
diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py
index f3111605..5fa7199f 100644
--- a/application/api/answer/routes/answer.py
+++ b/application/api/answer/routes/answer.py
@@ -74,57 +74,72 @@ class AnswerResource(Resource, BaseAnswerResource):
         decoded_token = getattr(request, "decoded_token", None)
         processor = StreamProcessor(data, decoded_token)
         try:
-            agent = processor.build_agent(data.get("question", ""))
-            if not processor.decoded_token:
-                return make_response({"error": "Unauthorized"}, 401)
+            # ---- Continuation mode ----
+            if data.get("tool_actions"):
+                (
+                    agent,
+                    messages,
+                    tools_dict,
+                    pending_tool_calls,
+                    tool_actions,
+                ) = processor.resume_from_tool_actions(
+                    data["tool_actions"], data["conversation_id"]
+                )
+                stream = self.complete_stream(
+                    question="",
+                    agent=agent,
+                    conversation_id=processor.conversation_id,
+                    user_api_key=processor.agent_config.get("user_api_key"),
+                    decoded_token=processor.decoded_token,
+                    agent_id=processor.agent_id,
+                    model_id=processor.model_id,
+                    _continuation={
+                        "messages": messages,
+                        "tools_dict": tools_dict,
+                        "pending_tool_calls": pending_tool_calls,
+                        "tool_actions": tool_actions,
+                    },
+                )
+            else:
+                # ---- Normal mode ----
+                agent = processor.build_agent(data.get("question", ""))
+                if not processor.decoded_token:
+                    return make_response({"error": "Unauthorized"}, 401)
 
-            if error := self.check_usage(processor.agent_config):
-                return error
+                if error := self.check_usage(processor.agent_config):
+                    return error
+
+                stream = self.complete_stream(
+                    question=data["question"],
+                    agent=agent,
+                    conversation_id=processor.conversation_id,
+                    user_api_key=processor.agent_config.get("user_api_key"),
+                    decoded_token=processor.decoded_token,
+                    isNoneDoc=data.get("isNoneDoc"),
+                    index=None,
+                    should_save_conversation=data.get("save_conversation", True),
+                    agent_id=processor.agent_id,
+                    is_shared_usage=processor.is_shared_usage,
+                    shared_token=processor.shared_token,
+                    model_id=processor.model_id,
+                )
 
-            stream = self.complete_stream(
-                question=data["question"],
-                agent=agent,
-                conversation_id=processor.conversation_id,
-                user_api_key=processor.agent_config.get("user_api_key"),
-                decoded_token=processor.decoded_token,
-                isNoneDoc=data.get("isNoneDoc"),
-                index=None,
-                should_save_conversation=data.get("save_conversation", True),
-                agent_id=processor.agent_id,
-                is_shared_usage=processor.is_shared_usage,
-                shared_token=processor.shared_token,
-                model_id=processor.model_id,
-            )
             stream_result = self.process_response_stream(stream)
-            if len(stream_result) == 7:
-                (
-                    conversation_id,
-                    response,
-                    sources,
-                    tool_calls,
-                    thought,
-                    error,
-                    structured_info,
-                ) = stream_result
-            else:
-                conversation_id, response, sources, tool_calls, thought, error = (
-                    stream_result
-                )
-                structured_info = None
+            if stream_result["error"]:
+                return make_response({"error": stream_result["error"]}, 400)
 
-            if error:
-                return make_response({"error": error}, 400)
             result = {
-                "conversation_id": conversation_id,
-                "answer": response,
-                "sources": sources,
-                "tool_calls": tool_calls,
-                "thought": thought,
+                "conversation_id": stream_result["conversation_id"],
+                "answer": stream_result["answer"],
+                "sources": stream_result["sources"],
+                "tool_calls": stream_result["tool_calls"],
+                "thought": stream_result["thought"],
             }
-            if structured_info:
-                result.update(structured_info)
+            extra_info = stream_result.get("extra")
+            if extra_info:
+                result.update(extra_info)
         except Exception as e:
             logger.error(
                 f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
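On the wire, resolving a pause via `/api/answer` is a second POST carrying `tool_actions` instead of `question`. A hedged sketch (URL, auth, and IDs are illustrative; the `decision`/`result` fields follow `gen_continuation` above):

```python
import requests

payload = {
    "conversation_id": "66f0c2...",  # conversation that paused (illustrative)
    "tool_actions": [
        {"call_id": "call_1", "decision": "approved"},                 # run server-side
        {"call_id": "call_2", "decision": "denied", "comment": "no"},  # refuse
        {"call_id": "call_3", "result": {"temp_c": 18}},               # executed client-side
    ],
}
resp = requests.post("http://localhost:7091/api/answer", json=payload)
print(resp.json()["answer"])
```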
diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py
index 29d3d66c..4a152b2a 100644
--- a/application/api/answer/routes/base.py
+++ b/application/api/answer/routes/base.py
@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
 from flask import jsonify, make_response, Response
 from flask_restx import Namespace
 
+from application.api.answer.services.continuation_service import ContinuationService
 from application.api.answer.services.conversation_service import ConversationService
 from application.core.model_utils import (
     get_api_key_for_provider,
@@ -39,7 +40,16 @@ class BaseAnswerResource:
     def validate_request(
         self, data: Dict[str, Any], require_conversation_id: bool = False
     ) -> Optional[Response]:
-        """Common request validation"""
+        """Common request validation.
+
+        Continuation requests (``tool_actions`` present) require
+        ``conversation_id`` but not ``question``.
+        """
+        if data.get("tool_actions"):
+            # Continuation mode — question is not required
+            if missing := check_required_fields(data, ["conversation_id"]):
+                return missing
+            return None
         required_fields = ["question"]
         if require_conversation_id:
             required_fields.append("conversation_id")
@@ -177,6 +187,7 @@ class BaseAnswerResource:
         is_shared_usage: bool = False,
         shared_token: Optional[str] = None,
         model_id: Optional[str] = None,
+        _continuation: Optional[Dict] = None,
     ) -> Generator[str, None, None]:
         """
         Generator function that streams the complete conversation response.
@@ -207,8 +218,19 @@ class BaseAnswerResource:
         schema_info = None
         structured_chunks = []
         query_metadata = {}
+        paused = False
 
-        for line in agent.gen(query=question):
+        if _continuation:
+            gen_iter = agent.gen_continuation(
+                messages=_continuation["messages"],
+                tools_dict=_continuation["tools_dict"],
+                pending_tool_calls=_continuation["pending_tool_calls"],
+                tool_actions=_continuation["tool_actions"],
+            )
+        else:
+            gen_iter = agent.gen(query=question)
+
+        for line in gen_iter:
             if "metadata" in line:
                 query_metadata.update(line["metadata"])
             elif "answer" in line:
@@ -244,15 +266,21 @@ class BaseAnswerResource:
                 data = json.dumps({"type": "thought", "thought": line["thought"]})
                 yield f"data: {data}\n\n"
             elif "type" in line:
-                if line.get("type") == "error":
+                if line.get("type") == "tool_calls_pending":
+                    # Save continuation state and end the stream
+                    paused = True
+                    data = json.dumps(line)
+                    yield f"data: {data}\n\n"
+                elif line.get("type") == "error":
                     sanitized_error = {
                         "type": "error",
                         "error": sanitize_api_error(line.get("error", "An error occurred"))
                     }
                     data = json.dumps(sanitized_error)
+                    yield f"data: {data}\n\n"
                 else:
                     data = json.dumps(line)
-                yield f"data: {data}\n\n"
+                    yield f"data: {data}\n\n"
 
         if is_structured and structured_chunks:
             structured_data = {
                 "type": "structured_answer",
@@ -262,6 +290,93 @@
             }
             data = json.dumps(structured_data)
             yield f"data: {data}\n\n"
+
+        # ---- Paused: save continuation state and end stream early ----
+        if paused:
+            continuation = getattr(agent, "_pending_continuation", None)
+            if continuation:
+                # Ensure we have a conversation_id — create a partial
+                # conversation if this is the first turn.
+                if not conversation_id and should_save_conversation:
+                    try:
+                        provider = (
+                            get_provider_from_model_id(model_id)
+                            if model_id
+                            else settings.LLM_PROVIDER
+                        )
+                        sys_api_key = get_api_key_for_provider(
+                            provider or settings.LLM_PROVIDER
+                        )
+                        llm = LLMCreator.create_llm(
+                            provider or settings.LLM_PROVIDER,
+                            api_key=sys_api_key,
+                            user_api_key=user_api_key,
+                            decoded_token=decoded_token,
+                            model_id=model_id,
+                            agent_id=agent_id,
+                        )
+                        conversation_id = (
+                            self.conversation_service.save_conversation(
+                                None,
+                                question,
+                                response_full,
+                                thought,
+                                source_log_docs,
+                                tool_calls,
+                                llm,
+                                model_id or self.default_model_id,
+                                decoded_token,
+                                api_key=user_api_key,
+                                agent_id=agent_id,
+                                is_shared_usage=is_shared_usage,
+                                shared_token=shared_token,
+                            )
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to create conversation for continuation: {e}",
+                            exc_info=True,
+                        )
+
+                if conversation_id:
+                    try:
+                        cont_service = ContinuationService()
+                        cont_service.save_state(
+                            conversation_id=str(conversation_id),
+                            user=decoded_token.get("sub", "local"),
+                            messages=continuation["messages"],
+                            pending_tool_calls=continuation["pending_tool_calls"],
+                            tools_dict=continuation["tools_dict"],
+                            tool_schemas=getattr(agent, "tools", []),
+                            agent_config={
+                                "model_id": model_id or self.default_model_id,
+                                "llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
+                                "api_key": getattr(agent, "api_key", None),
+                                "user_api_key": user_api_key,
+                                "agent_id": agent_id,
+                                "agent_type": agent.__class__.__name__,
+                                "prompt": getattr(agent, "prompt", ""),
+                                "json_schema": getattr(agent, "json_schema", None),
+                                "retriever_config": getattr(agent, "retriever_config", None),
+                            },
+                            client_tools=getattr(
+                                agent.tool_executor, "client_tools", None
+                            ),
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to save continuation state: {str(e)}",
+                            exc_info=True,
+                        )
+
+            id_data = {"type": "id", "id": str(conversation_id)}
+            data = json.dumps(id_data)
+            yield f"data: {data}\n\n"
+
+            data = json.dumps({"type": "end"})
+            yield f"data: {data}\n\n"
+            return
+
         if isNoneDoc:
             for doc in source_log_docs:
                 doc["source"] = "None"
@@ -425,8 +540,13 @@
             yield f"data: {data}\n\n"
             return
 
-    def process_response_stream(self, stream):
-        """Process the stream response for non-streaming endpoint"""
+    def process_response_stream(self, stream) -> Dict[str, Any]:
+        """Process the stream response for non-streaming endpoint.
+
+        Returns:
+            Dict with keys: conversation_id, answer, sources, tool_calls,
+            thought, error, and optional extra.
+        """
         conversation_id = ""
         response_full = ""
         source_log_docs = []
@@ -435,6 +555,7 @@ class BaseAnswerResource:
         stream_ended = False
         is_structured = False
         schema_info = None
+        pending_tool_calls = None
 
         for line in stream:
             try:
@@ -453,11 +574,22 @@ class BaseAnswerResource:
                     source_log_docs = event["source"]
                 elif event["type"] == "tool_calls":
                     tool_calls = event["tool_calls"]
+                elif event["type"] == "tool_calls_pending":
+                    pending_tool_calls = event.get("data", {}).get(
+                        "pending_tool_calls", []
+                    )
                 elif event["type"] == "thought":
                     thought = event["thought"]
                 elif event["type"] == "error":
                     logger.error(f"Error from stream: {event['error']}")
-                    return None, None, None, None, event["error"], None
+                    return {
+                        "conversation_id": None,
+                        "answer": None,
+                        "sources": None,
+                        "tool_calls": None,
+                        "thought": None,
+                        "error": event["error"],
+                    }
                 elif event["type"] == "end":
                     stream_ended = True
             except (json.JSONDecodeError, KeyError) as e:
@@ -465,18 +597,30 @@ class BaseAnswerResource:
                 continue
         if not stream_ended:
             logger.error("Stream ended unexpectedly without an 'end' event.")
-            return None, None, None, None, "Stream ended unexpectedly", None
-        result = (
-            conversation_id,
-            response_full,
-            source_log_docs,
-            tool_calls,
-            thought,
-            None,
-        )
+            return {
+                "conversation_id": None,
+                "answer": None,
+                "sources": None,
+                "tool_calls": None,
+                "thought": None,
+                "error": "Stream ended unexpectedly",
+            }
+
+        result: Dict[str, Any] = {
+            "conversation_id": conversation_id,
+            "answer": response_full,
+            "sources": source_log_docs,
+            "tool_calls": tool_calls,
+            "thought": thought,
+            "error": None,
+        }
+
+        if pending_tool_calls is not None:
+            result["extra"] = {"pending_tool_calls": pending_tool_calls}
         if is_structured:
-            result = result + ({"structured": True, "schema": schema_info},)
+            result["extra"] = {"structured": True, "schema": schema_info}
+
         return result
 
     def error_stream_generate(self, err_response):
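A streaming client sees the pause as a terminal `tool_calls_pending` event followed by `id` and `end`. A minimal client-side dispatch sketch (event names follow the routes above; everything else is illustrative):

```python
import json

def handle_sse_line(line: str, state: dict) -> None:
    if not line.startswith("data: "):
        return
    event = json.loads(line[len("data: "):])
    if event.get("type") == "tool_calls_pending":
        # The stream ends after this; resolve these calls, then
        # re-POST with conversation_id + tool_actions to resume.
        state["pending"] = event["data"]["pending_tool_calls"]
    elif event.get("type") == "id":
        state["conversation_id"] = event["id"]
```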
diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py
index c2cc0ec6..2a6b9a11 100644
--- a/application/api/answer/routes/stream.py
+++ b/application/api/answer/routes/stream.py
@@ -79,7 +79,39 @@ class StreamResource(Resource, BaseAnswerResource):
             return error
         decoded_token = getattr(request, "decoded_token", None)
         processor = StreamProcessor(data, decoded_token)
+
         try:
+            # ---- Continuation mode ----
+            if data.get("tool_actions"):
+                (
+                    agent,
+                    messages,
+                    tools_dict,
+                    pending_tool_calls,
+                    tool_actions,
+                ) = processor.resume_from_tool_actions(
+                    data["tool_actions"], data["conversation_id"]
+                )
+                return Response(
+                    self.complete_stream(
+                        question="",
+                        agent=agent,
+                        conversation_id=processor.conversation_id,
+                        user_api_key=processor.agent_config.get("user_api_key"),
+                        decoded_token=processor.decoded_token,
+                        agent_id=processor.agent_id,
+                        model_id=processor.model_id,
+                        _continuation={
+                            "messages": messages,
+                            "tools_dict": tools_dict,
+                            "pending_tool_calls": pending_tool_calls,
+                            "tool_actions": tool_actions,
+                        },
+                    ),
+                    mimetype="text/event-stream",
+                )
+
+            # ---- Normal mode ----
             agent = processor.build_agent(data["question"])
             if not processor.decoded_token:
                 return Response(
diff --git a/application/api/answer/services/compression/message_builder.py b/application/api/answer/services/compression/message_builder.py
index 93772fe5..0c54e4e4 100644
--- a/application/api/answer/services/compression/message_builder.py
+++ b/application/api/answer/services/compression/message_builder.py
@@ -1,5 +1,6 @@
 """Message reconstruction utilities for compression."""
 
+import json
 import logging
 import uuid
 from typing import Dict, List, Optional
@@ -49,28 +50,35 @@ class MessageBuilder:
         if include_tool_calls and "tool_calls" in query:
             for tool_call in query["tool_calls"]:
                 call_id = tool_call.get("call_id") or str(uuid.uuid4())
-
-                function_call_dict = {
-                    "function_call": {
-                        "name": tool_call.get("action_name"),
-                        "args": tool_call.get("arguments"),
-                        "call_id": call_id,
-                    }
-                }
-                function_response_dict = {
-                    "function_response": {
-                        "name": tool_call.get("action_name"),
-                        "response": {"result": tool_call.get("result")},
-                        "call_id": call_id,
-                    }
-                }
-
-                messages.append(
-                    {"role": "assistant", "content": [function_call_dict]}
+                args = tool_call.get("arguments")
+                args_str = (
+                    json.dumps(args)
+                    if isinstance(args, dict)
+                    else (args or "{}")
                 )
-                messages.append(
-                    {"role": "tool", "content": [function_response_dict]}
+                messages.append({
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [{
+                        "id": call_id,
+                        "type": "function",
+                        "function": {
+                            "name": tool_call.get("action_name", ""),
+                            "arguments": args_str,
+                        },
+                    }],
+                })
+                result = tool_call.get("result")
+                result_str = (
+                    json.dumps(result)
+                    if not isinstance(result, str)
+                    else (result or "")
                 )
+                messages.append({
+                    "role": "tool",
+                    "tool_call_id": call_id,
+                    "content": result_str,
+                })
 
         # If no recent queries (everything was compressed), add a continuation user message
         if len(recent_queries) == 0 and compressed_summary:
@@ -180,28 +188,35 @@ class MessageBuilder:
         if include_tool_calls and "tool_calls" in query:
             for tool_call in query["tool_calls"]:
                 call_id = tool_call.get("call_id") or str(uuid.uuid4())
-
-                function_call_dict = {
-                    "function_call": {
-                        "name": tool_call.get("action_name"),
-                        "args": tool_call.get("arguments"),
-                        "call_id": call_id,
-                    }
-                }
-                function_response_dict = {
-                    "function_response": {
-                        "name": tool_call.get("action_name"),
-                        "response": {"result": tool_call.get("result")},
-                        "call_id": call_id,
-                    }
-                }
-
-                rebuilt_messages.append(
-                    {"role": "assistant", "content": [function_call_dict]}
+                args = tool_call.get("arguments")
+                args_str = (
+                    json.dumps(args)
+                    if isinstance(args, dict)
+                    else (args or "{}")
                 )
-                rebuilt_messages.append(
-                    {"role": "tool", "content": [function_response_dict]}
+                rebuilt_messages.append({
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [{
+                        "id": call_id,
+                        "type": "function",
+                        "function": {
+                            "name": tool_call.get("action_name", ""),
+                            "arguments": args_str,
+                        },
+                    }],
+                })
+                result = tool_call.get("result")
+                result_str = (
+                    json.dumps(result)
+                    if not isinstance(result, str)
+                    else (result or "")
                 )
+                rebuilt_messages.append({
+                    "role": "tool",
+                    "tool_call_id": call_id,
+                    "content": result_str,
+                })
 
         # If no recent queries (everything was compressed), add a continuation user message
         if len(recent_queries) == 0 and compressed_summary:
diff --git a/application/api/answer/services/continuation_service.py b/application/api/answer/services/continuation_service.py
new file mode 100644
index 00000000..d63c3966
--- /dev/null
+++ b/application/api/answer/services/continuation_service.py
@@ -0,0 +1,141 @@
+"""Service for saving and restoring tool-call continuation state.
+
+When a stream pauses (tool needs approval or client-side execution),
+the full execution state is persisted to MongoDB so the client can
+resume later by sending tool_actions.
+"""
+
+import datetime
+import logging
+from typing import Any, Dict, List, Optional
+
+from bson import ObjectId
+
+from application.core.mongo_db import MongoDB
+from application.core.settings import settings
+
+logger = logging.getLogger(__name__)
+
+# TTL for pending states — auto-cleaned after this period
+PENDING_STATE_TTL_SECONDS = 30 * 60  # 30 minutes
+
+
+def _make_serializable(obj: Any) -> Any:
+    """Recursively convert MongoDB ObjectIds and other non-JSON types."""
+    if isinstance(obj, ObjectId):
+        return str(obj)
+    if isinstance(obj, dict):
+        return {str(k): _make_serializable(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_make_serializable(v) for v in obj]
+    if isinstance(obj, bytes):
+        return obj.decode("utf-8", errors="replace")
+    return obj
+
+
+class ContinuationService:
+    """Manages pending tool-call state in MongoDB."""
+
+    def __init__(self):
+        mongo = MongoDB.get_client()
+        db = mongo[settings.MONGO_DB_NAME]
+        self.collection = db["pending_tool_state"]
+        self._ensure_indexes()
+
+    def _ensure_indexes(self):
+        try:
+            self.collection.create_index(
+                "expires_at", expireAfterSeconds=0
+            )
+            self.collection.create_index(
+                [("conversation_id", 1), ("user", 1)], unique=True
+            )
+        except Exception:
+            # Indexes may already exist or mongomock doesn't support TTL
+            pass
+
+    def save_state(
+        self,
+        conversation_id: str,
+        user: str,
+        messages: List[Dict],
+        pending_tool_calls: List[Dict],
+        tools_dict: Dict,
+        tool_schemas: List[Dict],
+        agent_config: Dict,
+        client_tools: Optional[List[Dict]] = None,
+    ) -> str:
+        """Save execution state for later continuation.
+
+        Args:
+            conversation_id: The conversation this state belongs to.
+            user: Owner user ID.
+            messages: Full messages array at the pause point.
+            pending_tool_calls: Tool calls awaiting client action.
+            tools_dict: Serializable tools configuration dict.
+            tool_schemas: LLM-formatted tool schemas (agent.tools).
+            agent_config: Config needed to recreate the agent on resume.
+            client_tools: Client-provided tool schemas for client-side execution.
+
+        Returns:
+            The string ID of the saved state document.
+        """
+        now = datetime.datetime.now(datetime.timezone.utc)
+        expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
+
+        doc = {
+            "conversation_id": conversation_id,
+            "user": user,
+            "messages": _make_serializable(messages),
+            "pending_tool_calls": _make_serializable(pending_tool_calls),
+            "tools_dict": _make_serializable(tools_dict),
+            "tool_schemas": _make_serializable(tool_schemas),
+            "agent_config": _make_serializable(agent_config),
+            "client_tools": _make_serializable(client_tools) if client_tools else None,
+            "created_at": now,
+            "expires_at": expires_at,
+        }
+
+        # Upsert — only one pending state per conversation per user
+        result = self.collection.replace_one(
+            {"conversation_id": conversation_id, "user": user},
+            doc,
+            upsert=True,
+        )
+        state_id = str(result.upserted_id) if result.upserted_id else conversation_id
+        logger.info(
+            f"Saved continuation state for conversation {conversation_id} "
+            f"with {len(pending_tool_calls)} pending tool call(s)"
+        )
+        return state_id
+
+    def load_state(
+        self, conversation_id: str, user: str
+    ) -> Optional[Dict[str, Any]]:
+        """Load pending continuation state.
+
+        Returns:
+            The state dict, or None if no pending state exists.
+        """
+        doc = self.collection.find_one(
+            {"conversation_id": conversation_id, "user": user}
+        )
+        if not doc:
+            return None
+        doc["_id"] = str(doc["_id"])
+        return doc
+
+    def delete_state(self, conversation_id: str, user: str) -> bool:
+        """Delete pending state after successful resumption.
+
+        Returns:
+            True if a document was deleted.
+        """
+        result = self.collection.delete_one(
+            {"conversation_id": conversation_id, "user": user}
+        )
+        if result.deleted_count:
+            logger.info(
+                f"Deleted continuation state for conversation {conversation_id}"
+            )
+        return result.deleted_count > 0
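The intended lifecycle is save on pause, then load plus delete on resume, with the TTL index sweeping abandoned states. A round-trip sketch (IDs and payloads illustrative):

```python
from application.api.answer.services.continuation_service import ContinuationService

svc = ContinuationService()
svc.save_state(
    conversation_id="66f0c2...",
    user="auth0|123",
    messages=[{"role": "user", "content": "hi"}],
    pending_tool_calls=[{"call_id": "call_1", "name": "get_weather", "arguments": {}}],
    tools_dict={},
    tool_schemas=[],
    agent_config={"model_id": None},
)
state = svc.load_state("66f0c2...", "auth0|123")  # None once the 30-minute TTL fires
svc.delete_state("66f0c2...", "auth0|123")        # resume is one-shot by design
```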
+ """ + now = datetime.datetime.now(datetime.timezone.utc) + expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS) + + doc = { + "conversation_id": conversation_id, + "user": user, + "messages": _make_serializable(messages), + "pending_tool_calls": _make_serializable(pending_tool_calls), + "tools_dict": _make_serializable(tools_dict), + "tool_schemas": _make_serializable(tool_schemas), + "agent_config": _make_serializable(agent_config), + "client_tools": _make_serializable(client_tools) if client_tools else None, + "created_at": now, + "expires_at": expires_at, + } + + # Upsert — only one pending state per conversation per user + result = self.collection.replace_one( + {"conversation_id": conversation_id, "user": user}, + doc, + upsert=True, + ) + state_id = str(result.upserted_id) if result.upserted_id else conversation_id + logger.info( + f"Saved continuation state for conversation {conversation_id} " + f"with {len(pending_tool_calls)} pending tool call(s)" + ) + return state_id + + def load_state( + self, conversation_id: str, user: str + ) -> Optional[Dict[str, Any]]: + """Load pending continuation state. + + Returns: + The state dict, or None if no pending state exists. + """ + doc = self.collection.find_one( + {"conversation_id": conversation_id, "user": user} + ) + if not doc: + return None + doc["_id"] = str(doc["_id"]) + return doc + + def delete_state(self, conversation_id: str, user: str) -> bool: + """Delete pending state after successful resumption. + + Returns: + True if a document was deleted. + """ + result = self.collection.delete_one( + {"conversation_id": conversation_id, "user": user} + ) + if result.deleted_count: + logger.info( + f"Deleted continuation state for conversation {conversation_id}" + ) + return result.deleted_count > 0 diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index 28a2e8d4..fe2d7d8c 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -771,6 +771,121 @@ class StreamProcessor: logger.warning(f"Failed to fetch memory tool data: {str(e)}") return None + def resume_from_tool_actions( + self, + tool_actions: list, + conversation_id: str, + ): + """Resume a paused agent from saved continuation state. + + Loads the pending state from MongoDB, recreates the agent with + the saved configuration, and returns an agent ready to call + ``gen_continuation()``. + + Args: + tool_actions: Client-provided actions (approvals / results). + conversation_id: The conversation being resumed. + + Returns: + Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions). 
+ """ + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + from application.agents.agent_creator import AgentCreator + from application.agents.tool_executor import ToolExecutor + from application.llm.handlers.handler_creator import LLMHandlerCreator + from application.llm.llm_creator import LLMCreator + + cont_service = ContinuationService() + state = cont_service.load_state(conversation_id, self.initial_user_id) + if not state: + raise ValueError("No pending tool state found for this conversation") + + messages = state["messages"] + pending_tool_calls = state["pending_tool_calls"] + tools_dict = state["tools_dict"] + tool_schemas = state.get("tool_schemas", []) + agent_config = state["agent_config"] + + model_id = agent_config.get("model_id") + llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER) + api_key = agent_config.get("api_key") + user_api_key = agent_config.get("user_api_key") + agent_id = agent_config.get("agent_id") + prompt = agent_config.get("prompt", "") + json_schema = agent_config.get("json_schema") + retriever_config = agent_config.get("retriever_config") + + # Recreate dependencies + system_api_key = api_key or get_api_key_for_provider(llm_name) + llm = LLMCreator.create_llm( + llm_name, + api_key=system_api_key, + user_api_key=user_api_key, + decoded_token=self.decoded_token, + model_id=model_id, + agent_id=agent_id, + ) + llm_handler = LLMHandlerCreator.create_handler(llm_name or "default") + tool_executor = ToolExecutor( + user_api_key=user_api_key, + user=self.initial_user_id, + decoded_token=self.decoded_token, + ) + tool_executor.conversation_id = conversation_id + # Restore client tools so they stay available for subsequent LLM calls + saved_client_tools = state.get("client_tools") + if saved_client_tools: + tool_executor.client_tools = saved_client_tools + # Re-merge into tools_dict (they may have been stripped during serialization) + tool_executor.merge_client_tools(tools_dict, saved_client_tools) + + agent_type = agent_config.get("agent_type", "ClassicAgent") + # Map class names back to agent creator keys + type_map = { + "ClassicAgent": "classic", + "AgenticAgent": "agentic", + "ResearchAgent": "research", + "WorkflowAgent": "workflow", + } + agent_key = type_map.get(agent_type, "classic") + + agent_kwargs = { + "endpoint": "stream", + "llm_name": llm_name, + "model_id": model_id, + "api_key": system_api_key, + "agent_id": agent_id, + "user_api_key": user_api_key, + "prompt": prompt, + "chat_history": [], + "decoded_token": self.decoded_token, + "json_schema": json_schema, + "llm": llm, + "llm_handler": llm_handler, + "tool_executor": tool_executor, + } + + if agent_key in ("agentic", "research") and retriever_config: + agent_kwargs["retriever_config"] = retriever_config + + agent = AgentCreator.create_agent(agent_key, **agent_kwargs) + agent.conversation_id = conversation_id + agent.initial_user_id = self.initial_user_id + agent.tools = tool_schemas + + # Store config for the route layer + self.model_id = model_id + self.agent_id = agent_id + self.agent_config["user_api_key"] = user_api_key + self.conversation_id = conversation_id + + # Delete state so it can't be replayed + cont_service.delete_state(conversation_id, self.initial_user_id) + + return agent, messages, tools_dict, pending_tool_calls, tool_actions + def create_agent( self, docs_together: Optional[str] = None, @@ -841,6 +956,10 @@ class StreamProcessor: decoded_token=self.decoded_token, ) tool_executor.conversation_id = self.conversation_id + # 
diff --git a/application/api/v1/__init__.py b/application/api/v1/__init__.py
new file mode 100644
index 00000000..69e535af
--- /dev/null
+++ b/application/api/v1/__init__.py
@@ -0,0 +1,3 @@
+from application.api.v1.routes import v1_bp
+
+__all__ = ["v1_bp"]
diff --git a/application/api/v1/routes.py b/application/api/v1/routes.py
new file mode 100644
index 00000000..d773d962
--- /dev/null
+++ b/application/api/v1/routes.py
@@ -0,0 +1,314 @@
+"""Standard chat completions API routes.
+
+Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
+follow the widely-adopted chat completions protocol so external tools
+(opencode, continue, etc.) can connect to DocsGPT agents.
+"""
+
+import json
+import logging
+import time
+import traceback
+from typing import Any, Dict, Generator, Optional
+
+from flask import Blueprint, jsonify, make_response, request, Response
+
+from application.api.answer.routes.base import BaseAnswerResource
+from application.api.answer.services.stream_processor import StreamProcessor
+from application.api.v1.translator import (
+    translate_request,
+    translate_response,
+    translate_stream_event,
+)
+from application.core.mongo_db import MongoDB
+from application.core.settings import settings
+
+logger = logging.getLogger(__name__)
+
+v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
+
+
+def _extract_bearer_token() -> Optional[str]:
+    """Extract the API key from an ``Authorization: Bearer`` header."""
+    auth = request.headers.get("Authorization", "")
+    if auth.startswith("Bearer "):
+        return auth[7:].strip()
+    return None
+
+
+def _lookup_agent(api_key: str) -> Optional[Dict]:
+    """Look up the agent document for this API key."""
+    try:
+        mongo = MongoDB.get_client()
+        db = mongo[settings.MONGO_DB_NAME]
+        return db["agents"].find_one({"key": api_key})
+    except Exception:
+        logger.warning("Failed to look up agent for API key", exc_info=True)
+        return None
+
+
+def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
+    """Return the agent name for display as the model name."""
+    if agent:
+        return agent.get("name", api_key)
+    return api_key
+
+
+class _V1AnswerHelper(BaseAnswerResource):
+    """Thin wrapper to access complete_stream / process_response_stream."""
+    pass
+
+
+@v1_bp.route("/chat/completions", methods=["POST"])
+def chat_completions():
+    """Handle POST /v1/chat/completions."""
+    api_key = _extract_bearer_token()
+    if not api_key:
+        return make_response(
+            jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
+            401,
+        )
+
+    data = request.get_json()
+    if not data or not data.get("messages"):
+        return make_response(
+            jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
+            400,
+        )
+
+    is_stream = data.get("stream", False)
+    agent_doc = _lookup_agent(api_key)
+    model_name = _get_model_name(agent_doc, api_key)
+
+    try:
+        internal_data = translate_request(data, api_key)
+    except Exception as e:
+        logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
+        return make_response(
+            jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
+            400,
+        )
+
+    # Link decoded_token to the agent's owner so continuation state,
+    # logs, and tool execution use the correct user identity.
+    agent_user = agent_doc.get("user") if agent_doc else None
+    decoded_token = {"sub": agent_user or "api_key_user"}
+
+    try:
+        processor = StreamProcessor(internal_data, decoded_token)
+
+        if internal_data.get("tool_actions"):
+            # Continuation mode
+            conversation_id = internal_data.get("conversation_id")
+            if not conversation_id:
+                return make_response(
+                    jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
+                    400,
+                )
+            (
+                agent,
+                messages,
+                tools_dict,
+                pending_tool_calls,
+                tool_actions,
+            ) = processor.resume_from_tool_actions(
+                internal_data["tool_actions"], conversation_id
+            )
+            continuation = {
+                "messages": messages,
+                "tools_dict": tools_dict,
+                "pending_tool_calls": pending_tool_calls,
+                "tool_actions": tool_actions,
+            }
+            question = ""
+        else:
+            # Normal mode
+            question = internal_data.get("question", "")
+            agent = processor.build_agent(question)
+            continuation = None
+
+        if not processor.decoded_token:
+            return make_response(
+                jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
+                401,
+            )
+
+        helper = _V1AnswerHelper()
+        usage_error = helper.check_usage(processor.agent_config)
+        if usage_error:
+            return usage_error
+
+        if is_stream:
+            return Response(
+                _stream_response(
+                    helper, question, agent, processor, model_name, continuation
+                ),
+                mimetype="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "X-Accel-Buffering": "no",
+                },
+            )
+        else:
+            return _non_stream_response(
+                helper, question, agent, processor, model_name, continuation
+            )
+
+    except ValueError as e:
+        logger.error(
+            f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
+            extra={"error": str(e)},
+        )
+        return make_response(
+            jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
+            400,
+        )
+    except Exception as e:
+        logger.error(
+            f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
+            extra={"error": str(e)},
+        )
+        return make_response(
+            jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
+            500,
+        )
+
+
+def _stream_response(
+    helper: _V1AnswerHelper,
+    question: str,
+    agent: Any,
+    processor: StreamProcessor,
+    model_name: str,
+    continuation: Optional[Dict],
+) -> Generator[str, None, None]:
+    """Generate translated SSE chunks for streaming response."""
+    completion_id = f"chatcmpl-{int(time.time())}"
+
+    internal_stream = helper.complete_stream(
+        question=question,
+        agent=agent,
+        conversation_id=processor.conversation_id,
+        user_api_key=processor.agent_config.get("user_api_key"),
+        decoded_token=processor.decoded_token,
+        agent_id=processor.agent_id,
+        model_id=processor.model_id,
+        _continuation=continuation,
+    )
+
+    for line in internal_stream:
+        if not line.strip():
+            continue
+        # Parse the internal SSE event
+        event_str = line.replace("data: ", "").strip()
+        try:
+            event_data = json.loads(event_str)
+        except (json.JSONDecodeError, TypeError):
+            continue
+
+        # Update completion_id when we get the conversation id
+        if event_data.get("type") == "id":
+            conv_id = event_data.get("id", "")
+            if conv_id:
+                completion_id = f"chatcmpl-{conv_id}"
+
+        # Translate to standard format
+        translated = translate_stream_event(event_data, completion_id, model_name)
+        for chunk in translated:
+            yield chunk
+
+
+def _non_stream_response(
+    helper: _V1AnswerHelper,
+    question: str,
+    agent: Any,
+    processor: StreamProcessor,
+    model_name: str,
+    continuation: Optional[Dict],
+) -> Response:
+    """Collect the full response and return it as a single JSON body."""
JSON.""" + stream = helper.complete_stream( + question=question, + agent=agent, + conversation_id=processor.conversation_id, + user_api_key=processor.agent_config.get("user_api_key"), + decoded_token=processor.decoded_token, + agent_id=processor.agent_id, + model_id=processor.model_id, + _continuation=continuation, + ) + + result = helper.process_response_stream(stream) + + if result["error"]: + return make_response( + jsonify({"error": {"message": result["error"], "type": "server_error"}}), + 500, + ) + + extra = result.get("extra") + pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None + + response = translate_response( + conversation_id=result["conversation_id"], + answer=result["answer"] or "", + sources=result["sources"], + tool_calls=result["tool_calls"], + thought=result["thought"] or "", + model_name=model_name, + pending_tool_calls=pending, + ) + return make_response(jsonify(response), 200) + + +@v1_bp.route("/models", methods=["GET"]) +def list_models(): + """Handle GET /v1/models — return agents as models.""" + api_key = _extract_bearer_token() + if not api_key: + return make_response( + jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}), + 401, + ) + + try: + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + agents_collection = db["agents"] + + # Find the agent for this api_key + agent = agents_collection.find_one({"key": api_key}) + if not agent: + return make_response( + jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}), + 401, + ) + + user = agent.get("user") + + # Return all agents belonging to this user + user_agents = list(agents_collection.find({"user": user})) + + models = [] + for ag in user_agents: + created = ag.get("createdAt") + created_ts = int(created.timestamp()) if created else int(time.time()) + models.append({ + "id": str(ag.get("key", "")), + "object": "model", + "created": created_ts, + "owned_by": "docsgpt", + "name": ag.get("name", ""), + "description": ag.get("description", ""), + }) + + return make_response( + jsonify({"object": "list", "data": models}), + 200, + ) + except Exception as e: + logger.error(f"/v1/models error: {e}", exc_info=True) + return make_response( + jsonify({"error": {"message": "Internal server error", "type": "server_error"}}), + 500, + ) diff --git a/application/api/v1/translator.py b/application/api/v1/translator.py new file mode 100644 index 00000000..3d3c64fe --- /dev/null +++ b/application/api/v1/translator.py @@ -0,0 +1,415 @@ +"""Translate between standard chat completions format and DocsGPT internals. + +This module handles: +- Request translation (chat completions -> DocsGPT internal format) +- Response translation (DocsGPT response -> chat completions format) +- Streaming event translation (DocsGPT SSE -> standard SSE chunks) +""" + +import json +import time +from typing import Any, Dict, List, Optional + +def _get_client_tool_name(tc: Dict) -> str: + """Return the original tool name for client-facing responses. + + For client-side tools the ``tool_name`` field carries the name the + client originally registered. Fall back to ``action_name`` (which + is now the clean LLM-visible name) or ``name``. 
+ """ + return tc.get("tool_name", tc.get("action_name", tc.get("name", ""))) + + +# --------------------------------------------------------------------------- +# Request translation +# --------------------------------------------------------------------------- + + +def is_continuation(messages: List[Dict]) -> bool: + """Check if messages represent a tool-call continuation. + + A continuation is detected when the last message(s) have ``role: "tool"`` + immediately after an assistant message with ``tool_calls``. + """ + if not messages: + return False + # Walk backwards: if we see tool messages before hitting a non-tool, non-assistant message + # and there's an assistant message with tool_calls, it's a continuation. + i = len(messages) - 1 + while i >= 0 and messages[i].get("role") == "tool": + i -= 1 + if i < 0: + return False + return ( + messages[i].get("role") == "assistant" + and bool(messages[i].get("tool_calls")) + ) + + +def extract_tool_results(messages: List[Dict]) -> List[Dict]: + """Extract tool results from trailing tool messages for continuation. + + Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``. + """ + results = [] + for msg in reversed(messages): + if msg.get("role") != "tool": + break + call_id = msg.get("tool_call_id", "") + content = msg.get("content", "") + if isinstance(content, str): + try: + content = json.loads(content) + except (json.JSONDecodeError, TypeError): + pass + results.append({"call_id": call_id, "result": content}) + results.reverse() + return results + + +def extract_conversation_id(messages: List[Dict]) -> Optional[str]: + """Try to extract conversation_id from the assistant message before tool results. + + The conversation_id may be stored in a custom field on the assistant message + from a previous response cycle. + """ + for msg in reversed(messages): + if msg.get("role") == "assistant": + # Check docsgpt extension + return msg.get("docsgpt", {}).get("conversation_id") + return None + + +def convert_history(messages: List[Dict]) -> List[Dict]: + """Convert chat completions messages array to DocsGPT history format. + + DocsGPT history is a list of ``{prompt, response}`` dicts. + Excludes the last user message (that becomes the ``question``). + """ + history = [] + i = 0 + while i < len(messages): + msg = messages[i] + if msg.get("role") == "system": + i += 1 + continue + if msg.get("role") == "user": + # Look ahead for assistant response + if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant": + content = messages[i + 1].get("content") or "" + history.append({ + "prompt": msg.get("content", ""), + "response": content, + }) + i += 2 + continue + # Last user message without response — skip (it's the question) + i += 1 + continue + i += 1 + return history + + +def translate_request( + data: Dict[str, Any], api_key: str +) -> Dict[str, Any]: + """Translate a chat completions request to DocsGPT internal format. + + Args: + data: The incoming request body. + api_key: Agent API key from the Authorization header. + + Returns: + Dict suitable for passing to ``StreamProcessor``. 
+ """ + messages = data.get("messages", []) + + # Check for continuation (tool results after assistant tool_calls) + if is_continuation(messages): + tool_actions = extract_tool_results(messages) + conversation_id = extract_conversation_id(messages) + if not conversation_id: + conversation_id = data.get("conversation_id") + result = { + "conversation_id": conversation_id, + "tool_actions": tool_actions, + "api_key": api_key, + } + # Carry tools forward for next iteration + if data.get("tools"): + result["client_tools"] = data["tools"] + return result + + # Normal request — extract question from last user message + question = "" + for msg in reversed(messages): + if msg.get("role") == "user": + question = msg.get("content", "") + break + + history = convert_history(messages) + + result = { + "question": question, + "api_key": api_key, + "history": json.dumps(history), + "save_conversation": True, + } + + # Client tools + if data.get("tools"): + result["client_tools"] = data["tools"] + + # DocsGPT extensions + docsgpt = data.get("docsgpt", {}) + if docsgpt.get("attachments"): + result["attachments"] = docsgpt["attachments"] + + return result + + +# --------------------------------------------------------------------------- +# Response translation (non-streaming) +# --------------------------------------------------------------------------- + + +def translate_response( + conversation_id: str, + answer: str, + sources: Optional[List[Dict]], + tool_calls: Optional[List[Dict]], + thought: str, + model_name: str, + pending_tool_calls: Optional[List[Dict]] = None, +) -> Dict[str, Any]: + """Translate DocsGPT response to chat completions format. + + Args: + conversation_id: The DocsGPT conversation ID. + answer: The assistant's text response. + sources: RAG retrieval sources. + tool_calls: Completed tool call results. + thought: Reasoning/thinking tokens. + model_name: Model/agent identifier. + pending_tool_calls: Pending client-side tool calls (if paused). + + Returns: + Dict in the standard chat completions response format. 
+ """ + created = int(time.time()) + completion_id = f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}" + + # Build message + message: Dict[str, Any] = {"role": "assistant"} + + if pending_tool_calls: + # Tool calls pending — return them for client execution + message["content"] = None + message["tool_calls"] = [ + { + "id": tc.get("call_id", ""), + "type": "function", + "function": { + "name": _get_client_tool_name(tc), + "arguments": ( + json.dumps(tc["arguments"]) + if isinstance(tc.get("arguments"), dict) + else tc.get("arguments", "{}") + ), + }, + } + for tc in pending_tool_calls + ] + finish_reason = "tool_calls" + else: + message["content"] = answer + if thought: + message["reasoning_content"] = thought + finish_reason = "stop" + + result: Dict[str, Any] = { + "id": completion_id, + "object": "chat.completion", + "created": created, + "model": model_name, + "choices": [ + { + "index": 0, + "message": message, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + } + + # DocsGPT extensions + docsgpt: Dict[str, Any] = {} + if conversation_id: + docsgpt["conversation_id"] = conversation_id + if sources: + docsgpt["sources"] = sources + if tool_calls: + docsgpt["tool_calls"] = tool_calls + if docsgpt: + result["docsgpt"] = docsgpt + + return result + + +# --------------------------------------------------------------------------- +# Streaming event translation +# --------------------------------------------------------------------------- + + +def _make_chunk( + completion_id: str, + model_name: str, + delta: Dict[str, Any], + finish_reason: Optional[str] = None, +) -> str: + """Build a single SSE chunk in the standard streaming format.""" + chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + } + ], + } + return f"data: {json.dumps(chunk)}\n\n" + + +def _make_docsgpt_chunk(data: Dict[str, Any]) -> str: + """Build a DocsGPT extension SSE chunk.""" + return f"data: {json.dumps({'docsgpt': data})}\n\n" + + +def translate_stream_event( + event_data: Dict[str, Any], + completion_id: str, + model_name: str, +) -> List[str]: + """Translate a DocsGPT SSE event dict to standard streaming chunks. + + May return 0, 1, or 2 chunks per input event. For example, a completed + tool call produces both a docsgpt extension chunk and nothing on the + standard side (since server-side tool calls aren't surfaced in standard + format). + + Args: + event_data: Parsed DocsGPT event dict. + completion_id: The completion ID for this response. + model_name: Model/agent identifier. + + Returns: + List of SSE-formatted strings to send to the client. 
+ """ + event_type = event_data.get("type") + chunks: List[str] = [] + + if event_type == "answer": + chunks.append( + _make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")}) + ) + + elif event_type == "thought": + chunks.append( + _make_chunk( + completion_id, model_name, + {"reasoning_content": event_data.get("thought", "")}, + ) + ) + + elif event_type == "source": + chunks.append( + _make_docsgpt_chunk({ + "type": "source", + "sources": event_data.get("source", []), + }) + ) + + elif event_type == "tool_call": + tc_data = event_data.get("data", {}) + status = tc_data.get("status") + + if status == "requires_client_execution": + # Standard: stream as tool_calls delta + args = tc_data.get("arguments", {}) + args_str = json.dumps(args) if isinstance(args, dict) else str(args) + chunks.append( + _make_chunk(completion_id, model_name, { + "tool_calls": [{ + "index": 0, + "id": tc_data.get("call_id", ""), + "type": "function", + "function": { + "name": _get_client_tool_name(tc_data), + "arguments": args_str, + }, + }], + }) + ) + elif status == "awaiting_approval": + # Extension: approval needed + chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data})) + elif status in ("completed", "pending", "error", "denied", "skipped"): + # Extension: tool call progress + chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data})) + + elif event_type == "tool_calls_pending": + # Standard: finish_reason = tool_calls + chunks.append( + _make_chunk(completion_id, model_name, {}, finish_reason="tool_calls") + ) + # Also emit as docsgpt extension + chunks.append( + _make_docsgpt_chunk({ + "type": "tool_calls_pending", + "pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []), + }) + ) + + elif event_type == "end": + chunks.append( + _make_chunk(completion_id, model_name, {}, finish_reason="stop") + ) + chunks.append("data: [DONE]\n\n") + + elif event_type == "id": + chunks.append( + _make_docsgpt_chunk({ + "type": "id", + "conversation_id": event_data.get("id", ""), + }) + ) + + elif event_type == "error": + # Emit as standard error (non-standard but widely supported) + error_data = { + "error": { + "message": event_data.get("error", "An error occurred"), + "type": "server_error", + } + } + chunks.append(f"data: {json.dumps(error_data)}\n\n") + + elif event_type == "structured_answer": + chunks.append( + _make_chunk( + completion_id, model_name, + {"content": event_data.get("answer", "")}, + ) + ) + + # Skip: tool_calls (redundant), research_plan, research_progress + + return chunks diff --git a/application/app.py b/application/app.py index aed069e0..a2578fea 100644 --- a/application/app.py +++ b/application/app.py @@ -17,6 +17,7 @@ from application.api.answer import answer # noqa: E402 from application.api.internal.routes import internal # noqa: E402 from application.api.user.routes import user # noqa: E402 from application.api.connector.routes import connector # noqa: E402 +from application.api.v1 import v1_bp # noqa: E402 from application.celery_init import celery # noqa: E402 from application.core.settings import settings # noqa: E402 from application.stt.upload_limits import ( # noqa: E402 @@ -36,6 +37,7 @@ app.register_blueprint(user) app.register_blueprint(answer) app.register_blueprint(internal) app.register_blueprint(connector) +app.register_blueprint(v1_bp) app.config.update( UPLOAD_FOLDER="inputs", CELERY_BROKER_URL=settings.CELERY_BROKER_URL, diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py 
index ca984ba7..538abe67 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -167,6 +167,8 @@ class GoogleLLM(BaseLLM): return "\n".join(parts) return "" + import json as _json + for message in messages: role = message.get("role") content = message.get("content") @@ -180,9 +182,66 @@ class GoogleLLM(BaseLLM): if role == "assistant": role = "model" - elif role == "tool": - role = "model" + parts = [] + + # Standard format: assistant message with tool_calls array + msg_tool_calls = message.get("tool_calls") + if msg_tool_calls and role == "model": + for tc in msg_tool_calls: + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, str): + try: + args = _json.loads(args) + except (_json.JSONDecodeError, TypeError): + args = {} + cleaned_args = self._remove_null_values(args) + thought_sig = tc.get("thought_signature") + if thought_sig: + parts.append( + types.Part( + functionCall=types.FunctionCall( + name=func.get("name", ""), + args=cleaned_args, + ), + thoughtSignature=thought_sig, + ) + ) + else: + parts.append( + types.Part.from_function_call( + name=func.get("name", ""), + args=cleaned_args, + ) + ) + if parts: + cleaned_messages.append(types.Content(role=role, parts=parts)) + continue + + # Standard format: tool message with tool_call_id + tool_call_id = message.get("tool_call_id") + if role == "tool" and tool_call_id is not None: + result_content = content + if isinstance(result_content, str): + try: + result_content = _json.loads(result_content) + except (_json.JSONDecodeError, TypeError): + pass + # Google expects function_response name — extract from tool_call_id context + # We use a placeholder name since Google API doesn't require exact match + parts.append( + types.Part.from_function_response( + name="tool_result", + response={"result": result_content}, + ) + ) + cleaned_messages.append(types.Content(role="model", parts=parts)) + continue + + if role == "tool": + role = "model" + if role and content is not None: if isinstance(content, str): parts = [types.Part.from_text(text=content)] @@ -191,15 +250,11 @@ class GoogleLLM(BaseLLM): if "text" in item: parts.append(types.Part.from_text(text=item["text"])) elif "function_call" in item: - # Remove null values from args to avoid API errors - + # Legacy format support cleaned_args = self._remove_null_values( item["function_call"]["args"] ) - # Create function call part with thought_signature if present - # For Gemini 3 models, we need to include thought_signature if "thought_signature" in item: - # Use Part constructor with functionCall and thoughtSignature parts.append( types.Part( functionCall=types.FunctionCall( @@ -210,7 +265,6 @@ class GoogleLLM(BaseLLM): ) ) else: - # Use helper method when no thought_signature parts.append( types.Part.from_function_call( name=item["function_call"]["name"], diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index 7537d9c5..a2edd233 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -1,3 +1,4 @@ +import json import logging import uuid from abc import ABC, abstractmethod @@ -315,10 +316,34 @@ class LLMHandler(ABC): current_prompt = self._extract_text_from_content(content) elif role in {"assistant", "model"}: - # If this assistant turn contains tool calls, collect them; otherwise commit a response. 
+ # Standard format: tool_calls array on assistant message + msg_tool_calls = message.get("tool_calls") + if msg_tool_calls: + for tc in msg_tool_calls: + call_id = tc.get("id") or str(uuid.uuid4()) + func = tc.get("function", {}) + args = func.get("arguments") + if isinstance(args, str): + try: + args = json.loads(args) + except (json.JSONDecodeError, TypeError): + pass + current_tool_calls[call_id] = { + "tool_name": "unknown_tool", + "action_name": func.get("name"), + "arguments": args, + "result": None, + "status": "called", + "call_id": call_id, + } + continue + + # Legacy format: function_call/function_response in content list if isinstance(content, list): + has_fc = False for item in content: if "function_call" in item: + has_fc = True fc = item["function_call"] call_id = fc.get("call_id") or str(uuid.uuid4()) current_tool_calls[call_id] = { @@ -329,37 +354,30 @@ class LLMHandler(ABC): "status": "called", "call_id": call_id, } - elif "function_response" in item: - fr = item["function_response"] - call_id = fr.get("call_id") or str(uuid.uuid4()) - current_tool_calls[call_id] = { - "tool_name": "unknown_tool", - "action_name": fr.get("name"), - "arguments": None, - "result": fr.get("response", {}).get("result"), - "status": "completed", - "call_id": call_id, - } - # No direct assistant text here; continue to next message - continue + if has_fc: + continue response_text = self._extract_text_from_content(content) _commit_query(response_text) elif role == "tool": - # Attach tool outputs to the latest pending tool call if possible + # Standard format: tool_call_id on tool message + call_id = message.get("tool_call_id") tool_text = self._extract_text_from_content(content) - # Attempt to parse function_response style - call_id = None - if isinstance(content, list): - for item in content: - if "function_response" in item and item["function_response"].get("call_id"): - call_id = item["function_response"]["call_id"] - break + if call_id and call_id in current_tool_calls: current_tool_calls[call_id]["result"] = tool_text current_tool_calls[call_id]["status"] = "completed" - elif queries: + # Legacy: function_response in content list + elif isinstance(content, list): + for item in content: + if "function_response" in item: + legacy_id = item["function_response"].get("call_id") + if legacy_id and legacy_id in current_tool_calls: + current_tool_calls[legacy_id]["result"] = tool_text + current_tool_calls[legacy_id]["status"] = "completed" + break + elif call_id is None and queries: queries[-1].setdefault("tool_calls", []).append( { "tool_name": "unknown_tool", @@ -648,6 +666,13 @@ class LLMHandler(ABC): """ Execute tool calls and update conversation history. + When a tool requires approval or client-side execution, it is + collected as a pending action instead of being executed. The + generator returns ``(updated_messages, pending_actions)`` where + *pending_actions* is ``None`` when every tool was executed + normally, or a list of dicts describing actions the client must + resolve before the LLM loop can continue. + Args: agent: The agent instance tool_calls: List of tool calls to execute @@ -655,9 +680,11 @@ class LLMHandler(ABC): messages: Current conversation history Returns: - Updated messages list + Tuple of (updated_messages, pending_actions). + pending_actions is None if all tools executed, otherwise a list. 
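+
+        Illustrative pending action (field values are made up; the exact
+        shape comes from ``ToolExecutor.check_pause``):
+
+            {"call_id": "c1", "tool_id": "ct0", "tool_name": "write_file",
+             "name": "write_file", "llm_name": "write_file",
+             "arguments": {"path": "notes.txt"},
+             "pause_type": "requires_client_execution"}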
""" updated_messages = messages.copy() + pending_actions: List[Dict] = [] for i, call in enumerate(tool_calls): # Check context limit before executing tool call @@ -763,6 +790,29 @@ class LLMHandler(ABC): # Set flag on agent agent.context_limit_reached = True break + + # ---- Pause check: approval / client-side execution ---- + llm_class = agent.llm.__class__.__name__ + pause_info = agent.tool_executor.check_pause( + tools_dict, call, llm_class + ) + if pause_info: + # Yield pause event so the client knows this tool is waiting + yield { + "type": "tool_call", + "data": { + "tool_name": pause_info["tool_name"], + "call_id": pause_info["call_id"], + "action_name": pause_info.get("llm_name", pause_info["name"]), + "arguments": pause_info["arguments"], + "status": pause_info["pause_type"], + }, + } + pending_actions.append(pause_info) + # Do NOT add messages for pending tools here. + # They will be added on resume to keep call/result pairs together. + continue + try: self.tool_calls.append(call) tool_executor_gen = agent._execute_tool_action(tools_dict, call) @@ -772,25 +822,30 @@ class LLMHandler(ABC): except StopIteration as e: tool_response, call_id = e.value break - - function_call_content = { - "function_call": { - "name": call.name, - "args": call.arguments, - "call_id": call_id, - } - } - # Include thought_signature for Google Gemini 3 models - # It should be at the same level as function_call, not inside it - if call.thought_signature: - function_call_content["thought_signature"] = call.thought_signature - updated_messages.append( - { - "role": "assistant", - "content": [function_call_content], - } - ) + # Standard internal format: assistant message with tool_calls array + args_str = ( + json.dumps(call.arguments) + if isinstance(call.arguments, dict) + else call.arguments + ) + tool_call_obj = { + "id": call_id, + "type": "function", + "function": { + "name": call.name, + "arguments": args_str, + }, + } + # Preserve thought_signature for Google Gemini 3 models + if call.thought_signature: + tool_call_obj["thought_signature"] = call.thought_signature + + updated_messages.append({ + "role": "assistant", + "content": None, + "tool_calls": [tool_call_obj], + }) updated_messages.append(self.create_tool_message(call, tool_response)) except Exception as e: @@ -802,16 +857,15 @@ class LLMHandler(ABC): error_message = self.create_tool_message(error_call, error_response) updated_messages.append(error_message) - call_parts = call.name.split("_") - if len(call_parts) >= 2: - tool_id = call_parts[-1] # Last part is tool ID (e.g., "1") - action_name = "_".join(call_parts[:-1]) - tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool") - full_action_name = f"{action_name}_{tool_id}" + mapping = agent.tool_executor._name_to_tool + if call.name in mapping: + resolved_tool_id, _ = mapping[call.name] + tool_name = tools_dict.get(resolved_tool_id, {}).get( + "name", "unknown_tool" + ) else: tool_name = "unknown_tool" - action_name = call.name - full_action_name = call.name + full_action_name = call.name yield { "type": "tool_call", "data": { @@ -823,7 +877,7 @@ class LLMHandler(ABC): "status": "error", }, } - return updated_messages + return updated_messages, pending_actions if pending_actions else None def handle_non_streaming( self, agent, response: Any, tools_dict: Dict, messages: List[Dict] @@ -851,8 +905,22 @@ class LLMHandler(ABC): try: yield next(tool_handler_gen) except StopIteration as e: - messages = e.value + messages, pending_actions = e.value break + + # If tools need approval or 
client execution, pause the loop + if pending_actions: + agent._pending_continuation = { + "messages": messages, + "pending_tool_calls": pending_actions, + "tools_dict": tools_dict, + } + yield { + "type": "tool_calls_pending", + "data": {"pending_tool_calls": pending_actions}, + } + return "" + response = agent.llm.gen( model=agent.model_id, messages=messages, tools=agent.tools ) @@ -913,10 +981,23 @@ class LLMHandler(ABC): try: yield next(tool_handler_gen) except StopIteration as e: - messages = e.value + messages, pending_actions = e.value break tool_calls = {} + # If tools need approval or client execution, pause the loop + if pending_actions: + agent._pending_continuation = { + "messages": messages, + "pending_tool_calls": pending_actions, + "tools_dict": tools_dict, + } + yield { + "type": "tool_calls_pending", + "data": {"pending_tool_calls": pending_actions}, + } + return + # Check if context limit was reached during tool execution if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached: # Add system message warning about context limit diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py index 0142922a..5c202153 100644 --- a/application/llm/handlers/google.py +++ b/application/llm/handlers/google.py @@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler): ) def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: - """Create Google-style tool message.""" + """Create a tool result message in the standard internal format.""" + import json as _json + content = ( + _json.dumps(result) + if not isinstance(result, str) + else result + ) return { - "role": "model", - "content": [ - { - "function_response": { - "name": tool_call.name, - "response": {"result": result}, - } - } - ], + "role": "tool", + "tool_call_id": tool_call.id, + "content": content, } def _iterate_stream(self, response: Any) -> Generator: diff --git a/application/llm/handlers/openai.py b/application/llm/handlers/openai.py index 99ddde4c..6c8431d7 100644 --- a/application/llm/handlers/openai.py +++ b/application/llm/handlers/openai.py @@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler): ) def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: - """Create OpenAI-style tool message.""" + """Create a tool result message in the standard internal format.""" + import json as _json + + content = ( + _json.dumps(result) + if not isinstance(result, str) + else result + ) return { "role": "tool", - "content": [ - { - "function_response": { - "name": tool_call.name, - "response": {"result": result}, - "call_id": tool_call.id, - } - } - ], + "tool_call_id": tool_call.id, + "content": content, } def _iterate_stream(self, response: Any) -> Generator: diff --git a/application/llm/openai.py b/application/llm/openai.py index 8bfc3cef..ee9438de 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -91,16 +91,52 @@ class OpenAILLM(BaseLLM): if role == "model": role = "assistant" + + # Standard format: assistant message with tool_calls (passthrough) + tool_calls = message.get("tool_calls") + if tool_calls and role == "assistant": + cleaned_tcs = [] + for tc in tool_calls: + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, dict): + args = json.dumps(self._remove_null_values(args)) + elif isinstance(args, str): + try: + parsed = json.loads(args) + args = json.dumps(self._remove_null_values(parsed)) + except (json.JSONDecodeError, TypeError): + pass + cleaned_tcs.append({ + "id": tc.get("id", ""), + 
"type": "function", + "function": {"name": func.get("name", ""), "arguments": args}, + }) + cleaned_messages.append({ + "role": "assistant", + "content": None, + "tool_calls": cleaned_tcs, + }) + continue + + # Standard format: tool message with tool_call_id (passthrough) + tool_call_id = message.get("tool_call_id") + if role == "tool" and tool_call_id is not None: + cleaned_messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "content": content if isinstance(content, str) else json.dumps(content), + }) + continue + if role and content is not None: if isinstance(content, str): cleaned_messages.append({"role": role, "content": content}) elif isinstance(content, list): - # Collect all content parts into a single message content_parts = [] - for item in content: + # Legacy format support: function_call / function_response if "function_call" in item: - # Function calls need their own message args = item["function_call"]["args"] if isinstance(args, str): try: @@ -116,28 +152,20 @@ class OpenAILLM(BaseLLM): "arguments": json.dumps(cleaned_args), }, } - cleaned_messages.append( - { - "role": "assistant", - "content": None, - "tool_calls": [tool_call], - } - ) + cleaned_messages.append({ + "role": "assistant", + "content": None, + "tool_calls": [tool_call], + }) elif "function_response" in item: - # Function responses need their own message - cleaned_messages.append( - { - "role": "tool", - "tool_call_id": item["function_response"][ - "call_id" - ], - "content": json.dumps( - item["function_response"]["response"]["result"] - ), - } - ) + cleaned_messages.append({ + "role": "tool", + "tool_call_id": item["function_response"]["call_id"], + "content": json.dumps( + item["function_response"]["response"]["result"] + ), + }) elif isinstance(item, dict): - # Collect content parts (text, images, files) into a single message if "type" in item and item["type"] == "text" and "text" in item: content_parts.append(item) elif "type" in item and item["type"] == "file" and "file" in item: @@ -145,10 +173,7 @@ class OpenAILLM(BaseLLM): elif "type" in item and item["type"] == "image_url" and "image_url" in item: content_parts.append(item) elif "text" in item and "type" not in item: - # Legacy format: {"text": "..."} without type content_parts.append({"type": "text", "text": item["text"]}) - - # Add the collected content parts as a single message if content_parts: cleaned_messages.append({"role": role, "content": content_parts}) else: diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 573b58a8..e0dc7742 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -77,6 +77,10 @@ const endpoints = { WORKFLOWS: '/api/workflows', WORKFLOW: (id: string) => `/api/workflows/${id}`, }, + V1: { + CHAT_COMPLETIONS: '/v1/chat/completions', + MODELS: '/v1/models', + }, CONVERSATION: { ANSWER: '/api/answer', ANSWER_STREAMING: '/stream', diff --git a/frontend/src/api/services/conversationService.ts b/frontend/src/api/services/conversationService.ts index 853a6863..a79370aa 100644 --- a/frontend/src/api/services/conversationService.ts +++ b/frontend/src/api/services/conversationService.ts @@ -54,6 +54,18 @@ const conversationService = { apiClient.get(endpoints.CONVERSATION.DELETE_ALL, token, {}), update: (data: any, token: string | null): Promise => apiClient.post(endpoints.CONVERSATION.UPDATE, data, token, {}), + chatCompletions: ( + data: any, + agentApiKey: string, + signal: AbortSignal, + ): Promise => + apiClient.post( + endpoints.V1.CHAT_COMPLETIONS, + data, 
+ null, + { Authorization: `Bearer ${agentApiKey}` }, + signal, + ), }; export default conversationService; diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx index 66cb6830..179417be 100644 --- a/frontend/src/conversation/Conversation.tsx +++ b/frontend/src/conversation/Conversation.tsx @@ -22,6 +22,7 @@ import { resendQuery, selectQueries, selectStatus, + submitToolActions, updateQuery, } from './conversationSlice'; import { selectCompletedAttachments } from '../upload/uploadSlice'; @@ -41,6 +42,17 @@ export default function Conversation() { const [lastQueryReturnedErr, setLastQueryReturnedErr] = useState(false); + const handleToolAction = useCallback( + (callId: string, decision: 'approved' | 'denied', comment?: string) => { + dispatch( + submitToolActions({ + toolActions: [{ call_id: callId, decision, comment }], + }), + ); + }, + [dispatch], + ); + const lastAutoOpenedArtifactId = useRef(null); const didInitArtifactAutoOpen = useRef(false); const prevConversationId = useRef(conversationId); @@ -233,6 +245,7 @@ export default function Conversation() { status={status} showHeroOnEmpty={selectedAgent ? false : true} onOpenArtifact={handleOpenArtifact} + onToolAction={handleToolAction} isSplitView={isSplitArtifactOpen} headerContent={ selectedAgent ? ( diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 4fd9f868..41f81a42 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -65,6 +65,11 @@ const ConversationBubble = forwardRef< ) => void; filesAttached?: { id: string; fileName: string }[]; onOpenArtifact?: (artifact: { id: string; toolName: string }) => void; + onToolAction?: ( + callId: string, + decision: 'approved' | 'denied', + comment?: string, + ) => void; } >(function ConversationBubble( { @@ -83,6 +88,7 @@ const ConversationBubble = forwardRef< handleUpdatedQuestionSubmission, filesAttached, onOpenArtifact, + onToolAction, }, ref, ) { @@ -411,7 +417,7 @@ const ConversationBubble = forwardRef< )} {research && } {toolCalls && toolCalls.length > 0 && ( - + )} {!message && primaryArtifactCall?.artifact_id && onOpenArtifact && (
@@ -884,108 +890,263 @@ function AllSources(sources: AllSourcesProps) {
 }
 
 export default ConversationBubble;
 
-function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
+function ToolCallApprovalBar({
+  toolCall,
+  onToolAction,
+}: {
+  toolCall: ToolCallsType;
+  onToolAction?: (
+    callId: string,
+    decision: 'approved' | 'denied',
+    comment?: string,
+  ) => void;
+}) {
+  const [expanded, setExpanded] = useState(false);
+  const [comment, setComment] = useState('');
+  const actionLabel = toolCall.action_name.substring(
+    0,
+    toolCall.action_name.lastIndexOf('_'),
+  );
+  const argPreview = JSON.stringify(toolCall.arguments);
+  const truncated =
+    argPreview.length > 60 ? argPreview.slice(0, 57) + '...' : argPreview;
+
+  return (
+    <div>
+      {/* layout markup lost in extraction */}
+      <div>
+        <span>{toolCall.tool_name}</span>
+        <span>{actionLabel}</span>
+        <span>{truncated}</span>
+      </div>
+      <div>
+        {/* approve / deny / expand controls; labels and icons lost in extraction */}
+        <button onClick={() => onToolAction?.(toolCall.call_id, 'approved')} />
+        <button onClick={() => onToolAction?.(toolCall.call_id, 'denied')} />
+        <button onClick={() => setExpanded(!expanded)} />
+      </div>
+      {expanded && (
+        <div>
+          <span>Arguments</span>
+          <pre>
+            {JSON.stringify(toolCall.arguments, null, 2)}
+          </pre>
+          <input
+            value={comment}
+            onChange={(e) => setComment(e.target.value)}
+            onKeyDown={(e) => {
+              if (e.key === 'Enter' && comment) {
+                onToolAction?.(toolCall.call_id, 'denied', comment);
+              }
+            }}
+          />
+        </div>
+      )}
+    </div>
+  );
+}
+
+function ToolCalls({
+  toolCalls,
+  onToolAction,
+}: {
+  toolCalls: ToolCallsType[];
+  onToolAction?: (
+    callId: string,
+    decision: 'approved' | 'denied',
+    comment?: string,
+  ) => void;
+}) {
   const [isToolCallsOpen, setIsToolCallsOpen] = useState(false);
+  const awaitingCalls = toolCalls.filter(
+    (tc) => tc.status === 'awaiting_approval',
+  );
+  const resolvedCalls = toolCalls.filter(
+    (tc) => tc.status !== 'awaiting_approval',
+  );
+
   return (
     <div>
-      {toolCalls.length > 0 && (
-        <button onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}>
-          {/* accordion header; label and icon markup lost in extraction */}
-        </button>
-      )}
-      {isToolCallsOpen && (
-        <div>
-          <div>
-            {toolCalls.map((toolCall, index) => (
-              <div key={index}>
-                <div>
-                  <div>
-                    <span>Arguments</span>{' '}
-                  </div>
-                  <pre>
-                    {JSON.stringify(toolCall.arguments, null, 2)}
-                  </pre>
-                </div>
-                <div>
-                  <div>
-                    <span>Response</span>{' '}
-                  </div>
-                  {toolCall.status === 'pending' && (
-                    <div>{/* loading indicator markup lost in extraction */}</div>
-                  )}
-                  {toolCall.status === 'completed' && (
-                    <pre>
-                      {JSON.stringify(toolCall.result, null, 2)}
-                    </pre>
-                  )}
-                  {toolCall.status === 'error' && (
-                    <span>{toolCall.error}</span>
-                  )}
-                </div>
-              </div>
-            ))}
-          </div>
-        </div>
-      )}
+      {awaitingCalls.length > 0 && (
+        <div>
+          {awaitingCalls.map((tc) => (
+            <ToolCallApprovalBar
+              key={tc.call_id}
+              toolCall={tc}
+              onToolAction={onToolAction}
+            />
+          ))}
+        </div>
+      )}
+
+      {/* Regular tool calls accordion */}
+      {resolvedCalls.length > 0 && (
+        <>
+          <button onClick={() => setIsToolCallsOpen(!isToolCallsOpen)}>
+            {/* accordion header; label and icon markup lost in extraction */}
+          </button>
+          {isToolCallsOpen && (
+            <div>
+              <div>
+                {resolvedCalls.map((toolCall, index) => (
+                  <div key={index}>
+                    <div>
+                      <div>
+                        <span>Arguments</span>{' '}
+                      </div>
+                      <pre>
+                        {JSON.stringify(toolCall.arguments, null, 2)}
+                      </pre>
+                    </div>
+                    <div>
+                      <div>
+                        <span>Response</span>{' '}
+                      </div>
+                      {toolCall.status === 'pending' && (
+                        <div>{/* loading indicator markup lost in extraction */}</div>
+                      )}
+                      {toolCall.status === 'completed' && (
+                        <pre>
+                          {JSON.stringify(toolCall.result, null, 2)}
+                        </pre>
+                      )}
+                      {toolCall.status === 'error' && (
+                        <span>{toolCall.error}</span>
+                      )}
+                      {toolCall.status === 'denied' && (
+                        <span>Denied by user</span>
+                      )}
+                    </div>
+                  </div>
+                ))}
+              </div>
+            </div>
+          )}
+        </>
+      )}
     </div>
   );
 }
diff --git a/frontend/src/conversation/ConversationMessages.tsx b/frontend/src/conversation/ConversationMessages.tsx
index dbfdeba7..7747e7b8 100644
--- a/frontend/src/conversation/ConversationMessages.tsx
+++ b/frontend/src/conversation/ConversationMessages.tsx
@@ -38,6 +38,11 @@ type ConversationMessagesProps = {
   showHeroOnEmpty?: boolean;
   headerContent?: ReactNode;
   onOpenArtifact?: (artifact: { id: string; toolName: string }) => void;
+  onToolAction?: (
+    callId: string,
+    decision: 'approved' | 'denied',
+    comment?: string,
+  ) => void;
   isSplitView?: boolean;
 };
@@ -50,6 +55,7 @@ export default function ConversationMessages({
   showHeroOnEmpty = true,
   headerContent,
   onOpenArtifact,
+  onToolAction,
   isSplitView = false,
 }: ConversationMessagesProps) {
   const [isDarkTheme] = useDarkTheme();
@@ -154,6 +160,7 @@ export default function ConversationMessages({
                   toolCalls={query.tool_calls}
                   research={query.research}
                   onOpenArtifact={onOpenArtifact}
+                  onToolAction={onToolAction}
                   feedback={query.feedback}
                   isStreaming={isCurrentlyStreaming}
                   handleFeedback={
diff --git a/frontend/src/conversation/conversationHandlers.ts b/frontend/src/conversation/conversationHandlers.ts
index e55952fe..759635b6 100644
--- a/frontend/src/conversation/conversationHandlers.ts
+++ b/frontend/src/conversation/conversationHandlers.ts
@@ -188,6 +188,264 @@ export function handleFetchAnswerSteaming(
   });
 }
 
+export function handleSubmitToolActions(
+  conversationId: string,
+  toolActions: {
+    call_id: string;
+    decision?: 'approved' | 'denied';
+    comment?: string;
+    result?: Record<string, any>;
+  }[],
+  token: string | null,
+  signal: AbortSignal,
+  onEvent: (event: MessageEvent) => void,
+): Promise<void> {
+  const payload = {
+    conversation_id: conversationId,
+    tool_actions: toolActions,
+  };
+
+  return new Promise<void>((resolve, reject) => {
+    conversationService
+      .answerStream(payload, token, signal)
+      .then((response) => {
+        if (!response.body) throw Error('No response body');
+
+        let buffer = '';
+        const reader = response.body.getReader();
+        const decoder = new TextDecoder('utf-8');
+
+        const processStream = ({
+          done,
+          value,
+        }: ReadableStreamReadResult<Uint8Array>) => {
+          if (done) {
+            resolve();
+            return;
+          }
+
+          const chunk = decoder.decode(value);
+          buffer += chunk;
+
+          const events = buffer.split('\n\n');
+          buffer = events.pop() ?? '';
+
+          for (const event of events) {
+            if (event.trim().startsWith('data:')) {
+              const dataLine: string = event
+                .split('\n')
+                .map((line: string) => line.replace(/^data:\s?/, ''))
+                .join('');
+
+              const messageEvent = new MessageEvent('message', {
+                data: dataLine.trim(),
+              });
+
+              onEvent(messageEvent);
+            }
+          }
+
+          reader.read().then(processStream).catch(reject);
+        };
+
+        reader.read().then(processStream).catch(reject);
+      })
+      .catch((error) => {
+        console.error('Tool actions submission failed:', error);
+        reject(error);
+      });
+  });
+}
+
+/**
+ * Stream a chat completion via the /v1/chat/completions endpoint.
+ *
+ * Translates the standard streaming format (choices[0].delta) back into
+ * the internal DocsGPT event shape so the existing Redux reducers can
+ * consume the events without any changes.
+ */
+export function handleV1ChatCompletionStreaming(
+  question: string,
+  signal: AbortSignal,
+  agentApiKey: string,
+  history: { prompt: string; response: string }[],
+  onEvent: (event: MessageEvent) => void,
+  tools?: any[],
+  attachments?: string[],
+): Promise<void> {
+  // Build messages array from history + current question
+  const messages: any[] = [];
+  for (const h of history) {
+    messages.push({ role: 'user', content: h.prompt });
+    messages.push({ role: 'assistant', content: h.response });
+  }
+  messages.push({ role: 'user', content: question });
+
+  const payload: any = {
+    messages,
+    stream: true,
+  };
+  if (tools && tools.length > 0) {
+    payload.tools = tools;
+  }
+  if (attachments && attachments.length > 0) {
+    payload.docsgpt = { attachments };
+  }
+
+  return new Promise<void>((resolve, reject) => {
+    conversationService
+      .chatCompletions(payload, agentApiKey, signal)
+      .then((response) => {
+        if (!response.body) throw Error('No response body');
+
+        let buffer = '';
+        const reader = response.body.getReader();
+        const decoder = new TextDecoder('utf-8');
+
+        const processStream = ({
+          done,
+          value,
+        }: ReadableStreamReadResult<Uint8Array>) => {
+          if (done) {
+            resolve();
+            return;
+          }
+
+          const chunk = decoder.decode(value);
+          buffer += chunk;
+
+          const events = buffer.split('\n\n');
+          buffer = events.pop() ?? '';
+
+          for (const event of events) {
+            if (!event.trim().startsWith('data:')) continue;
+
+            const dataLine = event
+              .split('\n')
+              .map((line: string) => line.replace(/^data:\s?/, ''))
+              .join('');
+
+            const trimmed = dataLine.trim();
+
+            // Handle [DONE] sentinel
+            if (trimmed === '[DONE]') {
+              onEvent(
+                new MessageEvent('message', {
+                  data: JSON.stringify({ type: 'end' }),
+                }),
+              );
+              continue;
+            }
+
+            try {
+              const parsed = JSON.parse(trimmed);
+              // Translate standard format to DocsGPT internal events
+              const translated = translateV1ChunkToInternalEvents(parsed);
+              for (const evt of translated) {
+                onEvent(
+                  new MessageEvent('message', {
+                    data: JSON.stringify(evt),
+                  }),
+                );
+              }
+            } catch {
+              // Skip unparseable chunks
+            }
+          }
+
+          reader.read().then(processStream).catch(reject);
+        };
+
+        reader.read().then(processStream).catch(reject);
+      })
+      .catch((error) => {
+        console.error('V1 chat completion stream failed:', error);
+        reject(error);
+      });
+  });
+}
+
+/**
+ * Translate a single v1 streaming chunk to internal DocsGPT event(s).
+ * + * Standard format: {"choices": [{"delta": {"content": "chunk"}, ...}]} + * Extension format: {"docsgpt": {"type": "source", ...}} + */ +function translateV1ChunkToInternalEvents( + chunk: any, +): { type: string; [key: string]: any }[] { + const events: { type: string; [key: string]: any }[] = []; + + // DocsGPT extension chunks + if (chunk.docsgpt) { + const ext = chunk.docsgpt; + if (ext.type === 'source') { + events.push({ type: 'source', source: ext.sources }); + } else if (ext.type === 'tool_call') { + events.push({ type: 'tool_call', data: ext.data }); + } else if (ext.type === 'tool_calls_pending') { + events.push({ + type: 'tool_calls_pending', + data: { pending_tool_calls: ext.pending_tool_calls }, + }); + } else if (ext.type === 'id') { + events.push({ type: 'id', id: ext.conversation_id }); + } + return events; + } + + // Error chunks + if (chunk.error) { + events.push({ type: 'error', error: chunk.error.message || 'Error' }); + return events; + } + + // Standard choices chunks + const choice = chunk.choices?.[0]; + if (!choice) return events; + + const delta = choice.delta || {}; + const finishReason = choice.finish_reason; + + if (delta.content) { + events.push({ type: 'answer', answer: delta.content }); + } + + if (delta.reasoning_content) { + events.push({ type: 'thought', thought: delta.reasoning_content }); + } + + if (delta.tool_calls) { + for (const tc of delta.tool_calls) { + let parsedArgs: Record = {}; + if (tc.function?.arguments) { + try { + parsedArgs = JSON.parse(tc.function.arguments); + } catch { + // Arguments may arrive as fragments during streaming; + // keep the raw string so downstream can accumulate it. + parsedArgs = { _raw: tc.function.arguments }; + } + } + events.push({ + type: 'tool_call', + data: { + call_id: tc.id, + action_name: tc.function?.name || '', + tool_name: tc.function?.name || '', + arguments: parsedArgs, + status: 'requires_client_execution', + }, + }); + } + } + + if (finishReason === 'stop') { + events.push({ type: 'end' }); + } else if (finishReason === 'tool_calls') { + events.push({ + type: 'tool_calls_pending', + data: { pending_tool_calls: [] }, + }); + } + + return events; +} + export function handleSearch( question: string, token: string | null, diff --git a/frontend/src/conversation/conversationModels.ts b/frontend/src/conversation/conversationModels.ts index cfa4ec8e..61bbbb5d 100644 --- a/frontend/src/conversation/conversationModels.ts +++ b/frontend/src/conversation/conversationModels.ts @@ -1,7 +1,7 @@ import { ToolCallsType } from './types'; export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR'; -export type Status = 'idle' | 'loading' | 'failed'; +export type Status = 'idle' | 'loading' | 'failed' | 'awaiting_tool_actions'; export type FEEDBACK = 'LIKE' | 'DISLIKE' | null; export interface Message { diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index bc234de8..6181e1da 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -10,6 +10,8 @@ import { import { handleFetchAnswer, handleFetchAnswerSteaming, + handleSubmitToolActions, + handleV1ChatCompletionStreaming, } from './conversationHandlers'; import { Answer, @@ -27,6 +29,7 @@ const initialState: ConversationState = { }; const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true'; +const USE_V1_API = import.meta.env.VITE_USE_V1_API === 'true'; let abortController: AbortController | null = null; export function handleAbort() { @@ 
-60,7 +63,102 @@ export const fetchAnswer = createAsyncThunk< state.preference.selectedModel?.id; if (state.preference) { - if (API_STREAMING) { + const agentKey = state.preference.selectedAgent?.key; + if (USE_V1_API && agentKey) { + // Build history from prior queries for v1 format + const v1History = state.conversation.queries + .filter((q) => q.response) + .map((q) => ({ prompt: q.prompt, response: q.response || '' })); + + await handleV1ChatCompletionStreaming( + question, + signal, + agentKey, + v1History, + (event) => { + const data = JSON.parse(event.data); + const targetIndex = indx ?? state.conversation.queries.length - 1; + + if (currentConversationId === state.conversation.conversationId) { + if (data.type === 'end') { + dispatch(conversationSlice.actions.setStatus('idle')); + getConversations(state.preference.token) + .then((fetchedConversations) => { + dispatch(setConversations(fetchedConversations)); + }) + .catch((error) => { + console.error('Failed to fetch conversations: ', error); + }); + if (!isSourceUpdated) { + dispatch( + updateStreamingSource({ + conversationId: currentConversationId, + index: targetIndex, + query: { sources: [] }, + }), + ); + } + } else if (data.type === 'id') { + const currentState = getState() as RootState; + if (currentState.conversation.conversationId === null) { + dispatch( + updateConversationId({ + query: { conversationId: data.id }, + }), + ); + } + } else if (data.type === 'thought') { + dispatch( + updateThought({ + conversationId: currentConversationId, + index: targetIndex, + query: { thought: data.thought }, + }), + ); + } else if (data.type === 'source') { + isSourceUpdated = true; + dispatch( + updateStreamingSource({ + conversationId: currentConversationId, + index: targetIndex, + query: { sources: data.source ?? [] }, + }), + ); + } else if (data.type === 'tool_call') { + dispatch( + updateToolCall({ + index: targetIndex, + tool_call: data.data as ToolCallsType, + }), + ); + } else if (data.type === 'tool_calls_pending') { + dispatch( + conversationSlice.actions.setStatus('awaiting_tool_actions'), + ); + } else if (data.type === 'error') { + dispatch(conversationSlice.actions.setStatus('failed')); + dispatch( + conversationSlice.actions.raiseError({ + conversationId: currentConversationId, + index: targetIndex, + message: data.error, + }), + ); + } else { + dispatch( + updateStreamingQuery({ + conversationId: currentConversationId, + index: targetIndex, + query: { response: data.answer }, + }), + ); + } + } + }, + undefined, + attachmentIds.length > 0 ? 
attachmentIds : undefined, + ); + } else if (API_STREAMING) { await handleFetchAnswerSteaming( question, signal, @@ -138,6 +236,10 @@ export const fetchAnswer = createAsyncThunk< tool_call: data.data as ToolCallsType, }), ); + } else if (data.type === 'tool_calls_pending') { + dispatch( + conversationSlice.actions.setStatus('awaiting_tool_actions'), + ); } else if (data.type === 'research_plan') { dispatch( updateResearchPlan({ @@ -260,6 +362,94 @@ export const fetchAnswer = createAsyncThunk< }; }); +export const submitToolActions = createAsyncThunk< + void, + { + toolActions: { + call_id: string; + decision?: 'approved' | 'denied'; + comment?: string; + result?: Record; + }[]; + } +>('submitToolActions', async ({ toolActions }, { dispatch, getState }) => { + if (abortController) abortController.abort(); + abortController = new AbortController(); + const { signal } = abortController; + + const state = getState() as RootState; + const conversationId = state.conversation.conversationId; + if (!conversationId) return; + + dispatch(conversationSlice.actions.setStatus('loading')); + + await handleSubmitToolActions( + conversationId, + toolActions, + state.preference.token, + signal, + (event) => { + const data = JSON.parse(event.data); + const targetIndex = state.conversation.queries.length - 1; + + if (data.type === 'end') { + dispatch(conversationSlice.actions.setStatus('idle')); + getConversations(state.preference.token) + .then((fetchedConversations) => { + dispatch(setConversations(fetchedConversations)); + }) + .catch((error) => { + console.error('Failed to fetch conversations: ', error); + }); + } else if (data.type === 'id') { + // conversation ID already set + } else if (data.type === 'thought') { + dispatch( + updateThought({ + conversationId, + index: targetIndex, + query: { thought: data.thought }, + }), + ); + } else if (data.type === 'source') { + dispatch( + updateStreamingSource({ + conversationId, + index: targetIndex, + query: { sources: data.source ?? 
[] }, + }), + ); + } else if (data.type === 'tool_call') { + dispatch( + updateToolCall({ + index: targetIndex, + tool_call: data.data as ToolCallsType, + }), + ); + } else if (data.type === 'tool_calls_pending') { + dispatch(conversationSlice.actions.setStatus('awaiting_tool_actions')); + } else if (data.type === 'error') { + dispatch(conversationSlice.actions.setStatus('failed')); + dispatch( + conversationSlice.actions.raiseError({ + conversationId, + index: targetIndex, + message: data.error, + }), + ); + } else if (data.type === 'answer') { + dispatch( + updateStreamingQuery({ + conversationId, + index: targetIndex, + query: { response: data.answer }, + }), + ); + } + }, + ); +}); + export const conversationSlice = createSlice({ name: 'conversation', initialState, diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index c416bde6..fb0c948f 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -5,6 +5,12 @@ export type ToolCallsType = { arguments: Record; result?: Record; error?: string; - status?: 'pending' | 'completed' | 'error'; + status?: + | 'pending' + | 'completed' + | 'error' + | 'awaiting_approval' + | 'denied' + | 'requires_client_execution'; artifact_id?: string; }; diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx index 8c8faff6..5ae2a208 100644 --- a/frontend/src/settings/ToolConfig.tsx +++ b/frontend/src/settings/ToolConfig.tsx @@ -487,9 +487,33 @@ export default function ToolConfig({ )}
e.stopPropagation()}
        >
+          <div>
+            {/* layout markup lost in extraction */}
+            <span>
+              {t('settings.tools.requireApproval', 'Approval')}
+            </span>
+            <ToggleSwitch
+              checked={!!tool.actions[originalIndex]?.require_approval}
+              onChange={(checked) => {
+                setTool({
+                  ...tool,
+                  actions: tool.actions.map((act, index) => {
+                    if (index === originalIndex) {
+                      return {
+                        ...act,
+                        require_approval: checked,
+                      };
+                    }
+                    return act;
+                  }),
+                });
+              }}
+              size="small"
+              id={`approvalToggle-${originalIndex}`}
+            />
+          </div>
@@ -926,6 +950,35 @@ function APIToolConfig({
             className="h-4 w-4 opacity-40 transition-opacity hover:opacity-100"
           />
 
+          <div>
+            {/* layout markup lost in extraction */}
+            <span>
+              {t('settings.tools.requireApproval', 'Approval')}
+            </span>
+            <ToggleSwitch
+              checked={!!apiTool.config.actions[actionName]?.require_approval}
+              onChange={() => {
+                setApiTool((prevApiTool) => {
+                  const updatedActions = {
+                    ...prevApiTool.config.actions,
+                  };
+                  updatedActions[actionName] = {
+                    ...updatedActions[actionName],
+                    require_approval:
+                      !updatedActions[actionName].require_approval,
+                  };
+                  return {
+                    ...prevApiTool,
+                    config: {
+                      ...prevApiTool.config,
+                      actions: updatedActions,
+                    },
+                  };
+                });
+              }}
+              size="small"
+              id={`approvalToggle-${actionIndex}`}
+            />
+          </div>
           <ToggleSwitch
             onChange={() =>
handleActionToggle(actionName)} diff --git a/frontend/src/settings/types/index.ts b/frontend/src/settings/types/index.ts index ed1b9122..d692a960 100644 --- a/frontend/src/settings/types/index.ts +++ b/frontend/src/settings/types/index.ts @@ -69,6 +69,7 @@ export type UserToolType = { type: string; }; active: boolean; + require_approval?: boolean; }[]; }; @@ -81,6 +82,7 @@ export type APIActionType = { headers: ParameterGroupType; body: ParameterGroupType; active: boolean; + require_approval?: boolean; body_content_type?: | 'application/json' | 'application/x-www-form-urlencoded' diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py index 05e831f0..f65c7b0d 100644 --- a/tests/agents/test_base_agent.py +++ b/tests/agents/test_base_agent.py @@ -341,7 +341,7 @@ class TestBaseAgentTools: assert len(agent.tools) == 1 assert agent.tools[0]["type"] == "function" - assert agent.tools[0]["function"]["name"] == "get_data_1" + assert agent.tools[0]["function"]["name"] == "get_data" def test_prepare_tools_with_regular_tool( self, agent_base_params, mock_llm_creator, mock_llm_handler_creator @@ -365,7 +365,7 @@ class TestBaseAgentTools: agent._prepare_tools(tools_dict) assert len(agent.tools) == 1 - assert agent.tools[0]["function"]["name"] == "action1_1" + assert agent.tools[0]["function"]["name"] == "action1" def test_prepare_tools_filters_inactive_actions( self, agent_base_params, mock_llm_creator, mock_llm_handler_creator @@ -395,7 +395,7 @@ class TestBaseAgentTools: agent._prepare_tools(tools_dict) assert len(agent.tools) == 1 - assert agent.tools[0]["function"]["name"] == "active_action_1" + assert agent.tools[0]["function"]["name"] == "active_action" @pytest.mark.unit diff --git a/tests/agents/test_tool_action_parser.py b/tests/agents/test_tool_action_parser.py index 840d070c..61e9e962 100644 --- a/tests/agents/test_tool_action_parser.py +++ b/tests/agents/test_tool_action_parser.py @@ -202,3 +202,69 @@ class TestToolActionParser: assert action_name == "create_record" assert call_args["data"]["name"] == "John" assert call_args["data"]["age"] == 30 + + +@pytest.mark.unit +class TestToolActionParserWithMapping: + """Tests for the mapping-based lookup path.""" + + def test_openai_mapping_resolves_clean_name(self): + mapping = {"get_weather": ("ct0", "get_weather")} + parser = ToolActionParser("OpenAILLM", name_mapping=mapping) + + call = Mock() + call.name = "get_weather" + call.arguments = '{"city": "SF"}' + + tool_id, action_name, call_args = parser.parse_args(call) + assert tool_id == "ct0" + assert action_name == "get_weather" + assert call_args == {"city": "SF"} + + def test_openai_mapping_resolves_numbered_suffix(self): + mapping = {"search_1": ("t1", "search"), "search_2": ("t2", "search")} + parser = ToolActionParser("OpenAILLM", name_mapping=mapping) + + call = Mock() + call.name = "search_1" + call.arguments = '{"q": "test"}' + + tool_id, action_name, call_args = parser.parse_args(call) + assert tool_id == "t1" + assert action_name == "search" + + def test_google_mapping_resolves(self): + mapping = {"get_weather": ("ct0", "get_weather")} + parser = ToolActionParser("GoogleLLM", name_mapping=mapping) + + call = Mock() + call.name = "get_weather" + call.arguments = {"city": "SF"} + + tool_id, action_name, call_args = parser.parse_args(call) + assert tool_id == "ct0" + assert action_name == "get_weather" + + def test_fallback_to_split_when_not_in_mapping(self): + mapping = {"get_weather": ("ct0", "get_weather")} + parser = ToolActionParser("OpenAILLM", 
name_mapping=mapping) + + call = Mock() + call.name = "unknown_action_99" + call.arguments = "{}" + + tool_id, action_name, call_args = parser.parse_args(call) + # Falls back to legacy split + assert tool_id == "99" + assert action_name == "unknown_action" + + def test_no_mapping_uses_legacy_split(self): + parser = ToolActionParser("OpenAILLM", name_mapping=None) + + call = Mock() + call.name = "action_123" + call.arguments = '{"k": "v"}' + + tool_id, action_name, call_args = parser.parse_args(call) + assert tool_id == "123" + assert action_name == "action" diff --git a/tests/agents/test_tool_executor.py b/tests/agents/test_tool_executor.py index 96be815c..02b6ce0f 100644 --- a/tests/agents/test_tool_executor.py +++ b/tests/agents/test_tool_executor.py @@ -80,7 +80,7 @@ class TestToolExecutorPrepare: result = executor.prepare_tools_for_llm(tools_dict) assert len(result) == 1 assert result[0]["type"] == "function" - assert result[0]["function"]["name"] == "do_thing_t1" + assert result[0]["function"]["name"] == "do_thing" assert "query" in result[0]["function"]["parameters"]["properties"] def test_prepare_tools_skips_inactive_actions(self): @@ -97,7 +97,96 @@ class TestToolExecutorPrepare: result = executor.prepare_tools_for_llm(tools_dict) assert len(result) == 1 - assert result[0]["function"]["name"] == "active_one_t1" + assert result[0]["function"]["name"] == "active_one" + + def test_prepare_tools_builds_name_mapping(self): + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "test_tool", + "actions": [ + {"name": "do_thing", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + } + } + executor.prepare_tools_for_llm(tools_dict) + assert executor._name_to_tool["do_thing"] == ("t1", "do_thing") + assert executor._tool_to_name[("t1", "do_thing")] == "do_thing" + + def test_prepare_tools_duplicate_names_get_numbered_suffixes(self): + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "tool_a", + "actions": [ + {"name": "search", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + }, + "t2": { + "name": "tool_b", + "actions": [ + {"name": "search", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + }, + } + result = executor.prepare_tools_for_llm(tools_dict) + names = [r["function"]["name"] for r in result] + assert "search_1" in names + assert "search_2" in names + assert executor._name_to_tool["search_1"][1] == "search" + assert executor._name_to_tool["search_2"][1] == "search" + + def test_prepare_tools_unique_name_no_suffix(self): + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "tool_a", + "actions": [ + {"name": "get_weather", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + }, + "t2": { + "name": "tool_b", + "actions": [ + {"name": "send_email", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + }, + } + result = executor.prepare_tools_for_llm(tools_dict) + names = [r["function"]["name"] for r in result] + assert "get_weather" in names + assert "send_email" in names + + def test_prepare_tools_suffix_skips_collision_with_unique_name(self): + """If action 'foo_1' exists as unique and 'foo' is duplicated, skip '_1'.""" + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "tool_a", + "actions": [ + {"name": "foo", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + }, + "t2": { + "name": "tool_b", + "actions": [ + {"name": "foo", "description": "D", "active": True, "parameters": 
{"properties": {}}}, + ], + }, + "t3": { + "name": "tool_c", + "actions": [ + {"name": "foo_1", "description": "D", "active": True, "parameters": {"properties": {}}}, + ], + }, + } + result = executor.prepare_tools_for_llm(tools_dict) + names = [r["function"]["name"] for r in result] + # foo_1 is taken by the unique action, so duplicates skip to _2 and _3 + assert "foo_1" in names # The unique action + assert "foo_2" in names + assert "foo_3" in names + assert executor._name_to_tool["foo_1"] == ("t3", "foo_1") def test_build_tool_parameters_filters_non_llm_fields(self): executor = ToolExecutor() @@ -128,6 +217,68 @@ class TestToolExecutorPrepare: assert "value" not in result["properties"]["query"] +@pytest.mark.unit +class TestCheckPause: + + def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = arguments + call.thought_signature = None + return call + + def test_client_side_tool_returns_llm_name(self): + """check_pause returns the clean LLM-facing name and llm_name field.""" + executor = ToolExecutor() + + tools_dict = { + "ct0": { + "name": "write_file", + "client_side": True, + "actions": [ + {"name": "write_file", "description": "Write a file", "active": True, "parameters": {}}, + ], + } + } + + # Prepare tools so the mapping is built + executor.prepare_tools_for_llm(tools_dict) + + call = self._make_call(name="write_file") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + + assert result is not None + assert result["name"] == "write_file" + assert result["llm_name"] == "write_file" + assert result["action_name"] == "write_file" + assert result["tool_id"] == "ct0" + + def test_approval_required_returns_llm_name(self): + """check_pause for approval-required tools returns clean LLM name.""" + executor = ToolExecutor() + + tools_dict = { + "t1": { + "name": "dangerous_tool", + "actions": [ + {"name": "delete_all", "description": "Deletes everything", "active": True, + "require_approval": True, "parameters": {}}, + ], + } + } + + executor.prepare_tools_for_llm(tools_dict) + + call = self._make_call(name="delete_all") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + + assert result is not None + assert result["name"] == "delete_all" + assert result["llm_name"] == "delete_all" + assert result["action_name"] == "delete_all" + + @pytest.mark.unit class TestToolExecutorExecute: @@ -143,7 +294,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: Mock(parse_args=Mock(return_value=(None, None, {}))), + lambda _cls, **kw: Mock(parse_args=Mock(return_value=(None, None, {}))), ) call = self._make_call(name="bad") @@ -167,7 +318,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))), + lambda _cls, **kw: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))), ) call = self._make_call() @@ -190,7 +341,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))), + lambda _cls, **kw: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))), ) tools_dict = { @@ -244,7 +395,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: 
Mock(parse_args=Mock(return_value=("t1", "test_action", {}))), + lambda _cls, **kw: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))), ) tools_dict = { @@ -284,7 +435,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: Mock( + lambda _cls, **kw: Mock( parse_args=Mock(return_value=("t1", "get_users", {"body_param": "val"})) ), ) @@ -331,7 +482,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: Mock( + lambda _cls, **kw: Mock( parse_args=Mock(return_value=("t1", "act", {})) ), ) @@ -376,7 +527,7 @@ class TestToolExecutorExecute: monkeypatch.setattr( "application.agents.tool_executor.ToolActionParser", - lambda _cls: Mock( + lambda _cls, **kw: Mock( parse_args=Mock(return_value=("t1", "act", {"q": "v"})) ), ) diff --git a/tests/api/answer/routes/test_answer.py b/tests/api/answer/routes/test_answer.py index 53f2525d..4e4b7a8b 100644 --- a/tests/api/answer/routes/test_answer.py +++ b/tests/api/answer/routes/test_answer.py @@ -73,7 +73,7 @@ class TestAnswerResourcePost: ), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=(conv_id, "Hello", [], [], "", None), + return_value={"conversation_id": conv_id, "answer": "Hello", "sources": [], "tool_calls": [], "thought": "", "error": None}, ): resp = answer_client.post( "/api/answer", @@ -129,7 +129,7 @@ class TestAnswerResourcePost: return_value=iter([]), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=(None, None, None, None, None, "Stream error"), + return_value={"conversation_id": None, "answer": None, "sources": None, "tool_calls": None, "thought": None, "error": "Stream error"}, ): resp = answer_client.post( "/api/answer", @@ -173,15 +173,7 @@ class TestAnswerResourcePost: return_value=iter([]), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=( - conv_id, - '{"key": "val"}', - [], - [], - "", - None, - {"structured": True, "schema": {"type": "object"}}, - ), + return_value={"conversation_id": conv_id, "answer": '{"key": "val"}', "sources": [], "tool_calls": [], "thought": "", "error": None, "extra": {"structured": True, "schema": {"type": "object"}}}, ): resp = answer_client.post( "/api/answer", @@ -208,14 +200,7 @@ class TestAnswerResourcePost: return_value=iter([]), ), patch( "application.api.answer.routes.answer.AnswerResource.process_response_stream", - return_value=( - conv_id, - "answer text", - [{"title": "src"}], - [{"tool": "t"}], - "thinking...", - None, - ), + return_value={"conversation_id": conv_id, "answer": "answer text", "sources": [{"title": "src"}], "tool_calls": [{"tool": "t"}], "thought": "thinking...", "error": None}, ): resp = answer_client.post( "/api/answer", diff --git a/tests/api/answer/routes/test_base.py b/tests/api/answer/routes/test_base.py index eee2ab31..95b7e423 100644 --- a/tests/api/answer/routes/test_base.py +++ b/tests/api/answer/routes/test_base.py @@ -481,10 +481,10 @@ class TestProcessResponseStream: result = resource.process_response_stream(iter(stream)) - assert result[0] == conv_id - assert result[1] == "Hello world" - assert result[2] == [{"title": "doc1"}] - assert result[5] is None + assert result["conversation_id"] == conv_id + assert result["answer"] == "Hello world" + assert result["sources"] == [{"title": "doc1"}] + assert result["error"] is None def 
test_handles_stream_error(self, mock_mongo_db, flask_app): import json @@ -500,10 +500,8 @@ class TestProcessResponseStream: result = resource.process_response_stream(iter(stream)) - assert len(result) == 6 - assert result[0] is None - assert result[4] == "Test error" - assert result[5] is None + assert result["conversation_id"] is None + assert result["error"] == "Test error" def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource diff --git a/tests/api/answer/services/compression/test_message_builder.py b/tests/api/answer/services/compression/test_message_builder.py index 9cb681d4..99a0e1c2 100644 --- a/tests/api/answer/services/compression/test_message_builder.py +++ b/tests/api/answer/services/compression/test_message_builder.py @@ -79,9 +79,10 @@ class TestBuildFromCompressedContext: # system + user + assistant + tool_call_assistant + tool_response = 5 assert len(messages) == 5 assert messages[3]["role"] == "assistant" - assert "function_call" in messages[3]["content"][0] + assert messages[3].get("tool_calls") is not None + assert messages[3]["tool_calls"][0]["function"]["name"] == "search" assert messages[4]["role"] == "tool" - assert "function_response" in messages[4]["content"][0] + assert messages[4].get("tool_call_id") == "call-1" def test_tool_calls_not_included_by_default(self): queries = [ @@ -127,8 +128,8 @@ class TestBuildFromCompressedContext: recent_queries=queries, include_tool_calls=True, ) - tool_msg = messages[3]["content"][0] - call_id = tool_msg["function_call"]["call_id"] + assistant_msg = messages[3] + call_id = assistant_msg["tool_calls"][0]["id"] assert call_id is not None assert len(call_id) > 0 diff --git a/tests/api/answer/test_base_routes.py b/tests/api/answer/test_base_routes.py index dc6a19a5..d47076ad 100644 --- a/tests/api/answer/test_base_routes.py +++ b/tests/api/answer/test_base_routes.py @@ -295,10 +295,8 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "end"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[1] == "{}" - # Structured output adds extra tuple element - assert len(result) == 7 - assert result[6]["structured"] is True + assert result["answer"] == "{}" + assert result.get("extra", {}).get("structured") is True def test_handles_tool_calls_event(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource @@ -312,7 +310,7 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "end"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[3] == [{"name": "t1"}] + assert result["tool_calls"] == [{"name": "t1"}] def test_incomplete_stream(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource @@ -323,7 +321,7 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "answer", "answer": "partial"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[4] == "Stream ended unexpectedly" + assert result["error"] == "Stream ended unexpectedly" def test_handles_thought_event(self, mock_mongo_db, flask_app): from application.api.answer.routes.base import BaseAnswerResource @@ -335,7 +333,7 @@ class TestProcessResponseStreamExtended: f'data: {json.dumps({"type": "end"})}\n\n', ] result = resource.process_response_stream(iter(stream)) - assert result[4] == "thinking..." + assert result["thought"] == "thinking..." 
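For reviewers tracking the tuple-to-dict migration asserted above: a minimal sketch of the result shape these tests now expect from process_response_stream. The helper name and literal values below are illustrative assumptions; only the keys come from the assertions in this patch, and "extra" appears only for structured output.

# Hypothetical illustration, not part of the patch: the dict shape asserted above.
def example_stream_result(conv_id: str) -> dict:
    return {
        "conversation_id": conv_id,
        "answer": "Hello world",
        "sources": [{"title": "doc1"}],
        "tool_calls": [],
        "thought": "",
        "error": None,
        # present only for structured output:
        # "extra": {"structured": True, "schema": {"type": "object"}},
    }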
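The ToolExecutor naming tests earlier in this patch (unique action names pass through, duplicates get numbered suffixes, and a suffix is skipped when it would collide with a name some other action already owns) pin down a small allocation rule. The sketch below is one possible implementation, included only to make the expected behavior concrete; the function name and signature are assumptions, not the shipped code.

# Sketch of the suffix-allocation rule the tests above describe, assuming a
# flat list of (tool_id, action_name) pairs in registration order.
from collections import Counter
from typing import Dict, List, Tuple

def assign_llm_names(actions: List[Tuple[str, str]]) -> Dict[Tuple[str, str], str]:
    counts = Counter(name for _, name in actions)
    # Names owned by a unique action are reserved up front ("foo_1" below).
    taken = {name for _, name in actions if counts[name] == 1}
    result: Dict[Tuple[str, str], str] = {}
    next_suffix: Dict[str, int] = {}
    for tool_id, name in actions:
        if counts[name] == 1:
            result[(tool_id, name)] = name  # unique: no suffix
            continue
        k = next_suffix.get(name, 1)
        while f"{name}_{k}" in taken:  # skip collisions with reserved names
            k += 1
        result[(tool_id, name)] = f"{name}_{k}"
        taken.add(f"{name}_{k}")
        next_suffix[name] = k + 1
    return result

# Mirrors test_prepare_tools_suffix_skips_collision_with_unique_name:
# [("t1", "foo"), ("t2", "foo"), ("t3", "foo_1")] -> foo_2, foo_3, foo_1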
@pytest.mark.unit diff --git a/tests/integration/run_all.py b/tests/integration/run_all.py index 12397a84..b8137b27 100644 --- a/tests/integration/run_all.py +++ b/tests/integration/run_all.py @@ -50,6 +50,7 @@ from tests.integration.test_analytics import AnalyticsTests from tests.integration.test_connectors import ConnectorTests from tests.integration.test_mcp import MCPTests from tests.integration.test_misc import MiscTests +from tests.integration.test_v1_api import V1ApiTests # Module registry @@ -64,6 +65,7 @@ MODULES = { "connectors": ConnectorTests, "mcp": MCPTests, "misc": MiscTests, + "v1_api": V1ApiTests, } diff --git a/tests/integration/test_chat.py b/tests/integration/test_chat.py index 2651d169..b213e37b 100644 --- a/tests/integration/test_chat.py +++ b/tests/integration/test_chat.py @@ -1036,208 +1036,7 @@ This is test documentation for integration tests. return False # ------------------------------------------------------------------------- - # Compression Tests - # ------------------------------------------------------------------------- - - def test_compression_heavy_tool_usage(self) -> bool: - """Test compression with heavy conversation usage.""" - test_name = "Compression - Heavy Tool Usage" - self.print_header(f"Testing {test_name}") - - if not self.require_auth(test_name): - return True - - self.print_info("Making 10 consecutive requests to build conversation history...") - - current_conv_id = None - - for i in range(10): - question = f"Tell me about Python topic {i+1}: data structures, decorators, async, testing. Provide a comprehensive explanation." - - payload = { - "question": question, - "history": "[]", - "isNoneDoc": True, - } - - if current_conv_id: - payload["conversation_id"] = current_conv_id - - try: - response = self.post("/api/answer", json=payload, timeout=90) - - if response.status_code == 200: - result = response.json() - current_conv_id = result.get("conversation_id", current_conv_id) - answer_preview = (result.get("answer") or "")[:80] - self.print_success(f"Request {i+1}/10 completed") - self.print_info(f" Answer: {answer_preview}...") - else: - self.print_error(f"Request {i+1}/10 failed: status {response.status_code}") - self.record_result(test_name, False, f"Request {i+1} failed") - return False - - time.sleep(2) - - except Exception as e: - self.print_error(f"Request {i+1}/10 failed: {str(e)}") - self.record_result(test_name, False, str(e)) - return False - - if current_conv_id: - self.print_success("Heavy usage test completed") - self.record_result(test_name, True, f"10 requests, conv_id: {current_conv_id}") - return True - else: - self.print_warning("No conversation_id received") - self.record_result(test_name, False, "No conversation_id") - return False - - def test_compression_needle_in_haystack(self) -> bool: - """Test that compression preserves critical information. - - Note: This is a long-running test that may timeout due to LLM response times. - Timeouts are handled gracefully as they indicate performance issues, not bugs. 
- """ - test_name = "Compression - Needle in Haystack" - self.print_header(f"Testing {test_name}") - - if not self.require_auth(test_name): - return True - - conversation_id = None - - # Step 1: Send general questions - self.print_info("Step 1: Sending general questions...") - for i, question in enumerate([ - "Tell me about Python best practices in detail", - "Explain Python data structures comprehensively", - ]): - payload = { - "question": question, - "history": "[]", - "isNoneDoc": True, - } - if conversation_id: - payload["conversation_id"] = conversation_id - - try: - response = self.post("/api/answer", json=payload, timeout=90) - if response.status_code == 200: - result = response.json() - conversation_id = result.get("conversation_id", conversation_id) - self.print_success(f"General question {i+1}/2 completed") - else: - self.print_error(f"Request failed: status {response.status_code}") - self.record_result(test_name, False, "General questions failed") - return False - time.sleep(2) - except Exception as e: - # Timeout errors are expected for long LLM responses - if "timed out" in str(e).lower() or "timeout" in str(e).lower(): - self.print_warning(f"Request timed out: {str(e)[:50]}") - self.record_result(test_name, True, "Skipped (timeout)") - return True - self.print_error(f"Request failed: {str(e)}") - self.record_result(test_name, False, str(e)) - return False - - # Step 2: Send critical information - self.print_info("Step 2: Sending CRITICAL information...") - critical_payload = { - "question": "Please remember: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily.", - "history": "[]", - "isNoneDoc": True, - "conversation_id": conversation_id, - } - - try: - response = self.post("/api/answer", json=critical_payload, timeout=90) - if response.status_code == 200: - result = response.json() - conversation_id = result.get("conversation_id", conversation_id) - self.print_success("Critical information sent") - else: - self.record_result(test_name, False, "Critical info failed") - return False - time.sleep(2) - except Exception as e: - if "timed out" in str(e).lower() or "timeout" in str(e).lower(): - self.print_warning(f"Request timed out: {str(e)[:50]}") - self.record_result(test_name, True, "Skipped (timeout)") - return True - self.record_result(test_name, False, str(e)) - return False - - # Step 3: Bury with more questions - self.print_info("Step 3: Sending more questions to bury critical info...") - for i, question in enumerate([ - "Explain Python decorators in great detail", - "Tell me about Python async programming comprehensively", - ]): - payload = { - "question": question, - "history": "[]", - "isNoneDoc": True, - "conversation_id": conversation_id, - } - - try: - response = self.post("/api/answer", json=payload, timeout=90) - if response.status_code == 200: - result = response.json() - conversation_id = result.get("conversation_id", conversation_id) - self.print_success(f"Burying question {i+1}/2 completed") - else: - self.record_result(test_name, False, "Burying questions failed") - return False - time.sleep(2) - except Exception as e: - if "timed out" in str(e).lower() or "timeout" in str(e).lower(): - self.print_warning(f"Request timed out: {str(e)[:50]}") - self.record_result(test_name, True, "Skipped (timeout)") - return True - self.record_result(test_name, False, str(e)) - return False - - # Step 4: Test recall - self.print_info("Step 4: Testing if critical info was preserved...") - recall_payload = { 
- "question": "What was the database password environment variable I mentioned earlier?", - "history": "[]", - "isNoneDoc": True, - "conversation_id": conversation_id, - } - - try: - response = self.post("/api/answer", json=recall_payload, timeout=90) - if response.status_code == 200: - result = response.json() - answer = (result.get("answer") or "").lower() - - if "db_password_prod" in answer or "database password" in answer: - self.print_success("Critical information preserved!") - self.print_info(f"Answer: {answer[:150]}...") - self.record_result(test_name, True, "Info preserved") - return True - else: - self.print_warning("Critical information may have been lost") - self.print_info(f"Answer: {answer[:150]}...") - self.record_result(test_name, False, "Info not preserved") - return False - else: - self.record_result(test_name, False, "Recall failed") - return False - except Exception as e: - if "timed out" in str(e).lower() or "timeout" in str(e).lower(): - self.print_warning(f"Request timed out: {str(e)[:50]}") - self.record_result(test_name, True, "Skipped (timeout)") - return True - self.record_result(test_name, False, str(e)) - return False - - # ------------------------------------------------------------------------- - # Feedback Tests (NEW) + # Feedback Tests # ------------------------------------------------------------------------- def test_feedback_positive(self) -> bool: @@ -1435,15 +1234,6 @@ This is test documentation for integration tests. self.test_tts_basic() time.sleep(1) - # Compression tests (longer running) - if self.is_authenticated: - self.test_compression_heavy_tool_usage() - time.sleep(2) - - self.test_compression_needle_in_haystack() - else: - self.print_info("Skipping compression tests (no authentication)") - return self.print_summary() diff --git a/tests/integration/test_v1_api.py b/tests/integration/test_v1_api.py new file mode 100644 index 00000000..9774af2b --- /dev/null +++ b/tests/integration/test_v1_api.py @@ -0,0 +1,681 @@ +#!/usr/bin/env python3 +""" +Integration tests for the /v1/ chat completions API (Phase 4). + +Endpoints tested: +- /v1/chat/completions (POST) - Standard chat completions (streaming & non-streaming) +- /v1/models (GET) - List available agent models + +Usage: + python tests/integration/test_v1_api.py + python tests/integration/test_v1_api.py --base-url http://localhost:7091 + python tests/integration/test_v1_api.py --token YOUR_JWT_TOKEN +""" + +import json as json_module +import sys +import time +from pathlib import Path +from typing import Optional + +import requests + +# Add parent directory to path for standalone execution +_THIS_DIR = Path(__file__).parent +_TESTS_DIR = _THIS_DIR.parent +_ROOT_DIR = _TESTS_DIR.parent +if str(_ROOT_DIR) not in sys.path: + sys.path.insert(0, str(_ROOT_DIR)) + +from tests.integration.base import DocsGPTTestBase, create_client_from_args + + +class V1ApiTests(DocsGPTTestBase): + """Integration tests for /v1/ chat completions API.""" + + # ------------------------------------------------------------------------- + # Test Data Helpers + # ------------------------------------------------------------------------- + + def get_or_create_agent_key(self) -> Optional[str]: + """Get or create a test agent and return its API key.""" + if hasattr(self, "_agent_key") and self._agent_key: + return self._agent_key + + # Try both authenticated and unauthenticated creation. + # Published agents need a source to get an API key. 
+ payload = { + "name": f"V1 Test Agent {int(time.time())}", + "description": "Integration test agent for v1 API tests", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "classic", + "status": "published", + "source": "default", + } + + try: + response = self.post("/api/create_agent", json=payload, timeout=10) + if response.status_code in [200, 201]: + result = response.json() + api_key = result.get("key") + self._agent_id = result.get("id") + if api_key: + self._agent_key = api_key + self.print_info(f"Created test agent with key: {api_key[:8]}...") + return api_key + else: + self.print_warning("Agent created but no API key returned") + else: + self.print_warning(f"Agent creation returned {response.status_code}: {response.text[:200]}") + except Exception as e: + self.print_error(f"Failed to create agent: {e}") + + return None + + def _v1_headers(self, api_key: str) -> dict: + """Build headers for v1 API requests.""" + return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + # ------------------------------------------------------------------------- + # /v1/chat/completions — Auth Tests + # ------------------------------------------------------------------------- + + def test_no_auth_returns_401(self) -> bool: + """Test that /v1/chat/completions without auth returns 401.""" + test_name = "v1 chat completions - no auth" + self.print_header(f"Testing {test_name}") + + try: + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hi"}]}, + headers={"Content-Type": "application/json"}, + timeout=10, + ) + + if response.status_code == 401: + self.print_success("Correctly returned 401 for missing auth") + self.record_result(test_name, True, "401 as expected") + return True + else: + self.print_error(f"Expected 401, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + except Exception as e: + self.print_error(f"Request failed: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_invalid_key_returns_error(self) -> bool: + """Test that an invalid API key returns an error.""" + test_name = "v1 chat completions - invalid key" + self.print_header(f"Testing {test_name}") + + try: + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hi"}]}, + headers=self._v1_headers("invalid-key-12345"), + timeout=30, + ) + + # Should return 400, 401, or 500 (agent not found / unauthorized key) + if response.status_code in [400, 401, 500]: + self.print_success(f"Correctly returned {response.status_code} for invalid key") + self.record_result(test_name, True, f"Error as expected ({response.status_code})") + return True + else: + self.print_error(f"Unexpected status: {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + except Exception as e: + self.print_error(f"Request failed: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_missing_messages_returns_400(self) -> bool: + """Test that missing messages field returns 400.""" + test_name = "v1 chat completions - missing messages" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + response = requests.post( +
f"{self.base_url}/v1/chat/completions", + json={"stream": False}, + headers=self._v1_headers(api_key), + timeout=10, + ) + + if response.status_code == 400: + self.print_success("Correctly returned 400 for missing messages") + self.record_result(test_name, True, "400 as expected") + return True + else: + self.print_error(f"Expected 400, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + except Exception as e: + self.print_error(f"Request failed: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # /v1/chat/completions — Non-streaming + # ------------------------------------------------------------------------- + + def test_non_streaming_basic(self) -> bool: + """Test basic non-streaming chat completion.""" + test_name = "v1 chat completions - non-streaming" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Say hello in one word."}], + "stream": False, + }, + headers=self._v1_headers(api_key), + timeout=60, + ) + + self.print_info(f"Status: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.print_error(f"Response: {response.text[:300]}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + data = response.json() + + # Verify standard format + checks = [ + ("id" in data, "has id"), + (data.get("object") == "chat.completion", "object is chat.completion"), + ("choices" in data, "has choices"), + (len(data["choices"]) > 0, "choices not empty"), + (data["choices"][0].get("message", {}).get("role") == "assistant", "role is assistant"), + (data["choices"][0].get("message", {}).get("content") is not None, "has content"), + (data["choices"][0].get("finish_reason") == "stop", "finish_reason is stop"), + ("usage" in data, "has usage"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + content = data["choices"][0]["message"]["content"] + self.print_info(f"Response: {content[:100]}") + + # Check docsgpt extension + if "docsgpt" in data: + self.print_success(" has docsgpt extension") + if "conversation_id" in data["docsgpt"]: + self.print_success(f" conversation_id: {data['docsgpt']['conversation_id'][:8]}...") + + self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed") + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # /v1/chat/completions — Streaming + # ------------------------------------------------------------------------- + + def test_streaming_basic(self) -> bool: + """Test basic streaming chat completion.""" + test_name = "v1 chat completions - streaming" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") 
+ return True + + try: + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Say hi briefly."}], + "stream": True, + }, + headers=self._v1_headers(api_key), + stream=True, + timeout=60, + ) + + self.print_info(f"Status: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + chunks = [] + content_pieces = [] + got_done = False + got_stop = False + got_id = False + + for line in response.iter_lines(): + if not line: + continue + line_str = line.decode("utf-8") + if not line_str.startswith("data: "): + continue + + data_str = line_str[6:] + if data_str.strip() == "[DONE]": + got_done = True + break + + try: + chunk = json_module.loads(data_str) + chunks.append(chunk) + + # Standard chunks + if "choices" in chunk: + delta = chunk["choices"][0].get("delta", {}) + if "content" in delta: + content_pieces.append(delta["content"]) + if chunk["choices"][0].get("finish_reason") == "stop": + got_stop = True + + # Extension chunks + if "docsgpt" in chunk: + ext = chunk["docsgpt"] + if ext.get("type") == "id": + got_id = True + + except json_module.JSONDecodeError: + pass + + full_content = "".join(content_pieces) + + checks = [ + (len(chunks) > 0, f"received {len(chunks)} chunks"), + (len(content_pieces) > 0, f"got content: {full_content[:50]}..."), + (got_stop, "got finish_reason=stop"), + (got_done, "got [DONE] sentinel"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + if got_id: + self.print_success(" got conversation_id via docsgpt extension") + + self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed") + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # /v1/chat/completions — Multi-turn conversation + # ------------------------------------------------------------------------- + + def test_multi_turn_conversation(self) -> bool: + """Test multi-turn conversation with history in messages.""" + test_name = "v1 chat completions - multi-turn" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "My name is TestBot."}, + {"role": "assistant", "content": "Hello TestBot!"}, + {"role": "user", "content": "What is my name?"}, + ], + "stream": False, + }, + headers=self._v1_headers(api_key), + timeout=60, + ) + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + data = response.json() + content = data["choices"][0]["message"]["content"] + self.print_info(f"Response: {content[:150]}") + + # The response should reference "TestBot" from the history + has_content = bool(content) + self.print_success(f" Got response with {len(content)} chars") + self.record_result(test_name, has_content, "Multi-turn works") + 
return has_content + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # /v1/models + # ------------------------------------------------------------------------- + + def test_list_models(self) -> bool: + """Test GET /v1/models endpoint.""" + test_name = "v1 models - list" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + response = requests.get( + f"{self.base_url}/v1/models", + headers=self._v1_headers(api_key), + timeout=10, + ) + + self.print_info(f"Status: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + data = response.json() + + checks = [ + (data.get("object") == "list", "object is list"), + ("data" in data, "has data array"), + (len(data.get("data", [])) > 0, f"has {len(data.get('data', []))} model(s)"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + if data.get("data"): + model = data["data"][0] + model_checks = [ + ("id" in model, "model has id"), + (model.get("object") == "model", "model object is 'model'"), + (model.get("owned_by") == "docsgpt", "owned_by is docsgpt"), + ] + for check, label in model_checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + self.record_result(test_name, all_passed, "All checks passed" if all_passed else "Some checks failed") + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_models_no_auth(self) -> bool: + """Test that /v1/models without auth returns 401.""" + test_name = "v1 models - no auth" + self.print_header(f"Testing {test_name}") + + try: + response = requests.get( + f"{self.base_url}/v1/models", + timeout=10, + ) + + if response.status_code == 401: + self.print_success("Correctly returned 401") + self.record_result(test_name, True, "401 as expected") + return True + else: + self.print_error(f"Expected 401, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # Backward Compatibility — old endpoints still work + # ------------------------------------------------------------------------- + + def test_old_stream_endpoint_still_works(self) -> bool: + """Verify the old /stream endpoint still works after v1 changes.""" + test_name = "Backward compat - /stream" + self.print_header(f"Testing {test_name}") + + payload = { + "question": "Say hello briefly.", + "history": "[]", + "isNoneDoc": True, + } + + try: + response = requests.post( + f"{self.base_url}/stream", + json=payload, + headers=self.headers, + stream=True, + timeout=60, + ) + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status 
{response.status_code}") + return False + + events = [] + got_end = False + got_answer = False + + for line in response.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + try: + data = json_module.loads(line_str[6:]) + events.append(data) + if data.get("type") == "answer": + got_answer = True + if data.get("type") == "end": + got_end = True + break + except json_module.JSONDecodeError: + pass + + checks = [ + (len(events) > 0, f"received {len(events)} events"), + (got_answer, "got answer event"), + (got_end, "got end event"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression") + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_old_answer_endpoint_still_works(self) -> bool: + """Verify the old /api/answer endpoint still works.""" + test_name = "Backward compat - /api/answer" + self.print_header(f"Testing {test_name}") + + payload = { + "question": "Say hi.", + "history": "[]", + "isNoneDoc": True, + } + + try: + response = requests.post( + f"{self.base_url}/api/answer", + json=payload, + headers=self.headers, + timeout=60, + ) + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + data = response.json() + checks = [ + ("answer" in data, "has answer"), + ("conversation_id" in data, "has conversation_id"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + self.print_info(f"Answer: {data.get('answer', '')[:100]}") + self.record_result(test_name, all_passed, "Old endpoint works" if all_passed else "Regression") + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # Cleanup + # ------------------------------------------------------------------------- + + def cleanup(self): + """Clean up test resources.""" + if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated: + try: + self.post(f"/api/delete_agent?id={self._agent_id}", json={}) + self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...") + except Exception: + pass + + # ------------------------------------------------------------------------- + # Run All + # ------------------------------------------------------------------------- + + def run_all(self) -> bool: + """Run all v1 API integration tests.""" + self.print_header("V1 Chat Completions API Integration Tests") + self.print_info(f"Base URL: {self.base_url}") + self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}") + + try: + # Auth tests (no agent needed) + self.test_no_auth_returns_401() + time.sleep(0.5) + + self.test_models_no_auth() + time.sleep(0.5) + + self.test_invalid_key_returns_error() + time.sleep(0.5) + + self.test_missing_messages_returns_400() + time.sleep(0.5) + + # Non-streaming + self.test_non_streaming_basic() + time.sleep(1) + + # Streaming + self.test_streaming_basic() + time.sleep(1) + + # Multi-turn + self.test_multi_turn_conversation() + 
time.sleep(1) + + # Models + self.test_list_models() + time.sleep(0.5) + + # Backward compatibility + self.test_old_stream_endpoint_still_works() + time.sleep(1) + + self.test_old_answer_endpoint_still_works() + time.sleep(1) + + finally: + self.cleanup() + + return self.print_summary() + + +def main(): + """Main entry point.""" + client = create_client_from_args(V1ApiTests, "DocsGPT V1 API Integration Tests") + success = client.run_all() + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/integration/test_v1_tool_calls.py b/tests/integration/test_v1_tool_calls.py new file mode 100644 index 00000000..af0f650a --- /dev/null +++ b/tests/integration/test_v1_tool_calls.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python3 +r""" +Integration tests for the /v1/ chat completions API — client tool-call flow. + +Tests the full lifecycle: +1. Send request with client tools → LLM triggers a tool call +2. Verify response returns clean tool names (no internal _ct\d+ suffix) +3. Send continuation with tool results + top-level conversation_id +4. Verify the continuation completes successfully + +Usage: + python tests/integration/test_v1_tool_calls.py + python tests/integration/test_v1_tool_calls.py --base-url http://localhost:7091 +""" + +import json as json_module +import re +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import requests + +_THIS_DIR = Path(__file__).parent +_TESTS_DIR = _THIS_DIR.parent +_ROOT_DIR = _TESTS_DIR.parent +if str(_ROOT_DIR) not in sys.path: + sys.path.insert(0, str(_ROOT_DIR)) + +from tests.integration.base import DocsGPTTestBase, create_client_from_args + +# Internal suffix pattern that should NOT appear in client responses +_CT_SUFFIX_RE = re.compile(r"_ct\d+$") + + +class V1ToolCallTests(DocsGPTTestBase): + """Integration tests for /v1/ client tool-call flows.""" + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + + def get_or_create_agent_key(self) -> Optional[str]: + """Get or create a test agent and return its API key.""" + if hasattr(self, "_agent_key") and self._agent_key: + return self._agent_key + + payload = { + "name": f"V1 ToolCall Test {int(time.time())}", + "description": "Integration test agent for tool-call flow", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "classic", + "status": "published", + "source": "default", + } + + try: + response = self.post("/api/create_agent", json=payload, timeout=10) + if response.status_code in [200, 201]: + result = response.json() + api_key = result.get("key") + self._agent_id = result.get("id") + if api_key: + self._agent_key = api_key + self.print_info(f"Created test agent with key: {api_key[:8]}...") + return api_key + except Exception as e: + self.print_error(f"Failed to create agent: {e}") + + return None + + def _v1_headers(self, api_key: str) -> dict: + return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + # A simple client tool definition in OpenAI format + _CLIENT_TOOLS = [ + { + "type": "function", + "function": { + "name": "create", + "description": "Create a new todo item", + "parameters": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "The title of the new todo item", + } + }, + "required": ["title"], + }, + }, + } + ] + + def _send_streaming_request( + self, + api_key: str, + messages: 
List[Dict], + tools: Optional[List[Dict]] = None, + conversation_id: Optional[str] = None, + ) -> Tuple[List[Dict], str, Optional[Dict]]: + """Send a streaming request and collect all events. + + Returns: + (all_chunks, full_content, tool_call_info) + tool_call_info is a dict with 'name', 'arguments', 'call_id' + if the response paused for a client tool call, else None. + """ + body: Dict[str, Any] = { + "messages": messages, + "stream": True, + } + if tools: + body["tools"] = tools + if conversation_id: + body["conversation_id"] = conversation_id + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json=body, + headers=self._v1_headers(api_key), + stream=True, + timeout=120, + ) + + if response.status_code != 200: + raise RuntimeError( + f"Expected 200, got {response.status_code}: {response.text[:300]}" + ) + + chunks: List[Dict] = [] + content_pieces: List[str] = [] + tool_call_info: Optional[Dict] = None + conversation_id_from_response: Optional[str] = None + + for line in response.iter_lines(): + if not line: + continue + line_str = line.decode("utf-8") + if not line_str.startswith("data: "): + continue + + data_str = line_str[6:] + if data_str.strip() == "[DONE]": + break + + try: + chunk = json_module.loads(data_str) + chunks.append(chunk) + + # Standard chunks + if "choices" in chunk: + delta = chunk["choices"][0].get("delta", {}) + if "content" in delta: + content_pieces.append(delta["content"]) + + # Tool call delta + if "tool_calls" in delta: + tc = delta["tool_calls"][0] + tool_call_info = { + "call_id": tc.get("id", ""), + "name": tc["function"]["name"], + "arguments": tc["function"].get("arguments", "{}"), + } + + # Extension chunks + if "docsgpt" in chunk: + ext = chunk["docsgpt"] + if ext.get("type") == "id": + conversation_id_from_response = ext.get("conversation_id") + + except json_module.JSONDecodeError: + pass + + full_content = "".join(content_pieces) + + # Attach conversation_id to tool_call_info for convenience + if tool_call_info and conversation_id_from_response: + tool_call_info["conversation_id"] = conversation_id_from_response + + return chunks, full_content, tool_call_info + + def _send_non_streaming_request( + self, + api_key: str, + messages: List[Dict], + tools: Optional[List[Dict]] = None, + conversation_id: Optional[str] = None, + ) -> Dict: + """Send a non-streaming request and return parsed JSON.""" + body: Dict[str, Any] = { + "messages": messages, + "stream": False, + } + if tools: + body["tools"] = tools + if conversation_id: + body["conversation_id"] = conversation_id + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + json=body, + headers=self._v1_headers(api_key), + timeout=120, + ) + + if response.status_code != 200: + raise RuntimeError( + f"Expected 200, got {response.status_code}: {response.text[:300]}" + ) + + return response.json() + + # ------------------------------------------------------------------------- + # Tests + # ------------------------------------------------------------------------- + + def test_streaming_tool_call_clean_name(self) -> bool: + """Streaming: tool names returned to client must not have _ct suffixes.""" + test_name = "v1 streaming tool call - clean name" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + messages = [ + {"role": "user", "content": "Use the create tool to add a todo 
item titled 'Test integration'. Call the tool now."}, + ] + chunks, content, tool_call_info = self._send_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + if not tool_call_info: + # LLM didn't trigger the tool — could happen, not a failure of our code + self.print_warning("LLM did not trigger a tool call (may need prompt tuning)") + self.print_info(f"Got text response instead: {content[:100]}") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + tool_name = tool_call_info["name"] + self.print_info(f"Tool call name: {tool_name}") + + has_suffix = bool(_CT_SUFFIX_RE.search(tool_name)) + if has_suffix: + self.print_error(f"Tool name has internal suffix: {tool_name}") + self.record_result(test_name, False, f"Suffix leak: {tool_name}") + return False + + self.print_success(f"Tool name is clean: {tool_name}") + self.record_result(test_name, True, f"Clean name: {tool_name}") + return True + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_non_streaming_tool_call_clean_name(self) -> bool: + """Non-streaming: tool names returned to client must not have _ct suffixes.""" + test_name = "v1 non-streaming tool call - clean name" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Test non-stream'. Call the tool now."}, + ] + data = self._send_non_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + message = data["choices"][0]["message"] + tool_calls = message.get("tool_calls") + + if not tool_calls: + content = message.get("content", "") + self.print_warning("LLM did not trigger a tool call") + self.print_info(f"Got text response: {content[:100]}") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + tool_name = tool_calls[0]["function"]["name"] + self.print_info(f"Tool call name: {tool_name}") + + has_suffix = bool(_CT_SUFFIX_RE.search(tool_name)) + if has_suffix: + self.print_error(f"Tool name has internal suffix: {tool_name}") + self.record_result(test_name, False, f"Suffix leak: {tool_name}") + return False + + self.print_success(f"Tool name is clean: {tool_name}") + self.record_result(test_name, True, f"Clean name: {tool_name}") + return True + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool: + """Full tool-call round-trip: trigger → get conversation_id → continue with top-level id.""" + test_name = "v1 streaming tool continuation - top-level conversation_id" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + # Step 1: trigger a tool call + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Round trip test'. 
Call the tool now."}, + ] + chunks, content, tool_call_info = self._send_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + if not tool_call_info: + self.print_warning("LLM did not trigger a tool call") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + conversation_id = tool_call_info.get("conversation_id") + if not conversation_id: + self.print_error("No conversation_id returned in stream") + self.record_result(test_name, False, "Missing conversation_id") + return False + + self.print_info(f"Got conversation_id: {conversation_id[:12]}...") + self.print_info(f"Tool call: {tool_call_info['name']}({tool_call_info['arguments']})") + + # Step 2: send continuation with tool result + top-level conversation_id + # (standard OpenAI format — no docsgpt field in assistant message) + continuation_messages = [ + *messages, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_info["call_id"], + "type": "function", + "function": { + "name": tool_call_info["name"], + "arguments": tool_call_info["arguments"], + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": tool_call_info["call_id"], + "content": json_module.dumps({"id": 99, "title": "Round trip test", "status": "created"}), + }, + ] + + chunks2, content2, tool_call_info2 = self._send_streaming_request( + api_key, + continuation_messages, + tools=self._CLIENT_TOOLS, + conversation_id=conversation_id, + ) + + checks = [ + (len(chunks2) > 0, f"continuation returned {len(chunks2)} chunks"), + (bool(content2) or tool_call_info2 is not None, "got content or another tool call"), + ] + + all_passed = True + for check, label in checks: + if check: + self.print_success(f" {label}") + else: + self.print_error(f" {label}") + all_passed = False + + if content2: + self.print_info(f"Continuation response: {content2[:150]}") + + self.record_result( + test_name, + all_passed, + "Full round-trip works" if all_passed else "Continuation failed", + ) + return all_passed + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + def test_non_streaming_tool_continuation_with_top_level_conversation_id(self) -> bool: + """Non-streaming full round-trip with top-level conversation_id.""" + test_name = "v1 non-streaming tool continuation - top-level conversation_id" + self.print_header(f"Testing {test_name}") + + api_key = self.get_or_create_agent_key() + if not api_key: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + try: + # Step 1: trigger a tool call + messages = [ + {"role": "user", "content": "Use the create tool to add a todo item titled 'Non-stream round trip'. 
Call the tool now."}, + ] + data = self._send_non_streaming_request( + api_key, messages, tools=self._CLIENT_TOOLS + ) + + message = data["choices"][0]["message"] + tool_calls = message.get("tool_calls") + + if not tool_calls: + self.print_warning("LLM did not trigger a tool call") + self.record_result(test_name, True, "Skipped (LLM didn't call tool)") + return True + + conversation_id = data.get("docsgpt", {}).get("conversation_id") + if not conversation_id: + self.print_error("No conversation_id in response") + self.record_result(test_name, False, "Missing conversation_id") + return False + + tc = tool_calls[0] + self.print_info(f"Got tool call: {tc['function']['name']}") + self.print_info(f"conversation_id: {conversation_id[:12]}...") + + # Step 2: send continuation (standard format, top-level conversation_id) + continuation_messages = [ + *messages, + { + "role": "assistant", + "content": None, + "tool_calls": [tc], + }, + { + "role": "tool", + "tool_call_id": tc["id"], + "content": json_module.dumps({"id": 100, "title": "Non-stream round trip", "status": "created"}), + }, + ] + + data2 = self._send_non_streaming_request( + api_key, + continuation_messages, + tools=self._CLIENT_TOOLS, + conversation_id=conversation_id, + ) + + message2 = data2["choices"][0]["message"] + has_response = bool(message2.get("content")) or bool(message2.get("tool_calls")) + + if has_response: + self.print_success("Continuation returned a response") + content2 = message2.get("content", "") + if content2: + self.print_info(f"Response: {content2[:150]}") + else: + self.print_error("Continuation returned empty response") + + self.record_result( + test_name, + has_response, + "Round-trip works" if has_response else "Empty continuation response", + ) + return has_response + + except Exception as e: + self.print_error(f"Error: {e}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # Cleanup & Run All + # ------------------------------------------------------------------------- + + def cleanup(self): + if hasattr(self, "_agent_id") and self._agent_id and self.is_authenticated: + try: + self.post(f"/api/delete_agent?id={self._agent_id}", json={}) + self.print_info(f"Cleaned up test agent {self._agent_id[:8]}...") + except Exception: + pass + + def run_all(self) -> bool: + self.print_header("V1 Tool-Call Flow Integration Tests") + self.print_info(f"Base URL: {self.base_url}") + self.print_info(f"Authentication: {'Yes' if self.is_authenticated else 'No'}") + + try: + # Streaming tests + self.test_streaming_tool_call_clean_name() + time.sleep(1) + + self.test_non_streaming_tool_call_clean_name() + time.sleep(1) + + # Full round-trip tests + self.test_streaming_tool_continuation_with_top_level_conversation_id() + time.sleep(1) + + self.test_non_streaming_tool_continuation_with_top_level_conversation_id() + time.sleep(1) + + finally: + self.cleanup() + + return self.print_summary() + + +def main(): + client = create_client_from_args(V1ToolCallTests, "DocsGPT V1 Tool-Call Integration Tests") + success = client.run_all() + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/llm/handlers/test_google.py b/tests/llm/handlers/test_google.py index 900b2d5f..1a4558fc 100644 --- a/tests/llm/handlers/test_google.py +++ b/tests/llm/handlers/test_google.py @@ -196,9 +196,9 @@ class TestGoogleLLMHandler: assert result.finish_reason == "tool_calls" def test_create_tool_message(self): - """Test creating 
tool message.""" + """Test creating tool message in standard format.""" handler = GoogleLLMHandler() - + tool_call = ToolCall( id="call_123", name="get_weather", @@ -206,35 +206,26 @@ class TestGoogleLLMHandler: index=0 ) result = {"temperature": "25C", "condition": "cloudy"} - + message = handler.create_tool_message(tool_call, result) - - expected = { - "role": "model", - "content": [ - { - "function_response": { - "name": "get_weather", - "response": {"result": result}, - } - } - ], - } - - assert message == expected + + assert message["role"] == "tool" + assert message["tool_call_id"] == "call_123" + import json + assert json.loads(message["content"]) == result def test_create_tool_message_string_result(self): """Test creating tool message with string result.""" handler = GoogleLLMHandler() - + tool_call = ToolCall(id="call_456", name="get_time", arguments={}) result = "2023-12-01 15:30:00 JST" - + message = handler.create_tool_message(tool_call, result) - - assert message["role"] == "model" - assert message["content"][0]["function_response"]["response"]["result"] == result - assert message["content"][0]["function_response"]["name"] == "get_time" + + assert message["role"] == "tool" + assert message["tool_call_id"] == "call_456" + assert message["content"] == result def test_iterate_stream(self): """Test stream iteration.""" diff --git a/tests/llm/handlers/test_llm_handlers.py b/tests/llm/handlers/test_llm_handlers.py index 304d89fb..6baf04e8 100644 --- a/tests/llm/handlers/test_llm_handlers.py +++ b/tests/llm/handlers/test_llm_handlers.py @@ -621,6 +621,7 @@ class TestHandleToolCalls: agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) def fake_execute(tools_dict, call): yield {"type": "tool_call", "data": {"status": "pending"}} @@ -641,7 +642,7 @@ class TestHandleToolCalls: while True: events.append(next(gen)) except StopIteration as e: - messages = e.value + messages, _pending = e.value assert any(e.get("type") == "tool_call" for e in events) assert len(messages) >= 2 # function_call + tool_message @@ -675,6 +676,9 @@ class TestHandleToolCalls: agent = Mock() agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) + agent.tool_executor._name_to_tool = {} agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error")) call = ToolCall(id="c1", name="action_1", arguments="{}") @@ -704,18 +708,17 @@ class TestHandleToolCalls: while True: next(gen) except StopIteration as e: - messages = e.value + messages, _pending = e.value + # Standard format: thought_signature is on tool_calls items assistant_msgs = [ m for m in messages - if m.get("role") == "assistant" - and isinstance(m.get("content"), list) + if m.get("role") == "assistant" and m.get("tool_calls") ] assert any( - "thought_signature" in item + tc.get("thought_signature") == "sig" for m in assistant_msgs - for item in m["content"] - if isinstance(item, dict) + for tc in m["tool_calls"] ) @@ -751,6 +754,7 @@ class TestHandleNonStreaming: agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) # First response requires tool call, second is final call_count = {"n": 0} @@ -856,6 +860,7 @@ class TestHandleStreaming: 
agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) # First chunk has partial tool call, second completes it chunk1 = LLMResponse( @@ -907,6 +912,7 @@ class TestHandleStreaming: agent.context_limit_reached = True agent._check_context_limit = Mock(return_value=True) agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) # Chunk finishes with tool_calls chunk = LLMResponse( @@ -929,7 +935,7 @@ class TestHandleStreaming: def fake_handle_tool_calls(agent, calls, tools_dict, messages): agent.context_limit_reached = True yield {"type": "tool_call", "data": {"status": "skipped"}} - return messages + return messages, None handler.handle_tool_calls = fake_handle_tool_calls @@ -1501,6 +1507,7 @@ class TestHandleToolCallsCompressionSuccess: agent._check_context_limit = Mock(side_effect=check_limit) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) def fake_execute(tools_dict, call): yield {"type": "tool_call", "data": {"status": "pending"}} @@ -1538,6 +1545,7 @@ class TestHandleToolCallsCompressionSuccess: agent = Mock() agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) exec_count = {"n": 0} diff --git a/tests/llm/handlers/test_openai.py b/tests/llm/handlers/test_openai.py index 64c89f6c..03494573 100644 --- a/tests/llm/handlers/test_openai.py +++ b/tests/llm/handlers/test_openai.py @@ -128,9 +128,9 @@ class TestOpenAILLMHandler: assert result.finish_reason == "" def test_create_tool_message(self): - """Test creating tool message.""" + """Test creating tool message in standard format.""" handler = OpenAILLMHandler() - + tool_call = ToolCall( id="call_123", name="get_weather", @@ -138,36 +138,26 @@ class TestOpenAILLMHandler: index=0 ) result = {"temperature": "72F", "condition": "sunny"} - + message = handler.create_tool_message(tool_call, result) - - expected = { - "role": "tool", - "content": [ - { - "function_response": { - "name": "get_weather", - "response": {"result": result}, - "call_id": "call_123", - } - } - ], - } - - assert message == expected + + assert message["role"] == "tool" + assert message["tool_call_id"] == "call_123" + import json + assert json.loads(message["content"]) == result def test_create_tool_message_string_result(self): """Test creating tool message with string result.""" handler = OpenAILLMHandler() - + tool_call = ToolCall(id="call_456", name="get_time", arguments={}) result = "2023-12-01 10:30:00" - + message = handler.create_tool_message(tool_call, result) - + assert message["role"] == "tool" - assert message["content"][0]["function_response"]["response"]["result"] == result - assert message["content"][0]["function_response"]["call_id"] == "call_456" + assert message["tool_call_id"] == "call_456" + assert message["content"] == result def test_iterate_stream(self): """Test stream iteration.""" diff --git a/tests/llm/test_base.py b/tests/llm/test_base.py index c12bbdc7..e62d0186 100644 --- a/tests/llm/test_base.py +++ b/tests/llm/test_base.py @@ -478,11 +478,14 @@ class TestHandleToolCallsErrors: handler = ConcreteHandler() agent = MagicMock() agent._check_context_limit = MagicMock(return_value=False) + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = MagicMock(return_value=None) + 
agent.tool_executor._name_to_tool = {"search": ("1", "search")} agent._execute_tool_action = MagicMock( side_effect=RuntimeError("tool failed") ) - tool_call = ToolCall(id="tc1", name="search_1", arguments={"q": "test"}) + tool_call = ToolCall(id="tc1", name="search", arguments={"q": "test"}) tools_dict = {"1": {"name": "search_tool"}} messages = [{"role": "user", "content": "hi"}] @@ -506,6 +509,9 @@ class TestHandleToolCallsErrors: handler = ConcreteHandler() agent = MagicMock() agent._check_context_limit = MagicMock(return_value=False) + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = MagicMock(return_value=None) + agent.tool_executor._name_to_tool = {} agent._execute_tool_action = MagicMock( side_effect=RuntimeError("tool failed") ) @@ -1169,12 +1175,15 @@ class TestHandleToolCallsErrorsAdditional: handler = ConcreteHandler() agent = MagicMock() agent._check_context_limit = MagicMock(return_value=False) + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = MagicMock(return_value=None) + agent.tool_executor._name_to_tool = {"do_thing": ("42", "do_thing")} agent._execute_tool_action = MagicMock( side_effect=RuntimeError("broken tool") ) tool_call = ToolCall( - id="tc1", name="do_thing_42", arguments={"x": 1} + id="tc1", name="do_thing", arguments={"x": 1} ) tools_dict = {"42": {"name": "my_tool"}} messages = [{"role": "user", "content": "go"}] @@ -1188,7 +1197,7 @@ class TestHandleToolCallsErrorsAdditional: while True: events.append(next(gen)) except StopIteration as e: - final_messages = e.value + final_messages, _pending = e.value # Verify the error message was appended error_msgs = [ @@ -1205,12 +1214,17 @@ class TestHandleToolCallsErrorsAdditional: ] assert len(error_events) == 1 assert error_events[0]["data"]["tool_name"] == "my_tool" - assert error_events[0]["data"]["action_name"] == "do_thing_42" + assert error_events[0]["data"]["action_name"] == "do_thing" def test_tool_error_with_no_context_check(self): """Cover line 660: messages.copy() at start of handle_tool_calls.""" handler = ConcreteHandler() agent = MagicMock(spec=[]) # No _check_context_limit attribute + agent.llm = MagicMock() + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor = MagicMock() + agent.tool_executor.check_pause = MagicMock(return_value=None) + agent.tool_executor._name_to_tool = {} agent._execute_tool_action = MagicMock( side_effect=ValueError("bad args") ) diff --git a/tests/test_agent_token_tracking.py b/tests/test_agent_token_tracking.py index e168567a..bdb21549 100644 --- a/tests/test_agent_token_tracking.py +++ b/tests/test_agent_token_tracking.py @@ -176,6 +176,9 @@ class TestLLMHandlerTokenTracking: # Create mock agent that hits limit on second tool mock_agent = Mock() mock_agent.context_limit_reached = False + mock_agent.llm.__class__.__name__ = "MockLLM" + mock_agent.tool_executor.check_pause = Mock(return_value=None) + mock_agent.tool_executor._name_to_tool = {} call_count = [0] @@ -235,6 +238,9 @@ class TestLLMHandlerTokenTracking: mock_agent = Mock() mock_agent.context_limit_reached = False mock_agent._check_context_limit = Mock(return_value=False) + mock_agent.llm.__class__.__name__ = "MockLLM" + mock_agent.tool_executor.check_pause = Mock(return_value=None) + mock_agent.tool_executor._name_to_tool = {} mock_agent._execute_tool_action = Mock( return_value=iter([{"type": "tool_call", "data": {}}]) ) @@ -300,7 +306,7 @@ class TestLLMHandlerTokenTracking: def tool_handler_gen(*args): yield {"type": "tool", "data": {}} - return [] 
+ return [], None # Mock handle_tool_calls to return messages and set flag with patch.object( diff --git a/tests/test_client_tools.py b/tests/test_client_tools.py new file mode 100644 index 00000000..6c4cd6d8 --- /dev/null +++ b/tests/test_client_tools.py @@ -0,0 +1,430 @@ +"""Tests for client-side tools (Phase 2). + +Covers merge_client_tools, prepare_tools_for_llm with client tools, +check_pause for client-side tools, and the full flow through the handler. +""" + +from unittest.mock import Mock + +import pytest + +from application.agents.tool_executor import ToolExecutor +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +# --------------------------------------------------------------------------- +# ToolExecutor.merge_client_tools +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestMergeClientTools: + + def test_merge_single_tool(self): + executor = ToolExecutor() + tools_dict = {} + client_tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"], + }, + }, + } + ] + + result = executor.merge_client_tools(tools_dict, client_tools) + + assert "ct0" in result + tool = result["ct0"] + assert tool["name"] == "get_weather" + assert tool["client_side"] is True + assert len(tool["actions"]) == 1 + assert tool["actions"][0]["name"] == "get_weather" + assert tool["actions"][0]["active"] is True + assert "city" in tool["actions"][0]["parameters"]["properties"] + + def test_merge_multiple_tools(self): + executor = ToolExecutor() + tools_dict = {"0": {"name": "existing_tool", "actions": []}} + client_tools = [ + {"type": "function", "function": {"name": "tool_a", "description": "A"}}, + {"type": "function", "function": {"name": "tool_b", "description": "B"}}, + ] + + result = executor.merge_client_tools(tools_dict, client_tools) + + # Original tool still present + assert "0" in result + # Client tools added + assert "ct0" in result + assert "ct1" in result + assert result["ct0"]["name"] == "tool_a" + assert result["ct1"]["name"] == "tool_b" + + def test_merge_bare_format(self): + """Accept simplified format without the outer 'function' wrapper.""" + executor = ToolExecutor() + tools_dict = {} + client_tools = [ + {"name": "simple_tool", "description": "Simple", "parameters": {}}, + ] + + result = executor.merge_client_tools(tools_dict, client_tools) + + assert "ct0" in result + assert result["ct0"]["name"] == "simple_tool" + + def test_merge_preserves_existing_tools(self): + executor = ToolExecutor() + tools_dict = { + "abc123": { + "name": "brave", + "actions": [{"name": "search", "active": True}], + } + } + client_tools = [ + {"type": "function", "function": {"name": "my_tool", "description": "D"}}, + ] + + executor.merge_client_tools(tools_dict, client_tools) + + assert "abc123" in tools_dict + assert tools_dict["abc123"]["name"] == "brave" + assert "ct0" in tools_dict + + def test_merge_empty_list(self): + executor = ToolExecutor() + tools_dict = {"0": {"name": "existing"}} + + executor.merge_client_tools(tools_dict, []) + + assert len(tools_dict) == 1 + + +# --------------------------------------------------------------------------- +# prepare_tools_for_llm with client tools +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class 
TestPrepareClientToolsForLlm: + + def test_client_tools_included_in_llm_schema(self): + executor = ToolExecutor() + tools_dict = {} + client_tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"} + }, + "required": ["city"], + }, + }, + } + ] + executor.merge_client_tools(tools_dict, client_tools) + + schemas = executor.prepare_tools_for_llm(tools_dict) + + assert len(schemas) == 1 + assert schemas[0]["type"] == "function" + assert schemas[0]["function"]["name"] == "get_weather" + assert schemas[0]["function"]["description"] == "Get weather" + # Parameters passed through directly (not filtered by _build_tool_parameters) + assert "city" in schemas[0]["function"]["parameters"]["properties"] + assert schemas[0]["function"]["parameters"]["required"] == ["city"] + + def test_mixed_server_and_client_tools(self): + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "test_tool", + "actions": [ + { + "name": "do_thing", + "description": "Does a thing", + "active": True, + "parameters": { + "properties": { + "query": { + "type": "string", + "filled_by_llm": True, + "required": True, + } + } + }, + } + ], + } + } + client_tools = [ + { + "type": "function", + "function": { + "name": "local_fn", + "description": "Local function", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + executor.merge_client_tools(tools_dict, client_tools) + + schemas = executor.prepare_tools_for_llm(tools_dict) + + assert len(schemas) == 2 + names = {s["function"]["name"] for s in schemas} + assert "do_thing" in names + assert "local_fn" in names + + +# --------------------------------------------------------------------------- +# get_tools auto-merges client_tools +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestGetToolsAutoMerge: + + def test_get_tools_merges_client_tools(self, mock_mongo_db): + executor = ToolExecutor(user="alice") + executor.client_tools = [ + { + "type": "function", + "function": {"name": "my_fn", "description": "test"}, + } + ] + + tools = executor.get_tools() + + assert any( + t.get("client_side") is True for t in tools.values() + ), "Client tools should be merged into tools_dict" + + def test_get_tools_no_client_tools(self, mock_mongo_db): + executor = ToolExecutor(user="alice") + + tools = executor.get_tools() + + assert not any( + t.get("client_side") for t in tools.values() + ) + + +# --------------------------------------------------------------------------- +# check_pause for client-side tools +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckPauseClientTools: + + def _make_call(self, name="action_0", call_id="c1"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = "{}" + call.thought_signature = None + return call + + def test_client_tool_triggers_pause(self): + executor = ToolExecutor() + tools_dict = { + "ct0": { + "name": "get_weather", + "client_side": True, + "actions": [ + {"name": "get_weather", "active": True, "parameters": {}}, + ], + } + } + executor.prepare_tools_for_llm(tools_dict) + call = self._make_call(name="get_weather") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + + assert result is not None + assert result["pause_type"] == "requires_client_execution" + assert result["tool_name"] == "get_weather" + assert result["tool_id"] == "ct0" + + 
def test_server_tool_no_pause(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "brave", + "actions": [ + {"name": "search", "active": True, "parameters": {}}, + ], + } + } + executor.prepare_tools_for_llm(tools_dict) + call = self._make_call(name="search") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + + assert result is None + + +# --------------------------------------------------------------------------- +# Handler flow: client tool causes pause +# --------------------------------------------------------------------------- + + +class ConcreteHandler(LLMHandler): + """Minimal concrete handler for testing.""" + + def parse_response(self, response): + return LLMResponse( + content=str(response), tool_calls=[], finish_reason="stop", + raw_response=response, + ) + + def create_tool_message(self, tool_call, result): + return {"role": "tool", "content": str(result)} + + def _iterate_stream(self, response): + for chunk in response: + yield chunk + + +@pytest.mark.unit +class TestHandlerClientToolPause: + + def test_client_tool_pauses_stream(self): + """When LLM calls a client-side tool, handler yields tool_calls_pending.""" + handler = ConcreteHandler() + + agent = Mock() + agent.llm = Mock() + agent.model_id = "test" + agent.tools = [] + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + # check_pause returns pause info for client tool + agent.tool_executor.check_pause = Mock(return_value={ + "call_id": "c1", + "name": "get_weather", + "tool_name": "get_weather", + "tool_id": "ct0", + "action_name": "get_weather", + "llm_name": "get_weather", + "arguments": {"city": "SF"}, + "pause_type": "requires_client_execution", + "thought_signature": None, + }) + agent.tool_executor._name_to_tool = {"get_weather": ("ct0", "get_weather")} + + # Simulate streaming: one chunk with tool_calls finish_reason + chunk = LLMResponse( + content="", + tool_calls=[ToolCall(id="c1", name="get_weather", arguments='{"city": "SF"}', index=0)], + finish_reason="tool_calls", + raw_response={}, + ) + handler.parse_response = lambda c: c + handler._iterate_stream = lambda r: iter(r) + + gen = handler.handle_streaming( + agent, [chunk], {"ct0": {"name": "get_weather", "client_side": True}}, [] + ) + events = list(gen) + + # Should have a requires_client_execution event + client_events = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "requires_client_execution" + ] + assert len(client_events) == 1 + + # Should have a tool_calls_pending event + pending_events = [ + e for e in events + if isinstance(e, dict) and e.get("type") == "tool_calls_pending" + ] + assert len(pending_events) == 1 + + def test_mixed_server_and_client_tools_in_batch(self): + """Server tool executes, client tool pauses.""" + handler = ConcreteHandler() + + agent = Mock() + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + call_count = {"n": 0} + + def check_pause_fn(tools_dict, call, llm_class): + call_count["n"] += 1 + if call_count["n"] == 2: # Second tool is client-side + return { + "call_id": "c2", + "name": "get_weather", + "tool_name": "get_weather", + "tool_id": "ct0", + "action_name": "get_weather", + "llm_name": "get_weather", + "arguments": {}, + "pause_type": "requires_client_execution", + "thought_signature": None, + } + return None + + 
agent.tool_executor.check_pause = Mock(side_effect=check_pause_fn) + agent.tool_executor._name_to_tool = { + "search": ("0", "search"), + "get_weather": ("ct0", "get_weather"), + } + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("server result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + + calls = [ + ToolCall(id="c1", name="search", arguments="{}"), + ToolCall(id="c2", name="get_weather", arguments="{}"), + ] + + gen = handler.handle_tool_calls( + agent, + calls, + { + "0": {"name": "search"}, + "ct0": {"name": "get_weather", "client_side": True}, + }, + [], + ) + + events = [] + messages = None + pending = None + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + # Server tool executed + assert agent._execute_tool_action.call_count == 1 + # Client tool pending + assert pending is not None + assert len(pending) == 1 + assert pending[0]["pause_type"] == "requires_client_execution" diff --git a/tests/test_continuation.py b/tests/test_continuation.py new file mode 100644 index 00000000..2858290c --- /dev/null +++ b/tests/test_continuation.py @@ -0,0 +1,667 @@ +"""Tests for the continuation infrastructure (Phase 1). + +Covers ContinuationService, ToolExecutor.check_pause, handler pause +signaling, BaseAgent.gen_continuation, and request validation. +""" + +from unittest.mock import Mock + +import pytest + +from application.agents.tool_executor import ToolExecutor +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +# --------------------------------------------------------------------------- +# ContinuationService +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestContinuationService: + + def test_save_and_load(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + svc.save_state( + conversation_id="conv-1", + user="alice", + messages=[{"role": "user", "content": "hi"}], + pending_tool_calls=[{"call_id": "c1", "pause_type": "awaiting_approval"}], + tools_dict={"0": {"name": "test_tool"}}, + tool_schemas=[{"type": "function", "function": {"name": "act_0"}}], + agent_config={"model_id": "gpt-4"}, + ) + + state = svc.load_state("conv-1", "alice") + assert state is not None + assert state["conversation_id"] == "conv-1" + assert state["user"] == "alice" + assert len(state["messages"]) == 1 + assert len(state["pending_tool_calls"]) == 1 + assert state["agent_config"]["model_id"] == "gpt-4" + + def test_load_returns_none_when_missing(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + assert svc.load_state("nonexistent", "alice") is None + + def test_delete_state(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + svc.save_state( + conversation_id="conv-2", + user="bob", + messages=[], + pending_tool_calls=[], + tools_dict={}, + tool_schemas=[], + agent_config={}, + ) + assert svc.delete_state("conv-2", "bob") is True + assert svc.load_state("conv-2", "bob") is None + + def test_delete_nonexistent(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + assert 
svc.delete_state("nope", "nope") is False + + def test_upsert_replaces_existing(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + svc.save_state( + conversation_id="conv-3", + user="carol", + messages=[{"role": "user", "content": "v1"}], + pending_tool_calls=[], + tools_dict={}, + tool_schemas=[], + agent_config={}, + ) + svc.save_state( + conversation_id="conv-3", + user="carol", + messages=[{"role": "user", "content": "v2"}], + pending_tool_calls=[{"call_id": "c2"}], + tools_dict={}, + tool_schemas=[], + agent_config={}, + ) + state = svc.load_state("conv-3", "carol") + assert state["messages"][0]["content"] == "v2" + assert len(state["pending_tool_calls"]) == 1 + + +# --------------------------------------------------------------------------- +# ToolExecutor.check_pause +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckPause: + + def _make_call(self, name="action_0", call_id="c1", arguments="{}"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = arguments + call.thought_signature = None + return call + + def test_returns_none_for_normal_tool(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "brave", + "actions": [ + {"name": "search", "active": True, "parameters": {}}, + ], + } + } + call = self._make_call(name="search_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is None + + def test_returns_pause_for_client_side_tool(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "get_weather", + "client_side": True, + "actions": [ + {"name": "get_weather", "active": True, "parameters": {}}, + ], + } + } + call = self._make_call(name="get_weather_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "requires_client_execution" + assert result["call_id"] == "c1" + assert result["tool_id"] == "0" + + def test_returns_pause_for_approval_required(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "telegram", + "actions": [ + { + "name": "send_msg", + "active": True, + "require_approval": True, + "parameters": {}, + }, + ], + } + } + call = self._make_call(name="send_msg_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "awaiting_approval" + + def test_returns_none_when_parse_fails(self): + executor = ToolExecutor() + call = self._make_call(name="bad_name_no_id", arguments="not json") + # Bad arguments will cause parse error -> None + result = executor.check_pause({}, call, "OpenAILLM") + assert result is None + + def test_returns_none_when_tool_not_in_dict(self): + executor = ToolExecutor() + call = self._make_call(name="action_99") + result = executor.check_pause({"0": {"name": "t"}}, call, "OpenAILLM") + assert result is None + + def test_api_tool_approval(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "api_tool", + "config": { + "actions": { + "delete_user": { + "name": "delete_user", + "require_approval": True, + "url": "http://example.com", + "method": "DELETE", + "active": True, + } + } + }, + } + } + call = self._make_call(name="delete_user_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "awaiting_approval" + + +# 
--------------------------------------------------------------------------- +# Handler pause signaling (handle_tool_calls returns pending_actions) +# --------------------------------------------------------------------------- + + +class ConcreteHandler(LLMHandler): + """Minimal concrete handler for testing.""" + + def parse_response(self, response): + return LLMResponse( + content=str(response), tool_calls=[], finish_reason="stop", + raw_response=response, + ) + + def create_tool_message(self, tool_call, result): + return { + "role": "tool", + "content": [ + { + "function_response": { + "name": tool_call.name, + "response": {"result": result}, + "call_id": tool_call.id, + } + } + ], + } + + def _iterate_stream(self, response): + for chunk in response: + yield chunk + + +@pytest.mark.unit +class TestHandlerPauseSignaling: + + def _make_agent(self): + agent = Mock() + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("tool result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + return agent + + def test_no_pause_returns_none_pending(self): + handler = ConcreteHandler() + agent = self._make_agent() + call = ToolCall(id="c1", name="action_0", arguments="{}") + + gen = handler.handle_tool_calls(agent, [call], {"0": {"name": "t"}}, []) + events = [] + messages = None + pending = "NOT_SET" + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + assert pending is None + assert messages is not None + + def test_pause_returns_pending_actions(self): + handler = ConcreteHandler() + agent = self._make_agent() + agent.tool_executor.check_pause = Mock(return_value={ + "call_id": "c1", + "name": "send_msg_0", + "tool_name": "telegram", + "tool_id": "0", + "action_name": "send_msg", + "arguments": {"text": "hello"}, + "pause_type": "awaiting_approval", + "thought_signature": None, + }) + + call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}') + gen = handler.handle_tool_calls( + agent, [call], {"0": {"name": "telegram"}}, [] + ) + + events = [] + pending = None + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + assert pending is not None + assert len(pending) == 1 + assert pending[0]["pause_type"] == "awaiting_approval" + + # Should have yielded a tool_call event with awaiting_approval status + pause_events = [ + e for e in events + if e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "awaiting_approval" + ] + assert len(pause_events) == 1 + + def test_mixed_execute_and_pause(self): + """One tool executes, another needs approval.""" + handler = ConcreteHandler() + agent = self._make_agent() + + call_count = {"n": 0} + + def selective_pause(tools_dict, call, llm_class): + call_count["n"] += 1 + if call_count["n"] == 2: + return { + "call_id": "c2", + "name": "danger_0", + "tool_name": "danger", + "tool_id": "0", + "action_name": "danger", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + return None + + agent.tool_executor.check_pause = Mock(side_effect=selective_pause) + + calls = [ + ToolCall(id="c1", name="safe_0", arguments="{}"), + ToolCall(id="c2", name="danger_0", arguments="{}"), + ] + gen = handler.handle_tool_calls( + agent, 
calls, {"0": {"name": "multi"}}, [] + ) + + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + # First tool was executed normally + assert agent._execute_tool_action.call_count == 1 + # Second tool is pending + assert pending is not None + assert len(pending) == 1 + assert pending[0]["call_id"] == "c2" + + +# --------------------------------------------------------------------------- +# handle_streaming yields tool_calls_pending +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestStreamingPause: + + def test_streaming_yields_tool_calls_pending(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + agent.model_id = "test" + agent.tools = [] + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + pause_info = { + "call_id": "c1", + "name": "fn_0", + "tool_name": "test", + "tool_id": "0", + "action_name": "fn", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + agent.tool_executor.check_pause = Mock(return_value=pause_info) + + chunk = LLMResponse( + content="", + tool_calls=[ToolCall(id="c1", name="fn_0", arguments="{}", index=0)], + finish_reason="tool_calls", + raw_response={}, + ) + handler.parse_response = lambda c: c + + def fake_iterate(response): + yield from response + + handler._iterate_stream = fake_iterate + + gen = handler.handle_streaming(agent, [chunk], {"0": {"name": "t"}}, []) + events = list(gen) + + # Should contain a tool_calls_pending event + pending_events = [ + e for e in events + if isinstance(e, dict) and e.get("type") == "tool_calls_pending" + ] + assert len(pending_events) == 1 + assert len(pending_events[0]["data"]["pending_tool_calls"]) == 1 + + # Agent should have _pending_continuation set + assert hasattr(agent, "_pending_continuation") + + +# --------------------------------------------------------------------------- +# BaseAgent.gen_continuation +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestGenContinuation: + + def test_approved_tool_executes(self): + """When a tool action is approved, the tool is executed.""" + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Final answer"])) + mock_llm._supports_structured_output = Mock(return_value=False) + mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "content": [{"function_response": { + "name": "act_0", "response": {"result": "done"}, "call_id": "c1" + }}]} + ) + + mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + def fake_execute(tools_dict, call, llm_class): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("result_data", "c1") + + mock_executor.execute = Mock(side_effect=fake_execute) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + + messages = [{"role": "system", "content": "You are helpful."}] + tools_dict = {"0": 
{"name": "test_tool"}} + pending = [ + { + "call_id": "c1", + "name": "act_0", + "tool_name": "test_tool", + "tool_id": "0", + "action_name": "act", + "arguments": {"q": "test"}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [{"call_id": "c1", "decision": "approved"}] + + list(agent.gen_continuation(messages, tools_dict, pending, tool_actions)) + + # Tool should have been executed + assert mock_executor.execute.called + + def test_denied_tool_sends_denial(self): + """When a tool action is denied, a denial message is added.""" + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + mock_llm._supports_structured_output = Mock(return_value=False) + mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "content": "denied"} + ) + + mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "danger_0", + "tool_name": "danger", + "tool_id": "0", + "action_name": "danger", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [ + {"call_id": "c1", "decision": "denied", "comment": "too risky"} + ] + + events = list( + agent.gen_continuation(messages, {"0": {"name": "danger"}}, pending, tool_actions) + ) + + # Should have a denied tool_call event + denied = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "denied" + ] + assert len(denied) == 1 + + # create_tool_message should have been called with denial text + denial_arg = mock_handler.create_tool_message.call_args[0][1] + assert "denied" in denial_arg.lower() + assert "too risky" in denial_arg + + def test_client_result_appended(self): + """Client-provided tool result is added to messages.""" + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Done"])) + mock_llm._supports_structured_output = Mock(return_value=False) + mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "content": "client result"} + ) + + mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "weather_0", + "tool_name": "weather", + "tool_id": "0", + "action_name": "weather", + "arguments": {"city": "SF"}, + "pause_type": "requires_client_execution", + 
"thought_signature": None, + } + ] + tool_actions = [{"call_id": "c1", "result": {"temp": "72F"}}] + + events = list( + agent.gen_continuation(messages, {"0": {"name": "weather"}}, pending, tool_actions) + ) + + # create_tool_message was called with the client result + result_arg = mock_handler.create_tool_message.call_args[0][1] + assert "72F" in result_arg + + # Should have a completed tool_call event + completed = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "completed" + ] + assert len(completed) == 1 + + +# --------------------------------------------------------------------------- +# validate_request +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestValidateRequest: + + @pytest.fixture(autouse=True) + def _app_context(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + yield + + def test_continuation_request_without_question(self, mock_mongo_db): + from application.api.answer.routes.base import BaseAnswerResource + + base = BaseAnswerResource() + data = { + "conversation_id": "conv-1", + "tool_actions": [{"call_id": "c1", "decision": "approved"}], + } + result = base.validate_request(data) + assert result is None # Valid + + def test_continuation_request_missing_conversation_id(self, mock_mongo_db): + from application.api.answer.routes.base import BaseAnswerResource + + base = BaseAnswerResource() + data = { + "tool_actions": [{"call_id": "c1", "decision": "approved"}], + } + result = base.validate_request(data) + assert result is not None # Error — missing conversation_id + + def test_normal_request_still_requires_question(self, mock_mongo_db): + from application.api.answer.routes.base import BaseAnswerResource + + base = BaseAnswerResource() + data = {"conversation_id": "conv-1"} + result = base.validate_request(data) + assert result is not None # Error — missing question diff --git a/tests/test_remaining_coverage.py b/tests/test_remaining_coverage.py index e27f86f2..62d430bb 100644 --- a/tests/test_remaining_coverage.py +++ b/tests/test_remaining_coverage.py @@ -1238,7 +1238,8 @@ class TestEmbeddingPipelineAddDocWithRetry: # NUL characters should be removed assert "\x00" not in doc.page_content - def test_add_text_to_store_with_retry_failure(self): + @patch("time.sleep", return_value=None) + def test_add_text_to_store_with_retry_failure(self, _mock_sleep): from application.parser.embedding_pipeline import add_text_to_store_with_retry mock_store = MagicMock() diff --git a/tests/test_tool_approval.py b/tests/test_tool_approval.py new file mode 100644 index 00000000..a9e95b0b --- /dev/null +++ b/tests/test_tool_approval.py @@ -0,0 +1,481 @@ +"""Tests for tool approval (Phase 3). + +Covers require_approval flag, check_pause for approval, the handler +pause/resume flow, and gen_continuation with approved/denied actions. 
+""" + +from unittest.mock import Mock + +import pytest + +from application.agents.tool_executor import ToolExecutor +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +# --------------------------------------------------------------------------- +# check_pause with require_approval +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckPauseApproval: + + def _make_call(self, name="action_0", call_id="c1"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = "{}" + call.thought_signature = None + return call + + def test_approval_required_triggers_pause(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "telegram", + "actions": [ + { + "name": "send_msg", + "active": True, + "require_approval": True, + "parameters": {}, + }, + ], + } + } + call = self._make_call(name="send_msg_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + + assert result is not None + assert result["pause_type"] == "awaiting_approval" + assert result["tool_name"] == "telegram" + assert result["action_name"] == "send_msg" + assert result["tool_id"] == "0" + + def test_approval_not_required_no_pause(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "brave", + "actions": [ + { + "name": "search", + "active": True, + "require_approval": False, + "parameters": {}, + }, + ], + } + } + call = self._make_call(name="search_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is None + + def test_approval_absent_defaults_to_false(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "brave", + "actions": [ + { + "name": "search", + "active": True, + "parameters": {}, + }, + ], + } + } + call = self._make_call(name="search_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is None + + def test_api_tool_approval(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "api_tool", + "config": { + "actions": { + "delete_user": { + "name": "delete_user", + "require_approval": True, + "url": "http://example.com", + "method": "DELETE", + "active": True, + } + } + }, + } + } + call = self._make_call(name="delete_user_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "awaiting_approval" + + def test_api_tool_no_approval(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "api_tool", + "config": { + "actions": { + "list_users": { + "name": "list_users", + "url": "http://example.com", + "method": "GET", + "active": True, + } + } + }, + } + } + call = self._make_call(name="list_users_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is None + + +# --------------------------------------------------------------------------- +# Handler: approval tool causes pause signal +# --------------------------------------------------------------------------- + + +class ConcreteHandler(LLMHandler): + def parse_response(self, response): + return LLMResponse( + content=str(response), tool_calls=[], finish_reason="stop", + raw_response=response, + ) + + def create_tool_message(self, tool_call, result): + import json as _json + content = _json.dumps(result) if not isinstance(result, str) else result + return {"role": "tool", "tool_call_id": tool_call.id, "content": content} + + def _iterate_stream(self, response): + for chunk in response: + yield chunk + + 
+@pytest.mark.unit +class TestHandlerApprovalPause: + + def _make_agent(self, pause_return): + agent = Mock() + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=pause_return) + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("tool result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + return agent + + def test_approval_tool_pauses(self): + handler = ConcreteHandler() + pause_info = { + "call_id": "c1", + "name": "send_msg_0", + "tool_name": "telegram", + "tool_id": "0", + "action_name": "send_msg", + "arguments": {"text": "hello"}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + agent = self._make_agent(pause_info) + + call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}') + gen = handler.handle_tool_calls( + agent, [call], {"0": {"name": "telegram"}}, [] + ) + + events = [] + pending = None + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + assert pending is not None + assert len(pending) == 1 + assert pending[0]["pause_type"] == "awaiting_approval" + + # Should NOT have executed the tool + assert agent._execute_tool_action.call_count == 0 + + # Should have yielded awaiting_approval status + approval_events = [ + e for e in events + if e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "awaiting_approval" + ] + assert len(approval_events) == 1 + + def test_mixed_normal_and_approval(self): + """First tool runs normally, second needs approval.""" + handler = ConcreteHandler() + + call_count = {"n": 0} + + def selective_pause(tools_dict, call, llm_class): + call_count["n"] += 1 + if call_count["n"] == 2: + return { + "call_id": "c2", + "name": "send_msg_0", + "tool_name": "telegram", + "tool_id": "0", + "action_name": "send_msg", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + return None + + agent = self._make_agent(None) + agent.tool_executor.check_pause = Mock(side_effect=selective_pause) + + calls = [ + ToolCall(id="c1", name="search_0", arguments="{}"), + ToolCall(id="c2", name="send_msg_0", arguments="{}"), + ] + + gen = handler.handle_tool_calls( + agent, calls, {"0": {"name": "multi"}}, [] + ) + + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + # First tool executed + assert agent._execute_tool_action.call_count == 1 + # Second tool is pending + assert pending is not None + assert len(pending) == 1 + assert pending[0]["call_id"] == "c2" + + +# --------------------------------------------------------------------------- +# gen_continuation: approval and denial flows +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestGenContinuationApproval: + + def _make_agent(self): + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + mock_llm._supports_structured_output = Mock(return_value=False) + mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "tool_call_id": "c1", "content": "result"} + ) + + 
mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + def fake_execute(tools_dict, call, llm_class): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("executed_result", "c1") + + mock_executor.execute = Mock(side_effect=fake_execute) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + return agent, mock_executor, mock_handler + + def test_approved_tool_executes(self): + agent, mock_executor, mock_handler = self._make_agent() + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "send_msg_0", + "tool_name": "telegram", + "tool_id": "0", + "action_name": "send_msg", + "arguments": {"text": "hello"}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [{"call_id": "c1", "decision": "approved"}] + + list(agent.gen_continuation( + messages, {"0": {"name": "telegram"}}, pending, tool_actions + )) + + # Tool should have been executed + assert mock_executor.execute.called + + def test_denied_tool_sends_denial_to_llm(self): + agent, mock_executor, mock_handler = self._make_agent() + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "send_msg_0", + "tool_name": "telegram", + "tool_id": "0", + "action_name": "send_msg", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [ + {"call_id": "c1", "decision": "denied", "comment": "not safe"}, + ] + + events = list(agent.gen_continuation( + messages, {"0": {"name": "telegram"}}, pending, tool_actions + )) + + # Tool should NOT have been executed + assert not mock_executor.execute.called + + # Should have a denied event + denied = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "denied" + ] + assert len(denied) == 1 + + # create_tool_message should have been called with denial text + denial_text = mock_handler.create_tool_message.call_args[0][1] + assert "denied" in denial_text.lower() + assert "not safe" in denial_text + + def test_denied_without_comment(self): + agent, mock_executor, mock_handler = self._make_agent() + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "act_0", + "tool_name": "tool", + "tool_id": "0", + "action_name": "act", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [{"call_id": "c1", "decision": "denied"}] + + list(agent.gen_continuation( + messages, {"0": {"name": "tool"}}, pending, tool_actions + )) + + denial_text = mock_handler.create_tool_message.call_args[0][1] + assert "denied" in denial_text.lower() + + def test_mixed_approve_deny_batch(self): + """Two tools: one approved, one denied.""" + agent, mock_executor, mock_handler = self._make_agent() + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "safe_0", + "tool_name": "safe", + "tool_id": "0", + "action_name": "safe", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + }, + { + "call_id": "c2", + "name": "danger_0", + "tool_name": "danger", + "tool_id": "0", + "action_name": "danger", + "arguments": {}, + "pause_type": 
"awaiting_approval", + "thought_signature": None, + }, + ] + tool_actions = [ + {"call_id": "c1", "decision": "approved"}, + {"call_id": "c2", "decision": "denied", "comment": "too risky"}, + ] + + events = list(agent.gen_continuation( + messages, {"0": {"name": "multi"}}, pending, tool_actions + )) + + # First tool executed, second denied + assert mock_executor.execute.call_count == 1 + + denied = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "denied" + ] + assert len(denied) == 1 + + def test_missing_action_defaults_to_denial(self): + """If client doesn't respond for a pending tool, treat as denied.""" + agent, mock_executor, mock_handler = self._make_agent() + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "act_0", + "tool_name": "tool", + "tool_id": "0", + "action_name": "act", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + # Empty tool_actions — no response for c1 + tool_actions = [] + + events = list(agent.gen_continuation( + messages, {"0": {"name": "tool"}}, pending, tool_actions + )) + + # Should have been treated as denied + assert not mock_executor.execute.called + denied = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "denied" + ] + assert len(denied) == 1 diff --git a/tests/test_v1_translator.py b/tests/test_v1_translator.py new file mode 100644 index 00000000..92deb912 --- /dev/null +++ b/tests/test_v1_translator.py @@ -0,0 +1,577 @@ +"""Tests for the v1 API translator (Phase 4). + +Covers request translation, response translation, streaming event +translation, continuation detection, and history conversion. 
+""" + +import json + +import pytest + +from application.api.v1.translator import ( + _get_client_tool_name, + convert_history, + extract_tool_results, + is_continuation, + translate_request, + translate_response, + translate_stream_event, +) + + +# --------------------------------------------------------------------------- +# is_continuation +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestIsContinuation: + + def test_normal_messages_not_continuation(self): + messages = [ + {"role": "user", "content": "Hello"}, + ] + assert is_continuation(messages) is False + + def test_tool_after_assistant_tool_calls_is_continuation(self): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "c1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "c1", "content": '{"temp": "72F"}'}, + ] + assert is_continuation(messages) is True + + def test_assistant_without_tool_calls_not_continuation(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "tool", "tool_call_id": "c1", "content": "result"}, + ] + # assistant has no tool_calls — not a valid continuation + assert is_continuation(messages) is False + + def test_empty_messages(self): + assert is_continuation([]) is False + + def test_multiple_tool_results(self): + messages = [ + {"role": "user", "content": "Do stuff"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "c1", "type": "function", "function": {"name": "a", "arguments": "{}"}}, + {"id": "c2", "type": "function", "function": {"name": "b", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": "c1", "content": "r1"}, + {"role": "tool", "tool_call_id": "c2", "content": "r2"}, + ] + assert is_continuation(messages) is True + + +# --------------------------------------------------------------------------- +# extract_tool_results +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestExtractToolResults: + + def test_extracts_results(self): + messages = [ + {"role": "assistant", "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "tool_call_id": "c1", "content": '{"temp": "72F"}'}, + ] + results = extract_tool_results(messages) + assert len(results) == 1 + assert results[0]["call_id"] == "c1" + assert results[0]["result"] == {"temp": "72F"} + + def test_string_content(self): + messages = [ + {"role": "tool", "tool_call_id": "c1", "content": "plain text"}, + ] + results = extract_tool_results(messages) + assert results[0]["result"] == "plain text" + + def test_multiple_results(self): + messages = [ + {"role": "assistant", "tool_calls": []}, + {"role": "tool", "tool_call_id": "c1", "content": "r1"}, + {"role": "tool", "tool_call_id": "c2", "content": "r2"}, + ] + results = extract_tool_results(messages) + assert len(results) == 2 + assert results[0]["call_id"] == "c1" + assert results[1]["call_id"] == "c2" + + +# --------------------------------------------------------------------------- +# convert_history +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestConvertHistory: + + def test_user_assistant_pairs(self): + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + {"role": "user", "content": "How 
are you?"}, + {"role": "assistant", "content": "I'm good"}, + {"role": "user", "content": "What's 2+2?"}, # Last user = question + ] + history = convert_history(messages) + assert len(history) == 2 + assert history[0]["prompt"] == "Hello" + assert history[0]["response"] == "Hi there" + assert history[1]["prompt"] == "How are you?" + assert history[1]["response"] == "I'm good" + + def test_single_user_message(self): + messages = [{"role": "user", "content": "Hi"}] + history = convert_history(messages) + assert history == [] + + def test_system_messages_skipped(self): + messages = [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "Question"}, + ] + history = convert_history(messages) + assert history == [] + + +# --------------------------------------------------------------------------- +# translate_request +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTranslateRequest: + + def test_normal_request(self): + data = { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "What's 2+2?"}, + ], + } + result = translate_request(data, "test-key") + assert result["question"] == "What's 2+2?" + assert result["api_key"] == "test-key" + assert result["save_conversation"] is True + history = json.loads(result["history"]) + assert len(history) == 1 + assert history[0]["prompt"] == "Hello" + + def test_continuation_request(self): + data = { + "messages": [ + {"role": "user", "content": "Search for X"}, + { + "role": "assistant", + "tool_calls": [{"id": "c1", "type": "function", "function": {"name": "search", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "c1", "content": '{"found": true}'}, + ], + } + result = translate_request(data, "key") + assert "tool_actions" in result + assert len(result["tool_actions"]) == 1 + assert result["tool_actions"][0]["call_id"] == "c1" + + def test_continuation_with_top_level_conversation_id(self): + """Standard clients send conversation_id at request level, not in messages.""" + data = { + "conversation_id": "conv-top-level", + "messages": [ + {"role": "user", "content": "Do stuff"}, + { + "role": "assistant", + "tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "c1", "content": "done"}, + ], + } + result = translate_request(data, "key") + assert result["conversation_id"] == "conv-top-level" + + def test_continuation_in_message_conversation_id_takes_precedence(self): + """When both in-message and top-level conversation_id exist, in-message wins.""" + data = { + "conversation_id": "conv-top-level", + "messages": [ + {"role": "user", "content": "Do stuff"}, + { + "role": "assistant", + "tool_calls": [{"id": "c1", "type": "function", "function": {"name": "act", "arguments": "{}"}}], + "docsgpt": {"conversation_id": "conv-in-message"}, + }, + {"role": "tool", "tool_call_id": "c1", "content": "done"}, + ], + } + result = translate_request(data, "key") + assert result["conversation_id"] == "conv-in-message" + + def test_client_tools_passed_through(self): + data = { + "messages": [{"role": "user", "content": "Hi"}], + "tools": [{"type": "function", "function": {"name": "my_tool"}}], + } + result = translate_request(data, "key") + assert result["client_tools"] == data["tools"] + + def test_docsgpt_attachments(self): + data = { + "messages": [{"role": "user", "content": "Hi"}], + "docsgpt": {"attachments": 
["att1", "att2"]}, + } + result = translate_request(data, "key") + assert result["attachments"] == ["att1", "att2"] + + +# --------------------------------------------------------------------------- +# translate_response +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTranslateResponse: + + def test_basic_response(self): + resp = translate_response( + conversation_id="conv-1", + answer="Hello!", + sources=[], + tool_calls=[], + thought="", + model_name="my-agent", + ) + assert resp["id"] == "chatcmpl-conv-1" + assert resp["object"] == "chat.completion" + assert resp["model"] == "my-agent" + assert resp["choices"][0]["message"]["content"] == "Hello!" + assert resp["choices"][0]["finish_reason"] == "stop" + assert "reasoning_content" not in resp["choices"][0]["message"] + + def test_response_with_thought(self): + resp = translate_response( + conversation_id="c1", + answer="Result", + sources=[], + tool_calls=[], + thought="Thinking about it...", + model_name="agent", + ) + assert resp["choices"][0]["message"]["reasoning_content"] == "Thinking about it..." + + def test_response_with_sources(self): + sources = [{"title": "doc.txt", "text": "content", "source": "/doc.txt"}] + resp = translate_response( + conversation_id="c1", + answer="Found it", + sources=sources, + tool_calls=[], + thought="", + model_name="agent", + ) + assert resp["docsgpt"]["sources"] == sources + + def test_response_with_tool_calls(self): + tool_calls = [{"tool_name": "notes", "call_id": "c1", "artifact_id": "a1"}] + resp = translate_response( + conversation_id="c1", + answer="Done", + sources=[], + tool_calls=tool_calls, + thought="", + model_name="agent", + ) + assert resp["docsgpt"]["tool_calls"] == tool_calls + + def test_pending_tool_calls_uses_tool_name(self): + """Client tool responses use the original tool_name, not the LLM-visible action_name.""" + pending = [ + { + "call_id": "c1", + "tool_name": "get_weather", + "action_name": "get_weather", + "arguments": {"city": "SF"}, + } + ] + resp = translate_response( + conversation_id="c1", + answer="", + sources=[], + tool_calls=[], + thought="", + model_name="agent", + pending_tool_calls=pending, + ) + tc = resp["choices"][0]["message"]["tool_calls"][0] + assert tc["function"]["name"] == "get_weather" + + def test_pending_tool_calls_tool_name_takes_precedence(self): + """When tool_name differs from action_name, tool_name is used.""" + pending = [ + { + "call_id": "c1", + "tool_name": "search", + "action_name": "search_1", + "arguments": {"q": "test"}, + } + ] + resp = translate_response( + conversation_id="c1", + answer="", + sources=[], + tool_calls=[], + thought="", + model_name="agent", + pending_tool_calls=pending, + ) + tc = resp["choices"][0]["message"]["tool_calls"][0] + assert tc["function"]["name"] == "search" + + def test_pending_tool_calls(self): + pending = [ + { + "call_id": "c1", + "name": "get_weather", + "arguments": {"city": "SF"}, + } + ] + resp = translate_response( + conversation_id="c1", + answer="", + sources=[], + tool_calls=[], + thought="", + model_name="agent", + pending_tool_calls=pending, + ) + assert resp["choices"][0]["finish_reason"] == "tool_calls" + assert resp["choices"][0]["message"]["content"] is None + assert len(resp["choices"][0]["message"]["tool_calls"]) == 1 + tc = resp["choices"][0]["message"]["tool_calls"][0] + assert tc["id"] == "c1" + assert tc["function"]["name"] == "get_weather" + + def test_no_docsgpt_when_empty(self): + resp = translate_response( + 
+
+    def test_pending_tool_calls_uses_tool_name(self):
+        """Client tool responses use the original tool_name, not the LLM-visible action_name."""
+        pending = [
+            {
+                "call_id": "c1",
+                "tool_name": "get_weather",
+                "action_name": "get_weather",
+                "arguments": {"city": "SF"},
+            }
+        ]
+        resp = translate_response(
+            conversation_id="c1",
+            answer="",
+            sources=[],
+            tool_calls=[],
+            thought="",
+            model_name="agent",
+            pending_tool_calls=pending,
+        )
+        tc = resp["choices"][0]["message"]["tool_calls"][0]
+        assert tc["function"]["name"] == "get_weather"
+
+    def test_pending_tool_calls_tool_name_takes_precedence(self):
+        """When tool_name differs from action_name, tool_name is used."""
+        pending = [
+            {
+                "call_id": "c1",
+                "tool_name": "search",
+                "action_name": "search_1",
+                "arguments": {"q": "test"},
+            }
+        ]
+        resp = translate_response(
+            conversation_id="c1",
+            answer="",
+            sources=[],
+            tool_calls=[],
+            thought="",
+            model_name="agent",
+            pending_tool_calls=pending,
+        )
+        tc = resp["choices"][0]["message"]["tool_calls"][0]
+        assert tc["function"]["name"] == "search"
+
+    def test_pending_tool_calls(self):
+        pending = [
+            {
+                "call_id": "c1",
+                "name": "get_weather",
+                "arguments": {"city": "SF"},
+            }
+        ]
+        resp = translate_response(
+            conversation_id="c1",
+            answer="",
+            sources=[],
+            tool_calls=[],
+            thought="",
+            model_name="agent",
+            pending_tool_calls=pending,
+        )
+        assert resp["choices"][0]["finish_reason"] == "tool_calls"
+        assert resp["choices"][0]["message"]["content"] is None
+        assert len(resp["choices"][0]["message"]["tool_calls"]) == 1
+        tc = resp["choices"][0]["message"]["tool_calls"][0]
+        assert tc["id"] == "c1"
+        assert tc["function"]["name"] == "get_weather"
+
+    def test_no_docsgpt_when_empty(self):
+        resp = translate_response(
+            conversation_id="",
+            answer="Hi",
+            sources=None,
+            tool_calls=None,
+            thought="",
+            model_name="agent",
+        )
+        assert "docsgpt" not in resp
+
+
+# ---------------------------------------------------------------------------
+# translate_stream_event
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestTranslateStreamEvent:
+
+    def test_answer_event(self):
+        chunks = translate_stream_event(
+            {"type": "answer", "answer": "Hello"},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["choices"][0]["delta"]["content"] == "Hello"
+
+    def test_thought_event(self):
+        chunks = translate_stream_event(
+            {"type": "thought", "thought": "reasoning"},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["choices"][0]["delta"]["reasoning_content"] == "reasoning"
+
+    def test_source_event(self):
+        chunks = translate_stream_event(
+            {"type": "source", "source": [{"title": "t", "text": "x"}]},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["docsgpt"]["type"] == "source"
+        assert len(parsed["docsgpt"]["sources"]) == 1
+
+    def test_end_event(self):
+        chunks = translate_stream_event(
+            {"type": "end"},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 2
+        # First chunk: finish_reason stop
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["choices"][0]["finish_reason"] == "stop"
+        # Second chunk: [DONE]
+        assert chunks[1].strip() == "data: [DONE]"
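+
+    # This class exercises two chunk channels: standard events (answer,
+    # thought, client tool-call deltas, end) ride in "choices", while
+    # DocsGPT-specific events (source, conversation id, completed or
+    # awaiting-approval tool calls) ride in a "docsgpt" object with no
+    # "choices" key, so off-the-shelf OpenAI parsers skip those chunks.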
+
+    def test_tool_call_client_execution(self):
+        chunks = translate_stream_event(
+            {
+                "type": "tool_call",
+                "data": {
+                    "call_id": "c1",
+                    "action_name": "get_weather",
+                    "arguments": {"city": "SF"},
+                    "status": "requires_client_execution",
+                },
+            },
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        tc = parsed["choices"][0]["delta"]["tool_calls"][0]
+        assert tc["id"] == "c1"
+        assert tc["function"]["name"] == "get_weather"
+
+    def test_tool_call_client_execution_uses_tool_name(self):
+        """Streaming tool calls use tool_name (the original name) for client responses."""
+        chunks = translate_stream_event(
+            {
+                "type": "tool_call",
+                "data": {
+                    "call_id": "c1",
+                    "tool_name": "create",
+                    "action_name": "create",
+                    "arguments": {"title": "test"},
+                    "status": "requires_client_execution",
+                },
+            },
+            "chatcmpl-1", "agent",
+        )
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        tc = parsed["choices"][0]["delta"]["tool_calls"][0]
+        assert tc["function"]["name"] == "create"
+
+    def test_tool_call_completed(self):
+        chunks = translate_stream_event(
+            {
+                "type": "tool_call",
+                "data": {
+                    "call_id": "c1",
+                    "status": "completed",
+                    "result": "done",
+                    "artifact_id": "a1",
+                },
+            },
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["docsgpt"]["type"] == "tool_call"
+        assert parsed["docsgpt"]["data"]["artifact_id"] == "a1"
+
+    def test_tool_calls_pending(self):
+        chunks = translate_stream_event(
+            {
+                "type": "tool_calls_pending",
+                "data": {"pending_tool_calls": [{"call_id": "c1"}]},
+            },
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 2
+        # Standard chunk with finish_reason tool_calls
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["choices"][0]["finish_reason"] == "tool_calls"
+        # Extension chunk
+        ext = json.loads(chunks[1].replace("data: ", "").strip())
+        assert ext["docsgpt"]["type"] == "tool_calls_pending"
+
+    def test_id_event(self):
+        chunks = translate_stream_event(
+            {"type": "id", "id": "conv-123"},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["docsgpt"]["conversation_id"] == "conv-123"
+
+    def test_error_event(self):
+        chunks = translate_stream_event(
+            {"type": "error", "error": "Something went wrong"},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["error"]["message"] == "Something went wrong"
+
+    def test_tool_calls_event_skipped(self):
+        """The aggregate tool_calls event is redundant and should be skipped."""
+        chunks = translate_stream_event(
+            {"type": "tool_calls", "tool_calls": [{"call_id": "c1"}]},
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 0
+
+    def test_research_events_skipped(self):
+        assert translate_stream_event(
+            {"type": "research_plan", "data": {}}, "id", "m"
+        ) == []
+        assert translate_stream_event(
+            {"type": "research_progress", "data": {}}, "id", "m"
+        ) == []
+
+    def test_awaiting_approval_as_extension(self):
+        chunks = translate_stream_event(
+            {
+                "type": "tool_call",
+                "data": {"call_id": "c1", "status": "awaiting_approval"},
+            },
+            "chatcmpl-1", "agent",
+        )
+        assert len(chunks) == 1
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        assert parsed["docsgpt"]["type"] == "tool_call"
+
+    def test_standard_clients_can_ignore_docsgpt(self):
+        """Standard clients parse only 'choices'; the docsgpt namespace is ignored."""
+        chunks = translate_stream_event(
+            {"type": "source", "source": [{"title": "t"}]},
+            "chatcmpl-1", "agent",
+        )
+        parsed = json.loads(chunks[0].replace("data: ", "").strip())
+        # No "choices" key, so standard parsers skip this chunk entirely
+        assert "choices" not in parsed
+        # docsgpt key is present
+        assert "docsgpt" in parsed
+
+
+# ---------------------------------------------------------------------------
+# _get_client_tool_name
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestGetClientToolName:
+
+    def test_uses_tool_name_when_present(self):
+        assert _get_client_tool_name({"tool_name": "create", "action_name": "create_1"}) == "create"
+
+    def test_falls_back_to_action_name(self):
+        assert _get_client_tool_name({"action_name": "get_weather"}) == "get_weather"
+
+    def test_falls_back_to_name(self):
+        assert _get_client_tool_name({"name": "search"}) == "search"
+
+    def test_returns_empty_when_no_fields(self):
+        assert _get_client_tool_name({}) == ""
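+
+
+# ---------------------------------------------------------------------------
+# Reference sketch (not a test): how a standard OpenAI-style SSE client is
+# expected to consume this stream, per the chunk shapes asserted above. The
+# loop itself is an illustrative assumption, not part of the translator API.
+#
+#     for line in response.iter_lines(decode_unicode=True):
+#         if not line.startswith("data: "):
+#             continue
+#         payload = line[len("data: "):]
+#         if payload == "[DONE]":
+#             break
+#         chunk = json.loads(payload)
+#         for choice in chunk.get("choices", []):  # docsgpt-only chunks: none
+#             delta = choice.get("delta", {})
+#             if "content" in delta:
+#                 print(delta["content"], end="")
+# ---------------------------------------------------------------------------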