diff --git a/application/agents/base.py b/application/agents/base.py index 134de1c3..f975abad 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -213,18 +213,24 @@ class BaseAgent(ABC): ): target_dict[param] = value tm = ToolManager(config={}) + + # Prepare tool_config and add tool_id for memory tools + if tool_data["name"] == "api_tool": + tool_config = { + "url": tool_data["config"]["actions"][action_name]["url"], + "method": tool_data["config"]["actions"][action_name]["method"], + "headers": headers, + "query_params": query_params, + } + else: + tool_config = tool_data["config"].copy() if tool_data["config"] else {} + # Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool) + # Use MongoDB _id if available, otherwise fall back to enumerated tool_id + tool_config["tool_id"] = str(tool_data.get("_id", tool_id)) + tool = tm.load_tool( tool_data["name"], - tool_config=( - { - "url": tool_data["config"]["actions"][action_name]["url"], - "method": tool_data["config"]["actions"][action_name]["method"], - "headers": headers, - "query_params": query_params, - } - if tool_data["name"] == "api_tool" - else tool_data["config"] - ), + tool_config=tool_config, user_id=self.user, # Pass user ID for MCP tools credential decryption ) if tool_data["name"] == "api_tool": diff --git a/application/agents/tools/memory.py b/application/agents/tools/memory.py new file mode 100644 index 00000000..51c38f37 --- /dev/null +++ b/application/agents/tools/memory.py @@ -0,0 +1,535 @@ +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional +import re +import uuid + +from .base import Tool +from application.core.mongo_db import MongoDB +from application.core.settings import settings + + +class MemoryTool(Tool): + """Memory + + Stores and retrieves information across conversations through a memory file directory. + """ + + def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None: + """Initialize the tool. + + Args: + tool_config: Optional tool configuration. Should include: + - tool_id: Unique identifier for this memory tool instance (from user_tools._id) + This ensures each user's tool configuration has isolated memories + user_id: The authenticated user's id (should come from decoded_token["sub"]). + """ + self.user_id: Optional[str] = user_id + + # Get tool_id from configuration (passed from user_tools._id in production) + # In production, tool_id is the MongoDB ObjectId string from user_tools collection + if tool_config and "tool_id" in tool_config: + self.tool_id = tool_config["tool_id"] + elif user_id: + # Fallback for backward compatibility or testing + self.tool_id = f"default_{user_id}" + else: + # Last resort fallback (shouldn't happen in normal use) + self.tool_id = str(uuid.uuid4()) + + db = MongoDB.get_client()[settings.MONGO_DB_NAME] + self.collection = db["memories"] + + # ----------------------------- + # Action implementations + # ----------------------------- + def execute_action(self, action_name: str, **kwargs: Any) -> str: + """Execute an action by name. + + Args: + action_name: One of view, create, str_replace, insert, delete, rename. + **kwargs: Parameters for the action. + + Returns: + A human-readable string result. + """ + if not self.user_id: + return "Error: MemoryTool requires a valid user_id." + + if action_name == "view": + return self._view( + kwargs.get("path", "/"), + kwargs.get("view_range") + ) + + if action_name == "create": + return self._create( + kwargs.get("path", ""), + kwargs.get("file_text", "") + ) + + if action_name == "str_replace": + return self._str_replace( + kwargs.get("path", ""), + kwargs.get("old_str", ""), + kwargs.get("new_str", "") + ) + + if action_name == "insert": + return self._insert( + kwargs.get("path", ""), + kwargs.get("insert_line", 1), + kwargs.get("insert_text", "") + ) + + if action_name == "delete": + return self._delete(kwargs.get("path", "")) + + if action_name == "rename": + return self._rename( + kwargs.get("old_path", ""), + kwargs.get("new_path", "") + ) + + return f"Unknown action: {action_name}" + + def get_actions_metadata(self) -> List[Dict[str, Any]]: + """Return JSON metadata describing supported actions for tool schemas.""" + return [ + { + "name": "view", + "description": "Shows directory contents or file contents with optional line ranges.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to file or directory (e.g., /notes.txt or /project/)." + }, + "view_range": { + "type": "array", + "items": {"type": "integer"}, + "description": "Optional [start_line, end_line] to view specific lines (1-indexed)." + } + }, + "required": ["path"] + }, + }, + { + "name": "create", + "description": "Create or overwrite a file.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "File path to create (e.g., /notes.txt or /project/task.txt)." + }, + "file_text": { + "type": "string", + "description": "Content to write to the file." + } + }, + "required": ["path", "file_text"] + }, + }, + { + "name": "str_replace", + "description": "Replace text in a file.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "File path (e.g., /notes.txt)." + }, + "old_str": { + "type": "string", + "description": "String to find." + }, + "new_str": { + "type": "string", + "description": "String to replace with." + } + }, + "required": ["path", "old_str", "new_str"] + }, + }, + { + "name": "insert", + "description": "Insert text at a specific line in a file.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "File path (e.g., /notes.txt)." + }, + "insert_line": { + "type": "integer", + "description": "Line number to insert at (1-indexed)." + }, + "insert_text": { + "type": "string", + "description": "Text to insert." + } + }, + "required": ["path", "insert_line", "insert_text"] + }, + }, + { + "name": "delete", + "description": "Delete a file or directory.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to delete (e.g., /notes.txt or /project/)." + } + }, + "required": ["path"] + }, + }, + { + "name": "rename", + "description": "Rename or move a file/directory.", + "parameters": { + "type": "object", + "properties": { + "old_path": { + "type": "string", + "description": "Current path (e.g., /old.txt)." + }, + "new_path": { + "type": "string", + "description": "New path (e.g., /new.txt)." + } + }, + "required": ["old_path", "new_path"] + }, + }, + ] + + def get_config_requirements(self) -> Dict[str, Any]: + """Return configuration requirements.""" + return {} + + # ----------------------------- + # Path validation + # ----------------------------- + def _validate_path(self, path: str) -> Optional[str]: + """Validate and normalize path. + + Args: + path: User-provided path. + + Returns: + Normalized path or None if invalid. + """ + if not path: + return None + + # Remove any leading/trailing whitespace + path = path.strip() + + # Ensure path starts with / for consistency + if not path.startswith("/"): + path = "/" + path + + # Check for directory traversal patterns + if ".." in path or path.count("//") > 0: + return None + + # Normalize the path + try: + # Convert to Path object and resolve to canonical form + normalized = str(Path(path).as_posix()) + + # Ensure it still starts with / + if not normalized.startswith("/"): + return None + + return normalized + except Exception: + return None + + # ----------------------------- + # Internal helpers + # ----------------------------- + def _view(self, path: str, view_range: Optional[List[int]] = None) -> str: + """View directory contents or file contents.""" + validated_path = self._validate_path(path) + if not validated_path: + return "Error: Invalid path." + + # Check if viewing directory (ends with / or is root) + if validated_path == "/" or validated_path.endswith("/"): + return self._view_directory(validated_path) + + # Otherwise view file + return self._view_file(validated_path, view_range) + + def _view_directory(self, path: str) -> str: + """List files in a directory.""" + # Ensure path ends with / for proper prefix matching + search_path = path if path.endswith("/") else path + "/" + + # Find all files that start with this directory path + query = { + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": {"$regex": f"^{re.escape(search_path)}"} + } + + docs = list(self.collection.find(query, {"path": 1})) + + if not docs: + return f"Directory: {path}\n(empty)" + + # Extract filenames relative to the directory + files = [] + for doc in docs: + file_path = doc["path"] + # Remove the directory prefix + if file_path.startswith(search_path): + relative = file_path[len(search_path):] + if relative: + files.append(relative) + + files.sort() + file_list = "\n".join(f"- {f}" for f in files) + return f"Directory: {path}\n{file_list}" + + def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str: + """View file contents with optional line range.""" + doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path}) + + if not doc or not doc.get("content"): + return f"Error: File not found: {path}" + + content = str(doc["content"]) + + # Apply view_range if specified + if view_range and len(view_range) == 2: + lines = content.split("\n") + start, end = view_range + # Convert to 0-indexed + start_idx = max(0, start - 1) + end_idx = min(len(lines), end) + + if start_idx >= len(lines): + return f"Error: Line range out of bounds. File has {len(lines)} lines." + + selected_lines = lines[start_idx:end_idx] + # Add line numbers + numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)] + return "\n".join(numbered_lines) + + return content + + def _create(self, path: str, file_text: str) -> str: + """Create or overwrite a file.""" + validated_path = self._validate_path(path) + if not validated_path: + return "Error: Invalid path." + + if validated_path == "/" or validated_path.endswith("/"): + return "Error: Cannot create a file at directory path." + + self.collection.update_one( + {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path}, + { + "$set": { + "content": file_text, + "updated_at": datetime.now() + } + }, + upsert=True + ) + + return f"File created: {validated_path}" + + def _str_replace(self, path: str, old_str: str, new_str: str) -> str: + """Replace text in a file.""" + validated_path = self._validate_path(path) + if not validated_path: + return "Error: Invalid path." + + if not old_str: + return "Error: old_str is required." + + doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path}) + + if not doc or not doc.get("content"): + return f"Error: File not found: {validated_path}" + + current_content = str(doc["content"]) + + # Check if old_str exists (case-insensitive) + if old_str.lower() not in current_content.lower(): + return f"Error: String '{old_str}' not found in file." + + # Replace the string (case-insensitive) + import re as regex_module + updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE) + + self.collection.update_one( + {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path}, + { + "$set": { + "content": updated_content, + "updated_at": datetime.now() + } + } + ) + + return f"File updated: {validated_path}" + + def _insert(self, path: str, insert_line: int, insert_text: str) -> str: + """Insert text at a specific line.""" + validated_path = self._validate_path(path) + if not validated_path: + return "Error: Invalid path." + + if not insert_text: + return "Error: insert_text is required." + + doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path}) + + if not doc or not doc.get("content"): + return f"Error: File not found: {validated_path}" + + current_content = str(doc["content"]) + lines = current_content.split("\n") + + # Convert to 0-indexed + index = insert_line - 1 + if index < 0 or index > len(lines): + return f"Error: Invalid line number. File has {len(lines)} lines." + + lines.insert(index, insert_text) + updated_content = "\n".join(lines) + + self.collection.update_one( + {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path}, + { + "$set": { + "content": updated_content, + "updated_at": datetime.now() + } + } + ) + + return f"Text inserted at line {insert_line} in {validated_path}" + + def _delete(self, path: str) -> str: + """Delete a file or directory.""" + validated_path = self._validate_path(path) + if not validated_path: + return "Error: Invalid path." + + if validated_path == "/": + # Delete all files for this user and tool + result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id}) + return f"Deleted {result.deleted_count} file(s) from memory." + + # Check if it's a directory (ends with /) + if validated_path.endswith("/"): + # Delete all files in directory + result = self.collection.delete_many({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": {"$regex": f"^{re.escape(validated_path)}"} + }) + return f"Deleted directory and {result.deleted_count} file(s)." + + # Try to delete as directory first (without trailing slash) + # Check if any files start with this path + / + search_path = validated_path + "/" + directory_result = self.collection.delete_many({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": {"$regex": f"^{re.escape(search_path)}"} + }) + + if directory_result.deleted_count > 0: + return f"Deleted directory and {directory_result.deleted_count} file(s)." + + # Delete single file + result = self.collection.delete_one({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": validated_path + }) + + if result.deleted_count: + return f"Deleted: {validated_path}" + return f"Error: File not found: {validated_path}" + + def _rename(self, old_path: str, new_path: str) -> str: + """Rename or move a file/directory.""" + validated_old = self._validate_path(old_path) + validated_new = self._validate_path(new_path) + + if not validated_old or not validated_new: + return "Error: Invalid path." + + if validated_old == "/" or validated_new == "/": + return "Error: Cannot rename root directory." + + # Check if renaming a directory + if validated_old.endswith("/"): + # Find all files in the old directory + docs = list(self.collection.find({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": {"$regex": f"^{re.escape(validated_old)}"} + })) + + if not docs: + return f"Error: Directory not found: {validated_old}" + + # Update paths for all files + for doc in docs: + old_file_path = doc["path"] + new_file_path = old_file_path.replace(validated_old, validated_new, 1) + + self.collection.update_one( + {"_id": doc["_id"]}, + {"$set": {"path": new_file_path, "updated_at": datetime.now()}} + ) + + return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)" + + # Rename single file + doc = self.collection.find_one({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": validated_old + }) + + if not doc: + return f"Error: File not found: {validated_old}" + + # Check if new path already exists + existing = self.collection.find_one({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": validated_new + }) + + if existing: + return f"Error: File already exists at {validated_new}" + + # Delete the old document and create a new one with the new path + self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old}) + self.collection.insert_one({ + "user_id": self.user_id, + "tool_id": self.tool_id, + "path": validated_new, + "content": doc.get("content", ""), + "updated_at": datetime.now() + }) + + return f"Renamed: {validated_old} -> {validated_new}" diff --git a/application/agents/tools/tool_manager.py b/application/agents/tools/tool_manager.py index 40ce02ce..fb45b987 100644 --- a/application/agents/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -28,7 +28,7 @@ class ToolManager: module = importlib.import_module(f"application.agents.tools.{tool_name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: - if tool_name in {"mcp_tool","notes"} and user_id: + if tool_name in {"mcp_tool", "notes", "memory"} and user_id: return obj(tool_config, user_id) else: return obj(tool_config) @@ -36,7 +36,7 @@ class ToolManager: def execute_action(self, tool_name, action_name, user_id=None, **kwargs): if tool_name not in self.tools: raise ValueError(f"Tool '{tool_name}' not loaded") - if tool_name == "mcp_tool" and user_id: + if tool_name in {"mcp_tool", "memory"} and user_id: tool_config = self.config.get(tool_name, {}) tool = self.load_tool(tool_name, tool_config, user_id) return tool.execute_action(action_name, **kwargs) diff --git a/frontend/public/toolIcons/tool_memory.svg b/frontend/public/toolIcons/tool_memory.svg new file mode 100644 index 00000000..cd526397 --- /dev/null +++ b/frontend/public/toolIcons/tool_memory.svg @@ -0,0 +1,3 @@ + + + diff --git a/tests/test_memory_tool.py b/tests/test_memory_tool.py new file mode 100644 index 00000000..a2ac5f1e --- /dev/null +++ b/tests/test_memory_tool.py @@ -0,0 +1,609 @@ +import pytest +from application.agents.tools.memory import MemoryTool +from application.core.settings import settings + + +@pytest.fixture +def memory_tool(monkeypatch) -> MemoryTool: + """Provide a MemoryTool with a fake Mongo collection and fixed user_id.""" + + class FakeCollection: + def __init__(self) -> None: + self.docs = {} # path -> document + + def insert_one(self, doc): + user_id = doc.get("user_id") + tool_id = doc.get("tool_id") + path = doc.get("path") + key = f"{user_id}:{tool_id}:{path}" + self.docs[key] = doc + return type("res", (), {"inserted_id": key}) + + def update_one(self, q, u, upsert=False): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + path = q.get("path") + key = f"{user_id}:{tool_id}:{path}" + + if key not in self.docs and not upsert: + return type("res", (), {"modified_count": 0}) + + if key not in self.docs and upsert: + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""} + + if "$set" in u: + self.docs[key].update(u["$set"]) + + return type("res", (), {"modified_count": 1}) + + def find_one(self, q, projection=None): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + path = q.get("path") + + if path: + key = f"{user_id}:{tool_id}:{path}" + return self.docs.get(key) + + return None + + def find(self, q, projection=None): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + results = [] + + # Handle regex queries for directory listing + if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]: + regex_pattern = q["path"]["$regex"] + # Remove regex escape characters and ^ anchor for simple matching + pattern = regex_pattern.replace("\\", "").lstrip("^") + + for key, doc in self.docs.items(): + if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id: + doc_path = doc.get("path", "") + if doc_path.startswith(pattern): + results.append(doc) + else: + for key, doc in self.docs.items(): + if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id: + results.append(doc) + + return results + + def delete_one(self, q): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + path = q.get("path") + key = f"{user_id}:{tool_id}:{path}" + + if key in self.docs: + del self.docs[key] + return type("res", (), {"deleted_count": 1}) + + return type("res", (), {"deleted_count": 0}) + + def delete_many(self, q): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + deleted = 0 + + # Handle regex queries for directory deletion + if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]: + regex_pattern = q["path"]["$regex"] + pattern = regex_pattern.replace("\\", "").lstrip("^") + + keys_to_delete = [] + for key, doc in self.docs.items(): + if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id: + doc_path = doc.get("path", "") + if doc_path.startswith(pattern): + keys_to_delete.append(key) + + for key in keys_to_delete: + del self.docs[key] + deleted += 1 + else: + # Delete all for user and tool + keys_to_delete = [ + key for key, doc in self.docs.items() + if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id + ] + for key in keys_to_delete: + del self.docs[key] + deleted += 1 + + return type("res", (), {"deleted_count": deleted}) + + fake_collection = FakeCollection() + fake_db = {"memories": fake_collection} + fake_client = {settings.MONGO_DB_NAME: fake_db} + + monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + + # Return tool with a fixed tool_id for consistency in tests + return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user") + + +def test_init_without_user_id(): + """Should fail gracefully if no user_id is provided.""" + memory_tool = MemoryTool(tool_config={}) + result = memory_tool.execute_action("view", path="/") + assert "user_id" in result.lower() + + +def test_view_empty_directory(memory_tool: MemoryTool) -> None: + """Should show empty directory when no files exist.""" + result = memory_tool.execute_action("view", path="/") + assert "empty" in result.lower() + + +def test_create_and_view_file(memory_tool: MemoryTool) -> None: + """Test creating a file and viewing it.""" + # Create a file + result = memory_tool.execute_action( + "create", + path="/notes.txt", + file_text="Hello world" + ) + assert "created" in result.lower() + + # View the file + result = memory_tool.execute_action("view", path="/notes.txt") + assert "Hello world" in result + + +def test_create_overwrite_file(memory_tool: MemoryTool) -> None: + """Test that create overwrites existing files.""" + # Create initial file + memory_tool.execute_action( + "create", + path="/test.txt", + file_text="Original content" + ) + + # Overwrite + memory_tool.execute_action( + "create", + path="/test.txt", + file_text="New content" + ) + + # Verify overwrite + result = memory_tool.execute_action("view", path="/test.txt") + assert "New content" in result + assert "Original content" not in result + + +def test_view_directory_with_files(memory_tool: MemoryTool) -> None: + """Test viewing directory contents.""" + # Create multiple files + memory_tool.execute_action( + "create", + path="/file1.txt", + file_text="Content 1" + ) + memory_tool.execute_action( + "create", + path="/file2.txt", + file_text="Content 2" + ) + memory_tool.execute_action( + "create", + path="/subdir/file3.txt", + file_text="Content 3" + ) + + # View directory + result = memory_tool.execute_action("view", path="/") + assert "file1.txt" in result + assert "file2.txt" in result + assert "subdir/file3.txt" in result + + +def test_view_file_with_line_range(memory_tool: MemoryTool) -> None: + """Test viewing specific lines from a file.""" + # Create a multiline file + content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + memory_tool.execute_action( + "create", + path="/multiline.txt", + file_text=content + ) + + # View lines 2-4 + result = memory_tool.execute_action( + "view", + path="/multiline.txt", + view_range=[2, 4] + ) + assert "Line 2" in result + assert "Line 3" in result + assert "Line 4" in result + assert "Line 1" not in result + assert "Line 5" not in result + + +def test_str_replace(memory_tool: MemoryTool) -> None: + """Test string replacement in a file.""" + # Create a file + memory_tool.execute_action( + "create", + path="/replace.txt", + file_text="Hello world, hello universe" + ) + + # Replace text + result = memory_tool.execute_action( + "str_replace", + path="/replace.txt", + old_str="hello", + new_str="hi" + ) + assert "updated" in result.lower() + + # Verify replacement + content = memory_tool.execute_action("view", path="/replace.txt") + assert "hi world, hi universe" in content + + +def test_str_replace_not_found(memory_tool: MemoryTool) -> None: + """Test string replacement when string not found.""" + memory_tool.execute_action( + "create", + path="/test.txt", + file_text="Hello world" + ) + + result = memory_tool.execute_action( + "str_replace", + path="/test.txt", + old_str="goodbye", + new_str="hi" + ) + assert "not found" in result.lower() + + +def test_insert_line(memory_tool: MemoryTool) -> None: + """Test inserting text at a line number.""" + # Create a multiline file + memory_tool.execute_action( + "create", + path="/insert.txt", + file_text="Line 1\nLine 2\nLine 3" + ) + + # Insert at line 2 + result = memory_tool.execute_action( + "insert", + path="/insert.txt", + insert_line=2, + insert_text="Inserted line" + ) + assert "inserted" in result.lower() + + # Verify insertion + content = memory_tool.execute_action("view", path="/insert.txt") + lines = content.split("\n") + assert lines[1] == "Inserted line" + assert lines[2] == "Line 2" + + +def test_insert_invalid_line(memory_tool: MemoryTool) -> None: + """Test inserting at an invalid line number.""" + memory_tool.execute_action( + "create", + path="/test.txt", + file_text="Line 1\nLine 2" + ) + + result = memory_tool.execute_action( + "insert", + path="/test.txt", + insert_line=100, + insert_text="Text" + ) + assert "invalid" in result.lower() + + +def test_delete_file(memory_tool: MemoryTool) -> None: + """Test deleting a file.""" + # Create a file + memory_tool.execute_action( + "create", + path="/delete_me.txt", + file_text="Content" + ) + + # Delete it + result = memory_tool.execute_action("delete", path="/delete_me.txt") + assert "deleted" in result.lower() + + # Verify it's gone + result = memory_tool.execute_action("view", path="/delete_me.txt") + assert "not found" in result.lower() + + +def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None: + """Test deleting a file that doesn't exist.""" + result = memory_tool.execute_action("delete", path="/nonexistent.txt") + assert "not found" in result.lower() + + +def test_delete_directory(memory_tool: MemoryTool) -> None: + """Test deleting a directory with files.""" + # Create files in a directory + memory_tool.execute_action( + "create", + path="/subdir/file1.txt", + file_text="Content 1" + ) + memory_tool.execute_action( + "create", + path="/subdir/file2.txt", + file_text="Content 2" + ) + + # Delete the directory + result = memory_tool.execute_action("delete", path="/subdir/") + assert "deleted" in result.lower() + + # Verify files are gone + result = memory_tool.execute_action("view", path="/subdir/file1.txt") + assert "not found" in result.lower() + + +def test_rename_file(memory_tool: MemoryTool) -> None: + """Test renaming a file.""" + # Create a file + memory_tool.execute_action( + "create", + path="/old_name.txt", + file_text="Content" + ) + + # Rename it + result = memory_tool.execute_action( + "rename", + old_path="/old_name.txt", + new_path="/new_name.txt" + ) + assert "renamed" in result.lower() + + # Verify old path doesn't exist + result = memory_tool.execute_action("view", path="/old_name.txt") + assert "not found" in result.lower() + + # Verify new path exists + result = memory_tool.execute_action("view", path="/new_name.txt") + assert "Content" in result + + +def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None: + """Test renaming a file that doesn't exist.""" + result = memory_tool.execute_action( + "rename", + old_path="/nonexistent.txt", + new_path="/new.txt" + ) + assert "not found" in result.lower() + + +def test_rename_to_existing_file(memory_tool: MemoryTool) -> None: + """Test renaming to a path that already exists.""" + # Create two files + memory_tool.execute_action( + "create", + path="/file1.txt", + file_text="Content 1" + ) + memory_tool.execute_action( + "create", + path="/file2.txt", + file_text="Content 2" + ) + + # Try to rename file1 to file2 + result = memory_tool.execute_action( + "rename", + old_path="/file1.txt", + new_path="/file2.txt" + ) + assert "already exists" in result.lower() + + +def test_path_traversal_protection(memory_tool: MemoryTool) -> None: + """Test that directory traversal attacks are prevented.""" + # Try various path traversal attempts + invalid_paths = [ + "/../secrets.txt", + "/../../etc/passwd", + "..//file.txt", + "/subdir/../../outside.txt", + ] + + for path in invalid_paths: + result = memory_tool.execute_action( + "create", + path=path, + file_text="malicious content" + ) + assert "invalid path" in result.lower() + + +def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None: + """Test that paths work with or without leading slash (auto-normalized).""" + # These paths should all work now (auto-prepended with /) + valid_paths = [ + "etc/passwd", # Auto-prepended with / + "home/user/file.txt", # Auto-prepended with / + "file.txt", # Auto-prepended with / + ] + + for path in valid_paths: + result = memory_tool.execute_action( + "create", + path=path, + file_text="content" + ) + assert "created" in result.lower() + + # Verify the file can be accessed with or without leading slash + view_result = memory_tool.execute_action("view", path=path) + assert "content" in view_result + + +def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None: + """Test that you cannot create a file at a directory path.""" + result = memory_tool.execute_action( + "create", + path="/", + file_text="content" + ) + assert "cannot create a file at directory path" in result.lower() + + +def test_get_actions_metadata(memory_tool: MemoryTool) -> None: + """Test that action metadata is properly defined.""" + metadata = memory_tool.get_actions_metadata() + + # Check that all expected actions are defined + action_names = [action["name"] for action in metadata] + assert "view" in action_names + assert "create" in action_names + assert "str_replace" in action_names + assert "insert" in action_names + assert "delete" in action_names + assert "rename" in action_names + + # Check that each action has required fields + for action in metadata: + assert "name" in action + assert "description" in action + assert "parameters" in action + + +def test_memory_tool_isolation(monkeypatch) -> None: + """Test that different memory tool instances have isolated memories.""" + # Create fake collection + class FakeCollection: + def __init__(self) -> None: + self.docs = {} + + def insert_one(self, doc): + user_id = doc.get("user_id") + tool_id = doc.get("tool_id") + path = doc.get("path") + key = f"{user_id}:{tool_id}:{path}" + self.docs[key] = doc + return type("res", (), {"inserted_id": key}) + + def update_one(self, q, u, upsert=False): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + path = q.get("path") + key = f"{user_id}:{tool_id}:{path}" + + if key not in self.docs and not upsert: + return type("res", (), {"modified_count": 0}) + + if key not in self.docs and upsert: + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""} + + if "$set" in u: + self.docs[key].update(u["$set"]) + + return type("res", (), {"modified_count": 1}) + + def find_one(self, q, projection=None): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + path = q.get("path") + + if path: + key = f"{user_id}:{tool_id}:{path}" + return self.docs.get(key) + + return None + + fake_collection = FakeCollection() + fake_db = {"memories": fake_collection} + fake_client = {settings.MONGO_DB_NAME: fake_db} + + monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + + # Create two memory tools with different tool_ids for the same user + tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user") + tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user") + + # Create a file in tool1 + tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1") + + # Create a file with the same path in tool2 + tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2") + + # Verify that each tool sees only its own content + result1 = tool1.execute_action("view", path="/file.txt") + result2 = tool2.execute_action("view", path="/file.txt") + + assert "Content from tool 1" in result1 + assert "Content from tool 2" not in result1 + + assert "Content from tool 2" in result2 + assert "Content from tool 1" not in result2 + + +def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None: + """Test that tool_id defaults to 'default_{user_id}' for persistence.""" + class FakeCollection: + def __init__(self) -> None: + self.docs = {} + + def update_one(self, q, u, upsert=False): + return type("res", (), {"modified_count": 1}) + + fake_collection = FakeCollection() + fake_db = {"memories": fake_collection} + fake_client = {settings.MONGO_DB_NAME: fake_db} + + monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + + # Create two tools without providing tool_id for the same user + tool1 = MemoryTool({}, user_id="test_user") + tool2 = MemoryTool({}, user_id="test_user") + + # Both should have the same default tool_id for persistence + assert tool1.tool_id == "default_test_user" + assert tool2.tool_id == "default_test_user" + assert tool1.tool_id == tool2.tool_id + + # Different users should have different tool_ids + tool3 = MemoryTool({}, user_id="another_user") + assert tool3.tool_id == "default_another_user" + assert tool3.tool_id != tool1.tool_id + + +def test_paths_without_leading_slash(memory_tool) -> None: + """Test that paths without leading slash work correctly.""" + # Create file without leading slash + result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung") + assert "created" in result.lower() + + # View file without leading slash + view_result = memory_tool.execute_action("view", path="cat_breeds.txt") + assert "Korat" in view_result + assert "Chartreux" in view_result + + # View same file with leading slash (should work the same) + view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt") + assert "Korat" in view_result2 + + # Test str_replace without leading slash + replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon") + assert "updated" in replace_result.lower() + + # Test nested path without leading slash + nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2") + assert "created" in nested_result.lower() + + view_nested = memory_tool.execute_action("view", path="projects/tasks.txt") + assert "Task 1" in view_nested