pentestagent/ghostcrew/knowledge/embeddings.py
"""Embedding generation for GhostCrew."""
from typing import List, Optional
import numpy as np
def get_embeddings(
texts: List[str], model: str = "text-embedding-3-small"
) -> np.ndarray:
"""
Generate embeddings for a list of texts using LiteLLM.
Args:
texts: List of texts to embed
model: The embedding model to use
Returns:
NumPy array of embeddings
"""
try:
import litellm
response = litellm.embedding(model=model, input=texts)
embeddings = [item["embedding"] for item in response.data]
return np.array(embeddings, dtype=np.float32)
except ImportError as e:
raise ImportError(
"litellm is required for embeddings. Install with: pip install litellm"
) from e
except Exception as e:
raise RuntimeError(f"Failed to generate embeddings: {e}") from e


def get_embeddings_local(
    texts: List[str], model: str = "all-MiniLM-L6-v2"
) -> np.ndarray:
    """
    Generate embeddings locally using sentence-transformers.

    Args:
        texts: List of texts to embed
        model: The sentence-transformer model to use

    Returns:
        NumPy array of embeddings
    """
    try:
        from sentence_transformers import SentenceTransformer

        encoder = SentenceTransformer(model)
        embeddings = encoder.encode(texts, show_progress_bar=False)
        return np.array(embeddings, dtype=np.float32)
    except ImportError as e:
        raise ImportError(
            "sentence-transformers is required for local embeddings. "
            "Install with: pip install sentence-transformers"
        ) from e
    except Exception as e:
        # Mirror get_embeddings: surface model-load/encode failures uniformly.
        raise RuntimeError(f"Failed to generate local embeddings: {e}") from e
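
# Usage sketch (illustrative; the first call downloads the model from the
# Hugging Face hub, so it needs network access or a pre-populated local cache):
#
#     vecs = get_embeddings_local(["hydra brute-force results"])
#     vecs.shape  # (1, 384); all-MiniLM-L6-v2 returns 384-dim vectors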


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """
    Compute cosine similarity between two vectors.

    Args:
        a: First vector
        b: Second vector

    Returns:
        Cosine similarity score
    """
    # The small epsilon guards against division by zero for zero vectors.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)


def batch_cosine_similarity(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
    """
    Compute cosine similarity between a query and multiple embeddings.

    Args:
        query: Query vector
        embeddings: Matrix of embeddings

    Returns:
        Array of similarity scores
    """
    query_norm = np.linalg.norm(query)
    embeddings_norm = np.linalg.norm(embeddings, axis=1)  # per-row norms
    return np.dot(embeddings, query) / (embeddings_norm * query_norm + 1e-10)
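
# Usage sketch for top-k retrieval (pure NumPy; `docs` and `query` are
# hypothetical inputs, not defined in this module):
#
#     doc_vecs = get_embeddings_local(docs)        # shape (n_docs, dim)
#     q_vec = get_embeddings_local([query])[0]     # shape (dim,)
#     scores = batch_cosine_similarity(q_vec, doc_vecs)
#     top_k = np.argsort(scores)[::-1][:5]         # indices of the 5 best matches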


class EmbeddingCache:
    """Cache for embeddings to avoid recomputation."""

    def __init__(self, max_size: int = 1000):
        """
        Initialize the embedding cache.

        Args:
            max_size: Maximum number of embeddings to cache
        """
        self.max_size = max_size
        self._cache: dict[str, np.ndarray] = {}
        self._order: list[str] = []

    def get(self, text: str) -> Optional[np.ndarray]:
        """
        Get a cached embedding.

        Args:
            text: The text to look up

        Returns:
            The cached embedding or None
        """
        return self._cache.get(text)

    def set(self, text: str, embedding: np.ndarray):
        """
        Cache an embedding.

        Args:
            text: The text key
            embedding: The embedding to cache
        """
        if text in self._cache:
            return
        if len(self._cache) >= self.max_size:
            # Evict the oldest entry (FIFO insertion order, not true LRU)
            oldest = self._order.pop(0)
            del self._cache[oldest]
        self._cache[text] = embedding
        self._order.append(text)

    def clear(self):
        """Clear the cache."""
        self._cache.clear()
        self._order.clear()

    def __len__(self) -> int:
        return len(self._cache)
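
# Usage sketch showing the check-then-compute pattern (the get_embeddings
# call assumes litellm credentials as noted above; the text is a placeholder):
#
#     cache = EmbeddingCache(max_size=500)
#     text = "kerberoasting writeup"
#     vec = cache.get(text)
#     if vec is None:
#         vec = get_embeddings([text])[0]
#         cache.set(text, vec)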