mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-21 21:05:05 +00:00
280 lines
8.8 KiB
Python
280 lines
8.8 KiB
Python
"""Tests for ToolExecutor — tool discovery, preparation, and execution."""
|
|
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from application.agents.tool_executor import ToolExecutor
|
|
|
|
|
|
@pytest.mark.unit
class TestToolExecutorInit:
    """Construction of ``ToolExecutor`` with and without arguments."""

    def test_default_state(self):
        """A bare executor has no identity, no recorded calls and no loaded tools."""
        fresh = ToolExecutor()

        # Identity-related fields are unset when nothing is passed in.
        assert fresh.user is None
        assert fresh.user_api_key is None
        assert fresh.conversation_id is None

        # Per-instance bookkeeping starts out empty.
        assert fresh.tool_calls == []
        assert fresh._loaded_tools == {}

    def test_init_with_params(self):
        """``user`` and ``user_api_key`` given to the constructor are stored as-is.

        ``decoded_token`` is passed as well but not checked here.
        NOTE(review): the attribute it is stored under is not visible from this file.
        """
        configured = ToolExecutor(
            user="alice",
            user_api_key="key",
            decoded_token={"sub": "alice"},
        )

        assert configured.user == "alice"
        assert configured.user_api_key == "key"
|
|
|
|
|
|
@pytest.mark.unit
class TestToolExecutorGetTools:
    """``ToolExecutor.get_tools`` under the three identity configurations.

    Each test only checks that a ``dict`` comes back. ``mock_mongo_db`` is a
    fixture defined outside this file (presumably in a conftest that stubs the
    database — confirm).

    NOTE(review): despite the test names, none of these assert *which* lookup
    path (API key vs. user vs. local) was taken.
    """

    def test_get_tools_uses_api_key_when_present(self, mock_mongo_db):
        """Both an API key and a user are supplied."""
        found = ToolExecutor(user_api_key="test_key", user="alice").get_tools()
        assert isinstance(found, dict)

    def test_get_tools_uses_user_when_no_api_key(self, mock_mongo_db):
        """Only a user is supplied."""
        found = ToolExecutor(user="alice").get_tools()
        assert isinstance(found, dict)

    def test_get_tools_defaults_to_local(self, mock_mongo_db):
        """Neither an API key nor a user is supplied."""
        found = ToolExecutor().get_tools()
        assert isinstance(found, dict)
|
|
|
|
|
|
@pytest.mark.unit
class TestToolExecutorPrepare:
    """Conversion of stored tool definitions into LLM function-call schemas."""

    def test_prepare_tools_for_llm_empty(self):
        """No tools in, no function schemas out."""
        executor = ToolExecutor()
        result = executor.prepare_tools_for_llm({})
        assert result == []

    def test_prepare_tools_for_llm_non_api_tool(self):
        """One active action yields one ``function`` entry named ``<action>_<tool_id>``."""
        executor = ToolExecutor()
        tools_dict = {
            "t1": {
                "name": "test_tool",
                "actions": [
                    {
                        "name": "do_thing",
                        "description": "Does a thing",
                        "active": True,
                        "parameters": {
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "The query",
                                    "filled_by_llm": True,
                                    "required": True,
                                }
                            }
                        },
                    }
                ],
            }
        }

        result = executor.prepare_tools_for_llm(tools_dict)

        assert len(result) == 1
        assert result[0]["type"] == "function"
        assert result[0]["function"]["name"] == "do_thing_t1"
        assert "query" in result[0]["function"]["parameters"]["properties"]

    def test_prepare_tools_skips_inactive_actions(self):
        """Actions with ``active: False`` are not offered to the LLM."""
        executor = ToolExecutor()
        tools_dict = {
            "t1": {
                "name": "test_tool",
                "actions": [
                    {"name": "active_one", "description": "D", "active": True, "parameters": {"properties": {}}},
                    {"name": "inactive_one", "description": "D", "active": False, "parameters": {"properties": {}}},
                ],
            }
        }

        result = executor.prepare_tools_for_llm(tools_dict)

        assert len(result) == 1
        assert result[0]["function"]["name"] == "active_one_t1"

    def test_build_tool_parameters_filters_non_llm_fields(self):
        """Only LLM-filled params are exposed, with internal bookkeeping keys removed."""
        executor = ToolExecutor()
        action = {
            "parameters": {
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query",
                        "filled_by_llm": True,
                        "value": "default_val",
                        "required": True,
                    },
                    "hidden": {
                        "type": "string",
                        "filled_by_llm": False,
                    },
                }
            }
        }

        result = executor._build_tool_parameters(action)

        # Params not filled by the LLM are dropped entirely.
        assert "query" in result["properties"]
        assert "hidden" not in result["properties"]
        # Per-property "required" flags are collected into the top-level list...
        assert "query" in result["required"]

        # ...and filled_by_llm, value, required are all stripped from the
        # per-property schema sent to the LLM.
        query_schema = result["properties"]["query"]
        assert "filled_by_llm" not in query_schema
        assert "value" not in query_schema
        assert "required" not in query_schema
|
|
|
|
|
|
@pytest.mark.unit
class TestToolExecutorExecute:
    """``ToolExecutor.execute`` and the recorded-tool-call helpers.

    ``execute`` is driven as a generator: it yields status events (dicts with
    ``["data"]["status"]``) and *returns* a tuple, which surfaces as
    ``StopIteration.value``. Its first two items are checked below as the
    result text and the call id.
    """

    def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"):
        """Build a stand-in for an LLM tool call exposing ``name``/``id``/``arguments``."""
        call = Mock()
        # Assigned after construction: Mock(name=...) would only set the
        # mock's repr name, not a readable ``.name`` attribute.
        call.name = name
        call.id = call_id
        call.arguments = arguments
        return call

    @staticmethod
    def _drain(gen):
        """Exhaust generator ``gen`` and return ``(yielded_events, return_value)``."""
        events = []
        while True:
            try:
                events.append(next(gen))
            except StopIteration as stop:
                return events, stop.value

    @staticmethod
    def _patch_parser(monkeypatch, parsed):
        """Replace ``ToolActionParser`` so every instance's ``parse_args`` returns ``parsed``.

        ``parsed`` is the ``(tool_id, action_name, arguments)`` triple the
        executor receives. The constructor is stubbed with a one-argument
        callable, matching how the executor builds it (presumably from the LLM
        class name passed to ``execute`` — confirm).
        """
        monkeypatch.setattr(
            "application.agents.tool_executor.ToolActionParser",
            lambda _cls: Mock(parse_args=Mock(return_value=parsed)),
        )

    def test_execute_parse_failure(self, monkeypatch):
        """An unparseable call yields an error event and is still recorded."""
        executor = ToolExecutor()
        self._patch_parser(monkeypatch, (None, None, {}))

        call = self._make_call(name="bad")
        events, result = self._drain(executor.execute({}, call, "MockLLM"))

        assert result[0] == "Failed to parse tool call."
        assert len(executor.tool_calls) == 1
        assert events, "expected at least one status event"
        assert events[0]["data"]["status"] == "error"

    def test_execute_tool_not_found(self, monkeypatch):
        """A tool id absent from ``tools_dict`` yields an error event."""
        executor = ToolExecutor()
        self._patch_parser(monkeypatch, ("missing_id", "action", {}))

        call = self._make_call()
        events, result = self._drain(executor.execute({}, call, "MockLLM"))

        assert "not found" in result[0]
        assert events, "expected at least one status event"
        assert events[0]["data"]["status"] == "error"

    def test_execute_success(self, mock_tool_manager, monkeypatch):
        """A resolvable call reports pending then completed and returns the tool output."""
        executor = ToolExecutor(user="test_user")
        self._patch_parser(monkeypatch, ("t1", "test_action", {"param1": "val"}))

        tools_dict = {
            "t1": {
                "name": "test_tool",
                "config": {"key": "val"},
                "actions": [
                    {"name": "test_action", "description": "Test", "parameters": {"properties": {}}},
                ],
            }
        }

        call = self._make_call(name="test_action_t1", call_id="c1")
        events, result = self._drain(executor.execute(tools_dict, call, "MockLLM"))

        # "Tool result" comes from the mock_tool_manager fixture (defined elsewhere).
        assert result[0] == "Tool result"
        assert result[1] == "c1"

        statuses = [event["data"]["status"] for event in events]
        assert "pending" in statuses
        assert "completed" in statuses

    def test_get_truncated_tool_calls(self):
        """Long results are shortened and a status is reported."""
        executor = ToolExecutor()
        executor.tool_calls = [
            {
                "tool_name": "test",
                "call_id": "1",
                "action_name": "act",
                "arguments": {},
                "result": "A" * 100,
            }
        ]

        truncated = executor.get_truncated_tool_calls()

        assert len(truncated) == 1
        # NOTE(review): 53 presumably allows a 50-char prefix plus "..." — confirm.
        assert len(truncated[0]["result"]) <= 53
        # The input entry carries no status, yet the output reports "completed".
        assert truncated[0]["status"] == "completed"

    def test_tool_caching(self, mock_tool_manager, monkeypatch):
        """Executing the same tool twice loads it only once."""
        executor = ToolExecutor(user="test_user")
        self._patch_parser(monkeypatch, ("t1", "test_action", {}))

        tools_dict = {
            "t1": {
                "name": "test_tool",
                "config": {"key": "val"},
                "actions": [
                    {"name": "test_action", "description": "Test", "parameters": {"properties": {}}},
                ],
            }
        }

        call = self._make_call(name="test_action_t1")

        # First run loads the tool; the second should be served from the cache.
        for _ in range(2):
            self._drain(executor.execute(tools_dict, call, "MockLLM"))

        assert mock_tool_manager.load_tool.call_count == 1
|