# pentestagent/ghostcrew/llm/llm.py
"""LiteLLM wrapper for GhostCrew."""
import asyncio
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncIterator, List, Optional
from ..config.constants import DEFAULT_MODEL
from .config import ModelConfig
from .memory import ConversationMemory
if TYPE_CHECKING:
from ..knowledge import RAGEngine
from ..tools import Tool


@dataclass
class LLMResponse:
    """Response from LLM."""

    content: Optional[str]
    tool_calls: Optional[List[Any]]
    usage: Optional[dict]
    model: str = ""
    finish_reason: str = ""


class LLM:
    """LiteLLM wrapper with tool calling support."""

    def __init__(
        self,
        model: Optional[str] = None,
        config: Optional[ModelConfig] = None,
        rag_engine: Optional["RAGEngine"] = None,
    ):
        """
        Initialize the LLM wrapper.

        Args:
            model: The model to use (supports LiteLLM model names)
            config: Model configuration
            rag_engine: Optional RAG engine for context injection
        """
        self.model = model or DEFAULT_MODEL
        self.config = config or ModelConfig()
        self.rag_engine = rag_engine
        self.memory = ConversationMemory(max_tokens=self.config.max_context_tokens)

        # Ensure litellm is available
        try:
            import litellm

            # Drop unsupported params for models that don't support them
            litellm.drop_params = True
            self._litellm = litellm
        except ImportError as e:
            raise ImportError(
                "litellm is required for LLM functionality. "
                "Install with: pip install litellm"
            ) from e

    def _is_rate_limit_error(self, error: Exception) -> bool:
        """Check if an error is a rate limit error."""
        error_str = str(error).lower()
        error_type = type(error).__name__.lower()
        return (
            ("rate" in error_str and "limit" in error_str)
            or "ratelimit" in error_type
            or "429" in error_str
            or "too many requests" in error_str
        )

    async def _retry_with_backoff(self, coro_factory, max_retries: Optional[int] = None):
        """
        Retry a coroutine with exponential backoff for rate limits.

        Args:
            coro_factory: A callable that returns a new coroutine each call
            max_retries: Max retry attempts (uses config if not specified)
        """
        retries = max_retries if max_retries is not None else self.config.max_retries
        base_delay = self.config.retry_delay
        for attempt in range(retries + 1):
            try:
                return await coro_factory()
            except Exception as e:
                if not self._is_rate_limit_error(e) or attempt >= retries:
                    raise
                # Exponential backoff with jitter: base * 2**attempt + U(0, 1) seconds
                delay = base_delay * (2**attempt) + random.uniform(0, 1)
                await asyncio.sleep(delay)
        # Should not reach here
        raise RuntimeError("Retry logic failed unexpectedly")

    async def generate(
        self,
        system_prompt: str,
        messages: List[dict],
        tools: Optional[List["Tool"]] = None,
        stream: bool = False,
    ) -> LLMResponse:
        """
        Generate a response from the LLM.

        Args:
            system_prompt: The system prompt
            messages: Conversation messages
            tools: Available tools for function calling
            stream: Whether to stream the response

        Returns:
            LLMResponse with the result
        """
        # Build messages list
        llm_messages = [{"role": "system", "content": system_prompt}]

        # Add conversation history with summarization if needed
        history = await self.memory.get_messages_with_summary(
            messages, llm_call=self._summarize_call
        )
        llm_messages.extend(history)

        # Build tools list
        llm_tools = None
        if tools:
            llm_tools = [tool.to_llm_format() for tool in tools if tool.enabled]

        try:
            # Build call kwargs - only pass non-default optional params
            # to avoid conflicts (e.g., Claude doesn't allow temperature + top_p together)
            call_kwargs = {
                "model": self.model,
                "messages": llm_messages,
                "temperature": self.config.temperature,
                "max_tokens": self.config.max_tokens,
            }

            # Only include tools if they exist (Anthropic requires the tools
            # param to be omitted, not None/empty)
            if llm_tools:
                call_kwargs["tools"] = llm_tools

            # Only add optional params if explicitly changed from defaults
            if self.config.top_p != 1.0:
                call_kwargs["top_p"] = self.config.top_p
            if self.config.frequency_penalty != 0.0:
                call_kwargs["frequency_penalty"] = self.config.frequency_penalty
            if self.config.presence_penalty != 0.0:
                call_kwargs["presence_penalty"] = self.config.presence_penalty

            # Call LLM with retry for rate limits
            async def _call():
                return await self._litellm.acompletion(**call_kwargs)

            response = await self._retry_with_backoff(_call)

            # Parse response
            choice = response.choices[0]
            message = choice.message

            # Handle usage - convert to dict safely
            usage_dict = None
            if response.usage:
                try:
                    usage_dict = dict(response.usage)
                except (TypeError, ValueError):
                    usage_dict = {
                        "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                        "completion_tokens": getattr(
                            response.usage, "completion_tokens", 0
                        ),
                        "total_tokens": getattr(response.usage, "total_tokens", 0),
                    }

            return LLMResponse(
                content=message.content,
                tool_calls=message.tool_calls,
                usage=usage_dict,
                model=getattr(response, "model", self.model),
                finish_reason=choice.finish_reason or "",
            )
        except Exception as e:
            # Return error as response (after retries exhausted)
            return LLMResponse(
                content=f"LLM Error: {str(e)}",
                tool_calls=None,
                usage=None,
                model=self.model,
                finish_reason="error",
            )

    async def generate_stream(
        self,
        system_prompt: str,
        messages: List[dict],
        tools: Optional[List["Tool"]] = None,
    ) -> AsyncIterator[str]:
        """
        Stream a response from the LLM.

        Args:
            system_prompt: The system prompt
            messages: Conversation messages
            tools: Available tools for function calling

        Yields:
            Response content chunks
        """
        llm_messages = [{"role": "system", "content": system_prompt}]
        history = await self.memory.get_messages_with_summary(
            messages, llm_call=self._summarize_call
        )
        llm_messages.extend(history)

        llm_tools = None
        if tools:
            llm_tools = [tool.to_llm_format() for tool in tools if tool.enabled]

        call_kwargs = {
            "model": self.model,
            "messages": llm_messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "stream": True,
        }
        # As in generate(), omit the tools param entirely when there are none
        if llm_tools:
            call_kwargs["tools"] = llm_tools

        try:
            response = await self._litellm.acompletion(**call_kwargs)
            async for chunk in response:
                if chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            yield f"\nLLM Error: {str(e)}"

    async def simple_completion(
        self, prompt: str, system: str = "You are a helpful assistant."
    ) -> str:
        """
        Simple completion without tools.

        Args:
            prompt: The user prompt
            system: The system prompt

        Returns:
            The response text
        """
        response = await self.generate(
            system_prompt=system,
            messages=[{"role": "user", "content": prompt}],
            tools=None,
        )
        return response.content or ""

    def set_model(self, model: str):
        """Change the model."""
        self.model = model

    def update_config(self, **kwargs):
        """Update configuration parameters."""
        for key, value in kwargs.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)

    async def _summarize_call(self, prompt: str) -> str:
        """
        Internal LLM call for summarization.

        Args:
            prompt: The summarization prompt

        Returns:
            Summary text
        """
        try:
            response = await self._litellm.acompletion(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a terse summarizer for a pentesting agent.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,  # Lower temperature for consistent summaries
                max_tokens=1000,  # Summaries should be concise
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            return f"[Summarization failed: {e}]"

    def clear_memory(self):
        """Clear the conversation memory's summary cache."""
        self.memory.clear_summary_cache()

    def get_memory_stats(self) -> dict:
        """Get memory usage statistics."""
        return self.memory.get_stats()

    def get_available_models(self) -> List[str]:
        """
        Get list of commonly available models.

        Returns:
            List of model names
        """
        return [
            # OpenAI
            "gpt-5",
            "gpt-4.1",
            "gpt-4.1-mini",
            "gpt-4.1-nano",
            # Anthropic
            "claude-sonnet-4-20250514",
            "claude-opus-4-20250514",
            # Google
            "gemini-2.5-pro",
            "gemini-2.5-flash",
            # Others via LiteLLM
            "ollama/llama3",
            "ollama/mixtral",
            "groq/llama3-70b-8192",
        ]
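

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API).
    # Assumes provider credentials (e.g. OPENAI_API_KEY / ANTHROPIC_API_KEY)
    # are already exported in the environment and that DEFAULT_MODEL points
    # at a model reachable through LiteLLM.
    async def _demo() -> None:
        llm = LLM()  # falls back to DEFAULT_MODEL and ModelConfig defaults
        reply = await llm.simple_completion("Summarize what nmap -sV does.")
        print(reply)
        print(llm.get_memory_stats())

    asyncio.run(_demo())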