From 17698ce77405b8d72fbef36d94369dc0ad90d59e Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 24 Nov 2025 10:44:19 +0000 Subject: [PATCH] feat: context compression (#2173) * feat: context compression * fix: ruff --- .gitignore | 1 + application/agents/base.py | 69 + application/api/answer/routes/base.py | 39 + .../answer/services/compression/__init__.py | 20 + .../services/compression/message_builder.py | 234 +++ .../services/compression/orchestrator.py | 232 +++ .../services/compression/prompt_builder.py | 149 ++ .../answer/services/compression/service.py | 306 ++++ .../services/compression/threshold_checker.py | 103 ++ .../services/compression/token_counter.py | 103 ++ .../api/answer/services/compression/types.py | 83 ++ .../answer/services/conversation_service.py | 100 ++ .../api/answer/services/stream_processor.py | 80 +- application/core/model_configs.py | 58 +- application/core/settings.py | 7 + application/llm/google_ai.py | 146 +- application/llm/handlers/base.py | 548 ++++++- application/llm/handlers/google.py | 29 +- application/llm/openai.py | 8 + application/prompts/compression/v1.0.txt | 35 + application/requirements.txt | 2 +- application/utils.py | 18 + tests/llm/test_google_llm.py | 2 +- tests/test_agent_token_tracking.py | 325 +++++ tests/test_compression_service.py | 1082 ++++++++++++++ tests/test_integration.py | 1287 +++++++++++++++++ tests/test_model_validation.py | 106 ++ tests/test_token_management.py | 314 ++++ 28 files changed, 5393 insertions(+), 93 deletions(-) create mode 100644 application/api/answer/services/compression/__init__.py create mode 100644 application/api/answer/services/compression/message_builder.py create mode 100644 application/api/answer/services/compression/orchestrator.py create mode 100644 application/api/answer/services/compression/prompt_builder.py create mode 100644 application/api/answer/services/compression/service.py create mode 100644 application/api/answer/services/compression/threshold_checker.py create mode 100644 application/api/answer/services/compression/token_counter.py create mode 100644 application/api/answer/services/compression/types.py create mode 100644 application/prompts/compression/v1.0.txt create mode 100644 tests/test_agent_token_tracking.py create mode 100644 tests/test_compression_service.py create mode 100755 tests/test_integration.py create mode 100644 tests/test_model_validation.py create mode 100644 tests/test_token_management.py diff --git a/.gitignore b/.gitignore index 91abeca1..e0e6280a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +experiments/ experiments # C extensions diff --git a/application/agents/base.py b/application/agents/base.py index dbf15a1f..b2d79c03 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -34,6 +34,7 @@ class BaseAgent(ABC): token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"], limited_request_mode: Optional[bool] = False, request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"], + compressed_summary: Optional[str] = None, ): self.endpoint = endpoint self.llm_name = llm_name @@ -64,6 +65,9 @@ class BaseAgent(ABC): self.token_limit = token_limit self.limited_request_mode = limited_request_mode self.request_limit = request_limit + self.compressed_summary = compressed_summary + self.current_token_count = 0 + self.context_limit_reached = False @log_activity() def gen( @@ -276,12 +280,77 @@ class BaseAgent(ABC): for tool_call in self.tool_calls ] + def _calculate_current_context_tokens(self, messages: List[Dict]) -> int: + """ + Calculate total tokens in current context (messages). + + Args: + messages: List of message dicts + + Returns: + Total token count + """ + from application.api.answer.services.compression.token_counter import ( + TokenCounter, + ) + + return TokenCounter.count_message_tokens(messages) + + def _check_context_limit(self, messages: List[Dict]) -> bool: + """ + Check if we're approaching context limit (80%). + + Args: + messages: Current message list + + Returns: + True if at or above 80% of context limit + """ + from application.core.model_utils import get_token_limit + from application.core.settings import settings + + try: + # Calculate current tokens + current_tokens = self._calculate_current_context_tokens(messages) + self.current_token_count = current_tokens + + # Get context limit for model + context_limit = get_token_limit(self.model_id) + + # Calculate threshold (80%) + threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE) + + # Check if we've reached the limit + if current_tokens >= threshold: + logger.warning( + f"Context limit approaching: {current_tokens}/{context_limit} tokens " + f"({(current_tokens/context_limit)*100:.1f}%)" + ) + return True + + return False + + except Exception as e: + logger.error(f"Error checking context limit: {str(e)}", exc_info=True) + return False + def _build_messages( self, system_prompt: str, query: str, ) -> List[Dict]: """Build messages using pre-rendered system prompt""" + # Append compression summary to system prompt if present + if self.compressed_summary: + compression_context = ( + "\n\n---\n\n" + "This session is being continued from a previous conversation that " + "has been compressed to fit within context limits. " + "The conversation is summarized below:\n\n" + f"{self.compressed_summary}" + ) + system_prompt = system_prompt + compression_context + messages = [{"role": "system", "content": system_prompt}] for i in self.chat_history: diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py index aefb66c6..be112729 100644 --- a/application/api/answer/routes/base.py +++ b/application/api/answer/routes/base.py @@ -266,6 +266,26 @@ class BaseAnswerResource: shared_token=shared_token, attachment_ids=attachment_ids, ) + # Persist compression metadata/summary if it exists and wasn't saved mid-execution + compression_meta = getattr(agent, "compression_metadata", None) + compression_saved = getattr(agent, "compression_saved", False) + if conversation_id and compression_meta and not compression_saved: + try: + self.conversation_service.update_compression_metadata( + conversation_id, compression_meta + ) + self.conversation_service.append_compression_message( + conversation_id, compression_meta + ) + agent.compression_saved = True + logger.info( + f"Persisted compression metadata for conversation {conversation_id}" + ) + except Exception as e: + logger.error( + f"Failed to persist compression metadata: {str(e)}", + exc_info=True, + ) else: conversation_id = None id_data = {"type": "id", "id": str(conversation_id)} @@ -328,6 +348,25 @@ class BaseAnswerResource: shared_token=shared_token, attachment_ids=attachment_ids, ) + compression_meta = getattr(agent, "compression_metadata", None) + compression_saved = getattr(agent, "compression_saved", False) + if conversation_id and compression_meta and not compression_saved: + try: + self.conversation_service.update_compression_metadata( + conversation_id, compression_meta + ) + self.conversation_service.append_compression_message( + conversation_id, compression_meta + ) + agent.compression_saved = True + logger.info( + f"Persisted compression metadata for conversation {conversation_id} (partial stream)" + ) + except Exception as e: + logger.error( + f"Failed to persist compression metadata (partial stream): {str(e)}", + exc_info=True, + ) except Exception as e: logger.error( f"Error saving partial response: {str(e)}", exc_info=True diff --git a/application/api/answer/services/compression/__init__.py b/application/api/answer/services/compression/__init__.py new file mode 100644 index 00000000..4cbdb910 --- /dev/null +++ b/application/api/answer/services/compression/__init__.py @@ -0,0 +1,20 @@ +""" +Compression module for managing conversation context compression. + +""" + +from application.api.answer.services.compression.orchestrator import ( + CompressionOrchestrator, +) +from application.api.answer.services.compression.service import CompressionService +from application.api.answer.services.compression.types import ( + CompressionResult, + CompressionMetadata, +) + +__all__ = [ + "CompressionOrchestrator", + "CompressionService", + "CompressionResult", + "CompressionMetadata", +] diff --git a/application/api/answer/services/compression/message_builder.py b/application/api/answer/services/compression/message_builder.py new file mode 100644 index 00000000..93772fe5 --- /dev/null +++ b/application/api/answer/services/compression/message_builder.py @@ -0,0 +1,234 @@ +"""Message reconstruction utilities for compression.""" + +import logging +import uuid +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class MessageBuilder: + """Builds message arrays from compressed context.""" + + @staticmethod + def build_from_compressed_context( + system_prompt: str, + compressed_summary: Optional[str], + recent_queries: List[Dict], + include_tool_calls: bool = False, + context_type: str = "pre_request", + ) -> List[Dict]: + """ + Build messages from compressed context. + + Args: + system_prompt: Original system prompt + compressed_summary: Compressed summary (if any) + recent_queries: Recent uncompressed queries + include_tool_calls: Whether to include tool calls from history + context_type: Type of context ('pre_request' or 'mid_execution') + + Returns: + List of message dicts ready for LLM + """ + # Append compression summary to system prompt if present + if compressed_summary: + system_prompt = MessageBuilder._append_compression_context( + system_prompt, compressed_summary, context_type + ) + + messages = [{"role": "system", "content": system_prompt}] + + # Add recent history + for query in recent_queries: + if "prompt" in query and "response" in query: + messages.append({"role": "user", "content": query["prompt"]}) + messages.append({"role": "assistant", "content": query["response"]}) + + # Add tool calls from history if present + if include_tool_calls and "tool_calls" in query: + for tool_call in query["tool_calls"]: + call_id = tool_call.get("call_id") or str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages.append( + {"role": "tool", "content": [function_response_dict]} + ) + + # If no recent queries (everything was compressed), add a continuation user message + if len(recent_queries) == 0 and compressed_summary: + messages.append({ + "role": "user", + "content": "Please continue with the remaining tasks based on the context above." + }) + logger.info("Added continuation user message to maintain proper turn-taking after full compression") + + return messages + + @staticmethod + def _append_compression_context( + system_prompt: str, compressed_summary: str, context_type: str = "pre_request" + ) -> str: + """ + Append compression context to system prompt. + + Args: + system_prompt: Original system prompt + compressed_summary: Summary to append + context_type: Type of compression context + + Returns: + Updated system prompt + """ + # Remove existing compression context if present + if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt: + parts = system_prompt.split("\n\n---\n\n") + system_prompt = parts[0] + + # Build appropriate context message based on type + if context_type == "mid_execution": + context_message = ( + "\n\n---\n\n" + "Context window limit reached during execution. " + "Previous conversation has been compressed to fit within limits. " + "The conversation is summarized below:\n\n" + f"{compressed_summary}" + ) + else: # pre_request + context_message = ( + "\n\n---\n\n" + "This session is being continued from a previous conversation that " + "has been compressed to fit within context limits. " + "The conversation is summarized below:\n\n" + f"{compressed_summary}" + ) + + return system_prompt + context_message + + @staticmethod + def rebuild_messages_after_compression( + messages: List[Dict], + compressed_summary: Optional[str], + recent_queries: List[Dict], + include_current_execution: bool = False, + include_tool_calls: bool = False, + ) -> Optional[List[Dict]]: + """ + Rebuild the message list after compression so tool execution can continue. + + Args: + messages: Original message list + compressed_summary: Compressed summary + recent_queries: Recent uncompressed queries + include_current_execution: Whether to preserve current execution messages + include_tool_calls: Whether to include tool calls from history + + Returns: + Rebuilt message list or None if failed + """ + # Find the system message + system_message = next( + (msg for msg in messages if msg.get("role") == "system"), None + ) + if not system_message: + logger.warning("No system message found in messages list") + return None + + # Update system message with compressed summary + if compressed_summary: + content = system_message.get("content", "") + system_message["content"] = MessageBuilder._append_compression_context( + content, compressed_summary, "mid_execution" + ) + logger.info( + "Appended compression summary to system prompt (truncated): %s", + ( + compressed_summary[:500] + "..." + if len(compressed_summary) > 500 + else compressed_summary + ), + ) + + rebuilt_messages = [system_message] + + # Add recent history from compressed context + for query in recent_queries: + if "prompt" in query and "response" in query: + rebuilt_messages.append({"role": "user", "content": query["prompt"]}) + rebuilt_messages.append( + {"role": "assistant", "content": query["response"]} + ) + + # Add tool calls from history if present + if include_tool_calls and "tool_calls" in query: + for tool_call in query["tool_calls"]: + call_id = tool_call.get("call_id") or str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + rebuilt_messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + rebuilt_messages.append( + {"role": "tool", "content": [function_response_dict]} + ) + + # If no recent queries (everything was compressed), add a continuation user message + if len(recent_queries) == 0 and compressed_summary: + rebuilt_messages.append({ + "role": "user", + "content": "Please continue with the remaining tasks based on the context above." + }) + logger.info("Added continuation user message to maintain proper turn-taking after full compression") + + if include_current_execution: + # Preserve any messages that were added during the current execution cycle + recent_msg_count = 1 # system message + for query in recent_queries: + if "prompt" in query and "response" in query: + recent_msg_count += 2 + if "tool_calls" in query: + recent_msg_count += len(query["tool_calls"]) * 2 + + if len(messages) > recent_msg_count: + current_execution_messages = messages[recent_msg_count:] + rebuilt_messages.extend(current_execution_messages) + logger.info( + f"Preserved {len(current_execution_messages)} messages from current execution cycle" + ) + + logger.info( + f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. " + f"Ready to continue tool execution." + ) + return rebuilt_messages diff --git a/application/api/answer/services/compression/orchestrator.py b/application/api/answer/services/compression/orchestrator.py new file mode 100644 index 00000000..797a66d4 --- /dev/null +++ b/application/api/answer/services/compression/orchestrator.py @@ -0,0 +1,232 @@ +"""High-level compression orchestration.""" + +import logging +from typing import Any, Dict, Optional + +from application.api.answer.services.compression.service import CompressionService +from application.api.answer.services.compression.threshold_checker import ( + CompressionThresholdChecker, +) +from application.api.answer.services.compression.types import CompressionResult +from application.api.answer.services.conversation_service import ConversationService +from application.core.model_utils import ( + get_api_key_for_provider, + get_provider_from_model_id, +) +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator + +logger = logging.getLogger(__name__) + + +class CompressionOrchestrator: + """ + Facade for compression operations. + + Coordinates between all compression components and provides + a simple interface for callers. + """ + + def __init__( + self, + conversation_service: ConversationService, + threshold_checker: Optional[CompressionThresholdChecker] = None, + ): + """ + Initialize orchestrator. + + Args: + conversation_service: Service for DB operations + threshold_checker: Custom threshold checker (optional) + """ + self.conversation_service = conversation_service + self.threshold_checker = threshold_checker or CompressionThresholdChecker() + + def compress_if_needed( + self, + conversation_id: str, + user_id: str, + model_id: str, + decoded_token: Dict[str, Any], + current_query_tokens: int = 500, + ) -> CompressionResult: + """ + Check if compression is needed and perform it if so. + + This is the main entry point for compression operations. + + Args: + conversation_id: Conversation ID + user_id: User ID + model_id: Model being used for conversation + decoded_token: User's decoded JWT token + current_query_tokens: Estimated tokens for current query + + Returns: + CompressionResult with summary and recent queries + """ + try: + # Load conversation + conversation = self.conversation_service.get_conversation( + conversation_id, user_id + ) + + if not conversation: + logger.warning( + f"Conversation {conversation_id} not found for user {user_id}" + ) + return CompressionResult.failure("Conversation not found") + + # Check if compression is needed + if not self.threshold_checker.should_compress( + conversation, model_id, current_query_tokens + ): + # No compression needed, return full history + queries = conversation.get("queries", []) + return CompressionResult.success_no_compression(queries) + + # Perform compression + return self._perform_compression( + conversation_id, conversation, model_id, decoded_token + ) + + except Exception as e: + logger.error( + f"Error in compress_if_needed: {str(e)}", exc_info=True + ) + return CompressionResult.failure(str(e)) + + def _perform_compression( + self, + conversation_id: str, + conversation: Dict[str, Any], + model_id: str, + decoded_token: Dict[str, Any], + ) -> CompressionResult: + """ + Perform the actual compression operation. + + Args: + conversation_id: Conversation ID + conversation: Conversation document + model_id: Model ID for conversation + decoded_token: User token + + Returns: + CompressionResult + """ + try: + # Determine which model to use for compression + compression_model = ( + settings.COMPRESSION_MODEL_OVERRIDE + if settings.COMPRESSION_MODEL_OVERRIDE + else model_id + ) + + # Get provider and API key for compression model + provider = get_provider_from_model_id(compression_model) + api_key = get_api_key_for_provider(provider) + + # Create compression LLM + compression_llm = LLMCreator.create_llm( + provider, + api_key=api_key, + user_api_key=None, + decoded_token=decoded_token, + model_id=compression_model, + ) + + # Create compression service with DB update capability + compression_service = CompressionService( + llm=compression_llm, + model_id=compression_model, + conversation_service=self.conversation_service, + ) + + # Compress all queries up to the latest + queries_count = len(conversation.get("queries", [])) + compress_up_to = queries_count - 1 + + if compress_up_to < 0: + logger.warning("No queries to compress") + return CompressionResult.success_no_compression([]) + + logger.info( + f"Initiating compression for conversation {conversation_id}: " + f"compressing all {queries_count} queries (0-{compress_up_to})" + ) + + # Perform compression and save to DB + metadata = compression_service.compress_and_save( + conversation_id, conversation, compress_up_to + ) + + logger.info( + f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, " + f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens" + ) + + # Reload conversation with updated metadata + conversation = self.conversation_service.get_conversation( + conversation_id, user_id=decoded_token.get("sub") + ) + + # Get compressed context + compressed_summary, recent_queries = ( + compression_service.get_compressed_context(conversation) + ) + + return CompressionResult.success_with_compression( + compressed_summary, recent_queries, metadata + ) + + except Exception as e: + logger.error(f"Error performing compression: {str(e)}", exc_info=True) + return CompressionResult.failure(str(e)) + + def compress_mid_execution( + self, + conversation_id: str, + user_id: str, + model_id: str, + decoded_token: Dict[str, Any], + current_conversation: Optional[Dict[str, Any]] = None, + ) -> CompressionResult: + """ + Perform compression during tool execution. + + Args: + conversation_id: Conversation ID + user_id: User ID + model_id: Model ID + decoded_token: User token + current_conversation: Pre-loaded conversation (optional) + + Returns: + CompressionResult + """ + try: + # Load conversation if not provided + if current_conversation: + conversation = current_conversation + else: + conversation = self.conversation_service.get_conversation( + conversation_id, user_id + ) + + if not conversation: + logger.warning( + f"Could not load conversation {conversation_id} for mid-execution compression" + ) + return CompressionResult.failure("Conversation not found") + + # Perform compression + return self._perform_compression( + conversation_id, conversation, model_id, decoded_token + ) + + except Exception as e: + logger.error( + f"Error in mid-execution compression: {str(e)}", exc_info=True + ) + return CompressionResult.failure(str(e)) diff --git a/application/api/answer/services/compression/prompt_builder.py b/application/api/answer/services/compression/prompt_builder.py new file mode 100644 index 00000000..d5ce3183 --- /dev/null +++ b/application/api/answer/services/compression/prompt_builder.py @@ -0,0 +1,149 @@ +"""Compression prompt building logic.""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class CompressionPromptBuilder: + """Builds prompts for LLM compression calls.""" + + def __init__(self, version: str = "v1.0"): + """ + Initialize prompt builder. + + Args: + version: Prompt template version to use + """ + self.version = version + self.system_prompt = self._load_prompt(version) + + def _load_prompt(self, version: str) -> str: + """ + Load prompt template from file. + + Args: + version: Version string (e.g., 'v1.0') + + Returns: + Prompt template content + + Raises: + FileNotFoundError: If prompt template file doesn't exist + """ + current_dir = Path(__file__).resolve().parents[4] + prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt" + + try: + with open(prompt_path, "r") as f: + return f.read() + except FileNotFoundError: + logger.error(f"Compression prompt template not found: {prompt_path}") + raise FileNotFoundError( + f"Compression prompt template '{version}' not found at {prompt_path}. " + f"Please ensure the template file exists." + ) + + def build_prompt( + self, + queries: List[Dict[str, Any]], + existing_compressions: Optional[List[Dict[str, Any]]] = None, + ) -> List[Dict[str, str]]: + """ + Build messages for compression LLM call. + + Args: + queries: List of query objects to compress + existing_compressions: List of previous compression points + + Returns: + List of message dicts for LLM + """ + # Build conversation text + conversation_text = self._format_conversation(queries) + + # Add existing compression context if present + existing_compression_context = "" + if existing_compressions and len(existing_compressions) > 0: + existing_compression_context = ( + "\n\nIMPORTANT: This conversation has been compressed before. " + "Previous compression summaries:\n\n" + ) + for i, comp in enumerate(existing_compressions): + existing_compression_context += ( + f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n" + f"{comp.get('compressed_summary', '')}\n\n" + ) + existing_compression_context += ( + "Your task is to create a NEW summary that incorporates the context from " + "previous compressions AND the new messages below. The final summary should " + "be comprehensive and include all important information from both previous " + "compressions and new messages.\n\n" + ) + + user_prompt = ( + f"{existing_compression_context}" + f"Here is the conversation to summarize:\n\n" + f"{conversation_text}" + ) + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + return messages + + def _format_conversation(self, queries: List[Dict[str, Any]]) -> str: + """ + Format conversation queries into readable text for compression. + + Args: + queries: List of query objects + + Returns: + Formatted conversation text + """ + conversation_lines = [] + + for i, query in enumerate(queries): + conversation_lines.append(f"--- Message {i + 1} ---") + conversation_lines.append(f"User: {query.get('prompt', '')}") + + # Add tool calls if present + tool_calls = query.get("tool_calls", []) + if tool_calls: + conversation_lines.append("\nTool Calls:") + for tc in tool_calls: + tool_name = tc.get("tool_name", "unknown") + action_name = tc.get("action_name", "unknown") + arguments = tc.get("arguments", {}) + result = tc.get("result", "") + if result is None: + result = "" + status = tc.get("status", "unknown") + + # Include full tool result for complete compression context + conversation_lines.append( + f" - {tool_name}.{action_name}({arguments}) " + f"[{status}] → {result}" + ) + + # Add agent thought if present + thought = query.get("thought", "") + if thought: + conversation_lines.append(f"\nAgent Thought: {thought}") + + # Add assistant response + conversation_lines.append(f"\nAssistant: {query.get('response', '')}") + + # Add sources if present + sources = query.get("sources", []) + if sources: + conversation_lines.append(f"\nSources Used: {len(sources)} documents") + + conversation_lines.append("") # Empty line between messages + + return "\n".join(conversation_lines) diff --git a/application/api/answer/services/compression/service.py b/application/api/answer/services/compression/service.py new file mode 100644 index 00000000..ccf6f126 --- /dev/null +++ b/application/api/answer/services/compression/service.py @@ -0,0 +1,306 @@ +"""Core compression service with simplified responsibilities.""" + +import logging +import re +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from application.api.answer.services.compression.prompt_builder import ( + CompressionPromptBuilder, +) +from application.api.answer.services.compression.token_counter import TokenCounter +from application.api.answer.services.compression.types import ( + CompressionMetadata, +) +from application.core.settings import settings + +logger = logging.getLogger(__name__) + + +class CompressionService: + """ + Service for compressing conversation history. + + Handles DB updates. + """ + + def __init__( + self, + llm, + model_id: str, + conversation_service=None, + prompt_builder: Optional[CompressionPromptBuilder] = None, + ): + """ + Initialize compression service. + + Args: + llm: LLM instance to use for compression + model_id: Model ID for compression + conversation_service: Service for DB operations (optional, for DB updates) + prompt_builder: Custom prompt builder (optional) + """ + self.llm = llm + self.model_id = model_id + self.conversation_service = conversation_service + self.prompt_builder = prompt_builder or CompressionPromptBuilder( + version=settings.COMPRESSION_PROMPT_VERSION + ) + + def compress_conversation( + self, + conversation: Dict[str, Any], + compress_up_to_index: int, + ) -> CompressionMetadata: + """ + Compress conversation history up to specified index. + + Args: + conversation: Full conversation document + compress_up_to_index: Last query index to include in compression + + Returns: + CompressionMetadata with compression details + + Raises: + ValueError: If compress_up_to_index is invalid + """ + try: + queries = conversation.get("queries", []) + + if compress_up_to_index < 0 or compress_up_to_index >= len(queries): + raise ValueError( + f"Invalid compress_up_to_index: {compress_up_to_index} " + f"(conversation has {len(queries)} queries)" + ) + + # Get queries to compress + queries_to_compress = queries[: compress_up_to_index + 1] + + # Check if there are existing compressions + existing_compressions = conversation.get("compression_metadata", {}).get( + "compression_points", [] + ) + + if existing_compressions: + logger.info( + f"Found {len(existing_compressions)} previous compression(s) - " + f"will incorporate into new summary" + ) + + # Calculate original token count + original_tokens = TokenCounter.count_query_tokens(queries_to_compress) + + # Log tool call stats + self._log_tool_call_stats(queries_to_compress) + + # Build compression prompt + messages = self.prompt_builder.build_prompt( + queries_to_compress, existing_compressions + ) + + # Call LLM to generate compression + logger.info( + f"Starting compression: {len(queries_to_compress)} queries " + f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) " + f"using model {self.model_id}" + ) + + response = self.llm.gen( + model=self.model_id, messages=messages, max_tokens=4000 + ) + + # Extract summary from response + compressed_summary = self._extract_summary(response) + + # Calculate compressed token count + compressed_tokens = TokenCounter.count_message_tokens( + [{"content": compressed_summary}] + ) + + # Calculate compression ratio + compression_ratio = ( + original_tokens / compressed_tokens if compressed_tokens > 0 else 0 + ) + + logger.info( + f"Compression complete: {original_tokens} → {compressed_tokens} tokens " + f"({compression_ratio:.1f}x compression)" + ) + + # Build compression metadata + compression_metadata = CompressionMetadata( + timestamp=datetime.now(timezone.utc), + query_index=compress_up_to_index, + compressed_summary=compressed_summary, + original_token_count=original_tokens, + compressed_token_count=compressed_tokens, + compression_ratio=compression_ratio, + model_used=self.model_id, + compression_prompt_version=self.prompt_builder.version, + ) + + return compression_metadata + + except Exception as e: + logger.error(f"Error compressing conversation: {str(e)}", exc_info=True) + raise + + def compress_and_save( + self, + conversation_id: str, + conversation: Dict[str, Any], + compress_up_to_index: int, + ) -> CompressionMetadata: + """ + Compress conversation and save to database. + + Args: + conversation_id: Conversation ID + conversation: Full conversation document + compress_up_to_index: Last query index to include + + Returns: + CompressionMetadata + + Raises: + ValueError: If conversation_service not provided or invalid index + """ + if not self.conversation_service: + raise ValueError( + "conversation_service required for compress_and_save operation" + ) + + # Perform compression + metadata = self.compress_conversation(conversation, compress_up_to_index) + + # Save to database + self.conversation_service.update_compression_metadata( + conversation_id, metadata.to_dict() + ) + + logger.info(f"Compression metadata saved to database for {conversation_id}") + + return metadata + + def get_compressed_context( + self, conversation: Dict[str, Any] + ) -> tuple[Optional[str], List[Dict[str, Any]]]: + """ + Get compressed summary + recent uncompressed messages. + + Args: + conversation: Full conversation document + + Returns: + (compressed_summary, recent_messages) + """ + try: + compression_metadata = conversation.get("compression_metadata", {}) + + if not compression_metadata.get("is_compressed"): + logger.debug("No compression metadata found - using full history") + queries = conversation.get("queries", []) + if queries is None: + logger.error("Conversation queries is None - returning empty list") + return None, [] + return None, queries + + compression_points = compression_metadata.get("compression_points", []) + + if not compression_points: + logger.debug("No compression points found - using full history") + queries = conversation.get("queries", []) + if queries is None: + logger.error("Conversation queries is None - returning empty list") + return None, [] + return None, queries + + # Get the most recent compression point + latest_compression = compression_points[-1] + compressed_summary = latest_compression.get("compressed_summary") + last_compressed_index = latest_compression.get("query_index") + compressed_tokens = latest_compression.get("compressed_token_count", 0) + original_tokens = latest_compression.get("original_token_count", 0) + + # Get only messages after compression point + queries = conversation.get("queries", []) + total_queries = len(queries) + recent_queries = queries[last_compressed_index + 1 :] + + logger.info( + f"Using compressed context: summary ({compressed_tokens} tokens, " + f"compressed from {original_tokens}) + {len(recent_queries)} recent messages " + f"(messages {last_compressed_index + 1}-{total_queries - 1})" + ) + + return compressed_summary, recent_queries + + except Exception as e: + logger.error( + f"Error getting compressed context: {str(e)}", exc_info=True + ) + queries = conversation.get("queries", []) + if queries is None: + return None, [] + return None, queries + + def _extract_summary(self, llm_response: str) -> str: + """ + Extract clean summary from LLM response. + + Args: + llm_response: Raw LLM response + + Returns: + Cleaned summary text + """ + try: + # Try to extract content within tags + summary_match = re.search( + r"(.*?)", llm_response, re.DOTALL + ) + + if summary_match: + summary = summary_match.group(1).strip() + else: + # If no summary tags, remove analysis tags and use the rest + summary = re.sub( + r".*?", "", llm_response, flags=re.DOTALL + ).strip() + + return summary + + except Exception as e: + logger.warning(f"Error extracting summary: {str(e)}, using full response") + return llm_response + + def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None: + """Log statistics about tool calls in queries.""" + total_tool_calls = 0 + total_tool_result_chars = 0 + tool_call_breakdown = {} + + for q in queries: + for tc in q.get("tool_calls", []): + total_tool_calls += 1 + tool_name = tc.get("tool_name", "unknown") + action_name = tc.get("action_name", "unknown") + key = f"{tool_name}.{action_name}" + tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1 + + # Track total tool result size + result = tc.get("result", "") + if result: + total_tool_result_chars += len(str(result)) + + if total_tool_calls > 0: + tool_breakdown_str = ", ".join( + f"{tool}({count})" + for tool, count in sorted(tool_call_breakdown.items()) + ) + tool_result_kb = total_tool_result_chars / 1024 + logger.info( + f"Tool call breakdown: {tool_breakdown_str} " + f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)" + ) diff --git a/application/api/answer/services/compression/threshold_checker.py b/application/api/answer/services/compression/threshold_checker.py new file mode 100644 index 00000000..15397018 --- /dev/null +++ b/application/api/answer/services/compression/threshold_checker.py @@ -0,0 +1,103 @@ +"""Compression threshold checking logic.""" + +import logging +from typing import Any, Dict + +from application.core.model_utils import get_token_limit +from application.core.settings import settings +from application.api.answer.services.compression.token_counter import TokenCounter + +logger = logging.getLogger(__name__) + + +class CompressionThresholdChecker: + """Determines if compression is needed based on token thresholds.""" + + def __init__(self, threshold_percentage: float = None): + """ + Initialize threshold checker. + + Args: + threshold_percentage: Percentage of context to use as threshold + (defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE) + """ + self.threshold_percentage = ( + threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE + ) + + def should_compress( + self, + conversation: Dict[str, Any], + model_id: str, + current_query_tokens: int = 500, + ) -> bool: + """ + Determine if compression is needed. + + Args: + conversation: Full conversation document + model_id: Target model for this request + current_query_tokens: Estimated tokens for current query + + Returns: + True if tokens >= threshold% of context window + """ + try: + # Calculate total tokens in conversation + total_tokens = TokenCounter.count_conversation_tokens(conversation) + total_tokens += current_query_tokens + + # Get context window limit for model + context_limit = get_token_limit(model_id) + + # Calculate threshold + threshold = int(context_limit * self.threshold_percentage) + + compression_needed = total_tokens >= threshold + percentage_used = (total_tokens / context_limit) * 100 + + if compression_needed: + logger.warning( + f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit " + f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)" + ) + else: + logger.info( + f"Compression check: {total_tokens}/{context_limit} tokens " + f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed" + ) + + return compression_needed + + except Exception as e: + logger.error(f"Error checking compression need: {str(e)}", exc_info=True) + return False + + def check_message_tokens(self, messages: list, model_id: str) -> bool: + """ + Check if message list exceeds threshold. + + Args: + messages: List of message dicts + model_id: Target model + + Returns: + True if at or above threshold + """ + try: + current_tokens = TokenCounter.count_message_tokens(messages) + context_limit = get_token_limit(model_id) + threshold = int(context_limit * self.threshold_percentage) + + if current_tokens >= threshold: + logger.warning( + f"Message context limit approaching: {current_tokens}/{context_limit} tokens " + f"({(current_tokens/context_limit)*100:.1f}%)" + ) + return True + + return False + + except Exception as e: + logger.error(f"Error checking message tokens: {str(e)}", exc_info=True) + return False diff --git a/application/api/answer/services/compression/token_counter.py b/application/api/answer/services/compression/token_counter.py new file mode 100644 index 00000000..ac676cf0 --- /dev/null +++ b/application/api/answer/services/compression/token_counter.py @@ -0,0 +1,103 @@ +"""Token counting utilities for compression.""" + +import logging +from typing import Any, Dict, List + +from application.utils import num_tokens_from_string +from application.core.settings import settings + +logger = logging.getLogger(__name__) + + +class TokenCounter: + """Centralized token counting for conversations and messages.""" + + @staticmethod + def count_message_tokens(messages: List[Dict]) -> int: + """ + Calculate total tokens in a list of messages. + + Args: + messages: List of message dicts with 'content' field + + Returns: + Total token count + """ + total_tokens = 0 + for message in messages: + content = message.get("content", "") + if isinstance(content, str): + total_tokens += num_tokens_from_string(content) + elif isinstance(content, list): + # Handle structured content (tool calls, etc.) + for item in content: + if isinstance(item, dict): + total_tokens += num_tokens_from_string(str(item)) + return total_tokens + + @staticmethod + def count_query_tokens( + queries: List[Dict[str, Any]], include_tool_calls: bool = True + ) -> int: + """ + Count tokens across multiple query objects. + + Args: + queries: List of query objects from conversation + include_tool_calls: Whether to count tool call tokens + + Returns: + Total token count + """ + total_tokens = 0 + + for query in queries: + # Count prompt and response tokens + if "prompt" in query: + total_tokens += num_tokens_from_string(query["prompt"]) + if "response" in query: + total_tokens += num_tokens_from_string(query["response"]) + if "thought" in query: + total_tokens += num_tokens_from_string(query.get("thought", "")) + + # Count tool call tokens + if include_tool_calls and "tool_calls" in query: + for tool_call in query["tool_calls"]: + tool_call_string = ( + f"Tool: {tool_call.get('tool_name')} | " + f"Action: {tool_call.get('action_name')} | " + f"Args: {tool_call.get('arguments')} | " + f"Response: {tool_call.get('result')}" + ) + total_tokens += num_tokens_from_string(tool_call_string) + + return total_tokens + + @staticmethod + def count_conversation_tokens( + conversation: Dict[str, Any], include_system_prompt: bool = False + ) -> int: + """ + Calculate total tokens in a conversation. + + Args: + conversation: Conversation document + include_system_prompt: Whether to include system prompt in count + + Returns: + Total token count + """ + try: + queries = conversation.get("queries", []) + total_tokens = TokenCounter.count_query_tokens(queries) + + # Add system prompt tokens if requested + if include_system_prompt: + # Rough estimate for system prompt + total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500) + + return total_tokens + + except Exception as e: + logger.error(f"Error calculating conversation tokens: {str(e)}") + return 0 diff --git a/application/api/answer/services/compression/types.py b/application/api/answer/services/compression/types.py new file mode 100644 index 00000000..b71ab9ee --- /dev/null +++ b/application/api/answer/services/compression/types.py @@ -0,0 +1,83 @@ +"""Type definitions for compression module.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional + + +@dataclass +class CompressionMetadata: + """Metadata about a compression operation.""" + + timestamp: datetime + query_index: int + compressed_summary: str + original_token_count: int + compressed_token_count: int + compression_ratio: float + model_used: str + compression_prompt_version: str + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for DB storage.""" + return { + "timestamp": self.timestamp, + "query_index": self.query_index, + "compressed_summary": self.compressed_summary, + "original_token_count": self.original_token_count, + "compressed_token_count": self.compressed_token_count, + "compression_ratio": self.compression_ratio, + "model_used": self.model_used, + "compression_prompt_version": self.compression_prompt_version, + } + + +@dataclass +class CompressionResult: + """Result of a compression operation.""" + + success: bool + compressed_summary: Optional[str] = None + recent_queries: List[Dict[str, Any]] = field(default_factory=list) + metadata: Optional[CompressionMetadata] = None + error: Optional[str] = None + compression_performed: bool = False + + @classmethod + def success_with_compression( + cls, summary: str, queries: List[Dict], metadata: CompressionMetadata + ) -> "CompressionResult": + """Create a successful result with compression.""" + return cls( + success=True, + compressed_summary=summary, + recent_queries=queries, + metadata=metadata, + compression_performed=True, + ) + + @classmethod + def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult": + """Create a successful result without compression needed.""" + return cls( + success=True, + recent_queries=queries, + compression_performed=False, + ) + + @classmethod + def failure(cls, error: str) -> "CompressionResult": + """Create a failure result.""" + return cls(success=False, error=error, compression_performed=False) + + def as_history(self) -> List[Dict[str, str]]: + """ + Convert recent queries to history format. + + Returns: + List of prompt/response dicts + """ + return [ + {"prompt": q["prompt"], "response": q["response"]} + for q in self.recent_queries + ] diff --git a/application/api/answer/services/conversation_service.py b/application/api/answer/services/conversation_service.py index 0e98983e..bf55801c 100644 --- a/application/api/answer/services/conversation_service.py +++ b/application/api/answer/services/conversation_service.py @@ -180,3 +180,103 @@ class ConversationService: conversation_data["api_key"] = agent["key"] result = self.conversations_collection.insert_one(conversation_data) return str(result.inserted_id) + + def update_compression_metadata( + self, conversation_id: str, compression_metadata: Dict[str, Any] + ) -> None: + """ + Update conversation with compression metadata. + + Uses $push with $slice to keep only the most recent compression points, + preventing unbounded array growth. Since each compression incorporates + previous compressions, older points become redundant. + + Args: + conversation_id: Conversation ID + compression_metadata: Compression point data + """ + try: + self.conversations_collection.update_one( + {"_id": ObjectId(conversation_id)}, + { + "$set": { + "compression_metadata.is_compressed": True, + "compression_metadata.last_compression_at": compression_metadata.get( + "timestamp" + ), + }, + "$push": { + "compression_metadata.compression_points": { + "$each": [compression_metadata], + "$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS, + } + }, + }, + ) + logger.info( + f"Updated compression metadata for conversation {conversation_id}" + ) + except Exception as e: + logger.error( + f"Error updating compression metadata: {str(e)}", exc_info=True + ) + raise + + def append_compression_message( + self, conversation_id: str, compression_metadata: Dict[str, Any] + ) -> None: + """ + Append a synthetic compression summary entry into the conversation history. + This makes the summary visible in the DB alongside normal queries. + """ + try: + summary = compression_metadata.get("compressed_summary", "") + if not summary: + return + timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc)) + + self.conversations_collection.update_one( + {"_id": ObjectId(conversation_id)}, + { + "$push": { + "queries": { + "prompt": "[Context Compression Summary]", + "response": summary, + "thought": "", + "sources": [], + "tool_calls": [], + "timestamp": timestamp, + "attachments": [], + "model_id": compression_metadata.get("model_used"), + } + } + }, + ) + logger.info(f"Appended compression summary to conversation {conversation_id}") + except Exception as e: + logger.error( + f"Error appending compression summary: {str(e)}", exc_info=True + ) + + def get_compression_metadata( + self, conversation_id: str + ) -> Optional[Dict[str, Any]]: + """ + Get compression metadata for a conversation. + + Args: + conversation_id: Conversation ID + + Returns: + Compression metadata dict or None + """ + try: + conversation = self.conversations_collection.find_one( + {"_id": ObjectId(conversation_id)}, {"compression_metadata": 1} + ) + return conversation.get("compression_metadata") if conversation else None + except Exception as e: + logger.error( + f"Error getting compression metadata: {str(e)}", exc_info=True + ) + return None diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index 586e7696..5d97fbe8 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -10,6 +10,8 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from application.agents.agent_creator import AgentCreator +from application.api.answer.services.compression import CompressionOrchestrator +from application.api.answer.services.compression.token_counter import TokenCounter from application.api.answer.services.conversation_service import ConversationService from application.api.answer.services.prompt_renderer import PromptRenderer from application.core.model_utils import ( @@ -90,9 +92,14 @@ class StreamProcessor: self.shared_token = None self.model_id: Optional[str] = None self.conversation_service = ConversationService() + self.compression_orchestrator = CompressionOrchestrator( + self.conversation_service + ) self.prompt_renderer = PromptRenderer() self._prompt_content: Optional[str] = None self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None + self.compressed_summary: Optional[str] = None + self.compressed_summary_tokens: int = 0 def initialize(self): """Initialize all required components for processing""" @@ -112,15 +119,72 @@ class StreamProcessor: ) if not conversation: raise ValueError("Conversation not found or unauthorized") - self.history = [ - {"prompt": query["prompt"], "response": query["response"]} - for query in conversation.get("queries", []) - ] + + # Check if compression is enabled and needed + if settings.ENABLE_CONVERSATION_COMPRESSION: + self._handle_compression(conversation) + else: + # Original behavior - load all history + self.history = [ + {"prompt": query["prompt"], "response": query["response"]} + for query in conversation.get("queries", []) + ] else: self.history = limit_chat_history( json.loads(self.data.get("history", "[]")), model_id=self.model_id ) + def _handle_compression(self, conversation: Dict[str, Any]): + """ + Handle conversation compression logic using orchestrator. + + Args: + conversation: Full conversation document + """ + try: + # Use orchestrator to handle all compression logic + result = self.compression_orchestrator.compress_if_needed( + conversation_id=self.conversation_id, + user_id=self.initial_user_id, + model_id=self.model_id, + decoded_token=self.decoded_token, + ) + + if not result.success: + logger.error( + f"Compression failed: {result.error}, using full history" + ) + self.history = [ + {"prompt": query["prompt"], "response": query["response"]} + for query in conversation.get("queries", []) + ] + return + + # Set compressed summary if compression was performed + if result.compression_performed and result.compressed_summary: + self.compressed_summary = result.compressed_summary + self.compressed_summary_tokens = TokenCounter.count_message_tokens( + [{"content": result.compressed_summary}] + ) + logger.info( + f"Using compressed summary ({self.compressed_summary_tokens} tokens) " + f"+ {len(result.recent_queries)} recent messages" + ) + + # Build history from recent queries + self.history = result.as_history() + + except Exception as e: + logger.error( + f"Error handling compression, falling back to standard history: {str(e)}", + exc_info=True, + ) + # Fallback to original behavior + self.history = [ + {"prompt": query["prompt"], "response": query["response"]} + for query in conversation.get("queries", []) + ] + def _process_attachments(self): """Process any attachments in the request""" attachment_ids = self.data.get("attachments", []) @@ -658,7 +722,7 @@ class StreamProcessor: ) system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER) - return AgentCreator.create_agent( + agent = AgentCreator.create_agent( self.agent_config["agent_type"], endpoint="stream", llm_name=provider or settings.LLM_PROVIDER, @@ -671,4 +735,10 @@ class StreamProcessor: decoded_token=self.decoded_token, attachments=self.attachments, json_schema=self.agent_config.get("json_schema"), + compressed_summary=self.compressed_summary, ) + + agent.conversation_id = self.conversation_id + agent.initial_user_id = self.initial_user_id + + return agent diff --git a/application/core/model_configs.py b/application/core/model_configs.py index b802ee27..5f75bc83 100644 --- a/application/core/model_configs.py +++ b/application/core/model_configs.py @@ -29,63 +29,29 @@ GOOGLE_ATTACHMENTS = [ OPENAI_MODELS = [ AvailableModel( - id="gpt-4o", + id="gpt-5.1", provider=ModelProvider.OPENAI, - display_name="GPT-4 Omni", - description="Latest and most capable model", + display_name="GPT-5.1", + description="Flagship model with enhanced reasoning, coding, and agentic capabilities", capabilities=ModelCapabilities( supports_tools=True, supports_structured_output=True, supported_attachment_types=OPENAI_ATTACHMENTS, - context_window=128000, + context_window=400000, ), ), AvailableModel( - id="gpt-4o-mini", + id="gpt-5-mini", provider=ModelProvider.OPENAI, - display_name="GPT-4 Omni Mini", - description="Fast and efficient", + display_name="GPT-5 Mini", + description="Faster, cost-effective variant of GPT-5.1", capabilities=ModelCapabilities( supports_tools=True, supports_structured_output=True, supported_attachment_types=OPENAI_ATTACHMENTS, - context_window=128000, + context_window=400000, ), - ), - AvailableModel( - id="gpt-4-turbo", - provider=ModelProvider.OPENAI, - display_name="GPT-4 Turbo", - description="Fast GPT-4 with 128k context", - capabilities=ModelCapabilities( - supports_tools=True, - supports_structured_output=True, - supported_attachment_types=OPENAI_ATTACHMENTS, - context_window=128000, - ), - ), - AvailableModel( - id="gpt-4", - provider=ModelProvider.OPENAI, - display_name="GPT-4", - description="Most capable model", - capabilities=ModelCapabilities( - supports_tools=True, - supports_structured_output=True, - supported_attachment_types=OPENAI_ATTACHMENTS, - context_window=8192, - ), - ), - AvailableModel( - id="gpt-3.5-turbo", - provider=ModelProvider.OPENAI, - display_name="GPT-3.5 Turbo", - description="Fast and cost-effective", - capabilities=ModelCapabilities( - supports_tools=True, - context_window=4096, - ), - ), + ) ] @@ -159,15 +125,15 @@ GOOGLE_MODELS = [ ), ), AvailableModel( - id="gemini-2.5-pro", + id="gemini-3-pro-preview", provider=ModelProvider.GOOGLE, - display_name="Gemini 2.5 Pro", + display_name="Gemini 3 Pro", description="Most capable Gemini model", capabilities=ModelCapabilities( supports_tools=True, supports_structured_output=True, supported_attachment_types=GOOGLE_ATTACHMENTS, - context_window=2000000, + context_window=20000, # Set low for testing compression ), ), ] diff --git a/application/core/settings.py b/application/core/settings.py index ee7ffa05..cecbb333 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -144,6 +144,13 @@ class Settings(BaseSettings): # Tool pre-fetch settings ENABLE_TOOL_PREFETCH: bool = True + # Conversation Compression Settings + ENABLE_CONVERSATION_COMPRESSION: bool = True + COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context + COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression + COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations + COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat + path = Path(__file__).parent.parent.absolute() settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8") diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 9c58a3e1..00c609ec 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -1,4 +1,3 @@ -import json import logging from google import genai @@ -11,11 +10,13 @@ from application.storage.storage_creator import StorageCreator class GoogleLLM(BaseLLM): - def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): + def __init__( + self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs + ): super().__init__(*args, **kwargs) self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY self.user_api_key = user_api_key - + self.client = genai.Client(api_key=self.api_key) self.storage = StorageCreator.get_storage() @@ -33,6 +34,12 @@ class GoogleLLM(BaseLLM): "image/jpg", "image/webp", "image/gif", + "application/pdf", + "image/png", + "image/jpeg", + "image/jpg", + "image/webp", + "image/gif", ] def prepare_messages_with_attachments(self, messages, attachments=None): @@ -135,12 +142,38 @@ class GoogleLLM(BaseLLM): raise def _clean_messages_google(self, messages): - """Convert OpenAI format messages to Google AI format.""" + """ + Convert OpenAI format messages to Google AI format and collect system prompts. + + Returns: + tuple[list[types.Content], Optional[str]]: cleaned messages and optional + combined system instruction. + """ cleaned_messages = [] + system_instructions = [] + + def _extract_system_text(content): + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict) and "text" in item and item["text"] is not None: + parts.append(item["text"]) + return "\n".join(parts) + return "" + for message in messages: role = message.get("role") content = message.get("content") + # Gemini only accepts user/model in the contents list. + if role == "system": + sys_text = _extract_system_text(content) + if sys_text: + system_instructions.append(sys_text) + continue + if role == "assistant": role = "model" elif role == "tool": @@ -159,12 +192,27 @@ class GoogleLLM(BaseLLM): cleaned_args = self._remove_null_values( item["function_call"]["args"] ) - parts.append( - types.Part.from_function_call( - name=item["function_call"]["name"], - args=cleaned_args, + # Create function call part with thought_signature if present + # For Gemini 3 models, we need to include thought_signature + if "thought_signature" in item: + # Use Part constructor with functionCall and thoughtSignature + parts.append( + types.Part( + functionCall=types.FunctionCall( + name=item["function_call"]["name"], + args=cleaned_args, + ), + thoughtSignature=item["thought_signature"], + ) + ) + else: + # Use helper method when no thought_signature + parts.append( + types.Part.from_function_call( + name=item["function_call"]["name"], + args=cleaned_args, + ) ) - ) elif "function_response" in item: parts.append( types.Part.from_function_response( @@ -188,7 +236,8 @@ class GoogleLLM(BaseLLM): raise ValueError(f"Unexpected content type: {type(content)}") if parts: cleaned_messages.append(types.Content(role=role, parts=parts)) - return cleaned_messages + system_instruction = "\n\n".join(system_instructions) if system_instructions else None + return cleaned_messages, system_instruction def _clean_schema(self, schema_obj): """ @@ -274,6 +323,61 @@ class GoogleLLM(BaseLLM): genai_tools.append(genai_tool) return genai_tools + def _extract_preview_from_message(self, message): + """Get a short, human-readable preview from the last message.""" + try: + if hasattr(message, "parts"): + for part in reversed(message.parts): + if getattr(part, "text", None): + return part.text + function_call = getattr(part, "function_call", None) + if function_call: + name = getattr(function_call, "name", "") or "function_call" + return f"function_call:{name}" + function_response = getattr(part, "function_response", None) + if function_response: + name = getattr(function_response, "name", "") or "function_response" + return f"function_response:{name}" + if isinstance(message, dict): + content = message.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + for item in reversed(content): + if isinstance(item, str): + return item + if isinstance(item, dict): + if item.get("text"): + return item["text"] + if item.get("function_call"): + fn = item["function_call"] + if isinstance(fn, dict): + name = fn.get("name") or "function_call" + return f"function_call:{name}" + return "function_call" + if item.get("function_response"): + resp = item["function_response"] + if isinstance(resp, dict): + name = resp.get("name") or "function_response" + return f"function_response:{name}" + return "function_response" + if "text" in message and isinstance(message["text"], str): + return message["text"] + except Exception: + pass + return str(message) + + def _summarize_messages_for_log(self, messages, preview_chars=20): + """Return a compact summary for logging to avoid huge payloads.""" + message_count = len(messages) if messages else 0 + last_preview = "" + if messages: + last_preview = self._extract_preview_from_message(messages[-1]) or "" + last_preview = str(last_preview).replace("\n", " ") + if len(last_preview) > preview_chars: + last_preview = f"{last_preview[:preview_chars]}..." + return f"count={message_count}, last='{last_preview}'" + def _raw_gen( self, baseself, @@ -287,12 +391,12 @@ class GoogleLLM(BaseLLM): ): """Generate content using Google AI API without streaming.""" client = genai.Client(api_key=self.api_key) + system_instruction = None if formatting == "openai": - messages = self._clean_messages_google(messages) + messages, system_instruction = self._clean_messages_google(messages) config = types.GenerateContentConfig() - if messages[0].role == "system": - config.system_instruction = messages[0].parts[0].text - messages = messages[1:] + if system_instruction: + config.system_instruction = system_instruction if tools: cleaned_tools = self._clean_tools_format(tools) config.tools = cleaned_tools @@ -325,12 +429,12 @@ class GoogleLLM(BaseLLM): ): """Generate content using Google AI API with streaming.""" client = genai.Client(api_key=self.api_key) + system_instruction = None if formatting == "openai": - messages = self._clean_messages_google(messages) + messages, system_instruction = self._clean_messages_google(messages) config = types.GenerateContentConfig() - if messages[0].role == "system": - config.system_instruction = messages[0].parts[0].text - messages = messages[1:] + if system_instruction: + config.system_instruction = system_instruction if tools: cleaned_tools = self._clean_tools_format(tools) config.tools = cleaned_tools @@ -349,8 +453,12 @@ class GoogleLLM(BaseLLM): break if has_attachments: break + messages_summary = self._summarize_messages_for_log(messages) logging.info( - f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}" + "GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s", + model, + messages_summary, + has_attachments, ) response = client.models.generate_content_stream( diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index 920caf65..b11654c5 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -1,4 +1,5 @@ import logging +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, Generator, List, Optional, Union @@ -16,6 +17,7 @@ class ToolCall: name: str arguments: Union[str, Dict] index: Optional[int] = None + thought_signature: Optional[str] = None @classmethod def from_dict(cls, data: Dict) -> "ToolCall": @@ -178,6 +180,406 @@ class LLMHandler(ABC): system_msg["content"] += f"\n\n{combined_text}" return prepared_messages + def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]: + """ + Build a minimal context: system prompt + latest user message only. + Drops all tool/function messages to shrink context aggressively. + """ + system_message = next((m for m in messages if m.get("role") == "system"), None) + if not system_message: + logger.warning("Cannot prune messages minimally: missing system message.") + return None + last_non_system = None + for m in reversed(messages): + if m.get("role") == "user": + last_non_system = m + break + if not last_non_system and m.get("role") not in ("system", None): + last_non_system = m + if not last_non_system: + logger.warning("Cannot prune messages minimally: missing user/assistant messages.") + return None + logger.info("Pruning context to system + latest user/assistant message to proceed.") + return [system_message, last_non_system] + + def _extract_text_from_content(self, content: Any) -> str: + """ + Convert message content (str or list of parts) to plain text for compression. + """ + if isinstance(content, str): + return content + if isinstance(content, list): + parts_text = [] + for item in content: + if isinstance(item, dict): + if "text" in item and item["text"] is not None: + parts_text.append(str(item["text"])) + elif "function_call" in item or "function_response" in item: + # Keep serialized function calls/responses so the compressor sees actions + parts_text.append(str(item)) + elif "files" in item: + parts_text.append(str(item)) + return "\n".join(parts_text) + return "" + + def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]: + """ + Build a conversation-like dict from current messages so we can compress + even when the conversation isn't persisted yet. Includes tool calls/results. + """ + queries = [] + current_prompt = None + current_tool_calls = {} + + def _commit_query(response_text: str): + nonlocal current_prompt, current_tool_calls + if current_prompt is None and not response_text: + return + tool_calls_list = list(current_tool_calls.values()) + queries.append( + { + "prompt": current_prompt or "", + "response": response_text, + "tool_calls": tool_calls_list, + } + ) + current_prompt = None + current_tool_calls = {} + + for message in messages: + role = message.get("role") + content = message.get("content") + + if role == "user": + current_prompt = self._extract_text_from_content(content) + + elif role in {"assistant", "model"}: + # If this assistant turn contains tool calls, collect them; otherwise commit a response. + if isinstance(content, list): + for item in content: + if "function_call" in item: + fc = item["function_call"] + call_id = fc.get("call_id") or str(uuid.uuid4()) + current_tool_calls[call_id] = { + "tool_name": "unknown_tool", + "action_name": fc.get("name"), + "arguments": fc.get("args"), + "result": None, + "status": "called", + "call_id": call_id, + } + elif "function_response" in item: + fr = item["function_response"] + call_id = fr.get("call_id") or str(uuid.uuid4()) + current_tool_calls[call_id] = { + "tool_name": "unknown_tool", + "action_name": fr.get("name"), + "arguments": None, + "result": fr.get("response", {}).get("result"), + "status": "completed", + "call_id": call_id, + } + # No direct assistant text here; continue to next message + continue + + response_text = self._extract_text_from_content(content) + _commit_query(response_text) + + elif role == "tool": + # Attach tool outputs to the latest pending tool call if possible + tool_text = self._extract_text_from_content(content) + # Attempt to parse function_response style + call_id = None + if isinstance(content, list): + for item in content: + if "function_response" in item and item["function_response"].get("call_id"): + call_id = item["function_response"]["call_id"] + break + if call_id and call_id in current_tool_calls: + current_tool_calls[call_id]["result"] = tool_text + current_tool_calls[call_id]["status"] = "completed" + elif queries: + queries[-1].setdefault("tool_calls", []).append( + { + "tool_name": "unknown_tool", + "action_name": "unknown_action", + "arguments": {}, + "result": tool_text, + "status": "completed", + } + ) + + # If there's an unfinished prompt with tool_calls but no response yet, commit it + if current_prompt is not None or current_tool_calls: + _commit_query(response_text="") + + if not queries: + return None + + return { + "queries": queries, + "compression_metadata": { + "is_compressed": False, + "compression_points": [], + }, + } + + def _rebuild_messages_after_compression( + self, + messages: List[Dict], + compressed_summary: Optional[str], + recent_queries: List[Dict], + include_current_execution: bool = False, + include_tool_calls: bool = False, + ) -> Optional[List[Dict]]: + """ + Rebuild the message list after compression so tool execution can continue. + + Delegates to MessageBuilder for the actual reconstruction. + """ + from application.api.answer.services.compression.message_builder import ( + MessageBuilder, + ) + + return MessageBuilder.rebuild_messages_after_compression( + messages=messages, + compressed_summary=compressed_summary, + recent_queries=recent_queries, + include_current_execution=include_current_execution, + include_tool_calls=include_tool_calls, + ) + + def _perform_mid_execution_compression( + self, agent, messages: List[Dict] + ) -> tuple[bool, Optional[List[Dict]]]: + """ + Perform compression during tool execution and rebuild messages. + + Uses the new orchestrator for simplified compression. + + Args: + agent: The agent instance + messages: Current conversation messages + + Returns: + (success: bool, rebuilt_messages: Optional[List[Dict]]) + """ + try: + from application.api.answer.services.compression import ( + CompressionOrchestrator, + ) + from application.api.answer.services.conversation_service import ( + ConversationService, + ) + + conversation_service = ConversationService() + orchestrator = CompressionOrchestrator(conversation_service) + + # Get conversation from database (may be None for new sessions) + conversation = conversation_service.get_conversation( + agent.conversation_id, agent.initial_user_id + ) + + if conversation: + # Merge current in-flight messages (including tool calls) + conversation_from_msgs = self._build_conversation_from_messages(messages) + if conversation_from_msgs: + conversation = conversation_from_msgs + else: + logger.warning( + "Could not load conversation for compression; attempting in-memory compression" + ) + return self._perform_in_memory_compression(agent, messages) + + # Use orchestrator to perform compression + result = orchestrator.compress_mid_execution( + conversation_id=agent.conversation_id, + user_id=agent.initial_user_id, + model_id=agent.model_id, + decoded_token=getattr(agent, "decoded_token", {}), + current_conversation=conversation, + ) + + if not result.success: + logger.warning(f"Mid-execution compression failed: {result.error}") + # Try minimal pruning as fallback + pruned = self._prune_messages_minimal(messages) + if pruned: + agent.context_limit_reached = False + agent.current_token_count = 0 + return True, pruned + return False, None + + if not result.compression_performed: + logger.warning("Compression not performed") + return False, None + + # Check if compression actually reduced tokens + if result.metadata: + if result.metadata.compressed_token_count >= result.metadata.original_token_count: + logger.warning( + "Compression did not reduce token count; falling back to minimal pruning" + ) + pruned = self._prune_messages_minimal(messages) + if pruned: + agent.context_limit_reached = False + agent.current_token_count = 0 + return True, pruned + return False, None + + logger.info( + f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, " + f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens" + ) + + # Also store the compression summary as a visible message + if result.metadata: + conversation_service.append_compression_message( + agent.conversation_id, result.metadata.to_dict() + ) + + # Update agent's compressed summary for downstream persistence + agent.compressed_summary = result.compressed_summary + agent.compression_metadata = result.metadata.to_dict() if result.metadata else None + agent.compression_saved = False + + # Reset the context limit flag so tools can continue + agent.context_limit_reached = False + agent.current_token_count = 0 + + # Rebuild messages + rebuilt_messages = self._rebuild_messages_after_compression( + messages, + result.compressed_summary, + result.recent_queries, + include_current_execution=False, + include_tool_calls=False, + ) + + if rebuilt_messages is None: + return False, None + + return True, rebuilt_messages + + except Exception as e: + logger.error( + f"Error performing mid-execution compression: {str(e)}", exc_info=True + ) + return False, None + + def _perform_in_memory_compression( + self, agent, messages: List[Dict] + ) -> tuple[bool, Optional[List[Dict]]]: + """ + Fallback compression path when the conversation is not yet persisted. + + Uses CompressionService directly without DB persistence. + """ + try: + from application.api.answer.services.compression.service import ( + CompressionService, + ) + from application.core.model_utils import ( + get_api_key_for_provider, + get_provider_from_model_id, + ) + from application.core.settings import settings + from application.llm.llm_creator import LLMCreator + + conversation = self._build_conversation_from_messages(messages) + if not conversation: + logger.warning( + "Cannot perform in-memory compression: no user/assistant turns found" + ) + return False, None + + compression_model = ( + settings.COMPRESSION_MODEL_OVERRIDE + if settings.COMPRESSION_MODEL_OVERRIDE + else agent.model_id + ) + provider = get_provider_from_model_id(compression_model) + api_key = get_api_key_for_provider(provider) + compression_llm = LLMCreator.create_llm( + provider, + api_key, + getattr(agent, "user_api_key", None), + getattr(agent, "decoded_token", None), + model_id=compression_model, + ) + + # Create service without DB persistence capability + compression_service = CompressionService( + llm=compression_llm, + model_id=compression_model, + conversation_service=None, # No DB updates for in-memory + ) + + queries_count = len(conversation.get("queries", [])) + compress_up_to = queries_count - 1 + + if compress_up_to < 0 or queries_count == 0: + logger.warning("Not enough queries to compress in-memory context") + return False, None + + metadata = compression_service.compress_conversation( + conversation, + compress_up_to_index=compress_up_to, + ) + + # If compression doesn't reduce tokens, fall back to minimal pruning + if ( + metadata.compressed_token_count + >= metadata.original_token_count + ): + logger.warning( + "In-memory compression did not reduce token count; falling back to minimal pruning" + ) + pruned = self._prune_messages_minimal(messages) + if pruned: + agent.context_limit_reached = False + agent.current_token_count = 0 + return True, pruned + return False, None + + # Attach metadata to synthetic conversation + conversation["compression_metadata"] = { + "is_compressed": True, + "compression_points": [metadata.to_dict()], + } + + compressed_summary, recent_queries = ( + compression_service.get_compressed_context(conversation) + ) + + agent.compressed_summary = compressed_summary + agent.compression_metadata = metadata.to_dict() + agent.compression_saved = False + agent.context_limit_reached = False + agent.current_token_count = 0 + + rebuilt_messages = self._rebuild_messages_after_compression( + messages, + compressed_summary, + recent_queries, + include_current_execution=False, + include_tool_calls=False, + ) + if rebuilt_messages is None: + return False, None + + logger.info( + f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, " + f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens" + ) + return True, rebuilt_messages + + except Exception as e: + logger.error( + f"Error performing in-memory compression: {str(e)}", exc_info=True + ) + return False, None + def handle_tool_calls( self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict] ) -> Generator: @@ -195,7 +597,110 @@ class LLMHandler(ABC): """ updated_messages = messages.copy() - for call in tool_calls: + for i, call in enumerate(tool_calls): + # Check context limit before executing tool call + if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages): + # Context limit reached - attempt mid-execution compression + compression_attempted = False + compression_successful = False + + try: + from application.core.settings import settings + compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION + except Exception: + compression_enabled = False + + if compression_enabled: + compression_attempted = True + try: + logger.info( + f"Context limit reached with {len(tool_calls) - i} remaining tool calls. " + f"Attempting mid-execution compression..." + ) + + # Trigger mid-execution compression (DB-backed if available, otherwise in-memory) + compression_successful, rebuilt_messages = self._perform_mid_execution_compression( + agent, updated_messages + ) + + if compression_successful and rebuilt_messages is not None: + # Update the messages list with rebuilt compressed version + updated_messages = rebuilt_messages + + # Yield compression success message + yield { + "type": "info", + "data": { + "message": "Context window limit reached. Compressed conversation history to continue processing." + } + } + + logger.info( + f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls." + ) + # Proceed to execute the current tool call with the reduced context + else: + logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.") + except Exception as e: + logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True) + compression_attempted = True + compression_successful = False + + # If compression wasn't attempted or failed, skip remaining tools + if not compression_successful: + if i == 0: + # Special case: limit reached before executing any tools + # This can happen when previous tool responses pushed context over limit + if compression_attempted: + logger.warning( + f"Context limit reached before executing any tools. " + f"Compression attempted but failed. " + f"Skipping all {len(tool_calls)} pending tool call(s). " + f"This typically occurs when previous tool responses contained large amounts of data." + ) + else: + logger.warning( + f"Context limit reached before executing any tools. " + f"Skipping all {len(tool_calls)} pending tool call(s). " + f"This typically occurs when previous tool responses contained large amounts of data. " + f"Consider enabling compression or using a model with larger context window." + ) + else: + # Normal case: executed some tools, now stopping + tool_word = "tool call" if i == 1 else "tool calls" + remaining = len(tool_calls) - i + remaining_word = "tool call" if remaining == 1 else "tool calls" + if compression_attempted: + logger.warning( + f"Context limit reached after executing {i} {tool_word}. " + f"Compression attempted but failed. " + f"Skipping remaining {remaining} {remaining_word}." + ) + else: + logger.warning( + f"Context limit reached after executing {i} {tool_word}. " + f"Skipping remaining {remaining} {remaining_word}. " + f"Consider enabling compression or using a model with larger context window." + ) + + # Mark remaining tools as skipped + for remaining_call in tool_calls[i:]: + skip_message = { + "type": "tool_call", + "data": { + "tool_name": "system", + "call_id": remaining_call.id, + "action_name": remaining_call.name, + "arguments": {}, + "result": "Skipped: Context limit reached. Too many tool calls in conversation.", + "status": "skipped" + } + } + yield skip_message + + # Set flag on agent + agent.context_limit_reached = True + break try: self.tool_calls.append(call) tool_executor_gen = agent._execute_tool_action(tools_dict, call) @@ -205,21 +710,26 @@ class LLMHandler(ABC): except StopIteration as e: tool_response, call_id = e.value break + + function_call_content = { + "function_call": { + "name": call.name, + "args": call.arguments, + "call_id": call_id, + } + } + # Include thought_signature for Google Gemini 3 models + # It should be at the same level as function_call, not inside it + if call.thought_signature: + function_call_content["thought_signature"] = call.thought_signature updated_messages.append( { "role": "assistant", - "content": [ - { - "function_call": { - "name": call.name, - "args": call.arguments, - "call_id": call_id, - } - } - ], + "content": [function_call_content], } ) + updated_messages.append(self.create_tool_message(call, tool_response)) except Exception as e: logger.error(f"Error executing tool: {str(e)}", exc_info=True) @@ -324,6 +834,9 @@ class LLMHandler(ABC): existing.name = call.name if call.arguments: existing.arguments += call.arguments + # Preserve thought_signature for Google Gemini 3 models + if call.thought_signature: + existing.thought_signature = call.thought_signature if parsed.finish_reason == "tool_calls": tool_handler_gen = self.handle_tool_calls( agent, list(tool_calls.values()), tools_dict, messages @@ -336,8 +849,21 @@ class LLMHandler(ABC): break tool_calls = {} + # Check if context limit was reached during tool execution + if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached: + # Add system message warning about context limit + messages.append({ + "role": "system", + "content": ( + "WARNING: Context window limit has been reached. " + "Please provide a final response to the user without making additional tool calls. " + "Summarize the work completed so far." + ) + }) + logger.info("Context limit reached - instructing agent to wrap up") + response = agent.llm.gen_stream( - model=agent.model_id, messages=messages, tools=agent.tools + model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None ) self.llm_calls.append(build_stack_data(agent.llm)) diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py index 7fa44cb6..0142922a 100644 --- a/application/llm/handlers/google.py +++ b/application/llm/handlers/google.py @@ -19,15 +19,20 @@ class GoogleLLMHandler(LLMHandler): ) if hasattr(response, "candidates"): parts = response.candidates[0].content.parts if response.candidates else [] - tool_calls = [ - ToolCall( - id=str(uuid.uuid4()), - name=part.function_call.name, - arguments=part.function_call.args, - ) - for part in parts - if hasattr(part, "function_call") and part.function_call is not None - ] + tool_calls = [] + for idx, part in enumerate(parts): + if hasattr(part, "function_call") and part.function_call is not None: + has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None + thought_sig = part.thought_signature if has_sig else None + tool_calls.append( + ToolCall( + id=str(uuid.uuid4()), + name=part.function_call.name, + arguments=part.function_call.args, + index=idx, + thought_signature=thought_sig, + ) + ) content = " ".join( part.text @@ -41,13 +46,17 @@ class GoogleLLMHandler(LLMHandler): raw_response=response, ) else: + # This branch handles individual Part objects from streaming responses tool_calls = [] - if hasattr(response, "function_call"): + if hasattr(response, "function_call") and response.function_call is not None: + has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None + thought_sig = response.thought_signature if has_sig else None tool_calls.append( ToolCall( id=str(uuid.uuid4()), name=response.function_call.name, arguments=response.function_call.args, + thought_signature=thought_sig, ) ) return LLMResponse( diff --git a/application/llm/openai.py b/application/llm/openai.py index beab465b..3917cbf7 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -128,6 +128,10 @@ class OpenAILLM(BaseLLM): ): messages = self._clean_messages_openai(messages) + # Convert max_tokens to max_completion_tokens for newer models + if "max_tokens" in kwargs: + kwargs["max_completion_tokens"] = kwargs.pop("max_tokens") + request_params = { "model": model, "messages": messages, @@ -159,6 +163,10 @@ class OpenAILLM(BaseLLM): ): messages = self._clean_messages_openai(messages) + # Convert max_tokens to max_completion_tokens for newer models + if "max_tokens" in kwargs: + kwargs["max_completion_tokens"] = kwargs.pop("max_tokens") + request_params = { "model": model, "messages": messages, diff --git a/application/prompts/compression/v1.0.txt b/application/prompts/compression/v1.0.txt new file mode 100644 index 00000000..28e7550d --- /dev/null +++ b/application/prompts/compression/v1.0.txt @@ -0,0 +1,35 @@ +Your task is to create a detailed summary of the conversation so far, paying close attention to the user's explicit requests and your previous actions. + +This summary should be thorough in capturing technical details, code patterns, and architectural decisions that would be essential for continuing work without losing context. + +Before providing your final summary, wrap your analysis in tags to organize your thoughts and ensure you've covered all necessary points. In your analysis process: + +1. Chronologically analyze each message, tool call and section of the conversation. For each section thoroughly identify: + - The user's explicit requests and intents + - Your approach to addressing the user's requests + - Key decisions, concepts and patterns + - Specific details like if applicable: + - file names + - full code snippets + - function signatures + - file edits + - Errors that you ran into and how you fixed them + - Pay special attention to specific user feedback that you received, especially if the user told you to do something differently. + +2. Double-check for accuracy and completeness, addressing each required element thoroughly. + +Your summary should include the following sections: + +1. Primary Request and Intent: Capture all of the user's explicit requests and intents in detail +2. Key Concepts: List all important concepts discussed. +3. Files and Code Sections: Enumerate specific files and code sections examined, modified, or created. Pay special attention to the most recent messages and include full code snippets where applicable and include a summary of why this file read or edit is important. +4. Errors and fixes: List all errors that you ran into, and how you fixed them. Pay special attention to specific user feedback that you received, especially if the user told you to do something differently. +5. Problem Solving: Document problems solved and any ongoing troubleshooting efforts. +6. All user messages: List ALL user messages that are not tool results. These are critical for understanding the users' feedback and changing intent. +7. Tool Calls: List ALL tool calls made, including their inputs relevant parts of the outputs. +8. Pending Tasks: Outline any pending tasks that you have explicitly been asked to work on. +9. Current Work: Describe in detail precisely what was being worked on immediately before this summary request, paying special attention to the most recent messages from both user and assistant. Include file names and code snippets where applicable. +10. Optional Next Step: List the next step that you will take that is related to the most recent work you were doing. IMPORTANT: ensure that this step is DIRECTLY in line with the user's most recent explicit requests, and the task you were working on immediately before this summary request. If your last task was concluded, then only list next steps if they are explicitly in line with the users request. Do not start on tangential requests or really old requests that were already completed without confirming with the user first. +If there is a next step, include direct quotes from the most recent conversation showing exactly what task you were working on and where you left off. This should be verbatim to ensure there's no drift in task interpretation. + +Please provide your summary based on the conversation and tools used so far, following this structure and ensuring precision and thoroughness in your response. diff --git a/application/requirements.txt b/application/requirements.txt index 08d259b1..cb58247b 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -15,7 +15,7 @@ Flask==3.1.1 faiss-cpu==1.9.0.post1 fastmcp==2.11.0 flask-restx==1.3.0 -google-genai==1.3.0 +google-genai==1.49.0 google-api-python-client==2.179.0 google-auth-httplib2==0.2.0 google-auth-oauthlib==1.2.2 diff --git a/application/utils.py b/application/utils.py index 89b884f0..b25c4717 100644 --- a/application/utils.py +++ b/application/utils.py @@ -197,6 +197,24 @@ def generate_image_url(image_path): return f"{base_url}/api/images/{image_path}" +def calculate_compression_threshold( + model_id: str, threshold_percentage: float = 0.8 +) -> int: + """ + Calculate token threshold for triggering compression. + + Args: + model_id: Model identifier + threshold_percentage: Percentage of context window (default 80%) + + Returns: + Token count threshold + """ + total_context = get_token_limit(model_id) + threshold = int(total_context * threshold_percentage) + return threshold + + def clean_text_for_tts(text: str) -> str: """ clean text for Text-to-Speech processing. diff --git a/tests/llm/test_google_llm.py b/tests/llm/test_google_llm.py index 0862c727..80434e98 100644 --- a/tests/llm/test_google_llm.py +++ b/tests/llm/test_google_llm.py @@ -91,7 +91,7 @@ def test_clean_messages_google_basic(): {"function_call": {"name": "fn", "args": {"a": 1}}}, ]}, ] - cleaned = llm._clean_messages_google(msgs) + cleaned, system_instruction = llm._clean_messages_google(msgs) assert all(hasattr(c, "role") and hasattr(c, "parts") for c in cleaned) assert any(c.role == "model" for c in cleaned) diff --git a/tests/test_agent_token_tracking.py b/tests/test_agent_token_tracking.py new file mode 100644 index 00000000..e168567a --- /dev/null +++ b/tests/test_agent_token_tracking.py @@ -0,0 +1,325 @@ +import pytest +from unittest.mock import Mock, patch + +from application.agents.base import BaseAgent +from application.llm.handlers.base import LLMHandler, ToolCall + + +class MockAgent(BaseAgent): + """Mock agent for testing""" + + def _gen_inner(self, query, log_context=None): + yield {"answer": "test"} + + +@pytest.fixture +def mock_agent(): + """Create a mock agent for testing""" + agent = MockAgent( + endpoint="test", + llm_name="openai", + model_id="gpt-4o", + api_key="test-key", + ) + agent.llm = Mock() + return agent + + +@pytest.fixture +def mock_llm_handler(): + """Create a mock LLM handler""" + handler = Mock(spec=LLMHandler) + handler.tool_calls = [] + return handler + + +class TestAgentTokenTracking: + """Test suite for agent token tracking during execution""" + + def test_calculate_current_context_tokens(self, mock_agent): + """Test token calculation for current context""" + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + + tokens = mock_agent._calculate_current_context_tokens(messages) + + # Should count tokens from all messages + assert tokens > 0 + # Rough estimate: ~20-40 tokens for this conversation + assert 15 < tokens < 60 + + def test_calculate_tokens_with_tool_calls(self, mock_agent): + """Test token calculation includes tool call content""" + messages = [ + {"role": "system", "content": "Test"}, + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": "search_tool", + "args": {"query": "test"}, + "call_id": "123", + } + } + ], + }, + { + "role": "tool", + "content": [ + { + "function_response": { + "name": "search_tool", + "response": {"result": "Found 10 results"}, + "call_id": "123", + } + } + ], + }, + ] + + tokens = mock_agent._calculate_current_context_tokens(messages) + + # Should include tool call tokens + assert tokens > 0 + + @patch("application.core.model_utils.get_token_limit") + @patch("application.core.settings.settings") + def test_check_context_limit_below_threshold( + self, mock_settings, mock_get_token_limit, mock_agent + ): + """Test context limit check when below threshold""" + mock_get_token_limit.return_value = 128000 + mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8 + + messages = [ + {"role": "system", "content": "Short message"}, + {"role": "user", "content": "Hello"}, + ] + + # Should return False for small conversation + result = mock_agent._check_context_limit(messages) + assert result is False + + # Should track current token count + assert mock_agent.current_token_count > 0 + assert mock_agent.current_token_count < 128000 * 0.8 + + @patch("application.core.model_utils.get_token_limit") + @patch("application.core.settings.settings") + def test_check_context_limit_above_threshold( + self, mock_settings, mock_get_token_limit, mock_agent + ): + """Test context limit check when above threshold""" + mock_get_token_limit.return_value = 100 # Very small limit for testing + mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8 + + # Create messages that will exceed 80 tokens (80% of 100) + messages = [ + {"role": "system", "content": "a " * 50}, # ~50 tokens + {"role": "user", "content": "b " * 50}, # ~50 tokens + ] + + # Should return True when exceeding threshold + result = mock_agent._check_context_limit(messages) + assert result is True + + @patch("application.agents.base.logger") + def test_check_context_limit_error_handling(self, mock_logger, mock_agent): + """Test error handling in context limit check""" + # Force an error by making get_token_limit fail + with patch( + "application.core.model_utils.get_token_limit", side_effect=Exception("Test error") + ): + messages = [{"role": "user", "content": "test"}] + + result = mock_agent._check_context_limit(messages) + + # Should return False on error (safe default) + assert result is False + # Should log the error + assert mock_logger.error.called + + def test_context_limit_flag_initialization(self, mock_agent): + """Test that context limit flag is initialized""" + assert hasattr(mock_agent, "context_limit_reached") + assert mock_agent.context_limit_reached is False + + assert hasattr(mock_agent, "current_token_count") + assert mock_agent.current_token_count == 0 + + +class TestLLMHandlerTokenTracking: + """Test suite for LLM handler token tracking""" + + @patch("application.llm.handlers.base.logger") + def test_handle_tool_calls_stops_at_limit(self, mock_logger): + """Test that tool execution stops when context limit is reached""" + from application.llm.handlers.base import LLMHandler + + # Create a concrete handler for testing + class TestHandler(LLMHandler): + def parse_response(self, response): + pass + + def create_tool_message(self, tool_call, result): + return {"role": "tool", "content": str(result)} + + def _iterate_stream(self, response): + yield "" + + handler = TestHandler() + + # Create mock agent that hits limit on second tool + mock_agent = Mock() + mock_agent.context_limit_reached = False + + call_count = [0] + + def check_limit_side_effect(messages): + call_count[0] += 1 + # Return True on second call (second tool) + return call_count[0] >= 2 + + mock_agent._check_context_limit = Mock(side_effect=check_limit_side_effect) + mock_agent._execute_tool_action = Mock( + return_value=iter([{"type": "tool_call", "data": {}}]) + ) + + # Create multiple tool calls + tool_calls = [ + ToolCall(id="1", name="tool1", arguments={}), + ToolCall(id="2", name="tool2", arguments={}), + ToolCall(id="3", name="tool3", arguments={}), + ] + + messages = [] + tools_dict = {} + + # Execute tool calls + results = list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages)) + + # First tool should execute + assert mock_agent._execute_tool_action.call_count == 1 + + # Should have yielded skip messages for tools 2 and 3 + skip_messages = [r for r in results if r.get("type") == "tool_call" and r.get("data", {}).get("status") == "skipped"] + assert len(skip_messages) == 2 + + # Should have set the flag + assert mock_agent.context_limit_reached is True + + # Should have logged warning + assert mock_logger.warning.called + + def test_handle_tool_calls_all_execute_when_no_limit(self): + """Test that all tools execute when under limit""" + from application.llm.handlers.base import LLMHandler + + class TestHandler(LLMHandler): + def parse_response(self, response): + pass + + def create_tool_message(self, tool_call, result): + return {"role": "tool", "content": str(result)} + + def _iterate_stream(self, response): + yield "" + + handler = TestHandler() + + # Create mock agent that never hits limit + mock_agent = Mock() + mock_agent.context_limit_reached = False + mock_agent._check_context_limit = Mock(return_value=False) + mock_agent._execute_tool_action = Mock( + return_value=iter([{"type": "tool_call", "data": {}}]) + ) + + tool_calls = [ + ToolCall(id="1", name="tool1", arguments={}), + ToolCall(id="2", name="tool2", arguments={}), + ToolCall(id="3", name="tool3", arguments={}), + ] + + messages = [] + tools_dict = {} + + # Execute tool calls + list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages)) + + # All 3 tools should execute + assert mock_agent._execute_tool_action.call_count == 3 + + # Should not have set the flag + assert mock_agent.context_limit_reached is False + + @patch("application.llm.handlers.base.logger") + def test_handle_streaming_adds_warning_message(self, mock_logger): + """Test that streaming handler adds warning when limit reached""" + from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall + + class TestHandler(LLMHandler): + def parse_response(self, response): + if isinstance(response, dict) and response.get("type") == "tool_call": + return LLMResponse( + content="", + tool_calls=[ToolCall(id="1", name="test", arguments={}, index=0)], + finish_reason="tool_calls", + raw_response=None, + ) + else: + return LLMResponse( + content="Done", + tool_calls=[], + finish_reason="stop", + raw_response=None, + ) + + def create_tool_message(self, tool_call, result): + return {"role": "tool", "content": str(result)} + + def _iterate_stream(self, response): + if response == "first": + yield {"type": "tool_call"} # Object to be parsed, not string + else: + yield {"type": "stop"} # Object to be parsed, not string + + handler = TestHandler() + + # Create mock agent with limit reached + mock_agent = Mock() + mock_agent.context_limit_reached = True + mock_agent.model_id = "gpt-4o" + mock_agent.tools = [] + mock_agent.llm = Mock() + mock_agent.llm.gen_stream = Mock(return_value="second") + + def tool_handler_gen(*args): + yield {"type": "tool", "data": {}} + return [] + + # Mock handle_tool_calls to return messages and set flag + with patch.object( + handler, "handle_tool_calls", return_value=tool_handler_gen() + ): + messages = [] + tools_dict = {} + + # Execute streaming + list(handler.handle_streaming(mock_agent, "first", tools_dict, messages)) + + # Should have called gen_stream with tools=None (disabled) + mock_agent.llm.gen_stream.assert_called() + call_kwargs = mock_agent.llm.gen_stream.call_args.kwargs + assert call_kwargs.get("tools") is None + + # Should have logged the warning + assert mock_logger.info.called + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_compression_service.py b/tests/test_compression_service.py new file mode 100644 index 00000000..c6700e3a --- /dev/null +++ b/tests/test_compression_service.py @@ -0,0 +1,1082 @@ +import pytest +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +from application.api.answer.services.compression import CompressionService +from application.api.answer.services.compression.threshold_checker import ( + CompressionThresholdChecker, +) +from application.api.answer.services.compression.token_counter import TokenCounter +from application.api.answer.services.compression.prompt_builder import ( + CompressionPromptBuilder, +) +from application.core.settings import settings + + +@pytest.fixture +def mock_llm(): + """Create a mock LLM for testing""" + llm = Mock() + llm.gen = Mock() + return llm + + +@pytest.fixture +def compression_service(mock_llm): + """Create a CompressionService instance with mock LLM""" + return CompressionService(llm=mock_llm, model_id="gpt-4o") + + +@pytest.fixture +def threshold_checker(): + """Create a ThresholdChecker instance""" + return CompressionThresholdChecker() + + +@pytest.fixture +def prompt_builder(): + """Create a PromptBuilder instance""" + return CompressionPromptBuilder() + + +@pytest.fixture +def sample_conversation(): + """Create a sample conversation for testing""" + return { + "_id": "test_conversation_id", + "user": "test_user", + "date": datetime.now(timezone.utc), + "name": "Test Conversation", + "queries": [ + { + "prompt": "What is Python?", + "response": "Python is a high-level programming language.", + "thought": "", + "sources": [], + "tool_calls": [], + "timestamp": datetime.now(timezone.utc), + }, + { + "prompt": "How do I install it?", + "response": "You can install Python from python.org", + "thought": "", + "sources": [], + "tool_calls": [], + "timestamp": datetime.now(timezone.utc), + }, + { + "prompt": "What are some popular libraries?", + "response": "Popular Python libraries include NumPy, Pandas, Django, Flask, etc.", + "thought": "", + "sources": [], + "tool_calls": [], + "timestamp": datetime.now(timezone.utc), + }, + ], + } + + +@pytest.fixture +def large_conversation(): + """Create a large conversation that exceeds threshold""" + queries = [] + for i in range(100): + queries.append( + { + "prompt": f"Question {i}: " + ("test " * 100), # ~400 tokens each + "response": f"Answer {i}: " + ("response " * 100), # ~400 tokens each + "thought": "", + "sources": [], + "tool_calls": [], + "timestamp": datetime.now(timezone.utc), + } + ) + + return { + "_id": "large_conversation_id", + "user": "test_user", + "date": datetime.now(timezone.utc), + "name": "Large Conversation", + "queries": queries, + } + + +class TestCompressionService: + """Test suite for CompressionService""" + + def test_initialization(self, mock_llm): + """Test CompressionService initialization""" + service = CompressionService(llm=mock_llm, model_id="gpt-4o") + + assert service.llm == mock_llm + assert service.model_id == "gpt-4o" + assert service.prompt_builder is not None + assert service.prompt_builder.version == settings.COMPRESSION_PROMPT_VERSION + + @patch("application.api.answer.services.compression.threshold_checker.get_token_limit") + def test_should_compress_below_threshold( + self, mock_get_token_limit, threshold_checker, sample_conversation + ): + """Test that compression is not triggered when below threshold""" + mock_get_token_limit.return_value = 128000 # GPT-4o limit + + # Small conversation should not trigger compression + result = threshold_checker.should_compress( + sample_conversation, model_id="gpt-4o" + ) + + assert result is False + + @patch("application.api.answer.services.compression.threshold_checker.get_token_limit") + def test_should_compress_above_threshold( + self, mock_get_token_limit, threshold_checker, large_conversation + ): + """Test that compression is triggered when above threshold""" + mock_get_token_limit.return_value = 10000 # Lower limit to ensure large conversation exceeds threshold + + # Large conversation should trigger compression (100 queries with repeated text) + # Threshold at 80% of 10k = 8k tokens, so large_conversation > 8k should trigger + result = threshold_checker.should_compress( + large_conversation, model_id="gpt-4o" + ) + + assert result is True + + @patch("application.api.answer.services.compression.threshold_checker.get_token_limit") + def test_should_compress_at_exact_threshold( + self, mock_get_token_limit, threshold_checker + ): + """Test compression trigger at exact 80% threshold""" + mock_get_token_limit.return_value = 1000 + + # Create conversation with exactly 800 tokens (80% of 1000) + conversation = { + "queries": [ + { + "prompt": "a " * 200, # ~200 tokens + "response": "b " * 200, # ~200 tokens + }, + { + "prompt": "c " * 200, # ~200 tokens + "response": "d " * 200, # ~200 tokens + }, + ] + } + + result = threshold_checker.should_compress(conversation, model_id="test-model") + + # Should trigger at or above 80% + assert result is True + + def test_compress_conversation_basic(self, compression_service, sample_conversation): + """Test basic conversation compression""" + # Mock LLM response + mock_summary = """ + + The conversation covers Python basics and installation. + + + + 1. Primary Request and Intent: + User asked about Python and how to install it. + + 2. Key Concepts: + - Python programming language + - Installation process + + 3. Files and Code Sections: + None + + 4. Errors and fixes: + None + + 5. Problem Solving: + Explained Python installation from python.org + + 6. All user messages: + - What is Python? + - How do I install it? + - What are some popular libraries? + + 7. Pending Tasks: + None + + 8. Current Work: + Provided information about popular Python libraries. + + 9. Optional Next Step: + None + + """ + compression_service.llm.gen.return_value = mock_summary + + # Compress first 2 queries + result = compression_service.compress_conversation( + conversation=sample_conversation, compress_up_to_index=1 + ) + + # Verify LLM was called + assert compression_service.llm.gen.called + + # Verify result is a CompressionMetadata object + assert hasattr(result, 'timestamp') + assert result.query_index == 1 + assert hasattr(result, 'compressed_summary') + assert result.original_token_count > 0 + assert result.compressed_token_count > 0 + assert result.compression_ratio > 0 + assert result.model_used == "gpt-4o" + assert result.compression_prompt_version == settings.COMPRESSION_PROMPT_VERSION + + # Verify summary was extracted correctly (without analysis tags) + assert "" not in result.compressed_summary + assert "Primary Request and Intent" in result.compressed_summary + + def test_compress_conversation_with_tool_calls(self, compression_service): + """Test compression of conversation with tool calls""" + conversation = { + "queries": [ + { + "prompt": "Search for Python tutorials", + "response": "I'll search for Python tutorials.", + "thought": "Need to use search tool", + "sources": [], + "tool_calls": [ + { + "tool_name": "search_tool", + "action_name": "search", + "arguments": {"query": "Python tutorials"}, + "result": "Found 100 tutorials", + "status": "completed", + } + ], + "timestamp": datetime.now(timezone.utc), + } + ] + } + + mock_summary = "Test summary with tools" + compression_service.llm.gen.return_value = mock_summary + + compression_service.compress_conversation( + conversation=conversation, compress_up_to_index=0 + ) + + # Verify tool calls are included in compression prompt + call_args = compression_service.llm.gen.call_args + messages = call_args[1]["messages"] + user_message = messages[1]["content"] + + assert "Tool Calls:" in user_message + assert "search_tool" in user_message + + def test_compress_conversation_invalid_index( + self, compression_service, sample_conversation + ): + """Test compression with invalid index raises error""" + with pytest.raises(ValueError, match="Invalid compress_up_to_index"): + compression_service.compress_conversation( + conversation=sample_conversation, + compress_up_to_index=100, # Invalid - conversation only has 3 queries + ) + + def test_get_compressed_context_no_compression( + self, compression_service, sample_conversation + ): + """Test getting context when no compression exists""" + summary, recent = compression_service.get_compressed_context( + sample_conversation + ) + + assert summary is None + assert len(recent) == 3 # All queries returned + + def test_get_compressed_context_with_compression(self, compression_service): + """Test getting context when compression exists""" + conversation = { + "queries": [ + {"prompt": "Q1", "response": "A1"}, + {"prompt": "Q2", "response": "A2"}, + {"prompt": "Q3", "response": "A3"}, + {"prompt": "Q4", "response": "A4"}, + {"prompt": "Q5", "response": "A5"}, + ], + "compression_metadata": { + "is_compressed": True, + "last_compression_at": datetime.now(timezone.utc), + "compression_points": [ + { + "timestamp": datetime.now(timezone.utc), + "query_index": 2, # Compressed up to Q3 + "compressed_summary": "Summary of Q1-Q3", + "original_token_count": 100, + "compressed_token_count": 20, + "compression_ratio": 5.0, + } + ], + }, + } + + summary, recent = compression_service.get_compressed_context( + conversation + ) + + assert summary == "Summary of Q1-Q3" + assert len(recent) == 2 # Q4 and Q5 (after compression point) + assert recent[0]["prompt"] == "Q4" + assert recent[1]["prompt"] == "Q5" + + def test_get_compressed_context_multiple_compressions(self, compression_service): + """Test getting context when multiple compressions exist""" + conversation = { + "queries": [ + {"prompt": f"Q{i}", "response": f"A{i}"} for i in range(1, 11) + ], + "compression_metadata": { + "is_compressed": True, + "last_compression_at": datetime.now(timezone.utc), + "compression_points": [ + { + "timestamp": datetime.now(timezone.utc), + "query_index": 4, # First compression + "compressed_summary": "First compression summary", + "original_token_count": 100, + "compressed_token_count": 20, + }, + { + "timestamp": datetime.now(timezone.utc), + "query_index": 7, # Second compression + "compressed_summary": "Second compression summary (includes first)", + "original_token_count": 150, + "compressed_token_count": 30, + }, + ], + }, + } + + summary, recent = compression_service.get_compressed_context( + conversation + ) + + # Should use the most recent compression + assert summary == "Second compression summary (includes first)" + assert len(recent) == 2 # Q9 and Q10 (after compression point at index 7) + assert recent[0]["prompt"] == "Q9" + assert recent[1]["prompt"] == "Q10" + + def test_extract_summary_with_tags(self, compression_service): + """Test summary extraction with analysis and summary tags""" + llm_response = """ + + This is my analysis of the conversation. + It has multiple lines. + + + + This is the actual summary. + It should be extracted. + + """ + + result = compression_service._extract_summary(llm_response) + + assert "" not in result + assert "This is the actual summary" in result + assert "my analysis" not in result + + def test_extract_summary_without_tags(self, compression_service): + """Test summary extraction when no tags present""" + llm_response = "This is a plain summary without tags." + + result = compression_service._extract_summary(llm_response) + + assert result == "This is a plain summary without tags." + + def test_count_tokens_in_queries(self, sample_conversation): + """Test token counting in queries""" + queries = sample_conversation["queries"] + + token_count = TokenCounter.count_query_tokens(queries) + + # Should count all prompts and responses + assert token_count > 0 + + def test_count_tokens_with_tool_calls(self): + """Test token counting includes tool calls""" + queries = [ + { + "prompt": "Test prompt", + "response": "Test response", + "tool_calls": [ + { + "tool_name": "test_tool", + "action_name": "test_action", + "arguments": {"arg": "value"}, + "result": "Tool result", + } + ], + } + ] + + token_count_with_tools = TokenCounter.count_query_tokens( + queries, include_tool_calls=True + ) + token_count_without_tools = TokenCounter.count_query_tokens( + queries, include_tool_calls=False + ) + + assert token_count_with_tools > token_count_without_tools + + def test_format_conversation_for_compression( + self, prompt_builder, sample_conversation + ): + """Test conversation formatting for compression prompt""" + queries = sample_conversation["queries"] + + formatted = prompt_builder._format_conversation(queries) + + # Verify formatting includes all messages + assert "Message 1" in formatted + assert "What is Python?" in formatted + assert "Python is a high-level programming language" in formatted + assert "Message 2" in formatted + assert "How do I install it?" in formatted + + def test_build_compression_prompt_basic(self, prompt_builder): + """Test compression prompt building""" + queries = [ + {"prompt": "Q1", "response": "A1", "tool_calls": [], "sources": []}, + {"prompt": "Q2", "response": "A2", "tool_calls": [], "sources": []}, + ] + + messages = prompt_builder.build_prompt(queries) + + assert len(messages) == 2 # System and user messages + assert messages[0]["role"] == "system" + assert messages[1]["role"] == "user" + assert "conversation to summarize" in messages[1]["content"] + + def test_build_compression_prompt_with_existing_compressions( + self, prompt_builder + ): + """Test compression prompt building with existing compressions""" + queries = [ + {"prompt": "Q3", "response": "A3", "tool_calls": [], "sources": []}, + {"prompt": "Q4", "response": "A4", "tool_calls": [], "sources": []}, + ] + + existing_compressions = [ + { + "query_index": 1, + "compressed_summary": "Previous compression summary", + "timestamp": datetime.now(timezone.utc), + } + ] + + messages = prompt_builder.build_prompt( + queries, existing_compressions + ) + + user_content = messages[1]["content"] + + # Should mention existing compression + assert "compressed before" in user_content + assert "Previous compression summary" in user_content + assert "NEW summary" in user_content + + def test_calculate_conversation_tokens( + self, sample_conversation + ): + """Test conversation token calculation""" + token_count = TokenCounter.count_conversation_tokens( + sample_conversation, include_system_prompt=False + ) + + assert token_count > 0 + + # With system prompt should be higher + token_count_with_system = TokenCounter.count_conversation_tokens( + sample_conversation, include_system_prompt=True + ) + + assert token_count_with_system > token_count + + @patch("application.api.answer.services.compression.threshold_checker.logger") + def test_error_handling_in_should_compress( + self, mock_logger, threshold_checker, sample_conversation + ): + """Test error handling in should_compress""" + # Force an error by making get_token_limit raise an exception + with patch( + "application.api.answer.services.compression.threshold_checker.get_token_limit", + side_effect=Exception("Test error"), + ): + result = threshold_checker.should_compress( + sample_conversation, model_id="gpt-4o" + ) + + # Should return False on error + assert result is False + # Should log the error + assert mock_logger.error.called + + @patch("application.api.answer.services.compression.service.logger") + def test_error_handling_in_get_compressed_context( + self, mock_logger, compression_service + ): + """Test error handling in get_compressed_context""" + # Malformed conversation + malformed_conversation = {"queries": None} + + summary, recent = compression_service.get_compressed_context( + malformed_conversation + ) + + # Should return safe defaults + assert summary is None + assert recent == [] + # Should log the error + assert mock_logger.error.called + + + def test_compression_points_array_limiting(self, compression_service): + """Test that only the most recent compression points are kept""" + # Simulate a conversation with 3 previous compressions + conversation = { + "queries": [ + {"prompt": f"Q{i}", "response": f"A{i}"} for i in range(1, 11) + ], + "compression_metadata": { + "is_compressed": True, + "last_compression_at": datetime.now(timezone.utc), + "compression_points": [ + { + "timestamp": datetime.now(timezone.utc), + "query_index": 2, + "compressed_summary": "First compression summary", + "original_token_count": 100, + "compressed_token_count": 20, + }, + { + "timestamp": datetime.now(timezone.utc), + "query_index": 5, + "compressed_summary": "Second compression summary", + "original_token_count": 150, + "compressed_token_count": 30, + }, + { + "timestamp": datetime.now(timezone.utc), + "query_index": 7, + "compressed_summary": "Third compression summary", + "original_token_count": 200, + "compressed_token_count": 40, + }, + ], + }, + } + + # The service should use the most recent compression + summary, recent = compression_service.get_compressed_context( + conversation + ) + + # Should use the most recent (third) compression + assert summary == "Third compression summary" + assert len(recent) == 2 # Q9 and Q10 (after compression point at index 7) + assert recent[0]["prompt"] == "Q9" + assert recent[1]["prompt"] == "Q10" + + def test_compression_with_heavy_tool_usage(self, compression_service): + """Test compression when conversation has many tool calls with large responses + + Scenario: User asks agent to scrape all files in a GitHub repo, generating + dozens of tool calls with file contents as responses. This tests the system's + ability to compress tool-heavy conversations that hit token limits. + """ + # Simulate a conversation where agent scraped 50 files from DocsGPT repo + queries = [] + + # Initial user request + queries.append({ + "prompt": "Please analyze all Python files in the https://github.com/arc53/DocsGPT repository", + "response": "I'll scrape all the Python files from the DocsGPT repository and analyze them.", + "tool_calls": [] + }) + + # Simulate 50 file scraping tool calls with realistic file contents + file_paths = [ + "application/app.py", + "application/api/answer/routes.py", + "application/api/answer/services/conversation_service.py", + "application/api/answer/services/compression_service.py", + "application/api/answer/services/stream_processor.py", + "application/agents/base.py", + "application/agents/react.py", + "application/llm/handlers/base.py", + "application/llm/llm_creator.py", + "application/core/settings.py", + "application/core/model_configs.py", + "application/utils.py", + "application/vectorstore/base.py", + "application/parser/file_parser.py", + "tests/test_compression_service.py", + "tests/test_agent_token_tracking.py", + "frontend/src/App.tsx", + "frontend/src/store/index.ts", + "deployment/docker-compose.yaml", + "setup.py", + ] + + tool_calls = [] + for i, file_path in enumerate(file_paths[:20]): # First 20 files + # Each tool call with realistic file content (simulating ~500-1000 tokens per file) + file_content = f""" +# {file_path} + +import os +import sys +from typing import Dict, List, Optional, Any +from datetime import datetime + +class {file_path.split('/')[-1].replace('.py', '').title()}: + ''' + This is a module that handles various operations for the DocsGPT application. + It contains multiple classes and functions for processing data. + ''' + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.initialized = False + self.data_store = {{}} + + def process_data(self, input_data: List[str]) -> Dict[str, Any]: + '''Process input data and return results''' + results = {{}} + for item in input_data: + # Complex processing logic here + processed = self._transform_item(item) + results[item] = processed + return results + + def _transform_item(self, item: str) -> str: + '''Internal transformation logic''' + # Multiple lines of transformation code + transformed = item.upper().strip() + transformed = transformed.replace(' ', '_') + return transformed + + def validate_config(self) -> bool: + '''Validate configuration settings''' + required_keys = ['api_key', 'endpoint', 'model_id'] + return all(key in self.config for key in required_keys) + +# Additional helper functions +def utility_function_one(param: str) -> str: + return param.strip().lower() + +def utility_function_two(data: Dict) -> List: + return list(data.values()) + +def main(): + config = {{'api_key': 'test', 'endpoint': 'http://localhost', 'model_id': 'gpt-4'}} + instance = {file_path.split('/')[-1].replace('.py', '').title()}(config) + instance.process_data(['item1', 'item2', 'item3']) +""" * 2 # Double it to simulate ~1000-1500 tokens per response + + tool_calls.append({ + "call_id": f"call_{i}", + "tool_name": "github_file_scraper", + "action_name": "read_file", + "arguments": {"file_path": file_path}, + "result": {"content": file_content, "status": "success"}, + "status": "success" + }) + + # Add query with all tool calls + queries.append({ + "prompt": "[Agent continues processing]", + "response": "I've scraped 20 Python files. Let me analyze the patterns...", + "tool_calls": tool_calls + }) + + # Add analysis response + queries.append({ + "prompt": "[Agent continues analysis]", + "response": """Based on my analysis of the 20 Python files: + +1. Architecture: The codebase follows a modular architecture with clear separation between API, agents, LLM handlers, and utilities. + +2. Key patterns identified: + - Heavy use of type hints (typing module) + - Consistent error handling patterns + - Service-based architecture for API endpoints + - Factory pattern for LLM creation + - Abstract base classes for extensibility + +3. Core components: + - Agent system with tool integration + - LLM provider abstraction + - Compression service for context management + - Stream processing for real-time responses + +4. Code quality observations: + - Comprehensive docstrings + - Good test coverage + - Clear naming conventions + - Proper separation of concerns""", + "tool_calls": [] + }) + + conversation = {"queries": queries} + + # Mock LLM response for compression + mock_summary = """ + User requested analysis of all Python files in DocsGPT GitHub repository. + Agent scraped 20 files including app.py, API routes, services, agents, and tests. + Analysis revealed modular architecture with service-based design, type hints, + factory patterns, and agent system with tool integration. Code quality is high + with comprehensive docstrings and test coverage. + """ + compression_service.llm.gen.return_value = mock_summary + + # Compress the heavy tool usage + result = compression_service.compress_conversation( + conversation=conversation, + compress_up_to_index=1 # Compress first 2 queries (including all tool calls) + ) + + # Verify compression handled tool calls properly + assert result.query_index == 1 + assert result.compressed_summary is not None + + # Verify the compression prompt included tool call information + call_args = compression_service.llm.gen.call_args + messages = call_args[1]["messages"] + user_message = messages[1]["content"] + + # Should include tool calls section + assert "Tool Calls:" in user_message + assert "github_file_scraper" in user_message + + # Verify compression ratio (should be significant with all that tool data) + original_tokens = result.original_token_count + compressed_tokens = result.compressed_token_count + compression_ratio = result.compression_ratio + + # With 20 large tool responses, original should be substantial + assert original_tokens > 5000 # At least 5k tokens from tool responses + assert compressed_tokens < 500 # Summary should be much smaller + assert compression_ratio > 10 # Should achieve >10x compression + + def test_compression_with_needle_in_haystack(self, compression_service): + """Test compression preserves important information buried in long conversation + + Scenario: User has long conversation with verbose responses, hiding critical + information in the middle. Tests that compression correctly identifies and + preserves key details even when surrounded by lengthy content. + """ + # Create a long conversation with important info buried in the middle + queries = [] + + # Query 1: Long general discussion + queries.append({ + "prompt": "Tell me about Python programming best practices", + "response": """Python best practices encompass a wide range of principles and patterns. +Let me provide a comprehensive overview: + +1. Code Style and Formatting: + - Follow PEP 8 style guide for consistent formatting + - Use 4 spaces for indentation (not tabs) + - Limit lines to 79 characters for code, 72 for docstrings + - Use meaningful variable names that describe their purpose + - Add whitespace around operators and after commas + - Group imports: standard library, third-party, local + +2. Documentation: + - Write clear docstrings for all functions, classes, and modules + - Use type hints for better code clarity and IDE support + - Include examples in docstrings when helpful + - Keep comments up-to-date with code changes + +3. Error Handling: + - Use specific exceptions rather than bare except clauses + - Create custom exceptions for domain-specific errors + - Always clean up resources with context managers (with statement) + - Log errors appropriately for debugging + +4. Testing: + - Write unit tests for all critical functionality + - Aim for high test coverage (80%+) + - Use pytest for modern testing features + - Mock external dependencies in tests + +5. Code Organization: + - Keep functions small and focused on single tasks + - Use classes to group related functionality + - Avoid deep nesting (max 3-4 levels) + - Extract complex conditions into well-named variables + +6. Performance: + - Use list comprehensions for simple transformations + - Avoid premature optimization + - Profile code before optimizing + - Use generators for large datasets + +These practices help maintain readable, maintainable, and efficient code.""", + "tool_calls": [] + }) + + # Query 2: Another long response + queries.append({ + "prompt": "What about Python data structures?", + "response": """Python provides several built-in data structures, each optimized for different use cases: + +1. Lists: + - Ordered, mutable sequences + - Dynamic sizing with amortized O(1) append + - Access by index in O(1) + - Insertion/deletion in middle is O(n) + - Use cases: ordered collections, stacks, queues + - Methods: append(), extend(), insert(), remove(), pop(), sort() + +2. Tuples: + - Ordered, immutable sequences + - Slightly more memory efficient than lists + - Can be used as dictionary keys (if contents are hashable) + - Use cases: fixed collections, function return values, dictionary keys + +3. Dictionaries: + - Unordered (ordered in Python 3.7+) key-value mappings + - Average O(1) lookup, insertion, deletion + - Keys must be hashable + - Use cases: lookups, caching, counting, grouping + - Methods: get(), keys(), values(), items(), update(), pop() + +4. Sets: + - Unordered collections of unique elements + - Average O(1) membership testing + - Efficient for removing duplicates + - Support set operations: union, intersection, difference + - Use cases: membership testing, removing duplicates, set mathematics + +5. Collections module extensions: + - defaultdict: dict with default values for missing keys + - Counter: dict subclass for counting hashable objects + - deque: double-ended queue with O(1) append/pop from both ends + - OrderedDict: maintains insertion order (less relevant in Python 3.7+) + - namedtuple: tuple subclass with named fields + +6. Performance considerations: + - Lists for ordered data with frequent append operations + - Dictionaries for key-based lookups + - Sets for membership testing and uniqueness + - Deques for queue operations from both ends + - Tuples for immutable data + +Understanding these data structures is crucial for writing efficient Python code.""", + "tool_calls": [] + }) + + # Query 3: THE CRITICAL INFORMATION (needle in the haystack) + queries.append({ + "prompt": "I need to remember this important detail", + "response": """I'll make a note of that important detail. + +CRITICAL INFORMATION TO REMEMBER: +The production database password is stored in the environment variable DB_PASSWORD_PROD. +The backup schedule is set to run daily at 3:00 AM UTC. +The API rate limit for premium users is 10,000 requests per hour. +The encryption key rotation happens every 90 days. +The primary contact for incidents is: ops-team@example.com + +I've recorded this information for our conversation. These operational details are important for system administration and should be referenced when needed.""", + "tool_calls": [] + }) + + # Query 4: More long content after the important info + queries.append({ + "prompt": "Explain Python decorators in detail", + "response": """Python decorators are a powerful feature that allows you to modify or enhance functions and classes. Here's a comprehensive explanation: + +1. Basic Concept: + - Decorators are functions that take another function as input + - They return a modified version of that function + - Syntax: @decorator above function definition + - They implement the decorator design pattern + +2. Function Decorators: + ```python + def my_decorator(func): + def wrapper(*args, **kwargs): + # Code before function + result = func(*args, **kwargs) + # Code after function + return result + return wrapper + + @my_decorator + def my_function(): + pass + ``` + +3. Common Use Cases: + - Logging: Record function calls and results + - Timing: Measure execution time + - Authentication: Check permissions before execution + - Caching: Store and return cached results + - Validation: Check input parameters + - Rate limiting: Throttle function calls + +4. Decorators with Arguments: + ```python + def repeat(times): + def decorator(func): + def wrapper(*args, **kwargs): + for _ in range(times): + result = func(*args, **kwargs) + return result + return wrapper + return decorator + + @repeat(3) + def greet(): + print("Hello") + ``` + +5. Class Decorators: + - Can decorate entire classes + - Useful for adding methods or attributes + - Can enforce patterns like singleton + +6. Built-in Decorators: + - @property: Create managed attributes + - @staticmethod: Define static methods + - @classmethod: Define class methods + - @abstractmethod: Define abstract methods + +7. functools.wraps: + - Preserves original function metadata + - Should be used in decorator implementations + - Maintains __name__, __doc__, etc. + +8. Practical Examples: + - @login_required for web routes + - @cache for memoization + - @retry for resilient API calls + - @deprecated for marking old code + +Decorators are essential for writing clean, maintainable Python code with separation of concerns.""", + "tool_calls": [] + }) + + # Query 5: Final long response + queries.append({ + "prompt": "What about Python async programming?", + "response": """Asynchronous programming in Python allows for concurrent execution of I/O-bound operations: + +1. Core Concepts: + - Event loop: Manages and executes async tasks + - Coroutines: Functions defined with async def + - await: Pauses coroutine until awaitable completes + - Tasks: Wrapper for coroutines to run concurrently + +2. Basic Syntax: + ```python + import asyncio + + async def fetch_data(): + await asyncio.sleep(1) + return "data" + + async def main(): + result = await fetch_data() + print(result) + + asyncio.run(main()) + ``` + +3. When to Use Async: + - I/O-bound operations (network requests, file I/O, database queries) + - Multiple concurrent operations + - Real-time applications (websockets, streaming) + - NOT for CPU-bound tasks (use multiprocessing instead) + +4. Common Patterns: + - Gather: Run multiple coroutines concurrently + - create_task: Schedule coroutine execution + - Semaphore: Limit concurrent operations + - Queue: Producer-consumer patterns + +5. Async Libraries: + - aiohttp: Async HTTP client/server + - asyncpg: Async PostgreSQL driver + - motor: Async MongoDB driver + - aioredis: Async Redis client + +6. Error Handling: + - Use try/except in coroutines + - Tasks can be cancelled with task.cancel() + - Timeouts with asyncio.wait_for() + +Understanding async programming is crucial for building scalable Python applications.""", + "tool_calls": [] + }) + + conversation = {"queries": queries} + + # Mock LLM response that MUST preserve the critical information + mock_summary = """ + User asked about Python best practices, data structures, decorators, and async programming. + Discussed code style, testing, documentation standards, and various Python data structures. + + CRITICAL OPERATIONAL DETAILS PROVIDED: + - Production database password stored in DB_PASSWORD_PROD environment variable + - Backup schedule: daily at 3:00 AM UTC + - Premium API rate limit: 10,000 requests/hour + - Encryption key rotation: every 90 days + - Incident contact: ops-team@example.com + + Also covered decorators for code enhancement and async programming for I/O-bound operations. + """ + compression_service.llm.gen.return_value = mock_summary + + # Compress everything except the last query + result = compression_service.compress_conversation( + conversation=conversation, + compress_up_to_index=3 # Compress first 4 queries (includes the critical info) + ) + + # Verify compression happened + assert result.query_index == 3 + assert result.compressed_summary is not None + + # Get the compressed context + conversation["compression_metadata"] = { + "is_compressed": True, + "last_compression_at": datetime.now(timezone.utc), + "compression_points": [result.to_dict()] + } + + summary, recent = compression_service.get_compressed_context( + conversation + ) + + # Verify critical information is in the summary + assert summary is not None + assert "DB_PASSWORD_PROD" in summary or "database password" in summary.lower() + assert "3:00 AM UTC" in summary or "backup" in summary.lower() + assert "10,000" in summary or "rate limit" in summary.lower() + assert "ops-team@example.com" in summary or "incident contact" in summary.lower() + + # Verify only the last query is in recent + assert len(recent) == 1 + assert "async programming" in recent[0]["prompt"].lower() + + # The compression should be substantial (long responses compressed to summary) + assert result.original_token_count > 1300 # 4 long responses + assert result.compressed_token_count < 300 # Summary should be concise + assert result.compression_ratio > 4 # At least 4x compression + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100755 index 00000000..a3588c7e --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,1287 @@ +#!/usr/bin/env python3 +""" +Integration test script for DocsGPT API endpoints. + +Tests: +1. /stream endpoint without agent +2. /api/answer endpoint without agent +3. Create agent via API +4. /stream endpoint with agent +5. /api/answer endpoint with agent + +Usage: + python tests/test_integration.py # auto-generates JWT token from local secret when available + python tests/test_integration.py --base-url http://localhost:7091 + python tests/test_integration.py --token YOUR_JWT_TOKEN # override auto-generation +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path +from typing import Optional + +import requests + + +class Colors: + """ANSI color codes for terminal output""" + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + + +def generate_default_token() -> tuple[Optional[str], Optional[str]]: + """ + Try to generate a JWT token using the same logic as generate_test_token.py. + Returns a tuple of (token, error_message). Token is None on failure. + """ + secret = os.getenv("JWT_SECRET_KEY") + key_file = Path(".jwt_secret_key") + + if not secret: + try: + secret = key_file.read_text().strip() + except FileNotFoundError: + return None, f"Set JWT_SECRET_KEY or create {key_file} by running the backend once." + except OSError as exc: + return None, f"Could not read {key_file}: {exc}" + + if not secret: + return None, "JWT secret key is empty." + + try: + from jose import jwt # type: ignore + except ImportError: + return None, "python-jose is not installed (pip install 'python-jose' to auto-generate tokens)." + + try: + payload = {"sub": "test_integration_user"} + return jwt.encode(payload, secret, algorithm="HS256"), None + except Exception as exc: + return None, f"Failed to generate JWT token: {exc}" + + +class DocsGPTTester: + def __init__(self, base_url: str, token: Optional[str] = None, token_source: str = "provided"): + self.base_url = base_url.rstrip('/') + self.token = token + self.token_source = token_source + self.headers = {} + if token: + self.headers['Authorization'] = f'Bearer {token}' + self.agent_id = None + self.test_results = [] + + def print_header(self, message: str): + """Print a colored header""" + print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}") + print(f"{Colors.HEADER}{Colors.BOLD}{message}{Colors.ENDC}") + print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 70}{Colors.ENDC}\n") + + def print_success(self, message: str): + """Print a success message""" + print(f"{Colors.OKGREEN}✓ {message}{Colors.ENDC}") + + def print_error(self, message: str): + """Print an error message""" + print(f"{Colors.FAIL}✗ {message}{Colors.ENDC}") + + def print_info(self, message: str): + """Print an info message""" + print(f"{Colors.OKCYAN}ℹ {message}{Colors.ENDC}") + + def print_warning(self, message: str): + """Print a warning message""" + print(f"{Colors.WARNING}⚠ {message}{Colors.ENDC}") + + def test_stream_endpoint(self, agent_id: Optional[str] = None) -> bool: + """Test the /stream endpoint""" + endpoint = f"{self.base_url}/stream" + test_name = f"Stream endpoint{'with agent ' + agent_id if agent_id else ' (no agent)'}" + + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "isNoneDoc": True, + } + + if agent_id: + payload["agent_id"] = agent_id + + try: + self.print_info(f"POST {endpoint}") + self.print_info(f"Payload: {json.dumps(payload, indent=2)}") + + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + stream=True, + timeout=30 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return False + + # Parse SSE stream + events = [] + full_response = "" + conversation_id = None + + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + data_str = line[6:] # Remove 'data: ' prefix + try: + data = json.loads(data_str) + events.append(data) + + # Handle different event types + if data.get('type') in ['stream', 'answer']: + # Both 'stream' and 'answer' types contain response text + full_response += data.get('message', '') or data.get('answer', '') + elif data.get('type') == 'id': + conversation_id = data.get('id') + elif data.get('type') == 'end': + break + except json.JSONDecodeError: + pass + + self.print_success(f"Received {len(events)} events") + self.print_info(f"Response preview: {full_response[:100]}...") + + if conversation_id: + self.print_success(f"Conversation ID: {conversation_id}") + + if not full_response: + self.print_warning("No response content received") + + self.test_results.append((test_name, True, "Success")) + self.print_success(f"{test_name} passed!") + return True + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + def test_answer_endpoint(self, agent_id: Optional[str] = None) -> bool: + """Test the /api/answer endpoint""" + endpoint = f"{self.base_url}/api/answer" + test_name = f"Answer endpoint{' with agent ' + agent_id if agent_id else ' (no agent)'}" + + self.print_header(f"Testing {test_name}") + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "isNoneDoc": True, + } + + if agent_id: + payload["agent_id"] = agent_id + + try: + self.print_info(f"POST {endpoint}") + self.print_info(f"Payload: {json.dumps(payload, indent=2)}") + + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + timeout=30 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return False + + result = response.json() + + self.print_info(f"Response keys: {list(result.keys())}") + + if 'answer' in result: + answer = result['answer'] + self.print_success(f"Answer received: {answer[:100]}...") + else: + self.print_warning("No 'answer' field in response") + + if 'conversation_id' in result: + self.print_success(f"Conversation ID: {result['conversation_id']}") + + if 'sources' in result: + self.print_info(f"Sources: {len(result['sources'])} items") + + self.test_results.append((test_name, True, "Success")) + self.print_success(f"{test_name} passed!") + return True + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + def upload_text_source(self) -> Optional[str]: + """Upload a simple text source for testing + + This creates a source without requiring crawler infrastructure. + """ + endpoint = f"{self.base_url}/api/upload" + test_name = "Upload Text Source" + + self.print_header(f"Testing {test_name}") + + if not self.token: + self.print_warning("No authentication token provided") + self.print_info("Source upload requires authentication") + self.test_results.append((test_name, True, "Skipped (auth required)")) + return None + + # Create a simple text file for upload + test_content = """# DocsGPT Test Documentation + +## Installation + +To install DocsGPT, follow these steps: + +1. Clone the repository +2. Run `docker compose up` +3. Access the application at http://localhost:5173 + +## Configuration + +DocsGPT can be configured using environment variables: +- API_KEY: Your OpenAI API key +- LLM_PROVIDER: Choose between openai, anthropic, or google +- ENABLE_CONVERSATION_COMPRESSION: Enable context compression + +## Features + +DocsGPT provides: +- Conversation compression for long chats +- Real-time token tracking +- Multiple LLM provider support +- Agent system with tools +""" + + try: + self.print_info(f"POST {endpoint}") + self.print_info("Uploading test documentation...") + + # Create a file-like object + files = { + 'file': ('test_docs.txt', test_content.encode(), 'text/plain') + } + data = { + 'user': 'test_user', + 'name': f'Test Docs {int(time.time())}', + } + + response = requests.post( + endpoint, + files=files, + data=data, + headers=self.headers, + timeout=30 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code == 200: + result = response.json() + task_id = result.get('task_id') + + if task_id: + self.print_success(f"Upload task started: {task_id}") + self.print_info("Waiting for processing (10 seconds)...") + time.sleep(10) + self.test_results.append((test_name, True, f"Task: {task_id}")) + return task_id + else: + self.print_warning("No task_id returned") + self.test_results.append((test_name, False, "No task_id")) + return None + else: + self.print_error(f"Expected 200, got {response.status_code}") + try: + error_data = response.json() + self.print_error(f"Error: {error_data}") + except Exception: + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return None + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + + def upload_crawler_source(self) -> Optional[str]: + """Upload a crawler source for DocsGPT documentation""" + endpoint = f"{self.base_url}/api/remote" + test_name = "Upload Crawler Source" + + self.print_header(f"Testing {test_name}") + + if not self.token: + self.print_warning("No authentication token provided") + self.print_info("Source upload requires authentication") + self.print_info("Skipping source upload and agent tests...") + self.test_results.append((test_name, True, "Skipped (auth required)")) + return None + + payload = { + "user": "test_user", + "source": "crawler", + "name": f"DocsGPT Docs {int(time.time())}", + "data": json.dumps({"url": "https://docs.docsgpt.cloud/"}), + } + + try: + self.print_info(f"POST {endpoint}") + self.print_info("Crawling: https://docs.docsgpt.cloud/") + + response = requests.post( + endpoint, + data=payload, + headers=self.headers, + timeout=30 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code == 200: + result = response.json() + task_id = result.get('task_id') + + if task_id: + self.print_success(f"Crawler task started: {task_id}") + self.print_info("Waiting for crawler to complete (30 seconds)...") + time.sleep(30) # Wait for crawler to process + self.test_results.append((test_name, True, f"Task: {task_id}")) + return task_id + else: + self.print_warning("No task_id returned") + self.test_results.append((test_name, False, "No task_id")) + return None + else: + self.print_error(f"Expected 200, got {response.status_code}") + try: + error_data = response.json() + self.print_error(f"Error: {error_data}") + except Exception: + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return None + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + + def get_source_id_from_task(self, task_id: str) -> Optional[str]: + """Check task status and get source ID""" + endpoint = f"{self.base_url}/api/task_status" + + try: + response = requests.get( + endpoint, + params={"task_id": task_id}, + headers=self.headers, + timeout=10 + ) + + if response.status_code == 200: + result = response.json() + if result.get('status') == 'SUCCESS': + # Task completed, now find the source + # Query sources collection to find the latest source + sources_response = requests.get( + f"{self.base_url}/api/sources", + headers=self.headers, + timeout=10 + ) + if sources_response.status_code == 200: + sources = sources_response.json() + # Filter out the "Default" source and get user sources only + user_sources = [s for s in sources if s.get('date') != 'default'] + if user_sources and len(user_sources) > 0: + # Get the most recent source (first one, as they're sorted by date desc) + latest_source = user_sources[0] + return latest_source.get('id') + return None + except Exception as e: + self.print_error(f"Error getting source ID: {str(e)}") + return None + + def create_agent(self, source_id: Optional[str] = None, published: bool = False) -> Optional[tuple]: + """Create an agent via API + + Args: + source_id: Optional source ID to attach to agent + published: If True, create published agent (requires source_id) + + Returns: + Tuple of (agent_id, api_key) if successful, None otherwise + """ + endpoint = f"{self.base_url}/api/create_agent" + + if published and source_id: + test_name = f"Create Published Agent with source {source_id[:8]}..." + elif published: + test_name = "Create Published Agent (skipped - no source)" + else: + test_name = "Create Draft Agent" + + self.print_header(f"Testing {test_name}") + + if not self.token: + self.print_warning("No authentication token provided") + self.print_info("Agent creation requires authentication") + self.print_info("To test agents, provide a JWT token with --token argument") + self.print_info("Skipping agent tests...") + # Mark as skipped rather than attempting without auth + self.test_results.append((test_name, True, "Skipped (auth required)")) + return None + + # Published agents require a source + if published and not source_id: + self.print_warning("Cannot create published agent without source") + self.test_results.append((test_name, True, "Skipped (no source)")) + return None + + # Create payload based on type + if published: + self.print_info(f"Creating published agent with source {source_id[:8]}...") + payload = { + "name": f"Test Agent (Published) {int(time.time())}", + "description": "Integration test agent with source", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "classic", + "status": "published", + "source": source_id, + } + else: + self.print_info("Creating draft agent (for agent_id testing)") + payload = { + "name": f"Test Agent (Draft) {int(time.time())}", + "description": "Integration test draft agent", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "classic", + "status": "draft", + } + + try: + self.print_info(f"POST {endpoint}") + self.print_info(f"Payload: {json.dumps(payload, indent=2)}") + + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + timeout=10 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code in [200, 201]: # Accept both 200 OK and 201 Created + result = response.json() + agent_id = result.get('id') + api_key = result.get('key', '') + + if agent_id: + self.agent_id = agent_id + self.print_success(f"Agent created with ID: {agent_id}") + if api_key: + self.print_success(f"Agent API key: {api_key[:20]}...") + self.test_results.append((test_name, True, f"ID: {agent_id}, API Key: Yes")) + return (agent_id, api_key) + else: + self.print_warning("Agent created but no API key (draft agent)") + self.test_results.append((test_name, True, f"ID: {agent_id}, API Key: No")) + return (agent_id, None) + else: + self.print_warning("Agent created but no ID returned") + self.test_results.append((test_name, False, "No ID returned")) + return None + elif response.status_code == 401: + self.print_warning("Authentication required for agent creation") + self.print_info("To test agents, provide a JWT token with --token argument") + self.print_info("Skipping agent tests...") + # Mark as "skipped" rather than "failed" + self.test_results.append((test_name, True, "Skipped (auth required)")) + return None + else: + self.print_error(f"Expected 200/201, got {response.status_code}") + try: + error_data = response.json() + self.print_error(f"Error: {error_data.get('message', response.text[:200])}") + except Exception: + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return None + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + + def test_api_key_endpoint(self, api_key: str, endpoint_type: str = "stream") -> bool: + """Test endpoint with API key instead of agent_id""" + test_name = f"{endpoint_type.capitalize()} endpoint with API key" + + self.print_header(f"Testing {test_name}") + + if endpoint_type == "stream": + endpoint = f"{self.base_url}/stream" + else: + endpoint = f"{self.base_url}/api/answer" + + payload = { + "question": "What is DocsGPT?", + "history": "[]", + "api_key": api_key, # Use api_key instead of agent_id + } + + try: + self.print_info(f"POST {endpoint}") + self.print_info(f"Using API key: {api_key[:20]}...") + + if endpoint_type == "stream": + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + stream=True, + timeout=30 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return False + + # Parse SSE stream + events = [] + full_response = "" + + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + data_str = line[6:] + try: + data = json.loads(data_str) + events.append(data) + + if data.get('type') in ['stream', 'answer']: + full_response += data.get('message', '') or data.get('answer', '') + elif data.get('type') == 'end': + break + except json.JSONDecodeError: + pass + + self.print_success(f"Received {len(events)} events") + self.print_info(f"Response preview: {full_response[:100]}...") + self.test_results.append((test_name, True, "Success")) + return True + + else: # answer endpoint + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + timeout=30 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code != 200: + self.print_error(f"Expected 200, got {response.status_code}") + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return False + + result = response.json() + answer = result.get('answer', '') + self.print_success(f"Answer received: {answer[:100]}...") + self.test_results.append((test_name, True, "Success")) + return True + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + def test_model_validation(self) -> bool: + """Test model_id validation""" + endpoint = f"{self.base_url}/stream" + test_name = "Model validation (invalid model_id)" + + self.print_header(f"Testing {test_name}") + + payload = { + "question": "Test question", + "history": "[]", + "isNoneDoc": True, + "model_id": "invalid-model-xyz-123", + } + + try: + self.print_info(f"POST {endpoint}") + self.print_info("Testing with invalid model_id: invalid-model-xyz-123") + + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + stream=True, + timeout=10 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code == 400: + # Read the error from SSE stream + error_message = None + error_field = None + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + data_str = line[6:] + try: + data = json.loads(data_str) + if data.get('type') == 'error': + # Try both 'message' and 'error' fields + error_message = data.get('message') or data.get('error', '') + error_field = 'message' if 'message' in data else 'error' + break + except json.JSONDecodeError: + pass + + # Consider it successful if we got a 400 with any error message + if error_message: + self.print_success("Invalid model_id rejected with 400 status") + self.print_info(f"Error ({error_field}): {error_message[:200]}") + + # Check if it's the detailed validation error or generic error + if 'Invalid model_id' in error_message or 'model' in error_message.lower(): + self.print_success("✓ Validation error contains model information") + self.test_results.append((test_name, True, "Validation works")) + else: + self.print_warning("Generic error message (validation may need improvement)") + self.test_results.append((test_name, True, "Generic validation")) + return True + else: + self.print_warning("No error message in response") + self.test_results.append((test_name, False, "No error message")) + return False + else: + self.print_warning(f"Expected 400, got {response.status_code}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return False + + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + def create_web_scraping_agent(self) -> Optional[tuple]: + """Create an agent with read_webpage tool enabled + + Returns: + Tuple of (agent_id, api_key) if successful, None otherwise + """ + endpoint = f"{self.base_url}/api/create_agent" + test_name = "Create Web Scraping Agent" + + self.print_header(f"Testing {test_name}") + + if not self.token: + self.print_warning("No authentication token provided") + self.test_results.append((test_name, True, "Skipped (auth required)")) + return None + + # Create agent with read_webpage tool + payload = { + "name": f"Web Scraping Agent {int(time.time())}", + "description": "Test agent with read_webpage tool for compression testing", + "prompt_id": "default", + "chunks": 2, + "retriever": "classic", + "agent_type": "react", # ReAct agent supports tools + "status": "draft", + "tools": ["read_webpage"], # Enable read_webpage tool + } + + try: + self.print_info(f"POST {endpoint}") + self.print_info("Creating agent with read_webpage tool...") + + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + timeout=10 + ) + + self.print_info(f"Status Code: {response.status_code}") + + if response.status_code in [200, 201]: + result = response.json() + agent_id = result.get('id') + api_key = result.get('key', '') + + if agent_id: + self.print_success(f"Web scraping agent created with ID: {agent_id}") + if api_key: + self.print_success(f"Agent API key: {api_key[:20]}...") + self.test_results.append((test_name, True, f"ID: {agent_id}, API Key: Yes")) + return (agent_id, api_key) + else: + self.print_warning("Agent created but no API key (draft agent)") + self.test_results.append((test_name, True, f"ID: {agent_id}, API Key: No")) + return (agent_id, None) + else: + self.print_warning("Agent created but no ID returned") + self.test_results.append((test_name, False, "No ID returned")) + return None + else: + self.print_error(f"Expected 200/201, got {response.status_code}") + try: + error_data = response.json() + self.print_error(f"Error: {error_data.get('message', response.text[:200])}") + except Exception: + self.print_error(f"Response: {response.text[:500]}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + return None + + except requests.exceptions.RequestException as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + except Exception as e: + self.print_error(f"Unexpected error: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return None + + def test_compression_heavy_tool_usage(self, agent_result: Optional[tuple] = None) -> bool: + """Test compression with heavy tool usage (real API calls) + + This simulates a scenario where an agent makes many tool calls + (including read_webpage for web scraping), generating large responses + that should trigger compression. + + Args: + agent_result: Optional tuple of (agent_id, api_key) from agent creation + """ + endpoint = f"{self.base_url}/api/answer" + test_name = "Compression - Heavy Tool Usage" + + self.print_header(f"Testing {test_name}") + + if not self.token: + self.print_warning("Authentication required for compression tests") + self.test_results.append((test_name, True, "Skipped (auth required)")) + return False + + # Use provided agent or create one + if not agent_result: + self.print_info("No web scraping agent provided, creating one...") + agent_result = self.create_web_scraping_agent() + + if not agent_result: + self.print_warning("Could not create web scraping agent, using isNoneDoc instead") + agent_id = None + api_key = None + else: + agent_id, api_key = agent_result + + # Define URLs to scrape for testing + urls_to_scrape = [ + "https://docs.docsgpt.cloud/", + "https://docs.docsgpt.cloud/getting-started/quickstart", + "https://docs.docsgpt.cloud/getting-started/installation", + "https://docs.docsgpt.cloud/extensions/extensions-intro", + "https://github.com/arc53/DocsGPT", + ] + + # Make requests with tool usage + self.print_info("Making 10 consecutive requests to build up conversation history...") + self.print_info("Some requests will use read_webpage tool for web scraping...") + + current_conv_id = None + + for i in range(10): + # Alternate between regular questions and web scraping + if i < 5 and agent_id: + # Use web scraping for first 5 requests + url = urls_to_scrape[i % len(urls_to_scrape)] + question = f"Please read and summarize the content from this webpage: {url}" + else: + # Use regular questions for remaining requests + question = f"Tell me about Python topic number {i+1}: data structures, decorators, async, testing, etc. Please provide a comprehensive explanation." + + payload = { + "question": question, + "history": "[]", + "model_id": "gemini-2.5-pro", + } + + # Use agent if available, otherwise isNoneDoc + if agent_id: + payload["agent_id"] = agent_id + elif api_key: + payload["api_key"] = api_key + else: + payload["isNoneDoc"] = True + + if current_conv_id: + payload["conversation_id"] = current_conv_id + + try: + response = requests.post( + endpoint, + json=payload, + headers=self.headers, + timeout=90 # Longer timeout for web scraping + ) + + if response.status_code == 200: + result = response.json() + current_conv_id = result.get('conversation_id', current_conv_id) + answer_preview = result.get('answer', '')[:80] + self.print_success(f"Request {i+1}/10 completed (conv_id: {current_conv_id})") + self.print_info(f" Answer preview: {answer_preview}...") + else: + self.print_error(f"Request {i+1}/10 failed with status {response.status_code}") + self.test_results.append((test_name, False, f"Request {i+1} failed")) + return False + + time.sleep(2) # Small delay between requests + + except Exception as e: + self.print_error(f"Request {i+1}/10 failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + # Check if conversation was compressed by examining metadata + if current_conv_id: + self.print_info(f"Checking compression status for conversation {current_conv_id}") + # Note: This would require a /api/conversation/{id} endpoint to verify + self.print_success("Heavy tool usage test completed") + tool_info = "with read_webpage" if agent_id else "without tools" + self.test_results.append((test_name, True, f"10 requests {tool_info}, conv_id: {current_conv_id}")) + return True + else: + self.print_warning("No conversation_id received") + self.test_results.append((test_name, False, "No conversation_id")) + return False + + def test_compression_needle_in_haystack(self) -> bool: + """Test that compression preserves critical information + + This sends a long conversation with important info in the middle, + then asks about that info to verify it was preserved through compression. + """ + endpoint = f"{self.base_url}/api/answer" + test_name = "Compression - Needle in Haystack" + + self.print_header(f"Testing {test_name}") + + if not self.token: + self.print_warning("Authentication required for compression tests") + self.test_results.append((test_name, True, "Skipped (auth required)")) + return False + + conversation_id = None + + # Step 1: Send general questions + self.print_info("Step 1: Sending general questions...") + for i, question in enumerate([ + "Tell me about Python best practices in detail", + "Explain Python data structures comprehensively", + ]): + payload = { + "question": question, + "history": "[]", + "isNoneDoc": True, + "model_id": "gemini-2.5-pro", + } + + if conversation_id: + payload["conversation_id"] = conversation_id + + try: + response = requests.post(endpoint, json=payload, headers=self.headers, timeout=60) + if response.status_code == 200: + result = response.json() + conversation_id = result.get('conversation_id', conversation_id) + self.print_success(f"General question {i+1}/2 completed") + else: + self.print_error(f"Request failed with status {response.status_code}") + self.test_results.append((test_name, False, "General questions failed")) + return False + time.sleep(2) + except Exception as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + # Step 2: Send CRITICAL information + self.print_info("Step 2: Sending CRITICAL information to remember...") + critical_payload = { + "question": "Please remember this critical information: The production database password is stored in DB_PASSWORD_PROD environment variable. The backup runs at 3:00 AM UTC daily. Premium users have 10,000 req/hour limit.", + "history": "[]", + "isNoneDoc": True, + "model_id": "gemini-2.5-pro", + "conversation_id": conversation_id, + } + + try: + response = requests.post(endpoint, json=critical_payload, headers=self.headers, timeout=60) + if response.status_code == 200: + result = response.json() + conversation_id = result.get('conversation_id', conversation_id) + self.print_success("Critical information sent") + else: + self.print_error("Critical info request failed") + self.test_results.append((test_name, False, "Critical info failed")) + return False + time.sleep(2) + except Exception as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + # Step 3: Send more general questions to bury the critical info + self.print_info("Step 3: Sending more questions to bury the critical info...") + for i, question in enumerate([ + "Explain Python decorators in great detail", + "Tell me about Python async programming comprehensively", + ]): + payload = { + "question": question, + "history": "[]", + "isNoneDoc": True, + "model_id": "gemini-2.5-pro", + "conversation_id": conversation_id, + } + + try: + response = requests.post(endpoint, json=payload, headers=self.headers, timeout=60) + if response.status_code == 200: + result = response.json() + conversation_id = result.get('conversation_id', conversation_id) + self.print_success(f"Burying question {i+1}/2 completed") + else: + self.print_error("Request failed") + self.test_results.append((test_name, False, "Burying questions failed")) + return False + time.sleep(2) + except Exception as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + # Step 4: Ask about the critical information + self.print_info("Step 4: Testing if critical info was preserved...") + recall_payload = { + "question": "What was the database password environment variable I mentioned earlier?", + "history": "[]", + "isNoneDoc": True, + "model_id": "gemini-2.5-pro", + "conversation_id": conversation_id, + } + + try: + response = requests.post(endpoint, json=recall_payload, headers=self.headers, timeout=60) + if response.status_code == 200: + result = response.json() + answer = result.get('answer', '').lower() + + # Check if the critical info was preserved + if 'db_password_prod' in answer or 'database password' in answer: + self.print_success("✓ Critical information preserved through compression!") + self.print_info(f"Answer: {answer[:150]}...") + self.test_results.append((test_name, True, "Info preserved")) + return True + else: + self.print_warning("Critical information may have been lost") + self.print_info(f"Answer: {answer[:150]}...") + self.test_results.append((test_name, False, "Info not preserved")) + return False + else: + self.print_error("Recall request failed") + self.test_results.append((test_name, False, "Recall failed")) + return False + except Exception as e: + self.print_error(f"Request failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + return False + + def print_summary(self): + """Print test results summary""" + self.print_header("Test Results Summary") + + passed = sum(1 for _, success, _ in self.test_results if success) + failed = len(self.test_results) - passed + + print(f"\n{Colors.BOLD}Total Tests: {len(self.test_results)}{Colors.ENDC}") + print(f"{Colors.OKGREEN}Passed: {passed}{Colors.ENDC}") + print(f"{Colors.FAIL}Failed: {failed}{Colors.ENDC}\n") + + print(f"{Colors.BOLD}Detailed Results:{Colors.ENDC}") + for test_name, success, message in self.test_results: + status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if success else f"{Colors.FAIL}FAIL{Colors.ENDC}" + print(f" {status} - {test_name}: {message}") + + print() + return failed == 0 + + def run_all_tests(self): + """Run all integration tests""" + self.print_header("DocsGPT Integration Tests") + self.print_info(f"Base URL: {self.base_url}") + if self.token: + self.print_info(f"Authentication: Yes ({self.token_source})") + else: + self.print_info("Authentication: No (agent-related tests will be skipped)") + + # Test 1: Stream endpoint without agent + self.test_stream_endpoint() + time.sleep(1) + + # Test 2: Answer endpoint without agent + self.test_answer_endpoint() + time.sleep(1) + + # Test 3: Model validation + self.test_model_validation() + time.sleep(1) + + # Test 4: Compression tests (requires token) + if self.token: + self.print_info("Running compression integration tests...") + time.sleep(1) + + # Test 4a: Heavy tool usage compression + self.test_compression_heavy_tool_usage() + time.sleep(2) + + # Test 4b: Needle in haystack compression + self.test_compression_needle_in_haystack() + time.sleep(1) + else: + self.print_info("Skipping compression tests (no authentication)") + + # Test 5: Upload text source (requires token) - faster than crawler + task_id = self.upload_text_source() + source_id = None + + if task_id: + # Test 6: Get source ID from completed task + source_id = self.get_source_id_from_task(task_id) + if source_id: + self.print_success(f"Source created with ID: {source_id}") + else: + self.print_warning("Could not retrieve source ID from task - trying crawler fallback") + # Fallback to crawler if text upload failed + crawler_task_id = self.upload_crawler_source() + if crawler_task_id: + source_id = self.get_source_id_from_task(crawler_task_id) + if source_id: + self.print_success(f"Source created with ID (crawler): {source_id}") + else: + self.print_warning("Could not retrieve source ID from crawler task either") + + # Test 7: Create published agent (for API key testing) - default behavior + # Published agents get an API key automatically + published_result = self.create_agent(source_id=source_id, published=True) + + if published_result: + agent_id, api_key = published_result + time.sleep(1) + + if api_key: + # Test 8 & 9: Test with API key (primary method) + self.test_api_key_endpoint(api_key, endpoint_type="stream") + time.sleep(1) + self.test_api_key_endpoint(api_key, endpoint_type="answer") + time.sleep(1) + + # Test 10: Also test with agent_id for completeness + self.test_stream_endpoint(agent_id=agent_id) + time.sleep(1) + self.test_answer_endpoint(agent_id=agent_id) + + # Test 11: If agent has a source, test source-specific questions + if source_id: + time.sleep(1) + self.print_info("Testing published agent with source-specific questions...") + + test_name = "Published agent with source (DocsGPT question)" + self.print_header(f"Testing {test_name}") + + payload = { + "question": "How do I install DocsGPT?", + "history": "[]", + "api_key": api_key, + } + + try: + response = requests.post( + f"{self.base_url}/api/answer", + json=payload, + headers=self.headers, + timeout=30 + ) + + if response.status_code == 200: + result = response.json() + answer = result.get('answer', '') + self.print_success(f"Answer received: {answer[:100]}...") + + if any(word in answer.lower() for word in ['install', 'docker', 'setup']): + self.print_success("Answer contains relevant information from source") + self.test_results.append((test_name, True, "Success")) + else: + self.print_warning("Answer may not be using source data") + self.test_results.append((test_name, True, "Answer unclear")) + else: + self.print_error(f"Status {response.status_code}") + self.test_results.append((test_name, False, f"Status {response.status_code}")) + + except Exception as e: + self.print_error(f"Test failed: {str(e)}") + self.test_results.append((test_name, False, str(e))) + else: + self.print_warning("Published agent created but no API key received") + self.print_info("Testing with agent_id instead...") + # Fallback to agent_id testing + self.test_stream_endpoint(agent_id=agent_id) + time.sleep(1) + self.test_answer_endpoint(agent_id=agent_id) + else: + if self.token: + self.print_warning("Published agent creation failed - some tests skipped") + else: + self.print_info("Skipping agent tests (no authentication token)") + + # Print summary + success = self.print_summary() + return 0 if success else 1 + + +def main(): + parser = argparse.ArgumentParser( + description='Integration test script for DocsGPT API endpoints', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Test local instance + python tests/test_integration.py # auto-generates JWT token from local secret if possible + + # Test remote instance + python tests/test_integration.py --base-url https://app.docsgpt.com + + # Test with authentication (required for agent creation) + python tests/test_integration.py --token YOUR_JWT_TOKEN + + # Test specific endpoint only + python tests/test_integration.py --base-url http://localhost:7091 --token YOUR_TOKEN + """ + ) + + parser.add_argument( + '--base-url', + default='http://localhost:7091', + help='Base URL of DocsGPT instance (default: http://localhost:7091)' + ) + + parser.add_argument( + '--token', + help='JWT authentication token (auto-generated from local secret when available)' + ) + + args = parser.parse_args() + + token = args.token + token_source = "provided via --token" if token else "auto-generated from local JWT secret" + + if not token: + token, token_error = generate_default_token() + if token: + print(f"{Colors.OKCYAN}ℹ Using auto-generated JWT token from local secret{Colors.ENDC}") + else: + token_source = "none" + if token_error: + print(f"{Colors.WARNING}⚠ Could not auto-generate JWT token: {token_error}{Colors.ENDC}") + print(f"{Colors.WARNING}⚠ Agent creation tests will be skipped unless you provide --token{Colors.ENDC}") + + try: + tester = DocsGPTTester(args.base_url, token, token_source=token_source) + exit_code = tester.run_all_tests() + sys.exit(exit_code) + except KeyboardInterrupt: + print(f"\n{Colors.WARNING}Tests interrupted by user{Colors.ENDC}") + sys.exit(1) + except Exception as e: + print(f"\n{Colors.FAIL}Fatal error: {str(e)}{Colors.ENDC}") + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/tests/test_model_validation.py b/tests/test_model_validation.py new file mode 100644 index 00000000..379ccbf6 --- /dev/null +++ b/tests/test_model_validation.py @@ -0,0 +1,106 @@ +""" +Tests for model validation and base_url functionality +""" +import pytest +from application.core.model_settings import ( + AvailableModel, + ModelCapabilities, + ModelProvider, + ModelRegistry, +) +from application.core.model_utils import ( + get_base_url_for_model, + validate_model_id, +) + + +@pytest.mark.unit +def test_model_with_base_url(): + """Test that AvailableModel can store and retrieve base_url""" + model = AvailableModel( + id="test-model", + provider=ModelProvider.OPENAI, + display_name="Test Model", + description="Test model with custom base URL", + base_url="https://custom-endpoint.com/v1", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=8192, + ), + ) + + assert model.base_url == "https://custom-endpoint.com/v1" + assert model.id == "test-model" + assert model.provider == ModelProvider.OPENAI + + # Test to_dict includes base_url + model_dict = model.to_dict() + assert "base_url" in model_dict + assert model_dict["base_url"] == "https://custom-endpoint.com/v1" + + +@pytest.mark.unit +def test_model_without_base_url(): + """Test that models without base_url still work""" + model = AvailableModel( + id="test-model-no-url", + provider=ModelProvider.OPENAI, + display_name="Test Model", + description="Test model without base URL", + capabilities=ModelCapabilities( + supports_tools=True, + context_window=8192, + ), + ) + + assert model.base_url is None + + # Test to_dict doesn't include base_url when None + model_dict = model.to_dict() + assert "base_url" not in model_dict + + +@pytest.mark.unit +def test_validate_model_id(): + """Test model_id validation""" + # Get the registry instance to check what models are available + ModelRegistry.get_instance() + + # Test with a model that should exist (docsgpt-local is always added) + assert validate_model_id("docsgpt-local") is True + + # Test with invalid model_id + assert validate_model_id("invalid-model-xyz-123") is False + + # Test with None + assert validate_model_id(None) is False + + +@pytest.mark.unit +def test_get_base_url_for_model(): + """Test retrieving base_url for a model""" + # Test with a model that doesn't have base_url + result = get_base_url_for_model("docsgpt-local") + assert result is None # docsgpt-local doesn't have custom base_url + + # Test with invalid model + result = get_base_url_for_model("invalid-model") + assert result is None + + +@pytest.mark.unit +def test_model_validation_error_message(): + """Test that validation provides helpful error messages""" + from application.api.answer.services.stream_processor import StreamProcessor + + # Create processor with invalid model_id + data = {"model_id": "invalid-model-xyz"} + processor = StreamProcessor(data, None) + + # Should raise ValueError with helpful message + with pytest.raises(ValueError) as exc_info: + processor._validate_and_set_model() + + error_msg = str(exc_info.value) + assert "Invalid model_id 'invalid-model-xyz'" in error_msg + assert "Available models:" in error_msg diff --git a/tests/test_token_management.py b/tests/test_token_management.py new file mode 100644 index 00000000..0b166953 --- /dev/null +++ b/tests/test_token_management.py @@ -0,0 +1,314 @@ +""" +Tests for token management and compression features. + +NOTE: These tests are for future planned features that are not yet implemented. +They are skipped until the following modules are created: +- application.compression (DocumentCompressor, HistoryCompressor, etc.) +- application.core.token_budget (TokenBudgetManager) +""" +# ruff: noqa: F821 +import pytest + +pytest.skip( + "Token management features not yet implemented - planned for future release", + allow_module_level=True, +) + + +class TestTokenBudgetManager: + """Test TokenBudgetManager functionality""" + + def test_calculate_budget(self): + """Test budget calculation""" + manager = TokenBudgetManager(model_id="gpt-4o") + budget = manager.calculate_budget() + + assert budget.total_budget > 0 + assert budget.system_prompt > 0 + assert budget.chat_history > 0 + assert budget.retrieved_docs > 0 + + def test_measure_usage(self): + """Test token usage measurement""" + manager = TokenBudgetManager(model_id="gpt-4o") + + usage = manager.measure_usage( + system_prompt="You are a helpful assistant.", + current_query="What is Python?", + chat_history=[ + {"prompt": "Hello", "response": "Hi there!"}, + {"prompt": "How are you?", "response": "I'm doing well, thanks!"}, + ], + ) + + assert usage.total > 0 + assert usage.system_prompt > 0 + assert usage.current_query > 0 + assert usage.chat_history > 0 + + def test_compression_recommendation(self): + """Test compression recommendation generation""" + manager = TokenBudgetManager(model_id="gpt-4o") + + # Create scenario with excessive history + large_history = [ + {"prompt": f"Question {i}" * 100, "response": f"Answer {i}" * 100} + for i in range(100) + ] + + budget, usage, recommendation = manager.check_and_recommend( + system_prompt="You are a helpful assistant.", + current_query="What is Python?", + chat_history=large_history, + ) + + # Should recommend compression + assert recommendation.needs_compression() + assert recommendation.compress_history + + +class TestHistoryCompressor: + """Test HistoryCompressor functionality""" + + def test_sliding_window_compression(self): + """Test sliding window compression strategy""" + compressor = HistoryCompressor() + + history = [ + {"prompt": f"Question {i}", "response": f"Answer {i}"} for i in range(20) + ] + + compressed, metadata = compressor.compress( + history, target_tokens=500, strategy="sliding_window" + ) + + assert len(compressed) < len(history) + assert metadata["original_messages"] == 20 + assert metadata["compressed_messages"] < 20 + assert metadata["strategy"] == "sliding_window" + + def test_preserve_tool_calls(self): + """Test that tool calls are preserved during compression""" + compressor = HistoryCompressor() + + history = [ + {"prompt": "Question 1", "response": "Answer 1"}, + { + "prompt": "Use a tool", + "response": "Tool used", + "tool_calls": [{"tool_name": "search", "result": "Found something"}], + }, + {"prompt": "Question 3", "response": "Answer 3"}, + ] + + compressed, metadata = compressor.compress( + history, target_tokens=200, strategy="sliding_window", preserve_tool_calls=True + ) + + # Tool call message should be preserved + has_tool_calls = any("tool_calls" in msg for msg in compressed) + assert has_tool_calls + + +class TestDocumentCompressor: + """Test DocumentCompressor functionality""" + + def test_rerank_compression(self): + """Test re-ranking compression strategy""" + compressor = DocumentCompressor() + + docs = [ + {"text": f"Document {i} with some content here" * 20, "title": f"Doc {i}"} + for i in range(10) + ] + + compressed, metadata = compressor.compress( + docs, target_tokens=500, query="Document 5", strategy="rerank" + ) + + assert len(compressed) < len(docs) + assert metadata["original_docs"] == 10 + assert metadata["strategy"] == "rerank" + + def test_excerpt_extraction(self): + """Test excerpt extraction strategy""" + compressor = DocumentCompressor() + + docs = [ + { + "text": "This is a long document. " * 100 + + "Python is great. " + + "More text here. " * 100, + "title": "Python Guide", + } + ] + + compressed, metadata = compressor.compress( + docs, target_tokens=300, query="Python", strategy="excerpt" + ) + + assert metadata["excerpts_created"] > 0 + # Excerpt should contain the query term + assert "python" in compressed[0]["text"].lower() + + +class TestToolResultCompressor: + """Test ToolResultCompressor functionality""" + + def test_truncate_large_results(self): + """Test truncation of large tool results""" + compressor = ToolResultCompressor() + + tool_results = [ + { + "tool_name": "search", + "result": "Very long result " * 1000, + "arguments": {}, + } + ] + + compressed, metadata = compressor.compress( + tool_results, target_tokens=100, strategy="truncate" + ) + + assert metadata["results_truncated"] > 0 + # Result should be shorter + compressed_result_len = len(str(compressed[0]["result"])) + original_result_len = len(tool_results[0]["result"]) + assert compressed_result_len < original_result_len + + def test_extract_json_fields(self): + """Test extraction of key fields from JSON results""" + compressor = ToolResultCompressor() + + tool_results = [ + { + "tool_name": "api_call", + "result": { + "data": {"important": "value"}, + "metadata": {"verbose": "information" * 100}, + "debug": {"lots": "of data" * 100}, + }, + "arguments": {}, + } + ] + + compressed, metadata = compressor.compress( + tool_results, target_tokens=100, strategy="extract" + ) + + # Should keep important fields, discard verbose ones + assert "data" in compressed[0]["result"] + + +class TestPromptOptimizer: + """Test PromptOptimizer functionality""" + + def test_compress_tool_descriptions(self): + """Test compression of tool descriptions""" + optimizer = PromptOptimizer() + + tools = [ + { + "type": "function", + "function": { + "name": f"tool_{i}", + "description": "This is a very long description " * 50, + "parameters": {}, + }, + } + for i in range(10) + ] + + optimized, metadata = optimizer.optimize_tools( + tools, target_tokens=500, strategy="compress" + ) + + assert metadata["optimized_tokens"] < metadata["original_tokens"] + assert metadata["descriptions_compressed"] > 0 + + def test_lazy_load_tools(self): + """Test lazy loading of tools based on query""" + optimizer = PromptOptimizer() + + tools = [ + { + "type": "function", + "function": { + "name": "search_tool", + "description": "Search for information", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "calculate_tool", + "description": "Perform calculations", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "other_tool", + "description": "Do something else", + "parameters": {}, + }, + }, + ] + + optimized, metadata = optimizer.optimize_tools( + tools, target_tokens=200, query="I want to search for something", strategy="lazy_load" + ) + + # Should prefer search tool + assert len(optimized) < len(tools) + tool_names = [t["function"]["name"] for t in optimized] + # Search tool should be included due to query relevance + assert any("search" in name for name in tool_names) + + +def test_integration_compression_workflow(): + """Test complete compression workflow""" + # Simulate a scenario with large inputs + manager = TokenBudgetManager(model_id="gpt-4o") + history_compressor = HistoryCompressor() + doc_compressor = DocumentCompressor() + + # Large chat history + history = [ + {"prompt": f"Question {i}" * 50, "response": f"Answer {i}" * 50} + for i in range(50) + ] + + # Large documents + docs = [ + {"text": f"Document {i} content" * 100, "title": f"Doc {i}"} for i in range(20) + ] + + # Check budget + budget, usage, recommendation = manager.check_and_recommend( + system_prompt="You are a helpful assistant.", + current_query="What is Python?", + chat_history=history, + retrieved_docs=docs, + ) + + # Should need compression + assert recommendation.needs_compression() + + # Apply compression + if recommendation.compress_history: + compressed_history, hist_meta = history_compressor.compress( + history, recommendation.target_history_tokens or budget.chat_history + ) + assert len(compressed_history) < len(history) + + if recommendation.compress_docs: + compressed_docs, doc_meta = doc_compressor.compress( + docs, + recommendation.target_docs_tokens or budget.retrieved_docs, + query="Python", + ) + assert len(compressed_docs) < len(docs)