mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Refactor agent creation and update logic to improve error handling and default values; enhance logging for better traceability
This commit is contained in:
@@ -1726,7 +1726,7 @@ class CreateAgent(Resource):
|
|||||||
"key": key,
|
"key": key,
|
||||||
}
|
}
|
||||||
if new_agent["chunks"] == "":
|
if new_agent["chunks"] == "":
|
||||||
new_agent["chunks"] = "0"
|
new_agent["chunks"] = "2"
|
||||||
if (
|
if (
|
||||||
new_agent["source"] == ""
|
new_agent["source"] == ""
|
||||||
and new_agent["retriever"] == ""
|
and new_agent["retriever"] == ""
|
||||||
@@ -1782,43 +1782,56 @@ class UpdateAgent(Resource):
|
|||||||
@api.doc(description="Update an existing agent")
|
@api.doc(description="Update an existing agent")
|
||||||
def put(self, agent_id):
|
def put(self, agent_id):
|
||||||
if not (decoded_token := request.decoded_token):
|
if not (decoded_token := request.decoded_token):
|
||||||
return {"success": False}, 401
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||||
|
)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
if request.content_type == "application/json":
|
|
||||||
data = request.get_json()
|
|
||||||
else:
|
|
||||||
data = request.form.to_dict()
|
|
||||||
if "tools" in data:
|
|
||||||
try:
|
|
||||||
data["tools"] = json.loads(data["tools"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
data["tools"] = []
|
|
||||||
if "sources" in data:
|
|
||||||
try:
|
|
||||||
data["sources"] = json.loads(data["sources"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
data["sources"] = []
|
|
||||||
if "json_schema" in data:
|
|
||||||
try:
|
|
||||||
data["json_schema"] = json.loads(data["json_schema"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
data["json_schema"] = None
|
|
||||||
if not ObjectId.is_valid(agent_id):
|
if not ObjectId.is_valid(agent_id):
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
|
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
|
||||||
)
|
)
|
||||||
oid = ObjectId(agent_id)
|
oid = ObjectId(agent_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if request.content_type and "application/json" in request.content_type:
|
||||||
|
data = request.get_json()
|
||||||
|
else:
|
||||||
|
data = request.form.to_dict()
|
||||||
|
json_fields = ["tools", "sources", "json_schema"]
|
||||||
|
for field in json_fields:
|
||||||
|
if field in data and data[field]:
|
||||||
|
try:
|
||||||
|
data[field] = json.loads(data[field])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"Invalid JSON format for field: {field}",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error parsing request data: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid request data"}), 400
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
existing_agent = agents_collection.find_one({"_id": oid, "user": user})
|
existing_agent = agents_collection.find_one({"_id": oid, "user": user})
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return make_response(
|
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error finding agent {agent_id}: {err}", exc_info=True
|
f"Error finding agent {agent_id}: {err}", exc_info=True
|
||||||
),
|
)
|
||||||
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Database error finding agent"}),
|
jsonify({"success": False, "message": "Database error finding agent"}),
|
||||||
500,
|
500,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not existing_agent:
|
if not existing_agent:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
@@ -1826,13 +1839,19 @@ class UpdateAgent(Resource):
|
|||||||
),
|
),
|
||||||
404,
|
404,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url, error = handle_image_upload(
|
image_url, error = handle_image_upload(
|
||||||
request, existing_agent.get("image", ""), user, storage
|
request, existing_agent.get("image", ""), user, storage
|
||||||
)
|
)
|
||||||
if error:
|
if error:
|
||||||
return make_response(
|
current_app.logger.error(
|
||||||
jsonify({"success": False, "message": "Image upload failed"}), 400
|
f"Image upload error for agent {agent_id}: {error}"
|
||||||
)
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
update_fields = {}
|
update_fields = {}
|
||||||
allowed_fields = [
|
allowed_fields = [
|
||||||
"name",
|
"name",
|
||||||
@@ -1850,49 +1869,30 @@ class UpdateAgent(Resource):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for field in allowed_fields:
|
for field in allowed_fields:
|
||||||
if field in data:
|
if field not in data:
|
||||||
|
continue
|
||||||
|
|
||||||
if field == "status":
|
if field == "status":
|
||||||
new_status = data.get("status")
|
new_status = data.get("status")
|
||||||
if new_status not in ["draft", "published"]:
|
if new_status not in ["draft", "published"]:
|
||||||
return make_response(
|
|
||||||
jsonify(
|
|
||||||
{"success": False, "message": "Invalid status value"}
|
|
||||||
),
|
|
||||||
400,
|
|
||||||
)
|
|
||||||
update_fields[field] = new_status
|
|
||||||
elif field == "source":
|
|
||||||
source_id = data.get("source")
|
|
||||||
if source_id == "default":
|
|
||||||
# Handle special "default" source
|
|
||||||
|
|
||||||
update_fields[field] = "default"
|
|
||||||
elif source_id and ObjectId.is_valid(source_id):
|
|
||||||
update_fields[field] = DBRef("sources", ObjectId(source_id))
|
|
||||||
elif source_id:
|
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": False,
|
"success": False,
|
||||||
"message": "Invalid source ID format provided",
|
"message": "Invalid status value. Must be 'draft' or 'published'",
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
else:
|
update_fields[field] = new_status
|
||||||
update_fields[field] = ""
|
|
||||||
elif field == "sources":
|
elif field == "source":
|
||||||
sources_list = data.get("sources", [])
|
source_id = data.get("source")
|
||||||
if sources_list and isinstance(sources_list, list):
|
|
||||||
valid_sources = []
|
|
||||||
for source_id in sources_list:
|
|
||||||
if source_id == "default":
|
if source_id == "default":
|
||||||
valid_sources.append("default")
|
update_fields[field] = "default"
|
||||||
elif ObjectId.is_valid(source_id):
|
elif source_id and ObjectId.is_valid(source_id):
|
||||||
valid_sources.append(
|
update_fields[field] = DBRef("sources", ObjectId(source_id))
|
||||||
DBRef("sources", ObjectId(source_id))
|
elif source_id:
|
||||||
)
|
|
||||||
else:
|
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
@@ -1902,64 +1902,156 @@ class UpdateAgent(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
update_fields[field] = ""
|
||||||
|
|
||||||
|
elif field == "sources":
|
||||||
|
sources_list = data.get("sources", [])
|
||||||
|
if sources_list and isinstance(sources_list, list):
|
||||||
|
valid_sources = []
|
||||||
|
for source_id in sources_list:
|
||||||
|
if source_id == "default":
|
||||||
|
valid_sources.append("default")
|
||||||
|
elif ObjectId.is_valid(source_id):
|
||||||
|
valid_sources.append(DBRef("sources", ObjectId(source_id)))
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"Invalid source ID in list: {source_id}",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
update_fields[field] = valid_sources
|
update_fields[field] = valid_sources
|
||||||
else:
|
else:
|
||||||
update_fields[field] = []
|
update_fields[field] = []
|
||||||
|
|
||||||
elif field == "chunks":
|
elif field == "chunks":
|
||||||
chunks_value = data.get("chunks")
|
chunks_value = data.get("chunks")
|
||||||
if chunks_value == "":
|
if chunks_value == "" or chunks_value is None:
|
||||||
update_fields[field] = "0"
|
update_fields[field] = "2"
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if int(chunks_value) < 0:
|
chunks_int = int(chunks_value)
|
||||||
|
if chunks_int < 0:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": False,
|
"success": False,
|
||||||
"message": "Chunks value must be a positive integer",
|
"message": "Chunks value must be a non-negative integer",
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
update_fields[field] = chunks_value
|
update_fields[field] = str(chunks_int)
|
||||||
except ValueError:
|
except (ValueError, TypeError):
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": False,
|
"success": False,
|
||||||
"message": "Invalid chunks value provided",
|
"message": f"Invalid chunks value: {chunks_value}",
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif field == "tools":
|
||||||
|
tools_list = data.get("tools", [])
|
||||||
|
if isinstance(tools_list, list):
|
||||||
|
update_fields[field] = tools_list
|
||||||
else:
|
else:
|
||||||
update_fields[field] = data[field]
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Tools must be a list",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif field == "json_schema":
|
||||||
|
json_schema = data.get("json_schema")
|
||||||
|
if json_schema is not None:
|
||||||
|
if not isinstance(json_schema, dict):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "JSON schema must be a valid object",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
update_fields[field] = json_schema
|
||||||
|
else:
|
||||||
|
update_fields[field] = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
value = data[field]
|
||||||
|
if field in ["name", "description", "prompt_id", "agent_type"]:
|
||||||
|
if not value or not str(value).strip():
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"Field '{field}' cannot be empty",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
update_fields[field] = value
|
||||||
|
|
||||||
if image_url:
|
if image_url:
|
||||||
update_fields["image"] = image_url
|
update_fields["image"] = image_url
|
||||||
|
|
||||||
if not update_fields:
|
if not update_fields:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "No update data provided"}), 400
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "No valid update data provided",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
newly_generated_key = None
|
newly_generated_key = None
|
||||||
final_status = update_fields.get("status", existing_agent.get("status"))
|
final_status = update_fields.get("status", existing_agent.get("status"))
|
||||||
|
|
||||||
if final_status == "published":
|
if final_status == "published":
|
||||||
required_published_fields = [
|
required_published_fields = {
|
||||||
"name",
|
"name": "Agent name",
|
||||||
"description",
|
"description": "Agent description",
|
||||||
"source",
|
"chunks": "Chunks count",
|
||||||
"chunks",
|
"prompt_id": "Prompt",
|
||||||
"retriever",
|
"agent_type": "Agent type",
|
||||||
"prompt_id",
|
}
|
||||||
"agent_type",
|
|
||||||
]
|
|
||||||
missing_published_fields = []
|
missing_published_fields = []
|
||||||
for req_field in required_published_fields:
|
for req_field, field_label in required_published_fields.items():
|
||||||
final_value = update_fields.get(
|
final_value = update_fields.get(
|
||||||
req_field, existing_agent.get(req_field)
|
req_field, existing_agent.get(req_field)
|
||||||
)
|
)
|
||||||
if req_field == "source" and final_value:
|
if not final_value:
|
||||||
if not isinstance(final_value, DBRef):
|
missing_published_fields.append(field_label)
|
||||||
missing_published_fields.append(req_field)
|
|
||||||
|
source_val = update_fields.get("source", existing_agent.get("source"))
|
||||||
|
sources_val = update_fields.get(
|
||||||
|
"sources", existing_agent.get("sources", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
has_valid_source = (
|
||||||
|
isinstance(source_val, DBRef)
|
||||||
|
or source_val == "default"
|
||||||
|
or (isinstance(sources_val, list) and len(sources_val) > 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_valid_source:
|
||||||
|
missing_published_fields.append("Source")
|
||||||
|
|
||||||
if missing_published_fields:
|
if missing_published_fields:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
@@ -1970,9 +2062,11 @@ class UpdateAgent(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not existing_agent.get("key"):
|
if not existing_agent.get("key"):
|
||||||
newly_generated_key = str(uuid.uuid4())
|
newly_generated_key = str(uuid.uuid4())
|
||||||
update_fields["key"] = newly_generated_key
|
update_fields["key"] = newly_generated_key
|
||||||
|
|
||||||
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
|
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -1985,20 +2079,22 @@ class UpdateAgent(Resource):
|
|||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": False,
|
"success": False,
|
||||||
"message": "Agent not found or update failed unexpectedly",
|
"message": "Agent not found or update failed",
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
404,
|
404,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.modified_count == 0 and result.matched_count == 1:
|
if result.modified_count == 0 and result.matched_count == 1:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "Agent found, but no changes were applied",
|
"message": "No changes detected",
|
||||||
|
"id": agent_id,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
304,
|
200,
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
@@ -2008,6 +2104,7 @@ class UpdateAgent(Resource):
|
|||||||
jsonify({"success": False, "message": "Database error during update"}),
|
jsonify({"success": False, "message": "Database error during update"}),
|
||||||
500,
|
500,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"id": agent_id,
|
"id": agent_id,
|
||||||
@@ -2015,10 +2112,8 @@ class UpdateAgent(Resource):
|
|||||||
}
|
}
|
||||||
if newly_generated_key:
|
if newly_generated_key:
|
||||||
response_data["key"] = newly_generated_key
|
response_data["key"] = newly_generated_key
|
||||||
return make_response(
|
|
||||||
jsonify(response_data),
|
return make_response(jsonify(response_data), 200)
|
||||||
200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@user_ns.route("/api/delete_agent")
|
@user_ns.route("/api/delete_agent")
|
||||||
|
|||||||
@@ -36,6 +36,11 @@ class ClassicRAG(BaseRetriever):
|
|||||||
self.chunks = 2
|
self.chunks = 2
|
||||||
else:
|
else:
|
||||||
self.chunks = chunks
|
self.chunks = chunks
|
||||||
|
user_identifier = user_api_key if user_api_key else "default"
|
||||||
|
logging.info(
|
||||||
|
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
|
||||||
|
f"sources={'active_docs' in source and source['active_docs'] is not None}"
|
||||||
|
)
|
||||||
self.gpt_model = gpt_model
|
self.gpt_model = gpt_model
|
||||||
self.token_limit = (
|
self.token_limit = (
|
||||||
token_limit
|
token_limit
|
||||||
@@ -92,17 +97,12 @@ class ClassicRAG(BaseRetriever):
|
|||||||
or not self.vectorstores
|
or not self.vectorstores
|
||||||
):
|
):
|
||||||
return self.original_question
|
return self.original_question
|
||||||
prompt = f"""Given the following conversation history:
|
prompt = (
|
||||||
|
"Given the following conversation history:\n"
|
||||||
{self.chat_history}
|
f"{self.chat_history}\n\n"
|
||||||
|
"Rephrase the following user question to be a standalone search query "
|
||||||
|
"that captures all relevant context from the conversation:\n"
|
||||||
|
)
|
||||||
Rephrase the following user question to be a standalone search query
|
|
||||||
|
|
||||||
that captures all relevant context from the conversation:
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": prompt},
|
{"role": "system", "content": prompt},
|
||||||
@@ -120,10 +120,20 @@ class ClassicRAG(BaseRetriever):
|
|||||||
def _get_data(self):
|
def _get_data(self):
|
||||||
"""Retrieve relevant documents from configured vectorstores"""
|
"""Retrieve relevant documents from configured vectorstores"""
|
||||||
if self.chunks == 0 or not self.vectorstores:
|
if self.chunks == 0 or not self.vectorstores:
|
||||||
|
logging.info(
|
||||||
|
f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
|
||||||
|
f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
all_docs = []
|
all_docs = []
|
||||||
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"ClassicRAG._get_data: Starting retrieval with chunks={self.chunks}, "
|
||||||
|
f"vectorstores={self.vectorstores}, chunks_per_source={chunks_per_source}, "
|
||||||
|
f"query='{self.question[:50]}...'"
|
||||||
|
)
|
||||||
|
|
||||||
for vectorstore_id in self.vectorstores:
|
for vectorstore_id in self.vectorstores:
|
||||||
if vectorstore_id:
|
if vectorstore_id:
|
||||||
try:
|
try:
|
||||||
@@ -172,6 +182,10 @@ class ClassicRAG(BaseRetriever):
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
logging.info(
|
||||||
|
f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
|
||||||
|
f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source})"
|
||||||
|
)
|
||||||
return all_docs
|
return all_docs
|
||||||
|
|
||||||
def search(self, query: str = ""):
|
def search(self, query: str = ""):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
@@ -9,13 +10,27 @@ from application.core.settings import settings
|
|||||||
|
|
||||||
class EmbeddingsWrapper:
|
class EmbeddingsWrapper:
|
||||||
def __init__(self, model_name, *args, **kwargs):
|
def __init__(self, model_name, *args, **kwargs):
|
||||||
|
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
|
||||||
|
try:
|
||||||
|
kwargs.setdefault("trust_remote_code", True)
|
||||||
self.model = SentenceTransformer(
|
self.model = SentenceTransformer(
|
||||||
model_name,
|
model_name,
|
||||||
config_kwargs={"allow_dangerous_deserialization": True},
|
config_kwargs={"allow_dangerous_deserialization": True},
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs,
|
||||||
|
)
|
||||||
|
if self.model is None or self.model._first_module() is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"SentenceTransformer model failed to load properly for: {model_name}"
|
||||||
)
|
)
|
||||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||||
|
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def embed_query(self, query: str):
|
def embed_query(self, query: str):
|
||||||
return self.model.encode(query).tolist()
|
return self.model.encode(query).tolist()
|
||||||
@@ -117,15 +132,29 @@ class BaseVectorStore(ABC):
|
|||||||
embeddings_name, openai_api_key=embeddings_key
|
embeddings_name, openai_api_key=embeddings_key
|
||||||
)
|
)
|
||||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||||
if os.path.exists("./models/all-mpnet-base-v2"):
|
possible_paths = [
|
||||||
|
"/app/models/all-mpnet-base-v2", # Docker absolute path
|
||||||
|
"./models/all-mpnet-base-v2", # Relative path
|
||||||
|
]
|
||||||
|
local_model_path = None
|
||||||
|
for path in possible_paths:
|
||||||
|
if os.path.exists(path):
|
||||||
|
local_model_path = path
|
||||||
|
logging.info(f"Found local model at path: {path}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logging.info(f"Path does not exist: {path}")
|
||||||
|
if local_model_path:
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||||
embeddings_name="./models/all-mpnet-base-v2",
|
local_model_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Local model not found in any of the paths: {possible_paths}. Falling back to HuggingFace download."
|
||||||
|
)
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||||
embeddings_name,
|
embeddings_name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||||
|
|
||||||
return embedding_instance
|
return embedding_instance
|
||||||
|
|||||||
@@ -46,11 +46,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
image: '',
|
image: '',
|
||||||
source: '',
|
source: '',
|
||||||
sources: [],
|
sources: [],
|
||||||
chunks: '',
|
chunks: '2',
|
||||||
retriever: '',
|
retriever: 'classic',
|
||||||
prompt_id: 'default',
|
prompt_id: 'default',
|
||||||
tools: [],
|
tools: [],
|
||||||
agent_type: '',
|
agent_type: 'classic',
|
||||||
status: '',
|
status: '',
|
||||||
json_schema: undefined,
|
json_schema: undefined,
|
||||||
});
|
});
|
||||||
@@ -122,7 +122,8 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
agent.name && agent.description && agent.prompt_id && agent.agent_type;
|
agent.name && agent.description && agent.prompt_id && agent.agent_type;
|
||||||
const isJsonSchemaValidOrEmpty =
|
const isJsonSchemaValidOrEmpty =
|
||||||
jsonSchemaText.trim() === '' || jsonSchemaValid;
|
jsonSchemaText.trim() === '' || jsonSchemaValid;
|
||||||
return hasRequiredFields && isJsonSchemaValidOrEmpty;
|
const hasSource = selectedSourceIds.size > 0;
|
||||||
|
return hasRequiredFields && isJsonSchemaValidOrEmpty && hasSource;
|
||||||
};
|
};
|
||||||
|
|
||||||
const isJsonSchemaInvalid = () => {
|
const isJsonSchemaInvalid = () => {
|
||||||
@@ -353,6 +354,26 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
getPrompts();
|
getPrompts();
|
||||||
}, [token]);
|
}, [token]);
|
||||||
|
|
||||||
|
// Auto-select default source if none selected
|
||||||
|
useEffect(() => {
|
||||||
|
if (sourceDocs && sourceDocs.length > 0 && selectedSourceIds.size === 0) {
|
||||||
|
const defaultSource = sourceDocs.find((s) => s.name === 'Default');
|
||||||
|
if (defaultSource) {
|
||||||
|
setSelectedSourceIds(
|
||||||
|
new Set([
|
||||||
|
defaultSource.id || defaultSource.retriever || defaultSource.name,
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
setSelectedSourceIds(
|
||||||
|
new Set([
|
||||||
|
sourceDocs[0].id || sourceDocs[0].retriever || sourceDocs[0].name,
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [sourceDocs, selectedSourceIds.size]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if ((mode === 'edit' || mode === 'draft') && agentId) {
|
if ((mode === 'edit' || mode === 'draft') && agentId) {
|
||||||
const getAgent = async () => {
|
const getAgent = async () => {
|
||||||
@@ -650,7 +671,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
}
|
}
|
||||||
selectedIds={selectedSourceIds}
|
selectedIds={selectedSourceIds}
|
||||||
onSelectionChange={(newSelectedIds: Set<string | number>) => {
|
onSelectionChange={(newSelectedIds: Set<string | number>) => {
|
||||||
|
if (
|
||||||
|
newSelectedIds.size === 0 &&
|
||||||
|
sourceDocs &&
|
||||||
|
sourceDocs.length > 0
|
||||||
|
) {
|
||||||
|
const defaultSource = sourceDocs.find(
|
||||||
|
(s) => s.name === 'Default',
|
||||||
|
);
|
||||||
|
if (defaultSource) {
|
||||||
|
setSelectedSourceIds(
|
||||||
|
new Set([
|
||||||
|
defaultSource.id ||
|
||||||
|
defaultSource.retriever ||
|
||||||
|
defaultSource.name,
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
setSelectedSourceIds(
|
||||||
|
new Set([
|
||||||
|
sourceDocs[0].id ||
|
||||||
|
sourceDocs[0].retriever ||
|
||||||
|
sourceDocs[0].name,
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
setSelectedSourceIds(newSelectedIds);
|
setSelectedSourceIds(newSelectedIds);
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
title="Select Sources"
|
title="Select Sources"
|
||||||
searchPlaceholder="Search sources..."
|
searchPlaceholder="Search sources..."
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ interface TableProps {
|
|||||||
minWidth?: string;
|
minWidth?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
interface TableContainerProps {
|
interface TableContainerProps {
|
||||||
children: React.ReactNode;
|
children: React.ReactNode;
|
||||||
className?: string;
|
className?: string;
|
||||||
@@ -34,12 +33,8 @@ interface TableCellProps {
|
|||||||
align?: 'left' | 'right' | 'center';
|
align?: 'left' | 'right' | 'center';
|
||||||
}
|
}
|
||||||
|
|
||||||
const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(({
|
const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(
|
||||||
children,
|
({ children, className = '', height = 'auto', bordered = true }, ref) => {
|
||||||
className = '',
|
|
||||||
height = 'auto',
|
|
||||||
bordered = true
|
|
||||||
}, ref) => {
|
|
||||||
return (
|
return (
|
||||||
<div className={`relative rounded-[6px] ${className}`}>
|
<div className={`relative rounded-[6px] ${className}`}>
|
||||||
<div
|
<div
|
||||||
@@ -47,32 +42,35 @@ const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(({
|
|||||||
className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`}
|
className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`}
|
||||||
style={{
|
style={{
|
||||||
maxHeight: height === 'auto' ? undefined : height,
|
maxHeight: height === 'auto' ? undefined : height,
|
||||||
overflowY: height === 'auto' ? 'hidden' : 'auto'
|
overflowY: height === 'auto' ? 'hidden' : 'auto',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
});;
|
},
|
||||||
|
);
|
||||||
|
TableContainer.displayName = 'TableContainer';
|
||||||
|
|
||||||
const Table: React.FC<TableProps> = ({
|
const Table: React.FC<TableProps> = ({
|
||||||
children,
|
children,
|
||||||
className = '',
|
className = '',
|
||||||
minWidth = 'min-w-[600px]'
|
minWidth = 'min-w-[600px]',
|
||||||
}) => {
|
}) => {
|
||||||
return (
|
return (
|
||||||
<table className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}>
|
<table
|
||||||
|
className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}
|
||||||
|
>
|
||||||
{children}
|
{children}
|
||||||
</table>
|
</table>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
const TableHead: React.FC<TableHeadProps> = ({ children, className = '' }) => {
|
const TableHead: React.FC<TableHeadProps> = ({ children, className = '' }) => {
|
||||||
return (
|
return (
|
||||||
<thead className={`
|
<thead
|
||||||
sticky top-0 z-10
|
className={`sticky top-0 z-10 bg-gray-100 dark:bg-[#27282D] ${className} `}
|
||||||
bg-gray-100 dark:bg-[#27282D]
|
>
|
||||||
${className}
|
|
||||||
`}>
|
|
||||||
{children}
|
{children}
|
||||||
</thead>
|
</thead>
|
||||||
);
|
);
|
||||||
@@ -86,12 +84,20 @@ const TableBody: React.FC<TableHeadProps> = ({ children, className = '' }) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const TableRow: React.FC<TableRowProps> = ({ children, className = '', onClick }) => {
|
const TableRow: React.FC<TableRowProps> = ({
|
||||||
const baseClasses = "border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]";
|
children,
|
||||||
const cursorClass = onClick ? "cursor-pointer" : "";
|
className = '',
|
||||||
|
onClick,
|
||||||
|
}) => {
|
||||||
|
const baseClasses =
|
||||||
|
'border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]';
|
||||||
|
const cursorClass = onClick ? 'cursor-pointer' : '';
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<tr className={`${baseClasses} ${cursorClass} ${className}`} onClick={onClick}>
|
<tr
|
||||||
|
className={`${baseClasses} ${cursorClass} ${className}`}
|
||||||
|
onClick={onClick}
|
||||||
|
>
|
||||||
{children}
|
{children}
|
||||||
</tr>
|
</tr>
|
||||||
);
|
);
|
||||||
@@ -102,7 +108,7 @@ const TableHeader: React.FC<TableCellProps> = ({
|
|||||||
className = '',
|
className = '',
|
||||||
minWidth,
|
minWidth,
|
||||||
width,
|
width,
|
||||||
align = 'left'
|
align = 'left',
|
||||||
}) => {
|
}) => {
|
||||||
const getAlignmentClass = () => {
|
const getAlignmentClass = () => {
|
||||||
switch (align) {
|
switch (align) {
|
||||||
@@ -133,7 +139,7 @@ const TableCell: React.FC<TableCellProps> = ({
|
|||||||
className = '',
|
className = '',
|
||||||
minWidth,
|
minWidth,
|
||||||
width,
|
width,
|
||||||
align = 'left'
|
align = 'left',
|
||||||
}) => {
|
}) => {
|
||||||
const getAlignmentClass = () => {
|
const getAlignmentClass = () => {
|
||||||
switch (align) {
|
switch (align) {
|
||||||
|
|||||||
Reference in New Issue
Block a user