Compare commits

...

2 Commits

Author SHA1 Message Date
Alex
be842f89b9 fix: ruff 2025-11-24 10:39:27 +00:00
Alex
3737beb2ba feat: context compression 2025-11-23 18:35:51 +00:00
28 changed files with 5393 additions and 93 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
experiments/
experiments
# C extensions

View File

@@ -34,6 +34,7 @@ class BaseAgent(ABC):
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -64,6 +65,9 @@ class BaseAgent(ABC):
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
self.request_limit = request_limit
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
@log_activity()
def gen(
@@ -276,12 +280,77 @@ class BaseAgent(ABC):
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
"""
Calculate total tokens in current context (messages).
Args:
messages: List of message dicts
Returns:
Total token count
"""
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
"""
Check if we're approaching context limit (80%).
Args:
messages: Current message list
Returns:
True if at or above 80% of context limit
"""
from application.core.model_utils import get_token_limit
from application.core.settings import settings
try:
# Calculate current tokens
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
# Get context limit for model
context_limit = get_token_limit(self.model_id)
# Calculate threshold (80%)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
# Check if we've reached the limit
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _build_messages(
self,
system_prompt: str,
query: str,
) -> List[Dict]:
"""Build messages using pre-rendered system prompt"""
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{self.compressed_summary}"
)
system_prompt = system_prompt + compression_context
messages = [{"role": "system", "content": system_prompt}]
for i in self.chat_history:

View File

@@ -266,6 +266,26 @@ class BaseAnswerResource:
shared_token=shared_token,
attachment_ids=attachment_ids,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata: {str(e)}",
exc_info=True,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
@@ -328,6 +348,25 @@ class BaseAnswerResource:
shared_token=shared_token,
attachment_ids=attachment_ids,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata (partial stream): {str(e)}",
exc_info=True,
)
except Exception as e:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True

View File

@@ -0,0 +1,20 @@
"""
Compression module for managing conversation context compression.
"""
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
CompressionResult,
CompressionMetadata,
)
__all__ = [
"CompressionOrchestrator",
"CompressionService",
"CompressionResult",
"CompressionMetadata",
]

View File

@@ -0,0 +1,234 @@
"""Message reconstruction utilities for compression."""
import logging
import uuid
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class MessageBuilder:
"""Builds message arrays from compressed context."""
@staticmethod
def build_from_compressed_context(
system_prompt: str,
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_tool_calls: bool = False,
context_type: str = "pre_request",
) -> List[Dict]:
"""
Build messages from compressed context.
Args:
system_prompt: Original system prompt
compressed_summary: Compressed summary (if any)
recent_queries: Recent uncompressed queries
include_tool_calls: Whether to include tool calls from history
context_type: Type of context ('pre_request' or 'mid_execution')
Returns:
List of message dicts ready for LLM
"""
# Append compression summary to system prompt if present
if compressed_summary:
system_prompt = MessageBuilder._append_compression_context(
system_prompt, compressed_summary, context_type
)
messages = [{"role": "system", "content": system_prompt}]
# Add recent history
for query in recent_queries:
if "prompt" in query and "response" in query:
messages.append({"role": "user", "content": query["prompt"]})
messages.append({"role": "assistant", "content": query["response"]})
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
return messages
@staticmethod
def _append_compression_context(
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
) -> str:
"""
Append compression context to system prompt.
Args:
system_prompt: Original system prompt
compressed_summary: Summary to append
context_type: Type of compression context
Returns:
Updated system prompt
"""
# Remove existing compression context if present
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
parts = system_prompt.split("\n\n---\n\n")
system_prompt = parts[0]
# Build appropriate context message based on type
if context_type == "mid_execution":
context_message = (
"\n\n---\n\n"
"Context window limit reached during execution. "
"Previous conversation has been compressed to fit within limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
else: # pre_request
context_message = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
return system_prompt + context_message
@staticmethod
def rebuild_messages_after_compression(
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Args:
messages: Original message list
compressed_summary: Compressed summary
recent_queries: Recent uncompressed queries
include_current_execution: Whether to preserve current execution messages
include_tool_calls: Whether to include tool calls from history
Returns:
Rebuilt message list or None if failed
"""
# Find the system message
system_message = next(
(msg for msg in messages if msg.get("role") == "system"), None
)
if not system_message:
logger.warning("No system message found in messages list")
return None
# Update system message with compressed summary
if compressed_summary:
content = system_message.get("content", "")
system_message["content"] = MessageBuilder._append_compression_context(
content, compressed_summary, "mid_execution"
)
logger.info(
"Appended compression summary to system prompt (truncated): %s",
(
compressed_summary[:500] + "..."
if len(compressed_summary) > 500
else compressed_summary
),
)
rebuilt_messages = [system_message]
# Add recent history from compressed context
for query in recent_queries:
if "prompt" in query and "response" in query:
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
rebuilt_messages.append(
{"role": "assistant", "content": query["response"]}
)
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
rebuilt_messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
if include_current_execution:
# Preserve any messages that were added during the current execution cycle
recent_msg_count = 1 # system message
for query in recent_queries:
if "prompt" in query and "response" in query:
recent_msg_count += 2
if "tool_calls" in query:
recent_msg_count += len(query["tool_calls"]) * 2
if len(messages) > recent_msg_count:
current_execution_messages = messages[recent_msg_count:]
rebuilt_messages.extend(current_execution_messages)
logger.info(
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
)
logger.info(
f"Messages rebuilt: {len(messages)}{len(rebuilt_messages)} messages. "
f"Ready to continue tool execution."
)
return rebuilt_messages

View File

@@ -0,0 +1,232 @@
"""High-level compression orchestration."""
import logging
from typing import Any, Dict, Optional
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
logger = logging.getLogger(__name__)
class CompressionOrchestrator:
"""
Facade for compression operations.
Coordinates between all compression components and provides
a simple interface for callers.
"""
def __init__(
self,
conversation_service: ConversationService,
threshold_checker: Optional[CompressionThresholdChecker] = None,
):
"""
Initialize orchestrator.
Args:
conversation_service: Service for DB operations
threshold_checker: Custom threshold checker (optional)
"""
self.conversation_service = conversation_service
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
def compress_if_needed(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
This is the main entry point for compression operations.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {user_id}"
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
):
# No compression needed, return full history
queries = conversation.get("queries", [])
return CompressionResult.success_no_compression(queries)
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in compress_if_needed: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
def _perform_compression(
self,
conversation_id: str,
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
) -> CompressionResult:
"""
Perform the actual compression operation.
Args:
conversation_id: Conversation ID
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
Returns:
CompressionResult
"""
try:
# Determine which model to use for compression
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
)
# Create compression service with DB update capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=self.conversation_service,
)
# Compress all queries up to the latest
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0:
logger.warning("No queries to compress")
return CompressionResult.success_no_compression([])
logger.info(
f"Initiating compression for conversation {conversation_id}: "
f"compressing all {queries_count} queries (0-{compress_up_to})"
)
# Perform compression and save to DB
metadata = compression_service.compress_and_save(
conversation_id, conversation, compress_up_to
)
logger.info(
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
)
# Get compressed context
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
return CompressionResult.success_with_compression(
compressed_summary, recent_queries, metadata
)
except Exception as e:
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
return CompressionResult.failure(str(e))
def compress_mid_execution(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
Returns:
CompressionResult
"""
try:
# Load conversation if not provided
if current_conversation:
conversation = current_conversation
else:
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Could not load conversation {conversation_id} for mid-execution compression"
)
return CompressionResult.failure("Conversation not found")
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in mid-execution compression: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))

View File

@@ -0,0 +1,149 @@
"""Compression prompt building logic."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CompressionPromptBuilder:
"""Builds prompts for LLM compression calls."""
def __init__(self, version: str = "v1.0"):
"""
Initialize prompt builder.
Args:
version: Prompt template version to use
"""
self.version = version
self.system_prompt = self._load_prompt(version)
def _load_prompt(self, version: str) -> str:
"""
Load prompt template from file.
Args:
version: Version string (e.g., 'v1.0')
Returns:
Prompt template content
Raises:
FileNotFoundError: If prompt template file doesn't exist
"""
current_dir = Path(__file__).resolve().parents[4]
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
try:
with open(prompt_path, "r") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Compression prompt template not found: {prompt_path}")
raise FileNotFoundError(
f"Compression prompt template '{version}' not found at {prompt_path}. "
f"Please ensure the template file exists."
)
def build_prompt(
self,
queries: List[Dict[str, Any]],
existing_compressions: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, str]]:
"""
Build messages for compression LLM call.
Args:
queries: List of query objects to compress
existing_compressions: List of previous compression points
Returns:
List of message dicts for LLM
"""
# Build conversation text
conversation_text = self._format_conversation(queries)
# Add existing compression context if present
existing_compression_context = ""
if existing_compressions and len(existing_compressions) > 0:
existing_compression_context = (
"\n\nIMPORTANT: This conversation has been compressed before. "
"Previous compression summaries:\n\n"
)
for i, comp in enumerate(existing_compressions):
existing_compression_context += (
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
f"{comp.get('compressed_summary', '')}\n\n"
)
existing_compression_context += (
"Your task is to create a NEW summary that incorporates the context from "
"previous compressions AND the new messages below. The final summary should "
"be comprehensive and include all important information from both previous "
"compressions and new messages.\n\n"
)
user_prompt = (
f"{existing_compression_context}"
f"Here is the conversation to summarize:\n\n"
f"{conversation_text}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
return messages
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
"""
Format conversation queries into readable text for compression.
Args:
queries: List of query objects
Returns:
Formatted conversation text
"""
conversation_lines = []
for i, query in enumerate(queries):
conversation_lines.append(f"--- Message {i + 1} ---")
conversation_lines.append(f"User: {query.get('prompt', '')}")
# Add tool calls if present
tool_calls = query.get("tool_calls", [])
if tool_calls:
conversation_lines.append("\nTool Calls:")
for tc in tool_calls:
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
arguments = tc.get("arguments", {})
result = tc.get("result", "")
if result is None:
result = ""
status = tc.get("status", "unknown")
# Include full tool result for complete compression context
conversation_lines.append(
f" - {tool_name}.{action_name}({arguments}) "
f"[{status}] → {result}"
)
# Add agent thought if present
thought = query.get("thought", "")
if thought:
conversation_lines.append(f"\nAgent Thought: {thought}")
# Add assistant response
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
# Add sources if present
sources = query.get("sources", [])
if sources:
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
conversation_lines.append("") # Empty line between messages
return "\n".join(conversation_lines)

View File

@@ -0,0 +1,306 @@
"""Core compression service with simplified responsibilities."""
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.api.answer.services.compression.prompt_builder import (
CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
CompressionMetadata,
)
from application.core.settings import settings
logger = logging.getLogger(__name__)
class CompressionService:
"""
Service for compressing conversation history.
Handles DB updates.
"""
def __init__(
self,
llm,
model_id: str,
conversation_service=None,
prompt_builder: Optional[CompressionPromptBuilder] = None,
):
"""
Initialize compression service.
Args:
llm: LLM instance to use for compression
model_id: Model ID for compression
conversation_service: Service for DB operations (optional, for DB updates)
prompt_builder: Custom prompt builder (optional)
"""
self.llm = llm
self.model_id = model_id
self.conversation_service = conversation_service
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
version=settings.COMPRESSION_PROMPT_VERSION
)
def compress_conversation(
self,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation history up to specified index.
Args:
conversation: Full conversation document
compress_up_to_index: Last query index to include in compression
Returns:
CompressionMetadata with compression details
Raises:
ValueError: If compress_up_to_index is invalid
"""
try:
queries = conversation.get("queries", [])
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
raise ValueError(
f"Invalid compress_up_to_index: {compress_up_to_index} "
f"(conversation has {len(queries)} queries)"
)
# Get queries to compress
queries_to_compress = queries[: compress_up_to_index + 1]
# Check if there are existing compressions
existing_compressions = conversation.get("compression_metadata", {}).get(
"compression_points", []
)
if existing_compressions:
logger.info(
f"Found {len(existing_compressions)} previous compression(s) - "
f"will incorporate into new summary"
)
# Calculate original token count
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
# Log tool call stats
self._log_tool_call_stats(queries_to_compress)
# Build compression prompt
messages = self.prompt_builder.build_prompt(
queries_to_compress, existing_compressions
)
# Call LLM to generate compression
logger.info(
f"Starting compression: {len(queries_to_compress)} queries "
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
f"using model {self.model_id}"
)
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
)
# Extract summary from response
compressed_summary = self._extract_summary(response)
# Calculate compressed token count
compressed_tokens = TokenCounter.count_message_tokens(
[{"content": compressed_summary}]
)
# Calculate compression ratio
compression_ratio = (
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
)
logger.info(
f"Compression complete: {original_tokens}{compressed_tokens} tokens "
f"({compression_ratio:.1f}x compression)"
)
# Build compression metadata
compression_metadata = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=compress_up_to_index,
compressed_summary=compressed_summary,
original_token_count=original_tokens,
compressed_token_count=compressed_tokens,
compression_ratio=compression_ratio,
model_used=self.model_id,
compression_prompt_version=self.prompt_builder.version,
)
return compression_metadata
except Exception as e:
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
raise
def compress_and_save(
self,
conversation_id: str,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation and save to database.
Args:
conversation_id: Conversation ID
conversation: Full conversation document
compress_up_to_index: Last query index to include
Returns:
CompressionMetadata
Raises:
ValueError: If conversation_service not provided or invalid index
"""
if not self.conversation_service:
raise ValueError(
"conversation_service required for compress_and_save operation"
)
# Perform compression
metadata = self.compress_conversation(conversation, compress_up_to_index)
# Save to database
self.conversation_service.update_compression_metadata(
conversation_id, metadata.to_dict()
)
logger.info(f"Compression metadata saved to database for {conversation_id}")
return metadata
def get_compressed_context(
self, conversation: Dict[str, Any]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Get compressed summary + recent uncompressed messages.
Args:
conversation: Full conversation document
Returns:
(compressed_summary, recent_messages)
"""
try:
compression_metadata = conversation.get("compression_metadata", {})
if not compression_metadata.get("is_compressed"):
logger.debug("No compression metadata found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
compression_points = compression_metadata.get("compression_points", [])
if not compression_points:
logger.debug("No compression points found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
# Get the most recent compression point
latest_compression = compression_points[-1]
compressed_summary = latest_compression.get("compressed_summary")
last_compressed_index = latest_compression.get("query_index")
compressed_tokens = latest_compression.get("compressed_token_count", 0)
original_tokens = latest_compression.get("original_token_count", 0)
# Get only messages after compression point
queries = conversation.get("queries", [])
total_queries = len(queries)
recent_queries = queries[last_compressed_index + 1 :]
logger.info(
f"Using compressed context: summary ({compressed_tokens} tokens, "
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
)
return compressed_summary, recent_queries
except Exception as e:
logger.error(
f"Error getting compressed context: {str(e)}", exc_info=True
)
queries = conversation.get("queries", [])
if queries is None:
return None, []
return None, queries
def _extract_summary(self, llm_response: str) -> str:
"""
Extract clean summary from LLM response.
Args:
llm_response: Raw LLM response
Returns:
Cleaned summary text
"""
try:
# Try to extract content within <summary> tags
summary_match = re.search(
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
)
if summary_match:
summary = summary_match.group(1).strip()
else:
# If no summary tags, remove analysis tags and use the rest
summary = re.sub(
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
).strip()
return summary
except Exception as e:
logger.warning(f"Error extracting summary: {str(e)}, using full response")
return llm_response
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
"""Log statistics about tool calls in queries."""
total_tool_calls = 0
total_tool_result_chars = 0
tool_call_breakdown = {}
for q in queries:
for tc in q.get("tool_calls", []):
total_tool_calls += 1
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
key = f"{tool_name}.{action_name}"
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
# Track total tool result size
result = tc.get("result", "")
if result:
total_tool_result_chars += len(str(result))
if total_tool_calls > 0:
tool_breakdown_str = ", ".join(
f"{tool}({count})"
for tool, count in sorted(tool_call_breakdown.items())
)
tool_result_kb = total_tool_result_chars / 1024
logger.info(
f"Tool call breakdown: {tool_breakdown_str} "
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
)

View File

@@ -0,0 +1,103 @@
"""Compression threshold checking logic."""
import logging
from typing import Any, Dict
from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter
logger = logging.getLogger(__name__)
class CompressionThresholdChecker:
"""Determines if compression is needed based on token thresholds."""
def __init__(self, threshold_percentage: float = None):
"""
Initialize threshold checker.
Args:
threshold_percentage: Percentage of context to use as threshold
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
"""
self.threshold_percentage = (
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
)
def should_compress(
self,
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
) -> bool:
"""
Determine if compression is needed.
Args:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
Returns:
True if tokens >= threshold% of context window
"""
try:
# Calculate total tokens in conversation
total_tokens = TokenCounter.count_conversation_tokens(conversation)
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
compression_needed = total_tokens >= threshold
percentage_used = (total_tokens / context_limit) * 100
if compression_needed:
logger.warning(
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
)
else:
logger.info(
f"Compression check: {total_tokens}/{context_limit} tokens "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
)
return compression_needed
except Exception as e:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
logger.warning(
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
return False

View File

@@ -0,0 +1,103 @@
"""Token counting utilities for compression."""
import logging
from typing import Any, Dict, List
from application.utils import num_tokens_from_string
from application.core.settings import settings
logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
Calculate total tokens in a list of messages.
Args:
messages: List of message dicts with 'content' field
Returns:
Total token count
"""
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
) -> int:
"""
Count tokens across multiple query objects.
Args:
queries: List of query objects from conversation
include_tool_calls: Whether to count tool call tokens
Returns:
Total token count
"""
total_tokens = 0
for query in queries:
# Count prompt and response tokens
if "prompt" in query:
total_tokens += num_tokens_from_string(query["prompt"])
if "response" in query:
total_tokens += num_tokens_from_string(query["response"])
if "thought" in query:
total_tokens += num_tokens_from_string(query.get("thought", ""))
# Count tool call tokens
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
tool_call_string = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
total_tokens += num_tokens_from_string(tool_call_string)
return total_tokens
@staticmethod
def count_conversation_tokens(
conversation: Dict[str, Any], include_system_prompt: bool = False
) -> int:
"""
Calculate total tokens in a conversation.
Args:
conversation: Conversation document
include_system_prompt: Whether to include system prompt in count
Returns:
Total token count
"""
try:
queries = conversation.get("queries", [])
total_tokens = TokenCounter.count_query_tokens(queries)
# Add system prompt tokens if requested
if include_system_prompt:
# Rough estimate for system prompt
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
return total_tokens
except Exception as e:
logger.error(f"Error calculating conversation tokens: {str(e)}")
return 0

View File

@@ -0,0 +1,83 @@
"""Type definitions for compression module."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class CompressionMetadata:
"""Metadata about a compression operation."""
timestamp: datetime
query_index: int
compressed_summary: str
original_token_count: int
compressed_token_count: int
compression_ratio: float
model_used: str
compression_prompt_version: str
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for DB storage."""
return {
"timestamp": self.timestamp,
"query_index": self.query_index,
"compressed_summary": self.compressed_summary,
"original_token_count": self.original_token_count,
"compressed_token_count": self.compressed_token_count,
"compression_ratio": self.compression_ratio,
"model_used": self.model_used,
"compression_prompt_version": self.compression_prompt_version,
}
@dataclass
class CompressionResult:
"""Result of a compression operation."""
success: bool
compressed_summary: Optional[str] = None
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
metadata: Optional[CompressionMetadata] = None
error: Optional[str] = None
compression_performed: bool = False
@classmethod
def success_with_compression(
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
) -> "CompressionResult":
"""Create a successful result with compression."""
return cls(
success=True,
compressed_summary=summary,
recent_queries=queries,
metadata=metadata,
compression_performed=True,
)
@classmethod
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
"""Create a successful result without compression needed."""
return cls(
success=True,
recent_queries=queries,
compression_performed=False,
)
@classmethod
def failure(cls, error: str) -> "CompressionResult":
"""Create a failure result."""
return cls(success=False, error=error, compression_performed=False)
def as_history(self) -> List[Dict[str, str]]:
"""
Convert recent queries to history format.
Returns:
List of prompt/response dicts
"""
return [
{"prompt": q["prompt"], "response": q["response"]}
for q in self.recent_queries
]

View File

@@ -180,3 +180,103 @@ class ConversationService:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Update conversation with compression metadata.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
"""
try:
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
)
raise
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
)
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
)
return None

View File

@@ -10,6 +10,8 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
@@ -90,9 +92,14 @@ class StreamProcessor:
self.shared_token = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
)
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
def initialize(self):
"""Initialize all required components for processing"""
@@ -112,15 +119,72 @@ class StreamProcessor:
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
# Check if compression is enabled and needed
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
if not result.success:
logger.error(
f"Compression failed: {result.error}, using full history"
)
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
[{"content": result.compressed_summary}]
)
logger.info(
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
@@ -658,7 +722,7 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
return AgentCreator.create_agent(
agent = AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=provider or settings.LLM_PROVIDER,
@@ -671,4 +735,10 @@ class StreamProcessor:
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
compressed_summary=self.compressed_summary,
)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
return agent

View File

@@ -29,63 +29,29 @@ GOOGLE_ATTACHMENTS = [
OPENAI_MODELS = [
AvailableModel(
id="gpt-4o",
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Omni",
description="Latest and most capable model",
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
context_window=400000,
),
),
AvailableModel(
id="gpt-4o-mini",
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Omni Mini",
description="Fast and efficient",
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
context_window=400000,
),
),
AvailableModel(
id="gpt-4-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Turbo",
description="Fast GPT-4 with 128k context",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
),
),
AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="Most capable model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
AvailableModel(
id="gpt-3.5-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-3.5 Turbo",
description="Fast and cost-effective",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=4096,
),
),
)
]
@@ -159,15 +125,15 @@ GOOGLE_MODELS = [
),
),
AvailableModel(
id="gemini-2.5-pro",
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 2.5 Pro",
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
context_window=20000, # Set low for testing compression
),
),
]

View File

@@ -144,6 +144,13 @@ class Settings(BaseSettings):
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
# Conversation Compression Settings
ENABLE_CONVERSATION_COMPRESSION: bool = True
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -1,4 +1,3 @@
import json
import logging
from google import genai
@@ -11,11 +10,13 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -33,6 +34,12 @@ class GoogleLLM(BaseLLM):
"image/jpg",
"image/webp",
"image/gif",
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -135,12 +142,38 @@ class GoogleLLM(BaseLLM):
raise
def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
Returns:
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
combined system instruction.
"""
cleaned_messages = []
system_instructions = []
def _extract_system_text(content):
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict) and "text" in item and item["text"] is not None:
parts.append(item["text"])
return "\n".join(parts)
return ""
for message in messages:
role = message.get("role")
content = message.get("content")
# Gemini only accepts user/model in the contents list.
if role == "system":
sys_text = _extract_system_text(content)
if sys_text:
system_instructions.append(sys_text)
continue
if role == "assistant":
role = "model"
elif role == "tool":
@@ -159,12 +192,27 @@ class GoogleLLM(BaseLLM):
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=item["function_call"]["name"],
args=cleaned_args,
),
thoughtSignature=item["thought_signature"],
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
)
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
@@ -188,7 +236,8 @@ class GoogleLLM(BaseLLM):
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
return cleaned_messages, system_instruction
def _clean_schema(self, schema_obj):
"""
@@ -274,6 +323,61 @@ class GoogleLLM(BaseLLM):
genai_tools.append(genai_tool)
return genai_tools
def _extract_preview_from_message(self, message):
"""Get a short, human-readable preview from the last message."""
try:
if hasattr(message, "parts"):
for part in reversed(message.parts):
if getattr(part, "text", None):
return part.text
function_call = getattr(part, "function_call", None)
if function_call:
name = getattr(function_call, "name", "") or "function_call"
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = getattr(function_response, "name", "") or "function_response"
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
for item in reversed(content):
if isinstance(item, str):
return item
if isinstance(item, dict):
if item.get("text"):
return item["text"]
if item.get("function_call"):
fn = item["function_call"]
if isinstance(fn, dict):
name = fn.get("name") or "function_call"
return f"function_call:{name}"
return "function_call"
if item.get("function_response"):
resp = item["function_response"]
if isinstance(resp, dict):
name = resp.get("name") or "function_response"
return f"function_response:{name}"
return "function_response"
if "text" in message and isinstance(message["text"], str):
return message["text"]
except Exception:
pass
return str(message)
def _summarize_messages_for_log(self, messages, preview_chars=20):
"""Return a compact summary for logging to avoid huge payloads."""
message_count = len(messages) if messages else 0
last_preview = ""
if messages:
last_preview = self._extract_preview_from_message(messages[-1]) or ""
last_preview = str(last_preview).replace("\n", " ")
if len(last_preview) > preview_chars:
last_preview = f"{last_preview[:preview_chars]}..."
return f"count={message_count}, last='{last_preview}'"
def _raw_gen(
self,
baseself,
@@ -287,12 +391,12 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
@@ -325,12 +429,12 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
@@ -349,8 +453,12 @@ class GoogleLLM(BaseLLM):
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(

View File

@@ -1,4 +1,5 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -16,6 +17,7 @@ class ToolCall:
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
thought_signature: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
@@ -178,6 +180,406 @@ class LLMHandler(ABC):
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
"""
Build a minimal context: system prompt + latest user message only.
Drops all tool/function messages to shrink context aggressively.
"""
system_message = next((m for m in messages if m.get("role") == "system"), None)
if not system_message:
logger.warning("Cannot prune messages minimally: missing system message.")
return None
last_non_system = None
for m in reversed(messages):
if m.get("role") == "user":
last_non_system = m
break
if not last_non_system and m.get("role") not in ("system", None):
last_non_system = m
if not last_non_system:
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
return None
logger.info("Pruning context to system + latest user/assistant message to proceed.")
return [system_message, last_non_system]
def _extract_text_from_content(self, content: Any) -> str:
"""
Convert message content (str or list of parts) to plain text for compression.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts_text = []
for item in content:
if isinstance(item, dict):
if "text" in item and item["text"] is not None:
parts_text.append(str(item["text"]))
elif "function_call" in item or "function_response" in item:
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
"""
Build a conversation-like dict from current messages so we can compress
even when the conversation isn't persisted yet. Includes tool calls/results.
"""
queries = []
current_prompt = None
current_tool_calls = {}
def _commit_query(response_text: str):
nonlocal current_prompt, current_tool_calls
if current_prompt is None and not response_text:
return
tool_calls_list = list(current_tool_calls.values())
queries.append(
{
"prompt": current_prompt or "",
"response": response_text,
"tool_calls": tool_calls_list,
}
)
current_prompt = None
current_tool_calls = {}
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "user":
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
for item in content:
if "function_call" in item:
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fc.get("name"),
"arguments": fc.get("args"),
"result": None,
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
"action_name": "unknown_action",
"arguments": {},
"result": tool_text,
"status": "completed",
}
)
# If there's an unfinished prompt with tool_calls but no response yet, commit it
if current_prompt is not None or current_tool_calls:
_commit_query(response_text="")
if not queries:
return None
return {
"queries": queries,
"compression_metadata": {
"is_compressed": False,
"compression_points": [],
},
}
def _rebuild_messages_after_compression(
self,
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Delegates to MessageBuilder for the actual reconstruction.
"""
from application.api.answer.services.compression.message_builder import (
MessageBuilder,
)
return MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=compressed_summary,
recent_queries=recent_queries,
include_current_execution=include_current_execution,
include_tool_calls=include_tool_calls,
)
def _perform_mid_execution_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Perform compression during tool execution and rebuild messages.
Uses the new orchestrator for simplified compression.
Args:
agent: The agent instance
messages: Current conversation messages
Returns:
(success: bool, rebuilt_messages: Optional[List[Dict]])
"""
try:
from application.api.answer.services.compression import (
CompressionOrchestrator,
)
from application.api.answer.services.conversation_service import (
ConversationService,
)
conversation_service = ConversationService()
orchestrator = CompressionOrchestrator(conversation_service)
# Get conversation from database (may be None for new sessions)
conversation = conversation_service.get_conversation(
agent.conversation_id, agent.initial_user_id
)
if conversation:
# Merge current in-flight messages (including tool calls)
conversation_from_msgs = self._build_conversation_from_messages(messages)
if conversation_from_msgs:
conversation = conversation_from_msgs
else:
logger.warning(
"Could not load conversation for compression; attempting in-memory compression"
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
)
if not result.success:
logger.warning(f"Mid-execution compression failed: {result.error}")
# Try minimal pruning as fallback
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
if not result.compression_performed:
logger.warning("Compression not performed")
return False, None
# Check if compression actually reduced tokens
if result.metadata:
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
logger.warning(
"Compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
logger.info(
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
)
# Also store the compression summary as a visible message
if result.metadata:
conversation_service.append_compression_message(
agent.conversation_id, result.metadata.to_dict()
)
# Update agent's compressed summary for downstream persistence
agent.compressed_summary = result.compressed_summary
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
agent.compression_saved = False
# Reset the context limit flag so tools can continue
agent.context_limit_reached = False
agent.current_token_count = 0
# Rebuild messages
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
result.compressed_summary,
result.recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing mid-execution compression: {str(e)}", exc_info=True
)
return False, None
def _perform_in_memory_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Fallback compression path when the conversation is not yet persisted.
Uses CompressionService directly without DB persistence.
"""
try:
from application.api.answer.services.compression.service import (
CompressionService,
)
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
conversation = self._build_conversation_from_messages(messages)
if not conversation:
logger.warning(
"Cannot perform in-memory compression: no user/assistant turns found"
)
return False, None
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key,
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
)
# Create service without DB persistence capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=None, # No DB updates for in-memory
)
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0 or queries_count == 0:
logger.warning("Not enough queries to compress in-memory context")
return False, None
metadata = compression_service.compress_conversation(
conversation,
compress_up_to_index=compress_up_to,
)
# If compression doesn't reduce tokens, fall back to minimal pruning
if (
metadata.compressed_token_count
>= metadata.original_token_count
):
logger.warning(
"In-memory compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
# Attach metadata to synthetic conversation
conversation["compression_metadata"] = {
"is_compressed": True,
"compression_points": [metadata.to_dict()],
}
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
agent.compressed_summary = compressed_summary
agent.compression_metadata = metadata.to_dict()
agent.compression_saved = False
agent.context_limit_reached = False
agent.current_token_count = 0
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
compressed_summary,
recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
logger.info(
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing in-memory compression: {str(e)}", exc_info=True
)
return False, None
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
@@ -195,7 +597,110 @@ class LLMHandler(ABC):
"""
updated_messages = messages.copy()
for call in tool_calls:
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
# Context limit reached - attempt mid-execution compression
compression_attempted = False
compression_successful = False
try:
from application.core.settings import settings
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
except Exception:
compression_enabled = False
if compression_enabled:
compression_attempted = True
try:
logger.info(
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
f"Attempting mid-execution compression..."
)
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
agent, updated_messages
)
if compression_successful and rebuilt_messages is not None:
# Update the messages list with rebuilt compressed version
updated_messages = rebuilt_messages
# Yield compression success message
yield {
"type": "info",
"data": {
"message": "Context window limit reached. Compressed conversation history to continue processing."
}
}
logger.info(
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
)
# Proceed to execute the current tool call with the reduced context
else:
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
except Exception as e:
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
compression_attempted = True
compression_successful = False
# If compression wasn't attempted or failed, skip remaining tools
if not compression_successful:
if i == 0:
# Special case: limit reached before executing any tools
# This can happen when previous tool responses pushed context over limit
if compression_attempted:
logger.warning(
f"Context limit reached before executing any tools. "
f"Compression attempted but failed. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data."
)
else:
logger.warning(
f"Context limit reached before executing any tools. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data. "
f"Consider enabling compression or using a model with larger context window."
)
else:
# Normal case: executed some tools, now stopping
tool_word = "tool call" if i == 1 else "tool calls"
remaining = len(tool_calls) - i
remaining_word = "tool call" if remaining == 1 else "tool calls"
if compression_attempted:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Compression attempted but failed. "
f"Skipping remaining {remaining} {remaining_word}."
)
else:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Skipping remaining {remaining} {remaining_word}. "
f"Consider enabling compression or using a model with larger context window."
)
# Mark remaining tools as skipped
for remaining_call in tool_calls[i:]:
skip_message = {
"type": "tool_call",
"data": {
"tool_name": "system",
"call_id": remaining_call.id,
"action_name": remaining_call.name,
"arguments": {},
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
"status": "skipped"
}
}
yield skip_message
# Set flag on agent
agent.context_limit_reached = True
break
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -205,21 +710,26 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [
{
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
],
"content": [function_call_content],
}
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
@@ -324,6 +834,9 @@ class LLMHandler(ABC):
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
@@ -336,8 +849,21 @@ class LLMHandler(ABC):
break
tool_calls = {}
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
messages.append({
"role": "system",
"content": (
"WARNING: Context window limit has been reached. "
"Please provide a final response to the user without making additional tool calls. "
"Summarize the work completed so far."
)
})
logger.info("Context limit reached - instructing agent to wrap up")
response = agent.llm.gen_stream(
model=agent.model_id, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -19,15 +19,20 @@ class GoogleLLMHandler(LLMHandler):
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
)
for part in parts
if hasattr(part, "function_call") and part.function_call is not None
]
tool_calls = []
for idx, part in enumerate(parts):
if hasattr(part, "function_call") and part.function_call is not None:
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
thought_sig = part.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
index=idx,
thought_signature=thought_sig,
)
)
content = " ".join(
part.text
@@ -41,13 +46,17 @@ class GoogleLLMHandler(LLMHandler):
raw_response=response,
)
else:
# This branch handles individual Part objects from streaming responses
tool_calls = []
if hasattr(response, "function_call"):
if hasattr(response, "function_call") and response.function_call is not None:
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
thought_sig = response.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=response.function_call.name,
arguments=response.function_call.args,
thought_signature=thought_sig,
)
)
return LLMResponse(

View File

@@ -128,6 +128,10 @@ class OpenAILLM(BaseLLM):
):
messages = self._clean_messages_openai(messages)
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
"messages": messages,
@@ -159,6 +163,10 @@ class OpenAILLM(BaseLLM):
):
messages = self._clean_messages_openai(messages)
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
"messages": messages,

View File

@@ -0,0 +1,35 @@
Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions.
This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing work without losing context.
Before providing your final summary, wrap your analysis in <analysis> tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process:
1. Chronologically analyze each message, tool call and section of the conversation. For each section thoroughly identify:
- The user's explicit requests and intents
- Your approach to addressing the user's requests
- Key decisions, concepts and patterns
- Specific details like if applicable:
- file names
- full code snippets
- function signatures
- file edits
- Errors that you ran into and how you fixed them
- Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
2. Double-check for accuracy and completeness, addressing each required element thoroughly.
Your summary should include the following sections:
1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail
2. Key Concepts: List all important concepts discussed.
3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important.
4. Errors and fixes: List all errors that you ran into, and how you fixed them. Pay special attention to specific user feedback that you received, especially if the user told you to do something differently.
5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts.
6. All user messages: List ALL user messages that are not tool results. These are critical for understanding the users' feedback and changing intent.
7. Tool Calls: List ALL tool calls made, including their inputs relevant parts of the outputs.
8. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on.
9. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable.
10. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's most recent explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests or really old requests that were already completed without confirming with the user first.
If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation.
Please provide your summary based on the conversation and tools used so far, following this structure and ensuring precision and thoroughness in your response.

View File

@@ -15,7 +15,7 @@ Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0
google-genai==1.3.0
google-genai==1.49.0
google-api-python-client==2.179.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.2

View File

@@ -197,6 +197,24 @@ def generate_image_url(image_path):
return f"{base_url}/api/images/{image_path}"
def calculate_compression_threshold(
model_id: str, threshold_percentage: float = 0.8
) -> int:
"""
Calculate token threshold for triggering compression.
Args:
model_id: Model identifier
threshold_percentage: Percentage of context window (default 80%)
Returns:
Token count threshold
"""
total_context = get_token_limit(model_id)
threshold = int(total_context * threshold_percentage)
return threshold
def clean_text_for_tts(text: str) -> str:
"""
clean text for Text-to-Speech processing.

View File

@@ -91,7 +91,7 @@ def test_clean_messages_google_basic():
{"function_call": {"name": "fn", "args": {"a": 1}}},
]},
]
cleaned = llm._clean_messages_google(msgs)
cleaned, system_instruction = llm._clean_messages_google(msgs)
assert all(hasattr(c, "role") and hasattr(c, "parts") for c in cleaned)
assert any(c.role == "model" for c in cleaned)

View File

@@ -0,0 +1,325 @@
import pytest
from unittest.mock import Mock, patch
from application.agents.base import BaseAgent
from application.llm.handlers.base import LLMHandler, ToolCall
class MockAgent(BaseAgent):
"""Mock agent for testing"""
def _gen_inner(self, query, log_context=None):
yield {"answer": "test"}
@pytest.fixture
def mock_agent():
"""Create a mock agent for testing"""
agent = MockAgent(
endpoint="test",
llm_name="openai",
model_id="gpt-4o",
api_key="test-key",
)
agent.llm = Mock()
return agent
@pytest.fixture
def mock_llm_handler():
"""Create a mock LLM handler"""
handler = Mock(spec=LLMHandler)
handler.tool_calls = []
return handler
class TestAgentTokenTracking:
"""Test suite for agent token tracking during execution"""
def test_calculate_current_context_tokens(self, mock_agent):
"""Test token calculation for current context"""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
]
tokens = mock_agent._calculate_current_context_tokens(messages)
# Should count tokens from all messages
assert tokens > 0
# Rough estimate: ~20-40 tokens for this conversation
assert 15 < tokens < 60
def test_calculate_tokens_with_tool_calls(self, mock_agent):
"""Test token calculation includes tool call content"""
messages = [
{"role": "system", "content": "Test"},
{
"role": "assistant",
"content": [
{
"function_call": {
"name": "search_tool",
"args": {"query": "test"},
"call_id": "123",
}
}
],
},
{
"role": "tool",
"content": [
{
"function_response": {
"name": "search_tool",
"response": {"result": "Found 10 results"},
"call_id": "123",
}
}
],
},
]
tokens = mock_agent._calculate_current_context_tokens(messages)
# Should include tool call tokens
assert tokens > 0
@patch("application.core.model_utils.get_token_limit")
@patch("application.core.settings.settings")
def test_check_context_limit_below_threshold(
self, mock_settings, mock_get_token_limit, mock_agent
):
"""Test context limit check when below threshold"""
mock_get_token_limit.return_value = 128000
mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8
messages = [
{"role": "system", "content": "Short message"},
{"role": "user", "content": "Hello"},
]
# Should return False for small conversation
result = mock_agent._check_context_limit(messages)
assert result is False
# Should track current token count
assert mock_agent.current_token_count > 0
assert mock_agent.current_token_count < 128000 * 0.8
@patch("application.core.model_utils.get_token_limit")
@patch("application.core.settings.settings")
def test_check_context_limit_above_threshold(
self, mock_settings, mock_get_token_limit, mock_agent
):
"""Test context limit check when above threshold"""
mock_get_token_limit.return_value = 100 # Very small limit for testing
mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8
# Create messages that will exceed 80 tokens (80% of 100)
messages = [
{"role": "system", "content": "a " * 50}, # ~50 tokens
{"role": "user", "content": "b " * 50}, # ~50 tokens
]
# Should return True when exceeding threshold
result = mock_agent._check_context_limit(messages)
assert result is True
@patch("application.agents.base.logger")
def test_check_context_limit_error_handling(self, mock_logger, mock_agent):
"""Test error handling in context limit check"""
# Force an error by making get_token_limit fail
with patch(
"application.core.model_utils.get_token_limit", side_effect=Exception("Test error")
):
messages = [{"role": "user", "content": "test"}]
result = mock_agent._check_context_limit(messages)
# Should return False on error (safe default)
assert result is False
# Should log the error
assert mock_logger.error.called
def test_context_limit_flag_initialization(self, mock_agent):
"""Test that context limit flag is initialized"""
assert hasattr(mock_agent, "context_limit_reached")
assert mock_agent.context_limit_reached is False
assert hasattr(mock_agent, "current_token_count")
assert mock_agent.current_token_count == 0
class TestLLMHandlerTokenTracking:
"""Test suite for LLM handler token tracking"""
@patch("application.llm.handlers.base.logger")
def test_handle_tool_calls_stops_at_limit(self, mock_logger):
"""Test that tool execution stops when context limit is reached"""
from application.llm.handlers.base import LLMHandler
# Create a concrete handler for testing
class TestHandler(LLMHandler):
def parse_response(self, response):
pass
def create_tool_message(self, tool_call, result):
return {"role": "tool", "content": str(result)}
def _iterate_stream(self, response):
yield ""
handler = TestHandler()
# Create mock agent that hits limit on second tool
mock_agent = Mock()
mock_agent.context_limit_reached = False
call_count = [0]
def check_limit_side_effect(messages):
call_count[0] += 1
# Return True on second call (second tool)
return call_count[0] >= 2
mock_agent._check_context_limit = Mock(side_effect=check_limit_side_effect)
mock_agent._execute_tool_action = Mock(
return_value=iter([{"type": "tool_call", "data": {}}])
)
# Create multiple tool calls
tool_calls = [
ToolCall(id="1", name="tool1", arguments={}),
ToolCall(id="2", name="tool2", arguments={}),
ToolCall(id="3", name="tool3", arguments={}),
]
messages = []
tools_dict = {}
# Execute tool calls
results = list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages))
# First tool should execute
assert mock_agent._execute_tool_action.call_count == 1
# Should have yielded skip messages for tools 2 and 3
skip_messages = [r for r in results if r.get("type") == "tool_call" and r.get("data", {}).get("status") == "skipped"]
assert len(skip_messages) == 2
# Should have set the flag
assert mock_agent.context_limit_reached is True
# Should have logged warning
assert mock_logger.warning.called
def test_handle_tool_calls_all_execute_when_no_limit(self):
"""Test that all tools execute when under limit"""
from application.llm.handlers.base import LLMHandler
class TestHandler(LLMHandler):
def parse_response(self, response):
pass
def create_tool_message(self, tool_call, result):
return {"role": "tool", "content": str(result)}
def _iterate_stream(self, response):
yield ""
handler = TestHandler()
# Create mock agent that never hits limit
mock_agent = Mock()
mock_agent.context_limit_reached = False
mock_agent._check_context_limit = Mock(return_value=False)
mock_agent._execute_tool_action = Mock(
return_value=iter([{"type": "tool_call", "data": {}}])
)
tool_calls = [
ToolCall(id="1", name="tool1", arguments={}),
ToolCall(id="2", name="tool2", arguments={}),
ToolCall(id="3", name="tool3", arguments={}),
]
messages = []
tools_dict = {}
# Execute tool calls
list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages))
# All 3 tools should execute
assert mock_agent._execute_tool_action.call_count == 3
# Should not have set the flag
assert mock_agent.context_limit_reached is False
@patch("application.llm.handlers.base.logger")
def test_handle_streaming_adds_warning_message(self, mock_logger):
"""Test that streaming handler adds warning when limit reached"""
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
class TestHandler(LLMHandler):
def parse_response(self, response):
if isinstance(response, dict) and response.get("type") == "tool_call":
return LLMResponse(
content="",
tool_calls=[ToolCall(id="1", name="test", arguments={}, index=0)],
finish_reason="tool_calls",
raw_response=None,
)
else:
return LLMResponse(
content="Done",
tool_calls=[],
finish_reason="stop",
raw_response=None,
)
def create_tool_message(self, tool_call, result):
return {"role": "tool", "content": str(result)}
def _iterate_stream(self, response):
if response == "first":
yield {"type": "tool_call"} # Object to be parsed, not string
else:
yield {"type": "stop"} # Object to be parsed, not string
handler = TestHandler()
# Create mock agent with limit reached
mock_agent = Mock()
mock_agent.context_limit_reached = True
mock_agent.model_id = "gpt-4o"
mock_agent.tools = []
mock_agent.llm = Mock()
mock_agent.llm.gen_stream = Mock(return_value="second")
def tool_handler_gen(*args):
yield {"type": "tool", "data": {}}
return []
# Mock handle_tool_calls to return messages and set flag
with patch.object(
handler, "handle_tool_calls", return_value=tool_handler_gen()
):
messages = []
tools_dict = {}
# Execute streaming
list(handler.handle_streaming(mock_agent, "first", tools_dict, messages))
# Should have called gen_stream with tools=None (disabled)
mock_agent.llm.gen_stream.assert_called()
call_kwargs = mock_agent.llm.gen_stream.call_args.kwargs
assert call_kwargs.get("tools") is None
# Should have logged the warning
assert mock_logger.info.called
if __name__ == "__main__":
pytest.main([__file__, "-v"])

File diff suppressed because it is too large Load Diff

1287
tests/test_integration.py Executable file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
"""
Tests for model validation and base_url functionality
"""
import pytest
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
ModelRegistry,
)
from application.core.model_utils import (
get_base_url_for_model,
validate_model_id,
)
@pytest.mark.unit
def test_model_with_base_url():
"""Test that AvailableModel can store and retrieve base_url"""
model = AvailableModel(
id="test-model",
provider=ModelProvider.OPENAI,
display_name="Test Model",
description="Test model with custom base URL",
base_url="https://custom-endpoint.com/v1",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=8192,
),
)
assert model.base_url == "https://custom-endpoint.com/v1"
assert model.id == "test-model"
assert model.provider == ModelProvider.OPENAI
# Test to_dict includes base_url
model_dict = model.to_dict()
assert "base_url" in model_dict
assert model_dict["base_url"] == "https://custom-endpoint.com/v1"
@pytest.mark.unit
def test_model_without_base_url():
"""Test that models without base_url still work"""
model = AvailableModel(
id="test-model-no-url",
provider=ModelProvider.OPENAI,
display_name="Test Model",
description="Test model without base URL",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=8192,
),
)
assert model.base_url is None
# Test to_dict doesn't include base_url when None
model_dict = model.to_dict()
assert "base_url" not in model_dict
@pytest.mark.unit
def test_validate_model_id():
"""Test model_id validation"""
# Get the registry instance to check what models are available
ModelRegistry.get_instance()
# Test with a model that should exist (docsgpt-local is always added)
assert validate_model_id("docsgpt-local") is True
# Test with invalid model_id
assert validate_model_id("invalid-model-xyz-123") is False
# Test with None
assert validate_model_id(None) is False
@pytest.mark.unit
def test_get_base_url_for_model():
"""Test retrieving base_url for a model"""
# Test with a model that doesn't have base_url
result = get_base_url_for_model("docsgpt-local")
assert result is None # docsgpt-local doesn't have custom base_url
# Test with invalid model
result = get_base_url_for_model("invalid-model")
assert result is None
@pytest.mark.unit
def test_model_validation_error_message():
"""Test that validation provides helpful error messages"""
from application.api.answer.services.stream_processor import StreamProcessor
# Create processor with invalid model_id
data = {"model_id": "invalid-model-xyz"}
processor = StreamProcessor(data, None)
# Should raise ValueError with helpful message
with pytest.raises(ValueError) as exc_info:
processor._validate_and_set_model()
error_msg = str(exc_info.value)
assert "Invalid model_id 'invalid-model-xyz'" in error_msg
assert "Available models:" in error_msg

View File

@@ -0,0 +1,314 @@
"""
Tests for token management and compression features.
NOTE: These tests are for future planned features that are not yet implemented.
They are skipped until the following modules are created:
- application.compression (DocumentCompressor, HistoryCompressor, etc.)
- application.core.token_budget (TokenBudgetManager)
"""
# ruff: noqa: F821
import pytest
pytest.skip(
"Token management features not yet implemented - planned for future release",
allow_module_level=True,
)
class TestTokenBudgetManager:
"""Test TokenBudgetManager functionality"""
def test_calculate_budget(self):
"""Test budget calculation"""
manager = TokenBudgetManager(model_id="gpt-4o")
budget = manager.calculate_budget()
assert budget.total_budget > 0
assert budget.system_prompt > 0
assert budget.chat_history > 0
assert budget.retrieved_docs > 0
def test_measure_usage(self):
"""Test token usage measurement"""
manager = TokenBudgetManager(model_id="gpt-4o")
usage = manager.measure_usage(
system_prompt="You are a helpful assistant.",
current_query="What is Python?",
chat_history=[
{"prompt": "Hello", "response": "Hi there!"},
{"prompt": "How are you?", "response": "I'm doing well, thanks!"},
],
)
assert usage.total > 0
assert usage.system_prompt > 0
assert usage.current_query > 0
assert usage.chat_history > 0
def test_compression_recommendation(self):
"""Test compression recommendation generation"""
manager = TokenBudgetManager(model_id="gpt-4o")
# Create scenario with excessive history
large_history = [
{"prompt": f"Question {i}" * 100, "response": f"Answer {i}" * 100}
for i in range(100)
]
budget, usage, recommendation = manager.check_and_recommend(
system_prompt="You are a helpful assistant.",
current_query="What is Python?",
chat_history=large_history,
)
# Should recommend compression
assert recommendation.needs_compression()
assert recommendation.compress_history
class TestHistoryCompressor:
"""Test HistoryCompressor functionality"""
def test_sliding_window_compression(self):
"""Test sliding window compression strategy"""
compressor = HistoryCompressor()
history = [
{"prompt": f"Question {i}", "response": f"Answer {i}"} for i in range(20)
]
compressed, metadata = compressor.compress(
history, target_tokens=500, strategy="sliding_window"
)
assert len(compressed) < len(history)
assert metadata["original_messages"] == 20
assert metadata["compressed_messages"] < 20
assert metadata["strategy"] == "sliding_window"
def test_preserve_tool_calls(self):
"""Test that tool calls are preserved during compression"""
compressor = HistoryCompressor()
history = [
{"prompt": "Question 1", "response": "Answer 1"},
{
"prompt": "Use a tool",
"response": "Tool used",
"tool_calls": [{"tool_name": "search", "result": "Found something"}],
},
{"prompt": "Question 3", "response": "Answer 3"},
]
compressed, metadata = compressor.compress(
history, target_tokens=200, strategy="sliding_window", preserve_tool_calls=True
)
# Tool call message should be preserved
has_tool_calls = any("tool_calls" in msg for msg in compressed)
assert has_tool_calls
class TestDocumentCompressor:
"""Test DocumentCompressor functionality"""
def test_rerank_compression(self):
"""Test re-ranking compression strategy"""
compressor = DocumentCompressor()
docs = [
{"text": f"Document {i} with some content here" * 20, "title": f"Doc {i}"}
for i in range(10)
]
compressed, metadata = compressor.compress(
docs, target_tokens=500, query="Document 5", strategy="rerank"
)
assert len(compressed) < len(docs)
assert metadata["original_docs"] == 10
assert metadata["strategy"] == "rerank"
def test_excerpt_extraction(self):
"""Test excerpt extraction strategy"""
compressor = DocumentCompressor()
docs = [
{
"text": "This is a long document. " * 100
+ "Python is great. "
+ "More text here. " * 100,
"title": "Python Guide",
}
]
compressed, metadata = compressor.compress(
docs, target_tokens=300, query="Python", strategy="excerpt"
)
assert metadata["excerpts_created"] > 0
# Excerpt should contain the query term
assert "python" in compressed[0]["text"].lower()
class TestToolResultCompressor:
"""Test ToolResultCompressor functionality"""
def test_truncate_large_results(self):
"""Test truncation of large tool results"""
compressor = ToolResultCompressor()
tool_results = [
{
"tool_name": "search",
"result": "Very long result " * 1000,
"arguments": {},
}
]
compressed, metadata = compressor.compress(
tool_results, target_tokens=100, strategy="truncate"
)
assert metadata["results_truncated"] > 0
# Result should be shorter
compressed_result_len = len(str(compressed[0]["result"]))
original_result_len = len(tool_results[0]["result"])
assert compressed_result_len < original_result_len
def test_extract_json_fields(self):
"""Test extraction of key fields from JSON results"""
compressor = ToolResultCompressor()
tool_results = [
{
"tool_name": "api_call",
"result": {
"data": {"important": "value"},
"metadata": {"verbose": "information" * 100},
"debug": {"lots": "of data" * 100},
},
"arguments": {},
}
]
compressed, metadata = compressor.compress(
tool_results, target_tokens=100, strategy="extract"
)
# Should keep important fields, discard verbose ones
assert "data" in compressed[0]["result"]
class TestPromptOptimizer:
"""Test PromptOptimizer functionality"""
def test_compress_tool_descriptions(self):
"""Test compression of tool descriptions"""
optimizer = PromptOptimizer()
tools = [
{
"type": "function",
"function": {
"name": f"tool_{i}",
"description": "This is a very long description " * 50,
"parameters": {},
},
}
for i in range(10)
]
optimized, metadata = optimizer.optimize_tools(
tools, target_tokens=500, strategy="compress"
)
assert metadata["optimized_tokens"] < metadata["original_tokens"]
assert metadata["descriptions_compressed"] > 0
def test_lazy_load_tools(self):
"""Test lazy loading of tools based on query"""
optimizer = PromptOptimizer()
tools = [
{
"type": "function",
"function": {
"name": "search_tool",
"description": "Search for information",
"parameters": {},
},
},
{
"type": "function",
"function": {
"name": "calculate_tool",
"description": "Perform calculations",
"parameters": {},
},
},
{
"type": "function",
"function": {
"name": "other_tool",
"description": "Do something else",
"parameters": {},
},
},
]
optimized, metadata = optimizer.optimize_tools(
tools, target_tokens=200, query="I want to search for something", strategy="lazy_load"
)
# Should prefer search tool
assert len(optimized) < len(tools)
tool_names = [t["function"]["name"] for t in optimized]
# Search tool should be included due to query relevance
assert any("search" in name for name in tool_names)
def test_integration_compression_workflow():
"""Test complete compression workflow"""
# Simulate a scenario with large inputs
manager = TokenBudgetManager(model_id="gpt-4o")
history_compressor = HistoryCompressor()
doc_compressor = DocumentCompressor()
# Large chat history
history = [
{"prompt": f"Question {i}" * 50, "response": f"Answer {i}" * 50}
for i in range(50)
]
# Large documents
docs = [
{"text": f"Document {i} content" * 100, "title": f"Doc {i}"} for i in range(20)
]
# Check budget
budget, usage, recommendation = manager.check_and_recommend(
system_prompt="You are a helpful assistant.",
current_query="What is Python?",
chat_history=history,
retrieved_docs=docs,
)
# Should need compression
assert recommendation.needs_compression()
# Apply compression
if recommendation.compress_history:
compressed_history, hist_meta = history_compressor.compress(
history, recommendation.target_history_tokens or budget.chat_history
)
assert len(compressed_history) < len(history)
if recommendation.compress_docs:
compressed_docs, doc_meta = doc_compressor.compress(
docs,
recommendation.target_docs_tokens or budget.retrieved_docs,
query="Python",
)
assert len(compressed_docs) < len(docs)