mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #1924 from siiddhantt/feat/agent-schema-response
feat: add support for structured output and JSON schema validation
This commit is contained in:
@@ -28,6 +28,7 @@ class BaseAgent(ABC):
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
json_schema: Optional[Dict] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
@@ -51,6 +52,7 @@ class BaseAgent(ABC):
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = json_schema
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
@@ -283,6 +285,21 @@ class BaseAgent(ABC):
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
|
||||
if (
|
||||
self.json_schema
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
):
|
||||
structured_format = self.llm.prepare_structured_output_format(
|
||||
self.json_schema
|
||||
)
|
||||
if structured_format:
|
||||
if self.llm_name == "openai":
|
||||
gen_kwargs["response_format"] = structured_format
|
||||
elif self.llm_name == "google":
|
||||
gen_kwargs["response_schema"] = structured_format
|
||||
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
@@ -307,11 +324,25 @@ class BaseAgent(ABC):
|
||||
return resp
|
||||
|
||||
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||
is_structured_output = (
|
||||
self.json_schema is not None
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
)
|
||||
|
||||
if isinstance(response, str):
|
||||
yield {"answer": response}
|
||||
answer_data = {"answer": response}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||
yield {"answer": response.message.content}
|
||||
answer_data = {"answer": response.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
processed_response_gen = self._llm_handler(
|
||||
response, tools_dict, messages, log_context, self.attachments
|
||||
@@ -319,8 +350,16 @@ class BaseAgent(ABC):
|
||||
|
||||
for event in processed_response_gen:
|
||||
if isinstance(event, str):
|
||||
yield {"answer": event}
|
||||
answer_data = {"answer": event}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif hasattr(event, "message") and getattr(event.message, "content", None):
|
||||
yield {"answer": event.message.content}
|
||||
answer_data = {"answer": event.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif isinstance(event, dict) and "type" in event:
|
||||
yield event
|
||||
|
||||
@@ -83,9 +83,24 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
)
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
self.process_response_stream(stream)
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if len(stream_result) == 7:
|
||||
(
|
||||
conversation_id,
|
||||
response,
|
||||
sources,
|
||||
tool_calls,
|
||||
thought,
|
||||
error,
|
||||
structured_info,
|
||||
) = stream_result
|
||||
else:
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
stream_result
|
||||
)
|
||||
structured_info = None
|
||||
|
||||
if error:
|
||||
return make_response({"error": error}, 400)
|
||||
result = {
|
||||
@@ -95,6 +110,9 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
}
|
||||
|
||||
if structured_info:
|
||||
result.update(structured_info)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
|
||||
@@ -79,12 +79,20 @@ class BaseAnswerResource:
|
||||
"""
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
|
||||
for line in agent.gen(query=question, retriever=retriever):
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
@@ -109,6 +117,17 @@ class BaseAnswerResource:
|
||||
elif "type" in line:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
@@ -139,15 +158,12 @@ class BaseAnswerResource:
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
# Send conversation ID
|
||||
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Log the interaction
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
log_entry = {
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
@@ -159,13 +175,17 @@ class BaseAnswerResource:
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
|
||||
if is_structured:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_entry.items():
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_entry[key] = value[:10000]
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_entry)
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
# End of stream
|
||||
|
||||
@@ -190,6 +210,8 @@ class BaseAnswerResource:
|
||||
tool_calls = []
|
||||
thought = ""
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
@@ -200,6 +222,10 @@ class BaseAnswerResource:
|
||||
conversation_id = event["id"]
|
||||
elif event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "structured_answer":
|
||||
response_full = event["answer"]
|
||||
is_structured = True
|
||||
schema_info = event.get("schema")
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
@@ -217,7 +243,8 @@ class BaseAnswerResource:
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly"
|
||||
return (
|
||||
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
@@ -226,6 +253,11 @@ class BaseAnswerResource:
|
||||
None,
|
||||
)
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
@@ -192,6 +192,7 @@ class StreamProcessor:
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": api_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
}
|
||||
)
|
||||
self.initial_user_id = data_key.get("user")
|
||||
@@ -203,6 +204,7 @@ class StreamProcessor:
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": self.agent_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
}
|
||||
)
|
||||
self.decoded_token = (
|
||||
@@ -216,6 +218,7 @@ class StreamProcessor:
|
||||
"prompt_id": self.data.get("prompt_id", "default"),
|
||||
"agent_type": settings.AGENT_NAME,
|
||||
"user_api_key": None,
|
||||
"json_schema": None,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -243,6 +246,7 @@ class StreamProcessor:
|
||||
chat_history=self.history,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
)
|
||||
|
||||
def create_retriever(self):
|
||||
|
||||
@@ -1127,6 +1127,7 @@ class GetAgent(Resource):
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
@@ -1181,6 +1182,7 @@ class GetAgents(Resource):
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
@@ -1226,6 +1228,9 @@ class CreateAgent(Resource):
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False, description="JSON schema for enforcing structured output format"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1244,7 +1249,35 @@ class CreateAgent(Resource):
|
||||
data["tools"] = json.loads(data["tools"])
|
||||
except json.JSONDecodeError:
|
||||
data["tools"] = []
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
if data.get("json_schema"):
|
||||
try:
|
||||
# Basic validation - ensure it's a valid JSON structure
|
||||
json_schema = data.get("json_schema")
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "JSON schema must be a valid JSON object"}),
|
||||
400
|
||||
)
|
||||
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property"}),
|
||||
400
|
||||
)
|
||||
except Exception as e:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": f"Invalid JSON schema: {str(e)}"}),
|
||||
400
|
||||
)
|
||||
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
return make_response(
|
||||
@@ -1302,6 +1335,7 @@ class CreateAgent(Resource):
|
||||
"tools": data.get("tools", []),
|
||||
"agent_type": data.get("agent_type", ""),
|
||||
"status": data.get("status"),
|
||||
"json_schema": data.get("json_schema"),
|
||||
"createdAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
@@ -1342,6 +1376,9 @@ class UpdateAgent(Resource):
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False, description="JSON schema for enforcing structured output format"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1360,6 +1397,11 @@ class UpdateAgent(Resource):
|
||||
data["tools"] = json.loads(data["tools"])
|
||||
except json.JSONDecodeError:
|
||||
data["tools"] = []
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
|
||||
if not ObjectId.is_valid(agent_id):
|
||||
return make_response(
|
||||
@@ -1405,6 +1447,7 @@ class UpdateAgent(Resource):
|
||||
"tools",
|
||||
"agent_type",
|
||||
"status",
|
||||
"json_schema",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
@@ -1797,6 +1840,7 @@ class SharedAgent(Resource):
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
@@ -1874,6 +1918,7 @@ class SharedAgents(Resource):
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
|
||||
@@ -120,6 +120,20 @@ class BaseLLM(ABC):
|
||||
def _supports_tools(self):
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
|
||||
def supports_structured_output(self):
|
||||
"""Check if the LLM supports structured output/JSON schema enforcement"""
|
||||
return hasattr(self, "_supports_structured_output") and callable(
|
||||
getattr(self, "_supports_structured_output")
|
||||
)
|
||||
|
||||
def _supports_structured_output(self):
|
||||
return False
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
"""Prepare structured output format specific to the LLM provider"""
|
||||
_ = json_schema
|
||||
return None
|
||||
|
||||
def get_supported_attachment_types(self):
|
||||
"""
|
||||
Return a list of MIME types supported by this LLM for file uploads.
|
||||
@@ -127,4 +141,4 @@ class BaseLLM(ABC):
|
||||
Returns:
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
return [] # Default: no attachments supported
|
||||
return []
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
import logging
|
||||
import json
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
@@ -24,12 +26,12 @@ class GoogleLLM(BaseLLM):
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
return [
|
||||
'application/pdf',
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/jpg',
|
||||
'image/webp',
|
||||
'image/gif'
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -70,26 +72,30 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get('mime_type')
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type in self.get_supported_attachment_types():
|
||||
try:
|
||||
file_uri = self._upload_file_to_google(attachment)
|
||||
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
|
||||
logging.info(
|
||||
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
|
||||
)
|
||||
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
||||
except Exception as e:
|
||||
logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True)
|
||||
if 'content' in attachment:
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
|
||||
})
|
||||
logging.error(
|
||||
f"GoogleLLM: Error uploading file: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"files": files
|
||||
})
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
|
||||
return prepared_messages
|
||||
|
||||
@@ -103,10 +109,10 @@ class GoogleLLM(BaseLLM):
|
||||
Returns:
|
||||
str: Google AI file URI for the uploaded file.
|
||||
"""
|
||||
if 'google_file_uri' in attachment:
|
||||
return attachment['google_file_uri']
|
||||
if "google_file_uri" in attachment:
|
||||
return attachment["google_file_uri"]
|
||||
|
||||
file_path = attachment.get('path')
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
@@ -116,17 +122,19 @@ class GoogleLLM(BaseLLM):
|
||||
try:
|
||||
file_uri = self.storage.process_file(
|
||||
file_path,
|
||||
lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri
|
||||
lambda local_path, **kwargs: self.client.files.upload(
|
||||
file=local_path
|
||||
).uri,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
if '_id' in attachment:
|
||||
if "_id" in attachment:
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment['_id']},
|
||||
{"$set": {"google_file_uri": file_uri}}
|
||||
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
|
||||
)
|
||||
|
||||
return file_uri
|
||||
@@ -166,13 +174,13 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
)
|
||||
elif "files" in item:
|
||||
for file_data in item["files"]:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"]
|
||||
)
|
||||
for file_data in item["files"]:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
@@ -231,6 +239,7 @@ class GoogleLLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
formatting="openai",
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
@@ -244,16 +253,21 @@ class GoogleLLM(BaseLLM):
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
if tools:
|
||||
return response
|
||||
else:
|
||||
response = client.models.generate_content(
|
||||
model=model, contents=messages, config=config
|
||||
)
|
||||
return response.text
|
||||
|
||||
def _raw_gen_stream(
|
||||
@@ -264,6 +278,7 @@ class GoogleLLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
formatting="openai",
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
@@ -278,17 +293,24 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
# Check if we have both tools and file attachments
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
if hasattr(part, 'file_data') and part.file_data is not None:
|
||||
if hasattr(part, "file_data") and part.file_data is not None:
|
||||
has_attachments = True
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
|
||||
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
|
||||
logging.info(
|
||||
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
@@ -296,7 +318,6 @@ class GoogleLLM(BaseLLM):
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
@@ -311,3 +332,75 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
type_map = {
|
||||
"object": "OBJECT",
|
||||
"array": "ARRAY",
|
||||
"string": "STRING",
|
||||
"integer": "INTEGER",
|
||||
"number": "NUMBER",
|
||||
"boolean": "BOOLEAN",
|
||||
}
|
||||
|
||||
def convert(schema):
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
result = {}
|
||||
schema_type = schema.get("type")
|
||||
if schema_type:
|
||||
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
|
||||
|
||||
for key in [
|
||||
"description",
|
||||
"nullable",
|
||||
"enum",
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"required",
|
||||
"propertyOrdering",
|
||||
]:
|
||||
if key in schema:
|
||||
result[key] = schema[key]
|
||||
|
||||
if "format" in schema:
|
||||
format_value = schema["format"]
|
||||
if schema_type == "string":
|
||||
if format_value == "date":
|
||||
result["format"] = "date-time"
|
||||
elif format_value in ["enum", "date-time"]:
|
||||
result["format"] = format_value
|
||||
else:
|
||||
result["format"] = format_value
|
||||
|
||||
if "properties" in schema:
|
||||
result["properties"] = {
|
||||
k: convert(v) for k, v in schema["properties"].items()
|
||||
}
|
||||
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
|
||||
result["propertyOrdering"] = list(result["properties"].keys())
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert(schema["items"])
|
||||
|
||||
for field in ["anyOf", "oneOf", "allOf"]:
|
||||
if field in schema:
|
||||
result[field] = [convert(s) for s in schema[field]]
|
||||
|
||||
return result
|
||||
|
||||
try:
|
||||
return convert(json_schema)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error preparing structured output format for Google: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from application.core.settings import settings
|
||||
@@ -13,7 +13,10 @@ class OpenAILLM(BaseLLM):
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip():
|
||||
if (
|
||||
isinstance(settings.OPENAI_BASE_URL, str)
|
||||
and settings.OPENAI_BASE_URL.strip()
|
||||
):
|
||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||
else:
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||
@@ -73,14 +76,30 @@ class OpenAILLM(BaseLLM):
|
||||
elif isinstance(item, dict):
|
||||
content_parts = []
|
||||
if "text" in item:
|
||||
content_parts.append({"type": "text", "text": item["text"]})
|
||||
elif "type" in item and item["type"] == "text" and "text" in item:
|
||||
content_parts.append(
|
||||
{"type": "text", "text": item["text"]}
|
||||
)
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "text"
|
||||
and "text" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "file" and "file" in item:
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "file"
|
||||
and "file" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "image_url"
|
||||
and "image_url" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
cleaned_messages.append({"role": role, "content": content_parts})
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": content_parts}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format: {item}"
|
||||
@@ -98,22 +117,29 @@ class OpenAILLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
@@ -124,24 +150,32 @@ class OpenAILLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
for line in response:
|
||||
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
@@ -149,6 +183,66 @@ class OpenAILLM(BaseLLM):
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
try:
|
||||
|
||||
def add_additional_properties_false(schema_obj):
|
||||
if isinstance(schema_obj, dict):
|
||||
schema_copy = schema_obj.copy()
|
||||
|
||||
if schema_copy.get("type") == "object":
|
||||
schema_copy["additionalProperties"] = False
|
||||
# Ensure 'required' includes all properties for OpenAI strict mode
|
||||
if "properties" in schema_copy:
|
||||
schema_copy["required"] = list(
|
||||
schema_copy["properties"].keys()
|
||||
)
|
||||
|
||||
for key, value in schema_copy.items():
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
schema_copy[key] = {
|
||||
prop_name: add_additional_properties_false(prop_schema)
|
||||
for prop_name, prop_schema in value.items()
|
||||
}
|
||||
elif key == "items" and isinstance(value, dict):
|
||||
schema_copy[key] = add_additional_properties_false(value)
|
||||
elif key in ["anyOf", "oneOf", "allOf"] and isinstance(
|
||||
value, list
|
||||
):
|
||||
schema_copy[key] = [
|
||||
add_additional_properties_false(sub_schema)
|
||||
for sub_schema in value
|
||||
]
|
||||
|
||||
return schema_copy
|
||||
return schema_obj
|
||||
|
||||
processed_schema = add_additional_properties_false(json_schema)
|
||||
|
||||
result = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": processed_schema.get("name", "response"),
|
||||
"description": processed_schema.get(
|
||||
"description", "Structured response"
|
||||
),
|
||||
"schema": processed_schema,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error preparing structured output format: {e}")
|
||||
return None
|
||||
|
||||
def get_supported_attachment_types(self):
|
||||
"""
|
||||
Return a list of MIME types supported by OpenAI for file uploads.
|
||||
@@ -157,12 +251,12 @@ class OpenAILLM(BaseLLM):
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
return [
|
||||
'application/pdf',
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/jpg',
|
||||
'image/webp',
|
||||
'image/gif'
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -202,39 +296,46 @@ class OpenAILLM(BaseLLM):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get('mime_type')
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type and mime_type.startswith('image/'):
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
try:
|
||||
base64_image = self._get_base64_image(attachment)
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{base64_image}"
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{base64_image}"
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing image attachment: {e}", exc_info=True)
|
||||
if 'content' in attachment:
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "text",
|
||||
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
|
||||
})
|
||||
logging.error(
|
||||
f"Error processing image attachment: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
# Handle PDFs using the file API
|
||||
elif mime_type == 'application/pdf':
|
||||
elif mime_type == "application/pdf":
|
||||
try:
|
||||
file_id = self._upload_file_to_openai(attachment)
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "file",
|
||||
"file": {"file_id": file_id}
|
||||
})
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{"type": "file", "file": {"file_id": file_id}}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
|
||||
if 'content' in attachment:
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "text",
|
||||
"text": f"File content:\n\n{attachment['content']}"
|
||||
})
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"File content:\n\n{attachment['content']}",
|
||||
}
|
||||
)
|
||||
|
||||
return prepared_messages
|
||||
|
||||
@@ -248,13 +349,13 @@ class OpenAILLM(BaseLLM):
|
||||
Returns:
|
||||
str: Base64-encoded image data.
|
||||
"""
|
||||
file_path = attachment.get('path')
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
try:
|
||||
with self.storage.get_file(file_path) as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
@@ -273,10 +374,10 @@ class OpenAILLM(BaseLLM):
|
||||
"""
|
||||
import logging
|
||||
|
||||
if 'openai_file_id' in attachment:
|
||||
return attachment['openai_file_id']
|
||||
if "openai_file_id" in attachment:
|
||||
return attachment["openai_file_id"]
|
||||
|
||||
file_path = attachment.get('path')
|
||||
file_path = attachment.get("path")
|
||||
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
@@ -285,19 +386,18 @@ class OpenAILLM(BaseLLM):
|
||||
file_id = self.storage.process_file(
|
||||
file_path,
|
||||
lambda local_path, **kwargs: self.client.files.create(
|
||||
file=open(local_path, 'rb'),
|
||||
purpose="assistants"
|
||||
).id
|
||||
file=open(local_path, "rb"), purpose="assistants"
|
||||
).id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
if '_id' in attachment:
|
||||
if "_id" in attachment:
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment['_id']},
|
||||
{"$set": {"openai_file_id": file_id}}
|
||||
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
|
||||
)
|
||||
|
||||
return file_id
|
||||
@@ -308,9 +408,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
|
||||
def __init__(
|
||||
self, api_key, user_api_key, *args, **kwargs
|
||||
):
|
||||
def __init__(self, api_key, user_api_key, *args, **kwargs):
|
||||
|
||||
super().__init__(api_key)
|
||||
self.api_base = (settings.OPENAI_API_BASE,)
|
||||
@@ -321,5 +419,5 @@ class AzureOpenAILLM(OpenAILLM):
|
||||
self.client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
api_version=settings.OPENAI_API_VERSION,
|
||||
azure_endpoint=settings.OPENAI_API_BASE
|
||||
azure_endpoint=settings.OPENAI_API_BASE,
|
||||
)
|
||||
|
||||
@@ -51,6 +51,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
tools: [],
|
||||
agent_type: '',
|
||||
status: '',
|
||||
json_schema: undefined,
|
||||
});
|
||||
const [imageFile, setImageFile] = useState<File | null>(null);
|
||||
const [prompts, setPrompts] = useState<
|
||||
@@ -72,6 +73,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const [hasChanges, setHasChanges] = useState(false);
|
||||
const [draftLoading, setDraftLoading] = useState(false);
|
||||
const [publishLoading, setPublishLoading] = useState(false);
|
||||
const [jsonSchemaText, setJsonSchemaText] = useState('');
|
||||
const [jsonSchemaValid, setJsonSchemaValid] = useState(true);
|
||||
const [isJsonSchemaExpanded, setIsJsonSchemaExpanded] = useState(false);
|
||||
|
||||
const initialAgentRef = useRef<Agent | null>(null);
|
||||
const sourceAnchorButtonRef = useRef<HTMLButtonElement>(null);
|
||||
@@ -113,9 +117,15 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
];
|
||||
|
||||
const isPublishable = () => {
|
||||
return (
|
||||
agent.name && agent.description && agent.prompt_id && agent.agent_type
|
||||
);
|
||||
const hasRequiredFields =
|
||||
agent.name && agent.description && agent.prompt_id && agent.agent_type;
|
||||
const isJsonSchemaValidOrEmpty =
|
||||
jsonSchemaText.trim() === '' || jsonSchemaValid;
|
||||
return hasRequiredFields && isJsonSchemaValidOrEmpty;
|
||||
};
|
||||
|
||||
const isJsonSchemaInvalid = () => {
|
||||
return jsonSchemaText.trim() !== '' && !jsonSchemaValid;
|
||||
};
|
||||
|
||||
const handleUpload = useCallback((files: File[]) => {
|
||||
@@ -153,6 +163,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('tools', JSON.stringify(agent.tools));
|
||||
else formData.append('tools', '[]');
|
||||
|
||||
if (agent.json_schema) {
|
||||
formData.append('json_schema', JSON.stringify(agent.json_schema));
|
||||
}
|
||||
|
||||
try {
|
||||
setDraftLoading(true);
|
||||
const response =
|
||||
@@ -194,6 +208,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('tools', JSON.stringify(agent.tools));
|
||||
else formData.append('tools', '[]');
|
||||
|
||||
if (agent.json_schema) {
|
||||
formData.append('json_schema', JSON.stringify(agent.json_schema));
|
||||
}
|
||||
|
||||
try {
|
||||
setPublishLoading(true);
|
||||
const response =
|
||||
@@ -226,6 +244,22 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
}
|
||||
};
|
||||
|
||||
const validateAndSetJsonSchema = (text: string) => {
|
||||
setJsonSchemaText(text);
|
||||
if (text.trim() === '') {
|
||||
setAgent({ ...agent, json_schema: undefined });
|
||||
setJsonSchemaValid(true);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(text);
|
||||
setAgent({ ...agent, json_schema: parsed });
|
||||
setJsonSchemaValid(true);
|
||||
} catch (error) {
|
||||
setJsonSchemaValid(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const getTools = async () => {
|
||||
const response = await userService.getUserTools(token);
|
||||
@@ -264,6 +298,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
setSelectedSourceIds(new Set([data.retriever]));
|
||||
if (data.tools) setSelectedToolIds(new Set(data.tools));
|
||||
if (data.status === 'draft') setEffectiveMode('draft');
|
||||
if (data.json_schema) {
|
||||
const jsonText = JSON.stringify(data.json_schema, null, 2);
|
||||
setJsonSchemaText(jsonText);
|
||||
setJsonSchemaValid(true);
|
||||
}
|
||||
setAgent(data);
|
||||
initialAgentRef.current = data;
|
||||
};
|
||||
@@ -317,10 +356,17 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
setHasChanges(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const initialJsonSchemaText = initialAgentRef.current.json_schema
|
||||
? JSON.stringify(initialAgentRef.current.json_schema, null, 2)
|
||||
: '';
|
||||
|
||||
const isChanged =
|
||||
!isEqual(agent, initialAgentRef.current) || imageFile !== null;
|
||||
!isEqual(agent, initialAgentRef.current) ||
|
||||
imageFile !== null ||
|
||||
jsonSchemaText !== initialJsonSchemaText;
|
||||
setHasChanges(isChanged);
|
||||
}, [agent, dispatch, effectiveMode, imageFile]);
|
||||
}, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]);
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<div className="flex items-center gap-3 px-4">
|
||||
@@ -356,7 +402,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
)}
|
||||
{modeConfig[effectiveMode].showSaveDraft && (
|
||||
<button
|
||||
className="hover:bg-vi</button>olets-are-blue border-violets-are-blue text-violets-are-blue hover:bg-violets-are-blue w-28 rounded-3xl border border-solid py-2 text-sm font-medium transition-colors hover:text-white"
|
||||
disabled={isJsonSchemaInvalid()}
|
||||
className={`border-violets-are-blue text-violets-are-blue hover:bg-violets-are-blue w-28 rounded-3xl border border-solid py-2 text-sm font-medium transition-colors hover:text-white ${
|
||||
isJsonSchemaInvalid() ? 'cursor-not-allowed opacity-30' : ''
|
||||
}`}
|
||||
onClick={handleSaveDraft}
|
||||
>
|
||||
<span className="flex items-center justify-center transition-all duration-200">
|
||||
@@ -602,6 +651,78 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<button
|
||||
onClick={() => setIsJsonSchemaExpanded(!isJsonSchemaExpanded)}
|
||||
className="flex w-full items-center justify-between text-left focus:outline-none"
|
||||
>
|
||||
<div>
|
||||
<h2 className="text-lg font-semibold">Advanced</h2>
|
||||
</div>
|
||||
<div className="ml-4 flex items-center">
|
||||
<svg
|
||||
className={`h-5 w-5 transform transition-transform duration-200 ${
|
||||
isJsonSchemaExpanded ? 'rotate-180' : ''
|
||||
}`}
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</button>
|
||||
{isJsonSchemaExpanded && (
|
||||
<div className="mt-3">
|
||||
<div>
|
||||
<h2 className="text-sm font-medium">JSON response schema</h2>
|
||||
<p className="mt-1 text-xs text-gray-600 dark:text-gray-400">
|
||||
Define a JSON schema to enforce structured output format
|
||||
</p>
|
||||
</div>
|
||||
<textarea
|
||||
value={jsonSchemaText}
|
||||
onChange={(e) => validateAndSetJsonSchema(e.target.value)}
|
||||
placeholder={`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string"}
|
||||
},
|
||||
"required": ["name", "email"],
|
||||
"additionalProperties": false
|
||||
}`}
|
||||
rows={9}
|
||||
className={`border-silver text-jet dark:bg-raisin-black dark:text-bright-gray mt-2 w-full rounded-2xl border bg-white px-4 py-3 font-mono text-sm outline-hidden dark:border-[#7E7E7E]`}
|
||||
/>
|
||||
{jsonSchemaText.trim() !== '' && (
|
||||
<div
|
||||
className={`mt-2 flex items-center gap-2 text-sm ${
|
||||
jsonSchemaValid
|
||||
? 'text-green-600 dark:text-green-400'
|
||||
: 'text-red-600 dark:text-red-400'
|
||||
}`}
|
||||
>
|
||||
<span
|
||||
className={`h-4 w-4 bg-contain bg-center bg-no-repeat ${
|
||||
jsonSchemaValid
|
||||
? "bg-[url('/src/assets/circle-check.svg')]"
|
||||
: "bg-[url('/src/assets/circle-x.svg')]"
|
||||
}`}
|
||||
/>
|
||||
{jsonSchemaValid
|
||||
? 'Valid JSON'
|
||||
: 'Invalid JSON - fix to enable saving'}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="col-span-3 flex flex-col gap-3 rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
|
||||
<h2 className="text-lg font-semibold">Preview</h2>
|
||||
|
||||
@@ -96,6 +96,17 @@ export const fetchPreviewAnswer = createAsyncThunk<
|
||||
message: data.error,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'structured_answer') {
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
index: targetIndex,
|
||||
query: {
|
||||
response: data.answer,
|
||||
structured: data.structured,
|
||||
schema: data.schema,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
@@ -201,6 +212,14 @@ export const agentPreviewSlice = createSlice({
|
||||
state.queries[index].response =
|
||||
(state.queries[index].response || '') + query.response;
|
||||
}
|
||||
|
||||
if (query.structured !== undefined) {
|
||||
state.queries[index].structured = query.structured;
|
||||
}
|
||||
|
||||
if (query.schema !== undefined) {
|
||||
state.queries[index].schema = query.schema;
|
||||
}
|
||||
},
|
||||
updateThought(
|
||||
state,
|
||||
|
||||
@@ -26,4 +26,5 @@ export type Agent = {
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
last_used_at?: string;
|
||||
json_schema?: object;
|
||||
};
|
||||
|
||||
@@ -33,6 +33,8 @@ export interface Answer {
|
||||
thought: string;
|
||||
sources: { title: string; text: string; source: string }[];
|
||||
tool_calls: ToolCallsType[];
|
||||
structured?: boolean;
|
||||
schema?: object;
|
||||
}
|
||||
|
||||
export interface Query {
|
||||
@@ -46,6 +48,8 @@ export interface Query {
|
||||
tool_calls?: ToolCallsType[];
|
||||
error?: string;
|
||||
attachments?: { id: string; fileName: string }[];
|
||||
structured?: boolean;
|
||||
schema?: object;
|
||||
}
|
||||
|
||||
export interface RetrievalPayload {
|
||||
|
||||
@@ -130,6 +130,18 @@ export const fetchAnswer = createAsyncThunk<
|
||||
message: data.error,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'structured_answer') {
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
conversationId: currentConversationId,
|
||||
index: targetIndex,
|
||||
query: {
|
||||
response: data.answer,
|
||||
structured: data.structured,
|
||||
schema: data.schema,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateStreamingQuery({
|
||||
@@ -250,6 +262,14 @@ export const conversationSlice = createSlice({
|
||||
state.queries[index].response =
|
||||
(state.queries[index].response || '') + query.response;
|
||||
}
|
||||
|
||||
if (query.structured !== undefined) {
|
||||
state.queries[index].structured = query.structured;
|
||||
}
|
||||
|
||||
if (query.schema !== undefined) {
|
||||
state.queries[index].schema = query.schema;
|
||||
}
|
||||
},
|
||||
updateConversationId(
|
||||
state,
|
||||
|
||||
Reference in New Issue
Block a user