mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: enhance message and schema cleaning for Google AI integration
This commit is contained in:
@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _clean_messages_google(self, messages):
|
def _clean_messages_google(self, messages):
|
||||||
|
"""Convert OpenAI format messages to Google AI format."""
|
||||||
cleaned_messages = []
|
cleaned_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message.get("role")
|
role = message.get("role")
|
||||||
@@ -150,6 +151,15 @@ class GoogleLLM(BaseLLM):
|
|||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
role = "model"
|
role = "model"
|
||||||
|
elif role == "system":
|
||||||
|
continue
|
||||||
|
elif role == "tool":
|
||||||
|
continue
|
||||||
|
elif role not in ["user", "model"]:
|
||||||
|
logging.warning(
|
||||||
|
f"GoogleLLM: Converting unsupported role '{role}' to 'user'"
|
||||||
|
)
|
||||||
|
role = "user"
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
if role and content is not None:
|
if role and content is not None:
|
||||||
@@ -188,11 +198,63 @@ class GoogleLLM(BaseLLM):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||||
|
|
||||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
if parts:
|
||||||
|
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||||
|
|
||||||
return cleaned_messages
|
return cleaned_messages
|
||||||
|
|
||||||
|
def _clean_schema(self, schema_obj):
|
||||||
|
"""
|
||||||
|
Recursively remove unsupported fields from schema objects
|
||||||
|
and validate required properties.
|
||||||
|
"""
|
||||||
|
if not isinstance(schema_obj, dict):
|
||||||
|
return schema_obj
|
||||||
|
allowed_fields = {
|
||||||
|
"type",
|
||||||
|
"description",
|
||||||
|
"items",
|
||||||
|
"properties",
|
||||||
|
"required",
|
||||||
|
"enum",
|
||||||
|
"pattern",
|
||||||
|
"minimum",
|
||||||
|
"maximum",
|
||||||
|
"nullable",
|
||||||
|
"default",
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned = {}
|
||||||
|
for key, value in schema_obj.items():
|
||||||
|
if key not in allowed_fields:
|
||||||
|
continue
|
||||||
|
elif key == "type" and isinstance(value, str):
|
||||||
|
cleaned[key] = value.upper()
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
cleaned[key] = self._clean_schema(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||||
|
else:
|
||||||
|
cleaned[key] = value
|
||||||
|
|
||||||
|
# Validate that required properties actually exist in properties
|
||||||
|
if "required" in cleaned and "properties" in cleaned:
|
||||||
|
valid_required = []
|
||||||
|
properties_keys = set(cleaned["properties"].keys())
|
||||||
|
for required_prop in cleaned["required"]:
|
||||||
|
if required_prop in properties_keys:
|
||||||
|
valid_required.append(required_prop)
|
||||||
|
if valid_required:
|
||||||
|
cleaned["required"] = valid_required
|
||||||
|
else:
|
||||||
|
cleaned.pop("required", None)
|
||||||
|
elif "required" in cleaned and "properties" not in cleaned:
|
||||||
|
cleaned.pop("required", None)
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
def _clean_tools_format(self, tools_list):
|
def _clean_tools_format(self, tools_list):
|
||||||
|
"""Convert OpenAI format tools to Google AI format."""
|
||||||
genai_tools = []
|
genai_tools = []
|
||||||
for tool_data in tools_list:
|
for tool_data in tools_list:
|
||||||
if tool_data["type"] == "function":
|
if tool_data["type"] == "function":
|
||||||
@@ -201,18 +263,16 @@ class GoogleLLM(BaseLLM):
|
|||||||
properties = parameters.get("properties", {})
|
properties = parameters.get("properties", {})
|
||||||
|
|
||||||
if properties:
|
if properties:
|
||||||
|
cleaned_properties = {}
|
||||||
|
for k, v in properties.items():
|
||||||
|
cleaned_properties[k] = self._clean_schema(v)
|
||||||
|
|
||||||
genai_function = dict(
|
genai_function = dict(
|
||||||
name=function["name"],
|
name=function["name"],
|
||||||
description=function["description"],
|
description=function["description"],
|
||||||
parameters={
|
parameters={
|
||||||
"type": "OBJECT",
|
"type": "OBJECT",
|
||||||
"properties": {
|
"properties": cleaned_properties,
|
||||||
k: {
|
|
||||||
**v,
|
|
||||||
"type": v["type"].upper() if v["type"] else None,
|
|
||||||
}
|
|
||||||
for k, v in properties.items()
|
|
||||||
},
|
|
||||||
"required": (
|
"required": (
|
||||||
parameters["required"]
|
parameters["required"]
|
||||||
if "required" in parameters
|
if "required" in parameters
|
||||||
@@ -242,6 +302,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
response_schema=None,
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generate content using Google AI API without streaming."""
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -281,6 +342,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
response_schema=None,
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generate content using Google AI API with streaming."""
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -331,12 +393,15 @@ class GoogleLLM(BaseLLM):
|
|||||||
yield chunk.text
|
yield chunk.text
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
|
"""Return whether this LLM supports function calling."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _supports_structured_output(self):
|
def _supports_structured_output(self):
|
||||||
|
"""Return whether this LLM supports structured JSON output."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def prepare_structured_output_format(self, json_schema):
|
def prepare_structured_output_format(self, json_schema):
|
||||||
|
"""Convert JSON schema to Google AI structured output format."""
|
||||||
if not json_schema:
|
if not json_schema:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user