feat: context compression

Alex
2025-11-23 18:35:51 +00:00
parent 9e58eb02b3
commit 3737beb2ba
28 changed files with 5393 additions and 93 deletions


@@ -1,4 +1,3 @@
-import json
 import logging
 from google import genai
@@ -11,11 +10,13 @@ from application.storage.storage_creator import StorageCreator
 class GoogleLLM(BaseLLM):
-    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
+    def __init__(
+        self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
+    ):
         super().__init__(*args, **kwargs)
         self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
         self.client = genai.Client(api_key=self.api_key)
         self.storage = StorageCreator.get_storage()
@@ -33,6 +34,12 @@ class GoogleLLM(BaseLLM):
         "image/jpg",
         "image/webp",
         "image/gif",
+        "application/pdf",
+        "image/png",
+        "image/jpeg",
+        "image/jpg",
+        "image/webp",
+        "image/gif",
     ]
 
     def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -135,12 +142,38 @@ class GoogleLLM(BaseLLM):
             raise
 
     def _clean_messages_google(self, messages):
-        """Convert OpenAI format messages to Google AI format."""
+        """
+        Convert OpenAI format messages to Google AI format and collect system prompts.
+
+        Returns:
+            tuple[list[types.Content], Optional[str]]: cleaned messages and optional
+            combined system instruction.
+        """
         cleaned_messages = []
+        system_instructions = []
+
+        def _extract_system_text(content):
+            if isinstance(content, str):
+                return content
+            if isinstance(content, list):
+                parts = []
+                for item in content:
+                    if isinstance(item, dict) and "text" in item and item["text"] is not None:
+                        parts.append(item["text"])
+                return "\n".join(parts)
+            return ""
+
         for message in messages:
             role = message.get("role")
             content = message.get("content")
 
+            # Gemini only accepts user/model in the contents list.
+            if role == "system":
+                sys_text = _extract_system_text(content)
+                if sys_text:
+                    system_instructions.append(sys_text)
+                continue
+
             if role == "assistant":
                 role = "model"
             elif role == "tool":
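
For illustration, a minimal sketch of what the reworked _clean_messages_google now returns; the llm handle and message contents are hypothetical:

    # System prompts may appear anywhere in the history, not only at
    # index 0; each one is hoisted out and joined with blank lines.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
        {"role": "system", "content": [{"text": "Answer concisely."}]},
        {"role": "assistant", "content": "Hello!"},
    ]
    cleaned, system_instruction = llm._clean_messages_google(messages)
    # cleaned contains only user/model Content objects;
    # system_instruction == "You are a helpful assistant.\n\nAnswer concisely."
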
@@ -159,12 +192,27 @@ class GoogleLLM(BaseLLM):
                         cleaned_args = self._remove_null_values(
                             item["function_call"]["args"]
                         )
-                        parts.append(
-                            types.Part.from_function_call(
-                                name=item["function_call"]["name"],
-                                args=cleaned_args,
-                            )
-                        )
+                        # Create function call part with thought_signature if present
+                        # For Gemini 3 models, we need to include thought_signature
+                        if "thought_signature" in item:
+                            # Use Part constructor with functionCall and thoughtSignature
+                            parts.append(
+                                types.Part(
+                                    functionCall=types.FunctionCall(
+                                        name=item["function_call"]["name"],
+                                        args=cleaned_args,
+                                    ),
+                                    thoughtSignature=item["thought_signature"],
+                                )
+                            )
+                        else:
+                            # Use helper method when no thought_signature
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=cleaned_args,
+                                )
+                            )
                     elif "function_response" in item:
                         parts.append(
                             types.Part.from_function_response(
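
For context, a hypothetical shape of the assistant-history item this branch inspects; the tool name, args, and signature value are illustrative:

    # "thought_signature" is the opaque token Gemini 3 returns alongside a
    # function call; when present it is echoed back via the raw types.Part
    # constructor instead of Part.from_function_call.
    item = {
        "function_call": {"name": "search_docs", "args": {"query": "pricing", "page": None}},
        "thought_signature": "PLACEHOLDER_SIGNATURE",
    }
    # _remove_null_values drops {"page": None} from the args before the
    # FunctionCall is built.
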
@@ -188,7 +236,8 @@ class GoogleLLM(BaseLLM):
                 raise ValueError(f"Unexpected content type: {type(content)}")
             if parts:
                 cleaned_messages.append(types.Content(role=role, parts=parts))
-        return cleaned_messages
+        system_instruction = "\n\n".join(system_instructions) if system_instructions else None
+        return cleaned_messages, system_instruction
 
     def _clean_schema(self, schema_obj):
         """
@@ -274,6 +323,61 @@ class GoogleLLM(BaseLLM):
             genai_tools.append(genai_tool)
         return genai_tools
 
+    def _extract_preview_from_message(self, message):
+        """Get a short, human-readable preview from the last message."""
+        try:
+            if hasattr(message, "parts"):
+                for part in reversed(message.parts):
+                    if getattr(part, "text", None):
+                        return part.text
+                    function_call = getattr(part, "function_call", None)
+                    if function_call:
+                        name = getattr(function_call, "name", "") or "function_call"
+                        return f"function_call:{name}"
+                    function_response = getattr(part, "function_response", None)
+                    if function_response:
+                        name = getattr(function_response, "name", "") or "function_response"
+                        return f"function_response:{name}"
+            if isinstance(message, dict):
+                content = message.get("content")
+                if isinstance(content, str):
+                    return content
+                if isinstance(content, list):
+                    for item in reversed(content):
+                        if isinstance(item, str):
+                            return item
+                        if isinstance(item, dict):
+                            if item.get("text"):
+                                return item["text"]
+                            if item.get("function_call"):
+                                fn = item["function_call"]
+                                if isinstance(fn, dict):
+                                    name = fn.get("name") or "function_call"
+                                    return f"function_call:{name}"
+                                return "function_call"
+                            if item.get("function_response"):
+                                resp = item["function_response"]
+                                if isinstance(resp, dict):
+                                    name = resp.get("name") or "function_response"
+                                    return f"function_response:{name}"
+                                return "function_response"
+                if "text" in message and isinstance(message["text"], str):
+                    return message["text"]
+        except Exception:
+            pass
+        return str(message)
+
+    def _summarize_messages_for_log(self, messages, preview_chars=20):
+        """Return a compact summary for logging to avoid huge payloads."""
+        message_count = len(messages) if messages else 0
+        last_preview = ""
+        if messages:
+            last_preview = self._extract_preview_from_message(messages[-1]) or ""
+            last_preview = str(last_preview).replace("\n", " ")
+            if len(last_preview) > preview_chars:
+                last_preview = f"{last_preview[:preview_chars]}..."
+        return f"count={message_count}, last='{last_preview}'"
+
     def _raw_gen(
         self,
         baseself,
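
These helpers keep full message payloads out of the logs. A hypothetical call, using the default preview_chars=20 defined above:

    summary = llm._summarize_messages_for_log(
        [{"role": "user", "content": "Summarize the quarterly report for me"}]
    )
    # summary == "count=1, last='Summarize the quarte...'"
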
@@ -287,12 +391,12 @@ class GoogleLLM(BaseLLM):
     ):
         """Generate content using Google AI API without streaming."""
         client = genai.Client(api_key=self.api_key)
+        system_instruction = None
         if formatting == "openai":
-            messages = self._clean_messages_google(messages)
+            messages, system_instruction = self._clean_messages_google(messages)
         config = types.GenerateContentConfig()
-        if messages[0].role == "system":
-            config.system_instruction = messages[0].parts[0].text
-            messages = messages[1:]
+        if system_instruction:
+            config.system_instruction = system_instruction
         if tools:
             cleaned_tools = self._clean_tools_format(tools)
             config.tools = cleaned_tools
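
Taken together, a minimal sketch of the non-streaming path after this change; the model name, prompt, and llm handle are illustrative, and the client calls follow the google-genai API used in the diff:

    from google import genai
    from google.genai import types

    client = genai.Client(api_key="...")  # key resolution as in __init__ above
    contents, system_instruction = llm._clean_messages_google(
        [
            {"role": "system", "content": "Answer in one sentence."},
            {"role": "user", "content": "What is retrieval-augmented generation?"},
        ]
    )
    config = types.GenerateContentConfig()
    if system_instruction:
        config.system_instruction = system_instruction
    response = client.models.generate_content(
        model="gemini-2.0-flash", contents=contents, config=config
    )
    print(response.text)
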
@@ -325,12 +429,12 @@ class GoogleLLM(BaseLLM):
     ):
         """Generate content using Google AI API with streaming."""
         client = genai.Client(api_key=self.api_key)
+        system_instruction = None
         if formatting == "openai":
-            messages = self._clean_messages_google(messages)
+            messages, system_instruction = self._clean_messages_google(messages)
         config = types.GenerateContentConfig()
-        if messages[0].role == "system":
-            config.system_instruction = messages[0].parts[0].text
-            messages = messages[1:]
+        if system_instruction:
+            config.system_instruction = system_instruction
         if tools:
             cleaned_tools = self._clean_tools_format(tools)
             config.tools = cleaned_tools
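
The streaming path differs only in how the response is consumed; a sketch under the same assumptions as the non-streaming example above:

    for chunk in client.models.generate_content_stream(
        model="gemini-2.0-flash", contents=contents, config=config
    ):
        # Each chunk is a partial response; .text can be None for
        # tool-call chunks, so guard before printing.
        if chunk.text:
            print(chunk.text, end="")
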
@@ -349,8 +453,12 @@ class GoogleLLM(BaseLLM):
                         break
             if has_attachments:
                 break
 
+        messages_summary = self._summarize_messages_for_log(messages)
         logging.info(
-            f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
+            "GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
+            model,
+            messages_summary,
+            has_attachments,
         )
         response = client.models.generate_content_stream(