mini refactors

Alex
2026-03-25 22:34:25 +00:00
parent c6ece177cd
commit 462f2e9494
22 changed files with 2233 additions and 98 deletions

View File

@@ -4,8 +4,7 @@ from typing import Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
build_internal_tool_config,
build_internal_tool_entry,
add_internal_search_tool,
)
from application.logging import LogContext
@@ -32,24 +31,8 @@ class AgenticAgent(BaseAgent):
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
# 1. Get user tools (same as ClassicAgent)
tools_dict = self.tool_executor.get_tools()
# 2. Add internal search as a synthetic tool (only if sources are configured)
source = self.retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if self.retriever_config and has_sources:
has_dir = _sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(
has_directory_structure=has_dir
)
internal_entry["config"] = build_internal_tool_config(
**self.retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
# 3. Prepare all tools for the LLM
add_internal_search_tool(tools_dict, self.retriever_config)
self._prepare_tools(tools_dict)
# 4. Build messages (prompt has NO pre-fetched docs)
@@ -78,40 +61,3 @@ class AgenticAgent(BaseAgent):
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
self.retrieved_docs = tool.retrieved_docs
def _sources_have_directory_structure(source: Dict) -> bool:
"""Check if any of the active sources have directory_structure in MongoDB."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)},
{"directory_structure": 1},
)
if source_doc and source_doc.get("directory_structure"):
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
return False
# Import settings at module level for _sources_have_directory_structure
from application.core.settings import settings # noqa: E402

View File

@@ -6,11 +6,9 @@ from typing import Dict, Generator, List, Optional
from application.agents.base import BaseAgent
from application.agents.tool_executor import ToolExecutor
from application.agents.agentic_agent import _sources_have_directory_structure
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
build_internal_tool_config,
build_internal_tool_entry,
add_internal_search_tool,
)
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
from application.logging import LogContext
@@ -130,6 +128,7 @@ class ResearchAgent(BaseAgent):
self.citations = CitationManager()
self._start_time: float = 0
self._tokens_used: int = 0
self._last_token_snapshot: int = 0
# ------------------------------------------------------------------
# Budget & timeout helpers
@@ -153,7 +152,9 @@ class ResearchAgent(BaseAgent):
def _snapshot_llm_tokens(self) -> int:
"""Read current token usage from LLM and return delta since last snapshot."""
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
return current
delta = current - self._last_token_snapshot
self._last_token_snapshot = current
return delta
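This is the bug the hunk fixes: the old implementation returned the cumulative counter, so a caller charging each phase's "delta" against the budget re-charged all prior usage. A minimal sketch of the corrected snapshot pattern, using a stand-in TokenBudget class rather than the project's actual API:

from types import SimpleNamespace

class TokenBudget:
    # stand-in class, not the project's actual API
    def __init__(self, llm, budget: int):
        self.llm = llm
        self.budget = budget
        self._last_snapshot = 0

    def snapshot(self) -> int:
        # tokens consumed since the previous snapshot, not the lifetime total
        usage = self.llm.token_usage
        current = usage.get("prompt_tokens", 0) + usage.get("generated_tokens", 0)
        delta = current - self._last_snapshot
        self._last_snapshot = current
        return delta

llm = SimpleNamespace(token_usage={"prompt_tokens": 100, "generated_tokens": 50})
tracker = TokenBudget(llm, budget=1000)
assert tracker.snapshot() == 150
llm.token_usage = {"prompt_tokens": 200, "generated_tokens": 100}
assert tracker.snapshot() == 150  # only the new usage is charged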
# ------------------------------------------------------------------
# Main orchestration
@@ -272,21 +273,7 @@ class ResearchAgent(BaseAgent):
"""Build tools_dict with user tools + internal search + think."""
tools_dict = self.tool_executor.get_tools()
# Only add internal search if sources are configured
source = self.retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if self.retriever_config and has_sources:
has_dir = _sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(
has_directory_structure=has_dir
)
internal_entry["config"] = build_internal_tool_config(
**self.retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
elif self.retriever_config and not has_sources:
logger.info("ResearchAgent: No sources configured, skipping internal_search tool")
add_internal_search_tool(tools_dict, self.retriever_config)
think_entry = dict(THINK_TOOL_ENTRY)
think_entry["config"] = {}
@@ -580,7 +567,14 @@ class ResearchAgent(BaseAgent):
call_id = None
while True:
try:
next(gen)
event = next(gen)
# Log tool_call status events instead of discarding them
if isinstance(event, dict) and event.get("type") == "tool_call":
logger.debug(
"Tool %s status: %s",
event.get("data", {}).get("action_name", ""),
event.get("data", {}).get("status", ""),
)
except StopIteration as e:
result, call_id = e.value
break
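The loop above leans on a generator-protocol detail worth spelling out: a `return` inside a generator surfaces its value as `StopIteration.value`, so status events can be yielded mid-flight while the final `(result, call_id)` pair travels back through the exception. A self-contained sketch with hypothetical names:

from typing import Dict, Generator, Tuple

def execute_tool() -> Generator[Dict, None, Tuple[str, str]]:
    # hypothetical stand-in for the tool executor's generator
    yield {"type": "tool_call", "data": {"action_name": "search", "status": "pending"}}
    yield {"type": "tool_call", "data": {"action_name": "search", "status": "completed"}}
    return "tool result", "call-1"

gen = execute_tool()
while True:
    try:
        event = next(gen)  # status events, now logged rather than discarded
    except StopIteration as stop:
        result, call_id = stop.value  # the generator's return value
        break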

View File

@@ -354,6 +354,59 @@ def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
"""Check if any of the active sources have directory_structure in MongoDB."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)},
{"directory_structure": 1},
)
if source_doc and source_doc.get("directory_structure"):
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
return False
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
"""Add the internal search tool to tools_dict if sources are configured.
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
Mutates tools_dict in place.
"""
source = retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if not retriever_config or not has_sources:
return
has_dir = sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
internal_entry["config"] = build_internal_tool_config(
**retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
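A minimal usage sketch for the new shared helper, assuming the config shape shown in this hunk; note that the directory-structure probe reaches MongoDB, so the populated call only succeeds against a live database:

from application.agents.tools.internal_search import INTERNAL_TOOL_ID, add_internal_search_tool

tools_dict = {}
add_internal_search_tool(tools_dict, {})  # empty config: the helper returns early
assert INTERNAL_TOOL_ID not in tools_dict

retriever_config = {
    "source": {"active_docs": ["<doc-id>"]},  # placeholder id; probe needs a live MongoDB
    "retriever_name": "classic",
    "chunks": 2,
}
add_internal_search_tool(tools_dict, retriever_config)
assert INTERNAL_TOOL_ID in tools_dict
assert "config" in tools_dict[INTERNAL_TOOL_ID]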
def build_internal_tool_config(
source: Dict,
retriever_name: str = "classic",

View File

@@ -2,8 +2,10 @@
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflows.schemas import AgentType
@@ -35,7 +37,8 @@ class ToolFilterMixin:
return filtered_tools
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
def __init__(
self,
@@ -56,11 +59,25 @@ class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
pass
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
pass
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
pass
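Why the pass-bodied subclasses above work: Python's MRO places the mixins ahead of the concrete agent, so the shared `__init__` runs once per construction without being repeated in each class. A reduced sketch with stand-in classes (the real chain also includes `ToolFilterMixin`):

class BaseAgentDemo:
    # stand-in for ClassicAgent / AgenticAgent / ResearchAgent
    def __init__(self, **kwargs):
        self.endpoint = kwargs.get("endpoint")

class NodeMixinDemo:
    # stand-in for _WorkflowNodeMixin: one shared __init__ for every node agent
    def __init__(self, tool_ids=None, **kwargs):
        super().__init__(**kwargs)  # continue along the MRO to the base agent
        self._allowed_tool_ids = tool_ids or []

class NodeClassicDemo(NodeMixinDemo, BaseAgentDemo):
    pass

agent = NodeClassicDemo(endpoint="stream", tool_ids=["t1"])
assert agent._allowed_tool_ids == ["t1"]
# MRO: NodeClassicDemo -> NodeMixinDemo -> BaseAgentDemo -> object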
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
AgentType.RESEARCH: WorkflowNodeResearchAgent,
}
@classmethod

View File

@@ -18,6 +18,8 @@ class NodeType(str, Enum):
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
AGENTIC = "agentic"
RESEARCH = "research"
class ExecutionStatus(str, Enum):

View File

@@ -7,6 +7,7 @@ from application.agents.workflows.cel_evaluator import CelEvaluationError, evalu
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
@@ -223,18 +224,32 @@ class WorkflowEngine:
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
factory_kwargs = {
"agent_type": node_config.agent_type,
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
"chat_history": self.agent.chat_history,
"decoded_token": self.agent.decoded_token,
"json_schema": node_json_schema,
}
# Agentic/research agents need retriever_config for on-demand search
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
factory_kwargs["retriever_config"] = {
"source": {"active_docs": node_config.sources} if node_config.sources else {},
"retriever_name": node_config.retriever or "classic",
"chunks": int(node_config.chunks) if node_config.chunks else 2,
"model_id": node_model_id,
"llm_name": node_llm_name,
"api_key": node_api_key,
"decoded_token": self.agent.decoded_token,
}
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
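For reference, the extra kwarg an agentic or research node receives, mirrored from the hunk above with placeholder values; classic nodes never get this key, so their pre-fetched-docs path is unchanged:

retriever_config = {
    "source": {"active_docs": ["doc-1"]},  # or {} when the node has no sources
    "retriever_name": "classic",
    "chunks": 2,
    "model_id": "gpt-4o-mini",
    "llm_name": "openai",
    "api_key": None,
    "decoded_token": {"sub": "user-1"},
}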

View File

@@ -13,4 +13,4 @@ Use the search_internal tool to find relevant information before answering quest
You may search multiple times with different queries if needed.
Do not guess when documents are available — search first, then answer based on what you find.
If no relevant documents are found, use your general knowledge and tool capabilities.
Allow yourself to be very creative and use your imagination.

View File

@@ -12,4 +12,4 @@ You have access to a search tool that searches the user's uploaded documents and
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
Do not guess when documents are available — search first, then answer based on what you find.
If no relevant documents are found, use your general knowledge and tool capabilities.

View File

@@ -13,4 +13,4 @@ Use the search_internal tool to find relevant information before answering quest
You may search multiple times with different queries if needed.
You MUST search before answering any factual question. Do not guess or use general knowledge when documents are available.
If you don't have enough information from the search results or tools, answer "I don't know" or "I don't have enough information".
Never make up information or provide false information!

View File

@@ -20,4 +20,4 @@ You MUST respond with ONLY a valid JSON object (no markdown, no code fences):
"needs_clarification": true or false,
"questions": ["question 1", "question 2"] (only if needs_clarification is true, max 3 questions),
"reason": "brief explanation of why clarification is or isn't needed"
}

View File

@@ -20,4 +20,4 @@ You MUST respond with ONLY a valid JSON object in this exact format (no markdown
"steps": [
{"query": "specific sub-question to investigate", "rationale": "why this step is needed"}
]
}

View File

@@ -10,4 +10,4 @@ Instructions:
5. Cite specific documents and passages you found. Reference sources by their titles or filenames.
6. If you cannot find relevant information through tools, use your general knowledge but clearly indicate this.
Be thorough — prefer completeness over brevity. Include all relevant details you find.

View File

@@ -19,4 +19,4 @@ Write a well-structured, thorough report that:
Available sources for citation:
{references}
Format the report with clear headings and sections. Be comprehensive but well-organized.

View File

@@ -55,7 +55,7 @@ import WorkflowPreview from './workflow/WorkflowPreview';
import type { Model } from '../models/types';
interface AgentNodeConfig {
agent_type: 'classic';
agent_type: 'classic' | 'agentic' | 'research';
llm_name?: string;
model_id?: string;
system_prompt: string;
@@ -884,6 +884,12 @@ function WorkflowBuilderInner() {
<SelectItem value="classic">
Classic
</SelectItem>
<SelectItem value="agentic">
Agentic
</SelectItem>
<SelectItem value="research">
Research
</SelectItem>
</SelectContent>
</Select>
</div>

View File

@@ -75,7 +75,7 @@ import type { Model } from '../../models/types';
const PRIMARY_ACTION_SPINNER_DELAY_MS = 180;
interface AgentNodeConfig {
agent_type: 'classic';
agent_type: 'classic' | 'agentic' | 'research';
llm_name?: string;
model_id?: string;
system_prompt: string;
@@ -1748,6 +1748,12 @@ function WorkflowBuilderInner() {
<SelectItem value="classic">
Classic
</SelectItem>
<SelectItem value="agentic">
Agentic
</SelectItem>
<SelectItem value="research">
Research
</SelectItem>
</SelectContent>
</Select>
</div>

View File

@@ -0,0 +1,139 @@
"""Tests for AgenticAgent — LLM-controlled retrieval agent."""
from unittest.mock import Mock
import pytest
from application.agents.agentic_agent import AgenticAgent
@pytest.mark.unit
class TestAgenticAgentInit:
def test_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = AgenticAgent(**agent_base_params)
assert isinstance(agent, AgenticAgent)
assert agent.retriever_config == {}
def test_initialization_with_retriever_config(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
rc = {"source": {"active_docs": ["abc"]}, "retriever_name": "classic"}
agent = AgenticAgent(retriever_config=rc, **agent_base_params)
assert agent.retriever_config == rc
def test_inherits_base_properties(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = AgenticAgent(**agent_base_params)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
assert agent.model_id == agent_base_params["model_id"]
@pytest.mark.unit
class TestAgenticAgentGenInner:
def test_basic_flow_yields_sources_and_tool_calls(
self,
agent_base_params,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Processed"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = AgenticAgent(**agent_base_params)
results = list(agent._gen_inner("Test query", log_context))
sources = [r for r in results if "sources" in r]
tool_calls = [r for r in results if "tool_calls" in r]
assert len(sources) == 1
assert len(tool_calls) == 1
def test_logs_agent_component(
self,
agent_base_params,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Done"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = AgenticAgent(**agent_base_params)
list(agent._gen_inner("Query", log_context))
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
assert len(agent_logs) == 1
assert "tool_calls" in agent_logs[0]["data"]
def test_no_pre_fetched_docs_in_messages(
self,
agent_base_params,
mock_llm,
mock_llm_handler,
mock_llm_creator,
mock_llm_handler_creator,
mock_mongo_db,
log_context,
):
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
def mock_handler(*args, **kwargs):
yield "Done"
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
agent = AgenticAgent(**agent_base_params)
list(agent._gen_inner("Query", log_context))
call_kwargs = mock_llm.gen_stream.call_args[1]
messages = call_kwargs["messages"]
# System prompt should not contain {summaries} replacement
assert messages[0]["role"] == "system"
assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == "Query"
@pytest.mark.unit
class TestAgenticAgentCollectSources:
def test_collect_internal_sources_from_cache(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = AgenticAgent(**agent_base_params)
mock_tool = Mock()
mock_tool.retrieved_docs = [
{"text": "Found", "title": "Doc", "source": "test"},
]
cache_key = f"internal_search:internal:{agent.user or ''}"
agent.tool_executor._loaded_tools[cache_key] = mock_tool
agent._collect_internal_sources()
assert len(agent.retrieved_docs) == 1
assert agent.retrieved_docs[0]["title"] == "Doc"
def test_collect_internal_sources_no_cache(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = AgenticAgent(**agent_base_params)
agent._collect_internal_sources()
assert agent.retrieved_docs == []

View File

@@ -0,0 +1,250 @@
"""Tests for InternalSearchTool and its helper functions."""
from unittest.mock import Mock, patch
import pytest
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
InternalSearchTool,
add_internal_search_tool,
build_internal_tool_config,
build_internal_tool_entry,
)
@pytest.mark.unit
class TestInternalSearchToolSearch:
def _make_tool(self, **config_overrides):
config = {"source": {}, "retriever_name": "classic", "chunks": 2}
config.update(config_overrides)
return InternalSearchTool(config)
def test_search_no_query_returns_error(self):
tool = self._make_tool()
result = tool.execute_action("search", query="")
assert "required" in result.lower()
def test_search_returns_formatted_docs(self):
tool = self._make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "Hello world", "title": "Doc1", "source": "test", "filename": "doc1.md"},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="hello")
assert "doc1.md" in result
assert "Hello world" in result
assert len(tool.retrieved_docs) == 1
def test_search_no_results(self):
tool = self._make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = []
tool._retriever = mock_retriever
result = tool.execute_action("search", query="nonexistent")
assert "No documents found" in result
def test_search_accumulates_docs(self):
tool = self._make_tool()
mock_retriever = Mock()
tool._retriever = mock_retriever
mock_retriever.search.return_value = [
{"text": "A", "title": "D1", "source": "s1"},
]
tool.execute_action("search", query="first")
mock_retriever.search.return_value = [
{"text": "B", "title": "D2", "source": "s2"},
]
tool.execute_action("search", query="second")
assert len(tool.retrieved_docs) == 2
def test_search_deduplicates_docs(self):
tool = self._make_tool()
doc = {"text": "Same", "title": "Same", "source": "same"}
mock_retriever = Mock()
mock_retriever.search.return_value = [doc]
tool._retriever = mock_retriever
tool.execute_action("search", query="q1")
tool.execute_action("search", query="q2")
assert len(tool.retrieved_docs) == 1
def test_search_with_path_filter(self):
tool = self._make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "A", "title": "T", "source": "src/main.py", "filename": "main.py"},
{"text": "B", "title": "T", "source": "docs/readme.md", "filename": "readme.md"},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="code", path_filter="src/")
assert "main.py" in result
assert "readme.md" not in result
def test_search_path_filter_no_match(self):
tool = self._make_tool()
mock_retriever = Mock()
mock_retriever.search.return_value = [
{"text": "A", "title": "T", "source": "other/file.txt"},
]
tool._retriever = mock_retriever
result = tool.execute_action("search", query="code", path_filter="src/")
assert "No documents found" in result
def test_search_retriever_error(self):
tool = self._make_tool()
mock_retriever = Mock()
mock_retriever.search.side_effect = Exception("Connection error")
tool._retriever = mock_retriever
result = tool.execute_action("search", query="test")
assert "failed" in result.lower() or "error" in result.lower()
def test_unknown_action(self):
tool = self._make_tool()
result = tool.execute_action("nonexistent")
assert "Unknown action" in result
@pytest.mark.unit
class TestInternalSearchToolListFiles:
def test_list_files_no_structure(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = None
result = tool.execute_action("list_files")
assert "No file structure" in result
def test_list_files_root(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"src": {"main.py": {}},
"README.md": {"type": "md", "token_count": 100},
}
result = tool.execute_action("list_files")
assert "src/" in result
assert "README.md" in result
def test_list_files_nested_path(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {
"src": {
"utils": {"helper.py": {}},
},
}
result = tool.execute_action("list_files", path="src")
assert "utils/" in result
def test_list_files_invalid_path(self):
tool = InternalSearchTool({"source": {}})
tool._dir_structure_loaded = True
tool._directory_structure = {"src": {}}
result = tool.execute_action("list_files", path="nonexistent")
assert "not found" in result
@pytest.mark.unit
class TestInternalSearchToolMetadata:
def test_actions_without_directory_structure(self):
tool = InternalSearchTool({"has_directory_structure": False})
meta = tool.get_actions_metadata()
action_names = [a["name"] for a in meta]
assert "search" in action_names
assert "list_files" not in action_names
# search should not have path_filter
search = meta[0]
assert "path_filter" not in search["parameters"]["properties"]
def test_actions_with_directory_structure(self):
tool = InternalSearchTool({"has_directory_structure": True})
meta = tool.get_actions_metadata()
action_names = [a["name"] for a in meta]
assert "search" in action_names
assert "list_files" in action_names
# search should have path_filter
search = next(a for a in meta if a["name"] == "search")
assert "path_filter" in search["parameters"]["properties"]
@pytest.mark.unit
class TestBuildHelpers:
def test_build_entry_without_directory_structure(self):
entry = build_internal_tool_entry(has_directory_structure=False)
assert entry["name"] == "internal_search"
action_names = [a["name"] for a in entry["actions"]]
assert "search" in action_names
assert "list_files" not in action_names
def test_build_entry_with_directory_structure(self):
entry = build_internal_tool_entry(has_directory_structure=True)
action_names = [a["name"] for a in entry["actions"]]
assert "list_files" in action_names
def test_build_config(self):
config = build_internal_tool_config(
source={"active_docs": ["abc"]},
retriever_name="semantic",
chunks=4,
)
assert config["source"] == {"active_docs": ["abc"]}
assert config["retriever_name"] == "semantic"
assert config["chunks"] == 4
def test_internal_tool_id(self):
assert INTERNAL_TOOL_ID == "internal"
def test_add_internal_search_tool_with_sources(self):
tools_dict = {}
retriever_config = {
"source": {"active_docs": ["abc"]},
"retriever_name": "classic",
"chunks": 2,
"model_id": "gpt-4",
"llm_name": "openai",
"api_key": "key",
}
with patch(
"application.agents.tools.internal_search.sources_have_directory_structure",
return_value=False,
):
add_internal_search_tool(tools_dict, retriever_config)
assert INTERNAL_TOOL_ID in tools_dict
assert tools_dict[INTERNAL_TOOL_ID]["name"] == "internal_search"
assert "config" in tools_dict[INTERNAL_TOOL_ID]
def test_add_internal_search_tool_no_sources(self):
tools_dict = {}
retriever_config = {"source": {}}
add_internal_search_tool(tools_dict, retriever_config)
assert INTERNAL_TOOL_ID not in tools_dict
def test_add_internal_search_tool_empty_config(self):
tools_dict = {}
add_internal_search_tool(tools_dict, {})
assert INTERNAL_TOOL_ID not in tools_dict

View File

@@ -0,0 +1,400 @@
"""Tests for ResearchAgent — multi-step research with budget controls."""
import json
import time
from unittest.mock import Mock, patch
import pytest
from application.agents.research_agent import (
CitationManager,
ResearchAgent,
DEFAULT_MAX_STEPS,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TOKEN_BUDGET,
)
# ---------------------------------------------------------------------------
# CitationManager
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCitationManager:
def test_add_returns_citation_number(self):
cm = CitationManager()
num = cm.add({"source": "s1", "title": "T1"})
assert num == 1
def test_add_deduplicates(self):
cm = CitationManager()
n1 = cm.add({"source": "s1", "title": "T1"})
n2 = cm.add({"source": "s1", "title": "T1"})
assert n1 == n2
assert len(cm.citations) == 1
def test_add_different_sources(self):
cm = CitationManager()
n1 = cm.add({"source": "s1", "title": "T1"})
n2 = cm.add({"source": "s2", "title": "T2"})
assert n1 != n2
assert len(cm.citations) == 2
def test_add_docs_returns_mapping(self):
cm = CitationManager()
docs = [
{"source": "s1", "title": "Doc A"},
{"source": "s2", "title": "Doc B"},
]
text = cm.add_docs(docs)
assert "[1] Doc A" in text
assert "[2] Doc B" in text
def test_format_references(self):
cm = CitationManager()
cm.add({"source": "http://example.com", "title": "Example", "filename": "ex.md"})
refs = cm.format_references()
assert "[1]" in refs
assert "ex.md" in refs
def test_format_references_empty(self):
cm = CitationManager()
assert "No sources" in cm.format_references()
def test_get_all_docs(self):
cm = CitationManager()
cm.add({"source": "s1", "title": "T1"})
cm.add({"source": "s2", "title": "T2"})
docs = cm.get_all_docs()
assert len(docs) == 2
# ---------------------------------------------------------------------------
# ResearchAgent Init & Budget
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestResearchAgentInit:
def test_initialization(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ResearchAgent(**agent_base_params)
assert isinstance(agent, ResearchAgent)
assert agent.max_steps == DEFAULT_MAX_STEPS
assert agent.timeout_seconds == DEFAULT_TIMEOUT_SECONDS
assert agent.token_budget == DEFAULT_TOKEN_BUDGET
assert agent.retriever_config == {}
def test_custom_budget(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ResearchAgent(
max_steps=3,
timeout_seconds=60,
token_budget=50_000,
**agent_base_params,
)
assert agent.max_steps == 3
assert agent.timeout_seconds == 60
assert agent.token_budget == 50_000
def test_with_retriever_config(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
rc = {"source": {"active_docs": ["abc"]}}
agent = ResearchAgent(retriever_config=rc, **agent_base_params)
assert agent.retriever_config == rc
@pytest.mark.unit
class TestResearchAgentBudget:
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, **kwargs):
return ResearchAgent(**kwargs, **agent_base_params)
def test_timeout_detection(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
timeout_seconds=0,
)
agent._start_time = time.monotonic() - 1
assert agent._is_timed_out() is True
def test_not_timed_out(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
timeout_seconds=300,
)
agent._start_time = time.monotonic()
assert agent._is_timed_out() is False
def test_token_budget_tracking(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
token_budget=1000,
)
agent._track_tokens(500)
assert agent._budget_remaining() == 500
assert agent._is_over_budget() is False
agent._track_tokens(500)
assert agent._budget_remaining() == 0
assert agent._is_over_budget() is True
def test_snapshot_llm_tokens_returns_delta(
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
)
mock_llm.token_usage = {"prompt_tokens": 100, "generated_tokens": 50}
delta1 = agent._snapshot_llm_tokens()
assert delta1 == 150
# Simulate more tokens used
mock_llm.token_usage = {"prompt_tokens": 200, "generated_tokens": 100}
delta2 = agent._snapshot_llm_tokens()
assert delta2 == 150 # 300 - 150
def test_elapsed(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
)
agent._start_time = time.monotonic() - 1.5
elapsed = agent._elapsed()
assert elapsed >= 1.0
# ---------------------------------------------------------------------------
# ResearchAgent Phases
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestResearchAgentClarification:
def test_is_follow_up_no_history(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = ResearchAgent(**agent_base_params)
assert agent._is_follow_up() is False
def test_is_follow_up_with_clarification_metadata(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["chat_history"] = [
{"prompt": "What?", "response": "Clarify", "metadata": {"is_clarification": True}},
]
agent = ResearchAgent(**agent_base_params)
assert agent._is_follow_up() is True
def test_is_follow_up_without_metadata(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent_base_params["chat_history"] = [
{"prompt": "What?", "response": "Normal answer"},
]
agent = ResearchAgent(**agent_base_params)
assert agent._is_follow_up() is False
def test_clarification_returns_none_on_no_clarification_needed(
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
):
response = Mock()
response.choices = [Mock()]
response.choices[0].message = Mock()
response.choices[0].message.content = json.dumps(
{"needs_clarification": False, "reason": "Clear enough"}
)
mock_llm.gen = Mock(return_value=response)
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
agent = ResearchAgent(**agent_base_params)
result = agent._clarification_phase("What is Python?")
assert result is None
def test_clarification_returns_questions(
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
):
clarification_json = json.dumps({
"needs_clarification": True,
"questions": ["Which version?", "What context?"],
})
# Return a plain string so _extract_text handles it directly
mock_llm.gen = Mock(return_value=clarification_json)
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
agent = ResearchAgent(**agent_base_params)
result = agent._clarification_phase("Tell me about it")
assert result is not None
assert "Which version?" in result
assert "What context?" in result
@pytest.mark.unit
class TestResearchAgentPlanning:
def test_planning_returns_steps_and_complexity(
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
):
plan_json = json.dumps({
"complexity": "moderate",
"steps": [
{"query": "sub-question 1", "rationale": "reason 1"},
{"query": "sub-question 2", "rationale": "reason 2"},
],
})
# Return plain string so _extract_text handles it directly
mock_llm.gen = Mock(return_value=plan_json)
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
agent = ResearchAgent(**agent_base_params)
steps, complexity = agent._planning_phase("Compare A and B")
assert complexity == "moderate"
assert len(steps) == 2
assert steps[0]["query"] == "sub-question 1"
def test_planning_caps_steps_by_complexity(
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
):
plan_json = json.dumps({
"complexity": "simple",
"steps": [
{"query": f"q{i}", "rationale": f"r{i}"} for i in range(10)
],
})
response = Mock()
response.choices = [Mock()]
response.choices[0].message = Mock()
response.choices[0].message.content = plan_json
mock_llm.gen = Mock(return_value=response)
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
agent = ResearchAgent(**agent_base_params)
steps, complexity = agent._planning_phase("Simple question")
assert complexity == "simple"
assert len(steps) <= 2 # COMPLEXITY_CAPS["simple"] == 2
def test_planning_fallback_on_error(
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
):
mock_llm.gen = Mock(side_effect=Exception("LLM down"))
mock_llm.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
agent = ResearchAgent(**agent_base_params)
steps, complexity = agent._planning_phase("Anything")
assert complexity == "simple"
assert len(steps) == 1
assert steps[0]["query"] == "Anything"
@pytest.mark.unit
class TestResearchAgentExtractText:
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
return ResearchAgent(**agent_base_params)
def test_extract_from_string(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
assert agent._extract_text("hello") == "hello"
def test_extract_from_openai_response(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
response = Mock()
response.choices = [Mock()]
response.choices[0].message = Mock()
response.choices[0].message.content = "OpenAI content"
response.message = None
response.content = None
assert agent._extract_text(response) == "OpenAI content"
def test_extract_from_anthropic_response(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
text_block = Mock()
text_block.text = "Anthropic content"
response = Mock()
response.content = [text_block]
response.message = None
response.choices = None
assert agent._extract_text(response) == "Anthropic content"
def test_extract_from_none(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
assert agent._extract_text(None) == ""
@pytest.mark.unit
class TestResearchAgentParseJson:
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
return ResearchAgent(**agent_base_params)
def test_parse_plan_direct_json(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
text = '{"steps": [{"query": "q1"}], "complexity": "simple"}'
result = agent._parse_plan_json(text)
assert isinstance(result, dict)
assert len(result["steps"]) == 1
def test_parse_plan_from_code_fence(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
text = 'Here is the plan:\n```json\n{"steps": [{"query": "q1"}]}\n```'
result = agent._parse_plan_json(text)
assert isinstance(result, dict)
def test_parse_plan_invalid_returns_empty(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
result = agent._parse_plan_json("not json at all")
assert result == []
def test_parse_clarification_json(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
text = '{"needs_clarification": false, "reason": "clear"}'
result = agent._parse_clarification_json(text)
assert result["needs_clarification"] is False
def test_parse_clarification_json_from_code_fence(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
text = '```json\n{"needs_clarification": true, "questions": ["q1"]}\n```'
result = agent._parse_clarification_json(text)
assert result["needs_clarification"] is True
def test_parse_clarification_json_invalid(
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
result = agent._parse_clarification_json("not json")
assert result is None

View File

@@ -0,0 +1,52 @@
"""Tests for ThinkTool — the chain-of-thought pseudo-tool."""
import pytest
from application.agents.tools.think import (
THINK_TOOL_ENTRY,
THINK_TOOL_ID,
ThinkTool,
)
@pytest.mark.unit
class TestThinkTool:
def test_id_constant(self):
assert THINK_TOOL_ID == "think"
def test_entry_has_reason_action(self):
actions = THINK_TOOL_ENTRY["actions"]
assert len(actions) == 1
assert actions[0]["name"] == "reason"
assert actions[0]["active"] is True
def test_execute_reason_returns_continue(self):
tool = ThinkTool()
result = tool.execute_action("reason", reasoning="step by step thinking")
assert result == "Continue."
def test_execute_unknown_action_returns_continue(self):
tool = ThinkTool()
result = tool.execute_action("unknown_action")
assert result == "Continue."
def test_get_actions_metadata(self):
tool = ThinkTool()
meta = tool.get_actions_metadata()
assert len(meta) == 1
assert meta[0]["name"] == "reason"
props = meta[0]["parameters"]["properties"]
assert "reasoning" in props
assert props["reasoning"]["filled_by_llm"] is True
def test_get_config_requirements_empty(self):
tool = ThinkTool()
assert tool.get_config_requirements() == {}
def test_init_accepts_no_config(self):
tool = ThinkTool()
assert tool is not None
def test_init_accepts_config(self):
tool = ThinkTool(config={"key": "value"})
assert tool is not None

View File

@@ -0,0 +1,279 @@
"""Tests for ToolExecutor — tool discovery, preparation, and execution."""
from unittest.mock import Mock
import pytest
from application.agents.tool_executor import ToolExecutor
@pytest.mark.unit
class TestToolExecutorInit:
def test_default_state(self):
executor = ToolExecutor()
assert executor.user_api_key is None
assert executor.user is None
assert executor.tool_calls == []
assert executor._loaded_tools == {}
assert executor.conversation_id is None
def test_init_with_params(self):
executor = ToolExecutor(
user_api_key="key", user="alice", decoded_token={"sub": "alice"}
)
assert executor.user_api_key == "key"
assert executor.user == "alice"
@pytest.mark.unit
class TestToolExecutorGetTools:
def test_get_tools_uses_api_key_when_present(self, mock_mongo_db):
executor = ToolExecutor(user_api_key="test_key", user="alice")
tools = executor.get_tools()
assert isinstance(tools, dict)
def test_get_tools_uses_user_when_no_api_key(self, mock_mongo_db):
executor = ToolExecutor(user="alice")
tools = executor.get_tools()
assert isinstance(tools, dict)
def test_get_tools_defaults_to_local(self, mock_mongo_db):
executor = ToolExecutor()
tools = executor.get_tools()
assert isinstance(tools, dict)
@pytest.mark.unit
class TestToolExecutorPrepare:
def test_prepare_tools_for_llm_empty(self):
executor = ToolExecutor()
result = executor.prepare_tools_for_llm({})
assert result == []
def test_prepare_tools_for_llm_non_api_tool(self):
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "test_tool",
"actions": [
{
"name": "do_thing",
"description": "Does a thing",
"active": True,
"parameters": {
"properties": {
"query": {
"type": "string",
"description": "The query",
"filled_by_llm": True,
"required": True,
}
}
},
}
],
}
}
result = executor.prepare_tools_for_llm(tools_dict)
assert len(result) == 1
assert result[0]["type"] == "function"
assert result[0]["function"]["name"] == "do_thing_t1"
assert "query" in result[0]["function"]["parameters"]["properties"]
def test_prepare_tools_skips_inactive_actions(self):
executor = ToolExecutor()
tools_dict = {
"t1": {
"name": "test_tool",
"actions": [
{"name": "active_one", "description": "D", "active": True, "parameters": {"properties": {}}},
{"name": "inactive_one", "description": "D", "active": False, "parameters": {"properties": {}}},
],
}
}
result = executor.prepare_tools_for_llm(tools_dict)
assert len(result) == 1
assert result[0]["function"]["name"] == "active_one_t1"
def test_build_tool_parameters_filters_non_llm_fields(self):
executor = ToolExecutor()
action = {
"parameters": {
"properties": {
"query": {
"type": "string",
"description": "Search query",
"filled_by_llm": True,
"value": "default_val",
"required": True,
},
"hidden": {
"type": "string",
"filled_by_llm": False,
},
}
}
}
result = executor._build_tool_parameters(action)
assert "query" in result["properties"]
assert "hidden" not in result["properties"]
assert "query" in result["required"]
# filled_by_llm, value, required stripped from schema
assert "filled_by_llm" not in result["properties"]["query"]
assert "value" not in result["properties"]["query"]
@pytest.mark.unit
class TestToolExecutorExecute:
def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"):
call = Mock()
call.name = name
call.id = call_id
call.arguments = arguments
return call
def test_execute_parse_failure(self, monkeypatch):
executor = ToolExecutor()
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=(None, None, {}))),
)
call = self._make_call(name="bad")
gen = executor.execute({}, call, "MockLLM")
events = []
result = None
while True:
try:
events.append(next(gen))
except StopIteration as e:
result = e.value
break
assert result[0] == "Failed to parse tool call."
assert len(executor.tool_calls) == 1
assert events[0]["data"]["status"] == "error"
def test_execute_tool_not_found(self, monkeypatch):
executor = ToolExecutor()
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))),
)
call = self._make_call()
gen = executor.execute({}, call, "MockLLM")
events = []
result = None
while True:
try:
events.append(next(gen))
except StopIteration as e:
result = e.value
break
assert "not found" in result[0]
assert events[0]["data"]["status"] == "error"
def test_execute_success(self, mock_tool_manager, monkeypatch):
executor = ToolExecutor(user="test_user")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))),
)
tools_dict = {
"t1": {
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "Test", "parameters": {"properties": {}}},
],
}
}
call = self._make_call(name="test_action_t1", call_id="c1")
gen = executor.execute(tools_dict, call, "MockLLM")
events = []
result = None
while True:
try:
events.append(next(gen))
except StopIteration as e:
result = e.value
break
assert result[0] == "Tool result"
assert result[1] == "c1"
statuses = [e["data"]["status"] for e in events]
assert "pending" in statuses
assert "completed" in statuses
def test_get_truncated_tool_calls(self):
executor = ToolExecutor()
executor.tool_calls = [
{
"tool_name": "test",
"call_id": "1",
"action_name": "act",
"arguments": {},
"result": "A" * 100,
}
]
truncated = executor.get_truncated_tool_calls()
assert len(truncated) == 1
assert len(truncated[0]["result"]) <= 53
assert truncated[0]["status"] == "completed"
def test_tool_caching(self, mock_tool_manager, monkeypatch):
executor = ToolExecutor(user="test_user")
monkeypatch.setattr(
"application.agents.tool_executor.ToolActionParser",
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))),
)
tools_dict = {
"t1": {
"name": "test_tool",
"config": {"key": "val"},
"actions": [
{"name": "test_action", "description": "Test", "parameters": {"properties": {}}},
],
}
}
call = self._make_call(name="test_action_t1")
# First execution — loads tool
gen = executor.execute(tools_dict, call, "MockLLM")
while True:
try:
next(gen)
except StopIteration:
break
# Second execution — should use cache
gen = executor.execute(tools_dict, call, "MockLLM")
while True:
try:
next(gen)
except StopIteration:
break
# load_tool called only once due to cache
assert mock_tool_manager.load_tool.call_count == 1

View File

@@ -0,0 +1,475 @@
"""Tests for new agent types (agentic, research) in the workflow builder."""
from types import SimpleNamespace
from typing import Any, Dict
import pytest
from application.agents.agentic_agent import AgenticAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflows.node_agent import (
WorkflowNodeAgenticAgent,
WorkflowNodeAgentFactory,
WorkflowNodeClassicAgent,
WorkflowNodeResearchAgent,
)
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
NodeType,
Workflow,
WorkflowGraph,
WorkflowNode,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class StubNodeAgent:
"""Minimal agent stub that yields pre-defined events."""
def __init__(self, events):
self.events = events
def gen(self, _prompt):
yield from self.events
def create_engine() -> WorkflowEngine:
graph = WorkflowGraph(workflow=Workflow(name="Test"), nodes=[], edges=[])
agent = SimpleNamespace(
endpoint="stream",
llm_name="openai",
model_id="gpt-4o-mini",
api_key="test-key",
chat_history=[],
decoded_token={"sub": "user-1"},
)
return WorkflowEngine(graph, agent)
def create_agent_node(
node_id: str,
agent_type: str = "classic",
sources: list = None,
chunks: str = "2",
retriever: str = "",
output_variable: str = "",
) -> WorkflowNode:
config: Dict[str, Any] = {
"agent_type": agent_type,
"system_prompt": "You are a helpful assistant.",
"prompt_template": "",
"stream_to_user": False,
"tools": [],
"sources": sources or [],
"chunks": chunks,
"retriever": retriever,
}
if output_variable:
config["output_variable"] = output_variable
return WorkflowNode(
id=node_id,
workflow_id="workflow-1",
type=NodeType.AGENT,
title="Agent",
position={"x": 0, "y": 0},
config=config,
)
# ---------------------------------------------------------------------------
# AgentType enum
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestAgentTypeEnum:
def test_agentic_value_exists(self):
assert AgentType.AGENTIC == "agentic"
def test_research_value_exists(self):
assert AgentType.RESEARCH == "research"
def test_classic_still_exists(self):
assert AgentType.CLASSIC == "classic"
def test_react_still_exists(self):
assert AgentType.REACT == "react"
# ---------------------------------------------------------------------------
# AgentNodeConfig schema validation
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestAgentNodeConfigValidation:
def test_accepts_agentic_agent_type(self):
config = AgentNodeConfig(agent_type="agentic")
assert config.agent_type == AgentType.AGENTIC
def test_accepts_research_agent_type(self):
config = AgentNodeConfig(agent_type="research")
assert config.agent_type == AgentType.RESEARCH
def test_rejects_unknown_agent_type(self):
with pytest.raises(Exception):
AgentNodeConfig(agent_type="nonexistent")
def test_default_agent_type_is_classic(self):
config = AgentNodeConfig()
assert config.agent_type == AgentType.CLASSIC
# ---------------------------------------------------------------------------
# WorkflowNodeAgentFactory registry
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestWorkflowNodeAgentFactoryRegistry:
def test_factory_has_agentic(self):
assert AgentType.AGENTIC in WorkflowNodeAgentFactory._agents
assert WorkflowNodeAgentFactory._agents[AgentType.AGENTIC] is WorkflowNodeAgenticAgent
def test_factory_has_research(self):
assert AgentType.RESEARCH in WorkflowNodeAgentFactory._agents
assert WorkflowNodeAgentFactory._agents[AgentType.RESEARCH] is WorkflowNodeResearchAgent
def test_factory_raises_for_unknown_type(self):
with pytest.raises(ValueError, match="Unsupported agent type"):
WorkflowNodeAgentFactory.create(
agent_type="nonexistent",
endpoint="stream",
llm_name="openai",
model_id="gpt-4o-mini",
api_key="key",
)
# ---------------------------------------------------------------------------
# WorkflowNode agent classes (inheritance)
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestWorkflowNodeAgentClasses:
def test_agentic_agent_inherits_correctly(self):
assert issubclass(WorkflowNodeAgenticAgent, AgenticAgent)
def test_research_agent_inherits_correctly(self):
assert issubclass(WorkflowNodeResearchAgent, ResearchAgent)
def test_classic_agent_inherits_correctly(self):
assert issubclass(WorkflowNodeClassicAgent, ClassicAgent)
# ---------------------------------------------------------------------------
# Workflow engine: agentic agent node execution
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestWorkflowEngineAgenticNode:
def test_agentic_node_executes_and_saves_output(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_agentic",
agent_type="agentic",
output_variable="result",
)
node_events = [{"answer": "agentic answer"}]
captured: Dict[str, Any] = {}
def capture_create(**kwargs):
captured.update(kwargs)
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(capture_create),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
assert engine.state["node_agent_agentic_output"] == "agentic answer"
assert engine.state["result"] == "agentic answer"
def test_agentic_node_passes_retriever_config(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_rc",
agent_type="agentic",
sources=["source-abc"],
chunks="4",
retriever="semantic",
)
node_events = [{"answer": "ok"}]
captured: Dict[str, Any] = {}
def capture_create(**kwargs):
captured.update(kwargs)
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(capture_create),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
rc = captured.get("retriever_config")
assert rc is not None
assert rc["source"] == {"active_docs": ["source-abc"]}
assert rc["retriever_name"] == "semantic"
assert rc["chunks"] == 4
def test_agentic_node_empty_sources_gives_empty_source_dict(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_nosrc",
agent_type="agentic",
sources=[],
)
node_events = [{"answer": "ok"}]
captured: Dict[str, Any] = {}
def capture_create(**kwargs):
captured.update(kwargs)
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(capture_create),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
rc = captured["retriever_config"]
assert rc["source"] == {}
# ---------------------------------------------------------------------------
# Workflow engine: research agent node execution
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestWorkflowEngineResearchNode:
def test_research_node_executes_and_saves_output(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_research",
agent_type="research",
output_variable="report",
)
node_events = [{"answer": "research report"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
assert engine.state["node_agent_research_output"] == "research report"
assert engine.state["report"] == "research report"
def test_research_node_passes_retriever_config(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_rr",
agent_type="research",
sources=["doc-1", "doc-2"],
chunks="6",
)
node_events = [{"answer": "ok"}]
captured: Dict[str, Any] = {}
def capture_create(**kwargs):
captured.update(kwargs)
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(capture_create),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
rc = captured["retriever_config"]
assert rc["source"] == {"active_docs": ["doc-1", "doc-2"]}
assert rc["chunks"] == 6
assert rc["decoded_token"] == {"sub": "user-1"}
def test_research_node_handles_structured_output(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_rs",
agent_type="research",
output_variable="data",
)
# Simulate structured JSON output from research agent
node_events = [
{"answer": '{"findings": "important"}', "structured": True},
]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
# structured=True causes the engine to parse JSON
assert engine.state["data"] == {"findings": "important"}
# ---------------------------------------------------------------------------
# Workflow engine: classic node does NOT get retriever_config
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestWorkflowEngineClassicNodeNoRetrieverConfig:
def test_classic_node_does_not_pass_retriever_config(self, monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_classic",
agent_type="classic",
sources=["some-source"],
)
node_events = [{"answer": "classic answer"}]
captured: Dict[str, Any] = {}
def capture_create(**kwargs):
captured.update(kwargs)
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(capture_create),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
assert "retriever_config" not in captured
# ---------------------------------------------------------------------------
# Workflow engine: streaming events from new agent types
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestWorkflowEngineStreamingEvents:
def test_agentic_node_streams_answer_events(self, monkeypatch):
engine = create_engine()
node = create_agent_node(node_id="agent_s1", agent_type="agentic")
# Modify config to enable streaming
node.config["stream_to_user"] = True
node_events = [
{"answer": "chunk 1"},
{"answer": "chunk 2"},
]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
results = list(engine._execute_agent_node(node))
answer_events = [r for r in results if "answer" in r]
assert len(answer_events) == 2
def test_research_node_filters_non_answer_events(self, monkeypatch):
"""Research agents yield research_plan/research_progress events.
The workflow engine only forwards 'answer' events to the user."""
engine = create_engine()
node = create_agent_node(node_id="agent_s2", agent_type="research")
node.config["stream_to_user"] = True
node_events = [
{"type": "research_plan", "data": {"steps": [], "complexity": "simple"}},
{"type": "research_progress", "data": {"status": "planning"}},
{"answer": "final report"},
]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
results = list(engine._execute_agent_node(node))
# Only answer events are streamed to user
answer_events = [r for r in results if "answer" in r]
assert len(answer_events) == 1
assert answer_events[0]["answer"] == "final report"
# State still captures the full text
assert engine.state["node_agent_s2_output"] == "final report"

View File

@@ -492,6 +492,486 @@ This is test documentation for integration tests.
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Agentic / Research Agent Tests
# -------------------------------------------------------------------------
def _create_agent_with_type(self, agent_type: str) -> Optional[tuple]:
"""Create a test agent with the given agent_type. Returns (agent_id, api_key) or None."""
if not self.is_authenticated:
return None
payload = {
"name": f"Chat Test {agent_type.title()} Agent {int(time.time())}",
"description": f"Integration test {agent_type} agent",
"prompt_id": "default",
"chunks": 2,
"retriever": "classic",
"agent_type": agent_type,
"status": "draft",
}
try:
response = self.post("/api/create_agent", json=payload, timeout=10)
if response.status_code in [200, 201]:
result = response.json()
agent_id = result.get("id")
api_key = result.get("key")
if agent_id:
return (agent_id, api_key)
except Exception:
pass
return None
def test_stream_agentic_agent(self) -> bool:
"""Test /stream endpoint with an agentic agent."""
test_name = "Stream endpoint (agentic agent)"
agent_result = self._create_agent_with_type("agentic")
if not agent_result:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
agent_id, _ = agent_result
self.print_header(f"Testing {test_name}")
payload = {
"question": "What is DocsGPT?",
"history": "[]",
"agent_id": agent_id,
}
try:
self.print_info(f"POST /stream with agentic agent_id={agent_id[:8]}...")
response = requests.post(
f"{self.base_url}/stream",
json=payload,
headers=self.headers,
stream=True,
timeout=60,
)
self.print_info(f"Status Code: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
events = []
full_response = ""
for line in response.iter_lines():
if line:
line_str = line.decode("utf-8")
if line_str.startswith("data: "):
try:
data = json_module.loads(line_str[6:])
events.append(data)
if data.get("type") in ["stream", "answer"]:
full_response += data.get("message", "") or data.get("answer", "")
elif data.get("type") == "end":
break
except json_module.JSONDecodeError:
pass
self.print_success(f"Received {len(events)} events")
if full_response:
self.print_success(f"Answer preview: {full_response[:100]}...")
self.record_result(test_name, True, "Success")
return True
except Exception as e:
self.print_error(f"Error: {str(e)}")
self.record_result(test_name, False, str(e))
return False
def test_answer_agentic_agent(self) -> bool:
"""Test /api/answer endpoint with an agentic agent."""
test_name = "Answer endpoint (agentic agent)"
agent_result = self._create_agent_with_type("agentic")
if not agent_result:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
agent_id, _ = agent_result
self.print_header(f"Testing {test_name}")
payload = {
"question": "What is DocsGPT?",
"history": "[]",
"agent_id": agent_id,
}
try:
self.print_info(f"POST /api/answer with agentic agent_id={agent_id[:8]}...")
response = self.post("/api/answer", json=payload, timeout=60)
self.print_info(f"Status Code: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
result = response.json()
answer = result.get("answer", "")
self.print_success(f"Answer received: {answer[:100]}...")
self.record_result(test_name, True, "Success")
return True
except Exception as e:
self.print_error(f"Error: {str(e)}")
self.record_result(test_name, False, str(e))
return False
def test_stream_research_agent(self) -> bool:
"""Test /stream endpoint with a research agent.
Research agents emit additional SSE event types:
- research_plan: the decomposed research steps
- research_progress: per-step status updates
"""
test_name = "Stream endpoint (research agent)"
agent_result = self._create_agent_with_type("research")
if not agent_result:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
agent_id, _ = agent_result
self.print_header(f"Testing {test_name}")
payload = {
"question": "What is DocsGPT?",
"history": "[]",
"agent_id": agent_id,
}
try:
self.print_info(f"POST /stream with research agent_id={agent_id[:8]}...")
response = requests.post(
f"{self.base_url}/stream",
json=payload,
headers=self.headers,
stream=True,
timeout=120,
)
self.print_info(f"Status Code: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
events = []
full_response = ""
saw_plan = False
saw_progress = False
for line in response.iter_lines():
if line:
line_str = line.decode("utf-8")
if line_str.startswith("data: "):
try:
data = json_module.loads(line_str[6:])
events.append(data)
event_type = data.get("type", "")
if event_type in ["stream", "answer"]:
full_response += data.get("message", "") or data.get("answer", "")
elif event_type == "research_plan":
saw_plan = True
                                steps = data.get("data", {}).get("steps", [])
                                complexity = data.get("data", {}).get("complexity")
                                self.print_info(
                                    f"Research plan: {len(steps)} steps, "
                                    f"complexity={complexity}"
                                )
elif event_type == "research_progress":
saw_progress = True
elif event_type == "end":
break
except json_module.JSONDecodeError:
pass
self.print_success(f"Received {len(events)} events")
if saw_plan:
self.print_success("Received research_plan event")
if saw_progress:
self.print_success("Received research_progress events")
if full_response:
self.print_success(f"Report preview: {full_response[:100]}...")
self.record_result(test_name, True, "Success")
return True
except Exception as e:
self.print_error(f"Error: {str(e)}")
self.record_result(test_name, False, str(e))
return False
def test_answer_research_agent(self) -> bool:
"""Test /api/answer endpoint with a research agent."""
test_name = "Answer endpoint (research agent)"
agent_result = self._create_agent_with_type("research")
if not agent_result:
if not self.require_auth(test_name):
return True
self.record_result(test_name, True, "Skipped (no agent)")
return True
agent_id, _ = agent_result
self.print_header(f"Testing {test_name}")
payload = {
"question": "What is DocsGPT?",
"history": "[]",
"agent_id": agent_id,
}
try:
self.print_info(f"POST /api/answer with research agent_id={agent_id[:8]}...")
response = self.post("/api/answer", json=payload, timeout=120)
self.print_info(f"Status Code: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
result = response.json()
answer = result.get("answer", "")
self.print_success(f"Answer received: {answer[:100]}...")
self.record_result(test_name, True, "Success")
return True
except Exception as e:
self.print_error(f"Error: {str(e)}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Workflow with Agentic Node Tests
# -------------------------------------------------------------------------
def _create_workflow_with_agentic_node(self) -> Optional[str]:
"""Create a workflow with a single agentic agent node.
Returns the workflow_id or None on failure.
"""
nodes = [
{
"id": "start_1",
"type": "start",
"title": "Start",
"data": {},
},
{
"id": "agent_1",
"type": "agent",
"title": "Agentic Node",
"data": {
"agent_type": "agentic",
"system_prompt": "You are a helpful assistant.",
"prompt_template": "",
"stream_to_user": True,
"tools": [],
"sources": [],
"chunks": "2",
},
},
{
"id": "end_1",
"type": "end",
"title": "End",
"data": {},
},
]
edges = [
{"id": "edge_1", "source": "start_1", "target": "agent_1"},
{"id": "edge_2", "source": "agent_1", "target": "end_1"},
]
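        # The graph is a linear pipeline: start_1 -> agent_1 -> end_1, so the
        # agentic node's streamed output is the workflow's final answer.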
payload = {
"name": f"Agentic Workflow Test {int(time.time())}",
"nodes": nodes,
"edges": edges,
}
try:
response = self.post("/api/workflows", json=payload, timeout=15)
if response.status_code in [200, 201]:
result = response.json()
return result.get("id")
except Exception:
pass
return None
def _create_workflow_agent(self, workflow_id: str) -> Optional[tuple]:
"""Create an agent of type 'workflow' referencing a workflow.
Returns (agent_id, api_key) or None.
"""
payload = {
"name": f"Workflow Agent Test {int(time.time())}",
"description": "Integration test workflow agent with agentic node",
"prompt_id": "default",
"chunks": 2,
"retriever": "classic",
"agent_type": "workflow",
"workflow": workflow_id,
"status": "draft",
}
try:
response = self.post("/api/create_agent", json=payload, timeout=10)
if response.status_code in [200, 201]:
result = response.json()
agent_id = result.get("id")
api_key = result.get("key")
if agent_id:
return (agent_id, api_key)
except Exception:
pass
return None
def test_stream_workflow_with_agentic_node(self) -> bool:
"""Test /stream with a workflow agent that contains an agentic node."""
test_name = "Stream endpoint (workflow with agentic node)"
if not self.require_auth(test_name):
return True
workflow_id = self._create_workflow_with_agentic_node()
if not workflow_id:
self.print_warning("Could not create workflow")
self.record_result(test_name, True, "Skipped (workflow creation failed)")
return True
agent_result = self._create_workflow_agent(workflow_id)
if not agent_result:
self.print_warning("Could not create workflow agent")
self.record_result(test_name, True, "Skipped (agent creation failed)")
return True
agent_id, _ = agent_result
self.print_header(f"Testing {test_name}")
payload = {
"question": "What is DocsGPT?",
"history": "[]",
"agent_id": agent_id,
}
try:
self.print_info(f"POST /stream with workflow agent_id={agent_id[:8]}...")
self.print_info(f"Workflow ID: {workflow_id}")
response = requests.post(
f"{self.base_url}/stream",
json=payload,
headers=self.headers,
stream=True,
timeout=60,
)
self.print_info(f"Status Code: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
events = []
full_response = ""
for line in response.iter_lines():
if line:
line_str = line.decode("utf-8")
if line_str.startswith("data: "):
try:
data = json_module.loads(line_str[6:])
events.append(data)
if data.get("type") in ["stream", "answer"]:
full_response += data.get("message", "") or data.get("answer", "")
elif data.get("type") == "end":
break
except json_module.JSONDecodeError:
pass
self.print_success(f"Received {len(events)} events")
if full_response:
self.print_success(f"Answer preview: {full_response[:100]}...")
self.record_result(test_name, True, "Success")
return True
except Exception as e:
self.print_error(f"Error: {str(e)}")
self.record_result(test_name, False, str(e))
return False
def test_answer_workflow_with_agentic_node(self) -> bool:
"""Test /api/answer with a workflow agent that contains an agentic node."""
test_name = "Answer endpoint (workflow with agentic node)"
if not self.require_auth(test_name):
return True
workflow_id = self._create_workflow_with_agentic_node()
if not workflow_id:
self.print_warning("Could not create workflow")
self.record_result(test_name, True, "Skipped (workflow creation failed)")
return True
agent_result = self._create_workflow_agent(workflow_id)
if not agent_result:
self.print_warning("Could not create workflow agent")
self.record_result(test_name, True, "Skipped (agent creation failed)")
return True
agent_id, _ = agent_result
self.print_header(f"Testing {test_name}")
payload = {
"question": "What is DocsGPT?",
"history": "[]",
"agent_id": agent_id,
}
try:
self.print_info(f"POST /api/answer with workflow agent_id={agent_id[:8]}...")
response = self.post("/api/answer", json=payload, timeout=60)
self.print_info(f"Status Code: {response.status_code}")
if response.status_code != 200:
self.print_error(f"Expected 200, got {response.status_code}")
self.record_result(test_name, False, f"Status {response.status_code}")
return False
result = response.json()
answer = result.get("answer", "")
self.print_success(f"Answer received: {answer[:100]}...")
self.record_result(test_name, True, "Success")
return True
except Exception as e:
self.print_error(f"Error: {str(e)}")
self.record_result(test_name, False, str(e))
return False
# -------------------------------------------------------------------------
# Validation Tests
# -------------------------------------------------------------------------
@@ -916,6 +1396,27 @@ This is test documentation for integration tests.
self.test_answer_endpoint_with_agent()
time.sleep(1)
# Agentic agent tests
self.test_stream_agentic_agent()
time.sleep(1)
self.test_answer_agentic_agent()
time.sleep(1)
# Research agent tests
self.test_stream_research_agent()
time.sleep(2)
self.test_answer_research_agent()
time.sleep(2)
# Workflow with agentic node tests
self.test_stream_workflow_with_agentic_node()
time.sleep(2)
self.test_answer_workflow_with_agentic_node()
time.sleep(2)
# API key tests
self.test_stream_endpoint_with_api_key()
time.sleep(1)