mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
feat: continuation messages
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generator, List, Optional
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.core.json_schema_utils import (
|
||||
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.base import ToolCall
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
@@ -113,6 +115,140 @@ class BaseAgent(ABC):
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
def gen_continuation(
    self,
    messages: List[Dict],
    tools_dict: Dict,
    pending_tool_calls: List[Dict],
    tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
    """Resume generation after tool actions are resolved.

    Processes the client-provided *tool_actions* (approvals, denials,
    or client-side results), appends the resulting messages, then
    hands back to the LLM to continue the conversation.

    Args:
        messages: The saved messages array from the pause point.
        tools_dict: The saved tools dictionary.
        pending_tool_calls: The pending tool call descriptors from the pause.
        tool_actions: Client-provided actions resolving the pending calls.

    Yields:
        Stream event dicts (tool_call status events, LLM response events,
        then a final ``{"sources": ...}`` and ``{"tool_calls": ...}`` pair).
    """
    self._prepare_tools(tools_dict)

    # Index the client's responses by call_id for O(1) lookup per pending call.
    actions_by_id = {a["call_id"]: a for a in tool_actions}

    for pending in pending_tool_calls:
        call_id = pending["call_id"]
        action = actions_by_id.get(call_id)
        if not action:
            # A pending call the client did not answer is treated as denied,
            # so the LLM still receives a tool result for every call it made.
            action = {
                "call_id": call_id,
                "decision": "denied",
                "comment": "No response provided",
            }

        # Build the assistant tool-call message. Pending calls were NOT
        # added to messages at pause time, so the call/result pair is
        # appended together here to keep the history well-formed.
        args = pending["arguments"]
        function_call_content: Dict[str, Any] = {
            "function_call": {
                "name": pending["name"],
                "args": args,
                "call_id": call_id,
            }
        }
        if pending.get("thought_signature"):
            # Carried through for providers (e.g. Gemini) that require the
            # thought signature to be echoed back with the function call.
            function_call_content["thought_signature"] = pending[
                "thought_signature"
            ]
        messages.append(
            {"role": "assistant", "content": [function_call_content]}
        )

        if action.get("decision") == "approved":
            # Execute the tool server-side.
            tc = ToolCall(
                id=call_id,
                name=pending["name"],
                arguments=(
                    json.dumps(args) if isinstance(args, dict) else args
                ),
            )
            tool_gen = self._execute_tool_action(tools_dict, tc)
            tool_response = None
            # Drain the executor generator, forwarding its status events;
            # the (result, call_id) pair arrives via StopIteration.value.
            while True:
                try:
                    event = next(tool_gen)
                    yield event
                except StopIteration as e:
                    tool_response, _ = e.value
                    break
            messages.append(
                self.llm_handler.create_tool_message(tc, tool_response)
            )

        elif action.get("decision") == "denied":
            comment = action.get("comment", "")
            denial = (
                f"Tool execution denied by user. Reason: {comment}"
                if comment
                else "Tool execution denied by user."
            )
            # NOTE(review): unlike the approved branch, `args` is passed to
            # ToolCall.arguments without json.dumps here — confirm ToolCall
            # accepts both dict and str arguments.
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(tc, denial)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": f"{pending['action_name']}_{pending['tool_id']}",
                    "arguments": args,
                    "status": "denied",
                },
            }

        elif "result" in action:
            # Client-side execution: the client ran the tool and sent back
            # its result; record it as the tool message verbatim.
            result = action["result"]
            result_str = (
                json.dumps(result)
                if not isinstance(result, str)
                else result
            )
            tc = ToolCall(
                id=call_id, name=pending["name"], arguments=args
            )
            messages.append(
                self.llm_handler.create_tool_message(tc, result_str)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": f"{pending['action_name']}_{pending['tool_id']}",
                    "arguments": args,
                    # Preview only — full result already stored in messages.
                    "result": (
                        result_str[:50] + "..."
                        if len(result_str) > 50
                        else result_str
                    ),
                    "status": "completed",
                },
            }

    # Resume the LLM loop with the updated messages
    llm_response = self._llm_gen(messages)
    yield from self._handle_response(
        llm_response, tools_dict, messages, None
    )

    yield {"sources": self.retrieved_docs}
    yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||
|
||||
@property
|
||||
|
||||
@@ -104,6 +104,60 @@ class ToolExecutor:
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def check_pause(
    self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
    """Check if a tool call requires pausing for approval or client execution.

    Args:
        tools_dict: Mapping of tool_id -> tool configuration.
        call: The provider-specific tool-call object to inspect.
        llm_class_name: LLM class name, used to pick the right parser.

    Returns:
        A dict describing the pending action if a pause is needed,
        ``None`` otherwise (including for unparseable/unknown calls,
        which are handled as errors by ``execute()``).
    """
    parser = ToolActionParser(llm_class_name)
    tool_id, action_name, call_args = parser.parse_args(call)
    # Providers may omit an id; synthesize one so resume can match actions.
    call_id = getattr(call, "id", None) or str(uuid.uuid4())

    if tool_id is None or action_name is None or tool_id not in tools_dict:
        return None  # Will be handled as error by execute()

    tool_data = tools_dict[tool_id]

    def _pending(pause_type: str) -> Dict:
        # Both pause kinds share the same descriptor shape; only the
        # pause_type discriminator differs.
        return {
            "call_id": call_id,
            "name": getattr(call, "name", f"{action_name}_{tool_id}"),
            "tool_name": tool_data.get("name", "unknown"),
            "tool_id": tool_id,
            "action_name": action_name,
            "arguments": call_args if isinstance(call_args, dict) else {},
            "pause_type": pause_type,
            "thought_signature": getattr(call, "thought_signature", None),
        }

    # Phase 2: client-side tools always pause for client execution.
    if tool_data.get("client_side"):
        return _pending("requires_client_execution")

    # Phase 3: approval required. api_tool stores actions keyed by name in
    # its config; other tools store a list of action dicts.
    if tool_data["name"] == "api_tool":
        action_data = tool_data.get("config", {}).get("actions", {}).get(
            action_name, {}
        )
    else:
        action_data = next(
            (a for a in tool_data.get("actions", []) if a["name"] == action_name),
            {},
        )

    if action_data.get("require_approval"):
        return _pending("awaiting_approval")

    return None
|
||||
|
||||
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||
parser = ToolActionParser(llm_class_name)
|
||||
|
||||
@@ -74,27 +74,56 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
stream = self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if len(stream_result) == 7:
|
||||
@@ -105,16 +134,17 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
tool_calls,
|
||||
thought,
|
||||
error,
|
||||
structured_info,
|
||||
extra_info,
|
||||
) = stream_result
|
||||
else:
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
stream_result
|
||||
)
|
||||
structured_info = None
|
||||
extra_info = None
|
||||
|
||||
if error:
|
||||
return make_response({"error": error}, 400)
|
||||
|
||||
result = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response,
|
||||
@@ -123,8 +153,8 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"thought": thought,
|
||||
}
|
||||
|
||||
if structured_info:
|
||||
result.update(structured_info)
|
||||
if extra_info:
|
||||
result.update(extra_info)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
@@ -39,7 +40,16 @@ class BaseAnswerResource:
|
||||
def validate_request(
|
||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||
) -> Optional[Response]:
|
||||
"""Common request validation"""
|
||||
"""Common request validation.
|
||||
|
||||
Continuation requests (``tool_actions`` present) require
|
||||
``conversation_id`` but not ``question``.
|
||||
"""
|
||||
if data.get("tool_actions"):
|
||||
# Continuation mode — question is not required
|
||||
if missing := check_required_fields(data, ["conversation_id"]):
|
||||
return missing
|
||||
return None
|
||||
required_fields = ["question"]
|
||||
if require_conversation_id:
|
||||
required_fields.append("conversation_id")
|
||||
@@ -177,6 +187,7 @@ class BaseAnswerResource:
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
_continuation: Optional[Dict] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -207,8 +218,19 @@ class BaseAnswerResource:
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
|
||||
for line in agent.gen(query=question):
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
messages=_continuation["messages"],
|
||||
tools_dict=_continuation["tools_dict"],
|
||||
pending_tool_calls=_continuation["pending_tool_calls"],
|
||||
tool_actions=_continuation["tool_actions"],
|
||||
)
|
||||
else:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
@@ -244,15 +266,21 @@ class BaseAnswerResource:
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "error":
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield f"data: {data}\n\n"
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -262,6 +290,46 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation and conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
@@ -435,6 +503,7 @@ class BaseAnswerResource:
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
pending_tool_calls = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
@@ -453,6 +522,10 @@ class BaseAnswerResource:
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "tool_calls_pending":
|
||||
pending_tool_calls = event.get("data", {}).get(
|
||||
"pending_tool_calls", []
|
||||
)
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
@@ -466,6 +539,18 @@ class BaseAnswerResource:
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
|
||||
if pending_tool_calls is not None:
|
||||
return (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
thought,
|
||||
None,
|
||||
{"pending_tool_calls": pending_tool_calls},
|
||||
)
|
||||
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
|
||||
@@ -79,7 +79,39 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data["question"])
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
|
||||
141
application/api/answer/services/continuation_service.py
Normal file
141
application/api/answer/services/continuation_service.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Service for saving and restoring tool-call continuation state.
|
||||
|
||||
When a stream pauses (tool needs approval or client-side execution),
|
||||
the full execution state is persisted to MongoDB so the client can
|
||||
resume later by sending tool_actions.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
    """Recursively coerce *obj* into JSON-friendly values.

    ObjectIds become strings, dict keys are stringified and values
    converted recursively, lists are converted element-wise, and raw
    bytes are decoded as UTF-8 (undecodable sequences replaced).
    Everything else passes through unchanged.
    """
    if isinstance(obj, ObjectId):
        return str(obj)
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[str(key)] = _make_serializable(value)
        return converted
    if isinstance(obj, list):
        return list(map(_make_serializable, obj))
    if isinstance(obj, bytes):
        return obj.decode("utf-8", errors="replace")
    return obj
|
||||
|
||||
|
||||
class ContinuationService:
    """Manages pending tool-call state in MongoDB."""

    def __init__(self):
        client = MongoDB.get_client()
        database = client[settings.MONGO_DB_NAME]
        self.collection = database["pending_tool_state"]
        self._ensure_indexes()

    def _ensure_indexes(self):
        # Best-effort: a TTL index for auto-cleanup plus a uniqueness
        # constraint of one pending state per (conversation, user).
        try:
            self.collection.create_index(
                "expires_at", expireAfterSeconds=0
            )
            self.collection.create_index(
                [("conversation_id", 1), ("user", 1)], unique=True
            )
        except Exception:
            # Indexes may already exist or mongomock doesn't support TTL
            pass

    def save_state(
        self,
        conversation_id: str,
        user: str,
        messages: List[Dict],
        pending_tool_calls: List[Dict],
        tools_dict: Dict,
        tool_schemas: List[Dict],
        agent_config: Dict,
        client_tools: Optional[List[Dict]] = None,
    ) -> str:
        """Save execution state for later continuation.

        Args:
            conversation_id: The conversation this state belongs to.
            user: Owner user ID.
            messages: Full messages array at the pause point.
            pending_tool_calls: Tool calls awaiting client action.
            tools_dict: Serializable tools configuration dict.
            tool_schemas: LLM-formatted tool schemas (agent.tools).
            agent_config: Config needed to recreate the agent on resume.
            client_tools: Client-provided tool schemas (Phase 2).

        Returns:
            The string ID of the saved state document.
        """
        created = datetime.datetime.now(datetime.timezone.utc)
        doc = {
            "conversation_id": conversation_id,
            "user": user,
            "messages": _make_serializable(messages),
            "pending_tool_calls": _make_serializable(pending_tool_calls),
            "tools_dict": _make_serializable(tools_dict),
            "tool_schemas": _make_serializable(tool_schemas),
            "agent_config": _make_serializable(agent_config),
            "client_tools": _make_serializable(client_tools) if client_tools else None,
            "created_at": created,
            # TTL index on expires_at reclaims abandoned states.
            "expires_at": created
            + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS),
        }

        # Upsert — only one pending state per conversation per user
        outcome = self.collection.replace_one(
            {"conversation_id": conversation_id, "user": user},
            doc,
            upsert=True,
        )
        # Replacing an existing doc yields no upserted_id; fall back to the
        # conversation id as the stable identifier.
        state_id = str(outcome.upserted_id) if outcome.upserted_id else conversation_id
        logger.info(
            f"Saved continuation state for conversation {conversation_id} "
            f"with {len(pending_tool_calls)} pending tool call(s)"
        )
        return state_id

    def load_state(
        self, conversation_id: str, user: str
    ) -> Optional[Dict[str, Any]]:
        """Load pending continuation state.

        Returns:
            The state dict, or None if no pending state exists.
        """
        state = self.collection.find_one(
            {"conversation_id": conversation_id, "user": user}
        )
        if not state:
            return None
        # Stringify _id so the state is JSON-serializable downstream.
        state["_id"] = str(state["_id"])
        return state

    def delete_state(self, conversation_id: str, user: str) -> bool:
        """Delete pending state after successful resumption.

        Returns:
            True if a document was deleted.
        """
        outcome = self.collection.delete_one(
            {"conversation_id": conversation_id, "user": user}
        )
        if outcome.deleted_count:
            logger.info(
                f"Deleted continuation state for conversation {conversation_id}"
            )
        return outcome.deleted_count > 0
|
||||
@@ -771,6 +771,115 @@ class StreamProcessor:
|
||||
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||
return None
|
||||
|
||||
def resume_from_tool_actions(
    self,
    tool_actions: list,
    conversation_id: str,
):
    """Resume a paused agent from saved continuation state.

    Loads the pending state from MongoDB, recreates the agent with
    the saved configuration, and returns an agent ready to call
    ``gen_continuation()``.

    Args:
        tool_actions: Client-provided actions (approvals / results).
        conversation_id: The conversation being resumed.

    Returns:
        Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).

    Raises:
        ValueError: If no pending tool state exists for this conversation.
    """
    # Local imports avoid circular dependencies between the answer
    # services and the agent/LLM modules.
    from application.api.answer.services.continuation_service import (
        ContinuationService,
    )
    from application.agents.agent_creator import AgentCreator
    from application.agents.tool_executor import ToolExecutor
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.llm.llm_creator import LLMCreator

    cont_service = ContinuationService()
    state = cont_service.load_state(conversation_id, self.initial_user_id)
    if not state:
        raise ValueError("No pending tool state found for this conversation")

    messages = state["messages"]
    pending_tool_calls = state["pending_tool_calls"]
    tools_dict = state["tools_dict"]
    tool_schemas = state.get("tool_schemas", [])
    agent_config = state["agent_config"]

    model_id = agent_config.get("model_id")
    llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
    api_key = agent_config.get("api_key")
    user_api_key = agent_config.get("user_api_key")
    agent_id = agent_config.get("agent_id")
    prompt = agent_config.get("prompt", "")
    json_schema = agent_config.get("json_schema")
    retriever_config = agent_config.get("retriever_config")

    # Recreate dependencies
    system_api_key = api_key or get_api_key_for_provider(llm_name)
    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        user_api_key=user_api_key,
        decoded_token=self.decoded_token,
        model_id=model_id,
        agent_id=agent_id,
    )
    llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=self.initial_user_id,
        decoded_token=self.decoded_token,
    )
    tool_executor.conversation_id = conversation_id

    agent_type = agent_config.get("agent_type", "ClassicAgent")
    # Map class names back to agent creator keys
    type_map = {
        "ClassicAgent": "classic",
        "AgenticAgent": "agentic",
        "ResearchAgent": "research",
        "WorkflowAgent": "workflow",
    }
    agent_key = type_map.get(agent_type, "classic")

    # chat_history is intentionally empty: the full history lives in the
    # saved `messages` array passed to gen_continuation().
    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": model_id,
        "api_key": system_api_key,
        "agent_id": agent_id,
        "user_api_key": user_api_key,
        "prompt": prompt,
        "chat_history": [],
        "decoded_token": self.decoded_token,
        "json_schema": json_schema,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }

    # Only these agent types accept a retriever_config kwarg.
    if agent_key in ("agentic", "research") and retriever_config:
        agent_kwargs["retriever_config"] = retriever_config

    agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
    agent.conversation_id = conversation_id
    agent.initial_user_id = self.initial_user_id
    # Restore the LLM-formatted tool schemas so the continued loop can
    # keep offering the same tools to the model.
    agent.tools = tool_schemas

    # Store config for the route layer
    self.model_id = model_id
    self.agent_id = agent_id
    self.agent_config["user_api_key"] = user_api_key
    self.conversation_id = conversation_id

    # Delete state so it can't be replayed
    cont_service.delete_state(conversation_id, self.initial_user_id)

    return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
docs_together: Optional[str] = None,
|
||||
|
||||
@@ -648,6 +648,13 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
Execute tool calls and update conversation history.
|
||||
|
||||
When a tool requires approval or client-side execution, it is
|
||||
collected as a pending action instead of being executed. The
|
||||
generator returns ``(updated_messages, pending_actions)`` where
|
||||
*pending_actions* is ``None`` when every tool was executed
|
||||
normally, or a list of dicts describing actions the client must
|
||||
resolve before the LLM loop can continue.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
tool_calls: List of tool calls to execute
|
||||
@@ -655,9 +662,11 @@ class LLMHandler(ABC):
|
||||
messages: Current conversation history
|
||||
|
||||
Returns:
|
||||
Updated messages list
|
||||
Tuple of (updated_messages, pending_actions).
|
||||
pending_actions is None if all tools executed, otherwise a list.
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
pending_actions: List[Dict] = []
|
||||
|
||||
for i, call in enumerate(tool_calls):
|
||||
# Check context limit before executing tool call
|
||||
@@ -763,6 +772,29 @@ class LLMHandler(ABC):
|
||||
# Set flag on agent
|
||||
agent.context_limit_reached = True
|
||||
break
|
||||
|
||||
# ---- Pause check: approval / client-side execution ----
|
||||
llm_class = agent.llm.__class__.__name__
|
||||
pause_info = agent.tool_executor.check_pause(
|
||||
tools_dict, call, llm_class
|
||||
)
|
||||
if pause_info:
|
||||
# Yield pause event so the client knows this tool is waiting
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pause_info["tool_name"],
|
||||
"call_id": pause_info["call_id"],
|
||||
"action_name": f"{pause_info['action_name']}_{pause_info['tool_id']}",
|
||||
"arguments": pause_info["arguments"],
|
||||
"status": pause_info["pause_type"],
|
||||
},
|
||||
}
|
||||
pending_actions.append(pause_info)
|
||||
# Do NOT add messages for pending tools here.
|
||||
# They will be added on resume to keep call/result pairs together.
|
||||
continue
|
||||
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
@@ -772,7 +804,7 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
@@ -823,7 +855,7 @@ class LLMHandler(ABC):
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
return updated_messages
|
||||
return updated_messages, pending_actions if pending_actions else None
|
||||
|
||||
def handle_non_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
@@ -851,8 +883,22 @@ class LLMHandler(ABC):
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, pending_actions = e.value
|
||||
break
|
||||
|
||||
# If tools need approval or client execution, pause the loop
|
||||
if pending_actions:
|
||||
agent._pending_continuation = {
|
||||
"messages": messages,
|
||||
"pending_tool_calls": pending_actions,
|
||||
"tools_dict": tools_dict,
|
||||
}
|
||||
yield {
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": pending_actions},
|
||||
}
|
||||
return ""
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
)
|
||||
@@ -913,10 +959,23 @@ class LLMHandler(ABC):
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, pending_actions = e.value
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
# If tools need approval or client execution, pause the loop
|
||||
if pending_actions:
|
||||
agent._pending_continuation = {
|
||||
"messages": messages,
|
||||
"pending_tool_calls": pending_actions,
|
||||
"tools_dict": tools_dict,
|
||||
}
|
||||
yield {
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": pending_actions},
|
||||
}
|
||||
return
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
|
||||
@@ -621,6 +621,7 @@ class TestHandleToolCalls:
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
def fake_execute(tools_dict, call):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
@@ -641,7 +642,7 @@ class TestHandleToolCalls:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, _pending = e.value
|
||||
|
||||
assert any(e.get("type") == "tool_call" for e in events)
|
||||
assert len(messages) >= 2 # function_call + tool_message
|
||||
@@ -675,6 +676,8 @@ class TestHandleToolCalls:
|
||||
agent = Mock()
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error"))
|
||||
|
||||
call = ToolCall(id="c1", name="action_1", arguments="{}")
|
||||
@@ -704,7 +707,7 @@ class TestHandleToolCalls:
|
||||
while True:
|
||||
next(gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
messages, _pending = e.value
|
||||
|
||||
assistant_msgs = [
|
||||
m for m in messages
|
||||
@@ -751,6 +754,7 @@ class TestHandleNonStreaming:
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
# First response requires tool call, second is final
|
||||
call_count = {"n": 0}
|
||||
@@ -856,6 +860,7 @@ class TestHandleStreaming:
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
# First chunk has partial tool call, second completes it
|
||||
chunk1 = LLMResponse(
|
||||
@@ -907,6 +912,7 @@ class TestHandleStreaming:
|
||||
agent.context_limit_reached = True
|
||||
agent._check_context_limit = Mock(return_value=True)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
# Chunk finishes with tool_calls
|
||||
chunk = LLMResponse(
|
||||
@@ -929,7 +935,7 @@ class TestHandleStreaming:
|
||||
def fake_handle_tool_calls(agent, calls, tools_dict, messages):
|
||||
agent.context_limit_reached = True
|
||||
yield {"type": "tool_call", "data": {"status": "skipped"}}
|
||||
return messages
|
||||
return messages, None
|
||||
|
||||
handler.handle_tool_calls = fake_handle_tool_calls
|
||||
|
||||
@@ -1501,6 +1507,7 @@ class TestHandleToolCallsCompressionSuccess:
|
||||
agent._check_context_limit = Mock(side_effect=check_limit)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
def fake_execute(tools_dict, call):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
@@ -1538,6 +1545,7 @@ class TestHandleToolCallsCompressionSuccess:
|
||||
agent = Mock()
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
exec_count = {"n": 0}
|
||||
|
||||
|
||||
@@ -478,6 +478,8 @@ class TestHandleToolCallsErrors:
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock()
|
||||
agent._check_context_limit = MagicMock(return_value=False)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=RuntimeError("tool failed")
|
||||
)
|
||||
@@ -506,6 +508,8 @@ class TestHandleToolCallsErrors:
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock()
|
||||
agent._check_context_limit = MagicMock(return_value=False)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=RuntimeError("tool failed")
|
||||
)
|
||||
@@ -1169,6 +1173,8 @@ class TestHandleToolCallsErrorsAdditional:
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock()
|
||||
agent._check_context_limit = MagicMock(return_value=False)
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=RuntimeError("broken tool")
|
||||
)
|
||||
@@ -1188,7 +1194,7 @@ class TestHandleToolCallsErrorsAdditional:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
final_messages = e.value
|
||||
final_messages, _pending = e.value
|
||||
|
||||
# Verify the error message was appended
|
||||
error_msgs = [
|
||||
@@ -1211,6 +1217,10 @@ class TestHandleToolCallsErrorsAdditional:
|
||||
"""Cover line 660: messages.copy() at start of handle_tool_calls."""
|
||||
handler = ConcreteHandler()
|
||||
agent = MagicMock(spec=[]) # No _check_context_limit attribute
|
||||
agent.llm = MagicMock()
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor = MagicMock()
|
||||
agent.tool_executor.check_pause = MagicMock(return_value=None)
|
||||
agent._execute_tool_action = MagicMock(
|
||||
side_effect=ValueError("bad args")
|
||||
)
|
||||
|
||||
@@ -176,6 +176,8 @@ class TestLLMHandlerTokenTracking:
|
||||
# Create mock agent that hits limit on second tool
|
||||
mock_agent = Mock()
|
||||
mock_agent.context_limit_reached = False
|
||||
mock_agent.llm.__class__.__name__ = "MockLLM"
|
||||
mock_agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
@@ -235,6 +237,8 @@ class TestLLMHandlerTokenTracking:
|
||||
mock_agent = Mock()
|
||||
mock_agent.context_limit_reached = False
|
||||
mock_agent._check_context_limit = Mock(return_value=False)
|
||||
mock_agent.llm.__class__.__name__ = "MockLLM"
|
||||
mock_agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
mock_agent._execute_tool_action = Mock(
|
||||
return_value=iter([{"type": "tool_call", "data": {}}])
|
||||
)
|
||||
@@ -300,7 +304,7 @@ class TestLLMHandlerTokenTracking:
|
||||
|
||||
def tool_handler_gen(*args):
|
||||
yield {"type": "tool", "data": {}}
|
||||
return []
|
||||
return [], None
|
||||
|
||||
# Mock handle_tool_calls to return messages and set flag
|
||||
with patch.object(
|
||||
|
||||
667
tests/test_continuation.py
Normal file
667
tests/test_continuation.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""Tests for the continuation infrastructure (Phase 1).
|
||||
|
||||
Covers ContinuationService, ToolExecutor.check_pause, handler pause
|
||||
signaling, BaseAgent.gen_continuation, and request validation.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContinuationService
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContinuationService:
|
||||
|
||||
def test_save_and_load(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
svc = ContinuationService()
|
||||
svc.save_state(
|
||||
conversation_id="conv-1",
|
||||
user="alice",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
pending_tool_calls=[{"call_id": "c1", "pause_type": "awaiting_approval"}],
|
||||
tools_dict={"0": {"name": "test_tool"}},
|
||||
tool_schemas=[{"type": "function", "function": {"name": "act_0"}}],
|
||||
agent_config={"model_id": "gpt-4"},
|
||||
)
|
||||
|
||||
state = svc.load_state("conv-1", "alice")
|
||||
assert state is not None
|
||||
assert state["conversation_id"] == "conv-1"
|
||||
assert state["user"] == "alice"
|
||||
assert len(state["messages"]) == 1
|
||||
assert len(state["pending_tool_calls"]) == 1
|
||||
assert state["agent_config"]["model_id"] == "gpt-4"
|
||||
|
||||
def test_load_returns_none_when_missing(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
svc = ContinuationService()
|
||||
assert svc.load_state("nonexistent", "alice") is None
|
||||
|
||||
def test_delete_state(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
svc = ContinuationService()
|
||||
svc.save_state(
|
||||
conversation_id="conv-2",
|
||||
user="bob",
|
||||
messages=[],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
assert svc.delete_state("conv-2", "bob") is True
|
||||
assert svc.load_state("conv-2", "bob") is None
|
||||
|
||||
def test_delete_nonexistent(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
svc = ContinuationService()
|
||||
assert svc.delete_state("nope", "nope") is False
|
||||
|
||||
def test_upsert_replaces_existing(self, mock_mongo_db):
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
|
||||
svc = ContinuationService()
|
||||
svc.save_state(
|
||||
conversation_id="conv-3",
|
||||
user="carol",
|
||||
messages=[{"role": "user", "content": "v1"}],
|
||||
pending_tool_calls=[],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
svc.save_state(
|
||||
conversation_id="conv-3",
|
||||
user="carol",
|
||||
messages=[{"role": "user", "content": "v2"}],
|
||||
pending_tool_calls=[{"call_id": "c2"}],
|
||||
tools_dict={},
|
||||
tool_schemas=[],
|
||||
agent_config={},
|
||||
)
|
||||
state = svc.load_state("conv-3", "carol")
|
||||
assert state["messages"][0]["content"] == "v2"
|
||||
assert len(state["pending_tool_calls"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolExecutor.check_pause
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckPause:
|
||||
|
||||
def _make_call(self, name="action_0", call_id="c1", arguments="{}"):
|
||||
call = Mock()
|
||||
call.name = name
|
||||
call.id = call_id
|
||||
call.arguments = arguments
|
||||
call.thought_signature = None
|
||||
return call
|
||||
|
||||
def test_returns_none_for_normal_tool(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"0": {
|
||||
"name": "brave",
|
||||
"actions": [
|
||||
{"name": "search", "active": True, "parameters": {}},
|
||||
],
|
||||
}
|
||||
}
|
||||
call = self._make_call(name="search_0")
|
||||
result = executor.check_pause(tools_dict, call, "OpenAILLM")
|
||||
assert result is None
|
||||
|
||||
def test_returns_pause_for_client_side_tool(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"0": {
|
||||
"name": "get_weather",
|
||||
"client_side": True,
|
||||
"actions": [
|
||||
{"name": "get_weather", "active": True, "parameters": {}},
|
||||
],
|
||||
}
|
||||
}
|
||||
call = self._make_call(name="get_weather_0")
|
||||
result = executor.check_pause(tools_dict, call, "OpenAILLM")
|
||||
assert result is not None
|
||||
assert result["pause_type"] == "requires_client_execution"
|
||||
assert result["call_id"] == "c1"
|
||||
assert result["tool_id"] == "0"
|
||||
|
||||
def test_returns_pause_for_approval_required(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"0": {
|
||||
"name": "telegram",
|
||||
"actions": [
|
||||
{
|
||||
"name": "send_msg",
|
||||
"active": True,
|
||||
"require_approval": True,
|
||||
"parameters": {},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
call = self._make_call(name="send_msg_0")
|
||||
result = executor.check_pause(tools_dict, call, "OpenAILLM")
|
||||
assert result is not None
|
||||
assert result["pause_type"] == "awaiting_approval"
|
||||
|
||||
def test_returns_none_when_parse_fails(self):
|
||||
executor = ToolExecutor()
|
||||
call = self._make_call(name="bad_name_no_id", arguments="not json")
|
||||
# Bad arguments will cause parse error -> None
|
||||
result = executor.check_pause({}, call, "OpenAILLM")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_tool_not_in_dict(self):
|
||||
executor = ToolExecutor()
|
||||
call = self._make_call(name="action_99")
|
||||
result = executor.check_pause({"0": {"name": "t"}}, call, "OpenAILLM")
|
||||
assert result is None
|
||||
|
||||
def test_api_tool_approval(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"0": {
|
||||
"name": "api_tool",
|
||||
"config": {
|
||||
"actions": {
|
||||
"delete_user": {
|
||||
"name": "delete_user",
|
||||
"require_approval": True,
|
||||
"url": "http://example.com",
|
||||
"method": "DELETE",
|
||||
"active": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
call = self._make_call(name="delete_user_0")
|
||||
result = executor.check_pause(tools_dict, call, "OpenAILLM")
|
||||
assert result is not None
|
||||
assert result["pause_type"] == "awaiting_approval"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler pause signaling (handle_tool_calls returns pending_actions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteHandler(LLMHandler):
|
||||
"""Minimal concrete handler for testing."""
|
||||
|
||||
def parse_response(self, response):
|
||||
return LLMResponse(
|
||||
content=str(response), tool_calls=[], finish_reason="stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call, result):
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
"call_id": tool_call.id,
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response):
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandlerPauseSignaling:
|
||||
|
||||
def _make_agent(self):
|
||||
agent = Mock()
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
agent.tool_executor.check_pause = Mock(return_value=None)
|
||||
|
||||
def fake_execute(tools_dict, call):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
return ("tool result", call.id)
|
||||
|
||||
agent._execute_tool_action = Mock(side_effect=fake_execute)
|
||||
return agent
|
||||
|
||||
def test_no_pause_returns_none_pending(self):
|
||||
handler = ConcreteHandler()
|
||||
agent = self._make_agent()
|
||||
call = ToolCall(id="c1", name="action_0", arguments="{}")
|
||||
|
||||
gen = handler.handle_tool_calls(agent, [call], {"0": {"name": "t"}}, [])
|
||||
events = []
|
||||
messages = None
|
||||
pending = "NOT_SET"
|
||||
try:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
messages, pending = e.value
|
||||
|
||||
assert pending is None
|
||||
assert messages is not None
|
||||
|
||||
def test_pause_returns_pending_actions(self):
|
||||
handler = ConcreteHandler()
|
||||
agent = self._make_agent()
|
||||
agent.tool_executor.check_pause = Mock(return_value={
|
||||
"call_id": "c1",
|
||||
"name": "send_msg_0",
|
||||
"tool_name": "telegram",
|
||||
"tool_id": "0",
|
||||
"action_name": "send_msg",
|
||||
"arguments": {"text": "hello"},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
})
|
||||
|
||||
call = ToolCall(id="c1", name="send_msg_0", arguments='{"text": "hello"}')
|
||||
gen = handler.handle_tool_calls(
|
||||
agent, [call], {"0": {"name": "telegram"}}, []
|
||||
)
|
||||
|
||||
events = []
|
||||
pending = None
|
||||
try:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
messages, pending = e.value
|
||||
|
||||
assert pending is not None
|
||||
assert len(pending) == 1
|
||||
assert pending[0]["pause_type"] == "awaiting_approval"
|
||||
|
||||
# Should have yielded a tool_call event with awaiting_approval status
|
||||
pause_events = [
|
||||
e for e in events
|
||||
if e.get("type") == "tool_call"
|
||||
and e.get("data", {}).get("status") == "awaiting_approval"
|
||||
]
|
||||
assert len(pause_events) == 1
|
||||
|
||||
def test_mixed_execute_and_pause(self):
|
||||
"""One tool executes, another needs approval."""
|
||||
handler = ConcreteHandler()
|
||||
agent = self._make_agent()
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def selective_pause(tools_dict, call, llm_class):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 2:
|
||||
return {
|
||||
"call_id": "c2",
|
||||
"name": "danger_0",
|
||||
"tool_name": "danger",
|
||||
"tool_id": "0",
|
||||
"action_name": "danger",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
return None
|
||||
|
||||
agent.tool_executor.check_pause = Mock(side_effect=selective_pause)
|
||||
|
||||
calls = [
|
||||
ToolCall(id="c1", name="safe_0", arguments="{}"),
|
||||
ToolCall(id="c2", name="danger_0", arguments="{}"),
|
||||
]
|
||||
gen = handler.handle_tool_calls(
|
||||
agent, calls, {"0": {"name": "multi"}}, []
|
||||
)
|
||||
|
||||
events = []
|
||||
try:
|
||||
while True:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
messages, pending = e.value
|
||||
|
||||
# First tool was executed normally
|
||||
assert agent._execute_tool_action.call_count == 1
|
||||
# Second tool is pending
|
||||
assert pending is not None
|
||||
assert len(pending) == 1
|
||||
assert pending[0]["call_id"] == "c2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_streaming yields tool_calls_pending
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamingPause:
|
||||
|
||||
def test_streaming_yields_tool_calls_pending(self):
|
||||
handler = ConcreteHandler()
|
||||
agent = Mock()
|
||||
agent.llm = Mock()
|
||||
agent.model_id = "test"
|
||||
agent.tools = []
|
||||
agent._check_context_limit = Mock(return_value=False)
|
||||
agent.context_limit_reached = False
|
||||
agent.llm.__class__.__name__ = "MockLLM"
|
||||
|
||||
pause_info = {
|
||||
"call_id": "c1",
|
||||
"name": "fn_0",
|
||||
"tool_name": "test",
|
||||
"tool_id": "0",
|
||||
"action_name": "fn",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
agent.tool_executor.check_pause = Mock(return_value=pause_info)
|
||||
|
||||
chunk = LLMResponse(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="c1", name="fn_0", arguments="{}", index=0)],
|
||||
finish_reason="tool_calls",
|
||||
raw_response={},
|
||||
)
|
||||
handler.parse_response = lambda c: c
|
||||
|
||||
def fake_iterate(response):
|
||||
yield from response
|
||||
|
||||
handler._iterate_stream = fake_iterate
|
||||
|
||||
gen = handler.handle_streaming(agent, [chunk], {"0": {"name": "t"}}, [])
|
||||
events = list(gen)
|
||||
|
||||
# Should contain a tool_calls_pending event
|
||||
pending_events = [
|
||||
e for e in events
|
||||
if isinstance(e, dict) and e.get("type") == "tool_calls_pending"
|
||||
]
|
||||
assert len(pending_events) == 1
|
||||
assert len(pending_events[0]["data"]["pending_tool_calls"]) == 1
|
||||
|
||||
# Agent should have _pending_continuation set
|
||||
assert hasattr(agent, "_pending_continuation")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseAgent.gen_continuation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGenContinuation:
|
||||
|
||||
def test_approved_tool_executes(self):
|
||||
"""When a tool action is approved, the tool is executed."""
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm._supports_tools = True
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Final answer"]))
|
||||
mock_llm._supports_structured_output = Mock(return_value=False)
|
||||
mock_llm.__class__.__name__ = "MockLLM"
|
||||
|
||||
mock_handler = Mock()
|
||||
mock_handler.process_message_flow = Mock(return_value=iter([]))
|
||||
mock_handler.create_tool_message = Mock(
|
||||
return_value={"role": "tool", "content": [{"function_response": {
|
||||
"name": "act_0", "response": {"result": "done"}, "call_id": "c1"
|
||||
}}]}
|
||||
)
|
||||
|
||||
mock_executor = Mock()
|
||||
mock_executor.tool_calls = []
|
||||
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
|
||||
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
|
||||
|
||||
def fake_execute(tools_dict, call, llm_class):
|
||||
yield {"type": "tool_call", "data": {"status": "pending"}}
|
||||
return ("result_data", "c1")
|
||||
|
||||
mock_executor.execute = Mock(side_effect=fake_execute)
|
||||
|
||||
agent = ClassicAgent(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="test",
|
||||
llm=mock_llm,
|
||||
llm_handler=mock_handler,
|
||||
tool_executor=mock_executor,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": "You are helpful."}]
|
||||
tools_dict = {"0": {"name": "test_tool"}}
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "act_0",
|
||||
"tool_name": "test_tool",
|
||||
"tool_id": "0",
|
||||
"action_name": "act",
|
||||
"arguments": {"q": "test"},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
tool_actions = [{"call_id": "c1", "decision": "approved"}]
|
||||
|
||||
list(agent.gen_continuation(messages, tools_dict, pending, tool_actions))
|
||||
|
||||
# Tool should have been executed
|
||||
assert mock_executor.execute.called
|
||||
|
||||
def test_denied_tool_sends_denial(self):
|
||||
"""When a tool action is denied, a denial message is added."""
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm._supports_tools = True
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
mock_llm._supports_structured_output = Mock(return_value=False)
|
||||
mock_llm.__class__.__name__ = "MockLLM"
|
||||
|
||||
mock_handler = Mock()
|
||||
mock_handler.process_message_flow = Mock(return_value=iter([]))
|
||||
mock_handler.create_tool_message = Mock(
|
||||
return_value={"role": "tool", "content": "denied"}
|
||||
)
|
||||
|
||||
mock_executor = Mock()
|
||||
mock_executor.tool_calls = []
|
||||
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
|
||||
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
|
||||
|
||||
agent = ClassicAgent(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="test",
|
||||
llm=mock_llm,
|
||||
llm_handler=mock_handler,
|
||||
tool_executor=mock_executor,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "danger_0",
|
||||
"tool_name": "danger",
|
||||
"tool_id": "0",
|
||||
"action_name": "danger",
|
||||
"arguments": {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
tool_actions = [
|
||||
{"call_id": "c1", "decision": "denied", "comment": "too risky"}
|
||||
]
|
||||
|
||||
events = list(
|
||||
agent.gen_continuation(messages, {"0": {"name": "danger"}}, pending, tool_actions)
|
||||
)
|
||||
|
||||
# Should have a denied tool_call event
|
||||
denied = [
|
||||
e for e in events
|
||||
if isinstance(e, dict)
|
||||
and e.get("type") == "tool_call"
|
||||
and e.get("data", {}).get("status") == "denied"
|
||||
]
|
||||
assert len(denied) == 1
|
||||
|
||||
# create_tool_message should have been called with denial text
|
||||
denial_arg = mock_handler.create_tool_message.call_args[0][1]
|
||||
assert "denied" in denial_arg.lower()
|
||||
assert "too risky" in denial_arg
|
||||
|
||||
def test_client_result_appended(self):
|
||||
"""Client-provided tool result is added to messages."""
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm._supports_tools = True
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Done"]))
|
||||
mock_llm._supports_structured_output = Mock(return_value=False)
|
||||
mock_llm.__class__.__name__ = "MockLLM"
|
||||
|
||||
mock_handler = Mock()
|
||||
mock_handler.process_message_flow = Mock(return_value=iter([]))
|
||||
mock_handler.create_tool_message = Mock(
|
||||
return_value={"role": "tool", "content": "client result"}
|
||||
)
|
||||
|
||||
mock_executor = Mock()
|
||||
mock_executor.tool_calls = []
|
||||
mock_executor.prepare_tools_for_llm = Mock(return_value=[])
|
||||
mock_executor.get_truncated_tool_calls = Mock(return_value=[])
|
||||
|
||||
agent = ClassicAgent(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4",
|
||||
api_key="test",
|
||||
llm=mock_llm,
|
||||
llm_handler=mock_handler,
|
||||
tool_executor=mock_executor,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": "test"}]
|
||||
pending = [
|
||||
{
|
||||
"call_id": "c1",
|
||||
"name": "weather_0",
|
||||
"tool_name": "weather",
|
||||
"tool_id": "0",
|
||||
"action_name": "weather",
|
||||
"arguments": {"city": "SF"},
|
||||
"pause_type": "requires_client_execution",
|
||||
"thought_signature": None,
|
||||
}
|
||||
]
|
||||
tool_actions = [{"call_id": "c1", "result": {"temp": "72F"}}]
|
||||
|
||||
events = list(
|
||||
agent.gen_continuation(messages, {"0": {"name": "weather"}}, pending, tool_actions)
|
||||
)
|
||||
|
||||
# create_tool_message was called with the client result
|
||||
result_arg = mock_handler.create_tool_message.call_args[0][1]
|
||||
assert "72F" in result_arg
|
||||
|
||||
# Should have a completed tool_call event
|
||||
completed = [
|
||||
e for e in events
|
||||
if isinstance(e, dict)
|
||||
and e.get("type") == "tool_call"
|
||||
and e.get("data", {}).get("status") == "completed"
|
||||
]
|
||||
assert len(completed) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateRequest:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _app_context(self):
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
with app.app_context():
|
||||
yield
|
||||
|
||||
def test_continuation_request_without_question(self, mock_mongo_db):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
base = BaseAnswerResource()
|
||||
data = {
|
||||
"conversation_id": "conv-1",
|
||||
"tool_actions": [{"call_id": "c1", "decision": "approved"}],
|
||||
}
|
||||
result = base.validate_request(data)
|
||||
assert result is None # Valid
|
||||
|
||||
def test_continuation_request_missing_conversation_id(self, mock_mongo_db):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
base = BaseAnswerResource()
|
||||
data = {
|
||||
"tool_actions": [{"call_id": "c1", "decision": "approved"}],
|
||||
}
|
||||
result = base.validate_request(data)
|
||||
assert result is not None # Error — missing conversation_id
|
||||
|
||||
def test_normal_request_still_requires_question(self, mock_mongo_db):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
base = BaseAnswerResource()
|
||||
data = {"conversation_id": "conv-1"}
|
||||
result = base.validate_request(data)
|
||||
assert result is not None # Error — missing question
|
||||
Reference in New Issue
Block a user