Mirror of https://github.com/arc53/DocsGPT.git, synced 2025-12-02 01:53:14 +00:00
feat: add support for structured output and JSON schema validation
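The hunks below thread a json_schema field from the stored agent document through the agent CRUD endpoints and the streaming answer pipeline. For orientation, a sketch of the two value shapes that the validation added in CreateAgent accepts (a bare JSON schema carrying a "type" key, or an object wrapping one under "schema"); the property names inside the schemas are illustrative only, not taken from the diff:

# Illustrative json_schema values; both pass the CreateAgent validation below.
bare_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["title"],
}
wrapped_schema = {"schema": bare_schema}  # the same schema wrapped under "schema"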
@@ -83,9 +83,24 @@ class AnswerResource(Resource, BaseAnswerResource):
                 index=None,
                 should_save_conversation=data.get("save_conversation", True),
             )
-            conversation_id, response, sources, tool_calls, thought, error = (
-                self.process_response_stream(stream)
-            )
+            stream_result = self.process_response_stream(stream)
+
+            if len(stream_result) == 7:
+                (
+                    conversation_id,
+                    response,
+                    sources,
+                    tool_calls,
+                    thought,
+                    error,
+                    structured_info,
+                ) = stream_result
+            else:
+                conversation_id, response, sources, tool_calls, thought, error = (
+                    stream_result
+                )
+                structured_info = None
+
             if error:
                 return make_response({"error": error}, 400)
             result = {
@@ -95,6 +110,9 @@ class AnswerResource(Resource, BaseAnswerResource):
                 "tool_calls": tool_calls,
                 "thought": thought,
             }
+
+            if structured_info:
+                result.update(structured_info)
         except Exception as e:
             logger.error(
                 f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
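When structured output was detected, the only change to the non-streaming /api/answer body is that the keys from structured_info are merged into the result. Per process_response_stream further down, structured_info is {"structured": True, "schema": schema_info}, so the response simply gains those two keys. A runnable sketch with placeholder values (the endpoint also returns fields not visible in this hunk):

# Placeholder values; the real result dict also carries answer/sources/etc.
result = {"tool_calls": [], "thought": ""}
structured_info = {"structured": True, "schema": {"type": "object"}}

if structured_info:
    result.update(structured_info)

print(result)
# {'tool_calls': [], 'thought': '', 'structured': True, 'schema': {'type': 'object'}}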
@@ -79,12 +79,20 @@ class BaseAnswerResource:
         """
         try:
             response_full, thought, source_log_docs, tool_calls = "", "", [], []
+            is_structured = False
+            schema_info = None
+            structured_chunks = []
 
             for line in agent.gen(query=question, retriever=retriever):
                 if "answer" in line:
                     response_full += str(line["answer"])
-                    data = json.dumps({"type": "answer", "answer": line["answer"]})
-                    yield f"data: {data}\n\n"
+                    if line.get("structured"):
+                        is_structured = True
+                        schema_info = line.get("schema")
+                        structured_chunks.append(line["answer"])
+                    else:
+                        data = json.dumps({"type": "answer", "answer": line["answer"]})
+                        yield f"data: {data}\n\n"
                 elif "sources" in line:
                     truncated_sources = []
                     source_log_docs = line["sources"]
@@ -109,6 +117,17 @@ class BaseAnswerResource:
                 elif "type" in line:
                     data = json.dumps(line)
                     yield f"data: {data}\n\n"
+
+            if is_structured and structured_chunks:
+                structured_data = {
+                    "type": "structured_answer",
+                    "answer": response_full,
+                    "structured": True,
+                    "schema": schema_info,
+                }
+                data = json.dumps(structured_data)
+                yield f"data: {data}\n\n"
+
             if isNoneDoc:
                 for doc in source_log_docs:
                     doc["source"] = "None"
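On the streaming path, clients now receive one extra SSE event of type "structured_answer" after the per-chunk events, carrying the accumulated answer and the agent's schema. The event field names below come from the diff; the callback wiring, and treating the structured "answer" payload as JSON text, are assumptions for illustration only:

import json

def handle_sse_line(raw_line, on_text, on_structured):
    """Parse one 'data: {...}' line from the answer stream (sketch)."""
    if not raw_line.startswith("data: "):
        return
    event = json.loads(raw_line[len("data: "):])
    if event.get("type") == "answer":
        on_text(event["answer"])  # incremental plain-text chunk
    elif event.get("type") == "structured_answer":
        # Full answer plus the schema; parsing the answer as JSON is an assumption.
        on_structured(event["answer"], event.get("schema"))

# Example with a hand-written event line and placeholder callbacks:
handle_sse_line(
    'data: {"type": "structured_answer", "answer": "{\\"total\\": 42}", '
    '"structured": true, "schema": {"type": "object"}}',
    on_text=print,
    on_structured=lambda answer, schema: print(json.loads(answer), schema),
)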
@@ -139,28 +158,28 @@ class BaseAnswerResource:
                 )
             else:
                 conversation_id = None
             # Send conversation ID
 
-            data = json.dumps({"type": "id", "id": str(conversation_id)})
+            id_data = {"type": "id", "id": str(conversation_id)}
+            data = json.dumps(id_data)
             yield f"data: {data}\n\n"
 
             # Log the interaction
 
             retriever_params = retriever.get_params()
-            self.user_logs_collection.insert_one(
-                {
-                    "action": "stream_answer",
-                    "level": "info",
-                    "user": decoded_token.get("sub"),
-                    "api_key": user_api_key,
-                    "question": question,
-                    "response": response_full,
-                    "sources": source_log_docs,
-                    "retriever_params": retriever_params,
-                    "attachments": attachment_ids,
-                    "timestamp": datetime.datetime.now(datetime.timezone.utc),
-                }
-            )
+            log_data = {
+                "action": "stream_answer",
+                "level": "info",
+                "user": decoded_token.get("sub"),
+                "api_key": user_api_key,
+                "question": question,
+                "response": response_full,
+                "sources": source_log_docs,
+                "retriever_params": retriever_params,
+                "attachments": attachment_ids,
+                "timestamp": datetime.datetime.now(datetime.timezone.utc),
+            }
+            if is_structured:
+                log_data["structured_output"] = True
+                if schema_info:
+                    log_data["schema"] = schema_info
+            self.user_logs_collection.insert_one(log_data)
 
             # End of stream
@@ -185,6 +204,8 @@ class BaseAnswerResource:
         tool_calls = []
         thought = ""
         stream_ended = False
+        is_structured = False
+        schema_info = None
 
         for line in stream:
             try:
@@ -195,6 +216,10 @@ class BaseAnswerResource:
                     conversation_id = event["id"]
                 elif event["type"] == "answer":
                     response_full += event["answer"]
+                elif event["type"] == "structured_answer":
+                    response_full = event["answer"]
+                    is_structured = True
+                    schema_info = event.get("schema")
                 elif event["type"] == "source":
                     source_log_docs = event["source"]
                 elif event["type"] == "tool_calls":
@@ -212,7 +237,8 @@ class BaseAnswerResource:
         if not stream_ended:
             logger.error("Stream ended unexpectedly without an 'end' event.")
             return None, None, None, None, "Stream ended unexpectedly"
-        return (
+
+        result = (
             conversation_id,
             response_full,
             source_log_docs,
@@ -221,6 +247,11 @@ class BaseAnswerResource:
             None,
         )
 
+        if is_structured:
+            result = result + ({"structured": True, "schema": schema_info},)
+
+        return result
+
     def error_stream_generate(self, err_response):
         data = json.dumps({"type": "error", "error": err_response})
         yield f"data: {data}\n\n"
@@ -192,6 +192,7 @@ class StreamProcessor:
                     "prompt_id": data_key.get("prompt_id", "default"),
                     "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
                     "user_api_key": api_key,
+                    "json_schema": data_key.get("json_schema"),
                 }
             )
             self.initial_user_id = data_key.get("user")
@@ -203,6 +204,7 @@ class StreamProcessor:
                     "prompt_id": data_key.get("prompt_id", "default"),
                     "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
                     "user_api_key": self.agent_key,
+                    "json_schema": data_key.get("json_schema"),
                 }
             )
             self.decoded_token = (
@@ -216,6 +218,7 @@ class StreamProcessor:
                     "prompt_id": self.data.get("prompt_id", "default"),
                     "agent_type": settings.AGENT_NAME,
                     "user_api_key": None,
+                    "json_schema": None,
                 }
             )
@@ -243,6 +246,7 @@ class StreamProcessor:
             chat_history=self.history,
             decoded_token=self.decoded_token,
             attachments=self.attachments,
+            json_schema=self.agent_config.get("json_schema"),
         )
 
     def create_retriever(self):
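Here json_schema is read from the request or the stored agent and handed to the agent via create_agent; how the agent forwards it to the LLM is not part of this diff. As a sketch only, a helper that normalizes the two stored shapes accepted by the CreateAgent validation into a single schema dict a provider call could consume; the function name and its use are hypothetical:

def normalize_json_schema(json_schema):
    """Return the bare JSON schema from either accepted shape, else None."""
    if not isinstance(json_schema, dict):
        return None
    if "schema" in json_schema:   # wrapped form: {"schema": {...}}
        return json_schema["schema"]
    if "type" in json_schema:     # bare JSON schema
        return json_schema
    return None

assert normalize_json_schema({"schema": {"type": "object"}}) == {"type": "object"}
assert normalize_json_schema({"type": "object"}) == {"type": "object"}
assert normalize_json_schema(None) is None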
@@ -1127,6 +1127,7 @@ class GetAgent(Resource):
             "tool_details": resolve_tool_details(agent.get("tools", [])),
             "agent_type": agent.get("agent_type", ""),
             "status": agent.get("status", ""),
+            "json_schema": agent.get("json_schema"),
             "created_at": agent.get("createdAt", ""),
             "updated_at": agent.get("updatedAt", ""),
             "last_used_at": agent.get("lastUsedAt", ""),
@@ -1181,6 +1182,7 @@ class GetAgents(Resource):
                 "tool_details": resolve_tool_details(agent.get("tools", [])),
                 "agent_type": agent.get("agent_type", ""),
                 "status": agent.get("status", ""),
+                "json_schema": agent.get("json_schema"),
                 "created_at": agent.get("createdAt", ""),
                 "updated_at": agent.get("updatedAt", ""),
                 "last_used_at": agent.get("lastUsedAt", ""),
@@ -1226,6 +1228,9 @@ class CreateAgent(Resource):
             "status": fields.String(
                 required=True, description="Status of the agent (draft or published)"
             ),
+            "json_schema": fields.Raw(
+                required=False, description="JSON schema for enforcing structured output format"
+            ),
         },
     )
@@ -1244,7 +1249,35 @@ class CreateAgent(Resource):
                 data["tools"] = json.loads(data["tools"])
             except json.JSONDecodeError:
                 data["tools"] = []
+        if "json_schema" in data:
+            try:
+                data["json_schema"] = json.loads(data["json_schema"])
+            except json.JSONDecodeError:
+                data["json_schema"] = None
         print(f"Received data: {data}")
+
+        # Validate JSON schema if provided
+        if data.get("json_schema"):
+            try:
+                # Basic validation - ensure it's a valid JSON structure
+                json_schema = data.get("json_schema")
+                if not isinstance(json_schema, dict):
+                    return make_response(
+                        jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}),
+                        400
+                    )
+
+                # Validate that it has either a 'schema' property or is itself a schema
+                if "schema" not in json_schema and "type" not in json_schema:
+                    return make_response(
+                        jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}),
+                        400
+                    )
+            except Exception as e:
+                return make_response(
+                    jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}),
+                    400
+                )
+
         if data.get("status") not in ["draft", "published"]:
             return make_response(
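Because both CreateAgent and UpdateAgent call json.loads on the incoming json_schema, the field is expected to arrive as a JSON string (e.g. a form field) rather than a nested object. A hypothetical client call; the endpoint URL and the extra form fields are assumptions, only the status constraint and the stringified json_schema come from the diff:

import json
import requests

form = {
    "name": "Invoice extractor",  # assumed field, not shown in the diff
    "status": "draft",            # must be "draft" or "published"
    "json_schema": json.dumps(
        {
            "type": "object",
            "properties": {"total": {"type": "number"}},
            "required": ["total"],
        }
    ),
}
# URL is a placeholder; point it at the actual agent-creation endpoint.
response = requests.post("http://localhost:7091/api/create_agent", data=form)
print(response.status_code)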
@@ -1302,6 +1335,7 @@ class CreateAgent(Resource):
             "tools": data.get("tools", []),
             "agent_type": data.get("agent_type", ""),
             "status": data.get("status"),
+            "json_schema": data.get("json_schema"),
             "createdAt": datetime.datetime.now(datetime.timezone.utc),
             "updatedAt": datetime.datetime.now(datetime.timezone.utc),
             "lastUsedAt": None,
@@ -1342,6 +1376,9 @@ class UpdateAgent(Resource):
             "status": fields.String(
                 required=True, description="Status of the agent (draft or published)"
             ),
+            "json_schema": fields.Raw(
+                required=False, description="JSON schema for enforcing structured output format"
+            ),
         },
     )
@@ -1360,6 +1397,11 @@ class UpdateAgent(Resource):
                 data["tools"] = json.loads(data["tools"])
             except json.JSONDecodeError:
                 data["tools"] = []
+        if "json_schema" in data:
+            try:
+                data["json_schema"] = json.loads(data["json_schema"])
+            except json.JSONDecodeError:
+                data["json_schema"] = None
 
         if not ObjectId.is_valid(agent_id):
             return make_response(
@@ -1405,6 +1447,7 @@ class UpdateAgent(Resource):
             "tools",
             "agent_type",
             "status",
+            "json_schema",
         ]
 
         for field in allowed_fields:
@@ -1797,6 +1840,7 @@ class SharedAgent(Resource):
             "tool_details": resolve_tool_details(shared_agent.get("tools", [])),
             "agent_type": shared_agent.get("agent_type", ""),
             "status": shared_agent.get("status", ""),
+            "json_schema": shared_agent.get("json_schema"),
             "created_at": shared_agent.get("createdAt", ""),
             "updated_at": shared_agent.get("updatedAt", ""),
             "shared": shared_agent.get("shared_publicly", False),
@@ -1874,6 +1918,7 @@ class SharedAgents(Resource):
                 "tool_details": resolve_tool_details(agent.get("tools", [])),
                 "agent_type": agent.get("agent_type", ""),
                 "status": agent.get("status", ""),
+                "json_schema": agent.get("json_schema"),
                 "created_at": agent.get("createdAt", ""),
                 "updated_at": agent.get("updatedAt", ""),
                 "pinned": str(agent["_id"]) in pinned_ids,