mirror of
https://github.com/GH05TCREW/pentestagent.git
synced 2026-03-08 06:44:11 +00:00
236 lines
5.9 KiB
Python
236 lines
5.9 KiB
Python
"""Tool registry for GhostCrew."""
|
|
|
|
from dataclasses import dataclass, field
|
|
from functools import wraps
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
|
|
|
if TYPE_CHECKING:
|
|
from ..runtime import Runtime
|
|
|
|
|
|
@dataclass
|
|
class ToolSchema:
|
|
"""JSON Schema for tool parameters."""
|
|
|
|
type: str = "object"
|
|
properties: Optional[Dict[str, Any]] = None
|
|
required: Optional[List[str]] = None
|
|
|
|
def __post_init__(self):
|
|
if self.properties is None:
|
|
self.properties = {}
|
|
if self.required is None:
|
|
self.required = []
|
|
|
|
def to_dict(self) -> dict:
|
|
"""Convert to dictionary format."""
|
|
return {
|
|
"type": self.type,
|
|
"properties": self.properties,
|
|
"required": self.required,
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class Tool:
|
|
"""Represents a tool available to agents."""
|
|
|
|
name: str
|
|
description: str
|
|
schema: ToolSchema
|
|
execute_fn: Callable
|
|
category: str = "general"
|
|
enabled: bool = True
|
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
async def execute(self, arguments: dict, runtime: "Runtime") -> str:
|
|
"""
|
|
Execute the tool with given arguments.
|
|
|
|
Args:
|
|
arguments: The arguments for the tool
|
|
runtime: The runtime environment
|
|
|
|
Returns:
|
|
The tool execution result as a string
|
|
"""
|
|
if not self.enabled:
|
|
return f"Tool '{self.name}' is currently disabled."
|
|
|
|
return await self.execute_fn(arguments, runtime)
|
|
|
|
def to_llm_format(self) -> dict:
|
|
"""Convert to LLM function calling format."""
|
|
return {
|
|
"type": "function",
|
|
"function": {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"parameters": {
|
|
"type": self.schema.type,
|
|
"properties": self.schema.properties or {},
|
|
"required": self.schema.required or [],
|
|
},
|
|
},
|
|
}
|
|
|
|
def validate_arguments(self, arguments: dict) -> tuple[bool, Optional[str]]:
|
|
"""
|
|
Validate arguments against the schema.
|
|
|
|
Args:
|
|
arguments: The arguments to validate
|
|
|
|
Returns:
|
|
Tuple of (is_valid, error_message)
|
|
"""
|
|
# Check required fields
|
|
for required_field in self.schema.required or []:
|
|
if required_field not in arguments:
|
|
return False, f"Missing required field: {required_field}"
|
|
|
|
# Type checking (basic)
|
|
for key, value in arguments.items():
|
|
if key in (self.schema.properties or {}):
|
|
expected_type = self.schema.properties[key].get("type")
|
|
if expected_type:
|
|
if not self._check_type(value, expected_type):
|
|
return (
|
|
False,
|
|
f"Invalid type for {key}: expected {expected_type}",
|
|
)
|
|
|
|
return True, None
|
|
|
|
def _check_type(self, value: Any, expected_type: str) -> bool:
|
|
"""Check if a value matches the expected type."""
|
|
type_map = {
|
|
"string": str,
|
|
"integer": int,
|
|
"number": (int, float),
|
|
"boolean": bool,
|
|
"array": list,
|
|
"object": dict,
|
|
}
|
|
|
|
expected = type_map.get(expected_type)
|
|
if expected is None:
|
|
return True # Unknown type, allow
|
|
|
|
return isinstance(value, expected)
|
|
|
|
|
|
# Global tool registry
|
|
_tools: Dict[str, Tool] = {}
|
|
|
|
|
|
def register_tool(
|
|
name: str, description: str, schema: ToolSchema, category: str = "general"
|
|
) -> Callable:
|
|
"""
|
|
Decorator to register a tool.
|
|
|
|
Args:
|
|
name: The tool name
|
|
description: Description of what the tool does
|
|
schema: The parameter schema
|
|
category: Tool category for organization
|
|
|
|
Returns:
|
|
Decorator function
|
|
"""
|
|
|
|
def decorator(fn: Callable) -> Callable:
|
|
@wraps(fn)
|
|
async def wrapper(arguments: dict, runtime: "Runtime") -> str:
|
|
return await fn(arguments, runtime)
|
|
|
|
tool = Tool(
|
|
name=name,
|
|
description=description,
|
|
schema=schema,
|
|
execute_fn=wrapper,
|
|
category=category,
|
|
)
|
|
_tools[name] = tool
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
def get_all_tools() -> List[Tool]:
|
|
"""Get all registered tools."""
|
|
return list(_tools.values())
|
|
|
|
|
|
def get_tool(name: str) -> Optional[Tool]:
|
|
"""Get a tool by name."""
|
|
return _tools.get(name)
|
|
|
|
|
|
def register_tool_instance(tool: Tool) -> None:
|
|
"""
|
|
Register a pre-built tool instance (used for MCP tools).
|
|
|
|
Args:
|
|
tool: The Tool instance to register
|
|
"""
|
|
_tools[tool.name] = tool
|
|
|
|
|
|
def unregister_tool(name: str) -> bool:
|
|
"""
|
|
Unregister a tool by name.
|
|
|
|
Args:
|
|
name: The tool name to unregister
|
|
|
|
Returns:
|
|
True if tool was unregistered, False if not found
|
|
"""
|
|
if name in _tools:
|
|
del _tools[name]
|
|
return True
|
|
return False
|
|
|
|
|
|
def get_tools_by_category(category: str) -> List[Tool]:
|
|
"""
|
|
Get all tools in a specific category.
|
|
|
|
Args:
|
|
category: The category to filter by
|
|
|
|
Returns:
|
|
List of tools in that category
|
|
"""
|
|
return [tool for tool in _tools.values() if tool.category == category]
|
|
|
|
|
|
def clear_tools() -> None:
|
|
"""Clear all registered tools."""
|
|
_tools.clear()
|
|
|
|
|
|
def get_tool_names() -> List[str]:
|
|
"""Get list of all registered tool names."""
|
|
return list(_tools.keys())
|
|
|
|
|
|
def enable_tool(name: str) -> bool:
|
|
"""Enable a tool by name."""
|
|
tool = _tools.get(name)
|
|
if tool:
|
|
tool.enabled = True
|
|
return True
|
|
return False
|
|
|
|
|
|
def disable_tool(name: str) -> bool:
|
|
"""Disable a tool by name."""
|
|
tool = _tools.get(name)
|
|
if tool:
|
|
tool.enabled = False
|
|
return True
|
|
return False
|