mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Add MongoDB-backed NotesTool with CRUD actions and unit tests (#1982)
* Add MongoDB-backed NotesTool with CRUD actions and unit tests * NotesTool: enforce single note per user, require decoded_token['sub'] user_id, fix tests * chore: remove accidentally committed results.txt and ignore it * fix: remove results.txt, enforce single note per user, add tests * refactor: update NotesTool actions and tests for clarity and consistency * refactor: update NotesTool docstring for clarity * refactor: simplify MCPTool docstring and remove redundant import in test_notes_tool * lint: fix test * refactor: remove unused import from test_notes_tool.py --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -37,7 +37,7 @@ _mcp_clients_cache = {}
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
|
||||
187
application/agents/tools/notes.py
Normal file
187
application/agents/tools/notes.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
"""Notepad
|
||||
|
||||
Single note. Supports viewing, overwriting, string replacement.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration (unused for now).
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
|
||||
"""
|
||||
|
||||
|
||||
self.user_id: Optional[str] = user_id
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
if action_name == "overwrite":
|
||||
return self._overwrite_note(kwargs.get("text", ""))
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete_note()
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Retrieve the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "overwrite",
|
||||
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "New note content."}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_str": {"type": "string", "description": "String to find."},
|
||||
"new_str": {"type": "string", "description": "String to replace with."}
|
||||
},
|
||||
"required": ["old_str", "new_str"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at the specified line number (1-indexed).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||
"text": {"type": "string", "description": "Text to insert."}
|
||||
},
|
||||
"required": ["line_number", "text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _get_note(self) -> str:
|
||||
doc = self.collection.find_one({"user_id": self.user_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True, # ✅ create if missing
|
||||
)
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
return f"String '{old_str}' not found in note."
|
||||
|
||||
# Case-insensitive replacement
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
index = line_number - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Invalid line number. Note has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
res = self.collection.delete_one({"user_id": self.user_id})
|
||||
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||
@@ -28,7 +28,7 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
if tool_name in {"mcp_tool","notes"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
|
||||
1
frontend/public/toolIcons/tool_notes.svg
Normal file
1
frontend/public/toolIcons/tool_notes.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3"><path d="M320-240h320v-80H320v80Zm0-160h320v-80H320v80ZM240-80q-33 0-56.5-23.5T160-160v-640q0-33 23.5-56.5T240-880h320l240 240v480q0 33-23.5 56.5T720-80H240Zm280-520v-200H240v640h480v-440H520ZM240-800v200-200 640-640Z"/></svg>
|
||||
|
After Width: | Height: | Size: 334 B |
133
tests/test_notes_tool.py
Normal file
133
tests/test_notes_tool.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import pytest
|
||||
from application.agents.tools.notes import NotesTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notes_tool(monkeypatch) -> NotesTool:
|
||||
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.doc = None # single note per user
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# emulate single-note storage with optional upsert
|
||||
if self.doc is None and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
if self.doc is None and upsert:
|
||||
self.doc = {"user_id": q["user_id"], "note": ""}
|
||||
if "$set" in u and "note" in u["$set"]:
|
||||
self.doc["note"] = u["$set"]["note"]
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q):
|
||||
if self.doc and self.doc.get("user_id") == q.get("user_id"):
|
||||
return self.doc
|
||||
return None
|
||||
|
||||
def delete_one(self, q):
|
||||
if self.doc and self.doc.get("user_id") == q.get("user_id"):
|
||||
self.doc = None
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
# Patch MongoDB client globally for the tool
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# ToolManager will pass user_id in production; in tests we pass it directly
|
||||
return NotesTool({}, user_id="test_user")
|
||||
|
||||
|
||||
def test_view(notes_tool: NotesTool) -> None:
|
||||
# Manually insert a note to test retrieval
|
||||
notes_tool.collection.update_one(
|
||||
{"user_id": "test_user"},
|
||||
{"$set": {"note": "hello"}},
|
||||
upsert=True
|
||||
)
|
||||
assert "hello" in notes_tool.execute_action("view")
|
||||
|
||||
|
||||
def test_overwrite_and_delete(notes_tool: NotesTool) -> None:
|
||||
# Overwrite creates a new note
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="first").lower()
|
||||
assert "first" in notes_tool.execute_action("view")
|
||||
|
||||
# Overwrite replaces existing note
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="second").lower()
|
||||
assert "second" in notes_tool.execute_action("view")
|
||||
|
||||
assert "deleted" in notes_tool.execute_action("delete").lower()
|
||||
assert "no note" in notes_tool.execute_action("view").lower()
|
||||
|
||||
def test_init_without_user_id(monkeypatch):
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
notes_tool = NotesTool(tool_config={})
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "user_id" in str(result).lower()
|
||||
|
||||
|
||||
def test_view_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Should return 'No note found.' when no note exists"""
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
def test_str_replace(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement in note"""
|
||||
# Create a note
|
||||
notes_tool.execute_action("overwrite", text="Hello world, hello universe")
|
||||
|
||||
# Replace text
|
||||
result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi")
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
note = notes_tool.execute_action("view")
|
||||
assert "hi world, hi universe" in note.lower()
|
||||
|
||||
|
||||
def test_str_replace_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement when string not found"""
|
||||
notes_tool.execute_action("overwrite", text="Hello world")
|
||||
result = notes_tool.execute_action("str_replace", old_str="goodbye", new_str="hi")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_insert_line(notes_tool: NotesTool) -> None:
|
||||
"""Test inserting text at a line number"""
|
||||
# Create a multiline note
|
||||
notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3")
|
||||
|
||||
# Insert at line 2
|
||||
result = notes_tool.execute_action("insert", line_number=2, text="Inserted line")
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
note = notes_tool.execute_action("view")
|
||||
lines = note.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
def test_delete_nonexistent_note(monkeypatch):
|
||||
class FakeResult:
|
||||
deleted_count = 0
|
||||
|
||||
class FakeCollection:
|
||||
def delete_one(self, *args, **kwargs):
|
||||
return FakeResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client",
|
||||
lambda: {"docsgpt": {"notes": FakeCollection()}}
|
||||
)
|
||||
|
||||
notes_tool = NotesTool(tool_config={}, user_id="user123")
|
||||
result = notes_tool.execute_action("delete")
|
||||
assert "no note found" in result.lower()
|
||||
Reference in New Issue
Block a user