feat: enhance MemoryTool and NotesTool with tool_id management and directory renaming tests (#2026)

This commit is contained in:
Alex
2025-10-06 21:45:47 +01:00
committed by GitHub
parent e012189672
commit 9d452e3b04
4 changed files with 298 additions and 29 deletions

View File

@@ -16,10 +16,40 @@ def memory_tool(monkeypatch) -> MemoryTool:
tool_id = doc.get("tool_id")
path = doc.get("path")
key = f"{user_id}:{tool_id}:{path}"
# Add _id to document if not present
if "_id" not in doc:
doc["_id"] = key
self.docs[key] = doc
return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False):
# Handle query by _id
if "_id" in q:
doc_id = q["_id"]
if doc_id not in self.docs:
return type("res", (), {"modified_count": 0})
if "$set" in u:
old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"])
# If path changed, update the dictionary key
if "path" in u["$set"]:
new_path = u["$set"]["path"]
user_id = old_doc.get("user_id")
tool_id = old_doc.get("tool_id")
new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key
del self.docs[doc_id]
old_doc["_id"] = new_key
self.docs[new_key] = old_doc
else:
self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
@@ -29,7 +59,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
if "$set" in u:
self.docs[key].update(u["$set"])
@@ -498,6 +528,33 @@ def test_memory_tool_isolation(monkeypatch) -> None:
return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False):
# Handle query by _id
if "_id" in q:
doc_id = q["_id"]
if doc_id not in self.docs:
return type("res", (), {"modified_count": 0})
if "$set" in u:
old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"])
# If path changed, update the dictionary key
if "path" in u["$set"]:
new_path = u["$set"]["path"]
user_id = old_doc.get("user_id")
tool_id = old_doc.get("tool_id")
new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key
del self.docs[doc_id]
old_doc["_id"] = new_key
self.docs[new_key] = old_doc
else:
self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
@@ -507,7 +564,7 @@ def test_memory_tool_isolation(monkeypatch) -> None:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
if "$set" in u:
self.docs[key].update(u["$set"])
@@ -607,3 +664,102 @@ def test_paths_without_leading_slash(memory_tool) -> None:
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
assert "Task 1" in view_nested
def test_rename_directory(memory_tool: MemoryTool) -> None:
"""Test renaming a directory with files."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
)
# Rename directory (with trailing slash)
result = memory_tool.execute_action(
"rename",
old_path="/docs/",
new_path="/archive/"
)
assert "renamed" in result.lower()
assert "2 files" in result.lower()
# Verify old paths don't exist
result = memory_tool.execute_action("view", path="/docs/file1.txt")
assert "not found" in result.lower()
# Verify new paths exist
result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
assert "Content 2" in result
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
"""Test renaming a directory when new path is missing trailing slash."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
)
# Rename directory - old path has slash, new path doesn't
result = memory_tool.execute_action(
"rename",
old_path="/docs/",
new_path="/archive" # Missing trailing slash
)
assert "renamed" in result.lower()
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
assert "Content 2" in result
# Verify corrupted path doesn't exist
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
assert "not found" in result.lower()
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
"""Test that view_range displays correct line numbers."""
# Create a multiline file
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
memory_tool.execute_action(
"create",
path="/numbered.txt",
file_text=content
)
# View lines 2-4
result = memory_tool.execute_action(
"view",
path="/numbered.txt",
view_range=[2, 4]
)
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
assert "2: Line 2" in result
assert "3: Line 3" in result
assert "4: Line 4" in result
assert "1: Line 1" not in result
assert "5: Line 5" not in result
# Verify no off-by-one error
assert "3: Line 2" not in result # Wrong line number
assert "4: Line 3" not in result # Wrong line number
assert "5: Line 4" not in result # Wrong line number

View File

@@ -9,26 +9,34 @@ def notes_tool(monkeypatch) -> NotesTool:
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
class FakeCollection:
def __init__(self) -> None:
self.doc = None # single note per user
self.docs = {} # key: user_id:tool_id -> doc
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
# emulate single-note storage with optional upsert
if self.doc is None and not upsert:
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if self.doc is None and upsert:
self.doc = {"user_id": q["user_id"], "note": ""}
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
if "$set" in u and "note" in u["$set"]:
self.doc["note"] = u["$set"]["note"]
self.docs[key]["note"] = u["$set"]["note"]
return type("res", (), {"modified_count": 1})
def find_one(self, q):
if self.doc and self.doc.get("user_id") == q.get("user_id"):
return self.doc
return None
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
return self.docs.get(key)
def delete_one(self, q):
if self.doc and self.doc.get("user_id") == q.get("user_id"):
self.doc = None
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
if key in self.docs:
del self.docs[key]
return type("res", (), {"deleted_count": 1})
return type("res", (), {"deleted_count": 0})
@@ -39,14 +47,14 @@ def notes_tool(monkeypatch) -> NotesTool:
# Patch MongoDB client globally for the tool
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# ToolManager will pass user_id in production; in tests we pass it directly
return NotesTool({}, user_id="test_user")
# Return tool with a fixed tool_id for consistency in tests
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
def test_view(notes_tool: NotesTool) -> None:
# Manually insert a note to test retrieval
notes_tool.collection.update_one(
{"user_id": "test_user"},
{"user_id": "test_user", "tool_id": "test_tool_id"},
{"$set": {"note": "hello"}},
upsert=True
)
@@ -131,3 +139,85 @@ def test_delete_nonexistent_note(monkeypatch):
notes_tool = NotesTool(tool_config={}, user_id="user123")
result = notes_tool.execute_action("delete")
assert "no note found" in result.lower()
def test_notes_tool_isolation(monkeypatch) -> None:
"""Test that different notes tool instances have isolated notes."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
if "$set" in u and "note" in u["$set"]:
self.docs[key]["note"] = u["$set"]["note"]
return type("res", (), {"modified_count": 1})
def find_one(self, q):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
return self.docs.get(key)
fake_collection = FakeCollection()
fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two notes tools with different tool_ids for the same user
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
# Create a note in tool1
tool1.execute_action("overwrite", text="Content from tool 1")
# Create a note in tool2
tool2.execute_action("overwrite", text="Content from tool 2")
# Verify that each tool sees only its own content
result1 = tool1.execute_action("view")
result2 = tool2.execute_action("view")
assert "Content from tool 1" in result1
assert "Content from tool 2" not in result1
assert "Content from tool 2" in result2
assert "Content from tool 1" not in result2
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
return type("res", (), {"modified_count": 1})
fake_collection = FakeCollection()
fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two tools without providing tool_id for the same user
tool1 = NotesTool({}, user_id="test_user")
tool2 = NotesTool({}, user_id="test_user")
# Both should have the same default tool_id for persistence
assert tool1.tool_id == "default_test_user"
assert tool2.tool_id == "default_test_user"
assert tool1.tool_id == tool2.tool_id
# Different users should have different tool_ids
tool3 = NotesTool({}, user_id="another_user")
assert tool3.tool_id == "default_another_user"
assert tool3.tool_id != tool1.tool_id