feat: enhance MemoryTool and NotesTool with tool_id management and directory renaming tests (#2026)

This commit is contained in:
Alex
2025-10-06 21:45:47 +01:00
committed by GitHub
parent e012189672
commit 9d452e3b04
4 changed files with 298 additions and 29 deletions

View File

@@ -104,7 +104,7 @@ class MemoryTool(Tool):
"properties": { "properties": {
"path": { "path": {
"type": "string", "type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/)." "description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
}, },
"view_range": { "view_range": {
"type": "array", "type": "array",
@@ -233,6 +233,9 @@ class MemoryTool(Tool):
# Remove any leading/trailing whitespace # Remove any leading/trailing whitespace
path = path.strip() path = path.strip()
# Preserve whether path ends with / (indicates directory)
is_directory = path.endswith("/")
# Ensure path starts with / for consistency # Ensure path starts with / for consistency
if not path.startswith("/"): if not path.startswith("/"):
path = "/" + path path = "/" + path
@@ -250,6 +253,10 @@ class MemoryTool(Tool):
if not normalized.startswith("/"): if not normalized.startswith("/"):
return None return None
# Preserve trailing slash for directories
if is_directory and not normalized.endswith("/") and normalized != "/":
normalized = normalized + "/"
return normalized return normalized
except Exception: except Exception:
return None return None
@@ -322,8 +329,8 @@ class MemoryTool(Tool):
return f"Error: Line range out of bounds. File has {len(lines)} lines." return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx] selected_lines = lines[start_idx:end_idx]
# Add line numbers # Add line numbers (enumerate with 1-based start)
numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)] numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
return "\n".join(numbered_lines) return "\n".join(numbered_lines)
return content return content
@@ -480,6 +487,10 @@ class MemoryTool(Tool):
# Check if renaming a directory # Check if renaming a directory
if validated_old.endswith("/"): if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
# Find all files in the old directory # Find all files in the old directory
docs = list(self.collection.find({ docs = list(self.collection.find({
"user_id": self.user_id, "user_id": self.user_id,

View File

@@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import uuid
from .base import Tool from .base import Tool
from application.core.mongo_db import MongoDB from application.core.mongo_db import MongoDB
@@ -16,13 +17,24 @@ class NotesTool(Tool):
"""Initialize the tool. """Initialize the tool.
Args: Args:
tool_config: Optional tool configuration (unused for now). tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated notes
user_id: The authenticated user's id (should come from decoded_token["sub"]). user_id: The authenticated user's id (should come from decoded_token["sub"]).
""" """
self.user_id: Optional[str] = user_id self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME] db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"] self.collection = db["notes"]
@@ -117,7 +129,7 @@ class NotesTool(Tool):
# Internal helpers (single-note) # Internal helpers (single-note)
# ----------------------------- # -----------------------------
def _get_note(self) -> str: def _get_note(self) -> str:
doc = self.collection.find_one({"user_id": self.user_id}) doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"): if not doc or not doc.get("note"):
return "No note found." return "No note found."
return str(doc["note"]) return str(doc["note"])
@@ -127,7 +139,7 @@ class NotesTool(Tool):
if not content: if not content:
return "Note content required." return "Note content required."
self.collection.update_one( self.collection.update_one(
{"user_id": self.user_id}, {"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}}, {"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing upsert=True, # ✅ create if missing
) )
@@ -137,7 +149,7 @@ class NotesTool(Tool):
if not old_str: if not old_str:
return "old_str is required." return "old_str is required."
doc = self.collection.find_one({"user_id": self.user_id}) doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"): if not doc or not doc.get("note"):
return "No note found." return "No note found."
@@ -152,7 +164,7 @@ class NotesTool(Tool):
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE) updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one( self.collection.update_one(
{"user_id": self.user_id}, {"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
) )
return "Note updated." return "Note updated."
@@ -161,7 +173,7 @@ class NotesTool(Tool):
if not text: if not text:
return "Text is required." return "Text is required."
doc = self.collection.find_one({"user_id": self.user_id}) doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"): if not doc or not doc.get("note"):
return "No note found." return "No note found."
@@ -177,11 +189,11 @@ class NotesTool(Tool):
updated_note = "\n".join(lines) updated_note = "\n".join(lines)
self.collection.update_one( self.collection.update_one(
{"user_id": self.user_id}, {"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
) )
return "Text inserted." return "Text inserted."
def _delete_note(self) -> str: def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id}) res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete." return "Note deleted." if res.deleted_count else "No note found to delete."

View File

@@ -16,10 +16,40 @@ def memory_tool(monkeypatch) -> MemoryTool:
tool_id = doc.get("tool_id") tool_id = doc.get("tool_id")
path = doc.get("path") path = doc.get("path")
key = f"{user_id}:{tool_id}:{path}" key = f"{user_id}:{tool_id}:{path}"
# Add _id to document if not present
if "_id" not in doc:
doc["_id"] = key
self.docs[key] = doc self.docs[key] = doc
return type("res", (), {"inserted_id": key}) return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False): def update_one(self, q, u, upsert=False):
# Handle query by _id
if "_id" in q:
doc_id = q["_id"]
if doc_id not in self.docs:
return type("res", (), {"modified_count": 0})
if "$set" in u:
old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"])
# If path changed, update the dictionary key
if "path" in u["$set"]:
new_path = u["$set"]["path"]
user_id = old_doc.get("user_id")
tool_id = old_doc.get("tool_id")
new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key
del self.docs[doc_id]
old_doc["_id"] = new_key
self.docs[new_key] = old_doc
else:
self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path
user_id = q.get("user_id") user_id = q.get("user_id")
tool_id = q.get("tool_id") tool_id = q.get("tool_id")
path = q.get("path") path = q.get("path")
@@ -29,7 +59,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert: if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""} self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
if "$set" in u: if "$set" in u:
self.docs[key].update(u["$set"]) self.docs[key].update(u["$set"])
@@ -498,6 +528,33 @@ def test_memory_tool_isolation(monkeypatch) -> None:
return type("res", (), {"inserted_id": key}) return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False): def update_one(self, q, u, upsert=False):
# Handle query by _id
if "_id" in q:
doc_id = q["_id"]
if doc_id not in self.docs:
return type("res", (), {"modified_count": 0})
if "$set" in u:
old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"])
# If path changed, update the dictionary key
if "path" in u["$set"]:
new_path = u["$set"]["path"]
user_id = old_doc.get("user_id")
tool_id = old_doc.get("tool_id")
new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key
del self.docs[doc_id]
old_doc["_id"] = new_key
self.docs[new_key] = old_doc
else:
self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path
user_id = q.get("user_id") user_id = q.get("user_id")
tool_id = q.get("tool_id") tool_id = q.get("tool_id")
path = q.get("path") path = q.get("path")
@@ -507,7 +564,7 @@ def test_memory_tool_isolation(monkeypatch) -> None:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert: if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""} self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
if "$set" in u: if "$set" in u:
self.docs[key].update(u["$set"]) self.docs[key].update(u["$set"])
@@ -607,3 +664,102 @@ def test_paths_without_leading_slash(memory_tool) -> None:
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt") view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
assert "Task 1" in view_nested assert "Task 1" in view_nested
def test_rename_directory(memory_tool: MemoryTool) -> None:
"""Test renaming a directory with files."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
)
# Rename directory (with trailing slash)
result = memory_tool.execute_action(
"rename",
old_path="/docs/",
new_path="/archive/"
)
assert "renamed" in result.lower()
assert "2 files" in result.lower()
# Verify old paths don't exist
result = memory_tool.execute_action("view", path="/docs/file1.txt")
assert "not found" in result.lower()
# Verify new paths exist
result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
assert "Content 2" in result
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
"""Test renaming a directory when new path is missing trailing slash."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
)
# Rename directory - old path has slash, new path doesn't
result = memory_tool.execute_action(
"rename",
old_path="/docs/",
new_path="/archive" # Missing trailing slash
)
assert "renamed" in result.lower()
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
assert "Content 2" in result
# Verify corrupted path doesn't exist
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
assert "not found" in result.lower()
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
"""Test that view_range displays correct line numbers."""
# Create a multiline file
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
memory_tool.execute_action(
"create",
path="/numbered.txt",
file_text=content
)
# View lines 2-4
result = memory_tool.execute_action(
"view",
path="/numbered.txt",
view_range=[2, 4]
)
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
assert "2: Line 2" in result
assert "3: Line 3" in result
assert "4: Line 4" in result
assert "1: Line 1" not in result
assert "5: Line 5" not in result
# Verify no off-by-one error
assert "3: Line 2" not in result # Wrong line number
assert "4: Line 3" not in result # Wrong line number
assert "5: Line 4" not in result # Wrong line number

View File

@@ -9,26 +9,34 @@ def notes_tool(monkeypatch) -> NotesTool:
"""Provide a NotesTool with a fake Mongo collection and fixed user_id.""" """Provide a NotesTool with a fake Mongo collection and fixed user_id."""
class FakeCollection: class FakeCollection:
def __init__(self) -> None: def __init__(self) -> None:
self.doc = None # single note per user self.docs = {} # key: user_id:tool_id -> doc
def update_one(self, q, u, upsert=False): def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
# emulate single-note storage with optional upsert # emulate single-note storage with optional upsert
if self.doc is None and not upsert: if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0}) return type("res", (), {"modified_count": 0})
if self.doc is None and upsert: if key not in self.docs and upsert:
self.doc = {"user_id": q["user_id"], "note": ""} self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
if "$set" in u and "note" in u["$set"]: if "$set" in u and "note" in u["$set"]:
self.doc["note"] = u["$set"]["note"] self.docs[key]["note"] = u["$set"]["note"]
return type("res", (), {"modified_count": 1}) return type("res", (), {"modified_count": 1})
def find_one(self, q): def find_one(self, q):
if self.doc and self.doc.get("user_id") == q.get("user_id"): user_id = q.get("user_id")
return self.doc tool_id = q.get("tool_id")
return None key = f"{user_id}:{tool_id}"
return self.docs.get(key)
def delete_one(self, q): def delete_one(self, q):
if self.doc and self.doc.get("user_id") == q.get("user_id"): user_id = q.get("user_id")
self.doc = None tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
if key in self.docs:
del self.docs[key]
return type("res", (), {"deleted_count": 1}) return type("res", (), {"deleted_count": 1})
return type("res", (), {"deleted_count": 0}) return type("res", (), {"deleted_count": 0})
@@ -39,14 +47,14 @@ def notes_tool(monkeypatch) -> NotesTool:
# Patch MongoDB client globally for the tool # Patch MongoDB client globally for the tool
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client) monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# ToolManager will pass user_id in production; in tests we pass it directly # Return tool with a fixed tool_id for consistency in tests
return NotesTool({}, user_id="test_user") return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
def test_view(notes_tool: NotesTool) -> None: def test_view(notes_tool: NotesTool) -> None:
# Manually insert a note to test retrieval # Manually insert a note to test retrieval
notes_tool.collection.update_one( notes_tool.collection.update_one(
{"user_id": "test_user"}, {"user_id": "test_user", "tool_id": "test_tool_id"},
{"$set": {"note": "hello"}}, {"$set": {"note": "hello"}},
upsert=True upsert=True
) )
@@ -131,3 +139,85 @@ def test_delete_nonexistent_note(monkeypatch):
notes_tool = NotesTool(tool_config={}, user_id="user123") notes_tool = NotesTool(tool_config={}, user_id="user123")
result = notes_tool.execute_action("delete") result = notes_tool.execute_action("delete")
assert "no note found" in result.lower() assert "no note found" in result.lower()
def test_notes_tool_isolation(monkeypatch) -> None:
"""Test that different notes tool instances have isolated notes."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
if "$set" in u and "note" in u["$set"]:
self.docs[key]["note"] = u["$set"]["note"]
return type("res", (), {"modified_count": 1})
def find_one(self, q):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
return self.docs.get(key)
fake_collection = FakeCollection()
fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two notes tools with different tool_ids for the same user
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
# Create a note in tool1
tool1.execute_action("overwrite", text="Content from tool 1")
# Create a note in tool2
tool2.execute_action("overwrite", text="Content from tool 2")
# Verify that each tool sees only its own content
result1 = tool1.execute_action("view")
result2 = tool2.execute_action("view")
assert "Content from tool 1" in result1
assert "Content from tool 2" not in result1
assert "Content from tool 2" in result2
assert "Content from tool 1" not in result2
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
return type("res", (), {"modified_count": 1})
fake_collection = FakeCollection()
fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two tools without providing tool_id for the same user
tool1 = NotesTool({}, user_id="test_user")
tool2 = NotesTool({}, user_id="test_user")
# Both should have the same default tool_id for persistence
assert tool1.tool_id == "default_test_user"
assert tool2.tool_id == "default_test_user"
assert tool1.tool_id == tool2.tool_id
# Different users should have different tool_ids
tool3 = NotesTool({}, user_id="another_user")
assert tool3.tool_id == "default_another_user"
assert tool3.tool_id != tool1.tool_id