Merge pull request #19 from giveen/workspace

feat(workspaces): add unified /workspace lifecycle, target persistence, and workspace-scoped RAG
This commit is contained in:
Masic
2026-01-19 21:25:53 -07:00
committed by GitHub
33 changed files with 2652 additions and 1347 deletions

12
.gitignore vendored
View File

@@ -81,3 +81,15 @@ Thumbs.db
tmp/
temp/
*.tmp
# Local test artifacts and test scripts (do not commit local test runs)
tests/*.log
tests/*.out
tests/output/
tests/tmp/
tests/*.local.py
scripts/test_*.sh
*.test.sh
# Workspaces directory (user data should not be committed)
/workspaces/

BIN
dupe-workspace.tar.gz Normal file

Binary file not shown.

BIN
expimp-workspace.tar.gz Normal file

Binary file not shown.

View File

@@ -5,6 +5,8 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, AsyncIterator, List, Optional
from ..config.constants import AGENT_MAX_ITERATIONS
from ..workspaces.manager import WorkspaceManager
from ..workspaces import validation
from .state import AgentState, AgentStateManager
if TYPE_CHECKING:
@@ -79,68 +81,73 @@ class BaseAgent(ABC):
tools: List["Tool"],
runtime: "Runtime",
max_iterations: int = AGENT_MAX_ITERATIONS,
**kwargs,
):
"""
Initialize the base agent.
Initialize base agent state.
Args:
llm: The LLM instance for generating responses
tools: List of tools available to the agent
runtime: The runtime environment for tool execution
max_iterations: Maximum iterations before forcing stop (safety limit)
llm: LLM instance used for generation
tools: Available tool list
runtime: Runtime used for tool execution
max_iterations: Safety limit for iterations
"""
self.llm = llm
self.tools = tools
self.runtime = runtime
self.max_iterations = max_iterations
# Agent runtime state
self.state_manager = AgentStateManager()
self.conversation_history: List[AgentMessage] = []
# Each agent gets its own plan instance
from ..tools.finish import TaskPlan
# Task planning structure (used by finish tool)
try:
from ..tools.finish import TaskPlan
self._task_plan = TaskPlan()
self._task_plan = TaskPlan()
except Exception as e:
import logging
# Attach plan to runtime so finish tool can access it
self.runtime.plan = self._task_plan
logging.getLogger(__name__).exception("Failed importing TaskPlan: %s", e)
try:
from ..interface.notifier import notify
# Use tools as-is (finish accesses plan via runtime)
self.tools = list(tools)
notify("warning", f"Failed to import TaskPlan: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about TaskPlan import failure")
# Fallback simple plan structure
class _SimplePlan:
def __init__(self):
self.steps = []
self.original_request = ""
@property
def state(self) -> AgentState:
"""Get current agent state."""
return self.state_manager.current_state
def clear(self):
self.steps.clear()
@state.setter
def state(self, value: AgentState):
"""Set agent state."""
self.state_manager.transition_to(value)
def is_complete(self):
return True
def cleanup_after_cancel(self) -> None:
"""
Clean up agent state after a cancellation.
def has_failure(self):
return False
Removes the cancelled request and any pending tool calls from
conversation history to prevent stale responses from contaminating
the next conversation.
"""
# Remove incomplete messages from the end of conversation
while self.conversation_history:
last_msg = self.conversation_history[-1]
# Remove assistant message with tool calls (incomplete tool execution)
if last_msg.role == "assistant" and last_msg.tool_calls:
self.conversation_history.pop()
# Remove orphaned tool_result messages
elif last_msg.role == "tool":
self.conversation_history.pop()
# Remove the user message that triggered the cancelled request
elif last_msg.role == "user":
self.conversation_history.pop()
break # Stop after removing the user message
else:
break
self._task_plan = _SimplePlan()
# Reset state to idle
# Expose plan to runtime so tools like `finish` can access it
try:
self.runtime.plan = self._task_plan
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to attach plan to runtime: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to attach plan to runtime: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about runtime plan attach failure")
# Ensure agent starts idle
self.state_manager.transition_to(AgentState.IDLE)
@abstractmethod
@@ -448,16 +455,59 @@ class BaseAgent(ABC):
if tool:
try:
result = await tool.execute(arguments, self.runtime)
results.append(
ToolResult(
tool_call_id=tool_call_id,
tool_name=name,
result=result,
success=True,
# Before executing, enforce target safety gate when workspace active
wm = WorkspaceManager()
active = wm.get_active()
# Use centralized validation helpers for target extraction and scope checks
candidates = validation.gather_candidate_targets(arguments)
out_of_scope = []
if active:
allowed = wm.list_targets(active)
for c in candidates:
try:
if not validation.is_target_in_scope(c, allowed):
out_of_scope.append(c)
except Exception as e:
import logging
logging.getLogger(__name__).exception(
"Error validating candidate target %s: %s", c, e
)
out_of_scope.append(c)
if active and out_of_scope:
# Block execution and return an explicit error requiring operator confirmation
results.append(
ToolResult(
tool_call_id=tool_call_id,
tool_name=name,
error=(
f"Out-of-scope target(s): {out_of_scope} - operator confirmation required. "
"Set workspace targets with /target or run tool manually."
),
success=False,
)
)
else:
result = await tool.execute(arguments, self.runtime)
results.append(
ToolResult(
tool_call_id=tool_call_id,
tool_name=name,
result=result,
success=True,
)
)
)
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error executing tool %s: %s", name, e)
try:
from ..interface.notifier import notify
notify("warning", f"Tool execution failed ({name}): {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about tool execution failure")
results.append(
ToolResult(
tool_call_id=tool_call_id,

View File

@@ -111,8 +111,16 @@ class CrewOrchestrator:
sections.append("\n".join(grouped[cat]))
notes_context = "\n\n".join(sections)
except Exception:
pass # Notes not available
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to gather notes for orchestrator prompt: %s", e)
try:
from ...interface.notifier import notify
notify("warning", f"Orchestrator: failed to gather notes: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about orchestrator notes failure: %s", ne)
# Format insights for prompt
insights_text = ""
@@ -271,6 +279,15 @@ class CrewOrchestrator:
break # Exit immediately after finish
except Exception as e:
import logging
logging.getLogger(__name__).exception("Worker tool execution failed (%s): %s", tc_name, e)
try:
from ...interface.notifier import notify
notify("warning", f"Worker tool execution failed ({tc_name}): {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about worker tool failure: %s", ne)
error_msg = f"Error: {e}"
yield {
"phase": "tool_result",
@@ -327,6 +344,15 @@ class CrewOrchestrator:
self.pool.finish_tokens = 0
break
except Exception as e:
import logging
logging.getLogger(__name__).exception("Auto-finish failed: %s", e)
try:
from ...interface.notifier import notify
notify("warning", f"Auto-finish failed: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about auto-finish failure: %s", ne)
yield {
"phase": "error",
"error": f"Auto-finish failed: {e}",
@@ -344,6 +370,15 @@ class CrewOrchestrator:
yield {"phase": "complete", "report": final_report}
except Exception as e:
import logging
logging.getLogger(__name__).exception("Orchestrator run failed: %s", e)
try:
from ...interface.notifier import notify
notify("error", f"CrewOrchestrator run failed: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about orchestrator run failure: %s", ne)
self.state = CrewState.ERROR
yield {"phase": "error", "error": str(e)}

View File

@@ -230,6 +230,15 @@ class WorkerPool:
raise
except Exception as e:
import logging
logging.getLogger(__name__).exception("Worker execution failed (%s): %s", worker.id, e)
try:
from ...interface.notifier import notify
notify("warning", f"Worker execution failed ({worker.id}): {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about worker execution failure: %s", ne)
worker.error = str(e)
worker.status = AgentStatus.ERROR
worker.completed_at = time.time()
@@ -239,8 +248,16 @@ class WorkerPool:
# Cleanup worker's isolated runtime
try:
await worker_runtime.stop()
except Exception:
pass # Best effort cleanup
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to stop worker runtime for %s: %s", worker.id, e)
try:
from ...interface.notifier import notify
notify("warning", f"Failed to stop worker runtime for {worker.id}: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about worker runtime stop failure: %s", ne)
async def _wait_for_dependencies(self, depends_on: List[str]) -> None:
"""Wait for dependent workers to complete."""
@@ -248,8 +265,17 @@ class WorkerPool:
if dep_id in self._tasks:
try:
await self._tasks[dep_id]
except (asyncio.CancelledError, Exception):
pass # Dependency failed, but we continue
except (asyncio.CancelledError, Exception) as e:
import logging
logging.getLogger(__name__).exception("Dependency wait failed for %s: %s", dep_id, e)
try:
from ...interface.notifier import notify
notify("warning", f"Dependency wait failed for {dep_id}: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about dependency wait failure: %s", ne)
# Dependency failed, but we continue
async def wait_for(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Any]:
"""
@@ -269,8 +295,16 @@ class WorkerPool:
if agent_id in self._tasks:
try:
await self._tasks[agent_id]
except (asyncio.CancelledError, Exception):
pass
except (asyncio.CancelledError, Exception) as e:
import logging
logging.getLogger(__name__).exception("Waiting for agent task %s failed: %s", agent_id, e)
try:
from ...interface.notifier import notify
notify("warning", f"Waiting for agent {agent_id} failed: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about wait_for agent failure: %s", ne)
worker = self._workers.get(agent_id)
if worker:

View File

@@ -117,8 +117,16 @@ class PentestAgentAgent(BaseAgent):
sections.append("\n".join(grouped[cat]))
notes_context = "\n\n".join(sections)
except Exception:
pass # Notes not available
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to gather notes for agent prompt: %s", e)
try:
from ...interface.notifier import notify
notify("warning", f"Agent: failed to gather notes: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about agent notes failure")
# Get environment info from runtime
env = self.runtime.environment

View File

@@ -2,6 +2,7 @@
import argparse
import asyncio
from pathlib import Path
from ..config.constants import AGENT_MAX_ITERATIONS, DEFAULT_MODEL
from .cli import run_cli
@@ -127,6 +128,25 @@ Examples:
mcp_test = mcp_subparsers.add_parser("test", help="Test MCP server connection")
mcp_test.add_argument("name", help="Server name to test")
# workspace management
ws_parser = subparsers.add_parser(
"workspace", help="Workspace lifecycle and info commands"
)
ws_parser.add_argument(
"action",
nargs="?",
help="Action or workspace name. Subcommands: info, note, clear, export, import",
)
ws_parser.add_argument("rest", nargs=argparse.REMAINDER, help="Additional arguments")
# NOTE: use `workspace list` to list workspaces (handled by workspace subcommand)
# target management
tgt_parser = subparsers.add_parser(
"target", help="Add or list targets for the active workspace"
)
tgt_parser.add_argument("values", nargs="*", help="Targets to add (IP/CIDR/hostname)")
return parser, parser.parse_args()
@@ -304,6 +324,232 @@ def handle_mcp_command(args: argparse.Namespace):
console.print("[yellow]Use 'pentestagent mcp --help' for available commands[/]")
def handle_workspace_command(args: argparse.Namespace):
"""Handle workspace lifecycle commands and actions."""
from pentestagent.workspaces.manager import WorkspaceError, WorkspaceManager
from pentestagent.workspaces.utils import (
export_workspace,
import_workspace,
resolve_knowledge_paths,
)
wm = WorkspaceManager()
action = args.action
rest = args.rest or []
# No args -> show active workspace
if not action:
active = wm.get_active()
if not active:
print("No active workspace.")
else:
print(f"Active workspace: {active}")
return
# Subcommands
if action == "info":
# show info for active or specified workspace
name = rest[0] if rest else wm.get_active()
if not name:
print("No workspace specified and no active workspace.")
return
try:
meta = wm.get_meta(name)
created = meta.get("created_at")
last_active = meta.get("last_active_at")
targets = meta.get("targets", [])
kp = resolve_knowledge_paths()
ks = "workspace" if kp.get("using_workspace") else "global"
# estimate loot size if present
import os
loot_dir = (wm.workspace_path(name) / "loot").resolve()
size = 0
files = 0
if loot_dir.exists():
for rootp, _, filenames in os.walk(loot_dir):
for fn in filenames:
try:
fp = os.path.join(rootp, fn)
size += os.path.getsize(fp)
files += 1
except Exception:
# Best-effort loot stats: skip files we can't stat (e.g., permissions, broken symlinks)
pass
print(f"Name: {name}")
print(f"Created: {created}")
print(f"Last active: {last_active}")
print(f"Targets: {len(targets)}")
print(f"Knowledge scope: {ks}")
print(f"Loot files: {files}, approx size: {size} bytes")
except Exception as e:
print(f"Error retrieving workspace info: {e}")
return
if action == "list":
# list all workspaces and mark active
try:
wss = wm.list_workspaces()
active = wm.get_active()
if not wss:
print("No workspaces found.")
return
for name in sorted(wss):
prefix = "* " if name == active else " "
print(f"{prefix}{name}")
except Exception as e:
print(f"Error listing workspaces: {e}")
return
if action == "note":
# Append operator note to active workspace (or specified via --workspace/-w)
active = wm.get_active()
name = active
text_parts = rest or []
i = 0
# Parse optional workspace selector flags before the note text.
while i < len(text_parts):
part = text_parts[i]
if part in ("--workspace", "-w"):
if i + 1 >= len(text_parts):
print("Usage: workspace note [--workspace NAME] <text>")
return
name = text_parts[i + 1]
i += 2
continue
# First non-option token marks the start of the note text
break
if not name:
print("No active workspace. Set one with /workspace <name>.")
return
text = " ".join(text_parts[i:])
if not text:
print("Usage: workspace note [--workspace NAME] <text>")
return
try:
wm.set_operator_note(name, text)
print(f"Operator note saved for workspace '{name}'.")
except Exception as e:
print(f"Error saving note: {e}")
return
if action == "clear":
active = wm.get_active()
if not active:
print("No active workspace.")
return
marker = wm.active_marker()
try:
if marker.exists():
marker.unlink()
print(f"Workspace '{active}' deactivated.")
except Exception as e:
print(f"Error deactivating workspace: {e}")
return
if action == "export":
# export <NAME> [--output file.tar.gz]
if not rest:
print("Usage: workspace export <NAME> [--output file.tar.gz]")
return
name = rest[0]
out = None
if "--output" in rest:
idx = rest.index("--output")
if idx + 1 < len(rest):
out = Path(rest[idx + 1])
try:
archive = export_workspace(name, output=out)
print(f"Workspace exported: {archive}")
except Exception as e:
print(f"Export failed: {e}")
return
if action == "import":
# import <ARCHIVE>
if not rest:
print("Usage: workspace import <archive.tar.gz>")
return
archive = Path(rest[0])
try:
name = import_workspace(archive)
print(f"Workspace imported: {name} (not activated)")
except Exception as e:
print(f"Import failed: {e}")
return
# Default: treat action as workspace name -> create and set active
name = action
try:
existed = wm.workspace_path(name).exists()
if not existed:
wm.create(name)
wm.set_active(name)
# restore last target if present
last = wm.get_meta_field(name, "last_target")
if last:
print(f"Workspace '{name}' set active. Restored target: {last}")
else:
if existed:
print(f"Workspace '{name}' set active.")
else:
print(f"Workspace '{name}' created and set active.")
except WorkspaceError as e:
print(f"Error: {e}")
except Exception as e:
print(f"Error creating workspace: {e}")
def handle_workspaces_list():
from pentestagent.workspaces.manager import WorkspaceManager
wm = WorkspaceManager()
wss = wm.list_workspaces()
active = wm.get_active()
if not wss:
print("No workspaces found.")
return
for name in sorted(wss):
prefix = "* " if name == active else " "
print(f"{prefix}{name}")
def handle_target_command(args: argparse.Namespace):
"""Handle target add/list commands."""
from pentestagent.workspaces.manager import WorkspaceError, WorkspaceManager
wm = WorkspaceManager()
active = wm.get_active()
if not active:
print("No active workspace. Set one with /workspace <name>.")
return
vals = args.values or []
try:
if not vals:
targets = wm.list_targets(active)
if not targets:
print(f"No targets for workspace '{active}'.")
else:
print(f"Targets for workspace '{active}': {targets}")
return
saved = wm.add_targets(active, vals)
print(f"Targets for workspace '{active}': {saved}")
except WorkspaceError as e:
print(f"Error: {e}")
except Exception as e:
print(f"Error updating targets: {e}")
def main():
"""Main entry point."""
parser, args = parse_arguments()
@@ -317,6 +563,16 @@ def main():
handle_mcp_command(args)
return
if args.command == "workspace":
handle_workspace_command(args)
return
# 'workspace list' handled by workspace subcommand
if args.command == "target":
handle_target_command(args)
return
if args.command == "run":
# Check model configuration
if not args.model:

View File

@@ -0,0 +1,40 @@
"""Simple notifier bridge for UI notifications.
Modules can call `notify(level, message)` to emit operator-visible
notifications. A UI (TUI) may register a callback via `register_callback()`
to receive notifications and display them. If no callback is registered,
notifications are logged.
"""
import logging
from typing import Callable, Optional
_callback: Optional[Callable[[str, str], None]] = None
def register_callback(cb: Callable[[str, str], None]) -> None:
"""Register a callback to receive notifications.
Callback receives (level, message).
"""
global _callback
_callback = cb
def notify(level: str, message: str) -> None:
"""Emit a notification. If UI callback registered, call it; otherwise log."""
global _callback
if _callback:
try:
_callback(level, message)
return
except Exception:
logging.getLogger(__name__).exception("Notifier callback failed")
# Fallback to logging
log = logging.getLogger("pentestagent.notifier")
if level.lower() in ("error", "critical"):
log.error(message)
elif level.lower() in ("warn", "warning"):
log.warning(message)
else:
log.info(message)

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, List
from ..workspaces.utils import resolve_knowledge_paths
from .rag import Document
@@ -51,6 +52,11 @@ class KnowledgeIndexer:
total_files = 0
indexed_files = 0
# If directory is the default 'knowledge', prefer workspace knowledge if available
if directory == Path("knowledge"):
kp = resolve_knowledge_paths()
directory = kp.get("sources", Path("knowledge"))
if not directory.exists():
return documents, IndexingResult(
0, 0, 0, [f"Directory not found: {directory}"]

View File

@@ -1,12 +1,14 @@
"""RAG (Retrieval Augmented Generation) engine for PentestAgent."""
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
from ..workspaces.utils import resolve_knowledge_paths
from .embeddings import get_embeddings
@@ -65,9 +67,49 @@ class RAGEngine:
chunks = []
self._source_files = set() # Reset source file tracking
# Resolve knowledge paths (prefer workspace if available)
if self.knowledge_path != Path("knowledge"):
sources_base = self.knowledge_path
kp = None
else:
kp = resolve_knowledge_paths()
sources_base = kp.get("sources", Path("knowledge"))
# If workspace has a persisted index and we're not forcing reindex, try to load it
try:
if kp and kp.get("using_workspace"):
emb_dir = kp.get("embeddings")
emb_dir.mkdir(parents=True, exist_ok=True)
idx_path = emb_dir / "index.pkl"
if idx_path.exists() and not force:
try:
self.load_index(idx_path)
return
except Exception as e:
logging.getLogger(__name__).exception(
"Failed to load persisted RAG index at %s, will re-index: %s",
idx_path,
e,
)
try:
from ..interface.notifier import notify
notify(
"warning",
f"Failed to load persisted RAG index at {idx_path}: {e}",
)
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about RAG load failure: %s", ne
)
except Exception as e:
# Non-fatal — continue to index from sources, but log the error
logging.getLogger(__name__).exception(
"Error while checking for persisted workspace index: %s", e
)
# Process all files in knowledge directory
if self.knowledge_path.exists():
for file in self.knowledge_path.rglob("*"):
if sources_base.exists():
for file in sources_base.rglob("*"):
if not file.is_file():
continue
@@ -107,7 +149,9 @@ class RAGEngine:
)
)
except Exception as e:
print(f"[RAG] Error processing {file}: {e}")
logging.getLogger(__name__).exception(
"[RAG] Error processing %s: %s", file, e
)
self.documents = chunks
@@ -127,6 +171,36 @@ class RAGEngine:
doc.embedding = self.embeddings[i]
self._indexed = True
# If using a workspace, persist the built index for faster future loads
try:
if kp and kp.get("using_workspace") and self.embeddings is not None:
emb_dir = kp.get("embeddings")
emb_dir.mkdir(parents=True, exist_ok=True)
idx_path = emb_dir / "index.pkl"
try:
self.save_index(idx_path)
except Exception as e:
logging.getLogger(__name__).exception(
"Failed to save RAG index to %s: %s", idx_path, e
)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to save RAG index to {idx_path}: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about RAG save failure: %s", ne
)
except Exception as e:
logging.getLogger(__name__).exception(
"Error while attempting to persist RAG index: %s", e
)
try:
from ..interface.notifier import notify
notify("warning", f"Error while attempting to persist RAG index: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about RAG persist error: %s", ne
)
def _chunk_text(
self, text: str, source: str, chunk_size: int = 1000, overlap: int = 200
@@ -408,6 +482,22 @@ class RAGEngine:
with open(path, "wb") as f:
pickle.dump(data, f)
def save_index_to_workspace(self, root: Optional[Path] = None, filename: str = "index.pkl"):
"""
Convenience helper to save the index into the active workspace embeddings path.
Args:
root: Optional project root to resolve workspaces (defaults to cwd)
filename: Filename to use for the saved index
"""
from pathlib import Path as _P
kp = resolve_knowledge_paths(root=root)
emb_dir = kp.get("embeddings")
emb_dir.mkdir(parents=True, exist_ok=True)
path = _P(emb_dir) / filename
self.save_index(path)
def load_index(self, path: Path):
"""
Load the index from disk.
@@ -437,3 +527,20 @@ class RAGEngine:
doc.embedding = self.embeddings[i]
self._indexed = True
def load_index_from_workspace(self, root: Optional[Path] = None, filename: str = "index.pkl"):
"""
Convenience helper to load the index from the active workspace embeddings path.
Args:
root: Optional project root to resolve workspaces (defaults to cwd)
filename: Filename used for the saved index
"""
from pathlib import Path as _P
kp = resolve_knowledge_paths(root=root)
emb_dir = kp.get("embeddings")
path = _P(emb_dir) / filename
if not path.exists():
raise FileNotFoundError(f"Workspace index not found: {path}")
self.load_index(path)

View File

@@ -12,11 +12,10 @@ operates.
import asyncio
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
import signal
import time
from pathlib import Path
from typing import Optional
try:
import aiohttp
@@ -24,9 +23,7 @@ except Exception:
aiohttp = None
LOOT_DIR = Path("loot/artifacts")
LOOT_DIR.mkdir(parents=True, exist_ok=True)
LOG_FILE = LOOT_DIR / "hexstrike.log"
from ..workspaces.utils import get_loot_file
class HexstrikeAdapter:
@@ -97,10 +94,19 @@ class HexstrikeAdapter:
try:
pid = getattr(self._process, "pid", None)
if pid:
with LOG_FILE.open("a") as fh:
log_file = get_loot_file("artifacts/hexstrike.log")
with log_file.open("a") as fh:
fh.write(f"[HexstrikeAdapter] started pid={pid}\n")
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to write hexstrike start PID to log: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to write hexstrike PID to log: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike PID log failure")
# Start a background reader task to capture logs
loop = asyncio.get_running_loop()
@@ -118,18 +124,26 @@ class HexstrikeAdapter:
return
try:
with LOG_FILE.open("ab") as fh:
log_file = get_loot_file("artifacts/hexstrike.log")
with log_file.open("ab") as fh:
while True:
line = await self._process.stdout.readline()
if not line:
break
# Prefix timestamps for easier debugging
fh.write(line)
fh.flush()
except asyncio.CancelledError:
pass
except Exception:
pass
return
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error capturing hexstrike output: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"HexStrike log capture failed: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike log capture failure")
async def stop(self, timeout: int = 5) -> None:
"""Stop the server process gracefully."""
@@ -143,10 +157,26 @@ class HexstrikeAdapter:
except asyncio.TimeoutError:
try:
proc.kill()
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to kill hexstrike after timeout: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to kill hexstrike after timeout: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike kill failure")
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error stopping hexstrike process: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error stopping hexstrike process: {e}")
except Exception:
pass
except Exception:
pass
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike stop error")
self._process = None
@@ -154,8 +184,14 @@ class HexstrikeAdapter:
self._reader_task.cancel()
try:
await self._reader_task
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error awaiting hexstrike reader task: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error awaiting hexstrike reader task: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike reader await failure")
def stop_sync(self, timeout: int = 5) -> None:
"""Synchronous stop helper for use during process-exit cleanup.
@@ -179,7 +215,15 @@ class HexstrikeAdapter:
try:
os.kill(pid, signal.SIGTERM)
except Exception:
pass
import logging
logging.getLogger(__name__).exception("Failed to SIGTERM hexstrike pid: %s", pid)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to SIGTERM hexstrike pid {pid}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike SIGTERM failure")
# wait briefly for process to exit
end = time.time() + float(timeout)
@@ -197,20 +241,52 @@ class HexstrikeAdapter:
try:
os.kill(pid, signal.SIGKILL)
except Exception:
pass
except Exception:
pass
import logging
logging.getLogger(__name__).exception("Failed to SIGKILL hexstrike pid: %s", pid)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to SIGKILL hexstrike pid {pid}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike SIGKILL failure")
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error during hexstrike stop_sync cleanup: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error during hexstrike stop_sync cleanup: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike stop_sync cleanup error")
def __del__(self):
try:
self.stop_sync()
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Exception during HexstrikeAdapter.__del__: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error during HexstrikeAdapter cleanup: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike __del__ error")
# Clear references
try:
self._process = None
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to clear HexstrikeAdapter process reference: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to clear hexstrike process reference: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike process-clear failure")
async def health_check(self, timeout: int = 5) -> bool:
"""Check the server health endpoint. Returns True if healthy."""
@@ -221,7 +297,16 @@ class HexstrikeAdapter:
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=timeout) as resp:
return resp.status == 200
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("HexstrikeAdapter health_check (aiohttp) failed: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"HexStrike health check failed: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike health check failure")
return False
# Fallback: synchronous urllib in thread
@@ -231,7 +316,16 @@ class HexstrikeAdapter:
try:
with urllib.request.urlopen(url, timeout=timeout) as r:
return r.status == 200
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("HexstrikeAdapter health_check (urllib) failed: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"HexStrike health check failed: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about hexstrike urllib health check failure")
return False
loop = asyncio.get_running_loop()

View File

@@ -13,9 +13,9 @@ Uses standard MCP configuration format:
"""
import asyncio
import atexit
import json
import os
import atexit
import signal
from dataclasses import dataclass, field
from pathlib import Path
@@ -78,8 +78,14 @@ class MCPManager:
# Ensure we attempt to clean up vendored servers on process exit
try:
atexit.register(self._atexit_cleanup)
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Failed to register atexit cleanup: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to register MCP atexit cleanup: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about atexit.register failure")
def _find_config(self) -> Path:
for path in self.DEFAULT_CONFIG_PATHS:
@@ -202,8 +208,10 @@ class MCPManager:
try:
stop()
continue
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception(
"Error running adapter.stop(): %s", e
)
# Final fallback: kill underlying PID if available
pid = None
@@ -213,17 +221,18 @@ class MCPManager:
if pid:
try:
os.kill(pid, signal.SIGTERM)
except Exception:
except Exception as e:
logging.getLogger(__name__).exception("Failed to SIGTERM pid %s: %s", pid, e)
try:
os.kill(pid, signal.SIGKILL)
except Exception:
pass
except Exception:
pass
except Exception as e2:
logging.getLogger(__name__).exception("Failed to SIGKILL pid %s: %s", pid, e2)
except Exception as e:
logging.getLogger(__name__).exception("Error while attempting synchronous adapter stop: %s", e)
async def _stop_started_adapters_and_disconnect(self) -> None:
# Stop any adapters we started
for name, adapter in list(self._started_adapters.items()):
for _name, adapter in list(self._started_adapters.items()):
try:
stop = getattr(adapter, "stop", None)
if stop:
@@ -233,15 +242,15 @@ class MCPManager:
# run blocking stop in executor
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, stop)
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Error stopping adapter in async shutdown: %s", e)
self._started_adapters.clear()
# Disconnect any active MCP server connections
try:
await self.disconnect_all()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Error during disconnect_all in shutdown: %s", e)
def add_server(
self,

View File

@@ -10,10 +10,10 @@ health check on a configurable port.
import asyncio
import os
import shutil
import signal
import time
from pathlib import Path
from typing import Optional
import time
import signal
try:
import aiohttp
@@ -21,9 +21,7 @@ except Exception:
aiohttp = None
LOOT_DIR = Path("loot/artifacts")
LOOT_DIR.mkdir(parents=True, exist_ok=True)
LOG_FILE = LOOT_DIR / "metasploit_mcp.log"
from ..workspaces.utils import get_loot_file
class MetasploitAdapter:
@@ -136,14 +134,24 @@ class MetasploitAdapter:
await asyncio.sleep(0.5)
# If we fallthrough, msfrpcd didn't become ready in time
return
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to start msfrpcd: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to start msfrpcd: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd start failure")
return
async def _capture_msfrpcd_output(self) -> None:
if not self._msfrpcd_proc or not self._msfrpcd_proc.stdout:
return
try:
with LOG_FILE.open("ab") as fh:
log_file = get_loot_file("artifacts/msfrpcd.log")
with log_file.open("ab") as fh:
while True:
line = await self._msfrpcd_proc.stdout.readline()
if not line:
@@ -151,9 +159,17 @@ class MetasploitAdapter:
fh.write(b"[msfrpcd] " + line)
fh.flush()
except asyncio.CancelledError:
pass
except Exception:
pass
return
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error capturing msfrpcd output: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"msfrpcd log capture failed: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd log capture failure")
async def start(self, background: bool = True, timeout: int = 30) -> bool:
"""Start the vendored Metasploit MCP server.
@@ -173,8 +189,16 @@ class MetasploitAdapter:
if str(self.transport).lower() in ("http", "sse"):
try:
await self._start_msfrpcd_if_needed()
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error starting msfrpcd: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error starting msfrpcd: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd error")
cmd = self._build_command()
resolved = shutil.which(self.python_cmd) or self.python_cmd
@@ -193,10 +217,19 @@ class MetasploitAdapter:
try:
pid = getattr(self._process, "pid", None)
if pid:
with LOG_FILE.open("a") as fh:
log_file = get_loot_file("artifacts/metasploit_mcp.log")
with log_file.open("a") as fh:
fh.write(f"[MetasploitAdapter] started pid={pid}\n")
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to write metasploit start PID to log: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to write metasploit PID to log: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about metasploit PID log failure")
# Start background reader
loop = asyncio.get_running_loop()
@@ -204,7 +237,16 @@ class MetasploitAdapter:
try:
return await self.health_check(timeout=timeout)
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("MetasploitAdapter health_check raised: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Metasploit health check failed: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about metasploit health check failure")
return False
async def _capture_output(self) -> None:
@@ -212,7 +254,8 @@ class MetasploitAdapter:
return
try:
with LOG_FILE.open("ab") as fh:
log_file = get_loot_file("artifacts/metasploit_mcp.log")
with log_file.open("ab") as fh:
while True:
line = await self._process.stdout.readline()
if not line:
@@ -220,9 +263,17 @@ class MetasploitAdapter:
fh.write(line)
fh.flush()
except asyncio.CancelledError:
pass
except Exception:
pass
return
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error capturing metasploit output: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Metasploit log capture failed: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about metasploit log capture failure")
async def stop(self, timeout: int = 5) -> None:
proc = self._process
@@ -237,8 +288,16 @@ class MetasploitAdapter:
proc.kill()
except Exception:
pass
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error waiting for process termination: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error stopping metasploit adapter: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about metasploit stop error")
self._process = None
@@ -246,8 +305,14 @@ class MetasploitAdapter:
self._reader_task.cancel()
try:
await self._reader_task
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to kill msfrpcd during stop: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to kill msfrpcd: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about msfrpcd kill failure")
# Stop msfrpcd if we started it
try:
@@ -261,8 +326,16 @@ class MetasploitAdapter:
msf_proc.kill()
except Exception:
pass
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error stopping metasploit adapter cleanup: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error stopping metasploit adapter: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about metasploit adapter cleanup error")
finally:
self._msfrpcd_proc = None

View File

@@ -210,7 +210,16 @@ class SSETransport(MCPTransport):
except asyncio.TimeoutError:
# If endpoint not discovered, continue; send() will try discovery
pass
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed opening SSE stream: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed opening SSE stream: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE open failure")
# If opening the SSE stream fails, still mark connected so
# send() can attempt POST discovery and report meaningful errors.
self._sse_response = None
@@ -312,7 +321,16 @@ class SSETransport(MCPTransport):
else:
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}"
return
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error during SSE POST endpoint discovery: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error during SSE POST endpoint discovery: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE discovery error")
return
async def disconnect(self):
@@ -323,21 +341,53 @@ class SSETransport(MCPTransport):
self._sse_task.cancel()
try:
await self._sse_task
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error awaiting SSE listener task during disconnect: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error awaiting SSE listener task during disconnect: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE listener await failure")
self._sse_task = None
except Exception:
pass
import logging
logging.getLogger(__name__).exception("Error cancelling SSE listener task during disconnect")
try:
from ..interface.notifier import notify
notify("warning", "Error cancelling SSE listener task during disconnect")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE listener cancellation error")
try:
if self._sse_response:
try:
await self._sse_response.release()
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Error releasing SSE response during disconnect: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Error releasing SSE response during disconnect: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE response release error")
self._sse_response = None
except Exception:
pass
import logging
logging.getLogger(__name__).exception("Error handling SSE response during disconnect")
try:
from ..interface.notifier import notify
notify("warning", "Error handling SSE response during disconnect")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE response handling error")
# Fail any pending requests
async with self._pending_lock:
@@ -365,17 +415,20 @@ class SSETransport(MCPTransport):
async for raw in resp.content:
try:
line = raw.decode(errors="ignore").rstrip("\r\n")
except Exception:
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to decode SSE raw chunk: %s", e)
continue
if line == "":
# End of event; process accumulated lines
event_name = None
data_lines: list[str] = []
for l in event_lines:
if l.startswith("event:"):
event_name = l.split(":", 1)[1].strip()
elif l.startswith("data:"):
data_lines.append(l.split(":", 1)[1].lstrip())
for evt_line in event_lines:
if evt_line.startswith("event:"):
event_name = evt_line.split(":", 1)[1].strip()
elif evt_line.startswith("data:"):
data_lines.append(evt_line.split(":", 1)[1].lstrip())
if data_lines:
data_text = "\n".join(data_lines)
@@ -392,14 +445,30 @@ class SSETransport(MCPTransport):
self._post_url = f"{p.scheme}://{p.netloc}{endpoint}"
else:
self._post_url = f"{p.scheme}://{p.netloc}/{endpoint.lstrip('/')}"
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed parsing SSE endpoint announcement: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed parsing SSE endpoint announcement: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE endpoint parse failure")
# Notify connect() that endpoint is ready
try:
if self._endpoint_ready and not self._endpoint_ready.is_set():
self._endpoint_ready.set()
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed to set SSE endpoint ready event: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed to set SSE endpoint ready event: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE endpoint ready event failure")
else:
# Try to parse as JSON and resolve pending futures
try:
@@ -410,8 +479,16 @@ class SSETransport(MCPTransport):
fut = self._pending.get(msg_id)
if fut and not fut.done():
fut.set_result(obj)
except Exception:
pass
except Exception as e:
import logging
logging.getLogger(__name__).exception("Failed parsing SSE event JSON or resolving pending future: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Failed parsing SSE event JSON or resolving pending future: {e}")
except Exception:
logging.getLogger(__name__).exception("Failed to notify operator about SSE event parse/future failure")
event_lines = []
else:

View File

@@ -3,6 +3,7 @@
import asyncio
import io
import tarfile
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
@@ -77,7 +78,14 @@ class DockerRuntime(Runtime):
if self.container.status != "running":
self.container.start()
await asyncio.sleep(2) # Wait for container to fully start
except Exception:
except Exception as e:
logging.getLogger(__name__).exception("Failed to get or start existing container: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"DockerRuntime: container check failed: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about docker container check failure: %s", ne)
# Create new container
volumes = {
str(Path.home() / ".pentestagent"): {
@@ -109,8 +117,14 @@ class DockerRuntime(Runtime):
try:
self.container.stop(timeout=10)
self.container.remove()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Failed stopping/removing container: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"DockerRuntime: failed stopping/removing container: {e}")
except Exception as ne:
logging.getLogger(__name__).exception("Failed to notify operator about docker stop error: %s", ne)
finally:
self.container = None
@@ -261,7 +275,17 @@ class DockerRuntime(Runtime):
try:
self.container.reload()
return self.container.status == "running"
except Exception:
except Exception as e:
logging.getLogger(__name__).exception("Failed to determine container running state: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"DockerRuntime: is_running check failed: {e}")
except Exception as notify_error:
logging.getLogger(__name__).warning(
"Failed to send notification for DockerRuntime.is_running error: %s",
notify_error,
)
return False
async def get_status(self) -> dict:

View File

@@ -2,9 +2,9 @@
import platform
import shutil
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
@@ -304,8 +304,8 @@ def detect_environment() -> EnvironmentInfo:
with open("/proc/version", "r") as f:
if "microsoft" in f.read().lower():
os_name = "Linux (WSL)"
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).debug("WSL detection probe failed: %s", e)
# Detect available tools with categories
available_tools = []
@@ -455,11 +455,14 @@ class LocalRuntime(Runtime):
async def start(self):
"""Start the local runtime."""
self._running = True
# Create organized loot directory structure
Path("loot").mkdir(exist_ok=True)
Path("loot/reports").mkdir(exist_ok=True)
Path("loot/artifacts").mkdir(exist_ok=True)
Path("loot/artifacts/screenshots").mkdir(exist_ok=True)
# Create organized loot directory structure (workspace-aware)
from ..workspaces.utils import get_loot_base
base = get_loot_base()
(base).mkdir(parents=True, exist_ok=True)
(base / "reports").mkdir(parents=True, exist_ok=True)
(base / "artifacts").mkdir(parents=True, exist_ok=True)
(base / "artifacts" / "screenshots").mkdir(parents=True, exist_ok=True)
async def stop(self):
"""Stop the local runtime gracefully."""
@@ -476,8 +479,16 @@ class LocalRuntime(Runtime):
proc.stdout.close()
if proc.stderr:
proc.stderr.close()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Error cleaning up active process: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Runtime: error cleaning up process: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about process cleanup error: %s", ne
)
self._active_processes.clear()
# Clean up browser
@@ -490,29 +501,57 @@ class LocalRuntime(Runtime):
if self._page:
try:
await self._page.close()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Failed to close browser page: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Runtime: failed to close browser page: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about browser page close error: %s", ne
)
self._page = None
if self._browser_context:
try:
await self._browser_context.close()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Failed to close browser context: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Runtime: failed to close browser context: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about browser context close error: %s", ne
)
self._browser_context = None
if self._browser:
try:
await self._browser.close()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Failed to close browser: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Runtime: failed to close browser: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about browser close error: %s", ne
)
self._browser = None
if self._playwright:
try:
await self._playwright.stop()
except Exception:
pass
except Exception as e:
logging.getLogger(__name__).exception("Failed to stop playwright: %s", e)
try:
from ..interface.notifier import notify
notify("warning", f"Runtime: failed to stop playwright: {e}")
except Exception as ne:
logging.getLogger(__name__).exception(
"Failed to notify operator about playwright stop error: %s", ne
)
self._playwright = None
async def _ensure_browser(self):
@@ -651,7 +690,6 @@ class LocalRuntime(Runtime):
elif action == "screenshot":
import time
import uuid
from pathlib import Path
# Navigate first if URL provided
if kwargs.get("url"):
@@ -659,9 +697,10 @@ class LocalRuntime(Runtime):
kwargs["url"], timeout=timeout, wait_until="domcontentloaded"
)
# Save screenshot to loot/artifacts/screenshots/
output_dir = Path("loot/artifacts/screenshots")
output_dir.mkdir(parents=True, exist_ok=True)
# Save screenshot to workspace-aware loot/artifacts/screenshots/
from ..workspaces.utils import get_loot_file
output_dir = get_loot_file("artifacts/screenshots").parent
timestamp = int(time.time())
unique_id = uuid.uuid4().hex[:8]

View File

@@ -9,17 +9,27 @@ from ..registry import ToolSchema, register_tool
# Notes storage - kept at loot root for easy access
_notes: Dict[str, Dict[str, Any]] = {}
_notes_file: Path = Path("loot/notes.json")
# Optional override (tests can call set_notes_file)
_custom_notes_file: Path | None = None
# Lock for safe concurrent access from multiple agents (asyncio since agents are async tasks)
_notes_lock = asyncio.Lock()
def _notes_file_path() -> Path:
    """Resolve the notes file path: explicit test override first, else the
    workspace-aware loot location (computed fresh on every call)."""
    from ...workspaces.utils import get_loot_file
    override = _custom_notes_file
    return override if override else get_loot_file("notes.json")
def _load_notes_unlocked() -> None:
"""Load notes from file (caller must hold lock)."""
global _notes
if _notes_file.exists():
nf = _notes_file_path()
if nf.exists():
try:
loaded = json.loads(_notes_file.read_text(encoding="utf-8"))
loaded = json.loads(nf.read_text(encoding="utf-8"))
# Migration: Convert legacy string values to dicts
_notes = {}
for k, v in loaded.items():
@@ -37,8 +47,9 @@ def _load_notes_unlocked() -> None:
def _save_notes_unlocked() -> None:
"""Save notes to file (caller must hold lock)."""
_notes_file.parent.mkdir(parents=True, exist_ok=True)
_notes_file.write_text(json.dumps(_notes, indent=2), encoding="utf-8")
nf = _notes_file_path()
nf.parent.mkdir(parents=True, exist_ok=True)
nf.write_text(json.dumps(_notes, indent=2), encoding="utf-8")
async def get_all_notes() -> Dict[str, Dict[str, Any]]:
@@ -52,9 +63,9 @@ async def get_all_notes() -> Dict[str, Dict[str, Any]]:
def get_all_notes_sync() -> Dict[str, Dict[str, Any]]:
"""Get all notes synchronously (read-only, best effort for prompts)."""
# If notes are empty, try to load from disk (safe read)
if not _notes and _notes_file.exists():
if not _notes and _notes_file_path().exists():
try:
loaded = json.loads(_notes_file.read_text(encoding="utf-8"))
loaded = json.loads(_notes_file_path().read_text(encoding="utf-8"))
# Migration for sync read
result = {}
for k, v in loaded.items():
@@ -74,14 +85,13 @@ def get_all_notes_sync() -> Dict[str, Dict[str, Any]]:
def set_notes_file(path: Path) -> None:
"""Set custom notes file path."""
global _notes_file
_notes_file = path
global _custom_notes_file
_custom_notes_file = Path(path)
# Can't use async here, so load without lock (called at init time)
_load_notes_unlocked()
# Load notes on module import (init time, no contention yet)
_load_notes_unlocked()
# Defer loading until first access to avoid caching active workspace path at import
# Validation schema - declarative rules for note structure

View File

@@ -11,8 +11,8 @@ from datetime import date
from pathlib import Path
from typing import Any, Dict
# Persistent storage (loot root)
_data_file: Path = Path("loot/token_usage.json")
# Persistent storage (loot root) - compute at use to respect active workspace
_custom_data_file: Path | None = None
_data_lock = threading.Lock()
# In-memory cache
@@ -27,9 +27,15 @@ _data: Dict[str, Any] = {
def _load_unlocked() -> None:
global _data
if _data_file.exists():
data_file = _custom_data_file or None
if not data_file:
from ..workspaces.utils import get_loot_file
data_file = get_loot_file("token_usage.json")
if data_file.exists():
try:
loaded = json.loads(_data_file.read_text(encoding="utf-8"))
loaded = json.loads(data_file.read_text(encoding="utf-8"))
# Merge with defaults to be robust to schema changes
d = {**_data, **(loaded or {})}
_data = d
@@ -45,14 +51,20 @@ def _load_unlocked() -> None:
def _save_unlocked() -> None:
_data_file.parent.mkdir(parents=True, exist_ok=True)
_data_file.write_text(json.dumps(_data, indent=2), encoding="utf-8")
data_file = _custom_data_file or None
if not data_file:
from ..workspaces.utils import get_loot_file
data_file = get_loot_file("token_usage.json")
data_file.parent.mkdir(parents=True, exist_ok=True)
data_file.write_text(json.dumps(_data, indent=2), encoding="utf-8")
def set_data_file(path: Path) -> None:
"""Override the data file (used by tests)."""
global _data_file
_data_file = path
global _custom_data_file
_custom_data_file = Path(path)
_load_unlocked()

View File

@@ -0,0 +1,3 @@
from .manager import TargetManager, WorkspaceError, WorkspaceManager
__all__ = ["WorkspaceManager", "TargetManager", "WorkspaceError"]

View File

@@ -0,0 +1,226 @@
"""WorkspaceManager: file-backed workspace and target management using YAML.
Design goals:
- Workspace metadata stored as YAML at workspaces/{name}/meta.yaml
- Active workspace marker stored at workspaces/.active
- No in-memory caching: all operations read/write files directly
- Lightweight hostname validation; accept IPs, CIDRs, hostnames
"""
import ipaddress
import logging
import re
import time
from pathlib import Path
from typing import List
import yaml
class WorkspaceError(Exception):
    """Raised for invalid workspace names/targets or workspace metadata read failures."""
    pass
WORKSPACES_DIR_NAME = "workspaces"
NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
def _safe_mkdir(path: Path):
path.mkdir(parents=True, exist_ok=True)
class TargetManager:
    """Validate and normalize targets (IP, CIDR, hostname).

    Hostname validation is intentionally light: allow letters, digits,
    hyphens, dots.
    """

    HOST_RE = re.compile(r"^[A-Za-z0-9.-]{1,253}$")

    @staticmethod
    def normalize_target(value: str) -> str:
        """Return the canonical form of *value*, or raise WorkspaceError."""
        candidate = value.strip()
        # Interpret as a CIDR network when a prefix separator is present,
        # otherwise as a bare IP address.
        try:
            parsed = (
                ipaddress.ip_network(candidate, strict=False)
                if "/" in candidate
                else ipaddress.ip_address(candidate)
            )
        except Exception:
            parsed = None
        if parsed is not None:
            return str(parsed)
        # Not an IP/CIDR: fall back to (light) hostname validation.
        if ".." not in candidate and TargetManager.HOST_RE.match(candidate):
            return candidate.lower()
        raise WorkspaceError(f"Invalid target: {value}") from None

    @staticmethod
    def validate(value: str) -> bool:
        """True when *value* normalizes cleanly, False otherwise."""
        try:
            TargetManager.normalize_target(value)
        except WorkspaceError:
            return False
        return True
class WorkspaceManager:
    """File-backed workspace manager. No persistent in-memory state.

    All state lives on disk so concurrent processes see a consistent view:
      - workspaces/{name}/meta.yaml  -- per-workspace metadata (YAML)
      - workspaces/.active           -- plain-text name of the active workspace

    Root defaults to current working directory.
    """

    def __init__(self, root: Path = Path(".")):
        self.root = Path(root)
        self.workspaces_dir = self.root / WORKSPACES_DIR_NAME
        _safe_mkdir(self.workspaces_dir)

    @staticmethod
    def _utc_timestamp() -> str:
        """Return an ISO-8601 UTC timestamp ending in 'Z'.

        Bug fix: the original called time.strftime() without time.gmtime(),
        formatting *local* time but still appending the literal 'Z' (UTC)
        designator. Format in UTC so the suffix is truthful.
        """
        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    def validate_name(self, name: str):
        """Raise WorkspaceError unless *name* is a safe single path component."""
        if not NAME_RE.match(name):
            raise WorkspaceError(
                "Invalid workspace name; allowed characters: A-Za-z0-9._- (1-64 chars)"
            )
        # prevent path traversal and slashes (defense in depth: NAME_RE already
        # excludes '/', but a bare '..' would escape workspaces_dir)
        if "/" in name or ".." in name:
            raise WorkspaceError("Invalid workspace name; must not contain '/' or '..'")

    def workspace_path(self, name: str) -> Path:
        """Return workspaces/{name}, validating the name first."""
        self.validate_name(name)
        return self.workspaces_dir / name

    def meta_path(self, name: str) -> Path:
        """Return the meta.yaml path for *name*."""
        return self.workspace_path(name) / "meta.yaml"

    def active_marker(self) -> Path:
        """Return the path of the '.active' marker file."""
        return self.workspaces_dir / ".active"

    def create(self, name: str) -> dict:
        """Create workspace directories (idempotent) and return its metadata."""
        self.validate_name(name)
        p = self.workspace_path(name)
        # create required dirs
        for sub in ("loot", "knowledge/sources", "knowledge/embeddings", "notes", "memory"):
            _safe_mkdir(p / sub)
        # initialize meta if missing
        if not self.meta_path(name).exists():
            meta = {"name": name, "created_at": self._utc_timestamp(), "targets": []}
            self._write_meta(name, meta)
            return meta
        return self._read_meta(name)

    def _read_meta(self, name: str) -> dict:
        """Read meta.yaml for *name*; a missing/empty file yields a default dict.

        Raises:
            WorkspaceError: when the file exists but cannot be read/parsed.
        """
        mp = self.meta_path(name)
        if not mp.exists():
            return {"name": name, "targets": []}
        try:
            data = yaml.safe_load(mp.read_text(encoding="utf-8"))
            if data is None:
                return {"name": name, "targets": []}
            # ensure keys
            data.setdefault("name", name)
            data.setdefault("targets", [])
            return data
        except Exception as e:
            raise WorkspaceError(f"Failed to read meta for {name}: {e}") from e

    def _write_meta(self, name: str, meta: dict):
        """Persist *meta* as YAML atomically (tmp file + rename).

        Robustness fix: the original wrote meta.yaml in place, so a crash
        mid-write could leave a truncated/corrupt file. Path.replace() is an
        atomic rename on POSIX.
        """
        mp = self.meta_path(name)
        mp.parent.mkdir(parents=True, exist_ok=True)
        tmp = mp.with_name(mp.name + ".tmp")
        tmp.write_text(yaml.safe_dump(meta, sort_keys=False), encoding="utf-8")
        tmp.replace(mp)

    def set_active(self, name: str):
        """Mark *name* as the active workspace, creating it if necessary."""
        # ensure workspace exists
        self.create(name)
        marker = self.active_marker()
        marker.write_text(name, encoding="utf-8")
        # update last_active_at in meta.yaml
        try:
            meta = self._read_meta(name)
            meta["last_active_at"] = self._utc_timestamp()
            # ensure operator_notes and tool_runs exist
            meta.setdefault("operator_notes", "")
            meta.setdefault("tool_runs", [])
            self._write_meta(name, meta)
        except Exception as e:
            # Non-fatal - don't block activation on meta write errors, but log for visibility
            logging.getLogger(__name__).exception(
                "Failed to update meta.yaml for workspace '%s': %s", name, e
            )
            try:
                # Emit operator-visible notification if UI present
                from ..interface.notifier import notify
                notify("warning", f"Failed to update workspace meta for '{name}': {e}")
            except Exception:
                # ignore notifier failures
                pass

    def set_operator_note(self, name: str, note: str) -> dict:
        """Append or set operator_notes for a workspace (plain text)."""
        meta = self._read_meta(name)
        prev = meta.get("operator_notes", "") or ""
        meta["operator_notes"] = prev + "\n" + note if prev else note
        self._write_meta(name, meta)
        return meta

    def get_meta_field(self, name: str, field: str):
        """Return a single metadata field (None when absent)."""
        return self._read_meta(name).get(field)

    def get_active(self) -> str:
        """Return the active workspace name, or '' when none is set."""
        marker = self.active_marker()
        if not marker.exists():
            return ""
        return marker.read_text(encoding="utf-8").strip()

    def list_workspaces(self) -> List[str]:
        """Return the names of all workspace directories (unsorted)."""
        if not self.workspaces_dir.exists():
            return []
        return [p.name for p in self.workspaces_dir.iterdir() if p.is_dir()]

    def get_meta(self, name: str) -> dict:
        """Public accessor for a workspace's metadata dict."""
        return self._read_meta(name)

    def add_targets(self, name: str, values: List[str]) -> List[str]:
        """Normalize and add *values* to the workspace's target list.

        Raises WorkspaceError (via TargetManager) on any invalid value;
        nothing is written in that case.
        """
        # read-modify-write for strict file-backed behavior
        meta = self._read_meta(name)
        existing = set(meta.get("targets", []))
        changed = False
        for v in values:
            norm = TargetManager.normalize_target(v)
            if norm not in existing:
                existing.add(norm)
                changed = True
        if changed:
            meta["targets"] = sorted(existing)
            self._write_meta(name, meta)
        return meta.get("targets", [])

    def set_last_target(self, name: str, value: str) -> str:
        """Set the workspace's last used target and ensure it's in the targets list."""
        norm = TargetManager.normalize_target(value)
        meta = self._read_meta(name)
        # ensure targets contains it
        existing = set(meta.get("targets", []))
        if norm not in existing:
            existing.add(norm)
            meta["targets"] = sorted(existing)
        meta["last_target"] = norm
        self._write_meta(name, meta)
        return norm

    def remove_target(self, name: str, value: str) -> List[str]:
        """Remove a (normalized) target from the workspace, if present."""
        meta = self._read_meta(name)
        existing = set(meta.get("targets", []))
        norm = TargetManager.normalize_target(value)
        if norm in existing:
            existing.remove(norm)
            meta["targets"] = sorted(existing)
            self._write_meta(name, meta)
        return meta.get("targets", [])

    def list_targets(self, name: str) -> List[str]:
        """Return the workspace's stored target list."""
        meta = self._read_meta(name)
        return meta.get("targets", [])

View File

@@ -0,0 +1,189 @@
"""Utilities to route loot/output into the active workspace or global loot.
All functions are file-backed and do not cache the active workspace selection.
This module will emit a single warning per run if no active workspace is set.
"""
import logging
import shutil
from pathlib import Path
from typing import Optional
from .manager import WorkspaceManager
_WARNED = False
def get_loot_base(root: Optional[Path] = None) -> Path:
    """Return the base loot directory: workspaces/{active}/loot or top-level `loot/`.

    The directory is created if missing. Emits a single warning per process
    when no workspace is active (gated by the module-level _WARNED flag).
    """
    global _WARNED
    root = Path(root or "./")
    wm = WorkspaceManager(root=root)
    active = wm.get_active()
    if active:
        base = root / "workspaces" / active / "loot"
    else:
        if not _WARNED:
            # Consistency fix: log via the module logger like the rest of
            # this module, instead of the root logger.
            logging.getLogger(__name__).warning(
                "No active workspace — writing loot to global loot/ directory."
            )
            _WARNED = True
        base = root / "loot"
    base.mkdir(parents=True, exist_ok=True)
    return base
def get_loot_file(relpath: str, root: Optional[Path] = None) -> Path:
    """Return a Path for a file under the loot base, creating parent dirs.

    Example: get_loot_file('artifacts/hexstrike.log')
    """
    target = get_loot_base(root=root) / relpath
    target.parent.mkdir(parents=True, exist_ok=True)
    return target
def resolve_knowledge_paths(root: Optional[Path] = None) -> dict:
    """Resolve knowledge-related paths, preferring active workspace if present.

    Returns a dict with keys: base, sources, embeddings, graph, index, using_workspace
    """
    root = Path(root or "./")
    wm = WorkspaceManager(root=root)
    active = wm.get_active()
    global_base = root / "knowledge"
    workspace_base = root / "workspaces" / active / "knowledge" if active else None
    use_workspace = False
    if workspace_base and workspace_base.exists():
        # Prefer the workspace knowledge dir when it holds any content.
        # Note: a `.use_workspace` opt-in marker file itself counts as
        # content here, so the original's separate `elif` marker check was
        # unreachable and has been removed.
        try:
            # Non-recursive check to avoid walking the entire directory tree.
            use_workspace = any(workspace_base.iterdir())
        except Exception as e:
            logging.getLogger(__name__).exception(
                "Error while checking workspace knowledge directory: %s", e
            )
            use_workspace = False
    base = workspace_base if use_workspace else global_base
    return {
        "base": base,
        "sources": base / "sources",
        "embeddings": base / "embeddings",
        "graph": base / "graph",
        "index": base / "index",
        "using_workspace": use_workspace,
    }
def export_workspace(name: str, output: Optional[Path] = None, root: Optional[Path] = None) -> Path:
    """Create a tar.gz archive of workspaces/{name}/ and return the archive path.

    Members are added in sorted order for a stable archive listing. Excludes
    ``__pycache__`` directories and ``*.pyc`` files. Does not mutate the
    workspace.

    Args:
        name: Workspace directory name under ``workspaces/``.
        output: Optional archive destination; defaults to ``{name}-workspace.tar.gz``.
        root: Project root to resolve against (defaults to the current directory).

    Returns:
        Path to the created archive.

    Raises:
        FileNotFoundError: If the workspace directory does not exist.
    """
    import tarfile

    root = Path(root or "./")
    ws_dir = root / "workspaces" / name
    # is_dir() is False for missing paths, so one check covers both cases.
    if not ws_dir.is_dir():
        raise FileNotFoundError(f"Workspace not found: {name}")
    out_path = Path(output) if output else Path(f"{name}-workspace.tar.gz")
    # Deterministic ordering; skip caches and compiled bytecode.
    entries = sorted(
        (
            p.relative_to(root)
            for p in ws_dir.rglob("*")
            if "__pycache__" not in p.parts and p.suffix != ".pyc"
        ),
        key=str,
    )
    with tarfile.open(out_path, "w:gz") as tf:
        for rel in entries:
            # recursive=False: rglob already enumerates every file and
            # directory individually; letting tf.add() recurse into the
            # directories would store each file twice in the archive.
            tf.add(str(root / rel), arcname=str(rel), recursive=False)
    return out_path
def import_workspace(archive: Path, root: Optional[Path] = None) -> str:
    """Import a workspace tar.gz into workspaces/. Returns workspace name.

    Fails if the workspace already exists. Requires a ``meta.yaml`` (with a
    ``name`` field) present in the archive, either under
    ``workspaces/<name>/`` or at the archive root.

    Args:
        archive: Path to the tar.gz archive to import.
        root: Project root to resolve against (defaults to the current directory).

    Returns:
        The imported workspace's name.

    Raises:
        FileNotFoundError: If the archive does not exist.
        ValueError: If no meta.yaml is found or it lacks a 'name' field.
        FileExistsError: If a workspace with that name already exists.
        RuntimeError: If the extracted tree cannot be moved into place.
    """
    import tarfile
    import tempfile

    root = Path(root or "./")
    archive = Path(archive)
    if not archive.exists():
        raise FileNotFoundError(f"Archive not found: {archive}")
    with tempfile.TemporaryDirectory() as td:
        tdpath = Path(td)
        with tarfile.open(archive, "r:gz") as tf:
            # The archive may come from an untrusted source: use the "data"
            # extraction filter (Python 3.12+, backported in security
            # releases) to reject path-traversal members (absolute paths,
            # "..", symlink escapes).
            try:
                tf.extractall(path=tdpath, filter="data")
            except TypeError:
                # Older Python without extraction-filter support.
                tf.extractall(path=tdpath)
        # Look for workspaces/<name>/meta.yaml or meta.yaml at root
        candidates = list(tdpath.rglob("meta.yaml"))
        if not candidates:
            raise ValueError("No meta.yaml found in archive")
        meta_file = candidates[0]
        # read name
        import yaml

        meta = yaml.safe_load(meta_file.read_text(encoding="utf-8"))
        if not meta or not meta.get("name"):
            raise ValueError("meta.yaml missing 'name' field")
        name = meta["name"]
        dest = root / "workspaces" / name
        if dest.exists():
            raise FileExistsError(f"Workspace already exists: {name}")
        # Move extracted tree into place. If the archive stored paths as
        # workspaces/<name>/..., move that subtree; otherwise assume the
        # extracted contents are the workspace folder itself.
        candidate_root = None
        for p in tdpath.iterdir():
            if p.is_dir() and p.name == "workspaces":
                candidate_root = p / name
                break
        if candidate_root and candidate_root.exists():
            # shutil.move supports cross-filesystem moves (copy+delete).
            dest.parent.mkdir(parents=True, exist_ok=True)
            try:
                shutil.move(str(candidate_root), str(dest))
            except Exception as e:
                raise RuntimeError(f"Failed to move workspace subtree into place: {e}") from e
        else:
            # Fall back to moving the directory that contains meta.yaml.
            src = meta_file.parent
            dest.parent.mkdir(parents=True, exist_ok=True)
            try:
                shutil.move(str(src), str(dest))
            except Exception as e:
                raise RuntimeError(f"Failed to move extracted workspace into place: {e}") from e
    return name

View File

@@ -0,0 +1,115 @@
"""Workspace target validation utilities.
Provides helpers to extract candidate targets from arbitrary tool arguments
and to determine whether a candidate target is covered by the allowed
workspace targets (IP, CIDR, hostname).
"""
import ipaddress
import logging
from typing import Any, List
from .manager import TargetManager
def gather_candidate_targets(obj: Any) -> List[str]:
    """
    Extract candidate target strings from arguments (shallow, non-recursive).

    Only the top level of *obj* (a str or dict) is inspected; values found
    under common target-ish keys ('target', 'host', 'ip', 'url', ...) are
    collected. Nested dictionaries and lists are NOT descended into — if you
    need deep extraction, implement a recursive extractor separately.

    Rationale: shallow extraction is fast and predictable for most tool
    argument schemas.
    """
    target_keys = (
        "target",
        "host",
        "hostname",
        "ip",
        "address",
        "url",
        "hosts",
        "targets",
    )
    found: List[str] = []
    if isinstance(obj, str):
        found.append(obj)
    elif isinstance(obj, dict):
        for key, value in obj.items():
            if key.lower() not in target_keys:
                continue
            if isinstance(value, (list, tuple)):
                found.extend(item for item in value if isinstance(item, str))
            elif isinstance(value, str):
                found.append(value)
    return found
def is_target_in_scope(candidate: str, allowed: List[str]) -> bool:
    """Check whether `candidate` is covered by any entry in `allowed`.

    Allowed entries may be IPs, CIDRs, or hostnames/labels. Candidate may
    also be an IP, CIDR, or hostname. The function normalizes inputs and
    performs robust comparisons for networks and addresses.

    Never raises: a candidate that cannot be normalized, or any unexpected
    error, yields False (treated as out of scope).
    """
    try:
        # Normalization is delegated to the project's TargetManager; anything
        # it rejects is immediately out of scope.
        norm = TargetManager.normalize_target(candidate)
    except Exception:
        return False
    # If candidate is a network (contains '/'), treat as network
    try:
        if "/" in norm:
            cand_net = ipaddress.ip_network(norm, strict=False)
            for a in allowed:
                try:
                    if "/" in a:
                        an = ipaddress.ip_network(a, strict=False)
                        # In scope when the candidate network is contained in
                        # (or equal to) an allowed network.
                        if cand_net.subnet_of(an) or cand_net == an:
                            return True
                    else:
                        # allowed is IP or hostname; accept only when candidate
                        # network represents exactly one address equal to allowed IP
                        try:
                            allowed_ip = ipaddress.ip_address(a)
                        except Exception:
                            # not an IP (likely hostname) - skip
                            continue
                        if cand_net.num_addresses == 1 and cand_net.network_address == allowed_ip:
                            return True
                except Exception:
                    # Malformed allowed entry — ignore it and keep scanning.
                    continue
            return False
        else:
            # candidate is a single IP/hostname
            try:
                cand_ip = ipaddress.ip_address(norm)
                for a in allowed:
                    try:
                        if "/" in a:
                            an = ipaddress.ip_network(a, strict=False)
                            # Single IP contained in an allowed network.
                            if cand_ip in an:
                                return True
                        else:
                            # Compare after normalizing the allowed entry the
                            # same way the candidate was normalized.
                            if TargetManager.normalize_target(a) == norm:
                                return True
                    except Exception:
                        # Allowed entry failed normalization — fall back to a
                        # plain case-insensitive string comparison.
                        if isinstance(a, str) and a.lower() == norm.lower():
                            return True
                return False
            except Exception:
                # candidate is likely a hostname; compare case-insensitively
                for a in allowed:
                    try:
                        if a.lower() == norm.lower():
                            return True
                    except Exception:
                        # Non-string allowed entry — skip it.
                        continue
                return False
    except Exception as e:
        # Defensive catch-all: scope checks must never crash the caller.
        logging.getLogger(__name__).exception("Error checking target scope: %s", e)
        return False

View File

@@ -126,6 +126,8 @@ known_first_party = ["pentestagent"]
line-length = 88
target-version = "py310"
exclude = ["third_party/"]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors

View File

@@ -36,7 +36,7 @@ typer>=0.12.0
pydantic>=2.7.0
pydantic-settings>=2.2.0
python-dotenv>=1.0.0
pyyaml>=6.0.0
PyYAML>=6.0
jinja2>=3.1.0
# Dev

View File

@@ -1,110 +0,0 @@
"""Tests for the agent state management."""
import pytest
from pentestagent.agents.state import AgentState, AgentStateManager, StateTransition
class TestAgentState:
"""Tests for AgentState enum."""
def test_state_values(self):
"""Test state enum values."""
assert AgentState.IDLE.value == "idle"
assert AgentState.THINKING.value == "thinking"
assert AgentState.EXECUTING.value == "executing"
assert AgentState.WAITING_INPUT.value == "waiting_input"
assert AgentState.COMPLETE.value == "complete"
assert AgentState.ERROR.value == "error"
def test_all_states_exist(self):
"""Test that all expected states exist."""
states = list(AgentState)
assert len(states) >= 6
class TestAgentStateManager:
"""Tests for AgentStateManager class."""
@pytest.fixture
def state_manager(self):
"""Create a fresh AgentStateManager for each test."""
return AgentStateManager()
def test_initial_state(self, state_manager):
"""Test initial state is IDLE."""
assert state_manager.current_state == AgentState.IDLE
assert len(state_manager.history) == 0
def test_valid_transition(self, state_manager):
"""Test valid state transition."""
result = state_manager.transition_to(AgentState.THINKING)
assert result is True
assert state_manager.current_state == AgentState.THINKING
assert len(state_manager.history) == 1
def test_invalid_transition(self, state_manager):
"""Test invalid state transition."""
result = state_manager.transition_to(AgentState.COMPLETE)
assert result is False
assert state_manager.current_state == AgentState.IDLE
def test_transition_chain(self, state_manager):
"""Test a chain of valid transitions."""
assert state_manager.transition_to(AgentState.THINKING)
assert state_manager.transition_to(AgentState.EXECUTING)
assert state_manager.transition_to(AgentState.THINKING)
assert state_manager.transition_to(AgentState.COMPLETE)
assert state_manager.current_state == AgentState.COMPLETE
assert len(state_manager.history) == 4
def test_force_transition(self, state_manager):
"""Test forcing a transition."""
state_manager.force_transition(AgentState.ERROR, reason="Test error")
assert state_manager.current_state == AgentState.ERROR
assert "FORCED" in state_manager.history[-1].reason
def test_reset(self, state_manager):
"""Test resetting state."""
state_manager.transition_to(AgentState.THINKING)
state_manager.transition_to(AgentState.EXECUTING)
state_manager.reset()
assert state_manager.current_state == AgentState.IDLE
assert len(state_manager.history) == 0
def test_is_terminal(self, state_manager):
"""Test terminal state detection."""
assert state_manager.is_terminal() is False
state_manager.transition_to(AgentState.THINKING)
state_manager.transition_to(AgentState.COMPLETE)
assert state_manager.is_terminal() is True
def test_is_active(self, state_manager):
"""Test active state detection."""
assert state_manager.is_active() is False
state_manager.transition_to(AgentState.THINKING)
assert state_manager.is_active() is True
class TestStateTransition:
"""Tests for StateTransition dataclass."""
def test_create_transition(self):
"""Test creating a state transition."""
transition = StateTransition(
from_state=AgentState.IDLE,
to_state=AgentState.THINKING,
reason="Starting work"
)
assert transition.from_state == AgentState.IDLE
assert transition.to_state == AgentState.THINKING
assert transition.reason == "Starting work"
assert transition.timestamp is not None

View File

@@ -1,251 +0,0 @@
"""Tests for the Shadow Graph knowledge system."""
import networkx as nx
import pytest
from pentestagent.knowledge.graph import ShadowGraph
class TestShadowGraph:
"""Tests for ShadowGraph class."""
@pytest.fixture
def graph(self):
"""Create a fresh ShadowGraph for each test."""
return ShadowGraph()
def test_initialization(self, graph):
"""Test graph initialization."""
assert isinstance(graph.graph, nx.DiGraph)
assert len(graph.graph.nodes) == 0
assert len(graph._processed_notes) == 0
def test_extract_host_from_note(self, graph):
"""Test extracting host IP from a note."""
notes = {
"scan_result": {
"content": "Nmap scan for 192.168.1.10 shows open ports.",
"category": "info"
}
}
graph.update_from_notes(notes)
assert graph.graph.has_node("host:192.168.1.10")
node = graph.graph.nodes["host:192.168.1.10"]
assert node["type"] == "host"
assert node["label"] == "192.168.1.10"
def test_extract_service_finding(self, graph):
"""Test extracting services from a finding note."""
notes = {
"ports_scan": {
"content": "Found open ports: 80/tcp, 443/tcp on 10.0.0.5",
"category": "finding",
"metadata": {
"target": "10.0.0.5",
"services": [
{"port": 80, "protocol": "tcp", "service": "http"},
{"port": 443, "protocol": "tcp", "service": "https"}
]
}
}
}
graph.update_from_notes(notes)
# Check host exists
assert graph.graph.has_node("host:10.0.0.5")
# Check services exist
assert graph.graph.has_node("service:host:10.0.0.5:80")
assert graph.graph.has_node("service:host:10.0.0.5:443")
# Check edges
assert graph.graph.has_edge("host:10.0.0.5", "service:host:10.0.0.5:80")
edge = graph.graph.edges["host:10.0.0.5", "service:host:10.0.0.5:80"]
assert edge["type"] == "HAS_SERVICE"
assert edge["protocol"] == "tcp"
def test_extract_credential(self, graph):
"""Test extracting credentials and linking to host."""
notes = {
"ssh_creds": {
"content": "Found user: admin with password 'password123' for SSH on 192.168.1.20",
"category": "credential",
"metadata": {
"target": "192.168.1.20",
"username": "admin",
"password": "password123",
"protocol": "ssh"
}
}
}
graph.update_from_notes(notes)
cred_id = "cred:ssh_creds"
host_id = "host:192.168.1.20"
assert graph.graph.has_node(cred_id)
assert graph.graph.has_node(host_id)
# Check edge
assert graph.graph.has_edge(cred_id, host_id)
edge = graph.graph.edges[cred_id, host_id]
assert edge["type"] == "AUTH_ACCESS"
assert edge["protocol"] == "ssh"
def test_extract_credential_variations(self, graph):
"""Test different credential formats."""
notes = {
"creds_1": {
"content": "Username: root, Password: toor",
"category": "credential"
},
"creds_2": {
"content": "Just a password: secret",
"category": "credential"
}
}
graph.update_from_notes(notes)
# Check "Username: root" extraction
node1 = graph.graph.nodes["cred:creds_1"]
assert node1["label"] == "Creds (root)"
# Check fallback for no username
node2 = graph.graph.nodes["cred:creds_2"]
assert node2["label"] == "Credentials"
def test_metadata_extraction(self, graph):
"""Test extracting entities from structured metadata."""
notes = {
"meta_cred": {
"content": "Some random text",
"category": "credential",
"metadata": {
"username": "admin_meta",
"target": "10.0.0.99",
"source": "10.0.0.1"
}
},
"meta_vuln": {
"content": "Bad stuff",
"category": "vulnerability",
"metadata": {
"cve": "CVE-2025-1234",
"target": "10.0.0.99"
}
}
}
graph.update_from_notes(notes)
# Check Credential Metadata
cred_node = graph.graph.nodes["cred:meta_cred"]
assert cred_node["label"] == "Creds (admin_meta)"
# Check Target Host
assert graph.graph.has_node("host:10.0.0.99")
assert graph.graph.has_edge("cred:meta_cred", "host:10.0.0.99")
# Check Source Host (CONTAINS edge)
assert graph.graph.has_node("host:10.0.0.1")
assert graph.graph.has_edge("host:10.0.0.1", "cred:meta_cred")
# Check Vulnerability Metadata
vuln_node = graph.graph.nodes["vuln:meta_vuln"]
assert vuln_node["label"] == "CVE-2025-1234"
assert graph.graph.has_edge("host:10.0.0.99", "vuln:meta_vuln")
def test_url_metadata(self, graph):
"""Test that URL metadata is added to service labels."""
notes = {
"web_app": {
"content": "Admin panel found",
"category": "finding",
"metadata": {
"target": "10.0.0.5",
"port": "80/tcp",
"url": "http://10.0.0.5/admin"
}
}
}
graph.update_from_notes(notes)
service_id = "service:host:10.0.0.5:80"
assert graph.graph.has_node(service_id)
node = graph.graph.nodes[service_id]
assert "http://10.0.0.5/admin" in node["label"]
def test_legacy_note_format(self, graph):
"""Test handling legacy string-only notes."""
notes = {
"legacy_note": "Just a simple note about 10.10.10.10"
}
graph.update_from_notes(notes)
assert graph.graph.has_node("host:10.10.10.10")
def test_idempotency(self, graph):
"""Test that processing the same note twice doesn't duplicate or error."""
notes = {
"scan": {
"content": "Host 192.168.1.1 is up.",
"category": "info"
}
}
# First pass
graph.update_from_notes(notes)
assert len(graph.graph.nodes) == 1
# Second pass
graph.update_from_notes(notes)
assert len(graph.graph.nodes) == 1
def test_attack_paths(self, graph):
"""Test detection of multi-step attack paths."""
# Manually construct a path: Cred1 -> HostA -> Cred2 -> HostB
# 1. Cred1 gives access to HostA
graph._add_node("cred:1", "credential", "Root Creds")
graph._add_node("host:A", "host", "10.0.0.1")
graph._add_edge("cred:1", "host:A", "AUTH_ACCESS")
# 2. HostA has Cred2 (this edge type isn't auto-extracted yet, but logic should handle it)
graph._add_node("cred:2", "credential", "Db Admin")
graph._add_edge("host:A", "cred:2", "CONTAINS_CRED")
# 3. Cred2 gives access to HostB
graph._add_node("host:B", "host", "10.0.0.2")
graph._add_edge("cred:2", "host:B", "AUTH_ACCESS")
paths = graph._find_attack_paths()
assert len(paths) == 1
assert "Root Creds" in paths[0]
assert "10.0.0.1" in paths[0]
assert "Db Admin" in paths[0]
assert "10.0.0.2" in paths[0]
def test_mermaid_export(self, graph):
"""Test Mermaid diagram generation."""
graph._add_node("host:1", "host", "10.0.0.1")
graph._add_node("cred:1", "credential", "admin")
graph._add_edge("cred:1", "host:1", "AUTH_ACCESS")
mermaid = graph.to_mermaid()
assert "graph TD" in mermaid
assert 'host_1["🖥️ 10.0.0.1"]' in mermaid
assert 'cred_1["🔑 admin"]' in mermaid
assert "cred_1 -->|AUTH_ACCESS| host_1" in mermaid
def test_multiple_ips_in_one_note(self, graph):
"""Test a single note referencing multiple hosts."""
notes = {
"subnet_scan": {
"content": "Scanning 192.168.1.1, 192.168.1.2, and 192.168.1.3",
"category": "info"
}
}
graph.update_from_notes(notes)
assert graph.graph.has_node("host:192.168.1.1")
assert graph.graph.has_node("host:192.168.1.2")
assert graph.graph.has_node("host:192.168.1.3")

View File

@@ -1,99 +0,0 @@
"""Tests for the RAG knowledge system."""
import numpy as np
import pytest
from pentestagent.knowledge.rag import Document, RAGEngine
class TestDocument:
"""Tests for Document dataclass."""
def test_create_document(self):
"""Test creating a document."""
doc = Document(content="Test content", source="test.md")
assert doc.content == "Test content"
assert doc.source == "test.md"
assert doc.metadata == {}
assert doc.doc_id is not None
def test_document_with_metadata(self):
"""Test document with metadata."""
doc = Document(
content="Test",
source="test.md",
metadata={"cve_id": "CVE-2021-1234", "severity": "high"}
)
assert doc.metadata["cve_id"] == "CVE-2021-1234"
assert doc.metadata["severity"] == "high"
def test_document_with_embedding(self):
"""Test document with embedding."""
embedding = np.random.rand(384)
doc = Document(content="Test", source="test.md", embedding=embedding)
assert doc.embedding is not None
assert len(doc.embedding) == 384
def test_document_with_custom_id(self):
"""Test document with custom doc_id."""
doc = Document(content="Test", source="test.md", doc_id="custom-id-123")
assert doc.doc_id == "custom-id-123"
class TestRAGEngine:
"""Tests for RAGEngine class."""
@pytest.fixture
def rag_engine(self, tmp_path):
"""Create a RAG engine for testing."""
return RAGEngine(
knowledge_path=tmp_path / "knowledge",
use_local_embeddings=True
)
def test_create_engine(self, rag_engine):
"""Test creating a RAG engine."""
assert rag_engine is not None
assert len(rag_engine.documents) == 0
assert rag_engine.embeddings is None
def test_get_document_count_empty(self, rag_engine):
"""Test document count on empty engine."""
assert rag_engine.get_document_count() == 0
def test_clear(self, rag_engine):
"""Test clearing the engine."""
rag_engine.documents.append(Document(content="test", source="test.md"))
rag_engine.embeddings = np.random.rand(1, 384)
rag_engine._indexed = True
rag_engine.clear()
assert len(rag_engine.documents) == 0
assert rag_engine.embeddings is None
assert not rag_engine._indexed
class TestRAGEngineChunking:
"""Tests for text chunking functionality."""
@pytest.fixture
def engine(self, tmp_path):
"""Create engine for chunking tests."""
return RAGEngine(knowledge_path=tmp_path)
def test_chunk_short_text(self, engine):
"""Test chunking text shorter than chunk size."""
text = "This is a short paragraph.\n\nThis is another paragraph."
chunks = engine._chunk_text(text, source="test.md", chunk_size=1000)
assert len(chunks) >= 1
assert all(isinstance(c, Document) for c in chunks)
def test_chunk_preserves_source(self, engine):
"""Test that chunking preserves source information."""
text = "Test paragraph 1.\n\nTest paragraph 2."
chunks = engine._chunk_text(text, source="my_source.md")
assert all(c.source == "my_source.md" for c in chunks)

View File

@@ -1,159 +0,0 @@
"""Tests for the Notes tool."""
import json
import pytest
from pentestagent.tools.notes import _notes, get_all_notes, notes, set_notes_file
# We need to reset the global state for tests
@pytest.fixture(autouse=True)
def reset_notes_state(tmp_path):
"""Reset the notes global state for each test."""
# Point to a temp file
temp_notes_file = tmp_path / "notes.json"
set_notes_file(temp_notes_file)
# Clear the global dictionary (it's imported from the module)
# We need to clear the actual dictionary object in the module
_notes.clear()
yield
# Cleanup is handled by tmp_path
@pytest.mark.asyncio
async def test_create_note():
"""Test creating a new note."""
args = {
"action": "create",
"key": "test_note",
"value": "This is a test note",
"category": "info",
"confidence": "high"
}
result = await notes(args, runtime=None)
assert "Created note 'test_note'" in result
all_notes = await get_all_notes()
assert "test_note" in all_notes
assert all_notes["test_note"]["content"] == "This is a test note"
assert all_notes["test_note"]["category"] == "info"
assert all_notes["test_note"]["confidence"] == "high"
@pytest.mark.asyncio
async def test_read_note():
"""Test reading an existing note."""
# Create first
await notes({
"action": "create",
"key": "read_me",
"value": "Content to read"
}, runtime=None)
# Read
result = await notes({
"action": "read",
"key": "read_me"
}, runtime=None)
assert "Content to read" in result
# The format is "[key] (category, confidence, status) content"
assert "(info, medium, confirmed)" in result
@pytest.mark.asyncio
async def test_update_note():
"""Test updating a note."""
await notes({
"action": "create",
"key": "update_me",
"value": "Original content"
}, runtime=None)
result = await notes({
"action": "update",
"key": "update_me",
"value": "New content"
}, runtime=None)
assert "Updated note 'update_me'" in result
all_notes = await get_all_notes()
assert all_notes["update_me"]["content"] == "New content"
@pytest.mark.asyncio
async def test_delete_note():
"""Test deleting a note."""
await notes({
"action": "create",
"key": "delete_me",
"value": "Bye bye"
}, runtime=None)
result = await notes({
"action": "delete",
"key": "delete_me"
}, runtime=None)
assert "Deleted note 'delete_me'" in result
all_notes = await get_all_notes()
assert "delete_me" not in all_notes
@pytest.mark.asyncio
async def test_list_notes():
"""Test listing all notes."""
await notes({"action": "create", "key": "n1", "value": "v1"}, runtime=None)
await notes({"action": "create", "key": "n2", "value": "v2"}, runtime=None)
result = await notes({"action": "list"}, runtime=None)
assert "n1" in result
assert "n2" in result
assert "Notes (2 entries):" in result
@pytest.mark.asyncio
async def test_persistence(tmp_path):
"""Test that notes are saved to disk."""
# The fixture already sets a temp file
temp_file = tmp_path / "notes.json"
await notes({
"action": "create",
"key": "persistent_note",
"value": "I survive restarts"
}, runtime=None)
assert temp_file.exists()
content = json.loads(temp_file.read_text())
assert "persistent_note" in content
assert content["persistent_note"]["content"] == "I survive restarts"
@pytest.mark.asyncio
async def test_legacy_migration(tmp_path):
"""Test migration of legacy string notes."""
# Create a legacy file
legacy_file = tmp_path / "legacy_notes.json"
legacy_data = {
"old_note": "Just a string",
"new_note": {"content": "A dict", "category": "info"}
}
legacy_file.write_text(json.dumps(legacy_data))
# Point the tool to this file
set_notes_file(legacy_file)
# Trigger load (get_all_notes calls _load_notes_unlocked if empty, but we need to clear first)
_notes.clear()
all_notes = await get_all_notes()
assert "old_note" in all_notes
assert isinstance(all_notes["old_note"], dict)
assert all_notes["old_note"]["content"] == "Just a string"
assert all_notes["old_note"]["category"] == "info"
assert "new_note" in all_notes
assert all_notes["new_note"]["content"] == "A dict"

View File

@@ -1,146 +0,0 @@
"""Tests for the tool system."""
import pytest
from pentestagent.tools import (
ToolSchema,
disable_tool,
enable_tool,
get_all_tools,
get_tool,
get_tool_names,
register_tool,
)
class TestToolRegistry:
"""Tests for tool registry functions."""
def test_tools_loaded(self):
"""Test that built-in tools are loaded."""
tools = get_all_tools()
assert len(tools) > 0
tool_names = get_tool_names()
assert "terminal" in tool_names
assert "browser" in tool_names
def test_get_tool(self):
"""Test getting a tool by name."""
tool = get_tool("terminal")
assert tool is not None
assert tool.name == "terminal"
assert tool.category == "execution"
def test_get_nonexistent_tool(self):
"""Test getting a tool that doesn't exist."""
tool = get_tool("nonexistent_tool_xyz")
assert tool is None
def test_disable_enable_tool(self):
"""Test disabling and enabling a tool."""
result = disable_tool("terminal")
assert result is True
tool = get_tool("terminal")
assert tool.enabled is False
result = enable_tool("terminal")
assert result is True
tool = get_tool("terminal")
assert tool.enabled is True
def test_disable_nonexistent_tool(self):
"""Test disabling a tool that doesn't exist."""
result = disable_tool("nonexistent_tool_xyz")
assert result is False
class TestToolSchema:
"""Tests for ToolSchema class."""
def test_create_schema(self):
"""Test creating a tool schema."""
schema = ToolSchema(
properties={
"command": {"type": "string", "description": "Command to run"}
},
required=["command"]
)
assert schema.type == "object"
assert "command" in schema.properties
assert "command" in schema.required
def test_schema_to_dict(self):
"""Test converting schema to dictionary."""
schema = ToolSchema(
properties={"input": {"type": "string"}},
required=["input"]
)
d = schema.to_dict()
assert d["type"] == "object"
assert d["properties"]["input"]["type"] == "string"
assert d["required"] == ["input"]
class TestTool:
"""Tests for Tool class."""
def test_create_tool(self, sample_tool):
"""Test creating a tool."""
assert sample_tool.name == "test_tool"
assert sample_tool.description == "A test tool"
assert sample_tool.category == "test"
assert sample_tool.enabled is True
def test_tool_to_llm_format(self, sample_tool):
"""Test converting tool to LLM format."""
llm_format = sample_tool.to_llm_format()
assert llm_format["type"] == "function"
assert llm_format["function"]["name"] == "test_tool"
assert llm_format["function"]["description"] == "A test tool"
assert "parameters" in llm_format["function"]
def test_tool_validate_arguments(self, sample_tool):
"""Test argument validation."""
is_valid, error = sample_tool.validate_arguments({"param": "value"})
assert is_valid is True
assert error is None
is_valid, error = sample_tool.validate_arguments({})
assert is_valid is False
assert "param" in error
@pytest.mark.asyncio
async def test_tool_execute(self, sample_tool):
"""Test tool execution."""
result = await sample_tool.execute({"param": "test"}, runtime=None)
assert "test" in result
class TestRegisterToolDecorator:
"""Tests for register_tool decorator."""
def test_decorator_registers_tool(self):
"""Test that decorator registers a new tool."""
initial_count = len(get_all_tools())
@register_tool(
name="pytest_test_tool_unique",
description="A tool registered in tests",
schema=ToolSchema(properties={}, required=[]),
category="test"
)
async def pytest_test_tool_unique(arguments, runtime):
return "test result"
new_count = len(get_all_tools())
assert new_count == initial_count + 1
tool = get_tool("pytest_test_tool_unique")
assert tool is not None
assert tool.name == "pytest_test_tool_unique"

View File

@@ -17,17 +17,17 @@ Architecture: MCP Client for AI agent communication with HexStrike server
Framework: FastMCP integration for tool orchestration
"""
import sys
import os
import argparse
import logging
from typing import Dict, Any, Optional
import requests
import sys
import time
from datetime import datetime
from typing import Any, Dict, Optional
import requests
from mcp.server.fastmcp import FastMCP
class HexStrikeColors:
"""Enhanced color palette matching the server's ModernVisualEngine.COLORS"""
@@ -447,9 +447,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"☁️ Starting Prowler {provider} security assessment")
result = hexstrike_client.safe_post("api/tools/prowler", data)
if result.get("success"):
logger.info(f"✅ Prowler assessment completed")
logger.info("✅ Prowler assessment completed")
else:
logger.error(f"❌ Prowler assessment failed")
logger.error("❌ Prowler assessment failed")
return result
@mcp.tool()
@@ -517,9 +517,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"☁️ Starting Scout Suite {provider} assessment")
result = hexstrike_client.safe_post("api/tools/scout-suite", data)
if result.get("success"):
logger.info(f"✅ Scout Suite assessment completed")
logger.info("✅ Scout Suite assessment completed")
else:
logger.error(f"❌ Scout Suite assessment failed")
logger.error("❌ Scout Suite assessment failed")
return result
@mcp.tool()
@@ -575,12 +575,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"regions": regions,
"additional_args": additional_args
}
logger.info(f"☁️ Starting Pacu AWS exploitation")
logger.info("☁️ Starting Pacu AWS exploitation")
result = hexstrike_client.safe_post("api/tools/pacu", data)
if result.get("success"):
logger.info(f"✅ Pacu exploitation completed")
logger.info("✅ Pacu exploitation completed")
else:
logger.error(f"❌ Pacu exploitation failed")
logger.error("❌ Pacu exploitation failed")
return result
@mcp.tool()
@@ -611,12 +611,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"report": report,
"additional_args": additional_args
}
logger.info(f"☁️ Starting kube-hunter Kubernetes scan")
logger.info("☁️ Starting kube-hunter Kubernetes scan")
result = hexstrike_client.safe_post("api/tools/kube-hunter", data)
if result.get("success"):
logger.info(f"✅ kube-hunter scan completed")
logger.info("✅ kube-hunter scan completed")
else:
logger.error(f"❌ kube-hunter scan failed")
logger.error("❌ kube-hunter scan failed")
return result
@mcp.tool()
@@ -642,12 +642,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"output_format": output_format,
"additional_args": additional_args
}
logger.info(f"☁️ Starting kube-bench CIS benchmark")
logger.info("☁️ Starting kube-bench CIS benchmark")
result = hexstrike_client.safe_post("api/tools/kube-bench", data)
if result.get("success"):
logger.info(f"✅ kube-bench benchmark completed")
logger.info("✅ kube-bench benchmark completed")
else:
logger.error(f"❌ kube-bench benchmark failed")
logger.error("❌ kube-bench benchmark failed")
return result
@mcp.tool()
@@ -672,12 +672,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"output_file": output_file,
"additional_args": additional_args
}
logger.info(f"🐳 Starting Docker Bench Security assessment")
logger.info("🐳 Starting Docker Bench Security assessment")
result = hexstrike_client.safe_post("api/tools/docker-bench-security", data)
if result.get("success"):
logger.info(f"✅ Docker Bench Security completed")
logger.info("✅ Docker Bench Security completed")
else:
logger.error(f"❌ Docker Bench Security failed")
logger.error("❌ Docker Bench Security failed")
return result
@mcp.tool()
@@ -736,9 +736,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🛡️ Starting Falco runtime monitoring for {duration}s")
result = hexstrike_client.safe_post("api/tools/falco", data)
if result.get("success"):
logger.info(f"✅ Falco monitoring completed")
logger.info("✅ Falco monitoring completed")
else:
logger.error(f"❌ Falco monitoring failed")
logger.error("❌ Falco monitoring failed")
return result
@mcp.tool()
@@ -770,9 +770,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔍 Starting Checkov IaC scan: {directory}")
result = hexstrike_client.safe_post("api/tools/checkov", data)
if result.get("success"):
logger.info(f"✅ Checkov scan completed")
logger.info("✅ Checkov scan completed")
else:
logger.error(f"❌ Checkov scan failed")
logger.error("❌ Checkov scan failed")
return result
@mcp.tool()
@@ -804,9 +804,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔍 Starting Terrascan IaC scan: {iac_dir}")
result = hexstrike_client.safe_post("api/tools/terrascan", data)
if result.get("success"):
logger.info(f"✅ Terrascan scan completed")
logger.info("✅ Terrascan scan completed")
else:
logger.error(f"❌ Terrascan scan failed")
logger.error("❌ Terrascan scan failed")
return result
# ============================================================================
@@ -932,9 +932,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🎯 Generating {payload_type} payload: {size} bytes")
result = hexstrike_client.safe_post("api/payloads/generate", data)
if result.get("success"):
logger.info(f"✅ Payload generated successfully")
logger.info("✅ Payload generated successfully")
else:
logger.error(f"❌ Failed to generate payload")
logger.error("❌ Failed to generate payload")
return result
# ============================================================================
@@ -988,9 +988,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🐍 Executing Python script in env {env_name}")
result = hexstrike_client.safe_post("api/python/execute", data)
if result.get("success"):
logger.info(f"✅ Python script executed successfully")
logger.info("✅ Python script executed successfully")
else:
logger.error(f"❌ Python script execution failed")
logger.error("❌ Python script execution failed")
return result
# ============================================================================
@@ -1167,9 +1167,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔐 Starting John the Ripper: {hash_file}")
result = hexstrike_client.safe_post("api/tools/john", data)
if result.get("success"):
logger.info(f"✅ John the Ripper completed")
logger.info("✅ John the Ripper completed")
else:
logger.error(f"❌ John the Ripper failed")
logger.error("❌ John the Ripper failed")
return result
@mcp.tool()
@@ -1337,9 +1337,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔐 Starting Hashcat attack: mode {attack_mode}")
result = hexstrike_client.safe_post("api/tools/hashcat", data)
if result.get("success"):
logger.info(f"✅ Hashcat attack completed")
logger.info("✅ Hashcat attack completed")
else:
logger.error(f"❌ Hashcat attack failed")
logger.error("❌ Hashcat attack failed")
return result
@mcp.tool()
@@ -1690,9 +1690,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔍 Starting arp-scan: {target if target else 'local network'}")
result = hexstrike_client.safe_post("api/tools/arp-scan", data)
if result.get("success"):
logger.info(f"✅ arp-scan completed")
logger.info("✅ arp-scan completed")
else:
logger.error(f"❌ arp-scan failed")
logger.error("❌ arp-scan failed")
return result
@mcp.tool()
@@ -1727,9 +1727,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔍 Starting Responder on interface: {interface}")
result = hexstrike_client.safe_post("api/tools/responder", data)
if result.get("success"):
logger.info(f"✅ Responder completed")
logger.info("✅ Responder completed")
else:
logger.error(f"❌ Responder failed")
logger.error("❌ Responder failed")
return result
@mcp.tool()
@@ -1755,9 +1755,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🧠 Starting Volatility analysis: {plugin}")
result = hexstrike_client.safe_post("api/tools/volatility", data)
if result.get("success"):
logger.info(f"✅ Volatility analysis completed")
logger.info("✅ Volatility analysis completed")
else:
logger.error(f"❌ Volatility analysis failed")
logger.error("❌ Volatility analysis failed")
return result
@mcp.tool()
@@ -1787,9 +1787,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🚀 Starting MSFVenom payload generation: {payload}")
result = hexstrike_client.safe_post("api/tools/msfvenom", data)
if result.get("success"):
logger.info(f"✅ MSFVenom payload generated")
logger.info("✅ MSFVenom payload generated")
else:
logger.error(f"❌ MSFVenom payload generation failed")
logger.error("❌ MSFVenom payload generation failed")
return result
# ============================================================================
@@ -2071,9 +2071,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔧 Starting Pwntools exploit: {exploit_type}")
result = hexstrike_client.safe_post("api/tools/pwntools", data)
if result.get("success"):
logger.info(f"✅ Pwntools exploit completed")
logger.info("✅ Pwntools exploit completed")
else:
logger.error(f"❌ Pwntools exploit failed")
logger.error("❌ Pwntools exploit failed")
return result
@mcp.tool()
@@ -2097,9 +2097,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔧 Starting one_gadget analysis: {libc_path}")
result = hexstrike_client.safe_post("api/tools/one-gadget", data)
if result.get("success"):
logger.info(f"✅ one_gadget analysis completed")
logger.info("✅ one_gadget analysis completed")
else:
logger.error(f"❌ one_gadget analysis failed")
logger.error("❌ one_gadget analysis failed")
return result
@mcp.tool()
@@ -2157,9 +2157,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔧 Starting GDB-PEDA analysis: {binary or f'PID {attach_pid}' or core_file}")
result = hexstrike_client.safe_post("api/tools/gdb-peda", data)
if result.get("success"):
logger.info(f"✅ GDB-PEDA analysis completed")
logger.info("✅ GDB-PEDA analysis completed")
else:
logger.error(f"❌ GDB-PEDA analysis failed")
logger.error("❌ GDB-PEDA analysis failed")
return result
@mcp.tool()
@@ -2191,9 +2191,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔧 Starting angr analysis: {binary}")
result = hexstrike_client.safe_post("api/tools/angr", data)
if result.get("success"):
logger.info(f"✅ angr analysis completed")
logger.info("✅ angr analysis completed")
else:
logger.error(f"❌ angr analysis failed")
logger.error("❌ angr analysis failed")
return result
@mcp.tool()
@@ -2225,9 +2225,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔧 Starting ropper analysis: {binary}")
result = hexstrike_client.safe_post("api/tools/ropper", data)
if result.get("success"):
logger.info(f"✅ ropper analysis completed")
logger.info("✅ ropper analysis completed")
else:
logger.error(f"❌ ropper analysis failed")
logger.error("❌ ropper analysis failed")
return result
@mcp.tool()
@@ -2256,9 +2256,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔧 Starting pwninit setup: {binary}")
result = hexstrike_client.safe_post("api/tools/pwninit", data)
if result.get("success"):
logger.info(f"✅ pwninit setup completed")
logger.info("✅ pwninit setup completed")
else:
logger.error(f"❌ pwninit setup failed")
logger.error("❌ pwninit setup failed")
return result
@mcp.tool()
@@ -2667,9 +2667,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🎯 Starting Dalfox XSS scan: {url if url else 'pipe mode'}")
result = hexstrike_client.safe_post("api/tools/dalfox", data)
if result.get("success"):
logger.info(f"✅ Dalfox XSS scan completed")
logger.info("✅ Dalfox XSS scan completed")
else:
logger.error(f"❌ Dalfox XSS scan failed")
logger.error("❌ Dalfox XSS scan failed")
return result
@mcp.tool()
@@ -2922,7 +2922,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
if payload_info.get("risk_level") == "HIGH":
results["summary"]["high_risk_payloads"] += 1
logger.info(f"✅ Attack suite generated:")
logger.info("✅ Attack suite generated:")
logger.info(f" ├─ Total payloads: {results['summary']['total_payloads']}")
logger.info(f" ├─ High-risk payloads: {results['summary']['high_risk_payloads']}")
logger.info(f" └─ Test cases: {results['summary']['test_cases']}")
@@ -2967,7 +2967,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
endpoint_count = len(result.get("results", []))
logger.info(f"✅ API endpoint testing completed: {endpoint_count} endpoints tested")
else:
logger.info(f"✅ API endpoint discovery completed")
logger.info("✅ API endpoint discovery completed")
else:
logger.error("❌ API fuzzing failed")
@@ -3032,7 +3032,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"target_url": target_url
}
logger.info(f"🔍 Starting JWT security analysis")
logger.info("🔍 Starting JWT security analysis")
result = hexstrike_client.safe_post("api/tools/jwt_analyzer", data)
if result.get("success"):
@@ -3089,7 +3089,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.warning(f" ├─ [{severity}] {issue_type}")
if endpoint_count > 0:
logger.info(f"📊 Discovered endpoints:")
logger.info("📊 Discovered endpoints:")
for endpoint in analysis.get("endpoints_found", [])[:5]: # Show first 5
method = endpoint.get("method", "GET")
path = endpoint.get("path", "/")
@@ -3183,7 +3183,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"audit_coverage": "comprehensive" if len(audit_results["tests_performed"]) >= 3 else "partial"
}
logger.info(f"✅ Comprehensive API audit completed:")
logger.info("✅ Comprehensive API audit completed:")
logger.info(f" ├─ Tests performed: {audit_results['summary']['tests_performed']}")
logger.info(f" ├─ Total vulnerabilities: {audit_results['summary']['total_vulnerabilities']}")
logger.info(f" └─ Coverage: {audit_results['summary']['audit_coverage']}")
@@ -3220,9 +3220,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🧠 Starting Volatility3 analysis: {plugin}")
result = hexstrike_client.safe_post("api/tools/volatility3", data)
if result.get("success"):
logger.info(f"✅ Volatility3 analysis completed")
logger.info("✅ Volatility3 analysis completed")
else:
logger.error(f"❌ Volatility3 analysis failed")
logger.error("❌ Volatility3 analysis failed")
return result
@mcp.tool()
@@ -3248,9 +3248,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"📁 Starting Foremost file carving: {input_file}")
result = hexstrike_client.safe_post("api/tools/foremost", data)
if result.get("success"):
logger.info(f"✅ Foremost carving completed")
logger.info("✅ Foremost carving completed")
else:
logger.error(f"❌ Foremost carving failed")
logger.error("❌ Foremost carving failed")
return result
@mcp.tool()
@@ -3308,9 +3308,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"📷 Starting ExifTool analysis: {file_path}")
result = hexstrike_client.safe_post("api/tools/exiftool", data)
if result.get("success"):
logger.info(f"✅ ExifTool analysis completed")
logger.info("✅ ExifTool analysis completed")
else:
logger.error(f"❌ ExifTool analysis failed")
logger.error("❌ ExifTool analysis failed")
return result
@mcp.tool()
@@ -3335,12 +3335,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"append_data": append_data,
"additional_args": additional_args
}
logger.info(f"🔐 Starting HashPump attack")
logger.info("🔐 Starting HashPump attack")
result = hexstrike_client.safe_post("api/tools/hashpump", data)
if result.get("success"):
logger.info(f"✅ HashPump attack completed")
logger.info("✅ HashPump attack completed")
else:
logger.error(f"❌ HashPump attack failed")
logger.error("❌ HashPump attack failed")
return result
# ============================================================================
@@ -3383,9 +3383,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🕷️ Starting Hakrawler crawling: {url}")
result = hexstrike_client.safe_post("api/tools/hakrawler", data)
if result.get("success"):
logger.info(f"✅ Hakrawler crawling completed")
logger.info("✅ Hakrawler crawling completed")
else:
logger.error(f"❌ Hakrawler crawling failed")
logger.error("❌ Hakrawler crawling failed")
return result
@mcp.tool()
@@ -3416,12 +3416,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"output_file": output_file,
"additional_args": additional_args
}
logger.info(f"🌐 Starting HTTPx probing")
logger.info("🌐 Starting HTTPx probing")
result = hexstrike_client.safe_post("api/tools/httpx", data)
if result.get("success"):
logger.info(f"✅ HTTPx probing completed")
logger.info("✅ HTTPx probing completed")
else:
logger.error(f"❌ HTTPx probing failed")
logger.error("❌ HTTPx probing failed")
return result
@mcp.tool()
@@ -3449,9 +3449,9 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
logger.info(f"🔍 Starting ParamSpider discovery: {domain}")
result = hexstrike_client.safe_post("api/tools/paramspider", data)
if result.get("success"):
logger.info(f"✅ ParamSpider discovery completed")
logger.info("✅ ParamSpider discovery completed")
else:
logger.error(f"❌ ParamSpider discovery failed")
logger.error("❌ ParamSpider discovery failed")
return result
# ============================================================================
@@ -3486,12 +3486,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
"output_file": output_file,
"additional_args": additional_args
}
logger.info(f"🔍 Starting Burp Suite scan")
logger.info("🔍 Starting Burp Suite scan")
result = hexstrike_client.safe_post("api/tools/burpsuite", data)
if result.get("success"):
logger.info(f"✅ Burp Suite scan completed")
logger.info("✅ Burp Suite scan completed")
else:
logger.error(f"❌ Burp Suite scan failed")
logger.error("❌ Burp Suite scan failed")
return result
@mcp.tool()
@@ -3794,7 +3794,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
Returns:
Server health information with tool availability and telemetry
"""
logger.info(f"🏥 Checking HexStrike AI server health")
logger.info("🏥 Checking HexStrike AI server health")
result = hexstrike_client.check_health()
if result.get("status") == "healthy":
logger.info(f"✅ Server is healthy - {result.get('total_tools_available', 0)} tools available")
@@ -3810,7 +3810,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
Returns:
Cache performance statistics
"""
logger.info(f"💾 Getting cache statistics")
logger.info("💾 Getting cache statistics")
result = hexstrike_client.safe_get("api/cache/stats")
if "hit_rate" in result:
logger.info(f"📊 Cache hit rate: {result.get('hit_rate', 'unknown')}")
@@ -3824,12 +3824,12 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
Returns:
Cache clear operation results
"""
logger.info(f"🧹 Clearing server cache")
logger.info("🧹 Clearing server cache")
result = hexstrike_client.safe_post("api/cache/clear", {})
if result.get("success"):
logger.info(f"✅ Cache cleared successfully")
logger.info("✅ Cache cleared successfully")
else:
logger.error(f"❌ Failed to clear cache")
logger.error("❌ Failed to clear cache")
return result
@mcp.tool()
@@ -3840,7 +3840,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
Returns:
System performance and usage telemetry
"""
logger.info(f"📈 Getting system telemetry")
logger.info("📈 Getting system telemetry")
result = hexstrike_client.safe_get("api/telemetry")
if "commands_executed" in result:
logger.info(f"📊 Commands executed: {result.get('commands_executed', 0)}")
@@ -3993,7 +3993,7 @@ def setup_mcp_server(hexstrike_client: HexStrikeClient) -> FastMCP:
execution_time = result.get("execution_time", 0)
logger.info(f"✅ Command completed successfully in {execution_time:.2f}s")
else:
logger.warning(f"⚠️ Command completed with errors")
logger.warning("⚠️ Command completed with errors")
return result
except Exception as e:
@@ -5433,7 +5433,7 @@ def main():
logger.debug("🔍 Debug logging enabled")
# MCP compatibility: No banner output to avoid JSON parsing issues
logger.info(f"🚀 Starting HexStrike AI MCP Client v6.0")
logger.info("🚀 Starting HexStrike AI MCP Client v6.0")
logger.info(f"🔗 Connecting to: {args.server}")
try:

File diff suppressed because it is too large Load Diff