Refactor agent creation and update logic to improve error handling and default values; enhance logging for better traceability

This commit is contained in:
Siddhant Rai
2025-10-01 13:56:31 +05:30
parent 82beafc086
commit ba49eea23d
5 changed files with 373 additions and 181 deletions

View File

@@ -1726,7 +1726,7 @@ class CreateAgent(Resource):
"key": key,
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "0"
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
@@ -1782,43 +1782,56 @@ class UpdateAgent(Resource):
@api.doc(description="Update an existing agent")
def put(self, agent_id):
if not (decoded_token := request.decoded_token):
return {"success": False}, 401
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
if request.content_type == "application/json":
data = request.get_json()
else:
data = request.form.to_dict()
if "tools" in data:
try:
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "sources" in data:
try:
data["sources"] = json.loads(data["sources"])
except json.JSONDecodeError:
data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if not ObjectId.is_valid(agent_id):
return make_response(
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
)
oid = ObjectId(agent_id)
try:
if request.content_type and "application/json" in request.content_type:
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
for field in json_fields:
if field in data and data[field]:
try:
data[field] = json.loads(data[field])
except json.JSONDecodeError:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid JSON format for field: {field}",
}
),
400,
)
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Invalid request data"}), 400
)
try:
existing_agent = agents_collection.find_one({"_id": oid, "user": user})
except Exception as err:
current_app.logger.error(
f"Error finding agent {agent_id}: {err}", exc_info=True
)
return make_response(
current_app.logger.error(
f"Error finding agent {agent_id}: {err}", exc_info=True
),
jsonify({"success": False, "message": "Database error finding agent"}),
500,
)
if not existing_agent:
return make_response(
jsonify(
@@ -1826,13 +1839,19 @@ class UpdateAgent(Resource):
),
404,
)
image_url, error = handle_image_upload(
request, existing_agent.get("image", ""), user, storage
)
if error:
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
current_app.logger.error(
f"Image upload error for agent {agent_id}: {error}"
)
return make_response(
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
400,
)
update_fields = {}
allowed_fields = [
"name",
@@ -1850,116 +1869,189 @@ class UpdateAgent(Resource):
]
for field in allowed_fields:
if field in data:
if field == "status":
new_status = data.get("status")
if new_status not in ["draft", "published"]:
return make_response(
jsonify(
{"success": False, "message": "Invalid status value"}
),
400,
)
update_fields[field] = new_status
elif field == "source":
source_id = data.get("source")
if source_id == "default":
# Handle special "default" source
if field not in data:
continue
update_fields[field] = "default"
elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
elif source_id:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid source ID format provided",
}
),
400,
)
else:
update_fields[field] = ""
elif field == "sources":
sources_list = data.get("sources", [])
if sources_list and isinstance(sources_list, list):
valid_sources = []
for source_id in sources_list:
if source_id == "default":
valid_sources.append("default")
elif ObjectId.is_valid(source_id):
valid_sources.append(
DBRef("sources", ObjectId(source_id))
)
else:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid source ID format: {source_id}",
}
),
400,
)
update_fields[field] = valid_sources
else:
update_fields[field] = []
elif field == "chunks":
chunks_value = data.get("chunks")
if chunks_value == "":
update_fields[field] = "0"
else:
try:
if int(chunks_value) < 0:
return make_response(
jsonify(
{
"success": False,
"message": "Chunks value must be a positive integer",
}
),
400,
)
update_fields[field] = chunks_value
except ValueError:
if field == "status":
new_status = data.get("status")
if new_status not in ["draft", "published"]:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid status value. Must be 'draft' or 'published'",
}
),
400,
)
update_fields[field] = new_status
elif field == "source":
source_id = data.get("source")
if source_id == "default":
update_fields[field] = "default"
elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
elif source_id:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid source ID format: {source_id}",
}
),
400,
)
else:
update_fields[field] = ""
elif field == "sources":
sources_list = data.get("sources", [])
if sources_list and isinstance(sources_list, list):
valid_sources = []
for source_id in sources_list:
if source_id == "default":
valid_sources.append("default")
elif ObjectId.is_valid(source_id):
valid_sources.append(DBRef("sources", ObjectId(source_id)))
else:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid chunks value provided",
"message": f"Invalid source ID in list: {source_id}",
}
),
400,
)
update_fields[field] = valid_sources
else:
update_fields[field] = data[field]
update_fields[field] = []
elif field == "chunks":
chunks_value = data.get("chunks")
if chunks_value == "" or chunks_value is None:
update_fields[field] = "2"
else:
try:
chunks_int = int(chunks_value)
if chunks_int < 0:
return make_response(
jsonify(
{
"success": False,
"message": "Chunks value must be a non-negative integer",
}
),
400,
)
update_fields[field] = str(chunks_int)
except (ValueError, TypeError):
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid chunks value: {chunks_value}",
}
),
400,
)
elif field == "tools":
tools_list = data.get("tools", [])
if isinstance(tools_list, list):
update_fields[field] = tools_list
else:
return make_response(
jsonify(
{
"success": False,
"message": "Tools must be a list",
}
),
400,
)
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
if not value or not str(value).strip():
return make_response(
jsonify(
{
"success": False,
"message": f"Field '{field}' cannot be empty",
}
),
400,
)
update_fields[field] = value
if image_url:
update_fields["image"] = image_url
if not update_fields:
return make_response(
jsonify({"success": False, "message": "No update data provided"}), 400
jsonify(
{
"success": False,
"message": "No valid update data provided",
}
),
400,
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
if final_status == "published":
required_published_fields = [
"name",
"description",
"source",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
missing_published_fields = []
for req_field in required_published_fields:
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if req_field == "source" and final_value:
if not isinstance(final_value, DBRef):
missing_published_fields.append(req_field)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
@@ -1970,9 +2062,11 @@ class UpdateAgent(Resource):
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
try:
@@ -1985,20 +2079,22 @@ class UpdateAgent(Resource):
jsonify(
{
"success": False,
"message": "Agent not found or update failed unexpectedly",
"message": "Agent not found or update failed",
}
),
404,
)
if result.modified_count == 0 and result.matched_count == 1:
return make_response(
jsonify(
{
"success": True,
"message": "Agent found, but no changes were applied",
"message": "No changes detected",
"id": agent_id,
}
),
304,
200,
)
except Exception as err:
current_app.logger.error(
@@ -2008,6 +2104,7 @@ class UpdateAgent(Resource):
jsonify({"success": False, "message": "Database error during update"}),
500,
)
response_data = {
"success": True,
"id": agent_id,
@@ -2015,10 +2112,8 @@ class UpdateAgent(Resource):
}
if newly_generated_key:
response_data["key"] = newly_generated_key
return make_response(
jsonify(response_data),
200,
)
return make_response(jsonify(response_data), 200)
@user_ns.route("/api/delete_agent")

View File

@@ -36,6 +36,11 @@ class ClassicRAG(BaseRetriever):
self.chunks = 2
else:
self.chunks = chunks
user_identifier = user_api_key if user_api_key else "default"
logging.info(
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.gpt_model = gpt_model
self.token_limit = (
token_limit
@@ -92,17 +97,12 @@ class ClassicRAG(BaseRetriever):
or not self.vectorstores
):
return self.original_question
prompt = f"""Given the following conversation history:
{self.chat_history}
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""
prompt = (
"Given the following conversation history:\n"
f"{self.chat_history}\n\n"
"Rephrase the following user question to be a standalone search query "
"that captures all relevant context from the conversation:\n"
)
messages = [
{"role": "system", "content": prompt},
@@ -120,10 +120,20 @@ class ClassicRAG(BaseRetriever):
def _get_data(self):
"""Retrieve relevant documents from configured vectorstores"""
if self.chunks == 0 or not self.vectorstores:
logging.info(
f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
)
return []
all_docs = []
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
logging.info(
f"ClassicRAG._get_data: Starting retrieval with chunks={self.chunks}, "
f"vectorstores={self.vectorstores}, chunks_per_source={chunks_per_source}, "
f"query='{self.question[:50]}...'"
)
for vectorstore_id in self.vectorstores:
if vectorstore_id:
try:
@@ -172,6 +182,10 @@ class ClassicRAG(BaseRetriever):
exc_info=True,
)
continue
logging.info(
f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source})"
)
return all_docs
def search(self, query: str = ""):

View File

@@ -1,3 +1,4 @@
import logging
import os
from abc import ABC, abstractmethod
@@ -9,13 +10,27 @@ from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs
)
self.dimension = self.model.get_sentence_embedding_dimension()
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
try:
kwargs.setdefault("trust_remote_code", True)
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs,
)
if self.model is None or self.model._first_module() is None:
raise ValueError(
f"SentenceTransformer model failed to load properly for: {model_name}"
)
self.dimension = self.model.get_sentence_embedding_dimension()
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
except Exception as e:
logging.error(
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
exc_info=True,
)
raise
def embed_query(self, query: str):
return self.model.encode(query).tolist()
@@ -117,15 +132,29 @@ class BaseVectorStore(ABC):
embeddings_name, openai_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./models/all-mpnet-base-v2"):
possible_paths = [
"/app/models/all-mpnet-base-v2", # Docker absolute path
"./models/all-mpnet-base-v2", # Relative path
]
local_model_path = None
for path in possible_paths:
if os.path.exists(path):
local_model_path = path
logging.info(f"Found local model at path: {path}")
break
else:
logging.info(f"Path does not exist: {path}")
if local_model_path:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name="./models/all-mpnet-base-v2",
local_model_path,
)
else:
logging.warning(
f"Local model not found in any of the paths: {possible_paths}. Falling back to HuggingFace download."
)
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance

View File

@@ -46,11 +46,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
image: '',
source: '',
sources: [],
chunks: '',
retriever: '',
chunks: '2',
retriever: 'classic',
prompt_id: 'default',
tools: [],
agent_type: '',
agent_type: 'classic',
status: '',
json_schema: undefined,
});
@@ -122,7 +122,8 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
agent.name && agent.description && agent.prompt_id && agent.agent_type;
const isJsonSchemaValidOrEmpty =
jsonSchemaText.trim() === '' || jsonSchemaValid;
return hasRequiredFields && isJsonSchemaValidOrEmpty;
const hasSource = selectedSourceIds.size > 0;
return hasRequiredFields && isJsonSchemaValidOrEmpty && hasSource;
};
const isJsonSchemaInvalid = () => {
@@ -353,6 +354,26 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
getPrompts();
}, [token]);
// Auto-select default source if none selected
useEffect(() => {
if (sourceDocs && sourceDocs.length > 0 && selectedSourceIds.size === 0) {
const defaultSource = sourceDocs.find((s) => s.name === 'Default');
if (defaultSource) {
setSelectedSourceIds(
new Set([
defaultSource.id || defaultSource.retriever || defaultSource.name,
]),
);
} else {
setSelectedSourceIds(
new Set([
sourceDocs[0].id || sourceDocs[0].retriever || sourceDocs[0].name,
]),
);
}
}
}, [sourceDocs, selectedSourceIds.size]);
useEffect(() => {
if ((mode === 'edit' || mode === 'draft') && agentId) {
const getAgent = async () => {
@@ -650,7 +671,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}
selectedIds={selectedSourceIds}
onSelectionChange={(newSelectedIds: Set<string | number>) => {
setSelectedSourceIds(newSelectedIds);
if (
newSelectedIds.size === 0 &&
sourceDocs &&
sourceDocs.length > 0
) {
const defaultSource = sourceDocs.find(
(s) => s.name === 'Default',
);
if (defaultSource) {
setSelectedSourceIds(
new Set([
defaultSource.id ||
defaultSource.retriever ||
defaultSource.name,
]),
);
} else {
setSelectedSourceIds(
new Set([
sourceDocs[0].id ||
sourceDocs[0].retriever ||
sourceDocs[0].name,
]),
);
}
} else {
setSelectedSourceIds(newSelectedIds);
}
}}
title="Select Sources"
searchPlaceholder="Search sources..."

View File

@@ -6,7 +6,6 @@ interface TableProps {
minWidth?: string;
}
interface TableContainerProps {
children: React.ReactNode;
className?: string;
@@ -34,45 +33,44 @@ interface TableCellProps {
align?: 'left' | 'right' | 'center';
}
const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(({
children,
className = '',
height = 'auto',
bordered = true
}, ref) => {
return (
<div className={`relative rounded-[6px] ${className}`}>
<div
ref={ref}
className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`}
style={{
maxHeight: height === 'auto' ? undefined : height,
overflowY: height === 'auto' ? 'hidden' : 'auto'
}}
>
{children}
const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(
({ children, className = '', height = 'auto', bordered = true }, ref) => {
return (
<div className={`relative rounded-[6px] ${className}`}>
<div
ref={ref}
className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`}
style={{
maxHeight: height === 'auto' ? undefined : height,
overflowY: height === 'auto' ? 'hidden' : 'auto',
}}
>
{children}
</div>
</div>
</div>
);
});;
);
},
);
TableContainer.displayName = 'TableContainer';
const Table: React.FC<TableProps> = ({
children,
className = '',
minWidth = 'min-w-[600px]'
minWidth = 'min-w-[600px]',
}) => {
return (
<table className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}>
<table
className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}
>
{children}
</table>
);
};
const TableHead: React.FC<TableHeadProps> = ({ children, className = '' }) => {
return (
<thead className={`
sticky top-0 z-10
bg-gray-100 dark:bg-[#27282D]
${className}
`}>
<thead
className={`sticky top-0 z-10 bg-gray-100 dark:bg-[#27282D] ${className} `}
>
{children}
</thead>
);
@@ -86,12 +84,20 @@ const TableBody: React.FC<TableHeadProps> = ({ children, className = '' }) => {
);
};
const TableRow: React.FC<TableRowProps> = ({ children, className = '', onClick }) => {
const baseClasses = "border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]";
const cursorClass = onClick ? "cursor-pointer" : "";
const TableRow: React.FC<TableRowProps> = ({
children,
className = '',
onClick,
}) => {
const baseClasses =
'border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]';
const cursorClass = onClick ? 'cursor-pointer' : '';
return (
<tr className={`${baseClasses} ${cursorClass} ${className}`} onClick={onClick}>
<tr
className={`${baseClasses} ${cursorClass} ${className}`}
onClick={onClick}
>
{children}
</tr>
);
@@ -102,7 +108,7 @@ const TableHeader: React.FC<TableCellProps> = ({
className = '',
minWidth,
width,
align = 'left'
align = 'left',
}) => {
const getAlignmentClass = () => {
switch (align) {
@@ -133,7 +139,7 @@ const TableCell: React.FC<TableCellProps> = ({
className = '',
minWidth,
width,
align = 'left'
align = 'left',
}) => {
const getAlignmentClass = () => {
switch (align) {