mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: enhance MemoryTool and NotesTool with tool_id management and directory renaming tests (#2026)
This commit is contained in:
@@ -104,7 +104,7 @@ class MemoryTool(Tool):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"path": {
|
"path": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Path to file or directory (e.g., /notes.txt or /project/)."
|
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||||
},
|
},
|
||||||
"view_range": {
|
"view_range": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
@@ -233,6 +233,9 @@ class MemoryTool(Tool):
|
|||||||
# Remove any leading/trailing whitespace
|
# Remove any leading/trailing whitespace
|
||||||
path = path.strip()
|
path = path.strip()
|
||||||
|
|
||||||
|
# Preserve whether path ends with / (indicates directory)
|
||||||
|
is_directory = path.endswith("/")
|
||||||
|
|
||||||
# Ensure path starts with / for consistency
|
# Ensure path starts with / for consistency
|
||||||
if not path.startswith("/"):
|
if not path.startswith("/"):
|
||||||
path = "/" + path
|
path = "/" + path
|
||||||
@@ -250,6 +253,10 @@ class MemoryTool(Tool):
|
|||||||
if not normalized.startswith("/"):
|
if not normalized.startswith("/"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Preserve trailing slash for directories
|
||||||
|
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||||
|
normalized = normalized + "/"
|
||||||
|
|
||||||
return normalized
|
return normalized
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -322,8 +329,8 @@ class MemoryTool(Tool):
|
|||||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||||
|
|
||||||
selected_lines = lines[start_idx:end_idx]
|
selected_lines = lines[start_idx:end_idx]
|
||||||
# Add line numbers
|
# Add line numbers (enumerate with 1-based start)
|
||||||
numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)]
|
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||||
return "\n".join(numbered_lines)
|
return "\n".join(numbered_lines)
|
||||||
|
|
||||||
return content
|
return content
|
||||||
@@ -480,6 +487,10 @@ class MemoryTool(Tool):
|
|||||||
|
|
||||||
# Check if renaming a directory
|
# Check if renaming a directory
|
||||||
if validated_old.endswith("/"):
|
if validated_old.endswith("/"):
|
||||||
|
# Ensure validated_new also ends with / for proper path replacement
|
||||||
|
if not validated_new.endswith("/"):
|
||||||
|
validated_new = validated_new + "/"
|
||||||
|
|
||||||
# Find all files in the old directory
|
# Find all files in the old directory
|
||||||
docs = list(self.collection.find({
|
docs = list(self.collection.find({
|
||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
from .base import Tool
|
from .base import Tool
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
@@ -16,13 +17,24 @@ class NotesTool(Tool):
|
|||||||
"""Initialize the tool.
|
"""Initialize the tool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_config: Optional tool configuration (unused for now).
|
tool_config: Optional tool configuration. Should include:
|
||||||
|
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||||
|
This ensures each user's tool configuration has isolated notes
|
||||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
self.user_id: Optional[str] = user_id
|
self.user_id: Optional[str] = user_id
|
||||||
|
|
||||||
|
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||||
|
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||||
|
if tool_config and "tool_id" in tool_config:
|
||||||
|
self.tool_id = tool_config["tool_id"]
|
||||||
|
elif user_id:
|
||||||
|
# Fallback for backward compatibility or testing
|
||||||
|
self.tool_id = f"default_{user_id}"
|
||||||
|
else:
|
||||||
|
# Last resort fallback (shouldn't happen in normal use)
|
||||||
|
self.tool_id = str(uuid.uuid4())
|
||||||
|
|
||||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||||
self.collection = db["notes"]
|
self.collection = db["notes"]
|
||||||
|
|
||||||
@@ -117,7 +129,7 @@ class NotesTool(Tool):
|
|||||||
# Internal helpers (single-note)
|
# Internal helpers (single-note)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
def _get_note(self) -> str:
|
def _get_note(self) -> str:
|
||||||
doc = self.collection.find_one({"user_id": self.user_id})
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
if not doc or not doc.get("note"):
|
if not doc or not doc.get("note"):
|
||||||
return "No note found."
|
return "No note found."
|
||||||
return str(doc["note"])
|
return str(doc["note"])
|
||||||
@@ -127,7 +139,7 @@ class NotesTool(Tool):
|
|||||||
if not content:
|
if not content:
|
||||||
return "Note content required."
|
return "Note content required."
|
||||||
self.collection.update_one(
|
self.collection.update_one(
|
||||||
{"user_id": self.user_id},
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||||
upsert=True, # ✅ create if missing
|
upsert=True, # ✅ create if missing
|
||||||
)
|
)
|
||||||
@@ -137,7 +149,7 @@ class NotesTool(Tool):
|
|||||||
if not old_str:
|
if not old_str:
|
||||||
return "old_str is required."
|
return "old_str is required."
|
||||||
|
|
||||||
doc = self.collection.find_one({"user_id": self.user_id})
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
if not doc or not doc.get("note"):
|
if not doc or not doc.get("note"):
|
||||||
return "No note found."
|
return "No note found."
|
||||||
|
|
||||||
@@ -152,7 +164,7 @@ class NotesTool(Tool):
|
|||||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||||
|
|
||||||
self.collection.update_one(
|
self.collection.update_one(
|
||||||
{"user_id": self.user_id},
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||||
)
|
)
|
||||||
return "Note updated."
|
return "Note updated."
|
||||||
@@ -161,7 +173,7 @@ class NotesTool(Tool):
|
|||||||
if not text:
|
if not text:
|
||||||
return "Text is required."
|
return "Text is required."
|
||||||
|
|
||||||
doc = self.collection.find_one({"user_id": self.user_id})
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
if not doc or not doc.get("note"):
|
if not doc or not doc.get("note"):
|
||||||
return "No note found."
|
return "No note found."
|
||||||
|
|
||||||
@@ -177,11 +189,11 @@ class NotesTool(Tool):
|
|||||||
updated_note = "\n".join(lines)
|
updated_note = "\n".join(lines)
|
||||||
|
|
||||||
self.collection.update_one(
|
self.collection.update_one(
|
||||||
{"user_id": self.user_id},
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||||
)
|
)
|
||||||
return "Text inserted."
|
return "Text inserted."
|
||||||
|
|
||||||
def _delete_note(self) -> str:
|
def _delete_note(self) -> str:
|
||||||
res = self.collection.delete_one({"user_id": self.user_id})
|
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
return "Note deleted." if res.deleted_count else "No note found to delete."
|
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||||
|
|||||||
@@ -16,10 +16,40 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
|||||||
tool_id = doc.get("tool_id")
|
tool_id = doc.get("tool_id")
|
||||||
path = doc.get("path")
|
path = doc.get("path")
|
||||||
key = f"{user_id}:{tool_id}:{path}"
|
key = f"{user_id}:{tool_id}:{path}"
|
||||||
|
# Add _id to document if not present
|
||||||
|
if "_id" not in doc:
|
||||||
|
doc["_id"] = key
|
||||||
self.docs[key] = doc
|
self.docs[key] = doc
|
||||||
return type("res", (), {"inserted_id": key})
|
return type("res", (), {"inserted_id": key})
|
||||||
|
|
||||||
def update_one(self, q, u, upsert=False):
|
def update_one(self, q, u, upsert=False):
|
||||||
|
# Handle query by _id
|
||||||
|
if "_id" in q:
|
||||||
|
doc_id = q["_id"]
|
||||||
|
if doc_id not in self.docs:
|
||||||
|
return type("res", (), {"modified_count": 0})
|
||||||
|
|
||||||
|
if "$set" in u:
|
||||||
|
old_doc = self.docs[doc_id].copy()
|
||||||
|
old_doc.update(u["$set"])
|
||||||
|
|
||||||
|
# If path changed, update the dictionary key
|
||||||
|
if "path" in u["$set"]:
|
||||||
|
new_path = u["$set"]["path"]
|
||||||
|
user_id = old_doc.get("user_id")
|
||||||
|
tool_id = old_doc.get("tool_id")
|
||||||
|
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||||
|
|
||||||
|
# Remove old key and add with new key
|
||||||
|
del self.docs[doc_id]
|
||||||
|
old_doc["_id"] = new_key
|
||||||
|
self.docs[new_key] = old_doc
|
||||||
|
else:
|
||||||
|
self.docs[doc_id] = old_doc
|
||||||
|
|
||||||
|
return type("res", (), {"modified_count": 1})
|
||||||
|
|
||||||
|
# Handle query by user_id, tool_id, path
|
||||||
user_id = q.get("user_id")
|
user_id = q.get("user_id")
|
||||||
tool_id = q.get("tool_id")
|
tool_id = q.get("tool_id")
|
||||||
path = q.get("path")
|
path = q.get("path")
|
||||||
@@ -29,7 +59,7 @@ def memory_tool(monkeypatch) -> MemoryTool:
|
|||||||
return type("res", (), {"modified_count": 0})
|
return type("res", (), {"modified_count": 0})
|
||||||
|
|
||||||
if key not in self.docs and upsert:
|
if key not in self.docs and upsert:
|
||||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
|
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||||
|
|
||||||
if "$set" in u:
|
if "$set" in u:
|
||||||
self.docs[key].update(u["$set"])
|
self.docs[key].update(u["$set"])
|
||||||
@@ -498,6 +528,33 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
|||||||
return type("res", (), {"inserted_id": key})
|
return type("res", (), {"inserted_id": key})
|
||||||
|
|
||||||
def update_one(self, q, u, upsert=False):
|
def update_one(self, q, u, upsert=False):
|
||||||
|
# Handle query by _id
|
||||||
|
if "_id" in q:
|
||||||
|
doc_id = q["_id"]
|
||||||
|
if doc_id not in self.docs:
|
||||||
|
return type("res", (), {"modified_count": 0})
|
||||||
|
|
||||||
|
if "$set" in u:
|
||||||
|
old_doc = self.docs[doc_id].copy()
|
||||||
|
old_doc.update(u["$set"])
|
||||||
|
|
||||||
|
# If path changed, update the dictionary key
|
||||||
|
if "path" in u["$set"]:
|
||||||
|
new_path = u["$set"]["path"]
|
||||||
|
user_id = old_doc.get("user_id")
|
||||||
|
tool_id = old_doc.get("tool_id")
|
||||||
|
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||||
|
|
||||||
|
# Remove old key and add with new key
|
||||||
|
del self.docs[doc_id]
|
||||||
|
old_doc["_id"] = new_key
|
||||||
|
self.docs[new_key] = old_doc
|
||||||
|
else:
|
||||||
|
self.docs[doc_id] = old_doc
|
||||||
|
|
||||||
|
return type("res", (), {"modified_count": 1})
|
||||||
|
|
||||||
|
# Handle query by user_id, tool_id, path
|
||||||
user_id = q.get("user_id")
|
user_id = q.get("user_id")
|
||||||
tool_id = q.get("tool_id")
|
tool_id = q.get("tool_id")
|
||||||
path = q.get("path")
|
path = q.get("path")
|
||||||
@@ -507,7 +564,7 @@ def test_memory_tool_isolation(monkeypatch) -> None:
|
|||||||
return type("res", (), {"modified_count": 0})
|
return type("res", (), {"modified_count": 0})
|
||||||
|
|
||||||
if key not in self.docs and upsert:
|
if key not in self.docs and upsert:
|
||||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
|
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||||
|
|
||||||
if "$set" in u:
|
if "$set" in u:
|
||||||
self.docs[key].update(u["$set"])
|
self.docs[key].update(u["$set"])
|
||||||
@@ -607,3 +664,102 @@ def test_paths_without_leading_slash(memory_tool) -> None:
|
|||||||
|
|
||||||
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
|
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
|
||||||
assert "Task 1" in view_nested
|
assert "Task 1" in view_nested
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_directory(memory_tool: MemoryTool) -> None:
|
||||||
|
"""Test renaming a directory with files."""
|
||||||
|
# Create files in a directory
|
||||||
|
memory_tool.execute_action(
|
||||||
|
"create",
|
||||||
|
path="/docs/file1.txt",
|
||||||
|
file_text="Content 1"
|
||||||
|
)
|
||||||
|
memory_tool.execute_action(
|
||||||
|
"create",
|
||||||
|
path="/docs/sub/file2.txt",
|
||||||
|
file_text="Content 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rename directory (with trailing slash)
|
||||||
|
result = memory_tool.execute_action(
|
||||||
|
"rename",
|
||||||
|
old_path="/docs/",
|
||||||
|
new_path="/archive/"
|
||||||
|
)
|
||||||
|
assert "renamed" in result.lower()
|
||||||
|
assert "2 files" in result.lower()
|
||||||
|
|
||||||
|
# Verify old paths don't exist
|
||||||
|
result = memory_tool.execute_action("view", path="/docs/file1.txt")
|
||||||
|
assert "not found" in result.lower()
|
||||||
|
|
||||||
|
# Verify new paths exist
|
||||||
|
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||||
|
assert "Content 1" in result
|
||||||
|
|
||||||
|
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
|
||||||
|
assert "Content 2" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
|
||||||
|
"""Test renaming a directory when new path is missing trailing slash."""
|
||||||
|
# Create files in a directory
|
||||||
|
memory_tool.execute_action(
|
||||||
|
"create",
|
||||||
|
path="/docs/file1.txt",
|
||||||
|
file_text="Content 1"
|
||||||
|
)
|
||||||
|
memory_tool.execute_action(
|
||||||
|
"create",
|
||||||
|
path="/docs/sub/file2.txt",
|
||||||
|
file_text="Content 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rename directory - old path has slash, new path doesn't
|
||||||
|
result = memory_tool.execute_action(
|
||||||
|
"rename",
|
||||||
|
old_path="/docs/",
|
||||||
|
new_path="/archive" # Missing trailing slash
|
||||||
|
)
|
||||||
|
assert "renamed" in result.lower()
|
||||||
|
|
||||||
|
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
|
||||||
|
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||||
|
assert "Content 1" in result
|
||||||
|
|
||||||
|
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
|
||||||
|
assert "Content 2" in result
|
||||||
|
|
||||||
|
# Verify corrupted path doesn't exist
|
||||||
|
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
|
||||||
|
assert "not found" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
|
||||||
|
"""Test that view_range displays correct line numbers."""
|
||||||
|
# Create a multiline file
|
||||||
|
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||||
|
memory_tool.execute_action(
|
||||||
|
"create",
|
||||||
|
path="/numbered.txt",
|
||||||
|
file_text=content
|
||||||
|
)
|
||||||
|
|
||||||
|
# View lines 2-4
|
||||||
|
result = memory_tool.execute_action(
|
||||||
|
"view",
|
||||||
|
path="/numbered.txt",
|
||||||
|
view_range=[2, 4]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
|
||||||
|
assert "2: Line 2" in result
|
||||||
|
assert "3: Line 3" in result
|
||||||
|
assert "4: Line 4" in result
|
||||||
|
assert "1: Line 1" not in result
|
||||||
|
assert "5: Line 5" not in result
|
||||||
|
|
||||||
|
# Verify no off-by-one error
|
||||||
|
assert "3: Line 2" not in result # Wrong line number
|
||||||
|
assert "4: Line 3" not in result # Wrong line number
|
||||||
|
assert "5: Line 4" not in result # Wrong line number
|
||||||
|
|||||||
@@ -9,26 +9,34 @@ def notes_tool(monkeypatch) -> NotesTool:
|
|||||||
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
|
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
|
||||||
class FakeCollection:
|
class FakeCollection:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.doc = None # single note per user
|
self.docs = {} # key: user_id:tool_id -> doc
|
||||||
|
|
||||||
def update_one(self, q, u, upsert=False):
|
def update_one(self, q, u, upsert=False):
|
||||||
|
user_id = q.get("user_id")
|
||||||
|
tool_id = q.get("tool_id")
|
||||||
|
key = f"{user_id}:{tool_id}"
|
||||||
|
|
||||||
# emulate single-note storage with optional upsert
|
# emulate single-note storage with optional upsert
|
||||||
if self.doc is None and not upsert:
|
if key not in self.docs and not upsert:
|
||||||
return type("res", (), {"modified_count": 0})
|
return type("res", (), {"modified_count": 0})
|
||||||
if self.doc is None and upsert:
|
if key not in self.docs and upsert:
|
||||||
self.doc = {"user_id": q["user_id"], "note": ""}
|
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
|
||||||
if "$set" in u and "note" in u["$set"]:
|
if "$set" in u and "note" in u["$set"]:
|
||||||
self.doc["note"] = u["$set"]["note"]
|
self.docs[key]["note"] = u["$set"]["note"]
|
||||||
return type("res", (), {"modified_count": 1})
|
return type("res", (), {"modified_count": 1})
|
||||||
|
|
||||||
def find_one(self, q):
|
def find_one(self, q):
|
||||||
if self.doc and self.doc.get("user_id") == q.get("user_id"):
|
user_id = q.get("user_id")
|
||||||
return self.doc
|
tool_id = q.get("tool_id")
|
||||||
return None
|
key = f"{user_id}:{tool_id}"
|
||||||
|
return self.docs.get(key)
|
||||||
|
|
||||||
def delete_one(self, q):
|
def delete_one(self, q):
|
||||||
if self.doc and self.doc.get("user_id") == q.get("user_id"):
|
user_id = q.get("user_id")
|
||||||
self.doc = None
|
tool_id = q.get("tool_id")
|
||||||
|
key = f"{user_id}:{tool_id}"
|
||||||
|
if key in self.docs:
|
||||||
|
del self.docs[key]
|
||||||
return type("res", (), {"deleted_count": 1})
|
return type("res", (), {"deleted_count": 1})
|
||||||
return type("res", (), {"deleted_count": 0})
|
return type("res", (), {"deleted_count": 0})
|
||||||
|
|
||||||
@@ -39,14 +47,14 @@ def notes_tool(monkeypatch) -> NotesTool:
|
|||||||
# Patch MongoDB client globally for the tool
|
# Patch MongoDB client globally for the tool
|
||||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||||
|
|
||||||
# ToolManager will pass user_id in production; in tests we pass it directly
|
# Return tool with a fixed tool_id for consistency in tests
|
||||||
return NotesTool({}, user_id="test_user")
|
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||||
|
|
||||||
|
|
||||||
def test_view(notes_tool: NotesTool) -> None:
|
def test_view(notes_tool: NotesTool) -> None:
|
||||||
# Manually insert a note to test retrieval
|
# Manually insert a note to test retrieval
|
||||||
notes_tool.collection.update_one(
|
notes_tool.collection.update_one(
|
||||||
{"user_id": "test_user"},
|
{"user_id": "test_user", "tool_id": "test_tool_id"},
|
||||||
{"$set": {"note": "hello"}},
|
{"$set": {"note": "hello"}},
|
||||||
upsert=True
|
upsert=True
|
||||||
)
|
)
|
||||||
@@ -131,3 +139,85 @@ def test_delete_nonexistent_note(monkeypatch):
|
|||||||
notes_tool = NotesTool(tool_config={}, user_id="user123")
|
notes_tool = NotesTool(tool_config={}, user_id="user123")
|
||||||
result = notes_tool.execute_action("delete")
|
result = notes_tool.execute_action("delete")
|
||||||
assert "no note found" in result.lower()
|
assert "no note found" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_notes_tool_isolation(monkeypatch) -> None:
|
||||||
|
"""Test that different notes tool instances have isolated notes."""
|
||||||
|
class FakeCollection:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.docs = {}
|
||||||
|
|
||||||
|
def update_one(self, q, u, upsert=False):
|
||||||
|
user_id = q.get("user_id")
|
||||||
|
tool_id = q.get("tool_id")
|
||||||
|
key = f"{user_id}:{tool_id}"
|
||||||
|
|
||||||
|
if key not in self.docs and not upsert:
|
||||||
|
return type("res", (), {"modified_count": 0})
|
||||||
|
if key not in self.docs and upsert:
|
||||||
|
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
|
||||||
|
if "$set" in u and "note" in u["$set"]:
|
||||||
|
self.docs[key]["note"] = u["$set"]["note"]
|
||||||
|
return type("res", (), {"modified_count": 1})
|
||||||
|
|
||||||
|
def find_one(self, q):
|
||||||
|
user_id = q.get("user_id")
|
||||||
|
tool_id = q.get("tool_id")
|
||||||
|
key = f"{user_id}:{tool_id}"
|
||||||
|
return self.docs.get(key)
|
||||||
|
|
||||||
|
fake_collection = FakeCollection()
|
||||||
|
fake_db = {"notes": fake_collection}
|
||||||
|
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||||
|
|
||||||
|
# Create two notes tools with different tool_ids for the same user
|
||||||
|
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||||
|
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||||
|
|
||||||
|
# Create a note in tool1
|
||||||
|
tool1.execute_action("overwrite", text="Content from tool 1")
|
||||||
|
|
||||||
|
# Create a note in tool2
|
||||||
|
tool2.execute_action("overwrite", text="Content from tool 2")
|
||||||
|
|
||||||
|
# Verify that each tool sees only its own content
|
||||||
|
result1 = tool1.execute_action("view")
|
||||||
|
result2 = tool2.execute_action("view")
|
||||||
|
|
||||||
|
assert "Content from tool 1" in result1
|
||||||
|
assert "Content from tool 2" not in result1
|
||||||
|
|
||||||
|
assert "Content from tool 2" in result2
|
||||||
|
assert "Content from tool 1" not in result2
|
||||||
|
|
||||||
|
|
||||||
|
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||||
|
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||||
|
class FakeCollection:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.docs = {}
|
||||||
|
|
||||||
|
def update_one(self, q, u, upsert=False):
|
||||||
|
return type("res", (), {"modified_count": 1})
|
||||||
|
|
||||||
|
fake_collection = FakeCollection()
|
||||||
|
fake_db = {"notes": fake_collection}
|
||||||
|
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||||
|
|
||||||
|
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||||
|
|
||||||
|
# Create two tools without providing tool_id for the same user
|
||||||
|
tool1 = NotesTool({}, user_id="test_user")
|
||||||
|
tool2 = NotesTool({}, user_id="test_user")
|
||||||
|
|
||||||
|
# Both should have the same default tool_id for persistence
|
||||||
|
assert tool1.tool_id == "default_test_user"
|
||||||
|
assert tool2.tool_id == "default_test_user"
|
||||||
|
assert tool1.tool_id == tool2.tool_id
|
||||||
|
|
||||||
|
# Different users should have different tool_ids
|
||||||
|
tool3 = NotesTool({}, user_id="another_user")
|
||||||
|
assert tool3.tool_id == "default_another_user"
|
||||||
|
assert tool3.tool_id != tool1.tool_id
|
||||||
|
|||||||
Reference in New Issue
Block a user