mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
* feat: template-based prompt rendering with dynamic namespace injection * refactor: improve template engine initialization with clearer formatting * refactor: streamline ReActAgent methods and improve content extraction logic feat: enhance error handling in NamespaceManager and TemplateEngine fix: update NewAgent component to ensure consistent form data submission test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage * feat: tools namespace + three-tier token budget * refactor: remove unused variable assignment in message building tests * Enhance prompt customization and tool pre-fetching functionality * ruff lint fix * refactor: cleaner error handling and reduce code clutter --------- Co-authored-by: Alex <a@tushynski.me>
233 lines
6.6 KiB
Python
233 lines
6.6 KiB
Python
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from application.agents.classic_agent import ClassicAgent
|
|
|
|
|
|
@pytest.mark.unit
class TestClassicAgent:
    """Unit tests for ``ClassicAgent`` construction and its ``_gen_inner`` stream.

    Several tests request ``mock_llm_creator``, ``mock_llm_handler_creator`` and
    ``mock_mongo_db`` without referencing them in the body; presumably those
    fixtures patch the LLM factories and database the agent touches while it
    is built and run (NOTE(review): confirm against conftest).
    """

    @staticmethod
    def _stub_llm(
        mock_llm,
        mock_llm_handler,
        stream_chunks=("Answer",),
        handler_chunks=("Processed",),
    ):
        """Install canned output on the LLM and LLM-handler mocks.

        ``gen_stream`` returns one shared iterator over ``stream_chunks`` (so it
        can be consumed only once), while ``process_message_flow`` builds a
        fresh generator over ``handler_chunks`` on every call.
        """
        mock_llm.gen_stream = Mock(return_value=iter(stream_chunks))

        def mock_handler(*args, **kwargs):
            yield from handler_chunks

        mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)

    @staticmethod
    def _chunks_with(results, key):
        """Return the streamed chunks that contain ``key`` (e.g. ``"sources"``)."""
        return [r for r in results if key in r]

    def test_classic_agent_initialization(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """The constructor keeps the endpoint and LLM name it was given."""
        agent = ClassicAgent(**agent_base_params)

        assert isinstance(agent, ClassicAgent)
        assert agent.endpoint == agent_base_params["endpoint"]
        assert agent.llm_name == agent_base_params["llm_name"]

    def test_gen_inner_basic_flow(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """A run streams exactly one ``sources`` chunk and one ``tool_calls`` chunk."""
        self._stub_llm(
            mock_llm,
            mock_llm_handler,
            stream_chunks=("Answer chunk 1", "Answer chunk 2"),
            handler_chunks=("Processed answer",),
        )

        agent = ClassicAgent(**agent_base_params)
        results = list(agent._gen_inner("Test query", log_context))

        assert len(results) >= 2
        assert len(self._chunks_with(results, "sources")) == 1
        assert len(self._chunks_with(results, "tool_calls")) == 1

    def test_gen_inner_retrieves_documents(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """A run surfaces its document sources as exactly one ``sources`` chunk."""
        self._stub_llm(mock_llm, mock_llm_handler)

        agent = ClassicAgent(**agent_base_params)
        results = list(agent._gen_inner("Test query", log_context))

        # This test used to assert nothing and could only fail by raising.
        assert len(self._chunks_with(results, "sources")) == 1

    def test_gen_inner_uses_user_api_key_tools(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """With a ``user_api_key``, a run completes against the agent's tool list."""
        from application.core.settings import settings
        from bson.objectid import ObjectId

        tool_id = str(ObjectId())
        mock_mongo_db[settings.MONGO_DB_NAME]["agents"].docs = {
            "api_key_123": {"key": "api_key_123", "tools": [tool_id]}
        }
        mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
            tool_id: {"_id": ObjectId(tool_id), "name": "test_tool"}
        }
        self._stub_llm(mock_llm, mock_llm_handler)

        # Build a fresh params dict rather than mutating the fixture's dict.
        agent = ClassicAgent(**{**agent_base_params, "user_api_key": "api_key_123"})
        results = list(agent._gen_inner("Test query", log_context))

        # The former `len(agent.tools) >= 0` could never fail.
        # NOTE(review): tighten to an exact expected tool count once the
        # tool-preparation rules for these fixture documents are confirmed.
        assert len(self._chunks_with(results, "tool_calls")) == 1
        assert agent.tools is not None

    def test_gen_inner_uses_user_tools(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """Without a ``user_api_key``, a run completes with a user tool in the DB.

        The seeded tool belongs to ``"test_user"``; presumably that matches the
        fixture's user (NOTE(review): confirm against conftest).
        """
        from application.core.settings import settings

        mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
            "1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True}
        }
        self._stub_llm(mock_llm, mock_llm_handler)

        agent = ClassicAgent(**agent_base_params)
        results = list(agent._gen_inner("Test query", log_context))

        # The former `len(agent.tools) >= 0` could never fail.
        # NOTE(review): tighten to an exact expected tool count once the
        # tool-preparation rules for this fixture document are confirmed.
        assert len(self._chunks_with(results, "tool_calls")) == 1
        assert agent.tools is not None

    def test_gen_inner_builds_correct_messages(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """The LLM receives a system message first and the raw query last."""
        self._stub_llm(mock_llm, mock_llm_handler)

        agent = ClassicAgent(**agent_base_params)
        list(agent._gen_inner("Test query", log_context))

        messages = mock_llm.gen_stream.call_args.kwargs["messages"]

        assert len(messages) >= 2
        assert messages[0]["role"] == "system"
        assert messages[-1]["role"] == "user"
        assert messages[-1]["content"] == "Test query"

    def test_gen_inner_logs_tool_calls(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
        log_context,
    ):
        """A run appends one ``agent`` entry to the log context holding ``tool_calls``."""
        self._stub_llm(mock_llm, mock_llm_handler)

        agent = ClassicAgent(**agent_base_params)
        # Pre-seed a tool call so the logged data has something to carry.
        agent.tool_calls = [{"tool": "test", "result": "success"}]

        list(agent._gen_inner("Test query", log_context))

        agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
        assert len(agent_logs) == 1
        assert "tool_calls" in agent_logs[0]["data"]
|
|
|
|
|
|
@pytest.mark.integration
class TestClassicAgentIntegration:
    """Checks that drive ``ClassicAgent`` through its public ``gen`` entry point."""

    @staticmethod
    def _prime_mocks(llm, llm_handler):
        """Give the LLM a one-chunk stream and the handler a one-chunk generator."""
        llm.gen_stream = Mock(return_value=iter(["Answer"]))

        def fake_flow(*_args, **_kwargs):
            yield "Processed"

        llm_handler.process_message_flow = Mock(side_effect=fake_flow)

    def test_gen_method_with_logging(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
    ):
        """Calling ``gen`` with only a query streams at least one chunk.

        Any logging is presumably applied by a decorator on ``gen`` (see the
        ``__wrapped__`` check below); this test only checks that output arrives.
        """
        self._prime_mocks(mock_llm, mock_llm_handler)

        streamed = [chunk for chunk in ClassicAgent(**agent_base_params).gen("Test query")]

        assert streamed

    def test_gen_method_decorator_applied(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        mock_mongo_db,
    ):
        """``gen`` exposes ``__wrapped__``, i.e. a ``functools.wraps``-style decorator wraps it."""
        self._prime_mocks(mock_llm, mock_llm_handler)

        built = ClassicAgent(**agent_base_params)

        assert hasattr(built.gen, "__wrapped__")
|