From 462f2e9494f765957fe5d276dde65dbad8d2a26a Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 25 Mar 2026 22:34:25 +0000 Subject: [PATCH] mini refactors --- application/agents/agentic_agent.py | 58 +- application/agents/research_agent.py | 34 +- application/agents/tools/internal_search.py | 53 ++ application/agents/workflows/node_agent.py | 19 +- application/agents/workflows/schemas.py | 2 + .../agents/workflows/workflow_engine.py | 39 +- application/prompts/agentic/creative.txt | 2 +- application/prompts/agentic/default.txt | 2 +- application/prompts/agentic/strict.txt | 2 +- .../prompts/research/clarification.txt | 2 +- application/prompts/research/planning.txt | 2 +- application/prompts/research/step.txt | 2 +- application/prompts/research/synthesis.txt | 2 +- frontend/src/agents/WorkflowBuilder.tsx | 8 +- .../src/agents/workflow/WorkflowBuilder.tsx | 8 +- tests/agents/test_agentic_agent.py | 139 +++++ tests/agents/test_internal_search_tool.py | 250 +++++++++ tests/agents/test_research_agent.py | 400 ++++++++++++++ tests/agents/test_think_tool.py | 52 ++ tests/agents/test_tool_executor.py | 279 ++++++++++ tests/agents/test_workflow_agent_types.py | 475 +++++++++++++++++ tests/integration/test_chat.py | 501 ++++++++++++++++++ 22 files changed, 2233 insertions(+), 98 deletions(-) create mode 100644 tests/agents/test_agentic_agent.py create mode 100644 tests/agents/test_internal_search_tool.py create mode 100644 tests/agents/test_research_agent.py create mode 100644 tests/agents/test_think_tool.py create mode 100644 tests/agents/test_tool_executor.py create mode 100644 tests/agents/test_workflow_agent_types.py diff --git a/application/agents/agentic_agent.py b/application/agents/agentic_agent.py index c8af7a2a..fe08277a 100644 --- a/application/agents/agentic_agent.py +++ b/application/agents/agentic_agent.py @@ -4,8 +4,7 @@ from typing import Dict, Generator, Optional from application.agents.base import BaseAgent from application.agents.tools.internal_search import ( INTERNAL_TOOL_ID, - build_internal_tool_config, - build_internal_tool_entry, + add_internal_search_tool, ) from application.logging import LogContext @@ -32,24 +31,8 @@ class AgenticAgent(BaseAgent): def _gen_inner( self, query: str, log_context: LogContext ) -> Generator[Dict, None, None]: - # 1. Get user tools (same as ClassicAgent) tools_dict = self.tool_executor.get_tools() - - # 2. Add internal search as a synthetic tool (only if sources are configured) - source = self.retriever_config.get("source", {}) - has_sources = bool(source.get("active_docs")) - if self.retriever_config and has_sources: - has_dir = _sources_have_directory_structure(source) - internal_entry = build_internal_tool_entry( - has_directory_structure=has_dir - ) - internal_entry["config"] = build_internal_tool_config( - **self.retriever_config, - has_directory_structure=has_dir, - ) - tools_dict[INTERNAL_TOOL_ID] = internal_entry - - # 3. Prepare all tools for the LLM + add_internal_search_tool(tools_dict, self.retriever_config) self._prepare_tools(tools_dict) # 4. Build messages (prompt has NO pre-fetched docs) @@ -78,40 +61,3 @@ class AgenticAgent(BaseAgent): tool = self.tool_executor._loaded_tools.get(cache_key) if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs: self.retrieved_docs = tool.retrieved_docs - - -def _sources_have_directory_structure(source: Dict) -> bool: - """Check if any of the active sources have directory_structure in MongoDB.""" - active_docs = source.get("active_docs", []) - if not active_docs: - return False - - try: - from bson.objectid import ObjectId - from application.core.mongo_db import MongoDB - - mongo = MongoDB.get_client() - db = mongo[settings.MONGO_DB_NAME] - sources_collection = db["sources"] - - if isinstance(active_docs, str): - active_docs = [active_docs] - - for doc_id in active_docs: - try: - source_doc = sources_collection.find_one( - {"_id": ObjectId(doc_id)}, - {"directory_structure": 1}, - ) - if source_doc and source_doc.get("directory_structure"): - return True - except Exception: - continue - except Exception as e: - logger.debug(f"Could not check directory structure: {e}") - - return False - - -# Import settings at module level for _sources_have_directory_structure -from application.core.settings import settings # noqa: E402 diff --git a/application/agents/research_agent.py b/application/agents/research_agent.py index b743cb16..280fa2cd 100644 --- a/application/agents/research_agent.py +++ b/application/agents/research_agent.py @@ -6,11 +6,9 @@ from typing import Dict, Generator, List, Optional from application.agents.base import BaseAgent from application.agents.tool_executor import ToolExecutor -from application.agents.agentic_agent import _sources_have_directory_structure from application.agents.tools.internal_search import ( INTERNAL_TOOL_ID, - build_internal_tool_config, - build_internal_tool_entry, + add_internal_search_tool, ) from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID from application.logging import LogContext @@ -130,6 +128,7 @@ class ResearchAgent(BaseAgent): self.citations = CitationManager() self._start_time: float = 0 self._tokens_used: int = 0 + self._last_token_snapshot: int = 0 # ------------------------------------------------------------------ # Budget & timeout helpers @@ -153,7 +152,9 @@ class ResearchAgent(BaseAgent): def _snapshot_llm_tokens(self) -> int: """Read current token usage from LLM and return delta since last snapshot.""" current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0) - return current + delta = current - self._last_token_snapshot + self._last_token_snapshot = current + return delta # ------------------------------------------------------------------ # Main orchestration @@ -272,21 +273,7 @@ class ResearchAgent(BaseAgent): """Build tools_dict with user tools + internal search + think.""" tools_dict = self.tool_executor.get_tools() - # Only add internal search if sources are configured - source = self.retriever_config.get("source", {}) - has_sources = bool(source.get("active_docs")) - if self.retriever_config and has_sources: - has_dir = _sources_have_directory_structure(source) - internal_entry = build_internal_tool_entry( - has_directory_structure=has_dir - ) - internal_entry["config"] = build_internal_tool_config( - **self.retriever_config, - has_directory_structure=has_dir, - ) - tools_dict[INTERNAL_TOOL_ID] = internal_entry - elif self.retriever_config and not has_sources: - logger.info("ResearchAgent: No sources configured, skipping internal_search tool") + add_internal_search_tool(tools_dict, self.retriever_config) think_entry = dict(THINK_TOOL_ENTRY) think_entry["config"] = {} @@ -580,7 +567,14 @@ class ResearchAgent(BaseAgent): call_id = None while True: try: - next(gen) + event = next(gen) + # Log tool_call status events instead of discarding them + if isinstance(event, dict) and event.get("type") == "tool_call": + logger.debug( + "Tool %s status: %s", + event.get("data", {}).get("action_name", ""), + event.get("data", {}).get("status", ""), + ) except StopIteration as e: result, call_id = e.value break diff --git a/application/agents/tools/internal_search.py b/application/agents/tools/internal_search.py index 66f91e83..2cd7915b 100644 --- a/application/agents/tools/internal_search.py +++ b/application/agents/tools/internal_search.py @@ -354,6 +354,59 @@ def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict: INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False) +def sources_have_directory_structure(source: Dict) -> bool: + """Check if any of the active sources have directory_structure in MongoDB.""" + active_docs = source.get("active_docs", []) + if not active_docs: + return False + + try: + from bson.objectid import ObjectId + from application.core.mongo_db import MongoDB + + mongo = MongoDB.get_client() + db = mongo[settings.MONGO_DB_NAME] + sources_collection = db["sources"] + + if isinstance(active_docs, str): + active_docs = [active_docs] + + for doc_id in active_docs: + try: + source_doc = sources_collection.find_one( + {"_id": ObjectId(doc_id)}, + {"directory_structure": 1}, + ) + if source_doc and source_doc.get("directory_structure"): + return True + except Exception: + continue + except Exception as e: + logger.debug(f"Could not check directory structure: {e}") + + return False + + +def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None: + """Add the internal search tool to tools_dict if sources are configured. + + Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic. + Mutates tools_dict in place. + """ + source = retriever_config.get("source", {}) + has_sources = bool(source.get("active_docs")) + if not retriever_config or not has_sources: + return + + has_dir = sources_have_directory_structure(source) + internal_entry = build_internal_tool_entry(has_directory_structure=has_dir) + internal_entry["config"] = build_internal_tool_config( + **retriever_config, + has_directory_structure=has_dir, + ) + tools_dict[INTERNAL_TOOL_ID] = internal_entry + + def build_internal_tool_config( source: Dict, retriever_name: str = "classic", diff --git a/application/agents/workflows/node_agent.py b/application/agents/workflows/node_agent.py index c26f6464..67437e45 100644 --- a/application/agents/workflows/node_agent.py +++ b/application/agents/workflows/node_agent.py @@ -2,8 +2,10 @@ from typing import Any, Dict, List, Optional, Type +from application.agents.agentic_agent import AgenticAgent from application.agents.base import BaseAgent from application.agents.classic_agent import ClassicAgent +from application.agents.research_agent import ResearchAgent from application.agents.workflows.schemas import AgentType @@ -35,7 +37,8 @@ class ToolFilterMixin: return filtered_tools -class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent): +class _WorkflowNodeMixin: + """Common __init__ for all workflow node agents.""" def __init__( self, @@ -56,11 +59,25 @@ class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent): self._allowed_tool_ids = tool_ids or [] +class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent): + pass + + +class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent): + pass + + +class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent): + pass + + class WorkflowNodeAgentFactory: _agents: Dict[AgentType, Type[BaseAgent]] = { AgentType.CLASSIC: WorkflowNodeClassicAgent, AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat + AgentType.AGENTIC: WorkflowNodeAgenticAgent, + AgentType.RESEARCH: WorkflowNodeResearchAgent, } @classmethod diff --git a/application/agents/workflows/schemas.py b/application/agents/workflows/schemas.py index 5355b88e..2a5bc79e 100644 --- a/application/agents/workflows/schemas.py +++ b/application/agents/workflows/schemas.py @@ -18,6 +18,8 @@ class NodeType(str, Enum): class AgentType(str, Enum): CLASSIC = "classic" REACT = "react" + AGENTIC = "agentic" + RESEARCH = "research" class ExecutionStatus(str, Enum): diff --git a/application/agents/workflows/workflow_engine.py b/application/agents/workflows/workflow_engine.py index 5444458a..00d471da 100644 --- a/application/agents/workflows/workflow_engine.py +++ b/application/agents/workflows/workflow_engine.py @@ -7,6 +7,7 @@ from application.agents.workflows.cel_evaluator import CelEvaluationError, evalu from application.agents.workflows.node_agent import WorkflowNodeAgentFactory from application.agents.workflows.schemas import ( AgentNodeConfig, + AgentType, ConditionNodeConfig, ExecutionStatus, NodeExecutionLog, @@ -223,18 +224,32 @@ class WorkflowEngine: f'Model "{node_model_id}" does not support structured output for node "{node.title}"' ) - node_agent = WorkflowNodeAgentFactory.create( - agent_type=node_config.agent_type, - endpoint=self.agent.endpoint, - llm_name=node_llm_name, - model_id=node_model_id, - api_key=node_api_key, - tool_ids=node_config.tools, - prompt=node_config.system_prompt, - chat_history=self.agent.chat_history, - decoded_token=self.agent.decoded_token, - json_schema=node_json_schema, - ) + factory_kwargs = { + "agent_type": node_config.agent_type, + "endpoint": self.agent.endpoint, + "llm_name": node_llm_name, + "model_id": node_model_id, + "api_key": node_api_key, + "tool_ids": node_config.tools, + "prompt": node_config.system_prompt, + "chat_history": self.agent.chat_history, + "decoded_token": self.agent.decoded_token, + "json_schema": node_json_schema, + } + + # Agentic/research agents need retriever_config for on-demand search + if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH): + factory_kwargs["retriever_config"] = { + "source": {"active_docs": node_config.sources} if node_config.sources else {}, + "retriever_name": node_config.retriever or "classic", + "chunks": int(node_config.chunks) if node_config.chunks else 2, + "model_id": node_model_id, + "llm_name": node_llm_name, + "api_key": node_api_key, + "decoded_token": self.agent.decoded_token, + } + + node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs) full_response_parts: List[str] = [] structured_response_parts: List[str] = [] diff --git a/application/prompts/agentic/creative.txt b/application/prompts/agentic/creative.txt index da2f8ff2..a1360b76 100644 --- a/application/prompts/agentic/creative.txt +++ b/application/prompts/agentic/creative.txt @@ -13,4 +13,4 @@ Use the search_internal tool to find relevant information before answering quest You may search multiple times with different queries if needed. Do not guess when documents are available — search first, then answer based on what you find. If no relevant documents are found, use your general knowledge and tool capabilities. -Allow yourself to be very creative and use your imagination. \ No newline at end of file +Allow yourself to be very creative and use your imagination. diff --git a/application/prompts/agentic/default.txt b/application/prompts/agentic/default.txt index 41339fe6..a2118414 100644 --- a/application/prompts/agentic/default.txt +++ b/application/prompts/agentic/default.txt @@ -12,4 +12,4 @@ You have access to a search tool that searches the user's uploaded documents and Use the search_internal tool to find relevant information before answering questions. You may search multiple times with different queries if needed. Do not guess when documents are available — search first, then answer based on what you find. -If no relevant documents are found, use your general knowledge and tool capabilities. \ No newline at end of file +If no relevant documents are found, use your general knowledge and tool capabilities. diff --git a/application/prompts/agentic/strict.txt b/application/prompts/agentic/strict.txt index 8878b323..dd327918 100644 --- a/application/prompts/agentic/strict.txt +++ b/application/prompts/agentic/strict.txt @@ -13,4 +13,4 @@ Use the search_internal tool to find relevant information before answering quest You may search multiple times with different queries if needed. You MUST search before answering any factual question. Do not guess or use general knowledge when documents are available. If you dont have enough information from the search results or tools, answer "I don't know" or "I don't have enough information". -Never make up information or provide false information! \ No newline at end of file +Never make up information or provide false information! diff --git a/application/prompts/research/clarification.txt b/application/prompts/research/clarification.txt index 000c73c2..499d46e3 100644 --- a/application/prompts/research/clarification.txt +++ b/application/prompts/research/clarification.txt @@ -20,4 +20,4 @@ You MUST respond with ONLY a valid JSON object (no markdown, no code fences): "needs_clarification": true or false, "questions": ["question 1", "question 2"] (only if needs_clarification is true, max 3 questions), "reason": "brief explanation of why clarification is or isn't needed" -} \ No newline at end of file +} diff --git a/application/prompts/research/planning.txt b/application/prompts/research/planning.txt index 8263bcaf..74c84a94 100644 --- a/application/prompts/research/planning.txt +++ b/application/prompts/research/planning.txt @@ -20,4 +20,4 @@ You MUST respond with ONLY a valid JSON object in this exact format (no markdown "steps": [ {"query": "specific sub-question to investigate", "rationale": "why this step is needed"} ] -} \ No newline at end of file +} diff --git a/application/prompts/research/step.txt b/application/prompts/research/step.txt index aecaf2b8..c0bf48a0 100644 --- a/application/prompts/research/step.txt +++ b/application/prompts/research/step.txt @@ -10,4 +10,4 @@ Instructions: 5. Cite specific documents and passages you found. Reference sources by their titles or filenames. 6. If you cannot find relevant information through tools, use your general knowledge but clearly indicate this. -Be thorough — prefer completeness over brevity. Include all relevant details you find. \ No newline at end of file +Be thorough — prefer completeness over brevity. Include all relevant details you find. diff --git a/application/prompts/research/synthesis.txt b/application/prompts/research/synthesis.txt index 66bb1cad..30ea3958 100644 --- a/application/prompts/research/synthesis.txt +++ b/application/prompts/research/synthesis.txt @@ -19,4 +19,4 @@ Write a well-structured, thorough report that: Available sources for citation: {references} -Format the report with clear headings and sections. Be comprehensive but well-organized. \ No newline at end of file +Format the report with clear headings and sections. Be comprehensive but well-organized. diff --git a/frontend/src/agents/WorkflowBuilder.tsx b/frontend/src/agents/WorkflowBuilder.tsx index 38a2ac67..d3c7c8e1 100644 --- a/frontend/src/agents/WorkflowBuilder.tsx +++ b/frontend/src/agents/WorkflowBuilder.tsx @@ -55,7 +55,7 @@ import WorkflowPreview from './workflow/WorkflowPreview'; import type { Model } from '../models/types'; interface AgentNodeConfig { - agent_type: 'classic'; + agent_type: 'classic' | 'agentic' | 'research'; llm_name?: string; model_id?: string; system_prompt: string; @@ -884,6 +884,12 @@ function WorkflowBuilderInner() { Classic + + Agentic + + + Research + diff --git a/frontend/src/agents/workflow/WorkflowBuilder.tsx b/frontend/src/agents/workflow/WorkflowBuilder.tsx index 845145be..5e765ea9 100644 --- a/frontend/src/agents/workflow/WorkflowBuilder.tsx +++ b/frontend/src/agents/workflow/WorkflowBuilder.tsx @@ -75,7 +75,7 @@ import type { Model } from '../../models/types'; const PRIMARY_ACTION_SPINNER_DELAY_MS = 180; interface AgentNodeConfig { - agent_type: 'classic'; + agent_type: 'classic' | 'agentic' | 'research'; llm_name?: string; model_id?: string; system_prompt: string; @@ -1748,6 +1748,12 @@ function WorkflowBuilderInner() { Classic + + Agentic + + + Research + diff --git a/tests/agents/test_agentic_agent.py b/tests/agents/test_agentic_agent.py new file mode 100644 index 00000000..f742fffd --- /dev/null +++ b/tests/agents/test_agentic_agent.py @@ -0,0 +1,139 @@ +"""Tests for AgenticAgent — LLM-controlled retrieval agent.""" + +from unittest.mock import Mock + +import pytest +from application.agents.agentic_agent import AgenticAgent + + +@pytest.mark.unit +class TestAgenticAgentInit: + + def test_initialization( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = AgenticAgent(**agent_base_params) + assert isinstance(agent, AgenticAgent) + assert agent.retriever_config == {} + + def test_initialization_with_retriever_config( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + rc = {"source": {"active_docs": ["abc"]}, "retriever_name": "classic"} + agent = AgenticAgent(retriever_config=rc, **agent_base_params) + assert agent.retriever_config == rc + + def test_inherits_base_properties( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = AgenticAgent(**agent_base_params) + assert agent.endpoint == agent_base_params["endpoint"] + assert agent.llm_name == agent_base_params["llm_name"] + assert agent.model_id == agent_base_params["model_id"] + + +@pytest.mark.unit +class TestAgenticAgentGenInner: + + def test_basic_flow_yields_sources_and_tool_calls( + self, + agent_base_params, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Processed" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = AgenticAgent(**agent_base_params) + results = list(agent._gen_inner("Test query", log_context)) + + sources = [r for r in results if "sources" in r] + tool_calls = [r for r in results if "tool_calls" in r] + assert len(sources) == 1 + assert len(tool_calls) == 1 + + def test_logs_agent_component( + self, + agent_base_params, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Done" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = AgenticAgent(**agent_base_params) + list(agent._gen_inner("Query", log_context)) + + agent_logs = [s for s in log_context.stacks if s["component"] == "agent"] + assert len(agent_logs) == 1 + assert "tool_calls" in agent_logs[0]["data"] + + def test_no_pre_fetched_docs_in_messages( + self, + agent_base_params, + mock_llm, + mock_llm_handler, + mock_llm_creator, + mock_llm_handler_creator, + mock_mongo_db, + log_context, + ): + mock_llm.gen_stream = Mock(return_value=iter(["Answer"])) + + def mock_handler(*args, **kwargs): + yield "Done" + + mock_llm_handler.process_message_flow = Mock(side_effect=mock_handler) + + agent = AgenticAgent(**agent_base_params) + list(agent._gen_inner("Query", log_context)) + + call_kwargs = mock_llm.gen_stream.call_args[1] + messages = call_kwargs["messages"] + # System prompt should not contain {summaries} replacement + assert messages[0]["role"] == "system" + assert messages[-1]["role"] == "user" + assert messages[-1]["content"] == "Query" + + +@pytest.mark.unit +class TestAgenticAgentCollectSources: + + def test_collect_internal_sources_from_cache( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = AgenticAgent(**agent_base_params) + + mock_tool = Mock() + mock_tool.retrieved_docs = [ + {"text": "Found", "title": "Doc", "source": "test"}, + ] + cache_key = f"internal_search:internal:{agent.user or ''}" + agent.tool_executor._loaded_tools[cache_key] = mock_tool + + agent._collect_internal_sources() + assert len(agent.retrieved_docs) == 1 + assert agent.retrieved_docs[0]["title"] == "Doc" + + def test_collect_internal_sources_no_cache( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = AgenticAgent(**agent_base_params) + agent._collect_internal_sources() + assert agent.retrieved_docs == [] diff --git a/tests/agents/test_internal_search_tool.py b/tests/agents/test_internal_search_tool.py new file mode 100644 index 00000000..b9e4fce0 --- /dev/null +++ b/tests/agents/test_internal_search_tool.py @@ -0,0 +1,250 @@ +"""Tests for InternalSearchTool and its helper functions.""" + +from unittest.mock import Mock, patch + +import pytest +from application.agents.tools.internal_search import ( + INTERNAL_TOOL_ID, + InternalSearchTool, + add_internal_search_tool, + build_internal_tool_config, + build_internal_tool_entry, +) + + +@pytest.mark.unit +class TestInternalSearchToolSearch: + + def _make_tool(self, **config_overrides): + config = {"source": {}, "retriever_name": "classic", "chunks": 2} + config.update(config_overrides) + return InternalSearchTool(config) + + def test_search_no_query_returns_error(self): + tool = self._make_tool() + result = tool.execute_action("search", query="") + assert "required" in result.lower() + + def test_search_returns_formatted_docs(self): + tool = self._make_tool() + mock_retriever = Mock() + mock_retriever.search.return_value = [ + {"text": "Hello world", "title": "Doc1", "source": "test", "filename": "doc1.md"}, + ] + tool._retriever = mock_retriever + + result = tool.execute_action("search", query="hello") + assert "doc1.md" in result + assert "Hello world" in result + assert len(tool.retrieved_docs) == 1 + + def test_search_no_results(self): + tool = self._make_tool() + mock_retriever = Mock() + mock_retriever.search.return_value = [] + tool._retriever = mock_retriever + + result = tool.execute_action("search", query="nonexistent") + assert "No documents found" in result + + def test_search_accumulates_docs(self): + tool = self._make_tool() + mock_retriever = Mock() + tool._retriever = mock_retriever + + mock_retriever.search.return_value = [ + {"text": "A", "title": "D1", "source": "s1"}, + ] + tool.execute_action("search", query="first") + + mock_retriever.search.return_value = [ + {"text": "B", "title": "D2", "source": "s2"}, + ] + tool.execute_action("search", query="second") + + assert len(tool.retrieved_docs) == 2 + + def test_search_deduplicates_docs(self): + tool = self._make_tool() + doc = {"text": "Same", "title": "Same", "source": "same"} + mock_retriever = Mock() + mock_retriever.search.return_value = [doc] + tool._retriever = mock_retriever + + tool.execute_action("search", query="q1") + tool.execute_action("search", query="q2") + + assert len(tool.retrieved_docs) == 1 + + def test_search_with_path_filter(self): + tool = self._make_tool() + mock_retriever = Mock() + mock_retriever.search.return_value = [ + {"text": "A", "title": "T", "source": "src/main.py", "filename": "main.py"}, + {"text": "B", "title": "T", "source": "docs/readme.md", "filename": "readme.md"}, + ] + tool._retriever = mock_retriever + + result = tool.execute_action("search", query="code", path_filter="src/") + assert "main.py" in result + assert "readme.md" not in result + + def test_search_path_filter_no_match(self): + tool = self._make_tool() + mock_retriever = Mock() + mock_retriever.search.return_value = [ + {"text": "A", "title": "T", "source": "other/file.txt"}, + ] + tool._retriever = mock_retriever + + result = tool.execute_action("search", query="code", path_filter="src/") + assert "No documents found" in result + + def test_search_retriever_error(self): + tool = self._make_tool() + mock_retriever = Mock() + mock_retriever.search.side_effect = Exception("Connection error") + tool._retriever = mock_retriever + + result = tool.execute_action("search", query="test") + assert "failed" in result.lower() or "error" in result.lower() + + def test_unknown_action(self): + tool = self._make_tool() + result = tool.execute_action("nonexistent") + assert "Unknown action" in result + + +@pytest.mark.unit +class TestInternalSearchToolListFiles: + + def test_list_files_no_structure(self): + tool = InternalSearchTool({"source": {}}) + tool._dir_structure_loaded = True + tool._directory_structure = None + + result = tool.execute_action("list_files") + assert "No file structure" in result + + def test_list_files_root(self): + tool = InternalSearchTool({"source": {}}) + tool._dir_structure_loaded = True + tool._directory_structure = { + "src": {"main.py": {}}, + "README.md": {"type": "md", "token_count": 100}, + } + + result = tool.execute_action("list_files") + assert "src/" in result + assert "README.md" in result + + def test_list_files_nested_path(self): + tool = InternalSearchTool({"source": {}}) + tool._dir_structure_loaded = True + tool._directory_structure = { + "src": { + "utils": {"helper.py": {}}, + }, + } + + result = tool.execute_action("list_files", path="src") + assert "utils/" in result + + def test_list_files_invalid_path(self): + tool = InternalSearchTool({"source": {}}) + tool._dir_structure_loaded = True + tool._directory_structure = {"src": {}} + + result = tool.execute_action("list_files", path="nonexistent") + assert "not found" in result + + +@pytest.mark.unit +class TestInternalSearchToolMetadata: + + def test_actions_without_directory_structure(self): + tool = InternalSearchTool({"has_directory_structure": False}) + meta = tool.get_actions_metadata() + + action_names = [a["name"] for a in meta] + assert "search" in action_names + assert "list_files" not in action_names + + # search should not have path_filter + search = meta[0] + assert "path_filter" not in search["parameters"]["properties"] + + def test_actions_with_directory_structure(self): + tool = InternalSearchTool({"has_directory_structure": True}) + meta = tool.get_actions_metadata() + + action_names = [a["name"] for a in meta] + assert "search" in action_names + assert "list_files" in action_names + + # search should have path_filter + search = next(a for a in meta if a["name"] == "search") + assert "path_filter" in search["parameters"]["properties"] + + +@pytest.mark.unit +class TestBuildHelpers: + + def test_build_entry_without_directory_structure(self): + entry = build_internal_tool_entry(has_directory_structure=False) + assert entry["name"] == "internal_search" + action_names = [a["name"] for a in entry["actions"]] + assert "search" in action_names + assert "list_files" not in action_names + + def test_build_entry_with_directory_structure(self): + entry = build_internal_tool_entry(has_directory_structure=True) + action_names = [a["name"] for a in entry["actions"]] + assert "list_files" in action_names + + def test_build_config(self): + config = build_internal_tool_config( + source={"active_docs": ["abc"]}, + retriever_name="semantic", + chunks=4, + ) + assert config["source"] == {"active_docs": ["abc"]} + assert config["retriever_name"] == "semantic" + assert config["chunks"] == 4 + + def test_internal_tool_id(self): + assert INTERNAL_TOOL_ID == "internal" + + def test_add_internal_search_tool_with_sources(self): + tools_dict = {} + retriever_config = { + "source": {"active_docs": ["abc"]}, + "retriever_name": "classic", + "chunks": 2, + "model_id": "gpt-4", + "llm_name": "openai", + "api_key": "key", + } + + with patch( + "application.agents.tools.internal_search.sources_have_directory_structure", + return_value=False, + ): + add_internal_search_tool(tools_dict, retriever_config) + + assert INTERNAL_TOOL_ID in tools_dict + assert tools_dict[INTERNAL_TOOL_ID]["name"] == "internal_search" + assert "config" in tools_dict[INTERNAL_TOOL_ID] + + def test_add_internal_search_tool_no_sources(self): + tools_dict = {} + retriever_config = {"source": {}} + + add_internal_search_tool(tools_dict, retriever_config) + + assert INTERNAL_TOOL_ID not in tools_dict + + def test_add_internal_search_tool_empty_config(self): + tools_dict = {} + add_internal_search_tool(tools_dict, {}) + assert INTERNAL_TOOL_ID not in tools_dict diff --git a/tests/agents/test_research_agent.py b/tests/agents/test_research_agent.py new file mode 100644 index 00000000..3cc1e6d3 --- /dev/null +++ b/tests/agents/test_research_agent.py @@ -0,0 +1,400 @@ +"""Tests for ResearchAgent — multi-step research with budget controls.""" + +import json +import time +from unittest.mock import Mock, patch + +import pytest +from application.agents.research_agent import ( + CitationManager, + ResearchAgent, + DEFAULT_MAX_STEPS, + DEFAULT_TIMEOUT_SECONDS, + DEFAULT_TOKEN_BUDGET, +) + + +# --------------------------------------------------------------------------- +# CitationManager +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCitationManager: + + def test_add_returns_citation_number(self): + cm = CitationManager() + num = cm.add({"source": "s1", "title": "T1"}) + assert num == 1 + + def test_add_deduplicates(self): + cm = CitationManager() + n1 = cm.add({"source": "s1", "title": "T1"}) + n2 = cm.add({"source": "s1", "title": "T1"}) + assert n1 == n2 + assert len(cm.citations) == 1 + + def test_add_different_sources(self): + cm = CitationManager() + n1 = cm.add({"source": "s1", "title": "T1"}) + n2 = cm.add({"source": "s2", "title": "T2"}) + assert n1 != n2 + assert len(cm.citations) == 2 + + def test_add_docs_returns_mapping(self): + cm = CitationManager() + docs = [ + {"source": "s1", "title": "Doc A"}, + {"source": "s2", "title": "Doc B"}, + ] + text = cm.add_docs(docs) + assert "[1] Doc A" in text + assert "[2] Doc B" in text + + def test_format_references(self): + cm = CitationManager() + cm.add({"source": "http://example.com", "title": "Example", "filename": "ex.md"}) + refs = cm.format_references() + assert "[1]" in refs + assert "ex.md" in refs + + def test_format_references_empty(self): + cm = CitationManager() + assert "No sources" in cm.format_references() + + def test_get_all_docs(self): + cm = CitationManager() + cm.add({"source": "s1", "title": "T1"}) + cm.add({"source": "s2", "title": "T2"}) + docs = cm.get_all_docs() + assert len(docs) == 2 + + +# --------------------------------------------------------------------------- +# ResearchAgent Init & Budget +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestResearchAgentInit: + + def test_initialization( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ResearchAgent(**agent_base_params) + assert isinstance(agent, ResearchAgent) + assert agent.max_steps == DEFAULT_MAX_STEPS + assert agent.timeout_seconds == DEFAULT_TIMEOUT_SECONDS + assert agent.token_budget == DEFAULT_TOKEN_BUDGET + assert agent.retriever_config == {} + + def test_custom_budget( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ResearchAgent( + max_steps=3, + timeout_seconds=60, + token_budget=50_000, + **agent_base_params, + ) + assert agent.max_steps == 3 + assert agent.timeout_seconds == 60 + assert agent.token_budget == 50_000 + + def test_with_retriever_config( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + rc = {"source": {"active_docs": ["abc"]}} + agent = ResearchAgent(retriever_config=rc, **agent_base_params) + assert agent.retriever_config == rc + + +@pytest.mark.unit +class TestResearchAgentBudget: + + def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator, **kwargs): + return ResearchAgent(**kwargs, **agent_base_params) + + def test_timeout_detection( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator, + timeout_seconds=0, + ) + agent._start_time = time.monotonic() - 1 + assert agent._is_timed_out() is True + + def test_not_timed_out( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator, + timeout_seconds=300, + ) + agent._start_time = time.monotonic() + assert agent._is_timed_out() is False + + def test_token_budget_tracking( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator, + token_budget=1000, + ) + agent._track_tokens(500) + assert agent._budget_remaining() == 500 + assert agent._is_over_budget() is False + + agent._track_tokens(500) + assert agent._budget_remaining() == 0 + assert agent._is_over_budget() is True + + def test_snapshot_llm_tokens_returns_delta( + self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator, + ) + mock_llm.token_usage = {"prompt_tokens": 100, "generated_tokens": 50} + + delta1 = agent._snapshot_llm_tokens() + assert delta1 == 150 + + # Simulate more tokens used + mock_llm.token_usage = {"prompt_tokens": 200, "generated_tokens": 100} + delta2 = agent._snapshot_llm_tokens() + assert delta2 == 150 # 300 - 150 + + def test_elapsed( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator, + ) + agent._start_time = time.monotonic() - 1.5 + elapsed = agent._elapsed() + assert elapsed >= 1.0 + + +# --------------------------------------------------------------------------- +# ResearchAgent Phases +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestResearchAgentClarification: + + def test_is_follow_up_no_history( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ResearchAgent(**agent_base_params) + assert agent._is_follow_up() is False + + def test_is_follow_up_with_clarification_metadata( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent_base_params["chat_history"] = [ + {"prompt": "What?", "response": "Clarify", "metadata": {"is_clarification": True}}, + ] + agent = ResearchAgent(**agent_base_params) + assert agent._is_follow_up() is True + + def test_is_follow_up_without_metadata( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent_base_params["chat_history"] = [ + {"prompt": "What?", "response": "Normal answer"}, + ] + agent = ResearchAgent(**agent_base_params) + assert agent._is_follow_up() is False + + def test_clarification_returns_none_on_no_clarification_needed( + self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator + ): + response = Mock() + response.choices = [Mock()] + response.choices[0].message = Mock() + response.choices[0].message.content = json.dumps( + {"needs_clarification": False, "reason": "Clear enough"} + ) + mock_llm.gen = Mock(return_value=response) + mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5} + + agent = ResearchAgent(**agent_base_params) + result = agent._clarification_phase("What is Python?") + assert result is None + + def test_clarification_returns_questions( + self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator + ): + clarification_json = json.dumps({ + "needs_clarification": True, + "questions": ["Which version?", "What context?"], + }) + # Return a plain string so _extract_text handles it directly + mock_llm.gen = Mock(return_value=clarification_json) + mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5} + + agent = ResearchAgent(**agent_base_params) + result = agent._clarification_phase("Tell me about it") + assert result is not None + assert "Which version?" in result + assert "What context?" in result + + +@pytest.mark.unit +class TestResearchAgentPlanning: + + def test_planning_returns_steps_and_complexity( + self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator + ): + plan_json = json.dumps({ + "complexity": "moderate", + "steps": [ + {"query": "sub-question 1", "rationale": "reason 1"}, + {"query": "sub-question 2", "rationale": "reason 2"}, + ], + }) + # Return plain string so _extract_text handles it directly + mock_llm.gen = Mock(return_value=plan_json) + mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5} + + agent = ResearchAgent(**agent_base_params) + steps, complexity = agent._planning_phase("Compare A and B") + + assert complexity == "moderate" + assert len(steps) == 2 + assert steps[0]["query"] == "sub-question 1" + + def test_planning_caps_steps_by_complexity( + self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator + ): + plan_json = json.dumps({ + "complexity": "simple", + "steps": [ + {"query": f"q{i}", "rationale": f"r{i}"} for i in range(10) + ], + }) + response = Mock() + response.choices = [Mock()] + response.choices[0].message = Mock() + response.choices[0].message.content = plan_json + mock_llm.gen = Mock(return_value=response) + mock_llm.token_usage = {"prompt_tokens": 10, "generated_tokens": 5} + + agent = ResearchAgent(**agent_base_params) + steps, complexity = agent._planning_phase("Simple question") + + assert complexity == "simple" + assert len(steps) <= 2 # COMPLEXITY_CAPS["simple"] == 2 + + def test_planning_fallback_on_error( + self, agent_base_params, mock_llm, mock_llm_creator, mock_llm_handler_creator + ): + mock_llm.gen = Mock(side_effect=Exception("LLM down")) + mock_llm.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} + + agent = ResearchAgent(**agent_base_params) + steps, complexity = agent._planning_phase("Anything") + + assert complexity == "simple" + assert len(steps) == 1 + assert steps[0]["query"] == "Anything" + + +@pytest.mark.unit +class TestResearchAgentExtractText: + + def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator): + return ResearchAgent(**agent_base_params) + + def test_extract_from_string( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + assert agent._extract_text("hello") == "hello" + + def test_extract_from_openai_response( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + response = Mock() + response.choices = [Mock()] + response.choices[0].message = Mock() + response.choices[0].message.content = "OpenAI content" + response.message = None + response.content = None + assert agent._extract_text(response) == "OpenAI content" + + def test_extract_from_anthropic_response( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + text_block = Mock() + text_block.text = "Anthropic content" + response = Mock() + response.content = [text_block] + response.message = None + response.choices = None + assert agent._extract_text(response) == "Anthropic content" + + def test_extract_from_none( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + assert agent._extract_text(None) == "" + + +@pytest.mark.unit +class TestResearchAgentParseJson: + + def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator): + return ResearchAgent(**agent_base_params) + + def test_parse_plan_direct_json( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + text = '{"steps": [{"query": "q1"}], "complexity": "simple"}' + result = agent._parse_plan_json(text) + assert isinstance(result, dict) + assert len(result["steps"]) == 1 + + def test_parse_plan_from_code_fence( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + text = 'Here is the plan:\n```json\n{"steps": [{"query": "q1"}]}\n```' + result = agent._parse_plan_json(text) + assert isinstance(result, dict) + + def test_parse_plan_invalid_returns_empty( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + result = agent._parse_plan_json("not json at all") + assert result == [] + + def test_parse_clarification_json( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + text = '{"needs_clarification": false, "reason": "clear"}' + result = agent._parse_clarification_json(text) + assert result["needs_clarification"] is False + + def test_parse_clarification_json_from_code_fence( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + text = '```json\n{"needs_clarification": true, "questions": ["q1"]}\n```' + result = agent._parse_clarification_json(text) + assert result["needs_clarification"] is True + + def test_parse_clarification_json_invalid( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent(agent_base_params, mock_llm_creator, mock_llm_handler_creator) + result = agent._parse_clarification_json("not json") + assert result is None diff --git a/tests/agents/test_think_tool.py b/tests/agents/test_think_tool.py new file mode 100644 index 00000000..45f9e727 --- /dev/null +++ b/tests/agents/test_think_tool.py @@ -0,0 +1,52 @@ +"""Tests for ThinkTool — the chain-of-thought pseudo-tool.""" + +import pytest +from application.agents.tools.think import ( + THINK_TOOL_ENTRY, + THINK_TOOL_ID, + ThinkTool, +) + + +@pytest.mark.unit +class TestThinkTool: + + def test_id_constant(self): + assert THINK_TOOL_ID == "think" + + def test_entry_has_reason_action(self): + actions = THINK_TOOL_ENTRY["actions"] + assert len(actions) == 1 + assert actions[0]["name"] == "reason" + assert actions[0]["active"] is True + + def test_execute_reason_returns_continue(self): + tool = ThinkTool() + result = tool.execute_action("reason", reasoning="step by step thinking") + assert result == "Continue." + + def test_execute_unknown_action_returns_continue(self): + tool = ThinkTool() + result = tool.execute_action("unknown_action") + assert result == "Continue." + + def test_get_actions_metadata(self): + tool = ThinkTool() + meta = tool.get_actions_metadata() + assert len(meta) == 1 + assert meta[0]["name"] == "reason" + props = meta[0]["parameters"]["properties"] + assert "reasoning" in props + assert props["reasoning"]["filled_by_llm"] is True + + def test_get_config_requirements_empty(self): + tool = ThinkTool() + assert tool.get_config_requirements() == {} + + def test_init_accepts_no_config(self): + tool = ThinkTool() + assert tool is not None + + def test_init_accepts_config(self): + tool = ThinkTool(config={"key": "value"}) + assert tool is not None diff --git a/tests/agents/test_tool_executor.py b/tests/agents/test_tool_executor.py new file mode 100644 index 00000000..d995c22b --- /dev/null +++ b/tests/agents/test_tool_executor.py @@ -0,0 +1,279 @@ +"""Tests for ToolExecutor — tool discovery, preparation, and execution.""" + +from unittest.mock import Mock + +import pytest +from application.agents.tool_executor import ToolExecutor + + +@pytest.mark.unit +class TestToolExecutorInit: + + def test_default_state(self): + executor = ToolExecutor() + assert executor.user_api_key is None + assert executor.user is None + assert executor.tool_calls == [] + assert executor._loaded_tools == {} + assert executor.conversation_id is None + + def test_init_with_params(self): + executor = ToolExecutor( + user_api_key="key", user="alice", decoded_token={"sub": "alice"} + ) + assert executor.user_api_key == "key" + assert executor.user == "alice" + + +@pytest.mark.unit +class TestToolExecutorGetTools: + + def test_get_tools_uses_api_key_when_present(self, mock_mongo_db): + executor = ToolExecutor(user_api_key="test_key", user="alice") + tools = executor.get_tools() + assert isinstance(tools, dict) + + def test_get_tools_uses_user_when_no_api_key(self, mock_mongo_db): + executor = ToolExecutor(user="alice") + tools = executor.get_tools() + assert isinstance(tools, dict) + + def test_get_tools_defaults_to_local(self, mock_mongo_db): + executor = ToolExecutor() + tools = executor.get_tools() + assert isinstance(tools, dict) + + +@pytest.mark.unit +class TestToolExecutorPrepare: + + def test_prepare_tools_for_llm_empty(self): + executor = ToolExecutor() + result = executor.prepare_tools_for_llm({}) + assert result == [] + + def test_prepare_tools_for_llm_non_api_tool(self): + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "test_tool", + "actions": [ + { + "name": "do_thing", + "description": "Does a thing", + "active": True, + "parameters": { + "properties": { + "query": { + "type": "string", + "description": "The query", + "filled_by_llm": True, + "required": True, + } + } + }, + } + ], + } + } + + result = executor.prepare_tools_for_llm(tools_dict) + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["function"]["name"] == "do_thing_t1" + assert "query" in result[0]["function"]["parameters"]["properties"] + + def test_prepare_tools_skips_inactive_actions(self): + executor = ToolExecutor() + tools_dict = { + "t1": { + "name": "test_tool", + "actions": [ + {"name": "active_one", "description": "D", "active": True, "parameters": {"properties": {}}}, + {"name": "inactive_one", "description": "D", "active": False, "parameters": {"properties": {}}}, + ], + } + } + + result = executor.prepare_tools_for_llm(tools_dict) + assert len(result) == 1 + assert result[0]["function"]["name"] == "active_one_t1" + + def test_build_tool_parameters_filters_non_llm_fields(self): + executor = ToolExecutor() + action = { + "parameters": { + "properties": { + "query": { + "type": "string", + "description": "Search query", + "filled_by_llm": True, + "value": "default_val", + "required": True, + }, + "hidden": { + "type": "string", + "filled_by_llm": False, + }, + } + } + } + + result = executor._build_tool_parameters(action) + assert "query" in result["properties"] + assert "hidden" not in result["properties"] + assert "query" in result["required"] + # filled_by_llm, value, required stripped from schema + assert "filled_by_llm" not in result["properties"]["query"] + assert "value" not in result["properties"]["query"] + + +@pytest.mark.unit +class TestToolExecutorExecute: + + def _make_call(self, name="action_toolid", call_id="c1", arguments="{}"): + call = Mock() + call.name = name + call.id = call_id + call.arguments = arguments + return call + + def test_execute_parse_failure(self, monkeypatch): + executor = ToolExecutor() + + monkeypatch.setattr( + "application.agents.tool_executor.ToolActionParser", + lambda _cls: Mock(parse_args=Mock(return_value=(None, None, {}))), + ) + + call = self._make_call(name="bad") + gen = executor.execute({}, call, "MockLLM") + + events = [] + result = None + while True: + try: + events.append(next(gen)) + except StopIteration as e: + result = e.value + break + + assert result[0] == "Failed to parse tool call." + assert len(executor.tool_calls) == 1 + assert events[0]["data"]["status"] == "error" + + def test_execute_tool_not_found(self, monkeypatch): + executor = ToolExecutor() + + monkeypatch.setattr( + "application.agents.tool_executor.ToolActionParser", + lambda _cls: Mock(parse_args=Mock(return_value=("missing_id", "action", {}))), + ) + + call = self._make_call() + gen = executor.execute({}, call, "MockLLM") + + events = [] + result = None + while True: + try: + events.append(next(gen)) + except StopIteration as e: + result = e.value + break + + assert "not found" in result[0] + assert events[0]["data"]["status"] == "error" + + def test_execute_success(self, mock_tool_manager, monkeypatch): + executor = ToolExecutor(user="test_user") + + monkeypatch.setattr( + "application.agents.tool_executor.ToolActionParser", + lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {"param1": "val"}))), + ) + + tools_dict = { + "t1": { + "name": "test_tool", + "config": {"key": "val"}, + "actions": [ + {"name": "test_action", "description": "Test", "parameters": {"properties": {}}}, + ], + } + } + + call = self._make_call(name="test_action_t1", call_id="c1") + gen = executor.execute(tools_dict, call, "MockLLM") + + events = [] + result = None + while True: + try: + events.append(next(gen)) + except StopIteration as e: + result = e.value + break + + assert result[0] == "Tool result" + assert result[1] == "c1" + + statuses = [e["data"]["status"] for e in events] + assert "pending" in statuses + assert "completed" in statuses + + def test_get_truncated_tool_calls(self): + executor = ToolExecutor() + executor.tool_calls = [ + { + "tool_name": "test", + "call_id": "1", + "action_name": "act", + "arguments": {}, + "result": "A" * 100, + } + ] + + truncated = executor.get_truncated_tool_calls() + assert len(truncated) == 1 + assert len(truncated[0]["result"]) <= 53 + assert truncated[0]["status"] == "completed" + + def test_tool_caching(self, mock_tool_manager, monkeypatch): + executor = ToolExecutor(user="test_user") + + monkeypatch.setattr( + "application.agents.tool_executor.ToolActionParser", + lambda _cls: Mock(parse_args=Mock(return_value=("t1", "test_action", {}))), + ) + + tools_dict = { + "t1": { + "name": "test_tool", + "config": {"key": "val"}, + "actions": [ + {"name": "test_action", "description": "Test", "parameters": {"properties": {}}}, + ], + } + } + + call = self._make_call(name="test_action_t1") + + # First execution — loads tool + gen = executor.execute(tools_dict, call, "MockLLM") + while True: + try: + next(gen) + except StopIteration: + break + + # Second execution — should use cache + gen = executor.execute(tools_dict, call, "MockLLM") + while True: + try: + next(gen) + except StopIteration: + break + + # load_tool called only once due to cache + assert mock_tool_manager.load_tool.call_count == 1 diff --git a/tests/agents/test_workflow_agent_types.py b/tests/agents/test_workflow_agent_types.py new file mode 100644 index 00000000..4b751d17 --- /dev/null +++ b/tests/agents/test_workflow_agent_types.py @@ -0,0 +1,475 @@ +"""Tests for new agent types (agentic, research) in the workflow builder.""" + +from types import SimpleNamespace +from typing import Any, Dict + +import pytest + +from application.agents.agentic_agent import AgenticAgent +from application.agents.classic_agent import ClassicAgent +from application.agents.research_agent import ResearchAgent +from application.agents.workflows.node_agent import ( + WorkflowNodeAgenticAgent, + WorkflowNodeAgentFactory, + WorkflowNodeClassicAgent, + WorkflowNodeResearchAgent, +) +from application.agents.workflows.schemas import ( + AgentNodeConfig, + AgentType, + NodeType, + Workflow, + WorkflowGraph, + WorkflowNode, +) +from application.agents.workflows.workflow_engine import WorkflowEngine + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class StubNodeAgent: + """Minimal agent stub that yields pre-defined events.""" + + def __init__(self, events): + self.events = events + + def gen(self, _prompt): + yield from self.events + + +def create_engine() -> WorkflowEngine: + graph = WorkflowGraph(workflow=Workflow(name="Test"), nodes=[], edges=[]) + agent = SimpleNamespace( + endpoint="stream", + llm_name="openai", + model_id="gpt-4o-mini", + api_key="test-key", + chat_history=[], + decoded_token={"sub": "user-1"}, + ) + return WorkflowEngine(graph, agent) + + +def create_agent_node( + node_id: str, + agent_type: str = "classic", + sources: list = None, + chunks: str = "2", + retriever: str = "", + output_variable: str = "", +) -> WorkflowNode: + config: Dict[str, Any] = { + "agent_type": agent_type, + "system_prompt": "You are a helpful assistant.", + "prompt_template": "", + "stream_to_user": False, + "tools": [], + "sources": sources or [], + "chunks": chunks, + "retriever": retriever, + } + if output_variable: + config["output_variable"] = output_variable + return WorkflowNode( + id=node_id, + workflow_id="workflow-1", + type=NodeType.AGENT, + title="Agent", + position={"x": 0, "y": 0}, + config=config, + ) + + +# --------------------------------------------------------------------------- +# AgentType enum +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestAgentTypeEnum: + + def test_agentic_value_exists(self): + assert AgentType.AGENTIC == "agentic" + + def test_research_value_exists(self): + assert AgentType.RESEARCH == "research" + + def test_classic_still_exists(self): + assert AgentType.CLASSIC == "classic" + + def test_react_still_exists(self): + assert AgentType.REACT == "react" + + +# --------------------------------------------------------------------------- +# AgentNodeConfig schema validation +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestAgentNodeConfigValidation: + + def test_accepts_agentic_agent_type(self): + config = AgentNodeConfig(agent_type="agentic") + assert config.agent_type == AgentType.AGENTIC + + def test_accepts_research_agent_type(self): + config = AgentNodeConfig(agent_type="research") + assert config.agent_type == AgentType.RESEARCH + + def test_rejects_unknown_agent_type(self): + with pytest.raises(Exception): + AgentNodeConfig(agent_type="nonexistent") + + def test_default_agent_type_is_classic(self): + config = AgentNodeConfig() + assert config.agent_type == AgentType.CLASSIC + + +# --------------------------------------------------------------------------- +# WorkflowNodeAgentFactory registry +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestWorkflowNodeAgentFactoryRegistry: + + def test_factory_has_agentic(self): + assert AgentType.AGENTIC in WorkflowNodeAgentFactory._agents + assert WorkflowNodeAgentFactory._agents[AgentType.AGENTIC] is WorkflowNodeAgenticAgent + + def test_factory_has_research(self): + assert AgentType.RESEARCH in WorkflowNodeAgentFactory._agents + assert WorkflowNodeAgentFactory._agents[AgentType.RESEARCH] is WorkflowNodeResearchAgent + + def test_factory_raises_for_unknown_type(self): + with pytest.raises(ValueError, match="Unsupported agent type"): + WorkflowNodeAgentFactory.create( + agent_type="nonexistent", + endpoint="stream", + llm_name="openai", + model_id="gpt-4o-mini", + api_key="key", + ) + + +# --------------------------------------------------------------------------- +# WorkflowNode agent classes (inheritance) +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestWorkflowNodeAgentClasses: + + def test_agentic_agent_inherits_correctly(self): + assert issubclass(WorkflowNodeAgenticAgent, AgenticAgent) + + def test_research_agent_inherits_correctly(self): + assert issubclass(WorkflowNodeResearchAgent, ResearchAgent) + + def test_classic_agent_inherits_correctly(self): + assert issubclass(WorkflowNodeClassicAgent, ClassicAgent) + + +# --------------------------------------------------------------------------- +# Workflow engine: agentic agent node execution +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWorkflowEngineAgenticNode: + + def test_agentic_node_executes_and_saves_output(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_agentic", + agent_type="agentic", + output_variable="result", + ) + node_events = [{"answer": "agentic answer"}] + + captured: Dict[str, Any] = {} + + def capture_create(**kwargs): + captured.update(kwargs) + return StubNodeAgent(node_events) + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(capture_create), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + assert engine.state["node_agent_agentic_output"] == "agentic answer" + assert engine.state["result"] == "agentic answer" + + def test_agentic_node_passes_retriever_config(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_rc", + agent_type="agentic", + sources=["source-abc"], + chunks="4", + retriever="semantic", + ) + node_events = [{"answer": "ok"}] + + captured: Dict[str, Any] = {} + + def capture_create(**kwargs): + captured.update(kwargs) + return StubNodeAgent(node_events) + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(capture_create), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + rc = captured.get("retriever_config") + assert rc is not None + assert rc["source"] == {"active_docs": ["source-abc"]} + assert rc["retriever_name"] == "semantic" + assert rc["chunks"] == 4 + + def test_agentic_node_empty_sources_gives_empty_source_dict(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_nosrc", + agent_type="agentic", + sources=[], + ) + node_events = [{"answer": "ok"}] + + captured: Dict[str, Any] = {} + + def capture_create(**kwargs): + captured.update(kwargs) + return StubNodeAgent(node_events) + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(capture_create), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + rc = captured["retriever_config"] + assert rc["source"] == {} + + +# --------------------------------------------------------------------------- +# Workflow engine: research agent node execution +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWorkflowEngineResearchNode: + + def test_research_node_executes_and_saves_output(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_research", + agent_type="research", + output_variable="report", + ) + node_events = [{"answer": "research report"}] + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(lambda **kwargs: StubNodeAgent(node_events)), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + assert engine.state["node_agent_research_output"] == "research report" + assert engine.state["report"] == "research report" + + def test_research_node_passes_retriever_config(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_rr", + agent_type="research", + sources=["doc-1", "doc-2"], + chunks="6", + ) + node_events = [{"answer": "ok"}] + + captured: Dict[str, Any] = {} + + def capture_create(**kwargs): + captured.update(kwargs) + return StubNodeAgent(node_events) + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(capture_create), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + rc = captured["retriever_config"] + assert rc["source"] == {"active_docs": ["doc-1", "doc-2"]} + assert rc["chunks"] == 6 + assert rc["decoded_token"] == {"sub": "user-1"} + + def test_research_node_handles_structured_output(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_rs", + agent_type="research", + output_variable="data", + ) + # Simulate structured JSON output from research agent + node_events = [ + {"answer": '{"findings": "important"}', "structured": True}, + ] + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(lambda **kwargs: StubNodeAgent(node_events)), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + # structured=True causes the engine to parse JSON + assert engine.state["data"] == {"findings": "important"} + + +# --------------------------------------------------------------------------- +# Workflow engine: classic node does NOT get retriever_config +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWorkflowEngineClassicNodeNoRetrieverConfig: + + def test_classic_node_does_not_pass_retriever_config(self, monkeypatch): + engine = create_engine() + node = create_agent_node( + node_id="agent_classic", + agent_type="classic", + sources=["some-source"], + ) + node_events = [{"answer": "classic answer"}] + + captured: Dict[str, Any] = {} + + def capture_create(**kwargs): + captured.update(kwargs) + return StubNodeAgent(node_events) + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(capture_create), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + list(engine._execute_agent_node(node)) + + assert "retriever_config" not in captured + + +# --------------------------------------------------------------------------- +# Workflow engine: streaming events from new agent types +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestWorkflowEngineStreamingEvents: + + def test_agentic_node_streams_answer_events(self, monkeypatch): + engine = create_engine() + node = create_agent_node(node_id="agent_s1", agent_type="agentic") + # Modify config to enable streaming + node.config["stream_to_user"] = True + + node_events = [ + {"answer": "chunk 1"}, + {"answer": "chunk 2"}, + ] + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(lambda **kwargs: StubNodeAgent(node_events)), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + results = list(engine._execute_agent_node(node)) + answer_events = [r for r in results if "answer" in r] + assert len(answer_events) == 2 + + def test_research_node_passes_through_non_answer_events(self, monkeypatch): + """Research agents yield research_plan/research_progress events. + The workflow engine only forwards 'answer' events to the user.""" + engine = create_engine() + node = create_agent_node(node_id="agent_s2", agent_type="research") + node.config["stream_to_user"] = True + + node_events = [ + {"type": "research_plan", "data": {"steps": [], "complexity": "simple"}}, + {"type": "research_progress", "data": {"status": "planning"}}, + {"answer": "final report"}, + ] + + monkeypatch.setattr( + WorkflowNodeAgentFactory, + "create", + staticmethod(lambda **kwargs: StubNodeAgent(node_events)), + ) + monkeypatch.setattr( + "application.core.model_utils.get_api_key_for_provider", + lambda _provider: None, + ) + + results = list(engine._execute_agent_node(node)) + # Only answer events are streamed to user + answer_events = [r for r in results if "answer" in r] + assert len(answer_events) == 1 + assert answer_events[0]["answer"] == "final report" + + # State still captures the full text + assert engine.state["node_agent_s2_output"] == "final report" diff --git a/tests/integration/test_chat.py b/tests/integration/test_chat.py index 350faaba..2651d169 100644 --- a/tests/integration/test_chat.py +++ b/tests/integration/test_chat.py @@ -492,6 +492,486 @@ This is test documentation for integration tests. self.record_result(test_name, False, str(e)) return False + # ------------------------------------------------------------------------- + # Agentic / Research Agent Tests + # ------------------------------------------------------------------------- + + def _create_agent_with_type(self, agent_type: str) -> Optional[tuple]: + """Create a test agent with the given agent_type. Returns (agent_id, api_key) or None.""" + if not self.is_authenticated: + return None + + payload = { + "name": f"Chat Test {agent_type.title()} Agent {int(time.time())}", + "description": f"Integration test {agent_type} agent", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": agent_type, + "status": "draft", + } + + try: + response = self.post("/api/create_agent", json=payload, timeout=10) + if response.status_code in [200, 201]: + result = response.json() + agent_id = result.get("id") + api_key = result.get("key") + if agent_id: + return (agent_id, api_key) + except Exception: + pass + + return None + + def test_stream_agentic_agent(self) -> bool: + """Test /stream endpoint with an agentic agent.""" + test_name = "Stream endpoint (agentic agent)" + + agent_result = self._create_agent_with_type("agentic") + if not agent_result: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + agent_id, _ = agent_result + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "agent_id": agent_id, + } + + try: + self.print_info(f"POST /stream with agentic agent_id={agent_id[:8]}...") + + response = requests.post( + f"{self.base_url}/stream", + json=payload, + headers=self.headers, + stream=True, + timeout=60, + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + events = [] + full_response = "" + for line in response.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + try: + data = json_module.loads(line_str[6:]) + events.append(data) + if data.get("type") in ["stream", "answer"]: + full_response += data.get("message", "") or data.get("answer", "") + elif data.get("type") == "end": + break + except json_module.JSONDecodeError: + pass + + self.print_success(f"Received {len(events)} events") + if full_response: + self.print_success(f"Answer preview: {full_response[:100]}...") + self.record_result(test_name, True, "Success") + return True + + except Exception as e: + self.print_error(f"Error: {str(e)}") + self.record_result(test_name, False, str(e)) + return False + + def test_answer_agentic_agent(self) -> bool: + """Test /api/answer endpoint with an agentic agent.""" + test_name = "Answer endpoint (agentic agent)" + + agent_result = self._create_agent_with_type("agentic") + if not agent_result: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + agent_id, _ = agent_result + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "agent_id": agent_id, + } + + try: + self.print_info(f"POST /api/answer with agentic agent_id={agent_id[:8]}...") + + response = self.post("/api/answer", json=payload, timeout=60) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + result = response.json() + answer = result.get("answer", "") + self.print_success(f"Answer received: {answer[:100]}...") + self.record_result(test_name, True, "Success") + return True + + except Exception as e: + self.print_error(f"Error: {str(e)}") + self.record_result(test_name, False, str(e)) + return False + + def test_stream_research_agent(self) -> bool: + """Test /stream endpoint with a research agent. + + Research agents emit additional SSE event types: + - research_plan: the decomposed research steps + - research_progress: per-step status updates + """ + test_name = "Stream endpoint (research agent)" + + agent_result = self._create_agent_with_type("research") + if not agent_result: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + agent_id, _ = agent_result + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "agent_id": agent_id, + } + + try: + self.print_info(f"POST /stream with research agent_id={agent_id[:8]}...") + + response = requests.post( + f"{self.base_url}/stream", + json=payload, + headers=self.headers, + stream=True, + timeout=120, + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + events = [] + full_response = "" + saw_plan = False + saw_progress = False + for line in response.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + try: + data = json_module.loads(line_str[6:]) + events.append(data) + event_type = data.get("type", "") + if event_type in ["stream", "answer"]: + full_response += data.get("message", "") or data.get("answer", "") + elif event_type == "research_plan": + saw_plan = True + steps = data.get("data", {}).get("steps", []) + self.print_info(f"Research plan: {len(steps)} steps, complexity={data.get('data', {}).get('complexity')}") + elif event_type == "research_progress": + saw_progress = True + elif event_type == "end": + break + except json_module.JSONDecodeError: + pass + + self.print_success(f"Received {len(events)} events") + if saw_plan: + self.print_success("Received research_plan event") + if saw_progress: + self.print_success("Received research_progress events") + if full_response: + self.print_success(f"Report preview: {full_response[:100]}...") + + self.record_result(test_name, True, "Success") + return True + + except Exception as e: + self.print_error(f"Error: {str(e)}") + self.record_result(test_name, False, str(e)) + return False + + def test_answer_research_agent(self) -> bool: + """Test /api/answer endpoint with a research agent.""" + test_name = "Answer endpoint (research agent)" + + agent_result = self._create_agent_with_type("research") + if not agent_result: + if not self.require_auth(test_name): + return True + self.record_result(test_name, True, "Skipped (no agent)") + return True + + agent_id, _ = agent_result + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "agent_id": agent_id, + } + + try: + self.print_info(f"POST /api/answer with research agent_id={agent_id[:8]}...") + + response = self.post("/api/answer", json=payload, timeout=120) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + result = response.json() + answer = result.get("answer", "") + self.print_success(f"Answer received: {answer[:100]}...") + self.record_result(test_name, True, "Success") + return True + + except Exception as e: + self.print_error(f"Error: {str(e)}") + self.record_result(test_name, False, str(e)) + return False + + # ------------------------------------------------------------------------- + # Workflow with Agentic Node Tests + # ------------------------------------------------------------------------- + + def _create_workflow_with_agentic_node(self) -> Optional[str]: + """Create a workflow with a single agentic agent node. + + Returns the workflow_id or None on failure. + """ + nodes = [ + { + "id": "start_1", + "type": "start", + "title": "Start", + "data": {}, + }, + { + "id": "agent_1", + "type": "agent", + "title": "Agentic Node", + "data": { + "agent_type": "agentic", + "system_prompt": "You are a helpful assistant.", + "prompt_template": "", + "stream_to_user": True, + "tools": [], + "sources": [], + "chunks": "2", + }, + }, + { + "id": "end_1", + "type": "end", + "title": "End", + "data": {}, + }, + ] + edges = [ + {"id": "edge_1", "source": "start_1", "target": "agent_1"}, + {"id": "edge_2", "source": "agent_1", "target": "end_1"}, + ] + + payload = { + "name": f"Agentic Workflow Test {int(time.time())}", + "nodes": nodes, + "edges": edges, + } + + try: + response = self.post("/api/workflows", json=payload, timeout=15) + if response.status_code in [200, 201]: + result = response.json() + return result.get("id") + except Exception: + pass + + return None + + def _create_workflow_agent(self, workflow_id: str) -> Optional[tuple]: + """Create an agent of type 'workflow' referencing a workflow. + + Returns (agent_id, api_key) or None. + """ + payload = { + "name": f"Workflow Agent Test {int(time.time())}", + "description": "Integration test workflow agent with agentic node", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "workflow", + "workflow": workflow_id, + "status": "draft", + } + + try: + response = self.post("/api/create_agent", json=payload, timeout=10) + if response.status_code in [200, 201]: + result = response.json() + agent_id = result.get("id") + api_key = result.get("key") + if agent_id: + return (agent_id, api_key) + except Exception: + pass + + return None + + def test_stream_workflow_with_agentic_node(self) -> bool: + """Test /stream with a workflow agent that contains an agentic node.""" + test_name = "Stream endpoint (workflow with agentic node)" + + if not self.require_auth(test_name): + return True + + workflow_id = self._create_workflow_with_agentic_node() + if not workflow_id: + self.print_warning("Could not create workflow") + self.record_result(test_name, True, "Skipped (workflow creation failed)") + return True + + agent_result = self._create_workflow_agent(workflow_id) + if not agent_result: + self.print_warning("Could not create workflow agent") + self.record_result(test_name, True, "Skipped (agent creation failed)") + return True + + agent_id, _ = agent_result + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "agent_id": agent_id, + } + + try: + self.print_info(f"POST /stream with workflow agent_id={agent_id[:8]}...") + self.print_info(f"Workflow ID: {workflow_id}") + + response = requests.post( + f"{self.base_url}/stream", + json=payload, + headers=self.headers, + stream=True, + timeout=60, + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + events = [] + full_response = "" + for line in response.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + try: + data = json_module.loads(line_str[6:]) + events.append(data) + if data.get("type") in ["stream", "answer"]: + full_response += data.get("message", "") or data.get("answer", "") + elif data.get("type") == "end": + break + except json_module.JSONDecodeError: + pass + + self.print_success(f"Received {len(events)} events") + if full_response: + self.print_success(f"Answer preview: {full_response[:100]}...") + + self.record_result(test_name, True, "Success") + return True + + except Exception as e: + self.print_error(f"Error: {str(e)}") + self.record_result(test_name, False, str(e)) + return False + + def test_answer_workflow_with_agentic_node(self) -> bool: + """Test /api/answer with a workflow agent that contains an agentic node.""" + test_name = "Answer endpoint (workflow with agentic node)" + + if not self.require_auth(test_name): + return True + + workflow_id = self._create_workflow_with_agentic_node() + if not workflow_id: + self.print_warning("Could not create workflow") + self.record_result(test_name, True, "Skipped (workflow creation failed)") + return True + + agent_result = self._create_workflow_agent(workflow_id) + if not agent_result: + self.print_warning("Could not create workflow agent") + self.record_result(test_name, True, "Skipped (agent creation failed)") + return True + + agent_id, _ = agent_result + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "agent_id": agent_id, + } + + try: + self.print_info(f"POST /api/answer with workflow agent_id={agent_id[:8]}...") + + response = self.post("/api/answer", json=payload, timeout=60) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.record_result(test_name, False, f"Status {response.status_code}") + return False + + result = response.json() + answer = result.get("answer", "") + self.print_success(f"Answer received: {answer[:100]}...") + self.record_result(test_name, True, "Success") + return True + + except Exception as e: + self.print_error(f"Error: {str(e)}") + self.record_result(test_name, False, str(e)) + return False + # ------------------------------------------------------------------------- # Validation Tests # ------------------------------------------------------------------------- @@ -916,6 +1396,27 @@ This is test documentation for integration tests. self.test_answer_endpoint_with_agent() time.sleep(1) + # Agentic agent tests + self.test_stream_agentic_agent() + time.sleep(1) + + self.test_answer_agentic_agent() + time.sleep(1) + + # Research agent tests + self.test_stream_research_agent() + time.sleep(2) + + self.test_answer_research_agent() + time.sleep(2) + + # Workflow with agentic node tests + self.test_stream_workflow_with_agentic_node() + time.sleep(2) + + self.test_answer_workflow_with_agentic_node() + time.sleep(2) + # API key tests self.test_stream_endpoint_with_api_key() time.sleep(1)