From 09b9576eef0a344a310d5052fe374199554418bb Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 11 Sep 2025 17:54:46 +0530 Subject: [PATCH] feat: enhance message and schema cleaning for Google AI integration --- application/llm/google_ai.py | 81 ++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 8 deletions(-) diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 91065b74..54567f6f 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM): raise def _clean_messages_google(self, messages): + """Convert OpenAI format messages to Google AI format.""" cleaned_messages = [] for message in messages: role = message.get("role") @@ -150,6 +151,15 @@ class GoogleLLM(BaseLLM): if role == "assistant": role = "model" + elif role == "system": + continue + elif role == "tool": + continue + elif role not in ["user", "model"]: + logging.warning( + f"GoogleLLM: Converting unsupported role '{role}' to 'user'" + ) + role = "user" parts = [] if role and content is not None: @@ -188,11 +198,63 @@ class GoogleLLM(BaseLLM): else: raise ValueError(f"Unexpected content type: {type(content)}") - cleaned_messages.append(types.Content(role=role, parts=parts)) + if parts: + cleaned_messages.append(types.Content(role=role, parts=parts)) return cleaned_messages + def _clean_schema(self, schema_obj): + """ + Recursively remove unsupported fields from schema objects + and validate required properties. + """ + if not isinstance(schema_obj, dict): + return schema_obj + allowed_fields = { + "type", + "description", + "items", + "properties", + "required", + "enum", + "pattern", + "minimum", + "maximum", + "nullable", + "default", + } + + cleaned = {} + for key, value in schema_obj.items(): + if key not in allowed_fields: + continue + elif key == "type" and isinstance(value, str): + cleaned[key] = value.upper() + elif isinstance(value, dict): + cleaned[key] = self._clean_schema(value) + elif isinstance(value, list): + cleaned[key] = [self._clean_schema(item) for item in value] + else: + cleaned[key] = value + + # Validate that required properties actually exist in properties + if "required" in cleaned and "properties" in cleaned: + valid_required = [] + properties_keys = set(cleaned["properties"].keys()) + for required_prop in cleaned["required"]: + if required_prop in properties_keys: + valid_required.append(required_prop) + if valid_required: + cleaned["required"] = valid_required + else: + cleaned.pop("required", None) + elif "required" in cleaned and "properties" not in cleaned: + cleaned.pop("required", None) + + return cleaned + def _clean_tools_format(self, tools_list): + """Convert OpenAI format tools to Google AI format.""" genai_tools = [] for tool_data in tools_list: if tool_data["type"] == "function": @@ -201,18 +263,16 @@ class GoogleLLM(BaseLLM): properties = parameters.get("properties", {}) if properties: + cleaned_properties = {} + for k, v in properties.items(): + cleaned_properties[k] = self._clean_schema(v) + genai_function = dict( name=function["name"], description=function["description"], parameters={ "type": "OBJECT", - "properties": { - k: { - **v, - "type": v["type"].upper() if v["type"] else None, - } - for k, v in properties.items() - }, + "properties": cleaned_properties, "required": ( parameters["required"] if "required" in parameters @@ -242,6 +302,7 @@ class GoogleLLM(BaseLLM): response_schema=None, **kwargs, ): + """Generate content using Google AI API without streaming.""" client = genai.Client(api_key=self.api_key) if formatting == "openai": messages = self._clean_messages_google(messages) @@ -281,6 +342,7 @@ class GoogleLLM(BaseLLM): response_schema=None, **kwargs, ): + """Generate content using Google AI API with streaming.""" client = genai.Client(api_key=self.api_key) if formatting == "openai": messages = self._clean_messages_google(messages) @@ -331,12 +393,15 @@ class GoogleLLM(BaseLLM): yield chunk.text def _supports_tools(self): + """Return whether this LLM supports function calling.""" return True def _supports_structured_output(self): + """Return whether this LLM supports structured JSON output.""" return True def prepare_structured_output_format(self, json_schema): + """Convert JSON schema to Google AI structured output format.""" if not json_schema: return None