diff --git a/application/agents/base.py b/application/agents/base.py index 15735c8c..c791ec45 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -1,7 +1,8 @@ +import json import logging import uuid from abc import ABC, abstractmethod -from typing import Dict, Generator, List, Optional +from typing import Any, Dict, Generator, List, Optional from application.agents.tool_executor import ToolExecutor from application.core.json_schema_utils import ( @@ -9,6 +10,7 @@ from application.core.json_schema_utils import ( normalize_json_schema_payload, ) from application.core.settings import settings +from application.llm.handlers.base import ToolCall from application.llm.handlers.handler_creator import LLMHandlerCreator from application.llm.llm_creator import LLMCreator from application.logging import build_stack_data, log_activity, LogContext @@ -113,6 +115,140 @@ class BaseAgent(ABC): ) -> Generator[Dict, None, None]: pass + def gen_continuation( + self, + messages: List[Dict], + tools_dict: Dict, + pending_tool_calls: List[Dict], + tool_actions: List[Dict], + ) -> Generator[Dict, None, None]: + """Resume generation after tool actions are resolved. + + Processes the client-provided *tool_actions* (approvals, denials, + or client-side results), appends the resulting messages, then + hands back to the LLM to continue the conversation. + + Args: + messages: The saved messages array from the pause point. + tools_dict: The saved tools dictionary. + pending_tool_calls: The pending tool call descriptors from the pause. + tool_actions: Client-provided actions resolving the pending calls. 
+ """ + self._prepare_tools(tools_dict) + + actions_by_id = {a["call_id"]: a for a in tool_actions} + + for pending in pending_tool_calls: + call_id = pending["call_id"] + action = actions_by_id.get(call_id) + if not action: + action = { + "call_id": call_id, + "decision": "denied", + "comment": "No response provided", + } + + # Build the assistant tool-call message + args = pending["arguments"] + function_call_content: Dict[str, Any] = { + "function_call": { + "name": pending["name"], + "args": args, + "call_id": call_id, + } + } + if pending.get("thought_signature"): + function_call_content["thought_signature"] = pending[ + "thought_signature" + ] + messages.append( + {"role": "assistant", "content": [function_call_content]} + ) + + if action.get("decision") == "approved": + # Execute the tool server-side + tc = ToolCall( + id=call_id, + name=pending["name"], + arguments=( + json.dumps(args) if isinstance(args, dict) else args + ), + ) + tool_gen = self._execute_tool_action(tools_dict, tc) + tool_response = None + while True: + try: + event = next(tool_gen) + yield event + except StopIteration as e: + tool_response, _ = e.value + break + messages.append( + self.llm_handler.create_tool_message(tc, tool_response) + ) + + elif action.get("decision") == "denied": + comment = action.get("comment", "") + denial = ( + f"Tool execution denied by user. Reason: {comment}" + if comment + else "Tool execution denied by user." 
+ ) + tc = ToolCall( + id=call_id, name=pending["name"], arguments=args + ) + messages.append( + self.llm_handler.create_tool_message(tc, denial) + ) + yield { + "type": "tool_call", + "data": { + "tool_name": pending.get("tool_name", "unknown"), + "call_id": call_id, + "action_name": f"{pending['action_name']}_{pending['tool_id']}", + "arguments": args, + "status": "denied", + }, + } + + elif "result" in action: + result = action["result"] + result_str = ( + json.dumps(result) + if not isinstance(result, str) + else result + ) + tc = ToolCall( + id=call_id, name=pending["name"], arguments=args + ) + messages.append( + self.llm_handler.create_tool_message(tc, result_str) + ) + yield { + "type": "tool_call", + "data": { + "tool_name": pending.get("tool_name", "unknown"), + "call_id": call_id, + "action_name": f"{pending['action_name']}_{pending['tool_id']}", + "arguments": args, + "result": ( + result_str[:50] + "..." + if len(result_str) > 50 + else result_str + ), + "status": "completed", + }, + } + + # Resume the LLM loop with the updated messages + llm_response = self._llm_gen(messages) + yield from self._handle_response( + llm_response, tools_dict, messages, None + ) + + yield {"sources": self.retrieved_docs} + yield {"tool_calls": self._get_truncated_tool_calls()} + # ---- Tool delegation (thin wrappers around ToolExecutor) ---- @property diff --git a/application/agents/tool_executor.py b/application/agents/tool_executor.py index 69739076..f8fec61e 100644 --- a/application/agents/tool_executor.py +++ b/application/agents/tool_executor.py @@ -104,6 +104,60 @@ class ToolExecutor: params["required"].append(k) return params + def check_pause( + self, tools_dict: Dict, call, llm_class_name: str + ) -> Optional[Dict]: + """Check if a tool call requires pausing for approval or client execution. + + Returns a dict describing the pending action if pause is needed, None otherwise. 
+ """ + parser = ToolActionParser(llm_class_name) + tool_id, action_name, call_args = parser.parse_args(call) + call_id = getattr(call, "id", None) or str(uuid.uuid4()) + + if tool_id is None or action_name is None or tool_id not in tools_dict: + return None # Will be handled as error by execute() + + tool_data = tools_dict[tool_id] + + # Phase 2: client-side tools + if tool_data.get("client_side"): + return { + "call_id": call_id, + "name": getattr(call, "name", f"{action_name}_{tool_id}"), + "tool_name": tool_data.get("name", "unknown"), + "tool_id": tool_id, + "action_name": action_name, + "arguments": call_args if isinstance(call_args, dict) else {}, + "pause_type": "requires_client_execution", + "thought_signature": getattr(call, "thought_signature", None), + } + + # Phase 3: approval required + if tool_data["name"] == "api_tool": + action_data = tool_data.get("config", {}).get("actions", {}).get( + action_name, {} + ) + else: + action_data = next( + (a for a in tool_data.get("actions", []) if a["name"] == action_name), + {}, + ) + + if action_data.get("require_approval"): + return { + "call_id": call_id, + "name": getattr(call, "name", f"{action_name}_{tool_id}"), + "tool_name": tool_data.get("name", "unknown"), + "tool_id": tool_id, + "action_name": action_name, + "arguments": call_args if isinstance(call_args, dict) else {}, + "pause_type": "awaiting_approval", + "thought_signature": getattr(call, "thought_signature", None), + } + + return None + def execute(self, tools_dict: Dict, call, llm_class_name: str): """Execute a tool call. 
Yields status events, returns (result, call_id).""" parser = ToolActionParser(llm_class_name) diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py index f3111605..b6fe288a 100644 --- a/application/api/answer/routes/answer.py +++ b/application/api/answer/routes/answer.py @@ -74,27 +74,56 @@ class AnswerResource(Resource, BaseAnswerResource): decoded_token = getattr(request, "decoded_token", None) processor = StreamProcessor(data, decoded_token) try: - agent = processor.build_agent(data.get("question", "")) - if not processor.decoded_token: - return make_response({"error": "Unauthorized"}, 401) + # ---- Continuation mode ---- + if data.get("tool_actions"): + ( + agent, + messages, + tools_dict, + pending_tool_calls, + tool_actions, + ) = processor.resume_from_tool_actions( + data["tool_actions"], data["conversation_id"] + ) + stream = self.complete_stream( + question="", + agent=agent, + conversation_id=processor.conversation_id, + user_api_key=processor.agent_config.get("user_api_key"), + decoded_token=processor.decoded_token, + agent_id=processor.agent_id, + model_id=processor.model_id, + _continuation={ + "messages": messages, + "tools_dict": tools_dict, + "pending_tool_calls": pending_tool_calls, + "tool_actions": tool_actions, + }, + ) + else: + # ---- Normal mode ---- + agent = processor.build_agent(data.get("question", "")) + if not processor.decoded_token: + return make_response({"error": "Unauthorized"}, 401) - if error := self.check_usage(processor.agent_config): - return error + if error := self.check_usage(processor.agent_config): + return error + + stream = self.complete_stream( + question=data["question"], + agent=agent, + conversation_id=processor.conversation_id, + user_api_key=processor.agent_config.get("user_api_key"), + decoded_token=processor.decoded_token, + isNoneDoc=data.get("isNoneDoc"), + index=None, + should_save_conversation=data.get("save_conversation", True), + agent_id=processor.agent_id, + 
is_shared_usage=processor.is_shared_usage, + shared_token=processor.shared_token, + model_id=processor.model_id, + ) - stream = self.complete_stream( - question=data["question"], - agent=agent, - conversation_id=processor.conversation_id, - user_api_key=processor.agent_config.get("user_api_key"), - decoded_token=processor.decoded_token, - isNoneDoc=data.get("isNoneDoc"), - index=None, - should_save_conversation=data.get("save_conversation", True), - agent_id=processor.agent_id, - is_shared_usage=processor.is_shared_usage, - shared_token=processor.shared_token, - model_id=processor.model_id, - ) stream_result = self.process_response_stream(stream) if len(stream_result) == 7: @@ -105,16 +134,17 @@ class AnswerResource(Resource, BaseAnswerResource): tool_calls, thought, error, - structured_info, + extra_info, ) = stream_result else: conversation_id, response, sources, tool_calls, thought, error = ( stream_result ) - structured_info = None + extra_info = None if error: return make_response({"error": error}, 400) + result = { "conversation_id": conversation_id, "answer": response, @@ -123,8 +153,8 @@ class AnswerResource(Resource, BaseAnswerResource): "thought": thought, } - if structured_info: - result.update(structured_info) + if extra_info: + result.update(extra_info) except Exception as e: logger.error( f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py index 29d3d66c..9abd0261 100644 --- a/application/api/answer/routes/base.py +++ b/application/api/answer/routes/base.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional from flask import jsonify, make_response, Response from flask_restx import Namespace +from application.api.answer.services.continuation_service import ContinuationService from application.api.answer.services.conversation_service import ConversationService from application.core.model_utils import ( 
get_api_key_for_provider, @@ -39,7 +40,16 @@ class BaseAnswerResource: def validate_request( self, data: Dict[str, Any], require_conversation_id: bool = False ) -> Optional[Response]: - """Common request validation""" + """Common request validation. + + Continuation requests (``tool_actions`` present) require + ``conversation_id`` but not ``question``. + """ + if data.get("tool_actions"): + # Continuation mode — question is not required + if missing := check_required_fields(data, ["conversation_id"]): + return missing + return None required_fields = ["question"] if require_conversation_id: required_fields.append("conversation_id") @@ -177,6 +187,7 @@ class BaseAnswerResource: is_shared_usage: bool = False, shared_token: Optional[str] = None, model_id: Optional[str] = None, + _continuation: Optional[Dict] = None, ) -> Generator[str, None, None]: """ Generator function that streams the complete conversation response. @@ -207,8 +218,19 @@ class BaseAnswerResource: schema_info = None structured_chunks = [] query_metadata = {} + paused = False - for line in agent.gen(query=question): + if _continuation: + gen_iter = agent.gen_continuation( + messages=_continuation["messages"], + tools_dict=_continuation["tools_dict"], + pending_tool_calls=_continuation["pending_tool_calls"], + tool_actions=_continuation["tool_actions"], + ) + else: + gen_iter = agent.gen(query=question) + + for line in gen_iter: if "metadata" in line: query_metadata.update(line["metadata"]) elif "answer" in line: @@ -244,15 +266,21 @@ class BaseAnswerResource: data = json.dumps({"type": "thought", "thought": line["thought"]}) yield f"data: {data}\n\n" elif "type" in line: - if line.get("type") == "error": + if line.get("type") == "tool_calls_pending": + # Save continuation state and end the stream + paused = True + data = json.dumps(line) + yield f"data: {data}\n\n" + elif line.get("type") == "error": sanitized_error = { "type": "error", "error": sanitize_api_error(line.get("error", "An error 
occurred")) } data = json.dumps(sanitized_error) + yield f"data: {data}\n\n" else: data = json.dumps(line) - yield f"data: {data}\n\n" + yield f"data: {data}\n\n" if is_structured and structured_chunks: structured_data = { "type": "structured_answer", @@ -262,6 +290,46 @@ class BaseAnswerResource: } data = json.dumps(structured_data) yield f"data: {data}\n\n" + + # ---- Paused: save continuation state and end stream early ---- + if paused: + continuation = getattr(agent, "_pending_continuation", None) + if continuation and conversation_id: + try: + cont_service = ContinuationService() + cont_service.save_state( + conversation_id=str(conversation_id), + user=decoded_token.get("sub", "local"), + messages=continuation["messages"], + pending_tool_calls=continuation["pending_tool_calls"], + tools_dict=continuation["tools_dict"], + tool_schemas=getattr(agent, "tools", []), + agent_config={ + "model_id": model_id or self.default_model_id, + "llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER), + "api_key": getattr(agent, "api_key", None), + "user_api_key": user_api_key, + "agent_id": agent_id, + "agent_type": agent.__class__.__name__, + "prompt": getattr(agent, "prompt", ""), + "json_schema": getattr(agent, "json_schema", None), + "retriever_config": getattr(agent, "retriever_config", None), + }, + ) + except Exception as e: + logger.error( + f"Failed to save continuation state: {str(e)}", + exc_info=True, + ) + + id_data = {"type": "id", "id": str(conversation_id)} + data = json.dumps(id_data) + yield f"data: {data}\n\n" + + data = json.dumps({"type": "end"}) + yield f"data: {data}\n\n" + return + if isNoneDoc: for doc in source_log_docs: doc["source"] = "None" @@ -435,6 +503,7 @@ class BaseAnswerResource: stream_ended = False is_structured = False schema_info = None + pending_tool_calls = None for line in stream: try: @@ -453,6 +522,10 @@ class BaseAnswerResource: source_log_docs = event["source"] elif event["type"] == "tool_calls": tool_calls = 
event["tool_calls"] + elif event["type"] == "tool_calls_pending": + pending_tool_calls = event.get("data", {}).get( + "pending_tool_calls", [] + ) elif event["type"] == "thought": thought = event["thought"] elif event["type"] == "error": @@ -466,6 +539,18 @@ class BaseAnswerResource: if not stream_ended: logger.error("Stream ended unexpectedly without an 'end' event.") return None, None, None, None, "Stream ended unexpectedly", None + + if pending_tool_calls is not None: + return ( + conversation_id, + response_full, + source_log_docs, + tool_calls, + thought, + None, + {"pending_tool_calls": pending_tool_calls}, + ) + result = ( conversation_id, response_full, diff --git a/application/api/answer/routes/stream.py b/application/api/answer/routes/stream.py index c2cc0ec6..2a6b9a11 100644 --- a/application/api/answer/routes/stream.py +++ b/application/api/answer/routes/stream.py @@ -79,7 +79,39 @@ class StreamResource(Resource, BaseAnswerResource): return error decoded_token = getattr(request, "decoded_token", None) processor = StreamProcessor(data, decoded_token) + try: + # ---- Continuation mode ---- + if data.get("tool_actions"): + ( + agent, + messages, + tools_dict, + pending_tool_calls, + tool_actions, + ) = processor.resume_from_tool_actions( + data["tool_actions"], data["conversation_id"] + ) + return Response( + self.complete_stream( + question="", + agent=agent, + conversation_id=processor.conversation_id, + user_api_key=processor.agent_config.get("user_api_key"), + decoded_token=processor.decoded_token, + agent_id=processor.agent_id, + model_id=processor.model_id, + _continuation={ + "messages": messages, + "tools_dict": tools_dict, + "pending_tool_calls": pending_tool_calls, + "tool_actions": tool_actions, + }, + ), + mimetype="text/event-stream", + ) + + # ---- Normal mode ---- agent = processor.build_agent(data["question"]) if not processor.decoded_token: return Response( diff --git a/application/api/answer/services/continuation_service.py 
b/application/api/answer/services/continuation_service.py new file mode 100644 index 00000000..fa9a8a0d --- /dev/null +++ b/application/api/answer/services/continuation_service.py @@ -0,0 +1,141 @@ +"""Service for saving and restoring tool-call continuation state. + +When a stream pauses (tool needs approval or client-side execution), +the full execution state is persisted to MongoDB so the client can +resume later by sending tool_actions. +""" + +import datetime +import logging +from typing import Any, Dict, List, Optional + +from bson import ObjectId + +from application.core.mongo_db import MongoDB +from application.core.settings import settings + +logger = logging.getLogger(__name__) + +# TTL for pending states — auto-cleaned after this period +PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes + + +def _make_serializable(obj: Any) -> Any: + """Recursively convert MongoDB ObjectIds and other non-JSON types.""" + if isinstance(obj, ObjectId): + return str(obj) + if isinstance(obj, dict): + return {str(k): _make_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_make_serializable(v) for v in obj] + if isinstance(obj, bytes): + return obj.decode("utf-8", errors="replace") + return obj + + +class ContinuationService: + """Manages pending tool-call state in MongoDB.""" + + def __init__(self): + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + self.collection = db["pending_tool_state"] + self._ensure_indexes() + + def _ensure_indexes(self): + try: + self.collection.create_index( + "expires_at", expireAfterSeconds=0 + ) + self.collection.create_index( + [("conversation_id", 1), ("user", 1)], unique=True + ) + except Exception: + # Indexes may already exist or mongomock doesn't support TTL + pass + + def save_state( + self, + conversation_id: str, + user: str, + messages: List[Dict], + pending_tool_calls: List[Dict], + tools_dict: Dict, + tool_schemas: List[Dict], + agent_config: Dict, + client_tools: Optional[List[Dict]] = 
None, + ) -> str: + """Save execution state for later continuation. + + Args: + conversation_id: The conversation this state belongs to. + user: Owner user ID. + messages: Full messages array at the pause point. + pending_tool_calls: Tool calls awaiting client action. + tools_dict: Serializable tools configuration dict. + tool_schemas: LLM-formatted tool schemas (agent.tools). + agent_config: Config needed to recreate the agent on resume. + client_tools: Client-provided tool schemas (Phase 2). + + Returns: + The string ID of the saved state document. + """ + now = datetime.datetime.now(datetime.timezone.utc) + expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS) + + doc = { + "conversation_id": conversation_id, + "user": user, + "messages": _make_serializable(messages), + "pending_tool_calls": _make_serializable(pending_tool_calls), + "tools_dict": _make_serializable(tools_dict), + "tool_schemas": _make_serializable(tool_schemas), + "agent_config": _make_serializable(agent_config), + "client_tools": _make_serializable(client_tools) if client_tools else None, + "created_at": now, + "expires_at": expires_at, + } + + # Upsert — only one pending state per conversation per user + result = self.collection.replace_one( + {"conversation_id": conversation_id, "user": user}, + doc, + upsert=True, + ) + state_id = str(result.upserted_id) if result.upserted_id else conversation_id + logger.info( + f"Saved continuation state for conversation {conversation_id} " + f"with {len(pending_tool_calls)} pending tool call(s)" + ) + return state_id + + def load_state( + self, conversation_id: str, user: str + ) -> Optional[Dict[str, Any]]: + """Load pending continuation state. + + Returns: + The state dict, or None if no pending state exists. 
+ """ + doc = self.collection.find_one( + {"conversation_id": conversation_id, "user": user} + ) + if not doc: + return None + doc["_id"] = str(doc["_id"]) + return doc + + def delete_state(self, conversation_id: str, user: str) -> bool: + """Delete pending state after successful resumption. + + Returns: + True if a document was deleted. + """ + result = self.collection.delete_one( + {"conversation_id": conversation_id, "user": user} + ) + if result.deleted_count: + logger.info( + f"Deleted continuation state for conversation {conversation_id}" + ) + return result.deleted_count > 0 diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index 28a2e8d4..e663460f 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -771,6 +771,115 @@ class StreamProcessor: logger.warning(f"Failed to fetch memory tool data: {str(e)}") return None + def resume_from_tool_actions( + self, + tool_actions: list, + conversation_id: str, + ): + """Resume a paused agent from saved continuation state. + + Loads the pending state from MongoDB, recreates the agent with + the saved configuration, and returns an agent ready to call + ``gen_continuation()``. + + Args: + tool_actions: Client-provided actions (approvals / results). + conversation_id: The conversation being resumed. + + Returns: + Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions). 
+ """ + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + from application.agents.agent_creator import AgentCreator + from application.agents.tool_executor import ToolExecutor + from application.llm.handlers.handler_creator import LLMHandlerCreator + from application.llm.llm_creator import LLMCreator + + cont_service = ContinuationService() + state = cont_service.load_state(conversation_id, self.initial_user_id) + if not state: + raise ValueError("No pending tool state found for this conversation") + + messages = state["messages"] + pending_tool_calls = state["pending_tool_calls"] + tools_dict = state["tools_dict"] + tool_schemas = state.get("tool_schemas", []) + agent_config = state["agent_config"] + + model_id = agent_config.get("model_id") + llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER) + api_key = agent_config.get("api_key") + user_api_key = agent_config.get("user_api_key") + agent_id = agent_config.get("agent_id") + prompt = agent_config.get("prompt", "") + json_schema = agent_config.get("json_schema") + retriever_config = agent_config.get("retriever_config") + + # Recreate dependencies + system_api_key = api_key or get_api_key_for_provider(llm_name) + llm = LLMCreator.create_llm( + llm_name, + api_key=system_api_key, + user_api_key=user_api_key, + decoded_token=self.decoded_token, + model_id=model_id, + agent_id=agent_id, + ) + llm_handler = LLMHandlerCreator.create_handler(llm_name or "default") + tool_executor = ToolExecutor( + user_api_key=user_api_key, + user=self.initial_user_id, + decoded_token=self.decoded_token, + ) + tool_executor.conversation_id = conversation_id + + agent_type = agent_config.get("agent_type", "ClassicAgent") + # Map class names back to agent creator keys + type_map = { + "ClassicAgent": "classic", + "AgenticAgent": "agentic", + "ResearchAgent": "research", + "WorkflowAgent": "workflow", + } + agent_key = type_map.get(agent_type, "classic") + + agent_kwargs = { + 
"endpoint": "stream", + "llm_name": llm_name, + "model_id": model_id, + "api_key": system_api_key, + "agent_id": agent_id, + "user_api_key": user_api_key, + "prompt": prompt, + "chat_history": [], + "decoded_token": self.decoded_token, + "json_schema": json_schema, + "llm": llm, + "llm_handler": llm_handler, + "tool_executor": tool_executor, + } + + if agent_key in ("agentic", "research") and retriever_config: + agent_kwargs["retriever_config"] = retriever_config + + agent = AgentCreator.create_agent(agent_key, **agent_kwargs) + agent.conversation_id = conversation_id + agent.initial_user_id = self.initial_user_id + agent.tools = tool_schemas + + # Store config for the route layer + self.model_id = model_id + self.agent_id = agent_id + self.agent_config["user_api_key"] = user_api_key + self.conversation_id = conversation_id + + # Delete state so it can't be replayed + cont_service.delete_state(conversation_id, self.initial_user_id) + + return agent, messages, tools_dict, pending_tool_calls, tool_actions + def create_agent( self, docs_together: Optional[str] = None, diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index 7537d9c5..be8aff22 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -648,6 +648,13 @@ class LLMHandler(ABC): """ Execute tool calls and update conversation history. + When a tool requires approval or client-side execution, it is + collected as a pending action instead of being executed. The + generator returns ``(updated_messages, pending_actions)`` where + *pending_actions* is ``None`` when every tool was executed + normally, or a list of dicts describing actions the client must + resolve before the LLM loop can continue. + Args: agent: The agent instance tool_calls: List of tool calls to execute @@ -655,9 +662,11 @@ class LLMHandler(ABC): messages: Current conversation history Returns: - Updated messages list + Tuple of (updated_messages, pending_actions). 
+ pending_actions is None if all tools executed, otherwise a list. """ updated_messages = messages.copy() + pending_actions: List[Dict] = [] for i, call in enumerate(tool_calls): # Check context limit before executing tool call @@ -763,6 +772,29 @@ class LLMHandler(ABC): # Set flag on agent agent.context_limit_reached = True break + + # ---- Pause check: approval / client-side execution ---- + llm_class = agent.llm.__class__.__name__ + pause_info = agent.tool_executor.check_pause( + tools_dict, call, llm_class + ) + if pause_info: + # Yield pause event so the client knows this tool is waiting + yield { + "type": "tool_call", + "data": { + "tool_name": pause_info["tool_name"], + "call_id": pause_info["call_id"], + "action_name": f"{pause_info['action_name']}_{pause_info['tool_id']}", + "arguments": pause_info["arguments"], + "status": pause_info["pause_type"], + }, + } + pending_actions.append(pause_info) + # Do NOT add messages for pending tools here. + # They will be added on resume to keep call/result pairs together. 
+ continue + try: self.tool_calls.append(call) tool_executor_gen = agent._execute_tool_action(tools_dict, call) @@ -772,7 +804,7 @@ class LLMHandler(ABC): except StopIteration as e: tool_response, call_id = e.value break - + function_call_content = { "function_call": { "name": call.name, @@ -823,7 +855,7 @@ class LLMHandler(ABC): "status": "error", }, } - return updated_messages + return updated_messages, pending_actions if pending_actions else None def handle_non_streaming( self, agent, response: Any, tools_dict: Dict, messages: List[Dict] @@ -851,8 +883,22 @@ class LLMHandler(ABC): try: yield next(tool_handler_gen) except StopIteration as e: - messages = e.value + messages, pending_actions = e.value break + + # If tools need approval or client execution, pause the loop + if pending_actions: + agent._pending_continuation = { + "messages": messages, + "pending_tool_calls": pending_actions, + "tools_dict": tools_dict, + } + yield { + "type": "tool_calls_pending", + "data": {"pending_tool_calls": pending_actions}, + } + return "" + response = agent.llm.gen( model=agent.model_id, messages=messages, tools=agent.tools ) @@ -913,10 +959,23 @@ class LLMHandler(ABC): try: yield next(tool_handler_gen) except StopIteration as e: - messages = e.value + messages, pending_actions = e.value break tool_calls = {} + # If tools need approval or client execution, pause the loop + if pending_actions: + agent._pending_continuation = { + "messages": messages, + "pending_tool_calls": pending_actions, + "tools_dict": tools_dict, + } + yield { + "type": "tool_calls_pending", + "data": {"pending_tool_calls": pending_actions}, + } + return + # Check if context limit was reached during tool execution if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached: # Add system message warning about context limit diff --git a/tests/llm/handlers/test_llm_handlers.py b/tests/llm/handlers/test_llm_handlers.py index 304d89fb..b302970d 100644 --- 
a/tests/llm/handlers/test_llm_handlers.py +++ b/tests/llm/handlers/test_llm_handlers.py @@ -621,6 +621,7 @@ class TestHandleToolCalls: agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) def fake_execute(tools_dict, call): yield {"type": "tool_call", "data": {"status": "pending"}} @@ -641,7 +642,7 @@ class TestHandleToolCalls: while True: events.append(next(gen)) except StopIteration as e: - messages = e.value + messages, _pending = e.value assert any(e.get("type") == "tool_call" for e in events) assert len(messages) >= 2 # function_call + tool_message @@ -675,6 +676,8 @@ class TestHandleToolCalls: agent = Mock() agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error")) call = ToolCall(id="c1", name="action_1", arguments="{}") @@ -704,7 +707,7 @@ class TestHandleToolCalls: while True: next(gen) except StopIteration as e: - messages = e.value + messages, _pending = e.value assistant_msgs = [ m for m in messages @@ -751,6 +754,7 @@ class TestHandleNonStreaming: agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) # First response requires tool call, second is final call_count = {"n": 0} @@ -856,6 +860,7 @@ class TestHandleStreaming: agent._check_context_limit = Mock(return_value=False) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) # First chunk has partial tool call, second completes it chunk1 = LLMResponse( @@ -907,6 +912,7 @@ class TestHandleStreaming: agent.context_limit_reached = True 
agent._check_context_limit = Mock(return_value=True) agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) # Chunk finishes with tool_calls chunk = LLMResponse( @@ -929,7 +935,7 @@ class TestHandleStreaming: def fake_handle_tool_calls(agent, calls, tools_dict, messages): agent.context_limit_reached = True yield {"type": "tool_call", "data": {"status": "skipped"}} - return messages + return messages, None handler.handle_tool_calls = fake_handle_tool_calls @@ -1501,6 +1507,7 @@ class TestHandleToolCallsCompressionSuccess: agent._check_context_limit = Mock(side_effect=check_limit) agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) def fake_execute(tools_dict, call): yield {"type": "tool_call", "data": {"status": "pending"}} @@ -1538,6 +1545,7 @@ class TestHandleToolCallsCompressionSuccess: agent = Mock() agent.context_limit_reached = False agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) exec_count = {"n": 0} diff --git a/tests/llm/test_base.py b/tests/llm/test_base.py index c12bbdc7..613e541a 100644 --- a/tests/llm/test_base.py +++ b/tests/llm/test_base.py @@ -478,6 +478,8 @@ class TestHandleToolCallsErrors: handler = ConcreteHandler() agent = MagicMock() agent._check_context_limit = MagicMock(return_value=False) + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = MagicMock(return_value=None) agent._execute_tool_action = MagicMock( side_effect=RuntimeError("tool failed") ) @@ -506,6 +508,8 @@ class TestHandleToolCallsErrors: handler = ConcreteHandler() agent = MagicMock() agent._check_context_limit = MagicMock(return_value=False) + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = MagicMock(return_value=None) agent._execute_tool_action = MagicMock( side_effect=RuntimeError("tool failed") ) @@ -1169,6 +1173,8 @@ class 
TestHandleToolCallsErrorsAdditional: handler = ConcreteHandler() agent = MagicMock() agent._check_context_limit = MagicMock(return_value=False) + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = MagicMock(return_value=None) agent._execute_tool_action = MagicMock( side_effect=RuntimeError("broken tool") ) @@ -1188,7 +1194,7 @@ class TestHandleToolCallsErrorsAdditional: while True: events.append(next(gen)) except StopIteration as e: - final_messages = e.value + final_messages, _pending = e.value # Verify the error message was appended error_msgs = [ @@ -1211,6 +1217,10 @@ class TestHandleToolCallsErrorsAdditional: """Cover line 660: messages.copy() at start of handle_tool_calls.""" handler = ConcreteHandler() agent = MagicMock(spec=[]) # No _check_context_limit attribute + agent.llm = MagicMock() + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor = MagicMock() + agent.tool_executor.check_pause = MagicMock(return_value=None) agent._execute_tool_action = MagicMock( side_effect=ValueError("bad args") ) diff --git a/tests/test_agent_token_tracking.py b/tests/test_agent_token_tracking.py index e168567a..b1ab1a99 100644 --- a/tests/test_agent_token_tracking.py +++ b/tests/test_agent_token_tracking.py @@ -176,6 +176,8 @@ class TestLLMHandlerTokenTracking: # Create mock agent that hits limit on second tool mock_agent = Mock() mock_agent.context_limit_reached = False + mock_agent.llm.__class__.__name__ = "MockLLM" + mock_agent.tool_executor.check_pause = Mock(return_value=None) call_count = [0] @@ -235,6 +237,8 @@ class TestLLMHandlerTokenTracking: mock_agent = Mock() mock_agent.context_limit_reached = False mock_agent._check_context_limit = Mock(return_value=False) + mock_agent.llm.__class__.__name__ = "MockLLM" + mock_agent.tool_executor.check_pause = Mock(return_value=None) mock_agent._execute_tool_action = Mock( return_value=iter([{"type": "tool_call", "data": {}}]) ) @@ -300,7 +304,7 @@ class TestLLMHandlerTokenTracking: def 
tool_handler_gen(*args): yield {"type": "tool", "data": {}} - return [] + return [], None # Mock handle_tool_calls to return messages and set flag with patch.object( diff --git a/tests/test_continuation.py b/tests/test_continuation.py new file mode 100644 index 00000000..2858290c --- /dev/null +++ b/tests/test_continuation.py @@ -0,0 +1,667 @@ +"""Tests for the continuation infrastructure (Phase 1). + +Covers ContinuationService, ToolExecutor.check_pause, handler pause +signaling, BaseAgent.gen_continuation, and request validation. +""" + +from unittest.mock import Mock + +import pytest + +from application.agents.tool_executor import ToolExecutor +from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + +# --------------------------------------------------------------------------- +# ContinuationService +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestContinuationService: + + def test_save_and_load(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + svc.save_state( + conversation_id="conv-1", + user="alice", + messages=[{"role": "user", "content": "hi"}], + pending_tool_calls=[{"call_id": "c1", "pause_type": "awaiting_approval"}], + tools_dict={"0": {"name": "test_tool"}}, + tool_schemas=[{"type": "function", "function": {"name": "act_0"}}], + agent_config={"model_id": "gpt-4"}, + ) + + state = svc.load_state("conv-1", "alice") + assert state is not None + assert state["conversation_id"] == "conv-1" + assert state["user"] == "alice" + assert len(state["messages"]) == 1 + assert len(state["pending_tool_calls"]) == 1 + assert state["agent_config"]["model_id"] == "gpt-4" + + def test_load_returns_none_when_missing(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + assert 
svc.load_state("nonexistent", "alice") is None + + def test_delete_state(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + svc.save_state( + conversation_id="conv-2", + user="bob", + messages=[], + pending_tool_calls=[], + tools_dict={}, + tool_schemas=[], + agent_config={}, + ) + assert svc.delete_state("conv-2", "bob") is True + assert svc.load_state("conv-2", "bob") is None + + def test_delete_nonexistent(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + assert svc.delete_state("nope", "nope") is False + + def test_upsert_replaces_existing(self, mock_mongo_db): + from application.api.answer.services.continuation_service import ( + ContinuationService, + ) + + svc = ContinuationService() + svc.save_state( + conversation_id="conv-3", + user="carol", + messages=[{"role": "user", "content": "v1"}], + pending_tool_calls=[], + tools_dict={}, + tool_schemas=[], + agent_config={}, + ) + svc.save_state( + conversation_id="conv-3", + user="carol", + messages=[{"role": "user", "content": "v2"}], + pending_tool_calls=[{"call_id": "c2"}], + tools_dict={}, + tool_schemas=[], + agent_config={}, + ) + state = svc.load_state("conv-3", "carol") + assert state["messages"][0]["content"] == "v2" + assert len(state["pending_tool_calls"]) == 1 + + +# --------------------------------------------------------------------------- +# ToolExecutor.check_pause +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckPause: + + def _make_call(self, name="action_0", call_id="c1", arguments="{}"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = arguments + call.thought_signature = None + return call + + def test_returns_none_for_normal_tool(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": 
"brave", + "actions": [ + {"name": "search", "active": True, "parameters": {}}, + ], + } + } + call = self._make_call(name="search_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is None + + def test_returns_pause_for_client_side_tool(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "get_weather", + "client_side": True, + "actions": [ + {"name": "get_weather", "active": True, "parameters": {}}, + ], + } + } + call = self._make_call(name="get_weather_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "requires_client_execution" + assert result["call_id"] == "c1" + assert result["tool_id"] == "0" + + def test_returns_pause_for_approval_required(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "telegram", + "actions": [ + { + "name": "send_msg", + "active": True, + "require_approval": True, + "parameters": {}, + }, + ], + } + } + call = self._make_call(name="send_msg_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "awaiting_approval" + + def test_returns_none_when_parse_fails(self): + executor = ToolExecutor() + call = self._make_call(name="bad_name_no_id", arguments="not json") + # Bad arguments will cause parse error -> None + result = executor.check_pause({}, call, "OpenAILLM") + assert result is None + + def test_returns_none_when_tool_not_in_dict(self): + executor = ToolExecutor() + call = self._make_call(name="action_99") + result = executor.check_pause({"0": {"name": "t"}}, call, "OpenAILLM") + assert result is None + + def test_api_tool_approval(self): + executor = ToolExecutor() + tools_dict = { + "0": { + "name": "api_tool", + "config": { + "actions": { + "delete_user": { + "name": "delete_user", + "require_approval": True, + "url": "http://example.com", + "method": "DELETE", + "active": True, + } + } + }, + } + } + call = 
self._make_call(name="delete_user_0") + result = executor.check_pause(tools_dict, call, "OpenAILLM") + assert result is not None + assert result["pause_type"] == "awaiting_approval" + + +# --------------------------------------------------------------------------- +# Handler pause signaling (handle_tool_calls returns pending_actions) +# --------------------------------------------------------------------------- + + +class ConcreteHandler(LLMHandler): + """Minimal concrete handler for testing.""" + + def parse_response(self, response): + return LLMResponse( + content=str(response), tool_calls=[], finish_reason="stop", + raw_response=response, + ) + + def create_tool_message(self, tool_call, result): + return { + "role": "tool", + "content": [ + { + "function_response": { + "name": tool_call.name, + "response": {"result": result}, + "call_id": tool_call.id, + } + } + ], + } + + def _iterate_stream(self, response): + for chunk in response: + yield chunk + + +@pytest.mark.unit +class TestHandlerPauseSignaling: + + def _make_agent(self): + agent = Mock() + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + agent.tool_executor.check_pause = Mock(return_value=None) + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("tool result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + return agent + + def test_no_pause_returns_none_pending(self): + handler = ConcreteHandler() + agent = self._make_agent() + call = ToolCall(id="c1", name="action_0", arguments="{}") + + gen = handler.handle_tool_calls(agent, [call], {"0": {"name": "t"}}, []) + events = [] + messages = None + pending = "NOT_SET" + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + assert pending is None + assert messages is not None + + def test_pause_returns_pending_actions(self): + 
handler = ConcreteHandler() + agent = self._make_agent() + agent.tool_executor.check_pause = Mock(return_value={ + "call_id": "c1", + "name": "send_msg_0", + "tool_name": "telegram", + "tool_id": "0", + "action_name": "send_msg", + "arguments": {"text": "hello"}, + "pause_type": "awaiting_approval", + "thought_signature": None, + }) + + call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}') + gen = handler.handle_tool_calls( + agent, [call], {"0": {"name": "telegram"}}, [] + ) + + events = [] + pending = None + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + assert pending is not None + assert len(pending) == 1 + assert pending[0]["pause_type"] == "awaiting_approval" + + # Should have yielded a tool_call event with awaiting_approval status + pause_events = [ + e for e in events + if e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "awaiting_approval" + ] + assert len(pause_events) == 1 + + def test_mixed_execute_and_pause(self): + """One tool executes, another needs approval.""" + handler = ConcreteHandler() + agent = self._make_agent() + + call_count = {"n": 0} + + def selective_pause(tools_dict, call, llm_class): + call_count["n"] += 1 + if call_count["n"] == 2: + return { + "call_id": "c2", + "name": "danger_0", + "tool_name": "danger", + "tool_id": "0", + "action_name": "danger", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + return None + + agent.tool_executor.check_pause = Mock(side_effect=selective_pause) + + calls = [ + ToolCall(id="c1", name="safe_0", arguments="{}"), + ToolCall(id="c2", name="danger_0", arguments="{}"), + ] + gen = handler.handle_tool_calls( + agent, calls, {"0": {"name": "multi"}}, [] + ) + + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages, pending = e.value + + # First tool was executed normally + assert agent._execute_tool_action.call_count 
== 1 + # Second tool is pending + assert pending is not None + assert len(pending) == 1 + assert pending[0]["call_id"] == "c2" + + +# --------------------------------------------------------------------------- +# handle_streaming yields tool_calls_pending +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestStreamingPause: + + def test_streaming_yields_tool_calls_pending(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + agent.model_id = "test" + agent.tools = [] + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + pause_info = { + "call_id": "c1", + "name": "fn_0", + "tool_name": "test", + "tool_id": "0", + "action_name": "fn", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + agent.tool_executor.check_pause = Mock(return_value=pause_info) + + chunk = LLMResponse( + content="", + tool_calls=[ToolCall(id="c1", name="fn_0", arguments="{}", index=0)], + finish_reason="tool_calls", + raw_response={}, + ) + handler.parse_response = lambda c: c + + def fake_iterate(response): + yield from response + + handler._iterate_stream = fake_iterate + + gen = handler.handle_streaming(agent, [chunk], {"0": {"name": "t"}}, []) + events = list(gen) + + # Should contain a tool_calls_pending event + pending_events = [ + e for e in events + if isinstance(e, dict) and e.get("type") == "tool_calls_pending" + ] + assert len(pending_events) == 1 + assert len(pending_events[0]["data"]["pending_tool_calls"]) == 1 + + # hasattr() is always True on a Mock; check it was explicitly assigned + assert "_pending_continuation" in vars(agent) + + +# --------------------------------------------------------------------------- +# BaseAgent.gen_continuation +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestGenContinuation: + + def
test_approved_tool_executes(self): + """When a tool action is approved, the tool is executed.""" + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Final answer"])) + mock_llm._supports_structured_output = Mock(return_value=False) + mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "content": [{"function_response": { + "name": "act_0", "response": {"result": "done"}, "call_id": "c1" + }}]} + ) + + mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + def fake_execute(tools_dict, call, llm_class): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("result_data", "c1") + + mock_executor.execute = Mock(side_effect=fake_execute) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + + messages = [{"role": "system", "content": "You are helpful."}] + tools_dict = {"0": {"name": "test_tool"}} + pending = [ + { + "call_id": "c1", + "name": "act_0", + "tool_name": "test_tool", + "tool_id": "0", + "action_name": "act", + "arguments": {"q": "test"}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [{"call_id": "c1", "decision": "approved"}] + + list(agent.gen_continuation(messages, tools_dict, pending, tool_actions)) + + # Tool should have been executed + assert mock_executor.execute.called + + def test_denied_tool_sends_denial(self): + """When a tool action is denied, a denial message is added.""" + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + 
mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + mock_llm._supports_structured_output = Mock(return_value=False) + mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "content": "denied"} + ) + + mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "danger_0", + "tool_name": "danger", + "tool_id": "0", + "action_name": "danger", + "arguments": {}, + "pause_type": "awaiting_approval", + "thought_signature": None, + } + ] + tool_actions = [ + {"call_id": "c1", "decision": "denied", "comment": "too risky"} + ] + + events = list( + agent.gen_continuation(messages, {"0": {"name": "danger"}}, pending, tool_actions) + ) + + # Should have a denied tool_call event + denied = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "denied" + ] + assert len(denied) == 1 + + # create_tool_message should have been called with denial text + denial_arg = mock_handler.create_tool_message.call_args[0][1] + assert "denied" in denial_arg.lower() + assert "too risky" in denial_arg + + def test_client_result_appended(self): + """Client-provided tool result is added to messages.""" + from application.agents.classic_agent import ClassicAgent + + mock_llm = Mock() + mock_llm._supports_tools = True + mock_llm.gen_stream = Mock(return_value=iter(["Done"])) + mock_llm._supports_structured_output = Mock(return_value=False) 
+ mock_llm.__class__.__name__ = "MockLLM" + + mock_handler = Mock() + mock_handler.process_message_flow = Mock(return_value=iter([])) + mock_handler.create_tool_message = Mock( + return_value={"role": "tool", "content": "client result"} + ) + + mock_executor = Mock() + mock_executor.tool_calls = [] + mock_executor.prepare_tools_for_llm = Mock(return_value=[]) + mock_executor.get_truncated_tool_calls = Mock(return_value=[]) + + agent = ClassicAgent( + endpoint="stream", + llm_name="openai", + model_id="gpt-4", + api_key="test", + llm=mock_llm, + llm_handler=mock_handler, + tool_executor=mock_executor, + ) + + messages = [{"role": "system", "content": "test"}] + pending = [ + { + "call_id": "c1", + "name": "weather_0", + "tool_name": "weather", + "tool_id": "0", + "action_name": "weather", + "arguments": {"city": "SF"}, + "pause_type": "requires_client_execution", + "thought_signature": None, + } + ] + tool_actions = [{"call_id": "c1", "result": {"temp": "72F"}}] + + events = list( + agent.gen_continuation(messages, {"0": {"name": "weather"}}, pending, tool_actions) + ) + + # create_tool_message was called with the client result + result_arg = mock_handler.create_tool_message.call_args[0][1] + assert "72F" in result_arg + + # Should have a completed tool_call event + completed = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "completed" + ] + assert len(completed) == 1 + + +# --------------------------------------------------------------------------- +# validate_request +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestValidateRequest: + + @pytest.fixture(autouse=True) + def _app_context(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + yield + + def test_continuation_request_without_question(self, mock_mongo_db): + from application.api.answer.routes.base import BaseAnswerResource + + base = 
BaseAnswerResource() + data = { + "conversation_id": "conv-1", + "tool_actions": [{"call_id": "c1", "decision": "approved"}], + } + result = base.validate_request(data) + assert result is None # Valid + + def test_continuation_request_missing_conversation_id(self, mock_mongo_db): + from application.api.answer.routes.base import BaseAnswerResource + + base = BaseAnswerResource() + data = { + "tool_actions": [{"call_id": "c1", "decision": "approved"}], + } + result = base.validate_request(data) + assert result is not None # Error — missing conversation_id + + def test_normal_request_still_requires_question(self, mock_mongo_db): + from application.api.answer.routes.base import BaseAnswerResource + + base = BaseAnswerResource() + data = {"conversation_id": "conv-1"} + result = base.validate_request(data) + assert result is not None # Error — missing question