Files
pentestagent/ghostcrew/knowledge/graph.py

426 lines
15 KiB
Python

"""
Shadow Graph implementation for GhostCrew.
This module provides a lightweight knowledge graph that is built automatically
from agent notes. It is used by the Orchestrator to compute strategic insights
(e.g., "we have creds for X but haven't scanned it") without burdening the
agents with graph management.
Architecture:
Notes (Source of Truth) -> Shadow Graph (Derived View) -> Insights (Strategic Hints)
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Set
import networkx as nx
logger = logging.getLogger(__name__)
@dataclass
class GraphNode:
"""A node in the shadow graph."""
id: str
type: str # host, service, credential, finding, artifact
label: str
metadata: Dict[str, Any] = field(default_factory=dict)
def __hash__(self):
return hash(self.id)
@dataclass
class GraphEdge:
"""An edge in the shadow graph."""
source: str
target: str
type: str # CONNECTS_TO, HAS_SERVICE, AUTH_ACCESS, RELATED_TO
metadata: Dict[str, Any] = field(default_factory=dict)
class ShadowGraph:
"""
A NetworkX-backed knowledge graph that derives its state from notes.
"""
def __init__(self):
self.graph = nx.DiGraph()
self._processed_notes: Set[str] = set()
# Regex patterns for entity extraction
self._ip_pattern = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
self._port_pattern = re.compile(r"(\d{1,5})/(tcp|udp)")
self._user_pattern = re.compile(
r"(?:user|username)[:\s]+([a-zA-Z0-9_.-]+)", re.IGNORECASE
)
self._source_pattern = re.compile(
r"(?:found on|dumped from|extracted from|on host)\s+((?:\d{1,3}\.){3}\d{1,3})",
re.IGNORECASE,
)
def update_from_notes(self, notes: Dict[str, Dict[str, Any]]) -> None:
"""
Update the graph based on new notes.
This method is idempotent and incremental. It only processes notes
that haven't been seen before (based on key).
"""
for key, note_data in notes.items():
if key in self._processed_notes:
continue
# Handle legacy format
if isinstance(note_data, str):
content = note_data
category = "info"
metadata = {}
status = "confirmed"
else:
content = note_data.get("content", "")
category = note_data.get("category", "info")
metadata = note_data.get("metadata", {})
status = note_data.get("status", "confirmed")
self._process_note(key, content, category, metadata, status)
self._processed_notes.add(key)
def _process_note(
self,
key: str,
content: str,
category: str,
metadata: Dict[str, Any],
status: str,
) -> None:
"""Extract entities and relationships from a single note."""
# 1. Extract IPs (Hosts)
# Prefer metadata if available
hosts = []
# Check target in metadata
if metadata.get("target"):
target_ip = metadata["target"]
# Validate it looks like an IP or hostname? For now just accept it.
node_id = f"host:{target_ip}"
self._add_node(node_id, "host", target_ip)
hosts.append(node_id)
# Check source in metadata
if metadata.get("source"):
source_ip = metadata["source"]
node_id = f"host:{source_ip}"
self._add_node(node_id, "host", source_ip)
hosts.append(node_id)
# Fallback to regex if no hosts found in metadata
if not hosts:
ips = self._ip_pattern.findall(content)
for ip in ips:
node_id = f"host:{ip}"
self._add_node(node_id, "host", ip)
hosts.append(node_id)
# 2. Handle specific categories
if category == "credential":
self._process_credential(key, content, hosts, metadata, status)
elif category == "finding":
self._process_finding(key, content, hosts, metadata, status)
elif category == "vulnerability":
self._process_vulnerability(key, content, hosts, metadata, status)
# 3. Link note to hosts (provenance)
# We don't add the note itself as a node usually, but we could.
# For now, we just use the note to build Host-to-Host or Host-to-Service links.
def _add_node(self, node_id: str, node_type: str, label: str, **kwargs) -> None:
"""Add a node if it doesn't exist."""
if not self.graph.has_node(node_id):
self.graph.add_node(node_id, type=node_type, label=label, **kwargs)
def _add_edge(self, source: str, target: str, edge_type: str, **kwargs) -> None:
"""Add an edge."""
if self.graph.has_node(source) and self.graph.has_node(target):
self.graph.add_edge(source, target, type=edge_type, **kwargs)
def _process_credential(
self,
key: str,
content: str,
related_hosts: List[str],
metadata: Dict[str, Any],
status: str,
) -> None:
"""Process a credential note."""
# Skip if status is closed/filtered (invalid creds)
if status in ["closed", "filtered"]:
return
# Extract username from metadata or regex
username = metadata.get("username")
if not username:
user_match = self._user_pattern.search(content)
username = user_match.group(1) if user_match else None
cred_id = f"cred:{key}"
label = f"Creds ({username})" if username else "Credentials"
self._add_node(cred_id, "credential", label)
# Check for "found on" source host
source_host = None
if metadata.get("source"):
source_ip = metadata["source"]
source_host = f"host:{source_ip}"
else:
source_match = self._source_pattern.search(content)
if source_match:
source_ip = source_match.group(1)
source_host = f"host:{source_ip}"
if source_host:
# Add CONTAINS edge: Host -> Cred
if self.graph.has_node(source_host):
self._add_edge(source_host, cred_id, "CONTAINS")
# Link cred to hosts it belongs to (or works on)
for host_id in related_hosts:
# If this host is the source, skip adding it as a target unless explicitly clear?
# For now, if we identified it as source, assume it's NOT the target unless it's the only one?
# Let's just exclude the source host from being an AUTH_ACCESS target to avoid loops,
# unless we want to represent local privesc (which is valid).
# But for pivoting, we care about A -> Cred -> B.
# If we found a source, and this host is that source, treat it as CONTAINS (already done).
# Otherwise, treat as AUTH_ACCESS.
if source_host and host_id == source_host:
continue
# If the note says "ssh", assume SSH access
protocol = "ssh" if "ssh" in content.lower() else "unknown"
self._add_edge(cred_id, host_id, "AUTH_ACCESS", protocol=protocol)
def _process_finding(
self,
key: str,
content: str,
related_hosts: List[str],
metadata: Dict[str, Any],
status: str,
) -> None:
"""Process a finding note (e.g., open ports)."""
# Skip if status is closed/filtered
if status in ["closed", "filtered"]:
return
# Filter related_hosts: If we have explicit target metadata, ONLY use that.
# Otherwise, use all related hosts (fallback to regex behavior).
target_hosts = related_hosts
if metadata.get("target"):
target_ip = metadata["target"]
target_id = f"host:{target_ip}"
# Only use the target if it's in the related_hosts list (sanity check)
if target_id in related_hosts:
target_hosts = [target_id]
# Extract ports from metadata or regex
ports = []
if metadata.get("port"):
# Handle single port in metadata
p = str(metadata["port"])
# Assume tcp if not specified?
proto = "tcp"
if "/" in p:
p, proto = p.split("/")
ports.append((p, proto))
# Always check regex too, in case metadata missed some
regex_ports = self._port_pattern.findall(content)
for p, proto in regex_ports:
if (p, proto) not in ports:
ports.append((p, proto))
for port, proto in ports:
for host_id in target_hosts:
service_id = f"service:{host_id}:{port}"
# Add URL to label if present
label = f"{port}/{proto}"
if metadata.get("url"):
label += f" ({metadata['url']})"
self._add_node(service_id, "service", label)
self._add_edge(host_id, service_id, "HAS_SERVICE", protocol=proto)
def _process_vulnerability(
self,
key: str,
content: str,
related_hosts: List[str],
metadata: Dict[str, Any],
status: str,
) -> None:
"""Process a vulnerability note."""
# Skip if status is closed/filtered (patched or not vulnerable)
if status in ["closed", "filtered"]:
return
# Filter related_hosts: If we have explicit target metadata, ONLY use that.
target_hosts = related_hosts
if metadata.get("target"):
target_ip = metadata["target"]
target_id = f"host:{target_ip}"
if target_id in related_hosts:
target_hosts = [target_id]
vuln_id = f"vuln:{key}"
# Try to extract CVE from metadata or regex
label = "Vulnerability"
if metadata.get("cve"):
label = metadata["cve"]
else:
cve_match = re.search(r"CVE-\d{4}-\d{4,7}", content, re.IGNORECASE)
if cve_match:
label = cve_match.group(0)
self._add_node(vuln_id, "vulnerability", label)
for host_id in target_hosts:
self._add_edge(host_id, vuln_id, "AFFECTED_BY")
def get_strategic_insights(self) -> List[str]:
"""
Analyze the graph and return natural language insights for the Orchestrator.
"""
insights = []
# Insight 1: Unused Credentials
# Find credentials that have AUTH_ACCESS to a host, but we haven't "explored" that host fully?
# Or simply list valid access paths.
for node, data in self.graph.nodes(data=True):
if data.get("type") == "credential":
# Find what it connects to
targets = [v for u, v in self.graph.out_edges(node)]
if targets:
target_labels = [
self.graph.nodes[t].get("label", t) for t in targets
]
insights.append(
f"We have credentials that provide access to: {', '.join(target_labels)}"
)
# Insight 2: High Value Targets (Hosts with many open ports/vulns)
for node, data in self.graph.nodes(data=True):
if data.get("type") == "host":
# Count services
services = [
v
for u, v in self.graph.out_edges(node)
if self.graph.nodes[v].get("type") == "service"
]
vulns = [
v
for u, v in self.graph.out_edges(node)
if self.graph.nodes[v].get("type") == "vulnerability"
]
if len(services) > 0 or len(vulns) > 0:
insights.append(
f"Host {data['label']} has {len(services)} services and {len(vulns)} known vulnerabilities."
)
# Insight 3: Potential Pivots (Host A -> Cred -> Host B)
# Use NetworkX to find paths from Credentials to Hosts that aren't directly connected
attack_paths = self._find_attack_paths()
if attack_paths:
insights.extend(attack_paths)
return insights
def _find_attack_paths(self) -> List[str]:
"""
Find multi-step attack paths using shortest path algorithms.
Example: Credential A -> Host A -> Credential B -> Host B
"""
paths = []
creds = [n for n, d in self.graph.nodes(data=True) if d["type"] == "credential"]
hosts = [n for n, d in self.graph.nodes(data=True) if d["type"] == "host"]
for cred in creds:
for host in hosts:
# Skip if directly connected (we already know we have access)
if self.graph.has_edge(cred, host):
continue
try:
# Find shortest path
path = nx.shortest_path(self.graph, cred, host)
# Only interesting if it involves intermediate steps
if len(path) > 2:
# Convert IDs to Labels for readability
readable_path = []
for node_id in path:
node_data = self.graph.nodes[node_id]
readable_path.append(node_data.get("label", node_id))
paths.append(f"Attack Path Found: {' -> '.join(readable_path)}")
except nx.NetworkXNoPath:
continue
return paths
def to_mermaid(self) -> str:
"""Export graph to Mermaid flowchart format."""
lines = ["graph TD"]
# Add nodes
for node, data in self.graph.nodes(data=True):
# Sanitize ID for mermaid
safe_id = re.sub(r"[^a-zA-Z0-9]", "_", node)
label = data.get("label", node).replace('"', "'")
# Style based on type
if data["type"] == "host":
lines.append(f' {safe_id}["🖥️ {label}"]')
elif data["type"] == "credential":
lines.append(f' {safe_id}["🔑 {label}"]')
elif data["type"] == "vulnerability":
lines.append(f' {safe_id}["⚠️ {label}"]')
elif data["type"] == "service":
lines.append(f' {safe_id}["🔌 {label}"]')
else:
lines.append(f' {safe_id}["{label}"]')
# Add edges
for u, v, data in self.graph.edges(data=True):
safe_u = re.sub(r"[^a-zA-Z0-9]", "_", u)
safe_v = re.sub(r"[^a-zA-Z0-9]", "_", v)
edge_label = data.get("type", "")
lines.append(f" {safe_u} -->|{edge_label}| {safe_v}")
return "\n".join(lines)
def export_summary(self) -> str:
"""Export a text summary of the graph state."""
stats = {
"hosts": len(
[n for n, d in self.graph.nodes(data=True) if d["type"] == "host"]
),
"creds": len(
[n for n, d in self.graph.nodes(data=True) if d["type"] == "credential"]
),
"vulns": len(
[
n
for n, d in self.graph.nodes(data=True)
if d["type"] == "vulnerability"
]
),
}
return f"Graph State: {stats['hosts']} Hosts, {stats['creds']} Credentials, {stats['vulns']} Vulnerabilities"