test: add shadow graph and notes tests, fix linting

This commit is contained in:
GH05TCREW
2025-12-13 10:24:50 -07:00
parent 0d668a14af
commit d191a0104b
9 changed files with 426 additions and 61 deletions

View File

@@ -153,12 +153,11 @@ ghostcrew mcp add <name> <command> [args...] # Add MCP server
ghostcrew mcp test <name> # Test MCP connection
```
## Knowledge Base (RAG)
## Knowledge
Place files in `ghostcrew/knowledge/sources/` for RAG context injection:
- `methodologies.md` - Testing methodologies
- `cves.json` - CVE database
- `wordlists.txt` - Common wordlists
- **RAG:** Place methodologies, CVEs, or wordlists in `ghostcrew/knowledge/sources/` for automatic context injection.
- **Notes:** Agents save findings to `loot/notes.json` with categories (`credential`, `vulnerability`, `finding`, `artifact`). Notes persist across sessions and are injected into agent context.
- **Shadow Graph:** In Crew mode, the orchestrator builds a knowledge graph from notes to derive strategic insights (e.g., "We have credentials for host X").
## Project Structure
@@ -167,7 +166,7 @@ ghostcrew/
agents/ # Agent implementations
config/ # Settings and constants
interface/ # TUI and CLI
knowledge/ # RAG system
knowledge/ # RAG system and shadow graph
llm/ # LiteLLM wrapper
mcp/ # MCP client and server configs
runtime/ # Execution environment

View File

@@ -5,11 +5,11 @@ import platform
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
from ...config.constants import DEFAULT_MAX_ITERATIONS
from ...knowledge.graph import ShadowGraph
from ..prompts import ghost_crew
from .models import CrewState, WorkerCallback
from .tools import create_crew_tools
from .worker_pool import WorkerPool
from ...knowledge.graph import ShadowGraph
if TYPE_CHECKING:
from ...llm import LLM
@@ -77,7 +77,7 @@ class CrewOrchestrator:
else:
cat = data.get("category", "info")
content = data.get("content", "")
# Truncate long notes in system prompt to save tokens
if len(content) > 200:
content = content[:197] + "..."
@@ -85,11 +85,18 @@ class CrewOrchestrator:
if cat not in grouped:
grouped[cat] = []
grouped[cat].append(f"- {key}: {content}")
# Format output with specific order
sections = []
order = ["credential", "vulnerability", "finding", "artifact", "task", "info"]
order = [
"credential",
"vulnerability",
"finding",
"artifact",
"task",
"info",
]
for cat in order:
if cat in grouped:
header = cat.title() + "s"
@@ -97,13 +104,13 @@ class CrewOrchestrator:
header = "General Information"
sections.append(f"## {header}")
sections.append("\n".join(grouped[cat]))
# Add any remaining categories
for cat in sorted(grouped.keys()):
if cat not in order:
sections.append(f"## {cat.title()}")
sections.append("\n".join(grouped[cat]))
notes_context = "\n\n".join(sections)
except Exception:
pass # Notes not available
@@ -111,7 +118,9 @@ class CrewOrchestrator:
# Format insights for prompt
insights_text = ""
if graph_insights:
insights_text = "\n\n## Strategic Insights (Graph Analysis)\n" + "\n".join(f"- {i}" for i in graph_insights)
insights_text = "\n\n## Strategic Insights (Graph Analysis)\n" + "\n".join(
f"- {i}" for i in graph_insights
)
return ghost_crew.render(
target=self.target or "Not specified",

View File

@@ -140,9 +140,11 @@ class WorkerPool:
# Capture final response (text without tool calls)
if response.content and not response.tool_calls:
final_response = response.content
# Check if max iterations was hit
if response.metadata and response.metadata.get("max_iterations_reached"):
if response.metadata and response.metadata.get(
"max_iterations_reached"
):
hit_max_iterations = True
worker.result = final_response or "No findings."

View File

@@ -81,7 +81,7 @@ class GhostCrewAgent(BaseAgent):
else:
cat = data.get("category", "info")
content = data.get("content", "")
# Truncate long notes in system prompt to save tokens
# The agent can use the 'read' tool to get the full content
if len(content) > 200:
@@ -90,11 +90,18 @@ class GhostCrewAgent(BaseAgent):
if cat not in grouped:
grouped[cat] = []
grouped[cat].append(f"- {key}: {content}")
# Format output with specific order
sections = []
order = ["credential", "vulnerability", "finding", "artifact", "task", "info"]
order = [
"credential",
"vulnerability",
"finding",
"artifact",
"task",
"info",
]
for cat in order:
if cat in grouped:
header = cat.title() + "s"
@@ -102,13 +109,13 @@ class GhostCrewAgent(BaseAgent):
header = "General Information"
sections.append(f"## {header}")
sections.append("\n".join(grouped[cat]))
# Add any remaining categories
for cat in sorted(grouped.keys()):
if cat not in order:
sections.append(f"## {cat.title()}")
sections.append("\n".join(grouped[cat]))
notes_context = "\n\n".join(sections)
except Exception:
pass # Notes not available

View File

@@ -13,7 +13,7 @@ Architecture:
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Set
import networkx as nx
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
@dataclass
class GraphNode:
"""A node in the shadow graph."""
id: str
type: str # host, service, credential, finding, artifact
label: str
@@ -35,6 +36,7 @@ class GraphNode:
@dataclass
class GraphEdge:
"""An edge in the shadow graph."""
source: str
target: str
type: str # CONNECTS_TO, HAS_SERVICE, AUTH_ACCESS, RELATED_TO
@@ -49,16 +51,16 @@ class ShadowGraph:
def __init__(self):
    """Initialize an empty directed shadow graph.

    Sets up the networkx DiGraph, the set of note keys already
    processed (for idempotent incremental updates), and the compiled
    regex patterns used for entity extraction.
    """
    self.graph = nx.DiGraph()
    self._processed_notes: Set[str] = set()
    # Pre-compiled regex patterns for entity extraction from note text.
    self._ip_pattern = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
    self._port_pattern = re.compile(r"(\d{1,5})/(tcp|udp)")
    self._user_pattern = re.compile(r"user[:\s]+([a-zA-Z0-9_.-]+)", re.IGNORECASE)
def update_from_notes(self, notes: Dict[str, Dict[str, Any]]) -> None:
"""
Update the graph based on new notes.
This method is idempotent and incremental. It only processes notes
that haven't been seen before (based on key).
"""
@@ -79,7 +81,7 @@ class ShadowGraph:
def _process_note(self, key: str, content: str, category: str) -> None:
"""Extract entities and relationships from a single note."""
# 1. Extract IPs (Hosts)
ips = self._ip_pattern.findall(content)
hosts = []
@@ -110,12 +112,14 @@ class ShadowGraph:
if self.graph.has_node(source) and self.graph.has_node(target):
self.graph.add_edge(source, target, type=edge_type, **kwargs)
def _process_credential(self, key: str, content: str, related_hosts: List[str]) -> None:
def _process_credential(
self, key: str, content: str, related_hosts: List[str]
) -> None:
"""Process a credential note."""
# Extract username
user_match = self._user_pattern.search(content)
username = user_match.group(1) if user_match else "unknown"
cred_id = f"cred:{key}"
self._add_node(cred_id, "credential", f"Creds ({username})")
@@ -125,7 +129,9 @@ class ShadowGraph:
protocol = "ssh" if "ssh" in content.lower() else "unknown"
self._add_edge(cred_id, host_id, "AUTH_ACCESS", protocol=protocol)
def _process_finding(self, key: str, content: str, related_hosts: List[str]) -> None:
def _process_finding(
self, key: str, content: str, related_hosts: List[str]
) -> None:
"""Process a finding note (e.g., open ports)."""
# Extract ports
ports = self._port_pattern.findall(content)
@@ -135,15 +141,17 @@ class ShadowGraph:
self._add_node(service_id, "service", f"{port}/{proto}")
self._add_edge(host_id, service_id, "HAS_SERVICE", protocol=proto)
def _process_vulnerability(
    self, key: str, content: str, related_hosts: List[str]
) -> None:
    """Process a vulnerability note.

    Adds a vulnerability node (labelled with the CVE id when one is
    found in the note text) and links every related host to it with an
    AFFECTED_BY edge.

    Args:
        key: The note's unique key (used to build the node id).
        content: Raw note text, scanned for a CVE identifier.
        related_hosts: Host node ids already extracted from this note.
    """
    vuln_id = f"vuln:{key}"
    # Prefer a concrete CVE id as the label when the note mentions one.
    cve_match = re.search(r"CVE-\d{4}-\d{4,7}", content, re.IGNORECASE)
    label = cve_match.group(0) if cve_match else "Vulnerability"
    self._add_node(vuln_id, "vulnerability", label)
    for host_id in related_hosts:
        self._add_edge(host_id, vuln_id, "AFFECTED_BY")
@@ -152,7 +160,7 @@ class ShadowGraph:
Analyze the graph and return natural language insights for the Orchestrator.
"""
insights = []
# Insight 1: Unused Credentials
# Find credentials that have AUTH_ACCESS to a host, but we haven't "explored" that host fully?
# Or simply list valid access paths.
@@ -161,29 +169,53 @@ class ShadowGraph:
# Find what it connects to
targets = [v for u, v in self.graph.out_edges(node)]
if targets:
target_labels = [self.graph.nodes[t].get("label", t) for t in targets]
insights.append(f"We have credentials that provide access to: {', '.join(target_labels)}")
target_labels = [
self.graph.nodes[t].get("label", t) for t in targets
]
insights.append(
f"We have credentials that provide access to: {', '.join(target_labels)}"
)
# Insight 2: High Value Targets (Hosts with many open ports/vulns)
for node, data in self.graph.nodes(data=True):
if data.get("type") == "host":
# Count services
services = [v for u, v in self.graph.out_edges(node) if self.graph.nodes[v].get("type") == "service"]
vulns = [v for u, v in self.graph.out_edges(node) if self.graph.nodes[v].get("type") == "vulnerability"]
services = [
v
for u, v in self.graph.out_edges(node)
if self.graph.nodes[v].get("type") == "service"
]
vulns = [
v
for u, v in self.graph.out_edges(node)
if self.graph.nodes[v].get("type") == "vulnerability"
]
if len(services) > 0 or len(vulns) > 0:
insights.append(f"Host {data['label']} has {len(services)} services and {len(vulns)} known vulnerabilities.")
insights.append(
f"Host {data['label']} has {len(services)} services and {len(vulns)} known vulnerabilities."
)
# Insight 3: Potential Pivots (Host A -> Cred -> Host B)
# This is harder without explicit "source" of creds, but we can infer.
return insights
def export_summary(self) -> str:
    """Export a one-line text summary of the graph state.

    Returns:
        A string reporting how many host, credential, and vulnerability
        nodes the graph currently contains. Other node types (e.g.
        services) are not counted.
    """
    # Single pass over the nodes instead of one scan per node type.
    counts = {"host": 0, "credential": 0, "vulnerability": 0}
    for _node, data in self.graph.nodes(data=True):
        node_type = data["type"]
        if node_type in counts:
            counts[node_type] += 1
    return (
        f"Graph State: {counts['host']} Hosts, "
        f"{counts['credential']} Credentials, "
        f"{counts['vulnerability']} Vulnerabilities"
    )

View File

@@ -24,7 +24,11 @@ def _load_notes_unlocked() -> None:
_notes = {}
for k, v in loaded.items():
if isinstance(v, str):
_notes[k] = {"content": v, "category": "info", "confidence": "medium"}
_notes[k] = {
"content": v,
"category": "info",
"confidence": "medium",
}
else:
_notes[k] = v
except (json.JSONDecodeError, IOError):
@@ -55,7 +59,11 @@ def get_all_notes_sync() -> Dict[str, Dict[str, Any]]:
result = {}
for k, v in loaded.items():
if isinstance(v, str):
result[k] = {"content": v, "category": "info", "confidence": "medium"}
result[k] = {
"content": v,
"category": "info",
"confidence": "medium",
}
else:
result[k] = v
return result
@@ -96,7 +104,14 @@ _load_notes_unlocked()
},
"category": {
"type": "string",
"enum": ["finding", "credential", "task", "info", "vulnerability", "artifact"],
"enum": [
"finding",
"credential",
"task",
"info",
"vulnerability",
"artifact",
],
"description": "Category for organization (default: info)",
},
"confidence": {
@@ -123,13 +138,20 @@ async def notes(arguments: dict, runtime) -> str:
action = arguments["action"]
key = arguments.get("key", "").strip()
value = arguments.get("value", "")
# Soft validation for category
category = arguments.get("category", "info")
valid_categories = ["finding", "credential", "task", "info", "vulnerability", "artifact"]
valid_categories = [
"finding",
"credential",
"task",
"info",
"vulnerability",
"artifact",
]
if category not in valid_categories:
category = "info"
confidence = arguments.get("confidence", "medium")
async with _notes_lock:
@@ -144,7 +166,7 @@ async def notes(arguments: dict, runtime) -> str:
_notes[key] = {
"content": value,
"category": category,
"confidence": confidence
"confidence": confidence,
}
_save_notes_unlocked()
return f"Created note '{key}' ({category})"
@@ -156,7 +178,9 @@ async def notes(arguments: dict, runtime) -> str:
return f"Note '{key}' not found"
note = _notes[key]
return f"[{key}] ({note['category']}, {note['confidence']}) {note['content']}"
return (
f"[{key}] ({note['category']}, {note['confidence']}) {note['content']}"
)
elif action == "update":
if not key:
@@ -165,17 +189,17 @@ async def notes(arguments: dict, runtime) -> str:
return "Error: value is required for update"
existed = key in _notes
# Preserve existing metadata if not provided? No, overwrite is cleaner for now,
# Preserve existing metadata if not provided? No, overwrite is cleaner for now,
# but maybe we should default to existing if not provided.
# For now, let's just overwrite with defaults if missing, or use provided.
# Actually, if updating, we might want to keep category if not specified.
# But arguments.get("category", "info") defaults to info.
# Let's stick to simple overwrite for now to match previous behavior.
_notes[key] = {
"content": value,
"category": category,
"confidence": confidence
"confidence": confidence,
}
_save_notes_unlocked()
return f"{'Updated' if existed else 'Created'} note '{key}'"
@@ -195,7 +219,7 @@ async def notes(arguments: dict, runtime) -> str:
return "No notes saved"
lines = [f"Notes ({len(_notes)} entries):"]
# Group by category for display
by_category = {}
for k, v in _notes.items():
@@ -203,12 +227,14 @@ async def notes(arguments: dict, runtime) -> str:
if cat not in by_category:
by_category[cat] = []
by_category[cat].append((k, v))
for cat in sorted(by_category.keys()):
lines.append(f"\n## {cat.title()}")
for k, v in by_category[cat]:
content = v["content"]
display_val = content if len(content) <= 60 else content[:57] + "..."
display_val = (
content if len(content) <= 60 else content[:57] + "..."
)
conf = v.get("confidence", "medium")
lines.append(f" [{k}] ({conf}) {display_val}")

View File

@@ -40,6 +40,7 @@ dependencies = [
"beautifulsoup4>=4.12.0",
"httpx>=0.27.0",
"numpy>=1.26.0",
"networkx>=3.3",
"docker>=7.0.0",
"rich>=13.7.0",
"textual>=0.63.0",

127
tests/test_graph.py Normal file
View File

@@ -0,0 +1,127 @@
"""Tests for the Shadow Graph knowledge system."""
import pytest
import networkx as nx
from ghostcrew.knowledge.graph import ShadowGraph, GraphNode, GraphEdge
class TestShadowGraph:
    """Unit tests for the ShadowGraph knowledge system."""

    @pytest.fixture
    def graph(self):
        """Provide a fresh, empty ShadowGraph for each test."""
        return ShadowGraph()

    def test_initialization(self, graph):
        """A new graph is an empty DiGraph with no processed note keys."""
        assert isinstance(graph.graph, nx.DiGraph)
        assert len(graph.graph.nodes) == 0
        assert len(graph._processed_notes) == 0

    def test_extract_host_from_note(self, graph):
        """An IP address mentioned in a note becomes a host node."""
        graph.update_from_notes(
            {
                "scan_result": {
                    "content": "Nmap scan for 192.168.1.10 shows open ports.",
                    "category": "info",
                }
            }
        )
        assert graph.graph.has_node("host:192.168.1.10")
        host = graph.graph.nodes["host:192.168.1.10"]
        assert host["type"] == "host"
        assert host["label"] == "192.168.1.10"

    def test_extract_service_finding(self, graph):
        """Ports in a finding note become service nodes linked to the host."""
        graph.update_from_notes(
            {
                "ports_scan": {
                    "content": "Found open ports: 80/tcp, 443/tcp on 10.0.0.5",
                    "category": "finding",
                }
            }
        )
        host_id = "host:10.0.0.5"
        assert graph.graph.has_node(host_id)
        for port in (80, 443):
            assert graph.graph.has_node(f"service:{host_id}:{port}")
        assert graph.graph.has_edge(host_id, f"service:{host_id}:80")
        link = graph.graph.edges[host_id, f"service:{host_id}:80"]
        assert link["type"] == "HAS_SERVICE"
        assert link["protocol"] == "tcp"

    def test_extract_credential(self, graph):
        """Credential notes create a cred node with AUTH_ACCESS to the host."""
        graph.update_from_notes(
            {
                "ssh_creds": {
                    "content": "Found user: admin with password 'password123' for SSH on 192.168.1.20",
                    "category": "credential",
                }
            }
        )
        cred_id = "cred:ssh_creds"
        host_id = "host:192.168.1.20"
        assert graph.graph.has_node(cred_id)
        assert graph.graph.has_node(host_id)
        assert graph.graph.has_edge(cred_id, host_id)
        attrs = graph.graph.edges[cred_id, host_id]
        assert attrs["type"] == "AUTH_ACCESS"
        assert attrs["protocol"] == "ssh"

    def test_legacy_note_format(self, graph):
        """Plain-string (legacy) notes are still scanned for hosts."""
        graph.update_from_notes(
            {"legacy_note": "Just a simple note about 10.10.10.10"}
        )
        assert graph.graph.has_node("host:10.10.10.10")

    def test_idempotency(self, graph):
        """Re-processing an already-seen note key changes nothing."""
        notes = {
            "scan": {"content": "Host 192.168.1.1 is up.", "category": "info"}
        }
        graph.update_from_notes(notes)
        assert len(graph.graph.nodes) == 1
        # Second pass with identical input must be a no-op.
        graph.update_from_notes(notes)
        assert len(graph.graph.nodes) == 1
        # Mutated content for a processed key is ignored: the implementation
        # tracks processed keys in a set and never re-processes them.
        notes["scan"]["content"] = "Host 192.168.1.1 is down."
        graph.update_from_notes(notes)
        assert len(graph.graph.nodes) == 1

    def test_multiple_ips_in_one_note(self, graph):
        """Every IP in a single note gets its own host node."""
        graph.update_from_notes(
            {
                "subnet_scan": {
                    "content": "Scanning 192.168.1.1, 192.168.1.2, and 192.168.1.3",
                    "category": "info",
                }
            }
        )
        for last_octet in (1, 2, 3):
            assert graph.graph.has_node(f"host:192.168.1.{last_octet}")

162
tests/test_notes.py Normal file
View File

@@ -0,0 +1,162 @@
"""Tests for the Notes tool."""
import pytest
import json
import asyncio
from pathlib import Path
from unittest.mock import MagicMock, patch
from ghostcrew.tools.notes import notes, set_notes_file, get_all_notes, _notes
# We need to reset the global state for tests
@pytest.fixture(autouse=True)
def reset_notes_state(tmp_path):
    """Isolate the notes tool's module-level state for every test."""
    # Redirect persistence to a per-test temporary file.
    set_notes_file(tmp_path / "notes.json")
    # The in-memory store is a module-level dict; clear the actual
    # object in the module (re-importing yields the same reference).
    from ghostcrew.tools.notes import _notes

    _notes.clear()
    yield
    # Temp-file cleanup is handled by pytest's tmp_path fixture.
@pytest.mark.asyncio
async def test_create_note():
    """Creating a note stores its content, category and confidence."""
    result = await notes(
        {
            "action": "create",
            "key": "test_note",
            "value": "This is a test note",
            "category": "info",
            "confidence": "high",
        },
        runtime=None,
    )
    assert "Created note 'test_note'" in result
    stored = await get_all_notes()
    assert "test_note" in stored
    entry = stored["test_note"]
    assert entry["content"] == "This is a test note"
    assert entry["category"] == "info"
    assert entry["confidence"] == "high"
@pytest.mark.asyncio
async def test_read_note():
    """Reading a note returns '[key] (category, confidence) content'."""
    await notes(
        {"action": "create", "key": "read_me", "value": "Content to read"},
        runtime=None,
    )
    result = await notes({"action": "read", "key": "read_me"}, runtime=None)
    assert "Content to read" in result
    # Create defaults: category "info", confidence "medium".
    assert "(info, medium)" in result
@pytest.mark.asyncio
async def test_update_note():
    """Updating an existing note replaces its content."""
    await notes(
        {"action": "create", "key": "update_me", "value": "Original content"},
        runtime=None,
    )
    result = await notes(
        {"action": "update", "key": "update_me", "value": "New content"},
        runtime=None,
    )
    assert "Updated note 'update_me'" in result
    stored = await get_all_notes()
    assert stored["update_me"]["content"] == "New content"
@pytest.mark.asyncio
async def test_delete_note():
    """Deleting a note removes it from the store."""
    await notes(
        {"action": "create", "key": "delete_me", "value": "Bye bye"},
        runtime=None,
    )
    result = await notes({"action": "delete", "key": "delete_me"}, runtime=None)
    assert "Deleted note 'delete_me'" in result
    remaining = await get_all_notes()
    assert "delete_me" not in remaining
@pytest.mark.asyncio
async def test_list_notes():
    """Listing shows every note key plus an entry-count header."""
    for key, value in (("n1", "v1"), ("n2", "v2")):
        await notes({"action": "create", "key": key, "value": value}, runtime=None)
    listing = await notes({"action": "list"}, runtime=None)
    assert "n1" in listing
    assert "n2" in listing
    assert "Notes (2 entries):" in listing
@pytest.mark.asyncio
async def test_persistence(tmp_path):
    """Created notes are written through to the JSON file on disk."""
    # The autouse fixture already pointed the tool at this temp file.
    target = tmp_path / "notes.json"
    await notes(
        {
            "action": "create",
            "key": "persistent_note",
            "value": "I survive restarts",
        },
        runtime=None,
    )
    assert target.exists()
    on_disk = json.loads(target.read_text())
    assert "persistent_note" in on_disk
    assert on_disk["persistent_note"]["content"] == "I survive restarts"
@pytest.mark.asyncio
async def test_legacy_migration(tmp_path):
    """Plain-string notes on disk are upgraded to dict entries on load."""
    legacy_file = tmp_path / "legacy_notes.json"
    legacy_file.write_text(
        json.dumps(
            {
                "old_note": "Just a string",
                "new_note": {"content": "A dict", "category": "info"},
            }
        )
    )
    # Point the tool at the legacy file, then force a reload by clearing
    # the cached in-memory store (get_all_notes reloads when it is empty).
    set_notes_file(legacy_file)
    from ghostcrew.tools.notes import _notes

    _notes.clear()
    loaded = await get_all_notes()
    assert "old_note" in loaded
    assert isinstance(loaded["old_note"], dict)
    assert loaded["old_note"]["content"] == "Just a string"
    assert loaded["old_note"]["category"] == "info"
    assert "new_note" in loaded
    assert loaded["new_note"]["content"] == "A dict"