From 9d452e3b045125eadcb73f29743f919d3324c324 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 6 Oct 2025 21:45:47 +0100 Subject: [PATCH] feat: enhance MemoryTool and NotesTool with tool_id management and directory renaming tests (#2026) --- application/agents/tools/memory.py | 17 ++- application/agents/tools/notes.py | 34 ++++-- tests/test_memory_tool.py | 160 ++++++++++++++++++++++++++++- tests/test_notes_tool.py | 116 ++++++++++++++++++--- 4 files changed, 298 insertions(+), 29 deletions(-) diff --git a/application/agents/tools/memory.py b/application/agents/tools/memory.py index 51c38f37..ebee4ab0 100644 --- a/application/agents/tools/memory.py +++ b/application/agents/tools/memory.py @@ -104,7 +104,7 @@ class MemoryTool(Tool): "properties": { "path": { "type": "string", - "description": "Path to file or directory (e.g., /notes.txt or /project/)." + "description": "Path to file or directory (e.g., /notes.txt or /project/ or /)." }, "view_range": { "type": "array", @@ -233,6 +233,9 @@ class MemoryTool(Tool): # Remove any leading/trailing whitespace path = path.strip() + # Preserve whether path ends with / (indicates directory) + is_directory = path.endswith("/") + # Ensure path starts with / for consistency if not path.startswith("/"): path = "/" + path @@ -250,6 +253,10 @@ class MemoryTool(Tool): if not normalized.startswith("/"): return None + # Preserve trailing slash for directories + if is_directory and not normalized.endswith("/") and normalized != "/": + normalized = normalized + "/" + return normalized except Exception: return None @@ -322,8 +329,8 @@ class MemoryTool(Tool): return f"Error: Line range out of bounds. File has {len(lines)} lines." selected_lines = lines[start_idx:end_idx] - # Add line numbers - numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)] + # Add line numbers (enumerate with 1-based start) + numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)] return "\n".join(numbered_lines) return content @@ -480,6 +487,10 @@ class MemoryTool(Tool): # Check if renaming a directory if validated_old.endswith("/"): + # Ensure validated_new also ends with / for proper path replacement + if not validated_new.endswith("/"): + validated_new = validated_new + "/" + # Find all files in the old directory docs = list(self.collection.find({ "user_id": self.user_id, diff --git a/application/agents/tools/notes.py b/application/agents/tools/notes.py index 25bcbde7..3d7ced85 100644 --- a/application/agents/tools/notes.py +++ b/application/agents/tools/notes.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Any, Dict, List, Optional +import uuid from .base import Tool from application.core.mongo_db import MongoDB @@ -16,13 +17,24 @@ class NotesTool(Tool): """Initialize the tool. Args: - tool_config: Optional tool configuration (unused for now). + tool_config: Optional tool configuration. Should include: + - tool_id: Unique identifier for this notes tool instance (from user_tools._id) + This ensures each user's tool configuration has isolated notes user_id: The authenticated user's id (should come from decoded_token["sub"]). - """ - - self.user_id: Optional[str] = user_id + + # Get tool_id from configuration (passed from user_tools._id in production) + # In production, tool_id is the MongoDB ObjectId string from user_tools collection + if tool_config and "tool_id" in tool_config: + self.tool_id = tool_config["tool_id"] + elif user_id: + # Fallback for backward compatibility or testing + self.tool_id = f"default_{user_id}" + else: + # Last resort fallback (shouldn't happen in normal use) + self.tool_id = str(uuid.uuid4()) + db = MongoDB.get_client()[settings.MONGO_DB_NAME] self.collection = db["notes"] @@ -117,7 +129,7 @@ class NotesTool(Tool): # Internal helpers (single-note) # ----------------------------- def _get_note(self) -> str: - doc = self.collection.find_one({"user_id": self.user_id}) + doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id}) if not doc or not doc.get("note"): return "No note found." return str(doc["note"]) @@ -127,7 +139,7 @@ class NotesTool(Tool): if not content: return "Note content required." self.collection.update_one( - {"user_id": self.user_id}, + {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": content, "updated_at": datetime.utcnow()}}, upsert=True, # ✅ create if missing ) @@ -137,7 +149,7 @@ class NotesTool(Tool): if not old_str: return "old_str is required." - doc = self.collection.find_one({"user_id": self.user_id}) + doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id}) if not doc or not doc.get("note"): return "No note found." @@ -152,7 +164,7 @@ class NotesTool(Tool): updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE) self.collection.update_one( - {"user_id": self.user_id}, + {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, ) return "Note updated." @@ -161,7 +173,7 @@ class NotesTool(Tool): if not text: return "Text is required." - doc = self.collection.find_one({"user_id": self.user_id}) + doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id}) if not doc or not doc.get("note"): return "No note found." @@ -177,11 +189,11 @@ class NotesTool(Tool): updated_note = "\n".join(lines) self.collection.update_one( - {"user_id": self.user_id}, + {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, ) return "Text inserted." def _delete_note(self) -> str: - res = self.collection.delete_one({"user_id": self.user_id}) + res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id}) return "Note deleted." if res.deleted_count else "No note found to delete." diff --git a/tests/test_memory_tool.py b/tests/test_memory_tool.py index a2ac5f1e..2b041c6e 100644 --- a/tests/test_memory_tool.py +++ b/tests/test_memory_tool.py @@ -16,10 +16,40 @@ def memory_tool(monkeypatch) -> MemoryTool: tool_id = doc.get("tool_id") path = doc.get("path") key = f"{user_id}:{tool_id}:{path}" + # Add _id to document if not present + if "_id" not in doc: + doc["_id"] = key self.docs[key] = doc return type("res", (), {"inserted_id": key}) def update_one(self, q, u, upsert=False): + # Handle query by _id + if "_id" in q: + doc_id = q["_id"] + if doc_id not in self.docs: + return type("res", (), {"modified_count": 0}) + + if "$set" in u: + old_doc = self.docs[doc_id].copy() + old_doc.update(u["$set"]) + + # If path changed, update the dictionary key + if "path" in u["$set"]: + new_path = u["$set"]["path"] + user_id = old_doc.get("user_id") + tool_id = old_doc.get("tool_id") + new_key = f"{user_id}:{tool_id}:{new_path}" + + # Remove old key and add with new key + del self.docs[doc_id] + old_doc["_id"] = new_key + self.docs[new_key] = old_doc + else: + self.docs[doc_id] = old_doc + + return type("res", (), {"modified_count": 1}) + + # Handle query by user_id, tool_id, path user_id = q.get("user_id") tool_id = q.get("tool_id") path = q.get("path") @@ -29,7 +59,7 @@ def memory_tool(monkeypatch) -> MemoryTool: return type("res", (), {"modified_count": 0}) if key not in self.docs and upsert: - self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""} + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key} if "$set" in u: self.docs[key].update(u["$set"]) @@ -498,6 +528,33 @@ def test_memory_tool_isolation(monkeypatch) -> None: return type("res", (), {"inserted_id": key}) def update_one(self, q, u, upsert=False): + # Handle query by _id + if "_id" in q: + doc_id = q["_id"] + if doc_id not in self.docs: + return type("res", (), {"modified_count": 0}) + + if "$set" in u: + old_doc = self.docs[doc_id].copy() + old_doc.update(u["$set"]) + + # If path changed, update the dictionary key + if "path" in u["$set"]: + new_path = u["$set"]["path"] + user_id = old_doc.get("user_id") + tool_id = old_doc.get("tool_id") + new_key = f"{user_id}:{tool_id}:{new_path}" + + # Remove old key and add with new key + del self.docs[doc_id] + old_doc["_id"] = new_key + self.docs[new_key] = old_doc + else: + self.docs[doc_id] = old_doc + + return type("res", (), {"modified_count": 1}) + + # Handle query by user_id, tool_id, path user_id = q.get("user_id") tool_id = q.get("tool_id") path = q.get("path") @@ -507,7 +564,7 @@ def test_memory_tool_isolation(monkeypatch) -> None: return type("res", (), {"modified_count": 0}) if key not in self.docs and upsert: - self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""} + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key} if "$set" in u: self.docs[key].update(u["$set"]) @@ -607,3 +664,102 @@ def test_paths_without_leading_slash(memory_tool) -> None: view_nested = memory_tool.execute_action("view", path="projects/tasks.txt") assert "Task 1" in view_nested + + +def test_rename_directory(memory_tool: MemoryTool) -> None: + """Test renaming a directory with files.""" + # Create files in a directory + memory_tool.execute_action( + "create", + path="/docs/file1.txt", + file_text="Content 1" + ) + memory_tool.execute_action( + "create", + path="/docs/sub/file2.txt", + file_text="Content 2" + ) + + # Rename directory (with trailing slash) + result = memory_tool.execute_action( + "rename", + old_path="/docs/", + new_path="/archive/" + ) + assert "renamed" in result.lower() + assert "2 files" in result.lower() + + # Verify old paths don't exist + result = memory_tool.execute_action("view", path="/docs/file1.txt") + assert "not found" in result.lower() + + # Verify new paths exist + result = memory_tool.execute_action("view", path="/archive/file1.txt") + assert "Content 1" in result + + result = memory_tool.execute_action("view", path="/archive/sub/file2.txt") + assert "Content 2" in result + + +def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None: + """Test renaming a directory when new path is missing trailing slash.""" + # Create files in a directory + memory_tool.execute_action( + "create", + path="/docs/file1.txt", + file_text="Content 1" + ) + memory_tool.execute_action( + "create", + path="/docs/sub/file2.txt", + file_text="Content 2" + ) + + # Rename directory - old path has slash, new path doesn't + result = memory_tool.execute_action( + "rename", + old_path="/docs/", + new_path="/archive" # Missing trailing slash + ) + assert "renamed" in result.lower() + + # Verify paths are correct (not corrupted like "/archivesub/file2.txt") + result = memory_tool.execute_action("view", path="/archive/file1.txt") + assert "Content 1" in result + + result = memory_tool.execute_action("view", path="/archive/sub/file2.txt") + assert "Content 2" in result + + # Verify corrupted path doesn't exist + result = memory_tool.execute_action("view", path="/archivesub/file2.txt") + assert "not found" in result.lower() + + +def test_view_file_line_numbers(memory_tool: MemoryTool) -> None: + """Test that view_range displays correct line numbers.""" + # Create a multiline file + content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + memory_tool.execute_action( + "create", + path="/numbered.txt", + file_text=content + ) + + # View lines 2-4 + result = memory_tool.execute_action( + "view", + path="/numbered.txt", + view_range=[2, 4] + ) + + # Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5) + assert "2: Line 2" in result + assert "3: Line 3" in result + assert "4: Line 4" in result + assert "1: Line 1" not in result + assert "5: Line 5" not in result + + # Verify no off-by-one error + assert "3: Line 2" not in result # Wrong line number + assert "4: Line 3" not in result # Wrong line number + assert "5: Line 4" not in result # Wrong line number diff --git a/tests/test_notes_tool.py b/tests/test_notes_tool.py index d272c9cc..f57c5435 100644 --- a/tests/test_notes_tool.py +++ b/tests/test_notes_tool.py @@ -9,26 +9,34 @@ def notes_tool(monkeypatch) -> NotesTool: """Provide a NotesTool with a fake Mongo collection and fixed user_id.""" class FakeCollection: def __init__(self) -> None: - self.doc = None # single note per user + self.docs = {} # key: user_id:tool_id -> doc def update_one(self, q, u, upsert=False): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + # emulate single-note storage with optional upsert - if self.doc is None and not upsert: + if key not in self.docs and not upsert: return type("res", (), {"modified_count": 0}) - if self.doc is None and upsert: - self.doc = {"user_id": q["user_id"], "note": ""} + if key not in self.docs and upsert: + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""} if "$set" in u and "note" in u["$set"]: - self.doc["note"] = u["$set"]["note"] + self.docs[key]["note"] = u["$set"]["note"] return type("res", (), {"modified_count": 1}) def find_one(self, q): - if self.doc and self.doc.get("user_id") == q.get("user_id"): - return self.doc - return None + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + return self.docs.get(key) def delete_one(self, q): - if self.doc and self.doc.get("user_id") == q.get("user_id"): - self.doc = None + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + if key in self.docs: + del self.docs[key] return type("res", (), {"deleted_count": 1}) return type("res", (), {"deleted_count": 0}) @@ -39,14 +47,14 @@ def notes_tool(monkeypatch) -> NotesTool: # Patch MongoDB client globally for the tool monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) - # ToolManager will pass user_id in production; in tests we pass it directly - return NotesTool({}, user_id="test_user") + # Return tool with a fixed tool_id for consistency in tests + return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user") def test_view(notes_tool: NotesTool) -> None: # Manually insert a note to test retrieval notes_tool.collection.update_one( - {"user_id": "test_user"}, + {"user_id": "test_user", "tool_id": "test_tool_id"}, {"$set": {"note": "hello"}}, upsert=True ) @@ -131,3 +139,85 @@ def test_delete_nonexistent_note(monkeypatch): notes_tool = NotesTool(tool_config={}, user_id="user123") result = notes_tool.execute_action("delete") assert "no note found" in result.lower() + + +def test_notes_tool_isolation(monkeypatch) -> None: + """Test that different notes tool instances have isolated notes.""" + class FakeCollection: + def __init__(self) -> None: + self.docs = {} + + def update_one(self, q, u, upsert=False): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + + if key not in self.docs and not upsert: + return type("res", (), {"modified_count": 0}) + if key not in self.docs and upsert: + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""} + if "$set" in u and "note" in u["$set"]: + self.docs[key]["note"] = u["$set"]["note"] + return type("res", (), {"modified_count": 1}) + + def find_one(self, q): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + return self.docs.get(key) + + fake_collection = FakeCollection() + fake_db = {"notes": fake_collection} + fake_client = {settings.MONGO_DB_NAME: fake_db} + + monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + + # Create two notes tools with different tool_ids for the same user + tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user") + tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user") + + # Create a note in tool1 + tool1.execute_action("overwrite", text="Content from tool 1") + + # Create a note in tool2 + tool2.execute_action("overwrite", text="Content from tool 2") + + # Verify that each tool sees only its own content + result1 = tool1.execute_action("view") + result2 = tool2.execute_action("view") + + assert "Content from tool 1" in result1 + assert "Content from tool 2" not in result1 + + assert "Content from tool 2" in result2 + assert "Content from tool 1" not in result2 + + +def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None: + """Test that tool_id defaults to 'default_{user_id}' for persistence.""" + class FakeCollection: + def __init__(self) -> None: + self.docs = {} + + def update_one(self, q, u, upsert=False): + return type("res", (), {"modified_count": 1}) + + fake_collection = FakeCollection() + fake_db = {"notes": fake_collection} + fake_client = {settings.MONGO_DB_NAME: fake_db} + + monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) + + # Create two tools without providing tool_id for the same user + tool1 = NotesTool({}, user_id="test_user") + tool2 = NotesTool({}, user_id="test_user") + + # Both should have the same default tool_id for persistence + assert tool1.tool_id == "default_test_user" + assert tool2.tool_id == "default_test_user" + assert tool1.tool_id == tool2.tool_id + + # Different users should have different tool_ids + tool3 = NotesTool({}, user_id="another_user") + assert tool3.tool_id == "default_another_user" + assert tool3.tool_id != tool1.tool_id