mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
mini refactors
This commit is contained in:
@@ -4,8 +4,7 @@ from typing import Dict, Generator, Optional
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
build_internal_tool_config,
|
||||
build_internal_tool_entry,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.logging import LogContext
|
||||
|
||||
@@ -32,24 +31,8 @@ class AgenticAgent(BaseAgent):
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
# 1. Get user tools (same as ClassicAgent)
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
|
||||
# 2. Add internal search as a synthetic tool (only if sources are configured)
|
||||
source = self.retriever_config.get("source", {})
|
||||
has_sources = bool(source.get("active_docs"))
|
||||
if self.retriever_config and has_sources:
|
||||
has_dir = _sources_have_directory_structure(source)
|
||||
internal_entry = build_internal_tool_entry(
|
||||
has_directory_structure=has_dir
|
||||
)
|
||||
internal_entry["config"] = build_internal_tool_config(
|
||||
**self.retriever_config,
|
||||
has_directory_structure=has_dir,
|
||||
)
|
||||
tools_dict[INTERNAL_TOOL_ID] = internal_entry
|
||||
|
||||
# 3. Prepare all tools for the LLM
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
# 4. Build messages (prompt has NO pre-fetched docs)
|
||||
@@ -78,40 +61,3 @@ class AgenticAgent(BaseAgent):
|
||||
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
|
||||
self.retrieved_docs = tool.retrieved_docs
|
||||
|
||||
|
||||
def _sources_have_directory_structure(source: Dict) -> bool:
|
||||
"""Check if any of the active sources have directory_structure in MongoDB."""
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return False
|
||||
|
||||
try:
|
||||
from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(doc_id)},
|
||||
{"directory_structure": 1},
|
||||
)
|
||||
if source_doc and source_doc.get("directory_structure"):
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check directory structure: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Import settings at module level for _sources_have_directory_structure
|
||||
from application.core.settings import settings # noqa: E402
|
||||
|
||||
@@ -6,11 +6,9 @@ from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.agents.agentic_agent import _sources_have_directory_structure
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
build_internal_tool_config,
|
||||
build_internal_tool_entry,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
|
||||
from application.logging import LogContext
|
||||
@@ -130,6 +128,7 @@ class ResearchAgent(BaseAgent):
|
||||
self.citations = CitationManager()
|
||||
self._start_time: float = 0
|
||||
self._tokens_used: int = 0
|
||||
self._last_token_snapshot: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Budget & timeout helpers
|
||||
@@ -153,7 +152,9 @@ class ResearchAgent(BaseAgent):
|
||||
def _snapshot_llm_tokens(self) -> int:
|
||||
"""Read current token usage from LLM and return delta since last snapshot."""
|
||||
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
|
||||
return current
|
||||
delta = current - self._last_token_snapshot
|
||||
self._last_token_snapshot = current
|
||||
return delta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main orchestration
|
||||
@@ -272,21 +273,7 @@ class ResearchAgent(BaseAgent):
|
||||
"""Build tools_dict with user tools + internal search + think."""
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
|
||||
# Only add internal search if sources are configured
|
||||
source = self.retriever_config.get("source", {})
|
||||
has_sources = bool(source.get("active_docs"))
|
||||
if self.retriever_config and has_sources:
|
||||
has_dir = _sources_have_directory_structure(source)
|
||||
internal_entry = build_internal_tool_entry(
|
||||
has_directory_structure=has_dir
|
||||
)
|
||||
internal_entry["config"] = build_internal_tool_config(
|
||||
**self.retriever_config,
|
||||
has_directory_structure=has_dir,
|
||||
)
|
||||
tools_dict[INTERNAL_TOOL_ID] = internal_entry
|
||||
elif self.retriever_config and not has_sources:
|
||||
logger.info("ResearchAgent: No sources configured, skipping internal_search tool")
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
|
||||
think_entry = dict(THINK_TOOL_ENTRY)
|
||||
think_entry["config"] = {}
|
||||
@@ -580,7 +567,14 @@ class ResearchAgent(BaseAgent):
|
||||
call_id = None
|
||||
while True:
|
||||
try:
|
||||
next(gen)
|
||||
event = next(gen)
|
||||
# Log tool_call status events instead of discarding them
|
||||
if isinstance(event, dict) and event.get("type") == "tool_call":
|
||||
logger.debug(
|
||||
"Tool %s status: %s",
|
||||
event.get("data", {}).get("action_name", ""),
|
||||
event.get("data", {}).get("status", ""),
|
||||
)
|
||||
except StopIteration as e:
|
||||
result, call_id = e.value
|
||||
break
|
||||
|
||||
@@ -354,6 +354,59 @@ def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
|
||||
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
|
||||
|
||||
|
||||
def sources_have_directory_structure(source: Dict) -> bool:
|
||||
"""Check if any of the active sources have directory_structure in MongoDB."""
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return False
|
||||
|
||||
try:
|
||||
from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(doc_id)},
|
||||
{"directory_structure": 1},
|
||||
)
|
||||
if source_doc and source_doc.get("directory_structure"):
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check directory structure: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
|
||||
"""Add the internal search tool to tools_dict if sources are configured.
|
||||
|
||||
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
|
||||
Mutates tools_dict in place.
|
||||
"""
|
||||
source = retriever_config.get("source", {})
|
||||
has_sources = bool(source.get("active_docs"))
|
||||
if not retriever_config or not has_sources:
|
||||
return
|
||||
|
||||
has_dir = sources_have_directory_structure(source)
|
||||
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
|
||||
internal_entry["config"] = build_internal_tool_config(
|
||||
**retriever_config,
|
||||
has_directory_structure=has_dir,
|
||||
)
|
||||
tools_dict[INTERNAL_TOOL_ID] = internal_entry
|
||||
|
||||
|
||||
def build_internal_tool_config(
|
||||
source: Dict,
|
||||
retriever_name: str = "classic",
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflows.schemas import AgentType
|
||||
|
||||
|
||||
@@ -35,7 +37,8 @@ class ToolFilterMixin:
|
||||
return filtered_tools
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
|
||||
class _WorkflowNodeMixin:
|
||||
"""Common __init__ for all workflow node agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -56,11 +59,25 @@ class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
|
||||
self._allowed_tool_ids = tool_ids or []
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeAgentFactory:
|
||||
|
||||
_agents: Dict[AgentType, Type[BaseAgent]] = {
|
||||
AgentType.CLASSIC: WorkflowNodeClassicAgent,
|
||||
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
|
||||
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
|
||||
AgentType.RESEARCH: WorkflowNodeResearchAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -18,6 +18,8 @@ class NodeType(str, Enum):
|
||||
class AgentType(str, Enum):
|
||||
CLASSIC = "classic"
|
||||
REACT = "react"
|
||||
AGENTIC = "agentic"
|
||||
RESEARCH = "research"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
|
||||
@@ -7,6 +7,7 @@ from application.agents.workflows.cel_evaluator import CelEvaluationError, evalu
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
@@ -223,18 +224,32 @@ class WorkflowEngine:
|
||||
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
|
||||
)
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(
|
||||
agent_type=node_config.agent_type,
|
||||
endpoint=self.agent.endpoint,
|
||||
llm_name=node_llm_name,
|
||||
model_id=node_model_id,
|
||||
api_key=node_api_key,
|
||||
tool_ids=node_config.tools,
|
||||
prompt=node_config.system_prompt,
|
||||
chat_history=self.agent.chat_history,
|
||||
decoded_token=self.agent.decoded_token,
|
||||
json_schema=node_json_schema,
|
||||
)
|
||||
factory_kwargs = {
|
||||
"agent_type": node_config.agent_type,
|
||||
"endpoint": self.agent.endpoint,
|
||||
"llm_name": node_llm_name,
|
||||
"model_id": node_model_id,
|
||||
"api_key": node_api_key,
|
||||
"tool_ids": node_config.tools,
|
||||
"prompt": node_config.system_prompt,
|
||||
"chat_history": self.agent.chat_history,
|
||||
"decoded_token": self.agent.decoded_token,
|
||||
"json_schema": node_json_schema,
|
||||
}
|
||||
|
||||
# Agentic/research agents need retriever_config for on-demand search
|
||||
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
|
||||
factory_kwargs["retriever_config"] = {
|
||||
"source": {"active_docs": node_config.sources} if node_config.sources else {},
|
||||
"retriever_name": node_config.retriever or "classic",
|
||||
"chunks": int(node_config.chunks) if node_config.chunks else 2,
|
||||
"model_id": node_model_id,
|
||||
"llm_name": node_llm_name,
|
||||
"api_key": node_api_key,
|
||||
"decoded_token": self.agent.decoded_token,
|
||||
}
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
|
||||
|
||||
full_response_parts: List[str] = []
|
||||
structured_response_parts: List[str] = []
|
||||
|
||||
@@ -13,4 +13,4 @@ Use the search_internal tool to find relevant information before answering quest
|
||||
You may search multiple times with different queries if needed.
|
||||
Do not guess when documents are available — search first, then answer based on what you find.
|
||||
If no relevant documents are found, use your general knowledge and tool capabilities.
|
||||
Allow yourself to be very creative and use your imagination.
|
||||
Allow yourself to be very creative and use your imagination.
|
||||
|
||||
@@ -12,4 +12,4 @@ You have access to a search tool that searches the user's uploaded documents and
|
||||
Use the search_internal tool to find relevant information before answering questions.
|
||||
You may search multiple times with different queries if needed.
|
||||
Do not guess when documents are available — search first, then answer based on what you find.
|
||||
If no relevant documents are found, use your general knowledge and tool capabilities.
|
||||
If no relevant documents are found, use your general knowledge and tool capabilities.
|
||||
|
||||
@@ -13,4 +13,4 @@ Use the search_internal tool to find relevant information before answering quest
|
||||
You may search multiple times with different queries if needed.
|
||||
You MUST search before answering any factual question. Do not guess or use general knowledge when documents are available.
|
||||
If you dont have enough information from the search results or tools, answer "I don't know" or "I don't have enough information".
|
||||
Never make up information or provide false information!
|
||||
Never make up information or provide false information!
|
||||
|
||||
@@ -20,4 +20,4 @@ You MUST respond with ONLY a valid JSON object (no markdown, no code fences):
|
||||
"needs_clarification": true or false,
|
||||
"questions": ["question 1", "question 2"] (only if needs_clarification is true, max 3 questions),
|
||||
"reason": "brief explanation of why clarification is or isn't needed"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,4 +20,4 @@ You MUST respond with ONLY a valid JSON object in this exact format (no markdown
|
||||
"steps": [
|
||||
{"query": "specific sub-question to investigate", "rationale": "why this step is needed"}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,4 +10,4 @@ Instructions:
|
||||
5. Cite specific documents and passages you found. Reference sources by their titles or filenames.
|
||||
6. If you cannot find relevant information through tools, use your general knowledge but clearly indicate this.
|
||||
|
||||
Be thorough — prefer completeness over brevity. Include all relevant details you find.
|
||||
Be thorough — prefer completeness over brevity. Include all relevant details you find.
|
||||
|
||||
@@ -19,4 +19,4 @@ Write a well-structured, thorough report that:
|
||||
Available sources for citation:
|
||||
{references}
|
||||
|
||||
Format the report with clear headings and sections. Be comprehensive but well-organized.
|
||||
Format the report with clear headings and sections. Be comprehensive but well-organized.
|
||||
|
||||
@@ -55,7 +55,7 @@ import WorkflowPreview from './workflow/WorkflowPreview';
|
||||
|
||||
import type { Model } from '../models/types';
|
||||
interface AgentNodeConfig {
|
||||
agent_type: 'classic';
|
||||
agent_type: 'classic' | 'agentic' | 'research';
|
||||
llm_name?: string;
|
||||
model_id?: string;
|
||||
system_prompt: string;
|
||||
@@ -884,6 +884,12 @@ function WorkflowBuilderInner() {
|
||||
<SelectItem value="classic">
|
||||
Classic
|
||||
</SelectItem>
|
||||
<SelectItem value="agentic">
|
||||
Agentic
|
||||
</SelectItem>
|
||||
<SelectItem value="research">
|
||||
Research
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
@@ -75,7 +75,7 @@ import type { Model } from '../../models/types';
|
||||
const PRIMARY_ACTION_SPINNER_DELAY_MS = 180;
|
||||
|
||||
interface AgentNodeConfig {
|
||||
agent_type: 'classic';
|
||||
agent_type: 'classic' | 'agentic' | 'research';
|
||||
llm_name?: string;
|
||||
model_id?: string;
|
||||
system_prompt: string;
|
||||
@@ -1748,6 +1748,12 @@ function WorkflowBuilderInner() {
|
||||
<SelectItem value="classic">
|
||||
Classic
|
||||
</SelectItem>
|
||||
<SelectItem value="agentic">
|
||||
Agentic
|
||||
</SelectItem>
|
||||
<SelectItem value="research">
|
||||
Research
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
139
tests/agents/test_agentic_agent.py
Normal file
139
tests/agents/test_agentic_agent.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Tests for AgenticAgent — LLM-controlled retrieval agent."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgenticAgentInit:
|
||||
|
||||
def test_initialization(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
assert isinstance(agent, AgenticAgent)
|
||||
assert agent.retriever_config == {}
|
||||
|
||||
def test_initialization_with_retriever_config(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
rc = {"source": {"active_docs": ["abc"]}, "retriever_name": "classic"}
|
||||
agent = AgenticAgent(retriever_config=rc, **agent_base_params)
|
||||
assert agent.retriever_config == rc
|
||||
|
||||
def test_inherits_base_properties(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
assert agent.endpoint == agent_base_params["endpoint"]
|
||||
assert agent.llm_name == agent_base_params["llm_name"]
|
||||
assert agent.model_id == agent_base_params["model_id"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgenticAgentGenInner:
|
||||
|
||||
def test_basic_flow_yields_sources_and_tool_calls(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Processed"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
results = list(agent._gen_inner("Test query", log_context))
|
||||
|
||||
sources = [r for r in results if "sources" in r]
|
||||
tool_calls = [r for r in results if "tool_calls" in r]
|
||||
assert len(sources) == 1
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
def test_logs_agent_component(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Done"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Query", log_context))
|
||||
|
||||
agent_logs = [s for s in log_context.stacks if s["component"] == "agent"]
|
||||
assert len(agent_logs) == 1
|
||||
assert "tool_calls" in agent_logs[0]["data"]
|
||||
|
||||
def test_no_pre_fetched_docs_in_messages(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
mock_mongo_db,
|
||||
log_context,
|
||||
):
|
||||
mock_llm.gen_stream = Mock(return_value=iter(["Answer"]))
|
||||
|
||||
def mock_handler(*args, **kwargs):
|
||||
yield "Done"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler)
|
||||
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
list(agent._gen_inner("Query", log_context))
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
messages = call_kwargs["messages"]
|
||||
# System prompt should not contain {summaries} replacement
|
||||
assert messages[0]["role"] == "system"
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == "Query"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgenticAgentCollectSources:
|
||||
|
||||
def test_collect_internal_sources_from_cache(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
|
||||
mock_tool = Mock()
|
||||
mock_tool.retrieved_docs = [
|
||||
{"text": "Found", "title": "Doc", "source": "test"},
|
||||
]
|
||||
cache_key = f"internal_search:internal:{agent.user or ''}"
|
||||
agent.tool_executor._loaded_tools[cache_key] = mock_tool
|
||||
|
||||
agent._collect_internal_sources()
|
||||
assert len(agent.retrieved_docs) == 1
|
||||
assert agent.retrieved_docs[0]["title"] == "Doc"
|
||||
|
||||
def test_collect_internal_sources_no_cache(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = AgenticAgent(**agent_base_params)
|
||||
agent._collect_internal_sources()
|
||||
assert agent.retrieved_docs == []
|
||||
250
tests/agents/test_internal_search_tool.py
Normal file
250
tests/agents/test_internal_search_tool.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for InternalSearchTool and its helper functions."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
InternalSearchTool,
|
||||
add_internal_search_tool,
|
||||
build_internal_tool_config,
|
||||
build_internal_tool_entry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolSearch:
|
||||
|
||||
def _make_tool(self, **config_overrides):
|
||||
config = {"source": {}, "retriever_name": "classic", "chunks": 2}
|
||||
config.update(config_overrides)
|
||||
return InternalSearchTool(config)
|
||||
|
||||
def test_search_no_query_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
result = tool.execute_action("search", query="")
|
||||
assert "required" in result.lower()
|
||||
|
||||
def test_search_returns_formatted_docs(self):
|
||||
tool = self._make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "Hello world", "title": "Doc1", "source": "test", "filename": "doc1.md"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="hello")
|
||||
assert "doc1.md" in result
|
||||
assert "Hello world" in result
|
||||
assert len(tool.retrieved_docs) == 1
|
||||
|
||||
def test_search_no_results(self):
|
||||
tool = self._make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = []
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="nonexistent")
|
||||
assert "No documents found" in result
|
||||
|
||||
def test_search_accumulates_docs(self):
|
||||
tool = self._make_tool()
|
||||
mock_retriever = Mock()
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "D1", "source": "s1"},
|
||||
]
|
||||
tool.execute_action("search", query="first")
|
||||
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "B", "title": "D2", "source": "s2"},
|
||||
]
|
||||
tool.execute_action("search", query="second")
|
||||
|
||||
assert len(tool.retrieved_docs) == 2
|
||||
|
||||
def test_search_deduplicates_docs(self):
|
||||
tool = self._make_tool()
|
||||
doc = {"text": "Same", "title": "Same", "source": "same"}
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [doc]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
tool.execute_action("search", query="q1")
|
||||
tool.execute_action("search", query="q2")
|
||||
|
||||
assert len(tool.retrieved_docs) == 1
|
||||
|
||||
def test_search_with_path_filter(self):
|
||||
tool = self._make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "T", "source": "src/main.py", "filename": "main.py"},
|
||||
{"text": "B", "title": "T", "source": "docs/readme.md", "filename": "readme.md"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="code", path_filter="src/")
|
||||
assert "main.py" in result
|
||||
assert "readme.md" not in result
|
||||
|
||||
def test_search_path_filter_no_match(self):
|
||||
tool = self._make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.return_value = [
|
||||
{"text": "A", "title": "T", "source": "other/file.txt"},
|
||||
]
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="code", path_filter="src/")
|
||||
assert "No documents found" in result
|
||||
|
||||
def test_search_retriever_error(self):
|
||||
tool = self._make_tool()
|
||||
mock_retriever = Mock()
|
||||
mock_retriever.search.side_effect = Exception("Connection error")
|
||||
tool._retriever = mock_retriever
|
||||
|
||||
result = tool.execute_action("search", query="test")
|
||||
assert "failed" in result.lower() or "error" in result.lower()
|
||||
|
||||
def test_unknown_action(self):
|
||||
tool = self._make_tool()
|
||||
result = tool.execute_action("nonexistent")
|
||||
assert "Unknown action" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolListFiles:
|
||||
|
||||
def test_list_files_no_structure(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = None
|
||||
|
||||
result = tool.execute_action("list_files")
|
||||
assert "No file structure" in result
|
||||
|
||||
def test_list_files_root(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"src": {"main.py": {}},
|
||||
"README.md": {"type": "md", "token_count": 100},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files")
|
||||
assert "src/" in result
|
||||
assert "README.md" in result
|
||||
|
||||
def test_list_files_nested_path(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {
|
||||
"src": {
|
||||
"utils": {"helper.py": {}},
|
||||
},
|
||||
}
|
||||
|
||||
result = tool.execute_action("list_files", path="src")
|
||||
assert "utils/" in result
|
||||
|
||||
def test_list_files_invalid_path(self):
|
||||
tool = InternalSearchTool({"source": {}})
|
||||
tool._dir_structure_loaded = True
|
||||
tool._directory_structure = {"src": {}}
|
||||
|
||||
result = tool.execute_action("list_files", path="nonexistent")
|
||||
assert "not found" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestInternalSearchToolMetadata:
|
||||
|
||||
def test_actions_without_directory_structure(self):
|
||||
tool = InternalSearchTool({"has_directory_structure": False})
|
||||
meta = tool.get_actions_metadata()
|
||||
|
||||
action_names = [a["name"] for a in meta]
|
||||
assert "search" in action_names
|
||||
assert "list_files" not in action_names
|
||||
|
||||
# search should not have path_filter
|
||||
search = meta[0]
|
||||
assert "path_filter" not in search["parameters"]["properties"]
|
||||
|
||||
def test_actions_with_directory_structure(self):
|
||||
tool = InternalSearchTool({"has_directory_structure": True})
|
||||
meta = tool.get_actions_metadata()
|
||||
|
||||
action_names = [a["name"] for a in meta]
|
||||
assert "search" in action_names
|
||||
assert "list_files" in action_names
|
||||
|
||||
# search should have path_filter
|
||||
search = next(a for a in meta if a["name"] == "search")
|
||||
assert "path_filter" in search["parameters"]["properties"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildHelpers:
|
||||
|
||||
def test_build_entry_without_directory_structure(self):
|
||||
entry = build_internal_tool_entry(has_directory_structure=False)
|
||||
assert entry["name"] == "internal_search"
|
||||
action_names = [a["name"] for a in entry["actions"]]
|
||||
assert "search" in action_names
|
||||
assert "list_files" not in action_names
|
||||
|
||||
def test_build_entry_with_directory_structure(self):
|
||||
entry = build_internal_tool_entry(has_directory_structure=True)
|
||||
action_names = [a["name"] for a in entry["actions"]]
|
||||
assert "list_files" in action_names
|
||||
|
||||
def test_build_config(self):
|
||||
config = build_internal_tool_config(
|
||||
source={"active_docs": ["abc"]},
|
||||
retriever_name="semantic",
|
||||
chunks=4,
|
||||
)
|
||||
assert config["source"] == {"active_docs": ["abc"]}
|
||||
assert config["retriever_name"] == "semantic"
|
||||
assert config["chunks"] == 4
|
||||
|
||||
def test_internal_tool_id(self):
|
||||
assert INTERNAL_TOOL_ID == "internal"
|
||||
|
||||
def test_add_internal_search_tool_with_sources(self):
|
||||
tools_dict = {}
|
||||
retriever_config = {
|
||||
"source": {"active_docs": ["abc"]},
|
||||
"retriever_name": "classic",
|
||||
"chunks": 2,
|
||||
"model_id": "gpt-4",
|
||||
"llm_name": "openai",
|
||||
"api_key": "key",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"application.agents.tools.internal_search.sources_have_directory_structure",
|
||||
return_value=False,
|
||||
):
|
||||
add_internal_search_tool(tools_dict, retriever_config)
|
||||
|
||||
assert INTERNAL_TOOL_ID in tools_dict
|
||||
assert tools_dict[INTERNAL_TOOL_ID]["name"] == "internal_search"
|
||||
assert "config" in tools_dict[INTERNAL_TOOL_ID]
|
||||
|
||||
def test_add_internal_search_tool_no_sources(self):
|
||||
tools_dict = {}
|
||||
retriever_config = {"source": {}}
|
||||
|
||||
add_internal_search_tool(tools_dict, retriever_config)
|
||||
|
||||
assert INTERNAL_TOOL_ID not in tools_dict
|
||||
|
||||
def test_add_internal_search_tool_empty_config(self):
|
||||
tools_dict = {}
|
||||
add_internal_search_tool(tools_dict, {})
|
||||
assert INTERNAL_TOOL_ID not in tools_dict
|
||||
400
tests/agents/test_research_agent.py
Normal file
400
tests/agents/test_research_agent.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Tests for ResearchAgent — multi-step research with budget controls."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.research_agent import (
|
||||
CitationManager,
|
||||
ResearchAgent,
|
||||
DEFAULT_MAX_STEPS,
|
||||
DEFAULT_TIMEOUT_SECONDS,
|
||||
DEFAULT_TOKEN_BUDGET,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CitationManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCitationManager:
|
||||
|
||||
def test_add_returns_citation_number(self):
|
||||
cm = CitationManager()
|
||||
num = cm.add({"source": "s1", "title": "T1"})
|
||||
assert num == 1
|
||||
|
||||
def test_add_deduplicates(self):
|
||||
cm = CitationManager()
|
||||
n1 = cm.add({"source": "s1", "title": "T1"})
|
||||
n2 = cm.add({"source": "s1", "title": "T1"})
|
||||
assert n1 == n2
|
||||
assert len(cm.citations) == 1
|
||||
|
||||
def test_add_different_sources(self):
|
||||
cm = CitationManager()
|
||||
n1 = cm.add({"source": "s1", "title": "T1"})
|
||||
n2 = cm.add({"source": "s2", "title": "T2"})
|
||||
assert n1 != n2
|
||||
assert len(cm.citations) == 2
|
||||
|
||||
def test_add_docs_returns_mapping(self):
|
||||
cm = CitationManager()
|
||||
docs = [
|
||||
{"source": "s1", "title": "Doc A"},
|
||||
{"source": "s2", "title": "Doc B"},
|
||||
]
|
||||
text = cm.add_docs(docs)
|
||||
assert "[1] Doc A" in text
|
||||
assert "[2] Doc B" in text
|
||||
|
||||
def test_format_references(self):
|
||||
cm = CitationManager()
|
||||
cm.add({"source": "http://example.com", "title": "Example", "filename": "ex.md"})
|
||||
refs = cm.format_references()
|
||||
assert "[1]" in refs
|
||||
assert "ex.md" in refs
|
||||
|
||||
def test_format_references_empty(self):
|
||||
cm = CitationManager()
|
||||
assert "No sources" in cm.format_references()
|
||||
|
||||
def test_get_all_docs(self):
|
||||
cm = CitationManager()
|
||||
cm.add({"source": "s1", "title": "T1"})
|
||||
cm.add({"source": "s2", "title": "T2"})
|
||||
docs = cm.get_all_docs()
|
||||
assert len(docs) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResearchAgent Init & Budget
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchAgentInit:
|
||||
|
||||
def test_initialization(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
assert isinstance(agent, ResearchAgent)
|
||||
assert agent.max_steps == DEFAULT_MAX_STEPS
|
||||
assert agent.timeout_seconds == DEFAULT_TIMEOUT_SECONDS
|
||||
assert agent.token_budget == DEFAULT_TOKEN_BUDGET
|
||||
assert agent.retriever_config == {}
|
||||
|
||||
def test_custom_budget(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ResearchAgent(
|
||||
max_steps=3,
|
||||
timeout_seconds=60,
|
||||
token_budget=50_000,
|
||||
**agent_base_params,
|
||||
)
|
||||
assert agent.max_steps == 3
|
||||
assert agent.timeout_seconds == 60
|
||||
assert agent.token_budget == 50_000
|
||||
|
||||
def test_with_retriever_config(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
rc = {"source": {"active_docs": ["abc"]}}
|
||||
agent = ResearchAgent(retriever_config=rc, **agent_base_params)
|
||||
assert agent.retriever_config == rc
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchAgentBudget:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, **kwargs):
|
||||
return ResearchAgent(**kwargs, **agent_base_params)
|
||||
|
||||
def test_timeout_detection(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
|
||||
timeout_seconds=0,
|
||||
)
|
||||
agent._start_time = time.monotonic() - 1
|
||||
assert agent._is_timed_out() is True
|
||||
|
||||
def test_not_timed_out(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
|
||||
timeout_seconds=300,
|
||||
)
|
||||
agent._start_time = time.monotonic()
|
||||
assert agent._is_timed_out() is False
|
||||
|
||||
def test_token_budget_tracking(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
|
||||
token_budget=1000,
|
||||
)
|
||||
agent._track_tokens(500)
|
||||
assert agent._budget_remaining() == 500
|
||||
assert agent._is_over_budget() is False
|
||||
|
||||
agent._track_tokens(500)
|
||||
assert agent._budget_remaining() == 0
|
||||
assert agent._is_over_budget() is True
|
||||
|
||||
def test_snapshot_llm_tokens_returns_delta(
|
||||
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
|
||||
)
|
||||
mock_llm.token_usage = {"prompt_tokens": 100, "generated_tokens": 50}
|
||||
|
||||
delta1 = agent._snapshot_llm_tokens()
|
||||
assert delta1 == 150
|
||||
|
||||
# Simulate more tokens used
|
||||
mock_llm.token_usage = {"prompt_tokens": 200, "generated_tokens": 100}
|
||||
delta2 = agent._snapshot_llm_tokens()
|
||||
assert delta2 == 150 # 300 - 150
|
||||
|
||||
def test_elapsed(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator,
|
||||
)
|
||||
agent._start_time = time.monotonic() - 1.5
|
||||
elapsed = agent._elapsed()
|
||||
assert elapsed >= 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResearchAgent Phases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchAgentClarification:
|
||||
|
||||
def test_is_follow_up_no_history(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
assert agent._is_follow_up() is False
|
||||
|
||||
def test_is_follow_up_with_clarification_metadata(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["chat_history"] = [
|
||||
{"prompt": "What?", "response": "Clarify", "metadata": {"is_clarification": True}},
|
||||
]
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
assert agent._is_follow_up() is True
|
||||
|
||||
def test_is_follow_up_without_metadata(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["chat_history"] = [
|
||||
{"prompt": "What?", "response": "Normal answer"},
|
||||
]
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
assert agent._is_follow_up() is False
|
||||
|
||||
def test_clarification_returns_none_on_no_clarification_needed(
|
||||
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
response = Mock()
|
||||
response.choices = [Mock()]
|
||||
response.choices[0].message = Mock()
|
||||
response.choices[0].message.content = json.dumps(
|
||||
{"needs_clarification": False, "reason": "Clear enough"}
|
||||
)
|
||||
mock_llm.gen = Mock(return_value=response)
|
||||
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
|
||||
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
result = agent._clarification_phase("What is Python?")
|
||||
assert result is None
|
||||
|
||||
def test_clarification_returns_questions(
|
||||
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
clarification_json = json.dumps({
|
||||
"needs_clarification": True,
|
||||
"questions": ["Which version?", "What context?"],
|
||||
})
|
||||
# Return a plain string so _extract_text handles it directly
|
||||
mock_llm.gen = Mock(return_value=clarification_json)
|
||||
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
|
||||
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
result = agent._clarification_phase("Tell me about it")
|
||||
assert result is not None
|
||||
assert "Which version?" in result
|
||||
assert "What context?" in result
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchAgentPlanning:
|
||||
|
||||
def test_planning_returns_steps_and_complexity(
|
||||
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
plan_json = json.dumps({
|
||||
"complexity": "moderate",
|
||||
"steps": [
|
||||
{"query": "sub-question 1", "rationale": "reason 1"},
|
||||
{"query": "sub-question 2", "rationale": "reason 2"},
|
||||
],
|
||||
})
|
||||
# Return plain string so _extract_text handles it directly
|
||||
mock_llm.gen = Mock(return_value=plan_json)
|
||||
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
|
||||
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
steps, complexity = agent._planning_phase("Compare A and B")
|
||||
|
||||
assert complexity == "moderate"
|
||||
assert len(steps) == 2
|
||||
assert steps[0]["query"] == "sub-question 1"
|
||||
|
||||
def test_planning_caps_steps_by_complexity(
|
||||
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
plan_json = json.dumps({
|
||||
"complexity": "simple",
|
||||
"steps": [
|
||||
{"query": f"q{i}", "rationale": f"r{i}"} for i in range(10)
|
||||
],
|
||||
})
|
||||
response = Mock()
|
||||
response.choices = [Mock()]
|
||||
response.choices[0].message = Mock()
|
||||
response.choices[0].message.content = plan_json
|
||||
mock_llm.gen = Mock(return_value=response)
|
||||
mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5}
|
||||
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
steps, complexity = agent._planning_phase("Simple question")
|
||||
|
||||
assert complexity == "simple"
|
||||
assert len(steps) <= 2 # COMPLEXITY_CAPS["simple"] == 2
|
||||
|
||||
def test_planning_fallback_on_error(
|
||||
self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
mock_llm.gen = Mock(side_effect=Exception("LLM down"))
|
||||
mock_llm.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
agent = ResearchAgent(**agent_base_params)
|
||||
steps, complexity = agent._planning_phase("Anything")
|
||||
|
||||
assert complexity == "simple"
|
||||
assert len(steps) == 1
|
||||
assert steps[0]["query"] == "Anything"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchAgentExtractText:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
|
||||
return ResearchAgent(**agent_base_params)
|
||||
|
||||
def test_extract_from_string(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
assert agent._extract_text("hello") == "hello"
|
||||
|
||||
def test_extract_from_openai_response(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
response = Mock()
|
||||
response.choices = [Mock()]
|
||||
response.choices[0].message = Mock()
|
||||
response.choices[0].message.content = "OpenAI content"
|
||||
response.message = None
|
||||
response.content = None
|
||||
assert agent._extract_text(response) == "OpenAI content"
|
||||
|
||||
def test_extract_from_anthropic_response(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
text_block = Mock()
|
||||
text_block.text = "Anthropic content"
|
||||
response = Mock()
|
||||
response.content = [text_block]
|
||||
response.message = None
|
||||
response.choices = None
|
||||
assert agent._extract_text(response) == "Anthropic content"
|
||||
|
||||
def test_extract_from_none(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
assert agent._extract_text(None) == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestResearchAgentParseJson:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
|
||||
return ResearchAgent(**agent_base_params)
|
||||
|
||||
def test_parse_plan_direct_json(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
text = '{"steps": [{"query": "q1"}], "complexity": "simple"}'
|
||||
result = agent._parse_plan_json(text)
|
||||
assert isinstance(result, dict)
|
||||
assert len(result["steps"]) == 1
|
||||
|
||||
def test_parse_plan_from_code_fence(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
text = 'Here is the plan:\n```json\n{"steps": [{"query": "q1"}]}\n```'
|
||||
result = agent._parse_plan_json(text)
|
||||
assert isinstance(result, dict)
|
||||
|
||||
def test_parse_plan_invalid_returns_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
result = agent._parse_plan_json("not json at all")
|
||||
assert result == []
|
||||
|
||||
def test_parse_clarification_json(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
text = '{"needs_clarification": false, "reason": "clear"}'
|
||||
result = agent._parse_clarification_json(text)
|
||||
assert result["needs_clarification"] is False
|
||||
|
||||
def test_parse_clarification_json_from_code_fence(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
text = '```json\n{"needs_clarification": true, "questions": ["q1"]}\n```'
|
||||
result = agent._parse_clarification_json(text)
|
||||
assert result["needs_clarification"] is True
|
||||
|
||||
def test_parse_clarification_json_invalid(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator)
|
||||
result = agent._parse_clarification_json("not json")
|
||||
assert result is None
|
||||
52
tests/agents/test_think_tool.py
Normal file
52
tests/agents/test_think_tool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Tests for ThinkTool — the chain-of-thought pseudo-tool."""
|
||||
|
||||
import pytest
|
||||
from application.agents.tools.think import (
|
||||
THINK_TOOL_ENTRY,
|
||||
THINK_TOOL_ID,
|
||||
ThinkTool,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestThinkTool:
|
||||
|
||||
def test_id_constant(self):
|
||||
assert THINK_TOOL_ID == "think"
|
||||
|
||||
def test_entry_has_reason_action(self):
|
||||
actions = THINK_TOOL_ENTRY["actions"]
|
||||
assert len(actions) == 1
|
||||
assert actions[0]["name"] == "reason"
|
||||
assert actions[0]["active"] is True
|
||||
|
||||
def test_execute_reason_returns_continue(self):
|
||||
tool = ThinkTool()
|
||||
result = tool.execute_action("reason", reasoning="step by step thinking")
|
||||
assert result == "Continue."
|
||||
|
||||
def test_execute_unknown_action_returns_continue(self):
|
||||
tool = ThinkTool()
|
||||
result = tool.execute_action("unknown_action")
|
||||
assert result == "Continue."
|
||||
|
||||
def test_get_actions_metadata(self):
|
||||
tool = ThinkTool()
|
||||
meta = tool.get_actions_metadata()
|
||||
assert len(meta) == 1
|
||||
assert meta[0]["name"] == "reason"
|
||||
props = meta[0]["parameters"]["properties"]
|
||||
assert "reasoning" in props
|
||||
assert props["reasoning"]["filled_by_llm"] is True
|
||||
|
||||
def test_get_config_requirements_empty(self):
|
||||
tool = ThinkTool()
|
||||
assert tool.get_config_requirements() == {}
|
||||
|
||||
def test_init_accepts_no_config(self):
|
||||
tool = ThinkTool()
|
||||
assert tool is not None
|
||||
|
||||
def test_init_accepts_config(self):
|
||||
tool = ThinkTool(config={"key": "value"})
|
||||
assert tool is not None
|
||||
279
tests/agents/test_tool_executor.py
Normal file
279
tests/agents/test_tool_executor.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tests for ToolExecutor — tool discovery, preparation, and execution."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolExecutorInit:
|
||||
|
||||
def test_default_state(self):
|
||||
executor = ToolExecutor()
|
||||
assert executor.user_api_key is None
|
||||
assert executor.user is None
|
||||
assert executor.tool_calls == []
|
||||
assert executor._loaded_tools == {}
|
||||
assert executor.conversation_id is None
|
||||
|
||||
def test_init_with_params(self):
|
||||
executor = ToolExecutor(
|
||||
user_api_key="key", user="alice", decoded_token={"sub": "alice"}
|
||||
)
|
||||
assert executor.user_api_key == "key"
|
||||
assert executor.user == "alice"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolExecutorGetTools:
|
||||
|
||||
def test_get_tools_uses_api_key_when_present(self, mock_mongo_db):
|
||||
executor = ToolExecutor(user_api_key="test_key", user="alice")
|
||||
tools = executor.get_tools()
|
||||
assert isinstance(tools, dict)
|
||||
|
||||
def test_get_tools_uses_user_when_no_api_key(self, mock_mongo_db):
|
||||
executor = ToolExecutor(user="alice")
|
||||
tools = executor.get_tools()
|
||||
assert isinstance(tools, dict)
|
||||
|
||||
def test_get_tools_defaults_to_local(self, mock_mongo_db):
|
||||
executor = ToolExecutor()
|
||||
tools = executor.get_tools()
|
||||
assert isinstance(tools, dict)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolExecutorPrepare:
|
||||
|
||||
def test_prepare_tools_for_llm_empty(self):
|
||||
executor = ToolExecutor()
|
||||
result = executor.prepare_tools_for_llm({})
|
||||
assert result == []
|
||||
|
||||
def test_prepare_tools_for_llm_non_api_tool(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"actions": [
|
||||
{
|
||||
"name": "do_thing",
|
||||
"description": "Does a thing",
|
||||
"active": True,
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == "function"
|
||||
assert result[0]["function"]["name"] == "do_thing_t1"
|
||||
assert "query" in result[0]["function"]["parameters"]["properties"]
|
||||
|
||||
def test_prepare_tools_skips_inactive_actions(self):
|
||||
executor = ToolExecutor()
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"actions": [
|
||||
{"name": "active_one", "description": "D", "active": True, "parameters": {"properties": {}}},
|
||||
{"name": "inactive_one", "description": "D", "active": False, "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
result = executor.prepare_tools_for_llm(tools_dict)
|
||||
assert len(result) == 1
|
||||
assert result[0]["function"]["name"] == "active_one_t1"
|
||||
|
||||
def test_build_tool_parameters_filters_non_llm_fields(self):
|
||||
executor = ToolExecutor()
|
||||
action = {
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
"filled_by_llm": True,
|
||||
"value": "default_val",
|
||||
"required": True,
|
||||
},
|
||||
"hidden": {
|
||||
"type": "string",
|
||||
"filled_by_llm": False,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = executor._build_tool_parameters(action)
|
||||
assert "query" in result["properties"]
|
||||
assert "hidden" not in result["properties"]
|
||||
assert "query" in result["required"]
|
||||
# filled_by_llm, value, required stripped from schema
|
||||
assert "filled_by_llm" not in result["properties"]["query"]
|
||||
assert "value" not in result["properties"]["query"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestToolExecutorExecute:
|
||||
|
||||
def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"):
|
||||
call = Mock()
|
||||
call.name = name
|
||||
call.id = call_id
|
||||
call.arguments = arguments
|
||||
return call
|
||||
|
||||
def test_execute_parse_failure(self, monkeypatch):
|
||||
executor = ToolExecutor()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=(None, None, {}))),
|
||||
)
|
||||
|
||||
call = self._make_call(name="bad")
|
||||
gen = executor.execute({}, call, "MockLLM")
|
||||
|
||||
events = []
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
result = e.value
|
||||
break
|
||||
|
||||
assert result[0] == "Failed to parse tool call."
|
||||
assert len(executor.tool_calls) == 1
|
||||
assert events[0]["data"]["status"] == "error"
|
||||
|
||||
def test_execute_tool_not_found(self, monkeypatch):
|
||||
executor = ToolExecutor()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))),
|
||||
)
|
||||
|
||||
call = self._make_call()
|
||||
gen = executor.execute({}, call, "MockLLM")
|
||||
|
||||
events = []
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
result = e.value
|
||||
break
|
||||
|
||||
assert "not found" in result[0]
|
||||
assert events[0]["data"]["status"] == "error"
|
||||
|
||||
def test_execute_success(self, mock_tool_manager, monkeypatch):
|
||||
executor = ToolExecutor(user="test_user")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "Test", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
call = self._make_call(name="test_action_t1", call_id="c1")
|
||||
gen = executor.execute(tools_dict, call, "MockLLM")
|
||||
|
||||
events = []
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
events.append(next(gen))
|
||||
except StopIteration as e:
|
||||
result = e.value
|
||||
break
|
||||
|
||||
assert result[0] == "Tool result"
|
||||
assert result[1] == "c1"
|
||||
|
||||
statuses = [e["data"]["status"] for e in events]
|
||||
assert "pending" in statuses
|
||||
assert "completed" in statuses
|
||||
|
||||
def test_get_truncated_tool_calls(self):
|
||||
executor = ToolExecutor()
|
||||
executor.tool_calls = [
|
||||
{
|
||||
"tool_name": "test",
|
||||
"call_id": "1",
|
||||
"action_name": "act",
|
||||
"arguments": {},
|
||||
"result": "A" * 100,
|
||||
}
|
||||
]
|
||||
|
||||
truncated = executor.get_truncated_tool_calls()
|
||||
assert len(truncated) == 1
|
||||
assert len(truncated[0]["result"]) <= 53
|
||||
assert truncated[0]["status"] == "completed"
|
||||
|
||||
def test_tool_caching(self, mock_tool_manager, monkeypatch):
|
||||
executor = ToolExecutor(user="test_user")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))),
|
||||
)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "Test", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
call = self._make_call(name="test_action_t1")
|
||||
|
||||
# First execution — loads tool
|
||||
gen = executor.execute(tools_dict, call, "MockLLM")
|
||||
while True:
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# Second execution — should use cache
|
||||
gen = executor.execute(tools_dict, call, "MockLLM")
|
||||
while True:
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# load_tool called only once due to cache
|
||||
assert mock_tool_manager.load_tool.call_count == 1
|
||||
475
tests/agents/test_workflow_agent_types.py
Normal file
475
tests/agents/test_workflow_agent_types.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""Tests for new agent types (agentic, research) in the workflow builder."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflows.node_agent import (
|
||||
WorkflowNodeAgenticAgent,
|
||||
WorkflowNodeAgentFactory,
|
||||
WorkflowNodeClassicAgent,
|
||||
WorkflowNodeResearchAgent,
|
||||
)
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
NodeType,
|
||||
Workflow,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class StubNodeAgent:
|
||||
"""Minimal agent stub that yields pre-defined events."""
|
||||
|
||||
def __init__(self, events):
|
||||
self.events = events
|
||||
|
||||
def gen(self, _prompt):
|
||||
yield from self.events
|
||||
|
||||
|
||||
def create_engine() -> WorkflowEngine:
|
||||
graph = WorkflowGraph(workflow=Workflow(name="Test"), nodes=[], edges=[])
|
||||
agent = SimpleNamespace(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
api_key="test-key",
|
||||
chat_history=[],
|
||||
decoded_token={"sub": "user-1"},
|
||||
)
|
||||
return WorkflowEngine(graph, agent)
|
||||
|
||||
|
||||
def create_agent_node(
|
||||
node_id: str,
|
||||
agent_type: str = "classic",
|
||||
sources: list = None,
|
||||
chunks: str = "2",
|
||||
retriever: str = "",
|
||||
output_variable: str = "",
|
||||
) -> WorkflowNode:
|
||||
config: Dict[str, Any] = {
|
||||
"agent_type": agent_type,
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
"tools": [],
|
||||
"sources": sources or [],
|
||||
"chunks": chunks,
|
||||
"retriever": retriever,
|
||||
}
|
||||
if output_variable:
|
||||
config["output_variable"] = output_variable
|
||||
return WorkflowNode(
|
||||
id=node_id,
|
||||
workflow_id="workflow-1",
|
||||
type=NodeType.AGENT,
|
||||
title="Agent",
|
||||
position={"x": 0, "y": 0},
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentType enum
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentTypeEnum:
|
||||
|
||||
def test_agentic_value_exists(self):
|
||||
assert AgentType.AGENTIC == "agentic"
|
||||
|
||||
def test_research_value_exists(self):
|
||||
assert AgentType.RESEARCH == "research"
|
||||
|
||||
def test_classic_still_exists(self):
|
||||
assert AgentType.CLASSIC == "classic"
|
||||
|
||||
def test_react_still_exists(self):
|
||||
assert AgentType.REACT == "react"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentNodeConfig schema validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAgentNodeConfigValidation:
|
||||
|
||||
def test_accepts_agentic_agent_type(self):
|
||||
config = AgentNodeConfig(agent_type="agentic")
|
||||
assert config.agent_type == AgentType.AGENTIC
|
||||
|
||||
def test_accepts_research_agent_type(self):
|
||||
config = AgentNodeConfig(agent_type="research")
|
||||
assert config.agent_type == AgentType.RESEARCH
|
||||
|
||||
def test_rejects_unknown_agent_type(self):
|
||||
with pytest.raises(Exception):
|
||||
AgentNodeConfig(agent_type="nonexistent")
|
||||
|
||||
def test_default_agent_type_is_classic(self):
|
||||
config = AgentNodeConfig()
|
||||
assert config.agent_type == AgentType.CLASSIC
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowNodeAgentFactory registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowNodeAgentFactoryRegistry:
|
||||
|
||||
def test_factory_has_agentic(self):
|
||||
assert AgentType.AGENTIC in WorkflowNodeAgentFactory._agents
|
||||
assert WorkflowNodeAgentFactory._agents[AgentType.AGENTIC] is WorkflowNodeAgenticAgent
|
||||
|
||||
def test_factory_has_research(self):
|
||||
assert AgentType.RESEARCH in WorkflowNodeAgentFactory._agents
|
||||
assert WorkflowNodeAgentFactory._agents[AgentType.RESEARCH] is WorkflowNodeResearchAgent
|
||||
|
||||
def test_factory_raises_for_unknown_type(self):
|
||||
with pytest.raises(ValueError, match="Unsupported agent type"):
|
||||
WorkflowNodeAgentFactory.create(
|
||||
agent_type="nonexistent",
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
api_key="key",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowNode agent classes (inheritance)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestWorkflowNodeAgentClasses:
|
||||
|
||||
def test_agentic_agent_inherits_correctly(self):
|
||||
assert issubclass(WorkflowNodeAgenticAgent, AgenticAgent)
|
||||
|
||||
def test_research_agent_inherits_correctly(self):
|
||||
assert issubclass(WorkflowNodeResearchAgent, ResearchAgent)
|
||||
|
||||
def test_classic_agent_inherits_correctly(self):
|
||||
assert issubclass(WorkflowNodeClassicAgent, ClassicAgent)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow engine: agentic agent node execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestWorkflowEngineAgenticNode:
    """Workflow engine execution of 'agentic' agent nodes."""

    @staticmethod
    def _install_stub_agent(monkeypatch, events):
        """Replace the agent factory with a stub and silence the API-key lookup.

        Returns the dict of kwargs the engine passed to the factory, so tests
        can inspect what configuration reached the agent.
        """
        seen: Dict[str, Any] = {}

        def _fake_create(**kwargs):
            seen.update(kwargs)
            return StubNodeAgent(events)

        monkeypatch.setattr(
            WorkflowNodeAgentFactory,
            "create",
            staticmethod(_fake_create),
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_api_key_for_provider",
            lambda _provider: None,
        )
        return seen

    def test_agentic_node_executes_and_saves_output(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_agentic",
            agent_type="agentic",
            output_variable="result",
        )
        self._install_stub_agent(monkeypatch, [{"answer": "agentic answer"}])

        list(engine._execute_agent_node(node))

        # Output lands under both the node-derived key and the output variable.
        assert engine.state["node_agent_agentic_output"] == "agentic answer"
        assert engine.state["result"] == "agentic answer"

    def test_agentic_node_passes_retriever_config(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_rc",
            agent_type="agentic",
            sources=["source-abc"],
            chunks="4",
            retriever="semantic",
        )
        seen = self._install_stub_agent(monkeypatch, [{"answer": "ok"}])

        list(engine._execute_agent_node(node))

        rc = seen.get("retriever_config")
        assert rc is not None
        assert rc["source"] == {"active_docs": ["source-abc"]}
        assert rc["retriever_name"] == "semantic"
        # The string "4" from node config is coerced to an int.
        assert rc["chunks"] == 4

    def test_agentic_node_empty_sources_gives_empty_source_dict(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_nosrc",
            agent_type="agentic",
            sources=[],
        )
        seen = self._install_stub_agent(monkeypatch, [{"answer": "ok"}])

        list(engine._execute_agent_node(node))

        assert seen["retriever_config"]["source"] == {}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow engine: research agent node execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestWorkflowEngineResearchNode:
    """Workflow engine execution of 'research' agent nodes."""

    @staticmethod
    def _install_stub_agent(monkeypatch, events):
        """Stub out the agent factory and API-key lookup; return captured kwargs."""
        seen: Dict[str, Any] = {}

        def _fake_create(**kwargs):
            seen.update(kwargs)
            return StubNodeAgent(events)

        monkeypatch.setattr(
            WorkflowNodeAgentFactory,
            "create",
            staticmethod(_fake_create),
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_api_key_for_provider",
            lambda _provider: None,
        )
        return seen

    def test_research_node_executes_and_saves_output(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_research",
            agent_type="research",
            output_variable="report",
        )
        self._install_stub_agent(monkeypatch, [{"answer": "research report"}])

        list(engine._execute_agent_node(node))

        assert engine.state["node_agent_research_output"] == "research report"
        assert engine.state["report"] == "research report"

    def test_research_node_passes_retriever_config(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_rr",
            agent_type="research",
            sources=["doc-1", "doc-2"],
            chunks="6",
        )
        seen = self._install_stub_agent(monkeypatch, [{"answer": "ok"}])

        list(engine._execute_agent_node(node))

        rc = seen["retriever_config"]
        assert rc["source"] == {"active_docs": ["doc-1", "doc-2"]}
        assert rc["chunks"] == 6
        assert rc["decoded_token"] == {"sub": "user-1"}

    def test_research_node_handles_structured_output(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_rs",
            agent_type="research",
            output_variable="data",
        )
        # Simulate structured JSON output from research agent
        self._install_stub_agent(
            monkeypatch,
            [{"answer": '{"findings": "important"}', "structured": True}],
        )

        list(engine._execute_agent_node(node))

        # structured=True causes the engine to parse JSON
        assert engine.state["data"] == {"findings": "important"}
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow engine: classic node does NOT get retriever_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestWorkflowEngineClassicNodeNoRetrieverConfig:
    """Classic agent nodes must NOT receive a retriever_config kwarg."""

    def test_classic_node_does_not_pass_retriever_config(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(
            node_id="agent_classic",
            agent_type="classic",
            sources=["some-source"],
        )

        seen: Dict[str, Any] = {}

        def _fake_create(**kwargs):
            seen.update(kwargs)
            return StubNodeAgent([{"answer": "classic answer"}])

        monkeypatch.setattr(
            WorkflowNodeAgentFactory,
            "create",
            staticmethod(_fake_create),
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_api_key_for_provider",
            lambda _provider: None,
        )

        list(engine._execute_agent_node(node))

        # Even with sources configured, classic nodes get no retriever_config.
        assert "retriever_config" not in seen
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow engine: streaming events from new agent types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestWorkflowEngineStreamingEvents:
    """SSE event forwarding for the new agent node types."""

    @staticmethod
    def _install_stub_agent(monkeypatch, events):
        """Stub out the agent factory and API-key lookup (no kwarg capture)."""
        monkeypatch.setattr(
            WorkflowNodeAgentFactory,
            "create",
            staticmethod(lambda **kwargs: StubNodeAgent(events)),
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_api_key_for_provider",
            lambda _provider: None,
        )

    def test_agentic_node_streams_answer_events(self, monkeypatch):
        engine = create_engine()
        node = create_agent_node(node_id="agent_s1", agent_type="agentic")
        # Modify config to enable streaming
        node.config["stream_to_user"] = True

        self._install_stub_agent(
            monkeypatch,
            [{"answer": "chunk 1"}, {"answer": "chunk 2"}],
        )

        emitted = list(engine._execute_agent_node(node))

        assert len([e for e in emitted if "answer" in e]) == 2

    def test_research_node_passes_through_non_answer_events(self, monkeypatch):
        """Research agents yield research_plan/research_progress events.

        The workflow engine only forwards 'answer' events to the user."""
        engine = create_engine()
        node = create_agent_node(node_id="agent_s2", agent_type="research")
        node.config["stream_to_user"] = True

        self._install_stub_agent(
            monkeypatch,
            [
                {"type": "research_plan", "data": {"steps": [], "complexity": "simple"}},
                {"type": "research_progress", "data": {"status": "planning"}},
                {"answer": "final report"},
            ],
        )

        emitted = list(engine._execute_agent_node(node))

        # Only answer events are streamed to user
        answers = [e for e in emitted if "answer" in e]
        assert len(answers) == 1
        assert answers[0]["answer"] == "final report"

        # State still captures the full text
        assert engine.state["node_agent_s2_output"] == "final report"
||||
@@ -492,6 +492,486 @@ This is test documentation for integration tests.
|
||||
self.record_result(test_name, False, str(e))
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Agentic / Research Agent Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _create_agent_with_type(self, agent_type: str) -> Optional[tuple]:
    """Create a test agent with the given agent_type. Returns (agent_id, api_key) or None."""
    if not self.is_authenticated:
        return None

    agent_spec = {
        "name": f"Chat Test {agent_type.title()} Agent {int(time.time())}",
        "description": f"Integration test {agent_type} agent",
        "prompt_id": "default",
        "chunks": 2,
        "retriever": "classic",
        "agent_type": agent_type,
        "status": "draft",
    }

    try:
        resp = self.post("/api/create_agent", json=agent_spec, timeout=10)
        if resp.status_code in (200, 201):
            body = resp.json()
            new_id = body.get("id")
            if new_id:
                return (new_id, body.get("key"))
    except Exception:
        # Best effort: callers treat None as "could not create an agent".
        pass

    return None
||||
|
||||
def test_stream_agentic_agent(self) -> bool:
    """Test /stream endpoint with an agentic agent."""
    test_name = "Stream endpoint (agentic agent)"

    agent_result = self._create_agent_with_type("agentic")
    if not agent_result:
        # No agent available: either auth is missing or creation failed.
        # Both are treated as a skip rather than a failure.
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True

    agent_id, _ = agent_result
    self.print_header(f"Testing {test_name}")

    request_body = {
        "question": "What is DocsGPT?",
        "history": "[]",
        "agent_id": agent_id,
    }

    try:
        self.print_info(f"POST /stream with agentic agent_id={agent_id[:8]}...")

        response = requests.post(
            f"{self.base_url}/stream",
            json=request_body,
            headers=self.headers,
            stream=True,
            timeout=60,
        )

        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False

        events = []
        answer_parts = []
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data: "):
                continue
            try:
                event = json_module.loads(text[6:])
            except json_module.JSONDecodeError:
                continue
            events.append(event)
            kind = event.get("type")
            if kind in ("stream", "answer"):
                answer_parts.append(
                    event.get("message", "") or event.get("answer", "")
                )
            elif kind == "end":
                break

        full_response = "".join(answer_parts)
        self.print_success(f"Received {len(events)} events")
        if full_response:
            self.print_success(f"Answer preview: {full_response[:100]}...")
        self.record_result(test_name, True, "Success")
        return True

    except Exception as e:
        self.print_error(f"Error: {str(e)}")
        self.record_result(test_name, False, str(e))
        return False
||||
|
||||
def test_answer_agentic_agent(self) -> bool:
    """Test /api/answer endpoint with an agentic agent."""
    test_name = "Answer endpoint (agentic agent)"

    agent_result = self._create_agent_with_type("agentic")
    if not agent_result:
        # Skip (not fail) when no agent could be created.
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True

    agent_id, _ = agent_result
    self.print_header(f"Testing {test_name}")

    request_body = {
        "question": "What is DocsGPT?",
        "history": "[]",
        "agent_id": agent_id,
    }

    try:
        self.print_info(f"POST /api/answer with agentic agent_id={agent_id[:8]}...")

        response = self.post("/api/answer", json=request_body, timeout=60)

        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False

        answer = response.json().get("answer", "")
        self.print_success(f"Answer received: {answer[:100]}...")
        self.record_result(test_name, True, "Success")
        return True

    except Exception as e:
        self.print_error(f"Error: {str(e)}")
        self.record_result(test_name, False, str(e))
        return False
||||
|
||||
def test_stream_research_agent(self) -> bool:
    """Test /stream endpoint with a research agent.

    Research agents emit additional SSE event types:
    - research_plan: the decomposed research steps
    - research_progress: per-step status updates
    """
    test_name = "Stream endpoint (research agent)"

    agent_result = self._create_agent_with_type("research")
    if not agent_result:
        # Skip (not fail) when no agent could be created.
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True

    agent_id, _ = agent_result
    self.print_header(f"Testing {test_name}")

    request_body = {
        "question": "What is DocsGPT?",
        "history": "[]",
        "agent_id": agent_id,
    }

    try:
        self.print_info(f"POST /stream with research agent_id={agent_id[:8]}...")

        # Research runs are slower than plain chat -> generous timeout.
        response = requests.post(
            f"{self.base_url}/stream",
            json=request_body,
            headers=self.headers,
            stream=True,
            timeout=120,
        )

        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False

        events = []
        answer_parts = []
        saw_plan = False
        saw_progress = False
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data: "):
                continue
            try:
                event = json_module.loads(text[6:])
            except json_module.JSONDecodeError:
                continue
            events.append(event)
            kind = event.get("type", "")
            if kind in ("stream", "answer"):
                answer_parts.append(
                    event.get("message", "") or event.get("answer", "")
                )
            elif kind == "research_plan":
                saw_plan = True
                steps = event.get("data", {}).get("steps", [])
                self.print_info(f"Research plan: {len(steps)} steps, complexity={event.get('data', {}).get('complexity')}")
            elif kind == "research_progress":
                saw_progress = True
            elif kind == "end":
                break

        full_response = "".join(answer_parts)
        self.print_success(f"Received {len(events)} events")
        if saw_plan:
            self.print_success("Received research_plan event")
        if saw_progress:
            self.print_success("Received research_progress events")
        if full_response:
            self.print_success(f"Report preview: {full_response[:100]}...")

        self.record_result(test_name, True, "Success")
        return True

    except Exception as e:
        self.print_error(f"Error: {str(e)}")
        self.record_result(test_name, False, str(e))
        return False
||||
|
||||
def test_answer_research_agent(self) -> bool:
    """Test /api/answer endpoint with a research agent."""
    test_name = "Answer endpoint (research agent)"

    agent_result = self._create_agent_with_type("research")
    if not agent_result:
        # Skip (not fail) when no agent could be created.
        if not self.require_auth(test_name):
            return True
        self.record_result(test_name, True, "Skipped (no agent)")
        return True

    agent_id, _ = agent_result
    self.print_header(f"Testing {test_name}")

    request_body = {
        "question": "What is DocsGPT?",
        "history": "[]",
        "agent_id": agent_id,
    }

    try:
        self.print_info(f"POST /api/answer with research agent_id={agent_id[:8]}...")

        # Research runs are slower than plain chat -> generous timeout.
        response = self.post("/api/answer", json=request_body, timeout=120)

        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False

        answer = response.json().get("answer", "")
        self.print_success(f"Answer received: {answer[:100]}...")
        self.record_result(test_name, True, "Success")
        return True

    except Exception as e:
        self.print_error(f"Error: {str(e)}")
        self.record_result(test_name, False, str(e))
        return False
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Workflow with Agentic Node Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _create_workflow_with_agentic_node(self) -> Optional[str]:
    """Create a workflow with a single agentic agent node.

    Returns the workflow_id or None on failure.
    """
    # Minimal linear graph: start -> agentic agent -> end.
    agent_node = {
        "id": "agent_1",
        "type": "agent",
        "title": "Agentic Node",
        "data": {
            "agent_type": "agentic",
            "system_prompt": "You are a helpful assistant.",
            "prompt_template": "",
            "stream_to_user": True,
            "tools": [],
            "sources": [],
            "chunks": "2",
        },
    }
    nodes = [
        {"id": "start_1", "type": "start", "title": "Start", "data": {}},
        agent_node,
        {"id": "end_1", "type": "end", "title": "End", "data": {}},
    ]
    edges = [
        {"id": "edge_1", "source": "start_1", "target": "agent_1"},
        {"id": "edge_2", "source": "agent_1", "target": "end_1"},
    ]

    workflow_spec = {
        "name": f"Agentic Workflow Test {int(time.time())}",
        "nodes": nodes,
        "edges": edges,
    }

    try:
        response = self.post("/api/workflows", json=workflow_spec, timeout=15)
        if response.status_code in (200, 201):
            return response.json().get("id")
    except Exception:
        # Best effort: callers treat None as "workflow creation failed".
        pass

    return None
||||
|
||||
def _create_workflow_agent(self, workflow_id: str) -> Optional[tuple]:
    """Create an agent of type 'workflow' referencing a workflow.

    Returns (agent_id, api_key) or None.
    """
    agent_spec = {
        "name": f"Workflow Agent Test {int(time.time())}",
        "description": "Integration test workflow agent with agentic node",
        "prompt_id": "default",
        "chunks": 2,
        "retriever": "classic",
        "agent_type": "workflow",
        "workflow": workflow_id,
        "status": "draft",
    }

    try:
        response = self.post("/api/create_agent", json=agent_spec, timeout=10)
        if response.status_code in (200, 201):
            body = response.json()
            new_id = body.get("id")
            if new_id:
                return (new_id, body.get("key"))
    except Exception:
        # Best effort: callers treat None as "agent creation failed".
        pass

    return None
||||
|
||||
def test_stream_workflow_with_agentic_node(self) -> bool:
    """Test /stream with a workflow agent that contains an agentic node."""
    test_name = "Stream endpoint (workflow with agentic node)"

    if not self.require_auth(test_name):
        return True

    workflow_id = self._create_workflow_with_agentic_node()
    if not workflow_id:
        self.print_warning("Could not create workflow")
        self.record_result(test_name, True, "Skipped (workflow creation failed)")
        return True

    agent_result = self._create_workflow_agent(workflow_id)
    if not agent_result:
        self.print_warning("Could not create workflow agent")
        self.record_result(test_name, True, "Skipped (agent creation failed)")
        return True

    agent_id, _ = agent_result
    self.print_header(f"Testing {test_name}")

    request_body = {
        "question": "What is DocsGPT?",
        "history": "[]",
        "agent_id": agent_id,
    }

    try:
        self.print_info(f"POST /stream with workflow agent_id={agent_id[:8]}...")
        self.print_info(f"Workflow ID: {workflow_id}")

        response = requests.post(
            f"{self.base_url}/stream",
            json=request_body,
            headers=self.headers,
            stream=True,
            timeout=60,
        )

        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False

        events = []
        answer_parts = []
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data: "):
                continue
            try:
                event = json_module.loads(text[6:])
            except json_module.JSONDecodeError:
                continue
            events.append(event)
            kind = event.get("type")
            if kind in ("stream", "answer"):
                answer_parts.append(
                    event.get("message", "") or event.get("answer", "")
                )
            elif kind == "end":
                break

        full_response = "".join(answer_parts)
        self.print_success(f"Received {len(events)} events")
        if full_response:
            self.print_success(f"Answer preview: {full_response[:100]}...")

        self.record_result(test_name, True, "Success")
        return True

    except Exception as e:
        self.print_error(f"Error: {str(e)}")
        self.record_result(test_name, False, str(e))
        return False
||||
|
||||
def test_answer_workflow_with_agentic_node(self) -> bool:
    """Test /api/answer with a workflow agent that contains an agentic node."""
    test_name = "Answer endpoint (workflow with agentic node)"

    if not self.require_auth(test_name):
        return True

    workflow_id = self._create_workflow_with_agentic_node()
    if not workflow_id:
        self.print_warning("Could not create workflow")
        self.record_result(test_name, True, "Skipped (workflow creation failed)")
        return True

    agent_result = self._create_workflow_agent(workflow_id)
    if not agent_result:
        self.print_warning("Could not create workflow agent")
        self.record_result(test_name, True, "Skipped (agent creation failed)")
        return True

    agent_id, _ = agent_result
    self.print_header(f"Testing {test_name}")

    request_body = {
        "question": "What is DocsGPT?",
        "history": "[]",
        "agent_id": agent_id,
    }

    try:
        self.print_info(f"POST /api/answer with workflow agent_id={agent_id[:8]}...")

        response = self.post("/api/answer", json=request_body, timeout=60)

        self.print_info(f"Status Code: {response.status_code}")

        if response.status_code != 200:
            self.print_error(f"Expected 200, got {response.status_code}")
            self.record_result(test_name, False, f"Status {response.status_code}")
            return False

        answer = response.json().get("answer", "")
        self.print_success(f"Answer received: {answer[:100]}...")
        self.record_result(test_name, True, "Success")
        return True

    except Exception as e:
        self.print_error(f"Error: {str(e)}")
        self.record_result(test_name, False, str(e))
        return False
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Validation Tests
|
||||
# -------------------------------------------------------------------------
|
||||
@@ -916,6 +1396,27 @@ This is test documentation for integration tests.
|
||||
self.test_answer_endpoint_with_agent()
|
||||
time.sleep(1)
|
||||
|
||||
# Agentic agent tests
|
||||
self.test_stream_agentic_agent()
|
||||
time.sleep(1)
|
||||
|
||||
self.test_answer_agentic_agent()
|
||||
time.sleep(1)
|
||||
|
||||
# Research agent tests
|
||||
self.test_stream_research_agent()
|
||||
time.sleep(2)
|
||||
|
||||
self.test_answer_research_agent()
|
||||
time.sleep(2)
|
||||
|
||||
# Workflow with agentic node tests
|
||||
self.test_stream_workflow_with_agentic_node()
|
||||
time.sleep(2)
|
||||
|
||||
self.test_answer_workflow_with_agentic_node()
|
||||
time.sleep(2)
|
||||
|
||||
# API key tests
|
||||
self.test_stream_endpoint_with_api_key()
|
||||
time.sleep(1)
|
||||
|
||||
Reference in New Issue
Block a user