Files
DocsGPT/application/agents/workflows/node_agent.py
2026-03-25 22:34:25 +00:00

105 lines
3.0 KiB
Python

"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
    """Mixin that restricts fetched tools to an explicit allow-list.

    The allow-list is read from ``self._allowed_tool_ids``. Each entry is
    compared, as a string, against ``str(tool.get("_id", ""))`` of every
    fetched tool. An empty or not-yet-assigned allow-list means "no tools",
    and in that case the parent fetch is skipped entirely.

    Place this mixin *before* the agent class in the bases so that its
    ``_get_user_tools`` / ``_get_tools`` overrides wrap the parent
    implementations via ``super()``.
    """

    _allowed_tool_ids: List[str]

    def _allowed_tool_id_set(self) -> set:
        """Return the allow-list as a set of strings (empty when unset).

        Entries are stringified so that non-string ids (e.g. ObjectId-like
        values) still match ``str(tool["_id"])``; a set gives O(1) lookups.
        """
        allowed = getattr(self, "_allowed_tool_ids", None) or ()
        return {str(tool_id) for tool_id in allowed}

    @staticmethod
    def _filter_to_allowed_tools(
        tools: Dict[str, Dict[str, Any]], allowed: set
    ) -> Dict[str, Dict[str, Any]]:
        """Keep only the entries of *tools* whose ``_id`` is in *allowed*.

        Dict keys are preserved as returned by the parent fetch.
        """
        return {
            key: tool
            for key, tool in tools.items()
            if str(tool.get("_id", "")) in allowed
        }

    def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
        """Return the parent's tools for *user*, filtered to the allow-list.

        Returns an empty dict (without calling the parent) when the
        allow-list is empty or unset.
        """
        allowed = self._allowed_tool_id_set()
        if not allowed:
            # Deny by default: no allow-list means no tools, so skip the fetch.
            return {}
        return self._filter_to_allowed_tools(super()._get_user_tools(user), allowed)

    def _get_tools(self, api_key: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
        """Return the parent's tools for *api_key*, filtered to the allow-list.

        Returns an empty dict (without calling the parent) when the
        allow-list is empty or unset.
        """
        allowed = self._allowed_tool_id_set()
        if not allowed:
            # Deny by default: no allow-list means no tools, so skip the fetch.
            return {}
        return self._filter_to_allowed_tools(super()._get_tools(api_key), allowed)
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
    """ClassicAgent specialised for use inside a workflow node.

    ToolFilterMixin comes first in the bases, so its ``_get_user_tools`` /
    ``_get_tools`` overrides restrict tools to the node's ``tool_ids``.
    Construction runs through ``_WorkflowNodeMixin.__init__``, whose
    ``super().__init__`` call resolves to ``ClassicAgent.__init__``.
    """
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
    """AgenticAgent specialised for use inside a workflow node.

    ToolFilterMixin comes first in the bases, so its ``_get_user_tools`` /
    ``_get_tools`` overrides restrict tools to the node's ``tool_ids``.
    Construction runs through ``_WorkflowNodeMixin.__init__``, whose
    ``super().__init__`` call resolves to ``AgenticAgent.__init__``.
    """
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
    """ResearchAgent specialised for use inside a workflow node.

    ToolFilterMixin comes first in the bases, so its ``_get_user_tools`` /
    ``_get_tools`` overrides restrict tools to the node's ``tool_ids``.
    Construction runs through ``_WorkflowNodeMixin.__init__``, whose
    ``super().__init__`` call resolves to ``ResearchAgent.__init__``.
    """
class WorkflowNodeAgentFactory:
    """Factory that builds the workflow-node agent for a given AgentType."""

    # Registry of supported agent types. REACT deliberately maps to the
    # classic node agent so older workflow definitions keep working
    # (backwards compat).
    _agents: Dict[AgentType, Type[BaseAgent]] = {
        AgentType.CLASSIC: WorkflowNodeClassicAgent,
        AgentType.REACT: WorkflowNodeClassicAgent,
        AgentType.AGENTIC: WorkflowNodeAgenticAgent,
        AgentType.RESEARCH: WorkflowNodeResearchAgent,
    }

    @classmethod
    def create(
        cls,
        agent_type: AgentType,
        endpoint: str,
        llm_name: str,
        model_id: str,
        api_key: str,
        tool_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> BaseAgent:
        """Instantiate the node agent registered for *agent_type*.

        Args:
            agent_type: Key into the ``_agents`` registry.
            endpoint, llm_name, model_id, api_key: Forwarded to the agent's
                constructor unchanged.
            tool_ids: Optional tool allow-list, forwarded as ``tool_ids``.
            **kwargs: Any extra constructor arguments, forwarded as-is.

        Returns:
            A new instance of the registered agent class.

        Raises:
            ValueError: If *agent_type* has no registered agent class.
        """
        agent_class = cls._agents.get(agent_type)
        if agent_class is None:
            raise ValueError(f"Unsupported agent type: {agent_type}")

        return agent_class(
            endpoint=endpoint,
            llm_name=llm_name,
            model_id=model_id,
            api_key=api_key,
            tool_ids=tool_ids,
            **kwargs,
        )