Mirror of https://github.com/arc53/DocsGPT.git (synced 2025-12-02 10:03:15 +00:00)

Commit: feat: context compression
@@ -91,7 +91,7 @@ def test_clean_messages_google_basic():
             {"function_call": {"name": "fn", "args": {"a": 1}}},
         ]},
     ]
-    cleaned = llm._clean_messages_google(msgs)
+    cleaned, system_instruction = llm._clean_messages_google(msgs)

     assert all(hasattr(c, "role") and hasattr(c, "parts") for c in cleaned)
     assert any(c.role == "model" for c in cleaned)
325  tests/test_agent_token_tracking.py  (new file)
@@ -0,0 +1,325 @@
import pytest
from unittest.mock import Mock, MagicMock, patch

from application.agents.base import BaseAgent
from application.llm.handlers.base import LLMHandler, ToolCall


class MockAgent(BaseAgent):
    """Mock agent for testing"""

    def _gen_inner(self, query, log_context=None):
        yield {"answer": "test"}


@pytest.fixture
def mock_agent():
    """Create a mock agent for testing"""
    agent = MockAgent(
        endpoint="test",
        llm_name="openai",
        model_id="gpt-4o",
        api_key="test-key",
    )
    agent.llm = Mock()
    return agent


@pytest.fixture
def mock_llm_handler():
    """Create a mock LLM handler"""
    handler = Mock(spec=LLMHandler)
    handler.tool_calls = []
    return handler


class TestAgentTokenTracking:
    """Test suite for agent token tracking during execution"""

    def test_calculate_current_context_tokens(self, mock_agent):
        """Test token calculation for current context"""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing well, thank you!"},
        ]

        tokens = mock_agent._calculate_current_context_tokens(messages)

        # Should count tokens from all messages
        assert tokens > 0
        # Rough estimate: ~20-40 tokens for this conversation
        assert 15 < tokens < 60

    def test_calculate_tokens_with_tool_calls(self, mock_agent):
        """Test token calculation includes tool call content"""
        messages = [
            {"role": "system", "content": "Test"},
            {
                "role": "assistant",
                "content": [
                    {
                        "function_call": {
                            "name": "search_tool",
                            "args": {"query": "test"},
                            "call_id": "123",
                        }
                    }
                ],
            },
            {
                "role": "tool",
                "content": [
                    {
                        "function_response": {
                            "name": "search_tool",
                            "response": {"result": "Found 10 results"},
                            "call_id": "123",
                        }
                    }
                ],
            },
        ]

        tokens = mock_agent._calculate_current_context_tokens(messages)

        # Should include tool call tokens
        assert tokens > 0

    @patch("application.core.model_utils.get_token_limit")
    @patch("application.core.settings.settings")
    def test_check_context_limit_below_threshold(
        self, mock_settings, mock_get_token_limit, mock_agent
    ):
        """Test context limit check when below threshold"""
        mock_get_token_limit.return_value = 128000
        mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8

        messages = [
            {"role": "system", "content": "Short message"},
            {"role": "user", "content": "Hello"},
        ]

        # Should return False for small conversation
        result = mock_agent._check_context_limit(messages)
        assert result is False

        # Should track current token count
        assert mock_agent.current_token_count > 0
        assert mock_agent.current_token_count < 128000 * 0.8

    @patch("application.core.model_utils.get_token_limit")
    @patch("application.core.settings.settings")
    def test_check_context_limit_above_threshold(
        self, mock_settings, mock_get_token_limit, mock_agent
    ):
        """Test context limit check when above threshold"""
        mock_get_token_limit.return_value = 100  # Very small limit for testing
        mock_settings.COMPRESSION_THRESHOLD_PERCENTAGE = 0.8

        # Create messages that will exceed 80 tokens (80% of 100)
        messages = [
            {"role": "system", "content": "a " * 50},  # ~50 tokens
            {"role": "user", "content": "b " * 50},  # ~50 tokens
        ]

        # Should return True when exceeding threshold
        result = mock_agent._check_context_limit(messages)
        assert result is True

    @patch("application.agents.base.logger")
    def test_check_context_limit_error_handling(self, mock_logger, mock_agent):
        """Test error handling in context limit check"""
        # Force an error by making get_token_limit fail
        with patch(
            "application.core.model_utils.get_token_limit",
            side_effect=Exception("Test error"),
        ):
            messages = [{"role": "user", "content": "test"}]

            result = mock_agent._check_context_limit(messages)

            # Should return False on error (safe default)
            assert result is False
            # Should log the error
            assert mock_logger.error.called

    def test_context_limit_flag_initialization(self, mock_agent):
        """Test that context limit flag is initialized"""
        assert hasattr(mock_agent, "context_limit_reached")
        assert mock_agent.context_limit_reached is False

        assert hasattr(mock_agent, "current_token_count")
        assert mock_agent.current_token_count == 0
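Editor's note: between the two suites, a minimal sketch of the agent helpers
exercised above. The patch targets (application.core.model_utils.get_token_limit,
settings.COMPRESSION_THRESHOLD_PERCENTAGE) and the current_token_count /
context_limit_reached attributes come straight from the tests; the tokenizer
choice and everything else is assumed, not taken from the diff. Both are agent
methods, shown here standalone for brevity.

import logging

import tiktoken  # assumed; the real helper may count tokens differently

logger = logging.getLogger(__name__)

def _calculate_current_context_tokens(self, messages):
    """Sketch: count tokens across message text and tool-call payloads."""
    enc = tiktoken.get_encoding("cl100k_base")
    total = 0
    for msg in messages:
        content = msg.get("content", "")
        # Tool calls and responses arrive as lists of dicts; stringify them
        # so their tokens are counted too.
        text = content if isinstance(content, str) else str(content)
        total += len(enc.encode(text))
    return total

def _check_context_limit(self, messages):
    """Sketch: True once usage crosses the configured fraction of the limit."""
    try:
        from application.core.model_utils import get_token_limit
        from application.core.settings import settings

        limit = get_token_limit(self.model_id)
        threshold = limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
        self.current_token_count = self._calculate_current_context_tokens(messages)
        return self.current_token_count >= threshold
    except Exception as e:
        logger.error("Context limit check failed: %s", e)
        return False  # safe default: never block execution on a counting error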
class TestLLMHandlerTokenTracking:
    """Test suite for LLM handler token tracking"""

    @patch("application.llm.handlers.base.logger")
    def test_handle_tool_calls_stops_at_limit(self, mock_logger):
        """Test that tool execution stops when context limit is reached"""
        from application.llm.handlers.base import LLMHandler

        # Create a concrete handler for testing
        class TestHandler(LLMHandler):
            def parse_response(self, response):
                pass

            def create_tool_message(self, tool_call, result):
                return {"role": "tool", "content": str(result)}

            def _iterate_stream(self, response):
                yield ""

        handler = TestHandler()

        # Create mock agent that hits limit on second tool
        mock_agent = Mock()
        mock_agent.context_limit_reached = False

        call_count = [0]

        def check_limit_side_effect(messages):
            call_count[0] += 1
            # Return True on second call (second tool)
            return call_count[0] >= 2

        mock_agent._check_context_limit = Mock(side_effect=check_limit_side_effect)
        mock_agent._execute_tool_action = Mock(
            return_value=iter([{"type": "tool_call", "data": {}}])
        )

        # Create multiple tool calls
        tool_calls = [
            ToolCall(id="1", name="tool1", arguments={}),
            ToolCall(id="2", name="tool2", arguments={}),
            ToolCall(id="3", name="tool3", arguments={}),
        ]

        messages = []
        tools_dict = {}

        # Execute tool calls
        results = list(
            handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages)
        )

        # First tool should execute
        assert mock_agent._execute_tool_action.call_count == 1

        # Should have yielded skip messages for tools 2 and 3
        skip_messages = [
            r
            for r in results
            if r.get("type") == "tool_call"
            and r.get("data", {}).get("status") == "skipped"
        ]
        assert len(skip_messages) == 2

        # Should have set the flag
        assert mock_agent.context_limit_reached is True

        # Should have logged warning
        assert mock_logger.warning.called

    def test_handle_tool_calls_all_execute_when_no_limit(self):
        """Test that all tools execute when under limit"""
        from application.llm.handlers.base import LLMHandler

        class TestHandler(LLMHandler):
            def parse_response(self, response):
                pass

            def create_tool_message(self, tool_call, result):
                return {"role": "tool", "content": str(result)}

            def _iterate_stream(self, response):
                yield ""

        handler = TestHandler()

        # Create mock agent that never hits limit
        mock_agent = Mock()
        mock_agent.context_limit_reached = False
        mock_agent._check_context_limit = Mock(return_value=False)
        mock_agent._execute_tool_action = Mock(
            return_value=iter([{"type": "tool_call", "data": {}}])
        )

        tool_calls = [
            ToolCall(id="1", name="tool1", arguments={}),
            ToolCall(id="2", name="tool2", arguments={}),
            ToolCall(id="3", name="tool3", arguments={}),
        ]

        messages = []
        tools_dict = {}

        # Execute tool calls
        list(handler.handle_tool_calls(mock_agent, tool_calls, tools_dict, messages))

        # All 3 tools should execute
        assert mock_agent._execute_tool_action.call_count == 3

        # Should not have set the flag
        assert mock_agent.context_limit_reached is False

    @patch("application.llm.handlers.base.logger")
    def test_handle_streaming_adds_warning_message(self, mock_logger):
        """Test that streaming handler adds warning when limit reached"""
        from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall

        class TestHandler(LLMHandler):
            def parse_response(self, response):
                if isinstance(response, dict) and response.get("type") == "tool_call":
                    return LLMResponse(
                        content="",
                        tool_calls=[
                            ToolCall(id="1", name="test", arguments={}, index=0)
                        ],
                        finish_reason="tool_calls",
                        raw_response=None,
                    )
                else:
                    return LLMResponse(
                        content="Done",
                        tool_calls=[],
                        finish_reason="stop",
                        raw_response=None,
                    )

            def create_tool_message(self, tool_call, result):
                return {"role": "tool", "content": str(result)}

            def _iterate_stream(self, response):
                if response == "first":
                    yield {"type": "tool_call"}  # Object to be parsed, not string
                else:
                    yield {"type": "stop"}  # Object to be parsed, not string

        handler = TestHandler()

        # Create mock agent with limit reached
        mock_agent = Mock()
        mock_agent.context_limit_reached = True
        mock_agent.model_id = "gpt-4o"
        mock_agent.tools = []
        mock_agent.llm = Mock()
        mock_agent.llm.gen_stream = Mock(return_value="second")

        def tool_handler_gen(*args):
            yield {"type": "tool", "data": {}}
            return []

        # Mock handle_tool_calls to return messages and set flag
        with patch.object(
            handler, "handle_tool_calls", return_value=tool_handler_gen()
        ):
            messages = []
            tools_dict = {}

            # Execute streaming
            results = list(
                handler.handle_streaming(mock_agent, "first", tools_dict, messages)
            )

            # Should have called gen_stream with tools=None (disabled)
            mock_agent.llm.gen_stream.assert_called()
            call_kwargs = mock_agent.llm.gen_stream.call_args.kwargs
            assert call_kwargs.get("tools") is None

            # Should have logged the warning
            assert mock_logger.info.called


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
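Editor's note: the suite above pins down the contract for handle_tool_calls —
check the limit before each tool, execute while under it, and once the limit
trips, yield a "skipped" status for every remaining tool instead of running
it. A sketch of that guard loop follows (shape assumed, not the committed
implementation; the _execute_tool_action signature in particular is a guess).

import logging

logger = logging.getLogger(__name__)

def handle_tool_calls(self, agent, tool_calls, tools_dict, messages):
    """Sketch: execute tools in order, short-circuiting at the context limit."""
    for call in tool_calls:
        if agent.context_limit_reached or agent._check_context_limit(messages):
            agent.context_limit_reached = True
            logger.warning("Context limit reached; skipping tool %s", call.name)
            yield {
                "type": "tool_call",
                "data": {"tool_name": call.name, "status": "skipped"},
            }
            continue  # remaining tools are reported as skipped, never executed
        yield from agent._execute_tool_action(call, tools_dict, messages)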
1082  tests/test_compression_service.py  (new file; diff suppressed because it is too large)

1287  tests/test_integration.py  (new executable file; diff suppressed because it is too large)

106  tests/test_model_validation.py  (new file)
@@ -0,0 +1,106 @@
"""
Tests for model validation and base_url functionality
"""
import pytest
from application.core.model_settings import (
    AvailableModel,
    ModelCapabilities,
    ModelProvider,
    ModelRegistry,
)
from application.core.model_utils import (
    get_base_url_for_model,
    validate_model_id,
)


@pytest.mark.unit
def test_model_with_base_url():
    """Test that AvailableModel can store and retrieve base_url"""
    model = AvailableModel(
        id="test-model",
        provider=ModelProvider.OPENAI,
        display_name="Test Model",
        description="Test model with custom base URL",
        base_url="https://custom-endpoint.com/v1",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=8192,
        ),
    )

    assert model.base_url == "https://custom-endpoint.com/v1"
    assert model.id == "test-model"
    assert model.provider == ModelProvider.OPENAI

    # Test to_dict includes base_url
    model_dict = model.to_dict()
    assert "base_url" in model_dict
    assert model_dict["base_url"] == "https://custom-endpoint.com/v1"


@pytest.mark.unit
def test_model_without_base_url():
    """Test that models without base_url still work"""
    model = AvailableModel(
        id="test-model-no-url",
        provider=ModelProvider.OPENAI,
        display_name="Test Model",
        description="Test model without base URL",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=8192,
        ),
    )

    assert model.base_url is None

    # Test to_dict doesn't include base_url when None
    model_dict = model.to_dict()
    assert "base_url" not in model_dict


@pytest.mark.unit
def test_validate_model_id():
    """Test model_id validation"""
    # Get the registry instance to check what models are available
    registry = ModelRegistry.get_instance()

    # Test with a model that should exist (docsgpt-local is always added)
    assert validate_model_id("docsgpt-local") is True

    # Test with invalid model_id
    assert validate_model_id("invalid-model-xyz-123") is False

    # Test with None
    assert validate_model_id(None) is False


@pytest.mark.unit
def test_get_base_url_for_model():
    """Test retrieving base_url for a model"""
    # Test with a model that doesn't have base_url
    result = get_base_url_for_model("docsgpt-local")
    assert result is None  # docsgpt-local doesn't have custom base_url

    # Test with invalid model
    result = get_base_url_for_model("invalid-model")
    assert result is None


@pytest.mark.unit
def test_model_validation_error_message():
    """Test that validation provides helpful error messages"""
    from application.api.answer.services.stream_processor import StreamProcessor

    # Create processor with invalid model_id
    data = {"model_id": "invalid-model-xyz"}
    processor = StreamProcessor(data, None)

    # Should raise ValueError with helpful message
    with pytest.raises(ValueError) as exc_info:
        processor._validate_and_set_model()

    error_msg = str(exc_info.value)
    assert "Invalid model_id 'invalid-model-xyz'" in error_msg
    assert "Available models:" in error_msg
313  tests/test_token_management.py  (new file)
@@ -0,0 +1,313 @@
"""
Tests for token management and compression features.

NOTE: These tests are for future planned features that are not yet implemented.
They are skipped until the following modules are created:
- application.compression (DocumentCompressor, HistoryCompressor, etc.)
- application.core.token_budget (TokenBudgetManager)
"""
import pytest

pytest.skip(
    "Token management features not yet implemented - planned for future release",
    allow_module_level=True,
)


class TestTokenBudgetManager:
    """Test TokenBudgetManager functionality"""

    def test_calculate_budget(self):
        """Test budget calculation"""
        manager = TokenBudgetManager(model_id="gpt-4o")
        budget = manager.calculate_budget()

        assert budget.total_budget > 0
        assert budget.system_prompt > 0
        assert budget.chat_history > 0
        assert budget.retrieved_docs > 0

    def test_measure_usage(self):
        """Test token usage measurement"""
        manager = TokenBudgetManager(model_id="gpt-4o")

        usage = manager.measure_usage(
            system_prompt="You are a helpful assistant.",
            current_query="What is Python?",
            chat_history=[
                {"prompt": "Hello", "response": "Hi there!"},
                {"prompt": "How are you?", "response": "I'm doing well, thanks!"},
            ],
        )

        assert usage.total > 0
        assert usage.system_prompt > 0
        assert usage.current_query > 0
        assert usage.chat_history > 0

    def test_compression_recommendation(self):
        """Test compression recommendation generation"""
        manager = TokenBudgetManager(model_id="gpt-4o")

        # Create scenario with excessive history
        large_history = [
            {"prompt": f"Question {i}" * 100, "response": f"Answer {i}" * 100}
            for i in range(100)
        ]

        budget, usage, recommendation = manager.check_and_recommend(
            system_prompt="You are a helpful assistant.",
            current_query="What is Python?",
            chat_history=large_history,
        )

        # Should recommend compression
        assert recommendation.needs_compression()
        assert recommendation.compress_history


class TestHistoryCompressor:
    """Test HistoryCompressor functionality"""

    def test_sliding_window_compression(self):
        """Test sliding window compression strategy"""
        compressor = HistoryCompressor()

        history = [
            {"prompt": f"Question {i}", "response": f"Answer {i}"} for i in range(20)
        ]

        compressed, metadata = compressor.compress(
            history, target_tokens=500, strategy="sliding_window"
        )

        assert len(compressed) < len(history)
        assert metadata["original_messages"] == 20
        assert metadata["compressed_messages"] < 20
        assert metadata["strategy"] == "sliding_window"

    def test_preserve_tool_calls(self):
        """Test that tool calls are preserved during compression"""
        compressor = HistoryCompressor()

        history = [
            {"prompt": "Question 1", "response": "Answer 1"},
            {
                "prompt": "Use a tool",
                "response": "Tool used",
                "tool_calls": [{"tool_name": "search", "result": "Found something"}],
            },
            {"prompt": "Question 3", "response": "Answer 3"},
        ]

        compressed, metadata = compressor.compress(
            history,
            target_tokens=200,
            strategy="sliding_window",
            preserve_tool_calls=True,
        )

        # Tool call message should be preserved
        has_tool_calls = any("tool_calls" in msg for msg in compressed)
        assert has_tool_calls


class TestDocumentCompressor:
    """Test DocumentCompressor functionality"""

    def test_rerank_compression(self):
        """Test re-ranking compression strategy"""
        compressor = DocumentCompressor()

        docs = [
            {"text": f"Document {i} with some content here" * 20, "title": f"Doc {i}"}
            for i in range(10)
        ]

        compressed, metadata = compressor.compress(
            docs, target_tokens=500, query="Document 5", strategy="rerank"
        )

        assert len(compressed) < len(docs)
        assert metadata["original_docs"] == 10
        assert metadata["strategy"] == "rerank"

    def test_excerpt_extraction(self):
        """Test excerpt extraction strategy"""
        compressor = DocumentCompressor()

        docs = [
            {
                "text": "This is a long document. " * 100
                + "Python is great. "
                + "More text here. " * 100,
                "title": "Python Guide",
            }
        ]

        compressed, metadata = compressor.compress(
            docs, target_tokens=300, query="Python", strategy="excerpt"
        )

        assert metadata["excerpts_created"] > 0
        # Excerpt should contain the query term
        assert "python" in compressed[0]["text"].lower()


class TestToolResultCompressor:
    """Test ToolResultCompressor functionality"""

    def test_truncate_large_results(self):
        """Test truncation of large tool results"""
        compressor = ToolResultCompressor()

        tool_results = [
            {
                "tool_name": "search",
                "result": "Very long result " * 1000,
                "arguments": {},
            }
        ]

        compressed, metadata = compressor.compress(
            tool_results, target_tokens=100, strategy="truncate"
        )

        assert metadata["results_truncated"] > 0
        # Result should be shorter
        compressed_result_len = len(str(compressed[0]["result"]))
        original_result_len = len(tool_results[0]["result"])
        assert compressed_result_len < original_result_len

    def test_extract_json_fields(self):
        """Test extraction of key fields from JSON results"""
        compressor = ToolResultCompressor()

        tool_results = [
            {
                "tool_name": "api_call",
                "result": {
                    "data": {"important": "value"},
                    "metadata": {"verbose": "information" * 100},
                    "debug": {"lots": "of data" * 100},
                },
                "arguments": {},
            }
        ]

        compressed, metadata = compressor.compress(
            tool_results, target_tokens=100, strategy="extract"
        )

        # Should keep important fields, discard verbose ones
        assert "data" in compressed[0]["result"]


class TestPromptOptimizer:
    """Test PromptOptimizer functionality"""

    def test_compress_tool_descriptions(self):
        """Test compression of tool descriptions"""
        optimizer = PromptOptimizer()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": f"tool_{i}",
                    "description": "This is a very long description " * 50,
                    "parameters": {},
                },
            }
            for i in range(10)
        ]

        optimized, metadata = optimizer.optimize_tools(
            tools, target_tokens=500, strategy="compress"
        )

        assert metadata["optimized_tokens"] < metadata["original_tokens"]
        assert metadata["descriptions_compressed"] > 0

    def test_lazy_load_tools(self):
        """Test lazy loading of tools based on query"""
        optimizer = PromptOptimizer()

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "search_tool",
                    "description": "Search for information",
                    "parameters": {},
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "calculate_tool",
                    "description": "Perform calculations",
                    "parameters": {},
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "other_tool",
                    "description": "Do something else",
                    "parameters": {},
                },
            },
        ]

        optimized, metadata = optimizer.optimize_tools(
            tools,
            target_tokens=200,
            query="I want to search for something",
            strategy="lazy_load",
        )

        # Should prefer search tool
        assert len(optimized) < len(tools)
        tool_names = [t["function"]["name"] for t in optimized]
        # Search tool should be included due to query relevance
        assert any("search" in name for name in tool_names)


def test_integration_compression_workflow():
    """Test complete compression workflow"""
    # Simulate a scenario with large inputs
    manager = TokenBudgetManager(model_id="gpt-4o")
    history_compressor = HistoryCompressor()
    doc_compressor = DocumentCompressor()

    # Large chat history
    history = [
        {"prompt": f"Question {i}" * 50, "response": f"Answer {i}" * 50}
        for i in range(50)
    ]

    # Large documents
    docs = [
        {"text": f"Document {i} content" * 100, "title": f"Doc {i}"} for i in range(20)
    ]

    # Check budget
    budget, usage, recommendation = manager.check_and_recommend(
        system_prompt="You are a helpful assistant.",
        current_query="What is Python?",
        chat_history=history,
        retrieved_docs=docs,
    )

    # Should need compression
    assert recommendation.needs_compression()

    # Apply compression
    if recommendation.compress_history:
        compressed_history, hist_meta = history_compressor.compress(
            history, recommendation.target_history_tokens or budget.chat_history
        )
        assert len(compressed_history) < len(history)

    if recommendation.compress_docs:
        compressed_docs, doc_meta = doc_compressor.compress(
            docs,
            recommendation.target_docs_tokens or budget.retrieved_docs,
            query="Python",
        )
        assert len(compressed_docs) < len(docs)
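Editor's note: since the module under test is explicitly unimplemented, here
is one way the sliding-window strategy could satisfy the assertions above.
The class shape and the characters-per-token heuristic are assumptions, not
committed code; a real implementation would use the model's tokenizer.

class HistoryCompressor:
    """Sketch: keep the most recent turns that fit within the token target."""

    @staticmethod
    def _estimate_tokens(msg):
        # Crude ~4-characters-per-token heuristic, purely illustrative.
        text = str(msg.get("prompt", "")) + str(msg.get("response", ""))
        return len(text) // 4

    def compress(self, history, target_tokens, strategy="sliding_window",
                 preserve_tool_calls=False):
        kept, used = [], 0
        for msg in reversed(history):  # walk newest-first
            cost = self._estimate_tokens(msg)
            must_keep = preserve_tool_calls and "tool_calls" in msg
            if used + cost <= target_tokens or must_keep:
                kept.append(msg)
                used += cost
        kept.reverse()  # restore chronological order
        metadata = {
            "original_messages": len(history),
            "compressed_messages": len(kept),
            "strategy": strategy,
        }
        return kept, metadata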