feat: add support for structured output and JSON schema validation

This commit is contained in:
Siddhant Rai
2025-08-13 13:29:51 +05:30
parent 56831fbcf2
commit 896dcf1f9e
13 changed files with 660 additions and 153 deletions

View File

@@ -120,6 +120,20 @@ class BaseLLM(ABC):
def _supports_tools(self):
    """Provider hook for tool support; the base version is unimplemented.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Subclass must implement _supports_tools method")
def supports_structured_output(self):
    """Check if the LLM supports structured output/JSON schema enforcement.

    Returns:
        bool: The truthiness of ``self._supports_structured_output()``, or
        False when that hook is missing or not callable.
    """
    hook = getattr(self, "_supports_structured_output", None)
    if not callable(hook):
        return False
    # Ask the hook instead of merely checking that it exists: an
    # existence check cannot distinguish a hook that returns False from one
    # that returns True.
    return bool(hook())
def _supports_structured_output(self):
    """Default structured-output capability flag: not supported (returns False)."""
    return False
def prepare_structured_output_format(self, json_schema):
    """Build the provider-specific structured output payload for a schema.

    This default implementation performs no conversion.

    Args:
        json_schema: The JSON schema to convert; ignored here.

    Returns:
        None: Always.
    """
    del json_schema  # intentionally unused in the default implementation
    return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by this LLM for file uploads.
@@ -127,4 +141,4 @@ class BaseLLM(ABC):
Returns:
list: List of supported MIME types
"""
return [] # Default: no attachments supported
return []

View File

@@ -1,11 +1,13 @@
import json
import logging
from google import genai
from google.genai import types
import logging
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
from application.core.settings import settings
class GoogleLLM(BaseLLM):
@@ -24,12 +26,12 @@ class GoogleLLM(BaseLLM):
list: List of supported MIME types
"""
return [
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -70,26 +72,30 @@ class GoogleLLM(BaseLLM):
files = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
mime_type = attachment.get("mime_type")
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
})
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({
"files": files
})
prepared_messages[user_message_index]["content"].append({"files": files})
return prepared_messages
@@ -103,10 +109,10 @@ class GoogleLLM(BaseLLM):
Returns:
str: Google AI file URI for the uploaded file.
"""
if 'google_file_uri' in attachment:
return attachment['google_file_uri']
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get('path')
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
@@ -116,17 +122,19 @@ class GoogleLLM(BaseLLM):
try:
file_uri = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri
lambda local_path, **kwargs: self.client.files.upload(
file=local_path
).uri,
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if '_id' in attachment:
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment['_id']},
{"$set": {"google_file_uri": file_uri}}
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
)
return file_uri
@@ -166,13 +174,13 @@ class GoogleLLM(BaseLLM):
)
)
elif "files" in item:
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"]
)
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
@@ -231,6 +239,7 @@ class GoogleLLM(BaseLLM):
stream=False,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
client = genai.Client(api_key=self.api_key)
@@ -244,16 +253,21 @@ class GoogleLLM(BaseLLM):
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
if tools:
return response
else:
response = client.models.generate_content(
model=model, contents=messages, config=config
)
return response.text
def _raw_gen_stream(
@@ -264,6 +278,7 @@ class GoogleLLM(BaseLLM):
stream=True,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
client = genai.Client(api_key=self.api_key)
@@ -278,17 +293,24 @@ class GoogleLLM(BaseLLM):
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, 'file_data') and part.file_data is not None:
if hasattr(part, "file_data") and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
)
response = client.models.generate_content_stream(
model=model,
@@ -296,7 +318,6 @@ class GoogleLLM(BaseLLM):
config=config,
)
for chunk in response:
if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
@@ -311,3 +332,75 @@ class GoogleLLM(BaseLLM):
def _supports_tools(self):
    """Capability flag: this provider reports tool support (returns True)."""
    return True
def _supports_structured_output(self):
    """Capability flag: this provider reports structured-output support (returns True)."""
    return True
def prepare_structured_output_format(self, json_schema):
    """Convert a JSON Schema dict into the schema form used for Google requests.

    Type names are mapped to upper-case (``"object"`` -> ``"OBJECT"``), a
    fixed set of keywords is copied through unchanged, and sub-schemas under
    ``properties``, ``items``, ``anyOf``, ``oneOf`` and ``allOf`` are
    converted recursively. Keywords not handled here are dropped.

    Args:
        json_schema: JSON Schema as a dict. A falsy value means "no schema".

    Returns:
        The converted schema dict, or None when no schema was supplied or
        the conversion raised (the error is logged with its traceback).
    """
    if not json_schema:
        return None

    type_map = {
        "object": "OBJECT",
        "array": "ARRAY",
        "string": "STRING",
        "integer": "INTEGER",
        "number": "NUMBER",
        "boolean": "BOOLEAN",
    }
    passthrough_keys = (
        "description",
        "nullable",
        "enum",
        "minItems",
        "maxItems",
        "required",
        "propertyOrdering",
    )

    def convert(schema):
        # Non-dict nodes (e.g. booleans) are passed through untouched.
        if not isinstance(schema, dict):
            return schema

        result = {}
        schema_type = schema.get("type")

        # JSON Schema allows a list of types. The common nullable form
        # ["T", "null"] is folded into a single type plus nullable=True;
        # any other union is rejected and handled by the caller's except.
        if isinstance(schema_type, (list, tuple)):
            non_null = [t for t in schema_type if t != "null"]
            if len(non_null) != 1:
                raise ValueError(f"Unsupported type union in schema: {schema_type}")
            if len(non_null) != len(schema_type):
                result["nullable"] = True
            schema_type = non_null[0]

        if schema_type:
            result["type"] = type_map.get(schema_type.lower(), schema_type.upper())

        # An explicit "nullable" in the schema overrides the inferred one.
        for key in passthrough_keys:
            if key in schema:
                result[key] = schema[key]

        if "format" in schema:
            format_value = schema["format"]
            if schema_type == "string":
                # For strings only these formats are emitted; "date" is
                # widened to "date-time" and any other format is dropped.
                if format_value == "date":
                    result["format"] = "date-time"
                elif format_value in ["enum", "date-time"]:
                    result["format"] = format_value
            else:
                result["format"] = format_value

        if "properties" in schema:
            result["properties"] = {
                k: convert(v) for k, v in schema["properties"].items()
            }
            # Default the ordering to declaration order when none was given.
            if "propertyOrdering" not in result and result.get("type") == "OBJECT":
                result["propertyOrdering"] = list(result["properties"].keys())

        if "items" in schema:
            result["items"] = convert(schema["items"])

        for field in ["anyOf", "oneOf", "allOf"]:
            if field in schema:
                result[field] = [convert(s) for s in schema[field]]

        return result

    try:
        return convert(json_schema)
    except Exception as e:
        # Best-effort: a schema that cannot be converted yields None.
        logging.error(
            f"Error preparing structured output format for Google: {e}",
            exc_info=True,
        )
        return None

View File

@@ -1,5 +1,5 @@
import json
import base64
import json
import logging
from application.core.settings import settings
@@ -13,7 +13,10 @@ class OpenAILLM(BaseLLM):
from openai import OpenAI
super().__init__(*args, **kwargs)
if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip():
if (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
else:
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
@@ -73,14 +76,30 @@ class OpenAILLM(BaseLLM):
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append({"type": "text", "text": item["text"]})
elif "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
content_parts.append(item)
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
content_parts.append(item)
cleaned_messages.append({"role": role, "content": content_parts})
cleaned_messages.append(
{"role": role, "content": content_parts}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
@@ -98,22 +117,29 @@ class OpenAILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
@@ -124,24 +150,32 @@ class OpenAILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
for line in response:
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
@@ -149,6 +183,66 @@ class OpenAILLM(BaseLLM):
def _supports_tools(self):
    """Capability flag: this provider reports tool support (returns True)."""
    return True
def _supports_structured_output(self):
    """Capability flag: this provider reports structured-output support (returns True)."""
    return True
def prepare_structured_output_format(self, json_schema):
    """Wrap a JSON Schema in a strict ``json_schema`` response-format payload.

    Every object schema reachable through ``properties``, ``items``,
    ``anyOf``/``oneOf``/``allOf`` and the ``$defs``/``definitions`` tables
    gets ``additionalProperties: False`` and a ``required`` list naming all
    of its properties. The input schema is not mutated; copies are made at
    each rewritten level.

    Args:
        json_schema: JSON Schema as a dict. A falsy value means "no schema".
            Optional top-level ``name`` and ``description`` keys are reused
            for the payload, defaulting to ``"response"`` and
            ``"Structured response"``.

    Returns:
        A dict of the form ``{"type": "json_schema", "json_schema": {...}}``,
        or None when no schema was supplied or preparation raised (the error
        is logged with its traceback).
    """
    if not json_schema:
        return None

    def make_strict(node):
        # Non-dict nodes are returned unchanged.
        if not isinstance(node, dict):
            return node

        strict = node.copy()
        if strict.get("type") == "object":
            strict["additionalProperties"] = False
            # Strict mode: every declared property must be listed as required.
            if "properties" in strict:
                strict["required"] = list(strict["properties"].keys())

        for key, value in node.items():
            # "$defs"/"definitions" map names to schemas just like
            # "properties", and are reachable through "$ref", so they need
            # the same rewriting.
            if key in ("properties", "$defs", "definitions") and isinstance(
                value, dict
            ):
                strict[key] = {
                    name: make_strict(sub_schema)
                    for name, sub_schema in value.items()
                }
            elif key == "items" and isinstance(value, dict):
                strict[key] = make_strict(value)
            elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list):
                strict[key] = [make_strict(sub_schema) for sub_schema in value]
        return strict

    try:
        processed_schema = make_strict(json_schema)
        return {
            "type": "json_schema",
            "json_schema": {
                "name": processed_schema.get("name", "response"),
                "description": processed_schema.get(
                    "description", "Structured response"
                ),
                "schema": processed_schema,
                "strict": True,
            },
        }
    except Exception as e:
        # Best-effort: a schema that cannot be prepared yields None.
        logging.error(
            f"Error preparing structured output format: {e}", exc_info=True
        )
        return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by OpenAI for file uploads.
@@ -157,12 +251,12 @@ class OpenAILLM(BaseLLM):
list: List of supported MIME types
"""
return [
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -202,39 +296,46 @@ class OpenAILLM(BaseLLM):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith('image/'):
if mime_type and mime_type.startswith("image/"):
try:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append({
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
},
}
})
)
except Exception as e:
logging.error(f"Error processing image attachment: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
})
logging.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
# Handle PDFs using the file API
elif mime_type == 'application/pdf':
elif mime_type == "application/pdf":
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append({
"type": "file",
"file": {"file_id": file_id}
})
prepared_messages[user_message_index]["content"].append(
{"type": "file", "file": {"file_id": file_id}}
)
except Exception as e:
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"File content:\n\n{attachment['content']}"
})
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"File content:\n\n{attachment['content']}",
}
)
return prepared_messages
@@ -248,13 +349,13 @@ class OpenAILLM(BaseLLM):
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get('path')
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
@@ -273,10 +374,10 @@ class OpenAILLM(BaseLLM):
"""
import logging
if 'openai_file_id' in attachment:
return attachment['openai_file_id']
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
file_path = attachment.get('path')
file_path = attachment.get("path")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
@@ -285,19 +386,18 @@ class OpenAILLM(BaseLLM):
file_id = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.create(
file=open(local_path, 'rb'),
purpose="assistants"
).id
file=open(local_path, "rb"), purpose="assistants"
).id,
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if '_id' in attachment:
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment['_id']},
{"$set": {"openai_file_id": file_id}}
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
)
return file_id
@@ -308,9 +408,7 @@ class OpenAILLM(BaseLLM):
class AzureOpenAILLM(OpenAILLM):
def __init__(
self, api_key, user_api_key, *args, **kwargs
):
def __init__(self, api_key, user_api_key, *args, **kwargs):
super().__init__(api_key)
self.api_base = (settings.OPENAI_API_BASE,)
@@ -321,5 +419,5 @@ class AzureOpenAILLM(OpenAILLM):
self.client = AzureOpenAI(
api_key=api_key,
api_version=settings.OPENAI_API_VERSION,
azure_endpoint=settings.OPENAI_API_BASE
azure_endpoint=settings.OPENAI_API_BASE,
)