From 896dcf1f9eac1e0801c165a619e2cfd32f353170 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 13 Aug 2025 13:29:51 +0530 Subject: [PATCH] feat: add support for structured output and JSON schema validation --- application/agents/base.py | 47 +++- application/api/answer/routes/answer.py | 24 +- application/api/answer/routes/base.py | 75 ++++-- .../api/answer/services/stream_processor.py | 4 + application/api/user/routes.py | 45 ++++ application/llm/base.py | 16 +- application/llm/google_ai.py | 181 +++++++++---- application/llm/openai.py | 244 ++++++++++++------ frontend/src/agents/NewAgent.tsx | 133 +++++++++- frontend/src/agents/agentPreviewSlice.ts | 19 ++ frontend/src/agents/types/index.ts | 1 + .../src/conversation/conversationModels.ts | 4 + .../src/conversation/conversationSlice.ts | 20 ++ 13 files changed, 660 insertions(+), 153 deletions(-) diff --git a/application/agents/base.py b/application/agents/base.py index 509afb89..33c86f7a 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -28,6 +28,7 @@ class BaseAgent(ABC): chat_history: Optional[List[Dict]] = None, decoded_token: Optional[Dict] = None, attachments: Optional[List[Dict]] = None, + json_schema: Optional[Dict] = None, ): self.endpoint = endpoint self.llm_name = llm_name @@ -51,6 +52,7 @@ class BaseAgent(ABC): llm_name if llm_name else "default" ) self.attachments = attachments or [] + self.json_schema = json_schema @log_activity() def gen( @@ -283,6 +285,21 @@ class BaseAgent(ABC): and self.tools ): gen_kwargs["tools"] = self.tools + + if ( + self.json_schema + and hasattr(self.llm, "_supports_structured_output") + and self.llm._supports_structured_output() + ): + structured_format = self.llm.prepare_structured_output_format( + self.json_schema + ) + if structured_format: + if self.llm_name == "openai": + gen_kwargs["response_format"] = structured_format + elif self.llm_name == "google": + gen_kwargs["response_schema"] = structured_format + resp = 
self.llm.gen_stream(**gen_kwargs) if log_context: @@ -307,11 +324,25 @@ class BaseAgent(ABC): return resp def _handle_response(self, response, tools_dict, messages, log_context): + is_structured_output = ( + self.json_schema is not None + and hasattr(self.llm, "_supports_structured_output") + and self.llm._supports_structured_output() + ) + if isinstance(response, str): - yield {"answer": response} + answer_data = {"answer": response} + if is_structured_output: + answer_data["structured"] = True + answer_data["schema"] = self.json_schema + yield answer_data return if hasattr(response, "message") and getattr(response.message, "content", None): - yield {"answer": response.message.content} + answer_data = {"answer": response.message.content} + if is_structured_output: + answer_data["structured"] = True + answer_data["schema"] = self.json_schema + yield answer_data return processed_response_gen = self._llm_handler( response, tools_dict, messages, log_context, self.attachments @@ -319,8 +350,16 @@ class BaseAgent(ABC): for event in processed_response_gen: if isinstance(event, str): - yield {"answer": event} + answer_data = {"answer": event} + if is_structured_output: + answer_data["structured"] = True + answer_data["schema"] = self.json_schema + yield answer_data elif hasattr(event, "message") and getattr(event.message, "content", None): - yield {"answer": event.message.content} + answer_data = {"answer": event.message.content} + if is_structured_output: + answer_data["structured"] = True + answer_data["schema"] = self.json_schema + yield answer_data elif isinstance(event, dict) and "type" in event: yield event diff --git a/application/api/answer/routes/answer.py b/application/api/answer/routes/answer.py index 1b374638..2c2d8f7b 100644 --- a/application/api/answer/routes/answer.py +++ b/application/api/answer/routes/answer.py @@ -83,9 +83,24 @@ class AnswerResource(Resource, BaseAnswerResource): index=None, should_save_conversation=data.get("save_conversation", True), ) 
- conversation_id, response, sources, tool_calls, thought, error = ( - self.process_response_stream(stream) - ) + stream_result = self.process_response_stream(stream) + + if len(stream_result) == 7: + ( + conversation_id, + response, + sources, + tool_calls, + thought, + error, + structured_info, + ) = stream_result + else: + conversation_id, response, sources, tool_calls, thought, error = ( + stream_result + ) + structured_info = None + if error: return make_response({"error": error}, 400) result = { @@ -95,6 +110,9 @@ class AnswerResource(Resource, BaseAnswerResource): "tool_calls": tool_calls, "thought": thought, } + + if structured_info: + result.update(structured_info) except Exception as e: logger.error( f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", diff --git a/application/api/answer/routes/base.py b/application/api/answer/routes/base.py index 682da1f0..6176c5a8 100644 --- a/application/api/answer/routes/base.py +++ b/application/api/answer/routes/base.py @@ -79,12 +79,20 @@ class BaseAnswerResource: """ try: response_full, thought, source_log_docs, tool_calls = "", "", [], [] + is_structured = False + schema_info = None + structured_chunks = [] for line in agent.gen(query=question, retriever=retriever): if "answer" in line: response_full += str(line["answer"]) - data = json.dumps({"type": "answer", "answer": line["answer"]}) - yield f"data: {data}\n\n" + if line.get("structured"): + is_structured = True + schema_info = line.get("schema") + structured_chunks.append(line["answer"]) + else: + data = json.dumps({"type": "answer", "answer": line["answer"]}) + yield f"data: {data}\n\n" elif "sources" in line: truncated_sources = [] source_log_docs = line["sources"] @@ -109,6 +117,17 @@ class BaseAnswerResource: elif "type" in line: data = json.dumps(line) yield f"data: {data}\n\n" + + if is_structured and structured_chunks: + structured_data = { + "type": "structured_answer", + "answer": response_full, + "structured": True, + "schema": 
schema_info, + } + data = json.dumps(structured_data) + yield f"data: {data}\n\n" + if isNoneDoc: for doc in source_log_docs: doc["source"] = "None" @@ -139,28 +158,28 @@ ) else: conversation_id = None - # Send conversation ID - - data = json.dumps({"type": "id", "id": str(conversation_id)}) + id_data = {"type": "id", "id": str(conversation_id)} + data = json.dumps(id_data) yield f"data: {data}\n\n" - # Log the interaction - retriever_params = retriever.get_params() - self.user_logs_collection.insert_one( - { - "action": "stream_answer", - "level": "info", - "user": decoded_token.get("sub"), - "api_key": user_api_key, - "question": question, - "response": response_full, - "sources": source_log_docs, - "retriever_params": retriever_params, - "attachments": attachment_ids, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) + log_data = { + "action": "stream_answer", + "level": "info", + "user": decoded_token.get("sub"), + "api_key": user_api_key, + "question": question, + "response": response_full, + "sources": source_log_docs, + "retriever_params": retriever.get_params(), + "attachments": attachment_ids, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + if is_structured: + log_data["structured_output"] = True + if schema_info: + log_data["schema"] = schema_info + self.user_logs_collection.insert_one(log_data) # End of stream @@ -185,6 +204,8 @@ tool_calls = [] thought = "" stream_ended = False + is_structured = False + schema_info = None for line in stream: try: @@ -195,6 +216,10 @@ conversation_id = event["id"] elif event["type"] == "answer": response_full += event["answer"] + elif event["type"] == "structured_answer": + response_full = event["answer"] + is_structured = True + schema_info = event.get("schema") elif event["type"] == "source": source_log_docs = event["source"] elif event["type"] == "tool_calls": tool_calls = event["tool_calls"] @@ -212,7 +237,8 @@ if
not stream_ended: logger.error("Stream ended unexpectedly without an 'end' event.") return None, None, None, None, "Stream ended unexpectedly" - return ( + + result = ( conversation_id, response_full, source_log_docs, @@ -221,6 +247,11 @@ class BaseAnswerResource: None, ) + if is_structured: + result = result + ({"structured": True, "schema": schema_info},) + + return result + def error_stream_generate(self, err_response): data = json.dumps({"type": "error", "error": err_response}) yield f"data: {data}\n\n" diff --git a/application/api/answer/services/stream_processor.py b/application/api/answer/services/stream_processor.py index ac725898..dfcfcdd2 100644 --- a/application/api/answer/services/stream_processor.py +++ b/application/api/answer/services/stream_processor.py @@ -192,6 +192,7 @@ class StreamProcessor: "prompt_id": data_key.get("prompt_id", "default"), "agent_type": data_key.get("agent_type", settings.AGENT_NAME), "user_api_key": api_key, + "json_schema": data_key.get("json_schema"), } ) self.initial_user_id = data_key.get("user") @@ -203,6 +204,7 @@ class StreamProcessor: "prompt_id": data_key.get("prompt_id", "default"), "agent_type": data_key.get("agent_type", settings.AGENT_NAME), "user_api_key": self.agent_key, + "json_schema": data_key.get("json_schema"), } ) self.decoded_token = ( @@ -216,6 +218,7 @@ class StreamProcessor: "prompt_id": self.data.get("prompt_id", "default"), "agent_type": settings.AGENT_NAME, "user_api_key": None, + "json_schema": None, } ) @@ -243,6 +246,7 @@ class StreamProcessor: chat_history=self.history, decoded_token=self.decoded_token, attachments=self.attachments, + json_schema=self.agent_config.get("json_schema"), ) def create_retriever(self): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index a6c0d55b..ae9ccb31 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1127,6 +1127,7 @@ class GetAgent(Resource): "tool_details": resolve_tool_details(agent.get("tools", 
[])), "agent_type": agent.get("agent_type", ""), "status": agent.get("status", ""), + "json_schema": agent.get("json_schema"), "created_at": agent.get("createdAt", ""), "updated_at": agent.get("updatedAt", ""), "last_used_at": agent.get("lastUsedAt", ""), @@ -1181,6 +1182,7 @@ class GetAgents(Resource): "tool_details": resolve_tool_details(agent.get("tools", [])), "agent_type": agent.get("agent_type", ""), "status": agent.get("status", ""), + "json_schema": agent.get("json_schema"), "created_at": agent.get("createdAt", ""), "updated_at": agent.get("updatedAt", ""), "last_used_at": agent.get("lastUsedAt", ""), @@ -1226,6 +1228,9 @@ class CreateAgent(Resource): "status": fields.String( required=True, description="Status of the agent (draft or published)" ), + "json_schema": fields.Raw( + required=False, description="JSON schema for enforcing structured output format" + ), }, ) @@ -1244,7 +1249,35 @@ class CreateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "json_schema" in data: + try: + data["json_schema"] = json.loads(data["json_schema"]) + except json.JSONDecodeError: + data["json_schema"] = None print(f"Received data: {data}") + + # Validate JSON schema if provided + if data.get("json_schema"): + try: + # Basic validation - ensure it's a valid JSON structure + json_schema = data.get("json_schema") + if not isinstance(json_schema, dict): + return make_response( + jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}), + 400 + ) + + # Validate that it has either a 'schema' property or is itself a schema + if "schema" not in json_schema and "type" not in json_schema: + return make_response( + jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}), + 400 + ) + except Exception as e: + return make_response( + jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}), + 400 + ) 
if data.get("status") not in ["draft", "published"]: return make_response( @@ -1302,6 +1335,7 @@ class CreateAgent(Resource): "tools": data.get("tools", []), "agent_type": data.get("agent_type", ""), "status": data.get("status"), + "json_schema": data.get("json_schema"), "createdAt": datetime.datetime.now(datetime.timezone.utc), "updatedAt": datetime.datetime.now(datetime.timezone.utc), "lastUsedAt": None, @@ -1342,6 +1376,9 @@ class UpdateAgent(Resource): "status": fields.String( required=True, description="Status of the agent (draft or published)" ), + "json_schema": fields.Raw( + required=False, description="JSON schema for enforcing structured output format" + ), }, ) @@ -1360,6 +1397,11 @@ class UpdateAgent(Resource): data["tools"] = json.loads(data["tools"]) except json.JSONDecodeError: data["tools"] = [] + if "json_schema" in data: + try: + data["json_schema"] = json.loads(data["json_schema"]) + except json.JSONDecodeError: + data["json_schema"] = None if not ObjectId.is_valid(agent_id): return make_response( @@ -1405,6 +1447,7 @@ class UpdateAgent(Resource): "tools", "agent_type", "status", + "json_schema", ] for field in allowed_fields: @@ -1797,6 +1840,7 @@ class SharedAgent(Resource): "tool_details": resolve_tool_details(shared_agent.get("tools", [])), "agent_type": shared_agent.get("agent_type", ""), "status": shared_agent.get("status", ""), + "json_schema": shared_agent.get("json_schema"), "created_at": shared_agent.get("createdAt", ""), "updated_at": shared_agent.get("updatedAt", ""), "shared": shared_agent.get("shared_publicly", False), @@ -1874,6 +1918,7 @@ class SharedAgents(Resource): "tool_details": resolve_tool_details(agent.get("tools", [])), "agent_type": agent.get("agent_type", ""), "status": agent.get("status", ""), + "json_schema": agent.get("json_schema"), "created_at": agent.get("createdAt", ""), "updated_at": agent.get("updatedAt", ""), "pinned": str(agent["_id"]) in pinned_ids, diff --git a/application/llm/base.py 
b/application/llm/base.py index bef3e11f..b7f3c262 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -120,6 +120,20 @@ class BaseLLM(ABC): def _supports_tools(self): raise NotImplementedError("Subclass must implement _supports_tools method") + def supports_structured_output(self): + """Check if the LLM supports structured output/JSON schema enforcement""" + return hasattr(self, "_supports_structured_output") and callable( + getattr(self, "_supports_structured_output") + ) + + def _supports_structured_output(self): + return False + + def prepare_structured_output_format(self, json_schema): + """Prepare structured output format specific to the LLM provider""" + _ = json_schema + return None + def get_supported_attachment_types(self): """ Return a list of MIME types supported by this LLM for file uploads. @@ -127,4 +141,4 @@ class BaseLLM(ABC): Returns: list: List of supported MIME types """ - return [] # Default: no attachments supported + return [] diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index b749431b..91065b74 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -1,11 +1,13 @@ +import json +import logging + from google import genai from google.genai import types -import logging -import json + +from application.core.settings import settings from application.llm.base import BaseLLM from application.storage.storage_creator import StorageCreator -from application.core.settings import settings class GoogleLLM(BaseLLM): @@ -24,12 +26,12 @@ class GoogleLLM(BaseLLM): list: List of supported MIME types """ return [ - 'application/pdf', - 'image/png', - 'image/jpeg', - 'image/jpg', - 'image/webp', - 'image/gif' + "application/pdf", + "image/png", + "image/jpeg", + "image/jpg", + "image/webp", + "image/gif", ] def prepare_messages_with_attachments(self, messages, attachments=None): @@ -70,26 +72,30 @@ class GoogleLLM(BaseLLM): files = [] for attachment in attachments: - mime_type = 
attachment.get('mime_type') + mime_type = attachment.get("mime_type") if mime_type in self.get_supported_attachment_types(): try: file_uri = self._upload_file_to_google(attachment) - logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}") + logging.info( + f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}" + ) files.append({"file_uri": file_uri, "mime_type": mime_type}) except Exception as e: - logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True) - if 'content' in attachment: - prepared_messages[user_message_index]["content"].append({ - "type": "text", - "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]" - }) + logging.error( + f"GoogleLLM: Error uploading file: {e}", exc_info=True + ) + if "content" in attachment: + prepared_messages[user_message_index]["content"].append( + { + "type": "text", + "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]", + } + ) if files: logging.info(f"GoogleLLM: Adding {len(files)} files to message") - prepared_messages[user_message_index]["content"].append({ - "files": files - }) + prepared_messages[user_message_index]["content"].append({"files": files}) return prepared_messages @@ -103,10 +109,10 @@ class GoogleLLM(BaseLLM): Returns: str: Google AI file URI for the uploaded file. 
""" - if 'google_file_uri' in attachment: - return attachment['google_file_uri'] + if "google_file_uri" in attachment: + return attachment["google_file_uri"] - file_path = attachment.get('path') + file_path = attachment.get("path") if not file_path: raise ValueError("No file path provided in attachment") @@ -116,17 +122,19 @@ class GoogleLLM(BaseLLM): try: file_uri = self.storage.process_file( file_path, - lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri + lambda local_path, **kwargs: self.client.files.upload( + file=local_path + ).uri, ) from application.core.mongo_db import MongoDB + mongo = MongoDB.get_client() db = mongo[settings.MONGO_DB_NAME] attachments_collection = db["attachments"] - if '_id' in attachment: + if "_id" in attachment: attachments_collection.update_one( - {"_id": attachment['_id']}, - {"$set": {"google_file_uri": file_uri}} + {"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}} ) return file_uri @@ -166,13 +174,13 @@ class GoogleLLM(BaseLLM): ) ) elif "files" in item: - for file_data in item["files"]: - parts.append( - types.Part.from_uri( - file_uri=file_data["file_uri"], - mime_type=file_data["mime_type"] - ) + for file_data in item["files"]: + parts.append( + types.Part.from_uri( + file_uri=file_data["file_uri"], + mime_type=file_data["mime_type"], ) + ) else: raise ValueError( f"Unexpected content dictionary format:{item}" @@ -231,6 +239,7 @@ class GoogleLLM(BaseLLM): stream=False, tools=None, formatting="openai", + response_schema=None, **kwargs, ): client = genai.Client(api_key=self.api_key) @@ -244,16 +253,21 @@ class GoogleLLM(BaseLLM): if tools: cleaned_tools = self._clean_tools_format(tools) config.tools = cleaned_tools - response = client.models.generate_content( - model=model, - contents=messages, - config=config, - ) + + # Add response schema for structured output if provided + if response_schema: + config.response_schema = response_schema + config.response_mime_type = "application/json" + 
+ response = client.models.generate_content( + model=model, + contents=messages, + config=config, + ) + + if tools: return response else: - response = client.models.generate_content( - model=model, contents=messages, config=config - ) return response.text def _raw_gen_stream( @@ -264,6 +278,7 @@ class GoogleLLM(BaseLLM): stream=True, tools=None, formatting="openai", + response_schema=None, **kwargs, ): client = genai.Client(api_key=self.api_key) @@ -278,17 +293,24 @@ class GoogleLLM(BaseLLM): cleaned_tools = self._clean_tools_format(tools) config.tools = cleaned_tools + # Add response schema for structured output if provided + if response_schema: + config.response_schema = response_schema + config.response_mime_type = "application/json" + # Check if we have both tools and file attachments has_attachments = False for message in messages: for part in message.parts: - if hasattr(part, 'file_data') and part.file_data is not None: + if hasattr(part, "file_data") and part.file_data is not None: has_attachments = True break if has_attachments: break - logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}") + logging.info( + f"GoogleLLM: Starting stream generation. 
Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}" + ) response = client.models.generate_content_stream( model=model, @@ -296,7 +318,6 @@ class GoogleLLM(BaseLLM): config=config, ) - for chunk in response: if hasattr(chunk, "candidates") and chunk.candidates: for candidate in chunk.candidates: @@ -311,3 +332,75 @@ class GoogleLLM(BaseLLM): def _supports_tools(self): return True + + def _supports_structured_output(self): + return True + + def prepare_structured_output_format(self, json_schema): + if not json_schema: + return None + + type_map = { + "object": "OBJECT", + "array": "ARRAY", + "string": "STRING", + "integer": "INTEGER", + "number": "NUMBER", + "boolean": "BOOLEAN", + } + + def convert(schema): + if not isinstance(schema, dict): + return schema + + result = {} + schema_type = schema.get("type") + if schema_type: + result["type"] = type_map.get(schema_type.lower(), schema_type.upper()) + + for key in [ + "description", + "nullable", + "enum", + "minItems", + "maxItems", + "required", + "propertyOrdering", + ]: + if key in schema: + result[key] = schema[key] + + if "format" in schema: + format_value = schema["format"] + if schema_type == "string": + if format_value == "date": + result["format"] = "date-time" + elif format_value in ["enum", "date-time"]: + result["format"] = format_value + else: + result["format"] = format_value + + if "properties" in schema: + result["properties"] = { + k: convert(v) for k, v in schema["properties"].items() + } + if "propertyOrdering" not in result and result.get("type") == "OBJECT": + result["propertyOrdering"] = list(result["properties"].keys()) + + if "items" in schema: + result["items"] = convert(schema["items"]) + + for field in ["anyOf", "oneOf", "allOf"]: + if field in schema: + result[field] = [convert(s) for s in schema[field]] + + return result + + try: + return convert(json_schema) + except Exception as e: + logging.error( + f"Error preparing structured output 
format for Google: {e}", + exc_info=True, + ) + return None diff --git a/application/llm/openai.py b/application/llm/openai.py index e363130b..618aa238 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -1,5 +1,5 @@ -import json import base64 +import json import logging from application.core.settings import settings @@ -13,7 +13,10 @@ class OpenAILLM(BaseLLM): from openai import OpenAI super().__init__(*args, **kwargs) - if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip(): + if ( + isinstance(settings.OPENAI_BASE_URL, str) + and settings.OPENAI_BASE_URL.strip() + ): self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL) else: DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1" @@ -73,14 +76,30 @@ class OpenAILLM(BaseLLM): elif isinstance(item, dict): content_parts = [] if "text" in item: - content_parts.append({"type": "text", "text": item["text"]}) - elif "type" in item and item["type"] == "text" and "text" in item: + content_parts.append( + {"type": "text", "text": item["text"]} + ) + elif ( + "type" in item + and item["type"] == "text" + and "text" in item + ): content_parts.append(item) - elif "type" in item and item["type"] == "file" and "file" in item: + elif ( + "type" in item + and item["type"] == "file" + and "file" in item + ): content_parts.append(item) - elif "type" in item and item["type"] == "image_url" and "image_url" in item: + elif ( + "type" in item + and item["type"] == "image_url" + and "image_url" in item + ): content_parts.append(item) - cleaned_messages.append({"role": role, "content": content_parts}) + cleaned_messages.append( + {"role": role, "content": content_parts} + ) else: raise ValueError( f"Unexpected content dictionary format: {item}" @@ -98,22 +117,29 @@ class OpenAILLM(BaseLLM): stream=False, tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, + response_format=None, **kwargs, ): messages = self._clean_messages_openai(messages) + + request_params = { + 
"model": model, + "messages": messages, + "stream": stream, + **kwargs, + } + + if tools: + request_params["tools"] = tools + + if response_format: + request_params["response_format"] = response_format + + response = self.client.chat.completions.create(**request_params) + if tools: - response = self.client.chat.completions.create( - model=model, - messages=messages, - stream=stream, - tools=tools, - **kwargs, - ) return response.choices[0] else: - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) return response.choices[0].message.content def _raw_gen_stream( @@ -124,24 +150,32 @@ class OpenAILLM(BaseLLM): stream=True, tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, + response_format=None, **kwargs, ): messages = self._clean_messages_openai(messages) + + request_params = { + "model": model, + "messages": messages, + "stream": stream, + **kwargs, + } + if tools: - response = self.client.chat.completions.create( - model=model, - messages=messages, - stream=stream, - tools=tools, - **kwargs, - ) - else: - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) + request_params["tools"] = tools + + if response_format: + request_params["response_format"] = response_format + + response = self.client.chat.completions.create(**request_params) for line in response: - if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0: + if ( + len(line.choices) > 0 + and line.choices[0].delta.content is not None + and len(line.choices[0].delta.content) > 0 + ): yield line.choices[0].delta.content elif len(line.choices) > 0: yield line.choices[0] @@ -149,6 +183,66 @@ class OpenAILLM(BaseLLM): def _supports_tools(self): return True + def _supports_structured_output(self): + return True + + def prepare_structured_output_format(self, json_schema): + if not json_schema: + return None + + try: + + def 
add_additional_properties_false(schema_obj): + if isinstance(schema_obj, dict): + schema_copy = schema_obj.copy() + + if schema_copy.get("type") == "object": + schema_copy["additionalProperties"] = False + # Ensure 'required' includes all properties for OpenAI strict mode + if "properties" in schema_copy: + schema_copy["required"] = list( + schema_copy["properties"].keys() + ) + + for key, value in schema_copy.items(): + if key == "properties" and isinstance(value, dict): + schema_copy[key] = { + prop_name: add_additional_properties_false(prop_schema) + for prop_name, prop_schema in value.items() + } + elif key == "items" and isinstance(value, dict): + schema_copy[key] = add_additional_properties_false(value) + elif key in ["anyOf", "oneOf", "allOf"] and isinstance( + value, list + ): + schema_copy[key] = [ + add_additional_properties_false(sub_schema) + for sub_schema in value + ] + + return schema_copy + return schema_obj + + processed_schema = add_additional_properties_false(json_schema) + + result = { + "type": "json_schema", + "json_schema": { + "name": processed_schema.get("name", "response"), + "description": processed_schema.get( + "description", "Structured response" + ), + "schema": processed_schema, + "strict": True, + }, + } + + return result + + except Exception as e: + logging.error(f"Error preparing structured output format: {e}") + return None + def get_supported_attachment_types(self): """ Return a list of MIME types supported by OpenAI for file uploads. 
@@ -157,12 +251,12 @@ class OpenAILLM(BaseLLM): list: List of supported MIME types """ return [ - 'application/pdf', - 'image/png', - 'image/jpeg', - 'image/jpg', - 'image/webp', - 'image/gif' + "application/pdf", + "image/png", + "image/jpeg", + "image/jpg", + "image/webp", + "image/gif", ] def prepare_messages_with_attachments(self, messages, attachments=None): @@ -202,39 +296,46 @@ class OpenAILLM(BaseLLM): prepared_messages[user_message_index]["content"] = [] for attachment in attachments: - mime_type = attachment.get('mime_type') + mime_type = attachment.get("mime_type") - if mime_type and mime_type.startswith('image/'): + if mime_type and mime_type.startswith("image/"): try: base64_image = self._get_base64_image(attachment) - prepared_messages[user_message_index]["content"].append({ - "type": "image_url", - "image_url": { - "url": f"data:{mime_type};base64,{base64_image}" + prepared_messages[user_message_index]["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_image}" + }, } - }) + ) except Exception as e: - logging.error(f"Error processing image attachment: {e}", exc_info=True) - if 'content' in attachment: - prepared_messages[user_message_index]["content"].append({ - "type": "text", - "text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]" - }) + logging.error( + f"Error processing image attachment: {e}", exc_info=True + ) + if "content" in attachment: + prepared_messages[user_message_index]["content"].append( + { + "type": "text", + "text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]", + } + ) # Handle PDFs using the file API - elif mime_type == 'application/pdf': + elif mime_type == "application/pdf": try: file_id = self._upload_file_to_openai(attachment) - prepared_messages[user_message_index]["content"].append({ - "type": "file", - "file": {"file_id": file_id} - }) + prepared_messages[user_message_index]["content"].append( + {"type": "file", "file": 
{"file_id": file_id}} + ) except Exception as e: logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True) - if 'content' in attachment: - prepared_messages[user_message_index]["content"].append({ - "type": "text", - "text": f"File content:\n\n{attachment['content']}" - }) + if "content" in attachment: + prepared_messages[user_message_index]["content"].append( + { + "type": "text", + "text": f"File content:\n\n{attachment['content']}", + } + ) return prepared_messages @@ -248,13 +349,13 @@ class OpenAILLM(BaseLLM): Returns: str: Base64-encoded image data. """ - file_path = attachment.get('path') + file_path = attachment.get("path") if not file_path: raise ValueError("No file path provided in attachment") try: with self.storage.get_file(file_path) as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') + return base64.b64encode(image_file.read()).decode("utf-8") except FileNotFoundError: raise FileNotFoundError(f"File not found: {file_path}") @@ -273,10 +374,10 @@ class OpenAILLM(BaseLLM): """ import logging - if 'openai_file_id' in attachment: - return attachment['openai_file_id'] + if "openai_file_id" in attachment: + return attachment["openai_file_id"] - file_path = attachment.get('path') + file_path = attachment.get("path") if not self.storage.file_exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") @@ -285,19 +386,18 @@ class OpenAILLM(BaseLLM): file_id = self.storage.process_file( file_path, lambda local_path, **kwargs: self.client.files.create( - file=open(local_path, 'rb'), - purpose="assistants" - ).id + file=open(local_path, "rb"), purpose="assistants" + ).id, ) from application.core.mongo_db import MongoDB + mongo = MongoDB.get_client() db = mongo[settings.MONGO_DB_NAME] attachments_collection = db["attachments"] - if '_id' in attachment: + if "_id" in attachment: attachments_collection.update_one( - {"_id": attachment['_id']}, - {"$set": {"openai_file_id": file_id}} + {"_id": attachment["_id"]}, 
{"$set": {"openai_file_id": file_id}} ) return file_id @@ -308,9 +408,7 @@ class OpenAILLM(BaseLLM): class AzureOpenAILLM(OpenAILLM): - def __init__( - self, api_key, user_api_key, *args, **kwargs - ): + def __init__(self, api_key, user_api_key, *args, **kwargs): super().__init__(api_key) self.api_base = (settings.OPENAI_API_BASE,) @@ -321,5 +419,5 @@ class AzureOpenAILLM(OpenAILLM): self.client = AzureOpenAI( api_key=api_key, api_version=settings.OPENAI_API_VERSION, - azure_endpoint=settings.OPENAI_API_BASE + azure_endpoint=settings.OPENAI_API_BASE, ) diff --git a/frontend/src/agents/NewAgent.tsx b/frontend/src/agents/NewAgent.tsx index 1e5e117a..da8cef5d 100644 --- a/frontend/src/agents/NewAgent.tsx +++ b/frontend/src/agents/NewAgent.tsx @@ -51,6 +51,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { tools: [], agent_type: '', status: '', + json_schema: undefined, }); const [imageFile, setImageFile] = useState(null); const [prompts, setPrompts] = useState< @@ -72,6 +73,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { const [hasChanges, setHasChanges] = useState(false); const [draftLoading, setDraftLoading] = useState(false); const [publishLoading, setPublishLoading] = useState(false); + const [jsonSchemaText, setJsonSchemaText] = useState(''); + const [jsonSchemaValid, setJsonSchemaValid] = useState(true); + const [isJsonSchemaExpanded, setIsJsonSchemaExpanded] = useState(false); const initialAgentRef = useRef(null); const sourceAnchorButtonRef = useRef(null); @@ -113,9 +117,15 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { ]; const isPublishable = () => { - return ( - agent.name && agent.description && agent.prompt_id && agent.agent_type - ); + const hasRequiredFields = + agent.name && agent.description && agent.prompt_id && agent.agent_type; + const isJsonSchemaValidOrEmpty = + jsonSchemaText.trim() === '' || jsonSchemaValid; + return hasRequiredFields 
&& isJsonSchemaValidOrEmpty; + }; + + const isJsonSchemaInvalid = () => { + return jsonSchemaText.trim() !== '' && !jsonSchemaValid; }; const handleUpload = useCallback((files: File[]) => { @@ -153,6 +163,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { formData.append('tools', JSON.stringify(agent.tools)); else formData.append('tools', '[]'); + if (agent.json_schema) { + formData.append('json_schema', JSON.stringify(agent.json_schema)); + } + try { setDraftLoading(true); const response = @@ -194,6 +208,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { formData.append('tools', JSON.stringify(agent.tools)); else formData.append('tools', '[]'); + if (agent.json_schema) { + formData.append('json_schema', JSON.stringify(agent.json_schema)); + } + try { setPublishLoading(true); const response = @@ -226,6 +244,22 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { } }; + const validateAndSetJsonSchema = (text: string) => { + setJsonSchemaText(text); + if (text.trim() === '') { + setAgent({ ...agent, json_schema: undefined }); + setJsonSchemaValid(true); + return; + } + try { + const parsed = JSON.parse(text); + setAgent({ ...agent, json_schema: parsed }); + setJsonSchemaValid(true); + } catch (error) { + setJsonSchemaValid(false); + } + }; + useEffect(() => { const getTools = async () => { const response = await userService.getUserTools(token); @@ -264,6 +298,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { setSelectedSourceIds(new Set([data.retriever])); if (data.tools) setSelectedToolIds(new Set(data.tools)); if (data.status === 'draft') setEffectiveMode('draft'); + if (data.json_schema) { + const jsonText = JSON.stringify(data.json_schema, null, 2); + setJsonSchemaText(jsonText); + setJsonSchemaValid(true); + } setAgent(data); initialAgentRef.current = data; }; @@ -317,10 +356,17 @@ export default function NewAgent({ mode }: 
{ mode: 'new' | 'edit' | 'draft' }) { setHasChanges(false); return; } + + const initialJsonSchemaText = initialAgentRef.current.json_schema + ? JSON.stringify(initialAgentRef.current.json_schema, null, 2) + : ''; + const isChanged = - !isEqual(agent, initialAgentRef.current) || imageFile !== null; + !isEqual(agent, initialAgentRef.current) || + imageFile !== null || + jsonSchemaText !== initialJsonSchemaText; setHasChanges(isChanged); - }, [agent, dispatch, effectiveMode, imageFile]); + }, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]); return (
@@ -356,7 +402,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) { )} {modeConfig[effectiveMode].showSaveDraft && (
+
+ + {isJsonSchemaExpanded && ( +
+
+

JSON response schema

+

+ Define a JSON schema to enforce structured output format +

+
+