From feea508ff9c3a84c2e15374b0d98b5e50feda5b7 Mon Sep 17 00:00:00 2001 From: GH05TCREW Date: Mon, 22 Dec 2025 16:41:19 -0700 Subject: [PATCH] test: fix graph tests --- pentestagent/knowledge/graph.py | 34 +++++++++++++++++++--------- pentestagent/tools/notes/__init__.py | 12 +++++----- tests/test_graph.py | 17 ++++++++++++-- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/pentestagent/knowledge/graph.py b/pentestagent/knowledge/graph.py index da0d6c1..27e8f75 100644 --- a/pentestagent/knowledge/graph.py +++ b/pentestagent/knowledge/graph.py @@ -128,18 +128,30 @@ class ShadowGraph: # 2. Process structured metadata regardless of category # Category is organizational only - the metadata structure determines graph entities - + # Process credential data if present - if category == "credential" or metadata.get("username") or metadata.get("password"): + if ( + category == "credential" + or metadata.get("username") + or metadata.get("password") + ): self._process_credential(key, content, hosts, metadata, status) - + # Process services/endpoints/technologies if present - if (metadata.get("services") or metadata.get("endpoints") or - metadata.get("technologies") or metadata.get("port")): + if ( + metadata.get("services") + or metadata.get("endpoints") + or metadata.get("technologies") + or metadata.get("port") + ): self._process_services_and_tech(key, content, hosts, metadata, status) - + # Process vulnerability data if present - if category == "vulnerability" or metadata.get("cve") or metadata.get("weaknesses"): + if ( + category == "vulnerability" + or metadata.get("cve") + or metadata.get("weaknesses") + ): self._process_vulnerability(key, content, hosts, metadata, status) # 3. Link note to hosts (provenance) @@ -223,7 +235,7 @@ class ShadowGraph: port = svc.get("port") if not port: continue - + product = svc.get("product", "") version = svc.get("version", "") proto = svc.get("protocol", "tcp") @@ -251,7 +263,7 @@ class ShadowGraph: path = ep.get("path") if not path: continue - + methods = ep.get("methods", []) for host_id in target_hosts: endpoint_id = f"endpoint:{host_id}:{path}" @@ -268,7 +280,7 @@ class ShadowGraph: name = tech.get("name") if not name: continue - + version = tech.get("version", "") for host_id in target_hosts: tech_id = f"tech:{host_id}:{name}" @@ -287,7 +299,7 @@ class ShadowGraph: proto = "tcp" if "/" in port_str: port_str, proto = port_str.split("/") - + for host_id in target_hosts: service_id = f"service:{host_id}:{port_str}" label = f"{port_str}/{proto}" diff --git a/pentestagent/tools/notes/__init__.py b/pentestagent/tools/notes/__init__.py index 651262f..4a3883e 100644 --- a/pentestagent/tools/notes/__init__.py +++ b/pentestagent/tools/notes/__init__.py @@ -106,33 +106,33 @@ CATEGORY_REQUIREMENTS = { def _validate_note_schema(category: str, metadata: Dict[str, Any]) -> str | None: """ Validate note schema based on declarative rules. - + Returns: Error message if validation fails, None if valid """ # Check if note has host-specific structured data has_host_data = bool(HOST_SPECIFIC_FIELDS & metadata.keys()) - + # If note has host-specific data, require target if has_host_data and not metadata.get("target"): fields = ", ".join(f"'{f}'" for f in HOST_SPECIFIC_FIELDS if f in metadata) return f"Error: 'target' field is required when providing host-specific data ({fields})." - + # Apply category-specific validation rules if category in CATEGORY_REQUIREMENTS: rules = CATEGORY_REQUIREMENTS[category] - + # Check required fields for field in rules.get("required", []): if not metadata.get(field): return f"Error: '{field}' field is required for category '{category}'." - + # Check one_of constraints (at least one field from each group must be present) for field_group in rules.get("one_of", []): if not any(metadata.get(field) for field in field_group): field_list = "' or '".join(field_group) return f"Error: At least one of '{field_list}' is required for category '{category}'." - + return None diff --git a/tests/test_graph.py b/tests/test_graph.py index ef5420f..6d7b8ad 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -40,7 +40,14 @@ class TestShadowGraph: notes = { "ports_scan": { "content": "Found open ports: 80/tcp, 443/tcp on 10.0.0.5", - "category": "finding" + "category": "finding", + "metadata": { + "target": "10.0.0.5", + "services": [ + {"port": 80, "protocol": "tcp", "service": "http"}, + {"port": 443, "protocol": "tcp", "service": "https"} + ] + } } } graph.update_from_notes(notes) @@ -63,7 +70,13 @@ class TestShadowGraph: notes = { "ssh_creds": { "content": "Found user: admin with password 'password123' for SSH on 192.168.1.20", - "category": "credential" + "category": "credential", + "metadata": { + "target": "192.168.1.20", + "username": "admin", + "password": "password123", + "protocol": "ssh" + } } } graph.update_from_notes(notes)