mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
feat: implement MemoryTool with CRUD operations and integrate with MongoDB
This commit is contained in:
@@ -213,18 +213,24 @@ class BaseAgent(ABC):
|
||||
):
|
||||
target_dict[param] = value
|
||||
tm = ToolManager(config={})
|
||||
|
||||
# Prepare tool_config and add tool_id for memory tools
|
||||
if tool_data["name"] == "api_tool":
|
||||
tool_config = {
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
||||
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
tool_config=tool_config,
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
|
||||
535
application/agents/tools/memory.py
Normal file
535
application/agents/tools/memory.py
Normal file
@@ -0,0 +1,535 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated memories
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to file or directory (e.g., /notes.txt or /project/)."
|
||||
},
|
||||
"view_range": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create or overwrite a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||
},
|
||||
"file_text": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "file_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "String to find."
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "String to replace with."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_str", "new_str"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at a specific line in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"insert_line": {
|
||||
"type": "integer",
|
||||
"description": "Line number to insert at (1-indexed)."
|
||||
},
|
||||
"insert_text": {
|
||||
"type": "string",
|
||||
"description": "Text to insert."
|
||||
}
|
||||
},
|
||||
"required": ["path", "insert_line", "insert_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "rename",
|
||||
"description": "Rename or move a file/directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_path": {
|
||||
"type": "string",
|
||||
"description": "Current path (e.g., /old.txt)."
|
||||
},
|
||||
"new_path": {
|
||||
"type": "string",
|
||||
"description": "New path (e.g., /new.txt)."
|
||||
}
|
||||
},
|
||||
"required": ["old_path", "new_path"]
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View directory contents or file contents."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
# Check if viewing directory (ends with / or is root)
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return self._view_directory(validated_path)
|
||||
|
||||
# Otherwise view file
|
||||
return self._view_file(validated_path, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
|
||||
"""List files in a directory."""
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
# Find all files that start with this directory path
|
||||
query = {
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
}
|
||||
|
||||
docs = list(self.collection.find(query, {"path": 1}))
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
|
||||
# Extract filenames relative to the directory
|
||||
files = []
|
||||
for doc in docs:
|
||||
file_path = doc["path"]
|
||||
# Remove the directory prefix
|
||||
if file_path.startswith(search_path):
|
||||
relative = file_path[len(search_path):]
|
||||
if relative:
|
||||
files.append(relative)
|
||||
|
||||
files.sort()
|
||||
file_list = "\n".join(f"- {f}" for f in files)
|
||||
return f"Directory: {path}\n{file_list}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers
|
||||
numbered_lines = [f"{i+start}: {line}" for i, line in enumerate(selected_lines, start=start_idx)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
|
||||
"""Create or overwrite a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": file_text,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||
"""Insert text at a specific line."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
def _delete(self, path: str) -> str:
|
||||
"""Delete a file or directory."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
# Delete all files in directory
|
||||
result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||
})
|
||||
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||
|
||||
# Try to delete as directory first (without trailing slash)
|
||||
# Check if any files start with this path + /
|
||||
search_path = validated_path + "/"
|
||||
directory_result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
})
|
||||
|
||||
if directory_result.deleted_count > 0:
|
||||
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||
|
||||
# Delete single file
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_path
|
||||
})
|
||||
|
||||
if result.deleted_count:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
|
||||
"""Rename or move a file/directory."""
|
||||
validated_old = self._validate_path(old_path)
|
||||
validated_new = self._validate_path(new_path)
|
||||
|
||||
if not validated_old or not validated_new:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Check if renaming a directory
|
||||
if validated_old.endswith("/"):
|
||||
# Find all files in the old directory
|
||||
docs = list(self.collection.find({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||
}))
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
# Update paths for all files
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||
|
||||
self.collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Rename single file
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_old
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
# Check if new path already exists
|
||||
existing = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new
|
||||
})
|
||||
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
# Delete the old document and create a new one with the new path
|
||||
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||
self.collection.insert_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new,
|
||||
"content": doc.get("content", ""),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
@@ -28,7 +28,7 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name in {"mcp_tool","notes"} and user_id:
|
||||
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
@@ -36,7 +36,7 @@ class ToolManager:
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
if tool_name in {"mcp_tool", "memory"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
3
frontend/public/toolIcons/tool_memory.svg
Normal file
3
frontend/public/toolIcons/tool_memory.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3">
|
||||
<path d="M240-80q-33 0-56.5-23.5T160-160v-480q0-33 23.5-56.5T240-720h80v-80q0-17 11.5-28.5T360-840q17 0 28.5 11.5T400-800v80h40v-80q0-17 11.5-28.5T480-840q17 0 28.5 11.5T520-800v80h40v-80q0-17 11.5-28.5T600-840q17 0 28.5 11.5T640-800v80h80q33 0 56.5 23.5T800-640v480q0 33-23.5 56.5T720-80H240Zm0-80h480v-480H240v480Zm120-320v-80h240v80H360Zm0 120v-80h240v80H360Zm0 120v-80h160v80H360ZM240-160v-480 480Z"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 523 B |
609
tests/test_memory_tool.py
Normal file
609
tests/test_memory_tool.py
Normal file
@@ -0,0 +1,609 @@
|
||||
import pytest
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_tool(monkeypatch) -> MemoryTool:
|
||||
"""Provide a MemoryTool with a fake Mongo collection and fixed user_id."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {} # path -> document
|
||||
|
||||
def insert_one(self, doc):
|
||||
user_id = doc.get("user_id")
|
||||
tool_id = doc.get("tool_id")
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
|
||||
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
def find(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
results = []
|
||||
|
||||
# Handle regex queries for directory listing
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
# Remove regex escape characters and ^ anchor for simple matching
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
results.append(doc)
|
||||
else:
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
results.append(doc)
|
||||
|
||||
return results
|
||||
|
||||
def delete_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
def delete_many(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
deleted = 0
|
||||
|
||||
# Handle regex queries for directory deletion
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
keys_to_delete = []
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
else:
|
||||
# Delete all for user and tool
|
||||
keys_to_delete = [
|
||||
key for key, doc in self.docs.items()
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
|
||||
return type("res", (), {"deleted_count": deleted})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
def test_init_without_user_id():
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
memory_tool = MemoryTool(tool_config={})
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "user_id" in result.lower()
|
||||
|
||||
|
||||
def test_view_empty_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Should show empty directory when no files exist."""
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "empty" in result.lower()
|
||||
|
||||
|
||||
def test_create_and_view_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test creating a file and viewing it."""
|
||||
# Create a file
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/notes.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View the file
|
||||
result = memory_tool.execute_action("view", path="/notes.txt")
|
||||
assert "Hello world" in result
|
||||
|
||||
|
||||
def test_create_overwrite_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that create overwrites existing files."""
|
||||
# Create initial file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Original content"
|
||||
)
|
||||
|
||||
# Overwrite
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="New content"
|
||||
)
|
||||
|
||||
# Verify overwrite
|
||||
result = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "New content" in result
|
||||
assert "Original content" not in result
|
||||
|
||||
|
||||
def test_view_directory_with_files(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing directory contents."""
|
||||
# Create multiple files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file3.txt",
|
||||
file_text="Content 3"
|
||||
)
|
||||
|
||||
# View directory
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "file1.txt" in result
|
||||
assert "file2.txt" in result
|
||||
assert "subdir/file3.txt" in result
|
||||
|
||||
|
||||
def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing specific lines from a file."""
|
||||
# Create a multiline file
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/multiline.txt",
|
||||
file_text=content
|
||||
)
|
||||
|
||||
# View lines 2-4
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/multiline.txt",
|
||||
view_range=[2, 4]
|
||||
)
|
||||
assert "Line 2" in result
|
||||
assert "Line 3" in result
|
||||
assert "Line 4" in result
|
||||
assert "Line 1" not in result
|
||||
assert "Line 5" not in result
|
||||
|
||||
|
||||
def test_str_replace(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement in a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/replace.txt",
|
||||
file_text="Hello world, hello universe"
|
||||
)
|
||||
|
||||
# Replace text
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/replace.txt",
|
||||
old_str="hello",
|
||||
new_str="hi"
|
||||
)
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
content = memory_tool.execute_action("view", path="/replace.txt")
|
||||
assert "hi world, hi universe" in content
|
||||
|
||||
|
||||
def test_str_replace_not_found(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement when string not found."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/test.txt",
|
||||
old_str="goodbye",
|
||||
new_str="hi"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_insert_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting text at a line number."""
|
||||
# Create a multiline file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/insert.txt",
|
||||
file_text="Line 1\nLine 2\nLine 3"
|
||||
)
|
||||
|
||||
# Insert at line 2
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/insert.txt",
|
||||
insert_line=2,
|
||||
insert_text="Inserted line"
|
||||
)
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
content = memory_tool.execute_action("view", path="/insert.txt")
|
||||
lines = content.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
def test_insert_invalid_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting at an invalid line number."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Line 1\nLine 2"
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/test.txt",
|
||||
insert_line=100,
|
||||
insert_text="Text"
|
||||
)
|
||||
assert "invalid" in result.lower()
|
||||
|
||||
|
||||
def test_delete_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/delete_me.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
# Delete it
|
||||
result = memory_tool.execute_action("delete", path="/delete_me.txt")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify it's gone
|
||||
result = memory_tool.execute_action("view", path="/delete_me.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file that doesn't exist."""
|
||||
result = memory_tool.execute_action("delete", path="/nonexistent.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_delete_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a directory with files."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Delete the directory
|
||||
result = memory_tool.execute_action("delete", path="/subdir/")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify files are gone
|
||||
result = memory_tool.execute_action("view", path="/subdir/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_rename_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/old_name.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
# Rename it
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/old_name.txt",
|
||||
new_path="/new_name.txt"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify old path doesn't exist
|
||||
result = memory_tool.execute_action("view", path="/old_name.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new path exists
|
||||
result = memory_tool.execute_action("view", path="/new_name.txt")
|
||||
assert "Content" in result
|
||||
|
||||
|
||||
def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file that doesn't exist."""
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/nonexistent.txt",
|
||||
new_path="/new.txt"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_rename_to_existing_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming to a path that already exists."""
|
||||
# Create two files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Try to rename file1 to file2
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/file1.txt",
|
||||
new_path="/file2.txt"
|
||||
)
|
||||
assert "already exists" in result.lower()
|
||||
|
||||
|
||||
def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
|
||||
"""Test that directory traversal attacks are prevented."""
|
||||
# Try various path traversal attempts
|
||||
invalid_paths = [
|
||||
"/../secrets.txt",
|
||||
"/../../etc/passwd",
|
||||
"..//file.txt",
|
||||
"/subdir/../../outside.txt",
|
||||
]
|
||||
|
||||
for path in invalid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="malicious content"
|
||||
)
|
||||
assert "invalid path" in result.lower()
|
||||
|
||||
|
||||
def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test that paths work with or without leading slash (auto-normalized)."""
|
||||
# These paths should all work now (auto-prepended with /)
|
||||
valid_paths = [
|
||||
"etc/passwd", # Auto-prepended with /
|
||||
"home/user/file.txt", # Auto-prepended with /
|
||||
"file.txt", # Auto-prepended with /
|
||||
]
|
||||
|
||||
for path in valid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="content"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# Verify the file can be accessed with or without leading slash
|
||||
view_result = memory_tool.execute_action("view", path=path)
|
||||
assert "content" in view_result
|
||||
|
||||
|
||||
def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that you cannot create a file at a directory path."""
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/",
|
||||
file_text="content"
|
||||
)
|
||||
assert "cannot create a file at directory path" in result.lower()
|
||||
|
||||
|
||||
def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
|
||||
"""Test that action metadata is properly defined."""
|
||||
metadata = memory_tool.get_actions_metadata()
|
||||
|
||||
# Check that all expected actions are defined
|
||||
action_names = [action["name"] for action in metadata]
|
||||
assert "view" in action_names
|
||||
assert "create" in action_names
|
||||
assert "str_replace" in action_names
|
||||
assert "insert" in action_names
|
||||
assert "delete" in action_names
|
||||
assert "rename" in action_names
|
||||
|
||||
# Check that each action has required fields
|
||||
for action in metadata:
|
||||
assert "name" in action
|
||||
assert "description" in action
|
||||
assert "parameters" in action
|
||||
|
||||
|
||||
def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different memory tool instances have isolated memories."""
|
||||
# Create fake collection
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def insert_one(self, doc):
|
||||
user_id = doc.get("user_id")
|
||||
tool_id = doc.get("tool_id")
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": ""}
|
||||
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two memory tools with different tool_ids for the same user
|
||||
tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a file in tool1
|
||||
tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1")
|
||||
|
||||
# Create a file with the same path in tool2
|
||||
tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
result1 = tool1.execute_action("view", path="/file.txt")
|
||||
result2 = tool2.execute_action("view", path="/file.txt")
|
||||
|
||||
assert "Content from tool 1" in result1
|
||||
assert "Content from tool 2" not in result1
|
||||
|
||||
assert "Content from tool 2" in result2
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
tool1 = MemoryTool({}, user_id="test_user")
|
||||
tool2 = MemoryTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
tool3 = MemoryTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
|
||||
|
||||
def test_paths_without_leading_slash(memory_tool) -> None:
|
||||
"""Test that paths without leading slash work correctly."""
|
||||
# Create file without leading slash
|
||||
result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung")
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View file without leading slash
|
||||
view_result = memory_tool.execute_action("view", path="cat_breeds.txt")
|
||||
assert "Korat" in view_result
|
||||
assert "Chartreux" in view_result
|
||||
|
||||
# View same file with leading slash (should work the same)
|
||||
view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt")
|
||||
assert "Korat" in view_result2
|
||||
|
||||
# Test str_replace without leading slash
|
||||
replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon")
|
||||
assert "updated" in replace_result.lower()
|
||||
|
||||
# Test nested path without leading slash
|
||||
nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2")
|
||||
assert "created" in nested_result.lower()
|
||||
|
||||
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
|
||||
assert "Task 1" in view_nested
|
||||
Reference in New Issue
Block a user