feat: continuation messages

This commit is contained in:
Alex
2026-03-31 21:30:24 +01:00
parent 772860b667
commit d609efca49
12 changed files with 1373 additions and 38 deletions

View File

@@ -1,7 +1,8 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional
from application.agents.tool_executor import ToolExecutor
from application.core.json_schema_utils import (
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.llm.handlers.base import ToolCall
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
@@ -113,6 +115,140 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
def gen_continuation(
    self,
    messages: List[Dict],
    tools_dict: Dict,
    pending_tool_calls: List[Dict],
    tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
    """Resume generation after tool actions are resolved.

    Processes the client-provided *tool_actions* (approvals, denials,
    or client-side results), appends the resulting messages, then
    hands back to the LLM to continue the conversation.

    Every pending call always gets a matching tool-result message —
    unanswered or unrecognized actions are treated as denials — so the
    LLM never sees a function_call without its paired result.

    Args:
        messages: The saved messages array from the pause point.
        tools_dict: The saved tools dictionary.
        pending_tool_calls: The pending tool call descriptors from the pause.
        tool_actions: Client-provided actions resolving the pending calls.

    Yields:
        Streaming event dicts: tool_call status events, the resumed LLM
        output, then trailing ``sources`` and ``tool_calls`` summaries.
    """
    self._prepare_tools(tools_dict)

    actions_by_id = {a["call_id"]: a for a in tool_actions}
    for pending in pending_tool_calls:
        call_id = pending["call_id"]
        action = actions_by_id.get(call_id)
        if not action:
            # Client didn't answer this call — default to denial.
            action = {
                "call_id": call_id,
                "decision": "denied",
                "comment": "No response provided",
            }

        # Re-create the assistant tool-call message that paused the
        # stream (it was deliberately not saved — see handler pause path).
        args = pending["arguments"]
        function_call_content: Dict[str, Any] = {
            "function_call": {
                "name": pending["name"],
                "args": args,
                "call_id": call_id,
            }
        }
        if pending.get("thought_signature"):
            # Some providers (e.g. Gemini) require the thought signature
            # to be echoed back with the function call.
            function_call_content["thought_signature"] = pending[
                "thought_signature"
            ]
        messages.append(
            {"role": "assistant", "content": [function_call_content]}
        )

        if action.get("decision") == "approved":
            # Execute the tool server-side, streaming its status events.
            tc = ToolCall(
                id=call_id,
                name=pending["name"],
                arguments=(
                    json.dumps(args) if isinstance(args, dict) else args
                ),
            )
            tool_gen = self._execute_tool_action(tools_dict, tc)
            tool_response = None
            while True:
                try:
                    event = next(tool_gen)
                    yield event
                except StopIteration as e:
                    # Generator return value carries (result, call_id).
                    tool_response, _ = e.value
                    break
            messages.append(
                self.llm_handler.create_tool_message(tc, tool_response)
            )
        elif action.get("decision") == "denied":
            comment = action.get("comment", "")
            denial = (
                f"Tool execution denied by user. Reason: {comment}"
                if comment
                else "Tool execution denied by user."
            )
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(tc, denial)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": f"{pending['action_name']}_{pending['tool_id']}",
                    "arguments": args,
                    "status": "denied",
                },
            }
        elif "result" in action:
            # Client executed the tool itself and sent us the result.
            result = action["result"]
            result_str = (
                json.dumps(result)
                if not isinstance(result, str)
                else result
            )
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(tc, result_str)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": f"{pending['action_name']}_{pending['tool_id']}",
                    "arguments": args,
                    "result": (
                        result_str[:50] + "..."
                        if len(result_str) > 50
                        else result_str
                    ),
                    "status": "completed",
                },
            }
        else:
            # Unrecognized action (no known decision, no result): record
            # a denial so the function_call above still gets a paired
            # tool-result message — otherwise most LLM APIs reject the
            # conversation on resume.
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(
                    tc, "Tool execution denied by user."
                )
            )

    # Resume the LLM loop with the updated messages
    llm_response = self._llm_gen(messages)
    yield from self._handle_response(
        llm_response, tools_dict, messages, None
    )
    yield {"sources": self.retrieved_docs}
    yield {"tool_calls": self._get_truncated_tool_calls()}
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property

View File

@@ -104,6 +104,60 @@ class ToolExecutor:
params["required"].append(k)
return params
def check_pause(
    self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
    """Check if a tool call requires pausing for approval or client execution.

    Args:
        tools_dict: Mapping of tool_id -> tool configuration.
        call: The provider-specific tool-call object to inspect.
        llm_class_name: LLM class name, used to pick the argument parser.

    Returns:
        A dict describing the pending action if a pause is needed,
        ``None`` otherwise (including unparseable/unknown calls, which
        are left for ``execute()`` to report as errors).
    """
    parser = ToolActionParser(llm_class_name)
    tool_id, action_name, call_args = parser.parse_args(call)
    call_id = getattr(call, "id", None) or str(uuid.uuid4())

    if tool_id is None or action_name is None or tool_id not in tools_dict:
        return None  # Will be handled as error by execute()
    tool_data = tools_dict[tool_id]

    def _pending(pause_type: str) -> Dict:
        # Both pause kinds share the same descriptor shape; only
        # pause_type differs.
        return {
            "call_id": call_id,
            "name": getattr(call, "name", f"{action_name}_{tool_id}"),
            "tool_name": tool_data.get("name", "unknown"),
            "tool_id": tool_id,
            "action_name": action_name,
            "arguments": call_args if isinstance(call_args, dict) else {},
            "pause_type": pause_type,
            "thought_signature": getattr(call, "thought_signature", None),
        }

    # Phase 2: client-side tools
    if tool_data.get("client_side"):
        return _pending("requires_client_execution")

    # Phase 3: approval required. Use .get() — tool entries without a
    # "name" key must not raise here (the descriptor above already
    # defaults to "unknown").
    if tool_data.get("name") == "api_tool":
        action_data = tool_data.get("config", {}).get("actions", {}).get(
            action_name, {}
        )
    else:
        action_data = next(
            (a for a in tool_data.get("actions", []) if a["name"] == action_name),
            {},
        )
    if action_data.get("require_approval"):
        return _pending("awaiting_approval")
    return None
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name)

View File

@@ -74,27 +74,56 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
stream = self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
)
else:
# ---- Normal mode ----
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
@@ -105,16 +134,17 @@ class AnswerResource(Resource, BaseAnswerResource):
tool_calls,
thought,
error,
structured_info,
extra_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
extra_info = None
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
@@ -123,8 +153,8 @@ class AnswerResource(Resource, BaseAnswerResource):
"thought": thought,
}
if structured_info:
result.update(structured_info)
if extra_info:
result.update(extra_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
@@ -39,7 +40,16 @@ class BaseAnswerResource:
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
"""Common request validation.
Continuation requests (``tool_actions`` present) require
``conversation_id`` but not ``question``.
"""
if data.get("tool_actions"):
# Continuation mode — question is not required
if missing := check_required_fields(data, ["conversation_id"]):
return missing
return None
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
@@ -177,6 +187,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -207,8 +218,19 @@ class BaseAnswerResource:
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
for line in agent.gen(query=question):
if _continuation:
gen_iter = agent.gen_continuation(
messages=_continuation["messages"],
tools_dict=_continuation["tools_dict"],
pending_tool_calls=_continuation["pending_tool_calls"],
tool_actions=_continuation["tool_actions"],
)
else:
gen_iter = agent.gen(query=question)
for line in gen_iter:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
@@ -244,15 +266,21 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "error":
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -262,6 +290,46 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation and conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
@@ -435,6 +503,7 @@ class BaseAnswerResource:
stream_ended = False
is_structured = False
schema_info = None
pending_tool_calls = None
for line in stream:
try:
@@ -453,6 +522,10 @@ class BaseAnswerResource:
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "tool_calls_pending":
pending_tool_calls = event.get("data", {}).get(
"pending_tool_calls", []
)
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
@@ -466,6 +539,18 @@ class BaseAnswerResource:
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly", None
if pending_tool_calls is not None:
return (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
{"pending_tool_calls": pending_tool_calls},
)
result = (
conversation_id,
response_full,

View File

@@ -79,7 +79,39 @@ class StreamResource(Resource, BaseAnswerResource):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
return Response(
self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
),
mimetype="text/event-stream",
)
# ---- Normal mode ----
agent = processor.build_agent(data["question"])
if not processor.decoded_token:
return Response(

View File

@@ -0,0 +1,141 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to MongoDB so the client can
resume later by sending tool_actions.
"""
import datetime
import logging
from typing import Any, Dict, List, Optional
from bson import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
def _make_serializable(obj: Any) -> Any:
    """Recursively coerce *obj* into Mongo/JSON-friendly primitives.

    ObjectIds become strings, dict keys are stringified, bytes are
    decoded (lossily, with replacement characters), and containers are
    converted element by element. Anything else passes through as-is.
    """
    if isinstance(obj, bytes):
        return obj.decode("utf-8", errors="replace")
    if isinstance(obj, ObjectId):
        return str(obj)
    if isinstance(obj, list):
        return [_make_serializable(item) for item in obj]
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[str(key)] = _make_serializable(value)
        return converted
    return obj
class ContinuationService:
    """Persists and restores pending tool-call state in MongoDB.

    At most one pending state exists per (conversation_id, user) pair;
    documents expire automatically via a TTL index on ``expires_at``.
    """

    def __init__(self):
        client = MongoDB.get_client()
        database = client[settings.MONGO_DB_NAME]
        self.collection = database["pending_tool_state"]
        self._ensure_indexes()

    def _ensure_indexes(self):
        # Best-effort: TTL cleanup plus uniqueness per (conversation, user).
        try:
            self.collection.create_index(
                "expires_at", expireAfterSeconds=0
            )
            self.collection.create_index(
                [("conversation_id", 1), ("user", 1)], unique=True
            )
        except Exception:
            # Indexes may already exist or mongomock doesn't support TTL
            pass

    def save_state(
        self,
        conversation_id: str,
        user: str,
        messages: List[Dict],
        pending_tool_calls: List[Dict],
        tools_dict: Dict,
        tool_schemas: List[Dict],
        agent_config: Dict,
        client_tools: Optional[List[Dict]] = None,
    ) -> str:
        """Save execution state for later continuation.

        Args:
            conversation_id: The conversation this state belongs to.
            user: Owner user ID.
            messages: Full messages array at the pause point.
            pending_tool_calls: Tool calls awaiting client action.
            tools_dict: Serializable tools configuration dict.
            tool_schemas: LLM-formatted tool schemas (agent.tools).
            agent_config: Config needed to recreate the agent on resume.
            client_tools: Client-provided tool schemas (Phase 2).

        Returns:
            The string ID of the saved state document.
        """
        timestamp = datetime.datetime.now(datetime.timezone.utc)
        state_doc = {
            "conversation_id": conversation_id,
            "user": user,
            "messages": _make_serializable(messages),
            "pending_tool_calls": _make_serializable(pending_tool_calls),
            "tools_dict": _make_serializable(tools_dict),
            "tool_schemas": _make_serializable(tool_schemas),
            "agent_config": _make_serializable(agent_config),
            "client_tools": _make_serializable(client_tools) if client_tools else None,
            "created_at": timestamp,
            "expires_at": timestamp
            + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS),
        }
        # Upsert — only one pending state per conversation per user
        write_result = self.collection.replace_one(
            {"conversation_id": conversation_id, "user": user},
            state_doc,
            upsert=True,
        )
        logger.info(
            f"Saved continuation state for conversation {conversation_id} "
            f"with {len(pending_tool_calls)} pending tool call(s)"
        )
        if write_result.upserted_id:
            return str(write_result.upserted_id)
        # Replaced an existing document — no new _id, report the key instead.
        return conversation_id

    def load_state(
        self, conversation_id: str, user: str
    ) -> Optional[Dict[str, Any]]:
        """Load pending continuation state.

        Returns:
            The state dict (with ``_id`` stringified), or None if no
            pending state exists.
        """
        query = {"conversation_id": conversation_id, "user": user}
        doc = self.collection.find_one(query)
        if doc is None:
            return None
        doc["_id"] = str(doc["_id"])
        return doc

    def delete_state(self, conversation_id: str, user: str) -> bool:
        """Delete pending state after successful resumption.

        Returns:
            True if a document was deleted.
        """
        outcome = self.collection.delete_one(
            {"conversation_id": conversation_id, "user": user}
        )
        deleted = outcome.deleted_count > 0
        if deleted:
            logger.info(
                f"Deleted continuation state for conversation {conversation_id}"
            )
        return deleted

View File

@@ -771,6 +771,115 @@ class StreamProcessor:
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
return None
def resume_from_tool_actions(
    self,
    tool_actions: list,
    conversation_id: str,
):
    """Resume a paused agent from saved continuation state.

    Loads the pending state from MongoDB, recreates the agent with
    the saved configuration, and returns an agent ready to call
    ``gen_continuation()``. The saved state is deleted on success so
    the same pause cannot be replayed.

    Args:
        tool_actions: Client-provided actions (approvals / results).
        conversation_id: The conversation being resumed.

    Returns:
        Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).

    Raises:
        ValueError: If no pending state exists for this conversation.
    """
    # Local imports avoid import cycles between the processor, agents,
    # and the continuation service.
    from application.api.answer.services.continuation_service import (
        ContinuationService,
    )
    from application.agents.agent_creator import AgentCreator
    from application.agents.tool_executor import ToolExecutor
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.llm.llm_creator import LLMCreator

    cont_service = ContinuationService()
    state = cont_service.load_state(conversation_id, self.initial_user_id)
    if not state:
        raise ValueError("No pending tool state found for this conversation")

    # Unpack the snapshot taken at the pause point.
    messages = state["messages"]
    pending_tool_calls = state["pending_tool_calls"]
    tools_dict = state["tools_dict"]
    tool_schemas = state.get("tool_schemas", [])
    agent_config = state["agent_config"]

    model_id = agent_config.get("model_id")
    llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
    api_key = agent_config.get("api_key")
    user_api_key = agent_config.get("user_api_key")
    agent_id = agent_config.get("agent_id")
    prompt = agent_config.get("prompt", "")
    json_schema = agent_config.get("json_schema")
    retriever_config = agent_config.get("retriever_config")

    # Recreate dependencies (LLM client, handler, tool executor) from the
    # saved config; fall back to the provider's system key if none was saved.
    system_api_key = api_key or get_api_key_for_provider(llm_name)
    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        user_api_key=user_api_key,
        decoded_token=self.decoded_token,
        model_id=model_id,
        agent_id=agent_id,
    )
    llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=self.initial_user_id,
        decoded_token=self.decoded_token,
    )
    tool_executor.conversation_id = conversation_id

    agent_type = agent_config.get("agent_type", "ClassicAgent")
    # Map class names back to agent creator keys
    type_map = {
        "ClassicAgent": "classic",
        "AgenticAgent": "agentic",
        "ResearchAgent": "research",
        "WorkflowAgent": "workflow",
    }
    agent_key = type_map.get(agent_type, "classic")
    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": model_id,
        "api_key": system_api_key,
        "agent_id": agent_id,
        "user_api_key": user_api_key,
        "prompt": prompt,
        # Chat history is already embedded in the saved messages array.
        "chat_history": [],
        "decoded_token": self.decoded_token,
        "json_schema": json_schema,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }
    # Only these agent types accept a retriever_config kwarg.
    if agent_key in ("agentic", "research") and retriever_config:
        agent_kwargs["retriever_config"] = retriever_config
    agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
    agent.conversation_id = conversation_id
    agent.initial_user_id = self.initial_user_id
    # Restore the LLM-formatted tool schemas saved at the pause point.
    agent.tools = tool_schemas

    # Store config for the route layer
    self.model_id = model_id
    self.agent_id = agent_id
    self.agent_config["user_api_key"] = user_api_key
    self.conversation_id = conversation_id

    # Delete state so it can't be replayed
    cont_service.delete_state(conversation_id, self.initial_user_id)
    return agent, messages, tools_dict, pending_tool_calls, tool_actions
def create_agent(
self,
docs_together: Optional[str] = None,

View File

@@ -648,6 +648,13 @@ class LLMHandler(ABC):
"""
Execute tool calls and update conversation history.
When a tool requires approval or client-side execution, it is
collected as a pending action instead of being executed. The
generator returns ``(updated_messages, pending_actions)`` where
*pending_actions* is ``None`` when every tool was executed
normally, or a list of dicts describing actions the client must
resolve before the LLM loop can continue.
Args:
agent: The agent instance
tool_calls: List of tool calls to execute
@@ -655,9 +662,11 @@ class LLMHandler(ABC):
messages: Current conversation history
Returns:
Updated messages list
Tuple of (updated_messages, pending_actions).
pending_actions is None if all tools executed, otherwise a list.
"""
updated_messages = messages.copy()
pending_actions: List[Dict] = []
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
@@ -763,6 +772,29 @@ class LLMHandler(ABC):
# Set flag on agent
agent.context_limit_reached = True
break
# ---- Pause check: approval / client-side execution ----
llm_class = agent.llm.__class__.__name__
pause_info = agent.tool_executor.check_pause(
tools_dict, call, llm_class
)
if pause_info:
# Yield pause event so the client knows this tool is waiting
yield {
"type": "tool_call",
"data": {
"tool_name": pause_info["tool_name"],
"call_id": pause_info["call_id"],
"action_name": f"{pause_info['action_name']}_{pause_info['tool_id']}",
"arguments": pause_info["arguments"],
"status": pause_info["pause_type"],
},
}
pending_actions.append(pause_info)
# Do NOT add messages for pending tools here.
# They will be added on resume to keep call/result pairs together.
continue
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -772,7 +804,7 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
@@ -823,7 +855,7 @@ class LLMHandler(ABC):
"status": "error",
},
}
return updated_messages
return updated_messages, pending_actions if pending_actions else None
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
@@ -851,8 +883,22 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
messages, pending_actions = e.value
break
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return ""
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
)
@@ -913,10 +959,23 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages = e.value
messages, pending_actions = e.value
break
tool_calls = {}
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit

View File

@@ -621,6 +621,7 @@ class TestHandleToolCalls:
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
@@ -641,7 +642,7 @@ class TestHandleToolCalls:
while True:
events.append(next(gen))
except StopIteration as e:
messages = e.value
messages, _pending = e.value
assert any(e.get("type") == "tool_call" for e in events)
assert len(messages) >= 2 # function_call + tool_message
@@ -675,6 +676,8 @@ class TestHandleToolCalls:
agent = Mock()
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error"))
call = ToolCall(id="c1", name="action_1", arguments="{}")
@@ -704,7 +707,7 @@ class TestHandleToolCalls:
while True:
next(gen)
except StopIteration as e:
messages = e.value
messages, _pending = e.value
assistant_msgs = [
m for m in messages
@@ -751,6 +754,7 @@ class TestHandleNonStreaming:
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
# First response requires tool call, second is final
call_count = {"n": 0}
@@ -856,6 +860,7 @@ class TestHandleStreaming:
agent._check_context_limit = Mock(return_value=False)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
# First chunk has partial tool call, second completes it
chunk1 = LLMResponse(
@@ -907,6 +912,7 @@ class TestHandleStreaming:
agent.context_limit_reached = True
agent._check_context_limit = Mock(return_value=True)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
# Chunk finishes with tool_calls
chunk = LLMResponse(
@@ -929,7 +935,7 @@ class TestHandleStreaming:
def fake_handle_tool_calls(agent, calls, tools_dict, messages):
agent.context_limit_reached = True
yield {"type": "tool_call", "data": {"status": "skipped"}}
return messages
return messages, None
handler.handle_tool_calls = fake_handle_tool_calls
@@ -1501,6 +1507,7 @@ class TestHandleToolCallsCompressionSuccess:
agent._check_context_limit = Mock(side_effect=check_limit)
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
def fake_execute(tools_dict, call):
yield {"type": "tool_call", "data": {"status": "pending"}}
@@ -1538,6 +1545,7 @@ class TestHandleToolCallsCompressionSuccess:
agent = Mock()
agent.context_limit_reached = False
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = Mock(return_value=None)
exec_count = {"n": 0}

View File

@@ -478,6 +478,8 @@ class TestHandleToolCallsErrors:
handler = ConcreteHandler()
agent = MagicMock()
agent._check_context_limit = MagicMock(return_value=False)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent._execute_tool_action = MagicMock(
side_effect=RuntimeError("tool failed")
)
@@ -506,6 +508,8 @@ class TestHandleToolCallsErrors:
handler = ConcreteHandler()
agent = MagicMock()
agent._check_context_limit = MagicMock(return_value=False)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent._execute_tool_action = MagicMock(
side_effect=RuntimeError("tool failed")
)
@@ -1169,6 +1173,8 @@ class TestHandleToolCallsErrorsAdditional:
handler = ConcreteHandler()
agent = MagicMock()
agent._check_context_limit = MagicMock(return_value=False)
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent._execute_tool_action = MagicMock(
side_effect=RuntimeError("broken tool")
)
@@ -1188,7 +1194,7 @@ class TestHandleToolCallsErrorsAdditional:
while True:
events.append(next(gen))
except StopIteration as e:
final_messages = e.value
final_messages, _pending = e.value
# Verify the error message was appended
error_msgs = [
@@ -1211,6 +1217,10 @@ class TestHandleToolCallsErrorsAdditional:
"""Cover line 660: messages.copy() at start of handle_tool_calls."""
handler = ConcreteHandler()
agent = MagicMock(spec=[]) # No _check_context_limit attribute
agent.llm = MagicMock()
agent.llm.__class__.__name__ = "MockLLM"
agent.tool_executor = MagicMock()
agent.tool_executor.check_pause = MagicMock(return_value=None)
agent._execute_tool_action = MagicMock(
side_effect=ValueError("bad args")
)

View File

@@ -176,6 +176,8 @@ class TestLLMHandlerTokenTracking:
# Create mock agent that hits limit on second tool
mock_agent = Mock()
mock_agent.context_limit_reached = False
mock_agent.llm.__class__.__name__ = "MockLLM"
mock_agent.tool_executor.check_pause = Mock(return_value=None)
call_count = [0]
@@ -235,6 +237,8 @@ class TestLLMHandlerTokenTracking:
mock_agent = Mock()
mock_agent.context_limit_reached = False
mock_agent._check_context_limit = Mock(return_value=False)
mock_agent.llm.__class__.__name__ = "MockLLM"
mock_agent.tool_executor.check_pause = Mock(return_value=None)
mock_agent._execute_tool_action = Mock(
return_value=iter([{"type": "tool_call", "data": {}}])
)
@@ -300,7 +304,7 @@ class TestLLMHandlerTokenTracking:
def tool_handler_gen(*args):
yield {"type": "tool", "data": {}}
return []
return [], None
# Mock handle_tool_calls to return messages and set flag
with patch.object(

667
tests/test_continuation.py Normal file
View File

@@ -0,0 +1,667 @@
"""Tests for the continuation infrastructure (Phase 1).
Covers ContinuationService, ToolExecutor.check_pause, handler pause
signaling, BaseAgent.gen_continuation, and request validation.
"""
from unittest.mock import Mock
import pytest
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
# ---------------------------------------------------------------------------
# ContinuationService
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestContinuationService:
    """Round-trip tests for ContinuationService save/load/delete."""

    @staticmethod
    def _service():
        # Imported lazily so the mock_mongo_db fixture is in place before
        # the service module touches the database layer.
        from application.api.answer.services.continuation_service import (
            ContinuationService,
        )

        return ContinuationService()

    def test_save_and_load(self, mock_mongo_db):
        service = self._service()
        service.save_state(
            conversation_id="conv-1",
            user="alice",
            messages=[{"role": "user", "content": "hi"}],
            pending_tool_calls=[{"call_id": "c1", "pause_type": "awaiting_approval"}],
            tools_dict={"0": {"name": "test_tool"}},
            tool_schemas=[{"type": "function", "function": {"name": "act_0"}}],
            agent_config={"model_id": "gpt-4"},
        )
        loaded = service.load_state("conv-1", "alice")
        assert loaded is not None
        assert loaded["conversation_id"] == "conv-1"
        assert loaded["user"] == "alice"
        assert len(loaded["messages"]) == 1
        assert len(loaded["pending_tool_calls"]) == 1
        assert loaded["agent_config"]["model_id"] == "gpt-4"

    def test_load_returns_none_when_missing(self, mock_mongo_db):
        assert self._service().load_state("nonexistent", "alice") is None

    def test_delete_state(self, mock_mongo_db):
        service = self._service()
        service.save_state(
            conversation_id="conv-2",
            user="bob",
            messages=[],
            pending_tool_calls=[],
            tools_dict={},
            tool_schemas=[],
            agent_config={},
        )
        assert service.delete_state("conv-2", "bob") is True
        assert service.load_state("conv-2", "bob") is None

    def test_delete_nonexistent(self, mock_mongo_db):
        assert self._service().delete_state("nope", "nope") is False

    def test_upsert_replaces_existing(self, mock_mongo_db):
        # Saving twice for the same (conversation, user) pair must replace,
        # not duplicate, the stored state.
        service = self._service()
        for content, calls in (("v1", []), ("v2", [{"call_id": "c2"}])):
            service.save_state(
                conversation_id="conv-3",
                user="carol",
                messages=[{"role": "user", "content": content}],
                pending_tool_calls=calls,
                tools_dict={},
                tool_schemas=[],
                agent_config={},
            )
        loaded = service.load_state("conv-3", "carol")
        assert loaded["messages"][0]["content"] == "v2"
        assert len(loaded["pending_tool_calls"]) == 1
# ---------------------------------------------------------------------------
# ToolExecutor.check_pause
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCheckPause:
    """Behavior of ToolExecutor.check_pause across tool configurations."""

    @staticmethod
    def _make_call(name="action_0", call_id="c1", arguments="{}"):
        # Minimal stand-in for a parsed LLM tool call.
        call = Mock()
        call.name = name
        call.id = call_id
        call.arguments = arguments
        call.thought_signature = None
        return call

    def test_returns_none_for_normal_tool(self):
        tools_dict = {
            "0": {
                "name": "brave",
                "actions": [
                    {"name": "search", "active": True, "parameters": {}},
                ],
            }
        }
        outcome = ToolExecutor().check_pause(
            tools_dict, self._make_call(name="search_0"), "OpenAILLM"
        )
        assert outcome is None

    def test_returns_pause_for_client_side_tool(self):
        tools_dict = {
            "0": {
                "name": "get_weather",
                "client_side": True,
                "actions": [
                    {"name": "get_weather", "active": True, "parameters": {}},
                ],
            }
        }
        outcome = ToolExecutor().check_pause(
            tools_dict, self._make_call(name="get_weather_0"), "OpenAILLM"
        )
        assert outcome is not None
        assert outcome["pause_type"] == "requires_client_execution"
        assert outcome["call_id"] == "c1"
        assert outcome["tool_id"] == "0"

    def test_returns_pause_for_approval_required(self):
        send_action = {
            "name": "send_msg",
            "active": True,
            "require_approval": True,
            "parameters": {},
        }
        tools_dict = {"0": {"name": "telegram", "actions": [send_action]}}
        outcome = ToolExecutor().check_pause(
            tools_dict, self._make_call(name="send_msg_0"), "OpenAILLM"
        )
        assert outcome is not None
        assert outcome["pause_type"] == "awaiting_approval"

    def test_returns_none_when_parse_fails(self):
        # Unparseable call data should be treated as "no pause", not raise.
        call = self._make_call(name="bad_name_no_id", arguments="not json")
        assert ToolExecutor().check_pause({}, call, "OpenAILLM") is None

    def test_returns_none_when_tool_not_in_dict(self):
        call = self._make_call(name="action_99")
        outcome = ToolExecutor().check_pause({"0": {"name": "t"}}, call, "OpenAILLM")
        assert outcome is None

    def test_api_tool_approval(self):
        # API tools keep their actions under config["actions"] keyed by name.
        delete_action = {
            "name": "delete_user",
            "require_approval": True,
            "url": "http://example.com",
            "method": "DELETE",
            "active": True,
        }
        tools_dict = {
            "0": {
                "name": "api_tool",
                "config": {"actions": {"delete_user": delete_action}},
            }
        }
        outcome = ToolExecutor().check_pause(
            tools_dict, self._make_call(name="delete_user_0"), "OpenAILLM"
        )
        assert outcome is not None
        assert outcome["pause_type"] == "awaiting_approval"
# ---------------------------------------------------------------------------
# Handler pause signaling (handle_tool_calls returns pending_actions)
# ---------------------------------------------------------------------------
class ConcreteHandler(LLMHandler):
    """Minimal concrete handler for testing."""

    def parse_response(self, response):
        # Wrap anything into an LLMResponse with no tool calls.
        return LLMResponse(
            content=str(response),
            tool_calls=[],
            finish_reason="stop",
            raw_response=response,
        )

    def create_tool_message(self, tool_call, result):
        # Mirror the Google-style function_response message shape.
        function_response = {
            "name": tool_call.name,
            "response": {"result": result},
            "call_id": tool_call.id,
        }
        return {
            "role": "tool",
            "content": [{"function_response": function_response}],
        }

    def _iterate_stream(self, response):
        yield from response
@pytest.mark.unit
class TestHandlerPauseSignaling:
    """handle_tool_calls should report pending actions when a pause triggers."""

    @staticmethod
    def _drain(gen):
        # Exhaust the generator, collecting yielded events and the
        # (messages, pending) pair carried by StopIteration.value.
        events = []
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                return events, stop.value

    def _make_agent(self):
        agent = Mock()
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"
        agent.tool_executor.check_pause = Mock(return_value=None)

        def fake_execute(tools_dict, call):
            yield {"type": "tool_call", "data": {"status": "pending"}}
            return ("tool result", call.id)

        agent._execute_tool_action = Mock(side_effect=fake_execute)
        return agent

    def test_no_pause_returns_none_pending(self):
        agent = self._make_agent()
        call = ToolCall(id="c1", name="action_0", arguments="{}")
        gen = ConcreteHandler().handle_tool_calls(
            agent, [call], {"0": {"name": "t"}}, []
        )
        _, (messages, pending) = self._drain(gen)
        assert pending is None
        assert messages is not None

    def test_pause_returns_pending_actions(self):
        agent = self._make_agent()
        agent.tool_executor.check_pause = Mock(
            return_value={
                "call_id": "c1",
                "name": "send_msg_0",
                "tool_name": "telegram",
                "tool_id": "0",
                "action_name": "send_msg",
                "arguments": {"text": "hello"},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }
        )
        call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}')
        gen = ConcreteHandler().handle_tool_calls(
            agent, [call], {"0": {"name": "telegram"}}, []
        )
        events, (_, pending) = self._drain(gen)
        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["pause_type"] == "awaiting_approval"
        # Exactly one tool_call event should carry the awaiting_approval status.
        pause_events = [
            e
            for e in events
            if e.get("type") == "tool_call"
            and e.get("data", {}).get("status") == "awaiting_approval"
        ]
        assert len(pause_events) == 1

    def test_mixed_execute_and_pause(self):
        """One tool executes, another needs approval."""
        agent = self._make_agent()
        seen = {"n": 0}

        def selective_pause(tools_dict, call, llm_class):
            # Pause only the second call; let the first run normally.
            seen["n"] += 1
            if seen["n"] != 2:
                return None
            return {
                "call_id": "c2",
                "name": "danger_0",
                "tool_name": "danger",
                "tool_id": "0",
                "action_name": "danger",
                "arguments": {},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }

        agent.tool_executor.check_pause = Mock(side_effect=selective_pause)
        calls = [
            ToolCall(id="c1", name="safe_0", arguments="{}"),
            ToolCall(id="c2", name="danger_0", arguments="{}"),
        ]
        gen = ConcreteHandler().handle_tool_calls(
            agent, calls, {"0": {"name": "multi"}}, []
        )
        _, (_, pending) = self._drain(gen)
        # First tool was executed normally
        assert agent._execute_tool_action.call_count == 1
        # Second tool is pending
        assert pending is not None
        assert len(pending) == 1
        assert pending[0]["call_id"] == "c2"
# ---------------------------------------------------------------------------
# handle_streaming yields tool_calls_pending
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestStreamingPause:
    """handle_streaming should surface a tool_calls_pending event on pause."""

    def test_streaming_yields_tool_calls_pending(self):
        """A paused tool call produces tool_calls_pending and records state."""
        handler = ConcreteHandler()
        agent = Mock()
        agent.llm = Mock()
        agent.model_id = "test"
        agent.tools = []
        agent._check_context_limit = Mock(return_value=False)
        agent.context_limit_reached = False
        agent.llm.__class__.__name__ = "MockLLM"
        pause_info = {
            "call_id": "c1",
            "name": "fn_0",
            "tool_name": "test",
            "tool_id": "0",
            "action_name": "fn",
            "arguments": {},
            "pause_type": "awaiting_approval",
            "thought_signature": None,
        }
        agent.tool_executor.check_pause = Mock(return_value=pause_info)
        chunk = LLMResponse(
            content="",
            tool_calls=[ToolCall(id="c1", name="fn_0", arguments="{}", index=0)],
            finish_reason="tool_calls",
            raw_response={},
        )
        # Pass chunks through untouched so the fake stream drives the handler.
        handler.parse_response = lambda c: c
        def fake_iterate(response):
            yield from response
        handler._iterate_stream = fake_iterate
        gen = handler.handle_streaming(agent, [chunk], {"0": {"name": "t"}}, [])
        events = list(gen)
        # Should contain a tool_calls_pending event
        pending_events = [
            e for e in events
            if isinstance(e, dict) and e.get("type") == "tool_calls_pending"
        ]
        assert len(pending_events) == 1
        assert len(pending_events[0]["data"]["pending_tool_calls"]) == 1
        # Agent should have _pending_continuation set to real data.
        # NOTE: hasattr() is vacuous on a Mock — attribute access auto-creates
        # a child Mock, so hasattr always returns True. Assert instead that
        # the handler actually assigned a non-Mock value.
        assert not isinstance(agent._pending_continuation, Mock)
# ---------------------------------------------------------------------------
# BaseAgent.gen_continuation
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGenContinuation:
    """Tests for BaseAgent.gen_continuation resuming after a pause.

    Each test wires a ClassicAgent with mocked LLM, handler, and tool
    executor, then resumes with a different client decision: approved,
    denied, or a client-supplied tool result.
    """
    def test_approved_tool_executes(self):
        """When a tool action is approved, the tool is executed."""
        from application.agents.classic_agent import ClassicAgent
        # LLM mock: claims tool support and streams one final answer.
        mock_llm = Mock()
        mock_llm._supports_tools = True
        mock_llm.gen_stream = Mock(return_value=iter(["Final answer"]))
        mock_llm._supports_structured_output = Mock(return_value=False)
        mock_llm.__class__.__name__ = "MockLLM"
        # Handler mock: yields nothing downstream; builds tool messages.
        mock_handler = Mock()
        mock_handler.process_message_flow = Mock(return_value=iter([]))
        mock_handler.create_tool_message = Mock(
            return_value={"role": "tool", "content": [{"function_response": {
                "name": "act_0", "response": {"result": "done"}, "call_id": "c1"
            }}]}
        )
        # Executor mock: execute is a generator that yields a status event
        # and returns (result, call_id) — mirroring the real executor shape.
        mock_executor = Mock()
        mock_executor.tool_calls = []
        mock_executor.prepare_tools_for_llm = Mock(return_value=[])
        mock_executor.get_truncated_tool_calls = Mock(return_value=[])
        def fake_execute(tools_dict, call, llm_class):
            yield {"type": "tool_call", "data": {"status": "pending"}}
            return ("result_data", "c1")
        mock_executor.execute = Mock(side_effect=fake_execute)
        agent = ClassicAgent(
            endpoint="stream",
            llm_name="openai",
            model_id="gpt-4",
            api_key="test",
            llm=mock_llm,
            llm_handler=mock_handler,
            tool_executor=mock_executor,
        )
        messages = [{"role": "system", "content": "You are helpful."}]
        tools_dict = {"0": {"name": "test_tool"}}
        # One paused call awaiting approval at the pause point.
        pending = [
            {
                "call_id": "c1",
                "name": "act_0",
                "tool_name": "test_tool",
                "tool_id": "0",
                "action_name": "act",
                "arguments": {"q": "test"},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }
        ]
        tool_actions = [{"call_id": "c1", "decision": "approved"}]
        list(agent.gen_continuation(messages, tools_dict, pending, tool_actions))
        # Tool should have been executed
        assert mock_executor.execute.called
    def test_denied_tool_sends_denial(self):
        """When a tool action is denied, a denial message is added."""
        from application.agents.classic_agent import ClassicAgent
        mock_llm = Mock()
        mock_llm._supports_tools = True
        mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
        mock_llm._supports_structured_output = Mock(return_value=False)
        mock_llm.__class__.__name__ = "MockLLM"
        mock_handler = Mock()
        mock_handler.process_message_flow = Mock(return_value=iter([]))
        mock_handler.create_tool_message = Mock(
            return_value={"role": "tool", "content": "denied"}
        )
        mock_executor = Mock()
        mock_executor.tool_calls = []
        mock_executor.prepare_tools_for_llm = Mock(return_value=[])
        mock_executor.get_truncated_tool_calls = Mock(return_value=[])
        agent = ClassicAgent(
            endpoint="stream",
            llm_name="openai",
            model_id="gpt-4",
            api_key="test",
            llm=mock_llm,
            llm_handler=mock_handler,
            tool_executor=mock_executor,
        )
        messages = [{"role": "system", "content": "test"}]
        pending = [
            {
                "call_id": "c1",
                "name": "danger_0",
                "tool_name": "danger",
                "tool_id": "0",
                "action_name": "danger",
                "arguments": {},
                "pause_type": "awaiting_approval",
                "thought_signature": None,
            }
        ]
        # Denial carries a client comment that must reach the LLM message.
        tool_actions = [
            {"call_id": "c1", "decision": "denied", "comment": "too risky"}
        ]
        events = list(
            agent.gen_continuation(messages, {"0": {"name": "danger"}}, pending, tool_actions)
        )
        # Should have a denied tool_call event
        denied = [
            e for e in events
            if isinstance(e, dict)
            and e.get("type") == "tool_call"
            and e.get("data", {}).get("status") == "denied"
        ]
        assert len(denied) == 1
        # create_tool_message should have been called with denial text
        denial_arg = mock_handler.create_tool_message.call_args[0][1]
        assert "denied" in denial_arg.lower()
        assert "too risky" in denial_arg
    def test_client_result_appended(self):
        """Client-provided tool result is added to messages."""
        from application.agents.classic_agent import ClassicAgent
        mock_llm = Mock()
        mock_llm._supports_tools = True
        mock_llm.gen_stream = Mock(return_value=iter(["Done"]))
        mock_llm._supports_structured_output = Mock(return_value=False)
        mock_llm.__class__.__name__ = "MockLLM"
        mock_handler = Mock()
        mock_handler.process_message_flow = Mock(return_value=iter([]))
        mock_handler.create_tool_message = Mock(
            return_value={"role": "tool", "content": "client result"}
        )
        mock_executor = Mock()
        mock_executor.tool_calls = []
        mock_executor.prepare_tools_for_llm = Mock(return_value=[])
        mock_executor.get_truncated_tool_calls = Mock(return_value=[])
        agent = ClassicAgent(
            endpoint="stream",
            llm_name="openai",
            model_id="gpt-4",
            api_key="test",
            llm=mock_llm,
            llm_handler=mock_handler,
            tool_executor=mock_executor,
        )
        messages = [{"role": "system", "content": "test"}]
        # A client-side tool: the server never executes it; the client
        # returns the result via tool_actions instead.
        pending = [
            {
                "call_id": "c1",
                "name": "weather_0",
                "tool_name": "weather",
                "tool_id": "0",
                "action_name": "weather",
                "arguments": {"city": "SF"},
                "pause_type": "requires_client_execution",
                "thought_signature": None,
            }
        ]
        tool_actions = [{"call_id": "c1", "result": {"temp": "72F"}}]
        events = list(
            agent.gen_continuation(messages, {"0": {"name": "weather"}}, pending, tool_actions)
        )
        # create_tool_message was called with the client result
        result_arg = mock_handler.create_tool_message.call_args[0][1]
        assert "72F" in result_arg
        # Should have a completed tool_call event
        completed = [
            e for e in events
            if isinstance(e, dict)
            and e.get("type") == "tool_call"
            and e.get("data", {}).get("status") == "completed"
        ]
        assert len(completed) == 1
# ---------------------------------------------------------------------------
# validate_request
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestValidateRequest:
    """validate_request must accept continuation payloads without a question."""

    @pytest.fixture(autouse=True)
    def _app_context(self):
        # validate_request relies on a Flask request/app context.
        from flask import Flask

        app = Flask(__name__)
        with app.app_context():
            yield

    @staticmethod
    def _resource():
        from application.api.answer.routes.base import BaseAnswerResource

        return BaseAnswerResource()

    def test_continuation_request_without_question(self, mock_mongo_db):
        payload = {
            "conversation_id": "conv-1",
            "tool_actions": [{"call_id": "c1", "decision": "approved"}],
        }
        # None means the payload validated successfully.
        assert self._resource().validate_request(payload) is None

    def test_continuation_request_missing_conversation_id(self, mock_mongo_db):
        payload = {
            "tool_actions": [{"call_id": "c1", "decision": "approved"}],
        }
        # Error response expected: conversation_id is mandatory here.
        assert self._resource().validate_request(payload) is not None

    def test_normal_request_still_requires_question(self, mock_mongo_db):
        payload = {"conversation_id": "conv-1"}
        # Error response expected: non-continuation requests need a question.
        assert self._resource().validate_request(payload) is not None