diff --git a/application/agents/base.py b/application/agents/base.py index e75fbb7d..bdf7b100 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -256,6 +256,8 @@ class BaseAgent(ABC): # Use MongoDB _id if available, otherwise fall back to enumerated tool_id tool_config["tool_id"] = str(tool_data.get("_id", tool_id)) + if hasattr(self, "conversation_id") and self.conversation_id: + tool_config["conversation_id"] = self.conversation_id tool = tm.load_tool( tool_data["name"], tool_config=tool_config, @@ -269,6 +271,27 @@ class BaseAgent(ABC): else: logger.debug(f"Executing tool: {action_name} with args: {call_args}") result = tool.execute_action(action_name, **parameters) + + get_artifact_id = ( + getattr(tool, "get_artifact_id", None) + if tool_data["name"] != "api_tool" + else None + ) + + artifact_id = None + if callable(get_artifact_id): + try: + artifact_id = get_artifact_id(action_name, **parameters) + except Exception: + logger.exception( + "Failed to extract artifact_id from tool %s for action %s", + tool_data["name"], + action_name, + ) + + artifact_id = str(artifact_id).strip() if artifact_id is not None else "" + if artifact_id: + tool_call_data["artifact_id"] = artifact_id tool_call_data["result"] = ( f"{str(result)[:50]}..." if len(str(result)) > 50 else result ) diff --git a/application/agents/tools/notes.py b/application/agents/tools/notes.py index 3d7ced85..8afdd071 100644 --- a/application/agents/tools/notes.py +++ b/application/agents/tools/notes.py @@ -38,6 +38,8 @@ class NotesTool(Tool): db = MongoDB.get_client()[settings.MONGO_DB_NAME] self.collection = db["notes"] + self._last_artifact_id: Optional[str] = None + # ----------------------------- # Action implementations # ----------------------------- @@ -54,6 +56,8 @@ class NotesTool(Tool): if not self.user_id: return "Error: NotesTool requires a valid user_id." + self._last_artifact_id = None + if action_name == "view": return self._get_note() @@ -125,6 +129,9 @@ class NotesTool(Tool): """Return configuration requirements (none for now).""" return {} + def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]: + return self._last_artifact_id + # ----------------------------- # Internal helpers (single-note) # ----------------------------- @@ -132,17 +139,22 @@ class NotesTool(Tool): doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id}) if not doc or not doc.get("note"): return "No note found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) return str(doc["note"]) def _overwrite_note(self, content: str) -> str: content = (content or "").strip() if not content: return "Note content required." - self.collection.update_one( + result = self.collection.find_one_and_update( {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": content, "updated_at": datetime.utcnow()}}, - upsert=True, # ✅ create if missing + upsert=True, + return_document=True, ) + if result and result.get("_id") is not None: + self._last_artifact_id = str(result.get("_id")) return "Note saved." def _str_replace(self, old_str: str, new_str: str) -> str: @@ -163,10 +175,13 @@ class NotesTool(Tool): import re updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE) - self.collection.update_one( + result = self.collection.find_one_and_update( {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, + return_document=True, ) + if result and result.get("_id") is not None: + self._last_artifact_id = str(result.get("_id")) return "Note updated." def _insert(self, line_number: int, text: str) -> str: @@ -188,12 +203,21 @@ class NotesTool(Tool): lines.insert(index, text) updated_note = "\n".join(lines) - self.collection.update_one( + result = self.collection.find_one_and_update( {"user_id": self.user_id, "tool_id": self.tool_id}, {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}}, + return_document=True, ) + if result and result.get("_id") is not None: + self._last_artifact_id = str(result.get("_id")) return "Text inserted." def _delete_note(self) -> str: - res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id}) - return "Note deleted." if res.deleted_count else "No note found to delete." + doc = self.collection.find_one_and_delete( + {"user_id": self.user_id, "tool_id": self.tool_id} + ) + if not doc: + return "No note found to delete." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return "Note deleted." diff --git a/application/agents/tools/todo_list.py b/application/agents/tools/todo_list.py index 87a3e969..b515ad56 100644 --- a/application/agents/tools/todo_list.py +++ b/application/agents/tools/todo_list.py @@ -38,6 +38,8 @@ class TodoListTool(Tool): db = MongoDB.get_client()[settings.MONGO_DB_NAME] self.collection = db["todos"] + self._last_artifact_id: Optional[str] = None + # ----------------------------- # Action implementations # ----------------------------- @@ -54,6 +56,8 @@ class TodoListTool(Tool): if not self.user_id: return "Error: TodoListTool requires a valid user_id." + self._last_artifact_id = None + if action_name == "list": return self._list() @@ -165,6 +169,9 @@ class TodoListTool(Tool): """Return configuration requirements.""" return {} + def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]: + return self._last_artifact_id + # ----------------------------- # Internal helpers # ----------------------------- @@ -190,11 +197,8 @@ class TodoListTool(Tool): Returns a simple integer (1, 2, 3, ...) scoped to this user/tool. With 5-10 todos max, scanning is negligible. """ - # Find all todos for this user/tool and get their IDs - todos = list(self.collection.find( - {"user_id": self.user_id, "tool_id": self.tool_id}, - {"todo_id": 1} - )) + query = {"user_id": self.user_id, "tool_id": self.tool_id} + todos = list(self.collection.find(query, {"todo_id": 1})) # Find the maximum todo_id max_id = 0 @@ -207,8 +211,8 @@ class TodoListTool(Tool): def _list(self) -> str: """List all todos for the user.""" - cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id}) - todos = list(cursor) + query = {"user_id": self.user_id, "tool_id": self.tool_id} + todos = list(self.collection.find(query)) if not todos: return "No todos found." @@ -242,7 +246,10 @@ class TodoListTool(Tool): "created_at": now, "updated_at": now, } - self.collection.insert_one(doc) + insert_result = self.collection.insert_one(doc) + inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id") + if inserted_id is not None: + self._last_artifact_id = str(inserted_id) return f"Todo created with ID {todo_id}: {title}" def _get(self, todo_id: Optional[Any]) -> str: @@ -251,15 +258,15 @@ class TodoListTool(Tool): if parsed_todo_id is None: return "Error: todo_id must be a positive integer." - doc = self.collection.find_one({ - "user_id": self.user_id, - "tool_id": self.tool_id, - "todo_id": parsed_todo_id - }) + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one(query) if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + title = doc.get("title", "Untitled") status = doc.get("status", "open") @@ -277,14 +284,17 @@ class TodoListTool(Tool): if not title: return "Error: Title is required." - result = self.collection.update_one( - {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}, - {"$set": {"title": title, "updated_at": datetime.now()}} + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one_and_update( + query, + {"$set": {"title": title, "updated_at": datetime.now()}}, ) - - if result.matched_count == 0: + if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return f"Todo {parsed_todo_id} updated to: {title}" def _complete(self, todo_id: Optional[Any]) -> str: @@ -293,14 +303,17 @@ class TodoListTool(Tool): if parsed_todo_id is None: return "Error: todo_id must be a positive integer." - result = self.collection.update_one( - {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}, - {"$set": {"status": "completed", "updated_at": datetime.now()}} + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one_and_update( + query, + {"$set": {"status": "completed", "updated_at": datetime.now()}}, ) - - if result.matched_count == 0: + if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return f"Todo {parsed_todo_id} marked as completed." def _delete(self, todo_id: Optional[Any]) -> str: @@ -309,13 +322,12 @@ class TodoListTool(Tool): if parsed_todo_id is None: return "Error: todo_id must be a positive integer." - result = self.collection.delete_one({ - "user_id": self.user_id, - "tool_id": self.tool_id, - "todo_id": parsed_todo_id - }) - - if result.deleted_count == 0: + query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id} + doc = self.collection.find_one_and_delete(query) + if not doc: return f"Error: Todo with ID {parsed_todo_id} not found." + if doc.get("_id") is not None: + self._last_artifact_id = str(doc.get("_id")) + return f"Todo {parsed_todo_id} deleted." diff --git a/application/api/user/tools/routes.py b/application/api/user/tools/routes.py index 384e8b8a..760f0120 100644 --- a/application/api/user/tools/routes.py +++ b/application/api/user/tools/routes.py @@ -467,3 +467,84 @@ class ParseSpec(Resource): except Exception as err: current_app.logger.error(f"Error parsing spec: {err}", exc_info=True) return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500) + + +@tools_ns.route("/artifact/") +class GetArtifact(Resource): + @api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.") + def get(self, artifact_id: str): + decoded_token = request.decoded_token + if not decoded_token: + return make_response(jsonify({"success": False}), 401) + user_id = decoded_token.get("sub") + + try: + obj_id = ObjectId(artifact_id) + except Exception: + return make_response( + jsonify({"success": False, "message": "Invalid artifact ID"}), 400 + ) + + from application.core.mongo_db import MongoDB + from application.core.settings import settings + + db = MongoDB.get_client()[settings.MONGO_DB_NAME] + + note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id}) + if note_doc: + content = note_doc.get("note", "") + line_count = len(content.split("\n")) if content else 0 + artifact = { + "artifact_type": "note", + "data": { + "content": content, + "line_count": line_count, + "updated_at": ( + note_doc["updated_at"].isoformat() + if note_doc.get("updated_at") + else None + ), + }, + } + return make_response(jsonify({"success": True, "artifact": artifact}), 200) + + todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id}) + if todo_doc: + tool_id = todo_doc.get("tool_id") + # Return all todos for the tool + query = {"user_id": user_id, "tool_id": tool_id} + all_todos = list(db["todos"].find(query)) + items = [] + open_count = 0 + completed_count = 0 + for t in all_todos: + status = t.get("status", "open") + if status == "open": + open_count += 1 + elif status == "completed": + completed_count += 1 + items.append({ + "todo_id": t.get("todo_id"), + "title": t.get("title", ""), + "status": status, + "created_at": ( + t["created_at"].isoformat() if t.get("created_at") else None + ), + "updated_at": ( + t["updated_at"].isoformat() if t.get("updated_at") else None + ), + }) + artifact = { + "artifact_type": "todo_list", + "data": { + "items": items, + "total_count": len(items), + "open_count": open_count, + "completed_count": completed_count, + }, + } + return make_response(jsonify({"success": True, "artifact": artifact}), 200) + + return make_response( + jsonify({"success": False, "message": "Artifact not found"}), 404 + ) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 929cb014..649ef13c 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -135,6 +135,7 @@ "integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.5", @@ -2623,6 +2624,7 @@ "integrity": "sha512-8QqtOQT5ACVlmsvKOJNEaWmRPmcojMOzCz4Hs2BGG/toAp/K38LcsMRyLp349glq5AzJbCEeimEoxaX6v/fLrA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/core": "^7.21.3", "@svgr/babel-preset": "8.1.0", @@ -3408,6 +3410,7 @@ "integrity": "sha512-6mDvHUFSjyT2B2yeNx2nUgMxh9LtOWvkhIU3uePn2I2oyNymUAX1NIsdgviM4CH+JSrp2D2hsMvJOkxY+0wNRA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "csstype": "^3.0.2" } @@ -4071,6 +4074,7 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -4421,6 +4425,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.8.19", "caniuse-lite": "^1.0.30001751", @@ -4601,6 +4606,7 @@ "resolved": "https://registry.npmjs.org/chart.js/-/chart.js-4.5.1.tgz", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "license": "MIT", + "peer": true, "dependencies": { "@kurkle/color": "^0.3.0" }, @@ -4613,6 +4619,7 @@ "resolved": "https://registry.npmjs.org/chevrotain/-/chevrotain-11.0.3.tgz", "integrity": "sha512-ci2iJH6LeIkvP9eJW6gpueU8cnZhv85ELY8w8WiFtNjMHA5ad6pQLaJo9mEly/9qUyCpvqX8/POVUTf18/HFdw==", "license": "Apache-2.0", + "peer": true, "dependencies": { "@chevrotain/cst-dts-gen": "11.0.3", "@chevrotain/gast": "11.0.3", @@ -4858,6 +4865,7 @@ "resolved": "https://registry.npmjs.org/cytoscape/-/cytoscape-3.33.1.tgz", "integrity": "sha512-iJc4TwyANnOGR1OmWhsS9ayRS3s+XQ185FmuHObThD+5AeJCakAAbWv8KimMTt08xCCLNgneQwFp+JRJOr9qGQ==", "license": "MIT", + "peer": true, "engines": { "node": ">=0.10" } @@ -5267,6 +5275,7 @@ "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "license": "ISC", + "peer": true, "engines": { "node": ">=12" } @@ -5887,6 +5896,7 @@ "integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -5963,6 +5973,7 @@ "integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==", "dev": true, "license": "MIT", + "peer": true, "bin": { "eslint-config-prettier": "bin/cli.js" }, @@ -7360,6 +7371,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "@babel/runtime": "^7.27.6" }, @@ -10313,6 +10325,7 @@ "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", "dev": true, "license": "MIT", + "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -10497,6 +10510,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.0.tgz", "integrity": "sha512-tmbWg6W31tQLeB5cdIBOicJDJRR2KzXsV7uSK9iNfLWQ5bIZfxuPEHp7M8wiHyHnn0DD1i7w3Zmin0FtkrwoCQ==", "license": "MIT", + "peer": true, "engines": { "node": ">=0.10.0" } @@ -10516,6 +10530,7 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.0.tgz", "integrity": "sha512-UlbRu4cAiGaIewkPyiRGJk0imDN2T3JjieT6spoL2UeSf5od4n5LB/mQ4ejmxhCFT1tYe8IvaFulzynWovsEFQ==", "license": "MIT", + "peer": true, "dependencies": { "scheduler": "^0.27.0" }, @@ -10615,6 +10630,7 @@ "resolved": "https://registry.npmjs.org/react-redux/-/react-redux-9.2.0.tgz", "integrity": "sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==", "license": "MIT", + "peer": true, "dependencies": { "@types/use-sync-external-store": "^0.0.6", "use-sync-external-store": "^1.4.0" @@ -10789,7 +10805,8 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz", "integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==", - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/redux-thunk": { "version": "3.1.0", @@ -11927,6 +11944,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -12157,6 +12175,7 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -12459,6 +12478,7 @@ "integrity": "sha512-C/Naxf8H0pBx1PA4BdpT+c/5wdqI9ILMdwjSMILw7tVIh3JsxzZqdeTLmmdaoh5MYUEOyBnM9K3o0DzoZ/fe+w==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -12567,6 +12587,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index f9933109..eec77508 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -68,6 +68,7 @@ const endpoints = { AGENT_FOLDERS: '/api/agents/folders/', AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`, MOVE_AGENT_TO_FOLDER: '/api/agents/folders/move_agent', + GET_ARTIFACT: (artifactId: string) => `/api/artifact/${artifactId}`, WORKFLOWS: '/api/workflows', WORKFLOW: (id: string) => `/api/workflows/${id}`, }, diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 5d5da6d5..06382435 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -160,6 +160,8 @@ const userService = { token: string | null, ): Promise => apiClient.post(endpoints.USER.MOVE_AGENT_TO_FOLDER, data, token), + getArtifact: (artifactId: string, token: string | null): Promise => + apiClient.get(endpoints.USER.GET_ARTIFACT(artifactId), token), getWorkflow: (id: string, token: string | null): Promise => apiClient.get(endpoints.USER.WORKFLOW(id), token), createWorkflow: (data: any, token: string | null): Promise => diff --git a/frontend/src/components/ActionButtons.tsx b/frontend/src/components/ActionButtons.tsx index 5053e925..22d4f286 100644 --- a/frontend/src/components/ActionButtons.tsx +++ b/frontend/src/components/ActionButtons.tsx @@ -16,6 +16,7 @@ interface ActionButtonsProps { className?: string; showNewChat?: boolean; showShare?: boolean; + isArtifactOpen?: boolean; } import { useNavigate } from 'react-router-dom'; @@ -24,6 +25,7 @@ export default function ActionButtons({ className = '', showNewChat = true, showShare = true, + isArtifactOpen = false, }: ActionButtonsProps) { const { t } = useTranslation(); const dispatch = useDispatch(); @@ -41,7 +43,11 @@ export default function ActionButtons({ navigate('/'); }; return ( -
+
{showNewChat && (
+ ); +} + +export default function ArtifactSidebar({ + isOpen, + onClose, + artifactId, + toolName, + conversationId, + variant = 'overlay', +}: ArtifactSidebarProps) { + const sidebarRef = React.useRef(null); + const lastSuccessfulTodoArtifactIdRef = React.useRef(null); + const currentFetchIdRef = React.useRef(null); + const token = useSelector(selectToken); + const [artifact, setArtifact] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [effectiveArtifactId, setEffectiveArtifactId] = useState( + artifactId, + ); + + const title = getArtifactTitle(artifact, toolName); + + // Reset last successful todo artifact ID when conversation changes + useEffect(() => { + lastSuccessfulTodoArtifactIdRef.current = null; + }, [conversationId]); + + // Reset effectiveArtifactId when artifactId changes + useEffect(() => { + if (!isOpen) { + setEffectiveArtifactId(null); + return; + } + setEffectiveArtifactId(artifactId); + }, [isOpen, artifactId]); + + // Fetch artifact when effectiveArtifactId changes + useEffect(() => { + if (!isOpen || !effectiveArtifactId) { + setArtifact(null); + setError(null); + setLoading(false); + currentFetchIdRef.current = null; + return; + } + + // Generate a unique ID for this fetch + const fetchId = `${effectiveArtifactId}-${Date.now()}`; + currentFetchIdRef.current = fetchId; + + setLoading(true); + setError(null); + + // Note: For todo artifacts, the endpoint always returns all todos for the tool; will be coversation scoped later + userService + .getArtifact(effectiveArtifactId, token) + .then(async (res: any) => { + // Ignore if this is not the current fetch + if (currentFetchIdRef.current !== fetchId) return; + + const isResponseLike = res && typeof res.json === 'function'; + const status = isResponseLike ? res.status : undefined; + const ok = isResponseLike ? Boolean(res.ok) : true; + + let data: any = res; + if (isResponseLike) { + try { + data = await res.json(); + } catch { + data = null; + } + } + + // Check again after async operation + if (currentFetchIdRef.current !== fetchId) return; + + if (ok && data?.success && data?.artifact) { + setArtifact(data.artifact); + // Remember the last successful todo artifact id so we can fallback if a newer id 404s. + if (data.artifact?.artifact_type === 'todo_list') { + lastSuccessfulTodoArtifactIdRef.current = effectiveArtifactId; + } + setLoading(false); + return; + } + + const isTodoTool = (toolName ?? '').toLowerCase().includes('todo'); + + // If the latest todo artifact id is missing (404), fall back to the last known good one + // so the backend can still resolve `tool_id` for the todo list. + if ( + status === 404 && + isTodoTool && + lastSuccessfulTodoArtifactIdRef.current && + lastSuccessfulTodoArtifactIdRef.current !== effectiveArtifactId + ) { + // Update effectiveArtifactId to trigger a new fetch with the fallback id + setEffectiveArtifactId(lastSuccessfulTodoArtifactIdRef.current); + setLoading(false); + return; + } + + // Ensure we show a visible error state instead of rendering nothing. + const message = + data?.message || + (status === 404 ? 'Artifact not found' : null) || + 'Failed to load artifact'; + setError(message); + setLoading(false); + }) + .catch((err) => { + // Ignore if this is not the current fetch + if (currentFetchIdRef.current !== fetchId) return; + setError('Failed to fetch artifact'); + setLoading(false); + }); + }, [isOpen, effectiveArtifactId, token, toolName, conversationId]); + + const handleClickOutside = (event: MouseEvent) => { + if ( + sidebarRef.current && + !sidebarRef.current.contains(event.target as Node) + ) { + onClose(); + } + }; + + useEffect(() => { + if (variant === 'overlay' && isOpen) { + document.addEventListener('mousedown', handleClickOutside); + } + return () => document.removeEventListener('mousedown', handleClickOutside); + }, [isOpen, variant]); + + const renderContent = () => { + if (loading) { + return ( +
+ +
+ ); + } + if (error) { + return ( +
+

{error}

+
+ ); + } + // Avoid rendering an empty panel if the artifact couldn't be loaded for any reason. + if (!artifact) { + return ( +
+

+ Artifact not found +

+
+ ); + } + switch (artifact.artifact_type) { + case 'todo_list': + return ; + case 'note': + return ; + default: + return ( +
+            {JSON.stringify(artifact, null, 2)}
+          
+ ); + } + }; + + if (variant === 'split') { + if (!isOpen) return null; + + return ( +
+ {/* Space for top bar / actions */} +
+ {/* Artifact panel */} +
+
+ + {title} + + +
+
{renderContent()}
+
+
+ ); + } + + return ( +
+
+
+ + {title} + + +
+
{renderContent()}
+
+
+ ); +} diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx index c36d6a7c..cba4a479 100644 --- a/frontend/src/conversation/Conversation.tsx +++ b/frontend/src/conversation/Conversation.tsx @@ -3,6 +3,7 @@ import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; import SharedAgentCard from '../agents/SharedAgentCard'; +import ArtifactSidebar from '../components/ArtifactSidebar'; import MessageInput from '../components/MessageInput'; import { useMediaQuery } from '../hooks'; import { @@ -14,20 +15,16 @@ import { AppDispatch } from '../store'; import { handleSendFeedback } from './conversationHandlers'; import ConversationMessages from './ConversationMessages'; import { FEEDBACK, Query } from './conversationModels'; +import { ToolCallsType } from './types'; import { addQuery, fetchAnswer, resendQuery, selectQueries, selectStatus, - setConversation, - updateConversationId, updateQuery, } from './conversationSlice'; -import { - selectCompletedAttachments, - clearAttachments, -} from '../upload/uploadSlice'; +import { selectCompletedAttachments } from '../upload/uploadSlice'; export default function Conversation() { const { t } = useTranslation(); @@ -43,13 +40,33 @@ export default function Conversation() { const [lastQueryReturnedErr, setLastQueryReturnedErr] = useState(false); - const [isShareModalOpen, setShareModalState] = useState(false); - const fetchStream = useRef(null); + const lastAutoOpenedArtifactId = useRef(null); + const didInitArtifactAutoOpen = useRef(false); + const prevConversationId = useRef(conversationId); + + const [openArtifact, setOpenArtifact] = useState<{ + id: string; + toolName: string; + } | null>(null); + + useEffect(() => { + const prevId = prevConversationId.current; + // Don't reset when the backend assigns the conversation id mid-stream (null -> id) + const isServerAssignedId = + prevId === null && conversationId !== null && status === 'loading'; + + if (!isServerAssignedId && prevId !== conversationId) { + setOpenArtifact(null); + lastAutoOpenedArtifactId.current = null; + } + + prevConversationId.current = conversationId; + }, [conversationId, status]); const handleFetchAnswer = useCallback( ({ question, index }: { question: string; index?: number }) => { - fetchStream.current = dispatch(fetchAnswer({ question, indx: index })); + dispatch(fetchAnswer({ question, indx: index })); }, [dispatch, selectedAgent], ); @@ -143,61 +160,138 @@ export default function Conversation() { } }; - const resetConversation = () => { - dispatch(setConversation([])); - dispatch( - updateConversationId({ - query: { conversationId: null }, - }), - ); - dispatch(clearAttachments()); - }; + useEffect(() => { + if (queries.length) { + const last = queries[queries.length - 1]; + if (last.error) setLastQueryReturnedErr(true); + if (last.response) setLastQueryReturnedErr(false); + } + }, [queries]); useEffect(() => { - if (queries.length === 0) { - setLastQueryReturnedErr(false); + // Avoid auto-opening an artifact from existing conversation history on first mount. + if (!didInitArtifactAutoOpen.current) { + didInitArtifactAutoOpen.current = true; return; } - const lastQuery = queries[queries.length - 1]; - setLastQueryReturnedErr(!!lastQuery.error && !lastQuery.response); + const isNotesOrTodoTool = (toolName?: string) => { + const t = (toolName ?? '').toLowerCase(); + return t === 'notes' || t === 'todo_list' || t === 'todo'; + }; + + const findLatestCompletedArtifactCall = ( + items: Query[], + ): ToolCallsType | null => { + for (let i = items.length - 1; i >= 0; i -= 1) { + const calls = items[i].tool_calls ?? []; + for (let j = calls.length - 1; j >= 0; j -= 1) { + const call = calls[j]; + if (call.artifact_id && call.status === 'completed') return call; + } + } + return null; + }; + + const latest = findLatestCompletedArtifactCall(queries); + if (!latest?.artifact_id) return; + if (!isNotesOrTodoTool(latest.tool_name)) return; + if (latest.artifact_id === lastAutoOpenedArtifactId.current) return; + + lastAutoOpenedArtifactId.current = latest.artifact_id; + setOpenArtifact({ + id: latest.artifact_id, + toolName: latest.tool_name, + }); }, [queries]); - return ( -
- - -
- ) : undefined - } - /> + const handleOpenArtifact = useCallback( + (artifact: { id: string; toolName: string }) => { + lastAutoOpenedArtifactId.current = artifact.id; + setOpenArtifact(artifact); + }, + [], + ); -
-
- { - handleQuestionSubmission(text); - }} - loading={status === 'loading'} - showSourceButton={selectedAgent ? false : true} - showToolButton={selectedAgent ? false : true} + const handleCloseArtifact = useCallback(() => setOpenArtifact(null), []); + + const isSplitArtifactOpen = !isMobile && openArtifact !== null; + + return ( +
+
+
+ + +
+ ) : undefined + } />
-

- {t('tagline')} -

+
+
+ { + handleQuestionSubmission(text); + }} + loading={status === 'loading'} + showSourceButton={selectedAgent ? false : true} + showToolButton={selectedAgent ? false : true} + /> +
+ +

+ {t('tagline')} +

+
+ + {isSplitArtifactOpen && ( +
+ +
+ )} + + {isMobile && ( + + )}
); } diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 481d3387..bf058581 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -62,6 +62,7 @@ const ConversationBubble = forwardRef< index?: number, ) => void; filesAttached?: { id: string; fileName: string }[]; + onOpenArtifact?: (artifact: { id: string; toolName: string }) => void; } >(function ConversationBubble( { @@ -78,6 +79,7 @@ const ConversationBubble = forwardRef< isStreaming, handleUpdatedQuestionSubmission, filesAttached, + onOpenArtifact, }, ref, ) { @@ -96,6 +98,21 @@ const ConversationBubble = forwardRef< const editableQueryRef = useRef(null); const [isQuestionCollapsed, setIsQuestionCollapsed] = useState(true); + const completedArtifactCalls = (toolCalls ?? []).filter( + (toolCall) => toolCall.artifact_id && toolCall.status === 'completed', + ); + const primaryArtifactCall = + completedArtifactCalls[completedArtifactCalls.length - 1] ?? null; + const artifactCount = completedArtifactCalls.length; + + const formatToolName = (toolName: string | undefined): string => { + if (!toolName) return ''; + return toolName + .split('_') + .map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase()) + .join(' '); + }; + useOutsideAlerter(editableQueryRef, () => setIsEditClicked(false), [], true); useEffect(() => { @@ -379,6 +396,45 @@ const ConversationBubble = forwardRef< {toolCalls && toolCalls.length > 0 && ( )} + {!message && primaryArtifactCall?.artifact_id && onOpenArtifact && ( +
+ +
+ )} {thought && ( )} @@ -548,6 +604,46 @@ const ConversationBubble = forwardRef<
) : ( <> + {primaryArtifactCall?.artifact_id && onOpenArtifact && ( +
+ +
+ )} {!isStreaming && ( <>
@@ -692,106 +788,107 @@ export default ConversationBubble; function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) { const [isToolCallsOpen, setIsToolCallsOpen] = useState(false); + return ( -
-
- - } - /> - -
- {isToolCallsOpen && ( -
-
- {toolCalls.map((toolCall, index) => ( - -
-
-

- - Arguments - {' '} - -

-

- - {JSON.stringify(toolCall.arguments, null, 2)} - -

-
-
-

- - Response - {' '} - -

- {toolCall.status === 'pending' && ( - - - - )} - {toolCall.status === 'completed' && ( -

+ +

+ {isToolCallsOpen && ( +
+
+ {toolCalls.map((toolCall, index) => ( + +
+
+

+ + Arguments + {' '} + +

+

- {JSON.stringify(toolCall.result, null, 2)} + {JSON.stringify(toolCall.arguments, null, 2)}

- )} - {toolCall.status === 'error' && ( -

- - {toolCall.error} - +

+
+

+ + Response + {' '} +

- )} + {toolCall.status === 'pending' && ( + + + + )} + {toolCall.status === 'completed' && ( +

+ + {JSON.stringify(toolCall.result, null, 2)} + +

+ )} + {toolCall.status === 'error' && ( +

+ + {toolCall.error} + +

+ )} +
-
- - ))} + + ))} +
-
- )} -
+ )} +
); } diff --git a/frontend/src/conversation/ConversationMessages.tsx b/frontend/src/conversation/ConversationMessages.tsx index 02d5b831..b0f68ab7 100644 --- a/frontend/src/conversation/ConversationMessages.tsx +++ b/frontend/src/conversation/ConversationMessages.tsx @@ -36,6 +36,8 @@ type ConversationMessagesProps = { status: Status; showHeroOnEmpty?: boolean; headerContent?: ReactNode; + onOpenArtifact?: (artifact: { id: string; toolName: string }) => void; + isSplitView?: boolean; }; export default function ConversationMessages({ @@ -46,6 +48,8 @@ export default function ConversationMessages({ handleFeedback, showHeroOnEmpty = true, headerContent, + onOpenArtifact, + isSplitView = false, }: ConversationMessagesProps) { const [isDarkTheme] = useDarkTheme(); const { t } = useTranslation(); @@ -147,6 +151,7 @@ export default function ConversationMessages({ thought={query.thought} sources={query.sources} toolCalls={query.tool_calls} + onOpenArtifact={onOpenArtifact} feedback={query.feedback} isStreaming={isCurrentlyStreaming} handleFeedback={ @@ -213,7 +218,13 @@ export default function ConversationMessages({ )} -
+
{headerContent} {queries.length > 0 ? ( diff --git a/frontend/src/conversation/types/index.ts b/frontend/src/conversation/types/index.ts index d962e4bc..c416bde6 100644 --- a/frontend/src/conversation/types/index.ts +++ b/frontend/src/conversation/types/index.ts @@ -6,4 +6,5 @@ export type ToolCallsType = { result?: Record; error?: string; status?: 'pending' | 'completed' | 'error'; + artifact_id?: string; }; diff --git a/tests/agents/test_get_artifact.py b/tests/agents/test_get_artifact.py new file mode 100644 index 00000000..e6fb0f76 --- /dev/null +++ b/tests/agents/test_get_artifact.py @@ -0,0 +1,194 @@ +from datetime import datetime + +import pytest +from bson.objectid import ObjectId +from flask import request + + +@pytest.mark.unit +class TestGetArtifact: + def test_note_artifact_success(self, mock_mongo_db, flask_app, decoded_token): + from application.core.settings import settings + from application.api.user.tools.routes import GetArtifact + + db = mock_mongo_db[settings.MONGO_DB_NAME] + note_id = ObjectId() + db["notes"].insert_one( + { + "_id": note_id, + "user_id": decoded_token["sub"], + "tool_id": "tool1", + "note": "a\nb", + "updated_at": datetime(2025, 1, 1), + } + ) + + with flask_app.app_context(): + with flask_app.test_request_context(): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get(str(note_id)) + + assert resp.status_code == 200 + assert resp.json["artifact"]["artifact_type"] == "note" + assert resp.json["artifact"]["data"]["content"] == "a\nb" + assert resp.json["artifact"]["data"]["line_count"] == 2 + + def test_todo_artifact_success(self, mock_mongo_db, flask_app, decoded_token): + from application.core.settings import settings + from application.api.user.tools.routes import GetArtifact + + db = mock_mongo_db[settings.MONGO_DB_NAME] + todo_id_1 = ObjectId() + todo_id_2 = ObjectId() + db["todos"].insert_many([ + { + "_id": todo_id_1, + "user_id": decoded_token["sub"], + "tool_id": "tool1", + "todo_id": 1, + "title": "First task", + "status": "open", + "created_at": datetime(2025, 1, 1), + "updated_at": datetime(2025, 1, 1), + }, + { + "_id": todo_id_2, + "user_id": decoded_token["sub"], + "tool_id": "tool1", + "todo_id": 2, + "title": "Second task", + "status": "completed", + "created_at": datetime(2025, 1, 1), + "updated_at": datetime(2025, 1, 2), + }, + ]) + + with flask_app.app_context(): + with flask_app.test_request_context(): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get(str(todo_id_1)) + + assert resp.status_code == 200 + assert resp.json["artifact"]["artifact_type"] == "todo_list" + data = resp.json["artifact"]["data"] + assert data["total_count"] == 2 + assert data["open_count"] == 1 + assert data["completed_count"] == 1 + assert len(data["items"]) == 2 + # Verify both todos are returned + todo_ids = [item["todo_id"] for item in data["items"]] + assert 1 in todo_ids + assert 2 in todo_ids + + def test_todo_artifact_all_param(self, mock_mongo_db, flask_app, decoded_token): + """Test that all todos are returned regardless of the 'all' query parameter.""" + from application.core.settings import settings + from application.api.user.tools.routes import GetArtifact + + db = mock_mongo_db[settings.MONGO_DB_NAME] + todo_id_1 = ObjectId() + todo_id_2 = ObjectId() + db["todos"].insert_many([ + { + "_id": todo_id_1, + "user_id": decoded_token["sub"], + "tool_id": "tool1", + "todo_id": 1, + "title": "First task", + "status": "open", + "created_at": datetime(2025, 1, 1), + "updated_at": datetime(2025, 1, 1), + }, + { + "_id": todo_id_2, + "user_id": decoded_token["sub"], + "tool_id": "tool1", + "todo_id": 2, + "title": "Second task", + "status": "completed", + "created_at": datetime(2025, 1, 1), + "updated_at": datetime(2025, 1, 2), + }, + ]) + + # Test without query parameter - should return all todos + with flask_app.app_context(): + with flask_app.test_request_context(): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get(str(todo_id_1)) + + assert resp.status_code == 200 + assert resp.json["artifact"]["artifact_type"] == "todo_list" + data = resp.json["artifact"]["data"] + assert data["total_count"] == 2 + assert data["open_count"] == 1 + assert data["completed_count"] == 1 + assert len(data["items"]) == 2 + + # Test with query parameter (should still return all todos, parameter is ignored) + with flask_app.app_context(): + with flask_app.test_request_context(query_string={"all": "true"}): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get(str(todo_id_1)) + + assert resp.status_code == 200 + assert resp.json["artifact"]["artifact_type"] == "todo_list" + data = resp.json["artifact"]["data"] + assert data["total_count"] == 2 + assert data["open_count"] == 1 + assert data["completed_count"] == 1 + assert len(data["items"]) == 2 + + def test_invalid_artifact_id_returns_400(self, mock_mongo_db, flask_app, decoded_token): + from application.api.user.tools.routes import GetArtifact + + with flask_app.app_context(): + with flask_app.test_request_context(): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get("not_an_object_id") + + assert resp.status_code == 400 + assert resp.json["message"] == "Invalid artifact ID" + + def test_artifact_not_found_returns_404(self, mock_mongo_db, flask_app, decoded_token): + from application.api.user.tools.routes import GetArtifact + + non_existent_id = ObjectId() + + with flask_app.app_context(): + with flask_app.test_request_context(): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get(str(non_existent_id)) + + assert resp.status_code == 404 + assert resp.json["message"] == "Artifact not found" + + def test_other_user_artifact_returns_404(self, mock_mongo_db, flask_app, decoded_token): + from application.core.settings import settings + from application.api.user.tools.routes import GetArtifact + + db = mock_mongo_db[settings.MONGO_DB_NAME] + note_id = ObjectId() + db["notes"].insert_one( + { + "_id": note_id, + "user_id": "other_user", + "tool_id": "tool1", + "note": "secret", + "updated_at": datetime(2025, 1, 1), + } + ) + + with flask_app.app_context(): + with flask_app.test_request_context(): + request.decoded_token = decoded_token + resource = GetArtifact() + resp = resource.get(str(note_id)) + + assert resp.status_code == 404 diff --git a/tests/conftest.py b/tests/conftest.py index 325ed406..2355808e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -191,3 +191,11 @@ def mock_tool_manager(mock_tool, monkeypatch): "application.agents.base.ToolManager", Mock(return_value=manager) ) return manager + + +@pytest.fixture +def flask_app(): + from flask import Flask + + app = Flask(__name__) + return app diff --git a/tests/test_memory_tool.py b/tests/test_memory_tool.py index be00dbd4..dc4d97d5 100644 --- a/tests/test_memory_tool.py +++ b/tests/test_memory_tool.py @@ -759,4 +759,4 @@ def test_view_file_line_numbers(memory_tool: MemoryTool) -> None: assert "3: Line 2" not in result # Wrong line number assert "4: Line 3" not in result # Wrong line number - assert "5: Line 4" not in result # Wrong line number + assert "5: Line 4" not in result # Wrong line number \ No newline at end of file diff --git a/tests/test_notes_tool.py b/tests/test_notes_tool.py index d6bd63f8..4c915506 100644 --- a/tests/test_notes_tool.py +++ b/tests/test_notes_tool.py @@ -10,20 +10,23 @@ def notes_tool(monkeypatch) -> NotesTool: class FakeCollection: def __init__(self) -> None: self.docs = {} # key: user_id:tool_id -> doc + self._id_counter = 0 + + def _generate_id(self): + self._id_counter += 1 + return f"fake_id_{self._id_counter}" def update_one(self, q, u, upsert=False): user_id = q.get("user_id") tool_id = q.get("tool_id") key = f"{user_id}:{tool_id}" - # emulate single-note storage with optional upsert - if key not in self.docs and not upsert: return type("res", (), {"modified_count": 0}) if key not in self.docs and upsert: - self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""} - if "$set" in u and "note" in u["$set"]: - self.docs[key]["note"] = u["$set"]["note"] + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": "", "_id": self._generate_id()} + if "$set" in u: + self.docs[key].update(u["$set"]) return type("res", (), {"modified_count": 1}) def find_one(self, q): @@ -32,6 +35,28 @@ def notes_tool(monkeypatch) -> NotesTool: key = f"{user_id}:{tool_id}" return self.docs.get(key) + def find_one_and_update(self, q, u, upsert=False, return_document=None): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + + if key not in self.docs and not upsert: + return None + if key not in self.docs and upsert: + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": "", "_id": self._generate_id()} + if "$set" in u: + self.docs[key].update(u["$set"]) + return self.docs[key] + + def find_one_and_delete(self, q): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + if key in self.docs: + doc = self.docs.pop(key) + return doc + return None + def delete_one(self, q): user_id = q.get("user_id") tool_id = q.get("tool_id") @@ -147,12 +172,9 @@ def test_insert_line(notes_tool: NotesTool) -> None: @pytest.mark.unit def test_delete_nonexistent_note(monkeypatch): - class FakeResult: - deleted_count = 0 - class FakeCollection: - def delete_one(self, *args, **kwargs): - return FakeResult() + def find_one_and_delete(self, q): + return None monkeypatch.setattr( "application.core.mongo_db.MongoDB.get_client", @@ -171,6 +193,11 @@ def test_notes_tool_isolation(monkeypatch) -> None: class FakeCollection: def __init__(self) -> None: self.docs = {} + self._id_counter = 0 + + def _generate_id(self): + self._id_counter += 1 + return f"fake_id_{self._id_counter}" def update_one(self, q, u, upsert=False): user_id = q.get("user_id") @@ -180,9 +207,9 @@ def test_notes_tool_isolation(monkeypatch) -> None: if key not in self.docs and not upsert: return type("res", (), {"modified_count": 0}) if key not in self.docs and upsert: - self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""} - if "$set" in u and "note" in u["$set"]: - self.docs[key]["note"] = u["$set"]["note"] + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": "", "_id": self._generate_id()} + if "$set" in u: + self.docs[key].update(u["$set"]) return type("res", (), {"modified_count": 1}) def find_one(self, q): @@ -191,6 +218,19 @@ def test_notes_tool_isolation(monkeypatch) -> None: key = f"{user_id}:{tool_id}" return self.docs.get(key) + def find_one_and_update(self, q, u, upsert=False, return_document=None): + user_id = q.get("user_id") + tool_id = q.get("tool_id") + key = f"{user_id}:{tool_id}" + + if key not in self.docs and not upsert: + return None + if key not in self.docs and upsert: + self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": "", "_id": self._generate_id()} + if "$set" in u: + self.docs[key].update(u["$set"]) + return self.docs[key] + fake_collection = FakeCollection() fake_db = {"notes": fake_collection} fake_client = {settings.MONGO_DB_NAME: fake_db} diff --git a/tests/test_todo_tool.py b/tests/test_todo_tool.py index 5fa2b242..fc751edf 100644 --- a/tests/test_todo_tool.py +++ b/tests/test_todo_tool.py @@ -24,14 +24,21 @@ class FakeCursor(list): class FakeCollection: def __init__(self): self.docs = {} + self._id_counter = 0 + + def _generate_id(self): + self._id_counter += 1 + return f"fake_id_{self._id_counter}" def create_index(self, *args, **kwargs): pass def insert_one(self, doc): key = (doc["user_id"], doc["tool_id"], doc["todo_id"]) + if "_id" not in doc: + doc["_id"] = self._generate_id() self.docs[key] = doc - return type("res", (), {"inserted_id": key}) + return type("res", (), {"inserted_id": doc["_id"]}) def find_one(self, query): key = (query.get("user_id"), query.get("tool_id"), query.get("todo_id")) @@ -52,12 +59,25 @@ class FakeCollection: self.docs[key].update(update.get("$set", {})) return type("res", (), {"matched_count": 1}) elif upsert: - new_doc = {**query, **update.get("$set", {})} + new_doc = {**query, **update.get("$set", {}), "_id": self._generate_id()} self.docs[key] = new_doc return type("res", (), {"matched_count": 1}) else: return type("res", (), {"matched_count": 0}) + def find_one_and_update(self, query, update): + key = (query.get("user_id"), query.get("tool_id"), query.get("todo_id")) + if key in self.docs: + self.docs[key].update(update.get("$set", {})) + return self.docs[key] + return None + + def find_one_and_delete(self, query): + key = (query.get("user_id"), query.get("tool_id"), query.get("todo_id")) + if key in self.docs: + return self.docs.pop(key) + return None + def delete_one(self, query): key = (query.get("user_id"), query.get("tool_id"), query.get("todo_id")) if key in self.docs: