Refactor agent creation and update logic to improve error handling and default values; enhance logging for better traceability

This commit is contained in:
Siddhant Rai
2025-10-01 13:56:31 +05:30
parent 82beafc086
commit ba49eea23d
5 changed files with 373 additions and 181 deletions

View File

@@ -1726,7 +1726,7 @@ class CreateAgent(Resource):
"key": key, "key": key,
} }
if new_agent["chunks"] == "": if new_agent["chunks"] == "":
new_agent["chunks"] = "0" new_agent["chunks"] = "2"
if ( if (
new_agent["source"] == "" new_agent["source"] == ""
and new_agent["retriever"] == "" and new_agent["retriever"] == ""
@@ -1782,43 +1782,56 @@ class UpdateAgent(Resource):
@api.doc(description="Update an existing agent") @api.doc(description="Update an existing agent")
def put(self, agent_id): def put(self, agent_id):
if not (decoded_token := request.decoded_token): if not (decoded_token := request.decoded_token):
return {"success": False}, 401 return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub") user = decoded_token.get("sub")
if request.content_type == "application/json":
data = request.get_json()
else:
data = request.form.to_dict()
if "tools" in data:
try:
data["tools"] = json.loads(data["tools"])
except json.JSONDecodeError:
data["tools"] = []
if "sources" in data:
try:
data["sources"] = json.loads(data["sources"])
except json.JSONDecodeError:
data["sources"] = []
if "json_schema" in data:
try:
data["json_schema"] = json.loads(data["json_schema"])
except json.JSONDecodeError:
data["json_schema"] = None
if not ObjectId.is_valid(agent_id): if not ObjectId.is_valid(agent_id):
return make_response( return make_response(
jsonify({"success": False, "message": "Invalid agent ID format"}), 400 jsonify({"success": False, "message": "Invalid agent ID format"}), 400
) )
oid = ObjectId(agent_id) oid = ObjectId(agent_id)
try:
if request.content_type and "application/json" in request.content_type:
data = request.get_json()
else:
data = request.form.to_dict()
json_fields = ["tools", "sources", "json_schema"]
for field in json_fields:
if field in data and data[field]:
try:
data[field] = json.loads(data[field])
except json.JSONDecodeError:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid JSON format for field: {field}",
}
),
400,
)
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Invalid request data"}), 400
)
try: try:
existing_agent = agents_collection.find_one({"_id": oid, "user": user}) existing_agent = agents_collection.find_one({"_id": oid, "user": user})
except Exception as err: except Exception as err:
return make_response(
current_app.logger.error( current_app.logger.error(
f"Error finding agent {agent_id}: {err}", exc_info=True f"Error finding agent {agent_id}: {err}", exc_info=True
), )
return make_response(
jsonify({"success": False, "message": "Database error finding agent"}), jsonify({"success": False, "message": "Database error finding agent"}),
500, 500,
) )
if not existing_agent: if not existing_agent:
return make_response( return make_response(
jsonify( jsonify(
@@ -1826,13 +1839,19 @@ class UpdateAgent(Resource):
), ),
404, 404,
) )
image_url, error = handle_image_upload( image_url, error = handle_image_upload(
request, existing_agent.get("image", ""), user, storage request, existing_agent.get("image", ""), user, storage
) )
if error: if error:
return make_response( current_app.logger.error(
jsonify({"success": False, "message": "Image upload failed"}), 400 f"Image upload error for agent {agent_id}: {error}"
) )
return make_response(
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
400,
)
update_fields = {} update_fields = {}
allowed_fields = [ allowed_fields = [
"name", "name",
@@ -1850,49 +1869,30 @@ class UpdateAgent(Resource):
] ]
for field in allowed_fields: for field in allowed_fields:
if field in data: if field not in data:
continue
if field == "status": if field == "status":
new_status = data.get("status") new_status = data.get("status")
if new_status not in ["draft", "published"]: if new_status not in ["draft", "published"]:
return make_response(
jsonify(
{"success": False, "message": "Invalid status value"}
),
400,
)
update_fields[field] = new_status
elif field == "source":
source_id = data.get("source")
if source_id == "default":
# Handle special "default" source
update_fields[field] = "default"
elif source_id and ObjectId.is_valid(source_id):
update_fields[field] = DBRef("sources", ObjectId(source_id))
elif source_id:
return make_response( return make_response(
jsonify( jsonify(
{ {
"success": False, "success": False,
"message": "Invalid source ID format provided", "message": "Invalid status value. Must be 'draft' or 'published'",
} }
), ),
400, 400,
) )
else: update_fields[field] = new_status
update_fields[field] = ""
elif field == "sources": elif field == "source":
sources_list = data.get("sources", []) source_id = data.get("source")
if sources_list and isinstance(sources_list, list):
valid_sources = []
for source_id in sources_list:
if source_id == "default": if source_id == "default":
valid_sources.append("default") update_fields[field] = "default"
elif ObjectId.is_valid(source_id): elif source_id and ObjectId.is_valid(source_id):
valid_sources.append( update_fields[field] = DBRef("sources", ObjectId(source_id))
DBRef("sources", ObjectId(source_id)) elif source_id:
)
else:
return make_response( return make_response(
jsonify( jsonify(
{ {
@@ -1902,64 +1902,156 @@ class UpdateAgent(Resource):
), ),
400, 400,
) )
else:
update_fields[field] = ""
elif field == "sources":
sources_list = data.get("sources", [])
if sources_list and isinstance(sources_list, list):
valid_sources = []
for source_id in sources_list:
if source_id == "default":
valid_sources.append("default")
elif ObjectId.is_valid(source_id):
valid_sources.append(DBRef("sources", ObjectId(source_id)))
else:
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid source ID in list: {source_id}",
}
),
400,
)
update_fields[field] = valid_sources update_fields[field] = valid_sources
else: else:
update_fields[field] = [] update_fields[field] = []
elif field == "chunks": elif field == "chunks":
chunks_value = data.get("chunks") chunks_value = data.get("chunks")
if chunks_value == "": if chunks_value == "" or chunks_value is None:
update_fields[field] = "0" update_fields[field] = "2"
else: else:
try: try:
if int(chunks_value) < 0: chunks_int = int(chunks_value)
if chunks_int < 0:
return make_response( return make_response(
jsonify( jsonify(
{ {
"success": False, "success": False,
"message": "Chunks value must be a positive integer", "message": "Chunks value must be a non-negative integer",
} }
), ),
400, 400,
) )
update_fields[field] = chunks_value update_fields[field] = str(chunks_int)
except ValueError: except (ValueError, TypeError):
return make_response( return make_response(
jsonify( jsonify(
{ {
"success": False, "success": False,
"message": "Invalid chunks value provided", "message": f"Invalid chunks value: {chunks_value}",
} }
), ),
400, 400,
) )
elif field == "tools":
tools_list = data.get("tools", [])
if isinstance(tools_list, list):
update_fields[field] = tools_list
else: else:
update_fields[field] = data[field] return make_response(
jsonify(
{
"success": False,
"message": "Tools must be a list",
}
),
400,
)
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
if not value or not str(value).strip():
return make_response(
jsonify(
{
"success": False,
"message": f"Field '{field}' cannot be empty",
}
),
400,
)
update_fields[field] = value
if image_url: if image_url:
update_fields["image"] = image_url update_fields["image"] = image_url
if not update_fields: if not update_fields:
return make_response( return make_response(
jsonify({"success": False, "message": "No update data provided"}), 400 jsonify(
{
"success": False,
"message": "No valid update data provided",
}
),
400,
) )
newly_generated_key = None newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status")) final_status = update_fields.get("status", existing_agent.get("status"))
if final_status == "published": if final_status == "published":
required_published_fields = [ required_published_fields = {
"name", "name": "Agent name",
"description", "description": "Agent description",
"source", "chunks": "Chunks count",
"chunks", "prompt_id": "Prompt",
"retriever", "agent_type": "Agent type",
"prompt_id", }
"agent_type",
]
missing_published_fields = [] missing_published_fields = []
for req_field in required_published_fields: for req_field, field_label in required_published_fields.items():
final_value = update_fields.get( final_value = update_fields.get(
req_field, existing_agent.get(req_field) req_field, existing_agent.get(req_field)
) )
if req_field == "source" and final_value: if not final_value:
if not isinstance(final_value, DBRef): missing_published_fields.append(field_label)
missing_published_fields.append(req_field)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields: if missing_published_fields:
return make_response( return make_response(
jsonify( jsonify(
@@ -1970,9 +2062,11 @@ class UpdateAgent(Resource):
), ),
400, 400,
) )
if not existing_agent.get("key"): if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4()) newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key update_fields["key"] = newly_generated_key
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc) update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
try: try:
@@ -1985,20 +2079,22 @@ class UpdateAgent(Resource):
jsonify( jsonify(
{ {
"success": False, "success": False,
"message": "Agent not found or update failed unexpectedly", "message": "Agent not found or update failed",
} }
), ),
404, 404,
) )
if result.modified_count == 0 and result.matched_count == 1: if result.modified_count == 0 and result.matched_count == 1:
return make_response( return make_response(
jsonify( jsonify(
{ {
"success": True, "success": True,
"message": "Agent found, but no changes were applied", "message": "No changes detected",
"id": agent_id,
} }
), ),
304, 200,
) )
except Exception as err: except Exception as err:
current_app.logger.error( current_app.logger.error(
@@ -2008,6 +2104,7 @@ class UpdateAgent(Resource):
jsonify({"success": False, "message": "Database error during update"}), jsonify({"success": False, "message": "Database error during update"}),
500, 500,
) )
response_data = { response_data = {
"success": True, "success": True,
"id": agent_id, "id": agent_id,
@@ -2015,10 +2112,8 @@ class UpdateAgent(Resource):
} }
if newly_generated_key: if newly_generated_key:
response_data["key"] = newly_generated_key response_data["key"] = newly_generated_key
return make_response(
jsonify(response_data), return make_response(jsonify(response_data), 200)
200,
)
@user_ns.route("/api/delete_agent") @user_ns.route("/api/delete_agent")

View File

@@ -36,6 +36,11 @@ class ClassicRAG(BaseRetriever):
self.chunks = 2 self.chunks = 2
else: else:
self.chunks = chunks self.chunks = chunks
user_identifier = user_api_key if user_api_key else "default"
logging.info(
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.gpt_model = gpt_model self.gpt_model = gpt_model
self.token_limit = ( self.token_limit = (
token_limit token_limit
@@ -92,17 +97,12 @@ class ClassicRAG(BaseRetriever):
or not self.vectorstores or not self.vectorstores
): ):
return self.original_question return self.original_question
prompt = f"""Given the following conversation history: prompt = (
"Given the following conversation history:\n"
{self.chat_history} f"{self.chat_history}\n\n"
"Rephrase the following user question to be a standalone search query "
"that captures all relevant context from the conversation:\n"
)
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""
messages = [ messages = [
{"role": "system", "content": prompt}, {"role": "system", "content": prompt},
@@ -120,10 +120,20 @@ class ClassicRAG(BaseRetriever):
def _get_data(self): def _get_data(self):
"""Retrieve relevant documents from configured vectorstores""" """Retrieve relevant documents from configured vectorstores"""
if self.chunks == 0 or not self.vectorstores: if self.chunks == 0 or not self.vectorstores:
logging.info(
f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
)
return [] return []
all_docs = [] all_docs = []
chunks_per_source = max(1, self.chunks // len(self.vectorstores)) chunks_per_source = max(1, self.chunks // len(self.vectorstores))
logging.info(
f"ClassicRAG._get_data: Starting retrieval with chunks={self.chunks}, "
f"vectorstores={self.vectorstores}, chunks_per_source={chunks_per_source}, "
f"query='{self.question[:50]}...'"
)
for vectorstore_id in self.vectorstores: for vectorstore_id in self.vectorstores:
if vectorstore_id: if vectorstore_id:
try: try:
@@ -172,6 +182,10 @@ class ClassicRAG(BaseRetriever):
exc_info=True, exc_info=True,
) )
continue continue
logging.info(
f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source})"
)
return all_docs return all_docs
def search(self, query: str = ""): def search(self, query: str = ""):

View File

@@ -1,3 +1,4 @@
import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -9,13 +10,27 @@ from application.core.settings import settings
class EmbeddingsWrapper: class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs): def __init__(self, model_name, *args, **kwargs):
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
try:
kwargs.setdefault("trust_remote_code", True)
self.model = SentenceTransformer( self.model = SentenceTransformer(
model_name, model_name,
config_kwargs={"allow_dangerous_deserialization": True}, config_kwargs={"allow_dangerous_deserialization": True},
*args, *args,
**kwargs **kwargs,
)
if self.model is None or self.model._first_module() is None:
raise ValueError(
f"SentenceTransformer model failed to load properly for: {model_name}"
) )
self.dimension = self.model.get_sentence_embedding_dimension() self.dimension = self.model.get_sentence_embedding_dimension()
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
except Exception as e:
logging.error(
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
exc_info=True,
)
raise
def embed_query(self, query: str): def embed_query(self, query: str):
return self.model.encode(query).tolist() return self.model.encode(query).tolist()
@@ -117,15 +132,29 @@ class BaseVectorStore(ABC):
embeddings_name, openai_api_key=embeddings_key embeddings_name, openai_api_key=embeddings_key
) )
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2": elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./models/all-mpnet-base-v2"): possible_paths = [
"/app/models/all-mpnet-base-v2", # Docker absolute path
"./models/all-mpnet-base-v2", # Relative path
]
local_model_path = None
for path in possible_paths:
if os.path.exists(path):
local_model_path = path
logging.info(f"Found local model at path: {path}")
break
else:
logging.info(f"Path does not exist: {path}")
if local_model_path:
embedding_instance = EmbeddingsSingleton.get_instance( embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name="./models/all-mpnet-base-v2", local_model_path,
) )
else: else:
logging.warning(
f"Local model not found in any of the paths: {possible_paths}. Falling back to HuggingFace download."
)
embedding_instance = EmbeddingsSingleton.get_instance( embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name, embeddings_name,
) )
else: else:
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name) embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance return embedding_instance

View File

@@ -46,11 +46,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
image: '', image: '',
source: '', source: '',
sources: [], sources: [],
chunks: '', chunks: '2',
retriever: '', retriever: 'classic',
prompt_id: 'default', prompt_id: 'default',
tools: [], tools: [],
agent_type: '', agent_type: 'classic',
status: '', status: '',
json_schema: undefined, json_schema: undefined,
}); });
@@ -122,7 +122,8 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
agent.name && agent.description && agent.prompt_id && agent.agent_type; agent.name && agent.description && agent.prompt_id && agent.agent_type;
const isJsonSchemaValidOrEmpty = const isJsonSchemaValidOrEmpty =
jsonSchemaText.trim() === '' || jsonSchemaValid; jsonSchemaText.trim() === '' || jsonSchemaValid;
return hasRequiredFields && isJsonSchemaValidOrEmpty; const hasSource = selectedSourceIds.size > 0;
return hasRequiredFields && isJsonSchemaValidOrEmpty && hasSource;
}; };
const isJsonSchemaInvalid = () => { const isJsonSchemaInvalid = () => {
@@ -353,6 +354,26 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
getPrompts(); getPrompts();
}, [token]); }, [token]);
// Auto-select default source if none selected
useEffect(() => {
if (sourceDocs && sourceDocs.length > 0 && selectedSourceIds.size === 0) {
const defaultSource = sourceDocs.find((s) => s.name === 'Default');
if (defaultSource) {
setSelectedSourceIds(
new Set([
defaultSource.id || defaultSource.retriever || defaultSource.name,
]),
);
} else {
setSelectedSourceIds(
new Set([
sourceDocs[0].id || sourceDocs[0].retriever || sourceDocs[0].name,
]),
);
}
}
}, [sourceDocs, selectedSourceIds.size]);
useEffect(() => { useEffect(() => {
if ((mode === 'edit' || mode === 'draft') && agentId) { if ((mode === 'edit' || mode === 'draft') && agentId) {
const getAgent = async () => { const getAgent = async () => {
@@ -650,7 +671,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
} }
selectedIds={selectedSourceIds} selectedIds={selectedSourceIds}
onSelectionChange={(newSelectedIds: Set<string | number>) => { onSelectionChange={(newSelectedIds: Set<string | number>) => {
if (
newSelectedIds.size === 0 &&
sourceDocs &&
sourceDocs.length > 0
) {
const defaultSource = sourceDocs.find(
(s) => s.name === 'Default',
);
if (defaultSource) {
setSelectedSourceIds(
new Set([
defaultSource.id ||
defaultSource.retriever ||
defaultSource.name,
]),
);
} else {
setSelectedSourceIds(
new Set([
sourceDocs[0].id ||
sourceDocs[0].retriever ||
sourceDocs[0].name,
]),
);
}
} else {
setSelectedSourceIds(newSelectedIds); setSelectedSourceIds(newSelectedIds);
}
}} }}
title="Select Sources" title="Select Sources"
searchPlaceholder="Search sources..." searchPlaceholder="Search sources..."

View File

@@ -6,7 +6,6 @@ interface TableProps {
minWidth?: string; minWidth?: string;
} }
interface TableContainerProps { interface TableContainerProps {
children: React.ReactNode; children: React.ReactNode;
className?: string; className?: string;
@@ -34,12 +33,8 @@ interface TableCellProps {
align?: 'left' | 'right' | 'center'; align?: 'left' | 'right' | 'center';
} }
const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(({ const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(
children, ({ children, className = '', height = 'auto', bordered = true }, ref) => {
className = '',
height = 'auto',
bordered = true
}, ref) => {
return ( return (
<div className={`relative rounded-[6px] ${className}`}> <div className={`relative rounded-[6px] ${className}`}>
<div <div
@@ -47,32 +42,35 @@ const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(({
className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`} className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`}
style={{ style={{
maxHeight: height === 'auto' ? undefined : height, maxHeight: height === 'auto' ? undefined : height,
overflowY: height === 'auto' ? 'hidden' : 'auto' overflowY: height === 'auto' ? 'hidden' : 'auto',
}} }}
> >
{children} {children}
</div> </div>
</div> </div>
); );
});; },
);
TableContainer.displayName = 'TableContainer';
const Table: React.FC<TableProps> = ({ const Table: React.FC<TableProps> = ({
children, children,
className = '', className = '',
minWidth = 'min-w-[600px]' minWidth = 'min-w-[600px]',
}) => { }) => {
return ( return (
<table className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}> <table
className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}
>
{children} {children}
</table> </table>
); );
}; };
const TableHead: React.FC<TableHeadProps> = ({ children, className = '' }) => { const TableHead: React.FC<TableHeadProps> = ({ children, className = '' }) => {
return ( return (
<thead className={` <thead
sticky top-0 z-10 className={`sticky top-0 z-10 bg-gray-100 dark:bg-[#27282D] ${className} `}
bg-gray-100 dark:bg-[#27282D] >
${className}
`}>
{children} {children}
</thead> </thead>
); );
@@ -86,12 +84,20 @@ const TableBody: React.FC<TableHeadProps> = ({ children, className = '' }) => {
); );
}; };
const TableRow: React.FC<TableRowProps> = ({ children, className = '', onClick }) => { const TableRow: React.FC<TableRowProps> = ({
const baseClasses = "border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"; children,
const cursorClass = onClick ? "cursor-pointer" : ""; className = '',
onClick,
}) => {
const baseClasses =
'border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]';
const cursorClass = onClick ? 'cursor-pointer' : '';
return ( return (
<tr className={`${baseClasses} ${cursorClass} ${className}`} onClick={onClick}> <tr
className={`${baseClasses} ${cursorClass} ${className}`}
onClick={onClick}
>
{children} {children}
</tr> </tr>
); );
@@ -102,7 +108,7 @@ const TableHeader: React.FC<TableCellProps> = ({
className = '', className = '',
minWidth, minWidth,
width, width,
align = 'left' align = 'left',
}) => { }) => {
const getAlignmentClass = () => { const getAlignmentClass = () => {
switch (align) { switch (align) {
@@ -133,7 +139,7 @@ const TableCell: React.FC<TableCellProps> = ({
className = '', className = '',
minWidth, minWidth,
width, width,
align = 'left' align = 'left',
}) => { }) => {
const getAlignmentClass = () => { const getAlignmentClass = () => {
switch (align) { switch (align) {