mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 22:44:10 +00:00
Compare commits
31 Commits
messages-f
...
fast-ebook
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcdb4fb5e8 | ||
|
|
e787c896eb | ||
|
|
23aeaff5db | ||
|
|
689dd79597 | ||
|
|
0c15af90b1 | ||
|
|
cdd6ff6557 | ||
|
|
72b3d94453 | ||
|
|
7e88d09e5d | ||
|
|
74a4a237dc | ||
|
|
c3f01c6619 | ||
|
|
6b408823d4 | ||
|
|
3fc81ac5d8 | ||
|
|
2652f8a5b0 | ||
|
|
d711eefe96 | ||
|
|
79206f3919 | ||
|
|
de971d9452 | ||
|
|
1b4d5ca0dd | ||
|
|
81989e8258 | ||
|
|
dc262d1698 | ||
|
|
69f9c93869 | ||
|
|
74bf80b25c | ||
|
|
d9a92a7208 | ||
|
|
02e93d993d | ||
|
|
6b6495f48c | ||
|
|
249dd9ce37 | ||
|
|
9134ab0478 | ||
|
|
10ef68c9d0 | ||
|
|
7d65cf1c2b | ||
|
|
13c6cc59c1 | ||
|
|
648b3f1d20 | ||
|
|
a75a9e23f9 |
18
SECURITY.md
18
SECURITY.md
@@ -2,13 +2,21 @@
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Supported Versions:
|
||||
|
||||
Currently, we support security patches by committing changes and bumping the version published on Github.
|
||||
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Found a vulnerability? Please email us:
|
||||
Preferred method: use GitHub's private vulnerability reporting flow:
|
||||
https://github.com/arc53/DocsGPT/security
|
||||
|
||||
security@arc53.com
|
||||
Then click **Report a vulnerability**.
|
||||
|
||||
|
||||
Alternatively, email us at: security@arc53.com
|
||||
|
||||
We aim to acknowledge reports within 48 hours.
|
||||
|
||||
## Incident Handling
|
||||
|
||||
Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -61,7 +62,8 @@ class MCPTool(Tool):
|
||||
"""
|
||||
self.config = config
|
||||
self.user_id = user_id
|
||||
self.server_url = config.get("server_url", "")
|
||||
raw_url = config.get("server_url", "")
|
||||
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
|
||||
self.transport_type = config.get("transport_type", "auto")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
@@ -87,6 +89,18 @@ class MCPTool(Tool):
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
|
||||
@staticmethod
|
||||
def _validate_server_url(server_url: str) -> str:
|
||||
"""Validate server_url to prevent SSRF to internal networks.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL points to a private/internal address.
|
||||
"""
|
||||
try:
|
||||
return validate_url(server_url)
|
||||
except SSRFError as exc:
|
||||
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
|
||||
|
||||
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
|
||||
if configured_redirect_uri:
|
||||
return configured_redirect_uri.rstrip("/")
|
||||
@@ -108,8 +122,9 @@ class MCPTool(Tool):
|
||||
auth_key = ""
|
||||
if self.auth_type == "oauth":
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
|
||||
auth_key = (
|
||||
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
||||
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
||||
)
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
|
||||
@@ -85,6 +85,10 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
stream = self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
|
||||
@@ -92,6 +92,14 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question="",
|
||||
|
||||
@@ -112,6 +112,7 @@ class StreamProcessor:
|
||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||
self.compressed_summary: Optional[str] = None
|
||||
self.compressed_summary_tokens: int = 0
|
||||
self._agent_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
@@ -359,22 +360,29 @@ class StreamProcessor:
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
"""Configure the source based on agent data.
|
||||
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
The literal string ``"default"`` is a placeholder meaning "no
|
||||
ingested source" and is normalized to an empty source so that no
|
||||
retrieval is attempted.
|
||||
"""
|
||||
if self._agent_data:
|
||||
agent_data = self._agent_data
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
source["id"]
|
||||
for source in agent_data["sources"]
|
||||
if source.get("id") and source["id"] != "default"
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.all_sources = [
|
||||
s for s in agent_data["sources"] if s.get("id") != "default"
|
||||
]
|
||||
elif agent_data.get("source") and agent_data["source"] != "default":
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
@@ -387,11 +395,24 @@ class StreamProcessor:
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
active_docs = self.data["active_docs"]
|
||||
if active_docs and active_docs != "default":
|
||||
self.source = {"active_docs": active_docs}
|
||||
else:
|
||||
self.source = {}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _has_active_docs(self) -> bool:
|
||||
"""Return True if a real document source is configured for retrieval."""
|
||||
active_docs = self.source.get("active_docs") if self.source else None
|
||||
if not active_docs:
|
||||
return False
|
||||
if active_docs == "default":
|
||||
return False
|
||||
return True
|
||||
|
||||
def _resolve_agent_id(self) -> Optional[str]:
|
||||
"""Resolve agent_id from request, then fall back to conversation context."""
|
||||
request_agent_id = self.data.get("agent_id")
|
||||
@@ -433,48 +454,39 @@ class StreamProcessor:
|
||||
effective_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if effective_key:
|
||||
data_key = self._get_data_from_api_key(effective_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
self._agent_data = self._get_data_from_api_key(effective_key)
|
||||
if self._agent_data.get("_id"):
|
||||
self.agent_id = str(self._agent_data.get("_id"))
|
||||
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"prompt_id": self._agent_data.get("prompt_id", "default"),
|
||||
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": effective_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
"default_model_id": data_key.get("default_model_id", ""),
|
||||
"models": data_key.get("models", []),
|
||||
"json_schema": self._agent_data.get("json_schema"),
|
||||
"default_model_id": self._agent_data.get("default_model_id", ""),
|
||||
"models": self._agent_data.get("models", []),
|
||||
"allow_system_prompt_override": self._agent_data.get(
|
||||
"allow_system_prompt_override", False
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Set identity context
|
||||
if self.data.get("api_key"):
|
||||
# External API key: use the key owner's identity
|
||||
self.initial_user_id = data_key.get("user")
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
self.initial_user_id = self._agent_data.get("user")
|
||||
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||
elif self.is_shared_usage:
|
||||
# Shared agent: keep the caller's identity
|
||||
pass
|
||||
else:
|
||||
# Owner using their own agent
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("workflow"):
|
||||
self.agent_config["workflow"] = data_key["workflow"]
|
||||
self.agent_config["workflow_owner"] = data_key.get("user")
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
if self._agent_data.get("workflow"):
|
||||
self.agent_config["workflow"] = self._agent_data["workflow"]
|
||||
self.agent_config["workflow_owner"] = self._agent_data.get("user")
|
||||
else:
|
||||
# No API key — default/workflow configuration
|
||||
agent_type = settings.AGENT_NAME
|
||||
@@ -497,14 +509,45 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Assemble retriever config with precedence: request > agent > default."""
|
||||
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
||||
|
||||
# Start with defaults
|
||||
retriever_name = "classic"
|
||||
chunks = 2
|
||||
|
||||
# Layer agent-level config (if present)
|
||||
if self._agent_data:
|
||||
if self._agent_data.get("retriever"):
|
||||
retriever_name = self._agent_data["retriever"]
|
||||
if self._agent_data.get("chunks") is not None:
|
||||
try:
|
||||
chunks = int(self._agent_data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
# Explicit request values win over agent config
|
||||
if "retriever" in self.data:
|
||||
retriever_name = self.data["retriever"]
|
||||
if "chunks" in self.data:
|
||||
try:
|
||||
chunks = int(self.data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"retriever_name": retriever_name,
|
||||
"chunks": chunks,
|
||||
"doc_token_limit": doc_token_limit,
|
||||
}
|
||||
|
||||
# isNoneDoc without an API key forces no retrieval
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
@@ -528,6 +571,9 @@ class StreamProcessor:
|
||||
if self.data.get("isNoneDoc", False) and not self.agent_id:
|
||||
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||
return None, None
|
||||
if not self._has_active_docs():
|
||||
logger.info("Pre-fetch skipped: no active docs configured")
|
||||
return None, None
|
||||
try:
|
||||
retriever = self.create_retriever()
|
||||
logger.info(
|
||||
@@ -910,15 +956,23 @@ class StreamProcessor:
|
||||
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
|
||||
self._prompt_content = raw_prompt
|
||||
|
||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||
prompt_content=raw_prompt,
|
||||
user_id=self.initial_user_id,
|
||||
request_id=self.data.get("request_id"),
|
||||
passthrough_data=self.data.get("passthrough"),
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
# Allow API callers to override the system prompt when the agent
|
||||
# has opted in via allow_system_prompt_override.
|
||||
if (
|
||||
self.agent_config.get("allow_system_prompt_override", False)
|
||||
and self.data.get("system_prompt_override")
|
||||
):
|
||||
rendered_prompt = self.data["system_prompt_override"]
|
||||
else:
|
||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||
prompt_content=raw_prompt,
|
||||
user_id=self.initial_user_id,
|
||||
request_id=self.data.get("request_id"),
|
||||
passthrough_data=self.data.get("passthrough"),
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
|
||||
@@ -26,12 +26,20 @@ internal = Blueprint("internal", __name__)
|
||||
|
||||
@internal.before_request
|
||||
def verify_internal_key():
|
||||
"""Verify INTERNAL_KEY for all internal endpoint requests."""
|
||||
if settings.INTERNAL_KEY:
|
||||
internal_key = request.headers.get("X-Internal-Key")
|
||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||
"""Verify INTERNAL_KEY for all internal endpoint requests.
|
||||
|
||||
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
|
||||
"""
|
||||
if not settings.INTERNAL_KEY:
|
||||
logger.warning(
|
||||
f"Internal API request rejected from {request.remote_addr}: "
|
||||
"INTERNAL_KEY is not configured"
|
||||
)
|
||||
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
|
||||
internal_key = request.headers.get("X-Internal-Key")
|
||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||
|
||||
|
||||
@internal.route("/api/download", methods=["get"])
|
||||
|
||||
@@ -73,6 +73,7 @@ AGENT_TYPE_SCHEMAS = {
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"allow_system_prompt_override",
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"lastUsedAt",
|
||||
@@ -96,6 +97,7 @@ AGENT_TYPE_SCHEMAS = {
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"allow_system_prompt_override",
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"lastUsedAt",
|
||||
@@ -220,6 +222,12 @@ def build_agent_document(
|
||||
base_doc["request_limit"] = int(
|
||||
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
)
|
||||
if "allow_system_prompt_override" in allowed_fields:
|
||||
base_doc["allow_system_prompt_override"] = (
|
||||
data.get("allow_system_prompt_override") == "True"
|
||||
if isinstance(data.get("allow_system_prompt_override"), str)
|
||||
else bool(data.get("allow_system_prompt_override", False))
|
||||
)
|
||||
return {k: v for k, v in base_doc.items() if k in allowed_fields}
|
||||
|
||||
|
||||
@@ -292,6 +300,9 @@ class GetAgent(Resource):
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
"folder_id": agent.get("folder_id"),
|
||||
"workflow": agent.get("workflow"),
|
||||
"allow_system_prompt_override": agent.get(
|
||||
"allow_system_prompt_override", False
|
||||
),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as e:
|
||||
@@ -373,6 +384,9 @@ class GetAgents(Resource):
|
||||
"default_model_id": agent.get("default_model_id", ""),
|
||||
"folder_id": agent.get("folder_id"),
|
||||
"workflow": agent.get("workflow"),
|
||||
"allow_system_prompt_override": agent.get(
|
||||
"allow_system_prompt_override", False
|
||||
),
|
||||
}
|
||||
for agent in agents
|
||||
if "source" in agent
|
||||
@@ -450,6 +464,10 @@ class CreateAgent(Resource):
|
||||
"folder_id": fields.String(
|
||||
required=False, description="Folder ID to organize the agent"
|
||||
),
|
||||
"allow_system_prompt_override": fields.Boolean(
|
||||
required=False,
|
||||
description="Allow API callers to override the system prompt via the v1 endpoint",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -491,9 +509,9 @@ class CreateAgent(Resource):
|
||||
data["json_schema"] = normalize_json_schema_payload(
|
||||
data.get("json_schema")
|
||||
)
|
||||
except JsonSchemaValidationError as exc:
|
||||
except JsonSchemaValidationError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
||||
jsonify({"success": False, "message": "Invalid JSON schema"}),
|
||||
400,
|
||||
)
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
@@ -674,6 +692,10 @@ class UpdateAgent(Resource):
|
||||
"folder_id": fields.String(
|
||||
required=False, description="Folder ID to organize the agent"
|
||||
),
|
||||
"allow_system_prompt_override": fields.Boolean(
|
||||
required=False,
|
||||
description="Allow API callers to override the system prompt via the v1 endpoint",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -765,6 +787,7 @@ class UpdateAgent(Resource):
|
||||
"default_model_id",
|
||||
"folder_id",
|
||||
"workflow",
|
||||
"allow_system_prompt_override",
|
||||
]
|
||||
|
||||
for field in allowed_fields:
|
||||
@@ -872,9 +895,9 @@ class UpdateAgent(Resource):
|
||||
update_fields[field] = normalize_json_schema_payload(
|
||||
json_schema
|
||||
)
|
||||
except JsonSchemaValidationError as exc:
|
||||
except JsonSchemaValidationError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
||||
jsonify({"success": False, "message": "Invalid JSON schema"}),
|
||||
400,
|
||||
)
|
||||
else:
|
||||
@@ -983,6 +1006,13 @@ class UpdateAgent(Resource):
|
||||
if workflow_error:
|
||||
return workflow_error
|
||||
update_fields[field] = workflow_id
|
||||
elif field == "allow_system_prompt_override":
|
||||
raw_value = data.get("allow_system_prompt_override", False)
|
||||
update_fields[field] = (
|
||||
raw_value == "True"
|
||||
if isinstance(raw_value, str)
|
||||
else bool(raw_value)
|
||||
)
|
||||
else:
|
||||
value = data[field]
|
||||
if field in ["name", "description", "prompt_id", "agent_type"]:
|
||||
|
||||
@@ -612,6 +612,10 @@ class LiveSpeechToTextFinish(Resource):
|
||||
class ServeImage(Resource):
|
||||
@api.doc(description="Serve an image from storage")
|
||||
def get(self, image_path):
|
||||
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||
)
|
||||
try:
|
||||
from application.api.user.base import storage
|
||||
|
||||
@@ -629,6 +633,10 @@ class ServeImage(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image not found"}), 404
|
||||
)
|
||||
except ValueError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error serving image: {e}")
|
||||
return make_response(
|
||||
|
||||
@@ -57,7 +57,7 @@ class ShareConversation(Resource):
|
||||
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}
|
||||
{"_id": ObjectId(conversation_id), "user": user}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
|
||||
@@ -463,6 +463,16 @@ class ManageSourceFiles(Resource):
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
for file_path in file_paths:
|
||||
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid file path",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
|
||||
@@ -14,6 +14,7 @@ from application.api.user.tools.routes import transform_actions
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields
|
||||
|
||||
@@ -63,6 +64,21 @@ def _extract_auth_credentials(config):
|
||||
return auth_credentials
|
||||
|
||||
|
||||
def _validate_mcp_server_url(config: dict) -> None:
|
||||
"""Validate the server_url in an MCP config to prevent SSRF.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL is missing or points to a blocked address.
|
||||
"""
|
||||
server_url = (config.get("server_url") or "").strip()
|
||||
if not server_url:
|
||||
raise ValueError("server_url is required")
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError as exc:
|
||||
raise ValueError(f"Invalid server URL: {exc}") from exc
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/test")
|
||||
class TestMCPServerConfig(Resource):
|
||||
@api.expect(
|
||||
@@ -97,6 +113,8 @@ class TestMCPServerConfig(Resource):
|
||||
400,
|
||||
)
|
||||
|
||||
_validate_mcp_server_url(config)
|
||||
|
||||
auth_credentials = _extract_auth_credentials(config)
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
@@ -105,15 +123,41 @@ class TestMCPServerConfig(Resource):
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
if result.get("requires_oauth"):
|
||||
return make_response(jsonify(result), 200)
|
||||
safe_result = {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k in ("success", "requires_oauth", "auth_url")
|
||||
}
|
||||
return make_response(jsonify(safe_result), 200)
|
||||
|
||||
if not result.get("success") and "message" in result:
|
||||
if not result.get("success"):
|
||||
current_app.logger.error(
|
||||
f"MCP connection test failed: {result.get('message')}"
|
||||
)
|
||||
result["message"] = "Connection test failed"
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Connection test failed",
|
||||
"tools_count": 0,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
return make_response(jsonify(result), 200)
|
||||
safe_result = {
|
||||
"success": True,
|
||||
"message": result.get("message", "Connection successful"),
|
||||
"tools_count": result.get("tools_count", 0),
|
||||
"tools": result.get("tools", []),
|
||||
}
|
||||
return make_response(jsonify(safe_result), 200)
|
||||
except ValueError as e:
|
||||
current_app.logger.warning(f"Invalid MCP server test request: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
@@ -165,6 +209,8 @@ class MCPServerSave(Resource):
|
||||
400,
|
||||
)
|
||||
|
||||
_validate_mcp_server_url(config)
|
||||
|
||||
auth_credentials = _extract_auth_credentials(config)
|
||||
auth_type = config.get("auth_type", "none")
|
||||
mcp_config = config.copy()
|
||||
@@ -279,6 +325,12 @@ class MCPServerSave(Resource):
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except ValueError as e:
|
||||
current_app.logger.warning(f"Invalid MCP server save request: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
|
||||
@@ -8,6 +8,7 @@ from application.agents.tools.spec_parser import parse_spec
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
@@ -130,6 +131,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
|
||||
class AvailableTools(Resource):
|
||||
@api.doc(description="Get available tools for a user")
|
||||
def get(self):
|
||||
if not request.decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
tools_metadata = []
|
||||
for tool_name, tool_instance in tool_manager.tools.items():
|
||||
@@ -236,6 +239,16 @@ class CreateTool(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
if data["name"] == "mcp_tool":
|
||||
server_url = (data.get("config", {}).get("server_url") or "").strip()
|
||||
if server_url:
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||
400,
|
||||
)
|
||||
tool_instance = tool_manager.tools.get(data["name"])
|
||||
if not tool_instance:
|
||||
return make_response(
|
||||
@@ -421,6 +434,16 @@ class UpdateToolConfig(Resource):
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
|
||||
tool_name = tool_doc.get("name")
|
||||
if tool_name == "mcp_tool":
|
||||
server_url = (data["config"].get("server_url") or "").strip()
|
||||
if server_url:
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||
400,
|
||||
)
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
|
||||
@@ -138,10 +138,18 @@ def chat_completions():
|
||||
if usage_error:
|
||||
return usage_error
|
||||
|
||||
should_save_conversation = bool(internal_data.get("save_conversation", False))
|
||||
|
||||
if is_stream:
|
||||
return Response(
|
||||
_stream_response(
|
||||
helper, question, agent, processor, model_name, continuation
|
||||
helper,
|
||||
question,
|
||||
agent,
|
||||
processor,
|
||||
model_name,
|
||||
continuation,
|
||||
should_save_conversation,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
@@ -151,7 +159,13 @@ def chat_completions():
|
||||
)
|
||||
else:
|
||||
return _non_stream_response(
|
||||
helper, question, agent, processor, model_name, continuation
|
||||
helper,
|
||||
question,
|
||||
agent,
|
||||
processor,
|
||||
model_name,
|
||||
continuation,
|
||||
should_save_conversation,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
@@ -181,6 +195,7 @@ def _stream_response(
|
||||
processor: StreamProcessor,
|
||||
model_name: str,
|
||||
continuation: Optional[Dict],
|
||||
should_save_conversation: bool,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Generate translated SSE chunks for streaming response."""
|
||||
completion_id = f"chatcmpl-{int(time.time())}"
|
||||
@@ -193,6 +208,7 @@ def _stream_response(
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
|
||||
@@ -225,6 +241,7 @@ def _non_stream_response(
|
||||
processor: StreamProcessor,
|
||||
model_name: str,
|
||||
continuation: Optional[Dict],
|
||||
should_save_conversation: bool,
|
||||
) -> Response:
|
||||
"""Collect full response and return as single JSON."""
|
||||
stream = helper.complete_stream(
|
||||
@@ -235,6 +252,7 @@ def _non_stream_response(
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
|
||||
@@ -293,8 +311,9 @@ def list_models():
|
||||
for ag in user_agents:
|
||||
created = ag.get("createdAt")
|
||||
created_ts = int(created.timestamp()) if created else int(time.time())
|
||||
model_id = str(ag.get("_id") or ag.get("id") or "")
|
||||
models.append({
|
||||
"id": str(ag.get("key", "")),
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": created_ts,
|
||||
"owned_by": "docsgpt",
|
||||
|
||||
@@ -80,6 +80,17 @@ def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
|
||||
"""Extract the first system message content from the messages array.
|
||||
|
||||
Returns None if no system message is present.
|
||||
"""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
return msg.get("content", "")
|
||||
return None
|
||||
|
||||
|
||||
def convert_history(messages: List[Dict]) -> List[Dict]:
|
||||
"""Convert chat completions messages array to DocsGPT history format.
|
||||
|
||||
@@ -148,20 +159,27 @@ def translate_request(
|
||||
break
|
||||
|
||||
history = convert_history(messages)
|
||||
system_prompt_override = extract_system_prompt(messages)
|
||||
|
||||
docsgpt = data.get("docsgpt", {})
|
||||
|
||||
result = {
|
||||
"question": question,
|
||||
"api_key": api_key,
|
||||
"history": json.dumps(history),
|
||||
"save_conversation": True,
|
||||
# Conversations are NOT persisted by default on the v1 endpoint.
|
||||
# Callers opt in via ``docsgpt.save_conversation: true``.
|
||||
"save_conversation": bool(docsgpt.get("save_conversation", False)),
|
||||
}
|
||||
|
||||
if system_prompt_override is not None:
|
||||
result["system_prompt_override"] = system_prompt_override
|
||||
|
||||
# Client tools
|
||||
if data.get("tools"):
|
||||
result["client_tools"] = data["tools"]
|
||||
|
||||
# DocsGPT extensions
|
||||
docsgpt = data.get("docsgpt", {})
|
||||
if docsgpt.get("attachments"):
|
||||
result["attachments"] = docsgpt["attachments"]
|
||||
|
||||
|
||||
@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
|
||||
def parse_file(self, file: Path, errors: str = "ignore") -> str:
|
||||
"""Parse file."""
|
||||
try:
|
||||
import ebooklib
|
||||
from ebooklib import epub
|
||||
from fast_ebook import epub
|
||||
except ImportError:
|
||||
raise ValueError("`EbookLib` is required to read Epub files.")
|
||||
try:
|
||||
import html2text
|
||||
except ImportError:
|
||||
raise ValueError("`html2text` is required to parse Epub files.")
|
||||
raise ValueError("`fast-ebook` is required to read Epub files.")
|
||||
|
||||
text_list = []
|
||||
book = epub.read_epub(file, options={"ignore_ncx": True})
|
||||
|
||||
# Iterate through all chapters.
|
||||
for item in book.get_items():
|
||||
# Chapters are typically located in epub documents items.
|
||||
if item.get_type() == ebooklib.ITEM_DOCUMENT:
|
||||
text_list.append(
|
||||
html2text.html2text(item.get_content().decode("utf-8"))
|
||||
)
|
||||
|
||||
text = "\n".join(text_list)
|
||||
book = epub.read_epub(file)
|
||||
text = book.to_markdown()
|
||||
return text
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
anthropic==0.86.0
|
||||
boto3==1.42.24
|
||||
anthropic==0.88.0
|
||||
boto3==1.42.83
|
||||
beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.3
|
||||
@@ -11,11 +11,11 @@ rapidocr>=1.4.0
|
||||
onnxruntime>=1.19.0
|
||||
docx2txt==0.9
|
||||
ddgs>=8.0.0
|
||||
ebooklib==0.20
|
||||
elevenlabs==2.40.0
|
||||
fast-ebook
|
||||
elevenlabs==2.41.0
|
||||
Flask==3.1.3
|
||||
faiss-cpu==1.13.2
|
||||
fastmcp==2.14.6
|
||||
fastmcp==3.2.0
|
||||
flask-restx==1.3.2
|
||||
google-genai==1.69.0
|
||||
google-api-python-client==2.193.0
|
||||
@@ -23,10 +23,9 @@ google-auth-httplib2==0.3.1
|
||||
google-auth-oauthlib==1.3.1
|
||||
gTTS==2.5.4
|
||||
gunicorn==25.3.0
|
||||
html2text==2025.4.15
|
||||
jinja2==3.1.6
|
||||
jiter==0.13.0
|
||||
jmespath==1.0.1
|
||||
jmespath==1.1.0
|
||||
joblib==1.5.3
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
@@ -34,7 +33,7 @@ kombu==5.6.2
|
||||
langchain==1.2.3
|
||||
langchain-community==0.4.1
|
||||
langchain-core==1.2.23
|
||||
langchain-openai==1.1.7
|
||||
langchain-openai==1.1.12
|
||||
langchain-text-splitters==1.1.1
|
||||
langsmith==0.7.23
|
||||
lazy-object-proxy==1.12.0
|
||||
@@ -53,10 +52,10 @@ orjson==3.11.7
|
||||
packaging==26.0
|
||||
pandas==3.0.2
|
||||
openpyxl==3.1.5
|
||||
pathable==0.4.4
|
||||
pathable==0.5.0
|
||||
pdf2image>=1.17.0
|
||||
pillow
|
||||
portalocker>=2.7.0,<3.0.0
|
||||
portalocker>=2.7.0,<4.0.0
|
||||
prompt-toolkit==3.0.52
|
||||
protobuf==7.34.1
|
||||
psycopg2-binary==2.9.11
|
||||
@@ -65,14 +64,14 @@ pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pymongo==4.16.0
|
||||
pypdf==6.6.0
|
||||
pypdf==6.9.2
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv
|
||||
python-jose==3.5.0
|
||||
python-pptx==1.0.2
|
||||
redis==7.4.0
|
||||
referencing>=0.28.0,<0.38.0
|
||||
regex==2026.3.32
|
||||
regex==2026.4.4
|
||||
requests==2.33.1
|
||||
retry==0.9.2
|
||||
sentence-transformers==5.3.0
|
||||
|
||||
@@ -21,10 +21,19 @@ class LocalStorage(BaseStorage):
|
||||
)
|
||||
|
||||
def _get_full_path(self, path: str) -> str:
|
||||
"""Get absolute path by combining base_dir and path."""
|
||||
"""Get absolute path by combining base_dir and path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the resolved path escapes base_dir (path traversal).
|
||||
"""
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.join(self.base_dir, path)
|
||||
resolved = os.path.realpath(path)
|
||||
else:
|
||||
resolved = os.path.realpath(os.path.join(self.base_dir, path))
|
||||
base = os.path.realpath(self.base_dir)
|
||||
if not resolved.startswith(base + os.sep) and resolved != base:
|
||||
raise ValueError(f"Path traversal detected: {path}")
|
||||
return resolved
|
||||
|
||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||
"""Save a file to local storage."""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import io
|
||||
import os
|
||||
import posixpath
|
||||
from typing import BinaryIO, Callable, List
|
||||
|
||||
import boto3
|
||||
@@ -14,6 +15,20 @@ from botocore.exceptions import ClientError
|
||||
class S3Storage(BaseStorage):
|
||||
"""AWS S3 storage implementation."""
|
||||
|
||||
@staticmethod
|
||||
def _validate_path(path: str) -> str:
|
||||
"""Validate and normalize an S3 key to prevent path traversal.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path contains traversal sequences or is absolute.
|
||||
"""
|
||||
if "\x00" in path:
|
||||
raise ValueError(f"Null byte in path: {path}")
|
||||
normalized = posixpath.normpath(path)
|
||||
if normalized.startswith("/") or normalized.startswith(".."):
|
||||
raise ValueError(f"Path traversal detected: {path}")
|
||||
return normalized
|
||||
|
||||
def __init__(self, bucket_name=None):
|
||||
"""
|
||||
Initialize S3 storage.
|
||||
@@ -46,6 +61,7 @@ class S3Storage(BaseStorage):
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Save a file to S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
self.s3.upload_fileobj(
|
||||
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
|
||||
)
|
||||
@@ -61,6 +77,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
def get_file(self, path: str) -> BinaryIO:
|
||||
"""Get a file from S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
if not self.file_exists(path):
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
file_obj = io.BytesIO()
|
||||
@@ -70,6 +87,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
def delete_file(self, path: str) -> bool:
|
||||
"""Delete a file from S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
try:
|
||||
self.s3.delete_object(Bucket=self.bucket_name, Key=path)
|
||||
return True
|
||||
@@ -78,6 +96,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
def file_exists(self, path: str) -> bool:
|
||||
"""Check if a file exists in S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
try:
|
||||
self.s3.head_object(Bucket=self.bucket_name, Key=path)
|
||||
return True
|
||||
@@ -115,6 +134,7 @@ class S3Storage(BaseStorage):
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
path = self._validate_path(path)
|
||||
if not self.file_exists(path):
|
||||
raise FileNotFoundError(f"File not found in S3: {path}")
|
||||
with tempfile.NamedTemporaryFile(
|
||||
|
||||
@@ -11,11 +11,33 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
def get_vectorstore(path: str) -> str:
|
||||
if path:
|
||||
vectorstore = f"indexes/{path}"
|
||||
else:
|
||||
vectorstore = "indexes"
|
||||
return vectorstore
|
||||
"""Build a safe local path for a FAISS index.
|
||||
|
||||
Args:
|
||||
path: Source identifier provided by the caller.
|
||||
|
||||
Returns:
|
||||
The validated vectorstore path rooted under ``indexes``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``path`` escapes the ``indexes`` directory.
|
||||
"""
|
||||
base_dir = "indexes"
|
||||
if not path:
|
||||
return base_dir
|
||||
|
||||
normalized = str(path).strip()
|
||||
if "\\" in normalized:
|
||||
raise ValueError("Invalid source_id path")
|
||||
|
||||
candidate = os.path.normpath(os.path.join(base_dir, normalized))
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
candidate_abs = os.path.abspath(candidate)
|
||||
|
||||
if not candidate_abs.startswith(base_abs + os.sep) and candidate_abs != base_abs:
|
||||
raise ValueError("Invalid source_id path")
|
||||
|
||||
return candidate
|
||||
|
||||
|
||||
class FaissStore(BaseVectorStore):
|
||||
|
||||
@@ -7,6 +7,10 @@ export default {
|
||||
"title": "🔌 Agent API",
|
||||
"href": "/Agents/api"
|
||||
},
|
||||
"openai-compatible": {
|
||||
"title": "🔄 OpenAI-Compatible API",
|
||||
"href": "/Agents/openai-compatible"
|
||||
},
|
||||
"webhooks": {
|
||||
"title": "🪝 Agent Webhooks",
|
||||
"href": "/Agents/webhooks"
|
||||
|
||||
@@ -15,6 +15,10 @@ DocsGPT Agents can be accessed programmatically through API endpoints. This page
|
||||
|
||||
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
|
||||
|
||||
<Callout type="info">
|
||||
Looking to connect an existing OpenAI-compatible client (opencode, aider, the OpenAI SDKs, etc.) to a DocsGPT Agent? Use the [OpenAI-Compatible Chat Completions API](/Agents/openai-compatible) — it speaks the standard chat completions protocol so no adapter code is required.
|
||||
</Callout>
|
||||
|
||||
## Base URL
|
||||
|
||||
<Callout type="info">
|
||||
|
||||
@@ -111,6 +111,7 @@ Once an agent is created, you can:
|
||||
* Modify any of its configuration settings (name, description, source, prompt, tools, type).
|
||||
* **Generate a Public Link:** From the edit screen, you can create a shareable public link that allows others to import and use your agent.
|
||||
* **Get a Webhook URL:** You can also obtain a Webhook URL for the agent. This allows external applications or services to trigger the agent and receive responses programmatically, enabling powerful integrations and automations.
|
||||
* **Use it via API:** Every agent exposes an API key that can be used with the native [Agent API](/Agents/api) or the [OpenAI-Compatible API](/Agents/openai-compatible) so you can drop DocsGPT Agents into any tool that already speaks the chat completions protocol.
|
||||
|
||||
## Seeding Premade Agents from YAML
|
||||
|
||||
|
||||
93
docs/content/Agents/openai-compatible.mdx
Normal file
93
docs/content/Agents/openai-compatible.mdx
Normal file
@@ -0,0 +1,93 @@
|
||||
---
|
||||
title: OpenAI-Compatible API
|
||||
description: Connect any OpenAI-compatible client to DocsGPT Agents via /v1/chat/completions.
|
||||
---
|
||||
|
||||
import { Callout, Tabs } from 'nextra/components';
|
||||
|
||||
# OpenAI-Compatible API
|
||||
|
||||
DocsGPT exposes `/v1/chat/completions` following the standard chat completions protocol. Point any compatible client — **opencode**, **Aider**, **LibreChat** or the OpenAI SDKs — at your DocsGPT Agent by changing only the base URL and API key.
|
||||
|
||||
## Quick Start
|
||||
|
||||
<Tabs items={['Python', 'cURL']}>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:7091/v1", # or https://gptcloud.arc53.com/v1
|
||||
api_key="your_agent_api_key",
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="docsgpt-agent",
|
||||
messages=[{"role": "user", "content": "Summarize our refund policy"}],
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/v1/chat/completions \
|
||||
-H "Authorization: Bearer your_agent_api_key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"docsgpt-agent","messages":[{"role":"user","content":"Summarize our refund policy"}]}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
The `model` field is accepted but ignored — the agent bound to your API key determines the model. The agent's prompt, sources, tools, and default model are loaded automatically.
|
||||
|
||||
## Base URL & Auth
|
||||
|
||||
| Environment | Base URL |
|
||||
| --- | --- |
|
||||
| Local | `http://localhost:7091/v1` |
|
||||
| Cloud | `https://gptcloud.arc53.com/v1` |
|
||||
|
||||
Authenticate with `Authorization: Bearer <agent_api_key>`.
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
| --- | --- | --- |
|
||||
| `POST` | `/v1/chat/completions` | Chat request (streaming or non-streaming) |
|
||||
| `GET` | `/v1/models` | List agents available to your key |
|
||||
|
||||
## Streaming
|
||||
|
||||
Set `"stream": true`. You'll receive SSE chunks with `choices[0].delta.content`. DocsGPT-specific events (sources, tool calls) arrive as extra frames with a `docsgpt` key — standard clients ignore them.
|
||||
|
||||
```python
|
||||
stream = client.chat.completions.create(
|
||||
model="docsgpt-agent",
|
||||
stream=True,
|
||||
messages=[{"role": "user", "content": "Explain vector search"}],
|
||||
)
|
||||
for chunk in stream:
|
||||
print(chunk.choices[0].delta.content or "", end="", flush=True)
|
||||
```
|
||||
|
||||
## System Prompt Override
|
||||
|
||||
System messages are **dropped by default** — the agent's configured prompt is used. To allow callers to override it, enable **Allow prompt override** in the agent's Advanced settings.
|
||||
|
||||
<Callout type="warning">
|
||||
When an override is active, the agent's prompt template is replaced wholesale — template variables like `{summaries}` are not substituted.
|
||||
</Callout>
|
||||
|
||||
## Conversation Persistence
|
||||
|
||||
Conversations are **not persisted by default** (stateless, like most OpenAI clients expect). Opt in per request:
|
||||
|
||||
```json
|
||||
{ "docsgpt": { "save_conversation": true } }
|
||||
```
|
||||
|
||||
The response will include `docsgpt.conversation_id`.
|
||||
|
||||
## When to Use Native Endpoints Instead
|
||||
|
||||
Use [`/api/answer` or `/stream`](/Agents/api) if you need server-side attachments, `passthrough` template variables, explicit `conversation_id` reuse, or persistence by default.
|
||||
@@ -1,3 +1,5 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@@ -10,6 +12,7 @@ docsgpt_url = os.getenv("docsgpt_url")
|
||||
chatwoot_url = os.getenv("chatwoot_url")
|
||||
docsgpt_key = os.getenv("docsgpt_key")
|
||||
chatwoot_token = os.getenv("chatwoot_token")
|
||||
chatwoot_webhook_secret = os.getenv("chatwoot_webhook_secret", "")
|
||||
# account_id = os.getenv("account_id")
|
||||
# assignee_id = os.getenv("assignee_id")
|
||||
label_stop = "human-requested"
|
||||
@@ -45,12 +48,35 @@ def send_to_chatwoot(account, conversation, message):
|
||||
return r.json()
|
||||
|
||||
|
||||
def is_valid_chatwoot_signature(raw_body: bytes, signature_header: str | None) -> bool:
|
||||
"""Validate Chatwoot webhook signature using shared secret."""
|
||||
if not chatwoot_webhook_secret or not signature_header:
|
||||
return False
|
||||
|
||||
expected = hmac.new(
|
||||
chatwoot_webhook_secret.encode("utf-8"), raw_body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
provided = signature_header.strip()
|
||||
if provided.startswith("sha256="):
|
||||
provided = provided.split("=", maxsplit=1)[1]
|
||||
|
||||
return hmac.compare_digest(provided, expected)
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/docsgpt', methods=['POST'])
|
||||
def docsgpt():
|
||||
data = request.get_json()
|
||||
raw_body = request.get_data()
|
||||
signature = request.headers.get("X-Chatwoot-Signature")
|
||||
if not is_valid_chatwoot_signature(raw_body, signature):
|
||||
return "Unauthorized", 401
|
||||
|
||||
data = request.get_json(silent=True)
|
||||
if not isinstance(data, dict):
|
||||
return "Invalid payload", 400
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
pp.pprint(data)
|
||||
try:
|
||||
|
||||
@@ -73,6 +73,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
token_limit: undefined,
|
||||
limited_request_mode: false,
|
||||
request_limit: undefined,
|
||||
allow_system_prompt_override: false,
|
||||
models: [],
|
||||
default_model_id: '',
|
||||
});
|
||||
@@ -241,6 +242,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('request_limit', '0');
|
||||
}
|
||||
|
||||
formData.append(
|
||||
'allow_system_prompt_override',
|
||||
agent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
|
||||
if (imageFile) formData.append('image', imageFile);
|
||||
|
||||
if (agent.tools && agent.tools.length > 0)
|
||||
@@ -361,6 +367,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('request_limit', '0');
|
||||
}
|
||||
|
||||
formData.append(
|
||||
'allow_system_prompt_override',
|
||||
agent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
|
||||
if (agent.models && agent.models.length > 0) {
|
||||
formData.append('models', JSON.stringify(agent.models));
|
||||
}
|
||||
@@ -1266,6 +1277,43 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
}`}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-6">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="min-w-0 flex-1">
|
||||
<h2 className="text-sm font-medium">
|
||||
{t('agents.form.advanced.systemPromptOverride')}
|
||||
</h2>
|
||||
<p className="mt-1 text-xs text-gray-600 dark:text-gray-400">
|
||||
{t(
|
||||
'agents.form.advanced.systemPromptOverrideDescription',
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() =>
|
||||
setAgent({
|
||||
...agent,
|
||||
allow_system_prompt_override:
|
||||
!agent.allow_system_prompt_override,
|
||||
})
|
||||
}
|
||||
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
|
||||
agent.allow_system_prompt_override
|
||||
? 'bg-primary'
|
||||
: 'bg-gray-300 dark:bg-gray-600'
|
||||
}`}
|
||||
>
|
||||
<span
|
||||
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
|
||||
agent.allow_system_prompt_override
|
||||
? ''
|
||||
: '-translate-x-5'
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -36,6 +36,7 @@ export type Agent = {
|
||||
default_model_id?: string;
|
||||
folder_id?: string;
|
||||
workflow?: string;
|
||||
allow_system_prompt_override?: boolean;
|
||||
};
|
||||
|
||||
export type AgentFolder = {
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
X,
|
||||
} from 'lucide-react';
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useNavigate, useParams, useSearchParams } from 'react-router-dom';
|
||||
import ReactFlow, {
|
||||
@@ -301,6 +302,7 @@ function createWorkflowPayload(
|
||||
}
|
||||
|
||||
function WorkflowBuilderInner() {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const token = useSelector(selectToken);
|
||||
const sourceDocs = useSelector(selectSourceDocs);
|
||||
@@ -1142,6 +1144,10 @@ function WorkflowBuilderInner() {
|
||||
workflowDescription || `Workflow agent: ${workflowName}`,
|
||||
);
|
||||
agentFormData.append('status', 'published');
|
||||
agentFormData.append(
|
||||
'allow_system_prompt_override',
|
||||
currentAgent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
if (imageFile) {
|
||||
agentFormData.append('image', imageFile);
|
||||
}
|
||||
@@ -1203,6 +1209,10 @@ function WorkflowBuilderInner() {
|
||||
agentFormData.append('agent_type', 'workflow');
|
||||
agentFormData.append('status', 'published');
|
||||
agentFormData.append('workflow', savedWorkflowId || '');
|
||||
agentFormData.append(
|
||||
'allow_system_prompt_override',
|
||||
currentAgent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
if (imageFile) {
|
||||
agentFormData.append('image', imageFile);
|
||||
}
|
||||
@@ -1454,6 +1464,40 @@ function WorkflowBuilderInner() {
|
||||
Image updates are included the next time you save.
|
||||
</p>
|
||||
</div>
|
||||
<div className="mb-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{t('agents.form.advanced.systemPromptOverride')}
|
||||
</label>
|
||||
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
|
||||
{t('agents.form.advanced.systemPromptOverrideDescription')}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() =>
|
||||
setCurrentAgent((prev) => ({
|
||||
...prev,
|
||||
allow_system_prompt_override:
|
||||
!prev.allow_system_prompt_override,
|
||||
}))
|
||||
}
|
||||
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
|
||||
currentAgent.allow_system_prompt_override
|
||||
? 'bg-primary'
|
||||
: 'bg-gray-300 dark:bg-gray-600'
|
||||
}`}
|
||||
>
|
||||
<span
|
||||
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
|
||||
currentAgent.allow_system_prompt_override
|
||||
? ''
|
||||
: '-translate-x-5'
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={handleWorkflowSettingsDone}
|
||||
disabled={isPublishing}
|
||||
|
||||
@@ -78,6 +78,7 @@ function Dropdown<T extends DropdownOption>({
|
||||
const searchRef = useRef<HTMLInputElement>(null);
|
||||
const [open, setOpen] = useState(false);
|
||||
const [query, setQuery] = useState('');
|
||||
const [dropUp, setDropUp] = useState(false);
|
||||
|
||||
const radius = rounded === '3xl' ? 'rounded-3xl' : 'rounded-xl';
|
||||
const radiusTop = rounded === '3xl' ? 'rounded-t-3xl' : 'rounded-t-xl';
|
||||
@@ -90,14 +91,23 @@ function Dropdown<T extends DropdownOption>({
|
||||
setQuery('');
|
||||
}
|
||||
};
|
||||
document.addEventListener('mousedown', handler);
|
||||
return () => document.removeEventListener('mousedown', handler);
|
||||
document.addEventListener('mousedown', handler, true);
|
||||
return () => document.removeEventListener('mousedown', handler, true);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (open && searchable && searchRef.current) searchRef.current.focus();
|
||||
}, [open, searchable]);
|
||||
|
||||
const handleToggle = () => {
|
||||
if (!open && ref.current) {
|
||||
const rect = ref.current.getBoundingClientRect();
|
||||
const spaceBelow = window.innerHeight - rect.bottom;
|
||||
setDropUp(spaceBelow < 220);
|
||||
}
|
||||
setOpen((v) => !v);
|
||||
};
|
||||
|
||||
const filtered = useMemo(() => {
|
||||
if (!searchable || !query.trim()) return options;
|
||||
const q = query.toLowerCase();
|
||||
@@ -110,8 +120,8 @@ function Dropdown<T extends DropdownOption>({
|
||||
<div className={`relative ${size}`} ref={ref}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setOpen((v) => !v)}
|
||||
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? radiusTop : radius}`}
|
||||
onClick={handleToggle}
|
||||
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? (dropUp ? radiusBottom : radiusTop) : radius}`}
|
||||
>
|
||||
<span
|
||||
className={`truncate ${contentSize} ${displayValue ? '' : 'text-muted-foreground'}`}
|
||||
@@ -125,7 +135,11 @@ function Dropdown<T extends DropdownOption>({
|
||||
|
||||
{open && (
|
||||
<div
|
||||
className={`border-border bg-card absolute inset-x-0 z-20 -mt-px overflow-hidden border border-t-0 shadow-lg ${radiusBottom}`}
|
||||
className={`border-border bg-card absolute inset-x-0 z-20 overflow-hidden border shadow-lg ${
|
||||
dropUp
|
||||
? `bottom-full -mt-px border-b-0 ${radiusTop}`
|
||||
: `-mt-px border-t-0 ${radiusBottom}`
|
||||
}`}
|
||||
>
|
||||
{searchable && (
|
||||
<div className="flex items-center px-3 py-2">
|
||||
|
||||
@@ -10,7 +10,9 @@ interface SkeletonLoaderProps {
|
||||
| 'chatbot'
|
||||
| 'dropdown'
|
||||
| 'chunkCards'
|
||||
| 'sourceCards';
|
||||
| 'sourceCards'
|
||||
| 'toolCards'
|
||||
| 'addToolCards';
|
||||
}
|
||||
|
||||
const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
|
||||
@@ -237,6 +239,55 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
|
||||
</>
|
||||
);
|
||||
|
||||
const renderAddToolCards = () => (
|
||||
<>
|
||||
{Array.from({ length: count }).map((_, idx) => (
|
||||
<div
|
||||
key={`add-tool-skel-${idx}`}
|
||||
className="border-light-gainsboro dark:border-arsenic flex h-52 w-full animate-pulse flex-col justify-between rounded-2xl border p-6"
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center justify-between px-1">
|
||||
<div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="mt-[9px] space-y-2 px-1">
|
||||
<div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
|
||||
const renderToolCards = () => (
|
||||
<>
|
||||
{Array.from({ length: count }).map((_, idx) => (
|
||||
<div
|
||||
key={`tool-skel-${idx}`}
|
||||
className="bg-muted flex h-52 w-[300px] animate-pulse flex-col justify-between rounded-2xl p-6"
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="flex items-center gap-2 px-1">
|
||||
<div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="mt-[9px] space-y-2 px-1">
|
||||
<div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex justify-end">
|
||||
<div className="h-5 w-9 rounded-full bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
|
||||
const componentMap = {
|
||||
fileTable: renderTable,
|
||||
chatbot: renderChatbot,
|
||||
@@ -246,6 +297,8 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
|
||||
analysis: renderAnalysis,
|
||||
chunkCards: renderChunkCards,
|
||||
sourceCards: renderSourceCards,
|
||||
toolCards: renderToolCards,
|
||||
addToolCards: renderAddToolCards,
|
||||
};
|
||||
|
||||
const render = componentMap[component] || componentMap.default;
|
||||
|
||||
@@ -619,7 +619,9 @@
|
||||
"tokenLimiting": "Token-Limitierung",
|
||||
"tokenLimitingDescription": "Begrenze die täglich von diesem Agenten verwendbaren Tokens",
|
||||
"requestLimiting": "Anfrage-Limitierung",
|
||||
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen"
|
||||
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen",
|
||||
"systemPromptOverride": "Prompt-Überschreibung erlauben",
|
||||
"systemPromptOverrideDescription": "Erlaubt v1-API-Aufrufern, den System-Prompt dieses Agenten zu ersetzen"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Veröffentlichte Agenten können hier in der Vorschau angezeigt werden"
|
||||
|
||||
@@ -653,7 +653,9 @@
|
||||
"tokenLimiting": "Token limiting",
|
||||
"tokenLimitingDescription": "Limit daily total tokens that can be used by this agent",
|
||||
"requestLimiting": "Request limiting",
|
||||
"requestLimitingDescription": "Limit daily total requests that can be made to this agent"
|
||||
"requestLimitingDescription": "Limit daily total requests that can be made to this agent",
|
||||
"systemPromptOverride": "Allow prompt override",
|
||||
"systemPromptOverrideDescription": "Let v1 API callers replace this agent's system prompt"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Published agents can be previewed here"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "Límite de tokens",
|
||||
"tokenLimitingDescription": "Limita el total diario de tokens que puede usar este agente",
|
||||
"requestLimiting": "Límite de solicitudes",
|
||||
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente"
|
||||
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente",
|
||||
"systemPromptOverride": "Permitir sobrescribir el prompt",
|
||||
"systemPromptOverrideDescription": "Permitir que los llamadores de la API v1 reemplacen el prompt del sistema de este agente"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Los agentes publicados se pueden previsualizar aquí"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "トークン制限",
|
||||
"tokenLimitingDescription": "このエージェントが使用できる1日の合計トークン数を制限します",
|
||||
"requestLimiting": "リクエスト制限",
|
||||
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します"
|
||||
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します",
|
||||
"systemPromptOverride": "プロンプトの上書きを許可",
|
||||
"systemPromptOverrideDescription": "v1 API呼び出し元がこのエージェントのシステムプロンプトを置き換えることを許可します"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "公開されたエージェントはここでプレビューできます"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "Лимит токенов",
|
||||
"tokenLimitingDescription": "Ограничить ежедневное общее количество токенов, которые может использовать этот агент",
|
||||
"requestLimiting": "Лимит запросов",
|
||||
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту"
|
||||
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту",
|
||||
"systemPromptOverride": "Разрешить замену промпта",
|
||||
"systemPromptOverrideDescription": "Разрешить вызовам API v1 заменять системный промпт этого агента"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Опубликованные агенты можно просмотреть здесь"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "權杖限制",
|
||||
"tokenLimitingDescription": "限制此代理每天可使用的總權杖數",
|
||||
"requestLimiting": "請求限制",
|
||||
"requestLimitingDescription": "限制每天可向此代理發出的總請求數"
|
||||
"requestLimitingDescription": "限制每天可向此代理發出的總請求數",
|
||||
"systemPromptOverride": "允許覆蓋提示詞",
|
||||
"systemPromptOverrideDescription": "允許 v1 API 呼叫者替換此代理的系統提示詞"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "已發佈的代理可以在此處預覽"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "令牌限制",
|
||||
"tokenLimitingDescription": "限制此代理每天可使用的总令牌数",
|
||||
"requestLimiting": "请求限制",
|
||||
"requestLimitingDescription": "限制每天可向此代理发出的总请求数"
|
||||
"requestLimitingDescription": "限制每天可向此代理发出的总请求数",
|
||||
"systemPromptOverride": "允许覆盖提示词",
|
||||
"systemPromptOverrideDescription": "允许 v1 API 调用者替换此代理的系统提示词"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "已发布的代理可以在此处预览"
|
||||
|
||||
@@ -3,8 +3,8 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Spinner from '../components/Spinner';
|
||||
import { useOutsideAlerter } from '../hooks';
|
||||
import SkeletonLoader from '../components/SkeletonLoader';
|
||||
import { useLoaderState, useOutsideAlerter } from '../hooks';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import ConfigToolModal from './ConfigToolModal';
|
||||
@@ -37,7 +37,7 @@ export default function AddToolModal({
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [mcpModalState, setMcpModalState] =
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [loading, setLoading] = useLoaderState(false);
|
||||
|
||||
useOutsideAlerter(modalRef, () => {
|
||||
if (modalState === 'ACTIVE') {
|
||||
@@ -121,8 +121,8 @@ export default function AddToolModal({
|
||||
</h2>
|
||||
<div className="mt-5 h-[73vh] overflow-auto px-3 py-px">
|
||||
{loading ? (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Spinner />
|
||||
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">
|
||||
<SkeletonLoader component="addToolCards" count={6} />
|
||||
</div>
|
||||
) : (
|
||||
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">
|
||||
|
||||
@@ -10,9 +10,9 @@ import NoFilesIcon from '../assets/no-files.svg';
|
||||
import SearchIcon from '../assets/search.svg';
|
||||
import ThreeDotsIcon from '../assets/three-dots.svg';
|
||||
import ContextMenu, { MenuOption } from '../components/ContextMenu';
|
||||
import Spinner from '../components/Spinner';
|
||||
import SkeletonLoader from '../components/SkeletonLoader';
|
||||
import ToggleSwitch from '../components/ToggleSwitch';
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { useDarkTheme, useLoaderState } from '../hooks';
|
||||
import AddToolModal from '../modals/AddToolModal';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import MCPServerModal from '../modals/MCPServerModal';
|
||||
@@ -33,7 +33,7 @@ export default function Tools() {
|
||||
const [selectedTool, setSelectedTool] = React.useState<
|
||||
UserToolType | APIToolType | null
|
||||
>(null);
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [loading, setLoading] = useLoaderState(false);
|
||||
const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
|
||||
const menuRefs = React.useRef<{
|
||||
[key: string]: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -242,10 +242,8 @@ export default function Tools() {
|
||||
</div>
|
||||
<div className="border-border dark:border-border mt-5 mb-8 border-b" />
|
||||
{loading ? (
|
||||
<div className="grid grid-cols-2 gap-6 lg:grid-cols-3">
|
||||
<div className="col-span-2 mt-24 flex h-32 items-center justify-center lg:col-span-3">
|
||||
<Spinner />
|
||||
</div>
|
||||
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
|
||||
<SkeletonLoader component="toolCards" count={6} />
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
|
||||
|
||||
12
setup.ps1
12
setup.ps1
@@ -543,8 +543,20 @@ function Configure-TTS {
|
||||
}
|
||||
}
|
||||
|
||||
# Generate INTERNAL_KEY for worker-to-backend auth if not already present
|
||||
function Ensure-InternalKey {
|
||||
$content = if (Test-Path $ENV_FILE) { Get-Content $ENV_FILE -Raw } else { "" }
|
||||
if ($content -notmatch "(?m)^INTERNAL_KEY=") {
|
||||
$bytes = New-Object byte[] 32
|
||||
[System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
|
||||
$internal_key = ($bytes | ForEach-Object { $_.ToString("x2") }) -join ""
|
||||
"INTERNAL_KEY=$internal_key" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
}
|
||||
}
|
||||
|
||||
# Main advanced settings menu
|
||||
function Prompt-AdvancedSettings {
|
||||
Ensure-InternalKey
|
||||
Write-Host ""
|
||||
$configure_advanced = Read-Host "Would you like to configure advanced settings? (y/N)"
|
||||
if ($configure_advanced -ne "y" -and $configure_advanced -ne "Y") {
|
||||
|
||||
10
setup.sh
10
setup.sh
@@ -396,8 +396,18 @@ configure_tts() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Generate INTERNAL_KEY for worker-to-backend auth if not already present
|
||||
ensure_internal_key() {
|
||||
if ! grep -q "^INTERNAL_KEY=" "$ENV_FILE" 2>/dev/null; then
|
||||
local internal_key
|
||||
internal_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
|
||||
echo "INTERNAL_KEY=$internal_key" >> "$ENV_FILE"
|
||||
fi
|
||||
}
|
||||
|
||||
# Main advanced settings menu
|
||||
prompt_advanced_settings() {
|
||||
ensure_internal_key
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Would you like to configure advanced settings? (y/N): ${NC}")" configure_advanced
|
||||
if [[ ! "$configure_advanced" =~ ^[yY]$ ]]; then
|
||||
|
||||
@@ -28,6 +28,7 @@ def _patch_mcp_globals(monkeypatch):
|
||||
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
|
||||
monkeypatch.setattr(mcp_mod, "db", mock_db)
|
||||
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", lambda url: url)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -33,6 +33,8 @@ def _patch_mcp_globals(monkeypatch):
|
||||
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
|
||||
monkeypatch.setattr(mcp_mod, "db", mock_db)
|
||||
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
|
||||
# Bypass DNS-resolving URL validation for tests using fake hostnames.
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", lambda u, **kw: u)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -136,6 +138,47 @@ class TestMCPToolInit:
|
||||
})
|
||||
assert tool.custom_headers == {"X-Custom": "val"}
|
||||
|
||||
def test_rejects_metadata_ip(self, monkeypatch):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.core.url_validation import validate_url as real_validate_url
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
|
||||
with pytest.raises(ValueError, match="Invalid MCP server URL"):
|
||||
MCPTool(config={"server_url": "http://169.254.169.254/latest/meta-data", "auth_type": "none"})
|
||||
|
||||
def test_rejects_localhost(self, monkeypatch):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.core.url_validation import validate_url as real_validate_url
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
|
||||
with pytest.raises(ValueError, match="Invalid MCP server URL"):
|
||||
MCPTool(config={"server_url": "http://localhost:8080/mcp", "auth_type": "none"})
|
||||
|
||||
def test_rejects_private_ip(self, monkeypatch):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.core.url_validation import validate_url as real_validate_url
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
|
||||
with pytest.raises(ValueError, match="Invalid MCP server URL"):
|
||||
MCPTool(config={"server_url": "http://10.0.0.1/mcp", "auth_type": "none"})
|
||||
|
||||
def test_accepts_public_url(self):
|
||||
tool = _make_tool({
|
||||
"server_url": "https://mcp.example.com/api",
|
||||
"auth_type": "none",
|
||||
})
|
||||
assert tool.server_url == "https://mcp.example.com/api"
|
||||
|
||||
def test_empty_server_url_allowed(self):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client"):
|
||||
tool = MCPTool(config={"server_url": "", "auth_type": "none"})
|
||||
assert tool.server_url == ""
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Redirect URI Resolution
|
||||
|
||||
@@ -330,6 +330,170 @@ class TestStreamProcessorDocPrefetch:
|
||||
assert docs_together is not None
|
||||
assert "Agent doc content" in docs_together
|
||||
|
||||
def test_configure_source_treats_default_string_as_no_docs(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_default_source_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
"source": "default",
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Hi", "api_key": "agent_default_source_key"},
|
||||
None,
|
||||
)
|
||||
processor._configure_agent()
|
||||
processor._configure_source()
|
||||
|
||||
assert processor.source == {}
|
||||
assert processor.all_sources == []
|
||||
|
||||
def test_prefetch_skipped_when_no_active_docs(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Hi there"},
|
||||
{"sub": "user_123"},
|
||||
)
|
||||
processor.initialize()
|
||||
processor.create_retriever = MagicMock()
|
||||
|
||||
docs_together, docs = processor.pre_fetch_docs("Hi there")
|
||||
|
||||
processor.create_retriever.assert_not_called()
|
||||
assert docs_together is None
|
||||
assert docs is None
|
||||
|
||||
def test_prefetch_skipped_when_active_docs_is_default(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Hi", "active_docs": "default"},
|
||||
{"sub": "user_123"},
|
||||
)
|
||||
processor.initialize()
|
||||
processor.create_retriever = MagicMock()
|
||||
|
||||
docs_together, docs = processor.pre_fetch_docs("Hi")
|
||||
|
||||
processor.create_retriever.assert_not_called()
|
||||
assert docs_together is None
|
||||
assert docs is None
|
||||
|
||||
def test_agent_retriever_and_chunks_propagate_to_retriever_config(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
source_id = ObjectId()
|
||||
db["sources"].insert_one(
|
||||
{"_id": source_id, "name": "src", "retriever": "hybrid", "chunks": 5}
|
||||
)
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_ret_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
"retriever": "hybrid",
|
||||
"chunks": 5,
|
||||
"source": DBRef("sources", source_id),
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Test", "api_key": "agent_ret_key"},
|
||||
None,
|
||||
)
|
||||
processor.initialize()
|
||||
|
||||
assert processor.retriever_config["retriever_name"] == "hybrid"
|
||||
assert processor.retriever_config["chunks"] == 5
|
||||
|
||||
def test_request_retriever_and_chunks_override_agent_config(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_override_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
"retriever": "hybrid",
|
||||
"chunks": 5,
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{
|
||||
"question": "Test",
|
||||
"api_key": "agent_override_key",
|
||||
"retriever": "classic",
|
||||
"chunks": 7,
|
||||
},
|
||||
None,
|
||||
)
|
||||
processor.initialize()
|
||||
|
||||
assert processor.retriever_config["retriever_name"] == "classic"
|
||||
assert processor.retriever_config["chunks"] == 7
|
||||
|
||||
def test_agent_data_fetched_once_per_request(self, mock_mongo_db):
|
||||
from unittest.mock import patch
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_once_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Test", "api_key": "agent_once_key"},
|
||||
None,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
processor, "_get_data_from_api_key", wraps=processor._get_data_from_api_key
|
||||
) as spy:
|
||||
processor.initialize()
|
||||
assert spy.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAttachments:
|
||||
|
||||
@@ -566,14 +566,13 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [
|
||||
{"id": "src1", "retriever": "classic"},
|
||||
{"id": "src2", "retriever": "hybrid"},
|
||||
],
|
||||
"source": None,
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": ["src1", "src2"]}
|
||||
assert len(sp.all_sources) == 2
|
||||
@@ -593,12 +592,11 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [],
|
||||
"source": "single_src",
|
||||
"retriever": "classic",
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": "single_src"}
|
||||
assert len(sp.all_sources) == 1
|
||||
@@ -618,8 +616,7 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {"sources": [], "source": None}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._agent_data = {"sources": [], "source": None}
|
||||
sp._configure_source()
|
||||
assert sp.source == {}
|
||||
assert sp.all_sources == []
|
||||
@@ -639,11 +636,10 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = "agent_key_123"
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [{"id": "s1", "retriever": "classic"}],
|
||||
"source": None,
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": ["s1"]}
|
||||
|
||||
@@ -662,11 +658,10 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [{"id": None}, {"retriever": "classic"}],
|
||||
"source": None,
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {}
|
||||
|
||||
@@ -1189,6 +1184,8 @@ class TestConfigureAgent:
|
||||
"chunks": "5",
|
||||
})
|
||||
sp._configure_agent()
|
||||
sp.model_id = "test-model"
|
||||
sp._configure_retriever()
|
||||
assert sp.agent_config["workflow"] == "wf_123"
|
||||
assert sp.agent_config["workflow_owner"] == "user1"
|
||||
assert sp.retriever_config["retriever_name"] == "hybrid"
|
||||
@@ -1211,6 +1208,8 @@ class TestConfigureAgent:
|
||||
"chunks": "not_a_number",
|
||||
})
|
||||
sp._configure_agent()
|
||||
sp.model_id = "test-model"
|
||||
sp._configure_retriever()
|
||||
assert sp.retriever_config["chunks"] == 2
|
||||
|
||||
|
||||
@@ -1763,8 +1762,8 @@ class TestConfigureAgentAdditionalPaths:
|
||||
assert sp.decoded_token == {"sub": "owner_user"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_configure_agent_with_source_in_data_key(self):
|
||||
"""Cover line 463-464: data_key has 'source' set."""
|
||||
def test_configure_source_picks_up_cached_agent_data(self):
|
||||
"""After _configure_agent caches _agent_data, _configure_source uses it."""
|
||||
sp = self._make_sp()
|
||||
sp._resolve_agent_id = MagicMock(return_value="agent_id_1")
|
||||
sp._get_agent_key = MagicMock(return_value=("agent_key", False, None))
|
||||
@@ -1780,6 +1779,7 @@ class TestConfigureAgentAdditionalPaths:
|
||||
"source": "my_source",
|
||||
})
|
||||
sp._configure_agent()
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": "my_source"}
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -2067,7 +2067,7 @@ class TestPreFetchDocsFullPaths:
|
||||
"chunks": 2,
|
||||
"doc_token_limit": 50000,
|
||||
}
|
||||
sp.source = {}
|
||||
sp.source = {"active_docs": ["src1"]}
|
||||
sp.model_id = "test-model"
|
||||
sp.agent_id = None
|
||||
return sp
|
||||
|
||||
@@ -48,7 +48,7 @@ def internal_app(monkeypatch, mock_mongo_db):
|
||||
@pytest.mark.unit
|
||||
class TestVerifyInternalKey:
|
||||
|
||||
def test_no_internal_key_configured_allows_access(
|
||||
def test_no_internal_key_configured_rejects_access(
|
||||
self, internal_app, monkeypatch
|
||||
):
|
||||
app, db = internal_app
|
||||
@@ -63,9 +63,8 @@ class TestVerifyInternalKey:
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
# download will fail for missing file but should not be 401
|
||||
resp = client.get("/api/download?user=u&name=n&file=f")
|
||||
assert resp.status_code != 401
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_missing_key_returns_401(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
@@ -131,9 +130,12 @@ class TestVerifyInternalKey:
|
||||
@pytest.mark.unit
|
||||
class TestUploadIndex:
|
||||
|
||||
_TEST_INTERNAL_KEY = "test-internal-key"
|
||||
_AUTH_HEADERS = {"X-Internal-Key": "test-internal-key"}
|
||||
|
||||
def _make_settings(self, vector_store="faiss"):
|
||||
return MagicMock(
|
||||
INTERNAL_KEY=None,
|
||||
INTERNAL_KEY=self._TEST_INTERNAL_KEY,
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE=vector_store,
|
||||
EMBEDDINGS_NAME="test_embeddings",
|
||||
@@ -146,7 +148,7 @@ class TestUploadIndex:
|
||||
"application.api.internal.routes.settings", self._make_settings()
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={})
|
||||
resp = client.post("/api/upload_index", data={}, headers=self._AUTH_HEADERS)
|
||||
assert resp.json["status"] == "no user"
|
||||
|
||||
def test_missing_name_returns_no_name(self, internal_app, monkeypatch):
|
||||
@@ -155,7 +157,7 @@ class TestUploadIndex:
|
||||
"application.api.internal.routes.settings", self._make_settings()
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={"user": "testuser"})
|
||||
resp = client.post("/api/upload_index", data={"user": "testuser"}, headers=self._AUTH_HEADERS)
|
||||
assert resp.json["status"] == "no name"
|
||||
|
||||
def test_creates_new_source_entry(self, internal_app, monkeypatch):
|
||||
@@ -182,6 +184,7 @@ class TestUploadIndex:
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -219,6 +222,7 @@ class TestUploadIndex:
|
||||
"id": str(doc_id),
|
||||
"type": "remote",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -252,6 +256,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"directory_structure": json.dumps(dir_struct),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -285,6 +290,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"directory_structure": "not valid json",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -317,6 +323,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_name_map": json.dumps(fmap),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -349,6 +356,7 @@ class TestUploadIndex:
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file"
|
||||
|
||||
@@ -379,6 +387,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_faiss": (io.BytesIO(b""), ""),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file name"
|
||||
|
||||
@@ -408,6 +417,7 @@ class TestUploadIndex:
|
||||
"remote_data": '{"url":"http://example.com"}',
|
||||
"sync_frequency": "daily",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -443,6 +453,7 @@ class TestUploadIndex:
|
||||
"file_pkl": (io.BytesIO(b"pkl data"), "index.pkl"),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -477,6 +488,7 @@ class TestUploadIndex:
|
||||
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file"
|
||||
|
||||
@@ -508,9 +520,27 @@ class TestUploadIndex:
|
||||
"file_pkl": (io.BytesIO(b""), ""),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file name"
|
||||
|
||||
def test_no_internal_key_rejects_upload(self, internal_app, monkeypatch):
|
||||
"""Verify that upload_index is rejected when INTERNAL_KEY is not set."""
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings",
|
||||
MagicMock(
|
||||
INTERNAL_KEY=None,
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE="faiss",
|
||||
EMBEDDINGS_NAME="test",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={"user": "attacker"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_update_existing_with_file_name_map(self, internal_app, monkeypatch):
|
||||
"""Cover line 124: update existing entry with file_name_map."""
|
||||
app, db = internal_app
|
||||
@@ -540,6 +570,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_name_map": json.dumps(fmap),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -572,6 +603,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_name_map": "not valid json{{{",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
|
||||
@@ -14,6 +14,13 @@ def app():
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_url_validation():
|
||||
"""Bypass SSRF URL validation so tests using localhost URLs can proceed."""
|
||||
with patch("application.api.user.tools.mcp.validate_url"):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: _sanitize_mcp_transport
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -395,6 +395,9 @@ class TestAvailableTools:
|
||||
"application.api.user.tools.routes.tool_manager", mock_manager
|
||||
):
|
||||
with app.test_request_context("/api/available_tools"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AvailableTools().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -419,6 +422,9 @@ class TestAvailableTools:
|
||||
"application.api.user.tools.routes.tool_manager", mock_manager
|
||||
):
|
||||
with app.test_request_context("/api/available_tools"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AvailableTools().get()
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -438,6 +444,9 @@ class TestAvailableTools:
|
||||
"application.api.user.tools.routes.tool_manager", mock_manager
|
||||
):
|
||||
with app.test_request_context("/api/available_tools"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AvailableTools().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
0
tests/api/v1/__init__.py
Normal file
0
tests/api/v1/__init__.py
Normal file
64
tests/api/v1/test_routes.py
Normal file
64
tests/api/v1/test_routes.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from flask import Flask
|
||||
|
||||
from application.api.v1.routes import v1_bp
|
||||
|
||||
|
||||
class _FakeCollection:
|
||||
def __init__(self, docs):
|
||||
self.docs = docs
|
||||
|
||||
def find_one(self, query):
|
||||
for doc in self.docs:
|
||||
if all(doc.get(k) == v for k, v in query.items()):
|
||||
return doc
|
||||
return None
|
||||
|
||||
def find(self, query):
|
||||
return [doc for doc in self.docs if all(doc.get(k) == v for k, v in query.items())]
|
||||
|
||||
|
||||
def _build_app():
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(v1_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_v1_models_does_not_expose_agent_keys(monkeypatch):
|
||||
docs = [
|
||||
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
|
||||
{"_id": "agent-2", "key": "key-2", "user": "user-1", "name": "Agent Two"},
|
||||
]
|
||||
|
||||
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
|
||||
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
|
||||
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
|
||||
|
||||
app = _build_app()
|
||||
client = app.test_client()
|
||||
response = client.get("/v1/models", headers={"Authorization": "Bearer key-1"})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["object"] == "list"
|
||||
assert len(payload["data"]) == 2
|
||||
assert payload["data"][0]["id"] == "agent-1"
|
||||
assert payload["data"][1]["id"] == "agent-2"
|
||||
# Keys must never appear as model IDs
|
||||
assert all(model["id"] != "key-1" for model in payload["data"])
|
||||
assert all(model["id"] != "key-2" for model in payload["data"])
|
||||
|
||||
|
||||
def test_v1_models_invalid_key_returns_401(monkeypatch):
|
||||
docs = [
|
||||
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
|
||||
]
|
||||
|
||||
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
|
||||
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
|
||||
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
|
||||
|
||||
app = _build_app()
|
||||
client = app.test_client()
|
||||
response = client.get("/v1/models", headers={"Authorization": "Bearer wrong-key"})
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -20,133 +20,49 @@ def test_epub_init_parser():
|
||||
assert parser.parser_config_set
|
||||
|
||||
|
||||
def test_epub_parser_ebooklib_import_error(epub_parser):
|
||||
"""Test that ImportError is raised when ebooklib is not available."""
|
||||
with patch.dict(sys.modules, {"ebooklib": None}):
|
||||
with pytest.raises(ValueError, match="`EbookLib` is required to read Epub files"):
|
||||
def test_epub_parser_fast_ebook_import_error(epub_parser):
|
||||
"""Test that ImportError is raised when fast-ebook is not available."""
|
||||
with patch.dict(sys.modules, {"fast_ebook": None}):
|
||||
with pytest.raises(ValueError, match="`fast-ebook` is required to read Epub files"):
|
||||
epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
|
||||
def test_epub_parser_html2text_import_error(epub_parser):
|
||||
"""Test that ImportError is raised when html2text is not available."""
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
with patch.dict(sys.modules, {"ebooklib": fake_ebooklib, "ebooklib.epub": fake_epub}):
|
||||
with patch.dict(sys.modules, {"html2text": None}):
|
||||
with pytest.raises(ValueError, match="`html2text` is required to parse Epub files"):
|
||||
epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
|
||||
def test_epub_parser_successful_parsing(epub_parser):
|
||||
"""Test successful parsing of an epub file."""
|
||||
fake_fast_ebook = types.ModuleType("fast_ebook")
|
||||
fake_epub = types.ModuleType("fast_ebook.epub")
|
||||
fake_fast_ebook.epub = fake_epub
|
||||
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_html2text = types.ModuleType("html2text")
|
||||
|
||||
# Mock ebooklib constants
|
||||
fake_ebooklib.ITEM_DOCUMENT = "document"
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
mock_item1 = MagicMock()
|
||||
mock_item1.get_type.return_value = "document"
|
||||
mock_item1.get_content.return_value = b"<h1>Chapter 1</h1><p>Content 1</p>"
|
||||
|
||||
mock_item2 = MagicMock()
|
||||
mock_item2.get_type.return_value = "document"
|
||||
mock_item2.get_content.return_value = b"<h1>Chapter 2</h1><p>Content 2</p>"
|
||||
|
||||
mock_item3 = MagicMock()
|
||||
mock_item3.get_type.return_value = "other" # Should be ignored
|
||||
mock_item3.get_content.return_value = b"<p>Other content</p>"
|
||||
|
||||
mock_book = MagicMock()
|
||||
mock_book.get_items.return_value = [mock_item1, mock_item2, mock_item3]
|
||||
|
||||
mock_book.to_markdown.return_value = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
|
||||
|
||||
fake_epub.read_epub = MagicMock(return_value=mock_book)
|
||||
|
||||
def mock_html2text_func(html_content):
|
||||
if "Chapter 1" in html_content:
|
||||
return "# Chapter 1\n\nContent 1\n"
|
||||
elif "Chapter 2" in html_content:
|
||||
return "# Chapter 2\n\nContent 2\n"
|
||||
return "Other content\n"
|
||||
|
||||
fake_html2text.html2text = mock_html2text_func
|
||||
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"ebooklib": fake_ebooklib,
|
||||
"ebooklib.epub": fake_epub,
|
||||
"html2text": fake_html2text
|
||||
"fast_ebook": fake_fast_ebook,
|
||||
"fast_ebook.epub": fake_epub,
|
||||
}):
|
||||
result = epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
expected_result = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
|
||||
assert result == expected_result
|
||||
|
||||
# Verify epub.read_epub was called with correct parameters
|
||||
fake_epub.read_epub.assert_called_once_with(Path("test.epub"), options={"ignore_ncx": True})
|
||||
|
||||
assert result == "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
|
||||
fake_epub.read_epub.assert_called_once_with(Path("test.epub"))
|
||||
|
||||
|
||||
def test_epub_parser_empty_book(epub_parser):
|
||||
"""Test parsing an epub file with no document items."""
|
||||
# Create mock modules
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_html2text = types.ModuleType("html2text")
|
||||
|
||||
fake_ebooklib.ITEM_DOCUMENT = "document"
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
# Create mock book with no document items
|
||||
"""Test parsing an epub file with no content."""
|
||||
fake_fast_ebook = types.ModuleType("fast_ebook")
|
||||
fake_epub = types.ModuleType("fast_ebook.epub")
|
||||
fake_fast_ebook.epub = fake_epub
|
||||
|
||||
mock_book = MagicMock()
|
||||
mock_book.get_items.return_value = []
|
||||
|
||||
mock_book.to_markdown.return_value = ""
|
||||
|
||||
fake_epub.read_epub = MagicMock(return_value=mock_book)
|
||||
fake_html2text.html2text = MagicMock()
|
||||
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"ebooklib": fake_ebooklib,
|
||||
"ebooklib.epub": fake_epub,
|
||||
"html2text": fake_html2text
|
||||
"fast_ebook": fake_fast_ebook,
|
||||
"fast_ebook.epub": fake_epub,
|
||||
}):
|
||||
result = epub_parser.parse_file(Path("empty.epub"))
|
||||
|
||||
assert result == ""
|
||||
|
||||
fake_html2text.html2text.assert_not_called()
|
||||
|
||||
|
||||
def test_epub_parser_non_document_items_ignored(epub_parser):
|
||||
"""Test that non-document items are ignored during parsing."""
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_html2text = types.ModuleType("html2text")
|
||||
|
||||
fake_ebooklib.ITEM_DOCUMENT = "document"
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
mock_doc_item = MagicMock()
|
||||
mock_doc_item.get_type.return_value = "document"
|
||||
mock_doc_item.get_content.return_value = b"<p>Document content</p>"
|
||||
|
||||
mock_other_item = MagicMock()
|
||||
mock_other_item.get_type.return_value = "image" # Not a document
|
||||
|
||||
mock_book = MagicMock()
|
||||
mock_book.get_items.return_value = [mock_other_item, mock_doc_item]
|
||||
|
||||
fake_epub.read_epub = MagicMock(return_value=mock_book)
|
||||
fake_html2text.html2text = MagicMock(return_value="Document content\n")
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"ebooklib": fake_ebooklib,
|
||||
"ebooklib.epub": fake_epub,
|
||||
"html2text": fake_html2text
|
||||
}):
|
||||
result = epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
assert result == "Document content\n"
|
||||
|
||||
fake_html2text.html2text.assert_called_once_with("<p>Document content</p>")
|
||||
|
||||
@@ -8,7 +8,7 @@ from application.storage.local import LocalStorage
|
||||
|
||||
@pytest.fixture
|
||||
def temp_base_dir():
|
||||
return "/tmp/test_storage"
|
||||
return os.path.realpath("/tmp/test_storage")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -30,12 +30,12 @@ class TestLocalStorageInitialization:
|
||||
|
||||
def test_get_full_path_with_relative_path(self, local_storage):
|
||||
result = local_storage._get_full_path("documents/test.txt")
|
||||
expected = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
assert os.path.normpath(result) == os.path.normpath(expected)
|
||||
expected = os.path.realpath(os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt"))
|
||||
assert result == expected
|
||||
|
||||
def test_get_full_path_with_absolute_path(self, local_storage):
|
||||
result = local_storage._get_full_path("/absolute/path/test.txt")
|
||||
assert result == "/absolute/path/test.txt"
|
||||
def test_get_full_path_with_absolute_path_outside_base_raises(self, local_storage):
|
||||
with pytest.raises(ValueError, match="Path traversal detected"):
|
||||
local_storage._get_full_path("/absolute/path/test.txt")
|
||||
|
||||
@patch("os.makedirs")
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
@@ -48,8 +48,8 @@ class TestLocalStorageInitialization:
|
||||
|
||||
result = local_storage.save_file(file_data, path)
|
||||
|
||||
expected_dir = os.path.join("/tmp/test_storage", "documents")
|
||||
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
|
||||
expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
|
||||
assert mock_makedirs.call_count == 1
|
||||
assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath(
|
||||
@@ -74,25 +74,19 @@ class TestLocalStorageInitialization:
|
||||
|
||||
result = local_storage.save_file(file_data, path)
|
||||
|
||||
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert file_data.save.call_count == 1
|
||||
assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath(
|
||||
expected_file
|
||||
)
|
||||
assert result == {"storage_type": "local"}
|
||||
|
||||
@patch("os.makedirs")
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
def test_save_file_with_absolute_path(
|
||||
self, mock_file, mock_makedirs, local_storage
|
||||
):
|
||||
def test_save_file_with_absolute_path_outside_base_raises(self, local_storage):
|
||||
file_data = io.BytesIO(b"test content")
|
||||
path = "/absolute/path/test.txt"
|
||||
|
||||
local_storage.save_file(file_data, path)
|
||||
|
||||
mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True)
|
||||
mock_file.assert_called_once_with("/absolute/path/test.txt", "wb")
|
||||
with pytest.raises(ValueError, match="Path traversal detected"):
|
||||
local_storage.save_file(file_data, path)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -105,7 +99,7 @@ class TestLocalStorageGetFile:
|
||||
|
||||
result = local_storage.get_file(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
@@ -122,7 +116,7 @@ class TestLocalStorageGetFile:
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="File not found"):
|
||||
local_storage.get_file(path)
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
expected_path
|
||||
@@ -141,7 +135,7 @@ class TestLocalStorageDeleteFile:
|
||||
|
||||
result = local_storage.delete_file(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert result is True
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -158,7 +152,7 @@ class TestLocalStorageDeleteFile:
|
||||
|
||||
result = local_storage.delete_file(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
|
||||
assert result is False
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -175,7 +169,7 @@ class TestLocalStorageFileExists:
|
||||
|
||||
result = local_storage.file_exists(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert result is True
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -188,7 +182,7 @@ class TestLocalStorageFileExists:
|
||||
|
||||
result = local_storage.file_exists(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
|
||||
assert result is False
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -205,7 +199,7 @@ class TestLocalStorageListFiles:
|
||||
self, mock_exists, mock_walk, local_storage
|
||||
):
|
||||
directory = "documents"
|
||||
base_dir = os.path.join("/tmp/test_storage", "documents")
|
||||
base_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
|
||||
|
||||
mock_walk.return_value = [
|
||||
(base_dir, ["subdir"], ["file1.txt", "file2.txt"]),
|
||||
@@ -228,7 +222,7 @@ class TestLocalStorageListFiles:
|
||||
|
||||
result = local_storage.list_files(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
|
||||
assert result == []
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -248,7 +242,7 @@ class TestLocalStorageProcessFile:
|
||||
|
||||
result = local_storage.process_file(path, processor_func, extra_arg="value")
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert result == "processed"
|
||||
assert processor_func.call_count == 1
|
||||
call_kwargs = processor_func.call_args[1]
|
||||
@@ -280,7 +274,7 @@ class TestLocalStorageIsDirectory:
|
||||
|
||||
result = local_storage.is_directory(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
|
||||
assert result is True
|
||||
assert mock_isdir.call_count == 1
|
||||
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
|
||||
@@ -295,7 +289,7 @@ class TestLocalStorageIsDirectory:
|
||||
|
||||
result = local_storage.is_directory(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert result is False
|
||||
assert mock_isdir.call_count == 1
|
||||
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
|
||||
@@ -316,7 +310,7 @@ class TestLocalStorageRemoveDirectory:
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
|
||||
assert result is True
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -339,7 +333,7 @@ class TestLocalStorageRemoveDirectory:
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
|
||||
assert result is False
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -355,7 +349,7 @@ class TestLocalStorageRemoveDirectory:
|
||||
|
||||
result = local_storage.remove_directory(path)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
|
||||
assert result is False
|
||||
assert mock_exists.call_count == 1
|
||||
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
|
||||
@@ -376,7 +370,7 @@ class TestLocalStorageRemoveDirectory:
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
|
||||
assert result is False
|
||||
assert mock_rmtree.call_count == 1
|
||||
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
|
||||
@@ -393,7 +387,7 @@ class TestLocalStorageRemoveDirectory:
|
||||
|
||||
result = local_storage.remove_directory(directory)
|
||||
|
||||
expected_path = os.path.join("/tmp/test_storage", "documents")
|
||||
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
|
||||
assert result is False
|
||||
assert mock_rmtree.call_count == 1
|
||||
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
|
||||
|
||||
@@ -2105,23 +2105,35 @@ class TestInternalRoutes:
|
||||
app.register_blueprint(internal)
|
||||
return app
|
||||
|
||||
_TEST_KEY = "test-key"
|
||||
_AUTH_HEADERS = {"X-Internal-Key": "test-key"}
|
||||
|
||||
def test_upload_index_no_user(self, app):
|
||||
with app.test_client() as client:
|
||||
with patch(
|
||||
"application.api.internal.routes.settings"
|
||||
) as ms:
|
||||
ms.INTERNAL_KEY = None
|
||||
resp = client.post("/api/upload_index")
|
||||
ms.INTERNAL_KEY = self._TEST_KEY
|
||||
resp = client.post("/api/upload_index", headers=self._AUTH_HEADERS)
|
||||
assert resp.get_json()["status"] == "no user"
|
||||
|
||||
def test_upload_index_no_name(self, app):
|
||||
with app.test_client() as client:
|
||||
with patch(
|
||||
"application.api.internal.routes.settings"
|
||||
) as ms:
|
||||
ms.INTERNAL_KEY = self._TEST_KEY
|
||||
resp = client.post("/api/upload_index", data={"user": "u1"}, headers=self._AUTH_HEADERS)
|
||||
assert resp.get_json()["status"] == "no name"
|
||||
|
||||
def test_upload_index_rejected_without_internal_key(self, app):
|
||||
with app.test_client() as client:
|
||||
with patch(
|
||||
"application.api.internal.routes.settings"
|
||||
) as ms:
|
||||
ms.INTERNAL_KEY = None
|
||||
resp = client.post("/api/upload_index", data={"user": "u1"})
|
||||
assert resp.get_json()["status"] == "no name"
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from application.api.v1.translator import (
|
||||
_get_client_tool_name,
|
||||
convert_history,
|
||||
extract_system_prompt,
|
||||
extract_tool_results,
|
||||
is_continuation,
|
||||
translate_request,
|
||||
@@ -148,6 +149,48 @@ class TestConvertHistory:
|
||||
assert history == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_system_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExtractSystemPrompt:
|
||||
|
||||
def test_extracts_first_system_message(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a pirate"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
assert extract_system_prompt(messages) == "You are a pirate"
|
||||
|
||||
def test_returns_none_when_no_system_message(self):
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
assert extract_system_prompt(messages) is None
|
||||
|
||||
def test_returns_first_of_multiple_system_messages(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "First"},
|
||||
{"role": "system", "content": "Second"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
assert extract_system_prompt(messages) == "First"
|
||||
|
||||
def test_empty_content_returns_empty_string(self):
|
||||
messages = [
|
||||
{"role": "system", "content": ""},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
assert extract_system_prompt(messages) == ""
|
||||
|
||||
def test_missing_content_returns_empty_string(self):
|
||||
messages = [
|
||||
{"role": "system"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
assert extract_system_prompt(messages) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# translate_request
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -167,11 +210,25 @@ class TestTranslateRequest:
|
||||
result = translate_request(data, "test-key")
|
||||
assert result["question"] == "What's 2+2?"
|
||||
assert result["api_key"] == "test-key"
|
||||
assert result["save_conversation"] is True
|
||||
# Conversations are not persisted by default on the v1 endpoint.
|
||||
assert result["save_conversation"] is False
|
||||
history = json.loads(result["history"])
|
||||
assert len(history) == 1
|
||||
assert history[0]["prompt"] == "Hello"
|
||||
|
||||
def test_save_conversation_opt_in_via_docsgpt_extension(self):
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"docsgpt": {"save_conversation": True},
|
||||
}
|
||||
result = translate_request(data, "key")
|
||||
assert result["save_conversation"] is True
|
||||
|
||||
def test_save_conversation_default_false(self):
|
||||
data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
result = translate_request(data, "key")
|
||||
assert result["save_conversation"] is False
|
||||
|
||||
def test_continuation_request(self):
|
||||
data = {
|
||||
"messages": [
|
||||
@@ -237,6 +294,23 @@ class TestTranslateRequest:
|
||||
result = translate_request(data, "key")
|
||||
assert result["attachments"] == ["att1", "att2"]
|
||||
|
||||
def test_system_prompt_override_included_when_present(self):
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "Custom prompt"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
],
|
||||
}
|
||||
result = translate_request(data, "key")
|
||||
assert result["system_prompt_override"] == "Custom prompt"
|
||||
|
||||
def test_system_prompt_override_absent_when_no_system_message(self):
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
result = translate_request(data, "key")
|
||||
assert "system_prompt_override" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# translate_response
|
||||
|
||||
@@ -363,6 +363,28 @@ class TestGetVectorstore:
|
||||
|
||||
assert get_vectorstore("user/source123") == "indexes/user/source123"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"malicious_path",
|
||||
[
|
||||
"../outside",
|
||||
"../../etc/passwd",
|
||||
"nested/../../../outside",
|
||||
"/tmp/evil",
|
||||
"..\\outside",
|
||||
"valid/../../escape",
|
||||
],
|
||||
)
|
||||
def test_rejects_path_traversal(self, malicious_path):
|
||||
from application.vectorstore.faiss import get_vectorstore
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid source_id path"):
|
||||
get_vectorstore(malicious_path)
|
||||
|
||||
def test_allows_mongodb_style_ids(self):
|
||||
from application.vectorstore.faiss import get_vectorstore
|
||||
|
||||
assert get_vectorstore("65e8f6a8a7a96b1bdad4154f") == "indexes/65e8f6a8a7a96b1bdad4154f"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFaissStoreAddChunk:
|
||||
|
||||
Reference in New Issue
Block a user