diff --git a/pentestagent/interface/tui.py b/pentestagent/interface/tui.py index a035e78..03babc2 100644 --- a/pentestagent/interface/tui.py +++ b/pentestagent/interface/tui.py @@ -30,7 +30,10 @@ from textual.widgets.tree import TreeNode from ..config.constants import DEFAULT_MODEL # ANSI escape sequence pattern for stripping control codes from input -_ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[mGKHflSTABCDEFsu]|\x1b\].*?\x07|\x1b\[<[0-9;]*[Mm]') +_ANSI_ESCAPE = re.compile( + r"\x1b\[[0-9;]*[mGKHflSTABCDEFsu]|\x1b\].*?\x07|\x1b\[<[0-9;]*[Mm]" +) + # ASCII-safe scrollbar renderer to avoid Unicode glyph issues class ASCIIScrollBarRender(ScrollBarRender): @@ -1044,7 +1047,7 @@ Be concise. Use the actual data from notes.""" return # Strip ANSI escape sequences and control codes - message = _ANSI_ESCAPE.sub('', event.value).strip() + message = _ANSI_ESCAPE.sub("", event.value).strip() if not message: return diff --git a/pentestagent/knowledge/graph.py b/pentestagent/knowledge/graph.py index 6cf297d..96e2023 100644 --- a/pentestagent/knowledge/graph.py +++ b/pentestagent/knowledge/graph.py @@ -234,7 +234,7 @@ class ShadowGraph: product = svc.get("product", "") version = svc.get("version", "") proto = svc.get("protocol", "tcp") - + if port: for host_id in target_hosts: service_id = f"service:{host_id}:{port}" @@ -243,10 +243,18 @@ class ShadowGraph: label += f" {product}" if version: label += f" {version}" - - self._add_node(service_id, "service", label, product=product, version=version) - self._add_edge(host_id, service_id, "HAS_SERVICE", protocol=proto) - + + self._add_node( + service_id, + "service", + label, + product=product, + version=version, + ) + self._add_edge( + host_id, service_id, "HAS_SERVICE", protocol=proto + ) + # Handle nested endpoints metadata if metadata.get("endpoints"): for ep in metadata["endpoints"]: @@ -258,10 +266,10 @@ class ShadowGraph: label = path if methods: label += f" ({','.join(methods)})" - + self._add_node(endpoint_id, "endpoint", label, methods=methods) self._add_edge(host_id, endpoint_id, "HAS_ENDPOINT") - + # Handle nested technologies metadata if metadata.get("technologies"): for tech in metadata["technologies"]: @@ -273,12 +281,18 @@ class ShadowGraph: label = name if version and version != "unknown": label += f" {version}" - - self._add_node(tech_id, "technology", label, name=name, version=version) + + self._add_node( + tech_id, "technology", label, name=name, version=version + ) self._add_edge(host_id, tech_id, "USES_TECH") - + # If we processed nested metadata, we're done - if metadata.get("services") or metadata.get("endpoints") or metadata.get("technologies"): + if ( + metadata.get("services") + or metadata.get("endpoints") + or metadata.get("technologies") + ): return # Fallback to old port extraction logic @@ -378,8 +392,16 @@ class ShadowGraph: ) # Insight 2: High Value Targets (Hosts with many open ports/vulns/endpoints) - high_value_endpoints = ["admin", "phpmyadmin", "phpMyAdmin", "manager", "console", "webdav", "dav"] - + high_value_endpoints = [ + "admin", + "phpmyadmin", + "phpMyAdmin", + "manager", + "console", + "webdav", + "dav", + ] + for node, data in self.graph.nodes(data=True): if data.get("type") == "host": # Count services @@ -404,7 +426,12 @@ class ShadowGraph: if self.graph.nodes[v].get("type") == "technology" ] - if len(services) > 0 or len(vulns) > 0 or len(endpoints) > 0 or len(technologies) > 0: + if ( + len(services) > 0 + or len(vulns) > 0 + or len(endpoints) > 0 + or len(technologies) > 0 + ): parts = [] if len(services) > 0: parts.append(f"{len(services)} services") @@ -414,10 +441,8 @@ class ShadowGraph: parts.append(f"{len(technologies)} technologies") if len(vulns) > 0: parts.append(f"{len(vulns)} vulnerabilities") - insights.append( - f"Host {data['label']} has {', '.join(parts)}." - ) - + insights.append(f"Host {data['label']} has {', '.join(parts)}.") + # Flag high-value endpoints for ep_id in endpoints: ep_label = self.graph.nodes[ep_id].get("label", "") diff --git a/tests/conftest.py b/tests/conftest.py index 3703ae7..2b54504 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,14 @@ """Test fixtures for PentestAgent tests.""" -import pytest import asyncio from pathlib import Path -from typing import Generator, AsyncGenerator -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import AsyncMock, MagicMock +import pytest + +from pentestagent.agents.state import AgentStateManager from pentestagent.config import Settings -from pentestagent.agents.state import AgentState, AgentStateManager -from pentestagent.tools import get_all_tools, Tool, ToolSchema +from pentestagent.tools import Tool, ToolSchema @pytest.fixture @@ -82,7 +82,7 @@ def sample_tool() -> Tool: """Create a sample tool for testing.""" async def dummy_execute(arguments: dict, runtime) -> str: return f"Executed with: {arguments}" - + return Tool( name="test_tool", description="A test tool", diff --git a/tests/test_agents.py b/tests/test_agents.py index b08ab0d..0aff807 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -1,14 +1,14 @@ """Tests for the agent state management.""" + import pytest -from datetime import datetime from pentestagent.agents.state import AgentState, AgentStateManager, StateTransition class TestAgentState: """Tests for AgentState enum.""" - + def test_state_values(self): """Test state enum values.""" assert AgentState.IDLE.value == "idle" @@ -17,7 +17,7 @@ class TestAgentState: assert AgentState.WAITING_INPUT.value == "waiting_input" assert AgentState.COMPLETE.value == "complete" assert AgentState.ERROR.value == "error" - + def test_all_states_exist(self): """Test that all expected states exist.""" states = list(AgentState) @@ -26,76 +26,76 @@ class TestAgentState: class TestAgentStateManager: """Tests for AgentStateManager class.""" - + @pytest.fixture def state_manager(self): """Create a fresh AgentStateManager for each test.""" return AgentStateManager() - + def test_initial_state(self, state_manager): """Test initial state is IDLE.""" assert state_manager.current_state == AgentState.IDLE assert len(state_manager.history) == 0 - + def test_valid_transition(self, state_manager): """Test valid state transition.""" result = state_manager.transition_to(AgentState.THINKING) assert result is True assert state_manager.current_state == AgentState.THINKING assert len(state_manager.history) == 1 - + def test_invalid_transition(self, state_manager): """Test invalid state transition.""" result = state_manager.transition_to(AgentState.COMPLETE) assert result is False assert state_manager.current_state == AgentState.IDLE - + def test_transition_chain(self, state_manager): """Test a chain of valid transitions.""" assert state_manager.transition_to(AgentState.THINKING) assert state_manager.transition_to(AgentState.EXECUTING) assert state_manager.transition_to(AgentState.THINKING) assert state_manager.transition_to(AgentState.COMPLETE) - + assert state_manager.current_state == AgentState.COMPLETE assert len(state_manager.history) == 4 - + def test_force_transition(self, state_manager): """Test forcing a transition.""" state_manager.force_transition(AgentState.ERROR, reason="Test error") assert state_manager.current_state == AgentState.ERROR assert "FORCED" in state_manager.history[-1].reason - + def test_reset(self, state_manager): """Test resetting state.""" state_manager.transition_to(AgentState.THINKING) state_manager.transition_to(AgentState.EXECUTING) - + state_manager.reset() - + assert state_manager.current_state == AgentState.IDLE assert len(state_manager.history) == 0 - + def test_is_terminal(self, state_manager): """Test terminal state detection.""" assert state_manager.is_terminal() is False - + state_manager.transition_to(AgentState.THINKING) state_manager.transition_to(AgentState.COMPLETE) - + assert state_manager.is_terminal() is True - + def test_is_active(self, state_manager): """Test active state detection.""" assert state_manager.is_active() is False - + state_manager.transition_to(AgentState.THINKING) assert state_manager.is_active() is True class TestStateTransition: """Tests for StateTransition dataclass.""" - + def test_create_transition(self): """Test creating a state transition.""" transition = StateTransition( @@ -103,7 +103,7 @@ class TestStateTransition: to_state=AgentState.THINKING, reason="Starting work" ) - + assert transition.from_state == AgentState.IDLE assert transition.to_state == AgentState.THINKING assert transition.reason == "Starting work" diff --git a/tests/test_graph.py b/tests/test_graph.py index 49bfdcb..ef5420f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,8 +1,10 @@ """Tests for the Shadow Graph knowledge system.""" -import pytest import networkx as nx -from pentestagent.knowledge.graph import ShadowGraph, GraphNode, GraphEdge +import pytest + +from pentestagent.knowledge.graph import ShadowGraph + class TestShadowGraph: """Tests for ShadowGraph class.""" @@ -27,7 +29,7 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + assert graph.graph.has_node("host:192.168.1.10") node = graph.graph.nodes["host:192.168.1.10"] assert node["type"] == "host" @@ -42,14 +44,14 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + # Check host exists assert graph.graph.has_node("host:10.0.0.5") - + # Check services exist assert graph.graph.has_node("service:host:10.0.0.5:80") assert graph.graph.has_node("service:host:10.0.0.5:443") - + # Check edges assert graph.graph.has_edge("host:10.0.0.5", "service:host:10.0.0.5:80") edge = graph.graph.edges["host:10.0.0.5", "service:host:10.0.0.5:80"] @@ -65,13 +67,13 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + cred_id = "cred:ssh_creds" host_id = "host:192.168.1.20" - + assert graph.graph.has_node(cred_id) assert graph.graph.has_node(host_id) - + # Check edge assert graph.graph.has_edge(cred_id, host_id) edge = graph.graph.edges[cred_id, host_id] @@ -91,11 +93,11 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + # Check "Username: root" extraction node1 = graph.graph.nodes["cred:creds_1"] assert node1["label"] == "Creds (root)" - + # Check fallback for no username node2 = graph.graph.nodes["cred:creds_2"] assert node2["label"] == "Credentials" @@ -122,19 +124,19 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + # Check Credential Metadata cred_node = graph.graph.nodes["cred:meta_cred"] assert cred_node["label"] == "Creds (admin_meta)" - + # Check Target Host assert graph.graph.has_node("host:10.0.0.99") assert graph.graph.has_edge("cred:meta_cred", "host:10.0.0.99") - + # Check Source Host (CONTAINS edge) assert graph.graph.has_node("host:10.0.0.1") assert graph.graph.has_edge("host:10.0.0.1", "cred:meta_cred") - + # Check Vulnerability Metadata vuln_node = graph.graph.nodes["vuln:meta_vuln"] assert vuln_node["label"] == "CVE-2025-1234" @@ -154,7 +156,7 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + service_id = "service:host:10.0.0.5:80" assert graph.graph.has_node(service_id) node = graph.graph.nodes[service_id] @@ -166,7 +168,7 @@ class TestShadowGraph: "legacy_note": "Just a simple note about 10.10.10.10" } graph.update_from_notes(notes) - + assert graph.graph.has_node("host:10.10.10.10") def test_idempotency(self, graph): @@ -177,11 +179,11 @@ class TestShadowGraph: "category": "info" } } - + # First pass graph.update_from_notes(notes) assert len(graph.graph.nodes) == 1 - + # Second pass graph.update_from_notes(notes) assert len(graph.graph.nodes) == 1 @@ -193,15 +195,15 @@ class TestShadowGraph: graph._add_node("cred:1", "credential", "Root Creds") graph._add_node("host:A", "host", "10.0.0.1") graph._add_edge("cred:1", "host:A", "AUTH_ACCESS") - + # 2. HostA has Cred2 (this edge type isn't auto-extracted yet, but logic should handle it) graph._add_node("cred:2", "credential", "Db Admin") graph._add_edge("host:A", "cred:2", "CONTAINS_CRED") - + # 3. Cred2 gives access to HostB graph._add_node("host:B", "host", "10.0.0.2") graph._add_edge("cred:2", "host:B", "AUTH_ACCESS") - + paths = graph._find_attack_paths() assert len(paths) == 1 assert "Root Creds" in paths[0] @@ -214,7 +216,7 @@ class TestShadowGraph: graph._add_node("host:1", "host", "10.0.0.1") graph._add_node("cred:1", "credential", "admin") graph._add_edge("cred:1", "host:1", "AUTH_ACCESS") - + mermaid = graph.to_mermaid() assert "graph TD" in mermaid assert 'host_1["🖥️ 10.0.0.1"]' in mermaid @@ -230,7 +232,7 @@ class TestShadowGraph: } } graph.update_from_notes(notes) - + assert graph.graph.has_node("host:192.168.1.1") assert graph.graph.has_node("host:192.168.1.2") assert graph.graph.has_node("host:192.168.1.3") diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py index 35730cb..8181333 100644 --- a/tests/test_knowledge.py +++ b/tests/test_knowledge.py @@ -1,16 +1,15 @@ """Tests for the RAG knowledge system.""" -import pytest -import numpy as np -from pathlib import Path -from unittest.mock import patch -from pentestagent.knowledge.rag import RAGEngine, Document +import numpy as np +import pytest + +from pentestagent.knowledge.rag import Document, RAGEngine class TestDocument: """Tests for Document dataclass.""" - + def test_create_document(self): """Test creating a document.""" doc = Document(content="Test content", source="test.md") @@ -18,7 +17,7 @@ class TestDocument: assert doc.source == "test.md" assert doc.metadata == {} assert doc.doc_id is not None - + def test_document_with_metadata(self): """Test document with metadata.""" doc = Document( @@ -28,14 +27,14 @@ class TestDocument: ) assert doc.metadata["cve_id"] == "CVE-2021-1234" assert doc.metadata["severity"] == "high" - + def test_document_with_embedding(self): """Test document with embedding.""" embedding = np.random.rand(384) doc = Document(content="Test", source="test.md", embedding=embedding) assert doc.embedding is not None assert len(doc.embedding) == 384 - + def test_document_with_custom_id(self): """Test document with custom doc_id.""" doc = Document(content="Test", source="test.md", doc_id="custom-id-123") @@ -44,7 +43,7 @@ class TestDocument: class TestRAGEngine: """Tests for RAGEngine class.""" - + @pytest.fixture def rag_engine(self, tmp_path): """Create a RAG engine for testing.""" @@ -52,49 +51,49 @@ class TestRAGEngine: knowledge_path=tmp_path / "knowledge", use_local_embeddings=True ) - + def test_create_engine(self, rag_engine): """Test creating a RAG engine.""" assert rag_engine is not None assert len(rag_engine.documents) == 0 assert rag_engine.embeddings is None - + def test_get_document_count_empty(self, rag_engine): """Test document count on empty engine.""" assert rag_engine.get_document_count() == 0 - + def test_clear(self, rag_engine): """Test clearing the engine.""" rag_engine.documents.append(Document(content="test", source="test.md")) rag_engine.embeddings = np.random.rand(1, 384) rag_engine._indexed = True - + rag_engine.clear() - + assert len(rag_engine.documents) == 0 assert rag_engine.embeddings is None - assert rag_engine._indexed == False + assert not rag_engine._indexed class TestRAGEngineChunking: """Tests for text chunking functionality.""" - + @pytest.fixture def engine(self, tmp_path): """Create engine for chunking tests.""" return RAGEngine(knowledge_path=tmp_path) - + def test_chunk_short_text(self, engine): """Test chunking text shorter than chunk size.""" text = "This is a short paragraph.\n\nThis is another paragraph." chunks = engine._chunk_text(text, source="test.md", chunk_size=1000) - + assert len(chunks) >= 1 assert all(isinstance(c, Document) for c in chunks) - + def test_chunk_preserves_source(self, engine): """Test that chunking preserves source information.""" text = "Test paragraph 1.\n\nTest paragraph 2." chunks = engine._chunk_text(text, source="my_source.md") - + assert all(c.source == "my_source.md" for c in chunks) diff --git a/tests/test_notes.py b/tests/test_notes.py index e3f0bff..6ce565b 100644 --- a/tests/test_notes.py +++ b/tests/test_notes.py @@ -1,12 +1,11 @@ """Tests for the Notes tool.""" -import pytest import json -import asyncio -from pathlib import Path -from unittest.mock import MagicMock, patch -from pentestagent.tools.notes import notes, set_notes_file, get_all_notes, _notes +import pytest + +from pentestagent.tools.notes import _notes, get_all_notes, notes, set_notes_file + # We need to reset the global state for tests @pytest.fixture(autouse=True) @@ -15,14 +14,13 @@ def reset_notes_state(tmp_path): # Point to a temp file temp_notes_file = tmp_path / "notes.json" set_notes_file(temp_notes_file) - + # Clear the global dictionary (it's imported from the module) # We need to clear the actual dictionary object in the module - from pentestagent.tools.notes import _notes _notes.clear() - + yield - + # Cleanup is handled by tmp_path @pytest.mark.asyncio @@ -35,10 +33,10 @@ async def test_create_note(): "category": "info", "confidence": "high" } - + result = await notes(args, runtime=None) assert "Created note 'test_note'" in result - + all_notes = await get_all_notes() assert "test_note" in all_notes assert all_notes["test_note"]["content"] == "This is a test note" @@ -54,13 +52,13 @@ async def test_read_note(): "key": "read_me", "value": "Content to read" }, runtime=None) - + # Read result = await notes({ "action": "read", "key": "read_me" }, runtime=None) - + assert "Content to read" in result # The format is "[key] (category, confidence, status) content" assert "(info, medium, confirmed)" in result @@ -73,15 +71,15 @@ async def test_update_note(): "key": "update_me", "value": "Original content" }, runtime=None) - + result = await notes({ "action": "update", "key": "update_me", "value": "New content" }, runtime=None) - + assert "Updated note 'update_me'" in result - + all_notes = await get_all_notes() assert all_notes["update_me"]["content"] == "New content" @@ -93,14 +91,14 @@ async def test_delete_note(): "key": "delete_me", "value": "Bye bye" }, runtime=None) - + result = await notes({ "action": "delete", "key": "delete_me" }, runtime=None) - + assert "Deleted note 'delete_me'" in result - + all_notes = await get_all_notes() assert "delete_me" not in all_notes @@ -109,9 +107,9 @@ async def test_list_notes(): """Test listing all notes.""" await notes({"action": "create", "key": "n1", "value": "v1"}, runtime=None) await notes({"action": "create", "key": "n2", "value": "v2"}, runtime=None) - + result = await notes({"action": "list"}, runtime=None) - + assert "n1" in result assert "n2" in result assert "Notes (2 entries):" in result @@ -121,13 +119,13 @@ async def test_persistence(tmp_path): """Test that notes are saved to disk.""" # The fixture already sets a temp file temp_file = tmp_path / "notes.json" - + await notes({ "action": "create", "key": "persistent_note", "value": "I survive restarts" }, runtime=None) - + assert temp_file.exists() content = json.loads(temp_file.read_text()) assert "persistent_note" in content @@ -143,20 +141,19 @@ async def test_legacy_migration(tmp_path): "new_note": {"content": "A dict", "category": "info"} } legacy_file.write_text(json.dumps(legacy_data)) - + # Point the tool to this file set_notes_file(legacy_file) - + # Trigger load (get_all_notes calls _load_notes_unlocked if empty, but we need to clear first) - from pentestagent.tools.notes import _notes _notes.clear() - + all_notes = await get_all_notes() - + assert "old_note" in all_notes assert isinstance(all_notes["old_note"], dict) assert all_notes["old_note"]["content"] == "Just a string" assert all_notes["old_note"]["category"] == "info" - + assert "new_note" in all_notes assert all_notes["new_note"]["content"] == "A dict" diff --git a/tests/test_tools.py b/tests/test_tools.py index 5aeaa9a..b741e72 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -3,49 +3,54 @@ import pytest from pentestagent.tools import ( - Tool, ToolSchema, register_tool, get_all_tools, get_tool, - enable_tool, disable_tool, get_tool_names + ToolSchema, + disable_tool, + enable_tool, + get_all_tools, + get_tool, + get_tool_names, + register_tool, ) class TestToolRegistry: """Tests for tool registry functions.""" - + def test_tools_loaded(self): """Test that built-in tools are loaded.""" tools = get_all_tools() assert len(tools) > 0 - + tool_names = get_tool_names() assert "terminal" in tool_names assert "browser" in tool_names - + def test_get_tool(self): """Test getting a tool by name.""" tool = get_tool("terminal") assert tool is not None assert tool.name == "terminal" assert tool.category == "execution" - + def test_get_nonexistent_tool(self): """Test getting a tool that doesn't exist.""" tool = get_tool("nonexistent_tool_xyz") assert tool is None - + def test_disable_enable_tool(self): """Test disabling and enabling a tool.""" result = disable_tool("terminal") assert result is True - + tool = get_tool("terminal") assert tool.enabled is False - + result = enable_tool("terminal") assert result is True - + tool = get_tool("terminal") assert tool.enabled is True - + def test_disable_nonexistent_tool(self): """Test disabling a tool that doesn't exist.""" result = disable_tool("nonexistent_tool_xyz") @@ -54,7 +59,7 @@ class TestToolRegistry: class TestToolSchema: """Tests for ToolSchema class.""" - + def test_create_schema(self): """Test creating a tool schema.""" schema = ToolSchema( @@ -63,18 +68,18 @@ class TestToolSchema: }, required=["command"] ) - + assert schema.type == "object" assert "command" in schema.properties assert "command" in schema.required - + def test_schema_to_dict(self): """Test converting schema to dictionary.""" schema = ToolSchema( properties={"input": {"type": "string"}}, required=["input"] ) - + d = schema.to_dict() assert d["type"] == "object" assert d["properties"]["input"]["type"] == "string" @@ -83,33 +88,33 @@ class TestToolSchema: class TestTool: """Tests for Tool class.""" - + def test_create_tool(self, sample_tool): """Test creating a tool.""" assert sample_tool.name == "test_tool" assert sample_tool.description == "A test tool" assert sample_tool.category == "test" assert sample_tool.enabled is True - + def test_tool_to_llm_format(self, sample_tool): """Test converting tool to LLM format.""" llm_format = sample_tool.to_llm_format() - + assert llm_format["type"] == "function" assert llm_format["function"]["name"] == "test_tool" assert llm_format["function"]["description"] == "A test tool" assert "parameters" in llm_format["function"] - + def test_tool_validate_arguments(self, sample_tool): """Test argument validation.""" is_valid, error = sample_tool.validate_arguments({"param": "value"}) assert is_valid is True assert error is None - + is_valid, error = sample_tool.validate_arguments({}) assert is_valid is False assert "param" in error - + @pytest.mark.asyncio async def test_tool_execute(self, sample_tool): """Test tool execution.""" @@ -119,11 +124,11 @@ class TestTool: class TestRegisterToolDecorator: """Tests for register_tool decorator.""" - + def test_decorator_registers_tool(self): """Test that decorator registers a new tool.""" initial_count = len(get_all_tools()) - + @register_tool( name="pytest_test_tool_unique", description="A tool registered in tests", @@ -132,10 +137,10 @@ class TestRegisterToolDecorator: ) async def pytest_test_tool_unique(arguments, runtime): return "test result" - + new_count = len(get_all_tools()) assert new_count == initial_count + 1 - + tool = get_tool("pytest_test_tool_unique") assert tool is not None assert tool.name == "pytest_test_tool_unique"