feat: context compression

Alex
2025-11-23 18:35:51 +00:00
parent 9e58eb02b3
commit 3737beb2ba
28 changed files with 5393 additions and 93 deletions


@@ -1,4 +1,5 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -16,6 +17,7 @@ class ToolCall:
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
thought_signature: Optional[str] = None
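# Example construction (hypothetical values; assumes the dataclass's other
# fields, e.g. the `id` used later in handle_tool_calls):
#
#     ToolCall(
#         id="call_001",
#         name="search_docs",
#         arguments={"query": "compression"},
#         thought_signature="b64-opaque-signature",
#     )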
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
@@ -178,6 +180,406 @@ class LLMHandler(ABC):
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
"""
Build a minimal context: the system prompt plus the latest user message
(falling back to the most recent non-system message if none exists).
Drops all tool/function messages to shrink the context aggressively.
"""
system_message = next((m for m in messages if m.get("role") == "system"), None)
if not system_message:
logger.warning("Cannot prune messages minimally: missing system message.")
return None
last_non_system = None
for m in reversed(messages):
if m.get("role") == "user":
last_non_system = m
break
if not last_non_system and m.get("role") not in ("system", None):
last_non_system = m
if not last_non_system:
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
return None
logger.info("Pruning context to system + latest user/assistant message to proceed.")
return [system_message, last_non_system]
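# Usage sketch (hypothetical messages): given
#
#     [{"role": "system", "content": "You are a helpful assistant."},
#      {"role": "user", "content": "Summarize the repo"},
#      {"role": "tool", "content": "large tool output ..."},
#      {"role": "user", "content": "Now list open issues"}]
#
# the pruned context is [system message, {"role": "user", "content": "Now list open issues"}];
# every tool/function message is dropped.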
def _extract_text_from_content(self, content: Any) -> str:
"""
Convert message content (str or list of parts) to plain text for compression.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts_text = []
for item in content:
if isinstance(item, dict):
if "text" in item and item["text"] is not None:
parts_text.append(str(item["text"]))
elif "function_call" in item or "function_response" in item:
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
"""
Build a conversation-like dict from current messages so we can compress
even when the conversation isn't persisted yet. Includes tool calls/results.
"""
queries = []
current_prompt = None
current_tool_calls = {}
def _commit_query(response_text: str):
nonlocal current_prompt, current_tool_calls
if current_prompt is None and not response_text and not current_tool_calls:
return
tool_calls_list = list(current_tool_calls.values())
queries.append(
{
"prompt": current_prompt or "",
"response": response_text,
"tool_calls": tool_calls_list,
}
)
current_prompt = None
current_tool_calls = {}
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "user":
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
for item in content:
if not isinstance(item, dict):
continue
if "function_call" in item:
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fc.get("name"),
"arguments": fc.get("args"),
"result": None,
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
"action_name": "unknown_action",
"arguments": {},
"result": tool_text,
"status": "completed",
}
)
# If there's an unfinished prompt with tool_calls but no response yet, commit it
if current_prompt is not None or current_tool_calls:
_commit_query(response_text="")
if not queries:
return None
return {
"queries": queries,
"compression_metadata": {
"is_compressed": False,
"compression_points": [],
},
}
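# Shape of the synthetic conversation built above (hypothetical values):
#
#     {
#         "queries": [
#             {"prompt": "user text",
#              "response": "assistant text",
#              "tool_calls": [{"tool_name": "unknown_tool",
#                              "action_name": "search",
#                              "arguments": {"q": "..."},
#                              "result": "...",
#                              "status": "completed",
#                              "call_id": "..."}]},
#         ],
#         "compression_metadata": {"is_compressed": False, "compression_points": []},
#     }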
def _rebuild_messages_after_compression(
self,
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Delegates to MessageBuilder for the actual reconstruction.
"""
from application.api.answer.services.compression.message_builder import (
MessageBuilder,
)
return MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=compressed_summary,
recent_queries=recent_queries,
include_current_execution=include_current_execution,
include_tool_calls=include_tool_calls,
)
def _perform_mid_execution_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Perform compression during tool execution and rebuild messages.
Uses the new orchestrator for simplified compression.
Args:
agent: The agent instance
messages: Current conversation messages
Returns:
(success: bool, rebuilt_messages: Optional[List[Dict]])
"""
try:
from application.api.answer.services.compression import (
CompressionOrchestrator,
)
from application.api.answer.services.conversation_service import (
ConversationService,
)
conversation_service = ConversationService()
orchestrator = CompressionOrchestrator(conversation_service)
# Get conversation from database (may be None for new sessions)
conversation = conversation_service.get_conversation(
agent.conversation_id, agent.initial_user_id
)
if conversation:
# Prefer the in-flight messages (including tool calls) over the persisted
# conversation so compression sees the latest tool activity
conversation_from_msgs = self._build_conversation_from_messages(messages)
if conversation_from_msgs:
conversation = conversation_from_msgs
else:
logger.warning(
"Could not load conversation for compression; attempting in-memory compression"
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
)
if not result.success:
logger.warning(f"Mid-execution compression failed: {result.error}")
# Try minimal pruning as fallback
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
if not result.compression_performed:
logger.warning("Compression not performed")
return False, None
# Check if compression actually reduced tokens
if result.metadata:
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
logger.warning(
"Compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
logger.info(
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
)
# Also store the compression summary as a visible message
conversation_service.append_compression_message(
agent.conversation_id, result.metadata.to_dict()
)
# Update agent's compressed summary for downstream persistence
agent.compressed_summary = result.compressed_summary
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
agent.compression_saved = False
# Reset the context limit flag so tools can continue
agent.context_limit_reached = False
agent.current_token_count = 0
# Rebuild messages
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
result.compressed_summary,
result.recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing mid-execution compression: {str(e)}", exc_info=True
)
return False, None
def _perform_in_memory_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Fallback compression path when the conversation is not yet persisted.
Uses CompressionService directly without DB persistence.
"""
try:
from application.api.answer.services.compression.service import (
CompressionService,
)
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
conversation = self._build_conversation_from_messages(messages)
if not conversation:
logger.warning(
"Cannot perform in-memory compression: no user/assistant turns found"
)
return False, None
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key,
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
)
# Create service without DB persistence capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=None, # No DB updates for in-memory
)
queries_count = len(conversation.get("queries", []))
if queries_count == 0:
logger.warning("Not enough queries to compress in-memory context")
return False, None
compress_up_to = queries_count - 1
metadata = compression_service.compress_conversation(
conversation,
compress_up_to_index=compress_up_to,
)
# If compression doesn't reduce tokens, fall back to minimal pruning
if metadata.compressed_token_count >= metadata.original_token_count:
logger.warning(
"In-memory compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
# Attach metadata to synthetic conversation
conversation["compression_metadata"] = {
"is_compressed": True,
"compression_points": [metadata.to_dict()],
}
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
agent.compressed_summary = compressed_summary
agent.compression_metadata = metadata.to_dict()
agent.compression_saved = False
agent.context_limit_reached = False
agent.current_token_count = 0
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
compressed_summary,
recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
logger.info(
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing in-memory compression: {str(e)}", exc_info=True
)
return False, None
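# Recovery order when the context limit is hit mid-execution, as implemented
# by the two methods above: (1) DB-backed compression through
# CompressionOrchestrator, (2) in-memory compression through
# CompressionService when the conversation isn't persisted, (3) minimal
# pruning to system + latest user message, and finally (4) skipping the
# remaining tool calls in handle_tool_calls below.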
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
@@ -195,7 +597,110 @@ class LLMHandler(ABC):
"""
updated_messages = messages.copy()
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
# Context limit reached - attempt mid-execution compression
compression_attempted = False
compression_successful = False
try:
from application.core.settings import settings
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
except Exception:
compression_enabled = False
if compression_enabled:
compression_attempted = True
try:
logger.info(
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
f"Attempting mid-execution compression..."
)
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
agent, updated_messages
)
if compression_successful and rebuilt_messages is not None:
# Update the messages list with rebuilt compressed version
updated_messages = rebuilt_messages
# Yield compression success message
yield {
"type": "info",
"data": {
"message": "Context window limit reached. Compressed conversation history to continue processing."
}
}
logger.info(
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
)
# Proceed to execute the current tool call with the reduced context
else:
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
except Exception as e:
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
compression_attempted = True
compression_successful = False
# If compression wasn't attempted or failed, skip remaining tools
if not compression_successful:
if i == 0:
# Special case: limit reached before executing any tools
# This can happen when previous tool responses pushed context over limit
if compression_attempted:
logger.warning(
f"Context limit reached before executing any tools. "
f"Compression attempted but failed. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data."
)
else:
logger.warning(
f"Context limit reached before executing any tools. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data. "
f"Consider enabling compression or using a model with larger context window."
)
else:
# Normal case: executed some tools, now stopping
tool_word = "tool call" if i == 1 else "tool calls"
remaining = len(tool_calls) - i
remaining_word = "tool call" if remaining == 1 else "tool calls"
if compression_attempted:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Compression attempted but failed. "
f"Skipping remaining {remaining} {remaining_word}."
)
else:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Skipping remaining {remaining} {remaining_word}. "
f"Consider enabling compression or using a model with larger context window."
)
# Mark remaining tools as skipped
for remaining_call in tool_calls[i:]:
skip_message = {
"type": "tool_call",
"data": {
"tool_name": "system",
"call_id": remaining_call.id,
"action_name": remaining_call.name,
"arguments": {},
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
"status": "skipped"
}
}
yield skip_message
# Set flag on agent
agent.context_limit_reached = True
break
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -205,21 +710,26 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [function_call_content],
}
)
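# Resulting assistant message (sketch; values hypothetical). For Gemini 3
# models the thought_signature sits alongside function_call, not inside it:
#
#     {"role": "assistant",
#      "content": [{"function_call": {"name": "search_docs",
#                                     "args": {"query": "..."},
#                                     "call_id": "call_001"},
#                   "thought_signature": "b64-opaque-signature"}]}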
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
@@ -324,6 +834,9 @@ class LLMHandler(ABC):
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
@@ -336,8 +849,21 @@ class LLMHandler(ABC):
break
tool_calls = {}
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
messages.append({
"role": "system",
"content": (
"WARNING: Context window limit has been reached. "
"Please provide a final response to the user without making additional tool calls. "
"Summarize the work completed so far."
)
})
logger.info("Context limit reached - instructing agent to wrap up")
response = agent.llm.gen_stream(
model=agent.model_id,
messages=messages,
tools=agent.tools if not getattr(agent, "context_limit_reached", False) else None,
)
self.llm_calls.append(build_stack_data(agent.llm))