mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Compare commits
37 Commits
hacktoberf
...
github-fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2db33b2d82 | ||
|
|
da6317a242 | ||
|
|
8b8e616557 | ||
|
|
d260f1a1a6 | ||
|
|
9d452e3b04 | ||
|
|
e012189672 | ||
|
|
4c31e9a8b1 | ||
|
|
7cfc230316 | ||
|
|
9605e85f1c | ||
|
|
498e2b772c | ||
|
|
dad897da51 | ||
|
|
02ad5f062e | ||
|
|
4eb9471b4f | ||
|
|
b505d207d7 | ||
|
|
3c954bd07f | ||
|
|
c00b6459dc | ||
|
|
eb4d776784 | ||
|
|
5d7a890533 | ||
|
|
9c6aefef1e | ||
|
|
e4554d6c09 | ||
|
|
c184b63df8 | ||
|
|
6bb4195393 | ||
|
|
7827a4d40d | ||
|
|
f09fa8231a | ||
|
|
96ff10000d | ||
|
|
9460636867 | ||
|
|
6c43245295 | ||
|
|
266b6cf638 | ||
|
|
70183e234a | ||
|
|
17b9c359ca | ||
|
|
045630b8a5 | ||
|
|
55ff7dd640 | ||
|
|
5b2738aec9 | ||
|
|
892312fc08 | ||
|
|
444b1a0b65 | ||
|
|
814ea1c016 | ||
|
|
7c15a4c7ff |
@@ -35,4 +35,4 @@ Non-Code Contributions:
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish a t-shirt desing later into the October.
|
||||
We will publish a t-shirt design later into the October.
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
@@ -67,7 +67,7 @@
|
||||
- [x] Json Responses (August 2025)
|
||||
- [x] MCP support (August 2025)
|
||||
- [x] Google Drive integration (September 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [ ] Sharepoint integration (October 2025)
|
||||
- [ ] Deep Agents (October 2025)
|
||||
- [ ] Agent scheduling
|
||||
@@ -118,7 +118,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||
```
|
||||
|
||||
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
Either script will guide you through setting up DocsGPT. Five options available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or build the docker image locally. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**Navigate to http://localhost:5173/**
|
||||
|
||||
|
||||
@@ -213,18 +213,24 @@ class BaseAgent(ABC):
|
||||
):
|
||||
target_dict[param] = value
|
||||
tm = ToolManager(config={})
|
||||
|
||||
# Prepare tool_config and add tool_id for memory tools
|
||||
if tool_data["name"] == "api_tool":
|
||||
tool_config = {
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
||||
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
tool_config=tool_config,
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
|
||||
@@ -20,9 +20,10 @@ with open(
|
||||
"r",
|
||||
) as f:
|
||||
final_prompt_template = f.read()
|
||||
|
||||
|
||||
MAX_ITERATIONS_REASONING = 10
|
||||
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -38,49 +39,69 @@ class ReActAgent(BaseAgent):
|
||||
collected_content = []
|
||||
if isinstance(resp, str):
|
||||
collected_content.append(resp)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
collected_content.append(resp.message.content)
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices") and resp.choices and
|
||||
hasattr(resp.choices[0], "message") and
|
||||
hasattr(resp.choices[0].message, "content") and
|
||||
resp.choices[0].message.content is not None
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices")
|
||||
and resp.choices
|
||||
and hasattr(resp.choices[0], "message")
|
||||
and hasattr(resp.choices[0].message, "content")
|
||||
and resp.choices[0].message.content is not None
|
||||
):
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
|
||||
hasattr(resp.content[0], "text")
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content")
|
||||
and isinstance(resp.content, list)
|
||||
and resp.content
|
||||
and hasattr(resp.content[0], "text")
|
||||
):
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
else:
|
||||
# Assume resp is a stream if not a recognized object
|
||||
chunk = None
|
||||
try:
|
||||
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
for (
|
||||
chunk
|
||||
) in (
|
||||
resp
|
||||
): # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
content_piece = ""
|
||||
# OpenAI-like stream
|
||||
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
|
||||
hasattr(chunk.choices[0], 'delta') and \
|
||||
hasattr(chunk.choices[0].delta, 'content') and \
|
||||
chunk.choices[0].delta.content is not None:
|
||||
if (
|
||||
hasattr(chunk, "choices")
|
||||
and len(chunk.choices) > 0
|
||||
and hasattr(chunk.choices[0], "delta")
|
||||
and hasattr(chunk.choices[0].delta, "content")
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
content_piece = chunk.choices[0].delta.content
|
||||
# Anthropic-like stream (ContentBlockDelta)
|
||||
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
|
||||
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
|
||||
elif (
|
||||
hasattr(chunk, "type")
|
||||
and chunk.type == "content_block_delta"
|
||||
and hasattr(chunk, "delta")
|
||||
and hasattr(chunk.delta, "text")
|
||||
):
|
||||
content_piece = chunk.delta.text
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
content_piece = chunk
|
||||
|
||||
if content_piece:
|
||||
collected_content.append(content_piece)
|
||||
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
|
||||
except (
|
||||
TypeError
|
||||
): # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(
|
||||
f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
|
||||
|
||||
logger.error(
|
||||
f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk) if chunk is not None else 'N/A'}"
|
||||
)
|
||||
|
||||
return "".join(collected_content)
|
||||
|
||||
@@ -112,8 +133,9 @@ class ReActAgent(BaseAgent):
|
||||
yield {"thought": line_chunk}
|
||||
self.plan = "".join(current_plan_parts)
|
||||
if self.plan:
|
||||
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
|
||||
|
||||
self.observations.append(
|
||||
f"Plan: {self.plan} Iteration: {iterating_reasoning}"
|
||||
)
|
||||
|
||||
max_obs_len = 20000
|
||||
obs_str = "\n".join(self.observations)
|
||||
@@ -125,34 +147,55 @@ class ReActAgent(BaseAgent):
|
||||
+ f"\n\nObservations:\n{obs_str}"
|
||||
+ f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
|
||||
)
|
||||
|
||||
|
||||
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
|
||||
|
||||
resp_from_llm_gen = self._llm_gen(messages, log_context)
|
||||
|
||||
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
|
||||
initial_llm_thought_content = self._extract_content_from_llm_response(
|
||||
resp_from_llm_gen
|
||||
)
|
||||
if initial_llm_thought_content:
|
||||
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
|
||||
self.observations.append(
|
||||
f"Initial thought/response: {initial_llm_thought_content}"
|
||||
)
|
||||
else:
|
||||
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
|
||||
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
|
||||
|
||||
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
|
||||
logger.info(
|
||||
"ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls)."
|
||||
)
|
||||
resp_after_handler = self._llm_handler(
|
||||
resp_from_llm_gen, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
for (
|
||||
tool_call_info
|
||||
) in (
|
||||
self.tool_calls
|
||||
): # Iterate over self.tool_calls populated by _llm_handler
|
||||
observation_string = (
|
||||
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
|
||||
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
|
||||
)
|
||||
self.observations.append(observation_string)
|
||||
|
||||
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
|
||||
content_after_handler = self._extract_content_from_llm_response(
|
||||
resp_after_handler
|
||||
)
|
||||
if content_after_handler:
|
||||
self.observations.append(f"Response after tool execution: {content_after_handler}")
|
||||
self.observations.append(
|
||||
f"Response after tool execution: {content_after_handler}"
|
||||
)
|
||||
else:
|
||||
logger.info("ReActAgent: LLM response after handler had no textual content.")
|
||||
logger.info(
|
||||
"ReActAgent: LLM response after handler had no textual content."
|
||||
)
|
||||
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
{
|
||||
"component": "agent_tool_calls",
|
||||
"data": {"tool_calls": self.tool_calls.copy()},
|
||||
}
|
||||
)
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
@@ -165,13 +208,17 @@ class ReActAgent(BaseAgent):
|
||||
display_tool_calls.append(cleaned_tc)
|
||||
if display_tool_calls:
|
||||
yield {"tool_calls": display_tool_calls}
|
||||
|
||||
|
||||
if "SATISFIED" in content_after_handler:
|
||||
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
|
||||
logger.info(
|
||||
"ReActAgent: LLM satisfied with the plan and data. Stopping reasoning."
|
||||
)
|
||||
break
|
||||
|
||||
# 3. Create Final Answer based on all observations
|
||||
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
|
||||
final_answer_stream = self._create_final_answer(
|
||||
query, self.observations, log_context
|
||||
)
|
||||
for answer_chunk in final_answer_stream:
|
||||
yield {"answer": answer_chunk}
|
||||
logger.info("ReActAgent: Finished generating final answer.")
|
||||
@@ -184,12 +231,16 @@ class ReActAgent(BaseAgent):
|
||||
summaries = docs_data if docs_data else "No documents retrieved."
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
|
||||
plan_prompt_filled = plan_prompt_filled.replace(
|
||||
"{observations}", "\n".join(self.observations)
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": plan_prompt_filled}]
|
||||
|
||||
plan_stream_from_llm = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
|
||||
model=self.gpt_model,
|
||||
messages=messages,
|
||||
tools=getattr(self, "tools", None), # Use self.tools
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm)
|
||||
@@ -206,8 +257,12 @@ class ReActAgent(BaseAgent):
|
||||
observation_string = "\n".join(observations)
|
||||
max_obs_len = 10000
|
||||
if len(observation_string) > max_obs_len:
|
||||
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
|
||||
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
|
||||
observation_string = (
|
||||
observation_string[:max_obs_len] + "\n...[observations truncated]"
|
||||
)
|
||||
logger.warning(
|
||||
"ReActAgent: Truncated observations for final answer prompt due to length."
|
||||
)
|
||||
|
||||
final_answer_prompt_filled = final_prompt_template.format(
|
||||
query=query, observations=observation_string
|
||||
@@ -226,4 +281,4 @@ class ReActAgent(BaseAgent):
|
||||
for chunk in final_answer_stream_from_llm:
|
||||
content_piece = self._extract_content_from_llm_response(chunk)
|
||||
if content_piece:
|
||||
yield content_piece
|
||||
yield content_piece
|
||||
|
||||
@@ -37,7 +37,7 @@ _mcp_clients_cache = {}
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
|
||||
546
application/agents/tools/memory.py
Normal file
546
application/agents/tools/memory.py
Normal file
@@ -0,0 +1,546 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated memories
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||
},
|
||||
"view_range": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create or overwrite a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||
},
|
||||
"file_text": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "file_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "String to find."
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "String to replace with."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_str", "new_str"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at a specific line in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"insert_line": {
|
||||
"type": "integer",
|
||||
"description": "Line number to insert at (1-indexed)."
|
||||
},
|
||||
"insert_text": {
|
||||
"type": "string",
|
||||
"description": "Text to insert."
|
||||
}
|
||||
},
|
||||
"required": ["path", "insert_line", "insert_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "rename",
|
||||
"description": "Rename or move a file/directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_path": {
|
||||
"type": "string",
|
||||
"description": "Current path (e.g., /old.txt)."
|
||||
},
|
||||
"new_path": {
|
||||
"type": "string",
|
||||
"description": "New path (e.g., /new.txt)."
|
||||
}
|
||||
},
|
||||
"required": ["old_path", "new_path"]
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Preserve whether path ends with / (indicates directory)
|
||||
is_directory = path.endswith("/")
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
# Preserve trailing slash for directories
|
||||
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||
normalized = normalized + "/"
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View directory contents or file contents."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
# Check if viewing directory (ends with / or is root)
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return self._view_directory(validated_path)
|
||||
|
||||
# Otherwise view file
|
||||
return self._view_file(validated_path, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
|
||||
"""List files in a directory."""
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
# Find all files that start with this directory path
|
||||
query = {
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
}
|
||||
|
||||
docs = list(self.collection.find(query, {"path": 1}))
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
|
||||
# Extract filenames relative to the directory
|
||||
files = []
|
||||
for doc in docs:
|
||||
file_path = doc["path"]
|
||||
# Remove the directory prefix
|
||||
if file_path.startswith(search_path):
|
||||
relative = file_path[len(search_path):]
|
||||
if relative:
|
||||
files.append(relative)
|
||||
|
||||
files.sort()
|
||||
file_list = "\n".join(f"- {f}" for f in files)
|
||||
return f"Directory: {path}\n{file_list}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers (enumerate with 1-based start)
|
||||
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
|
||||
"""Create or overwrite a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": file_text,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||
"""Insert text at a specific line."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
def _delete(self, path: str) -> str:
|
||||
"""Delete a file or directory."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
# Delete all files in directory
|
||||
result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||
})
|
||||
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||
|
||||
# Try to delete as directory first (without trailing slash)
|
||||
# Check if any files start with this path + /
|
||||
search_path = validated_path + "/"
|
||||
directory_result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
})
|
||||
|
||||
if directory_result.deleted_count > 0:
|
||||
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||
|
||||
# Delete single file
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_path
|
||||
})
|
||||
|
||||
if result.deleted_count:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
|
||||
"""Rename or move a file/directory."""
|
||||
validated_old = self._validate_path(old_path)
|
||||
validated_new = self._validate_path(new_path)
|
||||
|
||||
if not validated_old or not validated_new:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Check if renaming a directory
|
||||
if validated_old.endswith("/"):
|
||||
# Ensure validated_new also ends with / for proper path replacement
|
||||
if not validated_new.endswith("/"):
|
||||
validated_new = validated_new + "/"
|
||||
|
||||
# Find all files in the old directory
|
||||
docs = list(self.collection.find({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||
}))
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
# Update paths for all files
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||
|
||||
self.collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Rename single file
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_old
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
# Check if new path already exists
|
||||
existing = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new
|
||||
})
|
||||
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
# Delete the old document and create a new one with the new path
|
||||
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||
self.collection.insert_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new,
|
||||
"content": doc.get("content", ""),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
199
application/agents/tools/notes.py
Normal file
199
application/agents/tools/notes.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
"""Notepad
|
||||
|
||||
Single note. Supports viewing, overwriting, string replacement.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated notes
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
if action_name == "overwrite":
|
||||
return self._overwrite_note(kwargs.get("text", ""))
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete_note()
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Retrieve the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "overwrite",
|
||||
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "New note content."}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_str": {"type": "string", "description": "String to find."},
|
||||
"new_str": {"type": "string", "description": "String to replace with."}
|
||||
},
|
||||
"required": ["old_str", "new_str"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at the specified line number (1-indexed).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||
"text": {"type": "string", "description": "Text to insert."}
|
||||
},
|
||||
"required": ["line_number", "text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _get_note(self) -> str:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True, # ✅ create if missing
|
||||
)
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
return f"String '{old_str}' not found in note."
|
||||
|
||||
# Case-insensitive replacement
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
index = line_number - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Invalid line number. Note has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
)
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||
@@ -28,7 +28,7 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
if tool_name in {"mcp_tool", "notes", "memory"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
@@ -36,7 +36,7 @@ class ToolManager:
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
if tool_name in {"mcp_tool", "memory"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""User API module - provides all user-related API endpoints"""
|
||||
|
||||
from .routes import user
|
||||
|
||||
__all__ = ["user"]
|
||||
|
||||
7
application/api/user/agents/__init__.py
Normal file
7
application/api/user/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Agents module."""
|
||||
|
||||
from .routes import agents_ns
|
||||
from .sharing import agents_sharing_ns
|
||||
from .webhooks import agents_webhooks_ns
|
||||
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
|
||||
974
application/api/user/agents/routes.py
Normal file
974
application/api/user/agents/routes.py
Normal file
@@ -0,0 +1,974 @@
|
||||
"""Agent management routes."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
handle_image_upload,
|
||||
resolve_tool_details,
|
||||
storage,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import (
|
||||
check_required_fields,
|
||||
generate_image_url,
|
||||
validate_required_fields,
|
||||
)
|
||||
|
||||
|
||||
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
|
||||
|
||||
|
||||
@agents_ns.route("/get_agent")
|
||||
class GetAgent(Resource):
|
||||
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
|
||||
def get(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
if not (agent_id := request.args.get("id")):
|
||||
return {"success": False, "message": "ID required"}, 400
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
if not agent:
|
||||
return {"status": "Not found"}, 404
|
||||
data = {
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent["name"],
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"sources": [
|
||||
(
|
||||
str(db.dereference(source_ref)["_id"])
|
||||
if isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
else source_ref
|
||||
)
|
||||
for source_ref in agent.get("sources", [])
|
||||
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
|
||||
or source_ref == "default"
|
||||
],
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
"key": (
|
||||
f"{agent['key'][:4]}...{agent['key'][-4:]}"
|
||||
if "key" in agent
|
||||
else ""
|
||||
),
|
||||
"pinned": agent.get("pinned", False),
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Agent fetch error: {e}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
|
||||
|
||||
@agents_ns.route("/get_agents")
|
||||
class GetAgents(Resource):
|
||||
@api.doc(description="Retrieve agents for the user")
|
||||
def get(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
user_doc = ensure_user_doc(user)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
agents = agents_collection.find({"user": user})
|
||||
list_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent["name"],
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(agent.get("source")))
|
||||
else (
|
||||
agent.get("source", "")
|
||||
if agent.get("source") == "default"
|
||||
else ""
|
||||
)
|
||||
),
|
||||
"sources": [
|
||||
(
|
||||
source_ref
|
||||
if source_ref == "default"
|
||||
else str(db.dereference(source_ref)["_id"])
|
||||
)
|
||||
for source_ref in agent.get("sources", [])
|
||||
if source_ref == "default"
|
||||
or (
|
||||
isinstance(source_ref, DBRef) and db.dereference(source_ref)
|
||||
)
|
||||
],
|
||||
"chunks": agent["chunks"],
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
"key": (
|
||||
f"{agent['key'][:4]}...{agent['key'][-4:]}"
|
||||
if "key" in agent
|
||||
else ""
|
||||
),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_agents), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/create_agent")
|
||||
class CreateAgent(Resource):
|
||||
create_agent_model = api.model(
|
||||
"CreateAgentModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="Name of the agent"),
|
||||
"description": fields.String(
|
||||
required=True, description="Description of the agent"
|
||||
),
|
||||
"image": fields.Raw(
|
||||
required=False, description="Image file upload", type="file"
|
||||
),
|
||||
"source": fields.String(
|
||||
required=False, description="Source ID (legacy single source)"
|
||||
),
|
||||
"sources": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
"tools": fields.List(
|
||||
fields.String, required=False, description="List of tool identifiers"
|
||||
),
|
||||
"agent_type": fields.String(required=True, description="Type of the agent"),
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False,
|
||||
description="JSON schema for enforcing structured output format",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_agent_model)
|
||||
@api.doc(description="Create a new agent")
|
||||
def post(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return {"success": False}, 401
|
||||
user = decoded_token.get("sub")
|
||||
if request.content_type == "application/json":
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
if "tools" in data:
|
||||
try:
|
||||
data["tools"] = json.loads(data["tools"])
|
||||
except json.JSONDecodeError:
|
||||
data["tools"] = []
|
||||
if "sources" in data:
|
||||
try:
|
||||
data["sources"] = json.loads(data["sources"])
|
||||
except json.JSONDecodeError:
|
||||
data["sources"] = []
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
data["json_schema"] = json.loads(data["json_schema"])
|
||||
except json.JSONDecodeError:
|
||||
data["json_schema"] = None
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
|
||||
if data.get("json_schema"):
|
||||
try:
|
||||
# Basic validation - ensure it's a valid JSON structure
|
||||
|
||||
json_schema = data.get("json_schema")
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid JSON object",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Status must be either 'draft' or 'published'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if data.get("status") == "published":
|
||||
required_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
"agent_type",
|
||||
]
|
||||
# Require either source or sources (but not both)
|
||||
|
||||
if not data.get("source") and not data.get("sources"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Either 'source' or 'sources' field is required for published agents",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
validate_fields = ["name", "description", "prompt_id", "agent_type"]
|
||||
else:
|
||||
required_fields = ["name"]
|
||||
validate_fields = []
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
invalid_fields = validate_required_fields(data, validate_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if invalid_fields:
|
||||
return invalid_fields
|
||||
image_url, error = handle_image_upload(request, "", user, storage)
|
||||
if error:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}), 400
|
||||
)
|
||||
try:
|
||||
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
|
||||
|
||||
sources_list = []
|
||||
if data.get("sources") and len(data.get("sources", [])) > 0:
|
||||
for source_id in data.get("sources", []):
|
||||
if source_id == "default":
|
||||
sources_list.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
sources_list.append(DBRef("sources", ObjectId(source_id)))
|
||||
source_field = ""
|
||||
else:
|
||||
source_value = data.get("source", "")
|
||||
if source_value == "default":
|
||||
source_field = "default"
|
||||
elif ObjectId.is_valid(source_value):
|
||||
source_field = DBRef("sources", ObjectId(source_value))
|
||||
else:
|
||||
source_field = ""
|
||||
new_agent = {
|
||||
"user": user,
|
||||
"name": data.get("name"),
|
||||
"description": data.get("description", ""),
|
||||
"image": image_url,
|
||||
"source": source_field,
|
||||
"sources": sources_list,
|
||||
"chunks": data.get("chunks", ""),
|
||||
"retriever": data.get("retriever", ""),
|
||||
"prompt_id": data.get("prompt_id", ""),
|
||||
"tools": data.get("tools", []),
|
||||
"agent_type": data.get("agent_type", ""),
|
||||
"status": data.get("status"),
|
||||
"json_schema": data.get("json_schema"),
|
||||
"createdAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
|
||||
"lastUsedAt": None,
|
||||
"key": key,
|
||||
}
|
||||
if new_agent["chunks"] == "":
|
||||
new_agent["chunks"] = "2"
|
||||
if (
|
||||
new_agent["source"] == ""
|
||||
and new_agent["retriever"] == ""
|
||||
and not new_agent["sources"]
|
||||
):
|
||||
new_agent["retriever"] = "classic"
|
||||
resp = agents_collection.insert_one(new_agent)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id, "key": key}), 201)
|
||||
|
||||
|
||||
@agents_ns.route("/update_agent/<string:agent_id>")
|
||||
class UpdateAgent(Resource):
|
||||
update_agent_model = api.model(
|
||||
"UpdateAgentModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="New name of the agent"),
|
||||
"description": fields.String(
|
||||
required=True, description="New description of the agent"
|
||||
),
|
||||
"image": fields.String(
|
||||
required=False, description="New image URL or identifier"
|
||||
),
|
||||
"source": fields.String(
|
||||
required=False, description="Source ID (legacy single source)"
|
||||
),
|
||||
"sources": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="List of source identifiers for multiple sources",
|
||||
),
|
||||
"chunks": fields.Integer(required=True, description="Chunks count"),
|
||||
"retriever": fields.String(required=True, description="Retriever ID"),
|
||||
"prompt_id": fields.String(required=True, description="Prompt ID"),
|
||||
"tools": fields.List(
|
||||
fields.String, required=False, description="List of tool identifiers"
|
||||
),
|
||||
"agent_type": fields.String(required=True, description="Type of the agent"),
|
||||
"status": fields.String(
|
||||
required=True, description="Status of the agent (draft or published)"
|
||||
),
|
||||
"json_schema": fields.Raw(
|
||||
required=False,
|
||||
description="JSON schema for enforcing structured output format",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_agent_model)
|
||||
@api.doc(description="Update an existing agent")
|
||||
def put(self, agent_id):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
if not ObjectId.is_valid(agent_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID format"}), 400
|
||||
)
|
||||
oid = ObjectId(agent_id)
|
||||
|
||||
try:
|
||||
if request.content_type and "application/json" in request.content_type:
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = request.form.to_dict()
|
||||
json_fields = ["tools", "sources", "json_schema"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
try:
|
||||
data[field] = json.loads(data[field])
|
||||
except json.JSONDecodeError:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid JSON format for field: {field}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error parsing request data: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid request data"}), 400
|
||||
)
|
||||
try:
|
||||
existing_agent = agents_collection.find_one({"_id": oid, "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error finding agent {agent_id}: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error finding agent"}),
|
||||
500,
|
||||
)
|
||||
if not existing_agent:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Agent not found or not authorized"}
|
||||
),
|
||||
404,
|
||||
)
|
||||
image_url, error = handle_image_upload(
|
||||
request, existing_agent.get("image", ""), user, storage
|
||||
)
|
||||
if error:
|
||||
current_app.logger.error(
|
||||
f"Image upload error for agent {agent_id}: {error}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
|
||||
400,
|
||||
)
|
||||
update_fields = {}
|
||||
allowed_fields = [
|
||||
"name",
|
||||
"description",
|
||||
"image",
|
||||
"source",
|
||||
"sources",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"prompt_id",
|
||||
"tools",
|
||||
"agent_type",
|
||||
"status",
|
||||
"json_schema",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
if field not in data:
|
||||
continue
|
||||
if field == "status":
|
||||
new_status = data.get("status")
|
||||
if new_status not in ["draft", "published"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid status value. Must be 'draft' or 'published'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = new_status
|
||||
elif field == "source":
|
||||
source_id = data.get("source")
|
||||
if source_id == "default":
|
||||
update_fields[field] = "default"
|
||||
elif source_id and ObjectId.is_valid(source_id):
|
||||
update_fields[field] = DBRef("sources", ObjectId(source_id))
|
||||
elif source_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid source ID format: {source_id}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
else:
|
||||
update_fields[field] = ""
|
||||
elif field == "sources":
|
||||
sources_list = data.get("sources", [])
|
||||
if sources_list and isinstance(sources_list, list):
|
||||
valid_sources = []
|
||||
for source_id in sources_list:
|
||||
if source_id == "default":
|
||||
valid_sources.append("default")
|
||||
elif ObjectId.is_valid(source_id):
|
||||
valid_sources.append(DBRef("sources", ObjectId(source_id)))
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid source ID in list: {source_id}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = valid_sources
|
||||
else:
|
||||
update_fields[field] = []
|
||||
elif field == "chunks":
|
||||
chunks_value = data.get("chunks")
|
||||
if chunks_value == "" or chunks_value is None:
|
||||
update_fields[field] = "2"
|
||||
else:
|
||||
try:
|
||||
chunks_int = int(chunks_value)
|
||||
if chunks_int < 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Chunks value must be a non-negative integer",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = str(chunks_int)
|
||||
except (ValueError, TypeError):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid chunks value: {chunks_value}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
elif field == "tools":
|
||||
tools_list = data.get("tools", [])
|
||||
if isinstance(tools_list, list):
|
||||
update_fields[field] = tools_list
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Tools must be a list",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
elif field == "json_schema":
|
||||
json_schema = data.get("json_schema")
|
||||
if json_schema is not None:
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid object",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = json_schema
|
||||
else:
|
||||
update_fields[field] = None
|
||||
else:
|
||||
value = data[field]
|
||||
if field in ["name", "description", "prompt_id", "agent_type"]:
|
||||
if not value or not str(value).strip():
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Field '{field}' cannot be empty",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = value
|
||||
if image_url:
|
||||
update_fields["image"] = image_url
|
||||
if not update_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "No valid update data provided",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
newly_generated_key = None
|
||||
final_status = update_fields.get("status", existing_agent.get("status"))
|
||||
|
||||
if final_status == "published":
|
||||
required_published_fields = {
|
||||
"name": "Agent name",
|
||||
"description": "Agent description",
|
||||
"chunks": "Chunks count",
|
||||
"prompt_id": "Prompt",
|
||||
"agent_type": "Agent type",
|
||||
}
|
||||
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in required_published_fields.items():
|
||||
final_value = update_fields.get(
|
||||
req_field, existing_agent.get(req_field)
|
||||
)
|
||||
if not final_value:
|
||||
missing_published_fields.append(field_label)
|
||||
source_val = update_fields.get("source", existing_agent.get("source"))
|
||||
sources_val = update_fields.get(
|
||||
"sources", existing_agent.get("sources", [])
|
||||
)
|
||||
|
||||
has_valid_source = (
|
||||
isinstance(source_val, DBRef)
|
||||
or source_val == "default"
|
||||
or (isinstance(sources_val, list) and len(sources_val) > 0)
|
||||
)
|
||||
|
||||
if not has_valid_source:
|
||||
missing_published_fields.append("Source")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if not existing_agent.get("key"):
|
||||
newly_generated_key = str(uuid.uuid4())
|
||||
update_fields["key"] = newly_generated_key
|
||||
update_fields["updatedAt"] = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
try:
|
||||
result = agents_collection.update_one(
|
||||
{"_id": oid, "user": user}, {"$set": update_fields}
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Agent not found or update failed",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
if result.modified_count == 0 and result.matched_count == 1:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": "No changes detected",
|
||||
"id": agent_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating agent {agent_id}: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error during update"}),
|
||||
500,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": agent_id,
|
||||
"message": "Agent updated successfully",
|
||||
}
|
||||
if newly_generated_key:
|
||||
response_data["key"] = newly_generated_key
|
||||
return make_response(jsonify(response_data), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/delete_agent")
|
||||
class DeleteAgent(Resource):
|
||||
@api.doc(params={"id": "ID of the agent"}, description="Delete an agent by ID")
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
deleted_agent = agents_collection.find_one_and_delete(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not deleted_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
deleted_id = str(deleted_agent["_id"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": deleted_id}), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/pinned_agents")
|
||||
class PinnedAgents(Resource):
|
||||
@api.doc(description="Get pinned agents for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
try:
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
pinned_ids = user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
|
||||
if not pinned_ids:
|
||||
return make_response(jsonify([]), 200)
|
||||
pinned_object_ids = [ObjectId(agent_id) for agent_id in pinned_ids]
|
||||
|
||||
pinned_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": pinned_object_ids}}
|
||||
)
|
||||
pinned_agents = list(pinned_agents_cursor)
|
||||
existing_ids = {str(agent["_id"]) for agent in pinned_agents}
|
||||
|
||||
# Clean up any stale pinned IDs
|
||||
|
||||
stale_ids = [
|
||||
agent_id for agent_id in pinned_ids if agent_id not in existing_ids
|
||||
]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
|
||||
)
|
||||
list_pinned_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"source": (
|
||||
str(db.dereference(agent["source"])["_id"])
|
||||
if "source" in agent
|
||||
and agent["source"]
|
||||
and isinstance(agent["source"], DBRef)
|
||||
and db.dereference(agent["source"]) is not None
|
||||
else ""
|
||||
),
|
||||
"chunks": agent.get("chunks", ""),
|
||||
"retriever": agent.get("retriever", ""),
|
||||
"prompt_id": agent.get("prompt_id", ""),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"last_used_at": agent.get("lastUsedAt", ""),
|
||||
"key": (
|
||||
f"{agent['key'][:4]}...{agent['key'][-4:]}"
|
||||
if "key" in agent
|
||||
else ""
|
||||
),
|
||||
"pinned": True,
|
||||
}
|
||||
for agent in pinned_agents
|
||||
if "source" in agent or "retriever" in agent
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving pinned agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_pinned_agents), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/template_agents")
|
||||
class GetTemplateAgents(Resource):
|
||||
@api.doc(description="Get template/premade agents")
|
||||
def get(self):
|
||||
try:
|
||||
template_agents = agents_collection.find({"user": "system"})
|
||||
template_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent["name"],
|
||||
"description": agent["description"],
|
||||
"image": agent.get("image", ""),
|
||||
}
|
||||
for agent in template_agents
|
||||
]
|
||||
return make_response(jsonify(template_agents), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Template agents fetch error: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_ns.route("/adopt_agent")
|
||||
class AdoptAgent(Resource):
|
||||
@api.doc(params={"id": "Agent ID"}, description="Adopt an agent by ID")
|
||||
def post(self):
|
||||
if not (decoded_token := request.decoded_token):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
if not (agent_id := request.args.get("id")):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID required"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": "system"}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(jsonify({"status": "Not found"}), 404)
|
||||
|
||||
new_agent = agent.copy()
|
||||
new_agent.pop("_id", None)
|
||||
new_agent["user"] = decoded_token["sub"]
|
||||
new_agent["status"] = "published"
|
||||
new_agent["lastUsedAt"] = datetime.datetime.now(datetime.timezone.utc)
|
||||
new_agent["key"] = str(uuid.uuid4())
|
||||
insert_result = agents_collection.insert_one(new_agent)
|
||||
|
||||
response_agent = new_agent.copy()
|
||||
response_agent.pop("_id", None)
|
||||
response_agent["id"] = str(insert_result.inserted_id)
|
||||
response_agent["tool_details"] = resolve_tool_details(
|
||||
response_agent.get("tools", [])
|
||||
)
|
||||
if isinstance(response_agent.get("source"), DBRef):
|
||||
response_agent["source"] = str(response_agent["source"].id)
|
||||
return make_response(
|
||||
jsonify({"success": True, "agent": response_agent}), 200
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Agent adopt error: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_ns.route("/pin_agent")
|
||||
class PinAgent(Resource):
|
||||
@api.doc(params={"id": "ID of the agent"}, description="Pin or unpin an agent")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
pinned_list = user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
|
||||
if agent_id in pinned_list:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pull": {"agent_preferences.pinned": agent_id}},
|
||||
)
|
||||
action = "unpinned"
|
||||
else:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.pinned": agent_id}},
|
||||
)
|
||||
action = "pinned"
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Server error"}), 500
|
||||
)
|
||||
return make_response(jsonify({"success": True, "action": action}), 200)
|
||||
|
||||
|
||||
@agents_ns.route("/remove_shared_agent")
|
||||
class RemoveSharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the shared agent"},
|
||||
description="Remove a shared agent from the current user's shared list",
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "shared_publicly": True}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{
|
||||
"$pull": {
|
||||
"agent_preferences.shared_with_me": agent_id,
|
||||
"agent_preferences.pinned": agent_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True, "action": "removed"}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error removing shared agent: {err}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Server error"}), 500
|
||||
)
|
||||
254
application/api/user/agents/sharing.py
Normal file
254
application/api/user/agents/sharing.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Agent management sharing functionality."""
|
||||
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={
|
||||
"token": "Shared token of the agent",
|
||||
},
|
||||
description="Get a shared agent by token or ID",
|
||||
)
|
||||
def get(self):
|
||||
shared_token = request.args.get("token")
|
||||
|
||||
if not shared_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
shared_agent = agents_collection.find_one(query)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
"user": shared_agent.get("user", ""),
|
||||
"name": shared_agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(shared_agent["image"])
|
||||
if shared_agent.get("image")
|
||||
else ""
|
||||
),
|
||||
"description": shared_agent.get("description", ""),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(shared_agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"chunks": shared_agent.get("chunks", "0"),
|
||||
"retriever": shared_agent.get("retriever", "classic"),
|
||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||
"tools": shared_agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
"shared_token": shared_agent.get("shared_token", ""),
|
||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||
}
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for tool in data["tools"]:
|
||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user")
|
||||
|
||||
if user_id != owner_id:
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agents")
|
||||
class SharedAgents(Resource):
|
||||
@api.doc(description="Get shared agents explicitly shared with the user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||
"shared_with_me", []
|
||||
)
|
||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||
|
||||
shared_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||
)
|
||||
shared_agents = list(shared_agents_cursor)
|
||||
|
||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
list_shared_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
}
|
||||
for agent in shared_agents
|
||||
]
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/share_agent")
|
||||
class ShareAgent(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ShareAgentModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="ID of the agent"),
|
||||
"shared": fields.Boolean(
|
||||
required=True, description="Share or unshare the agent"
|
||||
),
|
||||
"username": fields.String(
|
||||
required=False, description="Name of the user"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Share or unshare an agent")
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing JSON body"}), 400
|
||||
)
|
||||
agent_id = data.get("id")
|
||||
shared = data.get("shared")
|
||||
username = data.get("username", "")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
if shared is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Shared parameter is required and must be true or false",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{
|
||||
"$set": {
|
||||
"shared_publicly": shared,
|
||||
"shared_metadata": shared_metadata,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
119
application/api/user/agents/webhooks.py
Normal file
119
application/api/user/agents/webhooks.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
agents_webhooks_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the agent"},
|
||||
description="Generate webhook URL for the agent",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id), "user": user},
|
||||
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error generating webhook URL: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error generating webhook URL"}),
|
||||
400,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
if payload is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid or missing JSON data in request body",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
5
application/api/user/analytics/__init__.py
Normal file
5
application/api/user/analytics/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Analytics module."""
|
||||
|
||||
from .routes import analytics_ns
|
||||
|
||||
__all__ = ["analytics_ns"]
|
||||
540
application/api/user/analytics/routes.py
Normal file
540
application/api/user/analytics/routes.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""Analytics and reporting routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
|
||||
analytics_ns = Namespace(
|
||||
"analytics", description="Analytics and reporting operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
"GetMessageAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_message_analytics_model)
|
||||
@api.doc(description="Get message analytics based on filter option")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in message_data:
|
||||
daily_messages[entry["_id"]] = entry["count"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "messages": daily_messages}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_token_analytics")
|
||||
class GetTokenAnalytics(Resource):
|
||||
get_token_analytics_model = api.model(
|
||||
"GetTokenAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_token_analytics_model)
|
||||
@api.doc(description="Get token analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"minute": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"hour": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"day": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
match_stage,
|
||||
group_stage,
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in token_usage_data:
|
||||
if filter_option == "last_hour":
|
||||
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||
elif filter_option == "last_24_hour":
|
||||
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||
else:
|
||||
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "token_usage": daily_token_usage}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_feedback_analytics")
|
||||
class GetFeedbackAnalytics(Resource):
|
||||
get_feedback_analytics_model = api.model(
|
||||
"GetFeedbackAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_feedback_analytics_model)
|
||||
@api.doc(description="Get feedback analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.time",
|
||||
"positive": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"negative": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
feedback_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
|
||||
for entry in feedback_data:
|
||||
daily_feedback[entry["_id"]] = {
|
||||
"positive": entry["positive"],
|
||||
"negative": entry["negative"],
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting feedback analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "feedback": daily_feedback}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_user_logs")
|
||||
class GetUserLogs(Resource):
|
||||
get_user_logs_model = api.model(
|
||||
"GetUserLogsModel",
|
||||
{
|
||||
"page": fields.Integer(
|
||||
required=False,
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
),
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"page_size": fields.Integer(
|
||||
required=False,
|
||||
description="Number of logs per page",
|
||||
default=10,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_user_logs_model)
|
||||
@api.doc(description="Get user logs with pagination")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
items_cursor = (
|
||||
user_logs_collection.find(query)
|
||||
.sort("timestamp", -1)
|
||||
.skip(skip)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
items = list(items_cursor)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("_id")),
|
||||
"action": item.get("action"),
|
||||
"level": item.get("level"),
|
||||
"user": item.get("user"),
|
||||
"question": item.get("question"),
|
||||
"sources": item.get("sources"),
|
||||
"retriever_params": item.get("retriever_params"),
|
||||
"timestamp": item.get("timestamp"),
|
||||
}
|
||||
for item in items[:page_size]
|
||||
]
|
||||
|
||||
has_more = len(items) > page_size
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
5
application/api/user/attachments/__init__.py
Normal file
5
application/api/user/attachments/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Attachments module."""
|
||||
|
||||
from .routes import attachments_ns
|
||||
|
||||
__all__ = ["attachments_ns"]
|
||||
150
application/api/user/attachments/routes.py
Normal file
150
application/api/user/attachments/routes.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""File attachments and media routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, storage
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.core.settings import settings
|
||||
from application.tts.google_tts import GoogleTTS
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
attachments_ns = Namespace(
|
||||
"attachments", description="File attachments and media operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/store_attachment")
|
||||
class StoreAttachment(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AttachmentModel",
|
||||
{
|
||||
"file": fields.Raw(required=True, description="File to upload"),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key (optional)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||
file = request.files.get("file")
|
||||
|
||||
if not file or file.filename == "":
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "Missing file"}),
|
||||
400,
|
||||
)
|
||||
user = None
|
||||
if decoded_token:
|
||||
user = safe_filename(decoded_token.get("sub"))
|
||||
elif api_key:
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||
)
|
||||
user = safe_filename(agent.get("user"))
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||
)
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
|
||||
metadata = storage.save_file(file, relative_path)
|
||||
|
||||
file_info = {
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"path": relative_path,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
task = store_attachment.delay(file_info, user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task.id,
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||
|
||||
|
||||
@attachments_ns.route("/images/<path:image_path>")
|
||||
class ServeImage(Resource):
|
||||
@api.doc(description="Serve an image from storage")
|
||||
def get(self, image_path):
|
||||
try:
|
||||
file_obj = storage.get_file(image_path)
|
||||
extension = image_path.split(".")[-1].lower()
|
||||
content_type = f"image/{extension}"
|
||||
if extension == "jpg":
|
||||
content_type = "image/jpeg"
|
||||
response = make_response(file_obj.read())
|
||||
response.headers.set("Content-Type", content_type)
|
||||
response.headers.set("Cache-Control", "max-age=86400")
|
||||
|
||||
return response
|
||||
except FileNotFoundError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image not found"}), 404
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error serving image: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error retrieving image"}), 500
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/tts")
|
||||
class TextToSpeech(Resource):
|
||||
tts_model = api.model(
|
||||
"TextToSpeechModel",
|
||||
{
|
||||
"text": fields.String(
|
||||
required=True, description="Text to be synthesized as audio"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(tts_model)
|
||||
@api.doc(description="Synthesize audio speech from text")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = GoogleTTS()
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"audio_base64": audio_base64,
|
||||
"lang": detected_language,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
222
application/api/user/base.py
Normal file
222
application/api/user/base.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Shared utilities, database connections, and helper functions for user API routes.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
|
||||
|
||||
try:
|
||||
agents_collection.create_index(
|
||||
[("shared", 1)],
|
||||
name="shared_index",
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
||||
|
||||
def generate_minute_range(start_date, end_date):
|
||||
"""Generate a dictionary with minute-level time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_hourly_range(start_date, end_date):
|
||||
"""Generate a dictionary with hourly time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_date_range(start_date, end_date):
|
||||
"""Generate a dictionary with daily date ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
|
||||
for i in range((end_date - start_date).days + 1)
|
||||
}
|
||||
|
||||
|
||||
def ensure_user_doc(user_id):
|
||||
"""
|
||||
Ensure user document exists with proper agent preferences structure.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to ensure
|
||||
|
||||
Returns:
|
||||
The user document
|
||||
"""
|
||||
default_prefs = {
|
||||
"pinned": [],
|
||||
"shared_with_me": [],
|
||||
}
|
||||
|
||||
user_doc = users_collection.find_one_and_update(
|
||||
{"user_id": user_id},
|
||||
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||
upsert=True,
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
prefs = user_doc.get("agent_preferences", {})
|
||||
updates = {}
|
||||
if "pinned" not in prefs:
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
return user_doc
|
||||
|
||||
|
||||
def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their details.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs
|
||||
|
||||
Returns:
|
||||
List of tool details with id, name, and display_name
|
||||
"""
|
||||
tools = user_tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("name", ""),
|
||||
"display_name": tool.get("displayName", tool.get("name", "")),
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
|
||||
def get_vector_store(source_id):
|
||||
"""
|
||||
Get the Vector Store for a given source ID.
|
||||
|
||||
Args:
|
||||
source_id (str): source id of the document
|
||||
|
||||
Returns:
|
||||
Vector store instance
|
||||
"""
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
def handle_image_upload(
|
||||
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
|
||||
) -> Tuple[str, Optional[Response]]:
|
||||
"""
|
||||
Handle image file upload from request.
|
||||
|
||||
Args:
|
||||
request: Flask request object
|
||||
existing_url: Existing image URL (fallback)
|
||||
user: User ID
|
||||
storage: Storage instance
|
||||
base_path: Base path for upload
|
||||
|
||||
Returns:
|
||||
Tuple of (image_url, error_response)
|
||||
"""
|
||||
image_url = existing_url
|
||||
|
||||
if "image" in request.files:
|
||||
file = request.files["image"]
|
||||
if file.filename != "":
|
||||
filename = secure_filename(file.filename)
|
||||
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
|
||||
try:
|
||||
storage.save_file(file, upload_path, storage_class="STANDARD")
|
||||
image_url = upload_path
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error uploading image: {e}")
|
||||
return None, make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}),
|
||||
400,
|
||||
)
|
||||
return image_url, None
|
||||
|
||||
|
||||
def require_agent(func):
|
||||
"""
|
||||
Decorator to require valid agent webhook token.
|
||||
|
||||
Args:
|
||||
func: Function to decorate
|
||||
|
||||
Returns:
|
||||
Wrapped function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
webhook_token = kwargs.get("webhook_token")
|
||||
if not webhook_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one(
|
||||
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||
)
|
||||
if not agent:
|
||||
current_app.logger.warning(
|
||||
f"Webhook attempt with invalid token: {webhook_token}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
kwargs["agent"] = agent
|
||||
kwargs["agent_id_str"] = str(agent["_id"])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
5
application/api/user/conversations/__init__.py
Normal file
5
application/api/user/conversations/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Conversation management module."""
|
||||
|
||||
from .routes import conversations_ns
|
||||
|
||||
__all__ = ["conversations_ns"]
|
||||
280
application/api/user/conversations/routes.py
Normal file
280
application/api/user/conversations/routes.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""Conversation management routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
"conversations", description="Conversation management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_conversation")
|
||||
class DeleteConversation(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a conversation by ID",
|
||||
params={"id": "The ID of the conversation to delete"},
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_all_conversations")
|
||||
class DeleteAllConversations(Resource):
|
||||
@api.doc(
|
||||
description="Deletes all conversations for a specific user",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_conversations")
|
||||
class GetConversations(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["_id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a single conversation by ID",
|
||||
params={"id": "The conversation ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
# Process queries to include attachment names
|
||||
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
data = {
|
||||
"queries": queries,
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/update_conversation_name")
|
||||
class UpdateConversationName(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateConversationModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Conversation ID"),
|
||||
"name": fields.String(
|
||||
required=True, description="New name of the conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates the name of a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/feedback")
|
||||
class SubmitFeedback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"FeedbackModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=False, description="The user question"
|
||||
),
|
||||
"answer": fields.String(required=False, description="The AI answer"),
|
||||
"feedback": fields.String(required=True, description="User feedback"),
|
||||
"question_index": fields.Integer(
|
||||
required=True,
|
||||
description="The question number in that particular conversation",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="id of the particular conversation"
|
||||
),
|
||||
"api_key": fields.String(description="Optional API key"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Submit feedback for a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["feedback", "conversation_id", "question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
if data["feedback"] is None:
|
||||
# Remove feedback and feedback_timestamp if feedback is null
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Set feedback and feedback_timestamp if feedback has a value
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
5
application/api/user/prompts/__init__.py
Normal file
5
application/api/user/prompts/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Prompts module."""
|
||||
|
||||
from .routes import prompts_ns
|
||||
|
||||
__all__ = ["prompts_ns"]
|
||||
191
application/api/user/prompts/routes.py
Normal file
191
application/api/user/prompts/routes.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Prompt management routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
"prompts", description="Prompt management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@prompts_ns.route("/create_prompt")
|
||||
class CreatePrompt(Resource):
|
||||
create_prompt_model = api.model(
|
||||
"CreatePromptModel",
|
||||
{
|
||||
"content": fields.String(
|
||||
required=True, description="Content of the prompt"
|
||||
),
|
||||
"name": fields.String(required=True, description="Name of the prompt"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_prompt_model)
|
||||
@api.doc(description="Create a new prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["content", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"content": data["content"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_prompts")
|
||||
class GetPrompts(Resource):
|
||||
@api.doc(description="Get all prompts for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
list_prompts = [
|
||||
{"id": "default", "name": "default", "type": "public"},
|
||||
{"id": "creative", "name": "creative", "type": "public"},
|
||||
{"id": "strict", "name": "strict", "type": "public"},
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
list_prompts.append(
|
||||
{
|
||||
"id": str(prompt["_id"]),
|
||||
"name": prompt["name"],
|
||||
"type": "private",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_prompts), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_single_prompt")
|
||||
class GetSinglePrompt(Resource):
|
||||
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
prompt_id = request.args.get("id")
|
||||
if not prompt_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
if prompt_id == "default":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_combine_template = f.read()
|
||||
return make_response(jsonify({"content": chat_combine_template}), 200)
|
||||
elif prompt_id == "creative":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_reduce_creative = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_creative}), 200)
|
||||
elif prompt_id == "strict":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
|
||||
) as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/delete_prompt")
|
||||
class DeletePrompt(Resource):
|
||||
delete_prompt_model = api.model(
|
||||
"DeletePromptModel",
|
||||
{"id": fields.String(required=True, description="Prompt ID to delete")},
|
||||
)
|
||||
|
||||
@api.expect(delete_prompt_model)
|
||||
@api.doc(description="Delete a prompt by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/update_prompt")
|
||||
class UpdatePrompt(Resource):
|
||||
update_prompt_model = api.model(
|
||||
"UpdatePromptModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Prompt ID to update"),
|
||||
"name": fields.String(required=True, description="New name of the prompt"),
|
||||
"content": fields.String(
|
||||
required=True, description="New content of the prompt"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_prompt_model)
|
||||
@api.doc(description="Update an existing prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name", "content"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
File diff suppressed because it is too large
Load Diff
5
application/api/user/sharing/__init__.py
Normal file
5
application/api/user/sharing/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Sharing module."""
|
||||
|
||||
from .routes import sharing_ns
|
||||
|
||||
__all__ = ["sharing_ns"]
|
||||
301
application/api/user/sharing/routes.py
Normal file
301
application/api/user/sharing/routes.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Conversation sharing routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, inputs, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
attachments_collection,
|
||||
conversations_collection,
|
||||
db,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
sharing_ns = Namespace(
|
||||
"sharing", description="Conversation sharing operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sharing_ns.route("/share")
|
||||
class ShareConversation(Resource):
|
||||
share_conversation_model = api.model(
|
||||
"ShareConversationModel",
|
||||
{
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="Conversation ID"
|
||||
),
|
||||
"user": fields.String(description="User ID (optional)"),
|
||||
"prompt_id": fields.String(description="Prompt ID (optional)"),
|
||||
"chunks": fields.Integer(description="Chunks count (optional)"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(share_conversation_model)
|
||||
@api.doc(description="Share a conversation")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
is_promptable = request.args.get("isPromptable", type=inputs.boolean)
|
||||
if is_promptable is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "isPromptable is required"}), 400
|
||||
)
|
||||
conversation_id = data["conversation_id"]
|
||||
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
current_n_queries = len(conversation["queries"])
|
||||
explicit_binary = Binary.from_uuid(
|
||||
uuid.uuid4(), UuidRepresentation.STANDARD
|
||||
)
|
||||
|
||||
if is_promptable:
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = data.get("chunks", "2")
|
||||
|
||||
name = conversation["name"] + "(shared)"
|
||||
new_api_key_data = {
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
||||
if pre_existing_api_document:
|
||||
api_uuid = pre_existing_api_document["key"]
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": DBRef(
|
||||
"conversations", ObjectId(conversation_id)
|
||||
),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
new_api_key_data["key"] = api_uuid
|
||||
new_api_key_data["name"] = name
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
agents_collection.insert_one(new_api_key_data)
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": DBRef(
|
||||
"conversations", ObjectId(conversation_id)
|
||||
),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": {
|
||||
"$ref": "conversations",
|
||||
"$id": ObjectId(conversation_id),
|
||||
},
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error sharing conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sharing_ns.route("/shared_conversation/<string:identifier>")
|
||||
class GetPubliclySharedConversations(Resource):
|
||||
@api.doc(description="Get publicly shared conversations by identifier")
|
||||
def get(self, identifier: str):
|
||||
try:
|
||||
query_uuid = Binary.from_uuid(
|
||||
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
||||
)
|
||||
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
||||
conversation_queries = []
|
||||
|
||||
if (
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
and isinstance(shared["conversation_id"], DBRef)
|
||||
):
|
||||
conversation_ref = shared["conversation_id"]
|
||||
conversation = db.dereference(conversation_ref)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
conversation_queries = conversation["queries"][
|
||||
: (shared["first_n_queries"])
|
||||
]
|
||||
|
||||
for query in conversation_queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
date = conversation["_id"].generation_time.isoformat()
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation["name"],
|
||||
"timestamp": date,
|
||||
}
|
||||
if shared["isPromptable"] and "api_key" in shared:
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting shared conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
7
application/api/user/sources/__init__.py
Normal file
7
application/api/user/sources/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Sources module."""
|
||||
|
||||
from .chunks import sources_chunks_ns
|
||||
from .routes import sources_ns
|
||||
from .upload import sources_upload_ns
|
||||
|
||||
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]
|
||||
278
application/api/user/sources/chunks.py
Normal file
278
application/api/user/sources/chunks.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Source document management chunk management."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import get_vector_store, sources_collection
|
||||
from application.utils import check_required_fields, num_tokens_from_string
|
||||
|
||||
sources_chunks_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/get_chunks")
|
||||
class GetChunks(Resource):
|
||||
@api.doc(
|
||||
description="Retrieves chunks from a document, optionally filtered by file path and search term",
|
||||
params={
|
||||
"id": "The document ID",
|
||||
"page": "Page number for pagination",
|
||||
"per_page": "Number of chunks per page",
|
||||
"path": "Optional: Filter chunks by relative file path",
|
||||
"search": "Optional: Search term to filter chunks by title or content",
|
||||
},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
page = int(request.args.get("page", 1))
|
||||
per_page = int(request.args.get("per_page", 10))
|
||||
path = request.args.get("path")
|
||||
search_term = request.args.get("search", "").strip().lower()
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunks = store.get_chunks()
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk in chunks:
|
||||
metadata = chunk.get("metadata", {})
|
||||
|
||||
# Filter by path if provided
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
# Check if the chunk's source matches the requested path
|
||||
|
||||
if not chunk_source or not chunk_source.endswith(path):
|
||||
continue
|
||||
# Filter by search term if provided
|
||||
|
||||
if search_term:
|
||||
text_match = search_term in chunk.get("text", "").lower()
|
||||
title_match = search_term in metadata.get("title", "").lower()
|
||||
|
||||
if not (text_match or title_match):
|
||||
continue
|
||||
filtered_chunks.append(chunk)
|
||||
chunks = filtered_chunks
|
||||
|
||||
total_chunks = len(chunks)
|
||||
start = (page - 1) * per_page
|
||||
end = start + per_page
|
||||
paginated_chunks = chunks[start:end]
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"total": total_chunks,
|
||||
"chunks": paginated_chunks,
|
||||
"path": path if path else None,
|
||||
"search": search_term if search_term else None,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/add_chunk")
|
||||
class AddChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AddChunkModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Document ID"),
|
||||
"text": fields.String(required=True, description="Text of the chunk"),
|
||||
"metadata": fields.Raw(
|
||||
required=False,
|
||||
description="Metadata associated with the chunk",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Adds a new chunk to the document",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "text"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
doc_id = data.get("id")
|
||||
text = data.get("text")
|
||||
metadata = data.get("metadata", {})
|
||||
token_count = num_tokens_from_string(text)
|
||||
metadata["token_count"] = token_count
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
chunk_id = store.add_chunk(text, metadata)
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||
201,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/delete_chunk")
|
||||
class DeleteChunk(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a specific chunk from the document.",
|
||||
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
chunk_id = request.args.get("chunk_id")
|
||||
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if deleted:
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk deleted successfully"}), 200
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk not found or could not be deleted"}),
|
||||
404,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/update_chunk")
|
||||
class UpdateChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateChunkModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Document ID"),
|
||||
"chunk_id": fields.String(
|
||||
required=True, description="Chunk ID to update"
|
||||
),
|
||||
"text": fields.String(
|
||||
required=False, description="New text of the chunk"
|
||||
),
|
||||
"metadata": fields.Raw(
|
||||
required=False,
|
||||
description="Updated metadata associated with the chunk",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates an existing chunk in the document.",
|
||||
)
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "chunk_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
doc_id = data.get("id")
|
||||
chunk_id = data.get("chunk_id")
|
||||
text = data.get("text")
|
||||
metadata = data.get("metadata")
|
||||
|
||||
if text is not None:
|
||||
token_count = num_tokens_from_string(text)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["token_count"] = token_count
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(doc_id)
|
||||
|
||||
chunks = store.get_chunks()
|
||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||
if not existing_chunk:
|
||||
return make_response(jsonify({"error": "Chunk not found"}), 404)
|
||||
new_text = text if text is not None else existing_chunk["text"]
|
||||
|
||||
if metadata is not None:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
new_metadata.update(metadata)
|
||||
else:
|
||||
new_metadata = existing_chunk["metadata"].copy()
|
||||
if text is not None:
|
||||
new_metadata["token_count"] = num_tokens_from_string(new_text)
|
||||
try:
|
||||
new_chunk_id = store.add_chunk(new_text, new_metadata)
|
||||
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if not deleted:
|
||||
current_app.logger.warning(
|
||||
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"message": "Chunk updated successfully",
|
||||
"chunk_id": new_chunk_id,
|
||||
"original_chunk_id": chunk_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as add_error:
|
||||
current_app.logger.error(f"Failed to add updated chunk: {add_error}")
|
||||
return make_response(
|
||||
jsonify({"error": "Failed to update chunk - addition failed"}), 500
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
350
application/api/user/sources/routes.py
Normal file
350
application/api/user/sources/routes.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Source document management routes."""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
sources_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_ns.route("/sources")
|
||||
class CombinedJson(Resource):
|
||||
@api.doc(description="Provide JSON file with combined available indexes")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = [
|
||||
{
|
||||
"name": "Default",
|
||||
"date": "default",
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "remote",
|
||||
"tokens": "",
|
||||
"retriever": "classic",
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||
data.append(
|
||||
{
|
||||
"id": str(index["_id"]),
|
||||
"name": index.get("name"),
|
||||
"date": index.get("date"),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": index.get("tokens", ""),
|
||||
"retriever": index.get("retriever", "classic"),
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
), # Add type field with default "file"
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/paginated")
|
||||
class PaginatedSources(Resource):
|
||||
@api.doc(description="Get document with pagination, sorting and filtering")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||
# add .strip() to remove leading and trailing whitespaces
|
||||
|
||||
search_term = request.args.get(
|
||||
"search", ""
|
||||
).strip() # add search for filter documents
|
||||
|
||||
# Prepare query for filtering
|
||||
|
||||
query = {"user": user}
|
||||
if search_term:
|
||||
query["name"] = {
|
||||
"$regex": search_term,
|
||||
"$options": "i", # using case-insensitive search
|
||||
}
|
||||
total_documents = sources_collection.count_documents(query)
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
page = min(
|
||||
max(1, page), total_pages
|
||||
) # add this to make sure page inbound is within the range
|
||||
sort_order = 1 if sort_order == "asc" else -1
|
||||
skip = (page - 1) * rows_per_page
|
||||
|
||||
try:
|
||||
documents = (
|
||||
sources_collection.find(query)
|
||||
.sort(sort_field, sort_order)
|
||||
.skip(skip)
|
||||
.limit(rows_per_page)
|
||||
)
|
||||
|
||||
paginated_docs = []
|
||||
for doc in documents:
|
||||
doc_data = {
|
||||
"id": str(doc["_id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
paginated_docs.append(doc_data)
|
||||
response = {
|
||||
"total": total_documents,
|
||||
"totalPages": total_pages,
|
||||
"currentPage": page,
|
||||
"paginated": paginated_docs,
|
||||
}
|
||||
return make_response(jsonify(response), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving paginated sources: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/docs_check")
|
||||
class CheckDocs(Resource):
|
||||
check_docs_model = api.model(
|
||||
"CheckDocsModel",
|
||||
{"docs": fields.String(required=True, description="Document name")},
|
||||
)
|
||||
|
||||
@api.expect(check_docs_model)
|
||||
@api.doc(description="Check if document exists")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["docs"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
vectorstore = "vectors/" + secure_filename(data["docs"])
|
||||
if os.path.exists(vectorstore) or data["docs"] == "default":
|
||||
return {"status": "exists"}, 200
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error checking document: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_by_ids")
|
||||
class DeleteByIds(Resource):
|
||||
@api.doc(
|
||||
description="Deletes documents from the vector store by IDs",
|
||||
params={"path": "Comma-separated list of IDs"},
|
||||
)
|
||||
def get(self):
|
||||
ids = request.args.get("path")
|
||||
if not ids:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_index(ids=ids)
|
||||
if result:
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_old")
|
||||
class DeleteOldIndexes(Resource):
|
||||
@api.doc(
|
||||
description="Deletes old indexes and associated files",
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
try:
|
||||
# Delete vector index
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
index_path = f"indexes/{str(doc['_id'])}"
|
||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||
storage.delete_file(f"{index_path}/index.faiss")
|
||||
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||
storage.delete_file(f"{index_path}/index.pkl")
|
||||
else:
|
||||
vectorstore = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
if "file_path" in doc and doc["file_path"]:
|
||||
file_path = doc["file_path"]
|
||||
if storage.is_directory(file_path):
|
||||
files = storage.list_files(file_path)
|
||||
for f in files:
|
||||
storage.delete_file(f)
|
||||
else:
|
||||
storage.delete_file(file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting files and indexes: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/combine")
|
||||
class RedirectToSources(Resource):
|
||||
@api.doc(
|
||||
description="Redirects /api/combine to /api/sources for backward compatibility"
|
||||
)
|
||||
def get(self):
|
||||
return redirect("/api/sources", code=301)
|
||||
|
||||
|
||||
@sources_ns.route("/manage_sync")
|
||||
class ManageSync(Resource):
|
||||
manage_sync_model = api.model(
|
||||
"ManageSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID"),
|
||||
"sync_frequency": fields.String(
|
||||
required=True,
|
||||
description="Sync frequency (never, daily, weekly, monthly)",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(manage_sync_model)
|
||||
@api.doc(description="Manage sync frequency for sources")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["source_id", "sync_frequency"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
sync_frequency = data["sync_frequency"]
|
||||
|
||||
if sync_frequency not in ["never", "daily", "weekly", "monthly"]:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||
)
|
||||
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
||||
try:
|
||||
sources_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(source_id),
|
||||
"user": user,
|
||||
},
|
||||
update_data,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating sync frequency: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
description="Get the directory structure for a document",
|
||||
params={"id": "The document ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
doc_id = request.args.get("id")
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||
try:
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
directory_structure = doc.get("directory_structure", {})
|
||||
base_path = doc.get("file_path", "")
|
||||
|
||||
provider = None
|
||||
remote_data = doc.get("remote_data")
|
||||
try:
|
||||
if isinstance(remote_data, str) and remote_data:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
provider = remote_data_obj.get("provider")
|
||||
except Exception as e:
|
||||
current_app.logger.warning(
|
||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"directory_structure": directory_structure,
|
||||
"base_path": base_path,
|
||||
"provider": provider,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
572
application/api/user/sources/upload.py
Normal file
572
application/api/user/sources/upload.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""Source document management upload functionality."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields, safe_filename
|
||||
|
||||
|
||||
sources_upload_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/upload")
|
||||
class UploadFile(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Missing required fields or files",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
|
||||
safe_user = safe_filename(user)
|
||||
dir_name = safe_filename(job_name)
|
||||
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
for file in files:
|
||||
original_filename = file.filename
|
||||
safe_file = safe_filename(original_filename)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_file_path = os.path.join(temp_dir, safe_file)
|
||||
file.save(temp_file_path)
|
||||
|
||||
if zipfile.is_zipfile(temp_file_path):
|
||||
try:
|
||||
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(path=temp_dir)
|
||||
|
||||
# Walk through extracted files and upload them
|
||||
|
||||
for root, _, files in os.walk(temp_dir):
|
||||
for extracted_file in files:
|
||||
if (
|
||||
os.path.join(root, extracted_file)
|
||||
== temp_file_path
|
||||
):
|
||||
continue
|
||||
rel_path = os.path.relpath(
|
||||
os.path.join(root, extracted_file), temp_dir
|
||||
)
|
||||
storage_path = f"{base_path}/{rel_path}"
|
||||
|
||||
with open(
|
||||
os.path.join(root, extracted_file), "rb"
|
||||
) as f:
|
||||
storage.save_file(f, storage_path)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error extracting zip: {e}", exc_info=True
|
||||
)
|
||||
# If zip extraction fails, save the original zip file
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
else:
|
||||
# For non-zip files, save directly
|
||||
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
[
|
||||
".rst",
|
||||
".md",
|
||||
".pdf",
|
||||
".txt",
|
||||
".docx",
|
||||
".csv",
|
||||
".epub",
|
||||
".html",
|
||||
".mdx",
|
||||
".json",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
],
|
||||
job_name,
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/remote")
|
||||
class UploadRemote(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RemoteUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source of the data"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Data to process"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads remote source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Process file_ids
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
# Process folder_ids
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [
|
||||
id.strip() for id in folder_ids.split(",") if id.strip()
|
||||
]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading remote source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/manage_source_files")
|
||||
class ManageSourceFiles(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ManageSourceFilesModel",
|
||||
{
|
||||
"source_id": fields.String(
|
||||
required=True, description="Source ID to modify"
|
||||
),
|
||||
"operation": fields.String(
|
||||
required=True,
|
||||
description="Operation: 'add', 'remove', or 'remove_directory'",
|
||||
),
|
||||
"file_paths": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="File paths to remove (for remove operation)",
|
||||
),
|
||||
"directory_path": fields.String(
|
||||
required=False,
|
||||
description="Directory path to remove (for remove_directory operation)",
|
||||
),
|
||||
"file": fields.Raw(
|
||||
required=False, description="Files to add (for add operation)"
|
||||
),
|
||||
"parent_dir": fields.String(
|
||||
required=False,
|
||||
description="Parent directory path relative to source root",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Add files, remove files, or remove directories from an existing source",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
|
||||
if not source_id or not operation:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "source_id and operation are required",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if operation not in ["add", "remove", "remove_directory"]:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "operation must be 'add', 'remove', or 'remove_directory'",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
ObjectId(source_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||
)
|
||||
try:
|
||||
source = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Source not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
|
||||
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid parent directory path"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
if operation == "add":
|
||||
files = request.files.getlist("file")
|
||||
if not files or all(file.filename == "" for file in files):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "No files provided for add operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
added_files = []
|
||||
|
||||
target_dir = source_file_path
|
||||
if parent_dir:
|
||||
target_dir = f"{source_file_path}/{parent_dir}"
|
||||
for file in files:
|
||||
if file.filename:
|
||||
safe_filename_str = safe_filename(file.filename)
|
||||
file_path = f"{target_dir}/{safe_filename_str}"
|
||||
|
||||
# Save file to storage
|
||||
|
||||
storage.save_file(file, file_path)
|
||||
added_files.append(safe_filename_str)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Added {len(added_files)} files",
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths")
|
||||
if not file_paths_str:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "file_paths required for remove operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
file_paths = (
|
||||
json.loads(file_paths_str)
|
||||
if isinstance(file_paths_str, str)
|
||||
else file_paths_str
|
||||
)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid file_paths format"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
|
||||
if storage.file_exists(full_path):
|
||||
storage.delete_file(full_path)
|
||||
removed_files.append(file_path)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
elif operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path")
|
||||
if not directory_path:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "directory_path required for remove_directory operation",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate directory path (prevent path traversal)
|
||||
|
||||
if directory_path.startswith("/") or ".." in directory_path:
|
||||
current_app.logger.warning(
|
||||
f"Invalid directory path attempted for removal. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid directory path"}
|
||||
),
|
||||
400,
|
||||
)
|
||||
full_directory_path = (
|
||||
f"{source_file_path}/{directory_path}"
|
||||
if directory_path
|
||||
else source_file_path
|
||||
)
|
||||
|
||||
if not storage.is_directory(full_directory_path):
|
||||
current_app.logger.warning(
|
||||
f"Directory not found or is not a directory for removal. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Directory not found or is not a directory",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
current_app.logger.error(
|
||||
f"Failed to remove directory from storage. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Failed to remove directory"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Successfully removed directory. "
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path", "")
|
||||
error_context += f", directory_path={directory_path}"
|
||||
elif operation == "remove":
|
||||
file_paths_str = request.form.get("file_paths", "")
|
||||
error_context += f", file_paths={file_paths_str}"
|
||||
elif operation == "add":
|
||||
parent_dir = request.form.get("parent_dir", "")
|
||||
error_context += f", parent_dir={parent_dir}"
|
||||
current_app.logger.error(
|
||||
f"Error managing source files: {err} ({error_context})", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Operation failed"}), 500
|
||||
)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/task_status")
|
||||
class TaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"TaskStatusModel",
|
||||
{"task_id": fields.String(required=True, description="Task ID")},
|
||||
)
|
||||
|
||||
@api.expect(task_status_model)
|
||||
@api.doc(description="Get celery job status")
|
||||
def get(self):
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
from application.celery_init import celery
|
||||
|
||||
task = celery.AsyncResult(task_id)
|
||||
task_meta = task.info
|
||||
print(f"Task status: {task.status}")
|
||||
if not isinstance(
|
||||
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||
):
|
||||
task_meta = str(task_meta) # Convert to a string representation
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
6
application/api/user/tools/__init__.py
Normal file
6
application/api/user/tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Tools module."""
|
||||
|
||||
from .mcp import tools_mcp_ns
|
||||
from .routes import tools_ns
|
||||
|
||||
__all__ = ["tools_ns", "tools_mcp_ns"]
|
||||
333
application/api/user/tools/mcp.py
Normal file
333
application/api/user/tools/mcp.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from email.quoprimime import unquote
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.cache import get_redis_instance
|
||||
from application.security.encryption import encrypt_credentials
|
||||
from application.utils import check_required_fields
|
||||
|
||||
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/test")
|
||||
class TestMCPServerConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerTestModel",
|
||||
{
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration to test"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Test MCP server connection with provided configuration")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
|
||||
if auth_type == "api_key" and "api_key" in config:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer" and "bearer_token" in config:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config:
|
||||
auth_credentials["password"] = config["password"]
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
|
||||
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Connection test failed: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/save")
|
||||
class MCPServerSave(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerSaveModel",
|
||||
{
|
||||
"id": fields.String(
|
||||
required=False, description="Tool ID for updates (optional)"
|
||||
),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the MCP server"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="MCP server configuration"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=False, default=True, description="Tool status"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create or update MCP server with automatic tool discovery")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
|
||||
required_fields = ["displayName", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = data["config"]
|
||||
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
mcp_config = config.copy()
|
||||
mcp_config["auth_credentials"] = auth_credentials
|
||||
|
||||
if auth_type == "oauth":
|
||||
if not config.get("oauth_task_id"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Connection not authorized. Please complete the OAuth authorization first.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "OAuth failed or not completed. Please try authorizing again.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
actions_metadata = result.get("tools", [])
|
||||
elif auth_type == "none" or auth_credentials:
|
||||
mcp_tool = MCPTool(config=mcp_config, user_id=user)
|
||||
mcp_tool.discover_tools()
|
||||
actions_metadata = mcp_tool.get_actions_metadata()
|
||||
else:
|
||||
raise Exception(
|
||||
"No valid credentials provided for the selected authentication type"
|
||||
)
|
||||
storage_config = config.copy()
|
||||
if auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
transformed_actions = []
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"displayName": data["displayName"],
|
||||
"customName": data["displayName"],
|
||||
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": data.get("status", True),
|
||||
"user": user,
|
||||
}
|
||||
|
||||
tool_id = data.get("id")
|
||||
if tool_id:
|
||||
result = user_tools_collection.update_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
result = user_tools_collection.insert_one(tool_data)
|
||||
tool_id = str(result.inserted_id)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/callback")
|
||||
class MCPOAuthCallback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerCallbackModel",
|
||||
{
|
||||
"code": fields.String(required=True, description="Authorization code"),
|
||||
"state": fields.String(required=True, description="State parameter"),
|
||||
"error": fields.String(
|
||||
required=False, description="Error message (if any)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Handle OAuth callback by providing the authorization code and state"
|
||||
)
|
||||
def get(self):
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
|
||||
if error:
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
if not code or not state:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
|
||||
)
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
if not redis_client:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
|
||||
)
|
||||
code = unquote(code)
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
success = manager.handle_oauth_callback(state, code, error)
|
||||
if success:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
|
||||
)
|
||||
else:
|
||||
return redirect(
|
||||
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
|
||||
)
|
||||
return redirect(
|
||||
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
"""
|
||||
Get current status of OAuth flow.
|
||||
Frontend should poll this endpoint periodically.
|
||||
"""
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Task not found or expired",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
|
||||
)
|
||||
415
application/api/user/tools/routes.py
Normal file
415
application/api/user/tools/routes.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""Tool management routes."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
|
||||
@tools_ns.route("/available_tools")
|
||||
class AvailableTools(Resource):
|
||||
@api.doc(description="Get available tools for a user")
|
||||
def get(self):
|
||||
try:
|
||||
tools_metadata = []
|
||||
for tool_name, tool_instance in tool_manager.tools.items():
|
||||
doc = tool_instance.__doc__.strip()
|
||||
lines = doc.split("\n", 1)
|
||||
name = lines[0].strip()
|
||||
description = lines[1].strip() if len(lines) > 1 else ""
|
||||
tools_metadata.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"displayName": name,
|
||||
"description": description,
|
||||
"configRequirements": tool_instance.get_config_requirements(),
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting available tools: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/get_tools")
|
||||
class GetTools(Resource):
|
||||
@api.doc(description="Get tools created by a user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
user_tools = []
|
||||
for tool in tools:
|
||||
tool["id"] = str(tool["_id"])
|
||||
tool.pop("_id")
|
||||
user_tools.append(tool)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/create_tool")
|
||||
class CreateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateToolModel",
|
||||
{
|
||||
"name": fields.String(required=True, description="Name of the tool"),
|
||||
"displayName": fields.String(
|
||||
required=True, description="Display name for the tool"
|
||||
),
|
||||
"description": fields.String(
|
||||
required=True, description="Tool description"
|
||||
),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
"customName": fields.String(
|
||||
required=False, description="Custom name for the tool"
|
||||
),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Create a new tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = [
|
||||
"name",
|
||||
"displayName",
|
||||
"description",
|
||||
"config",
|
||||
"status",
|
||||
]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
tool_instance = tool_manager.tools.get(data["name"])
|
||||
if not tool_instance:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
)
|
||||
actions_metadata = tool_instance.get_actions_metadata()
|
||||
transformed_actions = []
|
||||
for action in actions_metadata:
|
||||
action["active"] = True
|
||||
if "parameters" in action:
|
||||
if "properties" in action["parameters"]:
|
||||
for param_name, param_details in action["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
param_details["filled_by_llm"] = True
|
||||
param_details["value"] = ""
|
||||
transformed_actions.append(action)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
try:
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"displayName": data["displayName"],
|
||||
"description": data["description"],
|
||||
"customName": data.get("customName", ""),
|
||||
"actions": transformed_actions,
|
||||
"config": data["config"],
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool")
|
||||
class UpdateTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"name": fields.String(description="Name of the tool"),
|
||||
"displayName": fields.String(description="Display name for the tool"),
|
||||
"customName": fields.String(description="Custom name for the tool"),
|
||||
"description": fields.String(description="Tool description"),
|
||||
"config": fields.Raw(description="Configuration of the tool"),
|
||||
"actions": fields.List(
|
||||
fields.Raw, description="Actions the tool can perform"
|
||||
),
|
||||
"status": fields.Boolean(description="Status of the tool"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
update_data = {}
|
||||
if "name" in data:
|
||||
update_data["name"] = data["name"]
|
||||
if "displayName" in data:
|
||||
update_data["displayName"] = data["displayName"]
|
||||
if "customName" in data:
|
||||
update_data["customName"] = data["customName"]
|
||||
if "description" in data:
|
||||
update_data["description"] = data["description"]
|
||||
if "actions" in data:
|
||||
update_data["actions"] = data["actions"]
|
||||
if "config" in data:
|
||||
if "actions" in data["config"]:
|
||||
for action_name in list(data["config"]["actions"].keys()):
|
||||
if not validate_function_name(action_name):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
|
||||
"param": "tools[].function.name",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if tool_doc and tool_doc.get("name") == "mcp_tool":
|
||||
config = data["config"]
|
||||
existing_config = tool_doc.get("config", {})
|
||||
storage_config = existing_config.copy()
|
||||
|
||||
storage_config.update(config)
|
||||
existing_credentials = {}
|
||||
if "encrypted_credentials" in existing_config:
|
||||
existing_credentials = decrypt_credentials(
|
||||
existing_config["encrypted_credentials"], user
|
||||
)
|
||||
auth_credentials = existing_credentials.copy()
|
||||
auth_type = storage_config.get("auth_type", "none")
|
||||
if auth_type == "api_key":
|
||||
if "api_key" in config and config["api_key"]:
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if "api_key_header" in config:
|
||||
auth_credentials["api_key_header"] = config[
|
||||
"api_key_header"
|
||||
]
|
||||
elif auth_type == "bearer":
|
||||
if "bearer_token" in config and config["bearer_token"]:
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif "encrypted_token" in config and config["encrypted_token"]:
|
||||
auth_credentials["bearer_token"] = config["encrypted_token"]
|
||||
elif auth_type == "basic":
|
||||
if "username" in config and config["username"]:
|
||||
auth_credentials["username"] = config["username"]
|
||||
if "password" in config and config["password"]:
|
||||
auth_credentials["password"] = config["password"]
|
||||
if auth_type != "none" and auth_credentials:
|
||||
encrypted_credentials_string = encrypt_credentials(
|
||||
auth_credentials, user
|
||||
)
|
||||
storage_config["encrypted_credentials"] = (
|
||||
encrypted_credentials_string
|
||||
)
|
||||
elif auth_type == "none":
|
||||
storage_config.pop("encrypted_credentials", None)
|
||||
for field in [
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"encrypted_token",
|
||||
"username",
|
||||
"password",
|
||||
"api_key_header",
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
update_data["config"] = storage_config
|
||||
else:
|
||||
update_data["config"] = data["config"]
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": update_data},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_config")
|
||||
class UpdateToolConfig(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolConfigModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"config": fields.Raw(
|
||||
required=True, description="Configuration of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the configuration of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "config"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"config": data["config"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool config: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_actions")
|
||||
class UpdateToolActions(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolActionsModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"actions": fields.List(
|
||||
fields.Raw,
|
||||
required=True,
|
||||
description="Actions the tool can perform",
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the actions of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "actions"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"actions": data["actions"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool actions: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_status")
|
||||
class UpdateToolStatus(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateToolStatusModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Tool ID"),
|
||||
"status": fields.Boolean(
|
||||
required=True, description="Status of the tool"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Update the status of a tool")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "status"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"status": data["status"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool status: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/delete_tool")
|
||||
class DeleteTool(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DeleteToolModel",
|
||||
{"id": fields.String(required=True, description="Tool ID")},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Delete a tool by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
result = user_tools_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return {"success": False, "message": "Tool not found"}, 404
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
|
||||
return {"success": False}, 400
|
||||
return {"success": True}, 200
|
||||
@@ -41,10 +41,18 @@ class Settings(BaseSettings):
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
GOOGLE_CLIENT_ID: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client ID
|
||||
)
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client secret
|
||||
)
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
|
||||
@@ -1,44 +1,135 @@
|
||||
import base64
|
||||
import requests
|
||||
from typing import List
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from application.parser.remote.base import BaseRemote
|
||||
from langchain_core.documents import Document
|
||||
from application.parser.schema.base import Document
|
||||
import mimetypes
|
||||
from application.core.settings import settings
|
||||
|
||||
class GitHubLoader(BaseRemote):
|
||||
def __init__(self):
|
||||
self.access_token = None
|
||||
self.access_token = settings.GITHUB_ACCESS_TOKEN
|
||||
self.headers = {
|
||||
"Authorization": f"token {self.access_token}"
|
||||
} if self.access_token else {}
|
||||
"Authorization": f"token {self.access_token}",
|
||||
"Accept": "application/vnd.github.v3+json"
|
||||
} if self.access_token else {
|
||||
"Accept": "application/vnd.github.v3+json"
|
||||
}
|
||||
return
|
||||
|
||||
def fetch_file_content(self, repo_url: str, file_path: str) -> str:
|
||||
def is_text_file(self, file_path: str) -> bool:
|
||||
"""Determine if a file is a text file based on extension."""
|
||||
# Common text file extensions
|
||||
text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.json', '.xml', '.yaml', '.yml',
|
||||
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
|
||||
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala',
|
||||
'.html', '.css', '.scss', '.sass', '.less',
|
||||
'.sh', '.bash', '.zsh', '.fish',
|
||||
'.sql', '.r', '.m', '.mat',
|
||||
'.ini', '.cfg', '.conf', '.config', '.env',
|
||||
'.gitignore', '.dockerignore', '.editorconfig',
|
||||
'.log', '.csv', '.tsv'
|
||||
}
|
||||
|
||||
# Get file extension
|
||||
file_lower = file_path.lower()
|
||||
for ext in text_extensions:
|
||||
if file_lower.endswith(ext):
|
||||
return True
|
||||
|
||||
# Also check MIME type
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
if mime_type and (mime_type.startswith("text") or mime_type in ["application/json", "application/xml"]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def fetch_file_content(self, repo_url: str, file_path: str) -> Optional[str]:
|
||||
"""Fetch file content. Returns None if file should be skipped (binary files or empty files)."""
|
||||
url = f"https://api.github.com/repos/{repo_url}/contents/{file_path}"
|
||||
response = requests.get(url, headers=self.headers)
|
||||
response = self._make_request(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.json()
|
||||
mime_type, _ = mimetypes.guess_type(file_path) # Guess the MIME type based on the file extension
|
||||
content = response.json()
|
||||
|
||||
if content.get("encoding") == "base64":
|
||||
if mime_type and mime_type.startswith("text"): # Handle only text files
|
||||
try:
|
||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8")
|
||||
return f"Filename: {file_path}\n\n{decoded_content}"
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
return f"Filename: {file_path} is a binary file and was skipped."
|
||||
if content.get("encoding") == "base64":
|
||||
if self.is_text_file(file_path): # Handle only text files
|
||||
try:
|
||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8").strip()
|
||||
# Skip empty files
|
||||
if not decoded_content:
|
||||
return None
|
||||
return decoded_content
|
||||
except Exception:
|
||||
# If decoding fails, it's probably a binary file
|
||||
return None
|
||||
else:
|
||||
return f"Filename: {file_path}\n\n{content['content']}"
|
||||
# Skip binary files by returning None
|
||||
return None
|
||||
else:
|
||||
response.raise_for_status()
|
||||
file_content = content['content'].strip()
|
||||
# Skip empty files
|
||||
if not file_content:
|
||||
return None
|
||||
return file_content
|
||||
|
||||
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
for attempt in range(max_retries):
|
||||
response = requests.get(url, headers=self.headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
elif response.status_code == 403:
|
||||
# Check if it's a rate limit issue
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get("message", "")
|
||||
|
||||
# Check rate limit headers
|
||||
remaining = response.headers.get("X-RateLimit-Remaining", "unknown")
|
||||
reset_time = response.headers.get("X-RateLimit-Reset", "unknown")
|
||||
|
||||
print(f"GitHub API 403 Error: {error_msg}")
|
||||
print(f"Rate limit remaining: {remaining}, Reset time: {reset_time}")
|
||||
|
||||
if "rate limit" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 ** attempt # Exponential backoff
|
||||
print(f"Rate limit hit, waiting {wait_time} seconds before retry...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# Provide helpful error message
|
||||
if remaining == "0":
|
||||
raise Exception(f"GitHub API rate limit exceeded. Please set GITHUB_ACCESS_TOKEN environment variable. Reset time: {reset_time}")
|
||||
else:
|
||||
raise Exception(f"GitHub API error: {error_msg}. This may require authentication - set GITHUB_ACCESS_TOKEN environment variable.")
|
||||
except Exception as e:
|
||||
if isinstance(e, Exception) and "GitHub API" in str(e):
|
||||
raise
|
||||
# If we can't parse the response, raise the original error
|
||||
response.raise_for_status()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
def fetch_repo_files(self, repo_url: str, path: str = "") -> List[str]:
|
||||
url = f"https://api.github.com/repos/{repo_url}/contents/{path}"
|
||||
response = requests.get(url, headers={**self.headers, "Accept": "application/vnd.github.v3.raw"})
|
||||
response = self._make_request(url)
|
||||
|
||||
contents = response.json()
|
||||
|
||||
# Handle error responses from GitHub API
|
||||
if isinstance(contents, dict) and "message" in contents:
|
||||
raise Exception(f"GitHub API error: {contents.get('message')}")
|
||||
|
||||
# Ensure contents is a list
|
||||
if not isinstance(contents, list):
|
||||
raise TypeError(f"Expected list from GitHub API, got {type(contents).__name__}: {contents}")
|
||||
|
||||
files = []
|
||||
for item in contents:
|
||||
if item["type"] == "file":
|
||||
@@ -53,6 +144,15 @@ class GitHubLoader(BaseRemote):
|
||||
documents = []
|
||||
for file_path in files:
|
||||
content = self.fetch_file_content(repo_name, file_path)
|
||||
documents.append(Document(page_content=content, metadata={"title": file_path,
|
||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"}))
|
||||
# Skip binary files (content is None)
|
||||
if content is None:
|
||||
continue
|
||||
documents.append(Document(
|
||||
text=content,
|
||||
doc_id=file_path,
|
||||
extra_info={
|
||||
"title": file_path,
|
||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"
|
||||
}
|
||||
))
|
||||
return documents
|
||||
|
||||
0
application/seed/__init__.py
Normal file
0
application/seed/__init__.py
Normal file
26
application/seed/commands.py
Normal file
26
application/seed/commands.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import click
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.seed.seeder import DatabaseSeeder
|
||||
|
||||
|
||||
@click.group()
|
||||
def seed():
|
||||
"""Database seeding commands"""
|
||||
pass
|
||||
|
||||
|
||||
@seed.command()
|
||||
@click.option("--force", is_flag=True, help="Force reseeding even if data exists")
|
||||
def init(force):
|
||||
"""Initialize database with seed data"""
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
seeder = DatabaseSeeder(db)
|
||||
seeder.seed_initial_data(force=force)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed()
|
||||
36
application/seed/config/agents_template.yaml
Normal file
36
application/seed/config/agents_template.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
# Configuration for Premade Agents
|
||||
# This file contains template agents that will be seeded into the database
|
||||
|
||||
agents:
|
||||
# Basic Agent Template
|
||||
- name: "Agent Name" # Required: Unique name for the agent
|
||||
description: "What this agent does" # Required: Brief description of the agent's purpose
|
||||
image: "URL_TO_IMAGE" # Optional: URL to agent's avatar/image
|
||||
agent_type: "classic" # Required: Type of agent (e.g., classic, react, etc.)
|
||||
prompt_id: "default" # Optional: Reference to prompt template
|
||||
prompt: # Optional: Define new prompt
|
||||
name: "New Prompt"
|
||||
content: "You are new agent with cool new prompt."
|
||||
chunks: "0" # Optional: Chunking strategy for documents
|
||||
retriever: "" # Optional: Retriever type for document search
|
||||
|
||||
# Source Configuration (where the agent gets its knowledge)
|
||||
source: # Optional: Select a source to link with agent
|
||||
name: "Source Display Name" # Human-readable name for the source
|
||||
url: "https://example.com/data-source" # URL or path to knowledge source
|
||||
loader: "url" # Type of loader (url, pdf, txt, etc.)
|
||||
|
||||
# Tools Configuration (what capabilities the agent has)
|
||||
tools: # Optional: Remove if agent doesn't need tools
|
||||
- name: "tool_name" # Must match a supported tool name
|
||||
display_name: "Tool Display Name" # Optional: Human-readable name for the tool
|
||||
config:
|
||||
# Tool-specific configuration
|
||||
# Example for DuckDuckGo:
|
||||
# token: "${DDG_API_KEY}" # ${} denotes environment variable
|
||||
|
||||
# Add more tools as needed
|
||||
# - name: "another_tool"
|
||||
# config:
|
||||
# param1: "value1"
|
||||
# param2: "${ENV_VAR}"
|
||||
94
application/seed/config/premade_agents.yaml
Normal file
94
application/seed/config/premade_agents.yaml
Normal file
@@ -0,0 +1,94 @@
|
||||
# Configuration for Premade Agents
|
||||
|
||||
agents:
|
||||
- name: "Assistant"
|
||||
description: "Your general-purpose AI assistant. Ready to help with a wide range of tasks."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-logo.svg"
|
||||
agent_type: "classic"
|
||||
prompt_id: "default"
|
||||
chunks: "0"
|
||||
retriever: ""
|
||||
|
||||
# Tools Configuration
|
||||
tools:
|
||||
- name: "tool_name"
|
||||
display_name: "read_webpage"
|
||||
config:
|
||||
|
||||
- name: "Researcher"
|
||||
description: "A specialized research agent that performs deep dives into subjects."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-researcher.svg"
|
||||
agent_type: "react"
|
||||
prompt:
|
||||
name: "Researcher-Agent"
|
||||
content: |
|
||||
You are a specialized AI research assistant, DocsGPT. Your primary function is to conduct in-depth research on a given subject or question. You are methodical, thorough, and analytical. You should perform multiple iterations of thinking to gather and synthesize information before providing a final, comprehensive answer.
|
||||
|
||||
You have access to the 'Read Webpage' tool. Use this tool to explore sources, gather data, and deepen your understanding. Be proactive in using the tool to fill in knowledge gaps and validate information.
|
||||
|
||||
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
|
||||
If appropriate, your answers can include code examples, formatted as follows:
|
||||
```(language)
|
||||
(code)
|
||||
```
|
||||
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses. Try to respond with mermaid charts if visualization helps with users queries. You effectively utilize chat history, ensuring relevant and tailored responses. Try to use additional provided context if it's available, otherwise use your knowledge and tool capabilities.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "0"
|
||||
retriever: ""
|
||||
|
||||
# Tools Configuration
|
||||
tools:
|
||||
- name: "tool_name"
|
||||
display_name: "read_webpage"
|
||||
config:
|
||||
|
||||
- name: "Search Widget"
|
||||
description: "A powerful search widget agent. Ask it anything about DocsGPT"
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-search.svg"
|
||||
agent_type: "classic"
|
||||
prompt:
|
||||
name: "Search-Agent"
|
||||
content: |
|
||||
You are a website search assistant, DocsGPT. Your sole purpose is to help users find information within the provided context of the DocsGPT documentation. Act as a specialized search engine.
|
||||
|
||||
Your answers must be based *only* on the provided context. Do not use any external knowledge. If the answer is not in the context, inform the user that you could not find the information within the documentation.
|
||||
|
||||
Keep your responses concise and directly related to the user's query, pointing them to the most relevant information.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "8"
|
||||
retriever: ""
|
||||
|
||||
source:
|
||||
name: "DocsGPT-Docs"
|
||||
url: "https://d3dg1063dc54p9.cloudfront.net/agent-source/docsgpt-documentation.md" # URL to DocsGPT documentation
|
||||
loader: "url"
|
||||
|
||||
- name: "Support Widget"
|
||||
description: "A friendly support widget agent to help you with any questions."
|
||||
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-support.svg"
|
||||
agent_type: "classic"
|
||||
prompt:
|
||||
name: "Support-Agent"
|
||||
content: |
|
||||
You are a helpful AI support widget agent, DocsGPT. Your goal is to assist users by answering their questions about our website, product and its features. Provide friendly, clear, and direct support.
|
||||
|
||||
Your knowledge is strictly limited to the provided context from the DocsGPT documentation. You must not answer questions outside of this scope. If a user asks something you cannot answer from the context, politely state that you can only help with questions about this website.
|
||||
|
||||
Effectively utilize chat history to understand the user's issue fully. Guide users to the information they need in a helpful and conversational manner.
|
||||
----------------
|
||||
Possible additional context from uploaded sources:
|
||||
{summaries}
|
||||
|
||||
chunks: "8"
|
||||
retriever: ""
|
||||
|
||||
source:
|
||||
name: "DocsGPT-Docs"
|
||||
url: "https://d3dg1063dc54p9.cloudfront.net/agent-source/docsgpt-documentation.md" # URL to DocsGPT documentation
|
||||
loader: "url"
|
||||
277
application/seed/seeder.py
Normal file
277
application/seed/seeder.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from bson import ObjectId
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pymongo import MongoClient
|
||||
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api.user.tasks import ingest_remote
|
||||
|
||||
load_dotenv()
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
class DatabaseSeeder:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.tools_collection = self.db["user_tools"]
|
||||
self.sources_collection = self.db["sources"]
|
||||
self.agents_collection = self.db["agents"]
|
||||
self.prompts_collection = self.db["prompts"]
|
||||
self.system_user_id = "system"
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def seed_initial_data(self, config_path: str = None, force=False):
|
||||
"""Main entry point for seeding all initial data"""
|
||||
if not force and self._is_already_seeded():
|
||||
self.logger.info("Database already seeded. Use force=True to reseed.")
|
||||
return
|
||||
config_path = config_path or os.path.join(
|
||||
os.path.dirname(__file__), "config", "premade_agents.yaml"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
self._seed_from_config(config)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load seeding config: {str(e)}")
|
||||
raise
|
||||
|
||||
def _seed_from_config(self, config: Dict):
|
||||
"""Seed all data from configuration"""
|
||||
self.logger.info("🌱 Starting seeding...")
|
||||
|
||||
if not config.get("agents"):
|
||||
self.logger.warning("No agents found in config")
|
||||
return
|
||||
used_tool_ids = set()
|
||||
|
||||
for agent_config in config["agents"]:
|
||||
try:
|
||||
self.logger.info(f"Processing agent: {agent_config['name']}")
|
||||
|
||||
# 1. Handle Source
|
||||
|
||||
source_result = self._handle_source(agent_config)
|
||||
if source_result is False:
|
||||
self.logger.error(
|
||||
f"Skipping agent {agent_config['name']} due to source ingestion failure"
|
||||
)
|
||||
continue
|
||||
source_id = source_result
|
||||
# 2. Handle Tools
|
||||
|
||||
tool_ids = self._handle_tools(agent_config)
|
||||
if len(tool_ids) == 0:
|
||||
self.logger.warning(
|
||||
f"No valid tools for agent {agent_config['name']}"
|
||||
)
|
||||
used_tool_ids.update(tool_ids)
|
||||
|
||||
# 3. Handle Prompt
|
||||
|
||||
prompt_id = self._handle_prompt(agent_config)
|
||||
|
||||
# 4. Create Agent
|
||||
|
||||
agent_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": agent_config["name"],
|
||||
"description": agent_config["description"],
|
||||
"image": agent_config.get("image", ""),
|
||||
"source": (
|
||||
DBRef("sources", ObjectId(source_id)) if source_id else ""
|
||||
),
|
||||
"tools": [str(tid) for tid in tool_ids],
|
||||
"agent_type": agent_config["agent_type"],
|
||||
"prompt_id": prompt_id or agent_config.get("prompt_id", "default"),
|
||||
"chunks": agent_config.get("chunks", "0"),
|
||||
"retriever": agent_config.get("retriever", ""),
|
||||
"status": "template",
|
||||
"createdAt": datetime.now(timezone.utc),
|
||||
"updatedAt": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
existing = self.agents_collection.find_one(
|
||||
{"user": self.system_user_id, "name": agent_config["name"]}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Updating existing agent: {agent_config['name']}")
|
||||
self.agents_collection.update_one(
|
||||
{"_id": existing["_id"]}, {"$set": agent_data}
|
||||
)
|
||||
agent_id = existing["_id"]
|
||||
else:
|
||||
self.logger.info(f"Creating new agent: {agent_config['name']}")
|
||||
result = self.agents_collection.insert_one(agent_data)
|
||||
agent_id = result.inserted_id
|
||||
self.logger.info(
|
||||
f"Successfully processed agent: {agent_config['name']} (ID: {agent_id})"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Error processing agent {agent_config['name']}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
self.logger.info("✅ Database seeding completed")
|
||||
|
||||
def _handle_source(self, agent_config: Dict) -> Union[ObjectId, None, bool]:
|
||||
"""Handle source ingestion and return source ID"""
|
||||
if not agent_config.get("source"):
|
||||
self.logger.info(
|
||||
"No source provided for agent - will create agent without source"
|
||||
)
|
||||
return None
|
||||
source_config = agent_config["source"]
|
||||
self.logger.info(f"Ingesting source: {source_config['url']}")
|
||||
|
||||
try:
|
||||
existing = self.sources_collection.find_one(
|
||||
{"user": self.system_user_id, "remote_data": source_config["url"]}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Source already exists: {existing['_id']}")
|
||||
return existing["_id"]
|
||||
# Ingest new source using worker
|
||||
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_config["url"],
|
||||
job_name=source_config["name"],
|
||||
user=self.system_user_id,
|
||||
loader=source_config.get("loader", "url"),
|
||||
)
|
||||
|
||||
result = task.get(timeout=300)
|
||||
|
||||
if not task.successful():
|
||||
raise Exception(f"Source ingestion failed: {result}")
|
||||
source_id = None
|
||||
if isinstance(result, dict) and "id" in result:
|
||||
source_id = result["id"]
|
||||
else:
|
||||
raise Exception(f"Source ingestion result missing 'id': {result}")
|
||||
self.logger.info(f"Source ingested successfully: {source_id}")
|
||||
return source_id
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to ingest source: {str(e)}")
|
||||
return False
|
||||
|
||||
def _handle_tools(self, agent_config: Dict) -> List[ObjectId]:
|
||||
"""Handle tool creation and return list of tool IDs"""
|
||||
tool_ids = []
|
||||
if not agent_config.get("tools"):
|
||||
return tool_ids
|
||||
for tool_config in agent_config["tools"]:
|
||||
try:
|
||||
tool_name = tool_config["name"]
|
||||
processed_config = self._process_config(tool_config.get("config", {}))
|
||||
self.logger.info(f"Processing tool: {tool_name}")
|
||||
|
||||
existing = self.tools_collection.find_one(
|
||||
{
|
||||
"user": self.system_user_id,
|
||||
"name": tool_name,
|
||||
"config": processed_config,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Tool already exists: {existing['_id']}")
|
||||
tool_ids.append(existing["_id"])
|
||||
continue
|
||||
tool_data = {
|
||||
"user": self.system_user_id,
|
||||
"name": tool_name,
|
||||
"displayName": tool_config.get("display_name", tool_name),
|
||||
"description": tool_config.get("description", ""),
|
||||
"actions": tool_manager.tools[tool_name].get_actions_metadata(),
|
||||
"config": processed_config,
|
||||
"status": True,
|
||||
}
|
||||
|
||||
result = self.tools_collection.insert_one(tool_data)
|
||||
tool_ids.append(result.inserted_id)
|
||||
self.logger.info(f"Created new tool: {result.inserted_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process tool {tool_name}: {str(e)}")
|
||||
continue
|
||||
return tool_ids
|
||||
|
||||
def _handle_prompt(self, agent_config: Dict) -> Optional[str]:
|
||||
"""Handle prompt creation and return prompt ID"""
|
||||
if not agent_config.get("prompt"):
|
||||
return None
|
||||
|
||||
prompt_config = agent_config["prompt"]
|
||||
prompt_name = prompt_config.get("name", f"{agent_config['name']} Prompt")
|
||||
prompt_content = prompt_config.get("content", "")
|
||||
|
||||
if not prompt_content:
|
||||
self.logger.warning(
|
||||
f"No prompt content provided for agent {agent_config['name']}"
|
||||
)
|
||||
return None
|
||||
|
||||
self.logger.info(f"Processing prompt: {prompt_name}")
|
||||
|
||||
try:
|
||||
existing = self.prompts_collection.find_one(
|
||||
{
|
||||
"user": self.system_user_id,
|
||||
"name": prompt_name,
|
||||
"content": prompt_content,
|
||||
}
|
||||
)
|
||||
if existing:
|
||||
self.logger.info(f"Prompt already exists: {existing['_id']}")
|
||||
return str(existing["_id"])
|
||||
|
||||
prompt_data = {
|
||||
"name": prompt_name,
|
||||
"content": prompt_content,
|
||||
"user": self.system_user_id,
|
||||
}
|
||||
|
||||
result = self.prompts_collection.insert_one(prompt_data)
|
||||
prompt_id = str(result.inserted_id)
|
||||
self.logger.info(f"Created new prompt: {prompt_id}")
|
||||
return prompt_id
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process prompt {prompt_name}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _process_config(self, config: Dict) -> Dict:
|
||||
"""Process config values to replace environment variables"""
|
||||
processed = {}
|
||||
for key, value in config.items():
|
||||
if (
|
||||
isinstance(value, str)
|
||||
and value.startswith("${")
|
||||
and value.endswith("}")
|
||||
):
|
||||
env_var = value[2:-1]
|
||||
processed[key] = os.getenv(env_var, "")
|
||||
else:
|
||||
processed[key] = value
|
||||
return processed
|
||||
|
||||
def _is_already_seeded(self) -> bool:
|
||||
"""Check if premade agents already exist"""
|
||||
return self.agents_collection.count_documents({"user": self.system_user_id}) > 0
|
||||
|
||||
@classmethod
|
||||
def initialize_from_env(cls, worker=None):
|
||||
"""Factory method to create seeder from environment"""
|
||||
mongo_uri = os.getenv("MONGO_URI", "mongodb://localhost:27017")
|
||||
db_name = os.getenv("MONGO_DB_NAME", "docsgpt")
|
||||
client = MongoClient(mongo_uri)
|
||||
db = client[db_name]
|
||||
return cls(db)
|
||||
@@ -26,7 +26,7 @@ class LocalStorage(BaseStorage):
|
||||
return path
|
||||
return os.path.join(self.base_dir, path)
|
||||
|
||||
def save_file(self, file_data: BinaryIO, path: str) -> dict:
|
||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||
"""Save a file to local storage."""
|
||||
full_path = self._get_full_path(path)
|
||||
|
||||
|
||||
@@ -168,6 +168,10 @@ def validate_function_name(function_name):
|
||||
|
||||
|
||||
def generate_image_url(image_path):
|
||||
if isinstance(image_path, str) and (
|
||||
image_path.startswith("http://") or image_path.startswith("https://")
|
||||
):
|
||||
return image_path
|
||||
strategy = getattr(settings, "URL_STRATEGY", "backend")
|
||||
if strategy == "s3":
|
||||
bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
|
||||
|
||||
@@ -39,6 +39,7 @@ sources_collection = db["sources"]
|
||||
|
||||
# Constants
|
||||
|
||||
|
||||
MIN_TOKENS = 150
|
||||
MAX_TOKENS = 1250
|
||||
RECURSION_DEPTH = 2
|
||||
@@ -740,7 +741,13 @@ def remote_worker(
|
||||
if os.path.exists(full_path):
|
||||
shutil.rmtree(full_path)
|
||||
logging.info("remote_worker task completed successfully")
|
||||
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
|
||||
return {
|
||||
"id": str(id),
|
||||
"urls": source_data,
|
||||
"name_job": name_job,
|
||||
"user": user,
|
||||
"limited": False,
|
||||
}
|
||||
|
||||
|
||||
def sync(
|
||||
|
||||
@@ -107,3 +107,13 @@ Once an agent is created, you can:
|
||||
* Modify any of its configuration settings (name, description, source, prompt, tools, type).
|
||||
* **Generate a Public Link:** From the edit screen, you can create a shareable public link that allows others to import and use your agent.
|
||||
* **Get a Webhook URL:** You can also obtain a Webhook URL for the agent. This allows external applications or services to trigger the agent and receive responses programmatically, enabling powerful integrations and automations.
|
||||
|
||||
## Seeding Premade Agents from YAML
|
||||
|
||||
You can bootstrap a fresh DocsGPT deployment with a curated set of agents by seeding them directly into MongoDB.
|
||||
|
||||
1. **Customize the configuration** – edit `application/seed/config/premade_agents.yaml` (or copy from `application/seed/config/agents_template.yaml`) to describe the agents you want to provision. Each entry lets you define prompts, tools, and optional data sources.
|
||||
2. **Ensure dependencies are running** – MongoDB must be reachable using the credentials in `.env`, and a Celery worker should be available if any agent sources need to be ingested via `ingest_remote`.
|
||||
3. **Execute the seeder** – run `python -m application.seed.commands init`. Add `--force` when you need to reseed an existing environment.
|
||||
|
||||
The seeder keeps templates under the `system` user so they appear in the UI for anyone to clone or customize. Environment variable placeholders such as `${MY_TOKEN}` inside tool configs are resolved during the seeding process.
|
||||
|
||||
@@ -43,7 +43,8 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
|
||||
2) Serve Local (with Ollama)
|
||||
3) Connect Local Inference Engine
|
||||
4) Connect Cloud API Provider
|
||||
Choose option (1-4):
|
||||
5) Advanced: Build images locally (for developers)
|
||||
Choose option (1-5):
|
||||
```
|
||||
|
||||
Let's break down each option:
|
||||
@@ -56,6 +57,8 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
|
||||
|
||||
* **4) Connect Cloud API Provider:** This option lets you connect DocsGPT to a commercial Cloud API provider such as OpenAI, Google (Vertex AI/Gemini), Anthropic (Claude), Groq, HuggingFace Inference API, or Azure OpenAI. You will need an API key from your chosen provider. Select this if you prefer to use a powerful cloud-based LLM.
|
||||
|
||||
* **5) Modify DocsGPT's source code and rebuild the Docker images locally. Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.
|
||||
|
||||
After selecting an option and providing any required information (like API keys or model names), the script will configure your `.env` file and start DocsGPT using Docker Compose.
|
||||
|
||||
4. **Access DocsGPT in your browser:**
|
||||
|
||||
17
frontend/package-lock.json
generated
17
frontend/package-lock.json
generated
@@ -12,7 +12,7 @@
|
||||
"chart.js": "^4.4.4",
|
||||
"clsx": "^2.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"i18next": "^24.2.0",
|
||||
"i18next": "^25.5.3",
|
||||
"i18next-browser-languagedetector": "^8.0.2",
|
||||
"lodash": "^4.17.21",
|
||||
"mermaid": "^11.6.0",
|
||||
@@ -321,9 +321,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/runtime": {
|
||||
"version": "7.27.3",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.3.tgz",
|
||||
"integrity": "sha512-7EYtGezsdiDMyY80+65EzwiGmcJqpmcZCojSXaRgdrBaGtWTgDZKq69cPIVped6MkIM78cTQ2GOiEYjwOlG4xw==",
|
||||
"version": "7.28.4",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz",
|
||||
"integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
@@ -6217,9 +6217,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/i18next": {
|
||||
"version": "24.2.0",
|
||||
"resolved": "https://registry.npmjs.org/i18next/-/i18next-24.2.0.tgz",
|
||||
"integrity": "sha512-ArJJTS1lV6lgKH7yEf4EpgNZ7+THl7bsGxxougPYiXRTJ/Fe1j08/TBpV9QsXCIYVfdE/HWG/xLezJ5DOlfBOA==",
|
||||
"version": "25.5.3",
|
||||
"resolved": "https://registry.npmjs.org/i18next/-/i18next-25.5.3.tgz",
|
||||
"integrity": "sha512-joFqorDeQ6YpIXni944upwnuHBf5IoPMuqAchGVeQLdWC2JOjxgM9V8UGLhNIIH/Q8QleRxIi0BSRQehSrDLcg==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
@@ -6234,8 +6234,9 @@
|
||||
"url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project"
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.23.2"
|
||||
"@babel/runtime": "^7.27.6"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"typescript": "^5"
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
"chart.js": "^4.4.4",
|
||||
"clsx": "^2.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"i18next": "^24.2.0",
|
||||
"i18next": "^25.5.3",
|
||||
"i18next-browser-languagedetector": "^8.0.2",
|
||||
"lodash": "^4.17.21",
|
||||
"mermaid": "^11.6.0",
|
||||
|
||||
3
frontend/public/toolIcons/tool_memory.svg
Normal file
3
frontend/public/toolIcons/tool_memory.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3">
|
||||
<path d="M240-80q-33 0-56.5-23.5T160-160v-480q0-33 23.5-56.5T240-720h80v-80q0-17 11.5-28.5T360-840q17 0 28.5 11.5T400-800v80h40v-80q0-17 11.5-28.5T480-840q17 0 28.5 11.5T520-800v80h40v-80q0-17 11.5-28.5T600-840q17 0 28.5 11.5T640-800v80h80q33 0 56.5 23.5T800-640v480q0 33-23.5 56.5T720-80H240Zm0-80h480v-480H240v480Zm120-320v-80h240v80H360Zm0 120v-80h240v80H360Zm0 120v-80h160v80H360ZM240-160v-480 480Z"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 523 B |
1
frontend/public/toolIcons/tool_notes.svg
Normal file
1
frontend/public/toolIcons/tool_notes.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#e3e3e3"><path d="M320-240h320v-80H320v80Zm0-160h320v-80H320v80ZM240-80q-33 0-56.5-23.5T160-160v-640q0-33 23.5-56.5T240-880h320l240 240v480q0 33-23.5 56.5T720-80H240Zm280-520v-200H240v640h480v-440H520ZM240-800v200-200 640-640Z"/></svg>
|
||||
|
After Width: | Height: | Size: 334 B |
@@ -7,6 +7,7 @@ import Agents from './agents';
|
||||
import SharedAgentGate from './agents/SharedAgentGate';
|
||||
import ActionButtons from './components/ActionButtons';
|
||||
import Spinner from './components/Spinner';
|
||||
import UploadToast from './components/UploadToast';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import { SharedConversation } from './conversation/SharedConversation';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
@@ -37,14 +38,15 @@ function MainLayout() {
|
||||
<Navigation navOpen={navOpen} setNavOpen={setNavOpen} />
|
||||
<ActionButtons showNewChat={true} showShare={true} />
|
||||
<div
|
||||
className={`h-[calc(100dvh-64px)] overflow-auto lg:h-screen ${
|
||||
className={`h-[calc(100dvh-64px)] overflow-auto transition-all duration-300 ease-in-out lg:h-screen ${
|
||||
!(isMobile || isTablet)
|
||||
? `ml-0 ${!navOpen ? 'lg:mx-auto' : 'lg:ml-72'}`
|
||||
? `${navOpen ? 'lg:ml-72' : 'lg:ml-0'}`
|
||||
: 'ml-0 lg:ml-16'
|
||||
}`}
|
||||
>
|
||||
<Outlet />
|
||||
</div>
|
||||
<UploadToast />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
import Pin from './assets/pin.svg';
|
||||
import Robot from './assets/robot.svg';
|
||||
import AgentImage from './components/AgentImage';
|
||||
import SettingGear from './assets/settingGear.svg';
|
||||
import Spark from './assets/spark.svg';
|
||||
import SpinnerDark from './assets/spinner-dark.svg';
|
||||
@@ -292,20 +292,26 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
useDefaultDocument();
|
||||
return (
|
||||
<>
|
||||
{!navOpen && (
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-25 lg:block">
|
||||
{(isMobile || isTablet) && navOpen && (
|
||||
<div
|
||||
className="fixed inset-0 z-10 bg-black opacity-50 transition-opacity duration-300"
|
||||
onClick={() => setNavOpen(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-300 ease-in-out lg:block">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
onClick={() => {
|
||||
setNavOpen(!navOpen);
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className={`${
|
||||
!navOpen ? 'rotate-180' : 'rotate-0'
|
||||
} m-auto transition-all duration-200`}
|
||||
className="m-auto transition-all duration-300 ease-in-out"
|
||||
/>
|
||||
</button>
|
||||
{queries?.length > 0 && (
|
||||
@@ -313,6 +319,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
onClick={() => {
|
||||
newChat();
|
||||
}}
|
||||
className="transition-transform duration-200 hover:scale-110"
|
||||
>
|
||||
<img
|
||||
src={openNewChat}
|
||||
@@ -326,12 +333,12 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
}
|
||||
<div
|
||||
ref={navRef}
|
||||
className={`${
|
||||
!navOpen && '-ml-96 md:-ml-72'
|
||||
} bg-lotion dark:border-r-purple-taupe dark:bg-chinese-black fixed top-0 z-20 flex h-full w-72 flex-col border-r border-b-0 transition-all duration-20 dark:text-white`}
|
||||
} bg-lotion dark:border-r-purple-taupe dark:bg-chinese-black fixed top-0 z-20 flex h-full w-72 flex-col border-r border-b-0 transition-all duration-300 ease-in-out dark:text-white`}
|
||||
>
|
||||
<div
|
||||
className={'visible mt-2 flex h-[6vh] w-full justify-between md:h-12'}
|
||||
@@ -345,7 +352,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
}}
|
||||
>
|
||||
<a href="/" className="flex gap-1.5">
|
||||
<img className="mb-2 h-10" src={DocsGPT3} alt="DocsGPT Logo" />
|
||||
<img className="h-10" src={DocsGPT3} alt="DocsGPT Logo" />
|
||||
<p className="my-auto text-2xl font-semibold">DocsGPT</p>
|
||||
</a>
|
||||
</div>
|
||||
@@ -358,9 +365,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<img
|
||||
src={Expand}
|
||||
alt="Toggle navigation menu"
|
||||
className={`${
|
||||
!navOpen ? 'rotate-180' : 'rotate-0'
|
||||
} m-auto transition-all duration-200`}
|
||||
className="m-auto transition-all duration-300 ease-in-out hover:scale-110"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
@@ -419,12 +424,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex w-6 justify-center">
|
||||
<img
|
||||
src={
|
||||
agent.image && agent.image.trim() !== ''
|
||||
? agent.image
|
||||
: Robot
|
||||
}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
alt="agent-logo"
|
||||
className="h-6 w-6 rounded-full object-contain"
|
||||
/>
|
||||
@@ -576,7 +577,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</NavLink>
|
||||
<NavLink
|
||||
target="_blank"
|
||||
to={'https://twitter.com/docsgptai'}
|
||||
to={'https://x.com/docsgptai'}
|
||||
className={
|
||||
'rounded-full hover:bg-gray-100 dark:hover:bg-[#28292E]'
|
||||
}
|
||||
@@ -585,7 +586,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
src={Twitter}
|
||||
width={20}
|
||||
height={20}
|
||||
alt="Follow us on Twitter"
|
||||
alt="Follow us on X"
|
||||
className="m-2 self-center filter dark:invert"
|
||||
/>
|
||||
</NavLink>
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
import { useRef, useState } from 'react';
|
||||
import { SyntheticEvent, useRef, useState } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Robot from '../assets/robot.svg';
|
||||
import Duplicate from '../assets/duplicate.svg';
|
||||
import Edit from '../assets/edit.svg';
|
||||
import Link from '../assets/link-gray.svg';
|
||||
import Monitoring from '../assets/monitoring.svg';
|
||||
import Pin from '../assets/pin.svg';
|
||||
import Trash from '../assets/red-trash.svg';
|
||||
import ThreeDots from '../assets/three-dots.svg';
|
||||
import UnPin from '../assets/unpin.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import ContextMenu, { MenuOption } from '../components/ContextMenu';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import {
|
||||
selectAgents,
|
||||
selectToken,
|
||||
setAgents,
|
||||
setSelectedAgent,
|
||||
@@ -18,46 +26,205 @@ import { Agent } from './types';
|
||||
type AgentCardProps = {
|
||||
agent: Agent;
|
||||
agents: Agent[];
|
||||
menuOptions?: MenuOption[];
|
||||
onDelete?: (agentId: string) => void;
|
||||
updateAgents?: (agents: Agent[]) => void;
|
||||
section: string;
|
||||
};
|
||||
|
||||
export default function AgentCard({
|
||||
agent,
|
||||
agents,
|
||||
menuOptions,
|
||||
onDelete,
|
||||
updateAgents,
|
||||
section,
|
||||
}: AgentCardProps) {
|
||||
const navigate = useNavigate();
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const userAgents = useSelector(selectAgents);
|
||||
|
||||
const [isMenuOpen, setIsMenuOpen] = useState(false);
|
||||
const [isMenuOpen, setIsMenuOpen] = useState<boolean>(false);
|
||||
const [deleteConfirmation, setDeleteConfirmation] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
|
||||
const menuRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (agent.status === 'published') {
|
||||
dispatch(setSelectedAgent(agent));
|
||||
navigate('/');
|
||||
const menuOptionsConfig: Record<string, MenuOption[]> = {
|
||||
template: [
|
||||
{
|
||||
icon: Duplicate,
|
||||
label: 'Duplicate',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
handleDuplicate();
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 18,
|
||||
iconHeight: 18,
|
||||
},
|
||||
],
|
||||
user: [
|
||||
{
|
||||
icon: Monitoring,
|
||||
label: 'Logs',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/agents/logs/${agent.id}`);
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 14,
|
||||
iconHeight: 14,
|
||||
},
|
||||
{
|
||||
icon: Edit,
|
||||
label: 'Edit',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/agents/edit/${agent.id}`);
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 14,
|
||||
iconHeight: 14,
|
||||
},
|
||||
...(agent.status === 'published'
|
||||
? [
|
||||
{
|
||||
icon: agent.pinned ? UnPin : Pin,
|
||||
label: agent.pinned ? 'Unpin' : 'Pin agent',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
togglePin();
|
||||
},
|
||||
variant: 'primary' as const,
|
||||
iconWidth: 18,
|
||||
iconHeight: 18,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
icon: Trash,
|
||||
label: 'Delete',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
setDeleteConfirmation('ACTIVE');
|
||||
},
|
||||
variant: 'danger',
|
||||
iconWidth: 13,
|
||||
iconHeight: 13,
|
||||
},
|
||||
],
|
||||
shared: [
|
||||
{
|
||||
icon: Link,
|
||||
label: 'Open',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 12,
|
||||
iconHeight: 12,
|
||||
},
|
||||
{
|
||||
icon: agent.pinned ? UnPin : Pin,
|
||||
label: agent.pinned ? 'Unpin' : 'Pin agent',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
togglePin();
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 18,
|
||||
iconHeight: 18,
|
||||
},
|
||||
{
|
||||
icon: Trash,
|
||||
label: 'Remove',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
handleHideSharedAgent();
|
||||
},
|
||||
variant: 'danger',
|
||||
iconWidth: 13,
|
||||
iconHeight: 13,
|
||||
},
|
||||
],
|
||||
};
|
||||
const menuOptions = menuOptionsConfig[section] || [];
|
||||
|
||||
const handleClick = () => {
|
||||
if (section === 'user') {
|
||||
if (agent.status === 'published') {
|
||||
dispatch(setSelectedAgent(agent));
|
||||
navigate(`/`);
|
||||
}
|
||||
}
|
||||
if (section === 'shared') {
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
}
|
||||
};
|
||||
|
||||
const defaultDelete = async (agentId: string) => {
|
||||
const response = await userService.deleteAgent(agentId, token);
|
||||
if (!response.ok) throw new Error('Failed to delete agent');
|
||||
const data = await response.json();
|
||||
dispatch(setAgents(agents.filter((prevAgent) => prevAgent.id !== data.id)));
|
||||
const togglePin = async () => {
|
||||
try {
|
||||
const response = await userService.togglePinAgent(agent.id ?? '', token);
|
||||
if (!response.ok) throw new Error('Failed to pin agent');
|
||||
const updatedAgents = agents.map((prevAgent) => {
|
||||
if (prevAgent.id === agent.id)
|
||||
return { ...prevAgent, pinned: !prevAgent.pinned };
|
||||
return prevAgent;
|
||||
});
|
||||
updateAgents?.(updatedAgents);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleHideSharedAgent = async () => {
|
||||
try {
|
||||
const response = await userService.removeSharedAgent(
|
||||
agent.id ?? '',
|
||||
token,
|
||||
);
|
||||
if (!response.ok) throw new Error('Failed to hide shared agent');
|
||||
const updatedAgents = agents.filter(
|
||||
(prevAgent) => prevAgent.id !== agent.id,
|
||||
);
|
||||
updateAgents?.(updatedAgents);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async () => {
|
||||
try {
|
||||
const response = await userService.deleteAgent(agent.id ?? '', token);
|
||||
if (!response.ok) throw new Error('Failed to delete agent');
|
||||
const updatedAgents = agents.filter(
|
||||
(prevAgent) => prevAgent.id !== agent.id,
|
||||
);
|
||||
updateAgents?.(updatedAgents);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDuplicate = async () => {
|
||||
try {
|
||||
const response = await userService.adoptAgent(agent.id ?? '', token);
|
||||
if (!response.ok) throw new Error('Failed to duplicate agent');
|
||||
const data = await response.json();
|
||||
if (userAgents) {
|
||||
const updatedAgents = [...userAgents, data.agent];
|
||||
dispatch(setAgents(updatedAgents));
|
||||
} else dispatch(setAgents([data.agent]));
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
}
|
||||
};
|
||||
return (
|
||||
<div
|
||||
className={`relative flex h-44 w-48 flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] dark:bg-[#383838] dark:hover:bg-[#383838]/80 ${
|
||||
agent.status === 'published' ? 'cursor-pointer' : ''
|
||||
}`}
|
||||
onClick={handleCardClick}
|
||||
className={`relative flex h-44 w-full flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] md:w-48 dark:bg-[#383838] dark:hover:bg-[#383838]/80 ${agent.status === 'published' && 'cursor-pointer'}`}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleClick();
|
||||
}}
|
||||
>
|
||||
<div
|
||||
ref={menuRef}
|
||||
@@ -67,30 +234,25 @@ export default function AgentCard({
|
||||
}}
|
||||
className="absolute top-4 right-4 z-10 cursor-pointer"
|
||||
>
|
||||
<img src={ThreeDots} alt="options" className="h-[19px] w-[19px]" />
|
||||
{menuOptions && (
|
||||
<ContextMenu
|
||||
isOpen={isMenuOpen}
|
||||
setIsOpen={setIsMenuOpen}
|
||||
options={menuOptions}
|
||||
anchorRef={menuRef}
|
||||
position="top-right"
|
||||
offset={{ x: 0, y: 0 }}
|
||||
/>
|
||||
)}
|
||||
<img src={ThreeDots} alt={'use-agent'} className="h-[19px] w-[19px]" />
|
||||
<ContextMenu
|
||||
isOpen={isMenuOpen}
|
||||
setIsOpen={setIsMenuOpen}
|
||||
options={menuOptions}
|
||||
anchorRef={menuRef}
|
||||
position="bottom-right"
|
||||
offset={{ x: 0, y: 0 }}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center gap-1 px-1">
|
||||
<img
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
alt={`${agent.name}`}
|
||||
className="h-7 w-7 rounded-full object-contain"
|
||||
/>
|
||||
{agent.status === 'draft' && (
|
||||
<p className="text-xs text-black opacity-50 dark:text-[#E0E0E0]">
|
||||
(Draft)
|
||||
</p>
|
||||
<p className="text-xs text-black opacity-50 dark:text-[#E0E0E0]">{`(Draft)`}</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
@@ -105,14 +267,13 @@ export default function AgentCard({
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ConfirmationModal
|
||||
message="Are you sure you want to delete this agent?"
|
||||
modalState={deleteConfirmation}
|
||||
setModalState={setDeleteConfirmation}
|
||||
submitLabel="Delete"
|
||||
handleSubmit={() => {
|
||||
onDelete ? onDelete(agent.id || '') : defaultDelete(agent.id || '');
|
||||
handleDelete();
|
||||
setDeleteConfirmation('INACTIVE');
|
||||
}}
|
||||
cancelLabel="Cancel"
|
||||
|
||||
@@ -49,7 +49,7 @@ export default function AgentLogs() {
|
||||
</p>
|
||||
</div>
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-eerie-black m-0 text-[40px] font-bold dark:text-white">
|
||||
<h1 className="text-eerie-black m-0 text-[32px] font-bold md:text-[40px] dark:text-white">
|
||||
Agent Logs
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
134
frontend/src/agents/AgentsList.tsx
Normal file
134
frontend/src/agents/AgentsList.tsx
Normal file
@@ -0,0 +1,134 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
|
||||
import Spinner from '../components/Spinner';
|
||||
import {
|
||||
setConversation,
|
||||
updateConversationId,
|
||||
} from '../conversation/conversationSlice';
|
||||
import {
|
||||
selectSelectedAgent,
|
||||
selectToken,
|
||||
setSelectedAgent,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import AgentCard from './AgentCard';
|
||||
import { agentSectionsConfig } from './agents.config';
|
||||
import { Agent } from './types';
|
||||
|
||||
export default function AgentsList() {
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const selectedAgent = useSelector(selectSelectedAgent);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setConversation([]));
|
||||
dispatch(
|
||||
updateConversationId({
|
||||
query: { conversationId: null },
|
||||
}),
|
||||
);
|
||||
if (selectedAgent) dispatch(setSelectedAgent(null));
|
||||
}, [token]);
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<h1 className="text-eerie-black mb-0 text-[32px] font-bold lg:text-[40px] dark:text-[#E0E0E0]">
|
||||
Agents
|
||||
</h1>
|
||||
<p className="dark:text-gray-4000 mt-5 text-[15px] text-[#71717A]">
|
||||
Discover and create custom versions of DocsGPT that combine
|
||||
instructions, extra knowledge, and any combination of skills
|
||||
</p>
|
||||
{agentSectionsConfig.map((sectionConfig) => (
|
||||
<AgentSection key={sectionConfig.id} config={sectionConfig} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function AgentSection({
|
||||
config,
|
||||
}: {
|
||||
config: (typeof agentSectionsConfig)[number];
|
||||
}) {
|
||||
const navigate = useNavigate();
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const agents = useSelector(config.selectData);
|
||||
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
const updateAgents = (updatedAgents: Agent[]) => {
|
||||
dispatch(config.updateAction(updatedAgents));
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const getAgents = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await config.fetchAgents(token);
|
||||
if (!response.ok)
|
||||
throw new Error(`Failed to fetch ${config.id} agents`);
|
||||
const data = await response.json();
|
||||
dispatch(config.updateAction(data));
|
||||
} catch (error) {
|
||||
console.error(`Error fetching ${config.id} agents:`, error);
|
||||
dispatch(config.updateAction([]));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
getAgents();
|
||||
}, [token, config, dispatch]);
|
||||
return (
|
||||
<div className="mt-8 flex flex-col gap-4">
|
||||
<div className="flex w-full items-center justify-between">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h2 className="text-[18px] font-semibold text-[#18181B] dark:text-[#E0E0E0]">
|
||||
{config.title}
|
||||
</h2>
|
||||
<p className="text-[13px] text-[#71717A]">{config.description}</p>
|
||||
</div>
|
||||
{config.showNewAgentButton && (
|
||||
<button
|
||||
className="bg-purple-30 hover:bg-violets-are-blue rounded-full px-4 py-2 text-sm text-white"
|
||||
onClick={() => navigate('/agents/new')}
|
||||
>
|
||||
New Agent
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
{loading ? (
|
||||
<div className="flex h-72 w-full items-center justify-center">
|
||||
<Spinner />
|
||||
</div>
|
||||
) : agents && agents.length > 0 ? (
|
||||
<div className="grid grid-cols-1 gap-4 sm:flex sm:flex-wrap">
|
||||
{agents.map((agent) => (
|
||||
<AgentCard
|
||||
key={agent.id}
|
||||
agent={agent}
|
||||
agents={agents}
|
||||
updateAgents={updateAgents}
|
||||
section={config.id}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex h-72 w-full flex-col items-center justify-center gap-3 text-base text-[#18181B] dark:text-[#E0E0E0]">
|
||||
<p>{config.emptyStateDescription}</p>
|
||||
{config.showNewAgentButton && (
|
||||
<button
|
||||
className="bg-purple-30 hover:bg-violets-are-blue ml-2 rounded-full px-4 py-2 text-sm text-white"
|
||||
onClick={() => navigate('/agents/new')}
|
||||
>
|
||||
New Agent
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,7 +23,7 @@ import PromptsModal from '../preferences/PromptsModal';
|
||||
import Prompts from '../settings/Prompts';
|
||||
import { UserToolType } from '../settings/types';
|
||||
import AgentPreview from './AgentPreview';
|
||||
import { Agent } from './types';
|
||||
import { Agent, ToolSummary } from './types';
|
||||
|
||||
const embeddingsName =
|
||||
import.meta.env.VITE_EMBEDDINGS_NAME ||
|
||||
@@ -64,9 +64,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const [selectedSourceIds, setSelectedSourceIds] = useState<
|
||||
Set<string | number>
|
||||
>(new Set());
|
||||
const [selectedToolIds, setSelectedToolIds] = useState<Set<string | number>>(
|
||||
new Set(),
|
||||
);
|
||||
const [selectedTools, setSelectedTools] = useState<ToolSummary[]>([]);
|
||||
const [deleteConfirmation, setDeleteConfirmation] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [agentDetails, setAgentDetails] = useState<ActiveState>('INACTIVE');
|
||||
@@ -337,7 +335,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const data = await response.json();
|
||||
const tools: OptionType[] = data.tools.map((tool: UserToolType) => ({
|
||||
id: tool.id,
|
||||
label: tool.displayName,
|
||||
label: tool.customName ? tool.customName : tool.displayName,
|
||||
icon: `/toolIcons/tool_${tool.name}.svg`,
|
||||
}));
|
||||
setUserTools(tools);
|
||||
@@ -410,7 +408,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
setSelectedSourceIds(new Set([data.retriever]));
|
||||
}
|
||||
|
||||
if (data.tools) setSelectedToolIds(new Set(data.tools));
|
||||
if (data.tool_details) setSelectedTools(data.tool_details);
|
||||
if (data.status === 'draft') setEffectiveMode('draft');
|
||||
if (data.json_schema) {
|
||||
const jsonText = JSON.stringify(data.json_schema, null, 2);
|
||||
@@ -480,16 +478,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
}, [selectedSourceIds]);
|
||||
|
||||
useEffect(() => {
|
||||
const selectedTool = Array.from(selectedToolIds).map((id) =>
|
||||
userTools.find((tool) => tool.id === id),
|
||||
);
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
tools: selectedTool
|
||||
tools: Array.from(selectedTools)
|
||||
.map((tool) => tool?.id)
|
||||
.filter((id): id is string => typeof id === 'string'),
|
||||
}));
|
||||
}, [selectedToolIds]);
|
||||
}, [selectedTools]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isPublishable()) dispatch(setSelectedAgent(agent));
|
||||
@@ -527,7 +522,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
</p>
|
||||
</div>
|
||||
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
|
||||
<h1 className="text-eerie-black m-0 text-[40px] font-bold dark:text-white">
|
||||
<h1 className="text-eerie-black m-0 text-[32px] font-bold lg:text-[40px] dark:text-white">
|
||||
{modeConfig[effectiveMode].heading}
|
||||
</h1>
|
||||
<div className="flex flex-wrap items-center gap-1">
|
||||
@@ -645,15 +640,15 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
>
|
||||
{selectedSourceIds.size > 0
|
||||
? Array.from(selectedSourceIds)
|
||||
.map(
|
||||
(id) =>
|
||||
sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === id ||
|
||||
source.name === id ||
|
||||
source.retriever === id,
|
||||
)?.name,
|
||||
)
|
||||
.map((id) => {
|
||||
const matchedDoc = sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === id ||
|
||||
source.name === id ||
|
||||
source.retriever === id,
|
||||
);
|
||||
return matchedDoc?.name || `External KB`;
|
||||
})
|
||||
.filter(Boolean)
|
||||
.join(', ')
|
||||
: 'Select sources'}
|
||||
@@ -768,16 +763,14 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
ref={toolAnchorButtonRef}
|
||||
onClick={() => setIsToolsPopupOpen(!isToolsPopupOpen)}
|
||||
className={`border-silver dark:bg-raisin-black w-full truncate rounded-3xl border bg-white px-5 py-3 text-left text-sm dark:border-[#7E7E7E] ${
|
||||
selectedToolIds.size > 0
|
||||
selectedTools.length > 0
|
||||
? 'text-jet dark:text-bright-gray'
|
||||
: 'dark:text-silver text-gray-400'
|
||||
}`}
|
||||
>
|
||||
{selectedToolIds.size > 0
|
||||
? Array.from(selectedToolIds)
|
||||
.map(
|
||||
(id) => userTools.find((tool) => tool.id === id)?.label,
|
||||
)
|
||||
{selectedTools.length > 0
|
||||
? selectedTools
|
||||
.map((tool) => tool.display_name || tool.name)
|
||||
.filter(Boolean)
|
||||
.join(', ')
|
||||
: 'Select tools'}
|
||||
@@ -787,9 +780,17 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
onClose={() => setIsToolsPopupOpen(false)}
|
||||
anchorRef={toolAnchorButtonRef}
|
||||
options={userTools}
|
||||
selectedIds={selectedToolIds}
|
||||
selectedIds={new Set(selectedTools.map((tool) => tool.id))}
|
||||
onSelectionChange={(newSelectedIds: Set<string | number>) =>
|
||||
setSelectedToolIds(newSelectedIds)
|
||||
setSelectedTools(
|
||||
userTools
|
||||
.filter((tool) => newSelectedIds.has(tool.id))
|
||||
.map((tool) => ({
|
||||
id: String(tool.id),
|
||||
name: tool.label,
|
||||
display_name: tool.label,
|
||||
})),
|
||||
)
|
||||
}
|
||||
title="Select Tools"
|
||||
searchPlaceholder="Search tools..."
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useParams } from 'react-router-dom';
|
||||
import userService from '../api/services/userService';
|
||||
import NoFilesDarkIcon from '../assets/no-files-dark.svg';
|
||||
import NoFilesIcon from '../assets/no-files.svg';
|
||||
import Robot from '../assets/robot.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import MessageInput from '../components/MessageInput';
|
||||
import Spinner from '../components/Spinner';
|
||||
import ConversationMessages from '../conversation/ConversationMessages';
|
||||
@@ -152,12 +152,8 @@ export default function SharedAgent() {
|
||||
return (
|
||||
<div className="relative h-full w-full">
|
||||
<div className="absolute top-5 left-4 hidden items-center gap-3 sm:flex">
|
||||
<img
|
||||
src={
|
||||
sharedAgent.image && sharedAgent.image.trim() !== ''
|
||||
? sharedAgent.image
|
||||
: Robot
|
||||
}
|
||||
<AgentImage
|
||||
src={sharedAgent.image}
|
||||
alt="agent-logo"
|
||||
className="h-6 w-6 rounded-full object-contain"
|
||||
/>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import Robot from '../assets/robot.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import { Agent } from './types';
|
||||
|
||||
export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
@@ -6,8 +6,8 @@ export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
<div className="border-dark-gray dark:border-grey flex w-full max-w-[720px] flex-col rounded-3xl border p-6 shadow-xs sm:w-fit sm:min-w-[480px]">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex h-12 w-12 items-center justify-center overflow-hidden rounded-full p-1">
|
||||
<img
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
className="h-full w-full rounded-full object-contain"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
|
||||
import {
|
||||
handleFetchAnswer,
|
||||
handleFetchAnswerSteaming,
|
||||
} from '../conversation/conversationHandlers';
|
||||
import {
|
||||
Answer,
|
||||
ConversationState,
|
||||
Query,
|
||||
Status,
|
||||
} from '../conversation/conversationModels';
|
||||
import {
|
||||
handleFetchAnswer,
|
||||
handleFetchAnswerSteaming,
|
||||
} from '../conversation/conversationHandlers';
|
||||
import {
|
||||
selectCompletedAttachments,
|
||||
clearAttachments,
|
||||
} from '../upload/uploadSlice';
|
||||
import store from '../store';
|
||||
import {
|
||||
clearAttachments,
|
||||
selectCompletedAttachments,
|
||||
} from '../upload/uploadSlice';
|
||||
|
||||
const initialState: ConversationState = {
|
||||
queries: [],
|
||||
|
||||
42
frontend/src/agents/agents.config.ts
Normal file
42
frontend/src/agents/agents.config.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import userService from '../api/services/userService';
|
||||
import {
|
||||
selectAgents,
|
||||
selectTemplateAgents,
|
||||
selectSharedAgents,
|
||||
setAgents,
|
||||
setTemplateAgents,
|
||||
setSharedAgents,
|
||||
} from '../preferences/preferenceSlice';
|
||||
|
||||
export const agentSectionsConfig = [
|
||||
{
|
||||
id: 'template',
|
||||
title: 'By DocsGPT',
|
||||
description: 'Agents provided by DocsGPT',
|
||||
showNewAgentButton: false,
|
||||
emptyStateDescription: 'No template agents found.',
|
||||
fetchAgents: (token: string | null) => userService.getTemplateAgents(token),
|
||||
selectData: selectTemplateAgents,
|
||||
updateAction: setTemplateAgents,
|
||||
},
|
||||
{
|
||||
id: 'user',
|
||||
title: 'By me',
|
||||
description: 'Agents created or published by you',
|
||||
showNewAgentButton: true,
|
||||
emptyStateDescription: 'You don’t have any created agents yet.',
|
||||
fetchAgents: (token: string | null) => userService.getAgents(token),
|
||||
selectData: selectAgents,
|
||||
updateAction: setAgents,
|
||||
},
|
||||
{
|
||||
id: 'shared',
|
||||
title: 'Shared with me',
|
||||
description: 'Agents imported by using a public link',
|
||||
showNewAgentButton: false,
|
||||
emptyStateDescription: 'No shared agents found.',
|
||||
fetchAgents: (token: string | null) => userService.getSharedAgents(token),
|
||||
selectData: selectSharedAgents,
|
||||
updateAction: setSharedAgents,
|
||||
},
|
||||
];
|
||||
@@ -1,37 +1,9 @@
|
||||
import { SyntheticEvent, useEffect, useRef, useState } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { Route, Routes, useNavigate } from 'react-router-dom';
|
||||
import { Route, Routes } from 'react-router-dom';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Edit from '../assets/edit.svg';
|
||||
import Link from '../assets/link-gray.svg';
|
||||
import Monitoring from '../assets/monitoring.svg';
|
||||
import Pin from '../assets/pin.svg';
|
||||
import Trash from '../assets/red-trash.svg';
|
||||
import Robot from '../assets/robot.svg';
|
||||
import ThreeDots from '../assets/three-dots.svg';
|
||||
import UnPin from '../assets/unpin.svg';
|
||||
import ContextMenu, { MenuOption } from '../components/ContextMenu';
|
||||
import Spinner from '../components/Spinner';
|
||||
import {
|
||||
setConversation,
|
||||
updateConversationId,
|
||||
} from '../conversation/conversationSlice';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import {
|
||||
selectAgents,
|
||||
selectSelectedAgent,
|
||||
selectSharedAgents,
|
||||
selectToken,
|
||||
setAgents,
|
||||
setSelectedAgent,
|
||||
setSharedAgents,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import AgentLogs from './AgentLogs';
|
||||
import AgentsList from './AgentsList';
|
||||
import NewAgent from './NewAgent';
|
||||
import SharedAgent from './SharedAgent';
|
||||
import { Agent } from './types';
|
||||
|
||||
export default function Agents() {
|
||||
return (
|
||||
@@ -44,431 +16,3 @@ export default function Agents() {
|
||||
</Routes>
|
||||
);
|
||||
}
|
||||
|
||||
const sectionConfig = {
|
||||
user: {
|
||||
title: 'By me',
|
||||
description: 'Agents created or published by you',
|
||||
showNewAgentButton: true,
|
||||
emptyStateDescription: 'You don’t have any created agents yet',
|
||||
},
|
||||
shared: {
|
||||
title: 'Shared with me',
|
||||
description: 'Agents imported by using a public link',
|
||||
showNewAgentButton: false,
|
||||
emptyStateDescription: 'No shared agents found',
|
||||
},
|
||||
};
|
||||
|
||||
function AgentsList() {
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const agents = useSelector(selectAgents);
|
||||
const sharedAgents = useSelector(selectSharedAgents);
|
||||
const selectedAgent = useSelector(selectSelectedAgent);
|
||||
|
||||
const [loadingUserAgents, setLoadingUserAgents] = useState<boolean>(true);
|
||||
const [loadingSharedAgents, setLoadingSharedAgents] = useState<boolean>(true);
|
||||
|
||||
const getAgents = async () => {
|
||||
try {
|
||||
setLoadingUserAgents(true);
|
||||
const response = await userService.getAgents(token);
|
||||
if (!response.ok) throw new Error('Failed to fetch agents');
|
||||
const data = await response.json();
|
||||
dispatch(setAgents(data));
|
||||
setLoadingUserAgents(false);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
setLoadingUserAgents(false);
|
||||
}
|
||||
};
|
||||
|
||||
const getSharedAgents = async () => {
|
||||
try {
|
||||
setLoadingSharedAgents(true);
|
||||
const response = await userService.getSharedAgents(token);
|
||||
if (!response.ok) throw new Error('Failed to fetch shared agents');
|
||||
const data = await response.json();
|
||||
dispatch(setSharedAgents(data));
|
||||
setLoadingSharedAgents(false);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
setLoadingSharedAgents(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getAgents();
|
||||
getSharedAgents();
|
||||
dispatch(setConversation([]));
|
||||
dispatch(
|
||||
updateConversationId({
|
||||
query: { conversationId: null },
|
||||
}),
|
||||
);
|
||||
if (selectedAgent) dispatch(setSelectedAgent(null));
|
||||
}, [token]);
|
||||
return (
|
||||
<div className="p-4 md:p-12">
|
||||
<h1 className="text-eerie-black mb-0 text-[40px] font-bold dark:text-[#E0E0E0]">
|
||||
Agents
|
||||
</h1>
|
||||
<p className="dark:text-gray-4000 mt-5 text-[15px] text-[#71717A]">
|
||||
Discover and create custom versions of DocsGPT that combine
|
||||
instructions, extra knowledge, and any combination of skills
|
||||
</p>
|
||||
{/* Premade agents section */}
|
||||
{/* <div className="mt-6">
|
||||
<h2 className="text-[18px] font-semibold text-[#18181B] dark:text-[#E0E0E0]">
|
||||
Premade by DocsGPT
|
||||
</h2>
|
||||
<div className="mt-4 flex w-full flex-wrap gap-4">
|
||||
{Array.from({ length: 5 }, (_, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="relative flex h-44 w-48 flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 dark:bg-[#383838]"
|
||||
>
|
||||
<button onClick={() => {}} className="absolute right-4 top-4">
|
||||
<img
|
||||
src={Copy}
|
||||
alt={'use-agent'}
|
||||
className="h-[19px] w-[19px]"
|
||||
/>
|
||||
</button>
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center px-1">
|
||||
<img
|
||||
src={Robot}
|
||||
alt="agent-logo"
|
||||
className="h-7 w-7 rounded-full"
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<p
|
||||
title={''}
|
||||
className="truncate px-1 text-[13px] font-semibold capitalize leading-relaxed text-raisin-black-light dark:text-bright-gray"
|
||||
>
|
||||
{}
|
||||
</p>
|
||||
<p className="mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-old-silver dark:text-sonic-silver-light">
|
||||
{}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="absolute bottom-4 right-4"></div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div> */}
|
||||
<AgentSection
|
||||
agents={agents ?? []}
|
||||
updateAgents={(updatedAgents) => {
|
||||
dispatch(setAgents(updatedAgents));
|
||||
}}
|
||||
loading={loadingUserAgents}
|
||||
section="user"
|
||||
/>
|
||||
<AgentSection
|
||||
agents={sharedAgents ?? []}
|
||||
updateAgents={(updatedAgents) => {
|
||||
dispatch(setSharedAgents(updatedAgents));
|
||||
}}
|
||||
loading={loadingSharedAgents}
|
||||
section="shared"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function AgentSection({
|
||||
agents,
|
||||
updateAgents,
|
||||
loading,
|
||||
section,
|
||||
}: {
|
||||
agents: Agent[];
|
||||
updateAgents?: (agents: Agent[]) => void;
|
||||
loading: boolean;
|
||||
section: keyof typeof sectionConfig;
|
||||
}) {
|
||||
const navigate = useNavigate();
|
||||
return (
|
||||
<div className="mt-8 flex flex-col gap-4">
|
||||
<div className="flex w-full items-center justify-between">
|
||||
<div className="flex flex-col gap-2">
|
||||
<h2 className="text-[18px] font-semibold text-[#18181B] dark:text-[#E0E0E0]">
|
||||
{sectionConfig[section].title}
|
||||
</h2>
|
||||
<p className="text-[13px] text-[#71717A]">
|
||||
{sectionConfig[section].description}
|
||||
</p>
|
||||
</div>
|
||||
{sectionConfig[section].showNewAgentButton && (
|
||||
<button
|
||||
className="bg-purple-30 hover:bg-violets-are-blue rounded-full px-4 py-2 text-sm text-white"
|
||||
onClick={() => navigate('/agents/new')}
|
||||
>
|
||||
New Agent
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
{loading ? (
|
||||
<div className="flex h-72 w-full items-center justify-center">
|
||||
<Spinner />
|
||||
</div>
|
||||
) : agents && agents.length > 0 ? (
|
||||
<div className="grid grid-cols-1 gap-4 sm:flex sm:flex-wrap">
|
||||
{agents.map((agent, idx) => (
|
||||
<AgentCard
|
||||
key={agent.id}
|
||||
agent={agent}
|
||||
agents={agents}
|
||||
updateAgents={updateAgents}
|
||||
section={section}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex h-72 w-full flex-col items-center justify-center gap-3 text-base text-[#18181B] dark:text-[#E0E0E0]">
|
||||
<p>{sectionConfig[section].emptyStateDescription}</p>
|
||||
{sectionConfig[section].showNewAgentButton && (
|
||||
<button
|
||||
className="bg-purple-30 hover:bg-violets-are-blue ml-2 rounded-full px-4 py-2 text-sm text-white"
|
||||
onClick={() => navigate('/agents/new')}
|
||||
>
|
||||
New Agent
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function AgentCard({
|
||||
agent,
|
||||
agents,
|
||||
updateAgents,
|
||||
section,
|
||||
}: {
|
||||
agent: Agent;
|
||||
agents: Agent[];
|
||||
updateAgents?: (agents: Agent[]) => void;
|
||||
section: keyof typeof sectionConfig;
|
||||
}) {
|
||||
const navigate = useNavigate();
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
|
||||
const [isMenuOpen, setIsMenuOpen] = useState<boolean>(false);
|
||||
const [deleteConfirmation, setDeleteConfirmation] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
|
||||
const menuRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const togglePin = async () => {
|
||||
try {
|
||||
const response = await userService.togglePinAgent(agent.id ?? '', token);
|
||||
if (!response.ok) throw new Error('Failed to pin agent');
|
||||
const updatedAgents = agents.map((prevAgent) => {
|
||||
if (prevAgent.id === agent.id)
|
||||
return { ...prevAgent, pinned: !prevAgent.pinned };
|
||||
return prevAgent;
|
||||
});
|
||||
updateAgents?.(updatedAgents);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleHideSharedAgent = async () => {
|
||||
try {
|
||||
const response = await userService.removeSharedAgent(
|
||||
agent.id ?? '',
|
||||
token,
|
||||
);
|
||||
if (!response.ok) throw new Error('Failed to hide shared agent');
|
||||
const updatedAgents = agents.filter(
|
||||
(prevAgent) => prevAgent.id !== agent.id,
|
||||
);
|
||||
updateAgents?.(updatedAgents);
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const menuOptionsConfig: Record<string, MenuOption[]> = {
|
||||
user: [
|
||||
{
|
||||
icon: Monitoring,
|
||||
label: 'Logs',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/agents/logs/${agent.id}`);
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 14,
|
||||
iconHeight: 14,
|
||||
},
|
||||
{
|
||||
icon: Edit,
|
||||
label: 'Edit',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/agents/edit/${agent.id}`);
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 14,
|
||||
iconHeight: 14,
|
||||
},
|
||||
...(agent.status === 'published'
|
||||
? [
|
||||
{
|
||||
icon: agent.pinned ? UnPin : Pin,
|
||||
label: agent.pinned ? 'Unpin' : 'Pin agent',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
togglePin();
|
||||
},
|
||||
variant: 'primary' as const,
|
||||
iconWidth: 18,
|
||||
iconHeight: 18,
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
icon: Trash,
|
||||
label: 'Delete',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
setDeleteConfirmation('ACTIVE');
|
||||
},
|
||||
variant: 'danger',
|
||||
iconWidth: 13,
|
||||
iconHeight: 13,
|
||||
},
|
||||
],
|
||||
shared: [
|
||||
{
|
||||
icon: Link,
|
||||
label: 'Open',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 12,
|
||||
iconHeight: 12,
|
||||
},
|
||||
{
|
||||
icon: agent.pinned ? UnPin : Pin,
|
||||
label: agent.pinned ? 'Unpin' : 'Pin agent',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
togglePin();
|
||||
},
|
||||
variant: 'primary',
|
||||
iconWidth: 18,
|
||||
iconHeight: 18,
|
||||
},
|
||||
{
|
||||
icon: Trash,
|
||||
label: 'Remove',
|
||||
onClick: (e: SyntheticEvent) => {
|
||||
e.stopPropagation();
|
||||
handleHideSharedAgent();
|
||||
},
|
||||
variant: 'danger',
|
||||
iconWidth: 13,
|
||||
iconHeight: 13,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const menuOptions = menuOptionsConfig[section] || [];
|
||||
|
||||
const handleClick = () => {
|
||||
if (section === 'user') {
|
||||
if (agent.status === 'published') {
|
||||
dispatch(setSelectedAgent(agent));
|
||||
navigate(`/`);
|
||||
}
|
||||
}
|
||||
if (section === 'shared') {
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async (agentId: string) => {
|
||||
const response = await userService.deleteAgent(agentId, token);
|
||||
if (!response.ok) throw new Error('Failed to delete agent');
|
||||
const data = await response.json();
|
||||
dispatch(setAgents(agents.filter((prevAgent) => prevAgent.id !== data.id)));
|
||||
};
|
||||
return (
|
||||
<div
|
||||
className={`relative flex h-44 w-full flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] md:w-48 dark:bg-[#383838] dark:hover:bg-[#383838]/80 ${agent.status === 'published' && 'cursor-pointer'}`}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleClick();
|
||||
}}
|
||||
>
|
||||
<div
|
||||
ref={menuRef}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsMenuOpen(true);
|
||||
}}
|
||||
className="absolute top-4 right-4 z-10 cursor-pointer"
|
||||
>
|
||||
<img src={ThreeDots} alt={'use-agent'} className="h-[19px] w-[19px]" />
|
||||
<ContextMenu
|
||||
isOpen={isMenuOpen}
|
||||
setIsOpen={setIsMenuOpen}
|
||||
options={menuOptions}
|
||||
anchorRef={menuRef}
|
||||
position="bottom-right"
|
||||
offset={{ x: 0, y: 0 }}
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center gap-1 px-1">
|
||||
<img
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
alt={`${agent.name}`}
|
||||
className="h-7 w-7 rounded-full object-contain"
|
||||
/>
|
||||
{agent.status === 'draft' && (
|
||||
<p className="text-xs text-black opacity-50 dark:text-[#E0E0E0]">{`(Draft)`}</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<p
|
||||
title={agent.name}
|
||||
className="truncate px-1 text-[13px] leading-relaxed font-semibold text-[#020617] capitalize dark:text-[#E0E0E0]"
|
||||
>
|
||||
{agent.name}
|
||||
</p>
|
||||
<p className="dark:text-sonic-silver-light mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B]">
|
||||
{agent.description}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<ConfirmationModal
|
||||
message="Are you sure you want to delete this agent?"
|
||||
modalState={deleteConfirmation}
|
||||
setModalState={setDeleteConfirmation}
|
||||
submitLabel="Delete"
|
||||
handleSubmit={() => {
|
||||
handleDelete(agent.id || '');
|
||||
setDeleteConfirmation('INACTIVE');
|
||||
}}
|
||||
cancelLabel="Cancel"
|
||||
variant="danger"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ const endpoints = {
|
||||
SHARED_AGENTS: '/api/shared_agents',
|
||||
SHARE_AGENT: `/api/share_agent`,
|
||||
REMOVE_SHARED_AGENT: (id: string) => `/api/remove_shared_agent?id=${id}`,
|
||||
TEMPLATE_AGENTS: '/api/template_agents',
|
||||
ADOPT_AGENT: (id: string) => `/api/adopt_agent?id=${id}`,
|
||||
AGENT_WEBHOOK: (id: string) => `/api/agent_webhook?id=${id}`,
|
||||
PROMPTS: '/api/get_prompts',
|
||||
CREATE_PROMPT: '/api/create_prompt',
|
||||
|
||||
@@ -44,6 +44,10 @@ const userService = {
|
||||
apiClient.put(endpoints.USER.SHARE_AGENT, data, token),
|
||||
removeSharedAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.delete(endpoints.USER.REMOVE_SHARED_AGENT(id), token),
|
||||
getTemplateAgents: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
|
||||
adoptAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.ADOPT_AGENT(id), {}, token),
|
||||
getAgentWebhook: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.AGENT_WEBHOOK(id), token),
|
||||
getPrompts: (token: string | null): Promise<any> =>
|
||||
|
||||
3
frontend/src/assets/check-circle-filled.svg
Normal file
3
frontend/src/assets/check-circle-filled.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="24" height="25" viewBox="0 0 24 25" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 2.5C17.523 2.5 22 6.977 22 12.5C22 18.023 17.523 22.5 12 22.5C6.477 22.5 2 18.023 2 12.5C2 6.977 6.477 2.5 12 2.5ZM15.22 9.47L10.75 13.94L8.78 11.97C8.63783 11.8375 8.44978 11.7654 8.25548 11.7688C8.06118 11.7723 7.87579 11.851 7.73838 11.9884C7.60097 12.1258 7.52225 12.3112 7.51883 12.5055C7.5154 12.6998 7.58752 12.8878 7.72 13.03L10.22 15.53C10.3606 15.6705 10.5512 15.7493 10.75 15.7493C10.9488 15.7493 11.1394 15.6705 11.28 15.53L16.28 10.53C16.3537 10.4613 16.4128 10.3785 16.4538 10.2865C16.4948 10.1945 16.5168 10.0952 16.5186 9.99452C16.5204 9.89382 16.5018 9.79379 16.4641 9.7004C16.4264 9.60701 16.3703 9.52218 16.299 9.45096C16.2278 9.37974 16.143 9.3236 16.0496 9.28588C15.9562 9.24816 15.8562 9.22963 15.7555 9.23141C15.6548 9.23318 15.5555 9.25523 15.4635 9.29622C15.3715 9.33721 15.2887 9.39631 15.22 9.47Z" fill="#0C9D35"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 958 B |
4
frontend/src/assets/duplicate.svg
Normal file
4
frontend/src/assets/duplicate.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="20" height="21" viewBox="0 0 20 21" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M15.8984 5.5H7.22656C5.99687 5.5 5 6.49687 5 7.72656V16.3984C5 17.6281 5.99687 18.625 7.22656 18.625H15.8984C17.1281 18.625 18.125 17.6281 18.125 16.3984V7.72656C18.125 6.49687 17.1281 5.5 15.8984 5.5Z" stroke="#949494" stroke-width="1.25" stroke-linejoin="round"/>
|
||||
<path d="M14.9805 5.5L15 4.5625C14.9984 3.98285 14.7674 3.4274 14.3575 3.01753C13.9476 2.60765 13.3922 2.37665 12.8125 2.375H4.375C3.71256 2.37696 3.07781 2.64098 2.6094 3.1094C2.14098 3.57781 1.87696 4.21256 1.875 4.875V13.3125C1.87665 13.8922 2.10765 14.4476 2.51753 14.8575C2.9274 15.2674 3.48285 15.4984 4.0625 15.5H5M11.5625 8.9375V15.1875M14.6875 12.0625H8.4375" stroke="#949494" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 833 B |
3
frontend/src/assets/warn.svg
Normal file
3
frontend/src/assets/warn.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="20" height="21" viewBox="0 0 20 21" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M15 1.83989C16.5202 2.71758 17.7826 3.97997 18.6603 5.50017C19.538 7.02038 20 8.74483 20 10.5002C20 12.2556 19.5379 13.98 18.6602 15.5002C17.7825 17.0204 16.5201 18.2828 14.9999 19.1605C13.4797 20.0381 11.7552 20.5002 9.99984 20.5001C8.24446 20.5001 6.52002 20.038 4.99984 19.1603C3.47965 18.2826 2.21729 17.0202 1.33963 15.5C0.46198 13.9797 -4.45897e-05 12.2553 3.22765e-09 10.4999L0.00500012 10.1759C0.0610032 8.44888 0.563548 6.76585 1.46364 5.29089C2.36373 3.81592 3.63065 2.59934 5.14089 1.75977C6.65113 0.920205 8.35315 0.486289 10.081 0.50033C11.8089 0.514371 13.5036 0.97589 15 1.83989ZM10 13.4999C9.73478 13.4999 9.48043 13.6052 9.29289 13.7928C9.10536 13.9803 9 14.2347 9 14.4999V14.5099C9 14.7751 9.10536 15.0295 9.29289 15.217C9.48043 15.4045 9.73478 15.5099 10 15.5099C10.2652 15.5099 10.5196 15.4045 10.7071 15.217C10.8946 15.0295 11 14.7751 11 14.5099V14.4999C11 14.2347 10.8946 13.9803 10.7071 13.7928C10.5196 13.6052 10.2652 13.4999 10 13.4999ZM10 6.49989C9.73478 6.49989 9.48043 6.60525 9.29289 6.79279C9.10536 6.98032 9 7.23468 9 7.49989V11.4999C9 11.7651 9.10536 12.0195 9.29289 12.207C9.48043 12.3945 9.73478 12.4999 10 12.4999C10.2652 12.4999 10.5196 12.3945 10.7071 12.207C10.8946 12.0195 11 11.7651 11 11.4999V7.49989C11 7.23468 10.8946 6.98032 10.7071 6.79279C10.5196 6.60525 10.2652 6.49989 10 6.49989Z" fill="#EA4335"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
40
frontend/src/components/AgentImage.tsx
Normal file
40
frontend/src/components/AgentImage.tsx
Normal file
@@ -0,0 +1,40 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import Robot from '../assets/robot.svg';
|
||||
|
||||
type AgentImageProps = {
|
||||
src?: string | null;
|
||||
alt?: string;
|
||||
className?: string;
|
||||
fallbackSrc?: string;
|
||||
};
|
||||
|
||||
export default function AgentImage({
|
||||
src,
|
||||
alt = 'agent',
|
||||
className = '',
|
||||
fallbackSrc = Robot,
|
||||
}: AgentImageProps) {
|
||||
const [currentSrc, setCurrentSrc] = useState(
|
||||
src && src.trim() !== '' ? src : fallbackSrc,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const newSrc = src && src.trim() !== '' ? src : fallbackSrc;
|
||||
if (newSrc !== currentSrc) {
|
||||
setCurrentSrc(newSrc);
|
||||
}
|
||||
}, [src, fallbackSrc]);
|
||||
|
||||
return (
|
||||
<img
|
||||
src={currentSrc}
|
||||
alt={alt}
|
||||
className={className}
|
||||
referrerPolicy="no-referrer"
|
||||
crossOrigin="anonymous"
|
||||
onError={() => {
|
||||
if (currentSrc !== fallbackSrc) setCurrentSrc(fallbackSrc);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import React, { useRef } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
|
||||
@@ -24,6 +25,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
onDisconnect,
|
||||
errorMessage,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const token = useSelector(selectToken);
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const completedRef = useRef(false);
|
||||
@@ -47,12 +49,16 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
cleanup();
|
||||
onSuccess({
|
||||
session_token: event.data.session_token,
|
||||
user_email: event.data.user_email || 'Connected User',
|
||||
user_email:
|
||||
event.data.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
});
|
||||
} else if (errorProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
onError(event.data.error || 'Authentication failed');
|
||||
onError(
|
||||
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -71,13 +77,15 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
|
||||
if (!authResponse.ok) {
|
||||
throw new Error(
|
||||
`Failed to get authorization URL: ${authResponse.status}`,
|
||||
`${t('modals.uploadDoc.connectors.auth.authUrlFailed')}: ${authResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
const authData = await authResponse.json();
|
||||
if (!authData.success || !authData.authorization_url) {
|
||||
throw new Error(authData.error || 'Failed to get authorization URL');
|
||||
throw new Error(
|
||||
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
|
||||
);
|
||||
}
|
||||
|
||||
const authWindow = window.open(
|
||||
@@ -86,9 +94,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
'width=500,height=600,scrollbars=yes,resizable=yes',
|
||||
);
|
||||
if (!authWindow) {
|
||||
throw new Error(
|
||||
'Failed to open authentication window. Please allow popups.',
|
||||
);
|
||||
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
|
||||
}
|
||||
|
||||
window.addEventListener('message', handleAuthMessage as any);
|
||||
@@ -98,13 +104,17 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
clearInterval(checkClosed);
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
if (!completedRef.current) {
|
||||
onError('Authentication was cancelled');
|
||||
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
|
||||
}
|
||||
}
|
||||
}, 1000);
|
||||
intervalRef.current = checkClosed;
|
||||
} catch (error) {
|
||||
onError(error instanceof Error ? error.message : 'Authentication failed');
|
||||
onError(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -147,14 +157,18 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41z"
|
||||
/>
|
||||
</svg>
|
||||
<span>Connected as {userEmail}</span>
|
||||
<span>
|
||||
{t('modals.uploadDoc.connectors.auth.connectedAs', {
|
||||
email: userEmail,
|
||||
})}
|
||||
</span>
|
||||
</div>
|
||||
{onDisconnect && (
|
||||
<button
|
||||
onClick={onDisconnect}
|
||||
className="text-xs font-medium text-[#212121] underline hover:text-gray-700"
|
||||
>
|
||||
Disconnect
|
||||
{t('modals.uploadDoc.connectors.auth.disconnect')}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -8,10 +8,6 @@ import CopyIcon from '../assets/copy.svg?react';
|
||||
|
||||
type CopyButtonProps = {
|
||||
textToCopy: string;
|
||||
bgColorLight?: string;
|
||||
bgColorDark?: string;
|
||||
hoverBgColorLight?: string;
|
||||
hoverBgColorDark?: string;
|
||||
iconSize?: string;
|
||||
padding?: string;
|
||||
showText?: boolean;
|
||||
@@ -27,14 +23,11 @@ const DEFAULT_COPIED_DURATION = 2000;
|
||||
const DEFAULT_BG_LIGHT = '#FFFFFF';
|
||||
const DEFAULT_BG_DARK = 'transparent';
|
||||
const DEFAULT_HOVER_BG_LIGHT = '#EEEEEE';
|
||||
const DEFAULT_HOVER_BG_DARK = '#4A4A4A';
|
||||
const DEFAULT_HOVER_BG_DARK = '#464152';
|
||||
|
||||
export default function CopyButton({
|
||||
textToCopy,
|
||||
bgColorLight = DEFAULT_BG_LIGHT,
|
||||
bgColorDark = DEFAULT_BG_DARK,
|
||||
hoverBgColorLight = DEFAULT_HOVER_BG_LIGHT,
|
||||
hoverBgColorDark = DEFAULT_HOVER_BG_DARK,
|
||||
|
||||
iconSize = DEFAULT_ICON_SIZE,
|
||||
padding = DEFAULT_PADDING,
|
||||
showText = false,
|
||||
@@ -50,9 +43,10 @@ export default function CopyButton({
|
||||
const iconWrapperClasses = clsx(
|
||||
'flex items-center justify-center rounded-full transition-colors duration-150 ease-in-out',
|
||||
padding,
|
||||
`bg-[${bgColorLight}] dark:bg-[${bgColorDark}]`,
|
||||
`hover:bg-[${hoverBgColorLight}] dark:hover:bg-[${hoverBgColorDark}]`,
|
||||
`bg-[${DEFAULT_BG_LIGHT}] dark:bg-[${DEFAULT_BG_DARK}]`,
|
||||
{
|
||||
[`hover:bg-[${DEFAULT_HOVER_BG_LIGHT}] dark:hover:bg-[${DEFAULT_HOVER_BG_DARK}]`]:
|
||||
!isCopied,
|
||||
'bg-green-100 dark:bg-green-900 hover:bg-green-100 dark:hover:bg-green-900':
|
||||
isCopied,
|
||||
},
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import useDrivePicker from 'react-google-drive-picker';
|
||||
|
||||
import ConnectorAuth from './ConnectorAuth';
|
||||
@@ -26,6 +27,7 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
token,
|
||||
onSelectionChange,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const [selectedFiles, setSelectedFiles] = useState<PickerFile[]>([]);
|
||||
const [selectedFolders, setSelectedFolders] = useState<PickerFile[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
@@ -66,14 +68,19 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
setIsConnected(false);
|
||||
setAuthError('Session expired. Please reconnect to Google Drive.');
|
||||
setAuthError(
|
||||
t('modals.uploadDoc.connectors.googleDrive.sessionExpired'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
const validateData = await validateResponse.json();
|
||||
if (validateData.success) {
|
||||
setUserEmail(validateData.user_email || 'Connected User');
|
||||
setUserEmail(
|
||||
validateData.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
setAccessToken(validateData.access_token || null);
|
||||
@@ -83,14 +90,14 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
setIsConnected(false);
|
||||
setAuthError(
|
||||
validateData.error ||
|
||||
'Session expired. Please reconnect your account.',
|
||||
t('modals.uploadDoc.connectors.googleDrive.sessionExpiredGeneric'),
|
||||
);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error validating session:', error);
|
||||
setAuthError('Failed to validate session. Please reconnect.');
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.validateFailed'));
|
||||
setIsConnected(false);
|
||||
setIsValidating(false);
|
||||
return false;
|
||||
@@ -103,15 +110,13 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
const sessionToken = getSessionToken('google_drive');
|
||||
|
||||
if (!sessionToken) {
|
||||
setAuthError('No valid session found. Please reconnect to Google Drive.');
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.noSession'));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!accessToken) {
|
||||
setAuthError(
|
||||
'No access token available. Please reconnect to Google Drive.',
|
||||
);
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.noAccessToken'));
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
@@ -193,7 +198,7 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error opening picker:', error);
|
||||
setAuthError('Failed to open file picker. Please try again.');
|
||||
setAuthError(t('modals.uploadDoc.connectors.googleDrive.pickerFailed'));
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
@@ -264,9 +269,12 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
<>
|
||||
<ConnectorAuth
|
||||
provider="google_drive"
|
||||
label="Connect to Google Drive"
|
||||
label={t('modals.uploadDoc.connectors.googleDrive.connect')}
|
||||
onSuccess={(data) => {
|
||||
setUserEmail(data.user_email || 'Connected User');
|
||||
setUserEmail(
|
||||
data.user_email ||
|
||||
t('modals.uploadDoc.connectors.auth.connectedUser'),
|
||||
);
|
||||
setIsConnected(true);
|
||||
setAuthError('');
|
||||
|
||||
@@ -289,26 +297,34 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
<div className="rounded-lg border border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
||||
<div className="p-4">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h3 className="text-sm font-medium">Selected Files</h3>
|
||||
<h3 className="text-sm font-medium">
|
||||
{t('modals.uploadDoc.connectors.googleDrive.selectedFiles')}
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => handleOpenPicker()}
|
||||
className="rounded-md bg-[#A076F6] px-3 py-1 text-sm text-white hover:bg-[#8A5FD4]"
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading ? 'Loading...' : 'Select Files'}
|
||||
{isLoading
|
||||
? t('modals.uploadDoc.connectors.googleDrive.loading')
|
||||
: t(
|
||||
'modals.uploadDoc.connectors.googleDrive.selectFiles',
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{selectedFiles.length === 0 && selectedFolders.length === 0 ? (
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400">
|
||||
No files or folders selected
|
||||
{t(
|
||||
'modals.uploadDoc.connectors.googleDrive.noFilesSelected',
|
||||
)}
|
||||
</p>
|
||||
) : (
|
||||
<div className="max-h-60 overflow-y-auto">
|
||||
{selectedFolders.length > 0 && (
|
||||
<div className="mb-2">
|
||||
<h4 className="mb-1 text-xs font-medium text-gray-500">
|
||||
Folders
|
||||
{t('modals.uploadDoc.connectors.googleDrive.folders')}
|
||||
</h4>
|
||||
{selectedFolders.map((folder) => (
|
||||
<div
|
||||
@@ -317,7 +333,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
>
|
||||
<img
|
||||
src={folder.iconUrl}
|
||||
alt="Folder"
|
||||
alt={t(
|
||||
'modals.uploadDoc.connectors.googleDrive.folderAlt',
|
||||
)}
|
||||
className="mr-2 h-5 w-5"
|
||||
/>
|
||||
<span className="flex-1 truncate text-sm">
|
||||
@@ -337,7 +355,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
}}
|
||||
className="ml-2 text-sm text-red-500 hover:text-red-700"
|
||||
>
|
||||
Remove
|
||||
{t(
|
||||
'modals.uploadDoc.connectors.googleDrive.remove',
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
@@ -347,7 +367,7 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
{selectedFiles.length > 0 && (
|
||||
<div>
|
||||
<h4 className="mb-1 text-xs font-medium text-gray-500">
|
||||
Files
|
||||
{t('modals.uploadDoc.connectors.googleDrive.files')}
|
||||
</h4>
|
||||
{selectedFiles.map((file) => (
|
||||
<div
|
||||
@@ -356,7 +376,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
>
|
||||
<img
|
||||
src={file.iconUrl}
|
||||
alt="File"
|
||||
alt={t(
|
||||
'modals.uploadDoc.connectors.googleDrive.fileAlt',
|
||||
)}
|
||||
className="mr-2 h-5 w-5"
|
||||
/>
|
||||
<span className="flex-1 truncate text-sm">
|
||||
@@ -375,7 +397,9 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
}}
|
||||
className="ml-2 text-sm text-red-500 hover:text-red-700"
|
||||
>
|
||||
Remove
|
||||
{t(
|
||||
'modals.uploadDoc.connectors.googleDrive.remove',
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
229
frontend/src/components/UploadToast.tsx
Normal file
229
frontend/src/components/UploadToast.tsx
Normal file
@@ -0,0 +1,229 @@
|
||||
import { useState } from 'react';
|
||||
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { selectUploadTasks, dismissUploadTask } from '../upload/uploadSlice';
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import CheckCircleFilled from '../assets/check-circle-filled.svg';
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
|
||||
const PROGRESS_RADIUS = 10;
|
||||
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
|
||||
|
||||
export default function UploadToast() {
|
||||
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
|
||||
{},
|
||||
);
|
||||
|
||||
const toggleTaskCollapse = (taskId: string) => {
|
||||
setCollapsedTasks((prev) => ({
|
||||
...prev,
|
||||
[taskId]: !prev[taskId],
|
||||
}));
|
||||
};
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const getStatusHeading = (status: string) => {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('attachments.uploadFailed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2">
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
const shouldShowProgress = [
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
].includes(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset =
|
||||
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
const isCollapsed = collapsedTasks[task.id] ?? false;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={task.id}
|
||||
className={`w-[271px] overflow-hidden rounded-2xl border border-[#00000021] shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300 ${
|
||||
task.status === 'completed'
|
||||
? 'bg-[#FBFBFB] dark:bg-[#26272E]'
|
||||
: task.status === 'failed'
|
||||
? 'bg-[#FBFBFB] dark:bg-[#26272E]'
|
||||
: 'bg-[#FBFBFB] dark:bg-[#26272E]'
|
||||
}`}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
task.status !== 'failed'
|
||||
? 'bg-[#FBF2FE] dark:bg-transparent'
|
||||
: ''
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter text-[14px] leading-[16.5px] font-medium text-black dark:text-[#DCDCDC]">
|
||||
{getStatusHeading(task.status)}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleTaskCollapse(task.id)}
|
||||
aria-label={
|
||||
isCollapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
isCollapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissUploadTask(task.id))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
isCollapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black dark:text-[#B7BAB8]"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{shouldShowProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 flex-shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t(
|
||||
'modals.uploadDoc.progress.uploadProgress',
|
||||
{
|
||||
progress: formattedProgress,
|
||||
},
|
||||
)}
|
||||
>
|
||||
<circle
|
||||
className="text-gray-300 dark:text-gray-700"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 flex-shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 flex-shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' && task.errorMessage && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -188,7 +188,7 @@ const ConversationBubble = forwardRef<
|
||||
setIsEditClicked(true);
|
||||
setEditInputBox(message ?? '');
|
||||
}}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
|
||||
className={`hover:bg-light-silver mt-3 flex h-fit shrink-0 cursor-pointer items-center rounded-full p-2 pt-1.5 pl-1.5 dark:hover:bg-[#35363B] ${isQuestionHovered || isEditClicked ? 'visible' : 'invisible'}`}
|
||||
>
|
||||
<img src={Edit} alt="Edit" className="cursor-pointer" />
|
||||
</button>
|
||||
@@ -407,7 +407,7 @@ const ConversationBubble = forwardRef<
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
className={`fade-in-bubble bg-gray-1000 dark:bg-gun-metal mr-5 flex max-w-full rounded-[28px] px-7 py-[18px] ${
|
||||
className={`fade-in-bubble bg-gray-1000 dark:bg-gun-metal mr-5 flex max-w-full rounded-[18px] px-6 py-4.5 ${
|
||||
type === 'ERROR'
|
||||
? 'text-red-3000 dark:border-red-2000 relative flex-row items-center rounded-full border border-transparent bg-[#FFE7E7] p-2 py-5 text-sm font-normal dark:text-white'
|
||||
: 'flex-col rounded-3xl'
|
||||
|
||||
@@ -229,6 +229,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "Upload new document",
|
||||
"select": "Choose how to upload your document to DocsGPT",
|
||||
"selectSource": "Select the way to add your source",
|
||||
"selectedFiles": "Selected Files",
|
||||
"noFilesSelected": "No files selected",
|
||||
"file": "Upload from device",
|
||||
"back": "Back",
|
||||
"wait": "Please wait ...",
|
||||
@@ -257,13 +260,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Upload is in progress",
|
||||
"training": "Training is in progress",
|
||||
"completed": "Training completed",
|
||||
"training": "Upload is in progress",
|
||||
"completed": "Upload completed",
|
||||
"wait": "This may take several minutes",
|
||||
"tokenLimit": "Over the token limit, please consider uploading smaller document"
|
||||
"preparing": "Preparing upload",
|
||||
"tokenLimit": "Over the token limit, please consider uploading smaller document",
|
||||
"expandDetails": "Expand upload details",
|
||||
"collapseDetails": "Collapse upload details",
|
||||
"dismiss": "Dismiss upload toast",
|
||||
"uploadProgress": "Upload progress {{progress}}%",
|
||||
"clear": "Clear"
|
||||
},
|
||||
"showAdvanced": "Show advanced options",
|
||||
"hideAdvanced": "Hide advanced options"
|
||||
"hideAdvanced": "Hide advanced options",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Upload File",
|
||||
"heading": "Upload new document"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Crawler",
|
||||
"heading": "Add content with Web Crawler"
|
||||
},
|
||||
"url": {
|
||||
"label": "Link",
|
||||
"heading": "Add content from URL"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Add content from GitHub"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Add content from Reddit"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Upload from Google Drive"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Connected User",
|
||||
"authFailed": "Authentication failed",
|
||||
"authUrlFailed": "Failed to get authorization URL",
|
||||
"popupBlocked": "Failed to open authentication window. Please allow popups.",
|
||||
"authCancelled": "Authentication was cancelled",
|
||||
"connectedAs": "Connected as {{email}}",
|
||||
"disconnect": "Disconnect"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Connect to Google Drive",
|
||||
"sessionExpired": "Session expired. Please reconnect to Google Drive.",
|
||||
"sessionExpiredGeneric": "Session expired. Please reconnect your account.",
|
||||
"validateFailed": "Failed to validate session. Please reconnect.",
|
||||
"noSession": "No valid session found. Please reconnect to Google Drive.",
|
||||
"noAccessToken": "No access token available. Please reconnect to Google Drive.",
|
||||
"pickerFailed": "Failed to open file picker. Please try again.",
|
||||
"selectedFiles": "Selected Files",
|
||||
"selectFiles": "Select Files",
|
||||
"loading": "Loading...",
|
||||
"noFilesSelected": "No files or folders selected",
|
||||
"folders": "Folders",
|
||||
"files": "Files",
|
||||
"remove": "Remove",
|
||||
"folderAlt": "Folder",
|
||||
"fileAlt": "File"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Create New API Key",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "Subir nuevo documento",
|
||||
"select": "Elige cómo cargar tu documento en DocsGPT",
|
||||
"selectSource": "Selecciona la forma de agregar tu fuente",
|
||||
"selectedFiles": "Archivos Seleccionados",
|
||||
"noFilesSelected": "No hay archivos seleccionados",
|
||||
"file": "Subir desde el dispositivo",
|
||||
"back": "Atrás",
|
||||
"wait": "Por favor espera ...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Subida en progreso",
|
||||
"training": "Entrenamiento en progreso",
|
||||
"completed": "Entrenamiento completado",
|
||||
"training": "Subida en progreso",
|
||||
"completed": "Subida completada",
|
||||
"wait": "Esto puede tardar varios minutos",
|
||||
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño"
|
||||
"preparing": "Preparando subida",
|
||||
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño",
|
||||
"expandDetails": "Expandir detalles de subida",
|
||||
"collapseDetails": "Contraer detalles de subida",
|
||||
"dismiss": "Descartar notificación de subida",
|
||||
"uploadProgress": "Progreso de subida {{progress}}%",
|
||||
"clear": "Limpiar"
|
||||
},
|
||||
"showAdvanced": "Mostrar opciones avanzadas",
|
||||
"hideAdvanced": "Ocultar opciones avanzadas"
|
||||
"hideAdvanced": "Ocultar opciones avanzadas",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Subir archivo",
|
||||
"heading": "Subir nuevo documento"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Rastreador",
|
||||
"heading": "Agregar contenido con rastreador web"
|
||||
},
|
||||
"url": {
|
||||
"label": "Enlace",
|
||||
"heading": "Agregar contenido desde URL"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Agregar contenido desde GitHub"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Agregar contenido desde Reddit"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Subir desde Google Drive"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Usuario Conectado",
|
||||
"authFailed": "Autenticación fallida",
|
||||
"authUrlFailed": "Error al obtener la URL de autorización",
|
||||
"popupBlocked": "Error al abrir la ventana de autenticación. Por favor, permita ventanas emergentes.",
|
||||
"authCancelled": "Autenticación cancelada",
|
||||
"connectedAs": "Conectado como {{email}}",
|
||||
"disconnect": "Desconectar"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Conectar a Google Drive",
|
||||
"sessionExpired": "Sesión expirada. Por favor, reconecte a Google Drive.",
|
||||
"sessionExpiredGeneric": "Sesión expirada. Por favor, reconecte su cuenta.",
|
||||
"validateFailed": "Error al validar la sesión. Por favor, reconecte.",
|
||||
"noSession": "No se encontró una sesión válida. Por favor, reconecte a Google Drive.",
|
||||
"noAccessToken": "No hay token de acceso disponible. Por favor, reconecte a Google Drive.",
|
||||
"pickerFailed": "Error al abrir el selector de archivos. Por favor, inténtelo de nuevo.",
|
||||
"selectedFiles": "Archivos Seleccionados",
|
||||
"selectFiles": "Seleccionar Archivos",
|
||||
"loading": "Cargando...",
|
||||
"noFilesSelected": "No hay archivos o carpetas seleccionados",
|
||||
"folders": "Carpetas",
|
||||
"files": "Archivos",
|
||||
"remove": "Eliminar",
|
||||
"folderAlt": "Carpeta",
|
||||
"fileAlt": "Archivo"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Crear Nueva Clave de API",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "新しい文書をアップロードする",
|
||||
"select": "ドキュメントを DocsGPT にアップロードする方法を選択します",
|
||||
"selectSource": "ソースを追加する方法を選択してください",
|
||||
"selectedFiles": "選択されたファイル",
|
||||
"noFilesSelected": "ファイルが選択されていません",
|
||||
"file": "デバイスからアップロード",
|
||||
"back": "戻る",
|
||||
"wait": "お待ちください ...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "アップロード中",
|
||||
"training": "トレーニング中",
|
||||
"completed": "トレーニング完了",
|
||||
"training": "アップロード中",
|
||||
"completed": "アップロード完了",
|
||||
"wait": "数分かかる場合があります",
|
||||
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください"
|
||||
"preparing": "アップロードを準備中",
|
||||
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください",
|
||||
"expandDetails": "アップロードの詳細を展開",
|
||||
"collapseDetails": "アップロードの詳細を折りたたむ",
|
||||
"dismiss": "アップロード通知を閉じる",
|
||||
"uploadProgress": "アップロード進行状況 {{progress}}%",
|
||||
"clear": "クリア"
|
||||
},
|
||||
"showAdvanced": "詳細オプションを表示",
|
||||
"hideAdvanced": "詳細オプションを非表示"
|
||||
"hideAdvanced": "詳細オプションを非表示",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "ファイルをアップロード",
|
||||
"heading": "新しいドキュメントをアップロード"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "クローラー",
|
||||
"heading": "Webクローラーでコンテンツを追加"
|
||||
},
|
||||
"url": {
|
||||
"label": "リンク",
|
||||
"heading": "URLからコンテンツを追加"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "GitHubからコンテンツを追加"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Redditからコンテンツを追加"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Google Driveからアップロード"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "接続されたユーザー",
|
||||
"authFailed": "認証に失敗しました",
|
||||
"authUrlFailed": "認証URLの取得に失敗しました",
|
||||
"popupBlocked": "認証ウィンドウを開けませんでした。ポップアップを許可してください。",
|
||||
"authCancelled": "認証がキャンセルされました",
|
||||
"connectedAs": "{{email}}として接続",
|
||||
"disconnect": "切断"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Google Driveに接続",
|
||||
"sessionExpired": "セッションが期限切れです。Google Driveに再接続してください。",
|
||||
"sessionExpiredGeneric": "セッションが期限切れです。アカウントに再接続してください。",
|
||||
"validateFailed": "セッションの検証に失敗しました。再接続してください。",
|
||||
"noSession": "有効なセッションが見つかりません。Google Driveに再接続してください。",
|
||||
"noAccessToken": "アクセストークンが利用できません。Google Driveに再接続してください。",
|
||||
"pickerFailed": "ファイルピッカーを開けませんでした。もう一度お試しください。",
|
||||
"selectedFiles": "選択されたファイル",
|
||||
"selectFiles": "ファイルを選択",
|
||||
"loading": "読み込み中...",
|
||||
"noFilesSelected": "ファイルまたはフォルダが選択されていません",
|
||||
"folders": "フォルダ",
|
||||
"files": "ファイル",
|
||||
"remove": "削除",
|
||||
"folderAlt": "フォルダ",
|
||||
"fileAlt": "ファイル"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "新しいAPIキーを作成",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "Загрузить новый документ",
|
||||
"select": "Выберите способ загрузки документа в DocsGPT",
|
||||
"selectSource": "Выберите способ добавления источника",
|
||||
"selectedFiles": "Выбранные файлы",
|
||||
"noFilesSelected": "Файлы не выбраны",
|
||||
"file": "Загрузить с устройства",
|
||||
"back": "Назад",
|
||||
"wait": "Пожалуйста, подождите...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "Идет загрузка",
|
||||
"training": "Идет обучение",
|
||||
"completed": "Обучение завершено",
|
||||
"training": "Идет загрузка",
|
||||
"completed": "Загрузка завершена",
|
||||
"wait": "Это может занять несколько минут",
|
||||
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера"
|
||||
"preparing": "Подготовка загрузки",
|
||||
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера",
|
||||
"expandDetails": "Развернуть детали загрузки",
|
||||
"collapseDetails": "Свернуть детали загрузки",
|
||||
"dismiss": "Закрыть уведомление о загрузке",
|
||||
"uploadProgress": "Прогресс загрузки {{progress}}%",
|
||||
"clear": "Очистить"
|
||||
},
|
||||
"showAdvanced": "Показать расширенные настройки",
|
||||
"hideAdvanced": "Скрыть расширенные настройки"
|
||||
"hideAdvanced": "Скрыть расширенные настройки",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "Загрузить файл",
|
||||
"heading": "Загрузить новый документ"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "Краулер",
|
||||
"heading": "Добавить контент с помощью веб-краулера"
|
||||
},
|
||||
"url": {
|
||||
"label": "Ссылка",
|
||||
"heading": "Добавить контент из URL"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "Добавить контент из GitHub"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "Добавить контент из Reddit"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "Загрузить из Google Drive"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "Подключенный пользователь",
|
||||
"authFailed": "Ошибка аутентификации",
|
||||
"authUrlFailed": "Не удалось получить URL авторизации",
|
||||
"popupBlocked": "Не удалось открыть окно аутентификации. Пожалуйста, разрешите всплывающие окна.",
|
||||
"authCancelled": "Аутентификация отменена",
|
||||
"connectedAs": "Подключен как {{email}}",
|
||||
"disconnect": "Отключить"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "Подключиться к Google Drive",
|
||||
"sessionExpired": "Сеанс истек. Пожалуйста, переподключитесь к Google Drive.",
|
||||
"sessionExpiredGeneric": "Сеанс истек. Пожалуйста, переподключите свою учетную запись.",
|
||||
"validateFailed": "Не удалось проверить сеанс. Пожалуйста, переподключитесь.",
|
||||
"noSession": "Действительный сеанс не найден. Пожалуйста, переподключитесь к Google Drive.",
|
||||
"noAccessToken": "Токен доступа недоступен. Пожалуйста, переподключитесь к Google Drive.",
|
||||
"pickerFailed": "Не удалось открыть средство выбора файлов. Пожалуйста, попробуйте еще раз.",
|
||||
"selectedFiles": "Выбранные файлы",
|
||||
"selectFiles": "Выбрать файлы",
|
||||
"loading": "Загрузка...",
|
||||
"noFilesSelected": "Файлы или папки не выбраны",
|
||||
"folders": "Папки",
|
||||
"files": "Файлы",
|
||||
"remove": "Удалить",
|
||||
"folderAlt": "Папка",
|
||||
"fileAlt": "Файл"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "Создать новый API ключ",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "上傳新文件",
|
||||
"select": "選擇如何將文件上傳到 DocsGPT",
|
||||
"selectSource": "選擇新增來源的方式",
|
||||
"selectedFiles": "已選擇的檔案",
|
||||
"noFilesSelected": "未選擇檔案",
|
||||
"file": "從檔案",
|
||||
"remote": "遠端",
|
||||
"back": "返回",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "正在上傳",
|
||||
"training": "正在訓練",
|
||||
"completed": "訓練完成",
|
||||
"training": "正在上傳",
|
||||
"completed": "上傳完成",
|
||||
"wait": "這可能需要幾分鐘",
|
||||
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔"
|
||||
"preparing": "準備上傳",
|
||||
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔",
|
||||
"expandDetails": "展開上傳詳情",
|
||||
"collapseDetails": "摺疊上傳詳情",
|
||||
"dismiss": "關閉上傳通知",
|
||||
"uploadProgress": "上傳進度 {{progress}}%",
|
||||
"clear": "清除"
|
||||
},
|
||||
"showAdvanced": "顯示進階選項",
|
||||
"hideAdvanced": "隱藏進階選項"
|
||||
"hideAdvanced": "隱藏進階選項",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "上傳檔案",
|
||||
"heading": "上傳新文檔"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "爬蟲",
|
||||
"heading": "使用網路爬蟲新增內容"
|
||||
},
|
||||
"url": {
|
||||
"label": "連結",
|
||||
"heading": "從URL新增內容"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "從GitHub新增內容"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "從Reddit新增內容"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "從Google Drive上傳"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "已連接使用者",
|
||||
"authFailed": "驗證失敗",
|
||||
"authUrlFailed": "取得授權URL失敗",
|
||||
"popupBlocked": "無法開啟驗證視窗。請允許彈出視窗。",
|
||||
"authCancelled": "驗證已取消",
|
||||
"connectedAs": "已連接為 {{email}}",
|
||||
"disconnect": "中斷連接"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "連接到 Google Drive",
|
||||
"sessionExpired": "工作階段已過期。請重新連接到 Google Drive。",
|
||||
"sessionExpiredGeneric": "工作階段已過期。請重新連接您的帳戶。",
|
||||
"validateFailed": "驗證工作階段失敗。請重新連接。",
|
||||
"noSession": "未找到有效工作階段。請重新連接到 Google Drive。",
|
||||
"noAccessToken": "存取權杖不可用。請重新連接到 Google Drive。",
|
||||
"pickerFailed": "無法開啟檔案選擇器。請重試。",
|
||||
"selectedFiles": "已選擇的檔案",
|
||||
"selectFiles": "選擇檔案",
|
||||
"loading": "載入中...",
|
||||
"noFilesSelected": "未選擇檔案或資料夾",
|
||||
"folders": "資料夾",
|
||||
"files": "檔案",
|
||||
"remove": "移除",
|
||||
"folderAlt": "資料夾",
|
||||
"fileAlt": "檔案"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "建立新的 API 金鑰",
|
||||
|
||||
@@ -192,6 +192,9 @@
|
||||
"uploadDoc": {
|
||||
"label": "上传新文档",
|
||||
"select": "选择如何将文档上传到 DocsGPT",
|
||||
"selectSource": "选择添加源的方式",
|
||||
"selectedFiles": "已选择的文件",
|
||||
"noFilesSelected": "未选择文件",
|
||||
"file": "从设备上传",
|
||||
"back": "后退",
|
||||
"wait": "请稍等 ...",
|
||||
@@ -220,13 +223,74 @@
|
||||
},
|
||||
"progress": {
|
||||
"upload": "正在上传",
|
||||
"training": "正在训练",
|
||||
"completed": "训练完成",
|
||||
"training": "正在上传",
|
||||
"completed": "上传完成",
|
||||
"wait": "这可能需要几分钟",
|
||||
"tokenLimit": "超出令牌限制,请考虑上传较小的文档"
|
||||
"preparing": "准备上传",
|
||||
"tokenLimit": "超出令牌限制,请考虑上传较小的文档",
|
||||
"expandDetails": "展开上传详情",
|
||||
"collapseDetails": "折叠上传详情",
|
||||
"dismiss": "关闭上传通知",
|
||||
"uploadProgress": "上传进度 {{progress}}%",
|
||||
"clear": "清除"
|
||||
},
|
||||
"showAdvanced": "显示高级选项",
|
||||
"hideAdvanced": "隐藏高级选项"
|
||||
"hideAdvanced": "隐藏高级选项",
|
||||
"ingestors": {
|
||||
"local_file": {
|
||||
"label": "上传文件",
|
||||
"heading": "上传新文档"
|
||||
},
|
||||
"crawler": {
|
||||
"label": "爬虫",
|
||||
"heading": "使用网络爬虫添加内容"
|
||||
},
|
||||
"url": {
|
||||
"label": "链接",
|
||||
"heading": "从URL添加内容"
|
||||
},
|
||||
"github": {
|
||||
"label": "GitHub",
|
||||
"heading": "从GitHub添加内容"
|
||||
},
|
||||
"reddit": {
|
||||
"label": "Reddit",
|
||||
"heading": "从Reddit添加内容"
|
||||
},
|
||||
"google_drive": {
|
||||
"label": "Google Drive",
|
||||
"heading": "从Google Drive上传"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
"auth": {
|
||||
"connectedUser": "已连接用户",
|
||||
"authFailed": "身份验证失败",
|
||||
"authUrlFailed": "获取授权URL失败",
|
||||
"popupBlocked": "无法打开身份验证窗口。请允许弹出窗口。",
|
||||
"authCancelled": "身份验证已取消",
|
||||
"connectedAs": "已连接为 {{email}}",
|
||||
"disconnect": "断开连接"
|
||||
},
|
||||
"googleDrive": {
|
||||
"connect": "连接到 Google Drive",
|
||||
"sessionExpired": "会话已过期。请重新连接到 Google Drive。",
|
||||
"sessionExpiredGeneric": "会话已过期。请重新连接您的账户。",
|
||||
"validateFailed": "验证会话失败。请重新连接。",
|
||||
"noSession": "未找到有效会话。请重新连接到 Google Drive。",
|
||||
"noAccessToken": "访问令牌不可用。请重新连接到 Google Drive。",
|
||||
"pickerFailed": "无法打开文件选择器。请重试。",
|
||||
"selectedFiles": "已选择的文件",
|
||||
"selectFiles": "选择文件",
|
||||
"loading": "加载中...",
|
||||
"noFilesSelected": "未选择文件或文件夹",
|
||||
"folders": "文件夹",
|
||||
"files": "文件",
|
||||
"remove": "删除",
|
||||
"folderAlt": "文件夹",
|
||||
"fileAlt": "文件"
|
||||
}
|
||||
}
|
||||
},
|
||||
"createAPIKey": {
|
||||
"label": "创建新的 API 密钥",
|
||||
|
||||
@@ -24,6 +24,7 @@ export interface Preference {
|
||||
token: string | null;
|
||||
modalState: ActiveState;
|
||||
paginatedDocuments: Doc[] | null;
|
||||
templateAgents: Agent[] | null;
|
||||
agents: Agent[] | null;
|
||||
sharedAgents: Agent[] | null;
|
||||
selectedAgent: Agent | null;
|
||||
@@ -52,6 +53,7 @@ const initialState: Preference = {
|
||||
token: localStorage.getItem('authToken') || null,
|
||||
modalState: 'INACTIVE',
|
||||
paginatedDocuments: null,
|
||||
templateAgents: null,
|
||||
agents: null,
|
||||
sharedAgents: null,
|
||||
selectedAgent: null,
|
||||
@@ -91,6 +93,9 @@ export const prefSlice = createSlice({
|
||||
setModalStateDeleteConv: (state, action: PayloadAction<ActiveState>) => {
|
||||
state.modalState = action.payload;
|
||||
},
|
||||
setTemplateAgents: (state, action) => {
|
||||
state.templateAgents = action.payload;
|
||||
},
|
||||
setAgents: (state, action) => {
|
||||
state.agents = action.payload;
|
||||
},
|
||||
@@ -114,6 +119,7 @@ export const {
|
||||
setTokenLimit,
|
||||
setModalStateDeleteConv,
|
||||
setPaginatedDocuments,
|
||||
setTemplateAgents,
|
||||
setAgents,
|
||||
setSharedAgents,
|
||||
setSelectedAgent,
|
||||
@@ -191,6 +197,8 @@ export const selectTokenLimit = (state: RootState) =>
|
||||
state.preference.token_limit;
|
||||
export const selectPaginatedDocuments = (state: RootState) =>
|
||||
state.preference.paginatedDocuments;
|
||||
export const selectTemplateAgents = (state: RootState) =>
|
||||
state.preference.templateAgents;
|
||||
export const selectAgents = (state: RootState) => state.preference.agents;
|
||||
export const selectSharedAgents = (state: RootState) =>
|
||||
state.preference.sharedAgents;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { configureStore } from '@reduxjs/toolkit';
|
||||
|
||||
import agentPreviewReducer from './agents/agentPreviewSlice';
|
||||
import { conversationSlice } from './conversation/conversationSlice';
|
||||
import { sharedConversationSlice } from './conversation/sharedConversationSlice';
|
||||
import {
|
||||
@@ -8,7 +9,6 @@ import {
|
||||
prefSlice,
|
||||
} from './preferences/preferenceSlice';
|
||||
import uploadReducer from './upload/uploadSlice';
|
||||
import agentPreviewReducer from './agents/agentPreviewSlice';
|
||||
|
||||
const key = localStorage.getItem('DocsGPTApiKey');
|
||||
const prompt = localStorage.getItem('DocsGPTPrompt');
|
||||
@@ -43,6 +43,7 @@ const preloadedState: { preference: Preference } = {
|
||||
],
|
||||
modalState: 'INACTIVE',
|
||||
paginatedDocuments: null,
|
||||
templateAgents: null,
|
||||
agents: null,
|
||||
sharedAgents: null,
|
||||
selectedAgent: null,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { nanoid } from '@reduxjs/toolkit';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
@@ -24,6 +25,8 @@ import {
|
||||
getIngestorSchema,
|
||||
IngestorOption,
|
||||
} from '../upload/types/ingestor';
|
||||
import { addUploadTask, updateUploadTask } from './uploadSlice';
|
||||
|
||||
import { FormField, IngestorConfig, IngestorType } from './types/ingestor';
|
||||
|
||||
import { FilePicker } from '../components/FilePicker';
|
||||
@@ -190,12 +193,12 @@ function Upload({
|
||||
<div className="mb-3" {...getRootProps()}>
|
||||
<span className="text-purple-30 dark:text-silver inline-block rounded-3xl border border-[#7F7F82] bg-transparent px-4 py-2 font-medium hover:cursor-pointer">
|
||||
<input type="button" {...getInputProps()} />
|
||||
Choose Files
|
||||
{t('modals.uploadDoc.choose')}
|
||||
</span>
|
||||
</div>
|
||||
<div className="mt-4 max-w-full">
|
||||
<p className="text-eerie-black dark:text-light-gray mb-[14px] text-[14px] font-medium">
|
||||
Selected Files
|
||||
{t('modals.uploadDoc.selectedFiles')}
|
||||
</p>
|
||||
<div className="max-w-full overflow-hidden">
|
||||
{files.map((file) => (
|
||||
@@ -209,7 +212,7 @@ function Upload({
|
||||
))}
|
||||
{files.length === 0 && (
|
||||
<p className="text-gray-6000 dark:text-light-gray text-[14px]">
|
||||
No files selected
|
||||
{t('modals.uploadDoc.noFilesSelected')}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
@@ -259,15 +262,8 @@ function Upload({
|
||||
config: {},
|
||||
}));
|
||||
|
||||
const [progress, setProgress] = useState<{
|
||||
type: 'UPLOAD' | 'TRAINING';
|
||||
percentage: number;
|
||||
taskId?: string;
|
||||
failed?: boolean;
|
||||
}>();
|
||||
|
||||
const { t } = useTranslation();
|
||||
const setTimeoutRef = useRef<number | null>(null);
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const ingestorOptions: IngestorOption[] = IngestorFormSchemas.filter(
|
||||
(schema) => (schema.validate ? schema.validate() : true),
|
||||
@@ -279,188 +275,120 @@ function Upload({
|
||||
}));
|
||||
|
||||
const sourceDocs = useSelector(selectSourceDocs);
|
||||
useEffect(() => {
|
||||
if (setTimeoutRef.current) {
|
||||
clearTimeout(setTimeoutRef.current);
|
||||
}
|
||||
|
||||
const resetUploaderState = useCallback(() => {
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setSelectedFiles([]);
|
||||
setSelectedFolders([]);
|
||||
setShowAdvancedOptions(false);
|
||||
}, []);
|
||||
|
||||
function ProgressBar({ progressPercent }: { progressPercent: number }) {
|
||||
return (
|
||||
<div className="my-8 flex h-full w-full items-center justify-center">
|
||||
<div className="relative h-32 w-32 rounded-full">
|
||||
<div className="absolute inset-0 rounded-full shadow-[0_0_10px_2px_rgba(0,0,0,0.3)_inset] dark:shadow-[0_0_10px_2px_rgba(0,0,0,0.3)_inset]"></div>
|
||||
<div
|
||||
className={`absolute inset-0 rounded-full ${progressPercent === 100 ? 'bg-linear-to-r from-white to-gray-400 shadow-xl shadow-lime-300/50 dark:bg-linear-to-br dark:from-gray-500 dark:to-gray-300 dark:shadow-lime-300/50' : 'shadow-[0_4px_0_#7D54D1] dark:shadow-[0_4px_0_#7D54D1]'}`}
|
||||
style={{
|
||||
animation: `${progressPercent === 100 ? 'none' : 'rotate 2s linear infinite'}`,
|
||||
}}
|
||||
></div>
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<span className="text-2xl font-bold">{progressPercent}%</span>
|
||||
</div>
|
||||
<style>
|
||||
{`@keyframes rotate {
|
||||
0% { transform: rotate(0deg); }
|
||||
100%{ transform: rotate(360deg); }
|
||||
}`}
|
||||
</style>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const handleTaskFailure = useCallback(
|
||||
(clientTaskId: string, errorMessage?: string) => {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
errorMessage: errorMessage || t('attachments.uploadFailed'),
|
||||
},
|
||||
}),
|
||||
);
|
||||
},
|
||||
[dispatch, t],
|
||||
);
|
||||
|
||||
function Progress({
|
||||
title,
|
||||
isCancellable = false,
|
||||
isFailed = false,
|
||||
isTraining = false,
|
||||
}: {
|
||||
title: string;
|
||||
isCancellable?: boolean;
|
||||
isFailed?: boolean;
|
||||
isTraining?: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div className="text-gray-2000 dark:text-bright-gray mt-5 flex flex-col items-center gap-2">
|
||||
<p className="text-gra text-xl tracking-[0.15px]">
|
||||
{isTraining &&
|
||||
(progress?.percentage === 100
|
||||
? t('modals.uploadDoc.progress.completed')
|
||||
: title)}
|
||||
{!isTraining && title}
|
||||
</p>
|
||||
<p className="text-sm">{t('modals.uploadDoc.progress.wait')}</p>
|
||||
<p className={`ml-5 text-xl text-red-400 ${isFailed ? '' : 'hidden'}`}>
|
||||
{t('modals.uploadDoc.progress.tokenLimit')}
|
||||
</p>
|
||||
{/* <p className="mt-10 text-2xl">{progress?.percentage || 0}%</p> */}
|
||||
<ProgressBar progressPercent={progress?.percentage || 0} />
|
||||
{isTraining &&
|
||||
(progress?.percentage === 100 ? (
|
||||
<button
|
||||
onClick={() => {
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setProgress(undefined);
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
className="h-[42px] cursor-pointer rounded-3xl bg-[#7D54D1] px-[28px] py-[6px] text-sm text-white shadow-lg hover:bg-[#6F3FD1]"
|
||||
>
|
||||
{t('modals.uploadDoc.start')}
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="ml-2 h-[42px] cursor-pointer rounded-3xl bg-[#7D54D14D] px-[28px] py-[6px] text-sm text-white shadow-lg"
|
||||
disabled
|
||||
>
|
||||
{t('modals.uploadDoc.wait')}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const trackTraining = useCallback(
|
||||
(backendTaskId: string, clientTaskId: string) => {
|
||||
let timeoutId: number | null = null;
|
||||
|
||||
function UploadProgress() {
|
||||
return <Progress title={t('modals.uploadDoc.progress.upload')}></Progress>;
|
||||
}
|
||||
|
||||
function TrainingProgress() {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
let timeoutID: number | undefined;
|
||||
|
||||
if ((progress?.percentage ?? 0) < 100) {
|
||||
timeoutID = setTimeout(() => {
|
||||
userService
|
||||
.getTaskStatus(progress?.taskId as string, null)
|
||||
.then((data) => data.json())
|
||||
.then((data) => {
|
||||
if (data.status == 'SUCCESS') {
|
||||
if (data.result.limited === true) {
|
||||
getDocs(token).then((data) => {
|
||||
dispatch(setSourceDocs(data));
|
||||
dispatch(
|
||||
setSelectedDocs(
|
||||
Array.isArray(data) &&
|
||||
data?.find(
|
||||
(d: Doc) => d.type?.toLowerCase() === 'local',
|
||||
),
|
||||
),
|
||||
);
|
||||
});
|
||||
setProgress(
|
||||
(progress) =>
|
||||
progress && {
|
||||
...progress,
|
||||
percentage: 100,
|
||||
failed: true,
|
||||
},
|
||||
);
|
||||
} else {
|
||||
getDocs(token).then((data) => {
|
||||
dispatch(setSourceDocs(data));
|
||||
const docIds = new Set(
|
||||
(Array.isArray(sourceDocs) &&
|
||||
sourceDocs?.map((doc: Doc) =>
|
||||
doc.id ? doc.id : null,
|
||||
)) ||
|
||||
[],
|
||||
);
|
||||
if (data && Array.isArray(data)) {
|
||||
data.map((updatedDoc: Doc) => {
|
||||
if (updatedDoc.id && !docIds.has(updatedDoc.id)) {
|
||||
// Select the doc not present in the intersection of current Docs and fetched data
|
||||
dispatch(setSelectedDocs(updatedDoc));
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
setProgress(
|
||||
(progress) =>
|
||||
progress && {
|
||||
...progress,
|
||||
percentage: 100,
|
||||
failed: false,
|
||||
},
|
||||
);
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setProgress(undefined);
|
||||
setModalState('INACTIVE');
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} else if (data.status == 'PROGRESS') {
|
||||
setProgress(
|
||||
(progress) =>
|
||||
progress && {
|
||||
...progress,
|
||||
percentage: data.result.current,
|
||||
},
|
||||
);
|
||||
const poll = () => {
|
||||
userService
|
||||
.getTaskStatus(backendTaskId, null)
|
||||
.then((response) => response.json())
|
||||
.then(async (data) => {
|
||||
if (data.status === 'SUCCESS') {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
});
|
||||
}, 5000);
|
||||
}
|
||||
|
||||
// cleanup
|
||||
return () => {
|
||||
if (timeoutID !== undefined) {
|
||||
clearTimeout(timeoutID);
|
||||
}
|
||||
const docs = await getDocs(token);
|
||||
dispatch(setSourceDocs(docs));
|
||||
|
||||
if (Array.isArray(docs)) {
|
||||
const existingDocIds = new Set(
|
||||
(Array.isArray(sourceDocs) ? sourceDocs : [])
|
||||
.map((doc: Doc) => doc?.id)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
const newDoc = docs.find(
|
||||
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
|
||||
);
|
||||
if (newDoc) {
|
||||
dispatch(setSelectedDocs([newDoc]));
|
||||
}
|
||||
}
|
||||
|
||||
if (data.result?.limited) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
progress: 100,
|
||||
errorMessage: t('modals.uploadDoc.progress.tokenLimit'),
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
errorMessage: undefined,
|
||||
},
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} else if (data.status === 'FAILURE') {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId, data.result?.message);
|
||||
} else if (data.status === 'PROGRESS') {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'training',
|
||||
progress: Math.min(100, data.result?.current ?? 0),
|
||||
},
|
||||
}),
|
||||
);
|
||||
timeoutId = window.setTimeout(poll, 5000);
|
||||
} else {
|
||||
timeoutId = window.setTimeout(poll, 5000);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId);
|
||||
});
|
||||
};
|
||||
}, [progress, dispatch]);
|
||||
return (
|
||||
<Progress
|
||||
title={t('modals.uploadDoc.progress.training')}
|
||||
isCancellable={progress?.percentage === 100}
|
||||
isFailed={progress?.failed === true}
|
||||
isTraining={true}
|
||||
></Progress>
|
||||
);
|
||||
}
|
||||
|
||||
timeoutId = window.setTimeout(poll, 3000);
|
||||
},
|
||||
[dispatch, handleTaskFailure, onSuccessfulUpload, sourceDocs, t, token],
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: File[]) => {
|
||||
@@ -483,7 +411,7 @@ function Upload({
|
||||
|
||||
const doNothing = () => undefined;
|
||||
|
||||
const uploadFile = () => {
|
||||
const uploadFile = (clientTaskId: string) => {
|
||||
const formData = new FormData();
|
||||
files.forEach((file) => {
|
||||
formData.append('file', file);
|
||||
@@ -491,34 +419,89 @@ function Upload({
|
||||
|
||||
formData.append('name', ingestor.name);
|
||||
formData.append('user', 'local');
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'uploading', progress: 0 },
|
||||
}),
|
||||
);
|
||||
|
||||
xhr.upload.addEventListener('progress', (event) => {
|
||||
const progress = +((event.loaded / event.total) * 100).toFixed(2);
|
||||
setProgress({ type: 'UPLOAD', percentage: progress });
|
||||
if (!event.lengthComputable) return;
|
||||
const progressPercentage = Number(
|
||||
((event.loaded / event.total) * 100).toFixed(2),
|
||||
);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { progress: progressPercentage },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
const { task_id } = JSON.parse(xhr.responseText);
|
||||
setTimeoutRef.current = setTimeout(() => {
|
||||
setProgress({ type: 'TRAINING', percentage: 0, taskId: task_id });
|
||||
}, 3000);
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const parsed = JSON.parse(xhr.responseText) as { task_id?: string };
|
||||
if (parsed.task_id) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
taskId: parsed.task_id,
|
||||
status: 'training',
|
||||
progress: 0,
|
||||
},
|
||||
}),
|
||||
);
|
||||
trackTraining(parsed.task_id, clientTaskId);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'completed', progress: 100 },
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} catch (error) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
}
|
||||
} else {
|
||||
handleTaskFailure(clientTaskId, xhr.statusText || undefined);
|
||||
}
|
||||
};
|
||||
xhr.open('POST', `${apiHost + '/api/upload'}`);
|
||||
|
||||
xhr.onerror = () => {
|
||||
handleTaskFailure(clientTaskId);
|
||||
};
|
||||
|
||||
xhr.open('POST', `${apiHost}/api/upload`);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
};
|
||||
|
||||
const uploadRemote = () => {
|
||||
if (!ingestor.type) return;
|
||||
const uploadRemote = (clientTaskId: string) => {
|
||||
if (!ingestor.type) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
return;
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('name', ingestor.name);
|
||||
formData.append('user', 'local');
|
||||
formData.append('source', ingestor.type as string);
|
||||
|
||||
let configData: any = {};
|
||||
|
||||
const ingestorSchema = getIngestorSchema(ingestor.type as IngestorType);
|
||||
if (!ingestorSchema) return;
|
||||
if (!ingestorSchema) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
return;
|
||||
}
|
||||
|
||||
const schema: FormField[] = ingestorSchema.fields;
|
||||
const hasLocalFilePicker = schema.some(
|
||||
(field: FormField) => field.type === 'local_file_picker',
|
||||
@@ -530,11 +513,12 @@ function Upload({
|
||||
(field: FormField) => field.type === 'google_drive_picker',
|
||||
);
|
||||
|
||||
let configData: Record<string, unknown> = { ...ingestor.config };
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
files.forEach((file) => {
|
||||
formData.append('file', file);
|
||||
});
|
||||
configData = { ...ingestor.config };
|
||||
} else if (hasRemoteFilePicker || hasGoogleDrivePicker) {
|
||||
const sessionToken = getSessionToken(ingestor.type as string);
|
||||
configData = {
|
||||
@@ -543,44 +527,122 @@ function Upload({
|
||||
file_ids: selectedFiles,
|
||||
folder_ids: selectedFolders,
|
||||
};
|
||||
} else {
|
||||
configData = { ...ingestor.config };
|
||||
}
|
||||
|
||||
formData.append('data', JSON.stringify(configData));
|
||||
|
||||
const apiHost: string = import.meta.env.VITE_API_HOST;
|
||||
const xhr = new XMLHttpRequest();
|
||||
xhr.upload.addEventListener('progress', (event: ProgressEvent) => {
|
||||
if (event.lengthComputable) {
|
||||
const progressPercentage = +(
|
||||
(event.loaded / event.total) *
|
||||
100
|
||||
).toFixed(2);
|
||||
setProgress({ type: 'UPLOAD', percentage: progressPercentage });
|
||||
}
|
||||
});
|
||||
xhr.onload = () => {
|
||||
const response = JSON.parse(xhr.responseText) as { task_id: string };
|
||||
setTimeoutRef.current = window.setTimeout(() => {
|
||||
setProgress({
|
||||
type: 'TRAINING',
|
||||
percentage: 0,
|
||||
taskId: response.task_id,
|
||||
});
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
const endpoint =
|
||||
ingestor.type === 'local_file'
|
||||
? `${apiHost}/api/upload`
|
||||
: `${apiHost}/api/remote`;
|
||||
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'uploading', progress: 0 },
|
||||
}),
|
||||
);
|
||||
|
||||
xhr.upload.addEventListener('progress', (event: ProgressEvent) => {
|
||||
if (!event.lengthComputable) return;
|
||||
const progressPercentage = Number(
|
||||
((event.loaded / event.total) * 100).toFixed(2),
|
||||
);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { progress: progressPercentage },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
xhr.onload = () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const response = JSON.parse(xhr.responseText) as { task_id?: string };
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
status: 'training',
|
||||
progress: 0,
|
||||
},
|
||||
}),
|
||||
);
|
||||
trackTraining(response.task_id, clientTaskId);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: { status: 'completed', progress: 100 },
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} catch (error) {
|
||||
handleTaskFailure(clientTaskId);
|
||||
}
|
||||
} else {
|
||||
handleTaskFailure(clientTaskId, xhr.statusText || undefined);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
handleTaskFailure(clientTaskId);
|
||||
};
|
||||
|
||||
xhr.open('POST', endpoint);
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
xhr.send(formData);
|
||||
};
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
resetUploaderState();
|
||||
setModalState('INACTIVE');
|
||||
close();
|
||||
}, [close, resetUploaderState, setModalState]);
|
||||
|
||||
const handleUpload = () => {
|
||||
if (!ingestor.type) return;
|
||||
|
||||
const ingestorSchemaForUpload = getIngestorSchema(
|
||||
ingestor.type as IngestorType,
|
||||
);
|
||||
if (!ingestorSchemaForUpload) return;
|
||||
|
||||
const schema: FormField[] = ingestorSchemaForUpload.fields;
|
||||
const hasLocalFilePicker = schema.some(
|
||||
(field: FormField) => field.type === 'local_file_picker',
|
||||
);
|
||||
|
||||
const displayName =
|
||||
ingestor.name?.trim() || files[0]?.name || t('modals.uploadDoc.label');
|
||||
|
||||
const clientTaskId = nanoid();
|
||||
|
||||
dispatch(
|
||||
addUploadTask({
|
||||
id: clientTaskId,
|
||||
fileName: displayName,
|
||||
progress: 0,
|
||||
status: 'preparing',
|
||||
}),
|
||||
);
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
uploadFile(clientTaskId);
|
||||
} else {
|
||||
uploadRemote(clientTaskId);
|
||||
}
|
||||
|
||||
handleClose();
|
||||
};
|
||||
|
||||
const { getRootProps, getInputProps } = useDropzone({
|
||||
onDrop,
|
||||
multiple: true,
|
||||
@@ -733,7 +795,7 @@ function Upload({
|
||||
/>
|
||||
</div>
|
||||
<p className="font-inter self-start text-[13px] leading-[18px] font-semibold">
|
||||
{option.label}
|
||||
{t(`modals.uploadDoc.ingestors.${option.value}.label`)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -741,18 +803,16 @@ function Upload({
|
||||
</div>
|
||||
);
|
||||
};
|
||||
let view;
|
||||
|
||||
if (progress?.type === 'UPLOAD') {
|
||||
view = <UploadProgress></UploadProgress>;
|
||||
} else if (progress?.type === 'TRAINING') {
|
||||
view = <TrainingProgress></TrainingProgress>;
|
||||
} else {
|
||||
view = (
|
||||
return (
|
||||
<WrapperModal
|
||||
close={handleClose}
|
||||
className="max-h-[90vh] w-11/12 sm:max-h-none sm:w-auto sm:min-w-[600px] md:min-w-[700px]"
|
||||
contentClassName="max-h-[80vh] sm:max-h-none"
|
||||
>
|
||||
<div className="flex w-full flex-col gap-6">
|
||||
{!ingestor.type && (
|
||||
<p className="font-inter text-left text-[20px] leading-[28px] font-semibold tracking-[0.15px] text-[#18181B] dark:text-[#ECECF1]">
|
||||
Select the way to add your source
|
||||
{t('modals.uploadDoc.selectSource')}
|
||||
</p>
|
||||
)}
|
||||
|
||||
@@ -770,12 +830,12 @@ function Upload({
|
||||
alt="back"
|
||||
className="h-3 w-3 rotate-180 transform"
|
||||
/>
|
||||
<span>Back</span>
|
||||
<span>{t('modals.uploadDoc.back')}</span>
|
||||
</button>
|
||||
|
||||
<h2 className="font-inter text-[22px] leading-[28px] font-semibold tracking-[0.15px] text-black dark:text-[#E0E0E0]">
|
||||
{ingestor.type &&
|
||||
getIngestorSchema(ingestor.type as IngestorType)?.heading}
|
||||
t(`modals.uploadDoc.ingestors.${ingestor.type}.heading`)}
|
||||
</h2>
|
||||
|
||||
<Input
|
||||
@@ -789,7 +849,7 @@ function Upload({
|
||||
}));
|
||||
}}
|
||||
borderVariant="thin"
|
||||
placeholder="Name"
|
||||
placeholder={t('modals.uploadDoc.name')}
|
||||
required={true}
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
className="w-full"
|
||||
@@ -816,23 +876,7 @@ function Upload({
|
||||
<div className="flex justify-end gap-4">
|
||||
{activeTab && ingestor.type && (
|
||||
<button
|
||||
onClick={() => {
|
||||
if (!ingestor.type) return;
|
||||
const ingestorSchemaForUpload = getIngestorSchema(
|
||||
ingestor.type as IngestorType,
|
||||
);
|
||||
if (!ingestorSchemaForUpload) return;
|
||||
const schema: FormField[] = ingestorSchemaForUpload.fields;
|
||||
const hasLocalFilePicker = schema.some(
|
||||
(field: FormField) => field.type === 'local_file_picker',
|
||||
);
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
uploadFile();
|
||||
} else {
|
||||
uploadRemote();
|
||||
}
|
||||
}}
|
||||
onClick={handleUpload}
|
||||
disabled={isUploadDisabled()}
|
||||
className={`rounded-3xl px-4 py-2 text-[14px] font-medium ${
|
||||
isUploadDisabled()
|
||||
@@ -845,22 +889,6 @@ function Upload({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<WrapperModal
|
||||
isPerformingTask={progress !== undefined && progress.percentage < 100}
|
||||
close={() => {
|
||||
close();
|
||||
setIngestor({ type: null, name: '', config: {} });
|
||||
setfiles([]);
|
||||
setModalState('INACTIVE');
|
||||
}}
|
||||
className="max-h-[90vh] w-11/12 sm:max-h-none sm:w-auto sm:min-w-[600px] md:min-w-[700px]"
|
||||
contentClassName="max-h-[80vh] sm:max-h-none"
|
||||
>
|
||||
{view}
|
||||
</WrapperModal>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -10,12 +10,31 @@ export interface Attachment {
|
||||
token_count?: number;
|
||||
}
|
||||
|
||||
export type UploadTaskStatus =
|
||||
| 'preparing'
|
||||
| 'uploading'
|
||||
| 'training'
|
||||
| 'completed'
|
||||
| 'failed';
|
||||
|
||||
export interface UploadTask {
|
||||
id: string;
|
||||
fileName: string;
|
||||
progress: number;
|
||||
status: UploadTaskStatus;
|
||||
taskId?: string;
|
||||
errorMessage?: string;
|
||||
dismissed?: boolean;
|
||||
}
|
||||
|
||||
interface UploadState {
|
||||
attachments: Attachment[];
|
||||
tasks: UploadTask[];
|
||||
}
|
||||
|
||||
const initialState: UploadState = {
|
||||
attachments: [],
|
||||
tasks: [],
|
||||
};
|
||||
|
||||
export const uploadSlice = createSlice({
|
||||
@@ -52,6 +71,49 @@ export const uploadSlice = createSlice({
|
||||
(att) => att.status === 'uploading' || att.status === 'processing',
|
||||
);
|
||||
},
|
||||
addUploadTask: (state, action: PayloadAction<UploadTask>) => {
|
||||
state.tasks.unshift(action.payload);
|
||||
},
|
||||
updateUploadTask: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
updates: Partial<UploadTask>;
|
||||
}>,
|
||||
) => {
|
||||
const index = state.tasks.findIndex(
|
||||
(task) => task.id === action.payload.id,
|
||||
);
|
||||
if (index !== -1) {
|
||||
const updates = action.payload.updates;
|
||||
|
||||
// When task completes or fails, set dismissed to false to notify user
|
||||
if (updates.status === 'completed' || updates.status === 'failed') {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
...updates,
|
||||
dismissed: false,
|
||||
};
|
||||
} else {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
...updates,
|
||||
};
|
||||
}
|
||||
}
|
||||
},
|
||||
dismissUploadTask: (state, action: PayloadAction<string>) => {
|
||||
const index = state.tasks.findIndex((task) => task.id === action.payload);
|
||||
if (index !== -1) {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
dismissed: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
removeUploadTask: (state, action: PayloadAction<string>) => {
|
||||
state.tasks = state.tasks.filter((task) => task.id !== action.payload);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -60,10 +122,15 @@ export const {
|
||||
updateAttachment,
|
||||
removeAttachment,
|
||||
clearAttachments,
|
||||
addUploadTask,
|
||||
updateUploadTask,
|
||||
dismissUploadTask,
|
||||
removeUploadTask,
|
||||
} = uploadSlice.actions;
|
||||
|
||||
export const selectAttachments = (state: RootState) => state.upload.attachments;
|
||||
export const selectCompletedAttachments = (state: RootState) =>
|
||||
state.upload.attachments.filter((att) => att.status === 'completed');
|
||||
export const selectUploadTasks = (state: RootState) => state.upload.tasks;
|
||||
|
||||
export default uploadSlice.reducer;
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestGitHubLoaderFetchFileContent:
|
||||
|
||||
result = loader.fetch_file_content("owner/repo", "README.md")
|
||||
|
||||
assert result == f"Filename: README.md\n\n{content_str}"
|
||||
assert result == content_str
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.github.com/repos/owner/repo/contents/README.md",
|
||||
headers=loader.headers,
|
||||
@@ -40,7 +40,7 @@ class TestGitHubLoaderFetchFileContent:
|
||||
|
||||
result = loader.fetch_file_content("owner/repo", "image.png")
|
||||
|
||||
assert result == "Filename: image.png is a binary file and was skipped."
|
||||
assert result is None
|
||||
|
||||
@patch("application.parser.remote.github_loader.requests.get")
|
||||
def test_non_base64_plain_content(self, mock_get):
|
||||
@@ -49,7 +49,7 @@ class TestGitHubLoaderFetchFileContent:
|
||||
|
||||
result = loader.fetch_file_content("owner/repo", "file.txt")
|
||||
|
||||
assert result == "Filename: file.txt\n\nPlain text"
|
||||
assert result == "Plain text"
|
||||
|
||||
@patch("application.parser.remote.github_loader.requests.get")
|
||||
def test_http_error_raises(self, mock_get):
|
||||
@@ -102,13 +102,13 @@ class TestGitHubLoaderLoadData:
|
||||
docs = loader.load_data("https://github.com/owner/repo")
|
||||
|
||||
assert len(docs) == 2
|
||||
assert docs[0].page_content == "content for README.md"
|
||||
assert docs[0].metadata == {
|
||||
assert docs[0].text == "content for README.md"
|
||||
assert docs[0].extra_info == {
|
||||
"title": "README.md",
|
||||
"source": "https://github.com/owner/repo/blob/main/README.md",
|
||||
}
|
||||
assert docs[1].page_content == "content for src/main.py"
|
||||
assert docs[1].metadata == {
|
||||
assert docs[1].text == "content for src/main.py"
|
||||
assert docs[1].extra_info == {
|
||||
"title": "src/main.py",
|
||||
"source": "https://github.com/owner/repo/blob/main/src/main.py",
|
||||
}
|
||||
@@ -142,12 +142,13 @@ class TestGitHubLoaderRobustness:
|
||||
GitHubLoader().fetch_file_content("owner/repo", "README.md")
|
||||
|
||||
@patch("application.parser.remote.github_loader.requests.get")
|
||||
def test_fetch_file_content_unexpected_shape_missing_content_raises(self, mock_get):
|
||||
def test_fetch_file_content_unexpected_shape_missing_content_returns_none(self, mock_get):
|
||||
# encoding indicates base64 text, but 'content' key is missing
|
||||
# With the new code, the exception is caught and returns None (treated as binary/skipped)
|
||||
resp = make_response({"encoding": "base64"})
|
||||
mock_get.return_value = resp
|
||||
with pytest.raises(KeyError):
|
||||
GitHubLoader().fetch_file_content("owner/repo", "README.md")
|
||||
result = GitHubLoader().fetch_file_content("owner/repo", "file.txt")
|
||||
assert result is None
|
||||
|
||||
@patch("application.parser.remote.github_loader.base64.b64decode")
|
||||
@patch("application.parser.remote.github_loader.requests.get")
|
||||
@@ -156,4 +157,4 @@ class TestGitHubLoaderRobustness:
|
||||
mock_b64decode.side_effect = AssertionError("b64decode should not be called for binary files")
|
||||
mock_get.return_value = make_response({"encoding": "base64", "content": "AAA"})
|
||||
result = GitHubLoader().fetch_file_content("owner/repo", "bigfile.bin")
|
||||
assert result == "Filename: bigfile.bin is a binary file and was skipped."
|
||||
assert result is None
|
||||
|
||||
765
tests/test_memory_tool.py
Normal file
765
tests/test_memory_tool.py
Normal file
@@ -0,0 +1,765 @@
|
||||
import pytest
|
||||
from application.agents.tools.memory import MemoryTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_tool(monkeypatch) -> MemoryTool:
|
||||
"""Provide a MemoryTool with a fake Mongo collection and fixed user_id."""
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {} # path -> document
|
||||
|
||||
def insert_one(self, doc):
|
||||
user_id = doc.get("user_id")
|
||||
tool_id = doc.get("tool_id")
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
# Add _id to document if not present
|
||||
if "_id" not in doc:
|
||||
doc["_id"] = key
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# Handle query by _id
|
||||
if "_id" in q:
|
||||
doc_id = q["_id"]
|
||||
if doc_id not in self.docs:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if "$set" in u:
|
||||
old_doc = self.docs[doc_id].copy()
|
||||
old_doc.update(u["$set"])
|
||||
|
||||
# If path changed, update the dictionary key
|
||||
if "path" in u["$set"]:
|
||||
new_path = u["$set"]["path"]
|
||||
user_id = old_doc.get("user_id")
|
||||
tool_id = old_doc.get("tool_id")
|
||||
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||
|
||||
# Remove old key and add with new key
|
||||
del self.docs[doc_id]
|
||||
old_doc["_id"] = new_key
|
||||
self.docs[new_key] = old_doc
|
||||
else:
|
||||
self.docs[doc_id] = old_doc
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
# Handle query by user_id, tool_id, path
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
def find(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
results = []
|
||||
|
||||
# Handle regex queries for directory listing
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
# Remove regex escape characters and ^ anchor for simple matching
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
results.append(doc)
|
||||
else:
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
results.append(doc)
|
||||
|
||||
return results
|
||||
|
||||
def delete_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
def delete_many(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
deleted = 0
|
||||
|
||||
# Handle regex queries for directory deletion
|
||||
if "path" in q and isinstance(q["path"], dict) and "$regex" in q["path"]:
|
||||
regex_pattern = q["path"]["$regex"]
|
||||
pattern = regex_pattern.replace("\\", "").lstrip("^")
|
||||
|
||||
keys_to_delete = []
|
||||
for key, doc in self.docs.items():
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id:
|
||||
doc_path = doc.get("path", "")
|
||||
if doc_path.startswith(pattern):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
else:
|
||||
# Delete all for user and tool
|
||||
keys_to_delete = [
|
||||
key for key, doc in self.docs.items()
|
||||
if doc.get("user_id") == user_id and doc.get("tool_id") == tool_id
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self.docs[key]
|
||||
deleted += 1
|
||||
|
||||
return type("res", (), {"deleted_count": deleted})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
return MemoryTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
def test_init_without_user_id():
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
memory_tool = MemoryTool(tool_config={})
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "user_id" in result.lower()
|
||||
|
||||
|
||||
def test_view_empty_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Should show empty directory when no files exist."""
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "empty" in result.lower()
|
||||
|
||||
|
||||
def test_create_and_view_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test creating a file and viewing it."""
|
||||
# Create a file
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/notes.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View the file
|
||||
result = memory_tool.execute_action("view", path="/notes.txt")
|
||||
assert "Hello world" in result
|
||||
|
||||
|
||||
def test_create_overwrite_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that create overwrites existing files."""
|
||||
# Create initial file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Original content"
|
||||
)
|
||||
|
||||
# Overwrite
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="New content"
|
||||
)
|
||||
|
||||
# Verify overwrite
|
||||
result = memory_tool.execute_action("view", path="/test.txt")
|
||||
assert "New content" in result
|
||||
assert "Original content" not in result
|
||||
|
||||
|
||||
def test_view_directory_with_files(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing directory contents."""
|
||||
# Create multiple files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file3.txt",
|
||||
file_text="Content 3"
|
||||
)
|
||||
|
||||
# View directory
|
||||
result = memory_tool.execute_action("view", path="/")
|
||||
assert "file1.txt" in result
|
||||
assert "file2.txt" in result
|
||||
assert "subdir/file3.txt" in result
|
||||
|
||||
|
||||
def test_view_file_with_line_range(memory_tool: MemoryTool) -> None:
|
||||
"""Test viewing specific lines from a file."""
|
||||
# Create a multiline file
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/multiline.txt",
|
||||
file_text=content
|
||||
)
|
||||
|
||||
# View lines 2-4
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/multiline.txt",
|
||||
view_range=[2, 4]
|
||||
)
|
||||
assert "Line 2" in result
|
||||
assert "Line 3" in result
|
||||
assert "Line 4" in result
|
||||
assert "Line 1" not in result
|
||||
assert "Line 5" not in result
|
||||
|
||||
|
||||
def test_str_replace(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement in a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/replace.txt",
|
||||
file_text="Hello world, hello universe"
|
||||
)
|
||||
|
||||
# Replace text
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/replace.txt",
|
||||
old_str="hello",
|
||||
new_str="hi"
|
||||
)
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
content = memory_tool.execute_action("view", path="/replace.txt")
|
||||
assert "hi world, hi universe" in content
|
||||
|
||||
|
||||
def test_str_replace_not_found(memory_tool: MemoryTool) -> None:
|
||||
"""Test string replacement when string not found."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Hello world"
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"str_replace",
|
||||
path="/test.txt",
|
||||
old_str="goodbye",
|
||||
new_str="hi"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_insert_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting text at a line number."""
|
||||
# Create a multiline file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/insert.txt",
|
||||
file_text="Line 1\nLine 2\nLine 3"
|
||||
)
|
||||
|
||||
# Insert at line 2
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/insert.txt",
|
||||
insert_line=2,
|
||||
insert_text="Inserted line"
|
||||
)
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
content = memory_tool.execute_action("view", path="/insert.txt")
|
||||
lines = content.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
def test_insert_invalid_line(memory_tool: MemoryTool) -> None:
|
||||
"""Test inserting at an invalid line number."""
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/test.txt",
|
||||
file_text="Line 1\nLine 2"
|
||||
)
|
||||
|
||||
result = memory_tool.execute_action(
|
||||
"insert",
|
||||
path="/test.txt",
|
||||
insert_line=100,
|
||||
insert_text="Text"
|
||||
)
|
||||
assert "invalid" in result.lower()
|
||||
|
||||
|
||||
def test_delete_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/delete_me.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
# Delete it
|
||||
result = memory_tool.execute_action("delete", path="/delete_me.txt")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify it's gone
|
||||
result = memory_tool.execute_action("view", path="/delete_me.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_delete_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a file that doesn't exist."""
|
||||
result = memory_tool.execute_action("delete", path="/nonexistent.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_delete_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test deleting a directory with files."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/subdir/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Delete the directory
|
||||
result = memory_tool.execute_action("delete", path="/subdir/")
|
||||
assert "deleted" in result.lower()
|
||||
|
||||
# Verify files are gone
|
||||
result = memory_tool.execute_action("view", path="/subdir/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_rename_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file."""
|
||||
# Create a file
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/old_name.txt",
|
||||
file_text="Content"
|
||||
)
|
||||
|
||||
# Rename it
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/old_name.txt",
|
||||
new_path="/new_name.txt"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify old path doesn't exist
|
||||
result = memory_tool.execute_action("view", path="/old_name.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new path exists
|
||||
result = memory_tool.execute_action("view", path="/new_name.txt")
|
||||
assert "Content" in result
|
||||
|
||||
|
||||
def test_rename_nonexistent_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a file that doesn't exist."""
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/nonexistent.txt",
|
||||
new_path="/new.txt"
|
||||
)
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_rename_to_existing_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming to a path that already exists."""
|
||||
# Create two files
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Try to rename file1 to file2
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/file1.txt",
|
||||
new_path="/file2.txt"
|
||||
)
|
||||
assert "already exists" in result.lower()
|
||||
|
||||
|
||||
def test_path_traversal_protection(memory_tool: MemoryTool) -> None:
|
||||
"""Test that directory traversal attacks are prevented."""
|
||||
# Try various path traversal attempts
|
||||
invalid_paths = [
|
||||
"/../secrets.txt",
|
||||
"/../../etc/passwd",
|
||||
"..//file.txt",
|
||||
"/subdir/../../outside.txt",
|
||||
]
|
||||
|
||||
for path in invalid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="malicious content"
|
||||
)
|
||||
assert "invalid path" in result.lower()
|
||||
|
||||
|
||||
def test_path_must_start_with_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test that paths work with or without leading slash (auto-normalized)."""
|
||||
# These paths should all work now (auto-prepended with /)
|
||||
valid_paths = [
|
||||
"etc/passwd", # Auto-prepended with /
|
||||
"home/user/file.txt", # Auto-prepended with /
|
||||
"file.txt", # Auto-prepended with /
|
||||
]
|
||||
|
||||
for path in valid_paths:
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path=path,
|
||||
file_text="content"
|
||||
)
|
||||
assert "created" in result.lower()
|
||||
|
||||
# Verify the file can be accessed with or without leading slash
|
||||
view_result = memory_tool.execute_action("view", path=path)
|
||||
assert "content" in view_result
|
||||
|
||||
|
||||
def test_cannot_create_directory_as_file(memory_tool: MemoryTool) -> None:
|
||||
"""Test that you cannot create a file at a directory path."""
|
||||
result = memory_tool.execute_action(
|
||||
"create",
|
||||
path="/",
|
||||
file_text="content"
|
||||
)
|
||||
assert "cannot create a file at directory path" in result.lower()
|
||||
|
||||
|
||||
def test_get_actions_metadata(memory_tool: MemoryTool) -> None:
|
||||
"""Test that action metadata is properly defined."""
|
||||
metadata = memory_tool.get_actions_metadata()
|
||||
|
||||
# Check that all expected actions are defined
|
||||
action_names = [action["name"] for action in metadata]
|
||||
assert "view" in action_names
|
||||
assert "create" in action_names
|
||||
assert "str_replace" in action_names
|
||||
assert "insert" in action_names
|
||||
assert "delete" in action_names
|
||||
assert "rename" in action_names
|
||||
|
||||
# Check that each action has required fields
|
||||
for action in metadata:
|
||||
assert "name" in action
|
||||
assert "description" in action
|
||||
assert "parameters" in action
|
||||
|
||||
|
||||
def test_memory_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different memory tool instances have isolated memories."""
|
||||
# Create fake collection
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def insert_one(self, doc):
|
||||
user_id = doc.get("user_id")
|
||||
tool_id = doc.get("tool_id")
|
||||
path = doc.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
self.docs[key] = doc
|
||||
return type("res", (), {"inserted_id": key})
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
# Handle query by _id
|
||||
if "_id" in q:
|
||||
doc_id = q["_id"]
|
||||
if doc_id not in self.docs:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if "$set" in u:
|
||||
old_doc = self.docs[doc_id].copy()
|
||||
old_doc.update(u["$set"])
|
||||
|
||||
# If path changed, update the dictionary key
|
||||
if "path" in u["$set"]:
|
||||
new_path = u["$set"]["path"]
|
||||
user_id = old_doc.get("user_id")
|
||||
tool_id = old_doc.get("tool_id")
|
||||
new_key = f"{user_id}:{tool_id}:{new_path}"
|
||||
|
||||
# Remove old key and add with new key
|
||||
del self.docs[doc_id]
|
||||
old_doc["_id"] = new_key
|
||||
self.docs[new_key] = old_doc
|
||||
else:
|
||||
self.docs[doc_id] = old_doc
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
# Handle query by user_id, tool_id, path
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "path": path, "content": "", "_id": key}
|
||||
|
||||
if "$set" in u:
|
||||
self.docs[key].update(u["$set"])
|
||||
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q, projection=None):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
path = q.get("path")
|
||||
|
||||
if path:
|
||||
key = f"{user_id}:{tool_id}:{path}"
|
||||
return self.docs.get(key)
|
||||
|
||||
return None
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two memory tools with different tool_ids for the same user
|
||||
tool1 = MemoryTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = MemoryTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a file in tool1
|
||||
tool1.execute_action("create", path="/file.txt", file_text="Content from tool 1")
|
||||
|
||||
# Create a file with the same path in tool2
|
||||
tool2.execute_action("create", path="/file.txt", file_text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
result1 = tool1.execute_action("view", path="/file.txt")
|
||||
result2 = tool2.execute_action("view", path="/file.txt")
|
||||
|
||||
assert "Content from tool 1" in result1
|
||||
assert "Content from tool 2" not in result1
|
||||
|
||||
assert "Content from tool 2" in result2
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
def test_memory_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"memories": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
tool1 = MemoryTool({}, user_id="test_user")
|
||||
tool2 = MemoryTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
tool3 = MemoryTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
|
||||
|
||||
def test_paths_without_leading_slash(memory_tool) -> None:
|
||||
"""Test that paths without leading slash work correctly."""
|
||||
# Create file without leading slash
|
||||
result = memory_tool.execute_action("create", path="cat_breeds.txt", file_text="- Korat\n- Chartreux\n- British Shorthair\n- Nebelung")
|
||||
assert "created" in result.lower()
|
||||
|
||||
# View file without leading slash
|
||||
view_result = memory_tool.execute_action("view", path="cat_breeds.txt")
|
||||
assert "Korat" in view_result
|
||||
assert "Chartreux" in view_result
|
||||
|
||||
# View same file with leading slash (should work the same)
|
||||
view_result2 = memory_tool.execute_action("view", path="/cat_breeds.txt")
|
||||
assert "Korat" in view_result2
|
||||
|
||||
# Test str_replace without leading slash
|
||||
replace_result = memory_tool.execute_action("str_replace", path="cat_breeds.txt", old_str="Korat", new_str="Maine Coon")
|
||||
assert "updated" in replace_result.lower()
|
||||
|
||||
# Test nested path without leading slash
|
||||
nested_result = memory_tool.execute_action("create", path="projects/tasks.txt", file_text="Task 1\nTask 2")
|
||||
assert "created" in nested_result.lower()
|
||||
|
||||
view_nested = memory_tool.execute_action("view", path="projects/tasks.txt")
|
||||
assert "Task 1" in view_nested
|
||||
|
||||
|
||||
def test_rename_directory(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a directory with files."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/sub/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Rename directory (with trailing slash)
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/docs/",
|
||||
new_path="/archive/"
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
assert "2 files" in result.lower()
|
||||
|
||||
# Verify old paths don't exist
|
||||
result = memory_tool.execute_action("view", path="/docs/file1.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
# Verify new paths exist
|
||||
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||
assert "Content 1" in result
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
|
||||
assert "Content 2" in result
|
||||
|
||||
|
||||
def test_rename_directory_without_trailing_slash(memory_tool: MemoryTool) -> None:
|
||||
"""Test renaming a directory when new path is missing trailing slash."""
|
||||
# Create files in a directory
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/file1.txt",
|
||||
file_text="Content 1"
|
||||
)
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/docs/sub/file2.txt",
|
||||
file_text="Content 2"
|
||||
)
|
||||
|
||||
# Rename directory - old path has slash, new path doesn't
|
||||
result = memory_tool.execute_action(
|
||||
"rename",
|
||||
old_path="/docs/",
|
||||
new_path="/archive" # Missing trailing slash
|
||||
)
|
||||
assert "renamed" in result.lower()
|
||||
|
||||
# Verify paths are correct (not corrupted like "/archivesub/file2.txt")
|
||||
result = memory_tool.execute_action("view", path="/archive/file1.txt")
|
||||
assert "Content 1" in result
|
||||
|
||||
result = memory_tool.execute_action("view", path="/archive/sub/file2.txt")
|
||||
assert "Content 2" in result
|
||||
|
||||
# Verify corrupted path doesn't exist
|
||||
result = memory_tool.execute_action("view", path="/archivesub/file2.txt")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_view_file_line_numbers(memory_tool: MemoryTool) -> None:
|
||||
"""Test that view_range displays correct line numbers."""
|
||||
# Create a multiline file
|
||||
content = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
memory_tool.execute_action(
|
||||
"create",
|
||||
path="/numbered.txt",
|
||||
file_text=content
|
||||
)
|
||||
|
||||
# View lines 2-4
|
||||
result = memory_tool.execute_action(
|
||||
"view",
|
||||
path="/numbered.txt",
|
||||
view_range=[2, 4]
|
||||
)
|
||||
|
||||
# Check that line numbers are correct (should be 2, 3, 4 not 3, 4, 5)
|
||||
assert "2: Line 2" in result
|
||||
assert "3: Line 3" in result
|
||||
assert "4: Line 4" in result
|
||||
assert "1: Line 1" not in result
|
||||
assert "5: Line 5" not in result
|
||||
|
||||
# Verify no off-by-one error
|
||||
assert "3: Line 2" not in result # Wrong line number
|
||||
assert "4: Line 3" not in result # Wrong line number
|
||||
assert "5: Line 4" not in result # Wrong line number
|
||||
223
tests/test_notes_tool.py
Normal file
223
tests/test_notes_tool.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import pytest
|
||||
from application.agents.tools.notes import NotesTool
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notes_tool(monkeypatch) -> NotesTool:
|
||||
"""Provide a NotesTool with a fake Mongo collection and fixed user_id."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {} # key: user_id:tool_id -> doc
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
|
||||
# emulate single-note storage with optional upsert
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
|
||||
if "$set" in u and "note" in u["$set"]:
|
||||
self.docs[key]["note"] = u["$set"]["note"]
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
return self.docs.get(key)
|
||||
|
||||
def delete_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
if key in self.docs:
|
||||
del self.docs[key]
|
||||
return type("res", (), {"deleted_count": 1})
|
||||
return type("res", (), {"deleted_count": 0})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
# Patch MongoDB client globally for the tool
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Return tool with a fixed tool_id for consistency in tests
|
||||
return NotesTool({"tool_id": "test_tool_id"}, user_id="test_user")
|
||||
|
||||
|
||||
def test_view(notes_tool: NotesTool) -> None:
|
||||
# Manually insert a note to test retrieval
|
||||
notes_tool.collection.update_one(
|
||||
{"user_id": "test_user", "tool_id": "test_tool_id"},
|
||||
{"$set": {"note": "hello"}},
|
||||
upsert=True
|
||||
)
|
||||
assert "hello" in notes_tool.execute_action("view")
|
||||
|
||||
|
||||
def test_overwrite_and_delete(notes_tool: NotesTool) -> None:
|
||||
# Overwrite creates a new note
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="first").lower()
|
||||
assert "first" in notes_tool.execute_action("view")
|
||||
|
||||
# Overwrite replaces existing note
|
||||
assert "saved" in notes_tool.execute_action("overwrite", text="second").lower()
|
||||
assert "second" in notes_tool.execute_action("view")
|
||||
|
||||
assert "deleted" in notes_tool.execute_action("delete").lower()
|
||||
assert "no note" in notes_tool.execute_action("view").lower()
|
||||
|
||||
def test_init_without_user_id(monkeypatch):
|
||||
"""Should fail gracefully if no user_id is provided."""
|
||||
notes_tool = NotesTool(tool_config={})
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "user_id" in str(result).lower()
|
||||
|
||||
|
||||
def test_view_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Should return 'No note found.' when no note exists"""
|
||||
result = notes_tool.execute_action("view")
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
def test_str_replace(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement in note"""
|
||||
# Create a note
|
||||
notes_tool.execute_action("overwrite", text="Hello world, hello universe")
|
||||
|
||||
# Replace text
|
||||
result = notes_tool.execute_action("str_replace", old_str="hello", new_str="hi")
|
||||
assert "updated" in result.lower()
|
||||
|
||||
# Verify replacement
|
||||
note = notes_tool.execute_action("view")
|
||||
assert "hi world, hi universe" in note.lower()
|
||||
|
||||
|
||||
def test_str_replace_not_found(notes_tool: NotesTool) -> None:
|
||||
"""Test string replacement when string not found"""
|
||||
notes_tool.execute_action("overwrite", text="Hello world")
|
||||
result = notes_tool.execute_action("str_replace", old_str="goodbye", new_str="hi")
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
def test_insert_line(notes_tool: NotesTool) -> None:
|
||||
"""Test inserting text at a line number"""
|
||||
# Create a multiline note
|
||||
notes_tool.execute_action("overwrite", text="Line 1\nLine 2\nLine 3")
|
||||
|
||||
# Insert at line 2
|
||||
result = notes_tool.execute_action("insert", line_number=2, text="Inserted line")
|
||||
assert "inserted" in result.lower()
|
||||
|
||||
# Verify insertion
|
||||
note = notes_tool.execute_action("view")
|
||||
lines = note.split("\n")
|
||||
assert lines[1] == "Inserted line"
|
||||
assert lines[2] == "Line 2"
|
||||
|
||||
|
||||
def test_delete_nonexistent_note(monkeypatch):
|
||||
class FakeResult:
|
||||
deleted_count = 0
|
||||
|
||||
class FakeCollection:
|
||||
def delete_one(self, *args, **kwargs):
|
||||
return FakeResult()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"application.core.mongo_db.MongoDB.get_client",
|
||||
lambda: {"docsgpt": {"notes": FakeCollection()}}
|
||||
)
|
||||
|
||||
notes_tool = NotesTool(tool_config={}, user_id="user123")
|
||||
result = notes_tool.execute_action("delete")
|
||||
assert "no note found" in result.lower()
|
||||
|
||||
|
||||
def test_notes_tool_isolation(monkeypatch) -> None:
|
||||
"""Test that different notes tool instances have isolated notes."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
|
||||
if key not in self.docs and not upsert:
|
||||
return type("res", (), {"modified_count": 0})
|
||||
if key not in self.docs and upsert:
|
||||
self.docs[key] = {"user_id": user_id, "tool_id": tool_id, "note": ""}
|
||||
if "$set" in u and "note" in u["$set"]:
|
||||
self.docs[key]["note"] = u["$set"]["note"]
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
def find_one(self, q):
|
||||
user_id = q.get("user_id")
|
||||
tool_id = q.get("tool_id")
|
||||
key = f"{user_id}:{tool_id}"
|
||||
return self.docs.get(key)
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two notes tools with different tool_ids for the same user
|
||||
tool1 = NotesTool({"tool_id": "tool_1"}, user_id="test_user")
|
||||
tool2 = NotesTool({"tool_id": "tool_2"}, user_id="test_user")
|
||||
|
||||
# Create a note in tool1
|
||||
tool1.execute_action("overwrite", text="Content from tool 1")
|
||||
|
||||
# Create a note in tool2
|
||||
tool2.execute_action("overwrite", text="Content from tool 2")
|
||||
|
||||
# Verify that each tool sees only its own content
|
||||
result1 = tool1.execute_action("view")
|
||||
result2 = tool2.execute_action("view")
|
||||
|
||||
assert "Content from tool 1" in result1
|
||||
assert "Content from tool 2" not in result1
|
||||
|
||||
assert "Content from tool 2" in result2
|
||||
assert "Content from tool 1" not in result2
|
||||
|
||||
|
||||
def test_notes_tool_auto_generates_tool_id(monkeypatch) -> None:
|
||||
"""Test that tool_id defaults to 'default_{user_id}' for persistence."""
|
||||
class FakeCollection:
|
||||
def __init__(self) -> None:
|
||||
self.docs = {}
|
||||
|
||||
def update_one(self, q, u, upsert=False):
|
||||
return type("res", (), {"modified_count": 1})
|
||||
|
||||
fake_collection = FakeCollection()
|
||||
fake_db = {"notes": fake_collection}
|
||||
fake_client = {settings.MONGO_DB_NAME: fake_db}
|
||||
|
||||
monkeypatch.setattr("application.core.mongo_db.MongoDB.get_client", lambda: fake_client)
|
||||
|
||||
# Create two tools without providing tool_id for the same user
|
||||
tool1 = NotesTool({}, user_id="test_user")
|
||||
tool2 = NotesTool({}, user_id="test_user")
|
||||
|
||||
# Both should have the same default tool_id for persistence
|
||||
assert tool1.tool_id == "default_test_user"
|
||||
assert tool2.tool_id == "default_test_user"
|
||||
assert tool1.tool_id == tool2.tool_id
|
||||
|
||||
# Different users should have different tool_ids
|
||||
tool3 = NotesTool({}, user_id="another_user")
|
||||
assert tool3.tool_id == "default_another_user"
|
||||
assert tool3.tool_id != tool1.tool_id
|
||||
Reference in New Issue
Block a user