mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Compare commits
33 Commits
hacktoberf
...
memory-fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b43f9f3cb2 | ||
|
|
e012189672 | ||
|
|
4c31e9a8b1 | ||
|
|
7cfc230316 | ||
|
|
9605e85f1c | ||
|
|
498e2b772c | ||
|
|
dad897da51 | ||
|
|
02ad5f062e | ||
|
|
4eb9471b4f | ||
|
|
b505d207d7 | ||
|
|
3c954bd07f | ||
|
|
c00b6459dc | ||
|
|
eb4d776784 | ||
|
|
5d7a890533 | ||
|
|
9c6aefef1e | ||
|
|
e4554d6c09 | ||
|
|
c184b63df8 | ||
|
|
6bb4195393 | ||
|
|
7827a4d40d | ||
|
|
f09fa8231a | ||
|
|
96ff10000d | ||
|
|
9460636867 | ||
|
|
6c43245295 | ||
|
|
266b6cf638 | ||
|
|
70183e234a | ||
|
|
17b9c359ca | ||
|
|
045630b8a5 | ||
|
|
55ff7dd640 | ||
|
|
5b2738aec9 | ||
|
|
892312fc08 | ||
|
|
444b1a0b65 | ||
|
|
814ea1c016 | ||
|
|
7c15a4c7ff |
@@ -35,4 +35,4 @@ Non-Code Contributions:
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish a t-shirt desing later into the October.
|
||||
We will publish a t-shirt design later into the October.
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
@@ -67,7 +67,7 @@
|
||||
- [x] Json Responses (August 2025)
|
||||
- [x] MCP support (August 2025)
|
||||
- [x] Google Drive integration (September 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [ ] Sharepoint integration (October 2025)
|
||||
- [ ] Deep Agents (October 2025)
|
||||
- [ ] Agent scheduling
|
||||
|
||||
@@ -213,18 +213,24 @@ class BaseAgent(ABC):
|
||||
):
|
||||
target_dict[param] = value
|
||||
tm = ToolManager(config={})
|
||||
|
||||
# Prepare tool_config and add tool_id for memory tools
|
||||
if tool_data["name"] == "api_tool":
|
||||
tool_config = {
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
||||
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
tool_config=tool_config,
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
|
||||
@@ -37,7 +37,7 @@ _mcp_clients_cache = {}
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
|
||||
546
application/agents/tools/memory.py
Normal file
546
application/agents/tools/memory.py
Normal file
@@ -0,0 +1,546 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated memories
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||
},
|
||||
"view_range": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create or overwrite a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||
},
|
||||
"file_text": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "file_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "String to find."
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "String to replace with."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_str", "new_str"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at a specific line in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"insert_line": {
|
||||
"type": "integer",
|
||||
"description": "Line number to insert at (1-indexed)."
|
||||
},
|
||||
"insert_text": {
|
||||
"type": "string",
|
||||
"description": "Text to insert."
|
||||
}
|
||||
},
|
||||
"required": ["path", "insert_line", "insert_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "rename",
|
||||
"description": "Rename or move a file/directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_path": {
|
||||
"type": "string",
|
||||
"description": "Current path (e.g., /old.txt)."
|
||||
},
|
||||
"new_path": {
|
||||
"type": "string",
|
||||
"description": "New path (e.g., /new.txt)."
|
||||
}
|
||||
},
|
||||
"required": ["old_path", "new_path"]
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Preserve whether path ends with / (indicates directory)
|
||||
is_directory = path.endswith("/")
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
# Preserve trailing slash for directories
|
||||
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||
normalized = normalized + "/"
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View directory contents or file contents."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
# Check if viewing directory (ends with / or is root)
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return self._view_directory(validated_path)
|
||||
|
||||
# Otherwise view file
|
||||
return self._view_file(validated_path, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
|
||||
"""List files in a directory."""
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
# Find all files that start with this directory path
|
||||
query = {
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
}
|
||||
|
||||
docs = list(self.collection.find(query, {"path": 1}))
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
|
||||
# Extract filenames relative to the directory
|
||||
files = []
|
||||
for doc in docs:
|
||||
file_path = doc["path"]
|
||||
# Remove the directory prefix
|
||||
if file_path.startswith(search_path):
|
||||
relative = file_path[len(search_path):]
|
||||
if relative:
|
||||
files.append(relative)
|
||||
|
||||
files.sort()
|
||||
file_list = "\n".join(f"- {f}" for f in files)
|
||||
return f"Directory: {path}\n{file_list}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers (enumerate with 1-based start)
|
||||
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
|
||||
"""Create or overwrite a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": file_text,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||
"""Insert text at a specific line."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
def _delete(self, path: str) -> str:
|
||||
"""Delete a file or directory."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
# Delete all files in directory
|
||||
result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||
})
|
||||
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||
|
||||
# Try to delete as directory first (without trailing slash)
|
||||
# Check if any files start with this path + /
|
||||
search_path = validated_path + "/"
|
||||
directory_result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
})
|
||||
|
||||
if directory_result.deleted_count > 0:
|
||||
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||
|
||||
# Delete single file
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_path
|
||||
})
|
||||
|
||||
if result.deleted_count:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
|
||||
"""Rename or move a file/directory."""
|
||||
validated_old = self._validate_path(old_path)
|
||||
validated_new = self._validate_path(new_path)
|
||||
|
||||
if not validated_old or not validated_new:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Check if renaming a directory
|
||||
if validated_old.endswith("/"):
|
||||
# Ensure validated_new also ends with / for proper path replacement
|
||||
if not validated_new.endswith("/"):
|
||||
validated_new = validated_new + "/"
|
||||
|
||||
# Find all files in the old directory
|
||||
docs = list(self.collection.find({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||
}))
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
# Update paths for all files
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||
|
||||
self.collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Rename single file
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_old
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
# Check if new path already exists
|
||||
existing = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new
|
||||
})
|
||||
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
# Delete the old document and create a new one with the new path
|
||||
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||
self.collection.insert_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new,
|
||||
"content": doc.get("content", ""),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
199
application/agents/tools/notes.py
Normal file
199
application/agents/tools/notes.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
"""Notepad
|
||||
|
||||
Single note. Supports viewing, overwriting, string replacement.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated notes
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
if action_name == "overwrite":
|
||||
return self._overwrite_note(kwargs.get("text", ""))
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete_note()
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Retrieve the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "overwrite",
|
||||
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "New note content."}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_str": {"type": "string", "description": "String to find."},
|
||||
"new_str": {"type": "string", "description": "String to replace with."}
|
||||
},
|
||||
"required": ["old_str", "new_str"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at the specified line number (1-indexed).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||
"text": {"type": "string", "description": "Text to insert."}
|
||||
},
|
||||
"required": ["line_number", "text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _get_note(self) -> str:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True, # ✅ create if missing
|
||||
)
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
return f"String '{old_str}' not found in note."
|
||||
|
||||
# Case-insensitive replacement
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
index = line_number - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Invalid line number. Note has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||
@@ -28,7 +28,7 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
@@ -36,7 +36,7 @@ class ToolManager:
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
if tool_name in {"mcp_tool", "memory"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""User API module - provides all user-related API endpoints"""
|
||||
|
||||
from .routes import user
|
||||
|
||||
__all__ = ["user"]
|
||||
|
||||
7
application/api/user/agents/__init__.py
Normal file
7
application/api/user/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Agents module."""
|
||||
|
||||
from .routes import agents_ns
|
||||
from .sharing import agents_sharing_ns
|
||||
from .webhooks import agents_webhooks_ns
|
||||
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
|
||||
910
application/api/user/agents/routes.py
Normal file
910
application/api/user/agents/routes.py
Normal file
@@ -0,0 +1,910 @@
|
||||
"""Agent management routes."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
handle_image_upload,
|
||||
resolve_tool_details,
|
||||
storage,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import (
|
||||
check_required_fields,
|
||||
generate_image_url,
|
||||
validate_required_fields,
|
||||
)
|
||||
|
||||
|
||||
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
|
||||
|
||||
|
||||
@agents_ns.route("/get_agent")
|
||||
class GetAgent(Resource):
|
||||
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
|
||||
def get(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
if not (agent_id := request.args.get("id")):
|
||||
return {"success": False, "message": "ID required"}, 400
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
if not agent:
|
||||
return {"status": "Not found"}, 404
|
||||
data = {
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent["name"],
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"sources": [
|
||||
(
|
||||
str(db.dereference(source_ref)["_id"])
|
||||
if isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
else source_ref
|
||||
)
|
||||
for source_ref in agent.get("sources", [])
|
||||
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
|
||||
or source_ref == "default"
|
||||
],
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
"key": (
|
||||
f"{agent['key'][:4]}...{agent['key'][-4:]}"
|
||||
if "key" in agent
|
||||
else ""
|
||||
),
|
||||
"pinned": agent.get("pinned", False),
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
|
||||
|
||||
@agents_ns.route("/get_agents")
|
||||
class GetAgents(Resource):
|
||||
@api.doc(description="Retrieve agents for the user")
|
||||
def get(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
user_doc = ensure_user_doc(user)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
agents = agents_collection.find({"user": user})
|
||||
list_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent["name"],
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(agent.get("source")))
|
||||
else (
|
||||
agent.get("source", "")
|
||||
if agent.get("source") == "default"
|
||||
else ""
|
||||
)
|
||||
),
|
||||
"sources": [
|
||||
(
|
||||
source_ref
|
||||
if source_ref == "default"
|
||||
else str(db.dereference(source_ref)["_id"])
|
||||
)
|
||||
for source_ref in agent.get("sources", [])
|
||||
if source_ref == "default"
|
||||
or (
|
||||
isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
)
|
||||
],
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
"key": (
|
||||
f"{agent['key'][:4]}...{agent['key'][-4:]}"
|
||||
if "key" in agent
|
||||
else ""
|
||||
),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_agents), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/create_agent")
|
||||
class CreateAgent(Resource):
|
||||
create_agent_model = api.model(
|
||||
"CreateAgentModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="Name of the agent"),
|
||||
"description": fields.String(
|
||||
required=True, description="Description of the agent"
|
||||
),
|
||||
"image": fields.Raw(
|
||||
required=False, description="Image file upload", type="file"
|
||||
),
|
||||
"source": fields.String(
|
||||
required=False, description="Source ID (legacy single source)"
|
||||
),
|
||||
"sources": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
"tools": fields.List(
|
||||
fields.String, required=False, description="List of tool identifiers"
|
||||
),
|
||||
"agent_type": fields.String(required=True, description="Type of the agent"),
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False,
|
||||
description="JSON schema for enforcing structured output format",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_agent_model)
|
||||
@api.doc(description="Create a new agent")
|
||||
def post(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
user = decoded_token.get("sub")
|
||||
if request.content_type == "application/json":
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
if "tools" in data:
|
||||
try:
|
||||
data["tools"] = json.loads(data["tools"])
|
||||
except json.JSONDecodeError:
|
||||
data["tools"] = []
|
||||
if "sources" in data:
|
||||
try:
|
||||
data["sources"] = json.loads(data["sources"])
|
||||
except json.JSONDecodeError:
|
||||
data["sources"] = []
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
|
||||
if data.get("json_schema"):
|
||||
try:
|
||||
# Basic validation - ensure it's a valid JSON structure
|
||||
|
||||
json_schema = data.get("json_schema")
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid JSON object",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Status must be either 'draft' or 'published'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if data.get("status") == "published":
|
||||
required_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
"agent_type",
|
||||
]
|
||||
# Require either source or sources (but not both)
|
||||
|
||||
if not data.get("source") and not data.get("sources"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Either 'source' or 'sources' field is required for published agents",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
validate_fields = ["name", "description", "prompt_id", "agent_type"]
|
||||
else:
|
||||
required_fields = ["name"]
|
||||
validate_fields = []
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
invalid_fields = validate_required_fields(data, validate_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if invalid_fields:
|
||||
return invalid_fields
|
||||
image_url, error = handle_image_upload(request, "", user, storage)
|
||||
if error:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}), 400
|
||||
)
|
||||
try:
|
||||
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
|
||||
|
||||
sources_list = []
|
||||
if data.get("sources") and len(data.get("sources", [])) > 0:
|
||||
for source_id in data.get("sources", []):
|
||||
if source_id == "default":
|
||||
sources_list.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
sources_list.append(DBRef("sources", ObjectId(source_id)))
|
||||
source_field = ""
|
||||
else:
|
||||
source_value = data.get("source", "")
|
||||
if source_value == "default":
|
||||
source_field = "default"
|
||||
elif ObjectId.is_valid(source_value):
|
||||
source_field = DBRef("sources", ObjectId(source_value))
|
||||
else:
|
||||
source_field = ""
|
||||
new_agent = {
|
||||
"user": user,
|
||||
"name": data.get("name"),
|
||||
"description": data.get("description", ""),
|
||||
"image": image_url,
|
||||
"source": source_field,
|
||||
"sources": sources_list,
|
||||
"chunks": data.get("chunks", ""),
|
||||
"retriever": data.get("retriever", ""),
|
||||
"prompt_id": data.get("prompt_id", ""),
|
||||
"tools": data.get("tools", []),
|
||||
"agent_type": data.get("agent_type", ""),
|
||||
"status": data.get("status"),
|
||||
"json_schema": data.get("json_schema"),
|
||||
"createdAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
"key": key,
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "2"
|
||||
if (
|
||||
new_agent["source"] == ""
|
||||
and new_agent["retriever"] == ""
|
||||
and not new_agent["sources"]
|
||||
):
|
||||
new_agent["retriever"] = "classic"
|
||||
resp = agents_collection.insert_one(new_agent)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id, "key": key}), 201)
|
||||
|
||||
|
||||
@agents_ns.route("/update_agent/<string:agent_id>")
|
||||
class UpdateAgent(Resource):
|
||||
update_agent_model = api.model(
|
||||
"UpdateAgentModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="New name of the agent"),
|
||||
"description": fields.String(
|
||||
required=True, description="New description of the agent"
|
||||
),
|
||||
"image": fields.String(
|
||||
required=False, description="New image URL or identifier"
|
||||
),
|
||||
"source": fields.String(
|
||||
required=False, description="Source ID (legacy single source)"
|
||||
),
|
||||
"sources": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
"tools": fields.List(
|
||||
fields.String, required=False, description="List of tool identifiers"
|
||||
),
|
||||
"agent_type": fields.String(required=True, description="Type of the agent"),
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False,
|
||||
description="JSON schema for enforcing structured output format",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_agent_model)
|
||||
@api.doc(description="Update an existing agent")
|
||||
def put(self, agent_id):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
if not ObjectId.is_valid(agent_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
|
||||
)
|
||||
oid = ObjectId(agent_id)
|
||||
|
||||
try:
|
||||
if request.content_type and "application/json" in request.content_type:
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
json_fields = ["tools", "sources", "json_schema"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
try:
|
||||
data[field] = json.loads(data[field])
|
||||
except json.JSONDecodeError:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid JSON format for field: {field}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error parsing request data: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid request data"}), 400
|
||||
)
|
||||
try:
|
||||
existing_agent = agents_collection.find_one({"_id": oid, "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error finding agent {agent_id}: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error finding agent"}),
|
||||
500,
|
||||
)
|
||||
if not existing_agent:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Agent not found or not authorized"}
|
||||
),
|
||||
404,
|
||||
)
|
||||
image_url, error = handle_image_upload(
|
||||
request, existing_agent.get("image", ""), user, storage
|
||||
)
|
||||
if error:
|
||||
current_app.logger.error(
|
||||
f"Image upload error for agent {agent_id}: {error}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
|
||||
400,
|
||||
)
|
||||
update_fields = {}
|
||||
allowed_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"image",
|
||||
"source",
|
||||
"sources",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
"tools",
|
||||
"agent_type",
|
||||
"status",
|
||||
"json_schema",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
if field not in data:
|
||||
continue
|
||||
if field == "status":
|
||||
new_status = data.get("status")
|
||||
if new_status not in ["draft", "published"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid status value. Must be 'draft' or 'published'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = new_status
|
||||
elif field == "source":
|
||||
source_id = data.get("source")
|
||||
if source_id == "default":
|
||||
update_fields[field] = "default"
|
||||
elif source_id and ObjectId.is_valid(source_id):
|
||||
update_fields[field] = DBRef("sources", ObjectId(source_id))
|
||||
elif source_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid source ID format: {source_id}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
else:
|
||||
update_fields[field] = ""
|
||||
elif field == "sources":
|
||||
sources_list = data.get("sources", [])
|
||||
if sources_list and isinstance(sources_list, list):
|
||||
valid_sources = []
|
||||
for source_id in sources_list:
|
||||
if source_id == "default":
|
||||
valid_sources.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
valid_sources.append(DBRef("sources", ObjectId(source_id)))
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid source ID in list: {source_id}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = valid_sources
|
||||
else:
|
||||
update_fields[field] = []
|
||||
elif field == "chunks":
|
||||
chunks_value = data.get("chunks")
|
||||
if chunks_value == "" or chunks_value is None:
|
||||
update_fields[field] = "2"
|
||||
else:
|
||||
try:
|
||||
chunks_int = int(chunks_value)
|
||||
if chunks_int < 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Chunks value must be a non-negative integer",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = str(chunks_int)
|
||||
except (ValueError, TypeError):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid chunks value: {chunks_value}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
elif field == "tools":
|
||||
tools_list = data.get("tools", [])
|
||||
if isinstance(tools_list, list):
|
||||
update_fields[field] = tools_list
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Tools must be a list",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
elif field == "json_schema":
|
||||
json_schema = data.get("json_schema")
|
||||
if json_schema is not None:
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid object",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = json_schema
|
||||
else:
|
||||
update_fields[field] = None
|
||||
else:
|
||||
value = data[field]
|
||||
if field in ["name", "description", "prompt_id", "agent_type"]:
|
||||
if not value or not str(value).strip():
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Field '{field}' cannot be empty",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = value
|
||||
if image_url:
|
||||
update_fields["image"] = image_url
|
||||
if not update_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "No valid update data provided",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
newly_generated_key = None
|
||||
final_status = update_fields.get("status", existing_agent.get("status"))
|
||||
|
||||
if final_status == "published":
|
||||
required_published_fields = {
|
||||
"name": "Agent name",
|
||||
"description": "Agent description",
|
||||
"chunks": "Chunks count",
|
||||
"prompt_id": "Prompt",
|
||||
"agent_type": "Agent type",
|
||||
}
|
||||
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in required_published_fields.items():
|
||||
final_value = update_fields.get(
|
||||
req_field, existing_agent.get(req_field)
|
||||
)
|
||||
if not final_value:
|
||||
missing_published_fields.append(field_label)
|
||||
source_val = update_fields.get("source", existing_agent.get("source"))
|
||||
sources_val = update_fields.get(
|
||||
"sources", existing_agent.get("sources", [])
|
||||
)
|
||||
|
||||
has_valid_source = (
|
||||
isinstance(source_val, DBRef)
|
||||
or source_val == "default"
|
||||
or (isinstance(sources_val, list) and len(sources_val) > 0)
|
||||
)
|
||||
|
||||
if not has_valid_source:
|
||||
missing_published_fields.append("Source")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if not existing_agent.get("key"):
|
||||
newly_generated_key = str(uuid.uuid4())
|
||||
update_fields["key"] = newly_generated_key
|
||||
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
try:
|
||||
result = agents_collection.update_one(
|
||||
{"_id": oid, "user": user}, {"$set": update_fields}
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Agent not found or update failed",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
if result.modified_count == 0 and result.matched_count == 1:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": "No changes detected",
|
||||
"id": agent_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating agent {agent_id}: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error during update"}),
|
||||
500,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": agent_id,
|
||||
"message": "Agent updated successfully",
|
||||
}
|
||||
if newly_generated_key:
|
||||
response_data["key"] = newly_generated_key
|
||||
return make_response(jsonify(response_data), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/delete_agent")
|
||||
class DeleteAgent(Resource):
|
||||
@api.doc(params={"id": "ID of the agent"}, description="Delete an agent by ID")
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
deleted_agent = agents_collection.find_one_and_delete(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not deleted_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
deleted_id = str(deleted_agent["_id"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": deleted_id}), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/pinned_agents")
|
||||
class PinnedAgents(Resource):
|
||||
@api.doc(description="Get pinned agents for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
try:
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
pinned_ids = user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
|
||||
if not pinned_ids:
|
||||
return make_response(jsonify([]), 200)
|
||||
pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids]
|
||||
|
||||
pinned_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": pinned_object_ids}}
|
||||
)
|
||||
pinned_agents = list(pinned_agents_cursor)
|
||||
existing_ids = {str(agent["_id"]) for agent in pinned_agents}
|
||||
|
||||
# Clean up any stale pinned IDs
|
||||
|
||||
stale_ids = [
|
||||
agent_id for agent_id in pinned_ids if agent_id not in existing_ids
|
||||
]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
|
||||
)
|
||||
list_pinned_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"source": (
|
||||
str(db.dereference(agent["source"])["_id"])
|
||||
if "source" in agent
|
||||
and agent["source"]
|
||||
and isinstance(agent["source"], DBRef)
|
||||
and db.dereference(agent["source"]) is not None
|
||||
else ""
|
||||
),
|
||||
"chunks": agent.get("chunks", ""),
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
"key": (
|
||||
f"{agent['key'][:4]}...{agent['key'][-4:]}"
|
||||
if "key" in agent
|
||||
else ""
|
||||
),
|
||||
"pinned": True,
|
||||
}
|
||||
for agent in pinned_agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving pinned agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_pinned_agents), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/pin_agent")
|
||||
class PinAgent(Resource):
|
||||
@api.doc(params={"id": "ID of the agent"}, description="Pin or unpin an agent")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
pinned_list = user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
|
||||
if agent_id in pinned_list:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pull": {"agent_preferences.pinned": agent_id}},
|
||||
)
|
||||
action = "unpinned"
|
||||
else:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.pinned": agent_id}},
|
||||
)
|
||||
action = "pinned"
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Server error"}), 500
|
||||
)
|
||||
return make_response(jsonify({"success": True, "action": action}), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/remove_shared_agent")
|
||||
class RemoveSharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the shared agent"},
|
||||
description="Remove a shared agent from the current user's shared list",
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "shared_publicly": True}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{
|
||||
"$pull": {
|
||||
"agent_preferences.shared_with_me": agent_id,
|
||||
"agent_preferences.pinned": agent_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True, "action": "removed"}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error removing shared agent: {err}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Server error"}), 500
|
||||
)
|
||||
254
application/api/user/agents/sharing.py
Normal file
254
application/api/user/agents/sharing.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Agent management sharing functionality."""
|
||||
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={
|
||||
"token": "Shared token of the agent",
|
||||
},
|
||||
description="Get a shared agent by token or ID",
|
||||
)
|
||||
def get(self):
|
||||
shared_token = request.args.get("token")
|
||||
|
||||
if not shared_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
shared_agent = agents_collection.find_one(query)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
"user": shared_agent.get("user", ""),
|
||||
"name": shared_agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(shared_agent["image"])
|
||||
if shared_agent.get("image")
|
||||
else ""
|
||||
),
|
||||
"description": shared_agent.get("description", ""),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(shared_agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"chunks": shared_agent.get("chunks", "0"),
|
||||
"retriever": shared_agent.get("retriever", "classic"),
|
||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||
"tools": shared_agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
"shared_token": shared_agent.get("shared_token", ""),
|
||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||
}
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for tool in data["tools"]:
|
||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user")
|
||||
|
||||
if user_id != owner_id:
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agents")
|
||||
class SharedAgents(Resource):
|
||||
@api.doc(description="Get shared agents explicitly shared with the user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||
"shared_with_me", []
|
||||
)
|
||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||
|
||||
shared_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||
)
|
||||
shared_agents = list(shared_agents_cursor)
|
||||
|
||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
list_shared_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
}
|
||||
for agent in shared_agents
|
||||
]
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/share_agent")
|
||||
class ShareAgent(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ShareAgentModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="ID of the agent"),
|
||||
"shared": fields.Boolean(
|
||||
required=True, description="Share or unshare the agent"
|
||||
),
|
||||
"username": fields.String(
|
||||
required=False, description="Name of the user"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Share or unshare an agent")
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing JSON body"}), 400
|
||||
)
|
||||
agent_id = data.get("id")
|
||||
shared = data.get("shared")
|
||||
username = data.get("username", "")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
if shared is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Shared parameter is required and must be true or false",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{
|
||||
"$set": {
|
||||
"shared_publicly": shared,
|
||||
"shared_metadata": shared_metadata,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
119
application/api/user/agents/webhooks.py
Normal file
119
application/api/user/agents/webhooks.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
agents_webhooks_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the agent"},
|
||||
description="Generate webhook URL for the agent",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id), "user": user},
|
||||
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error generating webhook URL: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error generating webhook URL"}),
|
||||
400,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
if payload is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid or missing JSON data in request body",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
5
application/api/user/analytics/__init__.py
Normal file
5
application/api/user/analytics/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Analytics module."""
|
||||
|
||||
from .routes import analytics_ns
|
||||
|
||||
__all__ = ["analytics_ns"]
|
||||
540
application/api/user/analytics/routes.py
Normal file
540
application/api/user/analytics/routes.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""Analytics and reporting routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
|
||||
analytics_ns = Namespace(
|
||||
"analytics", description="Analytics and reporting operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
"GetMessageAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_message_analytics_model)
|
||||
@api.doc(description="Get message analytics based on filter option")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in message_data:
|
||||
daily_messages[entry["_id"]] = entry["count"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "messages": daily_messages}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_token_analytics")
|
||||
class GetTokenAnalytics(Resource):
|
||||
get_token_analytics_model = api.model(
|
||||
"GetTokenAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_token_analytics_model)
|
||||
@api.doc(description="Get token analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"minute": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"hour": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"day": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
match_stage,
|
||||
group_stage,
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in token_usage_data:
|
||||
if filter_option == "last_hour":
|
||||
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||
elif filter_option == "last_24_hour":
|
||||
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||
else:
|
||||
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "token_usage": daily_token_usage}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_feedback_analytics")
|
||||
class GetFeedbackAnalytics(Resource):
|
||||
get_feedback_analytics_model = api.model(
|
||||
"GetFeedbackAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_feedback_analytics_model)
|
||||
@api.doc(description="Get feedback analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.time",
|
||||
"positive": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"negative": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
feedback_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
|
||||
for entry in feedback_data:
|
||||
daily_feedback[entry["_id"]] = {
|
||||
"positive": entry["positive"],
|
||||
"negative": entry["negative"],
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting feedback analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "feedback": daily_feedback}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_user_logs")
|
||||
class GetUserLogs(Resource):
|
||||
get_user_logs_model = api.model(
|
||||
"GetUserLogsModel",
|
||||
{
|
||||
"page": fields.Integer(
|
||||
required=False,
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
),
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"page_size": fields.Integer(
|
||||
required=False,
|
||||
description="Number of logs per page",
|
||||
default=10,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_user_logs_model)
|
||||
@api.doc(description="Get user logs with pagination")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
items_cursor = (
|
||||
user_logs_collection.find(query)
|
||||
.sort("timestamp", -1)
|
||||
.skip(skip)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
items = list(items_cursor)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("_id")),
|
||||
"action": item.get("action"),
|
||||
"level": item.get("level"),
|
||||
"user": item.get("user"),
|
||||
"question": item.get("question"),
|
||||
"sources": item.get("sources"),
|
||||
"retriever_params": item.get("retriever_params"),
|
||||
"timestamp": item.get("timestamp"),
|
||||
}
|
||||
for item in items[:page_size]
|
||||
]
|
||||
|
||||
has_more = len(items) > page_size
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
5
application/api/user/attachments/__init__.py
Normal file
5
application/api/user/attachments/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Attachments module."""
|
||||
|
||||
from .routes import attachments_ns
|
||||
|
||||
__all__ = ["attachments_ns"]
|
||||
150
application/api/user/attachments/routes.py
Normal file
150
application/api/user/attachments/routes.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""File attachments and media routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, storage
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.core.settings import settings
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
attachments_ns = Namespace(
|
||||
"attachments", description="File attachments and media operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/store_attachment")
|
||||
class StoreAttachment(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AttachmentModel",
|
||||
{
|
||||
"file": fields.Raw(required=True, description="File to upload"),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key (optional)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||
file = request.files.get("file")
|
||||
|
||||
if not file or file.filename == "":
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "Missing file"}),
|
||||
400,
|
||||
)
|
||||
user = None
|
||||
if decoded_token:
|
||||
user = safe_filename(decoded_token.get("sub"))
|
||||
elif api_key:
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||
)
|
||||
user = safe_filename(agent.get("user"))
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||
)
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
|
||||
metadata = storage.save_file(file, relative_path)
|
||||
|
||||
file_info = {
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"path": relative_path,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
task = store_attachment.delay(file_info, user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task.id,
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@attachments_ns.route("/images/<path:image_path>")
|
||||
class ServeImage(Resource):
|
||||
@api.doc(description="Serve an image from storage")
|
||||
def get(self, image_path):
|
||||
try:
|
||||
file_obj = storage.get_file(image_path)
|
||||
extension = image_path.split(".")[-1].lower()
|
||||
content_type = f"image/{extension}"
|
||||
if extension == "jpg":
|
||||
content_type = "image/jpeg"
|
||||
response = make_response(file_obj.read())
|
||||
response.headers.set("Content-Type", content_type)
|
||||
response.headers.set("Cache-Control", "max-age=86400")
|
||||
|
||||
return response
|
||||
except FileNotFoundError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image not found"}), 404
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error serving image: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error retrieving image"}), 500
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/tts")
|
||||
class TextToSpeech(Resource):
|
||||
tts_model = api.model(
|
||||
"TextToSpeechModel",
|
||||
{
|
||||
"text": fields.String(
|
||||
required=True, description="Text to be synthesized as audio"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(tts_model)
|
||||
@api.doc(description="Synthesize audio speech from text")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = GoogleTTS()
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"audio_base64": audio_base64,
|
||||
"lang": detected_language,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
222
application/api/user/base.py
Normal file
222
application/api/user/base.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Shared utilities, database connections, and helper functions for user API routes.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
|
||||
|
||||
try:
|
||||
agents_collection.create_index(
|
||||
[("shared", 1)],
|
||||
name="shared_index",
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
||||
|
||||
def generate_minute_range(start_date, end_date):
|
||||
"""Generate a dictionary with minute-level time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_hourly_range(start_date, end_date):
|
||||
"""Generate a dictionary with hourly time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_date_range(start_date, end_date):
|
||||
"""Generate a dictionary with daily date ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
|
||||
for i in range((end_date - start_date).days + 1)
|
||||
}
|
||||
|
||||
|
||||
def ensure_user_doc(user_id):
|
||||
"""
|
||||
Ensure user document exists with proper agent preferences structure.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to ensure
|
||||
|
||||
Returns:
|
||||
The user document
|
||||
"""
|
||||
default_prefs = {
|
||||
"pinned": [],
|
||||
"shared_with_me": [],
|
||||
}
|
||||
|
||||
user_doc = users_collection.find_one_and_update(
|
||||
{"user_id": user_id},
|
||||
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||
upsert=True,
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
prefs = user_doc.get("agent_preferences", {})
|
||||
updates = {}
|
||||
if "pinned" not in prefs:
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
return user_doc
|
||||
|
||||
|
||||
def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their details.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs
|
||||
|
||||
Returns:
|
||||
List of tool details with id, name, and display_name
|
||||
"""
|
||||
tools = user_tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("name", ""),
|
||||
"display_name": tool.get("displayName", tool.get("name", "")),
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
|
||||
def get_vector_store(source_id):
|
||||
"""
|
||||
Get the Vector Store for a given source ID.
|
||||
|
||||
Args:
|
||||
source_id (str): source id of the document
|
||||
|
||||
Returns:
|
||||
Vector store instance
|
||||
"""
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
def handle_image_upload(
|
||||
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
|
||||
) -> Tuple[str, Optional[Response]]:
|
||||
"""
|
||||
Handle image file upload from request.
|
||||
|
||||
Args:
|
||||
request: Flask request object
|
||||
existing_url: Existing image URL (fallback)
|
||||
user: User ID
|
||||
storage: Storage instance
|
||||
base_path: Base path for upload
|
||||
|
||||
Returns:
|
||||
Tuple of (image_url, error_response)
|
||||
"""
|
||||
image_url = existing_url
|
||||
|
||||
if "image" in request.files:
|
||||
file = request.files["image"]
|
||||
if file.filename != "":
|
||||
filename = secure_filename(file.filename)
|
||||
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
|
||||
try:
|
||||
storage.save_file(file, upload_path, storage_class="STANDARD")
|
||||
image_url = upload_path
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error uploading image: {e}")
|
||||
return None, make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}),
|
||||
400,
|
||||
)
|
||||
return image_url, None
|
||||
|
||||
|
||||
def require_agent(func):
|
||||
"""
|
||||
Decorator to require valid agent webhook token.
|
||||
|
||||
Args:
|
||||
func: Function to decorate
|
||||
|
||||
Returns:
|
||||
Wrapped function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
webhook_token = kwargs.get("webhook_token")
|
||||
if not webhook_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one(
|
||||
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||
)
|
||||
if not agent:
|
||||
current_app.logger.warning(
|
||||
f"Webhook attempt with invalid token: {webhook_token}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
kwargs["agent"] = agent
|
||||
kwargs["agent_id_str"] = str(agent["_id"])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
5
application/api/user/conversations/__init__.py
Normal file
5
application/api/user/conversations/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Conversation management module."""
|
||||
|
||||
from .routes import conversations_ns
|
||||
|
||||
__all__ = ["conversations_ns"]
|
||||
280
application/api/user/conversations/routes.py
Normal file
280
application/api/user/conversations/routes.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""Conversation management routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
"conversations", description="Conversation management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_conversation")
|
||||
class DeleteConversation(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a conversation by ID",
|
||||
params={"id": "The ID of the conversation to delete"},
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_all_conversations")
|
||||
class DeleteAllConversations(Resource):
|
||||
@api.doc(
|
||||
description="Deletes all conversations for a specific user",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_conversations")
|
||||
class GetConversations(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["_id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a single conversation by ID",
|
||||
params={"id": "The conversation ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
# Process queries to include attachment names
|
||||
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
data = {
|
||||
"queries": queries,
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/update_conversation_name")
|
||||
class UpdateConversationName(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateConversationModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Conversation ID"),
|
||||
"name": fields.String(
|
||||
required=True, description="New name of the conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates the name of a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/feedback")
|
||||
class SubmitFeedback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"FeedbackModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=False, description="The user question"
|
||||
),
|
||||
"answer": fields.String(required=False, description="The AI answer"),
|
||||
"feedback": fields.String(required=True, description="User feedback"),
|
||||
"question_index": fields.Integer(
|
||||
required=True,
|
||||
description="The question number in that particular conversation",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="id of the particular conversation"
|
||||
),
|
||||
"api_key": fields.String(description="Optional API key"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Submit feedback for a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["feedback", "conversation_id", "question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
if data["feedback"] is None:
|
||||
# Remove feedback and feedback_timestamp if feedback is null
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Set feedback and feedback_timestamp if feedback has a value
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
5
application/api/user/prompts/__init__.py
Normal file
5
application/api/user/prompts/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Prompts module."""
|
||||
|
||||
from .routes import prompts_ns
|
||||
|
||||
__all__ = ["prompts_ns"]
|
||||
191
application/api/user/prompts/routes.py
Normal file
191
application/api/user/prompts/routes.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Prompt management routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
"prompts", description="Prompt management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@prompts_ns.route("/create_prompt")
|
||||
class CreatePrompt(Resource):
|
||||
create_prompt_model = api.model(
|
||||
"CreatePromptModel",
|
||||
{
|
||||
"content": fields.String(
|
||||
required=True, description="Content of the prompt"
|
||||
),
|
||||
"name": fields.String(required=True, description="Name of the prompt"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_prompt_model)
|
||||
@api.doc(description="Create a new prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["content", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"content": data["content"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_prompts")
|
||||
class GetPrompts(Resource):
|
||||
@api.doc(description="Get all prompts for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
list_prompts = [
|
||||
{"id": "default", "name": "default", "type": "public"},
|
||||
{"id": "creative", "name": "creative", "type": "public"},
|
||||
{"id": "strict", "name": "strict", "type": "public"},
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
list_prompts.append(
|
||||
{
|
||||
"id": str(prompt["_id"]),
|
||||
"name": prompt["name"],
|
||||
"type": "private",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_prompts), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_single_prompt")
|
||||
class GetSinglePrompt(Resource):
|
||||
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
prompt_id = request.args.get("id")
|
||||
if not prompt_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
if prompt_id == "default":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_combine_template = f.read()
|
||||
return make_response(jsonify({"content": chat_combine_template}), 200)
|
||||
elif prompt_id == "creative":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_reduce_creative = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_creative}), 200)
|
||||
elif prompt_id == "strict":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
|
||||
) as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/delete_prompt")
|
||||
class DeletePrompt(Resource):
|
||||
delete_prompt_model = api.model(
|
||||
"DeletePromptModel",
|
||||
{"id": fields.String(required=True, description="Prompt ID to delete")},
|
||||
)
|
||||
|
||||
@api.expect(delete_prompt_model)
|
||||
@api.doc(description="Delete a prompt by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/update_prompt")
|
||||
class UpdatePrompt(Resource):
|
||||
update_prompt_model = api.model(
|
||||
"UpdatePromptModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Prompt ID to update"),
|
||||
"name": fields.String(required=True, description="New name of the prompt"),
|
||||
"content": fields.String(
|
||||
required=True, description="New content of the prompt"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_prompt_model)
|
||||
@api.doc(description="Update an existing prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name", "content"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
File diff suppressed because it is too large
Load Diff
5
application/api/user/sharing/__init__.py
Normal file
5
application/api/user/sharing/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Sharing module."""
|
||||
|
||||
from .routes import sharing_ns
|
||||
|
||||
__all__ = ["sharing_ns"]
|
||||
301
application/api/user/sharing/routes.py
Normal file
301
application/api/user/sharing/routes.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Conversation sharing routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, inputs, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
attachments_collection,
|
||||
conversations_collection,
|
||||
db,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
sharing_ns = Namespace(
|
||||
"sharing", description="Conversation sharing operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sharing_ns.route("/share")
|
||||
class ShareConversation(Resource):
|
||||
share_conversation_model = api.model(
|
||||
"ShareConversationModel",
|
||||
{
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="Conversation ID"
|
||||
),
|
||||
"user": fields.String(description="User ID (optional)"),
|
||||
"prompt_id": fields.String(description="Prompt ID (optional)"),
|
||||
"chunks": fields.Integer(description="Chunks count (optional)"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(share_conversation_model)
|
||||
@api.doc(description="Share a conversation")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
is_promptable = request.args.get("isPromptable", type=inputs.boolean)
|
||||
if is_promptable is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "isPromptable is required"}), 400
|
||||
)
|
||||
conversation_id = data["conversation_id"]
|
||||
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
current_n_queries = len(conversation["queries"])
|
||||
explicit_binary = Binary.from_uuid(
|
||||
uuid.uuid4(), UuidRepresentation.STANDARD
|
||||
)
|
||||
|
||||
if is_promptable:
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = data.get("chunks", "2")
|
||||
|
||||
name = conversation["name"] + "(shared)"
|
||||
new_api_key_data = {
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
||||
if pre_existing_api_document:
|
||||
api_uuid = pre_existing_api_document["key"]
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": DBRef(
|
||||
"conversations", ObjectId(conversation_id)
|
||||
),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
new_api_key_data["key"] = api_uuid
|
||||
new_api_key_data["name"] = name
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
agents_collection.insert_one(new_api_key_data)
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": DBRef(
|
||||
"conversations", ObjectId(conversation_id)
|
||||
),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error sharing conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sharing_ns.route("/shared_conversation/<string:identifier>")
|
||||
class GetPubliclySharedConversations(Resource):
|
||||
@api.doc(description="Get publicly shared conversations by identifier")
|
||||
def get(self, identifier: str):
|
||||
try:
|
||||
query_uuid = Binary.from_uuid(
|
||||
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
||||
)
|
||||
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
||||
conversation_queries = []
|
||||
|
||||
if (
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
and isinstance(shared["conversation_id"], DBRef)
|
||||
):
|
||||
conversation_ref = shared["conversation_id"]
|
||||
conversation = db.dereference(conversation_ref)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
conversation_queries = conversation["queries"][
|
||||
: (shared["first_n_queries"])
|
||||
]
|
||||
|
||||
for query in conversation_queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
date = conversation["_id"].generation_time.isoformat()
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation["name"],
|
||||
"timestamp": date,
|
||||
}
|
||||
if shared["isPromptable"] and "api_key" in shared:
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting shared conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
7
application/api/user/sources/__init__.py
Normal file
7
application/api/user/sources/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Sources module."""
|
||||
|
||||
from .chunks import sources_chunks_ns
|
||||
from .routes import sources_ns
|
||||
from .upload import sources_upload_ns
|
||||
|
||||
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]
|
||||
278
application/api/user/sources/chunks.py
Normal file
278
application/api/user/sources/chunks.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Source document management chunk management."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import get_vector_store, sources_collection
|
||||
from application.utils import check_required_fields, num_tokens_from_string
|
||||
|
||||
sources_chunks_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/get_chunks")
|
||||
class GetChunks(Resource):
|
||||
@api.doc(
|
||||
description="Retrieves chunks from a document, optionally filtered by file path and search term",
|
||||
params={
|
||||
"id": "The document ID",
|
||||
"page": "Page number for pagination",
|
||||
"per_page": "Number of chunks per page",
|
||||
"path": "Optional: Filter chunks by relative file path",
|
||||
"search": "Optional: Search term to filter chunks by title or content",
|
||||
},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
page = int(request.args.get("page", 1))
|
||||
per_page = int(request.args.get("per_page", 10))
|
||||
path = request.args.get("path")
|
||||
search_term = request.args.get("search", "").strip().lower()
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunks = store.get_chunks()
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk in chunks:
|
||||
metadata = chunk.get("metadata", {})
|
||||
|
||||
# Filter by path if provided
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
# Check if the chunk's source matches the requested path
|
||||
|
||||
if not chunk_source or not chunk_source.endswith(path):
|
||||
continue
|
||||
# Filter by search term if provided
|
||||
|
||||
if search_term:
|
||||
text_match = search_term in chunk.get("text", "").lower()
|
||||
title_match = search_term in metadata.get("title", "").lower()
|
||||
|
||||
if not (text_match or title_match):
|
||||
continue
|
||||
filtered_chunks.append(chunk)
|
||||
chunks = filtered_chunks
|
||||
|
||||
total_chunks = len(chunks)
|
||||
start = (page - 1) * per_page
|
||||
end = start + per_page
|
||||
paginated_chunks = chunks[start:end]
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"total": total_chunks,
|
||||
"chunks": paginated_chunks,
|
||||
"path": path if path else None,
|
||||
"search": search_term if search_term else None,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/add_chunk")
|
||||
class AddChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AddChunkModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Document ID"),
|
||||
"text": fields.String(required=True, description="Text of the chunk"),
|
||||
"metadata": fields.Raw(
|
||||
required=False,
|
||||
description="Metadata associated with the chunk",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Adds a new chunk to the document",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "text"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
doc_id = data.get("id")
|
||||
text = data.get("text")
|
||||
metadata = data.get("metadata", {})
|
||||
token_count = num_tokens_from_string(text)
|
||||
metadata["token_count"] = token_count
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunk_id = store.add_chunk(text, metadata)
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||
201,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/delete_chunk")
|
||||
class DeleteChunk(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a specific chunk from the document.",
|
||||
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
chunk_id = request.args.get("chunk_id")
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if deleted:
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk deleted successfully"}), 200
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk not found or could not be deleted"}),
|
||||
404,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/update_chunk")
|
||||
class UpdateChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateChunkModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Document ID"),
|
||||
"chunk_id": fields.String(
|
||||
required=True, description="Chunk ID to update"
|
||||
),
|
||||
"text": fields.String(
|
||||
required=False, description="New text of the chunk"
|
||||
),
|
||||
"metadata": fields.Raw(
|
||||
required=False,
|
||||
description="Updated metadata associated with the chunk",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates an existing chunk in the document.",
|
||||
)
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "chunk_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
doc_id = data.get("id")
|
||||
chunk_id = data.get("chunk_id")
|
||||
text = data.get("text")
|
||||
metadata = data.get("metadata")
|
||||
|
||||
if text is not None:
|
||||
token_count = num_tokens_from_string(text)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["token_count"] = token_count
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
|
||||
chunks = store.get_chunks()
|
||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||
if not existing_chunk:
|
||||
return make_response(jsonify({"error": "Chunk not found"}), 404)
|
||||
new_text = text if text is not None else existing_chunk["text"]
|
||||
|
||||
if metadata is not None:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
new_metadata.update(metadata)
|
||||
else:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
if text is not None:
|
||||
new_metadata["token_count"] = num_tokens_from_string(new_text)
|
||||
try:
|
||||
new_chunk_id = store.add_chunk(new_text, new_metadata)
|
||||
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if not deleted:
|
||||
current_app.logger.warning(
|
||||
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"message": "Chunk updated successfully",
|
||||
"chunk_id": new_chunk_id,
|
||||
"original_chunk_id": chunk_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as add_error:
|
||||
current_app.logger.error(f"Failed to add updated chunk: {add_error}")
|
||||
return make_response(
|
||||
jsonify({"error": "Failed to update chunk - addition failed"}), 500
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
350
application/api/user/sources/routes.py
Normal file
350
application/api/user/sources/routes.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Source document management routes."""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
sources_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_ns.route("/sources")
|
||||
class CombinedJson(Resource):
|
||||
@api.doc(description="Provide JSON file with combined available indexes")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = [
|
||||
{
|
||||
"name": "Default",
|
||||
"date": "default",
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "remote",
|
||||
"tokens": "",
|
||||
"retriever": "classic",
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||
data.append(
|
||||
{
|
||||
"id": str(index["_id"]),
|
||||
"name": index.get("name"),
|
||||
"date": index.get("date"),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": index.get("tokens", ""),
|
||||
"retriever": index.get("retriever", "classic"),
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
), # Add type field with default "file"
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/paginated")
|
||||
class PaginatedSources(Resource):
|
||||
@api.doc(description="Get document with pagination, sorting and filtering")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||
# add .strip() to remove leading and trailing whitespaces
|
||||
|
||||
search_term = request.args.get(
|
||||
"search", ""
|
||||
).strip() # add search for filter documents
|
||||
|
||||
# Prepare query for filtering
|
||||
|
||||
query = {"user": user}
|
||||
if search_term:
|
||||
query["name"] = {
|
||||
"$regex": search_term,
|
||||
"$options": "i", # using case-insensitive search
|
||||
}
|
||||
total_documents = sources_collection.count_documents(query)
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
page = min(
|
||||
max(1, page), total_pages
|
||||
) # add this to make sure page inbound is within the range
|
||||
sort_order = 1 if sort_order == "asc" else -1
|
||||
skip = (page - 1) * rows_per_page
|
||||
|
||||
try:
|
||||
documents = (
|
||||
sources_collection.find(query)
|
||||
.sort(sort_field, sort_order)
|
||||
.skip(skip)
|
||||
.limit(rows_per_page)
|
||||
)
|
||||
|
||||
paginated_docs = []
|
||||
for doc in documents:
|
||||
doc_data = {
|
||||
"id": str(doc["_id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
paginated_docs.append(doc_data)
|
||||
response = {
|
||||
"total": total_documents,
|
||||
"totalPages": total_pages,
|
||||
"currentPage": page,
|
||||
"paginated": paginated_docs,
|
||||
}
|
||||
return make_response(jsonify(response), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving paginated sources: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/docs_check")
|
||||
class CheckDocs(Resource):
|
||||
check_docs_model = api.model(
|
||||
"CheckDocsModel",
|
||||
{"docs": fields.String(required=True, description="Document name")},
|
||||
)
|
||||
|
||||
@api.expect(check_docs_model)
|
||||
@api.doc(description="Check if document exists")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["docs"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
vectorstore = "vectors/" + secure_filename(data["docs"])
|
||||
if os.path.exists(vectorstore) or data["docs"] == "default":
|
||||
return {"status": "exists"}, 200
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error checking document: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_by_ids")
|
||||
class DeleteByIds(Resource):
|
||||
@api.doc(
|
||||
description="Deletes documents from the vector store by IDs",
|
||||
params={"path": "Comma-separated list of IDs"},
|
||||
)
|
||||
def get(self):
|
||||
ids = request.args.get("path")
|
||||
if not ids:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_index(ids=ids)
|
||||
if result:
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_old")
|
||||
class DeleteOldIndexes(Resource):
|
||||
@api.doc(
|
||||
description="Deletes old indexes and associated files",
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
try:
|
||||
# Delete vector index
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
index_path = f"indexes/{str(doc['_id'])}"
|
||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||
storage.delete_file(f"{index_path}/index.faiss")
|
||||
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||
storage.delete_file(f"{index_path}/index.pkl")
|
||||
else:
|
||||
vectorstore = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
if "file_path" in doc and doc["file_path"]:
|
||||
file_path = doc["file_path"]
|
||||
if storage.is_directory(file_path):
|
||||
files = storage.list_files(file_path)
|
||||
for f in files:
|
||||
storage.delete_file(f)
|
||||
else:
|
||||
storage.delete_file(file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting files and indexes: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/combine")
|
||||
class RedirectToSources(Resource):
|
||||
@api.doc(
|
||||
description="Redirects /api/combine to /api/sources for backward compatibility"
|
||||
)
|
||||
def get(self):
|
||||
return redirect("/api/sources", code=301)
|
||||
|
||||
|
||||
@sources_ns.route("/manage_sync")
|
||||
class ManageSync(Resource):
|
||||
manage_sync_model = api.model(
|
||||
"ManageSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID"),
|
||||
"sync_frequency": fields.String(
|
||||
required=True,
|
||||
description="Sync frequency (never, daily, weekly, monthly)",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(manage_sync_model)
|
||||
@api.doc(description="Manage sync frequency for sources")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["source_id", "sync_frequency"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
sync_frequency = data["sync_frequency"]
|
||||
|
||||
if sync_frequency not in ["never", "daily", "weekly", "monthly"]:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||
)
|
||||
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
||||
try:
|
||||
sources_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(source_id),
|
||||
"user": user,
|
||||
},
|
||||
update_data,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating sync frequency: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
description="Get the directory structure for a document",
|
||||
params={"id": "The document ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||
try:
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
directory_structure = doc.get("directory_structure", {})
|
||||
base_path = doc.get("file_path", "")
|
||||
|
||||
provider = None
|
||||
remote_data = doc.get("remote_data")
|
||||
try:
|
||||
if isinstance(remote_data, str) and remote_data:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
provider = remote_data_obj.get("provider")
|
||||
except Exception as e:
|
||||
current_app.logger.warning(
|
||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"directory_structure": directory_structure,
|
||||
"base_path": base_path,
|
||||
"provider": provider,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
572
application/api/user/sources/upload.py
Normal file
572
application/api/user/sources/upload.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""Source document management upload functionality."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields, safe_filename
|
||||
|
||||
|
||||
sources_upload_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/upload")
|
||||
class UploadFile(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Missing required fields or files",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
|
||||
safe_user = safe_filename(user)
|
||||
dir_name = safe_filename(job_name)
|
||||
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
for file in files:
|
||||
original_filename = file.filename
|
||||
safe_file = safe_filename(original_filename)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_file_path = os.path.join(temp_dir, safe_file)
|
||||
file.save(temp_file_path)
|
||||
|
||||
if zipfile.is_zipfile(temp_file_path):
|
||||
try:
|
||||
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(path=temp_dir)
|
||||
|
||||
# Walk through extracted files and upload them
|
||||
|
||||
for root, _, files in os.walk(temp_dir):
|
||||
for extracted_file in files:
|
||||
if (
|
||||
os.path.join(root, extracted_file)
|
||||
== temp_file_path
|
||||
):
|
||||
continue
|
||||
rel_path = os.path.relpath(
|
||||
os.path.join(root, extracted_file), temp_dir
|
||||
)
|
||||
storage_path = f"{base_path}/{rel_path}"
|
||||
|
||||
with open(
|
||||
os.path.join(root, extracted_file), "rb"
|
||||
) as f:
|
||||
storage.save_file(f, storage_path)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error extracting zip: {e}", exc_info=True
|
||||
)
|
||||
# If zip extraction fails, save the original zip file
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
else:
|
||||
# For non-zip files, save directly
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
[
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
job_name,
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/remote")
|
||||
class UploadRemote(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RemoteUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source of the data"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Data to process"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads remote source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Process file_ids
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
# Process folder_ids
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [
|
||||
id.strip() for id in folder_ids.split(",") if id.strip()
|
||||
]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading remote source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/manage_source_files")
|
||||
class ManageSourceFiles(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ManageSourceFilesModel",
|
||||
{
|
||||
"source_id": fields.String(
|
||||
required=True, description="Source ID to modify"
|
||||
),
|
||||
"operation": fields.String(
|
||||
required=True,
|
||||
description="Operation: 'add', 'remove', or 'remove_directory'",
|
||||
),
|
||||
"file_paths": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="File paths to remove (for remove operation)",
|
||||
),
|
||||
"directory_path": fields.String(
|
||||
required=False,
|
||||
description="Directory path to remove (for remove_directory operation)",
|
||||
),
|
||||
"file": fields.Raw(
|
||||
required=False, description="Files to add (for add operation)"
|
||||
),
|
||||
"parent_dir": fields.String(
|
||||
required=False,
|
||||
description="Parent directory path relative to source root",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Add files, remove files, or remove directories from an existing source",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
|
||||
if not source_id or not operation:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "source_id and operation are required",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if operation not in ["add", "remove", "remove_directory"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "operation must be 'add', 'remove', or 'remove_directory'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
ObjectId(source_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||
)
|
||||
try:
|
||||
source = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Source not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
|
||||
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid parent directory path"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if operation == "add":
|
||||
files = request.files.getlist("file")
|
||||
if not files or all(file.filename == "" for file in files):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "No files provided for add operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
added_files = []
|
||||
|
||||
target_dir = source_file_path
|
||||
if parent_dir:
|
||||
target_dir = f"{source_file_path}/{parent_dir}"
|
||||
for file in files:
|
||||
if file.filename:
|
||||
safe_filename_str = safe_filename(file.filename)
|
||||
file_path = f"{target_dir}/{safe_filename_str}"
|
||||
|
||||
# Save file to storage
|
||||
|
||||
storage.save_file(file, file_path)
|
||||
added_files.append(safe_filename_str)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Added {len(added_files)} files",
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths")
|
||||
if not file_paths_str:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "file_paths required for remove operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
file_paths = (
|
||||
json.loads(file_paths_str)
|
||||
if isinstance(file_paths_str, str)
|
||||
else file_paths_str
|
||||
)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid file_paths format"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
|
||||
if storage.file_exists(full_path):
|
||||
storage.delete_file(full_path)
|
||||
removed_files.append(file_path)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
elif operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path")
|
||||
if not directory_path:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "directory_path required for remove_directory operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate directory path (prevent path traversal)
|
||||
|
||||
if directory_path.startswith("/") or ".." in directory_path:
|
||||
current_app.logger.warning(
|
||||
f"Invalid directory path attempted for removal. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid directory path"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
full_directory_path = (
|
||||
f"{source_file_path}/{directory_path}"
|
||||
if directory_path
|
||||
else source_file_path
|
||||
)
|
||||
|
||||
if not storage.is_directory(full_directory_path):
|
||||
current_app.logger.warning(
|
||||
f"Directory not found or is not a directory for removal. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Directory not found or is not a directory",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
current_app.logger.error(
|
||||
f"Failed to remove directory from storage. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Failed to remove directory"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Successfully removed directory. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path", "")
|
||||
error_context += f", directory_path={directory_path}"
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths", "")
|
||||
error_context += f", file_paths={file_paths_str}"
|
||||
elif operation == "add":
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
error_context += f", parent_dir={parent_dir}"
|
||||
current_app.logger.error(
|
||||
f"Error managing source files: {err} ({error_context})", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Operation failed"}), 500
|
||||
)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/task_status")
|
||||
class TaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"TaskStatusModel",
|
||||
{"task_id": fields.String(required=True, description="Task ID")},
|
||||
)
|
||||
|
||||
@api.expect(task_status_model)
|
||||
@api.doc(description="Get celery job status")
|
||||
def get(self):
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
from application.celery_init import celery
|
||||
|
||||
task = celery.AsyncResult(task_id)
|
||||
task_meta = task.info
|
||||
print(f"Task status: {task.status}")
|
||||
if not isinstance(
|
||||
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||
):
|
||||
task_meta = str(task_meta) # Convert to a string representation
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
6
application/api/user/tools/__init__.py
Normal file
6
application/api/user/tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Tools module."""
|
||||
|
||||
from .mcp import tools_mcp_ns
|
||||
from .routes import tools_ns
|
||||
|
||||
__all__ = ["tools_ns", "tools_mcp_ns"]
|
||||
333
application/api/user/tools/mcp.py
Normal file
333
application/api/user/tools/mcp.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from email.quoprimime import unquote
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.cache import get_redis_instance
|
||||
from application.security.encryption import encrypt_credentials
|
||||
from application.utils import check_required_fields
|
||||
|
||||
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/test")
|
||||
class TestMCPServerConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerTestModel",
|
||||
{
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration to test"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Test MCP server connection with provided configuration")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
|
||||
if auth_type == "api_key" and "api_key" in config:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer" and "bearer_token" in config:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config:
|
||||
auth_credentials["password"] = config["password"]
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
|
||||
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Connection test failed: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/save")
|
||||
class MCPServerSave(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerSaveModel",
|
||||
{
|
||||
"id": fields.String(
|
||||
required=False, description="Tool ID for updates (optional)"
|
||||
),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the MCP server"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=False, default=True, description="Tool status"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create or update MCP server with automatic tool discovery")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["displayName", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
mcp_config = config.copy()
|
||||
mcp_config["auth_credentials"] = auth_credentials
|
||||
|
||||
if auth_type == "oauth":
|
||||
if not config.get("oauth_task_id"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Connection not authorized. Please complete the OAuth authorization first.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "OAuth failed or not completed. Please try authorizing again.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
actions_metadata = result.get("tools", [])
|
||||
elif auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(config=mcp_config, user_id=user)
|
||||
mcp_tool.discover_tools()
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
else:
|
||||
raise Exception(
|
||||
"No valid credentials provided for the selected authentication type"
|
||||
)
|
||||
storage_config = config.copy()
|
||||
if auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
transformed_actions = []
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"displayName": data["displayName"],
|
||||
"customName": data["displayName"],
|
||||
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": data.get("status", True),
|
||||
"user": user,
|
||||
}
|
||||
|
||||
tool_id = data.get("id")
|
||||
if tool_id:
|
||||
result = user_tools_collection.update_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
result = user_tools_collection.insert_one(tool_data)
|
||||
tool_id = str(result.inserted_id)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/callback")
|
||||
class MCPOAuthCallback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerCallbackModel",
|
||||
{
|
||||
"code": fields.String(required=True, description="Authorization code"),
|
||||
"state": fields.String(required=True, description="State parameter"),
|
||||
"error": fields.String(
|
||||
required=False, description="Error message (if any)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Handle OAuth callback by providing the authorization code and state"
|
||||
)
|
||||
def get(self):
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
|
||||
if error:
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
if not code or not state:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
if not redis_client:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
|
||||
)
|
||||
code = unquote(code)
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
success = manager.handle_oauth_callback(state, code, error)
|
||||
if success:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
|
||||
)
|
||||
else:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
|
||||
)
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
"""
|
||||
Get current status of OAuth flow.
|
||||
Frontend should poll this endpoint periodically.
|
||||
"""
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Task not found or expired",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
|
||||
)
|
||||
415
application/api/user/tools/routes.py
Normal file
415
application/api/user/tools/routes.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""Tool management routes."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
|
||||
@tools_ns.route("/available_tools")
|
||||
class AvailableTools(Resource):
|
||||
@api.doc(description="Get available tools for a user")
|
||||
def get(self):
|
||||
try:
|
||||
tools_metadata = []
|
||||
for tool_name, tool_instance in tool_manager.tools.items():
|
||||
doc = tool_instance.__doc__.strip()
|
||||
lines = doc.split("\n", 1)
|
||||
name = lines[0].strip()
|
||||
description = lines[1].strip() if len(lines) > 1 else ""
|
||||
tools_metadata.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"displayName": name,
|
||||
"description": description,
|
||||
"configRequirements": tool_instance.get_config_requirements(),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting available tools: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/get_tools")
|
||||
class GetTools(Resource):
|
||||
@api.doc(description="Get tools created by a user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
user_tools = []
|
||||
for tool in tools:
|
||||
tool["id"] = str(tool["_id"])
|
||||
tool.pop("_id")
|
||||
user_tools.append(tool)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/create_tool")
|
||||
class CreateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateToolModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="Name of the tool"),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the tool"
|
||||
),
|
||||
"description": fields.String(
|
||||
required=True, description="Tool description"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
"customName": fields.String(
|
||||
required=False, description="Custom name for the tool"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create a new tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = [
|
||||
"name",
|
||||
"displayName",
|
||||
"description",
|
||||
"config",
|
||||
"status",
|
||||
]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
tool_instance = tool_manager.tools.get(data["name"])
|
||||
if not tool_instance:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
)
|
||||
actions_metadata = tool_instance.get_actions_metadata()
|
||||
transformed_actions = []
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
try:
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"displayName": data["displayName"],
|
||||
"description": data["description"],
|
||||
"customName": data.get("customName", ""),
|
||||
"actions": transformed_actions,
|
||||
"config": data["config"],
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool")
|
||||
class UpdateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"name": fields.String(description="Name of the tool"),
|
||||
"displayName": fields.String(description="Display name for the tool"),
|
||||
"customName": fields.String(description="Custom name for the tool"),
|
||||
"description": fields.String(description="Tool description"),
|
||||
"config": fields.Raw(description="Configuration of the tool"),
|
||||
"actions": fields.List(
|
||||
fields.Raw, description="Actions the tool can perform"
|
||||
),
|
||||
"status": fields.Boolean(description="Status of the tool"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
update_data = {}
|
||||
if "name" in data:
|
||||
update_data["name"] = data["name"]
|
||||
if "displayName" in data:
|
||||
update_data["displayName"] = data["displayName"]
|
||||
if "customName" in data:
|
||||
update_data["customName"] = data["customName"]
|
||||
if "description" in data:
|
||||
update_data["description"] = data["description"]
|
||||
if "actions" in data:
|
||||
update_data["actions"] = data["actions"]
|
||||
if "config" in data:
|
||||
if "actions" in data["config"]:
|
||||
for action_name in list(data["config"]["actions"].keys()):
|
||||
if not validate_function_name(action_name):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
|
||||
"param": "tools[].function.name",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if tool_doc and tool_doc.get("name") == "mcp_tool":
|
||||
config = data["config"]
|
||||
existing_config = tool_doc.get("config", {})
|
||||
storage_config = existing_config.copy()
|
||||
|
||||
storage_config.update(config)
|
||||
existing_credentials = {}
|
||||
if "encrypted_credentials" in existing_config:
|
||||
existing_credentials = decrypt_credentials(
|
||||
existing_config["encrypted_credentials"], user
|
||||
)
|
||||
auth_credentials = existing_credentials.copy()
|
||||
auth_type = storage_config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config[
|
||||
"api_key_header"
|
||||
]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif "encrypted_token" in config and config["encrypted_token"]:
|
||||
auth_credentials["bearer_token"] = config["encrypted_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
if auth_type != "none" and auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = (
|
||||
encrypted_credentials_string
|
||||
)
|
||||
elif auth_type == "none":
|
||||
storage_config.pop("encrypted_credentials", None)
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"encrypted_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
update_data["config"] = storage_config
|
||||
else:
|
||||
update_data["config"] = data["config"]
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": update_data},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_config")
|
||||
class UpdateToolConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolConfigModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the configuration of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"config": data["config"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool config: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_actions")
|
||||
class UpdateToolActions(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolActionsModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"actions": fields.List(
|
||||
fields.Raw,
|
||||
required=True,
|
||||
description="Actions the tool can perform",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the actions of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "actions"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"actions": data["actions"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_status")
|
||||
class UpdateToolStatus(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolStatusModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the status of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "status"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"status": data["status"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool status: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/delete_tool")
|
||||
class DeleteTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DeleteToolModel",
|
||||
{"id": fields.String(required=True, description="Tool ID")},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Delete a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
result = user_tools_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return {"success": False, "message": "Tool not found"}, 404
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
return {"success": True}, 200
|
||||
@@ -41,10 +41,15 @@ class Settings(BaseSettings):
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
|
||||
GOOGLE_CLIENT_ID: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client ID
|
||||
)
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client secret
|
||||
)
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
@@ -26,7 +26,7 @@ class LocalStorage(BaseStorage):
|
||||
return path
|
||||
return os.path.join(self.base_dir, path)
|
||||
|
||||
def save_file(self, file_data: BinaryIO, path: str) -> dict:
|
||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||
"""Save a file to local storage."""
|
||||
full_path = self._get_full_path(path)
|
||||
|
||||
|
||||
17
frontend/package-lock.json
generated
17
frontend/package-lock.json
generated
@@ -12,7 +12,7 @@
|
||||
"chart.js": "^4.4.4",
|
||||
"clsx": "^2.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"i18next": "^24.2.0",
|
||||
"i18next": "^25.5.3",
|
||||
"i18next-browser-languagedetector": "^8.0.2",
|
||||
"lodash": "^4.17.21",
|
||||
"mermaid": "^11.6.0",
|
||||
@@ -321,9 +321,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/runtime": {
|
||||
"version": "7.27.3",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.3.tgz",
|
||||
"integrity": "sha512-7EYtGezsdiDMyY80+65EzwiGmcJqpmcZCojSXaRgdrBaGtWTgDZKq69cPIVped6MkIM78cTQ2GOiEYjwOlG4xw==",
|
||||
"version": "7.28.4",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz",
|
||||
"integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
@@ -6217,9 +6217,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/i18next": {
|
||||
"version": "24.2.0",
|
||||
"resolved": "https://registry.npmjs.org/i18next/-/i18next-24.2.0.tgz",
|
||||
"integrity": "sha512-ArJJTS1lV6lgKH7yEf4EpgNZ7+THl7bsGxxougPYiXRTJ/Fe1j08/TBpV9QsXCIYVfdE/HWG/xLezJ5DOlfBOA==",
|
||||
"version": "25.5.3",
|
||||
"resolved": "https://registry.npmjs.org/i18next/-/i18next-25.5.3.tgz",
|
||||
"integrity": "sha512-joFqorDeQ6YpIXni944upwnuHBf5IoPMuqAchGVeQLdWC2JOjxgM9V8UGLhNIIH/Q8QleRxIi0BSRQehSrDLcg==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
@@ -6234,8 +6234,9 @@
|
||||
"url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project"
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.23.2"
|
||||
"@babel/runtime": "^7.27.6"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"typescript": "^5"
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
"chart.js": "^4.4.4",
|
||||
"clsx": "^2.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"i18next": "^24.2.0",
|
||||
"i18next": "^25.5.3",
|
||||
"i18next-browser-languagedetector": "^8.0.2",
|
||||
"lodash": "^4.17.21",
|
||||
"mermaid": "^11.6.0",
|
||||
|
||||
3
frontend/public/toolIcons/tool_memory.svg
Normal file
3
frontend/public/toolIcons/tool_memory.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3">
|
||||
<path d="M240-80q-33 0-56.5-23.5T160-160v-480q0-33 23.5-56.5T240-720h80v-80q0-17 11.5-28.5T360-840q17 0 28.5 11.5T400-800v80h40v-80q0-17 11.5-28.5T480-840q17 0 28.5 11.5T520-800v80h40v-80q0-17 11.5-28.5T600-840q17 0 28.5 11.5T640-800v80h80q33 0 56.5 23.5T800-640v480q0 33-23.5 56.5T720-80H240Zm0-80h480v-480H240v480Zm120-320v-80h240v80H360Zm0 120v-80h240v80H360Zm0 120v-80h160v80H360ZM240-160v-480 480Z"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 523 B |
1
frontend/public/toolIcons/tool_notes.svg
Normal file
1
frontend/public/toolIcons/tool_notes.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3"><path d="M320-240h320v-80H320v80Zm0-160h320v-80H320v80ZM240-80q-33 0-56.5-23.5T160-160v-640q0-33 23.5-56.5T240-880h320l240 240v480q0 33-23.5 56.5T720-80H240Zm280-520v-200H240v640h480v-440H520ZM240-800v200-200 640-640Z"/></svg>
|
||||
|
After Width: | Height: | Size: 334 B |
@@ -7,6 +7,7 @@ import Agents from './agents';
|
||||
import SharedAgentGate from './agents/SharedAgentGate';
|
||||
import ActionButtons from './components/ActionButtons';
|
||||
import Spinner from './components/Spinner';
|
||||
import UploadToast from './components/UploadToast';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import { SharedConversation } from './conversation/SharedConversation';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
@@ -37,14 +38,15 @@ function MainLayout() {
|
||||
<Navigation navOpen={navOpen} setNavOpen={setNavOpen} />
|
||||
<ActionButtons showNewChat={true} showShare={true} />
|
||||
<div
|
||||
className={`h-[calc(100dvh-64px)] overflow-auto lg:h-screen ${
|
||||
className={`h-[calc(100dvh-64px)] overflow-auto transition-all duration-300 ease-in-out lg:h-screen ${
|
||||
!(isMobile || isTablet)
|
||||
? `ml-0 ${!navOpen ? 'lg:mx-auto' : 'lg:ml-72'}`
|
||||
? `${navOpen ? 'lg:ml-72' : 'lg:ml-0'}`
|
||||
: 'ml-0 lg:ml-16'
|
||||
}`}
|
||||
>
|
||||
<Outlet />
|
||||
</div>
|
||||
<UploadToast />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
import Pin from './assets/pin.svg';
|
||||
import Robot from './assets/robot.svg';
|
||||
import AgentImage from './components/AgentImage';
|
||||
import SettingGear from './assets/settingGear.svg';
|
||||
import Spark from './assets/spark.svg';
|
||||
import SpinnerDark from './assets/spinner-dark.svg';
|
||||
@@ -292,20 +292,26 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
useDefaultDocument();
|
||||
return (
|
||||
<>
|
||||
{!navOpen && (
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-25 lg:block">
|
||||
{(isMobile || isTablet) && navOpen && (
|
||||
<div
|
||||
className="fixed inset-0 z-10 bg-black opacity-50 transition-opacity duration-300"
|
||||
onClick={() => setNavOpen(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-300 ease-in-out lg:block">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
onClick={() => {
|
||||
setNavOpen(!navOpen);
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className={`${
|
||||
!navOpen ? 'rotate-180' : 'rotate-0'
|
||||
} m-auto transition-all duration-200`}
|
||||
className="m-auto transition-all duration-300 ease-in-out"
|
||||
/>
|
||||
</button>
|
||||
{queries?.length > 0 && (
|
||||
@@ -313,6 +319,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
onClick={() => {
|
||||
newChat();
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={openNewChat}
|
||||
@@ -326,12 +333,12 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
}
|
||||
<div
|
||||
ref={navRef}
|
||||
className={`${
|
||||
!navOpen && '-ml-96 md:-ml-72'
|
||||
} bg-lotion dark:border-r-purple-taupe dark:bg-chinese-black fixed top-0 z-20 flex h-full w-72 flex-col border-r border-b-0 transition-all duration-20 dark:text-white`}
|
||||
} bg-lotion dark:border-r-purple-taupe dark:bg-chinese-black fixed top-0 z-20 flex h-full w-72 flex-col border-r border-b-0 transition-all duration-300 ease-in-out dark:text-white`}
|
||||
>
|
||||
<div
|
||||
className={'visible mt-2 flex h-[6vh] w-full justify-between md:h-12'}
|
||||
@@ -345,7 +352,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
}}
|
||||
>
|
||||
<a href="/" className="flex gap-1.5">
|
||||
<img className="mb-2 h-10" src={DocsGPT3} alt="DocsGPT Logo" />
|
||||
<img className="h-10" src={DocsGPT3} alt="DocsGPT Logo" />
|
||||
<p className="my-auto text-2xl font-semibold">DocsGPT</p>
|
||||
</a>
|
||||
</div>
|
||||
@@ -358,9 +365,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className={`${
|
||||
!navOpen ? 'rotate-180' : 'rotate-0'
|
||||
} m-auto transition-all duration-200`}
|
||||
className="m-auto transition-all duration-300 ease-in-out hover:scale-110"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
@@ -419,12 +424,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex w-6 justify-center">
|
||||
<img
|
||||
src={
|
||||
agent.image && agent.image.trim() !== ''
|
||||
? agent.image
|
||||
: Robot
|
||||
}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
alt="agent-logo"
|
||||
className="h-6 w-6 rounded-full object-contain"
|
||||
/>
|
||||
@@ -576,7 +577,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</NavLink>
|
||||
<NavLink
|
||||
target="_blank"
|
||||
to={'https://twitter.com/docsgptai'}
|
||||
to={'https://x.com/docsgptai'}
|
||||
className={
|
||||
'rounded-full hover:bg-gray-100 dark:hover:bg-[#28292E]'
|
||||
}
|
||||
@@ -585,7 +586,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
src={Twitter}
|
||||
width={20}
|
||||
height={20}
|
||||
alt="Follow us on Twitter"
|
||||
alt="Follow us on X"
|
||||
className="m-2 self-center filter dark:invert"
|
||||
/>
|
||||
</NavLink>
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Robot from '../assets/robot.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import ThreeDots from '../assets/three-dots.svg';
|
||||
import ContextMenu, { MenuOption } from '../components/ContextMenu';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
@@ -82,8 +82,8 @@ export default function AgentCard({
|
||||
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center gap-1 px-1">
|
||||
<img
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
alt={`${agent.name}`}
|
||||
className="h-7 w-7 rounded-full object-contain"
|
||||
/>
|
||||
|
||||
@@ -49,7 +49,7 @@ export default function AgentLogs() {
|
||||
</p>
|
||||
</div>
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-eerie-black m-0 text-[40px] font-bold dark:text-white">
|
||||
<h1 className="text-eerie-black m-0 text-[32px] font-bold md:text-[40px] dark:text-white">
|
||||
Agent Logs
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
@@ -527,7 +527,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</p>
|
||||
</div>
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-eerie-black m-0 text-[40px] font-bold dark:text-white">
|
||||
<h1 className="text-eerie-black m-0 text-[32px] font-bold lg:text-[40px] dark:text-white">
|
||||
{modeConfig[effectiveMode].heading}
|
||||
</h1>
|
||||
<div className="flex flex-wrap items-center gap-1">
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useParams } from 'react-router-dom';
|
||||
import userService from '../api/services/userService';
|
||||
import NoFilesDarkIcon from '../assets/no-files-dark.svg';
|
||||
import NoFilesIcon from '../assets/no-files.svg';
|
||||
import Robot from '../assets/robot.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import MessageInput from '../components/MessageInput';
|
||||
import Spinner from '../components/Spinner';
|
||||
import ConversationMessages from '../conversation/ConversationMessages';
|
||||
@@ -152,12 +152,8 @@ export default function SharedAgent() {
|
||||
return (
|
||||
<div className="relative h-full w-full">
|
||||
<div className="absolute top-5 left-4 hidden items-center gap-3 sm:flex">
|
||||
<img
|
||||
src={
|
||||
sharedAgent.image && sharedAgent.image.trim() !== ''
|
||||
? sharedAgent.image
|
||||
: Robot
|
||||
}
|
||||
<AgentImage
|
||||
src={sharedAgent.image}
|
||||
alt="agent-logo"
|
||||
className="h-6 w-6 rounded-full object-contain"
|
||||
/>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import Robot from '../assets/robot.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import { Agent } from './types';
|
||||
|
||||
export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
@@ -6,8 +6,8 @@ export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
<div className="border-dark-gray dark:border-grey flex w-full max-w-[720px] flex-col rounded-3xl border p-6 shadow-xs sm:w-fit sm:min-w-[480px]">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex h-12 w-12 items-center justify-center overflow-hidden rounded-full p-1">
|
||||
<img
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
className="h-full w-full rounded-full object-contain"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -8,7 +8,7 @@ import Link from '../assets/link-gray.svg';
|
||||
import Monitoring from '../assets/monitoring.svg';
|
||||
import Pin from '../assets/pin.svg';
|
||||
import Trash from '../assets/red-trash.svg';
|
||||
import Robot from '../assets/robot.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import ThreeDots from '../assets/three-dots.svg';
|
||||
import UnPin from '../assets/unpin.svg';
|
||||
import ContextMenu, { MenuOption } from '../components/ContextMenu';
|
||||
@@ -111,7 +111,7 @@ function AgentsList() {
|
||||
}, [token]);
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<h1 className="text-eerie-black mb-0 text-[40px] font-bold dark:text-[#E0E0E0]">
|
||||
<h1 className="text-eerie-black mb-0 text-[32px] font-bold lg:text-[40px] dark:text-[#E0E0E0]">
|
||||
Agents
|
||||
</h1>
|
||||
<p className="dark:text-gray-4000 mt-5 text-[15px] text-[#71717A]">
|
||||
@@ -138,11 +138,7 @@ function AgentsList() {
|
||||
</button>
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center px-1">
|
||||
<img
|
||||
src={Robot}
|
||||
alt="agent-logo"
|
||||
className="h-7 w-7 rounded-full"
|
||||
/>
|
||||
<AgentImage className="h-7 w-7 rounded-full" />
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<p
|
||||
@@ -436,8 +432,8 @@ function AgentCard({
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center gap-1 px-1">
|
||||
<img
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
alt={`${agent.name}`}
|
||||
className="h-7 w-7 rounded-full object-contain"
|
||||
/>
|
||||
|
||||
3
frontend/src/assets/check-circle-filled.svg
Normal file
3
frontend/src/assets/check-circle-filled.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="24" height="25" viewBox="0 0 24 25" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 2.5C17.523 2.5 22 6.977 22 12.5C22 18.023 17.523 22.5 12 22.5C6.477 22.5 2 18.023 2 12.5C2 6.977 6.477 2.5 12 2.5ZM15.22 9.47L10.75 13.94L8.78 11.97C8.63783 11.8375 8.44978 11.7654 8.25548 11.7688C8.06118 11.7723 7.87579 11.851 7.73838 11.9884C7.60097 12.1258 7.52225 12.3112 7.51883 12.5055C7.5154 12.6998 7.58752 12.8878 7.72 13.03L10.22 15.53C10.3606 15.6705 10.5512 15.7493 10.75 15.7493C10.9488 15.7493 11.1394 15.6705 11.28 15.53L16.28 10.53C16.3537 10.4613 16.4128 10.3785 16.4538 10.2865C16.4948 10.1945 16.5168 10.0952 16.5186 9.99452C16.5204 9.89382 16.5018 9.79379 16.4641 9.7004C16.4264 9.60701 16.3703 9.52218 16.299 9.45096C16.2278 9.37974 16.143 9.3236 16.0496 9.28588C15.9562 9.24816 15.8562 9.22963 15.7555 9.23141C15.6548 9.23318 15.5555 9.25523 15.4635 9.29622C15.3715 9.33721 15.2887 9.39631 15.22 9.47Z" fill="#0C9D35"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 958 B |
3
frontend/src/assets/warn.svg
Normal file
3
frontend/src/assets/warn.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="20" height="21" viewBox="0 0 20 21" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M15 1.83989C16.5202 2.71758 17.7826 3.97997 18.6603 5.50017C19.538 7.02038 20 8.74483 20 10.5002C20 12.2556 19.5379 13.98 18.6602 15.5002C17.7825 17.0204 16.5201 18.2828 14.9999 19.1605C13.4797 20.0381 11.7552 20.5002 9.99984 20.5001C8.24446 20.5001 6.52002 20.038 4.99984 19.1603C3.47965 18.2826 2.21729 17.0202 1.33963 15.5C0.46198 13.9797 -4.45897e-05 12.2553 3.22765e-09 10.4999L0.00500012 10.1759C0.0610032 8.44888 0.563548 6.76585 1.46364 5.29089C2.36373 3.81592 3.63065 2.59934 5.14089 1.75977C6.65113 0.920205 8.35315 0.486289 10.081 0.50033C11.8089 0.514371 13.5036 0.97589 15 1.83989ZM10 13.4999C9.73478 13.4999 9.48043 13.6052 9.29289 13.7928C9.10536 13.9803 9 14.2347 9 14.4999V14.5099C9 14.7751 9.10536 15.0295 9.29289 15.217C9.48043 15.4045 9.73478 15.5099 10 15.5099C10.2652 15.5099 10.5196 15.4045 10.7071 15.217C10.8946 15.0295 11 14.7751 11 14.5099V14.4999C11 14.2347 10.8946 13.9803 10.7071 13.7928C10.5196 13.6052 10.2652 13.4999 10 13.4999ZM10 6.49989C9.73478 6.49989 9.48043 6.60525 9.29289 6.79279C9.10536 6.98032 9 7.23468 9 7.49989V11.4999C9 11.7651 9.10536 12.0195 9.29289 12.207C9.48043 12.3945 9.73478 12.4999 10 12.4999C10.2652 12.4999 10.5196 12.3945 10.7071 12.207C10.8946 12.0195 11 11.7651 11 11.4999V7.49989C11 7.23468 10.8946 6.98032 10.7071 6.79279C10.5196 6.60525 10.2652 6.49989 10 6.49989Z" fill="#EA4335"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
40
frontend/src/components/AgentImage.tsx
Normal file
40
frontend/src/components/AgentImage.tsx
Normal file
@@ -0,0 +1,40 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import Robot from '../assets/robot.svg';
|
||||
|
||||
type AgentImageProps = {
|
||||
src?: string | null;
|
||||
alt?: string;
|
||||
className?: string;
|
||||
fallbackSrc?: string;
|
||||
};
|
||||
|
||||
export default function AgentImage({
|
||||
src,
|
||||
alt = 'agent',
|
||||
className = '',
|
||||
fallbackSrc = Robot,
|
||||
}: AgentImageProps) {
|
||||
const [currentSrc, setCurrentSrc] = useState(
|
||||
src && src.trim() !== '' ? src : fallbackSrc,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const newSrc = src && src.trim() !== '' ? src : fallbackSrc;
|
||||
if (newSrc !== currentSrc) {
|
||||
setCurrentSrc(newSrc);
|
||||
}
|
||||
}, [src, fallbackSrc]);
|
||||
|
||||
return (
|
||||
<img
|
||||
src={currentSrc}
|
||||
alt={alt}
|
||||
className={className}
|
||||
referrerPolicy="no-referrer"
|
||||
crossOrigin="anonymous"
|
||||
onError={() => {
|
||||
if (currentSrc !== fallbackSrc) setCurrentSrc(fallbackSrc);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import React, { useRef } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
|
||||
@@ -24,6 +25,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
onDisconnect,
|
||||
errorMessage,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const token = useSelector(selectToken);
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const completedRef = useRef(false);
|
||||
@@ -47,12 +49,16 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
cleanup();
|
||||
onSuccess({
|
||||
session_token: event.data.session_token,
|
||||
user_email: event.data.user_email || 'Connected User',
|
||||
user_email:
|
||||
event.data.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
});
|
||||
} else if (errorProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
onError(event.data.error || 'Authentication failed');
|
||||
onError(
|
||||
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -71,13 +77,15 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
|
||||
if (!authResponse.ok) {
|
||||
throw new Error(
|
||||
`Failed to get authorization URL: ${authResponse.status}`,
|
||||
`${t('modals.uploadDoc.connectors.auth.authUrlFailed')}: ${authResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
const authData = await authResponse.json();
|
||||
if (!authData.success || !authData.authorization_url) {
|
||||
throw new Error(authData.error || 'Failed to get authorization URL');
|
||||
throw new Error(
|
||||
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
|
||||
);
|
||||
}
|
||||
|
||||
const authWindow = window.open(
|
||||
@@ -86,9 +94,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
'width=500,height=600,scrollbars=yes,resizable=yes',
|
||||
);
|
||||
if (!authWindow) {
|
||||
throw new Error(
|
||||
'Failed to open authentication window. Please allow popups.',
|
||||
);
|
||||
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
|
||||
}
|
||||
|
||||
window.addEventListener('message', handleAuthMessage as any);
|
||||
@@ -98,13 +104,17 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
clearInterval(checkClosed);
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
if (!completedRef.current) {
|
||||
onError('Authentication was cancelled');
|
||||
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
|
||||
}
|
||||
}
|
||||
}, 1000);
|
||||
intervalRef.current = checkClosed;
|
||||
} catch (error) {
|
||||
onError(error instanceof Error ? error.message : 'Authentication failed');
|
||||
onError(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -147,14 +157,18 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41z"
|
||||
/>
|
||||
</svg>
|
||||
<span>Connected as {userEmail}</span>
|
||||
<span>
|
||||
{t('modals.uploadDoc.connectors.auth.connectedAs', {
|
||||
email: userEmail,
|
||||
})}
|
||||
</span>
|
||||
</div>
|
||||
{onDisconnect && (
|
||||
<button
|
||||
onClick={onDisconnect}
|
||||
className="text-xs font-medium text-[#212121] underline hover:text-gray-700"
|
||||
>
|
||||
Disconnect
|
||||
{t('modals.uploadDoc.connectors.auth.disconnect')}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -8,10 +8,6 @@ import CopyIcon from '../assets/copy.svg?react';
|
||||
|
||||
type CopyButtonProps = {
|
||||
textToCopy: string;
|
||||
bgColorLight?: string;
|
||||
bgColorDark?: string;
|
||||
hoverBgColorLight?: string;
|
||||
hoverBgColorDark?: string;
|
||||
iconSize?: string;
|
||||
padding?: string;
|
||||
showText?: boolean;
|
||||
@@ -27,14 +23,11 @@ const DEFAULT_COPIED_DURATION = 2000;
|
||||
const DEFAULT_BG_LIGHT = '#FFFFFF';
|
||||
const DEFAULT_BG_DARK = 'transparent';
|
||||
const DEFAULT_HOVER_BG_LIGHT = '#EEEEEE';
|
||||
const DEFAULT_HOVER_BG_DARK = '#4A4A4A';
|
||||
const DEFAULT_HOVER_BG_DARK = '#464152';
|
||||
|
||||
export default function CopyButton({
|
||||
textToCopy,
|
||||
bgColorLight = DEFAULT_BG_LIGHT,
|
||||
bgColorDark = DEFAULT_BG_DARK,
|
||||
hoverBgColorLight = DEFAULT_HOVER_BG_LIGHT,
|
||||
hoverBgColorDark = DEFAULT_HOVER_BG_DARK,
|
||||
|
||||
iconSize = DEFAULT_ICON_SIZE,
|
||||
padding = DEFAULT_PADDING,
|
||||
showText = false,
|
||||
@@ -50,9 +43,10 @@ export default function CopyButton({
|
||||
const iconWrapperClasses = clsx(
|
||||
'flex items-center justify-center rounded-full transition-colors duration-150 ease-in-out',
|
||||
padding,
|
||||
`bg-[${bgColorLight}] dark:bg-[${bgColorDark}]`,
|
||||
`hover:bg-[${hoverBgColorLight}] dark:hover:bg-[${hoverBgColorDark}]`,
|
||||
`bg-[${DEFAULT_BG_LIGHT}] dark:bg-[${DEFAULT_BG_DARK}]`,
|
||||
{
|
||||
[`hover:bg-[${DEFAULT_HOVER_BG_LIGHT}] dark:hover:bg-[${DEFAULT_HOVER_BG_DARK}]`]:
|
||||
!isCopied,
|
||||
'bg-green-100 dark:bg-green-900 hover:bg-green-100 dark:hover:bg-green-900':
|
||||
isCopied,
|
||||
},
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import useDrivePicker from 'react-google-drive-picker';
|
||||
|
||||
import ConnectorAuth from './ConnectorAuth';
|
||||
@@ -26,6 +27,7 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
token,
|
||||
onSelectionChange,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const [selectedFiles, setSelectedFiles] = useState<PickerFile[]>([]);
|
||||
const [selectedFolders, setSelectedFolders] = useState<PickerFile[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
@@ -66,14 +68,19 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
setIsConnected(false);
|
||||
setAuthError('Session expired. Please reconnect to Google Drive.');
|
||||
setAuthError(
|
||||
t('modals.uploadDoc.connectors.googleDrive.sessionExpired'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
const validateData = await validateResponse.json();
|
||||
if (validateData.success) {
|
||||
setUserEmail(validateData.user_email || 'Connected User');
|
||||
setUserEmail(
|
||||
validateData.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
setAccessToken(validateData.access_token || null);
|
||||
@@ -83,14 +90,14 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
setIsConnected(false);
|
||||
setAuthError(
|
||||
validateData.error ||
|
||||
'Session expired. Please reconnect your account.',
|
||||
t('modals.uploadDoc.connectors.googleDrive.sessionExpiredGeneric'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error validating session:', error);
|
||||
setAuthError('Failed to validate session. Please reconnect.');
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.validateFailed'));
|
||||
setIsConnected(false);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
@@ -103,15 +110,13 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
const sessionToken = getSessionToken('google_drive');
|
||||
|
||||
if (!sessionToken) {
|
||||
setAuthError('No valid session found. Please reconnect to Google Drive.');
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.noSession'));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!accessToken) {
|
||||
setAuthError(
|
||||
'No access token available. Please reconnect to Google Drive.',
|
||||
);
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.noAccessToken'));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
@@ -193,7 +198,7 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error opening picker:', error);
|
||||
setAuthError('Failed to open file picker. Please try again.');
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.pickerFailed'));
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
@@ -264,9 +269,12 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
<>
|
||||
<ConnectorAuth
|
||||
provider="google_drive"
|
||||
label="Connect to Google Drive"
|
||||
label={t('modals.uploadDoc.connectors.googleDrive.connect')}
|
||||
onSuccess={(data) => {
|
||||
setUserEmail(data.user_email || 'Connected User');
|
||||
setUserEmail(
|
||||
data.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
|
||||
@@ -289,26 +297,34 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h3 className="text-sm font-medium">Selected Files</h3>
|
||||
<h3 className="text-sm font-medium">
|
||||
{t('modals.uploadDoc.connectors.googleDrive.selectedFiles')}
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => handleOpenPicker()}
|
||||
className="rounded-md bg-[#A076F6] px-3 py-1 text-sm text-white hover:bg-[#8A5FD4]"
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading ? 'Loading...' : 'Select Files'}
|
||||
{isLoading
|
||||
? t('modals.uploadDoc.connectors.googleDrive.loading')
|
||||
: t(
|
||||
'modals.uploadDoc.connectors.googleDrive.selectFiles',
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{selectedFiles.length === 0 && selectedFolders.length === 0 ? (
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||
No files or folders selected
|
||||
{t(
|
||||
'modals.uploadDoc.connectors.googleDrive.noFilesSelected',
|
||||
)}
|
||||
</p>
|
||||
) : (
|
||||
<div className="max-h-60 overflow-y-auto">
|
||||
{selectedFolders.length > 0 && (
|
||||
<div className="mb-2">
|
||||
<h4 className="mb-1 text-xs font-medium text-gray-500">
|
||||
Folders
|
||||
{t('modals.uploadDoc.connectors.googleDrive.folders')}
|
||||
</h4>
|
||||
{selectedFolders.map((folder) => (
|
||||
<div
|
||||
@@ -317,7 +333,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
>
|
||||
<img
|
||||
src={folder.iconUrl}
|
||||
alt="Folder"
|
||||
alt={t(
|
||||
'modals.uploadDoc.connectors.googleDrive.folderAlt',
|
||||
)}
|
||||
className="mr-2 h-5 w-5"
|
||||
/>
|
||||
<span className="flex-1 truncate text-sm">
|
||||
@@ -337,7 +355,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
}}
|
||||
className="ml-2 text-sm text-red-500 hover:text-red-700"
|
||||
>
|
||||
Remove
|
||||
{t(
|
||||
'modals.uploadDoc.connectors.googleDrive.remove',
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
@@ -347,7 +367,7 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
{selectedFiles.length > 0 && (
|
||||
<div>
|
||||
<h4 className="mb-1 text-xs font-medium text-gray-500">
|
||||
Files
|
||||
{t('modals.uploadDoc.connectors.googleDrive.files')}
|
||||
</h4>
|
||||
{selectedFiles.map((file) => (
|
||||
<div
|
||||
@@ -356,7 +376,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
>
|
||||
<img
|
||||
src={file.iconUrl}
|
||||
alt="File"
|
||||
alt={t(
|
||||
'modals.uploadDoc.connectors.googleDrive.fileAlt',
|
||||
)}
|
||||
className="mr-2 h-5 w-5"
|
||||
/>
|
||||
<span className="flex-1 truncate text-sm">
|
||||
@@ -375,7 +397,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
}}
|
||||
className="ml-2 text-sm text-red-500 hover:text-red-700"
|
||||
>
|
||||
Remove
|
||||
{t(
|
||||
'modals.uploadDoc.connectors.googleDrive.remove',
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
229
frontend/src/components/UploadToast.tsx
Normal file
229
frontend/src/components/UploadToast.tsx
Normal file
@@ -0,0 +1,229 @@
|
||||
import { useState } from 'react';
|
||||
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { selectUploadTasks, dismissUploadTask } from '../upload/uploadSlice';
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import CheckCircleFilled from '../assets/check-circle-filled.svg';
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
|
||||
const PROGRESS_RADIUS = 10;
|
||||
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
|
||||
|
||||
export default function UploadToast() {
|
||||
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
|
||||
{},
|
||||
);
|
||||
|
||||
const toggleTaskCollapse = (taskId: string) => {
|
||||
setCollapsedTasks((prev) => ({
|
||||
...prev,
|
||||
[taskId]: !prev[taskId],
|
||||
}));
|
||||
};
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const getStatusHeading = (status: string) => {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('attachments.uploadFailed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2">
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
const shouldShowProgress = [
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
].includes(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset =
|
||||
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
const isCollapsed = collapsedTasks[task.id] ?? false;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={task.id}
|
||||
className={`w-[271px] overflow-hidden rounded-2xl border border-[#00000021] shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300 ${
|
||||
task.status === 'completed'
|
||||
? 'bg-[#FBFBFB] dark:bg-[#26272E]'
|
||||
: task.status === 'failed'
|
||||
? 'bg-[#FBFBFB] dark:bg-[#26272E]'
|
||||
: 'bg-[#FBFBFB] dark:bg-[#26272E]'
|
||||
}`}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
task.status !== 'failed'
|
||||
? 'bg-[#FBF2FE] dark:bg-transparent'
|
||||
: ''
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter text-[14px] leading-[16.5px] font-medium text-black dark:text-[#DCDCDC]">
|
||||
{getStatusHeading(task.status)}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleTaskCollapse(task.id)}
|
||||
aria-label={
|
||||
isCollapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
isCollapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissUploadTask(task.id))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
isCollapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black dark:text-[#B7BAB8]"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{shouldShowProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 flex-shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t(
|
||||
'modals.uploadDoc.progress.uploadProgress',
|
||||
{
|
||||
progress: formattedProgress,
|
||||
},
|
||||
)}
|
||||
>
|
||||
<circle
|
||||
className="text-gray-300 dark:text-gray-700"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 flex-shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 flex-shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' && task.errorMessage && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -188,7 +188,7 @@ const ConversationBubble = forwardRef<
|
||||
setIsEditClicked(true);
|
||||
setEditInputBox(message ?? '');
|
||||
}}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
|
||||
>
|
||||
<img src={Edit} alt="Edit" className="cursor-pointer" />
|
||||
</button>
|
||||
@@ -407,7 +407,7 @@ const ConversationBubble = forwardRef<
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
className={`fade-in-bubble bg-gray-1000 dark:bg-gun-metal mr-5 flex max-w-full rounded-[28px] px-7 py-[18px] ${
|
||||
className={`fade-in-bubble bg-gray-1000 dark:bg-gun-metal mr-5 flex max-w-full rounded-[18px] px-6 py-4.5 ${
|
||||
type === 'ERROR'
|
||||
? 'text-red-3000 dark:border-red-2000 relative flex-row items-center rounded-full border border-transparent bg-[#FFE7E7] p-2 py-5 text-sm font-normal dark:text-white'
|
||||
: 'flex-col rounded-3xl'
|
||||
|
||||
@@ -229,6 +229,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "Upload new document",
|
||||
"select": "Choose how to upload your document to DocsGPT",
|
||||
"selectSource": "Select the way to add your source",
|
||||
"selectedFiles": "Selected Files",
|
||||
"noFilesSelected": "No files selected",
|
||||
"file": "Upload from device",
|
||||
"back": "Back",
|
||||
"wait": "Please wait ...",
|
||||
@@ -257,13 +260,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Upload is in progress",
|
||||
"training": "Training is in progress",
|
||||
"completed": "Training completed",
|
||||
"training": "Upload is in progress",
|
||||
"completed": "Upload completed",
|
||||
"wait": "This may take several minutes",
|
||||
"tokenLimit": "Over the token limit, please consider uploading smaller document"
|
||||
"preparing": "Preparing upload",
|
||||
"tokenLimit": "Over the token limit, please consider uploading smaller document",
|
||||
"expandDetails": "Expand upload details",
|
||||
"collapseDetails": "Collapse upload details",
|
||||
"dismiss": "Dismiss upload toast",
|
||||
"uploadProgress": "Upload progress {{progress}}%",
|
||||
"clear": "Clear"
|
||||
},
|
||||
"showAdvanced": "Show advanced options",
|
||||
"hideAdvanced": "Hide advanced options"
|
||||
"hideAdvanced": "Hide advanced options",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Upload File",
|
||||
"heading": "Upload new document"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Crawler",
|
||||
"heading": "Add content with Web Crawler"
|
||||
},
|
||||
"url": {
|
||||
"label": "Link",
|
||||
"heading": "Add content from URL"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Add content from GitHub"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Add content from Reddit"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Upload from Google Drive"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Connected User",
|
||||
"authFailed": "Authentication failed",
|
||||
"authUrlFailed": "Failed to get authorization URL",
|
||||
"popupBlocked": "Failed to open authentication window. Please allow popups.",
|
||||
"authCancelled": "Authentication was cancelled",
|
||||
"connectedAs": "Connected as {{email}}",
|
||||
"disconnect": "Disconnect"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Connect to Google Drive",
|
||||
"sessionExpired": "Session expired. Please reconnect to Google Drive.",
|
||||
"sessionExpiredGeneric": "Session expired. Please reconnect your account.",
|
||||
"validateFailed": "Failed to validate session. Please reconnect.",
|
||||
"noSession": "No valid session found. Please reconnect to Google Drive.",
|
||||
"noAccessToken": "No access token available. Please reconnect to Google Drive.",
|
||||
"pickerFailed": "Failed to open file picker. Please try again.",
|
||||
"selectedFiles": "Selected Files",
|
||||
"selectFiles": "Select Files",
|
||||
"loading": "Loading...",
|
||||
"noFilesSelected": "No files or folders selected",
|
||||
"folders": "Folders",
|
||||
"files": "Files",
|
||||
"remove": "Remove",
|
||||
"folderAlt": "Folder",
|
||||
"fileAlt": "File"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Create New API Key",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "Subir nuevo documento",
|
||||
"select": "Elige cómo cargar tu documento en DocsGPT",
|
||||
"selectSource": "Selecciona la forma de agregar tu fuente",
|
||||
"selectedFiles": "Archivos Seleccionados",
|
||||
"noFilesSelected": "No hay archivos seleccionados",
|
||||
"file": "Subir desde el dispositivo",
|
||||
"back": "Atrás",
|
||||
"wait": "Por favor espera ...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Subida en progreso",
|
||||
"training": "Entrenamiento en progreso",
|
||||
"completed": "Entrenamiento completado",
|
||||
"training": "Subida en progreso",
|
||||
"completed": "Subida completada",
|
||||
"wait": "Esto puede tardar varios minutos",
|
||||
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño"
|
||||
"preparing": "Preparando subida",
|
||||
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño",
|
||||
"expandDetails": "Expandir detalles de subida",
|
||||
"collapseDetails": "Contraer detalles de subida",
|
||||
"dismiss": "Descartar notificación de subida",
|
||||
"uploadProgress": "Progreso de subida {{progress}}%",
|
||||
"clear": "Limpiar"
|
||||
},
|
||||
"showAdvanced": "Mostrar opciones avanzadas",
|
||||
"hideAdvanced": "Ocultar opciones avanzadas"
|
||||
"hideAdvanced": "Ocultar opciones avanzadas",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Subir archivo",
|
||||
"heading": "Subir nuevo documento"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Rastreador",
|
||||
"heading": "Agregar contenido con rastreador web"
|
||||
},
|
||||
"url": {
|
||||
"label": "Enlace",
|
||||
"heading": "Agregar contenido desde URL"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Agregar contenido desde GitHub"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Agregar contenido desde Reddit"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Subir desde Google Drive"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Usuario Conectado",
|
||||
"authFailed": "Autenticación fallida",
|
||||
"authUrlFailed": "Error al obtener la URL de autorización",
|
||||
"popupBlocked": "Error al abrir la ventana de autenticación. Por favor, permita ventanas emergentes.",
|
||||
"authCancelled": "Autenticación cancelada",
|
||||
"connectedAs": "Conectado como {{email}}",
|
||||
"disconnect": "Desconectar"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Conectar a Google Drive",
|
||||
"sessionExpired": "Sesión expirada. Por favor, reconecte a Google Drive.",
|
||||
"sessionExpiredGeneric": "Sesión expirada. Por favor, reconecte su cuenta.",
|
||||
"validateFailed": "Error al validar la sesión. Por favor, reconecte.",
|
||||
"noSession": "No se encontró una sesión válida. Por favor, reconecte a Google Drive.",
|
||||
"noAccessToken": "No hay token de acceso disponible. Por favor, reconecte a Google Drive.",
|
||||
"pickerFailed": "Error al abrir el selector de archivos. Por favor, inténtelo de nuevo.",
|
||||
"selectedFiles": "Archivos Seleccionados",
|
||||
"selectFiles": "Seleccionar Archivos",
|
||||
"loading": "Cargando...",
|
||||
"noFilesSelected": "No hay archivos o carpetas seleccionados",
|
||||
"folders": "Carpetas",
|
||||
"files": "Archivos",
|
||||
"remove": "Eliminar",
|
||||
"folderAlt": "Carpeta",
|
||||
"fileAlt": "Archivo"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Crear Nueva Clave de API",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "新しい文書をアップロードする",
|
||||
"select": "ドキュメントを DocsGPT にアップロードする方法を選択します",
|
||||
"selectSource": "ソースを追加する方法を選択してください",
|
||||
"selectedFiles": "選択されたファイル",
|
||||
"noFilesSelected": "ファイルが選択されていません",
|
||||
"file": "デバイスからアップロード",
|
||||
"back": "戻る",
|
||||
"wait": "お待ちください ...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "アップロード中",
|
||||
"training": "トレーニング中",
|
||||
"completed": "トレーニング完了",
|
||||
"training": "アップロード中",
|
||||
"completed": "アップロード完了",
|
||||
"wait": "数分かかる場合があります",
|
||||
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください"
|
||||
"preparing": "アップロードを準備中",
|
||||
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください",
|
||||
"expandDetails": "アップロードの詳細を展開",
|
||||
"collapseDetails": "アップロードの詳細を折りたたむ",
|
||||
"dismiss": "アップロード通知を閉じる",
|
||||
"uploadProgress": "アップロード進行状況 {{progress}}%",
|
||||
"clear": "クリア"
|
||||
},
|
||||
"showAdvanced": "詳細オプションを表示",
|
||||
"hideAdvanced": "詳細オプションを非表示"
|
||||
"hideAdvanced": "詳細オプションを非表示",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "ファイルをアップロード",
|
||||
"heading": "新しいドキュメントをアップロード"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "クローラー",
|
||||
"heading": "Webクローラーでコンテンツを追加"
|
||||
},
|
||||
"url": {
|
||||
"label": "リンク",
|
||||
"heading": "URLからコンテンツを追加"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "GitHubからコンテンツを追加"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Redditからコンテンツを追加"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Google Driveからアップロード"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "接続されたユーザー",
|
||||
"authFailed": "認証に失敗しました",
|
||||
"authUrlFailed": "認証URLの取得に失敗しました",
|
||||
"popupBlocked": "認証ウィンドウを開けませんでした。ポップアップを許可してください。",
|
||||
"authCancelled": "認証がキャンセルされました",
|
||||
"connectedAs": "{{email}}として接続",
|
||||
"disconnect": "切断"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Google Driveに接続",
|
||||
"sessionExpired": "セッションが期限切れです。Google Driveに再接続してください。",
|
||||
"sessionExpiredGeneric": "セッションが期限切れです。アカウントに再接続してください。",
|
||||
"validateFailed": "セッションの検証に失敗しました。再接続してください。",
|
||||
"noSession": "有効なセッションが見つかりません。Google Driveに再接続してください。",
|
||||
"noAccessToken": "アクセストークンが利用できません。Google Driveに再接続してください。",
|
||||
"pickerFailed": "ファイルピッカーを開けませんでした。もう一度お試しください。",
|
||||
"selectedFiles": "選択されたファイル",
|
||||
"selectFiles": "ファイルを選択",
|
||||
"loading": "読み込み中...",
|
||||
"noFilesSelected": "ファイルまたはフォルダが選択されていません",
|
||||
"folders": "フォルダ",
|
||||
"files": "ファイル",
|
||||
"remove": "削除",
|
||||
"folderAlt": "フォルダ",
|
||||
"fileAlt": "ファイル"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "新しいAPIキーを作成",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "Загрузить новый документ",
|
||||
"select": "Выберите способ загрузки документа в DocsGPT",
|
||||
"selectSource": "Выберите способ добавления источника",
|
||||
"selectedFiles": "Выбранные файлы",
|
||||
"noFilesSelected": "Файлы не выбраны",
|
||||
"file": "Загрузить с устройства",
|
||||
"back": "Назад",
|
||||
"wait": "Пожалуйста, подождите...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Идет загрузка",
|
||||
"training": "Идет обучение",
|
||||
"completed": "Обучение завершено",
|
||||
"training": "Идет загрузка",
|
||||
"completed": "Загрузка завершена",
|
||||
"wait": "Это может занять несколько минут",
|
||||
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера"
|
||||
"preparing": "Подготовка загрузки",
|
||||
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера",
|
||||
"expandDetails": "Развернуть детали загрузки",
|
||||
"collapseDetails": "Свернуть детали загрузки",
|
||||
"dismiss": "Закрыть уведомление о загрузке",
|
||||
"uploadProgress": "Прогресс загрузки {{progress}}%",
|
||||
"clear": "Очистить"
|
||||
},
|
||||
"showAdvanced": "Показать расширенные настройки",
|
||||
"hideAdvanced": "Скрыть расширенные настройки"
|
||||
"hideAdvanced": "Скрыть расширенные настройки",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Загрузить файл",
|
||||
"heading": "Загрузить новый документ"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Краулер",
|
||||
"heading": "Добавить контент с помощью веб-краулера"
|
||||
},
|
||||
"url": {
|
||||
"label": "Ссылка",
|
||||
"heading": "Добавить контент из URL"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Добавить контент из GitHub"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Добавить контент из Reddit"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Загрузить из Google Drive"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Подключенный пользователь",
|
||||
"authFailed": "Ошибка аутентификации",
|
||||
"authUrlFailed": "Не удалось получить URL авторизации",
|
||||
"popupBlocked": "Не удалось открыть окно аутентификации. Пожалуйста, разрешите всплывающие окна.",
|
||||
"authCancelled": "Аутентификация отменена",
|
||||
"connectedAs": "Подключен как {{email}}",
|
||||
"disconnect": "Отключить"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Подключиться к Google Drive",
|
||||
"sessionExpired": "Сеанс истек. Пожалуйста, переподключитесь к Google Drive.",
|
||||
"sessionExpiredGeneric": "Сеанс истек. Пожалуйста, переподключите свою учетную запись.",
|
||||
"validateFailed": "Не удалось проверить сеанс. Пожалуйста, переподключитесь.",
|
||||
"noSession": "Действительный сеанс не найден. Пожалуйста, переподключитесь к Google Drive.",
|
||||
"noAccessToken": "Токен доступа недоступен. Пожалуйста, переподключитесь к Google Drive.",
|
||||
"pickerFailed": "Не удалось открыть средство выбора файлов. Пожалуйста, попробуйте еще раз.",
|
||||
"selectedFiles": "Выбранные файлы",
|
||||
"selectFiles": "Выбрать файлы",
|
||||
"loading": "Загрузка...",
|
||||
"noFilesSelected": "Файлы или папки не выбраны",
|
||||
"folders": "Папки",
|
||||
"files": "Файлы",
|
||||
"remove": "Удалить",
|
||||
"folderAlt": "Папка",
|
||||
"fileAlt": "Файл"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Создать новый API ключ",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "上傳新文件",
|
||||
"select": "選擇如何將文件上傳到 DocsGPT",
|
||||
"selectSource": "選擇新增來源的方式",
|
||||
"selectedFiles": "已選擇的檔案",
|
||||
"noFilesSelected": "未選擇檔案",
|
||||
"file": "從檔案",
|
||||
"remote": "遠端",
|
||||
"back": "返回",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "正在上傳",
|
||||
"training": "正在訓練",
|
||||
"completed": "訓練完成",
|
||||
"training": "正在上傳",
|
||||
"completed": "上傳完成",
|
||||
"wait": "這可能需要幾分鐘",
|
||||
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔"
|
||||
"preparing": "準備上傳",
|
||||
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔",
|
||||
"expandDetails": "展開上傳詳情",
|
||||
"collapseDetails": "摺疊上傳詳情",
|
||||
"dismiss": "關閉上傳通知",
|
||||
"uploadProgress": "上傳進度 {{progress}}%",
|
||||
"clear": "清除"
|
||||
},
|
||||
"showAdvanced": "顯示進階選項",
|
||||
"hideAdvanced": "隱藏進階選項"
|
||||
"hideAdvanced": "隱藏進階選項",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "上傳檔案",
|
||||
"heading": "上傳新文檔"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "爬蟲",
|
||||
"heading": "使用網路爬蟲新增內容"
|
||||
},
|
||||
"url": {
|
||||
"label": "連結",
|
||||
"heading": "從URL新增內容"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "從GitHub新增內容"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "從Reddit新增內容"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "從Google Drive上傳"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "已連接使用者",
|
||||
"authFailed": "驗證失敗",
|
||||
"authUrlFailed": "取得授權URL失敗",
|
||||
"popupBlocked": "無法開啟驗證視窗。請允許彈出視窗。",
|
||||
"authCancelled": "驗證已取消",
|
||||
"connectedAs": "已連接為 {{email}}",
|
||||
"disconnect": "中斷連接"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "連接到 Google Drive",
|
||||
"sessionExpired": "工作階段已過期。請重新連接到 Google Drive。",
|
||||
"sessionExpiredGeneric": "工作階段已過期。請重新連接您的帳戶。",
|
||||
"validateFailed": "驗證工作階段失敗。請重新連接。",
|
||||
"noSession": "未找到有效工作階段。請重新連接到 Google Drive。",
|
||||
"noAccessToken": "存取權杖不可用。請重新連接到 Google Drive。",
|
||||
"pickerFailed": "無法開啟檔案選擇器。請重試。",
|
||||
"selectedFiles": "已選擇的檔案",
|
||||
"selectFiles": "選擇檔案",
|
||||
"loading": "載入中...",
|
||||
"noFilesSelected": "未選擇檔案或資料夾",
|
||||
"folders": "資料夾",
|
||||
"files": "檔案",
|
||||
"remove": "移除",
|
||||
"folderAlt": "資料夾",
|
||||
"fileAlt": "檔案"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "建立新的 API 金鑰",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "上传新文档",
|
||||
"select": "选择如何将文档上传到 DocsGPT",
|
||||
"selectSource": "选择添加源的方式",
|
||||
"selectedFiles": "已选择的文件",
|
||||
"noFilesSelected": "未选择文件",
|
||||
"file": "从设备上传",
|
||||
"back": "后退",
|
||||
"wait": "请稍等 ...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "正在上传",
|
||||
"training": "正在训练",
|
||||
"completed": "训练完成",
|
||||
"training": "正在上传",
|
||||
"completed": "上传完成",
|
||||
"wait": "这可能需要几分钟",
|
||||
"tokenLimit": "超出令牌限制,请考虑上传较小的文档"
|
||||
"preparing": "准备上传",
|
||||
"tokenLimit": "超出令牌限制,请考虑上传较小的文档",
|
||||
"expandDetails": "展开上传详情",
|
||||
"collapseDetails": "折叠上传详情",
|
||||
"dismiss": "关闭上传通知",
|
||||
"uploadProgress": "上传进度 {{progress}}%",
|
||||
"clear": "清除"
|
||||
},
|
||||
"showAdvanced": "显示高级选项",
|
||||
"hideAdvanced": "隐藏高级选项"
|
||||
"hideAdvanced": "隐藏高级选项",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "上传文件",
|
||||
"heading": "上传新文档"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "爬虫",
|
||||
"heading": "使用网络爬虫添加内容"
|
||||
},
|
||||
"url": {
|
||||
"label": "链接",
|
||||
"heading": "从URL添加内容"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "从GitHub添加内容"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "从Reddit添加内容"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "从Google Drive上传"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "已连接用户",
|
||||
"authFailed": "身份验证失败",
|
||||
"authUrlFailed": "获取授权URL失败",
|
||||
"popupBlocked": "无法打开身份验证窗口。请允许弹出窗口。",
|
||||
"authCancelled": "身份验证已取消",
|
||||
"connectedAs": "已连接为 {{email}}",
|
||||
"disconnect": "断开连接"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "连接到 Google Drive",
|
||||
"sessionExpired": "会话已过期。请重新连接到 Google Drive。",
|
||||
"sessionExpiredGeneric": "会话已过期。请重新连接您的账户。",
|
||||
"validateFailed": "验证会话失败。请重新连接。",
|
||||
"noSession": "未找到有效会话。请重新连接到 Google Drive。",
|
||||
"noAccessToken": "访问令牌不可用。请重新连接到 Google Drive。",
|
||||
"pickerFailed": "无法打开文件选择器。请重试。",
|
||||
"selectedFiles": "已选择的文件",
|
||||
"selectFiles": "选择文件",
|
||||
"loading": "加载中...",
|
||||
"noFilesSelected": "未选择文件或文件夹",
|
||||
"folders": "文件夹",
|
||||
"files": "文件",
|
||||
"remove": "删除",
|
||||
"folderAlt": "文件夹",
|
||||
"fileAlt": "文件"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "创建新的 API 密钥",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { nanoid } from '@reduxjs/toolkit';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
@@ -24,6 +25,8 @@ import {
|
||||
getIngestorSchema,
|
||||
IngestorOption,
|
||||
} from '../upload/types/ingestor';
|
||||
import { addUploadTask, updateUploadTask } from './uploadSlice';
|
||||
|
||||
import { FormField, IngestorConfig, IngestorType } from './types/ingestor';
|
||||
|
||||
import { FilePicker } from '../components/FilePicker';
|
||||
@@ -190,12 +193,12 @@ function Upload({
|
||||
<div className="mb-3" {...getRootProps()}>
|
||||
<span className="text-purple-30 dark:text-silver inline-block rounded-3xl border border-[#7F7F82] bg-transparent px-4 py-2 font-medium hover:cursor-pointer">
|
||||
<input type="button" {...getInputProps()} />
|
||||
Choose Files
|
||||
{t('modals.uploadDoc.choose')}
|
||||
</span>
|
||||
</div>
|
||||
<div className="mt-4 max-w-full">
|
||||
<p className="text-eerie-black dark:text-light-gray mb-[14px] text-[14px] font-medium">
|
||||
Selected Files
|
||||
{t('modals.uploadDoc.selectedFiles')}
|
||||
</p>
|
||||
<div className="max-w-full overflow-hidden">
|
||||
{files.map((file) => (
|
||||
@@ -209,7 +212,7 @@ function Upload({
|
||||
))}
|
||||
{files.length === 0 && (
|
||||
<p className="text-gray-6000 dark:text-light-gray text-[14px]">
|
||||
No files selected
|
||||
{t('modals.uploadDoc.noFilesSelected')}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
@@ -259,15 +262,8 @@ function Upload({
|
||||
config: {},
|
||||
}));
|
||||
|
||||
const [progress, setProgress] = useState<{
|
||||
type: 'UPLOAD' | 'TRAINING';
|
||||
percentage: number;
|
||||
taskId?: string;
|
||||
failed?: boolean;
|
||||
}>();
|
||||
|
||||
const { t } = useTranslation();
|
||||
const setTimeoutRef = useRef<number | null>(null);
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const ingestorOptions: IngestorOption[] = IngestorFormSchemas.filter(
|
||||
(schema) => (schema.validate ? schema.validate() : true),
|
||||
@@ -279,188 +275,120 @@ function Upload({
|
||||
}));
|
||||
|
||||
const sourceDocs = useSelector(selectSourceDocs);
|
||||
useEffect(() => {
|
||||
if (setTimeoutRef.current) {
|
||||
clearTimeout(setTimeoutRef.current);
|
||||
}
|
||||
|
||||
const resetUploaderState = useCallback(() => {
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setSelectedFiles([]);
|
||||
setSelectedFolders([]);
|
||||
setShowAdvancedOptions(false);
|
||||
}, []);
|
||||
|
||||
function ProgressBar({ progressPercent }: { progressPercent: number }) {
|
||||
return (
|
||||
<div className="my-8 flex h-full w-full items-center justify-center">
|
||||
<div className="relative h-32 w-32 rounded-full">
|
||||
<div className="absolute inset-0 rounded-full shadow-[0_0_10px_2px_rgba(0,0,0,0.3)_inset] dark:shadow-[0_0_10px_2px_rgba(0,0,0,0.3)_inset]"></div>
|
||||
<div
|
||||
className={`absolute inset-0 rounded-full ${progressPercent === 100 ? 'bg-linear-to-r from-white to-gray-400 shadow-xl shadow-lime-300/50 dark:bg-linear-to-br dark:from-gray-500 dark:to-gray-300 dark:shadow-lime-300/50' : 'shadow-[0_4px_0_#7D54D1] dark:shadow-[0_4px_0_#7D54D1]'}`}
|
||||
style={{
|
||||
animation: `${progressPercent === 100 ? 'none' : 'rotate 2s linear infinite'}`,
|
||||
}}
|
||||
></div>
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<span className="text-2xl font-bold">{progressPercent}%</span>
|
||||
</div>
|
||||
<style>
|
||||
{`@keyframes rotate {
|
||||
0% { transform: rotate(0deg); }
|
||||
100%{ transform: rotate(360deg); }
|
||||
}`}
|
||||
</style>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const handleTaskFailure = useCallback(
|
||||
(clientTaskId: string, errorMessage?: string) => {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
errorMessage: errorMessage || t('attachments.uploadFailed'),
|
||||
},
|
||||
}),
|
||||
);
|
||||
},
|
||||
[dispatch, t],
|
||||
);
|
||||
|
||||
function Progress({
|
||||
title,
|
||||
isCancellable = false,
|
||||
isFailed = false,
|
||||
isTraining = false,
|
||||
}: {
|
||||
title: string;
|
||||
isCancellable?: boolean;
|
||||
isFailed?: boolean;
|
||||
isTraining?: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div className="text-gray-2000 dark:text-bright-gray mt-5 flex flex-col items-center gap-2">
|
||||
<p className="text-gra text-xl tracking-[0.15px]">
|
||||
{isTraining &&
|
||||
(progress?.percentage === 100
|
||||
? t('modals.uploadDoc.progress.completed')
|
||||
: title)}
|
||||
{!isTraining && title}
|
||||
</p>
|
||||
<p className="text-sm">{t('modals.uploadDoc.progress.wait')}</p>
|
||||
<p className={`ml-5 text-xl text-red-400 ${isFailed ? '' : 'hidden'}`}>
|
||||
{t('modals.uploadDoc.progress.tokenLimit')}
|
||||
</p>
|
||||
{/* <p className="mt-10 text-2xl">{progress?.percentage || 0}%</p> */}
|
||||
<ProgressBar progressPercent={progress?.percentage || 0} />
|
||||
{isTraining &&
|
||||
(progress?.percentage === 100 ? (
|
||||
<button
|
||||
onClick={() => {
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setProgress(undefined);
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
className="h-[42px] cursor-pointer rounded-3xl bg-[#7D54D1] px-[28px] py-[6px] text-sm text-white shadow-lg hover:bg-[#6F3FD1]"
|
||||
>
|
||||
{t('modals.uploadDoc.start')}
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="ml-2 h-[42px] cursor-pointer rounded-3xl bg-[#7D54D14D] px-[28px] py-[6px] text-sm text-white shadow-lg"
|
||||
disabled
|
||||
>
|
||||
{t('modals.uploadDoc.wait')}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const trackTraining = useCallback(
|
||||
(backendTaskId: string, clientTaskId: string) => {
|
||||
let timeoutId: number | null = null;
|
||||
|
||||
function UploadProgress() {
|
||||
return <Progress title={t('modals.uploadDoc.progress.upload')}></Progress>;
|
||||
}
|
||||
|
||||
function TrainingProgress() {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
let timeoutID: number | undefined;
|
||||
|
||||
if ((progress?.percentage ?? 0) < 100) {
|
||||
timeoutID = setTimeout(() => {
|
||||
userService
|
||||
.getTaskStatus(progress?.taskId as string, null)
|
||||
.then((data) => data.json())
|
||||
.then((data) => {
|
||||
if (data.status == 'SUCCESS') {
|
||||
if (data.result.limited === true) {
|
||||
getDocs(token).then((data) => {
|
||||
dispatch(setSourceDocs(data));
|
||||
dispatch(
|
||||
setSelectedDocs(
|
||||
Array.isArray(data) &&
|
||||
data?.find(
|
||||
(d: Doc) => d.type?.toLowerCase() === 'local',
|
||||
),
|
||||
),
|
||||
);
|
||||
});
|
||||
setProgress(
|
||||
(progress) =>
|
||||
progress && {
|
||||
...progress,
|
||||
percentage: 100,
|
||||
failed: true,
|
||||
},
|
||||
);
|
||||
} else {
|
||||
getDocs(token).then((data) => {
|
||||
dispatch(setSourceDocs(data));
|
||||
const docIds = new Set(
|
||||
(Array.isArray(sourceDocs) &&
|
||||
sourceDocs?.map((doc: Doc) =>
|
||||
doc.id ? doc.id : null,
|
||||
)) ||
|
||||
[],
|
||||
);
|
||||
if (data && Array.isArray(data)) {
|
||||
data.map((updatedDoc: Doc) => {
|
||||
if (updatedDoc.id && !docIds.has(updatedDoc.id)) {
|
||||
// Select the doc not present in the intersection of current Docs and fetched data
|
||||
dispatch(setSelectedDocs(updatedDoc));
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
setProgress(
|
||||
(progress) =>
|
||||
progress && {
|
||||
...progress,
|
||||
percentage: 100,
|
||||
failed: false,
|
||||
},
|
||||
);
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setProgress(undefined);
|
||||
setModalState('INACTIVE');
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} else if (data.status == 'PROGRESS') {
|
||||
setProgress(
|
||||
(progress) =>
|
||||
progress && {
|
||||
...progress,
|
||||
percentage: data.result.current,
|
||||
},
|
||||
);
|
||||
const poll = () => {
|
||||
userService
|
||||
.getTaskStatus(backendTaskId, null)
|
||||
.then((response) => response.json())
|
||||
.then(async (data) => {
|
||||
if (data.status === 'SUCCESS') {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
});
|
||||
}, 5000);
|
||||
}
|
||||
|
||||
// cleanup
|
||||
return () => {
|
||||
if (timeoutID !== undefined) {
|
||||
clearTimeout(timeoutID);
|
||||
}
|
||||
const docs = await getDocs(token);
|
||||
dispatch(setSourceDocs(docs));
|
||||
|
||||
if (Array.isArray(docs)) {
|
||||
const existingDocIds = new Set(
|
||||
(Array.isArray(sourceDocs) ? sourceDocs : [])
|
||||
.map((doc: Doc) => doc?.id)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
const newDoc = docs.find(
|
||||
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
|
||||
);
|
||||
if (newDoc) {
|
||||
dispatch(setSelectedDocs([newDoc]));
|
||||
}
|
||||
}
|
||||
|
||||
if (data.result?.limited) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
progress: 100,
|
||||
errorMessage: t('modals.uploadDoc.progress.tokenLimit'),
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
errorMessage: undefined,
|
||||
},
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} else if (data.status === 'FAILURE') {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId, data.result?.message);
|
||||
} else if (data.status === 'PROGRESS') {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'training',
|
||||
progress: Math.min(100, data.result?.current ?? 0),
|
||||
},
|
||||
}),
|
||||
);
|
||||
timeoutId = window.setTimeout(poll, 5000);
|
||||
} else {
|
||||
timeoutId = window.setTimeout(poll, 5000);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId);
|
||||
});
|
||||
};
|
||||
}, [progress, dispatch]);
|
||||
return (
|
||||
<Progress
|
||||
title={t('modals.uploadDoc.progress.training')}
|
||||
isCancellable={progress?.percentage === 100}
|
||||
isFailed={progress?.failed === true}
|
||||
isTraining={true}
|
||||
></Progress>
|
||||
);
|
||||
}
|
||||
|
||||
timeoutId = window.setTimeout(poll, 3000);
|
||||
},
|
||||
[dispatch, handleTaskFailure, onSuccessfulUpload, sourceDocs, t, token],
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: File[]) => {
|
||||
@@ -483,7 +411,7 @@ function Upload({
|
||||
|
||||
const doNothing = () => undefined;
|
||||
|
||||
const uploadFile = () => {
|
||||
const uploadFile = (clientTaskId: string) => {
|
||||
const formData = new FormData();
|
||||
files.forEach((file) => {
|
||||
formData.append('file', file);
|
||||
@@ -491,34 +419,89 @@ function Upload({
|
||||
|
||||
formData.append('name', ingestor.name);
|
||||
formData.append('user', 'local');
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'uploading', progress: 0 },
|
||||
}),
|
||||
);
|
||||
|
||||
xhr.upload.addEventListener('progress', (event) => {
|
||||
const progress = +((event.loaded / event.total) * 100).toFixed(2);
|
||||
setProgress({ type: 'UPLOAD', percentage: progress });
|
||||
if (!event.lengthComputable) return;
|
||||
const progressPercentage = Number(
|
||||
((event.loaded / event.total) * 100).toFixed(2),
|
||||
);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { progress: progressPercentage },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
const { task_id } = JSON.parse(xhr.responseText);
|
||||
setTimeoutRef.current = setTimeout(() => {
|
||||
setProgress({ type: 'TRAINING', percentage: 0, taskId: task_id });
|
||||
}, 3000);
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const parsed = JSON.parse(xhr.responseText) as { task_id?: string };
|
||||
if (parsed.task_id) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
taskId: parsed.task_id,
|
||||
status: 'training',
|
||||
progress: 0,
|
||||
},
|
||||
}),
|
||||
);
|
||||
trackTraining(parsed.task_id, clientTaskId);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'completed', progress: 100 },
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} catch (error) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
}
|
||||
} else {
|
||||
handleTaskFailure(clientTaskId, xhr.statusText || undefined);
|
||||
}
|
||||
};
|
||||
xhr.open('POST', `${apiHost + '/api/upload'}`);
|
||||
|
||||
xhr.onerror = () => {
|
||||
handleTaskFailure(clientTaskId);
|
||||
};
|
||||
|
||||
xhr.open('POST', `${apiHost}/api/upload`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
};
|
||||
|
||||
const uploadRemote = () => {
|
||||
if (!ingestor.type) return;
|
||||
const uploadRemote = (clientTaskId: string) => {
|
||||
if (!ingestor.type) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
return;
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('name', ingestor.name);
|
||||
formData.append('user', 'local');
|
||||
formData.append('source', ingestor.type as string);
|
||||
|
||||
let configData: any = {};
|
||||
|
||||
const ingestorSchema = getIngestorSchema(ingestor.type as IngestorType);
|
||||
if (!ingestorSchema) return;
|
||||
if (!ingestorSchema) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
return;
|
||||
}
|
||||
|
||||
const schema: FormField[] = ingestorSchema.fields;
|
||||
const hasLocalFilePicker = schema.some(
|
||||
(field: FormField) => field.type === 'local_file_picker',
|
||||
@@ -530,11 +513,12 @@ function Upload({
|
||||
(field: FormField) => field.type === 'google_drive_picker',
|
||||
);
|
||||
|
||||
let configData: Record<string, unknown> = { ...ingestor.config };
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
files.forEach((file) => {
|
||||
formData.append('file', file);
|
||||
});
|
||||
configData = { ...ingestor.config };
|
||||
} else if (hasRemoteFilePicker || hasGoogleDrivePicker) {
|
||||
const sessionToken = getSessionToken(ingestor.type as string);
|
||||
configData = {
|
||||
@@ -543,44 +527,122 @@ function Upload({
|
||||
file_ids: selectedFiles,
|
||||
folder_ids: selectedFolders,
|
||||
};
|
||||
} else {
|
||||
configData = { ...ingestor.config };
|
||||
}
|
||||
|
||||
formData.append('data', JSON.stringify(configData));
|
||||
|
||||
const apiHost: string = import.meta.env.VITE_API_HOST;
|
||||
const xhr = new XMLHttpRequest();
|
||||
xhr.upload.addEventListener('progress', (event: ProgressEvent) => {
|
||||
if (event.lengthComputable) {
|
||||
const progressPercentage = +(
|
||||
(event.loaded / event.total) *
|
||||
100
|
||||
).toFixed(2);
|
||||
setProgress({ type: 'UPLOAD', percentage: progressPercentage });
|
||||
}
|
||||
});
|
||||
xhr.onload = () => {
|
||||
const response = JSON.parse(xhr.responseText) as { task_id: string };
|
||||
setTimeoutRef.current = window.setTimeout(() => {
|
||||
setProgress({
|
||||
type: 'TRAINING',
|
||||
percentage: 0,
|
||||
taskId: response.task_id,
|
||||
});
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
const endpoint =
|
||||
ingestor.type === 'local_file'
|
||||
? `${apiHost}/api/upload`
|
||||
: `${apiHost}/api/remote`;
|
||||
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'uploading', progress: 0 },
|
||||
}),
|
||||
);
|
||||
|
||||
xhr.upload.addEventListener('progress', (event: ProgressEvent) => {
|
||||
if (!event.lengthComputable) return;
|
||||
const progressPercentage = Number(
|
||||
((event.loaded / event.total) * 100).toFixed(2),
|
||||
);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { progress: progressPercentage },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const response = JSON.parse(xhr.responseText) as { task_id?: string };
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'training',
|
||||
progress: 0,
|
||||
},
|
||||
}),
|
||||
);
|
||||
trackTraining(response.task_id, clientTaskId);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'completed', progress: 100 },
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} catch (error) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
}
|
||||
} else {
|
||||
handleTaskFailure(clientTaskId, xhr.statusText || undefined);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
handleTaskFailure(clientTaskId);
|
||||
};
|
||||
|
||||
xhr.open('POST', endpoint);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
};
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
resetUploaderState();
|
||||
setModalState('INACTIVE');
|
||||
close();
|
||||
}, [close, resetUploaderState, setModalState]);
|
||||
|
||||
const handleUpload = () => {
|
||||
if (!ingestor.type) return;
|
||||
|
||||
const ingestorSchemaForUpload = getIngestorSchema(
|
||||
ingestor.type as IngestorType,
|
||||
);
|
||||
if (!ingestorSchemaForUpload) return;
|
||||
|
||||
const schema: FormField[] = ingestorSchemaForUpload.fields;
|
||||
const hasLocalFilePicker = schema.some(
|
||||
(field: FormField) => field.type === 'local_file_picker',
|
||||
);
|
||||
|
||||
const displayName =
|
||||
ingestor.name?.trim() || files[0]?.name || t('modals.uploadDoc.label');
|
||||
|
||||
const clientTaskId = nanoid();
|
||||
|
||||
dispatch(
|
||||
addUploadTask({
|
||||
id: clientTaskId,
|
||||
fileName: displayName,
|
||||
progress: 0,
|
||||
status: 'preparing',
|
||||
}),
|
||||
);
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
uploadFile(clientTaskId);
|
||||
} else {
|
||||
uploadRemote(clientTaskId);
|
||||
}
|
||||
|
||||
handleClose();
|
||||
};
|
||||
|
||||
const { getRootProps, getInputProps } = useDropzone({
|
||||
onDrop,
|
||||
multiple: true,
|
||||
@@ -733,7 +795,7 @@ function Upload({
|
||||
/>
|
||||
</div>
|
||||
<p className="font-inter self-start text-[13px] leading-[18px] font-semibold">
|
||||
{option.label}
|
||||
{t(`modals.uploadDoc.ingestors.${option.value}.label`)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -741,18 +803,16 @@ function Upload({
|
||||
</div>
|
||||
);
|
||||
};
|
||||
let view;
|
||||
|
||||
if (progress?.type === 'UPLOAD') {
|
||||
view = <UploadProgress></UploadProgress>;
|
||||
} else if (progress?.type === 'TRAINING') {
|
||||
view = <TrainingProgress></TrainingProgress>;
|
||||
} else {
|
||||
view = (
|
||||
return (
|
||||
<WrapperModal
|
||||
close={handleClose}
|
||||
className="max-h-[90vh] w-11/12 sm:max-h-none sm:w-auto sm:min-w-[600px] md:min-w-[700px]"
|
||||
contentClassName="max-h-[80vh] sm:max-h-none"
|
||||
>
|
||||
<div className="flex w-full flex-col gap-6">
|
||||
{!ingestor.type && (
|
||||
<p className="font-inter text-left text-[20px] leading-[28px] font-semibold tracking-[0.15px] text-[#18181B] dark:text-[#ECECF1]">
|
||||
Select the way to add your source
|
||||
{t('modals.uploadDoc.selectSource')}
|
||||
</p>
|
||||
)}
|
||||
|
||||
@@ -770,12 +830,12 @@ function Upload({
|
||||
alt="back"
|
||||
className="h-3 w-3 rotate-180 transform"
|
||||
/>
|
||||
<span>Back</span>
|
||||
<span>{t('modals.uploadDoc.back')}</span>
|
||||
</button>
|
||||
|
||||
<h2 className="font-inter text-[22px] leading-[28px] font-semibold tracking-[0.15px] text-black dark:text-[#E0E0E0]">
|
||||
{ingestor.type &&
|
||||
getIngestorSchema(ingestor.type as IngestorType)?.heading}
|
||||
t(`modals.uploadDoc.ingestors.${ingestor.type}.heading`)}
|
||||
</h2>
|
||||
|
||||
<Input
|
||||
@@ -789,7 +849,7 @@ function Upload({
|
||||
}));
|
||||
}}
|
||||
borderVariant="thin"
|
||||
placeholder="Name"
|
||||
placeholder={t('modals.uploadDoc.name')}
|
||||
required={true}
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
className="w-full"
|
||||
@@ -816,23 +876,7 @@ function Upload({
|
||||
<div className="flex justify-end gap-4">
|
||||
{activeTab && ingestor.type && (
|
||||
<button
|
||||
onClick={() => {
|
||||
if (!ingestor.type) return;
|
||||
const ingestorSchemaForUpload = getIngestorSchema(
|
||||
ingestor.type as IngestorType,
|
||||
);
|
||||
if (!ingestorSchemaForUpload) return;
|
||||
const schema: FormField[] = ingestorSchemaForUpload.fields;
|
||||
const hasLocalFilePicker = schema.some(
|
||||
(field: FormField) => field.type === 'local_file_picker',
|
||||
);
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
uploadFile();
|
||||
} else {
|
||||
uploadRemote();
|
||||
}
|
||||
}}
|
||||
onClick={handleUpload}
|
||||
disabled={isUploadDisabled()}
|
||||
className={`rounded-3xl px-4 py-2 text-[14px] font-medium ${
|
||||
isUploadDisabled()
|
||||
@@ -845,22 +889,6 @@ function Upload({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<WrapperModal
|
||||
isPerformingTask={progress !== undefined && progress.percentage < 100}
|
||||
close={() => {
|
||||
close();
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
className="max-h-[90vh] w-11/12 sm:max-h-none sm:w-auto sm:min-w-[600px] md:min-w-[700px]"
|
||||
contentClassName="max-h-[80vh] sm:max-h-none"
|
||||
>
|
||||
{view}
|
||||
</WrapperModal>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -10,12 +10,31 @@ export interface Attachment {
|
||||
token_count?: number;
|
||||
}
|
||||
|
||||
export type UploadTaskStatus =
|
||||
| 'preparing'
|
||||
| 'uploading'
|
||||
| 'training'
|
||||
| 'completed'
|
||||
| 'failed';
|
||||
|
||||
export interface UploadTask {
|
||||
id: string;
|
||||
fileName: string;
|
||||
progress: number;
|
||||
status: UploadTaskStatus;
|
||||
taskId?: string;
|
||||
errorMessage?: string;
|
||||
dismissed?: boolean;
|
||||
}
|
||||
|
||||
interface UploadState {
|
||||
attachments: Attachment[];
|
||||
tasks: UploadTask[];
|
||||
}
|
||||
|
||||
const initialState: UploadState = {
|
||||
attachments: [],
|
||||
tasks: [],
|
||||
};
|
||||
|
||||
export const uploadSlice = createSlice({
|
||||
@@ -52,6 +71,49 @@ export const uploadSlice = createSlice({
|
||||
(att) => att.status === 'uploading' || att.status === 'processing',
|
||||
);
|
||||
},
|
||||
addUploadTask: (state, action: PayloadAction<UploadTask>) => {
|
||||
state.tasks.unshift(action.payload);
|
||||
},
|
||||
updateUploadTask: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
updates: Partial<UploadTask>;
|
||||
}>,
|
||||
) => {
|
||||
const index = state.tasks.findIndex(
|
||||
(task) => task.id === action.payload.id,
|
||||
);
|
||||
if (index !== -1) {
|
||||
const updates = action.payload.updates;
|
||||
|
||||
// When task completes or fails, set dismissed to false to notify user
|
||||
if (updates.status === 'completed' || updates.status === 'failed') {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
...updates,
|
||||
dismissed: false,
|
||||
};
|
||||
} else {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
...updates,
|
||||
};
|
||||
}
|
||||
}
|
||||
},
|
||||
dismissUploadTask: (state, action: PayloadAction<string>) => {
|
||||
const index = state.tasks.findIndex((task) => task.id === action.payload);
|
||||
if (index !== -1) {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
dismissed: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
removeUploadTask: (state, action: PayloadAction<string>) => {
|
||||
state.tasks = state.tasks.filter((task) => task.id !== action.payload);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -60,10 +122,15 @@ export const {
|
||||
updateAttachment,
|
||||
removeAttachment,
|
||||
clearAttachments,
|
||||
addUploadTask,
|
||||
updateUploadTask,
|
||||
dismissUploadTask,
|
||||
removeUploadTask,
|
||||
} = uploadSlice.actions;
|
||||
|
||||
export const selectAttachments = (state: RootState) => state.upload.attachments;
|
||||
export const selectCompletedAttachments = (state: RootState) =>
|
||||
state.upload.attachments.filter((att) => att.status === 'completed');
|
||||
export const selectUploadTasks = (state: RootState) => state.upload.tasks;
|
||||
|
||||
export default uploadSlice.reducer;
|
||||
|
||||
765
tests/test_memory_tool.py
Normal file
765
tests/test_memory_tool.py
Normal file
@@ -0,0 +1,765 @@
|
||||
import pytest
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_tool(monkeypatch) -> MemoryTool:
|
||||
"""Provide a MemoryTool with a fake Mongo collection and fixed user_id."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {} # path -> document
|
||||
|
||||
def insert_one(self, doc):
|
||||
user_id = doc.get("user_id")
|
||||
tool_id = doc.get("tool_id")
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
# Add _id to document if not present
|
||||
if "_id" not in doc:
|
||||
doc["_id"] = key
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# Handle query by _id
|
||||
if "_id" in q:
|
||||
doc_id = q["_id"]
|
||||
if doc_id not in self.docs:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if "$set" in u:
|
||||
old_doc = self.docs[doc_id].copy()
|
||||
old_doc.update(u["$set"])
|
||||
|
||||
# If path changed, update the dictionary key
|
||||
if "path" in u["$set"]:
|
||||
new_path = u["$set"]["path"]
|
||||
user_id = old_doc.get("user_id")
|
||||
tool_id = old_doc.get("tool_id")
|
||||
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||
|
||||
# Remove old key and add with new key
|
||||
del self.docs[doc_id]
|
||||
old_doc["_id"] = new_key
|
||||
self.docs[new_key] = old_doc
|
||||
else:
|
||||
self.docs[doc_id] = old_doc
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
# Handle query by user_id, tool_id, path
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
def find(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
results = []
|
||||
|
||||
# Handle regex queries for directory listing
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
# Remove regex escape characters and ^ anchor for simple matching
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
results.append(doc)
|
||||
else:
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
results.append(doc)
|
||||
|
||||
return results
|
||||
|
||||
def delete_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
def delete_many(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
deleted = 0
|
||||
|
||||
# Handle regex queries for directory deletion
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
keys_to_delete = []
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
else:
|
||||
# Delete all for user and tool
|
||||
keys_to_delete = [
|
||||
key for key, doc in self.docs.items()
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
|
||||
return type("res", (), {"deleted_count": deleted})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
def test_init_without_user_id():
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
memory_tool = MemoryTool(tool_config={})
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "user_id" in result.lower()
|
||||
|
||||
|
||||
def test_view_empty_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Should show empty directory when no files exist."""
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "empty" in result.lower()
|
||||
|
||||
|
||||
def test_create_and_view_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test creating a file and viewing it."""
|
||||
# Create a file
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/notes.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View the file
|
||||
result = memory_tool.execute_action("view", path="/notes.txt")
|
||||
assert "Hello world" in result
|
||||
|
||||
|
||||
def test_create_overwrite_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that create overwrites existing files."""
|
||||
# Create initial file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Original content"
|
||||
)
|
||||
|
||||
# Overwrite
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="New content"
|
||||
)
|
||||
|
||||
# Verify overwrite
|
||||
result = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "New content" in result
|
||||
assert "Original content" not in result
|
||||
|
||||
|
||||
def test_view_directory_with_files(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing directory contents."""
|
||||
# Create multiple files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file3.txt",
|
||||
file_text="Content 3"
|
||||
)
|
||||
|
||||
# View directory
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "file1.txt" in result
|
||||
assert "file2.txt" in result
|
||||
assert "subdir/file3.txt" in result
|
||||
|
||||
|
||||
def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing specific lines from a file."""
|
||||
# Create a multiline file
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/multiline.txt",
|
||||
file_text=content
|
||||
)
|
||||
|
||||
# View lines 2-4
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/multiline.txt",
|
||||
view_range=[2, 4]
|
||||
)
|
||||
assert "Line 2" in result
|
||||
assert "Line 3" in result
|
||||
assert "Line 4" in result
|
||||
assert "Line 1" not in result
|
||||
assert "Line 5" not in result
|
||||
|
||||
|
||||
def test_str_replace(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement in a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/replace.txt",
|
||||
file_text="Hello world, hello universe"
|
||||
)
|
||||
|
||||
# Replace text
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/replace.txt",
|
||||
old_str="hello",
|
||||
new_str="hi"
|
||||
)
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
content = memory_tool.execute_action("view", path="/replace.txt")
|
||||
assert "hi world, hi universe" in content
|
||||
|
||||
|
||||
def test_str_replace_not_found(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement when string not found."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/test.txt",
|
||||
old_str="goodbye",
|
||||
new_str="hi"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_insert_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting text at a line number."""
|
||||
# Create a multiline file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/insert.txt",
|
||||
file_text="Line 1\nLine 2\nLine 3"
|
||||
)
|
||||
|
||||
# Insert at line 2
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/insert.txt",
|
||||
insert_line=2,
|
||||
insert_text="Inserted line"
|
||||
)
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
content = memory_tool.execute_action("view", path="/insert.txt")
|
||||
lines = content.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
def test_insert_invalid_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting at an invalid line number."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Line 1\nLine 2"
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/test.txt",
|
||||
insert_line=100,
|
||||
insert_text="Text"
|
||||
)
|
||||
assert "invalid" in result.lower()
|
||||
|
||||
|
||||
def test_delete_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/delete_me.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
# Delete it
|
||||
result = memory_tool.execute_action("delete", path="/delete_me.txt")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify it's gone
|
||||
result = memory_tool.execute_action("view", path="/delete_me.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file that doesn't exist."""
|
||||
result = memory_tool.execute_action("delete", path="/nonexistent.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_delete_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a directory with files."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Delete the directory
|
||||
result = memory_tool.execute_action("delete", path="/subdir/")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify files are gone
|
||||
result = memory_tool.execute_action("view", path="/subdir/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_rename_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/old_name.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
# Rename it
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/old_name.txt",
|
||||
new_path="/new_name.txt"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify old path doesn't exist
|
||||
result = memory_tool.execute_action("view", path="/old_name.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new path exists
|
||||
result = memory_tool.execute_action("view", path="/new_name.txt")
|
||||
assert "Content" in result
|
||||
|
||||
|
||||
def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file that doesn't exist."""
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/nonexistent.txt",
|
||||
new_path="/new.txt"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_rename_to_existing_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming to a path that already exists."""
|
||||
# Create two files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Try to rename file1 to file2
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/file1.txt",
|
||||
new_path="/file2.txt"
|
||||
)
|
||||
assert "already exists" in result.lower()
|
||||
|
||||
|
||||
def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
|
||||
"""Test that directory traversal attacks are prevented."""
|
||||
# Try various path traversal attempts
|
||||
invalid_paths = [
|
||||
"/../secrets.txt",
|
||||
"/../../etc/passwd",
|
||||
"..//file.txt",
|
||||
"/subdir/../../outside.txt",
|
||||
]
|
||||
|
||||
for path in invalid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="malicious content"
|
||||
)
|
||||
assert "invalid path" in result.lower()
|
||||
|
||||
|
||||
def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test that paths work with or without leading slash (auto-normalized)."""
|
||||
# These paths should all work now (auto-prepended with /)
|
||||
valid_paths = [
|
||||
"etc/passwd", # Auto-prepended with /
|
||||
"home/user/file.txt", # Auto-prepended with /
|
||||
"file.txt", # Auto-prepended with /
|
||||
]
|
||||
|
||||
for path in valid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="content"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# Verify the file can be accessed with or without leading slash
|
||||
view_result = memory_tool.execute_action("view", path=path)
|
||||
assert "content" in view_result
|
||||
|
||||
|
||||
def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that you cannot create a file at a directory path."""
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/",
|
||||
file_text="content"
|
||||
)
|
||||
assert "cannot create a file at directory path" in result.lower()
|
||||
|
||||
|
||||
def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
|
||||
"""Test that action metadata is properly defined."""
|
||||
metadata = memory_tool.get_actions_metadata()
|
||||
|
||||
# Check that all expected actions are defined
|
||||
action_names = [action["name"] for action in metadata]
|
||||
assert "view" in action_names
|
||||
assert "create" in action_names
|
||||
assert "str_replace" in action_names
|
||||
assert "insert" in action_names
|
||||
assert "delete" in action_names
|
||||
assert "rename" in action_names
|
||||
|
||||
# Check that each action has required fields
|
||||
for action in metadata:
|
||||
assert "name" in action
|
||||
assert "description" in action
|
||||
assert "parameters" in action
|
||||
|
||||
|
||||
def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different memory tool instances have isolated memories."""
|
||||
# Create fake collection
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def insert_one(self, doc):
|
||||
user_id = doc.get("user_id")
|
||||
tool_id = doc.get("tool_id")
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# Handle query by _id
|
||||
if "_id" in q:
|
||||
doc_id = q["_id"]
|
||||
if doc_id not in self.docs:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if "$set" in u:
|
||||
old_doc = self.docs[doc_id].copy()
|
||||
old_doc.update(u["$set"])
|
||||
|
||||
# If path changed, update the dictionary key
|
||||
if "path" in u["$set"]:
|
||||
new_path = u["$set"]["path"]
|
||||
user_id = old_doc.get("user_id")
|
||||
tool_id = old_doc.get("tool_id")
|
||||
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||
|
||||
# Remove old key and add with new key
|
||||
del self.docs[doc_id]
|
||||
old_doc["_id"] = new_key
|
||||
self.docs[new_key] = old_doc
|
||||
else:
|
||||
self.docs[doc_id] = old_doc
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
# Handle query by user_id, tool_id, path
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two memory tools with different tool_ids for the same user
|
||||
tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a file in tool1
|
||||
tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1")
|
||||
|
||||
# Create a file with the same path in tool2
|
||||
tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
result1 = tool1.execute_action("view", path="/file.txt")
|
||||
result2 = tool2.execute_action("view", path="/file.txt")
|
||||
|
||||
assert "Content from tool 1" in result1
|
||||
assert "Content from tool 2" not in result1
|
||||
|
||||
assert "Content from tool 2" in result2
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
tool1 = MemoryTool({}, user_id="test_user")
|
||||
tool2 = MemoryTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
tool3 = MemoryTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
|
||||
|
||||
def test_paths_without_leading_slash(memory_tool) -> None:
|
||||
"""Test that paths without leading slash work correctly."""
|
||||
# Create file without leading slash
|
||||
result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung")
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View file without leading slash
|
||||
view_result = memory_tool.execute_action("view", path="cat_breeds.txt")
|
||||
assert "Korat" in view_result
|
||||
assert "Chartreux" in view_result
|
||||
|
||||
# View same file with leading slash (should work the same)
|
||||
view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt")
|
||||
assert "Korat" in view_result2
|
||||
|
||||
# Test str_replace without leading slash
|
||||
replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon")
|
||||
assert "updated" in replace_result.lower()
|
||||
|
||||
# Test nested path without leading slash
|
||||
nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2")
|
||||
assert "created" in nested_result.lower()
|
||||
|
||||
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
|
||||
assert "Task 1" in view_nested
|
||||
|
||||
|
||||
def test_rename_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a directory with files."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/sub/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Rename directory (with trailing slash)
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/docs/",
|
||||
new_path="/archive/"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
assert "2 files" in result.lower()
|
||||
|
||||
# Verify old paths don't exist
|
||||
result = memory_tool.execute_action("view", path="/docs/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new paths exist
|
||||
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||
assert "Content 1" in result
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
|
||||
assert "Content 2" in result
|
||||
|
||||
|
||||
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a directory when new path is missing trailing slash."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/sub/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Rename directory - old path has slash, new path doesn't
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/docs/",
|
||||
new_path="/archive" # Missing trailing slash
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
|
||||
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||
assert "Content 1" in result
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
|
||||
assert "Content 2" in result
|
||||
|
||||
# Verify corrupted path doesn't exist
|
||||
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
|
||||
"""Test that view_range displays correct line numbers."""
|
||||
# Create a multiline file
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/numbered.txt",
|
||||
file_text=content
|
||||
)
|
||||
|
||||
# View lines 2-4
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/numbered.txt",
|
||||
view_range=[2, 4]
|
||||
)
|
||||
|
||||
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
|
||||
assert "2: Line 2" in result
|
||||
assert "3: Line 3" in result
|
||||
assert "4: Line 4" in result
|
||||
assert "1: Line 1" not in result
|
||||
assert "5: Line 5" not in result
|
||||
|
||||
# Verify no off-by-one error
|
||||
assert "3: Line 2" not in result # Wrong line number
|
||||
assert "4: Line 3" not in result # Wrong line number
|
||||
assert "5: Line 4" not in result # Wrong line number
|
||||
223
tests/test_notes_tool.py
Normal file
223
tests/test_notes_tool.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import pytest
|
||||
from application.agents.tools.notes import NotesTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notes_tool(monkeypatch) -> NotesTool:
|
||||
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {} # key: user_id:tool_id -> doc
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
|
||||
# emulate single-note storage with optional upsert
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
|
||||
if "$set" in u and "note" in u["$set"]:
|
||||
self.docs[key]["note"] = u["$set"]["note"]
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
return self.docs.get(key)
|
||||
|
||||
def delete_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
# Patch MongoDB client globally for the tool
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
def test_view(notes_tool: NotesTool) -> None:
|
||||
# Manually insert a note to test retrieval
|
||||
notes_tool.collection.update_one(
|
||||
{"user_id": "test_user", "tool_id": "test_tool_id"},
|
||||
{"$set": {"note": "hello"}},
|
||||
upsert=True
|
||||
)
|
||||
assert "hello" in notes_tool.execute_action("view")
|
||||
|
||||
|
||||
def test_overwrite_and_delete(notes_tool: NotesTool) -> None:
|
||||
# Overwrite creates a new note
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="first").lower()
|
||||
assert "first" in notes_tool.execute_action("view")
|
||||
|
||||
# Overwrite replaces existing note
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="second").lower()
|
||||
assert "second" in notes_tool.execute_action("view")
|
||||
|
||||
assert "deleted" in notes_tool.execute_action("delete").lower()
|
||||
assert "no note" in notes_tool.execute_action("view").lower()
|
||||
|
||||
def test_init_without_user_id(monkeypatch):
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
notes_tool = NotesTool(tool_config={})
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "user_id" in str(result).lower()
|
||||
|
||||
|
||||
def test_view_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Should return 'No note found.' when no note exists"""
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
def test_str_replace(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement in note"""
|
||||
# Create a note
|
||||
notes_tool.execute_action("overwrite", text="Hello world, hello universe")
|
||||
|
||||
# Replace text
|
||||
result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi")
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
note = notes_tool.execute_action("view")
|
||||
assert "hi world, hi universe" in note.lower()
|
||||
|
||||
|
||||
def test_str_replace_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement when string not found"""
|
||||
notes_tool.execute_action("overwrite", text="Hello world")
|
||||
result = notes_tool.execute_action("str_replace", old_str="goodbye", new_str="hi")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_insert_line(notes_tool: NotesTool) -> None:
|
||||
"""Test inserting text at a line number"""
|
||||
# Create a multiline note
|
||||
notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3")
|
||||
|
||||
# Insert at line 2
|
||||
result = notes_tool.execute_action("insert", line_number=2, text="Inserted line")
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
note = notes_tool.execute_action("view")
|
||||
lines = note.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
def test_delete_nonexistent_note(monkeypatch):
|
||||
class FakeResult:
|
||||
deleted_count = 0
|
||||
|
||||
class FakeCollection:
|
||||
def delete_one(self, *args, **kwargs):
|
||||
return FakeResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client",
|
||||
lambda: {"docsgpt": {"notes": FakeCollection()}}
|
||||
)
|
||||
|
||||
notes_tool = NotesTool(tool_config={}, user_id="user123")
|
||||
result = notes_tool.execute_action("delete")
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
def test_notes_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different notes tool instances have isolated notes."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
|
||||
if "$set" in u and "note" in u["$set"]:
|
||||
self.docs[key]["note"] = u["$set"]["note"]
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
return self.docs.get(key)
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two notes tools with different tool_ids for the same user
|
||||
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a note in tool1
|
||||
tool1.execute_action("overwrite", text="Content from tool 1")
|
||||
|
||||
# Create a note in tool2
|
||||
tool2.execute_action("overwrite", text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
result1 = tool1.execute_action("view")
|
||||
result2 = tool2.execute_action("view")
|
||||
|
||||
assert "Content from tool 1" in result1
|
||||
assert "Content from tool 2" not in result1
|
||||
|
||||
assert "Content from tool 2" in result2
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
tool1 = NotesTool({}, user_id="test_user")
|
||||
tool2 = NotesTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
tool3 = NotesTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
Reference in New Issue
Block a user