feat: enhance MemoryTool and NotesTool with tool_id management and directory renaming tests (#2026)

This commit is contained in:
Alex
2025-10-06 21:45:47 +01:00
committed by GitHub
parent e012189672
commit 9d452e3b04
4 changed files with 298 additions and 29 deletions

View File

@@ -104,7 +104,7 @@ class MemoryTool(Tool):
"properties": {
"path": {
"type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/)."
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
},
"view_range": {
"type": "array",
@@ -233,6 +233,9 @@ class MemoryTool(Tool):
# Remove any leading/trailing whitespace
path = path.strip()
# Preserve whether path ends with / (indicates directory)
is_directory = path.endswith("/")
# Ensure path starts with / for consistency
if not path.startswith("/"):
path = "/" + path
@@ -250,6 +253,10 @@ class MemoryTool(Tool):
if not normalized.startswith("/"):
return None
# Preserve trailing slash for directories
if is_directory and not normalized.endswith("/") and normalized != "/":
normalized = normalized + "/"
return normalized
except Exception:
return None
@@ -322,8 +329,8 @@ class MemoryTool(Tool):
return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx]
# Add line numbers
numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)]
# Add line numbers (enumerate with 1-based start)
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
return "\n".join(numbered_lines)
return content
@@ -480,6 +487,10 @@ class MemoryTool(Tool):
# Check if renaming a directory
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
@@ -16,13 +17,24 @@ class NotesTool(Tool):
"""Initialize the tool.
Args:
tool_config: Optional tool configuration (unused for now).
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated notes
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
@@ -117,7 +129,7 @@ class NotesTool(Tool):
# Internal helpers (single-note)
# -----------------------------
def _get_note(self) -> str:
doc = self.collection.find_one({"user_id": self.user_id})
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
return str(doc["note"])
@@ -127,7 +139,7 @@ class NotesTool(Tool):
if not content:
return "Note content required."
self.collection.update_one(
{"user_id": self.user_id},
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
)
@@ -137,7 +149,7 @@ class NotesTool(Tool):
if not old_str:
return "old_str is required."
doc = self.collection.find_one({"user_id": self.user_id})
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
@@ -152,7 +164,7 @@ class NotesTool(Tool):
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id},
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Note updated."
@@ -161,7 +173,7 @@ class NotesTool(Tool):
if not text:
return "Text is required."
doc = self.collection.find_one({"user_id": self.user_id})
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
@@ -177,11 +189,11 @@ class NotesTool(Tool):
updated_note = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id},
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id})
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."