chore: apply lint and formatting fix

This commit is contained in:
GH05TCREW
2025-12-22 15:48:25 -07:00
parent 642ce2f4cc
commit 9b14094a4a
8 changed files with 176 additions and 145 deletions

View File

@@ -30,7 +30,10 @@ from textual.widgets.tree import TreeNode
from ..config.constants import DEFAULT_MODEL
# ANSI escape sequence pattern for stripping control codes from input
_ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[mGKHflSTABCDEFsu]|\x1b\].*?\x07|\x1b\[<[0-9;]*[Mm]')
_ANSI_ESCAPE = re.compile(
r"\x1b\[[0-9;]*[mGKHflSTABCDEFsu]|\x1b\].*?\x07|\x1b\[<[0-9;]*[Mm]"
)
# ASCII-safe scrollbar renderer to avoid Unicode glyph issues
class ASCIIScrollBarRender(ScrollBarRender):
@@ -1044,7 +1047,7 @@ Be concise. Use the actual data from notes."""
return
# Strip ANSI escape sequences and control codes
message = _ANSI_ESCAPE.sub('', event.value).strip()
message = _ANSI_ESCAPE.sub("", event.value).strip()
if not message:
return

View File

@@ -234,7 +234,7 @@ class ShadowGraph:
product = svc.get("product", "")
version = svc.get("version", "")
proto = svc.get("protocol", "tcp")
if port:
for host_id in target_hosts:
service_id = f"service:{host_id}:{port}"
@@ -243,10 +243,18 @@ class ShadowGraph:
label += f" {product}"
if version:
label += f" {version}"
self._add_node(service_id, "service", label, product=product, version=version)
self._add_edge(host_id, service_id, "HAS_SERVICE", protocol=proto)
self._add_node(
service_id,
"service",
label,
product=product,
version=version,
)
self._add_edge(
host_id, service_id, "HAS_SERVICE", protocol=proto
)
# Handle nested endpoints metadata
if metadata.get("endpoints"):
for ep in metadata["endpoints"]:
@@ -258,10 +266,10 @@ class ShadowGraph:
label = path
if methods:
label += f" ({','.join(methods)})"
self._add_node(endpoint_id, "endpoint", label, methods=methods)
self._add_edge(host_id, endpoint_id, "HAS_ENDPOINT")
# Handle nested technologies metadata
if metadata.get("technologies"):
for tech in metadata["technologies"]:
@@ -273,12 +281,18 @@ class ShadowGraph:
label = name
if version and version != "unknown":
label += f" {version}"
self._add_node(tech_id, "technology", label, name=name, version=version)
self._add_node(
tech_id, "technology", label, name=name, version=version
)
self._add_edge(host_id, tech_id, "USES_TECH")
# If we processed nested metadata, we're done
if metadata.get("services") or metadata.get("endpoints") or metadata.get("technologies"):
if (
metadata.get("services")
or metadata.get("endpoints")
or metadata.get("technologies")
):
return
# Fallback to old port extraction logic
@@ -378,8 +392,16 @@ class ShadowGraph:
)
# Insight 2: High Value Targets (Hosts with many open ports/vulns/endpoints)
high_value_endpoints = ["admin", "phpmyadmin", "phpMyAdmin", "manager", "console", "webdav", "dav"]
high_value_endpoints = [
"admin",
"phpmyadmin",
"phpMyAdmin",
"manager",
"console",
"webdav",
"dav",
]
for node, data in self.graph.nodes(data=True):
if data.get("type") == "host":
# Count services
@@ -404,7 +426,12 @@ class ShadowGraph:
if self.graph.nodes[v].get("type") == "technology"
]
if len(services) > 0 or len(vulns) > 0 or len(endpoints) > 0 or len(technologies) > 0:
if (
len(services) > 0
or len(vulns) > 0
or len(endpoints) > 0
or len(technologies) > 0
):
parts = []
if len(services) > 0:
parts.append(f"{len(services)} services")
@@ -414,10 +441,8 @@ class ShadowGraph:
parts.append(f"{len(technologies)} technologies")
if len(vulns) > 0:
parts.append(f"{len(vulns)} vulnerabilities")
insights.append(
f"Host {data['label']} has {', '.join(parts)}."
)
insights.append(f"Host {data['label']} has {', '.join(parts)}.")
# Flag high-value endpoints
for ep_id in endpoints:
ep_label = self.graph.nodes[ep_id].get("label", "")

View File

@@ -1,14 +1,14 @@
"""Test fixtures for PentestAgent tests."""
import pytest
import asyncio
from pathlib import Path
from typing import Generator, AsyncGenerator
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from pentestagent.agents.state import AgentStateManager
from pentestagent.config import Settings
from pentestagent.agents.state import AgentState, AgentStateManager
from pentestagent.tools import get_all_tools, Tool, ToolSchema
from pentestagent.tools import Tool, ToolSchema
@pytest.fixture
@@ -82,7 +82,7 @@ def sample_tool() -> Tool:
"""Create a sample tool for testing."""
async def dummy_execute(arguments: dict, runtime) -> str:
return f"Executed with: {arguments}"
return Tool(
name="test_tool",
description="A test tool",

View File

@@ -1,14 +1,14 @@
"""Tests for the agent state management."""
import pytest
from datetime import datetime
from pentestagent.agents.state import AgentState, AgentStateManager, StateTransition
class TestAgentState:
"""Tests for AgentState enum."""
def test_state_values(self):
"""Test state enum values."""
assert AgentState.IDLE.value == "idle"
@@ -17,7 +17,7 @@ class TestAgentState:
assert AgentState.WAITING_INPUT.value == "waiting_input"
assert AgentState.COMPLETE.value == "complete"
assert AgentState.ERROR.value == "error"
def test_all_states_exist(self):
"""Test that all expected states exist."""
states = list(AgentState)
@@ -26,76 +26,76 @@ class TestAgentState:
class TestAgentStateManager:
"""Tests for AgentStateManager class."""
@pytest.fixture
def state_manager(self):
"""Create a fresh AgentStateManager for each test."""
return AgentStateManager()
def test_initial_state(self, state_manager):
"""Test initial state is IDLE."""
assert state_manager.current_state == AgentState.IDLE
assert len(state_manager.history) == 0
def test_valid_transition(self, state_manager):
"""Test valid state transition."""
result = state_manager.transition_to(AgentState.THINKING)
assert result is True
assert state_manager.current_state == AgentState.THINKING
assert len(state_manager.history) == 1
def test_invalid_transition(self, state_manager):
"""Test invalid state transition."""
result = state_manager.transition_to(AgentState.COMPLETE)
assert result is False
assert state_manager.current_state == AgentState.IDLE
def test_transition_chain(self, state_manager):
"""Test a chain of valid transitions."""
assert state_manager.transition_to(AgentState.THINKING)
assert state_manager.transition_to(AgentState.EXECUTING)
assert state_manager.transition_to(AgentState.THINKING)
assert state_manager.transition_to(AgentState.COMPLETE)
assert state_manager.current_state == AgentState.COMPLETE
assert len(state_manager.history) == 4
def test_force_transition(self, state_manager):
"""Test forcing a transition."""
state_manager.force_transition(AgentState.ERROR, reason="Test error")
assert state_manager.current_state == AgentState.ERROR
assert "FORCED" in state_manager.history[-1].reason
def test_reset(self, state_manager):
"""Test resetting state."""
state_manager.transition_to(AgentState.THINKING)
state_manager.transition_to(AgentState.EXECUTING)
state_manager.reset()
assert state_manager.current_state == AgentState.IDLE
assert len(state_manager.history) == 0
def test_is_terminal(self, state_manager):
"""Test terminal state detection."""
assert state_manager.is_terminal() is False
state_manager.transition_to(AgentState.THINKING)
state_manager.transition_to(AgentState.COMPLETE)
assert state_manager.is_terminal() is True
def test_is_active(self, state_manager):
"""Test active state detection."""
assert state_manager.is_active() is False
state_manager.transition_to(AgentState.THINKING)
assert state_manager.is_active() is True
class TestStateTransition:
"""Tests for StateTransition dataclass."""
def test_create_transition(self):
"""Test creating a state transition."""
transition = StateTransition(
@@ -103,7 +103,7 @@ class TestStateTransition:
to_state=AgentState.THINKING,
reason="Starting work"
)
assert transition.from_state == AgentState.IDLE
assert transition.to_state == AgentState.THINKING
assert transition.reason == "Starting work"

View File

@@ -1,8 +1,10 @@
"""Tests for the Shadow Graph knowledge system."""
import pytest
import networkx as nx
from pentestagent.knowledge.graph import ShadowGraph, GraphNode, GraphEdge
import pytest
from pentestagent.knowledge.graph import ShadowGraph
class TestShadowGraph:
"""Tests for ShadowGraph class."""
@@ -27,7 +29,7 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
assert graph.graph.has_node("host:192.168.1.10")
node = graph.graph.nodes["host:192.168.1.10"]
assert node["type"] == "host"
@@ -42,14 +44,14 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
# Check host exists
assert graph.graph.has_node("host:10.0.0.5")
# Check services exist
assert graph.graph.has_node("service:host:10.0.0.5:80")
assert graph.graph.has_node("service:host:10.0.0.5:443")
# Check edges
assert graph.graph.has_edge("host:10.0.0.5", "service:host:10.0.0.5:80")
edge = graph.graph.edges["host:10.0.0.5", "service:host:10.0.0.5:80"]
@@ -65,13 +67,13 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
cred_id = "cred:ssh_creds"
host_id = "host:192.168.1.20"
assert graph.graph.has_node(cred_id)
assert graph.graph.has_node(host_id)
# Check edge
assert graph.graph.has_edge(cred_id, host_id)
edge = graph.graph.edges[cred_id, host_id]
@@ -91,11 +93,11 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
# Check "Username: root" extraction
node1 = graph.graph.nodes["cred:creds_1"]
assert node1["label"] == "Creds (root)"
# Check fallback for no username
node2 = graph.graph.nodes["cred:creds_2"]
assert node2["label"] == "Credentials"
@@ -122,19 +124,19 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
# Check Credential Metadata
cred_node = graph.graph.nodes["cred:meta_cred"]
assert cred_node["label"] == "Creds (admin_meta)"
# Check Target Host
assert graph.graph.has_node("host:10.0.0.99")
assert graph.graph.has_edge("cred:meta_cred", "host:10.0.0.99")
# Check Source Host (CONTAINS edge)
assert graph.graph.has_node("host:10.0.0.1")
assert graph.graph.has_edge("host:10.0.0.1", "cred:meta_cred")
# Check Vulnerability Metadata
vuln_node = graph.graph.nodes["vuln:meta_vuln"]
assert vuln_node["label"] == "CVE-2025-1234"
@@ -154,7 +156,7 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
service_id = "service:host:10.0.0.5:80"
assert graph.graph.has_node(service_id)
node = graph.graph.nodes[service_id]
@@ -166,7 +168,7 @@ class TestShadowGraph:
"legacy_note": "Just a simple note about 10.10.10.10"
}
graph.update_from_notes(notes)
assert graph.graph.has_node("host:10.10.10.10")
def test_idempotency(self, graph):
@@ -177,11 +179,11 @@ class TestShadowGraph:
"category": "info"
}
}
# First pass
graph.update_from_notes(notes)
assert len(graph.graph.nodes) == 1
# Second pass
graph.update_from_notes(notes)
assert len(graph.graph.nodes) == 1
@@ -193,15 +195,15 @@ class TestShadowGraph:
graph._add_node("cred:1", "credential", "Root Creds")
graph._add_node("host:A", "host", "10.0.0.1")
graph._add_edge("cred:1", "host:A", "AUTH_ACCESS")
# 2. HostA has Cred2 (this edge type isn't auto-extracted yet, but logic should handle it)
graph._add_node("cred:2", "credential", "Db Admin")
graph._add_edge("host:A", "cred:2", "CONTAINS_CRED")
# 3. Cred2 gives access to HostB
graph._add_node("host:B", "host", "10.0.0.2")
graph._add_edge("cred:2", "host:B", "AUTH_ACCESS")
paths = graph._find_attack_paths()
assert len(paths) == 1
assert "Root Creds" in paths[0]
@@ -214,7 +216,7 @@ class TestShadowGraph:
graph._add_node("host:1", "host", "10.0.0.1")
graph._add_node("cred:1", "credential", "admin")
graph._add_edge("cred:1", "host:1", "AUTH_ACCESS")
mermaid = graph.to_mermaid()
assert "graph TD" in mermaid
assert 'host_1["🖥️ 10.0.0.1"]' in mermaid
@@ -230,7 +232,7 @@ class TestShadowGraph:
}
}
graph.update_from_notes(notes)
assert graph.graph.has_node("host:192.168.1.1")
assert graph.graph.has_node("host:192.168.1.2")
assert graph.graph.has_node("host:192.168.1.3")

View File

@@ -1,16 +1,15 @@
"""Tests for the RAG knowledge system."""
import pytest
import numpy as np
from pathlib import Path
from unittest.mock import patch
from pentestagent.knowledge.rag import RAGEngine, Document
import numpy as np
import pytest
from pentestagent.knowledge.rag import Document, RAGEngine
class TestDocument:
"""Tests for Document dataclass."""
def test_create_document(self):
"""Test creating a document."""
doc = Document(content="Test content", source="test.md")
@@ -18,7 +17,7 @@ class TestDocument:
assert doc.source == "test.md"
assert doc.metadata == {}
assert doc.doc_id is not None
def test_document_with_metadata(self):
"""Test document with metadata."""
doc = Document(
@@ -28,14 +27,14 @@ class TestDocument:
)
assert doc.metadata["cve_id"] == "CVE-2021-1234"
assert doc.metadata["severity"] == "high"
def test_document_with_embedding(self):
"""Test document with embedding."""
embedding = np.random.rand(384)
doc = Document(content="Test", source="test.md", embedding=embedding)
assert doc.embedding is not None
assert len(doc.embedding) == 384
def test_document_with_custom_id(self):
"""Test document with custom doc_id."""
doc = Document(content="Test", source="test.md", doc_id="custom-id-123")
@@ -44,7 +43,7 @@ class TestDocument:
class TestRAGEngine:
"""Tests for RAGEngine class."""
@pytest.fixture
def rag_engine(self, tmp_path):
"""Create a RAG engine for testing."""
@@ -52,49 +51,49 @@ class TestRAGEngine:
knowledge_path=tmp_path / "knowledge",
use_local_embeddings=True
)
def test_create_engine(self, rag_engine):
"""Test creating a RAG engine."""
assert rag_engine is not None
assert len(rag_engine.documents) == 0
assert rag_engine.embeddings is None
def test_get_document_count_empty(self, rag_engine):
"""Test document count on empty engine."""
assert rag_engine.get_document_count() == 0
def test_clear(self, rag_engine):
"""Test clearing the engine."""
rag_engine.documents.append(Document(content="test", source="test.md"))
rag_engine.embeddings = np.random.rand(1, 384)
rag_engine._indexed = True
rag_engine.clear()
assert len(rag_engine.documents) == 0
assert rag_engine.embeddings is None
assert rag_engine._indexed == False
assert not rag_engine._indexed
class TestRAGEngineChunking:
"""Tests for text chunking functionality."""
@pytest.fixture
def engine(self, tmp_path):
"""Create engine for chunking tests."""
return RAGEngine(knowledge_path=tmp_path)
def test_chunk_short_text(self, engine):
"""Test chunking text shorter than chunk size."""
text = "This is a short paragraph.\n\nThis is another paragraph."
chunks = engine._chunk_text(text, source="test.md", chunk_size=1000)
assert len(chunks) >= 1
assert all(isinstance(c, Document) for c in chunks)
def test_chunk_preserves_source(self, engine):
"""Test that chunking preserves source information."""
text = "Test paragraph 1.\n\nTest paragraph 2."
chunks = engine._chunk_text(text, source="my_source.md")
assert all(c.source == "my_source.md" for c in chunks)

View File

@@ -1,12 +1,11 @@
"""Tests for the Notes tool."""
import pytest
import json
import asyncio
from pathlib import Path
from unittest.mock import MagicMock, patch
from pentestagent.tools.notes import notes, set_notes_file, get_all_notes, _notes
import pytest
from pentestagent.tools.notes import _notes, get_all_notes, notes, set_notes_file
# We need to reset the global state for tests
@pytest.fixture(autouse=True)
@@ -15,14 +14,13 @@ def reset_notes_state(tmp_path):
# Point to a temp file
temp_notes_file = tmp_path / "notes.json"
set_notes_file(temp_notes_file)
# Clear the global dictionary (it's imported from the module)
# We need to clear the actual dictionary object in the module
from pentestagent.tools.notes import _notes
_notes.clear()
yield
# Cleanup is handled by tmp_path
@pytest.mark.asyncio
@@ -35,10 +33,10 @@ async def test_create_note():
"category": "info",
"confidence": "high"
}
result = await notes(args, runtime=None)
assert "Created note 'test_note'" in result
all_notes = await get_all_notes()
assert "test_note" in all_notes
assert all_notes["test_note"]["content"] == "This is a test note"
@@ -54,13 +52,13 @@ async def test_read_note():
"key": "read_me",
"value": "Content to read"
}, runtime=None)
# Read
result = await notes({
"action": "read",
"key": "read_me"
}, runtime=None)
assert "Content to read" in result
# The format is "[key] (category, confidence, status) content"
assert "(info, medium, confirmed)" in result
@@ -73,15 +71,15 @@ async def test_update_note():
"key": "update_me",
"value": "Original content"
}, runtime=None)
result = await notes({
"action": "update",
"key": "update_me",
"value": "New content"
}, runtime=None)
assert "Updated note 'update_me'" in result
all_notes = await get_all_notes()
assert all_notes["update_me"]["content"] == "New content"
@@ -93,14 +91,14 @@ async def test_delete_note():
"key": "delete_me",
"value": "Bye bye"
}, runtime=None)
result = await notes({
"action": "delete",
"key": "delete_me"
}, runtime=None)
assert "Deleted note 'delete_me'" in result
all_notes = await get_all_notes()
assert "delete_me" not in all_notes
@@ -109,9 +107,9 @@ async def test_list_notes():
"""Test listing all notes."""
await notes({"action": "create", "key": "n1", "value": "v1"}, runtime=None)
await notes({"action": "create", "key": "n2", "value": "v2"}, runtime=None)
result = await notes({"action": "list"}, runtime=None)
assert "n1" in result
assert "n2" in result
assert "Notes (2 entries):" in result
@@ -121,13 +119,13 @@ async def test_persistence(tmp_path):
"""Test that notes are saved to disk."""
# The fixture already sets a temp file
temp_file = tmp_path / "notes.json"
await notes({
"action": "create",
"key": "persistent_note",
"value": "I survive restarts"
}, runtime=None)
assert temp_file.exists()
content = json.loads(temp_file.read_text())
assert "persistent_note" in content
@@ -143,20 +141,19 @@ async def test_legacy_migration(tmp_path):
"new_note": {"content": "A dict", "category": "info"}
}
legacy_file.write_text(json.dumps(legacy_data))
# Point the tool to this file
set_notes_file(legacy_file)
# Trigger load (get_all_notes calls _load_notes_unlocked if empty, but we need to clear first)
from pentestagent.tools.notes import _notes
_notes.clear()
all_notes = await get_all_notes()
assert "old_note" in all_notes
assert isinstance(all_notes["old_note"], dict)
assert all_notes["old_note"]["content"] == "Just a string"
assert all_notes["old_note"]["category"] == "info"
assert "new_note" in all_notes
assert all_notes["new_note"]["content"] == "A dict"

View File

@@ -3,49 +3,54 @@
import pytest
from pentestagent.tools import (
Tool, ToolSchema, register_tool, get_all_tools, get_tool,
enable_tool, disable_tool, get_tool_names
ToolSchema,
disable_tool,
enable_tool,
get_all_tools,
get_tool,
get_tool_names,
register_tool,
)
class TestToolRegistry:
"""Tests for tool registry functions."""
def test_tools_loaded(self):
"""Test that built-in tools are loaded."""
tools = get_all_tools()
assert len(tools) > 0
tool_names = get_tool_names()
assert "terminal" in tool_names
assert "browser" in tool_names
def test_get_tool(self):
"""Test getting a tool by name."""
tool = get_tool("terminal")
assert tool is not None
assert tool.name == "terminal"
assert tool.category == "execution"
def test_get_nonexistent_tool(self):
"""Test getting a tool that doesn't exist."""
tool = get_tool("nonexistent_tool_xyz")
assert tool is None
def test_disable_enable_tool(self):
"""Test disabling and enabling a tool."""
result = disable_tool("terminal")
assert result is True
tool = get_tool("terminal")
assert tool.enabled is False
result = enable_tool("terminal")
assert result is True
tool = get_tool("terminal")
assert tool.enabled is True
def test_disable_nonexistent_tool(self):
"""Test disabling a tool that doesn't exist."""
result = disable_tool("nonexistent_tool_xyz")
@@ -54,7 +59,7 @@ class TestToolRegistry:
class TestToolSchema:
"""Tests for ToolSchema class."""
def test_create_schema(self):
"""Test creating a tool schema."""
schema = ToolSchema(
@@ -63,18 +68,18 @@ class TestToolSchema:
},
required=["command"]
)
assert schema.type == "object"
assert "command" in schema.properties
assert "command" in schema.required
def test_schema_to_dict(self):
"""Test converting schema to dictionary."""
schema = ToolSchema(
properties={"input": {"type": "string"}},
required=["input"]
)
d = schema.to_dict()
assert d["type"] == "object"
assert d["properties"]["input"]["type"] == "string"
@@ -83,33 +88,33 @@ class TestToolSchema:
class TestTool:
"""Tests for Tool class."""
def test_create_tool(self, sample_tool):
"""Test creating a tool."""
assert sample_tool.name == "test_tool"
assert sample_tool.description == "A test tool"
assert sample_tool.category == "test"
assert sample_tool.enabled is True
def test_tool_to_llm_format(self, sample_tool):
"""Test converting tool to LLM format."""
llm_format = sample_tool.to_llm_format()
assert llm_format["type"] == "function"
assert llm_format["function"]["name"] == "test_tool"
assert llm_format["function"]["description"] == "A test tool"
assert "parameters" in llm_format["function"]
def test_tool_validate_arguments(self, sample_tool):
"""Test argument validation."""
is_valid, error = sample_tool.validate_arguments({"param": "value"})
assert is_valid is True
assert error is None
is_valid, error = sample_tool.validate_arguments({})
assert is_valid is False
assert "param" in error
@pytest.mark.asyncio
async def test_tool_execute(self, sample_tool):
"""Test tool execution."""
@@ -119,11 +124,11 @@ class TestTool:
class TestRegisterToolDecorator:
"""Tests for register_tool decorator."""
def test_decorator_registers_tool(self):
"""Test that decorator registers a new tool."""
initial_count = len(get_all_tools())
@register_tool(
name="pytest_test_tool_unique",
description="A tool registered in tests",
@@ -132,10 +137,10 @@ class TestRegisterToolDecorator:
)
async def pytest_test_tool_unique(arguments, runtime):
return "test result"
new_count = len(get_all_tools())
assert new_count == initial_count + 1
tool = get_tool("pytest_test_tool_unique")
assert tool is not None
assert tool.name == "pytest_test_tool_unique"