feat: add support for structured output and JSON schema validation

This commit is contained in:
Siddhant Rai
2025-08-13 13:29:51 +05:30
parent 56831fbcf2
commit 896dcf1f9e
13 changed files with 660 additions and 153 deletions

View File

@@ -83,9 +83,24 @@ class AnswerResource(Resource, BaseAnswerResource):
index=None,
should_save_conversation=data.get("save_conversation", True),
)
conversation_id, response, sources, tool_calls, thought, error = (
self.process_response_stream(stream)
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if error:
return make_response({"error": error}, 400)
result = {
@@ -95,6 +110,9 @@ class AnswerResource(Resource, BaseAnswerResource):
"tool_calls": tool_calls,
"thought": thought,
}
if structured_info:
result.update(structured_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",

View File

@@ -79,12 +79,20 @@ class BaseAnswerResource:
"""
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
for line in agent.gen(query=question, retriever=retriever):
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
@@ -109,6 +117,17 @@ class BaseAnswerResource:
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
@@ -139,28 +158,28 @@ class BaseAnswerResource:
)
else:
conversation_id = None
# Send conversation ID
data = json.dumps({"type": "id", "id": str(conversation_id)})
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
# Log the interaction
retriever_params = retriever.get_params()
self.user_logs_collection.insert_one(
{
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if is_structured:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
self.user_logs_collection.insert_one(log_data)
# End of stream
@@ -185,6 +204,8 @@ class BaseAnswerResource:
tool_calls = []
thought = ""
stream_ended = False
is_structured = False
schema_info = None
for line in stream:
try:
@@ -195,6 +216,10 @@ class BaseAnswerResource:
conversation_id = event["id"]
elif event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "structured_answer":
response_full = event["answer"]
is_structured = True
schema_info = event.get("schema")
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
@@ -212,7 +237,8 @@ class BaseAnswerResource:
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly"
return (
result = (
conversation_id,
response_full,
source_log_docs,
@@ -221,6 +247,11 @@ class BaseAnswerResource:
None,
)
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"

View File

@@ -192,6 +192,7 @@ class StreamProcessor:
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
}
)
self.initial_user_id = data_key.get("user")
@@ -203,6 +204,7 @@ class StreamProcessor:
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
}
)
self.decoded_token = (
@@ -216,6 +218,7 @@ class StreamProcessor:
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"user_api_key": None,
"json_schema": None,
}
)
@@ -243,6 +246,7 @@ class StreamProcessor:
chat_history=self.history,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
)
def create_retriever(self):

View File

@@ -1127,6 +1127,7 @@ class GetAgent(Resource):
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"last_used_at": agent.get("lastUsedAt", ""),
@@ -1181,6 +1182,7 @@ class GetAgents(Resource):
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"last_used_at": agent.get("lastUsedAt", ""),
@@ -1226,6 +1228,9 @@ class CreateAgent(Resource):
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False, description="JSON schema for enforcing structured output format"
),
},
)
@@ -1244,7 +1249,35 @@ class CreateAgent(Resource):
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
try:
# Basic validation - the parsed schema must be a JSON object (a dict), not a list or scalar
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}),
400
)
# Require either a wrapper object with a 'schema' key or a bare schema with a top-level 'type' key
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}),
400
)
except Exception as e:
return make_response(
jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}),
400
)
if data.get("status") not in ["draft", "published"]:
return make_response(
@@ -1302,6 +1335,7 @@ class CreateAgent(Resource):
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
@@ -1342,6 +1376,9 @@ class UpdateAgent(Resource):
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"json_schema": fields.Raw(
required=False, description="JSON schema for enforcing structured output format"
),
},
)
@@ -1360,6 +1397,11 @@ class UpdateAgent(Resource):
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if not ObjectId.is_valid(agent_id):
return make_response(
@@ -1405,6 +1447,7 @@ class UpdateAgent(Resource):
"tools",
"agent_type",
"status",
"json_schema",
]
for field in allowed_fields:
@@ -1797,6 +1840,7 @@ class SharedAgent(Resource):
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"json_schema": shared_agent.get("json_schema"),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
@@ -1874,6 +1918,7 @@ class SharedAgents(Resource):
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"pinned": str(agent["_id"]) in pinned_ids,