Compare commits

...

1 Commits

4 changed files with 298 additions and 29 deletions

View File

@@ -104,7 +104,7 @@ class MemoryTool(Tool):
"properties": {
"path": {
"type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/)."
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
},
"view_range": {
"type": "array",
@@ -233,6 +233,9 @@ class MemoryTool(Tool):
# Remove any leading/trailing whitespace
path = path.strip()
# Preserve whether path ends with / (indicates directory)
is_directory = path.endswith("/")
# Ensure path starts with / for consistency
if not path.startswith("/"):
path = "/" + path
@@ -250,6 +253,10 @@ class MemoryTool(Tool):
if not normalized.startswith("/"):
return None
# Preserve trailing slash for directories
if is_directory and not normalized.endswith("/") and normalized != "/":
normalized = normalized + "/"
return normalized
except Exception:
return None
@@ -322,8 +329,8 @@ class MemoryTool(Tool):
return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx]
# Add line numbers
numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)]
# Add line numbers (enumerate with 1-based start)
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
return "\n".join(numbered_lines)
return content
@@ -480,6 +487,10 @@ class MemoryTool(Tool):
# Check if renaming a directory
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
@@ -16,13 +17,24 @@ class NotesTool(Tool):
"""Initialize the tool.
Args:
tool_config: Optional tool configuration (unused for now).
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated notes
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
@@ -117,7 +129,7 @@ class NotesTool(Tool):
# Internal helpers (single-note)
# -----------------------------
def _get_note(self) -> str:
doc = self.collection.find_one({"user_id": self.user_id})
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
return str(doc["note"])
@@ -127,7 +139,7 @@ class NotesTool(Tool):
if not content:
return "Note content required."
self.collection.update_one(
{"user_id": self.user_id},
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
)
@@ -137,7 +149,7 @@ class NotesTool(Tool):
if not old_str:
return "old_str is required."
doc = self.collection.find_one({"user_id": self.user_id})
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
@@ -152,7 +164,7 @@ class NotesTool(Tool):
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id},
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Note updated."
@@ -161,7 +173,7 @@ class NotesTool(Tool):
if not text:
return "Text is required."
doc = self.collection.find_one({"user_id": self.user_id})
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
@@ -177,11 +189,11 @@ class NotesTool(Tool):
updated_note = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id},
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id})
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."

View File

@@ -16,10 +16,40 @@ def memory_tool(monkeypatch) -> MemoryTool:
tool_id = doc.get("tool_id")
path = doc.get("path")
key = f"{user_id}:{tool_id}:{path}"
# Add _id to document if not present
if "_id" not in doc:
doc["_id"] = key
self.docs[key] = doc
return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False):
# Handle query by _id
if "_id" in q:
doc_id = q["_id"]
if doc_id not in self.docs:
return type("res", (), {"modified_count": 0})
if "$set" in u:
old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"])
# If path changed, update the dictionary key
if "path" in u["$set"]:
new_path = u["$set"]["path"]
user_id = old_doc.get("user_id")
tool_id = old_doc.get("tool_id")
new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key
del self.docs[doc_id]
old_doc["_id"] = new_key
self.docs[new_key] = old_doc
else:
self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
@@ -29,7 +59,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
if "$set" in u:
self.docs[key].update(u["$set"])
@@ -498,6 +528,33 @@ def test_memory_tool_isolation(monkeypatch) -> None:
return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False):
# Handle query by _id
if "_id" in q:
doc_id = q["_id"]
if doc_id not in self.docs:
return type("res", (), {"modified_count": 0})
if "$set" in u:
old_doc = self.docs[doc_id].copy()
old_doc.update(u["$set"])
# If path changed, update the dictionary key
if "path" in u["$set"]:
new_path = u["$set"]["path"]
user_id = old_doc.get("user_id")
tool_id = old_doc.get("tool_id")
new_key = f"{user_id}:{tool_id}:{new_path}"
# Remove old key and add with new key
del self.docs[doc_id]
old_doc["_id"] = new_key
self.docs[new_key] = old_doc
else:
self.docs[doc_id] = old_doc
return type("res", (), {"modified_count": 1})
# Handle query by user_id, tool_id, path
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
@@ -507,7 +564,7 @@ def test_memory_tool_isolation(monkeypatch) -> None:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
if "$set" in u:
self.docs[key].update(u["$set"])
@@ -607,3 +664,102 @@ def test_paths_without_leading_slash(memory_tool) -> None:
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
assert "Task 1" in view_nested
def test_rename_directory(memory_tool: MemoryTool) -> None:
"""Test renaming a directory with files."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
)
# Rename directory (with trailing slash)
result = memory_tool.execute_action(
"rename",
old_path="/docs/",
new_path="/archive/"
)
assert "renamed" in result.lower()
assert "2 files" in result.lower()
# Verify old paths don't exist
result = memory_tool.execute_action("view", path="/docs/file1.txt")
assert "not found" in result.lower()
# Verify new paths exist
result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
assert "Content 2" in result
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
"""Test renaming a directory when new path is missing trailing slash."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/docs/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/docs/sub/file2.txt",
file_text="Content 2"
)
# Rename directory - old path has slash, new path doesn't
result = memory_tool.execute_action(
"rename",
old_path="/docs/",
new_path="/archive" # Missing trailing slash
)
assert "renamed" in result.lower()
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
result = memory_tool.execute_action("view", path="/archive/file1.txt")
assert "Content 1" in result
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
assert "Content 2" in result
# Verify corrupted path doesn't exist
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
assert "not found" in result.lower()
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
"""Test that view_range displays correct line numbers."""
# Create a multiline file
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
memory_tool.execute_action(
"create",
path="/numbered.txt",
file_text=content
)
# View lines 2-4
result = memory_tool.execute_action(
"view",
path="/numbered.txt",
view_range=[2, 4]
)
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
assert "2: Line 2" in result
assert "3: Line 3" in result
assert "4: Line 4" in result
assert "1: Line 1" not in result
assert "5: Line 5" not in result
# Verify no off-by-one error
assert "3: Line 2" not in result # Wrong line number
assert "4: Line 3" not in result # Wrong line number
assert "5: Line 4" not in result # Wrong line number

View File

@@ -9,26 +9,34 @@ def notes_tool(monkeypatch) -> NotesTool:
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
class FakeCollection:
def __init__(self) -> None:
self.doc = None # single note per user
self.docs = {} # key: user_id:tool_id -> doc
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
# emulate single-note storage with optional upsert
if self.doc is None and not upsert:
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if self.doc is None and upsert:
self.doc = {"user_id": q["user_id"], "note": ""}
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
if "$set" in u and "note" in u["$set"]:
self.doc["note"] = u["$set"]["note"]
self.docs[key]["note"] = u["$set"]["note"]
return type("res", (), {"modified_count": 1})
def find_one(self, q):
if self.doc and self.doc.get("user_id") == q.get("user_id"):
return self.doc
return None
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
return self.docs.get(key)
def delete_one(self, q):
if self.doc and self.doc.get("user_id") == q.get("user_id"):
self.doc = None
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
if key in self.docs:
del self.docs[key]
return type("res", (), {"deleted_count": 1})
return type("res", (), {"deleted_count": 0})
@@ -39,14 +47,14 @@ def notes_tool(monkeypatch) -> NotesTool:
# Patch MongoDB client globally for the tool
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# ToolManager will pass user_id in production; in tests we pass it directly
return NotesTool({}, user_id="test_user")
# Return tool with a fixed tool_id for consistency in tests
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
def test_view(notes_tool: NotesTool) -> None:
# Manually insert a note to test retrieval
notes_tool.collection.update_one(
{"user_id": "test_user"},
{"user_id": "test_user", "tool_id": "test_tool_id"},
{"$set": {"note": "hello"}},
upsert=True
)
@@ -131,3 +139,85 @@ def test_delete_nonexistent_note(monkeypatch):
notes_tool = NotesTool(tool_config={}, user_id="user123")
result = notes_tool.execute_action("delete")
assert "no note found" in result.lower()
def test_notes_tool_isolation(monkeypatch) -> None:
"""Test that different notes tool instances have isolated notes."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
if "$set" in u and "note" in u["$set"]:
self.docs[key]["note"] = u["$set"]["note"]
return type("res", (), {"modified_count": 1})
def find_one(self, q):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
key = f"{user_id}:{tool_id}"
return self.docs.get(key)
fake_collection = FakeCollection()
fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two notes tools with different tool_ids for the same user
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
# Create a note in tool1
tool1.execute_action("overwrite", text="Content from tool 1")
# Create a note in tool2
tool2.execute_action("overwrite", text="Content from tool 2")
# Verify that each tool sees only its own content
result1 = tool1.execute_action("view")
result2 = tool2.execute_action("view")
assert "Content from tool 1" in result1
assert "Content from tool 2" not in result1
assert "Content from tool 2" in result2
assert "Content from tool 1" not in result2
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
return type("res", (), {"modified_count": 1})
fake_collection = FakeCollection()
fake_db = {"notes": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two tools without providing tool_id for the same user
tool1 = NotesTool({}, user_id="test_user")
tool2 = NotesTool({}, user_id="test_user")
# Both should have the same default tool_id for persistence
assert tool1.tool_id == "default_test_user"
assert tool2.tool_id == "default_test_user"
assert tool1.tool_id == tool2.tool_id
# Different users should have different tool_ids
tool3 = NotesTool({}, user_id="another_user")
assert tool3.tool_id == "default_another_user"
assert tool3.tool_id != tool1.tool_id