Files
pentestagent/ghostcrew/knowledge/rag.py
2025-12-07 09:11:26 -07:00

440 lines
14 KiB
Python

"""RAG (Retrieval Augmented Generation) engine for GhostCrew."""
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

from .embeddings import get_embeddings
@dataclass
class Document:
    """A chunk of knowledge.

    Attributes:
        content: The chunk text.
        source: Identifier of where the chunk came from (file path or "user").
        metadata: Arbitrary extra data about the chunk (defaults to {}).
        embedding: Vector for this chunk, filled in by the RAG engine.
        doc_id: Stable identifier; derived from content+source if not given.
    """

    content: str
    source: str
    metadata: Optional[Dict[str, Any]] = None
    embedding: Optional[np.ndarray] = None
    doc_id: Optional[str] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}
        if self.doc_id is None:
            # Fix: use a content digest instead of hash(). Python string
            # hashes are salted per-process (PYTHONHASHSEED), so hash()-based
            # IDs change between runs and break doc_ids persisted via
            # save_index()/load_index(). sha256 is deterministic everywhere.
            digest = hashlib.sha256(
                f"{self.source}\x00{self.content}".encode("utf-8")
            ).hexdigest()
            self.doc_id = digest[:32]
class RAGEngine:
    """Vector search over security knowledge.

    Indexes text/markdown/JSON files under a knowledge directory, embeds the
    resulting chunks, and answers queries via cosine similarity. The
    `documents` list and the `embeddings` matrix are kept row-aligned at all
    times.
    """

    def __init__(
        self,
        knowledge_path: Path = Path("knowledge"),
        embedding_model: str = "text-embedding-3-small",
        use_local_embeddings: bool = False,
    ):
        """
        Initialize the RAG engine.

        Args:
            knowledge_path: Path to the knowledge directory
            embedding_model: Model to use for embeddings
            use_local_embeddings: Whether to use local embeddings (sentence-transformers)
        """
        self.knowledge_path = knowledge_path
        self.embedding_model = embedding_model
        self.use_local_embeddings = use_local_embeddings
        self.documents: List[Document] = []
        self.embeddings: Optional[np.ndarray] = None
        self._indexed = False
        self._source_files: set = set()  # Track unique source files

    def _embed(self, texts: List[str]) -> np.ndarray:
        """
        Embed a batch of texts with the configured backend.

        Centralizes the local-vs-remote branch previously duplicated in
        index(), search(), search_with_scores(), add_document() and
        add_documents().

        Args:
            texts: Texts to embed

        Returns:
            Array of embedding vectors, one row per input text
        """
        if self.use_local_embeddings:
            from .embeddings import get_embeddings_local

            vectors = get_embeddings_local(texts)
        else:
            vectors = get_embeddings(texts, model=self.embedding_model)
        # Normalize to an ndarray so vstack/delete/dot downstream are safe
        # even if a backend returns a plain list of vectors.
        return np.asarray(vectors)

    def index(self, force: bool = False):
        """
        Index all documents in knowledge directory.

        Args:
            force: Force re-indexing even if already indexed
        """
        if self._indexed and not force:
            return
        chunks = []
        self._source_files = set()  # Reset source file tracking
        # Process all files in knowledge directory
        if self.knowledge_path.exists():
            for file in self.knowledge_path.rglob("*"):
                if not file.is_file():
                    continue
                try:
                    if file.suffix in [".txt", ".md"]:
                        self._source_files.add(str(file))
                        content = file.read_text(encoding="utf-8", errors="ignore")
                        chunks.extend(self._chunk_text(content, source=str(file)))
                    elif file.suffix == ".json":
                        self._source_files.add(str(file))
                        data = json.loads(file.read_text(encoding="utf-8"))
                        # A JSON array yields one Document per element;
                        # any other JSON value yields a single Document.
                        items = data if isinstance(data, list) else [data]
                        for item in items:
                            chunks.append(
                                Document(
                                    content=json.dumps(item, indent=2),
                                    source=str(file),
                                    metadata=(
                                        item
                                        if isinstance(item, dict)
                                        else {"data": item}
                                    ),
                                )
                            )
                except Exception as e:
                    # Best-effort: one unreadable file must not abort
                    # indexing of the rest of the knowledge base.
                    print(f"[RAG] Error processing {file}: {e}")
        self.documents = chunks
        if chunks:
            # Generate embeddings and mirror them onto the documents.
            self.embeddings = self._embed([doc.content for doc in chunks])
            for i, doc in enumerate(self.documents):
                doc.embedding = self.embeddings[i]
        else:
            # Fix: a (forced) re-index that finds nothing previously left a
            # stale embeddings matrix beside an empty document list.
            self.embeddings = None
        self._indexed = True

    def _chunk_text(
        self, text: str, source: str, chunk_size: int = 1000, overlap: int = 200
    ) -> List[Document]:
        """
        Split text into overlapping chunks.

        Paragraph-based splitting is tried first for better context; a
        fixed-size character window is the fallback. NOTE: a single paragraph
        longer than chunk_size is kept intact (chunks may exceed chunk_size).

        Args:
            text: The text to split
            source: The source file path
            chunk_size: Maximum chunk size in characters
            overlap: Overlap between chunks

        Returns:
            List of Document objects
        """
        chunks: List[Document] = []
        # Split by paragraphs first for better context
        current = ""
        for para in text.split("\n\n"):
            if len(current) + len(para) + 2 <= chunk_size:
                current += para + "\n\n"
            else:
                if current.strip():
                    chunks.append(Document(content=current.strip(), source=source))
                current = para + "\n\n"
        # Add the last chunk
        if current.strip():
            chunks.append(Document(content=current.strip(), source=source))
        # If no paragraphs were found, fall back to simple chunking
        if not chunks and text.strip():
            # Fix: clamp the step to >= 1 so overlap >= chunk_size can no
            # longer cause an infinite loop.
            step = max(chunk_size - overlap, 1)
            for start in range(0, len(text), step):
                piece = text[start : start + chunk_size].strip()
                if piece:
                    chunks.append(Document(content=piece, source=source))
        return chunks

    def _similarities(self, query: str) -> np.ndarray:
        """Cosine similarity of the query against every indexed embedding."""
        query_embedding = self._embed([query])[0]
        # 1e-10 guards against division by zero for all-zero vectors.
        return np.dot(self.embeddings, query_embedding) / (
            np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
            + 1e-10
        )

    @staticmethod
    def _top_k(similarities: np.ndarray, k: int, threshold: float):
        """
        Indices of the top-k similarities >= threshold, best first.

        Returns:
            Array of at most k indices, or None when nothing clears the
            threshold (callers decide their own fallback).
        """
        candidates = np.where(similarities >= threshold)[0]
        if len(candidates) == 0:
            return None
        ranked = candidates[np.argsort(similarities[candidates])[::-1]]
        return ranked[:k]

    def search(
        self, query: str, k: int = 5, threshold: float = 0.35, max_tokens: int = 1500
    ) -> List[str]:
        """
        Find relevant documents for a query.

        Args:
            query: The search query
            k: Maximum number of results to return
            threshold: Minimum similarity threshold
            max_tokens: Maximum total tokens to return (prevents context bloat)

        Returns:
            List of relevant document contents
        """
        # Guard against empty/invalid queries
        if not query or not isinstance(query, str) or not query.strip():
            return []
        if not self._indexed:
            self.index()
        if not self.documents or self.embeddings is None:
            return []
        similarities = self._similarities(query)
        top_indices = self._top_k(similarities, k, threshold)
        if top_indices is None:
            # No results above threshold - return empty rather than
            # irrelevant content (search_with_scores deliberately differs).
            return []
        # Collect results up to max_tokens budget
        results: List[str] = []
        total_tokens = 0
        for idx in top_indices:
            content = self.documents[idx].content
            # Rough token estimate: ~4 chars per token
            chunk_tokens = len(content) // 4
            if total_tokens + chunk_tokens > max_tokens and results:
                # Stop if we'd exceed budget (but always include at least one)
                break
            results.append(content)
            total_tokens += chunk_tokens
        return results

    def search_with_scores(
        self, query: str, k: int = 5, threshold: float = 0.35
    ) -> List[tuple[Document, float]]:
        """
        Search with similarity scores.

        Unlike search(), when nothing clears the threshold this returns the
        single best match anyway, so callers can inspect the score.

        Args:
            query: The search query
            k: Maximum number of results to return
            threshold: Minimum similarity threshold

        Returns:
            List of (Document, score) tuples above threshold
        """
        # Consistency fix: guard empty/invalid queries like search() does,
        # instead of embedding an empty string.
        if not query or not isinstance(query, str) or not query.strip():
            return []
        if not self._indexed:
            self.index()
        if not self.documents or self.embeddings is None:
            return []
        similarities = self._similarities(query)
        top_indices = self._top_k(similarities, k, threshold)
        if top_indices is None:
            # Fallback: return single best result even if below threshold
            top_indices = [int(np.argmax(similarities))]
        return [(self.documents[i], float(similarities[i])) for i in top_indices]

    def add_document(
        self, content: str, source: str = "user", metadata: Optional[dict] = None
    ):
        """
        Add a document to the knowledge base.

        Args:
            content: The document content
            source: The source identifier
            metadata: Optional metadata
        """
        # Delegate to the batch path so embedding/bookkeeping logic lives
        # in exactly one place.
        self.add_documents([Document(content=content, source=source, metadata=metadata)])

    def add_documents(self, documents: List[Document]):
        """
        Add multiple documents to the knowledge base.

        Args:
            documents: List of Document objects to add
        """
        if not documents:
            return
        new_embeddings = self._embed([doc.content for doc in documents])
        for doc, vector in zip(documents, new_embeddings):
            doc.embedding = vector
            self.documents.append(doc)
        if self.embeddings is not None:
            self.embeddings = np.vstack([self.embeddings, new_embeddings])
        else:
            self.embeddings = new_embeddings

    def remove_document(self, doc_id: str) -> bool:
        """
        Remove a document by ID.

        Args:
            doc_id: The document ID to remove

        Returns:
            True if removed, False if not found
        """
        for i, doc in enumerate(self.documents):
            if doc.doc_id == doc_id:
                self.documents.pop(i)
                if self.embeddings is not None:
                    # Keep the embedding matrix row-aligned with documents.
                    self.embeddings = np.delete(self.embeddings, i, axis=0)
                return True
        return False

    def clear(self):
        """Clear all documents and embeddings."""
        self.documents.clear()
        self.embeddings = None
        self._indexed = False
        self._source_files = set()

    def get_document_count(self) -> int:
        """Get the number of source files indexed."""
        return len(self._source_files)

    def get_chunk_count(self) -> int:
        """Get the number of indexed chunks (internal document segments)."""
        return len(self.documents)

    def save_index(self, path: Path):
        """
        Save the index to disk.

        Args:
            path: Path to save the index
        """
        import pickle

        data = {
            "documents": [
                {
                    "content": doc.content,
                    "source": doc.source,
                    "metadata": doc.metadata,
                    "doc_id": doc.doc_id,
                }
                for doc in self.documents
            ],
            "embeddings": self.embeddings,
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)

    def load_index(self, path: Path):
        """
        Load the index from disk.

        Args:
            path: Path to load the index from
        """
        import pickle

        # SECURITY: pickle.load executes arbitrary code from the file.
        # Only load index files this application wrote itself; never load
        # an index from an untrusted source.
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.documents = [
            Document(
                content=d["content"],
                source=d["source"],
                metadata=d["metadata"],
                doc_id=d["doc_id"],
            )
            for d in data["documents"]
        ]
        self.embeddings = data["embeddings"]
        # Fix: rebuild source tracking so get_document_count() is correct
        # after a load (it previously stayed at 0 until the next index()).
        self._source_files = {doc.source for doc in self.documents}
        # Restore embeddings in documents
        if self.embeddings is not None:
            for i, doc in enumerate(self.documents):
                doc.embedding = self.embeddings[i]
        self._indexed = True