mirror of
https://github.com/GH05TCREW/pentestagent.git
synced 2026-03-07 14:23:20 +00:00
Merge pull request #19 from giveen/workspace
feat(workspaces): add unified /workspace lifecycle, target persistence, and workspace-scoped RAG
This commit is contained in:
12
.gitignore
vendored
12
.gitignore
vendored
@@ -81,3 +81,15 @@ Thumbs.db
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
|
||||
# Local test artifacts and test scripts (do not commit local test runs)
|
||||
tests/*.log
|
||||
tests/*.out
|
||||
tests/output/
|
||||
tests/tmp/
|
||||
tests/*.local.py
|
||||
scripts/test_*.sh
|
||||
*.test.sh
|
||||
|
||||
# Workspaces directory (user data should not be committed)
|
||||
/workspaces/
|
||||
|
||||
BIN
dupe-workspace.tar.gz
Normal file
BIN
dupe-workspace.tar.gz
Normal file
Binary file not shown.
BIN
expimp-workspace.tar.gz
Normal file
BIN
expimp-workspace.tar.gz
Normal file
Binary file not shown.
@@ -5,6 +5,8 @@ from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, List, Optional
|
||||
|
||||
from ..config.constants import AGENT_MAX_ITERATIONS
|
||||
from ..workspaces.manager import WorkspaceManager
|
||||
from ..workspaces import validation
|
||||
from .state import AgentState, AgentStateManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -79,68 +81,73 @@ class BaseAgent(ABC):
|
||||
tools: List["Tool"],
|
||||
runtime: "Runtime",
|
||||
max_iterations: int = AGENT_MAX_ITERATIONS,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the base agent.
|
||||
Initialize base agent state.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance for generating responses
|
||||
tools: List of tools available to the agent
|
||||
runtime: The runtime environment for tool execution
|
||||
max_iterations: Maximum iterations before forcing stop (safety limit)
|
||||
llm: LLM instance used for generation
|
||||
tools: Available tool list
|
||||
runtime: Runtime used for tool execution
|
||||
max_iterations: Safety limit for iterations
|
||||
"""
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
self.runtime = runtime
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
# Agent runtime state
|
||||
self.state_manager = AgentStateManager()
|
||||
self.conversation_history: List[AgentMessage] = []
|
||||
|
||||
# Each agent gets its own plan instance
|
||||
from ..tools.finish import TaskPlan
|
||||
# Task planning structure (used by finish tool)
|
||||
try:
|
||||
from ..tools.finish import TaskPlan
|
||||
|
||||
self._task_plan = TaskPlan()
|
||||
self._task_plan = TaskPlan()
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
# Attach plan to runtime so finish tool can access it
|
||||
self.runtime.plan = self._task_plan
|
||||
logging.getLogger(__name__).exception("Failed importing TaskPlan: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
# Use tools as-is (finish accesses plan via runtime)
|
||||
self.tools = list(tools)
|
||||
notify("warning", f"Failed to import TaskPlan: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about TaskPlan import failure")
|
||||
# Fallback simple plan structure
|
||||
class _SimplePlan:
|
||||
def __init__(self):
|
||||
self.steps = []
|
||||
self.original_request = ""
|
||||
|
||||
@property
|
||||
def state(self) -> AgentState:
|
||||
"""Get current agent state."""
|
||||
return self.state_manager.current_state
|
||||
def clear(self):
|
||||
self.steps.clear()
|
||||
|
||||
@state.setter
|
||||
def state(self, value: AgentState):
|
||||
"""Set agent state."""
|
||||
self.state_manager.transition_to(value)
|
||||
def is_complete(self):
|
||||
return True
|
||||
|
||||
def cleanup_after_cancel(self) -> None:
|
||||
"""
|
||||
Clean up agent state after a cancellation.
|
||||
def has_failure(self):
|
||||
return False
|
||||
|
||||
Removes the cancelled request and any pending tool calls from
|
||||
conversation history to prevent stale responses from contaminating
|
||||
the next conversation.
|
||||
"""
|
||||
# Remove incomplete messages from the end of conversation
|
||||
while self.conversation_history:
|
||||
last_msg = self.conversation_history[-1]
|
||||
# Remove assistant message with tool calls (incomplete tool execution)
|
||||
if last_msg.role == "assistant" and last_msg.tool_calls:
|
||||
self.conversation_history.pop()
|
||||
# Remove orphaned tool_result messages
|
||||
elif last_msg.role == "tool":
|
||||
self.conversation_history.pop()
|
||||
# Remove the user message that triggered the cancelled request
|
||||
elif last_msg.role == "user":
|
||||
self.conversation_history.pop()
|
||||
break # Stop after removing the user message
|
||||
else:
|
||||
break
|
||||
self._task_plan = _SimplePlan()
|
||||
|
||||
# Reset state to idle
|
||||
# Expose plan to runtime so tools like `finish` can access it
|
||||
try:
|
||||
self.runtime.plan = self._task_plan
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to attach plan to runtime: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to attach plan to runtime: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about runtime plan attach failure")
|
||||
|
||||
# Ensure agent starts idle
|
||||
self.state_manager.transition_to(AgentState.IDLE)
|
||||
|
||||
@abstractmethod
|
||||
@@ -448,16 +455,59 @@ class BaseAgent(ABC):
|
||||
|
||||
if tool:
|
||||
try:
|
||||
result = await tool.execute(arguments, self.runtime)
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=name,
|
||||
result=result,
|
||||
success=True,
|
||||
# Before executing, enforce target safety gate when workspace active
|
||||
wm = WorkspaceManager()
|
||||
active = wm.get_active()
|
||||
|
||||
# Use centralized validation helpers for target extraction and scope checks
|
||||
candidates = validation.gather_candidate_targets(arguments)
|
||||
out_of_scope = []
|
||||
if active:
|
||||
allowed = wm.list_targets(active)
|
||||
for c in candidates:
|
||||
try:
|
||||
if not validation.is_target_in_scope(c, allowed):
|
||||
out_of_scope.append(c)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).exception(
|
||||
"Error validating candidate target %s: %s", c, e
|
||||
)
|
||||
out_of_scope.append(c)
|
||||
|
||||
if active and out_of_scope:
|
||||
# Block execution and return an explicit error requiring operator confirmation
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=name,
|
||||
error=(
|
||||
f"Out-of-scope target(s): {out_of_scope} - operator confirmation required. "
|
||||
"Set workspace targets with /target or run tool manually."
|
||||
),
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result = await tool.execute(arguments, self.runtime)
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=name,
|
||||
result=result,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error executing tool %s: %s", name, e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Tool execution failed ({name}): {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about tool execution failure")
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
|
||||
@@ -111,8 +111,16 @@ class CrewOrchestrator:
|
||||
sections.append("\n".join(grouped[cat]))
|
||||
|
||||
notes_context = "\n\n".join(sections)
|
||||
except Exception:
|
||||
pass # Notes not available
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to gather notes for orchestrator prompt: %s", e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Orchestrator: failed to gather notes: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about orchestrator notes failure: %s", ne)
|
||||
|
||||
# Format insights for prompt
|
||||
insights_text = ""
|
||||
@@ -271,6 +279,15 @@ class CrewOrchestrator:
|
||||
break # Exit immediately after finish
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Worker tool execution failed (%s): %s", tc_name, e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Worker tool execution failed ({tc_name}): {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about worker tool failure: %s", ne)
|
||||
error_msg = f"Error: {e}"
|
||||
yield {
|
||||
"phase": "tool_result",
|
||||
@@ -327,6 +344,15 @@ class CrewOrchestrator:
|
||||
self.pool.finish_tokens = 0
|
||||
break
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Auto-finish failed: %s", e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Auto-finish failed: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about auto-finish failure: %s", ne)
|
||||
yield {
|
||||
"phase": "error",
|
||||
"error": f"Auto-finish failed: {e}",
|
||||
@@ -344,6 +370,15 @@ class CrewOrchestrator:
|
||||
yield {"phase": "complete", "report": final_report}
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Orchestrator run failed: %s", e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("error", f"CrewOrchestrator run failed: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about orchestrator run failure: %s", ne)
|
||||
self.state = CrewState.ERROR
|
||||
yield {"phase": "error", "error": str(e)}
|
||||
|
||||
|
||||
@@ -230,6 +230,15 @@ class WorkerPool:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Worker execution failed (%s): %s", worker.id, e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Worker execution failed ({worker.id}): {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about worker execution failure: %s", ne)
|
||||
worker.error = str(e)
|
||||
worker.status = AgentStatus.ERROR
|
||||
worker.completed_at = time.time()
|
||||
@@ -239,8 +248,16 @@ class WorkerPool:
|
||||
# Cleanup worker's isolated runtime
|
||||
try:
|
||||
await worker_runtime.stop()
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to stop worker runtime for %s: %s", worker.id, e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to stop worker runtime for {worker.id}: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about worker runtime stop failure: %s", ne)
|
||||
|
||||
async def _wait_for_dependencies(self, depends_on: List[str]) -> None:
|
||||
"""Wait for dependent workers to complete."""
|
||||
@@ -248,8 +265,17 @@ class WorkerPool:
|
||||
if dep_id in self._tasks:
|
||||
try:
|
||||
await self._tasks[dep_id]
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass # Dependency failed, but we continue
|
||||
except (asyncio.CancelledError, Exception) as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Dependency wait failed for %s: %s", dep_id, e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Dependency wait failed for {dep_id}: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about dependency wait failure: %s", ne)
|
||||
# Dependency failed, but we continue
|
||||
|
||||
async def wait_for(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -269,8 +295,16 @@ class WorkerPool:
|
||||
if agent_id in self._tasks:
|
||||
try:
|
||||
await self._tasks[agent_id]
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
except (asyncio.CancelledError, Exception) as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Waiting for agent task %s failed: %s", agent_id, e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Waiting for agent {agent_id} failed: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about wait_for agent failure: %s", ne)
|
||||
|
||||
worker = self._workers.get(agent_id)
|
||||
if worker:
|
||||
|
||||
@@ -117,8 +117,16 @@ class PentestAgentAgent(BaseAgent):
|
||||
sections.append("\n".join(grouped[cat]))
|
||||
|
||||
notes_context = "\n\n".join(sections)
|
||||
except Exception:
|
||||
pass # Notes not available
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to gather notes for agent prompt: %s", e)
|
||||
try:
|
||||
from ...interface.notifier import notify
|
||||
|
||||
notify("warning", f"Agent: failed to gather notes: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about agent notes failure")
|
||||
|
||||
# Get environment info from runtime
|
||||
env = self.runtime.environment
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from ..config.constants import AGENT_MAX_ITERATIONS, DEFAULT_MODEL
|
||||
from .cli import run_cli
|
||||
@@ -127,6 +128,25 @@ Examples:
|
||||
mcp_test = mcp_subparsers.add_parser("test", help="Test MCP server connection")
|
||||
mcp_test.add_argument("name", help="Server name to test")
|
||||
|
||||
# workspace management
|
||||
ws_parser = subparsers.add_parser(
|
||||
"workspace", help="Workspace lifecycle and info commands"
|
||||
)
|
||||
ws_parser.add_argument(
|
||||
"action",
|
||||
nargs="?",
|
||||
help="Action or workspace name. Subcommands: info, note, clear, export, import",
|
||||
)
|
||||
ws_parser.add_argument("rest", nargs=argparse.REMAINDER, help="Additional arguments")
|
||||
|
||||
# NOTE: use `workspace list` to list workspaces (handled by workspace subcommand)
|
||||
|
||||
# target management
|
||||
tgt_parser = subparsers.add_parser(
|
||||
"target", help="Add or list targets for the active workspace"
|
||||
)
|
||||
tgt_parser.add_argument("values", nargs="*", help="Targets to add (IP/CIDR/hostname)")
|
||||
|
||||
return parser, parser.parse_args()
|
||||
|
||||
|
||||
@@ -304,6 +324,232 @@ def handle_mcp_command(args: argparse.Namespace):
|
||||
console.print("[yellow]Use 'pentestagent mcp --help' for available commands[/]")
|
||||
|
||||
|
||||
def handle_workspace_command(args: argparse.Namespace):
    """Handle workspace lifecycle commands and actions.

    Dispatches on ``args.action``:

    * no action     -- print the active workspace
    * ``info``      -- show metadata for the active/named workspace
    * ``list``      -- list all workspaces, marking the active one
    * ``note``      -- save an operator note ([--workspace NAME] <text>)
    * ``clear``     -- deactivate the active workspace
    * ``export``    -- archive a workspace (<NAME> [--output file.tar.gz])
    * ``import``    -- restore a workspace from an archive (not activated)
    * anything else -- treat the action as a workspace name: create it if
      missing, then set it active

    Args:
        args: Parsed CLI namespace carrying ``action`` (optional str) and
            ``rest`` (remaining positional tokens).
    """
    from pentestagent.workspaces.manager import WorkspaceManager

    wm = WorkspaceManager()
    action = args.action
    rest = args.rest or []

    # No args -> show active workspace
    if not action:
        active = wm.get_active()
        if not active:
            print("No active workspace.")
        else:
            print(f"Active workspace: {active}")
        return

    # Known subcommands dispatch to a dedicated helper.
    handlers = {
        "info": _ws_info,
        "list": _ws_list,
        "note": _ws_note,
        "clear": _ws_clear,
        "export": _ws_export,
        "import": _ws_import,
    }
    handler = handlers.get(action)
    if handler is not None:
        handler(wm, rest)
        return

    # Default: treat action as workspace name -> create and set active
    _ws_activate(wm, action)


def _ws_info(wm, rest) -> None:
    """Print metadata, knowledge scope, and loot stats for one workspace."""
    import os

    from pentestagent.workspaces.utils import resolve_knowledge_paths

    # Explicit name wins; otherwise fall back to the active workspace.
    name = rest[0] if rest else wm.get_active()
    if not name:
        print("No workspace specified and no active workspace.")
        return
    try:
        meta = wm.get_meta(name)
        created = meta.get("created_at")
        last_active = meta.get("last_active_at")
        targets = meta.get("targets", [])
        kp = resolve_knowledge_paths()
        ks = "workspace" if kp.get("using_workspace") else "global"

        # Best-effort loot stats: walk loot/ and total file sizes, skipping
        # anything we can't stat (permissions, broken symlinks).
        loot_dir = (wm.workspace_path(name) / "loot").resolve()
        size = 0
        files = 0
        if loot_dir.exists():
            for rootp, _, filenames in os.walk(loot_dir):
                for fn in filenames:
                    try:
                        fp = os.path.join(rootp, fn)
                        size += os.path.getsize(fp)
                        files += 1
                    except Exception:
                        pass

        print(f"Name: {name}")
        print(f"Created: {created}")
        print(f"Last active: {last_active}")
        print(f"Targets: {len(targets)}")
        print(f"Knowledge scope: {ks}")
        print(f"Loot files: {files}, approx size: {size} bytes")
    except Exception as e:
        print(f"Error retrieving workspace info: {e}")


def _ws_list(wm, rest) -> None:
    """List all workspaces, prefixing the active one with '*'."""
    try:
        wss = wm.list_workspaces()
        active = wm.get_active()
        if not wss:
            print("No workspaces found.")
            return
        for name in sorted(wss):
            prefix = "* " if name == active else " "
            print(f"{prefix}{name}")
    except Exception as e:
        print(f"Error listing workspaces: {e}")


def _ws_note(wm, rest) -> None:
    """Save an operator note: note [--workspace NAME | -w NAME] <text>."""
    name = wm.get_active()

    text_parts = rest or []
    i = 0
    # Parse optional workspace selector flags before the note text.
    while i < len(text_parts):
        part = text_parts[i]
        if part in ("--workspace", "-w"):
            if i + 1 >= len(text_parts):
                print("Usage: workspace note [--workspace NAME] <text>")
                return
            name = text_parts[i + 1]
            i += 2
            continue
        # First non-option token marks the start of the note text
        break

    if not name:
        print("No active workspace. Set one with /workspace <name>.")
        return

    text = " ".join(text_parts[i:])
    if not text:
        print("Usage: workspace note [--workspace NAME] <text>")
        return
    try:
        wm.set_operator_note(name, text)
        print(f"Operator note saved for workspace '{name}'.")
    except Exception as e:
        print(f"Error saving note: {e}")


def _ws_clear(wm, rest) -> None:
    """Deactivate the active workspace by removing its active-marker file."""
    active = wm.get_active()
    if not active:
        print("No active workspace.")
        return
    marker = wm.active_marker()
    try:
        if marker.exists():
            marker.unlink()
        print(f"Workspace '{active}' deactivated.")
    except Exception as e:
        print(f"Error deactivating workspace: {e}")


def _ws_export(wm, rest) -> None:
    """Export a workspace archive: export <NAME> [--output file.tar.gz]."""
    from pentestagent.workspaces.utils import export_workspace

    if not rest:
        print("Usage: workspace export <NAME> [--output file.tar.gz]")
        return
    name = rest[0]
    out = None
    if "--output" in rest:
        idx = rest.index("--output")
        if idx + 1 < len(rest):
            out = Path(rest[idx + 1])
    try:
        archive = export_workspace(name, output=out)
        print(f"Workspace exported: {archive}")
    except Exception as e:
        print(f"Export failed: {e}")


def _ws_import(wm, rest) -> None:
    """Import a workspace from an archive without activating it."""
    from pentestagent.workspaces.utils import import_workspace

    if not rest:
        print("Usage: workspace import <archive.tar.gz>")
        return
    archive = Path(rest[0])
    try:
        name = import_workspace(archive)
        print(f"Workspace imported: {name} (not activated)")
    except Exception as e:
        print(f"Import failed: {e}")


def _ws_activate(wm, name: str) -> None:
    """Create the workspace if missing, set it active, report restored target."""
    from pentestagent.workspaces.manager import WorkspaceError

    try:
        existed = wm.workspace_path(name).exists()
        if not existed:
            wm.create(name)
        wm.set_active(name)

        # restore last target if present (persisted in workspace metadata)
        last = wm.get_meta_field(name, "last_target")
        if last:
            print(f"Workspace '{name}' set active. Restored target: {last}")
        elif existed:
            print(f"Workspace '{name}' set active.")
        else:
            print(f"Workspace '{name}' created and set active.")
    except WorkspaceError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"Error creating workspace: {e}")
|
||||
|
||||
|
||||
def handle_workspaces_list():
    """Print every workspace, one per line, marking the active one with '*'."""
    from pentestagent.workspaces.manager import WorkspaceManager

    manager = WorkspaceManager()
    names = manager.list_workspaces()
    current = manager.get_active()
    if not names:
        print("No workspaces found.")
        return
    for ws_name in sorted(names):
        marker = "* " if ws_name == current else " "
        print(f"{marker}{ws_name}")
|
||||
|
||||
|
||||
|
||||
def handle_target_command(args: argparse.Namespace):
    """List the active workspace's targets, or append the given ones.

    With no positional values the current target list is printed; otherwise
    the provided values are added to the active workspace's targets.
    """
    from pentestagent.workspaces.manager import WorkspaceError, WorkspaceManager

    manager = WorkspaceManager()
    workspace = manager.get_active()
    if not workspace:
        print("No active workspace. Set one with /workspace <name>.")
        return

    requested = args.values or []
    try:
        if requested:
            stored = manager.add_targets(workspace, requested)
            print(f"Targets for workspace '{workspace}': {stored}")
            return

        current = manager.list_targets(workspace)
        if current:
            print(f"Targets for workspace '{workspace}': {current}")
        else:
            print(f"No targets for workspace '{workspace}'.")
    except WorkspaceError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"Error updating targets: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
parser, args = parse_arguments()
|
||||
@@ -317,6 +563,16 @@ def main():
|
||||
handle_mcp_command(args)
|
||||
return
|
||||
|
||||
if args.command == "workspace":
|
||||
handle_workspace_command(args)
|
||||
return
|
||||
|
||||
# 'workspace list' handled by workspace subcommand
|
||||
|
||||
if args.command == "target":
|
||||
handle_target_command(args)
|
||||
return
|
||||
|
||||
if args.command == "run":
|
||||
# Check model configuration
|
||||
if not args.model:
|
||||
|
||||
40
pentestagent/interface/notifier.py
Normal file
40
pentestagent/interface/notifier.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Simple notifier bridge for UI notifications.
|
||||
|
||||
Modules can call `notify(level, message)` to emit operator-visible
|
||||
notifications. A UI (TUI) may register a callback via `register_callback()`
|
||||
to receive notifications and display them. If no callback is registered,
|
||||
notifications are logged.
|
||||
"""
|
||||
import logging
|
||||
from typing import Callable, Optional
|
||||
|
||||
_callback: Optional[Callable[[str, str], None]] = None
|
||||
|
||||
|
||||
def register_callback(cb: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback to receive notifications.
|
||||
|
||||
Callback receives (level, message).
|
||||
"""
|
||||
global _callback
|
||||
_callback = cb
|
||||
|
||||
|
||||
def notify(level: str, message: str) -> None:
|
||||
"""Emit a notification. If UI callback registered, call it; otherwise log."""
|
||||
global _callback
|
||||
if _callback:
|
||||
try:
|
||||
_callback(level, message)
|
||||
return
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Notifier callback failed")
|
||||
|
||||
# Fallback to logging
|
||||
log = logging.getLogger("pentestagent.notifier")
|
||||
if level.lower() in ("error", "critical"):
|
||||
log.error(message)
|
||||
elif level.lower() in ("warn", "warning"):
|
||||
log.warning(message)
|
||||
else:
|
||||
log.info(message)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
from ..workspaces.utils import resolve_knowledge_paths
|
||||
from .rag import Document
|
||||
|
||||
|
||||
@@ -51,6 +52,11 @@ class KnowledgeIndexer:
|
||||
total_files = 0
|
||||
indexed_files = 0
|
||||
|
||||
# If directory is the default 'knowledge', prefer workspace knowledge if available
|
||||
if directory == Path("knowledge"):
|
||||
kp = resolve_knowledge_paths()
|
||||
directory = kp.get("sources", Path("knowledge"))
|
||||
|
||||
if not directory.exists():
|
||||
return documents, IndexingResult(
|
||||
0, 0, 0, [f"Directory not found: {directory}"]
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""RAG (Retrieval Augmented Generation) engine for PentestAgent."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..workspaces.utils import resolve_knowledge_paths
|
||||
from .embeddings import get_embeddings
|
||||
|
||||
|
||||
@@ -65,9 +67,49 @@ class RAGEngine:
|
||||
chunks = []
|
||||
self._source_files = set() # Reset source file tracking
|
||||
|
||||
# Resolve knowledge paths (prefer workspace if available)
|
||||
if self.knowledge_path != Path("knowledge"):
|
||||
sources_base = self.knowledge_path
|
||||
kp = None
|
||||
else:
|
||||
kp = resolve_knowledge_paths()
|
||||
sources_base = kp.get("sources", Path("knowledge"))
|
||||
|
||||
# If workspace has a persisted index and we're not forcing reindex, try to load it
|
||||
try:
|
||||
if kp and kp.get("using_workspace"):
|
||||
emb_dir = kp.get("embeddings")
|
||||
emb_dir.mkdir(parents=True, exist_ok=True)
|
||||
idx_path = emb_dir / "index.pkl"
|
||||
if idx_path.exists() and not force:
|
||||
try:
|
||||
self.load_index(idx_path)
|
||||
return
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to load persisted RAG index at %s, will re-index: %s",
|
||||
idx_path,
|
||||
e,
|
||||
)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify(
|
||||
"warning",
|
||||
f"Failed to load persisted RAG index at {idx_path}: {e}",
|
||||
)
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about RAG load failure: %s", ne
|
||||
)
|
||||
except Exception as e:
|
||||
# Non-fatal — continue to index from sources, but log the error
|
||||
logging.getLogger(__name__).exception(
|
||||
"Error while checking for persisted workspace index: %s", e
|
||||
)
|
||||
|
||||
# Process all files in knowledge directory
|
||||
if self.knowledge_path.exists():
|
||||
for file in self.knowledge_path.rglob("*"):
|
||||
if sources_base.exists():
|
||||
for file in sources_base.rglob("*"):
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
@@ -107,7 +149,9 @@ class RAGEngine:
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[RAG] Error processing {file}: {e}")
|
||||
logging.getLogger(__name__).exception(
|
||||
"[RAG] Error processing %s: %s", file, e
|
||||
)
|
||||
|
||||
self.documents = chunks
|
||||
|
||||
@@ -127,6 +171,36 @@ class RAGEngine:
|
||||
doc.embedding = self.embeddings[i]
|
||||
|
||||
self._indexed = True
|
||||
# If using a workspace, persist the built index for faster future loads
|
||||
try:
|
||||
if kp and kp.get("using_workspace") and self.embeddings is not None:
|
||||
emb_dir = kp.get("embeddings")
|
||||
emb_dir.mkdir(parents=True, exist_ok=True)
|
||||
idx_path = emb_dir / "index.pkl"
|
||||
try:
|
||||
self.save_index(idx_path)
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to save RAG index to %s: %s", idx_path, e
|
||||
)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Failed to save RAG index to {idx_path}: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about RAG save failure: %s", ne
|
||||
)
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Error while attempting to persist RAG index: %s", e
|
||||
)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Error while attempting to persist RAG index: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about RAG persist error: %s", ne
|
||||
)
|
||||
|
||||
def _chunk_text(
|
||||
self, text: str, source: str, chunk_size: int = 1000, overlap: int = 200
|
||||
@@ -408,6 +482,22 @@ class RAGEngine:
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
def save_index_to_workspace(self, root: Optional[Path] = None, filename: str = "index.pkl"):
    """Persist the current index into the active workspace's embeddings dir.

    Args:
        root: Optional project root used to resolve workspaces (defaults to cwd)
        filename: Name of the index file to write
    """
    paths = resolve_knowledge_paths(root=root)
    embeddings_dir = paths.get("embeddings")
    embeddings_dir.mkdir(parents=True, exist_ok=True)
    target = Path(embeddings_dir) / filename
    self.save_index(target)
|
||||
|
||||
def load_index(self, path: Path):
|
||||
"""
|
||||
Load the index from disk.
|
||||
@@ -437,3 +527,20 @@ class RAGEngine:
|
||||
doc.embedding = self.embeddings[i]
|
||||
|
||||
self._indexed = True
|
||||
|
||||
def load_index_from_workspace(self, root: Optional[Path] = None, filename: str = "index.pkl"):
    """Load a previously persisted index from the active workspace.

    Args:
        root: Optional project root used to resolve workspaces (defaults to cwd)
        filename: Name of the saved index file

    Raises:
        FileNotFoundError: If no index file exists at the resolved path.
    """
    paths = resolve_knowledge_paths(root=root)
    index_path = Path(paths.get("embeddings")) / filename
    if not index_path.exists():
        raise FileNotFoundError(f"Workspace index not found: {index_path}")
    self.load_index(index_path)
|
||||
|
||||
@@ -12,11 +12,10 @@ operates.
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import signal
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -24,9 +23,7 @@ except Exception:
|
||||
aiohttp = None
|
||||
|
||||
|
||||
LOOT_DIR = Path("loot/artifacts")
|
||||
LOOT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
LOG_FILE = LOOT_DIR / "hexstrike.log"
|
||||
from ..workspaces.utils import get_loot_file
|
||||
|
||||
|
||||
class HexstrikeAdapter:
|
||||
@@ -97,10 +94,19 @@ class HexstrikeAdapter:
|
||||
try:
|
||||
pid = getattr(self._process, "pid", None)
|
||||
if pid:
|
||||
with LOG_FILE.open("a") as fh:
|
||||
log_file = get_loot_file("artifacts/hexstrike.log")
|
||||
with log_file.open("a") as fh:
|
||||
fh.write(f"[HexstrikeAdapter] started pid={pid}\n")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to write hexstrike start PID to log: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to write hexstrike PID to log: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike PID log failure")
|
||||
|
||||
# Start a background reader task to capture logs
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -118,18 +124,26 @@ class HexstrikeAdapter:
|
||||
return
|
||||
|
||||
try:
|
||||
with LOG_FILE.open("ab") as fh:
|
||||
log_file = get_loot_file("artifacts/hexstrike.log")
|
||||
with log_file.open("ab") as fh:
|
||||
while True:
|
||||
line = await self._process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
# Prefix timestamps for easier debugging
|
||||
fh.write(line)
|
||||
fh.flush()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error capturing hexstrike output: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"HexStrike log capture failed: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike log capture failure")
|
||||
|
||||
async def stop(self, timeout: int = 5) -> None:
|
||||
"""Stop the server process gracefully."""
|
||||
@@ -143,10 +157,26 @@ class HexstrikeAdapter:
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to kill hexstrike after timeout: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to kill hexstrike after timeout: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike kill failure")
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error stopping hexstrike process: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error stopping hexstrike process: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike stop error")
|
||||
|
||||
self._process = None
|
||||
|
||||
@@ -154,8 +184,14 @@ class HexstrikeAdapter:
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).exception("Error awaiting hexstrike reader task: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Error awaiting hexstrike reader task: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike reader await failure")
|
||||
|
||||
def stop_sync(self, timeout: int = 5) -> None:
|
||||
"""Synchronous stop helper for use during process-exit cleanup.
|
||||
@@ -179,7 +215,15 @@ class HexstrikeAdapter:
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
except Exception:
|
||||
pass
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to SIGTERM hexstrike pid: %s", pid)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to SIGTERM hexstrike pid {pid}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike SIGTERM failure")
|
||||
|
||||
# wait briefly for process to exit
|
||||
end = time.time() + float(timeout)
|
||||
@@ -197,20 +241,52 @@ class HexstrikeAdapter:
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to SIGKILL hexstrike pid: %s", pid)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to SIGKILL hexstrike pid {pid}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike SIGKILL failure")
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error during hexstrike stop_sync cleanup: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error during hexstrike stop_sync cleanup: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike stop_sync cleanup error")
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.stop_sync()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Exception during HexstrikeAdapter.__del__: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error during HexstrikeAdapter cleanup: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike __del__ error")
|
||||
# Clear references
|
||||
try:
|
||||
self._process = None
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to clear HexstrikeAdapter process reference: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to clear hexstrike process reference: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike process-clear failure")
|
||||
|
||||
async def health_check(self, timeout: int = 5) -> bool:
|
||||
"""Check the server health endpoint. Returns True if healthy."""
|
||||
@@ -221,7 +297,16 @@ class HexstrikeAdapter:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=timeout) as resp:
|
||||
return resp.status == 200
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("HexstrikeAdapter health_check (aiohttp) failed: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"HexStrike health check failed: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike health check failure")
|
||||
return False
|
||||
|
||||
# Fallback: synchronous urllib in thread
|
||||
@@ -231,7 +316,16 @@ class HexstrikeAdapter:
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=timeout) as r:
|
||||
return r.status == 200
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("HexstrikeAdapter health_check (urllib) failed: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"HexStrike health check failed: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike urllib health check failure")
|
||||
return False
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
@@ -13,9 +13,9 @@ Uses standard MCP configuration format:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
import atexit
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -78,8 +78,14 @@ class MCPManager:
|
||||
# Ensure we attempt to clean up vendored servers on process exit
|
||||
try:
|
||||
atexit.register(self._atexit_cleanup)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to register atexit cleanup: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to register MCP atexit cleanup: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about atexit.register failure")
|
||||
|
||||
def _find_config(self) -> Path:
|
||||
for path in self.DEFAULT_CONFIG_PATHS:
|
||||
@@ -202,8 +208,10 @@ class MCPManager:
|
||||
try:
|
||||
stop()
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Error running adapter.stop(): %s", e
|
||||
)
|
||||
|
||||
# Final fallback: kill underlying PID if available
|
||||
pid = None
|
||||
@@ -213,17 +221,18 @@ class MCPManager:
|
||||
if pid:
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to SIGTERM pid %s: %s", pid, e)
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e2:
|
||||
logging.getLogger(__name__).exception("Failed to SIGKILL pid %s: %s", pid, e2)
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Error while attempting synchronous adapter stop: %s", e)
|
||||
|
||||
async def _stop_started_adapters_and_disconnect(self) -> None:
|
||||
# Stop any adapters we started
|
||||
for name, adapter in list(self._started_adapters.items()):
|
||||
for _name, adapter in list(self._started_adapters.items()):
|
||||
try:
|
||||
stop = getattr(adapter, "stop", None)
|
||||
if stop:
|
||||
@@ -233,15 +242,15 @@ class MCPManager:
|
||||
# run blocking stop in executor
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, stop)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Error stopping adapter in async shutdown: %s", e)
|
||||
self._started_adapters.clear()
|
||||
|
||||
# Disconnect any active MCP server connections
|
||||
try:
|
||||
await self.disconnect_all()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Error during disconnect_all in shutdown: %s", e)
|
||||
|
||||
def add_server(
|
||||
self,
|
||||
|
||||
@@ -10,10 +10,10 @@ health check on a configurable port.
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import time
|
||||
import signal
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -21,9 +21,7 @@ except Exception:
|
||||
aiohttp = None
|
||||
|
||||
|
||||
LOOT_DIR = Path("loot/artifacts")
|
||||
LOOT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
LOG_FILE = LOOT_DIR / "metasploit_mcp.log"
|
||||
from ..workspaces.utils import get_loot_file
|
||||
|
||||
|
||||
class MetasploitAdapter:
|
||||
@@ -136,14 +134,24 @@ class MetasploitAdapter:
|
||||
await asyncio.sleep(0.5)
|
||||
# If we fallthrough, msfrpcd didn't become ready in time
|
||||
return
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to start msfrpcd: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to start msfrpcd: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd start failure")
|
||||
return
|
||||
|
||||
async def _capture_msfrpcd_output(self) -> None:
|
||||
if not self._msfrpcd_proc or not self._msfrpcd_proc.stdout:
|
||||
return
|
||||
try:
|
||||
with LOG_FILE.open("ab") as fh:
|
||||
log_file = get_loot_file("artifacts/msfrpcd.log")
|
||||
with log_file.open("ab") as fh:
|
||||
while True:
|
||||
line = await self._msfrpcd_proc.stdout.readline()
|
||||
if not line:
|
||||
@@ -151,9 +159,17 @@ class MetasploitAdapter:
|
||||
fh.write(b"[msfrpcd] " + line)
|
||||
fh.flush()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error capturing msfrpcd output: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"msfrpcd log capture failed: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd log capture failure")
|
||||
|
||||
async def start(self, background: bool = True, timeout: int = 30) -> bool:
|
||||
"""Start the vendored Metasploit MCP server.
|
||||
@@ -173,8 +189,16 @@ class MetasploitAdapter:
|
||||
if str(self.transport).lower() in ("http", "sse"):
|
||||
try:
|
||||
await self._start_msfrpcd_if_needed()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error starting msfrpcd: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error starting msfrpcd: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd error")
|
||||
|
||||
cmd = self._build_command()
|
||||
resolved = shutil.which(self.python_cmd) or self.python_cmd
|
||||
@@ -193,10 +217,19 @@ class MetasploitAdapter:
|
||||
try:
|
||||
pid = getattr(self._process, "pid", None)
|
||||
if pid:
|
||||
with LOG_FILE.open("a") as fh:
|
||||
log_file = get_loot_file("artifacts/metasploit_mcp.log")
|
||||
with log_file.open("a") as fh:
|
||||
fh.write(f"[MetasploitAdapter] started pid={pid}\n")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to write metasploit start PID to log: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to write metasploit PID to log: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about metasploit PID log failure")
|
||||
|
||||
# Start background reader
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -204,7 +237,16 @@ class MetasploitAdapter:
|
||||
|
||||
try:
|
||||
return await self.health_check(timeout=timeout)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("MetasploitAdapter health_check raised: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Metasploit health check failed: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about metasploit health check failure")
|
||||
return False
|
||||
|
||||
async def _capture_output(self) -> None:
|
||||
@@ -212,7 +254,8 @@ class MetasploitAdapter:
|
||||
return
|
||||
|
||||
try:
|
||||
with LOG_FILE.open("ab") as fh:
|
||||
log_file = get_loot_file("artifacts/metasploit_mcp.log")
|
||||
with log_file.open("ab") as fh:
|
||||
while True:
|
||||
line = await self._process.stdout.readline()
|
||||
if not line:
|
||||
@@ -220,9 +263,17 @@ class MetasploitAdapter:
|
||||
fh.write(line)
|
||||
fh.flush()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error capturing metasploit output: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Metasploit log capture failed: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about metasploit log capture failure")
|
||||
|
||||
async def stop(self, timeout: int = 5) -> None:
|
||||
proc = self._process
|
||||
@@ -237,8 +288,16 @@ class MetasploitAdapter:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error waiting for process termination: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error stopping metasploit adapter: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about metasploit stop error")
|
||||
|
||||
self._process = None
|
||||
|
||||
@@ -246,8 +305,14 @@ class MetasploitAdapter:
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).exception("Failed to kill msfrpcd during stop: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Failed to kill msfrpcd: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd kill failure")
|
||||
|
||||
# Stop msfrpcd if we started it
|
||||
try:
|
||||
@@ -261,8 +326,16 @@ class MetasploitAdapter:
|
||||
msf_proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error stopping metasploit adapter cleanup: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error stopping metasploit adapter: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about metasploit adapter cleanup error")
|
||||
finally:
|
||||
self._msfrpcd_proc = None
|
||||
|
||||
|
||||
@@ -210,7 +210,16 @@ class SSETransport(MCPTransport):
|
||||
except asyncio.TimeoutError:
|
||||
# If endpoint not discovered, continue; send() will try discovery
|
||||
pass
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed opening SSE stream: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed opening SSE stream: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE open failure")
|
||||
# If opening the SSE stream fails, still mark connected so
|
||||
# send() can attempt POST discovery and report meaningful errors.
|
||||
self._sse_response = None
|
||||
@@ -312,7 +321,16 @@ class SSETransport(MCPTransport):
|
||||
else:
|
||||
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}"
|
||||
return
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error during SSE POST endpoint discovery: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error during SSE POST endpoint discovery: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE discovery error")
|
||||
return
|
||||
|
||||
async def disconnect(self):
|
||||
@@ -323,21 +341,53 @@ class SSETransport(MCPTransport):
|
||||
self._sse_task.cancel()
|
||||
try:
|
||||
await self._sse_task
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error awaiting SSE listener task during disconnect: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error awaiting SSE listener task during disconnect: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE listener await failure")
|
||||
self._sse_task = None
|
||||
except Exception:
|
||||
pass
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error cancelling SSE listener task during disconnect")
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", "Error cancelling SSE listener task during disconnect")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE listener cancellation error")
|
||||
|
||||
try:
|
||||
if self._sse_response:
|
||||
try:
|
||||
await self._sse_response.release()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error releasing SSE response during disconnect: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Error releasing SSE response during disconnect: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE response release error")
|
||||
self._sse_response = None
|
||||
except Exception:
|
||||
pass
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Error handling SSE response during disconnect")
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", "Error handling SSE response during disconnect")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE response handling error")
|
||||
|
||||
# Fail any pending requests
|
||||
async with self._pending_lock:
|
||||
@@ -365,17 +415,20 @@ class SSETransport(MCPTransport):
|
||||
async for raw in resp.content:
|
||||
try:
|
||||
line = raw.decode(errors="ignore").rstrip("\r\n")
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to decode SSE raw chunk: %s", e)
|
||||
continue
|
||||
if line == "":
|
||||
# End of event; process accumulated lines
|
||||
event_name = None
|
||||
data_lines: list[str] = []
|
||||
for l in event_lines:
|
||||
if l.startswith("event:"):
|
||||
event_name = l.split(":", 1)[1].strip()
|
||||
elif l.startswith("data:"):
|
||||
data_lines.append(l.split(":", 1)[1].lstrip())
|
||||
for evt_line in event_lines:
|
||||
if evt_line.startswith("event:"):
|
||||
event_name = evt_line.split(":", 1)[1].strip()
|
||||
elif evt_line.startswith("data:"):
|
||||
data_lines.append(evt_line.split(":", 1)[1].lstrip())
|
||||
|
||||
if data_lines:
|
||||
data_text = "\n".join(data_lines)
|
||||
@@ -392,14 +445,30 @@ class SSETransport(MCPTransport):
|
||||
self._post_url = f"{p.scheme}://{p.netloc}{endpoint}"
|
||||
else:
|
||||
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed parsing SSE endpoint announcement: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed parsing SSE endpoint announcement: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE endpoint parse failure")
|
||||
# Notify connect() that endpoint is ready
|
||||
try:
|
||||
if self._endpoint_ready and not self._endpoint_ready.is_set():
|
||||
self._endpoint_ready.set()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed to set SSE endpoint ready event: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed to set SSE endpoint ready event: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE endpoint ready event failure")
|
||||
else:
|
||||
# Try to parse as JSON and resolve pending futures
|
||||
try:
|
||||
@@ -410,8 +479,16 @@ class SSETransport(MCPTransport):
|
||||
fut = self._pending.get(msg_id)
|
||||
if fut and not fut.done():
|
||||
fut.set_result(obj)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).exception("Failed parsing SSE event JSON or resolving pending future: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Failed parsing SSE event JSON or resolving pending future: {e}")
|
||||
except Exception:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about SSE event parse/future failure")
|
||||
|
||||
event_lines = []
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import io
|
||||
import tarfile
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
@@ -77,7 +78,14 @@ class DockerRuntime(Runtime):
|
||||
if self.container.status != "running":
|
||||
self.container.start()
|
||||
await asyncio.sleep(2) # Wait for container to fully start
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to get or start existing container: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"DockerRuntime: container check failed: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about docker container check failure: %s", ne)
|
||||
# Create new container
|
||||
volumes = {
|
||||
str(Path.home() / ".pentestagent"): {
|
||||
@@ -109,8 +117,14 @@ class DockerRuntime(Runtime):
|
||||
try:
|
||||
self.container.stop(timeout=10)
|
||||
self.container.remove()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed stopping/removing container: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"DockerRuntime: failed stopping/removing container: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception("Failed to notify operator about docker stop error: %s", ne)
|
||||
finally:
|
||||
self.container = None
|
||||
|
||||
@@ -261,7 +275,17 @@ class DockerRuntime(Runtime):
|
||||
try:
|
||||
self.container.reload()
|
||||
return self.container.status == "running"
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to determine container running state: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"DockerRuntime: is_running check failed: {e}")
|
||||
except Exception as notify_error:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to send notification for DockerRuntime.is_running error: %s",
|
||||
notify_error,
|
||||
)
|
||||
return False
|
||||
|
||||
async def get_status(self) -> dict:
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
import platform
|
||||
import shutil
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -304,8 +304,8 @@ def detect_environment() -> EnvironmentInfo:
|
||||
with open("/proc/version", "r") as f:
|
||||
if "microsoft" in f.read().lower():
|
||||
os_name = "Linux (WSL)"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).debug("WSL detection probe failed: %s", e)
|
||||
|
||||
# Detect available tools with categories
|
||||
available_tools = []
|
||||
@@ -455,11 +455,14 @@ class LocalRuntime(Runtime):
|
||||
async def start(self):
|
||||
"""Start the local runtime."""
|
||||
self._running = True
|
||||
# Create organized loot directory structure
|
||||
Path("loot").mkdir(exist_ok=True)
|
||||
Path("loot/reports").mkdir(exist_ok=True)
|
||||
Path("loot/artifacts").mkdir(exist_ok=True)
|
||||
Path("loot/artifacts/screenshots").mkdir(exist_ok=True)
|
||||
# Create organized loot directory structure (workspace-aware)
|
||||
from ..workspaces.utils import get_loot_base
|
||||
|
||||
base = get_loot_base()
|
||||
(base).mkdir(parents=True, exist_ok=True)
|
||||
(base / "reports").mkdir(parents=True, exist_ok=True)
|
||||
(base / "artifacts").mkdir(parents=True, exist_ok=True)
|
||||
(base / "artifacts" / "screenshots").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the local runtime gracefully."""
|
||||
@@ -476,8 +479,16 @@ class LocalRuntime(Runtime):
|
||||
proc.stdout.close()
|
||||
if proc.stderr:
|
||||
proc.stderr.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Error cleaning up active process: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
|
||||
notify("warning", f"Runtime: error cleaning up process: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about process cleanup error: %s", ne
|
||||
)
|
||||
self._active_processes.clear()
|
||||
|
||||
# Clean up browser
|
||||
@@ -490,29 +501,57 @@ class LocalRuntime(Runtime):
|
||||
if self._page:
|
||||
try:
|
||||
await self._page.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to close browser page: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Runtime: failed to close browser page: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about browser page close error: %s", ne
|
||||
)
|
||||
self._page = None
|
||||
|
||||
if self._browser_context:
|
||||
try:
|
||||
await self._browser_context.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to close browser context: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Runtime: failed to close browser context: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about browser context close error: %s", ne
|
||||
)
|
||||
self._browser_context = None
|
||||
|
||||
if self._browser:
|
||||
try:
|
||||
await self._browser.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to close browser: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Runtime: failed to close browser: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about browser close error: %s", ne
|
||||
)
|
||||
self._browser = None
|
||||
|
||||
if self._playwright:
|
||||
try:
|
||||
await self._playwright.stop()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).exception("Failed to stop playwright: %s", e)
|
||||
try:
|
||||
from ..interface.notifier import notify
|
||||
notify("warning", f"Runtime: failed to stop playwright: {e}")
|
||||
except Exception as ne:
|
||||
logging.getLogger(__name__).exception(
|
||||
"Failed to notify operator about playwright stop error: %s", ne
|
||||
)
|
||||
self._playwright = None
|
||||
|
||||
async def _ensure_browser(self):
|
||||
@@ -651,7 +690,6 @@ class LocalRuntime(Runtime):
|
||||
elif action == "screenshot":
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
# Navigate first if URL provided
|
||||
if kwargs.get("url"):
|
||||
@@ -659,9 +697,10 @@ class LocalRuntime(Runtime):
|
||||
kwargs["url"], timeout=timeout, wait_until="domcontentloaded"
|
||||
)
|
||||
|
||||
# Save screenshot to loot/artifacts/screenshots/
|
||||
output_dir = Path("loot/artifacts/screenshots")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Save screenshot to workspace-aware loot/artifacts/screenshots/
|
||||
from ..workspaces.utils import get_loot_file
|
||||
|
||||
output_dir = get_loot_file("artifacts/screenshots").parent
|
||||
|
||||
timestamp = int(time.time())
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
|
||||
@@ -9,17 +9,27 @@ from ..registry import ToolSchema, register_tool
|
||||
|
||||
# Notes storage - kept at loot root for easy access
|
||||
_notes: Dict[str, Dict[str, Any]] = {}
|
||||
_notes_file: Path = Path("loot/notes.json")
|
||||
# Optional override (tests can call set_notes_file)
|
||||
_custom_notes_file: Path | None = None
|
||||
# Lock for safe concurrent access from multiple agents (asyncio since agents are async tasks)
|
||||
_notes_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _notes_file_path() -> Path:
|
||||
from ...workspaces.utils import get_loot_file
|
||||
|
||||
if _custom_notes_file:
|
||||
return _custom_notes_file
|
||||
return get_loot_file("notes.json")
|
||||
|
||||
|
||||
def _load_notes_unlocked() -> None:
|
||||
"""Load notes from file (caller must hold lock)."""
|
||||
global _notes
|
||||
if _notes_file.exists():
|
||||
nf = _notes_file_path()
|
||||
if nf.exists():
|
||||
try:
|
||||
loaded = json.loads(_notes_file.read_text(encoding="utf-8"))
|
||||
loaded = json.loads(nf.read_text(encoding="utf-8"))
|
||||
# Migration: Convert legacy string values to dicts
|
||||
_notes = {}
|
||||
for k, v in loaded.items():
|
||||
@@ -37,8 +47,9 @@ def _load_notes_unlocked() -> None:
|
||||
|
||||
def _save_notes_unlocked() -> None:
|
||||
"""Save notes to file (caller must hold lock)."""
|
||||
_notes_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
_notes_file.write_text(json.dumps(_notes, indent=2), encoding="utf-8")
|
||||
nf = _notes_file_path()
|
||||
nf.parent.mkdir(parents=True, exist_ok=True)
|
||||
nf.write_text(json.dumps(_notes, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
async def get_all_notes() -> Dict[str, Dict[str, Any]]:
|
||||
@@ -52,9 +63,9 @@ async def get_all_notes() -> Dict[str, Dict[str, Any]]:
|
||||
def get_all_notes_sync() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all notes synchronously (read-only, best effort for prompts)."""
|
||||
# If notes are empty, try to load from disk (safe read)
|
||||
if not _notes and _notes_file.exists():
|
||||
if not _notes and _notes_file_path().exists():
|
||||
try:
|
||||
loaded = json.loads(_notes_file.read_text(encoding="utf-8"))
|
||||
loaded = json.loads(_notes_file_path().read_text(encoding="utf-8"))
|
||||
# Migration for sync read
|
||||
result = {}
|
||||
for k, v in loaded.items():
|
||||
@@ -74,14 +85,13 @@ def get_all_notes_sync() -> Dict[str, Dict[str, Any]]:
|
||||
|
||||
def set_notes_file(path: Path) -> None:
|
||||
"""Set custom notes file path."""
|
||||
global _notes_file
|
||||
_notes_file = path
|
||||
global _custom_notes_file
|
||||
_custom_notes_file = Path(path)
|
||||
# Can't use async here, so load without lock (called at init time)
|
||||
_load_notes_unlocked()
|
||||
|
||||
|
||||
# Load notes on module import (init time, no contention yet)
|
||||
_load_notes_unlocked()
|
||||
# Defer loading until first access to avoid caching active workspace path at import
|
||||
|
||||
|
||||
# Validation schema - declarative rules for note structure
|
||||
|
||||
@@ -11,8 +11,8 @@ from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
# Persistent storage (loot root)
|
||||
_data_file: Path = Path("loot/token_usage.json")
|
||||
# Persistent storage (loot root) - compute at use to respect active workspace
|
||||
_custom_data_file: Path | None = None
|
||||
_data_lock = threading.Lock()
|
||||
|
||||
# In-memory cache
|
||||
@@ -27,9 +27,15 @@ _data: Dict[str, Any] = {
|
||||
|
||||
def _load_unlocked() -> None:
|
||||
global _data
|
||||
if _data_file.exists():
|
||||
data_file = _custom_data_file or None
|
||||
if not data_file:
|
||||
from ..workspaces.utils import get_loot_file
|
||||
|
||||
data_file = get_loot_file("token_usage.json")
|
||||
|
||||
if data_file.exists():
|
||||
try:
|
||||
loaded = json.loads(_data_file.read_text(encoding="utf-8"))
|
||||
loaded = json.loads(data_file.read_text(encoding="utf-8"))
|
||||
# Merge with defaults to be robust to schema changes
|
||||
d = {**_data, **(loaded or {})}
|
||||
_data = d
|
||||
@@ -45,14 +51,20 @@ def _load_unlocked() -> None:
|
||||
|
||||
|
||||
def _save_unlocked() -> None:
|
||||
_data_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
_data_file.write_text(json.dumps(_data, indent=2), encoding="utf-8")
|
||||
data_file = _custom_data_file or None
|
||||
if not data_file:
|
||||
from ..workspaces.utils import get_loot_file
|
||||
|
||||
data_file = get_loot_file("token_usage.json")
|
||||
|
||||
data_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
data_file.write_text(json.dumps(_data, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def set_data_file(path: Path) -> None:
|
||||
"""Override the data file (used by tests)."""
|
||||
global _data_file
|
||||
_data_file = path
|
||||
global _custom_data_file
|
||||
_custom_data_file = Path(path)
|
||||
_load_unlocked()
|
||||
|
||||
|
||||
|
||||
3
pentestagent/workspaces/__init__.py
Normal file
3
pentestagent/workspaces/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .manager import TargetManager, WorkspaceError, WorkspaceManager
|
||||
|
||||
__all__ = ["WorkspaceManager", "TargetManager", "WorkspaceError"]
|
||||
226
pentestagent/workspaces/manager.py
Normal file
226
pentestagent/workspaces/manager.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""WorkspaceManager: file-backed workspace and target management using YAML.
|
||||
|
||||
Design goals:
|
||||
- Workspace metadata stored as YAML at workspaces/{name}/meta.yaml
|
||||
- Active workspace marker stored at workspaces/.active
|
||||
- No in-memory caching: all operations read/write files directly
|
||||
- Lightweight hostname validation; accept IPs, CIDRs, hostnames
|
||||
"""
|
||||
import ipaddress
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class WorkspaceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
WORKSPACES_DIR_NAME = "workspaces"
|
||||
NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
|
||||
|
||||
|
||||
def _safe_mkdir(path: Path) -> None:
    """Idempotently create *path*, including any missing parent directories."""
    Path(path).mkdir(parents=True, exist_ok=True)
||||
|
||||
|
||||
class TargetManager:
    """Normalize and validate scan targets (IPs, CIDR ranges, hostnames).

    Hostname checking is deliberately permissive: letters, digits, hyphens
    and dots only, with consecutive dots rejected.
    """

    HOST_RE = re.compile(r"^[A-Za-z0-9.-]{1,253}$")

    @staticmethod
    def normalize_target(value: str) -> str:
        """Return the canonical form of *value*, or raise WorkspaceError.

        IPs and CIDR networks are canonicalized via ``ipaddress``;
        anything else falls back to lightweight hostname validation and
        is lowercased.
        """
        candidate = value.strip()
        try:
            parsed = (
                ipaddress.ip_network(candidate, strict=False)
                if "/" in candidate
                else ipaddress.ip_address(candidate)
            )
        except Exception:
            parsed = None
        if parsed is not None:
            return str(parsed)
        # Not an IP/CIDR: lightweight hostname check (reject empty, odd
        # characters, and consecutive dots).
        if TargetManager.HOST_RE.match(candidate) and ".." not in candidate:
            return candidate.lower()
        raise WorkspaceError(f"Invalid target: {value}") from None

    @staticmethod
    def validate(value: str) -> bool:
        """Return True when *value* normalizes cleanly, False otherwise."""
        try:
            TargetManager.normalize_target(value)
        except WorkspaceError:
            return False
        return True
||||
|
||||
|
||||
class WorkspaceManager:
    """File-backed workspace manager. No persistent in-memory state.

    Every operation reads/writes YAML metadata and marker files directly so
    that concurrent processes observe a consistent view.

    Root defaults to current working directory.
    """

    def __init__(self, root: Path = Path(".")):
        """Bind the manager to *root* and ensure ``workspaces/`` exists."""
        self.root = Path(root)
        self.workspaces_dir = self.root / WORKSPACES_DIR_NAME
        _safe_mkdir(self.workspaces_dir)

    @staticmethod
    def _utc_timestamp() -> str:
        """Current UTC time in ISO-8601 'Z' form.

        Passes time.gmtime() explicitly: strftime() without a time tuple
        formats *local* time, which would mislabel the value with a 'Z'
        (UTC) suffix.
        """
        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    def validate_name(self, name: str):
        """Raise WorkspaceError unless *name* is a safe single path component."""
        if not NAME_RE.match(name):
            raise WorkspaceError(
                "Invalid workspace name; allowed characters: A-Za-z0-9._- (1-64 chars)"
            )
        # prevent path traversal and slashes
        if "/" in name or ".." in name:
            raise WorkspaceError("Invalid workspace name; must not contain '/' or '..'")
        # "." passes NAME_RE but would alias the workspaces root directory
        # itself, so meta.yaml would be written into workspaces/ directly
        if name == ".":
            raise WorkspaceError("Invalid workspace name; '.' is reserved")

    def workspace_path(self, name: str) -> Path:
        """Return workspaces/{name}, validating *name* first."""
        self.validate_name(name)
        return self.workspaces_dir / name

    def meta_path(self, name: str) -> Path:
        """Return the meta.yaml path for workspace *name*."""
        return self.workspace_path(name) / "meta.yaml"

    def active_marker(self) -> Path:
        """Return the marker file recording the active workspace name."""
        return self.workspaces_dir / ".active"

    def create(self, name: str) -> dict:
        """Create the workspace layout (idempotent) and return its metadata."""
        self.validate_name(name)
        p = self.workspace_path(name)
        # create required dirs
        for sub in ("loot", "knowledge/sources", "knowledge/embeddings", "notes", "memory"):
            _safe_mkdir(p / sub)

        # initialize meta if missing
        if not self.meta_path(name).exists():
            meta = {"name": name, "created_at": self._utc_timestamp(), "targets": []}
            self._write_meta(name, meta)
            return meta
        return self._read_meta(name)

    def _read_meta(self, name: str) -> dict:
        """Load meta.yaml for *name*; tolerate a missing or empty file.

        Raises:
            WorkspaceError: the file exists but cannot be read or parsed.
        """
        mp = self.meta_path(name)
        if not mp.exists():
            return {"name": name, "targets": []}
        try:
            data = yaml.safe_load(mp.read_text(encoding="utf-8"))
            if data is None:
                return {"name": name, "targets": []}
            # ensure keys
            data.setdefault("name", name)
            data.setdefault("targets", [])
            return data
        except Exception as e:
            raise WorkspaceError(f"Failed to read meta for {name}: {e}") from e

    def _write_meta(self, name: str, meta: dict):
        """Persist *meta* to meta.yaml, creating parent dirs as needed."""
        mp = self.meta_path(name)
        mp.parent.mkdir(parents=True, exist_ok=True)
        mp.write_text(yaml.safe_dump(meta, sort_keys=False), encoding="utf-8")

    def set_active(self, name: str):
        """Mark *name* as the active workspace, creating it if needed.

        Meta update failures are non-fatal: activation (the marker write)
        still succeeds; the error is logged and surfaced to the operator.
        """
        # ensure workspace exists
        self.create(name)
        marker = self.active_marker()
        marker.write_text(name, encoding="utf-8")
        # update last_active_at in meta.yaml
        try:
            meta = self._read_meta(name)
            meta["last_active_at"] = self._utc_timestamp()
            # ensure operator_notes and tool_runs exist
            meta.setdefault("operator_notes", "")
            meta.setdefault("tool_runs", [])
            self._write_meta(name, meta)
        except Exception as e:
            # Non-fatal - don't block activation on meta write errors, but log for visibility
            logging.getLogger(__name__).exception(
                "Failed to update meta.yaml for workspace '%s': %s", name, e
            )
            try:
                # Emit operator-visible notification if UI present
                from ..interface.notifier import notify

                notify("warning", f"Failed to update workspace meta for '{name}': {e}")
            except Exception:
                # ignore notifier failures
                pass

    def set_operator_note(self, name: str, note: str) -> dict:
        """Append or set operator_notes for a workspace (plain text)."""
        meta = self._read_meta(name)
        prev = meta.get("operator_notes", "") or ""
        meta["operator_notes"] = prev + "\n" + note if prev else note
        self._write_meta(name, meta)
        return meta

    def get_meta_field(self, name: str, field: str):
        """Return a single metadata field, or None when absent."""
        return self._read_meta(name).get(field)

    def get_active(self) -> str:
        """Return the active workspace name, or '' when none is set."""
        marker = self.active_marker()
        if not marker.exists():
            return ""
        return marker.read_text(encoding="utf-8").strip()

    def list_workspaces(self) -> List[str]:
        """Return the names of all workspace directories (unordered)."""
        if not self.workspaces_dir.exists():
            return []
        return [p.name for p in self.workspaces_dir.iterdir() if p.is_dir()]

    def get_meta(self, name: str) -> dict:
        """Public accessor for a workspace's full metadata dict."""
        return self._read_meta(name)

    def add_targets(self, name: str, values: List[str]) -> List[str]:
        """Normalize *values* and add any new ones to the target list.

        Raises WorkspaceError (via normalize_target) on the first invalid
        value; nothing is written in that case unless earlier values already
        changed the set — writes happen once, after all values normalize.
        """
        # read-modify-write for strict file-backed behavior
        meta = self._read_meta(name)
        existing = set(meta.get("targets", []))
        changed = False
        for v in values:
            norm = TargetManager.normalize_target(v)
            if norm not in existing:
                existing.add(norm)
                changed = True
        if changed:
            meta["targets"] = sorted(existing)
            self._write_meta(name, meta)
        return meta.get("targets", [])

    def set_last_target(self, name: str, value: str) -> str:
        """Set the workspace's last used target and ensure it's in the targets list."""
        norm = TargetManager.normalize_target(value)
        meta = self._read_meta(name)
        # ensure targets contains it
        existing = set(meta.get("targets", []))
        if norm not in existing:
            existing.add(norm)
            meta["targets"] = sorted(existing)
        meta["last_target"] = norm
        self._write_meta(name, meta)
        return norm

    def remove_target(self, name: str, value: str) -> List[str]:
        """Remove a target (after normalization); absent targets are a no-op."""
        meta = self._read_meta(name)
        existing = set(meta.get("targets", []))
        norm = TargetManager.normalize_target(value)
        if norm in existing:
            existing.remove(norm)
            meta["targets"] = sorted(existing)
            self._write_meta(name, meta)
        return meta.get("targets", [])

    def list_targets(self, name: str) -> List[str]:
        """Return the workspace's normalized target list."""
        meta = self._read_meta(name)
        return meta.get("targets", [])
||||
189
pentestagent/workspaces/utils.py
Normal file
189
pentestagent/workspaces/utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Utilities to route loot/output into the active workspace or global loot.
|
||||
|
||||
All functions are file-backed and do not cache the active workspace selection.
|
||||
This module will emit a single warning per run if no active workspace is set.
|
||||
"""
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .manager import WorkspaceManager
|
||||
|
||||
_WARNED = False
|
||||
|
||||
|
||||
def get_loot_base(root: Optional[Path] = None) -> Path:
    """Return the base loot directory: workspaces/{active}/loot or top-level `loot/`.

    The directory is created if missing. Emits a single warning per run if
    no workspace is active.
    """
    global _WARNED
    base_root = Path(root or "./")
    active = WorkspaceManager(root=base_root).get_active()
    if not active:
        # Fall back to the global loot directory, warning only once.
        if not _WARNED:
            logging.warning("No active workspace — writing loot to global loot/ directory.")
            _WARNED = True
        loot = base_root / "loot"
    else:
        loot = base_root / "workspaces" / active / "loot"
    loot.mkdir(parents=True, exist_ok=True)
    return loot
|
||||
|
||||
|
||||
def get_loot_file(relpath: str, root: Optional[Path] = None) -> Path:
    """Return a Path for a file under the loot base, creating parent dirs.

    Example: get_loot_file('artifacts/hexstrike.log')
    """
    target = get_loot_base(root=root) / relpath
    target.parent.mkdir(parents=True, exist_ok=True)
    return target
|
||||
|
||||
|
||||
def resolve_knowledge_paths(root: Optional[Path] = None) -> dict:
    """Resolve knowledge-related paths, preferring active workspace if present.

    The workspace knowledge directory is used when it contains any content,
    or when the explicit opt-in marker file ``.use_workspace`` exists.

    Returns a dict with keys: base, sources, embeddings, graph, index, using_workspace
    """
    root = Path(root or "./")
    wm = WorkspaceManager(root=root)
    active = wm.get_active()

    global_base = root / "knowledge"
    workspace_base = root / "workspaces" / active / "knowledge" if active else None

    use_workspace = False
    if workspace_base and workspace_base.exists():
        try:
            # Check the explicit marker first; previously it was checked after
            # any(iterdir()), which made that branch dead code — the marker's
            # mere existence already makes the directory non-empty.
            # Non-recursive content check to avoid walking the whole tree.
            marker = workspace_base / ".use_workspace"
            if marker.exists() or any(workspace_base.iterdir()):
                use_workspace = True
        except Exception as e:
            logging.getLogger(__name__).exception(
                "Error while checking workspace knowledge directory: %s", e
            )
            use_workspace = False

    base = workspace_base if use_workspace else global_base

    return {
        "base": base,
        "sources": base / "sources",
        "embeddings": base / "embeddings",
        "graph": base / "graph",
        "index": base / "index",
        "using_workspace": use_workspace,
    }
|
||||
|
||||
|
||||
def export_workspace(name: str, output: Optional[Path] = None, root: Optional[Path] = None) -> Path:
    """Create a tar.gz archive of workspaces/{name}/ and return the archive path.

    Member ordering is deterministic (sorted by path). Excludes __pycache__
    and *.pyc. Does not mutate the workspace.

    Args:
        name: workspace directory name under workspaces/.
        output: archive destination; defaults to ./{name}-workspace.tar.gz.
        root: project root containing the workspaces/ directory.

    Raises:
        FileNotFoundError: the workspace directory does not exist.
    """
    import tarfile

    root = Path(root or "./")
    ws_dir = root / "workspaces" / name
    if not ws_dir.exists() or not ws_dir.is_dir():
        raise FileNotFoundError(f"Workspace not found: {name}")

    out_path = Path(output) if output else Path(f"{name}-workspace.tar.gz")

    # Collect entries in deterministic (sorted) order, skipping caches.
    entries = []
    for p in ws_dir.rglob("*"):
        if "__pycache__" in p.parts or p.suffix == ".pyc":
            continue
        entries.append(p.relative_to(root))
    entries.sort(key=str)

    # Archive names preserve the workspaces/<name>/... prefix.
    with tarfile.open(out_path, "w:gz") as tf:
        for rel in entries:
            # recursive=False: rglob() already yields every child, so letting
            # add() recurse into directories would store each file twice.
            tf.add(str(root / rel), arcname=str(rel), recursive=False)

    return out_path
|
||||
|
||||
|
||||
def import_workspace(archive: Path, root: Optional[Path] = None) -> str:
    """Import a workspace tar.gz into workspaces/. Returns workspace name.

    Fails if workspace already exists. Requires meta.yaml present in archive.

    Raises:
        FileNotFoundError: *archive* does not exist.
        ValueError: the archive has no usable meta.yaml.
        FileExistsError: a workspace with the same name already exists.
        RuntimeError: the extracted tree could not be moved into place.
    """
    import tarfile
    import tempfile

    root = Path(root or "./")
    archive = Path(archive)
    if not archive.exists():
        raise FileNotFoundError(f"Archive not found: {archive}")

    with tempfile.TemporaryDirectory() as td:
        tdpath = Path(td)
        with tarfile.open(archive, "r:gz") as tf:
            # Archives may come from other operators: use the 'data' filter
            # (Python >= 3.12; backported to 3.11.4/3.10.12) to reject path
            # traversal, absolute names, and device/link tricks on extraction.
            try:
                tf.extractall(path=tdpath, filter="data")
            except TypeError:
                # Older Python without the filter= parameter.
                tf.extractall(path=tdpath)

        # Look for workspaces/<name>/meta.yaml or meta.yaml at root
        candidates = list(tdpath.rglob("meta.yaml"))
        if not candidates:
            raise ValueError("No meta.yaml found in archive")
        meta_file = candidates[0]
        # read name
        import yaml

        meta = yaml.safe_load(meta_file.read_text(encoding="utf-8"))
        if not meta or not meta.get("name"):
            raise ValueError("meta.yaml missing 'name' field")
        name = meta["name"]

        dest = root / "workspaces" / name
        if dest.exists():
            raise FileExistsError(f"Workspace already exists: {name}")

        # Move extracted tree into place.
        # If the archive stored paths as workspaces/<name>/..., move that
        # subtree; otherwise assume the extracted contents are the workspace
        # folder itself and move meta.yaml's parent directory.
        candidate_root = None
        for p in tdpath.iterdir():
            if p.is_dir() and p.name == "workspaces":
                candidate_root = p / name
                break
        if candidate_root and candidate_root.exists():
            # move candidate_root to dest (use shutil.move to support cross-filesystem)
            dest.parent.mkdir(parents=True, exist_ok=True)
            try:
                shutil.move(str(candidate_root), str(dest))
            except Exception as e:
                raise RuntimeError(f"Failed to move workspace subtree into place: {e}") from e
        else:
            src = meta_file.parent
            dest.parent.mkdir(parents=True, exist_ok=True)
            try:
                shutil.move(str(src), str(dest))
            except Exception as e:
                raise RuntimeError(f"Failed to move extracted workspace into place: {e}") from e

    return name
|
||||
115
pentestagent/workspaces/validation.py
Normal file
115
pentestagent/workspaces/validation.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Workspace target validation utilities.
|
||||
|
||||
Provides helpers to extract candidate targets from arbitrary tool arguments
|
||||
and to determine whether a candidate target is covered by the allowed
|
||||
workspace targets (IP, CIDR, hostname).
|
||||
"""
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
from .manager import TargetManager
|
||||
|
||||
|
||||
def gather_candidate_targets(obj: Any) -> List[str]:
    """
    Extract candidate target strings from arguments (shallow, non-recursive).

    Only the top level of *obj* is inspected: a bare string is returned as a
    single-item list; a dict contributes the string values (and strings inside
    list/tuple values) of the common target-ish keys ('target', 'host', 'ip',
    'url', ...). Nested dicts and lists are NOT walked — shallow extraction is
    fast and predictable for typical tool argument schemas; write or call a
    recursive extractor if you need deep traversal.
    """
    if isinstance(obj, str):
        return [obj]

    found: List[str] = []
    if isinstance(obj, dict):
        target_keys = {
            "target",
            "host",
            "hostname",
            "ip",
            "address",
            "url",
            "hosts",
            "targets",
        }
        for key, val in obj.items():
            if key.lower() not in target_keys:
                continue
            if isinstance(val, (list, tuple)):
                # keep only string items; ignore numbers/None/etc.
                found.extend(item for item in val if isinstance(item, str))
            elif isinstance(val, str):
                found.append(val)
    return found
|
||||
|
||||
|
||||
def is_target_in_scope(candidate: str, allowed: List[str]) -> bool:
    """Check whether `candidate` is covered by any entry in `allowed`.

    Allowed entries may be IPs, CIDRs, or hostnames/labels. Candidate may
    also be an IP, CIDR, or hostname. The function normalizes inputs and
    performs robust comparisons for networks and addresses.
    """
    # A candidate that cannot be normalized is not a recognizable
    # IP/CIDR/hostname at all -> out of scope by definition.
    try:
        norm = TargetManager.normalize_target(candidate)
    except Exception:
        return False

    try:
        # If candidate is a network (contains '/'), treat as network
        if "/" in norm:
            cand_net = ipaddress.ip_network(norm, strict=False)
            for a in allowed:
                try:
                    if "/" in a:
                        an = ipaddress.ip_network(a, strict=False)
                        # NOTE(review): subnet_of() raises TypeError on mixed
                        # IPv4/IPv6; the per-entry except below swallows that
                        # and moves on to the next allowed entry.
                        if cand_net.subnet_of(an) or cand_net == an:
                            return True
                    else:
                        # allowed is IP or hostname; accept only when candidate
                        # network represents exactly one address equal to allowed IP
                        try:
                            allowed_ip = ipaddress.ip_address(a)
                        except Exception:
                            # not an IP (likely hostname) - skip
                            continue
                        if cand_net.num_addresses == 1 and cand_net.network_address == allowed_ip:
                            return True
                except Exception:
                    continue
            return False
        else:
            # candidate is a single IP/hostname
            try:
                cand_ip = ipaddress.ip_address(norm)
                for a in allowed:
                    try:
                        if "/" in a:
                            an = ipaddress.ip_network(a, strict=False)
                            if cand_ip in an:
                                return True
                        else:
                            if TargetManager.normalize_target(a) == norm:
                                return True
                    except Exception:
                        # allowed entry failed to parse/normalize; fall back
                        # to a plain case-insensitive string comparison
                        if isinstance(a, str) and a.lower() == norm.lower():
                            return True
                return False
            except Exception:
                # candidate is likely a hostname; compare case-insensitively
                for a in allowed:
                    try:
                        if a.lower() == norm.lower():
                            return True
                    except Exception:
                        continue
                return False
    except Exception as e:
        # Defensive outer guard: any unexpected failure is logged and treated
        # as out-of-scope rather than allowed.
        logging.getLogger(__name__).exception("Error checking target scope: %s", e)
        return False
||||
@@ -126,6 +126,8 @@ known_first_party = ["pentestagent"]
|
||||
line-length = 88
|
||||
target-version = "py310"
|
||||
|
||||
exclude = ["third_party/"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
|
||||
@@ -36,7 +36,7 @@ typer>=0.12.0
|
||||
pydantic>=2.7.0
|
||||
pydantic-settings>=2.2.0
|
||||
python-dotenv>=1.0.0
|
||||
pyyaml>=6.0.0
|
||||
PyYAML>=6.0
|
||||
jinja2>=3.1.0
|
||||
|
||||
# Dev
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""Tests for the agent state management."""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from pentestagent.agents.state import AgentState, AgentStateManager, StateTransition
|
||||
|
||||
|
||||
class TestAgentState:
    """Tests for AgentState enum."""

    def test_state_values(self):
        """Test state enum values."""
        # Pin the exact string values so an accidental rename of an enum
        # member's value is caught by the suite.
        assert AgentState.IDLE.value == "idle"
        assert AgentState.THINKING.value == "thinking"
        assert AgentState.EXECUTING.value == "executing"
        assert AgentState.WAITING_INPUT.value == "waiting_input"
        assert AgentState.COMPLETE.value == "complete"
        assert AgentState.ERROR.value == "error"

    def test_all_states_exist(self):
        """Test that all expected states exist."""
        states = list(AgentState)
        # >= rather than == so adding new states does not break this test
        assert len(states) >= 6
||||
|
||||
|
||||
class TestAgentStateManager:
    """Tests for AgentStateManager class."""

    @pytest.fixture
    def state_manager(self):
        """Create a fresh AgentStateManager for each test."""
        return AgentStateManager()

    def test_initial_state(self, state_manager):
        """Test initial state is IDLE."""
        assert state_manager.current_state == AgentState.IDLE
        assert len(state_manager.history) == 0

    def test_valid_transition(self, state_manager):
        """Test valid state transition."""
        # IDLE -> THINKING is an allowed edge: success is reported, the state
        # advances, and one transition is recorded in history.
        result = state_manager.transition_to(AgentState.THINKING)
        assert result is True
        assert state_manager.current_state == AgentState.THINKING
        assert len(state_manager.history) == 1

    def test_invalid_transition(self, state_manager):
        """Test invalid state transition."""
        # IDLE -> COMPLETE is rejected: False is returned and state is unchanged.
        result = state_manager.transition_to(AgentState.COMPLETE)
        assert result is False
        assert state_manager.current_state == AgentState.IDLE

    def test_transition_chain(self, state_manager):
        """Test a chain of valid transitions."""
        # Each hop in the chain must be an allowed edge, including going
        # back from EXECUTING to THINKING.
        assert state_manager.transition_to(AgentState.THINKING)
        assert state_manager.transition_to(AgentState.EXECUTING)
        assert state_manager.transition_to(AgentState.THINKING)
        assert state_manager.transition_to(AgentState.COMPLETE)

        assert state_manager.current_state == AgentState.COMPLETE
        assert len(state_manager.history) == 4

    def test_force_transition(self, state_manager):
        """Test forcing a transition."""
        # force_transition bypasses the allowed-edge check and marks the
        # recorded reason as forced.
        state_manager.force_transition(AgentState.ERROR, reason="Test error")
        assert state_manager.current_state == AgentState.ERROR
        assert "FORCED" in state_manager.history[-1].reason

    def test_reset(self, state_manager):
        """Test resetting state."""
        state_manager.transition_to(AgentState.THINKING)
        state_manager.transition_to(AgentState.EXECUTING)

        state_manager.reset()

        # reset returns to IDLE and clears the recorded history
        assert state_manager.current_state == AgentState.IDLE
        assert len(state_manager.history) == 0

    def test_is_terminal(self, state_manager):
        """Test terminal state detection."""
        assert state_manager.is_terminal() is False

        state_manager.transition_to(AgentState.THINKING)
        state_manager.transition_to(AgentState.COMPLETE)

        # COMPLETE counts as a terminal state
        assert state_manager.is_terminal() is True

    def test_is_active(self, state_manager):
        """Test active state detection."""
        # IDLE is not considered active; THINKING is
        assert state_manager.is_active() is False

        state_manager.transition_to(AgentState.THINKING)
        assert state_manager.is_active() is True
|
||||
|
||||
class TestStateTransition:
    """Tests for StateTransition dataclass."""

    def test_create_transition(self):
        """Test creating a state transition."""
        transition = StateTransition(
            from_state=AgentState.IDLE,
            to_state=AgentState.THINKING,
            reason="Starting work"
        )

        assert transition.from_state == AgentState.IDLE
        assert transition.to_state == AgentState.THINKING
        assert transition.reason == "Starting work"
        # timestamp is populated automatically on construction; the exact
        # value is not pinned, only its presence
        assert transition.timestamp is not None
||||
@@ -1,251 +0,0 @@
|
||||
"""Tests for the Shadow Graph knowledge system."""
|
||||
|
||||
import networkx as nx
|
||||
import pytest
|
||||
|
||||
from pentestagent.knowledge.graph import ShadowGraph
|
||||
|
||||
|
||||
class TestShadowGraph:
|
||||
"""Tests for ShadowGraph class."""
|
||||
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
"""Create a fresh ShadowGraph for each test."""
|
||||
return ShadowGraph()
|
||||
|
||||
def test_initialization(self, graph):
|
||||
"""Test graph initialization."""
|
||||
assert isinstance(graph.graph, nx.DiGraph)
|
||||
assert len(graph.graph.nodes) == 0
|
||||
assert len(graph._processed_notes) == 0
|
||||
|
||||
def test_extract_host_from_note(self, graph):
|
||||
"""Test extracting host IP from a note."""
|
||||
notes = {
|
||||
"scan_result": {
|
||||
"content": "Nmap scan for 192.168.1.10 shows open ports.",
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
assert graph.graph.has_node("host:192.168.1.10")
|
||||
node = graph.graph.nodes["host:192.168.1.10"]
|
||||
assert node["type"] == "host"
|
||||
assert node["label"] == "192.168.1.10"
|
||||
|
||||
def test_extract_service_finding(self, graph):
|
||||
"""Test extracting services from a finding note."""
|
||||
notes = {
|
||||
"ports_scan": {
|
||||
"content": "Found open ports: 80/tcp, 443/tcp on 10.0.0.5",
|
||||
"category": "finding",
|
||||
"metadata": {
|
||||
"target": "10.0.0.5",
|
||||
"services": [
|
||||
{"port": 80, "protocol": "tcp", "service": "http"},
|
||||
{"port": 443, "protocol": "tcp", "service": "https"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
# Check host exists
|
||||
assert graph.graph.has_node("host:10.0.0.5")
|
||||
|
||||
# Check services exist
|
||||
assert graph.graph.has_node("service:host:10.0.0.5:80")
|
||||
assert graph.graph.has_node("service:host:10.0.0.5:443")
|
||||
|
||||
# Check edges
|
||||
assert graph.graph.has_edge("host:10.0.0.5", "service:host:10.0.0.5:80")
|
||||
edge = graph.graph.edges["host:10.0.0.5", "service:host:10.0.0.5:80"]
|
||||
assert edge["type"] == "HAS_SERVICE"
|
||||
assert edge["protocol"] == "tcp"
|
||||
|
||||
def test_extract_credential(self, graph):
|
||||
"""Test extracting credentials and linking to host."""
|
||||
notes = {
|
||||
"ssh_creds": {
|
||||
"content": "Found user: admin with password 'password123' for SSH on 192.168.1.20",
|
||||
"category": "credential",
|
||||
"metadata": {
|
||||
"target": "192.168.1.20",
|
||||
"username": "admin",
|
||||
"password": "password123",
|
||||
"protocol": "ssh"
|
||||
}
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
cred_id = "cred:ssh_creds"
|
||||
host_id = "host:192.168.1.20"
|
||||
|
||||
assert graph.graph.has_node(cred_id)
|
||||
assert graph.graph.has_node(host_id)
|
||||
|
||||
# Check edge
|
||||
assert graph.graph.has_edge(cred_id, host_id)
|
||||
edge = graph.graph.edges[cred_id, host_id]
|
||||
assert edge["type"] == "AUTH_ACCESS"
|
||||
assert edge["protocol"] == "ssh"
|
||||
|
||||
def test_extract_credential_variations(self, graph):
|
||||
"""Test different credential formats."""
|
||||
notes = {
|
||||
"creds_1": {
|
||||
"content": "Username: root, Password: toor",
|
||||
"category": "credential"
|
||||
},
|
||||
"creds_2": {
|
||||
"content": "Just a password: secret",
|
||||
"category": "credential"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
# Check "Username: root" extraction
|
||||
node1 = graph.graph.nodes["cred:creds_1"]
|
||||
assert node1["label"] == "Creds (root)"
|
||||
|
||||
# Check fallback for no username
|
||||
node2 = graph.graph.nodes["cred:creds_2"]
|
||||
assert node2["label"] == "Credentials"
|
||||
|
||||
def test_metadata_extraction(self, graph):
|
||||
"""Test extracting entities from structured metadata."""
|
||||
notes = {
|
||||
"meta_cred": {
|
||||
"content": "Some random text",
|
||||
"category": "credential",
|
||||
"metadata": {
|
||||
"username": "admin_meta",
|
||||
"target": "10.0.0.99",
|
||||
"source": "10.0.0.1"
|
||||
}
|
||||
},
|
||||
"meta_vuln": {
|
||||
"content": "Bad stuff",
|
||||
"category": "vulnerability",
|
||||
"metadata": {
|
||||
"cve": "CVE-2025-1234",
|
||||
"target": "10.0.0.99"
|
||||
}
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
# Check Credential Metadata
|
||||
cred_node = graph.graph.nodes["cred:meta_cred"]
|
||||
assert cred_node["label"] == "Creds (admin_meta)"
|
||||
|
||||
# Check Target Host
|
||||
assert graph.graph.has_node("host:10.0.0.99")
|
||||
assert graph.graph.has_edge("cred:meta_cred", "host:10.0.0.99")
|
||||
|
||||
# Check Source Host (CONTAINS edge)
|
||||
assert graph.graph.has_node("host:10.0.0.1")
|
||||
assert graph.graph.has_edge("host:10.0.0.1", "cred:meta_cred")
|
||||
|
||||
# Check Vulnerability Metadata
|
||||
vuln_node = graph.graph.nodes["vuln:meta_vuln"]
|
||||
assert vuln_node["label"] == "CVE-2025-1234"
|
||||
assert graph.graph.has_edge("host:10.0.0.99", "vuln:meta_vuln")
|
||||
|
||||
def test_url_metadata(self, graph):
|
||||
"""Test that URL metadata is added to service labels."""
|
||||
notes = {
|
||||
"web_app": {
|
||||
"content": "Admin panel found",
|
||||
"category": "finding",
|
||||
"metadata": {
|
||||
"target": "10.0.0.5",
|
||||
"port": "80/tcp",
|
||||
"url": "http://10.0.0.5/admin"
|
||||
}
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
service_id = "service:host:10.0.0.5:80"
|
||||
assert graph.graph.has_node(service_id)
|
||||
node = graph.graph.nodes[service_id]
|
||||
assert "http://10.0.0.5/admin" in node["label"]
|
||||
|
||||
def test_legacy_note_format(self, graph):
|
||||
"""Test handling legacy string-only notes."""
|
||||
notes = {
|
||||
"legacy_note": "Just a simple note about 10.10.10.10"
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
assert graph.graph.has_node("host:10.10.10.10")
|
||||
|
||||
def test_idempotency(self, graph):
|
||||
"""Test that processing the same note twice doesn't duplicate or error."""
|
||||
notes = {
|
||||
"scan": {
|
||||
"content": "Host 192.168.1.1 is up.",
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
|
||||
# First pass
|
||||
graph.update_from_notes(notes)
|
||||
assert len(graph.graph.nodes) == 1
|
||||
|
||||
# Second pass
|
||||
graph.update_from_notes(notes)
|
||||
assert len(graph.graph.nodes) == 1
|
||||
|
||||
def test_attack_paths(self, graph):
|
||||
"""Test detection of multi-step attack paths."""
|
||||
# Manually construct a path: Cred1 -> HostA -> Cred2 -> HostB
|
||||
# 1. Cred1 gives access to HostA
|
||||
graph._add_node("cred:1", "credential", "Root Creds")
|
||||
graph._add_node("host:A", "host", "10.0.0.1")
|
||||
graph._add_edge("cred:1", "host:A", "AUTH_ACCESS")
|
||||
|
||||
# 2. HostA has Cred2 (this edge type isn't auto-extracted yet, but logic should handle it)
|
||||
graph._add_node("cred:2", "credential", "Db Admin")
|
||||
graph._add_edge("host:A", "cred:2", "CONTAINS_CRED")
|
||||
|
||||
# 3. Cred2 gives access to HostB
|
||||
graph._add_node("host:B", "host", "10.0.0.2")
|
||||
graph._add_edge("cred:2", "host:B", "AUTH_ACCESS")
|
||||
|
||||
paths = graph._find_attack_paths()
|
||||
assert len(paths) == 1
|
||||
assert "Root Creds" in paths[0]
|
||||
assert "10.0.0.1" in paths[0]
|
||||
assert "Db Admin" in paths[0]
|
||||
assert "10.0.0.2" in paths[0]
|
||||
|
||||
def test_mermaid_export(self, graph):
|
||||
"""Test Mermaid diagram generation."""
|
||||
graph._add_node("host:1", "host", "10.0.0.1")
|
||||
graph._add_node("cred:1", "credential", "admin")
|
||||
graph._add_edge("cred:1", "host:1", "AUTH_ACCESS")
|
||||
|
||||
mermaid = graph.to_mermaid()
|
||||
assert "graph TD" in mermaid
|
||||
assert 'host_1["🖥️ 10.0.0.1"]' in mermaid
|
||||
assert 'cred_1["🔑 admin"]' in mermaid
|
||||
assert "cred_1 -->|AUTH_ACCESS| host_1" in mermaid
|
||||
|
||||
def test_multiple_ips_in_one_note(self, graph):
|
||||
"""Test a single note referencing multiple hosts."""
|
||||
notes = {
|
||||
"subnet_scan": {
|
||||
"content": "Scanning 192.168.1.1, 192.168.1.2, and 192.168.1.3",
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
assert graph.graph.has_node("host:192.168.1.1")
|
||||
assert graph.graph.has_node("host:192.168.1.2")
|
||||
assert graph.graph.has_node("host:192.168.1.3")
|
||||
@@ -1,99 +0,0 @@
|
||||
"""Tests for the RAG knowledge system."""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pentestagent.knowledge.rag import Document, RAGEngine
|
||||
|
||||
|
||||
class TestDocument:
|
||||
"""Tests for Document dataclass."""
|
||||
|
||||
def test_create_document(self):
|
||||
"""Test creating a document."""
|
||||
doc = Document(content="Test content", source="test.md")
|
||||
assert doc.content == "Test content"
|
||||
assert doc.source == "test.md"
|
||||
assert doc.metadata == {}
|
||||
assert doc.doc_id is not None
|
||||
|
||||
def test_document_with_metadata(self):
|
||||
"""Test document with metadata."""
|
||||
doc = Document(
|
||||
content="Test",
|
||||
source="test.md",
|
||||
metadata={"cve_id": "CVE-2021-1234", "severity": "high"}
|
||||
)
|
||||
assert doc.metadata["cve_id"] == "CVE-2021-1234"
|
||||
assert doc.metadata["severity"] == "high"
|
||||
|
||||
def test_document_with_embedding(self):
|
||||
"""Test document with embedding."""
|
||||
embedding = np.random.rand(384)
|
||||
doc = Document(content="Test", source="test.md", embedding=embedding)
|
||||
assert doc.embedding is not None
|
||||
assert len(doc.embedding) == 384
|
||||
|
||||
def test_document_with_custom_id(self):
|
||||
"""Test document with custom doc_id."""
|
||||
doc = Document(content="Test", source="test.md", doc_id="custom-id-123")
|
||||
assert doc.doc_id == "custom-id-123"
|
||||
|
||||
|
||||
class TestRAGEngine:
|
||||
"""Tests for RAGEngine class."""
|
||||
|
||||
@pytest.fixture
|
||||
def rag_engine(self, tmp_path):
|
||||
"""Create a RAG engine for testing."""
|
||||
return RAGEngine(
|
||||
knowledge_path=tmp_path / "knowledge",
|
||||
use_local_embeddings=True
|
||||
)
|
||||
|
||||
def test_create_engine(self, rag_engine):
|
||||
"""Test creating a RAG engine."""
|
||||
assert rag_engine is not None
|
||||
assert len(rag_engine.documents) == 0
|
||||
assert rag_engine.embeddings is None
|
||||
|
||||
def test_get_document_count_empty(self, rag_engine):
|
||||
"""Test document count on empty engine."""
|
||||
assert rag_engine.get_document_count() == 0
|
||||
|
||||
def test_clear(self, rag_engine):
|
||||
"""Test clearing the engine."""
|
||||
rag_engine.documents.append(Document(content="test", source="test.md"))
|
||||
rag_engine.embeddings = np.random.rand(1, 384)
|
||||
rag_engine._indexed = True
|
||||
|
||||
rag_engine.clear()
|
||||
|
||||
assert len(rag_engine.documents) == 0
|
||||
assert rag_engine.embeddings is None
|
||||
assert not rag_engine._indexed
|
||||
|
||||
|
||||
class TestRAGEngineChunking:
|
||||
"""Tests for text chunking functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, tmp_path):
|
||||
"""Create engine for chunking tests."""
|
||||
return RAGEngine(knowledge_path=tmp_path)
|
||||
|
||||
def test_chunk_short_text(self, engine):
|
||||
"""Test chunking text shorter than chunk size."""
|
||||
text = "This is a short paragraph.\n\nThis is another paragraph."
|
||||
chunks = engine._chunk_text(text, source="test.md", chunk_size=1000)
|
||||
|
||||
assert len(chunks) >= 1
|
||||
assert all(isinstance(c, Document) for c in chunks)
|
||||
|
||||
def test_chunk_preserves_source(self, engine):
|
||||
"""Test that chunking preserves source information."""
|
||||
text = "Test paragraph 1.\n\nTest paragraph 2."
|
||||
chunks = engine._chunk_text(text, source="my_source.md")
|
||||
|
||||
assert all(c.source == "my_source.md" for c in chunks)
|
||||
@@ -1,159 +0,0 @@
|
||||
"""Tests for the Notes tool."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from pentestagent.tools.notes import _notes, get_all_notes, notes, set_notes_file
|
||||
|
||||
|
||||
# We need to reset the global state for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_notes_state(tmp_path):
|
||||
"""Reset the notes global state for each test."""
|
||||
# Point to a temp file
|
||||
temp_notes_file = tmp_path / "notes.json"
|
||||
set_notes_file(temp_notes_file)
|
||||
|
||||
# Clear the global dictionary (it's imported from the module)
|
||||
# We need to clear the actual dictionary object in the module
|
||||
_notes.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup is handled by tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note():
|
||||
"""Test creating a new note."""
|
||||
args = {
|
||||
"action": "create",
|
||||
"key": "test_note",
|
||||
"value": "This is a test note",
|
||||
"category": "info",
|
||||
"confidence": "high"
|
||||
}
|
||||
|
||||
result = await notes(args, runtime=None)
|
||||
assert "Created note 'test_note'" in result
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert "test_note" in all_notes
|
||||
assert all_notes["test_note"]["content"] == "This is a test note"
|
||||
assert all_notes["test_note"]["category"] == "info"
|
||||
assert all_notes["test_note"]["confidence"] == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_note():
|
||||
"""Test reading an existing note."""
|
||||
# Create first
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "read_me",
|
||||
"value": "Content to read"
|
||||
}, runtime=None)
|
||||
|
||||
# Read
|
||||
result = await notes({
|
||||
"action": "read",
|
||||
"key": "read_me"
|
||||
}, runtime=None)
|
||||
|
||||
assert "Content to read" in result
|
||||
# The format is "[key] (category, confidence, status) content"
|
||||
assert "(info, medium, confirmed)" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note():
|
||||
"""Test updating a note."""
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "update_me",
|
||||
"value": "Original content"
|
||||
}, runtime=None)
|
||||
|
||||
result = await notes({
|
||||
"action": "update",
|
||||
"key": "update_me",
|
||||
"value": "New content"
|
||||
}, runtime=None)
|
||||
|
||||
assert "Updated note 'update_me'" in result
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert all_notes["update_me"]["content"] == "New content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_note():
|
||||
"""Test deleting a note."""
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "delete_me",
|
||||
"value": "Bye bye"
|
||||
}, runtime=None)
|
||||
|
||||
result = await notes({
|
||||
"action": "delete",
|
||||
"key": "delete_me"
|
||||
}, runtime=None)
|
||||
|
||||
assert "Deleted note 'delete_me'" in result
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert "delete_me" not in all_notes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_notes():
|
||||
"""Test listing all notes."""
|
||||
await notes({"action": "create", "key": "n1", "value": "v1"}, runtime=None)
|
||||
await notes({"action": "create", "key": "n2", "value": "v2"}, runtime=None)
|
||||
|
||||
result = await notes({"action": "list"}, runtime=None)
|
||||
|
||||
assert "n1" in result
|
||||
assert "n2" in result
|
||||
assert "Notes (2 entries):" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence(tmp_path):
|
||||
"""Test that notes are saved to disk."""
|
||||
# The fixture already sets a temp file
|
||||
temp_file = tmp_path / "notes.json"
|
||||
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "persistent_note",
|
||||
"value": "I survive restarts"
|
||||
}, runtime=None)
|
||||
|
||||
assert temp_file.exists()
|
||||
content = json.loads(temp_file.read_text())
|
||||
assert "persistent_note" in content
|
||||
assert content["persistent_note"]["content"] == "I survive restarts"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_migration(tmp_path):
|
||||
"""Test migration of legacy string notes."""
|
||||
# Create a legacy file
|
||||
legacy_file = tmp_path / "legacy_notes.json"
|
||||
legacy_data = {
|
||||
"old_note": "Just a string",
|
||||
"new_note": {"content": "A dict", "category": "info"}
|
||||
}
|
||||
legacy_file.write_text(json.dumps(legacy_data))
|
||||
|
||||
# Point the tool to this file
|
||||
set_notes_file(legacy_file)
|
||||
|
||||
# Trigger load (get_all_notes calls _load_notes_unlocked if empty, but we need to clear first)
|
||||
_notes.clear()
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
|
||||
assert "old_note" in all_notes
|
||||
assert isinstance(all_notes["old_note"], dict)
|
||||
assert all_notes["old_note"]["content"] == "Just a string"
|
||||
assert all_notes["old_note"]["category"] == "info"
|
||||
|
||||
assert "new_note" in all_notes
|
||||
assert all_notes["new_note"]["content"] == "A dict"
|
||||
@@ -1,146 +0,0 @@
|
||||
"""Tests for the tool system."""
|
||||
|
||||
import pytest
|
||||
|
||||
from pentestagent.tools import (
|
||||
ToolSchema,
|
||||
disable_tool,
|
||||
enable_tool,
|
||||
get_all_tools,
|
||||
get_tool,
|
||||
get_tool_names,
|
||||
register_tool,
|
||||
)
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Tests for tool registry functions."""
|
||||
|
||||
def test_tools_loaded(self):
|
||||
"""Test that built-in tools are loaded."""
|
||||
tools = get_all_tools()
|
||||
assert len(tools) > 0
|
||||
|
||||
tool_names = get_tool_names()
|
||||
assert "terminal" in tool_names
|
||||
assert "browser" in tool_names
|
||||
|
||||
def test_get_tool(self):
|
||||
"""Test getting a tool by name."""
|
||||
tool = get_tool("terminal")
|
||||
assert tool is not None
|
||||
assert tool.name == "terminal"
|
||||
assert tool.category == "execution"
|
||||
|
||||
def test_get_nonexistent_tool(self):
|
||||
"""Test getting a tool that doesn't exist."""
|
||||
tool = get_tool("nonexistent_tool_xyz")
|
||||
assert tool is None
|
||||
|
||||
def test_disable_enable_tool(self):
|
||||
"""Test disabling and enabling a tool."""
|
||||
result = disable_tool("terminal")
|
||||
assert result is True
|
||||
|
||||
tool = get_tool("terminal")
|
||||
assert tool.enabled is False
|
||||
|
||||
result = enable_tool("terminal")
|
||||
assert result is True
|
||||
|
||||
tool = get_tool("terminal")
|
||||
assert tool.enabled is True
|
||||
|
||||
def test_disable_nonexistent_tool(self):
|
||||
"""Test disabling a tool that doesn't exist."""
|
||||
result = disable_tool("nonexistent_tool_xyz")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestToolSchema:
|
||||
"""Tests for ToolSchema class."""
|
||||
|
||||
def test_create_schema(self):
|
||||
"""Test creating a tool schema."""
|
||||
schema = ToolSchema(
|
||||
properties={
|
||||
"command": {"type": "string", "description": "Command to run"}
|
||||
},
|
||||
required=["command"]
|
||||
)
|
||||
|
||||
assert schema.type == "object"
|
||||
assert "command" in schema.properties
|
||||
assert "command" in schema.required
|
||||
|
||||
def test_schema_to_dict(self):
|
||||
"""Test converting schema to dictionary."""
|
||||
schema = ToolSchema(
|
||||
properties={"input": {"type": "string"}},
|
||||
required=["input"]
|
||||
)
|
||||
|
||||
d = schema.to_dict()
|
||||
assert d["type"] == "object"
|
||||
assert d["properties"]["input"]["type"] == "string"
|
||||
assert d["required"] == ["input"]
|
||||
|
||||
|
||||
class TestTool:
|
||||
"""Tests for Tool class."""
|
||||
|
||||
def test_create_tool(self, sample_tool):
|
||||
"""Test creating a tool."""
|
||||
assert sample_tool.name == "test_tool"
|
||||
assert sample_tool.description == "A test tool"
|
||||
assert sample_tool.category == "test"
|
||||
assert sample_tool.enabled is True
|
||||
|
||||
def test_tool_to_llm_format(self, sample_tool):
|
||||
"""Test converting tool to LLM format."""
|
||||
llm_format = sample_tool.to_llm_format()
|
||||
|
||||
assert llm_format["type"] == "function"
|
||||
assert llm_format["function"]["name"] == "test_tool"
|
||||
assert llm_format["function"]["description"] == "A test tool"
|
||||
assert "parameters" in llm_format["function"]
|
||||
|
||||
def test_tool_validate_arguments(self, sample_tool):
|
||||
"""Test argument validation."""
|
||||
is_valid, error = sample_tool.validate_arguments({"param": "value"})
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
is_valid, error = sample_tool.validate_arguments({})
|
||||
assert is_valid is False
|
||||
assert "param" in error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execute(self, sample_tool):
|
||||
"""Test tool execution."""
|
||||
result = await sample_tool.execute({"param": "test"}, runtime=None)
|
||||
assert "test" in result
|
||||
|
||||
|
||||
class TestRegisterToolDecorator:
|
||||
"""Tests for register_tool decorator."""
|
||||
|
||||
def test_decorator_registers_tool(self):
|
||||
"""Test that decorator registers a new tool."""
|
||||
initial_count = len(get_all_tools())
|
||||
|
||||
@register_tool(
|
||||
name="pytest_test_tool_unique",
|
||||
description="A tool registered in tests",
|
||||
schema=ToolSchema(properties={}, required=[]),
|
||||
category="test"
|
||||
)
|
||||
async def pytest_test_tool_unique(arguments, runtime):
|
||||
return "test result"
|
||||
|
||||
new_count = len(get_all_tools())
|
||||
assert new_count == initial_count + 1
|
||||
|
||||
tool = get_tool("pytest_test_tool_unique")
|
||||
assert tool is not None
|
||||
assert tool.name == "pytest_test_tool_unique"
|
||||
176
third_party/hexstrike/hexstrike_mcp.py
vendored
176
third_party/hexstrike/hexstrike_mcp.py
vendored
@@ -17,17 +17,17 @@ Architecture: MCP Client for AI agent communication with HexStrike server
|
||||
Framework: FastMCP integration for tool orchestration
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import requests
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
|
||||
class HexStrikeColors:
|
||||
"""Enhanced color palette matching the server's ModernVisualEngine.COLORS"""
|
||||
|
||||
@@ -447,9 +447,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"☁️ Starting Prowler {provider} security assessment")
|
||||
result = hexstrike_client.safe_post("api/tools/prowler", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Prowler assessment completed")
|
||||
logger.info("✅ Prowler assessment completed")
|
||||
else:
|
||||
logger.error(f"❌ Prowler assessment failed")
|
||||
logger.error("❌ Prowler assessment failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -517,9 +517,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"☁️ Starting Scout Suite {provider} assessment")
|
||||
result = hexstrike_client.safe_post("api/tools/scout-suite", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Scout Suite assessment completed")
|
||||
logger.info("✅ Scout Suite assessment completed")
|
||||
else:
|
||||
logger.error(f"❌ Scout Suite assessment failed")
|
||||
logger.error("❌ Scout Suite assessment failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -575,12 +575,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"regions": regions,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"☁️ Starting Pacu AWS exploitation")
|
||||
logger.info("☁️ Starting Pacu AWS exploitation")
|
||||
result = hexstrike_client.safe_post("api/tools/pacu", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Pacu exploitation completed")
|
||||
logger.info("✅ Pacu exploitation completed")
|
||||
else:
|
||||
logger.error(f"❌ Pacu exploitation failed")
|
||||
logger.error("❌ Pacu exploitation failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -611,12 +611,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"report": report,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"☁️ Starting kube-hunter Kubernetes scan")
|
||||
logger.info("☁️ Starting kube-hunter Kubernetes scan")
|
||||
result = hexstrike_client.safe_post("api/tools/kube-hunter", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ kube-hunter scan completed")
|
||||
logger.info("✅ kube-hunter scan completed")
|
||||
else:
|
||||
logger.error(f"❌ kube-hunter scan failed")
|
||||
logger.error("❌ kube-hunter scan failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -642,12 +642,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"output_format": output_format,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"☁️ Starting kube-bench CIS benchmark")
|
||||
logger.info("☁️ Starting kube-bench CIS benchmark")
|
||||
result = hexstrike_client.safe_post("api/tools/kube-bench", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ kube-bench benchmark completed")
|
||||
logger.info("✅ kube-bench benchmark completed")
|
||||
else:
|
||||
logger.error(f"❌ kube-bench benchmark failed")
|
||||
logger.error("❌ kube-bench benchmark failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -672,12 +672,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"output_file": output_file,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"🐳 Starting Docker Bench Security assessment")
|
||||
logger.info("🐳 Starting Docker Bench Security assessment")
|
||||
result = hexstrike_client.safe_post("api/tools/docker-bench-security", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Docker Bench Security completed")
|
||||
logger.info("✅ Docker Bench Security completed")
|
||||
else:
|
||||
logger.error(f"❌ Docker Bench Security failed")
|
||||
logger.error("❌ Docker Bench Security failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -736,9 +736,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🛡️ Starting Falco runtime monitoring for {duration}s")
|
||||
result = hexstrike_client.safe_post("api/tools/falco", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Falco monitoring completed")
|
||||
logger.info("✅ Falco monitoring completed")
|
||||
else:
|
||||
logger.error(f"❌ Falco monitoring failed")
|
||||
logger.error("❌ Falco monitoring failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -770,9 +770,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔍 Starting Checkov IaC scan: {directory}")
|
||||
result = hexstrike_client.safe_post("api/tools/checkov", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Checkov scan completed")
|
||||
logger.info("✅ Checkov scan completed")
|
||||
else:
|
||||
logger.error(f"❌ Checkov scan failed")
|
||||
logger.error("❌ Checkov scan failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -804,9 +804,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔍 Starting Terrascan IaC scan: {iac_dir}")
|
||||
result = hexstrike_client.safe_post("api/tools/terrascan", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Terrascan scan completed")
|
||||
logger.info("✅ Terrascan scan completed")
|
||||
else:
|
||||
logger.error(f"❌ Terrascan scan failed")
|
||||
logger.error("❌ Terrascan scan failed")
|
||||
return result
|
||||
|
||||
# ============================================================================
|
||||
@@ -932,9 +932,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🎯 Generating {payload_type} payload: {size} bytes")
|
||||
result = hexstrike_client.safe_post("api/payloads/generate", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Payload generated successfully")
|
||||
logger.info("✅ Payload generated successfully")
|
||||
else:
|
||||
logger.error(f"❌ Failed to generate payload")
|
||||
logger.error("❌ Failed to generate payload")
|
||||
return result
|
||||
|
||||
# ============================================================================
|
||||
@@ -988,9 +988,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🐍 Executing Python script in env {env_name}")
|
||||
result = hexstrike_client.safe_post("api/python/execute", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Python script executed successfully")
|
||||
logger.info("✅ Python script executed successfully")
|
||||
else:
|
||||
logger.error(f"❌ Python script execution failed")
|
||||
logger.error("❌ Python script execution failed")
|
||||
return result
|
||||
|
||||
# ============================================================================
|
||||
@@ -1167,9 +1167,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔐 Starting John the Ripper: {hash_file}")
|
||||
result = hexstrike_client.safe_post("api/tools/john", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ John the Ripper completed")
|
||||
logger.info("✅ John the Ripper completed")
|
||||
else:
|
||||
logger.error(f"❌ John the Ripper failed")
|
||||
logger.error("❌ John the Ripper failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1337,9 +1337,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔐 Starting Hashcat attack: mode {attack_mode}")
|
||||
result = hexstrike_client.safe_post("api/tools/hashcat", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Hashcat attack completed")
|
||||
logger.info("✅ Hashcat attack completed")
|
||||
else:
|
||||
logger.error(f"❌ Hashcat attack failed")
|
||||
logger.error("❌ Hashcat attack failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1690,9 +1690,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔍 Starting arp-scan: {target if target else 'local network'}")
|
||||
result = hexstrike_client.safe_post("api/tools/arp-scan", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ arp-scan completed")
|
||||
logger.info("✅ arp-scan completed")
|
||||
else:
|
||||
logger.error(f"❌ arp-scan failed")
|
||||
logger.error("❌ arp-scan failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1727,9 +1727,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔍 Starting Responder on interface: {interface}")
|
||||
result = hexstrike_client.safe_post("api/tools/responder", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Responder completed")
|
||||
logger.info("✅ Responder completed")
|
||||
else:
|
||||
logger.error(f"❌ Responder failed")
|
||||
logger.error("❌ Responder failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1755,9 +1755,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🧠 Starting Volatility analysis: {plugin}")
|
||||
result = hexstrike_client.safe_post("api/tools/volatility", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Volatility analysis completed")
|
||||
logger.info("✅ Volatility analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ Volatility analysis failed")
|
||||
logger.error("❌ Volatility analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1787,9 +1787,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🚀 Starting MSFVenom payload generation: {payload}")
|
||||
result = hexstrike_client.safe_post("api/tools/msfvenom", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ MSFVenom payload generated")
|
||||
logger.info("✅ MSFVenom payload generated")
|
||||
else:
|
||||
logger.error(f"❌ MSFVenom payload generation failed")
|
||||
logger.error("❌ MSFVenom payload generation failed")
|
||||
return result
|
||||
|
||||
# ============================================================================
|
||||
@@ -2071,9 +2071,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔧 Starting Pwntools exploit: {exploit_type}")
|
||||
result = hexstrike_client.safe_post("api/tools/pwntools", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Pwntools exploit completed")
|
||||
logger.info("✅ Pwntools exploit completed")
|
||||
else:
|
||||
logger.error(f"❌ Pwntools exploit failed")
|
||||
logger.error("❌ Pwntools exploit failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2097,9 +2097,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔧 Starting one_gadget analysis: {libc_path}")
|
||||
result = hexstrike_client.safe_post("api/tools/one-gadget", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ one_gadget analysis completed")
|
||||
logger.info("✅ one_gadget analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ one_gadget analysis failed")
|
||||
logger.error("❌ one_gadget analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2157,9 +2157,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔧 Starting GDB-PEDA analysis: {binary or f'PID {attach_pid}' or core_file}")
|
||||
result = hexstrike_client.safe_post("api/tools/gdb-peda", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ GDB-PEDA analysis completed")
|
||||
logger.info("✅ GDB-PEDA analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ GDB-PEDA analysis failed")
|
||||
logger.error("❌ GDB-PEDA analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2191,9 +2191,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔧 Starting angr analysis: {binary}")
|
||||
result = hexstrike_client.safe_post("api/tools/angr", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ angr analysis completed")
|
||||
logger.info("✅ angr analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ angr analysis failed")
|
||||
logger.error("❌ angr analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2225,9 +2225,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔧 Starting ropper analysis: {binary}")
|
||||
result = hexstrike_client.safe_post("api/tools/ropper", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ ropper analysis completed")
|
||||
logger.info("✅ ropper analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ ropper analysis failed")
|
||||
logger.error("❌ ropper analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2256,9 +2256,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔧 Starting pwninit setup: {binary}")
|
||||
result = hexstrike_client.safe_post("api/tools/pwninit", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ pwninit setup completed")
|
||||
logger.info("✅ pwninit setup completed")
|
||||
else:
|
||||
logger.error(f"❌ pwninit setup failed")
|
||||
logger.error("❌ pwninit setup failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2667,9 +2667,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🎯 Starting Dalfox XSS scan: {url if url else 'pipe mode'}")
|
||||
result = hexstrike_client.safe_post("api/tools/dalfox", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Dalfox XSS scan completed")
|
||||
logger.info("✅ Dalfox XSS scan completed")
|
||||
else:
|
||||
logger.error(f"❌ Dalfox XSS scan failed")
|
||||
logger.error("❌ Dalfox XSS scan failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -2922,7 +2922,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
if payload_info.get("risk_level") == "HIGH":
|
||||
results["summary"]["high_risk_payloads"] += 1
|
||||
|
||||
logger.info(f"✅ Attack suite generated:")
|
||||
logger.info("✅ Attack suite generated:")
|
||||
logger.info(f" ├─ Total payloads: {results['summary']['total_payloads']}")
|
||||
logger.info(f" ├─ High-risk payloads: {results['summary']['high_risk_payloads']}")
|
||||
logger.info(f" └─ Test cases: {results['summary']['test_cases']}")
|
||||
@@ -2967,7 +2967,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
endpoint_count = len(result.get("results", []))
|
||||
logger.info(f"✅ API endpoint testing completed: {endpoint_count} endpoints tested")
|
||||
else:
|
||||
logger.info(f"✅ API endpoint discovery completed")
|
||||
logger.info("✅ API endpoint discovery completed")
|
||||
else:
|
||||
logger.error("❌ API fuzzing failed")
|
||||
|
||||
@@ -3032,7 +3032,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"target_url": target_url
|
||||
}
|
||||
|
||||
logger.info(f"🔍 Starting JWT security analysis")
|
||||
logger.info("🔍 Starting JWT security analysis")
|
||||
result = hexstrike_client.safe_post("api/tools/jwt_analyzer", data)
|
||||
|
||||
if result.get("success"):
|
||||
@@ -3089,7 +3089,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.warning(f" ├─ [{severity}] {issue_type}")
|
||||
|
||||
if endpoint_count > 0:
|
||||
logger.info(f"📊 Discovered endpoints:")
|
||||
logger.info("📊 Discovered endpoints:")
|
||||
for endpoint in analysis.get("endpoints_found", [])[:5]: # Show first 5
|
||||
method = endpoint.get("method", "GET")
|
||||
path = endpoint.get("path", "/")
|
||||
@@ -3183,7 +3183,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"audit_coverage": "comprehensive" if len(audit_results["tests_performed"]) >= 3 else "partial"
|
||||
}
|
||||
|
||||
logger.info(f"✅ Comprehensive API audit completed:")
|
||||
logger.info("✅ Comprehensive API audit completed:")
|
||||
logger.info(f" ├─ Tests performed: {audit_results['summary']['tests_performed']}")
|
||||
logger.info(f" ├─ Total vulnerabilities: {audit_results['summary']['total_vulnerabilities']}")
|
||||
logger.info(f" └─ Coverage: {audit_results['summary']['audit_coverage']}")
|
||||
@@ -3220,9 +3220,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🧠 Starting Volatility3 analysis: {plugin}")
|
||||
result = hexstrike_client.safe_post("api/tools/volatility3", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Volatility3 analysis completed")
|
||||
logger.info("✅ Volatility3 analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ Volatility3 analysis failed")
|
||||
logger.error("❌ Volatility3 analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3248,9 +3248,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"📁 Starting Foremost file carving: {input_file}")
|
||||
result = hexstrike_client.safe_post("api/tools/foremost", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Foremost carving completed")
|
||||
logger.info("✅ Foremost carving completed")
|
||||
else:
|
||||
logger.error(f"❌ Foremost carving failed")
|
||||
logger.error("❌ Foremost carving failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3308,9 +3308,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"📷 Starting ExifTool analysis: {file_path}")
|
||||
result = hexstrike_client.safe_post("api/tools/exiftool", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ ExifTool analysis completed")
|
||||
logger.info("✅ ExifTool analysis completed")
|
||||
else:
|
||||
logger.error(f"❌ ExifTool analysis failed")
|
||||
logger.error("❌ ExifTool analysis failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3335,12 +3335,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"append_data": append_data,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"🔐 Starting HashPump attack")
|
||||
logger.info("🔐 Starting HashPump attack")
|
||||
result = hexstrike_client.safe_post("api/tools/hashpump", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ HashPump attack completed")
|
||||
logger.info("✅ HashPump attack completed")
|
||||
else:
|
||||
logger.error(f"❌ HashPump attack failed")
|
||||
logger.error("❌ HashPump attack failed")
|
||||
return result
|
||||
|
||||
# ============================================================================
|
||||
@@ -3383,9 +3383,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🕷️ Starting Hakrawler crawling: {url}")
|
||||
result = hexstrike_client.safe_post("api/tools/hakrawler", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Hakrawler crawling completed")
|
||||
logger.info("✅ Hakrawler crawling completed")
|
||||
else:
|
||||
logger.error(f"❌ Hakrawler crawling failed")
|
||||
logger.error("❌ Hakrawler crawling failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3416,12 +3416,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"output_file": output_file,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"🌐 Starting HTTPx probing")
|
||||
logger.info("🌐 Starting HTTPx probing")
|
||||
result = hexstrike_client.safe_post("api/tools/httpx", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ HTTPx probing completed")
|
||||
logger.info("✅ HTTPx probing completed")
|
||||
else:
|
||||
logger.error(f"❌ HTTPx probing failed")
|
||||
logger.error("❌ HTTPx probing failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3449,9 +3449,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
logger.info(f"🔍 Starting ParamSpider discovery: {domain}")
|
||||
result = hexstrike_client.safe_post("api/tools/paramspider", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ ParamSpider discovery completed")
|
||||
logger.info("✅ ParamSpider discovery completed")
|
||||
else:
|
||||
logger.error(f"❌ ParamSpider discovery failed")
|
||||
logger.error("❌ ParamSpider discovery failed")
|
||||
return result
|
||||
|
||||
# ============================================================================
|
||||
@@ -3486,12 +3486,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
"output_file": output_file,
|
||||
"additional_args": additional_args
|
||||
}
|
||||
logger.info(f"🔍 Starting Burp Suite scan")
|
||||
logger.info("🔍 Starting Burp Suite scan")
|
||||
result = hexstrike_client.safe_post("api/tools/burpsuite", data)
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Burp Suite scan completed")
|
||||
logger.info("✅ Burp Suite scan completed")
|
||||
else:
|
||||
logger.error(f"❌ Burp Suite scan failed")
|
||||
logger.error("❌ Burp Suite scan failed")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3794,7 +3794,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
Returns:
|
||||
Server health information with tool availability and telemetry
|
||||
"""
|
||||
logger.info(f"🏥 Checking HexStrike AI server health")
|
||||
logger.info("🏥 Checking HexStrike AI server health")
|
||||
result = hexstrike_client.check_health()
|
||||
if result.get("status") == "healthy":
|
||||
logger.info(f"✅ Server is healthy - {result.get('total_tools_available', 0)} tools available")
|
||||
@@ -3810,7 +3810,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
Returns:
|
||||
Cache performance statistics
|
||||
"""
|
||||
logger.info(f"💾 Getting cache statistics")
|
||||
logger.info("💾 Getting cache statistics")
|
||||
result = hexstrike_client.safe_get("api/cache/stats")
|
||||
if "hit_rate" in result:
|
||||
logger.info(f"📊 Cache hit rate: {result.get('hit_rate', 'unknown')}")
|
||||
@@ -3824,12 +3824,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
Returns:
|
||||
Cache clear operation results
|
||||
"""
|
||||
logger.info(f"🧹 Clearing server cache")
|
||||
logger.info("🧹 Clearing server cache")
|
||||
result = hexstrike_client.safe_post("api/cache/clear", {})
|
||||
if result.get("success"):
|
||||
logger.info(f"✅ Cache cleared successfully")
|
||||
logger.info("✅ Cache cleared successfully")
|
||||
else:
|
||||
logger.error(f"❌ Failed to clear cache")
|
||||
logger.error("❌ Failed to clear cache")
|
||||
return result
|
||||
|
||||
@mcp.tool()
|
||||
@@ -3840,7 +3840,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
Returns:
|
||||
System performance and usage telemetry
|
||||
"""
|
||||
logger.info(f"📈 Getting system telemetry")
|
||||
logger.info("📈 Getting system telemetry")
|
||||
result = hexstrike_client.safe_get("api/telemetry")
|
||||
if "commands_executed" in result:
|
||||
logger.info(f"📊 Commands executed: {result.get('commands_executed', 0)}")
|
||||
@@ -3993,7 +3993,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
|
||||
execution_time = result.get("execution_time", 0)
|
||||
logger.info(f"✅ Command completed successfully in {execution_time:.2f}s")
|
||||
else:
|
||||
logger.warning(f"⚠️ Command completed with errors")
|
||||
logger.warning("⚠️ Command completed with errors")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -5433,7 +5433,7 @@ def main():
|
||||
logger.debug("🔍 Debug logging enabled")
|
||||
|
||||
# MCP compatibility: No banner output to avoid JSON parsing issues
|
||||
logger.info(f"🚀 Starting HexStrike AI MCP Client v6.0")
|
||||
logger.info("🚀 Starting HexStrike AI MCP Client v6.0")
|
||||
logger.info(f"🔗 Connecting to: {args.server}")
|
||||
|
||||
try:
|
||||
|
||||
427
third_party/hexstrike/hexstrike_server.py
vendored
427
third_party/hexstrike/hexstrike_server.py
vendored
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user