mirror of
https://github.com/GH05TCREW/pentestagent.git
synced 2026-03-07 14:23:20 +00:00
chore: apply lint and formatting fix
This commit is contained in:
@@ -30,7 +30,10 @@ from textual.widgets.tree import TreeNode
|
||||
from ..config.constants import DEFAULT_MODEL
|
||||
|
||||
# ANSI escape sequence pattern for stripping control codes from input
|
||||
_ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[mGKHflSTABCDEFsu]|\x1b\].*?\x07|\x1b\[<[0-9;]*[Mm]')
|
||||
_ANSI_ESCAPE = re.compile(
|
||||
r"\x1b\[[0-9;]*[mGKHflSTABCDEFsu]|\x1b\].*?\x07|\x1b\[<[0-9;]*[Mm]"
|
||||
)
|
||||
|
||||
|
||||
# ASCII-safe scrollbar renderer to avoid Unicode glyph issues
|
||||
class ASCIIScrollBarRender(ScrollBarRender):
|
||||
@@ -1044,7 +1047,7 @@ Be concise. Use the actual data from notes."""
|
||||
return
|
||||
|
||||
# Strip ANSI escape sequences and control codes
|
||||
message = _ANSI_ESCAPE.sub('', event.value).strip()
|
||||
message = _ANSI_ESCAPE.sub("", event.value).strip()
|
||||
if not message:
|
||||
return
|
||||
|
||||
|
||||
@@ -234,7 +234,7 @@ class ShadowGraph:
|
||||
product = svc.get("product", "")
|
||||
version = svc.get("version", "")
|
||||
proto = svc.get("protocol", "tcp")
|
||||
|
||||
|
||||
if port:
|
||||
for host_id in target_hosts:
|
||||
service_id = f"service:{host_id}:{port}"
|
||||
@@ -243,10 +243,18 @@ class ShadowGraph:
|
||||
label += f" {product}"
|
||||
if version:
|
||||
label += f" {version}"
|
||||
|
||||
self._add_node(service_id, "service", label, product=product, version=version)
|
||||
self._add_edge(host_id, service_id, "HAS_SERVICE", protocol=proto)
|
||||
|
||||
|
||||
self._add_node(
|
||||
service_id,
|
||||
"service",
|
||||
label,
|
||||
product=product,
|
||||
version=version,
|
||||
)
|
||||
self._add_edge(
|
||||
host_id, service_id, "HAS_SERVICE", protocol=proto
|
||||
)
|
||||
|
||||
# Handle nested endpoints metadata
|
||||
if metadata.get("endpoints"):
|
||||
for ep in metadata["endpoints"]:
|
||||
@@ -258,10 +266,10 @@ class ShadowGraph:
|
||||
label = path
|
||||
if methods:
|
||||
label += f" ({','.join(methods)})"
|
||||
|
||||
|
||||
self._add_node(endpoint_id, "endpoint", label, methods=methods)
|
||||
self._add_edge(host_id, endpoint_id, "HAS_ENDPOINT")
|
||||
|
||||
|
||||
# Handle nested technologies metadata
|
||||
if metadata.get("technologies"):
|
||||
for tech in metadata["technologies"]:
|
||||
@@ -273,12 +281,18 @@ class ShadowGraph:
|
||||
label = name
|
||||
if version and version != "unknown":
|
||||
label += f" {version}"
|
||||
|
||||
self._add_node(tech_id, "technology", label, name=name, version=version)
|
||||
|
||||
self._add_node(
|
||||
tech_id, "technology", label, name=name, version=version
|
||||
)
|
||||
self._add_edge(host_id, tech_id, "USES_TECH")
|
||||
|
||||
|
||||
# If we processed nested metadata, we're done
|
||||
if metadata.get("services") or metadata.get("endpoints") or metadata.get("technologies"):
|
||||
if (
|
||||
metadata.get("services")
|
||||
or metadata.get("endpoints")
|
||||
or metadata.get("technologies")
|
||||
):
|
||||
return
|
||||
|
||||
# Fallback to old port extraction logic
|
||||
@@ -378,8 +392,16 @@ class ShadowGraph:
|
||||
)
|
||||
|
||||
# Insight 2: High Value Targets (Hosts with many open ports/vulns/endpoints)
|
||||
high_value_endpoints = ["admin", "phpmyadmin", "phpMyAdmin", "manager", "console", "webdav", "dav"]
|
||||
|
||||
high_value_endpoints = [
|
||||
"admin",
|
||||
"phpmyadmin",
|
||||
"phpMyAdmin",
|
||||
"manager",
|
||||
"console",
|
||||
"webdav",
|
||||
"dav",
|
||||
]
|
||||
|
||||
for node, data in self.graph.nodes(data=True):
|
||||
if data.get("type") == "host":
|
||||
# Count services
|
||||
@@ -404,7 +426,12 @@ class ShadowGraph:
|
||||
if self.graph.nodes[v].get("type") == "technology"
|
||||
]
|
||||
|
||||
if len(services) > 0 or len(vulns) > 0 or len(endpoints) > 0 or len(technologies) > 0:
|
||||
if (
|
||||
len(services) > 0
|
||||
or len(vulns) > 0
|
||||
or len(endpoints) > 0
|
||||
or len(technologies) > 0
|
||||
):
|
||||
parts = []
|
||||
if len(services) > 0:
|
||||
parts.append(f"{len(services)} services")
|
||||
@@ -414,10 +441,8 @@ class ShadowGraph:
|
||||
parts.append(f"{len(technologies)} technologies")
|
||||
if len(vulns) > 0:
|
||||
parts.append(f"{len(vulns)} vulnerabilities")
|
||||
insights.append(
|
||||
f"Host {data['label']} has {', '.join(parts)}."
|
||||
)
|
||||
|
||||
insights.append(f"Host {data['label']} has {', '.join(parts)}.")
|
||||
|
||||
# Flag high-value endpoints
|
||||
for ep_id in endpoints:
|
||||
ep_label = self.graph.nodes[ep_id].get("label", "")
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Test fixtures for PentestAgent tests."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Generator, AsyncGenerator
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from pentestagent.agents.state import AgentStateManager
|
||||
from pentestagent.config import Settings
|
||||
from pentestagent.agents.state import AgentState, AgentStateManager
|
||||
from pentestagent.tools import get_all_tools, Tool, ToolSchema
|
||||
from pentestagent.tools import Tool, ToolSchema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -82,7 +82,7 @@ def sample_tool() -> Tool:
|
||||
"""Create a sample tool for testing."""
|
||||
async def dummy_execute(arguments: dict, runtime) -> str:
|
||||
return f"Executed with: {arguments}"
|
||||
|
||||
|
||||
return Tool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Tests for the agent state management."""
|
||||
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from pentestagent.agents.state import AgentState, AgentStateManager, StateTransition
|
||||
|
||||
|
||||
class TestAgentState:
|
||||
"""Tests for AgentState enum."""
|
||||
|
||||
|
||||
def test_state_values(self):
|
||||
"""Test state enum values."""
|
||||
assert AgentState.IDLE.value == "idle"
|
||||
@@ -17,7 +17,7 @@ class TestAgentState:
|
||||
assert AgentState.WAITING_INPUT.value == "waiting_input"
|
||||
assert AgentState.COMPLETE.value == "complete"
|
||||
assert AgentState.ERROR.value == "error"
|
||||
|
||||
|
||||
def test_all_states_exist(self):
|
||||
"""Test that all expected states exist."""
|
||||
states = list(AgentState)
|
||||
@@ -26,76 +26,76 @@ class TestAgentState:
|
||||
|
||||
class TestAgentStateManager:
|
||||
"""Tests for AgentStateManager class."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def state_manager(self):
|
||||
"""Create a fresh AgentStateManager for each test."""
|
||||
return AgentStateManager()
|
||||
|
||||
|
||||
def test_initial_state(self, state_manager):
|
||||
"""Test initial state is IDLE."""
|
||||
assert state_manager.current_state == AgentState.IDLE
|
||||
assert len(state_manager.history) == 0
|
||||
|
||||
|
||||
def test_valid_transition(self, state_manager):
|
||||
"""Test valid state transition."""
|
||||
result = state_manager.transition_to(AgentState.THINKING)
|
||||
assert result is True
|
||||
assert state_manager.current_state == AgentState.THINKING
|
||||
assert len(state_manager.history) == 1
|
||||
|
||||
|
||||
def test_invalid_transition(self, state_manager):
|
||||
"""Test invalid state transition."""
|
||||
result = state_manager.transition_to(AgentState.COMPLETE)
|
||||
assert result is False
|
||||
assert state_manager.current_state == AgentState.IDLE
|
||||
|
||||
|
||||
def test_transition_chain(self, state_manager):
|
||||
"""Test a chain of valid transitions."""
|
||||
assert state_manager.transition_to(AgentState.THINKING)
|
||||
assert state_manager.transition_to(AgentState.EXECUTING)
|
||||
assert state_manager.transition_to(AgentState.THINKING)
|
||||
assert state_manager.transition_to(AgentState.COMPLETE)
|
||||
|
||||
|
||||
assert state_manager.current_state == AgentState.COMPLETE
|
||||
assert len(state_manager.history) == 4
|
||||
|
||||
|
||||
def test_force_transition(self, state_manager):
|
||||
"""Test forcing a transition."""
|
||||
state_manager.force_transition(AgentState.ERROR, reason="Test error")
|
||||
assert state_manager.current_state == AgentState.ERROR
|
||||
assert "FORCED" in state_manager.history[-1].reason
|
||||
|
||||
|
||||
def test_reset(self, state_manager):
|
||||
"""Test resetting state."""
|
||||
state_manager.transition_to(AgentState.THINKING)
|
||||
state_manager.transition_to(AgentState.EXECUTING)
|
||||
|
||||
|
||||
state_manager.reset()
|
||||
|
||||
|
||||
assert state_manager.current_state == AgentState.IDLE
|
||||
assert len(state_manager.history) == 0
|
||||
|
||||
|
||||
def test_is_terminal(self, state_manager):
|
||||
"""Test terminal state detection."""
|
||||
assert state_manager.is_terminal() is False
|
||||
|
||||
|
||||
state_manager.transition_to(AgentState.THINKING)
|
||||
state_manager.transition_to(AgentState.COMPLETE)
|
||||
|
||||
|
||||
assert state_manager.is_terminal() is True
|
||||
|
||||
|
||||
def test_is_active(self, state_manager):
|
||||
"""Test active state detection."""
|
||||
assert state_manager.is_active() is False
|
||||
|
||||
|
||||
state_manager.transition_to(AgentState.THINKING)
|
||||
assert state_manager.is_active() is True
|
||||
|
||||
|
||||
class TestStateTransition:
|
||||
"""Tests for StateTransition dataclass."""
|
||||
|
||||
|
||||
def test_create_transition(self):
|
||||
"""Test creating a state transition."""
|
||||
transition = StateTransition(
|
||||
@@ -103,7 +103,7 @@ class TestStateTransition:
|
||||
to_state=AgentState.THINKING,
|
||||
reason="Starting work"
|
||||
)
|
||||
|
||||
|
||||
assert transition.from_state == AgentState.IDLE
|
||||
assert transition.to_state == AgentState.THINKING
|
||||
assert transition.reason == "Starting work"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Tests for the Shadow Graph knowledge system."""
|
||||
|
||||
import pytest
|
||||
import networkx as nx
|
||||
from pentestagent.knowledge.graph import ShadowGraph, GraphNode, GraphEdge
|
||||
import pytest
|
||||
|
||||
from pentestagent.knowledge.graph import ShadowGraph
|
||||
|
||||
|
||||
class TestShadowGraph:
|
||||
"""Tests for ShadowGraph class."""
|
||||
@@ -27,7 +29,7 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
assert graph.graph.has_node("host:192.168.1.10")
|
||||
node = graph.graph.nodes["host:192.168.1.10"]
|
||||
assert node["type"] == "host"
|
||||
@@ -42,14 +44,14 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
# Check host exists
|
||||
assert graph.graph.has_node("host:10.0.0.5")
|
||||
|
||||
|
||||
# Check services exist
|
||||
assert graph.graph.has_node("service:host:10.0.0.5:80")
|
||||
assert graph.graph.has_node("service:host:10.0.0.5:443")
|
||||
|
||||
|
||||
# Check edges
|
||||
assert graph.graph.has_edge("host:10.0.0.5", "service:host:10.0.0.5:80")
|
||||
edge = graph.graph.edges["host:10.0.0.5", "service:host:10.0.0.5:80"]
|
||||
@@ -65,13 +67,13 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
cred_id = "cred:ssh_creds"
|
||||
host_id = "host:192.168.1.20"
|
||||
|
||||
|
||||
assert graph.graph.has_node(cred_id)
|
||||
assert graph.graph.has_node(host_id)
|
||||
|
||||
|
||||
# Check edge
|
||||
assert graph.graph.has_edge(cred_id, host_id)
|
||||
edge = graph.graph.edges[cred_id, host_id]
|
||||
@@ -91,11 +93,11 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
# Check "Username: root" extraction
|
||||
node1 = graph.graph.nodes["cred:creds_1"]
|
||||
assert node1["label"] == "Creds (root)"
|
||||
|
||||
|
||||
# Check fallback for no username
|
||||
node2 = graph.graph.nodes["cred:creds_2"]
|
||||
assert node2["label"] == "Credentials"
|
||||
@@ -122,19 +124,19 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
# Check Credential Metadata
|
||||
cred_node = graph.graph.nodes["cred:meta_cred"]
|
||||
assert cred_node["label"] == "Creds (admin_meta)"
|
||||
|
||||
|
||||
# Check Target Host
|
||||
assert graph.graph.has_node("host:10.0.0.99")
|
||||
assert graph.graph.has_edge("cred:meta_cred", "host:10.0.0.99")
|
||||
|
||||
|
||||
# Check Source Host (CONTAINS edge)
|
||||
assert graph.graph.has_node("host:10.0.0.1")
|
||||
assert graph.graph.has_edge("host:10.0.0.1", "cred:meta_cred")
|
||||
|
||||
|
||||
# Check Vulnerability Metadata
|
||||
vuln_node = graph.graph.nodes["vuln:meta_vuln"]
|
||||
assert vuln_node["label"] == "CVE-2025-1234"
|
||||
@@ -154,7 +156,7 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
service_id = "service:host:10.0.0.5:80"
|
||||
assert graph.graph.has_node(service_id)
|
||||
node = graph.graph.nodes[service_id]
|
||||
@@ -166,7 +168,7 @@ class TestShadowGraph:
|
||||
"legacy_note": "Just a simple note about 10.10.10.10"
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
assert graph.graph.has_node("host:10.10.10.10")
|
||||
|
||||
def test_idempotency(self, graph):
|
||||
@@ -177,11 +179,11 @@ class TestShadowGraph:
|
||||
"category": "info"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# First pass
|
||||
graph.update_from_notes(notes)
|
||||
assert len(graph.graph.nodes) == 1
|
||||
|
||||
|
||||
# Second pass
|
||||
graph.update_from_notes(notes)
|
||||
assert len(graph.graph.nodes) == 1
|
||||
@@ -193,15 +195,15 @@ class TestShadowGraph:
|
||||
graph._add_node("cred:1", "credential", "Root Creds")
|
||||
graph._add_node("host:A", "host", "10.0.0.1")
|
||||
graph._add_edge("cred:1", "host:A", "AUTH_ACCESS")
|
||||
|
||||
|
||||
# 2. HostA has Cred2 (this edge type isn't auto-extracted yet, but logic should handle it)
|
||||
graph._add_node("cred:2", "credential", "Db Admin")
|
||||
graph._add_edge("host:A", "cred:2", "CONTAINS_CRED")
|
||||
|
||||
|
||||
# 3. Cred2 gives access to HostB
|
||||
graph._add_node("host:B", "host", "10.0.0.2")
|
||||
graph._add_edge("cred:2", "host:B", "AUTH_ACCESS")
|
||||
|
||||
|
||||
paths = graph._find_attack_paths()
|
||||
assert len(paths) == 1
|
||||
assert "Root Creds" in paths[0]
|
||||
@@ -214,7 +216,7 @@ class TestShadowGraph:
|
||||
graph._add_node("host:1", "host", "10.0.0.1")
|
||||
graph._add_node("cred:1", "credential", "admin")
|
||||
graph._add_edge("cred:1", "host:1", "AUTH_ACCESS")
|
||||
|
||||
|
||||
mermaid = graph.to_mermaid()
|
||||
assert "graph TD" in mermaid
|
||||
assert 'host_1["🖥️ 10.0.0.1"]' in mermaid
|
||||
@@ -230,7 +232,7 @@ class TestShadowGraph:
|
||||
}
|
||||
}
|
||||
graph.update_from_notes(notes)
|
||||
|
||||
|
||||
assert graph.graph.has_node("host:192.168.1.1")
|
||||
assert graph.graph.has_node("host:192.168.1.2")
|
||||
assert graph.graph.has_node("host:192.168.1.3")
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
"""Tests for the RAG knowledge system."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from pentestagent.knowledge.rag import RAGEngine, Document
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pentestagent.knowledge.rag import Document, RAGEngine
|
||||
|
||||
|
||||
class TestDocument:
|
||||
"""Tests for Document dataclass."""
|
||||
|
||||
|
||||
def test_create_document(self):
|
||||
"""Test creating a document."""
|
||||
doc = Document(content="Test content", source="test.md")
|
||||
@@ -18,7 +17,7 @@ class TestDocument:
|
||||
assert doc.source == "test.md"
|
||||
assert doc.metadata == {}
|
||||
assert doc.doc_id is not None
|
||||
|
||||
|
||||
def test_document_with_metadata(self):
|
||||
"""Test document with metadata."""
|
||||
doc = Document(
|
||||
@@ -28,14 +27,14 @@ class TestDocument:
|
||||
)
|
||||
assert doc.metadata["cve_id"] == "CVE-2021-1234"
|
||||
assert doc.metadata["severity"] == "high"
|
||||
|
||||
|
||||
def test_document_with_embedding(self):
|
||||
"""Test document with embedding."""
|
||||
embedding = np.random.rand(384)
|
||||
doc = Document(content="Test", source="test.md", embedding=embedding)
|
||||
assert doc.embedding is not None
|
||||
assert len(doc.embedding) == 384
|
||||
|
||||
|
||||
def test_document_with_custom_id(self):
|
||||
"""Test document with custom doc_id."""
|
||||
doc = Document(content="Test", source="test.md", doc_id="custom-id-123")
|
||||
@@ -44,7 +43,7 @@ class TestDocument:
|
||||
|
||||
class TestRAGEngine:
|
||||
"""Tests for RAGEngine class."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rag_engine(self, tmp_path):
|
||||
"""Create a RAG engine for testing."""
|
||||
@@ -52,49 +51,49 @@ class TestRAGEngine:
|
||||
knowledge_path=tmp_path / "knowledge",
|
||||
use_local_embeddings=True
|
||||
)
|
||||
|
||||
|
||||
def test_create_engine(self, rag_engine):
|
||||
"""Test creating a RAG engine."""
|
||||
assert rag_engine is not None
|
||||
assert len(rag_engine.documents) == 0
|
||||
assert rag_engine.embeddings is None
|
||||
|
||||
|
||||
def test_get_document_count_empty(self, rag_engine):
|
||||
"""Test document count on empty engine."""
|
||||
assert rag_engine.get_document_count() == 0
|
||||
|
||||
|
||||
def test_clear(self, rag_engine):
|
||||
"""Test clearing the engine."""
|
||||
rag_engine.documents.append(Document(content="test", source="test.md"))
|
||||
rag_engine.embeddings = np.random.rand(1, 384)
|
||||
rag_engine._indexed = True
|
||||
|
||||
|
||||
rag_engine.clear()
|
||||
|
||||
|
||||
assert len(rag_engine.documents) == 0
|
||||
assert rag_engine.embeddings is None
|
||||
assert rag_engine._indexed == False
|
||||
assert not rag_engine._indexed
|
||||
|
||||
|
||||
class TestRAGEngineChunking:
|
||||
"""Tests for text chunking functionality."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, tmp_path):
|
||||
"""Create engine for chunking tests."""
|
||||
return RAGEngine(knowledge_path=tmp_path)
|
||||
|
||||
|
||||
def test_chunk_short_text(self, engine):
|
||||
"""Test chunking text shorter than chunk size."""
|
||||
text = "This is a short paragraph.\n\nThis is another paragraph."
|
||||
chunks = engine._chunk_text(text, source="test.md", chunk_size=1000)
|
||||
|
||||
|
||||
assert len(chunks) >= 1
|
||||
assert all(isinstance(c, Document) for c in chunks)
|
||||
|
||||
|
||||
def test_chunk_preserves_source(self, engine):
|
||||
"""Test that chunking preserves source information."""
|
||||
text = "Test paragraph 1.\n\nTest paragraph 2."
|
||||
chunks = engine._chunk_text(text, source="my_source.md")
|
||||
|
||||
|
||||
assert all(c.source == "my_source.md" for c in chunks)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
"""Tests for the Notes tool."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pentestagent.tools.notes import notes, set_notes_file, get_all_notes, _notes
|
||||
import pytest
|
||||
|
||||
from pentestagent.tools.notes import _notes, get_all_notes, notes, set_notes_file
|
||||
|
||||
|
||||
# We need to reset the global state for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -15,14 +14,13 @@ def reset_notes_state(tmp_path):
|
||||
# Point to a temp file
|
||||
temp_notes_file = tmp_path / "notes.json"
|
||||
set_notes_file(temp_notes_file)
|
||||
|
||||
|
||||
# Clear the global dictionary (it's imported from the module)
|
||||
# We need to clear the actual dictionary object in the module
|
||||
from pentestagent.tools.notes import _notes
|
||||
_notes.clear()
|
||||
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# Cleanup is handled by tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -35,10 +33,10 @@ async def test_create_note():
|
||||
"category": "info",
|
||||
"confidence": "high"
|
||||
}
|
||||
|
||||
|
||||
result = await notes(args, runtime=None)
|
||||
assert "Created note 'test_note'" in result
|
||||
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert "test_note" in all_notes
|
||||
assert all_notes["test_note"]["content"] == "This is a test note"
|
||||
@@ -54,13 +52,13 @@ async def test_read_note():
|
||||
"key": "read_me",
|
||||
"value": "Content to read"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
# Read
|
||||
result = await notes({
|
||||
"action": "read",
|
||||
"key": "read_me"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
assert "Content to read" in result
|
||||
# The format is "[key] (category, confidence, status) content"
|
||||
assert "(info, medium, confirmed)" in result
|
||||
@@ -73,15 +71,15 @@ async def test_update_note():
|
||||
"key": "update_me",
|
||||
"value": "Original content"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
result = await notes({
|
||||
"action": "update",
|
||||
"key": "update_me",
|
||||
"value": "New content"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
assert "Updated note 'update_me'" in result
|
||||
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert all_notes["update_me"]["content"] == "New content"
|
||||
|
||||
@@ -93,14 +91,14 @@ async def test_delete_note():
|
||||
"key": "delete_me",
|
||||
"value": "Bye bye"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
result = await notes({
|
||||
"action": "delete",
|
||||
"key": "delete_me"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
assert "Deleted note 'delete_me'" in result
|
||||
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
assert "delete_me" not in all_notes
|
||||
|
||||
@@ -109,9 +107,9 @@ async def test_list_notes():
|
||||
"""Test listing all notes."""
|
||||
await notes({"action": "create", "key": "n1", "value": "v1"}, runtime=None)
|
||||
await notes({"action": "create", "key": "n2", "value": "v2"}, runtime=None)
|
||||
|
||||
|
||||
result = await notes({"action": "list"}, runtime=None)
|
||||
|
||||
|
||||
assert "n1" in result
|
||||
assert "n2" in result
|
||||
assert "Notes (2 entries):" in result
|
||||
@@ -121,13 +119,13 @@ async def test_persistence(tmp_path):
|
||||
"""Test that notes are saved to disk."""
|
||||
# The fixture already sets a temp file
|
||||
temp_file = tmp_path / "notes.json"
|
||||
|
||||
|
||||
await notes({
|
||||
"action": "create",
|
||||
"key": "persistent_note",
|
||||
"value": "I survive restarts"
|
||||
}, runtime=None)
|
||||
|
||||
|
||||
assert temp_file.exists()
|
||||
content = json.loads(temp_file.read_text())
|
||||
assert "persistent_note" in content
|
||||
@@ -143,20 +141,19 @@ async def test_legacy_migration(tmp_path):
|
||||
"new_note": {"content": "A dict", "category": "info"}
|
||||
}
|
||||
legacy_file.write_text(json.dumps(legacy_data))
|
||||
|
||||
|
||||
# Point the tool to this file
|
||||
set_notes_file(legacy_file)
|
||||
|
||||
|
||||
# Trigger load (get_all_notes calls _load_notes_unlocked if empty, but we need to clear first)
|
||||
from pentestagent.tools.notes import _notes
|
||||
_notes.clear()
|
||||
|
||||
|
||||
all_notes = await get_all_notes()
|
||||
|
||||
|
||||
assert "old_note" in all_notes
|
||||
assert isinstance(all_notes["old_note"], dict)
|
||||
assert all_notes["old_note"]["content"] == "Just a string"
|
||||
assert all_notes["old_note"]["category"] == "info"
|
||||
|
||||
|
||||
assert "new_note" in all_notes
|
||||
assert all_notes["new_note"]["content"] == "A dict"
|
||||
|
||||
@@ -3,49 +3,54 @@
|
||||
import pytest
|
||||
|
||||
from pentestagent.tools import (
|
||||
Tool, ToolSchema, register_tool, get_all_tools, get_tool,
|
||||
enable_tool, disable_tool, get_tool_names
|
||||
ToolSchema,
|
||||
disable_tool,
|
||||
enable_tool,
|
||||
get_all_tools,
|
||||
get_tool,
|
||||
get_tool_names,
|
||||
register_tool,
|
||||
)
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Tests for tool registry functions."""
|
||||
|
||||
|
||||
def test_tools_loaded(self):
|
||||
"""Test that built-in tools are loaded."""
|
||||
tools = get_all_tools()
|
||||
assert len(tools) > 0
|
||||
|
||||
|
||||
tool_names = get_tool_names()
|
||||
assert "terminal" in tool_names
|
||||
assert "browser" in tool_names
|
||||
|
||||
|
||||
def test_get_tool(self):
|
||||
"""Test getting a tool by name."""
|
||||
tool = get_tool("terminal")
|
||||
assert tool is not None
|
||||
assert tool.name == "terminal"
|
||||
assert tool.category == "execution"
|
||||
|
||||
|
||||
def test_get_nonexistent_tool(self):
|
||||
"""Test getting a tool that doesn't exist."""
|
||||
tool = get_tool("nonexistent_tool_xyz")
|
||||
assert tool is None
|
||||
|
||||
|
||||
def test_disable_enable_tool(self):
|
||||
"""Test disabling and enabling a tool."""
|
||||
result = disable_tool("terminal")
|
||||
assert result is True
|
||||
|
||||
|
||||
tool = get_tool("terminal")
|
||||
assert tool.enabled is False
|
||||
|
||||
|
||||
result = enable_tool("terminal")
|
||||
assert result is True
|
||||
|
||||
|
||||
tool = get_tool("terminal")
|
||||
assert tool.enabled is True
|
||||
|
||||
|
||||
def test_disable_nonexistent_tool(self):
|
||||
"""Test disabling a tool that doesn't exist."""
|
||||
result = disable_tool("nonexistent_tool_xyz")
|
||||
@@ -54,7 +59,7 @@ class TestToolRegistry:
|
||||
|
||||
class TestToolSchema:
|
||||
"""Tests for ToolSchema class."""
|
||||
|
||||
|
||||
def test_create_schema(self):
|
||||
"""Test creating a tool schema."""
|
||||
schema = ToolSchema(
|
||||
@@ -63,18 +68,18 @@ class TestToolSchema:
|
||||
},
|
||||
required=["command"]
|
||||
)
|
||||
|
||||
|
||||
assert schema.type == "object"
|
||||
assert "command" in schema.properties
|
||||
assert "command" in schema.required
|
||||
|
||||
|
||||
def test_schema_to_dict(self):
|
||||
"""Test converting schema to dictionary."""
|
||||
schema = ToolSchema(
|
||||
properties={"input": {"type": "string"}},
|
||||
required=["input"]
|
||||
)
|
||||
|
||||
|
||||
d = schema.to_dict()
|
||||
assert d["type"] == "object"
|
||||
assert d["properties"]["input"]["type"] == "string"
|
||||
@@ -83,33 +88,33 @@ class TestToolSchema:
|
||||
|
||||
class TestTool:
|
||||
"""Tests for Tool class."""
|
||||
|
||||
|
||||
def test_create_tool(self, sample_tool):
|
||||
"""Test creating a tool."""
|
||||
assert sample_tool.name == "test_tool"
|
||||
assert sample_tool.description == "A test tool"
|
||||
assert sample_tool.category == "test"
|
||||
assert sample_tool.enabled is True
|
||||
|
||||
|
||||
def test_tool_to_llm_format(self, sample_tool):
|
||||
"""Test converting tool to LLM format."""
|
||||
llm_format = sample_tool.to_llm_format()
|
||||
|
||||
|
||||
assert llm_format["type"] == "function"
|
||||
assert llm_format["function"]["name"] == "test_tool"
|
||||
assert llm_format["function"]["description"] == "A test tool"
|
||||
assert "parameters" in llm_format["function"]
|
||||
|
||||
|
||||
def test_tool_validate_arguments(self, sample_tool):
|
||||
"""Test argument validation."""
|
||||
is_valid, error = sample_tool.validate_arguments({"param": "value"})
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
|
||||
is_valid, error = sample_tool.validate_arguments({})
|
||||
assert is_valid is False
|
||||
assert "param" in error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execute(self, sample_tool):
|
||||
"""Test tool execution."""
|
||||
@@ -119,11 +124,11 @@ class TestTool:
|
||||
|
||||
class TestRegisterToolDecorator:
|
||||
"""Tests for register_tool decorator."""
|
||||
|
||||
|
||||
def test_decorator_registers_tool(self):
|
||||
"""Test that decorator registers a new tool."""
|
||||
initial_count = len(get_all_tools())
|
||||
|
||||
|
||||
@register_tool(
|
||||
name="pytest_test_tool_unique",
|
||||
description="A tool registered in tests",
|
||||
@@ -132,10 +137,10 @@ class TestRegisterToolDecorator:
|
||||
)
|
||||
async def pytest_test_tool_unique(arguments, runtime):
|
||||
return "test result"
|
||||
|
||||
|
||||
new_count = len(get_all_tools())
|
||||
assert new_count == initial_count + 1
|
||||
|
||||
|
||||
tool = get_tool("pytest_test_tool_unique")
|
||||
assert tool is not None
|
||||
assert tool.name == "pytest_test_tool_unique"
|
||||
|
||||
Reference in New Issue
Block a user