feat: implement MemoryTool with CRUD operations and integrate with MongoDB

This commit is contained in:
Alex
2025-10-06 21:09:22 +01:00
parent 4c31e9a8b1
commit e012189672
5 changed files with 1165 additions and 12 deletions

View File

@@ -213,18 +213,24 @@ class BaseAgent(ABC):
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
tool_config = {
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
tool = tm.load_tool(
tool_data["name"],
tool_config=(
{
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
tool_config=tool_config,
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":

View File

@@ -0,0 +1,535 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import re
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class MemoryTool(Tool):
"""Memory
Stores and retrieves information across conversations through a memory file directory.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated memories
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["memories"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, create, str_replace, insert, delete, rename.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: MemoryTool requires a valid user_id."
if action_name == "view":
return self._view(
kwargs.get("path", "/"),
kwargs.get("view_range")
)
if action_name == "create":
return self._create(
kwargs.get("path", ""),
kwargs.get("file_text", "")
)
if action_name == "str_replace":
return self._str_replace(
kwargs.get("path", ""),
kwargs.get("old_str", ""),
kwargs.get("new_str", "")
)
if action_name == "insert":
return self._insert(
kwargs.get("path", ""),
kwargs.get("insert_line", 1),
kwargs.get("insert_text", "")
)
if action_name == "delete":
return self._delete(kwargs.get("path", ""))
if action_name == "rename":
return self._rename(
kwargs.get("old_path", ""),
kwargs.get("new_path", "")
)
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Shows directory contents or file contents with optional line ranges.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/)."
},
"view_range": {
"type": "array",
"items": {"type": "integer"},
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
}
},
"required": ["path"]
},
},
{
"name": "create",
"description": "Create or overwrite a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
},
"file_text": {
"type": "string",
"description": "Content to write to the file."
}
},
"required": ["path", "file_text"]
},
},
{
"name": "str_replace",
"description": "Replace text in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"old_str": {
"type": "string",
"description": "String to find."
},
"new_str": {
"type": "string",
"description": "String to replace with."
}
},
"required": ["path", "old_str", "new_str"]
},
},
{
"name": "insert",
"description": "Insert text at a specific line in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"insert_line": {
"type": "integer",
"description": "Line number to insert at (1-indexed)."
},
"insert_text": {
"type": "string",
"description": "Text to insert."
}
},
"required": ["path", "insert_line", "insert_text"]
},
},
{
"name": "delete",
"description": "Delete a file or directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to delete (e.g., /notes.txt or /project/)."
}
},
"required": ["path"]
},
},
{
"name": "rename",
"description": "Rename or move a file/directory.",
"parameters": {
"type": "object",
"properties": {
"old_path": {
"type": "string",
"description": "Current path (e.g., /old.txt)."
},
"new_path": {
"type": "string",
"description": "New path (e.g., /new.txt)."
}
},
"required": ["old_path", "new_path"]
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
# -----------------------------
# Path validation
# -----------------------------
def _validate_path(self, path: str) -> Optional[str]:
"""Validate and normalize path.
Args:
path: User-provided path.
Returns:
Normalized path or None if invalid.
"""
if not path:
return None
# Remove any leading/trailing whitespace
path = path.strip()
# Ensure path starts with / for consistency
if not path.startswith("/"):
path = "/" + path
# Check for directory traversal patterns
if ".." in path or path.count("//") > 0:
return None
# Normalize the path
try:
# Convert to Path object and resolve to canonical form
normalized = str(Path(path).as_posix())
# Ensure it still starts with /
if not normalized.startswith("/"):
return None
return normalized
except Exception:
return None
# -----------------------------
# Internal helpers
# -----------------------------
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View directory contents or file contents."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
# Check if viewing directory (ends with / or is root)
if validated_path == "/" or validated_path.endswith("/"):
return self._view_directory(validated_path)
# Otherwise view file
return self._view_file(validated_path, view_range)
def _view_directory(self, path: str) -> str:
"""List files in a directory."""
# Ensure path ends with / for proper prefix matching
search_path = path if path.endswith("/") else path + "/"
# Find all files that start with this directory path
query = {
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
}
docs = list(self.collection.find(query, {"path": 1}))
if not docs:
return f"Directory: {path}\n(empty)"
# Extract filenames relative to the directory
files = []
for doc in docs:
file_path = doc["path"]
# Remove the directory prefix
if file_path.startswith(search_path):
relative = file_path[len(search_path):]
if relative:
files.append(relative)
files.sort()
file_list = "\n".join(f"- {f}" for f in files)
return f"Directory: {path}\n{file_list}"
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View file contents with optional line range."""
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
if not doc or not doc.get("content"):
return f"Error: File not found: {path}"
content = str(doc["content"])
# Apply view_range if specified
if view_range and len(view_range) == 2:
lines = content.split("\n")
start, end = view_range
# Convert to 0-indexed
start_idx = max(0, start - 1)
end_idx = min(len(lines), end)
if start_idx >= len(lines):
return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx]
# Add line numbers
numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)]
return "\n".join(numbered_lines)
return content
def _create(self, path: str, file_text: str) -> str:
"""Create or overwrite a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/" or validated_path.endswith("/"):
return "Error: Cannot create a file at directory path."
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": file_text,
"updated_at": datetime.now()
}
},
upsert=True
)
return f"File created: {validated_path}"
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
"""Replace text in a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not old_str:
return "Error: old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Replace the string (case-insensitive)
import re as regex_module
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"File updated: {validated_path}"
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
"""Insert text at a specific line."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not insert_text:
return "Error: insert_text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
lines = current_content.split("\n")
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"Text inserted at line {insert_line} in {validated_path}"
def _delete(self, path: str) -> str:
"""Delete a file or directory."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/":
# Delete all files for this user and tool
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
return f"Deleted {result.deleted_count} file(s) from memory."
# Check if it's a directory (ends with /)
if validated_path.endswith("/"):
# Delete all files in directory
result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_path)}"}
})
return f"Deleted directory and {result.deleted_count} file(s)."
# Try to delete as directory first (without trailing slash)
# Check if any files start with this path + /
search_path = validated_path + "/"
directory_result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
})
if directory_result.deleted_count > 0:
return f"Deleted directory and {directory_result.deleted_count} file(s)."
# Delete single file
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_path
})
if result.deleted_count:
return f"Deleted: {validated_path}"
return f"Error: File not found: {validated_path}"
def _rename(self, old_path: str, new_path: str) -> str:
"""Rename or move a file/directory."""
validated_old = self._validate_path(old_path)
validated_new = self._validate_path(new_path)
if not validated_old or not validated_new:
return "Error: Invalid path."
if validated_old == "/" or validated_new == "/":
return "Error: Cannot rename root directory."
# Check if renaming a directory
if validated_old.endswith("/"):
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_old)}"}
}))
if not docs:
return f"Error: Directory not found: {validated_old}"
# Update paths for all files
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
self.collection.update_one(
{"_id": doc["_id"]},
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
)
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
# Rename single file
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_old
})
if not doc:
return f"Error: File not found: {validated_old}"
# Check if new path already exists
existing = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new
})
if existing:
return f"Error: File already exists at {validated_new}"
# Delete the old document and create a new one with the new path
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
self.collection.insert_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new,
"content": doc.get("content", ""),
"updated_at": datetime.now()
})
return f"Renamed: {validated_old} -> {validated_new}"

View File

@@ -28,7 +28,7 @@ class ToolManager:
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if tool_name in {"mcp_tool","notes"} and user_id:
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name == "mcp_tool" and user_id:
if tool_name in {"mcp_tool", "memory"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3">
<path d="M240-80q-33 0-56.5-23.5T160-160v-480q0-33 23.5-56.5T240-720h80v-80q0-17 11.5-28.5T360-840q17 0 28.5 11.5T400-800v80h40v-80q0-17 11.5-28.5T480-840q17 0 28.5 11.5T520-800v80h40v-80q0-17 11.5-28.5T600-840q17 0 28.5 11.5T640-800v80h80q33 0 56.5 23.5T800-640v480q0 33-23.5 56.5T720-80H240Zm0-80h480v-480H240v480Zm120-320v-80h240v80H360Zm0 120v-80h240v80H360Zm0 120v-80h160v80H360ZM240-160v-480 480Z"/>
</svg>

After

Width:  |  Height:  |  Size: 523 B

609
tests/test_memory_tool.py Normal file
View File

@@ -0,0 +1,609 @@
import pytest
from application.agents.tools.memory import MemoryTool
from application.core.settings import settings
@pytest.fixture
def memory_tool(monkeypatch) -> MemoryTool:
"""Provide a MemoryTool with a fake Mongo collection and fixed user_id."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {} # path -> document
def insert_one(self, doc):
user_id = doc.get("user_id")
tool_id = doc.get("tool_id")
path = doc.get("path")
key = f"{user_id}:{tool_id}:{path}"
self.docs[key] = doc
return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
key = f"{user_id}:{tool_id}:{path}"
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
if "$set" in u:
self.docs[key].update(u["$set"])
return type("res", (), {"modified_count": 1})
def find_one(self, q, projection=None):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
if path:
key = f"{user_id}:{tool_id}:{path}"
return self.docs.get(key)
return None
def find(self, q, projection=None):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
results = []
# Handle regex queries for directory listing
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
regex_pattern = q["path"]["$regex"]
# Remove regex escape characters and ^ anchor for simple matching
pattern = regex_pattern.replace("\\", "").lstrip("^")
for key, doc in self.docs.items():
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
doc_path = doc.get("path", "")
if doc_path.startswith(pattern):
results.append(doc)
else:
for key, doc in self.docs.items():
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
results.append(doc)
return results
def delete_one(self, q):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
key = f"{user_id}:{tool_id}:{path}"
if key in self.docs:
del self.docs[key]
return type("res", (), {"deleted_count": 1})
return type("res", (), {"deleted_count": 0})
def delete_many(self, q):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
deleted = 0
# Handle regex queries for directory deletion
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
regex_pattern = q["path"]["$regex"]
pattern = regex_pattern.replace("\\", "").lstrip("^")
keys_to_delete = []
for key, doc in self.docs.items():
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
doc_path = doc.get("path", "")
if doc_path.startswith(pattern):
keys_to_delete.append(key)
for key in keys_to_delete:
del self.docs[key]
deleted += 1
else:
# Delete all for user and tool
keys_to_delete = [
key for key, doc in self.docs.items()
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id
]
for key in keys_to_delete:
del self.docs[key]
deleted += 1
return type("res", (), {"deleted_count": deleted})
fake_collection = FakeCollection()
fake_db = {"memories": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Return tool with a fixed tool_id for consistency in tests
return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user")
def test_init_without_user_id():
"""Should fail gracefully if no user_id is provided."""
memory_tool = MemoryTool(tool_config={})
result = memory_tool.execute_action("view", path="/")
assert "user_id" in result.lower()
def test_view_empty_directory(memory_tool: MemoryTool) -> None:
"""Should show empty directory when no files exist."""
result = memory_tool.execute_action("view", path="/")
assert "empty" in result.lower()
def test_create_and_view_file(memory_tool: MemoryTool) -> None:
"""Test creating a file and viewing it."""
# Create a file
result = memory_tool.execute_action(
"create",
path="/notes.txt",
file_text="Hello world"
)
assert "created" in result.lower()
# View the file
result = memory_tool.execute_action("view", path="/notes.txt")
assert "Hello world" in result
def test_create_overwrite_file(memory_tool: MemoryTool) -> None:
"""Test that create overwrites existing files."""
# Create initial file
memory_tool.execute_action(
"create",
path="/test.txt",
file_text="Original content"
)
# Overwrite
memory_tool.execute_action(
"create",
path="/test.txt",
file_text="New content"
)
# Verify overwrite
result = memory_tool.execute_action("view", path="/test.txt")
assert "New content" in result
assert "Original content" not in result
def test_view_directory_with_files(memory_tool: MemoryTool) -> None:
"""Test viewing directory contents."""
# Create multiple files
memory_tool.execute_action(
"create",
path="/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/file2.txt",
file_text="Content 2"
)
memory_tool.execute_action(
"create",
path="/subdir/file3.txt",
file_text="Content 3"
)
# View directory
result = memory_tool.execute_action("view", path="/")
assert "file1.txt" in result
assert "file2.txt" in result
assert "subdir/file3.txt" in result
def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
"""Test viewing specific lines from a file."""
# Create a multiline file
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
memory_tool.execute_action(
"create",
path="/multiline.txt",
file_text=content
)
# View lines 2-4
result = memory_tool.execute_action(
"view",
path="/multiline.txt",
view_range=[2, 4]
)
assert "Line 2" in result
assert "Line 3" in result
assert "Line 4" in result
assert "Line 1" not in result
assert "Line 5" not in result
def test_str_replace(memory_tool: MemoryTool) -> None:
"""Test string replacement in a file."""
# Create a file
memory_tool.execute_action(
"create",
path="/replace.txt",
file_text="Hello world, hello universe"
)
# Replace text
result = memory_tool.execute_action(
"str_replace",
path="/replace.txt",
old_str="hello",
new_str="hi"
)
assert "updated" in result.lower()
# Verify replacement
content = memory_tool.execute_action("view", path="/replace.txt")
assert "hi world, hi universe" in content
def test_str_replace_not_found(memory_tool: MemoryTool) -> None:
"""Test string replacement when string not found."""
memory_tool.execute_action(
"create",
path="/test.txt",
file_text="Hello world"
)
result = memory_tool.execute_action(
"str_replace",
path="/test.txt",
old_str="goodbye",
new_str="hi"
)
assert "not found" in result.lower()
def test_insert_line(memory_tool: MemoryTool) -> None:
"""Test inserting text at a line number."""
# Create a multiline file
memory_tool.execute_action(
"create",
path="/insert.txt",
file_text="Line 1\nLine 2\nLine 3"
)
# Insert at line 2
result = memory_tool.execute_action(
"insert",
path="/insert.txt",
insert_line=2,
insert_text="Inserted line"
)
assert "inserted" in result.lower()
# Verify insertion
content = memory_tool.execute_action("view", path="/insert.txt")
lines = content.split("\n")
assert lines[1] == "Inserted line"
assert lines[2] == "Line 2"
def test_insert_invalid_line(memory_tool: MemoryTool) -> None:
"""Test inserting at an invalid line number."""
memory_tool.execute_action(
"create",
path="/test.txt",
file_text="Line 1\nLine 2"
)
result = memory_tool.execute_action(
"insert",
path="/test.txt",
insert_line=100,
insert_text="Text"
)
assert "invalid" in result.lower()
def test_delete_file(memory_tool: MemoryTool) -> None:
"""Test deleting a file."""
# Create a file
memory_tool.execute_action(
"create",
path="/delete_me.txt",
file_text="Content"
)
# Delete it
result = memory_tool.execute_action("delete", path="/delete_me.txt")
assert "deleted" in result.lower()
# Verify it's gone
result = memory_tool.execute_action("view", path="/delete_me.txt")
assert "not found" in result.lower()
def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None:
"""Test deleting a file that doesn't exist."""
result = memory_tool.execute_action("delete", path="/nonexistent.txt")
assert "not found" in result.lower()
def test_delete_directory(memory_tool: MemoryTool) -> None:
"""Test deleting a directory with files."""
# Create files in a directory
memory_tool.execute_action(
"create",
path="/subdir/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/subdir/file2.txt",
file_text="Content 2"
)
# Delete the directory
result = memory_tool.execute_action("delete", path="/subdir/")
assert "deleted" in result.lower()
# Verify files are gone
result = memory_tool.execute_action("view", path="/subdir/file1.txt")
assert "not found" in result.lower()
def test_rename_file(memory_tool: MemoryTool) -> None:
"""Test renaming a file."""
# Create a file
memory_tool.execute_action(
"create",
path="/old_name.txt",
file_text="Content"
)
# Rename it
result = memory_tool.execute_action(
"rename",
old_path="/old_name.txt",
new_path="/new_name.txt"
)
assert "renamed" in result.lower()
# Verify old path doesn't exist
result = memory_tool.execute_action("view", path="/old_name.txt")
assert "not found" in result.lower()
# Verify new path exists
result = memory_tool.execute_action("view", path="/new_name.txt")
assert "Content" in result
def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None:
"""Test renaming a file that doesn't exist."""
result = memory_tool.execute_action(
"rename",
old_path="/nonexistent.txt",
new_path="/new.txt"
)
assert "not found" in result.lower()
def test_rename_to_existing_file(memory_tool: MemoryTool) -> None:
"""Test renaming to a path that already exists."""
# Create two files
memory_tool.execute_action(
"create",
path="/file1.txt",
file_text="Content 1"
)
memory_tool.execute_action(
"create",
path="/file2.txt",
file_text="Content 2"
)
# Try to rename file1 to file2
result = memory_tool.execute_action(
"rename",
old_path="/file1.txt",
new_path="/file2.txt"
)
assert "already exists" in result.lower()
def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
"""Test that directory traversal attacks are prevented."""
# Try various path traversal attempts
invalid_paths = [
"/../secrets.txt",
"/../../etc/passwd",
"..//file.txt",
"/subdir/../../outside.txt",
]
for path in invalid_paths:
result = memory_tool.execute_action(
"create",
path=path,
file_text="malicious content"
)
assert "invalid path" in result.lower()
def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
"""Test that paths work with or without leading slash (auto-normalized)."""
# These paths should all work now (auto-prepended with /)
valid_paths = [
"etc/passwd", # Auto-prepended with /
"home/user/file.txt", # Auto-prepended with /
"file.txt", # Auto-prepended with /
]
for path in valid_paths:
result = memory_tool.execute_action(
"create",
path=path,
file_text="content"
)
assert "created" in result.lower()
# Verify the file can be accessed with or without leading slash
view_result = memory_tool.execute_action("view", path=path)
assert "content" in view_result
def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None:
"""Test that you cannot create a file at a directory path."""
result = memory_tool.execute_action(
"create",
path="/",
file_text="content"
)
assert "cannot create a file at directory path" in result.lower()
def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
"""Test that action metadata is properly defined."""
metadata = memory_tool.get_actions_metadata()
# Check that all expected actions are defined
action_names = [action["name"] for action in metadata]
assert "view" in action_names
assert "create" in action_names
assert "str_replace" in action_names
assert "insert" in action_names
assert "delete" in action_names
assert "rename" in action_names
# Check that each action has required fields
for action in metadata:
assert "name" in action
assert "description" in action
assert "parameters" in action
def test_memory_tool_isolation(monkeypatch) -> None:
"""Test that different memory tool instances have isolated memories."""
# Create fake collection
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def insert_one(self, doc):
user_id = doc.get("user_id")
tool_id = doc.get("tool_id")
path = doc.get("path")
key = f"{user_id}:{tool_id}:{path}"
self.docs[key] = doc
return type("res", (), {"inserted_id": key})
def update_one(self, q, u, upsert=False):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
key = f"{user_id}:{tool_id}:{path}"
if key not in self.docs and not upsert:
return type("res", (), {"modified_count": 0})
if key not in self.docs and upsert:
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
if "$set" in u:
self.docs[key].update(u["$set"])
return type("res", (), {"modified_count": 1})
def find_one(self, q, projection=None):
user_id = q.get("user_id")
tool_id = q.get("tool_id")
path = q.get("path")
if path:
key = f"{user_id}:{tool_id}:{path}"
return self.docs.get(key)
return None
fake_collection = FakeCollection()
fake_db = {"memories": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two memory tools with different tool_ids for the same user
tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user")
tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user")
# Create a file in tool1
tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1")
# Create a file with the same path in tool2
tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2")
# Verify that each tool sees only its own content
result1 = tool1.execute_action("view", path="/file.txt")
result2 = tool2.execute_action("view", path="/file.txt")
assert "Content from tool 1" in result1
assert "Content from tool 2" not in result1
assert "Content from tool 2" in result2
assert "Content from tool 1" not in result2
def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
class FakeCollection:
def __init__(self) -> None:
self.docs = {}
def update_one(self, q, u, upsert=False):
return type("res", (), {"modified_count": 1})
fake_collection = FakeCollection()
fake_db = {"memories": fake_collection}
fake_client = {settings.MONGO_DB_NAME: fake_db}
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
# Create two tools without providing tool_id for the same user
tool1 = MemoryTool({}, user_id="test_user")
tool2 = MemoryTool({}, user_id="test_user")
# Both should have the same default tool_id for persistence
assert tool1.tool_id == "default_test_user"
assert tool2.tool_id == "default_test_user"
assert tool1.tool_id == tool2.tool_id
# Different users should have different tool_ids
tool3 = MemoryTool({}, user_id="another_user")
assert tool3.tool_id == "default_another_user"
assert tool3.tool_id != tool1.tool_id
def test_paths_without_leading_slash(memory_tool) -> None:
"""Test that paths without leading slash work correctly."""
# Create file without leading slash
result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung")
assert "created" in result.lower()
# View file without leading slash
view_result = memory_tool.execute_action("view", path="cat_breeds.txt")
assert "Korat" in view_result
assert "Chartreux" in view_result
# View same file with leading slash (should work the same)
view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt")
assert "Korat" in view_result2
# Test str_replace without leading slash
replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon")
assert "updated" in replace_result.lower()
# Test nested path without leading slash
nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2")
assert "created" in nested_result.lower()
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
assert "Task 1" in view_nested