feat(agent): implement context validation and message truncation (#2249)
@@ -346,12 +346,81 @@ class BaseAgent(ABC):
            logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
            return False

    def _validate_context_size(self, messages: List[Dict]) -> None:
        """
        Pre-flight validation before calling LLM. Logs warnings but never raises errors.

        Args:
            messages: Messages to be sent to LLM
        """
        from application.core.model_utils import get_token_limit

        current_tokens = self._calculate_current_context_tokens(messages)
        self.current_token_count = current_tokens
        context_limit = get_token_limit(self.model_id)

        percentage = (current_tokens / context_limit) * 100

        # Log based on usage level
        if current_tokens >= context_limit:
            logger.warning(
                f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
                f"({percentage:.1f}%). Model: {self.model_id}"
            )
        elif current_tokens >= int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE):
            logger.info(
                f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
                f"({percentage:.1f}%)"
            )

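To make the two logging thresholds concrete, here is a minimal, self-contained sketch of the same branching, assuming an illustrative 128,000-token limit and a COMPRESSION_THRESHOLD_PERCENTAGE of 0.8 (neither value is taken from this diff; the real ones come from get_token_limit() and settings):

# Illustrative stand-ins only.
context_limit = 128_000
compression_threshold = 0.8

for current_tokens in (50_000, 110_000, 130_000):
    percentage = (current_tokens / context_limit) * 100
    if current_tokens >= context_limit:
        print(f"warning: at limit ({percentage:.1f}%)")
    elif current_tokens >= int(context_limit * compression_threshold):
        print(f"info: approaching limit ({percentage:.1f}%)")
    else:
        print(f"ok ({percentage:.1f}%)")
# 50,000 -> ok, 110,000 -> approaching limit, 130,000 -> at limit
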
    def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
        """
        Truncate text by removing content from the middle, preserving start and end.

        Args:
            text: Text to truncate
            max_tokens: Maximum tokens allowed

        Returns:
            Truncated text with middle removed if needed
        """
        from application.utils import num_tokens_from_string

        current_tokens = num_tokens_from_string(text)
        if current_tokens <= max_tokens:
            return text

        # Estimate chars per token (roughly 4 chars per token for English)
        chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
        target_chars = int(max_tokens * chars_per_token * 0.95)  # 5% safety margin

        if target_chars <= 0:
            return ""

        # Split: keep 40% from start, 40% from end, remove middle
        start_chars = int(target_chars * 0.4)
        end_chars = int(target_chars * 0.4)

        truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"

        truncated = text[:start_chars] + truncation_marker + text[-end_chars:]

        logger.info(
            f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
            f"(removed middle section)"
        )

        return truncated

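As a rough illustration of the middle-removal strategy, the sketch below reproduces the same arithmetic with a stand-in token counter (len(text) // 4) instead of the project's num_tokens_from_string; the input string and budget are invented for the example:

def rough_token_count(text: str) -> int:
    # Stand-in tokenizer: ~4 characters per token, matching the heuristic above.
    return max(len(text) // 4, 1)

text = "A" * 2000 + " [important middle detail] " + "B" * 2000
max_tokens = 200

current_tokens = rough_token_count(text)                   # ~1006 tokens
chars_per_token = len(text) / current_tokens
target_chars = int(max_tokens * chars_per_token * 0.95)    # ~760 characters
start_chars = end_chars = int(target_chars * 0.4)          # 40% head + 40% tail

marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + marker + text[-end_chars:]
# The head of "A"s and the tail of "B"s survive; the middle detail is dropped.
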
    def _build_messages(
        self,
        system_prompt: str,
        query: str,
    ) -> List[Dict]:
        """Build messages using pre-rendered system prompt"""
        from application.core.model_utils import get_token_limit
        from application.utils import num_tokens_from_string

        # Append compression summary to system prompt if present
        if self.compressed_summary:
            compression_context = (
@@ -363,9 +432,34 @@ class BaseAgent(ABC):
            )
            system_prompt = system_prompt + compression_context

        context_limit = get_token_limit(self.model_id)
        system_tokens = num_tokens_from_string(system_prompt)

        # Reserve 10% for response/tools
        safety_buffer = int(context_limit * 0.1)
        available_after_system = context_limit - system_tokens - safety_buffer

        # Max tokens for query: 80% of available space (leave room for history)
        max_query_tokens = int(available_after_system * 0.8)
        query_tokens = num_tokens_from_string(query)

        # Truncate query from middle if it exceeds 80% of available context
        if query_tokens > max_query_tokens:
            query = self._truncate_text_middle(query, max_query_tokens)
            query_tokens = num_tokens_from_string(query)

        # Calculate remaining budget for chat history
        available_for_history = max(available_after_system - query_tokens, 0)

        # Truncate chat history to fit within available budget
        working_history = self._truncate_history_to_fit(
            self.chat_history,
            available_for_history,
        )

        messages = [{"role": "system", "content": system_prompt}]

-        for i in self.chat_history:
+        for i in working_history:
if "prompt" in i and "response" in i:
|
||||
messages.append({"role": "user", "content": i["prompt"]})
|
||||
messages.append({"role": "assistant", "content": i["response"]})
|
||||
@@ -397,7 +491,65 @@ class BaseAgent(ABC):
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
|
||||
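A worked example of the budget arithmetic above, under assumed sizes (a 128,000-token limit, a 3,000-token system prompt, and a 1,500-token query; none of these figures appear in the diff):

context_limit = 128_000                                    # assumed model limit
system_tokens = 3_000                                      # assumed system prompt size
safety_buffer = int(context_limit * 0.1)                   # 12_800 reserved for response/tools
available_after_system = context_limit - system_tokens - safety_buffer  # 112_200
max_query_tokens = int(available_after_system * 0.8)       # 89_760
query_tokens = 1_500                                       # assumed query; under the cap, so no truncation
available_for_history = max(available_after_system - query_tokens, 0)   # 110_700
print(available_after_system, max_query_tokens, available_for_history)
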
    def _truncate_history_to_fit(
        self,
        history: List[Dict],
        max_tokens: int,
    ) -> List[Dict]:
        """
        Truncate chat history to fit within token budget, keeping most recent messages.

        Args:
            history: Full chat history
            max_tokens: Maximum tokens allowed for history

        Returns:
            Truncated history (most recent messages that fit)
        """
        from application.utils import num_tokens_from_string

        if not history or max_tokens <= 0:
            return []

        truncated = []
        current_tokens = 0

        # Iterate from newest to oldest
        for message in reversed(history):
            message_tokens = 0

            if "prompt" in message and "response" in message:
                message_tokens += num_tokens_from_string(message["prompt"])
                message_tokens += num_tokens_from_string(message["response"])

            if "tool_calls" in message:
                for tool_call in message["tool_calls"]:
                    tool_str = (
                        f"Tool: {tool_call.get('tool_name')} | "
                        f"Action: {tool_call.get('action_name')} | "
                        f"Args: {tool_call.get('arguments')} | "
                        f"Response: {tool_call.get('result')}"
                    )
                    message_tokens += num_tokens_from_string(tool_str)

            if current_tokens + message_tokens <= max_tokens:
                current_tokens += message_tokens
                truncated.insert(0, message)  # Maintain chronological order
            else:
                break

        if len(truncated) < len(history):
            logger.info(
                f"Truncated chat history from {len(history)} to {len(truncated)} messages "
                f"to fit within {max_tokens:,} token budget"
            )

        return truncated

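A small, self-contained sketch of the newest-first selection, again using a stand-in token counter so it runs on its own (the history entries and budget are invented):

def rough_tokens(text: str) -> int:
    return max(len(text) // 4, 1)   # stand-in for num_tokens_from_string

history = [
    {"prompt": "old question " * 60, "response": "old answer " * 60},  # large, older turn
    {"prompt": "recent question", "response": "recent answer"},        # small, newest turn
]
budget = 50

kept, used = [], 0
for message in reversed(history):                # newest to oldest
    cost = rough_tokens(message["prompt"]) + rough_tokens(message["response"])
    if used + cost <= budget:
        used += cost
        kept.insert(0, message)                  # restore chronological order
    else:
        break

# Only the recent exchange fits the 50-token budget; the older turn is dropped.
print(len(kept), used)                           # prints: 1 6
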
    def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
        # Pre-flight context validation: logs warnings but never raises
        self._validate_context_size(messages)

        gen_kwargs = {"model": self.model_id, "messages": messages}

        if (