mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-01 15:51:10 +00:00
feat: context compression (#2173)
* feat: context compression * fix: ruff
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from google import genai
|
||||
@@ -11,11 +10,13 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
@@ -33,6 +34,12 @@ class GoogleLLM(BaseLLM):
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -135,12 +142,38 @@ class GoogleLLM(BaseLLM):
|
||||
raise
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""Convert OpenAI format messages to Google AI format."""
|
||||
"""
|
||||
Convert OpenAI format messages to Google AI format and collect system prompts.
|
||||
|
||||
Returns:
|
||||
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
|
||||
combined system instruction.
|
||||
"""
|
||||
cleaned_messages = []
|
||||
system_instructions = []
|
||||
|
||||
def _extract_system_text(content):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item and item["text"] is not None:
|
||||
parts.append(item["text"])
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
# Gemini only accepts user/model in the contents list.
|
||||
if role == "system":
|
||||
sys_text = _extract_system_text(content)
|
||||
if sys_text:
|
||||
system_instructions.append(sys_text)
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
@@ -159,12 +192,27 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
# Create function call part with thought_signature if present
|
||||
# For Gemini 3 models, we need to include thought_signature
|
||||
if "thought_signature" in item:
|
||||
# Use Part constructor with functionCall and thoughtSignature
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=item["thought_signature"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use helper method when no thought_signature
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif "function_response" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
@@ -188,7 +236,8 @@ class GoogleLLM(BaseLLM):
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
return cleaned_messages
|
||||
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
|
||||
return cleaned_messages, system_instruction
|
||||
|
||||
def _clean_schema(self, schema_obj):
|
||||
"""
|
||||
@@ -274,6 +323,61 @@ class GoogleLLM(BaseLLM):
|
||||
genai_tools.append(genai_tool)
|
||||
return genai_tools
|
||||
|
||||
def _extract_preview_from_message(self, message):
|
||||
"""Get a short, human-readable preview from the last message."""
|
||||
try:
|
||||
if hasattr(message, "parts"):
|
||||
for part in reversed(message.parts):
|
||||
if getattr(part, "text", None):
|
||||
return part.text
|
||||
function_call = getattr(part, "function_call", None)
|
||||
if function_call:
|
||||
name = getattr(function_call, "name", "") or "function_call"
|
||||
return f"function_call:{name}"
|
||||
function_response = getattr(part, "function_response", None)
|
||||
if function_response:
|
||||
name = getattr(function_response, "name", "") or "function_response"
|
||||
return f"function_response:{name}"
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for item in reversed(content):
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, dict):
|
||||
if item.get("text"):
|
||||
return item["text"]
|
||||
if item.get("function_call"):
|
||||
fn = item["function_call"]
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name") or "function_call"
|
||||
return f"function_call:{name}"
|
||||
return "function_call"
|
||||
if item.get("function_response"):
|
||||
resp = item["function_response"]
|
||||
if isinstance(resp, dict):
|
||||
name = resp.get("name") or "function_response"
|
||||
return f"function_response:{name}"
|
||||
return "function_response"
|
||||
if "text" in message and isinstance(message["text"], str):
|
||||
return message["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return str(message)
|
||||
|
||||
def _summarize_messages_for_log(self, messages, preview_chars=20):
|
||||
"""Return a compact summary for logging to avoid huge payloads."""
|
||||
message_count = len(messages) if messages else 0
|
||||
last_preview = ""
|
||||
if messages:
|
||||
last_preview = self._extract_preview_from_message(messages[-1]) or ""
|
||||
last_preview = str(last_preview).replace("\n", " ")
|
||||
if len(last_preview) > preview_chars:
|
||||
last_preview = f"{last_preview[:preview_chars]}..."
|
||||
return f"count={message_count}, last='{last_preview}'"
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -287,12 +391,12 @@ class GoogleLLM(BaseLLM):
|
||||
):
|
||||
"""Generate content using Google AI API without streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
system_instruction = None
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
messages, system_instruction = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
if system_instruction:
|
||||
config.system_instruction = system_instruction
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
@@ -325,12 +429,12 @@ class GoogleLLM(BaseLLM):
|
||||
):
|
||||
"""Generate content using Google AI API with streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
system_instruction = None
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
messages, system_instruction = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
if system_instruction:
|
||||
config.system_instruction = system_instruction
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
@@ -349,8 +453,12 @@ class GoogleLLM(BaseLLM):
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
messages_summary = self._summarize_messages_for_log(messages)
|
||||
logging.info(
|
||||
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
|
||||
model,
|
||||
messages_summary,
|
||||
has_attachments,
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
@@ -16,6 +17,7 @@ class ToolCall:
|
||||
name: str
|
||||
arguments: Union[str, Dict]
|
||||
index: Optional[int] = None
|
||||
thought_signature: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "ToolCall":
|
||||
@@ -178,6 +180,406 @@ class LLMHandler(ABC):
|
||||
system_msg["content"] += f"\n\n{combined_text}"
|
||||
return prepared_messages
|
||||
|
||||
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Build a minimal context: system prompt + latest user message only.
|
||||
Drops all tool/function messages to shrink context aggressively.
|
||||
"""
|
||||
system_message = next((m for m in messages if m.get("role") == "system"), None)
|
||||
if not system_message:
|
||||
logger.warning("Cannot prune messages minimally: missing system message.")
|
||||
return None
|
||||
last_non_system = None
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
last_non_system = m
|
||||
break
|
||||
if not last_non_system and m.get("role") not in ("system", None):
|
||||
last_non_system = m
|
||||
if not last_non_system:
|
||||
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
|
||||
return None
|
||||
logger.info("Pruning context to system + latest user/assistant message to proceed.")
|
||||
return [system_message, last_non_system]
|
||||
|
||||
def _extract_text_from_content(self, content: Any) -> str:
|
||||
"""
|
||||
Convert message content (str or list of parts) to plain text for compression.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts_text = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if "text" in item and item["text"] is not None:
|
||||
parts_text.append(str(item["text"]))
|
||||
elif "function_call" in item or "function_response" in item:
|
||||
# Keep serialized function calls/responses so the compressor sees actions
|
||||
parts_text.append(str(item))
|
||||
elif "files" in item:
|
||||
parts_text.append(str(item))
|
||||
return "\n".join(parts_text)
|
||||
return ""
|
||||
|
||||
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
|
||||
"""
|
||||
Build a conversation-like dict from current messages so we can compress
|
||||
even when the conversation isn't persisted yet. Includes tool calls/results.
|
||||
"""
|
||||
queries = []
|
||||
current_prompt = None
|
||||
current_tool_calls = {}
|
||||
|
||||
def _commit_query(response_text: str):
|
||||
nonlocal current_prompt, current_tool_calls
|
||||
if current_prompt is None and not response_text:
|
||||
return
|
||||
tool_calls_list = list(current_tool_calls.values())
|
||||
queries.append(
|
||||
{
|
||||
"prompt": current_prompt or "",
|
||||
"response": response_text,
|
||||
"tool_calls": tool_calls_list,
|
||||
}
|
||||
)
|
||||
current_prompt = None
|
||||
current_tool_calls = {}
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "user":
|
||||
current_prompt = self._extract_text_from_content(content)
|
||||
|
||||
elif role in {"assistant", "model"}:
|
||||
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_call" in item:
|
||||
fc = item["function_call"]
|
||||
call_id = fc.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fc.get("name"),
|
||||
"arguments": fc.get("args"),
|
||||
"result": None,
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
elif "function_response" in item:
|
||||
fr = item["function_response"]
|
||||
call_id = fr.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fr.get("name"),
|
||||
"arguments": None,
|
||||
"result": fr.get("response", {}).get("result"),
|
||||
"status": "completed",
|
||||
"call_id": call_id,
|
||||
}
|
||||
# No direct assistant text here; continue to next message
|
||||
continue
|
||||
|
||||
response_text = self._extract_text_from_content(content)
|
||||
_commit_query(response_text)
|
||||
|
||||
elif role == "tool":
|
||||
# Attach tool outputs to the latest pending tool call if possible
|
||||
tool_text = self._extract_text_from_content(content)
|
||||
# Attempt to parse function_response style
|
||||
call_id = None
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item and item["function_response"].get("call_id"):
|
||||
call_id = item["function_response"]["call_id"]
|
||||
break
|
||||
if call_id and call_id in current_tool_calls:
|
||||
current_tool_calls[call_id]["result"] = tool_text
|
||||
current_tool_calls[call_id]["status"] = "completed"
|
||||
elif queries:
|
||||
queries[-1].setdefault("tool_calls", []).append(
|
||||
{
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": "unknown_action",
|
||||
"arguments": {},
|
||||
"result": tool_text,
|
||||
"status": "completed",
|
||||
}
|
||||
)
|
||||
|
||||
# If there's an unfinished prompt with tool_calls but no response yet, commit it
|
||||
if current_prompt is not None or current_tool_calls:
|
||||
_commit_query(response_text="")
|
||||
|
||||
if not queries:
|
||||
return None
|
||||
|
||||
return {
|
||||
"queries": queries,
|
||||
"compression_metadata": {
|
||||
"is_compressed": False,
|
||||
"compression_points": [],
|
||||
},
|
||||
}
|
||||
|
||||
def _rebuild_messages_after_compression(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Delegates to MessageBuilder for the actual reconstruction.
|
||||
"""
|
||||
from application.api.answer.services.compression.message_builder import (
|
||||
MessageBuilder,
|
||||
)
|
||||
|
||||
return MessageBuilder.rebuild_messages_after_compression(
|
||||
messages=messages,
|
||||
compressed_summary=compressed_summary,
|
||||
recent_queries=recent_queries,
|
||||
include_current_execution=include_current_execution,
|
||||
include_tool_calls=include_tool_calls,
|
||||
)
|
||||
|
||||
def _perform_mid_execution_compression(
|
||||
self, agent, messages: List[Dict]
|
||||
) -> tuple[bool, Optional[List[Dict]]]:
|
||||
"""
|
||||
Perform compression during tool execution and rebuild messages.
|
||||
|
||||
Uses the new orchestrator for simplified compression.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current conversation messages
|
||||
|
||||
Returns:
|
||||
(success: bool, rebuilt_messages: Optional[List[Dict]])
|
||||
"""
|
||||
try:
|
||||
from application.api.answer.services.compression import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
|
||||
conversation_service = ConversationService()
|
||||
orchestrator = CompressionOrchestrator(conversation_service)
|
||||
|
||||
# Get conversation from database (may be None for new sessions)
|
||||
conversation = conversation_service.get_conversation(
|
||||
agent.conversation_id, agent.initial_user_id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
# Merge current in-flight messages (including tool calls)
|
||||
conversation_from_msgs = self._build_conversation_from_messages(messages)
|
||||
if conversation_from_msgs:
|
||||
conversation = conversation_from_msgs
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not load conversation for compression; attempting in-memory compression"
|
||||
)
|
||||
return self._perform_in_memory_compression(agent, messages)
|
||||
|
||||
# Use orchestrator to perform compression
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id=agent.conversation_id,
|
||||
user_id=agent.initial_user_id,
|
||||
model_id=agent.model_id,
|
||||
decoded_token=getattr(agent, "decoded_token", {}),
|
||||
current_conversation=conversation,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
logger.warning(f"Mid-execution compression failed: {result.error}")
|
||||
# Try minimal pruning as fallback
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
if not result.compression_performed:
|
||||
logger.warning("Compression not performed")
|
||||
return False, None
|
||||
|
||||
# Check if compression actually reduced tokens
|
||||
if result.metadata:
|
||||
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
|
||||
logger.warning(
|
||||
"Compression did not reduce token count; falling back to minimal pruning"
|
||||
)
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
logger.info(
|
||||
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
|
||||
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Also store the compression summary as a visible message
|
||||
if result.metadata:
|
||||
conversation_service.append_compression_message(
|
||||
agent.conversation_id, result.metadata.to_dict()
|
||||
)
|
||||
|
||||
# Update agent's compressed summary for downstream persistence
|
||||
agent.compressed_summary = result.compressed_summary
|
||||
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
|
||||
agent.compression_saved = False
|
||||
|
||||
# Reset the context limit flag so tools can continue
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
|
||||
# Rebuild messages
|
||||
rebuilt_messages = self._rebuild_messages_after_compression(
|
||||
messages,
|
||||
result.compressed_summary,
|
||||
result.recent_queries,
|
||||
include_current_execution=False,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
|
||||
if rebuilt_messages is None:
|
||||
return False, None
|
||||
|
||||
return True, rebuilt_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error performing mid-execution compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _perform_in_memory_compression(
|
||||
self, agent, messages: List[Dict]
|
||||
) -> tuple[bool, Optional[List[Dict]]]:
|
||||
"""
|
||||
Fallback compression path when the conversation is not yet persisted.
|
||||
|
||||
Uses CompressionService directly without DB persistence.
|
||||
"""
|
||||
try:
|
||||
from application.api.answer.services.compression.service import (
|
||||
CompressionService,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
conversation = self._build_conversation_from_messages(messages)
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
"Cannot perform in-memory compression: no user/assistant turns found"
|
||||
)
|
||||
return False, None
|
||||
|
||||
compression_model = (
|
||||
settings.COMPRESSION_MODEL_OVERRIDE
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else agent.model_id
|
||||
)
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key,
|
||||
getattr(agent, "user_api_key", None),
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
compression_service = CompressionService(
|
||||
llm=compression_llm,
|
||||
model_id=compression_model,
|
||||
conversation_service=None, # No DB updates for in-memory
|
||||
)
|
||||
|
||||
queries_count = len(conversation.get("queries", []))
|
||||
compress_up_to = queries_count - 1
|
||||
|
||||
if compress_up_to < 0 or queries_count == 0:
|
||||
logger.warning("Not enough queries to compress in-memory context")
|
||||
return False, None
|
||||
|
||||
metadata = compression_service.compress_conversation(
|
||||
conversation,
|
||||
compress_up_to_index=compress_up_to,
|
||||
)
|
||||
|
||||
# If compression doesn't reduce tokens, fall back to minimal pruning
|
||||
if (
|
||||
metadata.compressed_token_count
|
||||
>= metadata.original_token_count
|
||||
):
|
||||
logger.warning(
|
||||
"In-memory compression did not reduce token count; falling back to minimal pruning"
|
||||
)
|
||||
pruned = self._prune_messages_minimal(messages)
|
||||
if pruned:
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
return True, pruned
|
||||
return False, None
|
||||
|
||||
# Attach metadata to synthetic conversation
|
||||
conversation["compression_metadata"] = {
|
||||
"is_compressed": True,
|
||||
"compression_points": [metadata.to_dict()],
|
||||
}
|
||||
|
||||
compressed_summary, recent_queries = (
|
||||
compression_service.get_compressed_context(conversation)
|
||||
)
|
||||
|
||||
agent.compressed_summary = compressed_summary
|
||||
agent.compression_metadata = metadata.to_dict()
|
||||
agent.compression_saved = False
|
||||
agent.context_limit_reached = False
|
||||
agent.current_token_count = 0
|
||||
|
||||
rebuilt_messages = self._rebuild_messages_after_compression(
|
||||
messages,
|
||||
compressed_summary,
|
||||
recent_queries,
|
||||
include_current_execution=False,
|
||||
include_tool_calls=False,
|
||||
)
|
||||
if rebuilt_messages is None:
|
||||
return False, None
|
||||
|
||||
logger.info(
|
||||
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
return True, rebuilt_messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error performing in-memory compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return False, None
|
||||
|
||||
def handle_tool_calls(
|
||||
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||
) -> Generator:
|
||||
@@ -195,7 +597,110 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
|
||||
for call in tool_calls:
|
||||
for i, call in enumerate(tool_calls):
|
||||
# Check context limit before executing tool call
|
||||
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
|
||||
# Context limit reached - attempt mid-execution compression
|
||||
compression_attempted = False
|
||||
compression_successful = False
|
||||
|
||||
try:
|
||||
from application.core.settings import settings
|
||||
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
|
||||
except Exception:
|
||||
compression_enabled = False
|
||||
|
||||
if compression_enabled:
|
||||
compression_attempted = True
|
||||
try:
|
||||
logger.info(
|
||||
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
|
||||
f"Attempting mid-execution compression..."
|
||||
)
|
||||
|
||||
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
|
||||
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
|
||||
agent, updated_messages
|
||||
)
|
||||
|
||||
if compression_successful and rebuilt_messages is not None:
|
||||
# Update the messages list with rebuilt compressed version
|
||||
updated_messages = rebuilt_messages
|
||||
|
||||
# Yield compression success message
|
||||
yield {
|
||||
"type": "info",
|
||||
"data": {
|
||||
"message": "Context window limit reached. Compressed conversation history to continue processing."
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
|
||||
)
|
||||
# Proceed to execute the current tool call with the reduced context
|
||||
else:
|
||||
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
|
||||
compression_attempted = True
|
||||
compression_successful = False
|
||||
|
||||
# If compression wasn't attempted or failed, skip remaining tools
|
||||
if not compression_successful:
|
||||
if i == 0:
|
||||
# Special case: limit reached before executing any tools
|
||||
# This can happen when previous tool responses pushed context over limit
|
||||
if compression_attempted:
|
||||
logger.warning(
|
||||
f"Context limit reached before executing any tools. "
|
||||
f"Compression attempted but failed. "
|
||||
f"Skipping all {len(tool_calls)} pending tool call(s). "
|
||||
f"This typically occurs when previous tool responses contained large amounts of data."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Context limit reached before executing any tools. "
|
||||
f"Skipping all {len(tool_calls)} pending tool call(s). "
|
||||
f"This typically occurs when previous tool responses contained large amounts of data. "
|
||||
f"Consider enabling compression or using a model with larger context window."
|
||||
)
|
||||
else:
|
||||
# Normal case: executed some tools, now stopping
|
||||
tool_word = "tool call" if i == 1 else "tool calls"
|
||||
remaining = len(tool_calls) - i
|
||||
remaining_word = "tool call" if remaining == 1 else "tool calls"
|
||||
if compression_attempted:
|
||||
logger.warning(
|
||||
f"Context limit reached after executing {i} {tool_word}. "
|
||||
f"Compression attempted but failed. "
|
||||
f"Skipping remaining {remaining} {remaining_word}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Context limit reached after executing {i} {tool_word}. "
|
||||
f"Skipping remaining {remaining} {remaining_word}. "
|
||||
f"Consider enabling compression or using a model with larger context window."
|
||||
)
|
||||
|
||||
# Mark remaining tools as skipped
|
||||
for remaining_call in tool_calls[i:]:
|
||||
skip_message = {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": "system",
|
||||
"call_id": remaining_call.id,
|
||||
"action_name": remaining_call.name,
|
||||
"arguments": {},
|
||||
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
|
||||
"status": "skipped"
|
||||
}
|
||||
}
|
||||
yield skip_message
|
||||
|
||||
# Set flag on agent
|
||||
agent.context_limit_reached = True
|
||||
break
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
@@ -205,21 +710,26 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
# Include thought_signature for Google Gemini 3 models
|
||||
# It should be at the same level as function_call, not inside it
|
||||
if call.thought_signature:
|
||||
function_call_content["thought_signature"] = call.thought_signature
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
],
|
||||
"content": [function_call_content],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
@@ -324,6 +834,9 @@ class LLMHandler(ABC):
|
||||
existing.name = call.name
|
||||
if call.arguments:
|
||||
existing.arguments += call.arguments
|
||||
# Preserve thought_signature for Google Gemini 3 models
|
||||
if call.thought_signature:
|
||||
existing.thought_signature = call.thought_signature
|
||||
if parsed.finish_reason == "tool_calls":
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, list(tool_calls.values()), tools_dict, messages
|
||||
@@ -336,8 +849,21 @@ class LLMHandler(ABC):
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": (
|
||||
"WARNING: Context window limit has been reached. "
|
||||
"Please provide a final response to the user without making additional tool calls. "
|
||||
"Summarize the work completed so far."
|
||||
)
|
||||
})
|
||||
logger.info("Context limit reached - instructing agent to wrap up")
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
@@ -19,15 +19,20 @@ class GoogleLLMHandler(LLMHandler):
|
||||
)
|
||||
if hasattr(response, "candidates"):
|
||||
parts = response.candidates[0].content.parts if response.candidates else []
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
)
|
||||
for part in parts
|
||||
if hasattr(part, "function_call") and part.function_call is not None
|
||||
]
|
||||
tool_calls = []
|
||||
for idx, part in enumerate(parts):
|
||||
if hasattr(part, "function_call") and part.function_call is not None:
|
||||
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
|
||||
thought_sig = part.thought_signature if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
index=idx,
|
||||
thought_signature=thought_sig,
|
||||
)
|
||||
)
|
||||
|
||||
content = " ".join(
|
||||
part.text
|
||||
@@ -41,13 +46,17 @@ class GoogleLLMHandler(LLMHandler):
|
||||
raw_response=response,
|
||||
)
|
||||
else:
|
||||
# This branch handles individual Part objects from streaming responses
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call"):
|
||||
if hasattr(response, "function_call") and response.function_call is not None:
|
||||
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
|
||||
thought_sig = response.thought_signature if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=response.function_call.name,
|
||||
arguments=response.function_call.args,
|
||||
thought_signature=thought_sig,
|
||||
)
|
||||
)
|
||||
return LLMResponse(
|
||||
|
||||
@@ -128,6 +128,10 @@ class OpenAILLM(BaseLLM):
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -159,6 +163,10 @@ class OpenAILLM(BaseLLM):
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
|
||||
Reference in New Issue
Block a user