mirror of
https://github.com/GH05TCREW/pentestagent.git
synced 2026-03-07 14:23:20 +00:00
test: add shadow graph and notes tests, fix linting
This commit is contained in:
11
README.md
11
README.md
@@ -153,12 +153,11 @@ ghostcrew mcp add <name> <command> [args...] # Add MCP server
|
||||
ghostcrew mcp test <name> # Test MCP connection
|
||||
```
|
||||
|
||||
## Knowledge Base (RAG)
|
||||
## Knowledge
|
||||
|
||||
Place files in `ghostcrew/knowledge/sources/` for RAG context injection:
|
||||
- `methodologies.md` - Testing methodologies
|
||||
- `cves.json` - CVE database
|
||||
- `wordlists.txt` - Common wordlists
|
||||
- **RAG:** Place methodologies, CVEs, or wordlists in `ghostcrew/knowledge/sources/` for automatic context injection.
|
||||
- **Notes:** Agents save findings to `loot/notes.json` with categories (`credential`, `vulnerability`, `finding`, `artifact`). Notes persist across sessions and are injected into agent context.
|
||||
- **Shadow Graph:** In Crew mode, the orchestrator builds a knowledge graph from notes to derive strategic insights (e.g., "We have credentials for host X").
|
||||
|
||||
## Project Structure
|
||||
|
||||
@@ -167,7 +166,7 @@ ghostcrew/
|
||||
agents/ # Agent implementations
|
||||
config/ # Settings and constants
|
||||
interface/ # TUI and CLI
|
||||
knowledge/ # RAG system
|
||||
knowledge/ # RAG system and shadow graph
|
||||
llm/ # LiteLLM wrapper
|
||||
mcp/ # MCP client and server configs
|
||||
runtime/ # Execution environment
|
||||
|
||||
@@ -5,11 +5,11 @@ import platform
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
from ...config.constants import DEFAULT_MAX_ITERATIONS
|
||||
from ...knowledge.graph import ShadowGraph
|
||||
from ..prompts import ghost_crew
|
||||
from .models import CrewState, WorkerCallback
|
||||
from .tools import create_crew_tools
|
||||
from .worker_pool import WorkerPool
|
||||
from ...knowledge.graph import ShadowGraph
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...llm import LLM
|
||||
@@ -77,7 +77,7 @@ class CrewOrchestrator:
|
||||
else:
|
||||
cat = data.get("category", "info")
|
||||
content = data.get("content", "")
|
||||
|
||||
|
||||
# Truncate long notes in system prompt to save tokens
|
||||
if len(content) > 200:
|
||||
content = content[:197] + "..."
|
||||
@@ -85,11 +85,18 @@ class CrewOrchestrator:
|
||||
if cat not in grouped:
|
||||
grouped[cat] = []
|
||||
grouped[cat].append(f"- {key}: {content}")
|
||||
|
||||
|
||||
# Format output with specific order
|
||||
sections = []
|
||||
order = ["credential", "vulnerability", "finding", "artifact", "task", "info"]
|
||||
|
||||
order = [
|
||||
"credential",
|
||||
"vulnerability",
|
||||
"finding",
|
||||
"artifact",
|
||||
"task",
|
||||
"info",
|
||||
]
|
||||
|
||||
for cat in order:
|
||||
if cat in grouped:
|
||||
header = cat.title() + "s"
|
||||
@@ -97,13 +104,13 @@ class CrewOrchestrator:
|
||||
header = "General Information"
|
||||
sections.append(f"## {header}")
|
||||
sections.append("\n".join(grouped[cat]))
|
||||
|
||||
|
||||
# Add any remaining categories
|
||||
for cat in sorted(grouped.keys()):
|
||||
if cat not in order:
|
||||
sections.append(f"## {cat.title()}")
|
||||
sections.append("\n".join(grouped[cat]))
|
||||
|
||||
|
||||
notes_context = "\n\n".join(sections)
|
||||
except Exception:
|
||||
pass # Notes not available
|
||||
@@ -111,7 +118,9 @@ class CrewOrchestrator:
|
||||
# Format insights for prompt
|
||||
insights_text = ""
|
||||
if graph_insights:
|
||||
insights_text = "\n\n## Strategic Insights (Graph Analysis)\n" + "\n".join(f"- {i}" for i in graph_insights)
|
||||
insights_text = "\n\n## Strategic Insights (Graph Analysis)\n" + "\n".join(
|
||||
f"- {i}" for i in graph_insights
|
||||
)
|
||||
|
||||
return ghost_crew.render(
|
||||
target=self.target or "Not specified",
|
||||
|
||||
@@ -140,9 +140,11 @@ class WorkerPool:
|
||||
# Capture final response (text without tool calls)
|
||||
if response.content and not response.tool_calls:
|
||||
final_response = response.content
|
||||
|
||||
|
||||
# Check if max iterations was hit
|
||||
if response.metadata and response.metadata.get("max_iterations_reached"):
|
||||
if response.metadata and response.metadata.get(
|
||||
"max_iterations_reached"
|
||||
):
|
||||
hit_max_iterations = True
|
||||
|
||||
worker.result = final_response or "No findings."
|
||||
|
||||
@@ -81,7 +81,7 @@ class GhostCrewAgent(BaseAgent):
|
||||
else:
|
||||
cat = data.get("category", "info")
|
||||
content = data.get("content", "")
|
||||
|
||||
|
||||
# Truncate long notes in system prompt to save tokens
|
||||
# The agent can use the 'read' tool to get the full content
|
||||
if len(content) > 200:
|
||||
@@ -90,11 +90,18 @@ class GhostCrewAgent(BaseAgent):
|
||||
if cat not in grouped:
|
||||
grouped[cat] = []
|
||||
grouped[cat].append(f"- {key}: {content}")
|
||||
|
||||
|
||||
# Format output with specific order
|
||||
sections = []
|
||||
order = ["credential", "vulnerability", "finding", "artifact", "task", "info"]
|
||||
|
||||
order = [
|
||||
"credential",
|
||||
"vulnerability",
|
||||
"finding",
|
||||
"artifact",
|
||||
"task",
|
||||
"info",
|
||||
]
|
||||
|
||||
for cat in order:
|
||||
if cat in grouped:
|
||||
header = cat.title() + "s"
|
||||
@@ -102,13 +109,13 @@ class GhostCrewAgent(BaseAgent):
|
||||
header = "General Information"
|
||||
sections.append(f"## {header}")
|
||||
sections.append("\n".join(grouped[cat]))
|
||||
|
||||
|
||||
# Add any remaining categories
|
||||
for cat in sorted(grouped.keys()):
|
||||
if cat not in order:
|
||||
sections.append(f"## {cat.title()}")
|
||||
sections.append("\n".join(grouped[cat]))
|
||||
|
||||
|
||||
notes_context = "\n\n".join(sections)
|
||||
except Exception:
|
||||
pass # Notes not available
|
||||
|
||||
@@ -13,7 +13,7 @@ Architecture:
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class GraphNode:
|
||||
"""A node in the shadow graph."""
|
||||
|
||||
id: str
|
||||
type: str # host, service, credential, finding, artifact
|
||||
label: str
|
||||
@@ -35,6 +36,7 @@ class GraphNode:
|
||||
@dataclass
|
||||
class GraphEdge:
|
||||
"""An edge in the shadow graph."""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
type: str # CONNECTS_TO, HAS_SERVICE, AUTH_ACCESS, RELATED_TO
|
||||
@@ -49,16 +51,16 @@ class ShadowGraph:
|
||||
def __init__(self):
|
||||
self.graph = nx.DiGraph()
|
||||
self._processed_notes: Set[str] = set()
|
||||
|
||||
|
||||
# Regex patterns for entity extraction
|
||||
self._ip_pattern = re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b')
|
||||
self._port_pattern = re.compile(r'(\d{1,5})/(tcp|udp)')
|
||||
self._user_pattern = re.compile(r'user[:\s]+([a-zA-Z0-9_.-]+)', re.IGNORECASE)
|
||||
self._ip_pattern = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
|
||||
self._port_pattern = re.compile(r"(\d{1,5})/(tcp|udp)")
|
||||
self._user_pattern = re.compile(r"user[:\s]+([a-zA-Z0-9_.-]+)", re.IGNORECASE)
|
||||
|
||||
def update_from_notes(self, notes: Dict[str, Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Update the graph based on new notes.
|
||||
|
||||
|
||||
This method is idempotent and incremental. It only processes notes
|
||||
that haven't been seen before (based on key).
|
||||
"""
|
||||
@@ -79,7 +81,7 @@ class ShadowGraph:
|
||||
|
||||
def _process_note(self, key: str, content: str, category: str) -> None:
|
||||
"""Extract entities and relationships from a single note."""
|
||||
|
||||
|
||||
# 1. Extract IPs (Hosts)
|
||||
ips = self._ip_pattern.findall(content)
|
||||
hosts = []
|
||||
@@ -110,12 +112,14 @@ class ShadowGraph:
|
||||
if self.graph.has_node(source) and self.graph.has_node(target):
|
||||
self.graph.add_edge(source, target, type=edge_type, **kwargs)
|
||||
|
||||
def _process_credential(self, key: str, content: str, related_hosts: List[str]) -> None:
|
||||
def _process_credential(
|
||||
self, key: str, content: str, related_hosts: List[str]
|
||||
) -> None:
|
||||
"""Process a credential note."""
|
||||
# Extract username
|
||||
user_match = self._user_pattern.search(content)
|
||||
username = user_match.group(1) if user_match else "unknown"
|
||||
|
||||
|
||||
cred_id = f"cred:{key}"
|
||||
self._add_node(cred_id, "credential", f"Creds ({username})")
|
||||
|
||||
@@ -125,7 +129,9 @@ class ShadowGraph:
|
||||
protocol = "ssh" if "ssh" in content.lower() else "unknown"
|
||||
self._add_edge(cred_id, host_id, "AUTH_ACCESS", protocol=protocol)
|
||||
|
||||
def _process_finding(self, key: str, content: str, related_hosts: List[str]) -> None:
|
||||
def _process_finding(
|
||||
self, key: str, content: str, related_hosts: List[str]
|
||||
) -> None:
|
||||
"""Process a finding note (e.g., open ports)."""
|
||||
# Extract ports
|
||||
ports = self._port_pattern.findall(content)
|
||||
@@ -135,15 +141,17 @@ class ShadowGraph:
|
||||
self._add_node(service_id, "service", f"{port}/{proto}")
|
||||
self._add_edge(host_id, service_id, "HAS_SERVICE", protocol=proto)
|
||||
|
||||
def _process_vulnerability(self, key: str, content: str, related_hosts: List[str]) -> None:
|
||||
def _process_vulnerability(
|
||||
self, key: str, content: str, related_hosts: List[str]
|
||||
) -> None:
|
||||
"""Process a vulnerability note."""
|
||||
vuln_id = f"vuln:{key}"
|
||||
# Try to extract CVE
|
||||
cve_match = re.search(r'CVE-\d{4}-\d{4,7}', content, re.IGNORECASE)
|
||||
cve_match = re.search(r"CVE-\d{4}-\d{4,7}", content, re.IGNORECASE)
|
||||
label = cve_match.group(0) if cve_match else "Vulnerability"
|
||||
|
||||
|
||||
self._add_node(vuln_id, "vulnerability", label)
|
||||
|
||||
|
||||
for host_id in related_hosts:
|
||||
self._add_edge(host_id, vuln_id, "AFFECTED_BY")
|
||||
|
||||
@@ -152,7 +160,7 @@ class ShadowGraph:
|
||||
Analyze the graph and return natural language insights for the Orchestrator.
|
||||
"""
|
||||
insights = []
|
||||
|
||||
|
||||
# Insight 1: Unused Credentials
|
||||
# Find credentials that have AUTH_ACCESS to a host, but we haven't "explored" that host fully?
|
||||
# Or simply list valid access paths.
|
||||
@@ -161,29 +169,53 @@ class ShadowGraph:
|
||||
# Find what it connects to
|
||||
targets = [v for u, v in self.graph.out_edges(node)]
|
||||
if targets:
|
||||
target_labels = [self.graph.nodes[t].get("label", t) for t in targets]
|
||||
insights.append(f"We have credentials that provide access to: {', '.join(target_labels)}")
|
||||
target_labels = [
|
||||
self.graph.nodes[t].get("label", t) for t in targets
|
||||
]
|
||||
insights.append(
|
||||
f"We have credentials that provide access to: {', '.join(target_labels)}"
|
||||
)
|
||||
|
||||
# Insight 2: High Value Targets (Hosts with many open ports/vulns)
|
||||
for node, data in self.graph.nodes(data=True):
|
||||
if data.get("type") == "host":
|
||||
# Count services
|
||||
services = [v for u, v in self.graph.out_edges(node) if self.graph.nodes[v].get("type") == "service"]
|
||||
vulns = [v for u, v in self.graph.out_edges(node) if self.graph.nodes[v].get("type") == "vulnerability"]
|
||||
|
||||
services = [
|
||||
v
|
||||
for u, v in self.graph.out_edges(node)
|
||||
if self.graph.nodes[v].get("type") == "service"
|
||||
]
|
||||
vulns = [
|
||||
v
|
||||
for u, v in self.graph.out_edges(node)
|
||||
if self.graph.nodes[v].get("type") == "vulnerability"
|
||||
]
|
||||
|
||||
if len(services) > 0 or len(vulns) > 0:
|
||||
insights.append(f"Host {data['label']} has {len(services)} services and {len(vulns)} known vulnerabilities.")
|
||||
insights.append(
|
||||
f"Host {data['label']} has {len(services)} services and {len(vulns)} known vulnerabilities."
|
||||
)
|
||||
|
||||
# Insight 3: Potential Pivots (Host A -> Cred -> Host B)
|
||||
# This is harder without explicit "source" of creds, but we can infer.
|
||||
|
||||
|
||||
return insights
|
||||
|
||||
def export_summary(self) -> str:
|
||||
"""Export a text summary of the graph state."""
|
||||
stats = {
|
||||
"hosts": len([n for n, d in self.graph.nodes(data=True) if d['type'] == 'host']),
|
||||
"creds": len([n for n, d in self.graph.nodes(data=True) if d['type'] == 'credential']),
|
||||
"vulns": len([n for n, d in self.graph.nodes(data=True) if d['type'] == 'vulnerability']),
|
||||
"hosts": len(
|
||||
[n for n, d in self.graph.nodes(data=True) if d["type"] == "host"]
|
||||
),
|
||||
"creds": len(
|
||||
[n for n, d in self.graph.nodes(data=True) if d["type"] == "credential"]
|
||||
),
|
||||
"vulns": len(
|
||||
[
|
||||
n
|
||||
for n, d in self.graph.nodes(data=True)
|
||||
if d["type"] == "vulnerability"
|
||||
]
|
||||
),
|
||||
}
|
||||
return f"Graph State: {stats['hosts']} Hosts, {stats['creds']} Credentials, {stats['vulns']} Vulnerabilities"
|
||||
|
||||
@@ -24,7 +24,11 @@ def _load_notes_unlocked() -> None:
|
||||
_notes = {}
|
||||
for k, v in loaded.items():
|
||||
if isinstance(v, str):
|
||||
_notes[k] = {"content": v, "category": "info", "confidence": "medium"}
|
||||
_notes[k] = {
|
||||
"content": v,
|
||||
"category": "info",
|
||||
"confidence": "medium",
|
||||
}
|
||||
else:
|
||||
_notes[k] = v
|
||||
except (json.JSONDecodeError, IOError):
|
||||
@@ -55,7 +59,11 @@ def get_all_notes_sync() -> Dict[str, Dict[str, Any]]:
|
||||
result = {}
|
||||
for k, v in loaded.items():
|
||||
if isinstance(v, str):
|
||||
result[k] = {"content": v, "category": "info", "confidence": "medium"}
|
||||
result[k] = {
|
||||
"content": v,
|
||||
"category": "info",
|
||||
"confidence": "medium",
|
||||
}
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
@@ -96,7 +104,14 @@ _load_notes_unlocked()
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["finding", "credential", "task", "info", "vulnerability", "artifact"],
|
||||
"enum": [
|
||||
"finding",
|
||||
"credential",
|
||||
"task",
|
||||
"info",
|
||||
"vulnerability",
|
||||
"artifact",
|
||||
],
|
||||
"description": "Category for organization (default: info)",
|
||||
},
|
||||
"confidence": {
|
||||
@@ -123,13 +138,20 @@ async def notes(arguments: dict, runtime) -> str:
|
||||
action = arguments["action"]
|
||||
key = arguments.get("key", "").strip()
|
||||
value = arguments.get("value", "")
|
||||
|
||||
|
||||
# Soft validation for category
|
||||
category = arguments.get("category", "info")
|
||||
valid_categories = ["finding", "credential", "task", "info", "vulnerability", "artifact"]
|
||||
valid_categories = [
|
||||
"finding",
|
||||
"credential",
|
||||
"task",
|
||||
"info",
|
||||
"vulnerability",
|
||||
"artifact",
|
||||
]
|
||||
if category not in valid_categories:
|
||||
category = "info"
|
||||
|
||||
|
||||
confidence = arguments.get("confidence", "medium")
|
||||
|
||||
async with _notes_lock:
|
||||
@@ -144,7 +166,7 @@ async def notes(arguments: dict, runtime) -> str:
|
||||
_notes[key] = {
|
||||
"content": value,
|
||||
"category": category,
|
||||
"confidence": confidence
|
||||
"confidence": confidence,
|
||||
}
|
||||
_save_notes_unlocked()
|
||||
return f"Created note '{key}' ({category})"
|
||||
@@ -156,7 +178,9 @@ async def notes(arguments: dict, runtime) -> str:
|
||||
return f"Note '{key}' not found"
|
||||
|
||||
note = _notes[key]
|
||||
return f"[{key}] ({note['category']}, {note['confidence']}) {note['content']}"
|
||||
return (
|
||||
f"[{key}] ({note['category']}, {note['confidence']}) {note['content']}"
|
||||
)
|
||||
|
||||
elif action == "update":
|
||||
if not key:
|
||||
@@ -165,17 +189,17 @@ async def notes(arguments: dict, runtime) -> str:
|
||||
return "Error: value is required for update"
|
||||
|
||||
existed = key in _notes
|
||||
# Preserve existing metadata if not provided? No, overwrite is cleaner for now,
|
||||
# Preserve existing metadata if not provided? No, overwrite is cleaner for now,
|
||||
# but maybe we should default to existing if not provided.
|
||||
# For now, let's just overwrite with defaults if missing, or use provided.
|
||||
# Actually, if updating, we might want to keep category if not specified.
|
||||
# But arguments.get("category", "info") defaults to info.
|
||||
# Let's stick to simple overwrite for now to match previous behavior.
|
||||
|
||||
|
||||
_notes[key] = {
|
||||
"content": value,
|
||||
"category": category,
|
||||
"confidence": confidence
|
||||
"confidence": confidence,
|
||||
}
|
||||
_save_notes_unlocked()
|
||||
return f"{'Updated' if existed else 'Created'} note '{key}'"
|
||||
@@ -195,7 +219,7 @@ async def notes(arguments: dict, runtime) -> str:
|
||||
return "No notes saved"
|
||||
|
||||
lines = [f"Notes ({len(_notes)} entries):"]
|
||||
|
||||
|
||||
# Group by category for display
|
||||
by_category = {}
|
||||
for k, v in _notes.items():
|
||||
@@ -203,12 +227,14 @@ async def notes(arguments: dict, runtime) -> str:
|
||||
if cat not in by_category:
|
||||
by_category[cat] = []
|
||||
by_category[cat].append((k, v))
|
||||
|
||||
|
||||
for cat in sorted(by_category.keys()):
|
||||
lines.append(f"\n## {cat.title()}")
|
||||
for k, v in by_category[cat]:
|
||||
content = v["content"]
|
||||
display_val = content if len(content) <= 60 else content[:57] + "..."
|
||||
display_val = (
|
||||
content if len(content) <= 60 else content[:57] + "..."
|
||||
)
|
||||
conf = v.get("confidence", "medium")
|
||||
lines.append(f" [{k}] ({conf}) {display_val}")
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ dependencies = [
|
||||
"beautifulsoup4>=4.12.0",
|
||||
"httpx>=0.27.0",
|
||||
"numpy>=1.26.0",
|
||||
"networkx>=3.3",
|
||||
"docker>=7.0.0",
|
||||
"rich>=13.7.0",
|
||||
"textual>=0.63.0",
|
||||
|
||||
127
tests/test_graph.py
Normal file
127
tests/test_graph.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Tests for the Shadow Graph knowledge system."""
|
||||
|
||||
import pytest
|
||||
import networkx as nx
|
||||
from ghostcrew.knowledge.graph import ShadowGraph, GraphNode, GraphEdge
|
||||
|
||||
class TestShadowGraph:
|
||||
"""Tests for ShadowGraph class."""
|
||||
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
"""Create a fresh ShadowGraph for each test."""
|
||||
return ShadowGraph()
|
||||
|
||||
def test_initialization(self, graph):
|
||||
"""Test graph initialization."""
|
||||
assert isinstance(graph.graph, nx.DiGraph)
|
||||
assert len(graph.graph.nodes) == 0
|
||||
assert len(graph._processed_notes) == 0
|
||||
|
||||
def test_extract_host_from_note(self, graph):
|
||||
"""Test extracting host IP from a note."""
|
||||
notes = {
|
||||
"scan_result": {
|
||||
"content": "Nmap scan for 192.168.1.10 shows open ports.",
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
assert graph.graph.has_node("host:192.168.1.10")
|
||||
node = graph.graph.nodes["host:192.168.1.10"]
|
||||
assert node["type"] == "host"
|
||||
assert node["label"] == "192.168.1.10"
|
||||
|
||||
def test_extract_service_finding(self, graph):
|
||||
"""Test extracting services from a finding note."""
|
||||
notes = {
|
||||
"ports_scan": {
|
||||
"content": "Found open ports: 80/tcp, 443/tcp on 10.0.0.5",
|
||||
"category": "finding"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
# Check host exists
|
||||
assert graph.graph.has_node("host:10.0.0.5")
|
||||
|
||||
# Check services exist
|
||||
assert graph.graph.has_node("service:host:10.0.0.5:80")
|
||||
assert graph.graph.has_node("service:host:10.0.0.5:443")
|
||||
|
||||
# Check edges
|
||||
assert graph.graph.has_edge("host:10.0.0.5", "service:host:10.0.0.5:80")
|
||||
edge = graph.graph.edges["host:10.0.0.5", "service:host:10.0.0.5:80"]
|
||||
assert edge["type"] == "HAS_SERVICE"
|
||||
assert edge["protocol"] == "tcp"
|
||||
|
||||
def test_extract_credential(self, graph):
|
||||
"""Test extracting credentials and linking to host."""
|
||||
notes = {
|
||||
"ssh_creds": {
|
||||
"content": "Found user: admin with password 'password123' for SSH on 192.168.1.20",
|
||||
"category": "credential"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
cred_id = "cred:ssh_creds"
|
||||
host_id = "host:192.168.1.20"
|
||||
|
||||
assert graph.graph.has_node(cred_id)
|
||||
assert graph.graph.has_node(host_id)
|
||||
|
||||
# Check edge
|
||||
assert graph.graph.has_edge(cred_id, host_id)
|
||||
edge = graph.graph.edges[cred_id, host_id]
|
||||
assert edge["type"] == "AUTH_ACCESS"
|
||||
assert edge["protocol"] == "ssh"
|
||||
|
||||
def test_legacy_note_format(self, graph):
|
||||
"""Test handling legacy string-only notes."""
|
||||
notes = {
|
||||
"legacy_note": "Just a simple note about 10.10.10.10"
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
assert graph.graph.has_node("host:10.10.10.10")
|
||||
|
||||
def test_idempotency(self, graph):
|
||||
"""Test that processing the same note twice doesn't duplicate or error."""
|
||||
notes = {
|
||||
"scan": {
|
||||
"content": "Host 192.168.1.1 is up.",
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
|
||||
# First pass
|
||||
graph.update_from_notes(notes)
|
||||
assert len(graph.graph.nodes) == 1
|
||||
|
||||
# Second pass
|
||||
graph.update_from_notes(notes)
|
||||
assert len(graph.graph.nodes) == 1
|
||||
|
||||
# Modify note (simulate update - though currently graph only processes new keys,
|
||||
# in a real scenario we might want to handle updates, but for now we test it ignores processed keys)
|
||||
notes["scan"]["content"] = "Host 192.168.1.1 is down."
|
||||
graph.update_from_notes(notes)
|
||||
# Should still be based on first pass if we strictly check processed keys
|
||||
# The current implementation uses a set of processed keys, so it won't re-process.
|
||||
assert len(graph.graph.nodes) == 1
|
||||
|
||||
def test_multiple_ips_in_one_note(self, graph):
|
||||
"""Test a single note referencing multiple hosts."""
|
||||
notes = {
|
||||
"subnet_scan": {
|
||||
"content": "Scanning 192.168.1.1, 192.168.1.2, and 192.168.1.3",
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
assert graph.graph.has_node("host:192.168.1.1")
|
||||
assert graph.graph.has_node("host:192.168.1.2")
|
||||
assert graph.graph.has_node("host:192.168.1.3")
|
||||
162
tests/test_notes.py
Normal file
162
tests/test_notes.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Tests for the Notes tool."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from ghostcrew.tools.notes import notes, set_notes_file, get_all_notes, _notes
|
||||
|
||||
# We need to reset the global state for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_notes_state(tmp_path):
|
||||
"""Reset the notes global state for each test."""
|
||||
# Point to a temp file
|
||||
temp_notes_file = tmp_path / "notes.json"
|
||||
set_notes_file(temp_notes_file)
|
||||
|
||||
# Clear the global dictionary (it's imported from the module)
|
||||
# We need to clear the actual dictionary object in the module
|
||||
from ghostcrew.tools.notes import _notes
|
||||
_notes.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup is handled by tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note():
|
||||
"""Test creating a new note."""
|
||||
args = {
|
||||
"action": "create",
|
||||
"key": "test_note",
|
||||
"value": "This is a test note",
|
||||
"category": "info",
|
||||
"confidence": "high"
|
||||
}
|
||||
|
||||
result = await notes(args, runtime=None)
|
||||
assert "Created note 'test_note'" in result
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert "test_note" in all_notes
|
||||
assert all_notes["test_note"]["content"] == "This is a test note"
|
||||
assert all_notes["test_note"]["category"] == "info"
|
||||
assert all_notes["test_note"]["confidence"] == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_note():
|
||||
"""Test reading an existing note."""
|
||||
# Create first
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "read_me",
|
||||
"value": "Content to read"
|
||||
}, runtime=None)
|
||||
|
||||
# Read
|
||||
result = await notes({
|
||||
"action": "read",
|
||||
"key": "read_me"
|
||||
}, runtime=None)
|
||||
|
||||
assert "Content to read" in result
|
||||
# The format is "[key] (category, confidence) content"
|
||||
assert "(info, medium)" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note():
|
||||
"""Test updating a note."""
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "update_me",
|
||||
"value": "Original content"
|
||||
}, runtime=None)
|
||||
|
||||
result = await notes({
|
||||
"action": "update",
|
||||
"key": "update_me",
|
||||
"value": "New content"
|
||||
}, runtime=None)
|
||||
|
||||
assert "Updated note 'update_me'" in result
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert all_notes["update_me"]["content"] == "New content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_note():
|
||||
"""Test deleting a note."""
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "delete_me",
|
||||
"value": "Bye bye"
|
||||
}, runtime=None)
|
||||
|
||||
result = await notes({
|
||||
"action": "delete",
|
||||
"key": "delete_me"
|
||||
}, runtime=None)
|
||||
|
||||
assert "Deleted note 'delete_me'" in result
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert "delete_me" not in all_notes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_notes():
|
||||
"""Test listing all notes."""
|
||||
await notes({"action": "create", "key": "n1", "value": "v1"}, runtime=None)
|
||||
await notes({"action": "create", "key": "n2", "value": "v2"}, runtime=None)
|
||||
|
||||
result = await notes({"action": "list"}, runtime=None)
|
||||
|
||||
assert "n1" in result
|
||||
assert "n2" in result
|
||||
assert "Notes (2 entries):" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence(tmp_path):
|
||||
"""Test that notes are saved to disk."""
|
||||
# The fixture already sets a temp file
|
||||
temp_file = tmp_path / "notes.json"
|
||||
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "persistent_note",
|
||||
"value": "I survive restarts"
|
||||
}, runtime=None)
|
||||
|
||||
assert temp_file.exists()
|
||||
content = json.loads(temp_file.read_text())
|
||||
assert "persistent_note" in content
|
||||
assert content["persistent_note"]["content"] == "I survive restarts"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_migration(tmp_path):
|
||||
"""Test migration of legacy string notes."""
|
||||
# Create a legacy file
|
||||
legacy_file = tmp_path / "legacy_notes.json"
|
||||
legacy_data = {
|
||||
"old_note": "Just a string",
|
||||
"new_note": {"content": "A dict", "category": "info"}
|
||||
}
|
||||
legacy_file.write_text(json.dumps(legacy_data))
|
||||
|
||||
# Point the tool to this file
|
||||
set_notes_file(legacy_file)
|
||||
|
||||
# Trigger load (get_all_notes calls _load_notes_unlocked if empty, but we need to clear first)
|
||||
from ghostcrew.tools.notes import _notes
|
||||
_notes.clear()
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
|
||||
assert "old_note" in all_notes
|
||||
assert isinstance(all_notes["old_note"], dict)
|
||||
assert all_notes["old_note"]["content"] == "Just a string"
|
||||
assert all_notes["old_note"]["category"] == "info"
|
||||
|
||||
assert "new_note" in all_notes
|
||||
assert all_notes["new_note"]["content"] == "A dict"
|
||||
Reference in New Issue
Block a user