Compare commits

..

35 Commits

Author SHA1 Message Date
Alex
59d9d4ac50 fix: comments in settings 2026-04-12 10:42:46 +01:00
Alex
55717043f6 fix: vale 2026-04-12 00:29:23 +01:00
Alex
ececcb8b17 feat: init pg migration 2026-04-12 00:07:24 +01:00
Alex
bd03a513e3 Merge pull request #2372 from arc53/fast-ebook
feat: faster ebook parsing
2026-04-09 18:38:13 +01:00
Alex
fcdb4fb5e8 feat: faster ebook parsing 2026-04-09 18:31:06 +01:00
Alex
e787c896eb upd Security.md 2026-04-08 12:49:20 +01:00
Alex
23aeaff5db Merge pull request #2362 from arc53/v1-mini-improvements
feat: history overwrite
2026-04-06 15:02:32 +01:00
Alex
689dd79597 fix: lang 2026-04-06 14:57:51 +01:00
Alex
0c15af90b1 feat: history overwrite 2026-04-06 14:42:01 +01:00
Alex
cdd6ff6557 chore: bump deps 2026-04-04 12:45:34 +01:00
Alex
72b3d94453 fix: tests 2026-04-03 18:30:46 +01:00
Alex
7e88d09e5d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 18:26:37 +01:00
Alex
74a4a237dc fix: bump deps 2026-04-03 18:26:29 +01:00
Alex
c3f01c6619 Merge pull request #2347 from ManishMadan2882/main
Minor frontend updates
2026-04-03 18:17:27 +01:00
Alex
6b408823d4 fix: mini theme color edits 2026-04-03 18:16:07 +01:00
Alex
3fc81ac5d8 fix: clean error 2026-04-03 18:08:38 +01:00
Alex
2652f8a5b0 fix: chatwoot 2026-04-03 18:04:49 +01:00
Alex
d711eefe96 patch: agent usage limits 2026-04-03 18:03:31 +01:00
Alex
79206f3919 fix: harden faiss 2026-04-03 17:57:49 +01:00
Alex
de971d9452 fix: validate mcp url 2026-04-03 17:52:48 +01:00
Alex
1b4d5ca0dd patch: mcp identity 2026-04-03 17:40:22 +01:00
Alex
81989e8258 fix: patch /v1/models 2026-04-03 17:37:09 +01:00
Alex
dc262d1698 patch: error 2026-04-03 17:30:23 +01:00
Alex
69f9c93869 patch: s3 2026-04-03 17:28:09 +01:00
Alex
74bf80b25c patch: sharing convos 2026-04-03 17:20:06 +01:00
Alex
d9a92a7208 feat: improve setup scripts 2026-04-03 17:15:21 +01:00
Alex
02e93d993d patch: available tools 2026-04-03 17:12:36 +01:00
Alex
6b6495f48c patch: key 2026-04-03 17:06:35 +01:00
Alex
249dd9ce37 patch: paths 2026-04-03 16:45:03 +01:00
Alex
9134ab0478 Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 16:40:50 +01:00
Alex
10ef68c9d0 Revise vulnerability reporting process
Updated vulnerability reporting instructions to use GitHub's private reporting flow.
2026-04-03 16:36:10 +01:00
Alex
7d65cf1c2b chore: bump deps 2026-04-03 16:35:10 +01:00
Alex
13c6cc59c1 Merge pull request #2349 from arc53/messages-format
Messages format
2026-04-03 16:26:57 +01:00
ManishMadan2882
648b3f1d20 (fix) lint/fe 2026-04-01 03:30:44 +05:30
ManishMadan2882
a75a9e23f9 (feat:fe) minor good things 2026-04-01 03:19:03 +05:30
85 changed files with 3369 additions and 346 deletions

View File

@@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt

View File

@@ -44,3 +44,8 @@ boolean
bool
hardcode
EOL
Postgres
Supabase
config
backfill
backfills

2
.gitignore vendored
View File

@@ -108,6 +108,8 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/

View File

@@ -1,5 +1,7 @@
MinAlertLevel = warning
StylesPath = .github/styles
Vocab = DocsGPT
[*.{md,mdx}]
BasedOnStyles = DocsGPT

View File

@@ -2,13 +2,21 @@
## Supported Versions
Supported Versions:
Currently, we support security patches by committing changes and bumping the version published on GitHub.
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
## Reporting a Vulnerability
Found a vulnerability? Please email us:
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security
security@arc53.com
Then click **Report a vulnerability**.
Alternatively, email us at: security@arc53.com
We aim to acknowledge reports within 48 hours.
## Incident Handling
Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.

View File

@@ -24,6 +24,7 @@ from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -61,7 +62,8 @@ class MCPTool(Tool):
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
raw_url = config.get("server_url", "")
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
@@ -87,6 +89,18 @@ class MCPTool(Tool):
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@staticmethod
def _validate_server_url(server_url: str) -> str:
    """Validate server_url to prevent SSRF to internal networks.

    Returns:
        The URL as returned by ``validate_url`` when it is safe to use.

    Raises:
        ValueError: If the URL points to a private/internal address.
    """
    try:
        checked = validate_url(server_url)
    except SSRFError as exc:
        # Surface the SSRF failure as a plain ValueError so callers do not
        # need to know about the url_validation module's exception type.
        raise ValueError(f"Invalid MCP server URL: {exc}") from exc
    return checked
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
@@ -108,8 +122,9 @@ class MCPTool(Tool):
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
auth_key = (
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(

View File

@@ -1,6 +1,6 @@
import logging
import psycopg2
import psycopg
from application.agents.tools.base import Tool
@@ -33,7 +33,7 @@ class PostgresTool(Tool):
"""
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
@@ -60,7 +60,7 @@ class PostgresTool(Tool):
"response_data": response_data,
}
except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL execute_sql error: %s", e)
return {
@@ -78,7 +78,7 @@ class PostgresTool(Tool):
"""
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute(
@@ -120,7 +120,7 @@ class PostgresTool(Tool):
"schema": schema_data,
}
except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL get_schema error: %s", e)
return {

52
application/alembic.ini Normal file
View File

@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
# alembic -c application/alembic.ini upgrade head
[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os
# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,82 @@
"""Alembic environment for the DocsGPT user-data Postgres database.
The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""
import sys
from logging.config import fileConfig
from pathlib import Path
# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(_PROJECT_ROOT))
from alembic import context # noqa: E402
from sqlalchemy import engine_from_config, pool # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.models import metadata as target_metadata # noqa: E402
config = context.config
# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
if config.config_file_name is not None:
fileConfig(config.config_file_name)
def _require_url() -> str:
    """Return the configured ``sqlalchemy.url`` or raise with setup guidance.

    Both migration modes need the URL to be present; centralizing the check
    keeps the error message identical for offline and online runs (it was
    previously duplicated verbatim in both functions).

    Raises:
        RuntimeError: If ``POSTGRES_URI`` was never propagated into the
            Alembic config (i.e. the env var is unset).
    """
    url = config.get_main_option("sqlalchemy.url")
    if not url:
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    return url


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode (emits SQL without a live DB)."""
    url = _require_url()
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True,
    )
    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode against a live connection."""
    # Fail fast with the same guidance as offline mode before building an
    # engine from an empty URL.
    _require_url()
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
        future=True,
    )
    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,
        )
        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,462 @@
"""0001 initial schema — user-level tables migrated from MongoDB.
Creates every table described in §2.2 of ``migration-postgres.md``: tiers 1,
2, and 3 in one shot. The schema is small enough that splitting the baseline
across multiple revisions would only cost clarity.
Subsequent migrations will add columns / tables incrementally. This file is
hand-written raw DDL rather than Core ``op.create_table`` calls because the
DDL uses several Postgres-specific features (``CITEXT``, partial indexes,
``text_pattern_ops``, JSONB defaults) that are clearer in SQL than in
Alembic's Python API.
Revision ID: 0001_initial
Revises:
Create Date: 2026-04-10
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0001_initial"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the full Phase-0 user-data schema (tiers 1-3) in one pass."""
    # ------------------------------------------------------------------
    # Extensions
    # ------------------------------------------------------------------
    # pgcrypto supplies gen_random_uuid(); citext backs the CITEXT columns
    # used below for case-insensitive unique keys on agents.
    op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
    op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')

    # ------------------------------------------------------------------
    # Tier 1: leaf tables, no FKs into other migrated tables
    # ------------------------------------------------------------------
    op.execute("""
        CREATE TABLE users (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL UNIQUE,
            agent_preferences JSONB NOT NULL
                DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    # No extra index on users.user_id: the UNIQUE constraint above already
    # creates an implicit unique index, so a separate users_user_id_idx
    # would only add write overhead without speeding any lookup.

    op.execute("""
        CREATE TABLE prompts (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")

    op.execute("""
        CREATE TABLE user_tools (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            custom_name TEXT,
            display_name TEXT,
            config JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")

    op.execute("""
        CREATE TABLE token_usage (
            id BIGSERIAL PRIMARY KEY,
            user_id TEXT,
            api_key TEXT,
            agent_id UUID,  -- FK added later in this migration
            prompt_tokens INTEGER NOT NULL DEFAULT 0,
            generated_tokens INTEGER NOT NULL DEFAULT 0,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, timestamp DESC);")
    op.execute("CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, timestamp DESC);")
    op.execute("CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, timestamp DESC);")

    op.execute("""
        CREATE TABLE user_logs (
            id BIGSERIAL PRIMARY KEY,
            user_id TEXT,
            endpoint TEXT,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
            data JSONB
        );
    """)
    op.execute("CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, timestamp DESC);")

    op.execute("""
        CREATE TABLE feedback (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL,  -- FK added later in this migration
            user_id TEXT NOT NULL,
            question_index INTEGER NOT NULL,
            feedback_text TEXT,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX feedback_conv_idx ON feedback (conversation_id);")

    # Append-only debug/error log. The Mongo doc has both `_id` (auto) and an
    # `id` field (the activity id). Here the serial PK owns `id`; the
    # application-level identifier is renamed to `activity_id`.
    op.execute("""
        CREATE TABLE stack_logs (
            id BIGSERIAL PRIMARY KEY,
            activity_id TEXT NOT NULL,
            endpoint TEXT,
            level TEXT,
            user_id TEXT,
            api_key TEXT,
            query TEXT,
            stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX stack_logs_timestamp_idx ON stack_logs (timestamp DESC);")
    op.execute("CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, timestamp DESC);")
    op.execute("CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, timestamp DESC);")
    op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")

    # ------------------------------------------------------------------
    # Tier 2: FK-bearing tables
    # ------------------------------------------------------------------
    op.execute("""
        CREATE TABLE agent_folders (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")

    op.execute("""
        CREATE TABLE sources (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT,  -- NULL for system/template sources
            name TEXT NOT NULL,
            type TEXT,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")

    op.execute("""
        CREATE TABLE agents (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            agent_type TEXT,
            status TEXT NOT NULL,
            key CITEXT UNIQUE,
            source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
            extra_source_ids UUID[] NOT NULL DEFAULT '{}',
            chunks INTEGER,
            retriever TEXT,
            prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
            tools JSONB NOT NULL DEFAULT '[]'::jsonb,
            json_schema JSONB,
            models JSONB,
            default_model_id TEXT,
            folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
            limited_token_mode BOOLEAN NOT NULL DEFAULT false,
            token_limit INTEGER,
            limited_request_mode BOOLEAN NOT NULL DEFAULT false,
            request_limit INTEGER,
            shared BOOLEAN NOT NULL DEFAULT false,
            incoming_webhook_token CITEXT UNIQUE,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            last_used_at TIMESTAMPTZ
        );
    """)
    op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
    op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
    op.execute("CREATE INDEX agents_status_idx ON agents (status);")

    # Backfill the token_usage.agent_id FK now that agents exists.
    op.execute("""
        ALTER TABLE token_usage
            ADD CONSTRAINT token_usage_agent_fk
            FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;
    """)

    op.execute("""
        CREATE TABLE attachments (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            filename TEXT NOT NULL,
            upload_path TEXT NOT NULL,
            mime_type TEXT,
            size BIGINT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")

    op.execute("""
        CREATE TABLE memories (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            path TEXT NOT NULL,
            content TEXT NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("""
        CREATE UNIQUE INDEX memories_user_tool_path_uidx
            ON memories (user_id, tool_id, path);
    """)
    # text_pattern_ops enables index-assisted `path LIKE 'prefix%'` scans.
    op.execute("""
        CREATE INDEX memories_path_prefix_idx
            ON memories (user_id, tool_id, path text_pattern_ops);
    """)

    op.execute("""
        CREATE TABLE todos (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            title TEXT NOT NULL,
            completed BOOLEAN NOT NULL DEFAULT false,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")

    op.execute("""
        CREATE TABLE notes (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            title TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX notes_user_tool_idx ON notes (user_id, tool_id);")

    op.execute("""
        CREATE TABLE connector_sessions (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            provider TEXT NOT NULL,
            session_data JSONB NOT NULL,
            expires_at TIMESTAMPTZ,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("""
        CREATE INDEX connector_sessions_user_provider_idx
            ON connector_sessions (user_id, provider);
    """)
    op.execute("""
        CREATE INDEX connector_sessions_expiry_idx
            ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;
    """)

    # ------------------------------------------------------------------
    # Tier 3: conversations, pending_tool_state, workflows
    # ------------------------------------------------------------------
    op.execute("""
        CREATE TABLE conversations (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
            name TEXT,
            api_key TEXT,
            is_shared_usage BOOLEAN NOT NULL DEFAULT false,
            shared_token TEXT,
            date TIMESTAMPTZ NOT NULL DEFAULT now(),
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
    op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
    op.execute("""
        CREATE INDEX conversations_shared_token_idx
            ON conversations (shared_token) WHERE shared_token IS NOT NULL;
    """)

    op.execute("""
        CREATE TABLE conversation_messages (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            position INTEGER NOT NULL,
            prompt TEXT,
            response TEXT,
            thought TEXT,
            sources JSONB NOT NULL DEFAULT '[]'::jsonb,
            tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
            attachments UUID[] NOT NULL DEFAULT '{}',
            model_id TEXT,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            feedback JSONB,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("""
        CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx
            ON conversation_messages (conversation_id, position);
    """)

    # Backfill the feedback.conversation_id FK now that conversations exists.
    op.execute("""
        ALTER TABLE feedback
            ADD CONSTRAINT feedback_conv_fk
            FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE;
    """)

    op.execute("""
        CREATE TABLE shared_conversations (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
            chunks INTEGER,
            is_promptable BOOLEAN NOT NULL DEFAULT false,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
    op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")

    # Paused-tool continuation state. The Mongo version relies on a TTL index;
    # Postgres has no native TTL, so a Celery beat task (added in Phase 3)
    # deletes rows where expires_at < now() once a minute. The unique
    # constraint on (conversation_id, user_id) matches the existing upsert
    # semantics.
    op.execute("""
        CREATE TABLE pending_tool_state (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            messages JSONB NOT NULL,
            pending_tool_calls JSONB NOT NULL,
            tools_dict JSONB NOT NULL,
            tool_schemas JSONB NOT NULL,
            agent_config JSONB NOT NULL,
            client_tools JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            expires_at TIMESTAMPTZ NOT NULL
        );
    """)
    op.execute("""
        CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx
            ON pending_tool_state (conversation_id, user_id);
    """)
    op.execute("""
        CREATE INDEX pending_tool_state_expires_idx
            ON pending_tool_state (expires_at);
    """)

    # Workflows
    op.execute("""
        CREATE TABLE workflows (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")

    op.execute("""
        CREATE TABLE workflow_nodes (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            graph_version INTEGER NOT NULL,
            node_type TEXT NOT NULL,
            config JSONB NOT NULL DEFAULT '{}'::jsonb
        );
    """)
    op.execute("""
        CREATE INDEX workflow_nodes_workflow_version_idx
            ON workflow_nodes (workflow_id, graph_version);
    """)

    op.execute("""
        CREATE TABLE workflow_edges (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            graph_version INTEGER NOT NULL,
            from_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
            to_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
            config JSONB NOT NULL DEFAULT '{}'::jsonb
        );
    """)
    op.execute("""
        CREATE INDEX workflow_edges_workflow_version_idx
            ON workflow_edges (workflow_id, graph_version);
    """)

    op.execute("""
        CREATE TABLE workflow_runs (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            status TEXT NOT NULL,
            started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            ended_at TIMESTAMPTZ,
            result JSONB
        );
    """)
    op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
    op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
def downgrade() -> None:
    """Drop every table created by :func:`upgrade`, children before parents."""
    # Reverse dependency order. CASCADE would handle the FKs regardless, but
    # spelling the order out keeps the migration readable.
    drop_order = (
        "workflow_runs",
        "workflow_edges",
        "workflow_nodes",
        "workflows",
        "pending_tool_state",
        "shared_conversations",
        "conversation_messages",
        "conversations",
        "connector_sessions",
        "notes",
        "todos",
        "memories",
        "attachments",
        "agents",
        "sources",
        "agent_folders",
        "stack_logs",
        "feedback",
        "user_logs",
        "token_usage",
        "user_tools",
        "prompts",
        "users",
    )
    for table_name in drop_order:
        op.execute(f"DROP TABLE IF EXISTS {table_name} CASCADE;")
    # The pgcrypto/citext extensions are deliberately left installed — other
    # databases or already-enabled extensions (e.g. pgvector) on the same
    # cluster may rely on them.

View File

@@ -85,6 +85,10 @@ class AnswerResource(Resource, BaseAnswerResource):
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question="",
agent=agent,

View File

@@ -92,6 +92,14 @@ class StreamResource(Resource, BaseAnswerResource):
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question="",

View File

@@ -112,6 +112,7 @@ class StreamProcessor:
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
self._agent_data: Optional[Dict[str, Any]] = None
def initialize(self):
"""Initialize all required components for processing"""
@@ -359,22 +360,29 @@ class StreamProcessor:
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
"""Configure the source based on agent data.
if api_key:
agent_data = self._get_data_from_api_key(api_key)
The literal string ``"default"`` is a placeholder meaning "no
ingested source" and is normalized to an empty source so that no
retrieval is attempted.
"""
if self._agent_data:
agent_data = self._agent_data
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
source["id"]
for source in agent_data["sources"]
if source.get("id") and source["id"] != "default"
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.all_sources = [
s for s in agent_data["sources"] if s.get("id") != "default"
]
elif agent_data.get("source") and agent_data["source"] != "default":
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
@@ -387,11 +395,24 @@ class StreamProcessor:
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
active_docs = self.data["active_docs"]
if active_docs and active_docs != "default":
self.source = {"active_docs": active_docs}
else:
self.source = {}
return
self.source = {}
self.all_sources = []
def _has_active_docs(self) -> bool:
    """Return True if a real document source is configured for retrieval."""
    docs = self.source.get("active_docs") if self.source else None
    # Empty/None and the literal placeholder "default" both mean "no source".
    return bool(docs) and docs != "default"
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
@@ -433,48 +454,39 @@ class StreamProcessor:
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
data_key = self._get_data_from_api_key(effective_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self._agent_data = self._get_data_from_api_key(effective_key)
if self._agent_data.get("_id"):
self.agent_id = str(self._agent_data.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"prompt_id": self._agent_data.get("prompt_id", "default"),
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
"user_api_key": effective_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
"models": data_key.get("models", []),
"json_schema": self._agent_data.get("json_schema"),
"default_model_id": self._agent_data.get("default_model_id", ""),
"models": self._agent_data.get("models", []),
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
}
)
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
self.initial_user_id = self._agent_data.get("user")
self.decoded_token = {"sub": self._agent_data.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": data_key.get("user")}
self.decoded_token = {"sub": self._agent_data.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
if self._agent_data.get("workflow"):
self.agent_config["workflow"] = self._agent_data["workflow"]
self.agent_config["workflow_owner"] = self._agent_data.get("user")
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
@@ -497,14 +509,45 @@ class StreamProcessor:
)
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# Start with defaults
retriever_name = "classic"
chunks = 2
# Layer agent-level config (if present)
if self._agent_data:
if self._agent_data.get("retriever"):
retriever_name = self._agent_data["retriever"]
if self._agent_data.get("chunks") is not None:
try:
chunks = int(self._agent_data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
"using default value 2"
)
# Explicit request values win over agent config
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
}
# isNoneDoc without an API key forces no retrieval
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
@@ -528,6 +571,9 @@ class StreamProcessor:
if self.data.get("isNoneDoc", False) and not self.agent_id:
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
if not self._has_active_docs():
logger.info("Pre-fetch skipped: no active docs configured")
return None, None
try:
retriever = self.create_retriever()
logger.info(
@@ -910,15 +956,23 @@ class StreamProcessor:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
# Allow API callers to override the system prompt when the agent
# has opted in via allow_system_prompt_override.
if (
self.agent_config.get("allow_system_prompt_override", False)
and self.data.get("system_prompt_override")
):
rendered_prompt = self.data["system_prompt_override"]
else:
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)

View File

@@ -26,12 +26,20 @@ internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
"""Verify INTERNAL_KEY for all internal endpoint requests.
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
"""
if not settings.INTERNAL_KEY:
logger.warning(
f"Internal API request rejected from {request.remote_addr}: "
"INTERNAL_KEY is not configured"
)
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])

View File

@@ -23,6 +23,8 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
@@ -73,6 +75,7 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -96,6 +99,7 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -220,6 +224,12 @@ def build_agent_document(
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
if "allow_system_prompt_override" in allowed_fields:
base_doc["allow_system_prompt_override"] = (
data.get("allow_system_prompt_override") == "True"
if isinstance(data.get("allow_system_prompt_override"), str)
else bool(data.get("allow_system_prompt_override", False))
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@@ -292,6 +302,9 @@ class GetAgent(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -373,6 +386,9 @@ class GetAgents(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
for agent in agents
if "source" in agent
@@ -450,6 +466,10 @@ class CreateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -491,9 +511,9 @@ class CreateAgent(Resource):
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
except JsonSchemaValidationError:
return make_response(
jsonify({"success": False, "message": f"JSON schema {exc}"}),
jsonify({"success": False, "message": "Invalid JSON schema"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -674,6 +694,10 @@ class UpdateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -765,6 +789,7 @@ class UpdateAgent(Resource):
"default_model_id",
"folder_id",
"workflow",
"allow_system_prompt_override",
]
for field in allowed_fields:
@@ -872,9 +897,9 @@ class UpdateAgent(Resource):
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
except JsonSchemaValidationError:
return make_response(
jsonify({"success": False, "message": f"JSON schema {exc}"}),
jsonify({"success": False, "message": "Invalid JSON schema"}),
400,
)
else:
@@ -983,6 +1008,13 @@ class UpdateAgent(Resource):
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
elif field == "allow_system_prompt_override":
raw_value = data.get("allow_system_prompt_override", False)
update_fields[field] = (
raw_value == "True"
if isinstance(raw_value, str)
else bool(raw_value)
)
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -1220,6 +1252,9 @@ class PinnedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_pinned_bulk(uid, ids)
)
list_pinned_agents = [
{
"id": str(agent["_id"]),
@@ -1351,12 +1386,18 @@ class PinAgent(Resource):
{"user_id": user_id},
{"$pull": {"agent_preferences.pinned": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.remove_pinned(uid, aid)
)
action = "unpinned"
else:
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.pinned": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_pinned(uid, aid)
)
action = "pinned"
except Exception as err:
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
@@ -1402,6 +1443,9 @@ class RemoveSharedAgent(Resource):
}
},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.remove_agent_from_all(uid, aid)
)
return make_response(jsonify({"success": True, "action": "removed"}), 200)
except Exception as err:

View File

@@ -18,6 +18,8 @@ from application.api.user.base import (
user_tools_collection,
users_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
@@ -105,6 +107,9 @@ class SharedAgent(Resource):
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
@@ -139,6 +144,9 @@ class SharedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = [

View File

@@ -612,6 +612,10 @@ class LiveSpeechToTextFinish(Resource):
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
try:
from application.api.user.base import storage
@@ -629,6 +633,10 @@ class ServeImage(Resource):
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(

View File

@@ -15,6 +15,8 @@ from werkzeug.utils import secure_filename
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
@@ -132,6 +134,9 @@ def ensure_user_doc(user_id):
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
dual_write(UsersRepository, lambda repo: repo.upsert(user_id))
return user_doc

View File

@@ -57,7 +57,7 @@ class ShareConversation(Resource):
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
{"_id": ObjectId(conversation_id), "user": user}
)
if conversation is None:
return make_response(

View File

@@ -463,6 +463,16 @@ class ManageSourceFiles(Resource):
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid file path",
}
),
400,
)
full_path = f"{source_file_path}/{file_path}"
# Remove from storage

View File

@@ -14,6 +14,7 @@ from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields
@@ -63,6 +64,21 @@ def _extract_auth_credentials(config):
return auth_credentials
def _validate_mcp_server_url(config: dict) -> None:
    """Reject MCP configs whose server_url is absent or SSRF-unsafe.

    Raises:
        ValueError: If the URL is missing or points to a blocked address.
    """
    raw = config.get("server_url") or ""
    candidate = raw.strip()
    if not candidate:
        raise ValueError("server_url is required")
    try:
        validate_url(candidate)
    except SSRFError as exc:
        raise ValueError(f"Invalid server URL: {exc}") from exc
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@api.expect(
@@ -97,6 +113,8 @@ class TestMCPServerConfig(Resource):
400,
)
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
@@ -105,15 +123,41 @@ class TestMCPServerConfig(Resource):
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
return make_response(jsonify(result), 200)
safe_result = {
k: v
for k, v in result.items()
if k in ("success", "requires_oauth", "auth_url")
}
return make_response(jsonify(safe_result), 200)
if not result.get("success") and "message" in result:
if not result.get("success"):
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
result["message"] = "Connection test failed"
return make_response(
jsonify(
{
"success": False,
"message": "Connection test failed",
"tools_count": 0,
}
),
200,
)
return make_response(jsonify(result), 200)
safe_result = {
"success": True,
"message": result.get("message", "Connection successful"),
"tools_count": result.get("tools_count", 0),
"tools": result.get("tools", []),
}
return make_response(jsonify(safe_result), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server test request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
@@ -165,6 +209,8 @@ class MCPServerSave(Resource):
400,
)
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
auth_type = config.get("auth_type", "none")
mcp_config = config.copy()
@@ -279,6 +325,12 @@ class MCPServerSave(Resource):
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server save request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(

View File

@@ -8,6 +8,7 @@ from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
@@ -130,6 +131,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
if not request.decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
@@ -236,6 +239,16 @@ class CreateTool(Resource):
if missing_fields:
return missing_fields
try:
if data["name"] == "mcp_tool":
server_url = (data.get("config", {}).get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
@@ -421,6 +434,16 @@ class UpdateToolConfig(Resource):
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}

View File

@@ -138,10 +138,18 @@ def chat_completions():
if usage_error:
return usage_error
should_save_conversation = bool(internal_data.get("save_conversation", False))
if is_stream:
return Response(
_stream_response(
helper, question, agent, processor, model_name, continuation
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
),
mimetype="text/event-stream",
headers={
@@ -151,7 +159,13 @@ def chat_completions():
)
else:
return _non_stream_response(
helper, question, agent, processor, model_name, continuation
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
)
except ValueError as e:
@@ -181,6 +195,7 @@ def _stream_response(
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Generator[str, None, None]:
"""Generate translated SSE chunks for streaming response."""
completion_id = f"chatcmpl-{int(time.time())}"
@@ -193,6 +208,7 @@ def _stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -225,6 +241,7 @@ def _non_stream_response(
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Response:
"""Collect full response and return as single JSON."""
stream = helper.complete_stream(
@@ -235,6 +252,7 @@ def _non_stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -293,8 +311,9 @@ def list_models():
for ag in user_agents:
created = ag.get("createdAt")
created_ts = int(created.timestamp()) if created else int(time.time())
model_id = str(ag.get("_id") or ag.get("id") or "")
models.append({
"id": str(ag.get("key", "")),
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",

View File

@@ -80,6 +80,17 @@ def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
return None
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
    """Return the content of the first system-role message, if any.

    Scans ``messages`` in order and returns the ``content`` of the first
    entry whose ``role`` is ``"system"`` (an empty string when that entry
    carries no ``content`` key). Returns ``None`` when no system message
    is present.
    """
    return next(
        (m.get("content", "") for m in messages if m.get("role") == "system"),
        None,
    )
def convert_history(messages: List[Dict]) -> List[Dict]:
"""Convert chat completions messages array to DocsGPT history format.
@@ -148,20 +159,27 @@ def translate_request(
break
history = convert_history(messages)
system_prompt_override = extract_system_prompt(messages)
docsgpt = data.get("docsgpt", {})
result = {
"question": question,
"api_key": api_key,
"history": json.dumps(history),
"save_conversation": True,
# Conversations are NOT persisted by default on the v1 endpoint.
# Callers opt in via ``docsgpt.save_conversation: true``.
"save_conversation": bool(docsgpt.get("save_conversation", False)),
}
if system_prompt_override is not None:
result["system_prompt_override"] = system_prompt_override
# Client tools
if data.get("tools"):
result["client_tools"] = data["tools"]
# DocsGPT extensions
docsgpt = data.get("docsgpt", {})
if docsgpt.get("attachments"):
result["attachments"] = docsgpt["attachments"]

View File

@@ -1,6 +1,6 @@
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
from celery.signals import setup_logging, worker_process_init
def make_celery(app_name=__name__):
@@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs):
setup_logging()
@worker_process_init.connect
def _dispose_db_engine_on_fork(*args, **kwargs):
    """Give each forked Celery worker a fresh SQLAlchemy engine pool.

    SQLAlchemy connection pools are not fork-safe: file descriptors shared
    between the parent and a forked worker corrupt the pool, so the pool
    is disposed on ``worker_process_init`` and rebuilt lazily on first use
    inside the worker.

    The engine module is imported lazily so Celery workers that never
    touch Postgres (or where ``POSTGRES_URI`` is unset) don't fail at
    startup.
    """
    try:
        from application.storage.db.engine import dispose_engine as _dispose
    except Exception:
        # Postgres layer unavailable in this worker — nothing to dispose.
        return
    _dispose()
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -0,0 +1,89 @@
"""Normalize user-supplied Postgres URIs for different drivers.
DocsGPT has two Postgres connection strings pointing at potentially
different databases:
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
dialect prefix is an invalid URI from its point of view.
The two fields therefore need opposite normalization so operators don't
have to know which driver a given field feeds. Each normalizer also
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
psycopg2 is no longer in the project.
This module is deliberately separate from ``application/core/settings.py``
so the Settings class stays focused on field declarations, and the
URI-rewriting logic can be unit-tested without triggering ``.env``
file loading from importing Settings.
"""
from __future__ import annotations
def _rewrite_uri_prefixes(v, rewrites):
"""Shared URI prefix rewriter used by both normalizers below.
Strips whitespace, returns ``None`` for empty / ``"none"`` values,
applies the first matching rewrite, and passes unrecognised input
through so downstream consumers (SQLAlchemy, libpq) can produce
their own error messages rather than us silently eating a
misconfiguration.
"""
if v is None:
return None
if not isinstance(v, str):
return v
v = v.strip()
if not v or v.lower() == "none":
return None
for prefix, target in rewrites:
if v.startswith(prefix):
return target + v[len(prefix):]
return v
# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
# dialect prefix to select the psycopg v3 driver. Normalize the
# operator-friendly forms TOWARD that dialect.
_POSTGRES_URI_REWRITES = (
("postgresql+psycopg2://", "postgresql+psycopg://"),
("postgresql://", "postgresql+psycopg://"),
("postgres://", "postgresql+psycopg://"),
)
# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
# dialect prefix is an invalid URI from libpq's point of view. Strip it
# if the operator accidentally copied their POSTGRES_URI value here.
_PGVECTOR_CONNECTION_STRING_REWRITES = (
("postgresql+psycopg2://", "postgresql://"),
("postgresql+psycopg://", "postgresql://"),
)
def normalize_postgres_uri(v):
"""Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.
Accepts the forms operators naturally write (``postgres://``,
``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
Unknown schemes pass through unchanged so SQLAlchemy can produce its
own dialect-not-found error.
"""
return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)
def normalize_pgvector_connection_string(v):
"""Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.
Strips the SQLAlchemy dialect prefix if the operator accidentally
copied their POSTGRES_URI value here — libpq can't parse it.
User-friendly forms (``postgres://``, ``postgresql://``) pass
through unchanged since libpq accepts them natively.
"""
return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)

View File

@@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from application.core.db_uri import ( # noqa: E402
normalize_pgvector_connection_string,
normalize_postgres_uri,
)
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
@@ -22,6 +28,13 @@ class Settings(BaseSettings):
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
# User-data Postgres DB. Optional during the MongoDB→Postgres migration;
# becomes required once the migration is complete.
POSTGRES_URI: Optional[str] = None
# MongoDB→Postgres migration switches
USE_POSTGRES: bool = False
READ_POSTGRES: bool = False
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
@@ -117,7 +130,10 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config
# PGVector vectorstore config. Write the URI in whichever form you
# prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
# dialect form (``postgresql+psycopg://``) are all accepted and
# normalized internally for ``psycopg.connect()``.
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
@@ -156,6 +172,16 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
@field_validator("POSTGRES_URI", mode="before")
@classmethod
def _normalize_postgres_uri_validator(cls, v):
return normalize_postgres_uri(v)
@field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
@classmethod
def _normalize_pgvector_connection_string_validator(cls, v):
return normalize_pgvector_connection_string(v)
@field_validator(
"API_KEY",
"OPENAI_API_KEY",

View File

@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
try:
import ebooklib
from ebooklib import epub
from fast_ebook import epub
except ImportError:
raise ValueError("`EbookLib` is required to read Epub files.")
try:
import html2text
except ImportError:
raise ValueError("`html2text` is required to parse Epub files.")
raise ValueError("`fast-ebook` is required to read Epub files.")
text_list = []
book = epub.read_epub(file, options={"ignore_ncx": True})
# Iterate through all chapters.
for item in book.get_items():
# Chapters are typically located in epub documents items.
if item.get_type() == ebooklib.ITEM_DOCUMENT:
text_list.append(
html2text.html2text(item.get_content().decode("utf-8"))
)
text = "\n".join(text_list)
book = epub.read_epub(file)
text = book.to_markdown()
return text

View File

@@ -1,5 +1,6 @@
anthropic==0.86.0
boto3==1.42.24
alembic>=1.13,<2
anthropic==0.88.0
boto3==1.42.83
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.3
@@ -11,11 +12,11 @@ rapidocr>=1.4.0
onnxruntime>=1.19.0
docx2txt==0.9
ddgs>=8.0.0
ebooklib==0.20
elevenlabs==2.40.0
fast-ebook
elevenlabs==2.41.0
Flask==3.1.3
faiss-cpu==1.13.2
fastmcp==2.14.6
fastmcp==3.2.0
flask-restx==1.3.2
google-genai==1.69.0
google-api-python-client==2.193.0
@@ -23,10 +24,9 @@ google-auth-httplib2==0.3.1
google-auth-oauthlib==1.3.1
gTTS==2.5.4
gunicorn==25.3.0
html2text==2025.4.15
jinja2==3.1.6
jiter==0.13.0
jmespath==1.0.1
jmespath==1.1.0
joblib==1.5.3
jsonpatch==1.33
jsonpointer==3.0.0
@@ -34,7 +34,7 @@ kombu==5.6.2
langchain==1.2.3
langchain-community==0.4.1
langchain-core==1.2.23
langchain-openai==1.1.7
langchain-openai==1.1.12
langchain-text-splitters==1.1.1
langsmith==0.7.23
lazy-object-proxy==1.12.0
@@ -53,29 +53,30 @@ orjson==3.11.7
packaging==26.0
pandas==3.0.2
openpyxl==3.1.5
pathable==0.4.4
pathable==0.5.0
pdf2image>=1.17.0
pillow
portalocker>=2.7.0,<3.0.0
portalocker>=2.7.0,<4.0.0
prompt-toolkit==3.0.52
protobuf==7.34.1
psycopg2-binary==2.9.11
psycopg[binary,pool]>=3.1,<4
py==1.11.0
pydantic
pydantic-core
pydantic-settings
pymongo==4.16.0
pypdf==6.6.0
pypdf==6.9.2
python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
redis==7.4.0
referencing>=0.28.0,<0.38.0
regex==2026.3.32
regex==2026.4.4
requests==2.33.1
retry==0.9.2
sentence-transformers==5.3.0
sqlalchemy>=2.0,<3
tiktoken==0.12.0
tokenizers==0.22.2
torch==2.11.0

View File

@@ -0,0 +1,10 @@
"""PostgreSQL storage layer for user-level data.
This package holds the SQLAlchemy Core engine, metadata, repositories, and
migration infrastructure for the user-data Postgres database. It is separate
from ``application/vectorstore/pgvector.py`` — the two may point at the same
cluster or at different clusters depending on operator configuration.
Repository modules are added in later phases
as individual collections are ported.
"""

View File

@@ -0,0 +1,39 @@
"""Common helpers shared by all repositories.
Repositories are thin wrappers around SQLAlchemy Core query construction.
They take a ``Connection`` on call and return plain ``dict`` rows during the
Mongo→Postgres cutover so that call sites don't have to change shape. Once
cutover is complete, a follow-up phase may migrate repo return types to
Pydantic DTOs (tracked in the migration plan as a post-migration item).
"""
from typing import Any, Mapping
from uuid import UUID
def row_to_dict(row: Any) -> dict:
    """Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.

    During the migration window, API responses and downstream code still
    expect a string ``_id`` field (matching the Mongo shape), so UUID ids
    are stringified and mirrored into both ``id`` and ``_id`` to keep
    existing serializers working unchanged.

    Args:
        row: A SQLAlchemy ``Row`` object, or ``None``.

    Returns:
        A plain dict, or an empty dict if ``row`` is ``None``.
    """
    if row is None:
        return {}
    # ``Row._mapping`` exposes the row as a read-only mapping view.
    result = dict(row._mapping)  # type: ignore[attr-defined]
    raw_id = result.get("id")
    if raw_id is not None:
        normalized = str(raw_id) if isinstance(raw_id, UUID) else raw_id
        result["id"] = normalized
        result["_id"] = normalized
    return result

View File

@@ -0,0 +1,67 @@
"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres
migration.
The helper:
* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off
call sites add literally zero work.
* Opens a transactional connection from the user-data SQLAlchemy engine.
* Instantiates the caller's repository class on that connection.
* Runs the caller's operation.
* Swallows and logs any exception. **Mongo remains the source of truth
during the dual-write window** — a Postgres-side failure must never
break a user-facing request. Drift that builds up from swallowed
failures is caught separately by re-running the backfill script.
Call sites look like::
users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror
A single parameterised helper rather than one function per collection
means a new collection just needs its repository class — no new helper
function, no new feature flag. The whole helper is deleted at Phase 5
when the migration is complete.
"""
from __future__ import annotations
import logging
from typing import Callable, TypeVar
from application.core.settings import settings
logger = logging.getLogger(__name__)
_Repo = TypeVar("_Repo")
def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None:
    """Best-effort mirror of a Mongo write into Postgres via ``repo_cls``.

    Does nothing when ``settings.USE_POSTGRES`` is off. Otherwise the
    repository is instantiated on a fresh transactional connection and
    handed to ``fn``. Every exception (pool exhaustion, migration drift,
    SQL error) is logged and suppressed: Mongo is the source of truth
    during the dual-write window, so a Postgres-side failure must never
    surface to the user-facing request.

    Args:
        repo_cls: Repository class to instantiate (e.g. ``UsersRepository``).
        fn: Callable that receives the instantiated repository and performs
            the desired write.
    """
    if not settings.USE_POSTGRES:
        return
    try:
        # Imported lazily so modules importing dual_write only pay the
        # SQLAlchemy import cost when the flag is actually on.
        from application.storage.db.engine import get_engine

        with get_engine().begin() as connection:
            repository = repo_cls(connection)
            fn(repository)
    except Exception:
        logger.warning(
            "Postgres dual-write failed for %s — Mongo write already committed",
            repo_cls.__name__,
            exc_info=True,
        )

View File

@@ -0,0 +1,67 @@
"""SQLAlchemy Core engine factory for the user-data Postgres database.
The engine is lazily constructed on first use and cached as a module-level
singleton. Repositories and the Alembic env module both obtain connections
through this factory, so pool tuning lives in one place.
``POSTGRES_URI`` can be written in any of the common Postgres URI forms::
postgres://user:pass@host:5432/docsgpt
postgresql://user:pass@host:5432/docsgpt
Both are accepted and normalized internally to the psycopg3 dialect
(``postgresql+psycopg://``) by ``application.core.settings``. Operators
don't need to know about SQLAlchemy dialect prefixes.
"""
from typing import Optional
from sqlalchemy import Engine, create_engine
from application.core.settings import settings
_engine: Optional[Engine] = None
def get_engine() -> Engine:
    """Return the process-wide SQLAlchemy Engine, creating it on first use.

    Raises:
        RuntimeError: If ``settings.POSTGRES_URI`` is unset — reaching this
            path without a configured URI is a setup bug, and the message
            points at the right setting.

    Returns:
        A SQLAlchemy ``Engine`` with a pooled psycopg3 connection to
        Postgres.
    """
    global _engine
    # Fast path: the singleton already exists.
    if _engine is not None:
        return _engine
    uri = settings.POSTGRES_URI
    if not uri:
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    _engine = create_engine(
        uri,
        pool_size=10,
        max_overflow=20,
        pool_pre_ping=True,  # survive PgBouncer / idle-disconnect recycles
        pool_recycle=1800,
        future=True,
    )
    return _engine
def dispose_engine() -> None:
    """Dispose the pooled connections and reset the module singleton.

    Hooked into the Celery ``worker_process_init`` signal so each forked
    worker builds a fresh pool instead of sharing file descriptors with
    the parent process (sharing them corrupts the pool on fork).
    """
    global _engine
    # Swap-then-dispose: the singleton is cleared in all cases.
    engine, _engine = _engine, None
    if engine is not None:
        engine.dispose()

View File

@@ -0,0 +1,38 @@
"""SQLAlchemy Core metadata for the user-data Postgres database.
Tables are added here one at a time as repositories are built during the
MongoDB→Postgres migration. The baseline schema in the Alembic migration
(``application/alembic/versions/0001_initial.py``) is the source of truth
for DDL; the ``Table`` definitions below must match it column-for-column.
If the two drift, migrations win — update this file to match.
"""
from sqlalchemy import (
Column,
DateTime,
MetaData,
Table,
Text,
func,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
# Shared metadata object; Alembic's env module targets this for autogenerate.
metadata = MetaData()
# --- Phase 1, Tier 1 --------------------------------------------------------
# Mirrors the Mongo ``users`` collection: one row per auth-provider subject.
users_table = Table(
    "users",
    metadata,
    # Surrogate key generated server-side. gen_random_uuid() is built in on
    # Postgres 13+; presumably pgcrypto provides it on older versions — the
    # setup docs require that extension. TODO confirm minimum supported PG.
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    # Opaque auth-provider ``sub``; the natural key repositories query by.
    Column("user_id", Text, nullable=False, unique=True),
    Column(
        "agent_preferences",
        JSONB,
        nullable=False,
        # Must stay in sync with _DEFAULT_PREFERENCES in the users repository.
        server_default='{"pinned": [], "shared_with_me": []}',
    ),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    # NOTE(review): updated_at is bumped explicitly by repository UPDATEs,
    # not by a trigger — rows touched outside the repository will go stale.
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)

View File

@@ -0,0 +1,11 @@
"""Repositories for the user-data Postgres database.
Each module in this package exposes exactly one repository class. Repository
methods take a ``Connection`` (either as a constructor argument or as a
method argument) and return plain ``dict`` rows via
``application.storage.db.base_repository.row_to_dict`` during the
MongoDB→Postgres cutover, so call sites don't have to change shape.
Repositories are added one collection at a time, matching the phased
rollout in ``migration-postgres.md``.
"""

View File

@@ -0,0 +1,245 @@
"""Repository for the ``users`` table.
Covers every operation the legacy Mongo code performs on
``users_collection``:
1. ``ensure_user_doc`` in ``application/api/user/base.py`` (upsert + get)
2. Pin/unpin agents in ``application/api/user/agents/routes.py`` (add/remove
on ``agent_preferences.pinned``)
3. Share accept/reject in ``application/api/user/agents/sharing.py`` (add/
bulk-remove on ``agent_preferences.shared_with_me``)
4. Cascade delete of an agent id from both arrays at once
All array mutations are implemented as single atomic UPDATE statements
using JSONB operators (``jsonb_set``, ``jsonb_array_elements``, ``@>``)
so there is no read-modify-write race between concurrent writers on the
same user row.
The repository takes a ``Connection`` and does not manage its own
transactions. Callers are responsible for wrapping writes in
``with engine.begin() as conn:`` (production) or the test fixture's
rollback-per-test connection (tests).
"""
from __future__ import annotations
from typing import Iterable, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
_DEFAULT_PREFERENCES = '{"pinned": [], "shared_with_me": []}'
class UsersRepository:
    """Postgres-backed replacement for Mongo ``users_collection`` writes/reads."""

    def __init__(self, conn: Connection) -> None:
        # Borrowed connection: transaction lifecycle is owned by the caller
        # (``engine.begin()`` in production, rollback fixture in tests).
        self._conn = conn

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------
    def get(self, user_id: str) -> Optional[dict]:
        """Return the user row as a dict, or ``None`` if missing.

        Args:
            user_id: Auth-provider ``sub`` (opaque string).
        """
        result = self._conn.execute(
            text("SELECT * FROM users WHERE user_id = :user_id"),
            {"user_id": user_id},
        )
        # ``user_id`` carries a UNIQUE constraint, so at most one row matches.
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    # ------------------------------------------------------------------
    # Upsert
    # ------------------------------------------------------------------
    def upsert(self, user_id: str) -> dict:
        """Ensure a row exists for ``user_id`` and return it.

        Matches Mongo's ``find_one_and_update(..., $setOnInsert, upsert=True,
        return_document=AFTER)`` semantics: if the row exists, preferences
        are preserved untouched; if it doesn't, a new row is created with
        default preferences.

        The ``DO UPDATE SET user_id = EXCLUDED.user_id`` branch is a
        deliberate no-op that lets ``RETURNING *`` fire on both the insert
        and conflict paths (``DO NOTHING`` would suppress the returning).
        """
        # CAST(:default_prefs AS jsonb) makes the bind type explicit so the
        # driver doesn't have to guess from the Python string.
        result = self._conn.execute(
            text(
                """
                INSERT INTO users (user_id, agent_preferences)
                VALUES (:user_id, CAST(:default_prefs AS jsonb))
                ON CONFLICT (user_id) DO UPDATE
                SET user_id = EXCLUDED.user_id
                RETURNING *
                """
            ),
            {"user_id": user_id, "default_prefs": _DEFAULT_PREFERENCES},
        )
        return row_to_dict(result.fetchone())

    # ------------------------------------------------------------------
    # Pinned agents
    # ------------------------------------------------------------------
    def add_pinned(self, user_id: str, agent_id: str) -> None:
        """Idempotently append ``agent_id`` to ``agent_preferences.pinned``.

        Uses ``@>`` containment so a duplicate add is a no-op rather than a
        silent double-insert. The whole update is a single atomic statement
        so concurrent add_pinned calls on the same user cannot interleave
        into a read-modify-write race.
        """
        self._append_to_jsonb_array(user_id, "pinned", agent_id)

    def remove_pinned(self, user_id: str, agent_id: str) -> None:
        """Remove ``agent_id`` from ``agent_preferences.pinned`` if present."""
        self._remove_from_jsonb_array(user_id, "pinned", [agent_id])

    def remove_pinned_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
        """Remove every id in ``agent_ids`` from ``agent_preferences.pinned``.

        No-op if the list is empty. Unknown ids are silently ignored so
        callers can pass the full "stale" set without pre-filtering.
        """
        # Materialize once: ``agent_ids`` may be a one-shot iterator.
        ids = list(agent_ids)
        if not ids:
            return
        self._remove_from_jsonb_array(user_id, "pinned", ids)

    # ------------------------------------------------------------------
    # Shared-with-me agents
    # ------------------------------------------------------------------
    def add_shared(self, user_id: str, agent_id: str) -> None:
        """Idempotently append ``agent_id`` to ``agent_preferences.shared_with_me``."""
        self._append_to_jsonb_array(user_id, "shared_with_me", agent_id)

    def remove_shared_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
        """Bulk-remove from ``agent_preferences.shared_with_me``. Empty list is a no-op."""
        ids = list(agent_ids)
        if not ids:
            return
        self._remove_from_jsonb_array(user_id, "shared_with_me", ids)

    # ------------------------------------------------------------------
    # Combined removal — called when an agent is hard-deleted
    # ------------------------------------------------------------------
    def remove_agent_from_all(self, user_id: str, agent_id: str) -> None:
        """Remove ``agent_id`` from BOTH pinned and shared_with_me atomically.

        Mirrors the Mongo ``$pull`` that targets both nested array fields
        in one ``update_one`` — see ``application/api/user/agents/routes.py``
        around the agent-delete path.
        """
        # Two nested jsonb_set calls rebuild both arrays in one UPDATE so a
        # concurrent writer can never observe one array updated without the
        # other. COALESCE(..., '[]') handles both a missing key and the NULL
        # that jsonb_agg returns when the filter removes every element.
        self._conn.execute(
            text(
                """
                UPDATE users
                SET
                    agent_preferences = jsonb_set(
                        jsonb_set(
                            agent_preferences,
                            '{pinned}',
                            COALESCE(
                                (
                                    SELECT jsonb_agg(elem)
                                    FROM jsonb_array_elements(
                                        COALESCE(agent_preferences->'pinned', '[]'::jsonb)
                                    ) AS elem
                                    WHERE (elem #>> '{}') != :agent_id
                                ),
                                '[]'::jsonb
                            )
                        ),
                        '{shared_with_me}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem)
                                FROM jsonb_array_elements(
                                    COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
                                ) AS elem
                                WHERE (elem #>> '{}') != :agent_id
                            ),
                            '[]'::jsonb
                        )
                    ),
                    updated_at = now()
                WHERE user_id = :user_id
                """
            ),
            {"user_id": user_id, "agent_id": agent_id},
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _append_to_jsonb_array(self, user_id: str, key: str, agent_id: str) -> None:
        """Idempotent append of ``agent_id`` to ``agent_preferences.<key>``.

        The ``key`` argument is NOT user input — it's hard-coded by the
        calling method (``pinned`` / ``shared_with_me``). It goes into the
        SQL literal because ``jsonb_set`` requires a path literal, not a
        bind parameter. This is safe as long as callers never pass
        untrusted strings for ``key``.
        """
        # Defense-in-depth: reject any key outside the known whitelist before
        # it is interpolated into the SQL string below.
        if key not in ("pinned", "shared_with_me"):
            raise ValueError(f"unsupported jsonb key: {key!r}")
        # CASE keeps the array unchanged when it already contains agent_id;
        # otherwise ``||`` appends. agent_id itself is always a bind param.
        self._conn.execute(
            text(
                f"""
                UPDATE users
                SET
                    agent_preferences = jsonb_set(
                        agent_preferences,
                        '{{{key}}}',
                        CASE
                            WHEN agent_preferences->'{key}' @> to_jsonb(CAST(:agent_id AS text))
                            THEN agent_preferences->'{key}'
                            ELSE
                                COALESCE(agent_preferences->'{key}', '[]'::jsonb)
                                || to_jsonb(CAST(:agent_id AS text))
                        END
                    ),
                    updated_at = now()
                WHERE user_id = :user_id
                """
            ),
            {"user_id": user_id, "agent_id": agent_id},
        )

    def _remove_from_jsonb_array(
        self, user_id: str, key: str, agent_ids: list[str]
    ) -> None:
        """Remove every id in ``agent_ids`` from ``agent_preferences.<key>``."""
        # Same whitelist guard as _append_to_jsonb_array — ``key`` is
        # interpolated into the SQL text, so it must never be untrusted.
        if key not in ("pinned", "shared_with_me"):
            raise ValueError(f"unsupported jsonb key: {key!r}")
        # Rebuilds the array keeping only elements not in agent_ids.
        # NOTE(review): relies on the driver adapting the Python list bound
        # to :agent_ids into a Postgres array for ANY() — confirm with the
        # psycopg3 dialect in use.
        self._conn.execute(
            text(
                f"""
                UPDATE users
                SET
                    agent_preferences = jsonb_set(
                        agent_preferences,
                        '{{{key}}}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem)
                                FROM jsonb_array_elements(
                                    COALESCE(agent_preferences->'{key}', '[]'::jsonb)
                                ) AS elem
                                WHERE NOT ((elem #>> '{{}}') = ANY(:agent_ids))
                            ),
                            '[]'::jsonb
                        )
                    ),
                    updated_at = now()
                WHERE user_id = :user_id
                """
            ),
            {"user_id": user_id, "agent_ids": agent_ids},
        )

View File

@@ -21,10 +21,19 @@ class LocalStorage(BaseStorage):
)
def _get_full_path(self, path: str) -> str:
"""Get absolute path by combining base_dir and path."""
"""Get absolute path by combining base_dir and path.
Raises:
ValueError: If the resolved path escapes base_dir (path traversal).
"""
if os.path.isabs(path):
return path
return os.path.join(self.base_dir, path)
resolved = os.path.realpath(path)
else:
resolved = os.path.realpath(os.path.join(self.base_dir, path))
base = os.path.realpath(self.base_dir)
if not resolved.startswith(base + os.sep) and resolved != base:
raise ValueError(f"Path traversal detected: {path}")
return resolved
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
"""Save a file to local storage."""

View File

@@ -2,6 +2,7 @@
import io
import os
import posixpath
from typing import BinaryIO, Callable, List
import boto3
@@ -14,6 +15,20 @@ from botocore.exceptions import ClientError
class S3Storage(BaseStorage):
"""AWS S3 storage implementation."""
@staticmethod
def _validate_path(path: str) -> str:
"""Validate and normalize an S3 key to prevent path traversal.
Raises:
ValueError: If the path contains traversal sequences or is absolute.
"""
if "\x00" in path:
raise ValueError(f"Null byte in path: {path}")
normalized = posixpath.normpath(path)
if normalized.startswith("/") or normalized.startswith(".."):
raise ValueError(f"Path traversal detected: {path}")
return normalized
def __init__(self, bucket_name=None):
"""
Initialize S3 storage.
@@ -46,6 +61,7 @@ class S3Storage(BaseStorage):
**kwargs,
) -> dict:
"""Save a file to S3 storage."""
path = self._validate_path(path)
self.s3.upload_fileobj(
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
)
@@ -61,6 +77,7 @@ class S3Storage(BaseStorage):
def get_file(self, path: str) -> BinaryIO:
"""Get a file from S3 storage."""
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
file_obj = io.BytesIO()
@@ -70,6 +87,7 @@ class S3Storage(BaseStorage):
def delete_file(self, path: str) -> bool:
"""Delete a file from S3 storage."""
path = self._validate_path(path)
try:
self.s3.delete_object(Bucket=self.bucket_name, Key=path)
return True
@@ -78,6 +96,7 @@ class S3Storage(BaseStorage):
def file_exists(self, path: str) -> bool:
"""Check if a file exists in S3 storage."""
path = self._validate_path(path)
try:
self.s3.head_object(Bucket=self.bucket_name, Key=path)
return True
@@ -115,6 +134,7 @@ class S3Storage(BaseStorage):
import logging
import tempfile
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(

View File

@@ -11,11 +11,33 @@ from application.storage.storage_creator import StorageCreator
def get_vectorstore(path: str) -> str:
if path:
vectorstore = f"indexes/{path}"
else:
vectorstore = "indexes"
return vectorstore
"""Build a safe local path for a FAISS index.
Args:
path: Source identifier provided by the caller.
Returns:
The validated vectorstore path rooted under ``indexes``.
Raises:
ValueError: If ``path`` escapes the ``indexes`` directory.
"""
base_dir = "indexes"
if not path:
return base_dir
normalized = str(path).strip()
if "\\" in normalized:
raise ValueError("Invalid source_id path")
candidate = os.path.normpath(os.path.join(base_dir, normalized))
base_abs = os.path.abspath(base_dir)
candidate_abs = os.path.abspath(candidate)
if not candidate_abs.startswith(base_abs + os.sep) and candidate_abs != base_abs:
raise ValueError("Invalid source_id path")
return candidate
class FaissStore(BaseVectorStore):

View File

@@ -37,27 +37,25 @@ class PGVectorStore(BaseVectorStore):
)
try:
import psycopg2
from psycopg2.extras import Json
import pgvector.psycopg2
import psycopg
from pgvector.psycopg import register_vector
except ImportError:
raise ImportError(
"Could not import required packages. "
"Please install with `pip install psycopg2-binary pgvector`."
"Please install with `pip install 'psycopg[binary,pool]' pgvector`."
)
self._psycopg2 = psycopg2
self._Json = Json
self._pgvector = pgvector.psycopg2
self._psycopg = psycopg
self._register_vector = register_vector
self._connection = None
self._ensure_table_exists()
def _get_connection(self):
"""Get or create database connection"""
if self._connection is None or self._connection.closed:
self._connection = self._psycopg2.connect(self._connection_string)
self._connection = self._psycopg.connect(self._connection_string)
# Register pgvector types
self._pgvector.register_vector(self._connection)
self._register_vector(self._connection)
return self._connection
def _ensure_table_exists(self):
@@ -170,7 +168,7 @@ class PGVectorStore(BaseVectorStore):
for text, embedding, metadata in zip(texts, embeddings, metadatas):
cursor.execute(
insert_query,
(text, embedding, self._Json(metadata), self._source_id)
(text, embedding, metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
inserted_ids.append(str(inserted_id))
@@ -261,7 +259,7 @@ class PGVectorStore(BaseVectorStore):
cursor.execute(
insert_query,
(text, embeddings[0], self._Json(final_metadata), self._source_id)
(text, embeddings[0], final_metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
conn.commit()

View File

@@ -7,6 +7,10 @@ export default {
"title": "🔌 Agent API",
"href": "/Agents/api"
},
"openai-compatible": {
"title": "🔄 OpenAI-Compatible API",
"href": "/Agents/openai-compatible"
},
"webhooks": {
"title": "🪝 Agent Webhooks",
"href": "/Agents/webhooks"

View File

@@ -15,6 +15,10 @@ DocsGPT Agents can be accessed programmatically through API endpoints. This page
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
<Callout type="info">
Looking to connect an existing OpenAI-compatible client (opencode, aider, the OpenAI SDKs, etc.) to a DocsGPT Agent? Use the [OpenAI-Compatible Chat Completions API](/Agents/openai-compatible) — it speaks the standard chat completions protocol so no adapter code is required.
</Callout>
## Base URL
<Callout type="info">

View File

@@ -111,6 +111,7 @@ Once an agent is created, you can:
* Modify any of its configuration settings (name, description, source, prompt, tools, type).
* **Generate a Public Link:** From the edit screen, you can create a shareable public link that allows others to import and use your agent.
* **Get a Webhook URL:** You can also obtain a Webhook URL for the agent. This allows external applications or services to trigger the agent and receive responses programmatically, enabling powerful integrations and automations.
* **Use it via API:** Every agent exposes an API key that can be used with the native [Agent API](/Agents/api) or the [OpenAI-Compatible API](/Agents/openai-compatible) so you can drop DocsGPT Agents into any tool that already speaks the chat completions protocol.
## Seeding Premade Agents from YAML

View File

@@ -0,0 +1,93 @@
---
title: OpenAI-Compatible API
description: Connect any OpenAI-compatible client to DocsGPT Agents via /v1/chat/completions.
---
import { Callout, Tabs } from 'nextra/components';
# OpenAI-Compatible API
DocsGPT exposes `/v1/chat/completions` following the standard chat completions protocol. Point any compatible client — **opencode**, **Aider**, **LibreChat**, or the OpenAI SDKs — at your DocsGPT Agent by changing only the base URL and API key.
## Quick Start
<Tabs items={['Python', 'cURL']}>
<Tabs.Tab>
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:7091/v1", # or https://gptcloud.arc53.com/v1
api_key="your_agent_api_key",
)
response = client.chat.completions.create(
model="docsgpt-agent",
messages=[{"role": "user", "content": "Summarize our refund policy"}],
)
print(response.choices[0].message.content)
```
</Tabs.Tab>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/v1/chat/completions \
-H "Authorization: Bearer your_agent_api_key" \
-H "Content-Type: application/json" \
-d '{"model":"docsgpt-agent","messages":[{"role":"user","content":"Summarize our refund policy"}]}'
```
</Tabs.Tab>
</Tabs>
The `model` field is accepted but ignored — the agent bound to your API key determines the model. The agent's prompt, sources, tools, and default model are loaded automatically.
## Base URL & Auth
| Environment | Base URL |
| --- | --- |
| Local | `http://localhost:7091/v1` |
| Cloud | `https://gptcloud.arc53.com/v1` |
Authenticate with `Authorization: Bearer <agent_api_key>`.
## Endpoints
| Method | Path | Description |
| --- | --- | --- |
| `POST` | `/v1/chat/completions` | Chat request (streaming or non-streaming) |
| `GET` | `/v1/models` | List agents available to your key |
## Streaming
Set `"stream": true`. You'll receive SSE chunks with `choices[0].delta.content`. DocsGPT-specific events (sources, tool calls) arrive as extra frames with a `docsgpt` key — standard clients ignore them.
```python
stream = client.chat.completions.create(
model="docsgpt-agent",
stream=True,
messages=[{"role": "user", "content": "Explain vector search"}],
)
for chunk in stream:
print(chunk.choices[0].delta.content or "", end="", flush=True)
```
## System Prompt Override
System messages are **dropped by default** — the agent's configured prompt is used. To allow callers to override it, enable **Allow prompt override** in the agent's Advanced settings.
<Callout type="warning">
When an override is active, the agent's prompt template is replaced wholesale — template variables like `{summaries}` are not substituted.
</Callout>
## Conversation Persistence
Conversations are **not persisted by default** (stateless, like most OpenAI clients expect). Opt in per request:
```json
{ "docsgpt": { "save_conversation": true } }
```
The response will include `docsgpt.conversation_id`.
## When to Use Native Endpoints Instead
Use [`/api/answer` or `/stream`](/Agents/api) if you need server-side attachments, `passthrough` template variables, explicit `conversation_id` reuse, or persistence by default.

View File

@@ -0,0 +1,114 @@
---
title: PostgreSQL for User Data
description: Set up PostgreSQL as the user-data store for DocsGPT and migrate from MongoDB at your own pace.
---
import { Callout } from 'nextra/components'
# PostgreSQL for User Data
DocsGPT is progressively moving user data (conversations, agents, prompts,
preferences, etc.) from MongoDB to PostgreSQL, one collection at a time.
Each collection is guarded by a feature flag so you can opt in and roll
back instantly. MongoDB stays the source of truth until you cut over
reads; vector stores (`VECTOR_STORE=pgvector`, `faiss`, `qdrant`, `mongodb`, …)
are unaffected.
<Callout type="info">
Which collections are available today is in the [Status](#status)
table below. That table is the only part of this page that changes
release to release.
</Callout>
## Setup
1. **Run Postgres 13+.** Native install, Docker, or managed (Neon, RDS,
Supabase, Cloud SQL…) — all work. You'll need the `pgcrypto` and
`citext` extensions, both standard contrib modules available
everywhere.
2. **Create a database and role** (skip if your managed provider gave
you these):
```sql
CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt';
CREATE DATABASE docsgpt OWNER docsgpt;
```
3. **Set `POSTGRES_URI` in `.env`.** Any standard Postgres URI works —
DocsGPT normalizes it internally.
```bash
POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
# Append ?sslmode=require for managed providers that enforce SSL.
```
4. **Apply the schema** (idempotent — safe to re-run):
```bash
python scripts/db/init_postgres.py
```
## Migrating data
Two global flags, no per-collection knobs — every collection marked ✅
in the [Status](#status) table is handled automatically.
1. **Enable dual-write.** Writes go to both Mongo and Postgres; Mongo
remains source of truth. Set the flag in `.env` and restart:
```bash
USE_POSTGRES=true
```
2. **Backfill existing data.** Idempotent — re-run any time to re-sync
drifted rows. Without arguments, backfills every registered table;
pass `--tables` to limit.
```bash
python scripts/db/backfill.py --dry-run # preview everything
python scripts/db/backfill.py # real run, everything
python scripts/db/backfill.py --tables users # only specific tables
```
3. **Cut over reads** once you trust the Postgres state:
```bash
READ_POSTGRES=true
```
Rollback is instant: unset `READ_POSTGRES` and restart. Dual-write
keeps Postgres up to date so you can flip back and forth.
<Callout type="warning" emoji="⚠️">
Don't decommission MongoDB until every collection you use is fully
cut over. During the migration window, Mongo is still required.
</Callout>
## Status
_Last updated: 2026-04-10_
| Collection | Status |
|---|---|
| `users` | ✅ Phase 1 |
| `prompts`, `user_tools`, `feedback`, `stack_logs`, `user_logs`, `token_usage` | ⏳ Phase 1 |
| `agents`, `sources`, `attachments`, `memories`, `todos`, `notes`, `connector_sessions`, `agent_folders` | ⏳ Phase 2 |
| `conversations`, `pending_tool_state`, `workflows` | ⏳ Phase 3 |
Schemas for **every** row above already exist after `init_postgres.py`
runs. What's landing progressively is the application-level dual-write
wiring and the backfill logic for each collection. Once a collection
is ✅, enabling `USE_POSTGRES=true` and running `python scripts/db/backfill.py`
picks it up automatically — no per-collection config change.
## Troubleshooting
- **`relation "..." does not exist`** — run `python scripts/db/init_postgres.py`.
- **`FATAL: role "docsgpt" does not exist`** — run the `CREATE ROLE` /
`CREATE DATABASE` statements from step 2 as a Postgres superuser.
- **SSL errors on a managed provider** — append `?sslmode=require` to
`POSTGRES_URI`.
- **Dual-write warnings in the logs** — expected to be non-fatal. Mongo
is source of truth, so the user-facing request succeeds. Re-run the
backfill to re-sync whichever rows drifted.

View File

@@ -19,6 +19,10 @@ export default {
"title": "☁️ Hosting DocsGPT",
"href": "/Deploying/Hosting-the-app"
},
"Postgres-Migration": {
"title": "🐘 PostgreSQL for User Data",
"href": "/Deploying/Postgres-Migration"
},
"Amazon-Lightsail": {
"title": "Hosting DocsGPT on Amazon Lightsail",
"href": "/Deploying/Amazon-Lightsail",

View File

@@ -1,3 +1,5 @@
import hashlib
import hmac
import os
import pprint
@@ -10,6 +12,7 @@ docsgpt_url = os.getenv("docsgpt_url")
chatwoot_url = os.getenv("chatwoot_url")
docsgpt_key = os.getenv("docsgpt_key")
chatwoot_token = os.getenv("chatwoot_token")
chatwoot_webhook_secret = os.getenv("chatwoot_webhook_secret", "")
# account_id = os.getenv("account_id")
# assignee_id = os.getenv("assignee_id")
label_stop = "human-requested"
@@ -45,12 +48,35 @@ def send_to_chatwoot(account, conversation, message):
return r.json()
def is_valid_chatwoot_signature(raw_body: bytes, signature_header: str | None) -> bool:
    """Validate Chatwoot webhook signature using shared secret."""
    # Fail closed: no configured secret or no header means no trust.
    if not chatwoot_webhook_secret or not signature_header:
        return False
    digest = hmac.new(
        chatwoot_webhook_secret.encode("utf-8"), raw_body, hashlib.sha256
    ).hexdigest()
    # Accept both bare hex digests and the "sha256=<hex>" header form.
    candidate = signature_header.strip().removeprefix("sha256=")
    # Constant-time comparison to avoid timing side channels.
    return hmac.compare_digest(candidate, digest)
app = Flask(__name__)
@app.route('/docsgpt', methods=['POST'])
def docsgpt():
data = request.get_json()
raw_body = request.get_data()
signature = request.headers.get("X-Chatwoot-Signature")
if not is_valid_chatwoot_signature(raw_body, signature):
return "Unauthorized", 401
data = request.get_json(silent=True)
if not isinstance(data, dict):
return "Invalid payload", 400
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(data)
try:

View File

@@ -73,6 +73,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
token_limit: undefined,
limited_request_mode: false,
request_limit: undefined,
allow_system_prompt_override: false,
models: [],
default_model_id: '',
});
@@ -241,6 +242,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
formData.append('request_limit', '0');
}
formData.append(
'allow_system_prompt_override',
agent.allow_system_prompt_override ? 'True' : 'False',
);
if (imageFile) formData.append('image', imageFile);
if (agent.tools && agent.tools.length > 0)
@@ -361,6 +367,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
formData.append('request_limit', '0');
}
formData.append(
'allow_system_prompt_override',
agent.allow_system_prompt_override ? 'True' : 'False',
);
if (agent.models && agent.models.length > 0) {
formData.append('models', JSON.stringify(agent.models));
}
@@ -1266,6 +1277,43 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}`}
/>
</div>
<div className="mt-6">
<div className="flex items-center justify-between gap-4">
<div className="min-w-0 flex-1">
<h2 className="text-sm font-medium">
{t('agents.form.advanced.systemPromptOverride')}
</h2>
<p className="mt-1 text-xs text-gray-600 dark:text-gray-400">
{t(
'agents.form.advanced.systemPromptOverrideDescription',
)}
</p>
</div>
<button
onClick={() =>
setAgent({
...agent,
allow_system_prompt_override:
!agent.allow_system_prompt_override,
})
}
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
agent.allow_system_prompt_override
? 'bg-primary'
: 'bg-gray-300 dark:bg-gray-600'
}`}
>
<span
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
agent.allow_system_prompt_override
? ''
: '-translate-x-5'
}`}
/>
</button>
</div>
</div>
</div>
)}
</div>

View File

@@ -36,6 +36,7 @@ export type Agent = {
default_model_id?: string;
folder_id?: string;
workflow?: string;
allow_system_prompt_override?: boolean;
};
export type AgentFolder = {

View File

@@ -18,6 +18,7 @@ import {
X,
} from 'lucide-react';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import { useNavigate, useParams, useSearchParams } from 'react-router-dom';
import ReactFlow, {
@@ -301,6 +302,7 @@ function createWorkflowPayload(
}
function WorkflowBuilderInner() {
const { t } = useTranslation();
const navigate = useNavigate();
const token = useSelector(selectToken);
const sourceDocs = useSelector(selectSourceDocs);
@@ -1142,6 +1144,10 @@ function WorkflowBuilderInner() {
workflowDescription || `Workflow agent: ${workflowName}`,
);
agentFormData.append('status', 'published');
agentFormData.append(
'allow_system_prompt_override',
currentAgent.allow_system_prompt_override ? 'True' : 'False',
);
if (imageFile) {
agentFormData.append('image', imageFile);
}
@@ -1203,6 +1209,10 @@ function WorkflowBuilderInner() {
agentFormData.append('agent_type', 'workflow');
agentFormData.append('status', 'published');
agentFormData.append('workflow', savedWorkflowId || '');
agentFormData.append(
'allow_system_prompt_override',
currentAgent.allow_system_prompt_override ? 'True' : 'False',
);
if (imageFile) {
agentFormData.append('image', imageFile);
}
@@ -1454,6 +1464,40 @@ function WorkflowBuilderInner() {
Image updates are included the next time you save.
</p>
</div>
<div className="mb-3">
<div className="flex items-center justify-between">
<div>
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300">
{t('agents.form.advanced.systemPromptOverride')}
</label>
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
{t('agents.form.advanced.systemPromptOverrideDescription')}
</p>
</div>
<button
onClick={() =>
setCurrentAgent((prev) => ({
...prev,
allow_system_prompt_override:
!prev.allow_system_prompt_override,
}))
}
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
currentAgent.allow_system_prompt_override
? 'bg-primary'
: 'bg-gray-300 dark:bg-gray-600'
}`}
>
<span
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
currentAgent.allow_system_prompt_override
? ''
: '-translate-x-5'
}`}
/>
</button>
</div>
</div>
<button
onClick={handleWorkflowSettingsDone}
disabled={isPublishing}

View File

@@ -78,6 +78,7 @@ function Dropdown<T extends DropdownOption>({
const searchRef = useRef<HTMLInputElement>(null);
const [open, setOpen] = useState(false);
const [query, setQuery] = useState('');
const [dropUp, setDropUp] = useState(false);
const radius = rounded === '3xl' ? 'rounded-3xl' : 'rounded-xl';
const radiusTop = rounded === '3xl' ? 'rounded-t-3xl' : 'rounded-t-xl';
@@ -90,14 +91,23 @@ function Dropdown<T extends DropdownOption>({
setQuery('');
}
};
document.addEventListener('mousedown', handler);
return () => document.removeEventListener('mousedown', handler);
document.addEventListener('mousedown', handler, true);
return () => document.removeEventListener('mousedown', handler, true);
}, []);
useEffect(() => {
if (open && searchable && searchRef.current) searchRef.current.focus();
}, [open, searchable]);
const handleToggle = () => {
  // Before opening, decide which way the menu should unfold: if fewer than
  // ~220px remain below the trigger (approximate max menu height — TODO
  // confirm against the rendered list height), open upward instead.
  if (!open && ref.current) {
    const rect = ref.current.getBoundingClientRect();
    const spaceBelow = window.innerHeight - rect.bottom;
    setDropUp(spaceBelow < 220);
  }
  setOpen((v) => !v);
};
const filtered = useMemo(() => {
if (!searchable || !query.trim()) return options;
const q = query.toLowerCase();
@@ -110,8 +120,8 @@ function Dropdown<T extends DropdownOption>({
<div className={`relative ${size}`} ref={ref}>
<button
type="button"
onClick={() => setOpen((v) => !v)}
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? radiusTop : radius}`}
onClick={handleToggle}
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? (dropUp ? radiusBottom : radiusTop) : radius}`}
>
<span
className={`truncate ${contentSize} ${displayValue ? '' : 'text-muted-foreground'}`}
@@ -125,7 +135,11 @@ function Dropdown<T extends DropdownOption>({
{open && (
<div
className={`border-border bg-card absolute inset-x-0 z-20 -mt-px overflow-hidden border border-t-0 shadow-lg ${radiusBottom}`}
className={`border-border bg-card absolute inset-x-0 z-20 overflow-hidden border shadow-lg ${
dropUp
? `bottom-full -mt-px border-b-0 ${radiusTop}`
: `-mt-px border-t-0 ${radiusBottom}`
}`}
>
{searchable && (
<div className="flex items-center px-3 py-2">

View File

@@ -10,7 +10,9 @@ interface SkeletonLoaderProps {
| 'chatbot'
| 'dropdown'
| 'chunkCards'
| 'sourceCards';
| 'sourceCards'
| 'toolCards'
| 'addToolCards';
}
const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
@@ -237,6 +239,55 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
</>
);
// Skeleton placeholders for the "add tool" grid: `count` pulsing cards that
// mirror the real card layout (icon row, then a title bar and three
// description lines of decreasing width).
const renderAddToolCards = () => (
  <>
    {Array.from({ length: count }).map((_, idx) => (
      <div
        key={`add-tool-skel-${idx}`}
        className="border-light-gainsboro dark:border-arsenic flex h-52 w-full animate-pulse flex-col justify-between rounded-2xl border p-6"
      >
        <div className="w-full">
          <div className="flex w-full items-center justify-between px-1">
            <div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
          </div>
          <div className="mt-[9px] space-y-2 px-1">
            <div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
            <div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
          </div>
        </div>
      </div>
    ))}
  </>
);
// Skeleton placeholders for installed-tool cards: fixed-width (300px) pulsing
// cards shaped like the real tool card — icon, title, three description
// lines, and a toggle-switch pill pinned to the bottom-right.
const renderToolCards = () => (
  <>
    {Array.from({ length: count }).map((_, idx) => (
      <div
        key={`tool-skel-${idx}`}
        className="bg-muted flex h-52 w-[300px] animate-pulse flex-col justify-between rounded-2xl p-6"
      >
        <div className="w-full">
          <div className="flex items-center gap-2 px-1">
            <div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
          </div>
          <div className="mt-[9px] space-y-2 px-1">
            <div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
            <div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
          </div>
        </div>
        <div className="flex justify-end">
          <div className="h-5 w-9 rounded-full bg-gray-300 dark:bg-gray-600"></div>
        </div>
      </div>
    ))}
  </>
);
const componentMap = {
fileTable: renderTable,
chatbot: renderChatbot,
@@ -246,6 +297,8 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
analysis: renderAnalysis,
chunkCards: renderChunkCards,
sourceCards: renderSourceCards,
toolCards: renderToolCards,
addToolCards: renderAddToolCards,
};
const render = componentMap[component] || componentMap.default;

View File

@@ -619,7 +619,9 @@
"tokenLimiting": "Token-Limitierung",
"tokenLimitingDescription": "Begrenze die täglich von diesem Agenten verwendbaren Tokens",
"requestLimiting": "Anfrage-Limitierung",
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen"
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen",
"systemPromptOverride": "Prompt-Überschreibung erlauben",
"systemPromptOverrideDescription": "Erlaubt v1-API-Aufrufern, den System-Prompt dieses Agenten zu ersetzen"
},
"preview": {
"publishedPreview": "Veröffentlichte Agenten können hier in der Vorschau angezeigt werden"

View File

@@ -653,7 +653,9 @@
"tokenLimiting": "Token limiting",
"tokenLimitingDescription": "Limit daily total tokens that can be used by this agent",
"requestLimiting": "Request limiting",
"requestLimitingDescription": "Limit daily total requests that can be made to this agent"
"requestLimitingDescription": "Limit daily total requests that can be made to this agent",
"systemPromptOverride": "Allow prompt override",
"systemPromptOverrideDescription": "Let v1 API callers replace this agent's system prompt"
},
"preview": {
"publishedPreview": "Published agents can be previewed here"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "Límite de tokens",
"tokenLimitingDescription": "Limita el total diario de tokens que puede usar este agente",
"requestLimiting": "Límite de solicitudes",
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente"
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente",
"systemPromptOverride": "Permitir sobrescribir el prompt",
"systemPromptOverrideDescription": "Permitir que los llamadores de la API v1 reemplacen el prompt del sistema de este agente"
},
"preview": {
"publishedPreview": "Los agentes publicados se pueden previsualizar aquí"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "トークン制限",
"tokenLimitingDescription": "このエージェントが使用できる1日の合計トークン数を制限します",
"requestLimiting": "リクエスト制限",
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します"
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します",
"systemPromptOverride": "プロンプトの上書きを許可",
"systemPromptOverrideDescription": "v1 API呼び出し元がこのエージェントのシステムプロンプトを置き換えることを許可します"
},
"preview": {
"publishedPreview": "公開されたエージェントはここでプレビューできます"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "Лимит токенов",
"tokenLimitingDescription": "Ограничить ежедневное общее количество токенов, которые может использовать этот агент",
"requestLimiting": "Лимит запросов",
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту"
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту",
"systemPromptOverride": "Разрешить замену промпта",
"systemPromptOverrideDescription": "Разрешить вызовам API v1 заменять системный промпт этого агента"
},
"preview": {
"publishedPreview": "Опубликованные агенты можно просмотреть здесь"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "權杖限制",
"tokenLimitingDescription": "限制此代理每天可使用的總權杖數",
"requestLimiting": "請求限制",
"requestLimitingDescription": "限制每天可向此代理發出的總請求數"
"requestLimitingDescription": "限制每天可向此代理發出的總請求數",
"systemPromptOverride": "允許覆蓋提示詞",
"systemPromptOverrideDescription": "允許 v1 API 呼叫者替換此代理的系統提示詞"
},
"preview": {
"publishedPreview": "已發佈的代理可以在此處預覽"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "令牌限制",
"tokenLimitingDescription": "限制此代理每天可使用的总令牌数",
"requestLimiting": "请求限制",
"requestLimitingDescription": "限制每天可向此代理发出的总请求数"
"requestLimitingDescription": "限制每天可向此代理发出的总请求数",
"systemPromptOverride": "允许覆盖提示词",
"systemPromptOverrideDescription": "允许 v1 API 调用者替换此代理的系统提示词"
},
"preview": {
"publishedPreview": "已发布的代理可以在此处预览"

View File

@@ -3,8 +3,8 @@ import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Spinner from '../components/Spinner';
import { useOutsideAlerter } from '../hooks';
import SkeletonLoader from '../components/SkeletonLoader';
import { useLoaderState, useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import ConfigToolModal from './ConfigToolModal';
@@ -37,7 +37,7 @@ export default function AddToolModal({
React.useState<ActiveState>('INACTIVE');
const [mcpModalState, setMcpModalState] =
React.useState<ActiveState>('INACTIVE');
const [loading, setLoading] = React.useState(false);
const [loading, setLoading] = useLoaderState(false);
useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') {
@@ -121,8 +121,8 @@ export default function AddToolModal({
</h2>
<div className="mt-5 h-[73vh] overflow-auto px-3 py-px">
{loading ? (
<div className="flex h-full items-center justify-center">
<Spinner />
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">
<SkeletonLoader component="addToolCards" count={6} />
</div>
) : (
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">

View File

@@ -10,9 +10,9 @@ import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import ThreeDotsIcon from '../assets/three-dots.svg';
import ContextMenu, { MenuOption } from '../components/ContextMenu';
import Spinner from '../components/Spinner';
import SkeletonLoader from '../components/SkeletonLoader';
import ToggleSwitch from '../components/ToggleSwitch';
import { useDarkTheme } from '../hooks';
import { useDarkTheme, useLoaderState } from '../hooks';
import AddToolModal from '../modals/AddToolModal';
import ConfirmationModal from '../modals/ConfirmationModal';
import MCPServerModal from '../modals/MCPServerModal';
@@ -33,7 +33,7 @@ export default function Tools() {
const [selectedTool, setSelectedTool] = React.useState<
UserToolType | APIToolType | null
>(null);
const [loading, setLoading] = React.useState(false);
const [loading, setLoading] = useLoaderState(false);
const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
const menuRefs = React.useRef<{
[key: string]: React.RefObject<HTMLDivElement | null>;
@@ -242,10 +242,8 @@ export default function Tools() {
</div>
<div className="border-border dark:border-border mt-5 mb-8 border-b" />
{loading ? (
<div className="grid grid-cols-2 gap-6 lg:grid-cols-3">
<div className="col-span-2 mt-24 flex h-32 items-center justify-center lg:col-span-3">
<Spinner />
</div>
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
<SkeletonLoader component="toolCards" count={6} />
</div>
) : (
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">

218
scripts/db/backfill.py Normal file
View File

@@ -0,0 +1,218 @@
"""Backfill DocsGPT's Postgres user-data tables from MongoDB.
One script for every migrated collection. Adding a new collection is a
two-step change in this file:
1. Write a ``_backfill_<name>`` function that takes keyword args
``conn``, ``mongo_db``, ``batch_size``, ``dry_run`` and returns a
stats ``dict``.
2. Add a single entry to :data:`BACKFILLERS`.
There are intentionally no per-collection CLI flags or environment
variables — ``USE_POSTGRES`` / ``READ_POSTGRES`` in ``.env`` are the
only knobs operators need. This script discovers what's available from
the :data:`BACKFILLERS` registry and runs whichever tables were asked for.
Usage::
python scripts/db/backfill.py # every registered table
python scripts/db/backfill.py --tables users # only specific tables
python scripts/db/backfill.py --dry-run # count without writing
python scripts/db/backfill.py --batch 1000 # tune commit size
Exit codes:
0 — every requested table completed successfully
1 — misconfiguration (missing env var, unknown table name)
2 — at least one table failed at runtime (others may still have succeeded)
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Any, Callable
# Make the project root importable regardless of cwd.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from sqlalchemy import Connection, text # noqa: E402
from application.core.mongo_db import MongoDB # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.engine import get_engine # noqa: E402
logger = logging.getLogger("backfill")
# ---------------------------------------------------------------------------
# Per-table backfillers
# ---------------------------------------------------------------------------
def _backfill_users(
    *,
    conn: Connection,
    mongo_db: Any,
    batch_size: int,
    dry_run: bool,
) -> dict:
    """Sync the ``users`` table from the Mongo ``users`` collection.

    Overwrites each Postgres row's ``agent_preferences`` with the Mongo
    state (Mongo is the source of truth during the cutover window).
    Missing ``pinned`` / ``shared_with_me`` keys are filled with empty
    arrays so the Postgres row always has the full shape the application
    expects.
    """
    upsert_sql = text(
        """
        INSERT INTO users (user_id, agent_preferences)
        VALUES (:user_id, CAST(:prefs AS jsonb))
        ON CONFLICT (user_id) DO UPDATE
        SET agent_preferences = EXCLUDED.agent_preferences,
            updated_at = now()
        """
    )

    stats = {"seen": 0, "written": 0, "skipped_no_user_id": 0}
    pending: list[dict] = []

    def flush() -> None:
        # Push the accumulated rows unless this is a dry run; either way
        # they count as "written" so dry-run output mirrors a real run.
        if not dry_run:
            conn.execute(upsert_sql, pending)
        stats["written"] += len(pending)
        pending.clear()

    cursor = (
        mongo_db["users"]
        .find({}, no_cursor_timeout=True)
        .batch_size(batch_size)
    )
    try:
        for doc in cursor:
            stats["seen"] += 1
            uid = doc.get("user_id")
            if not uid:
                stats["skipped_no_user_id"] += 1
                continue
            raw = doc.get("agent_preferences") or {}
            normalized = {
                "pinned": list(raw.get("pinned") or []),
                "shared_with_me": list(raw.get("shared_with_me") or []),
            }
            pending.append({"user_id": uid, "prefs": json.dumps(normalized)})
            if len(pending) >= batch_size:
                flush()
        if pending:
            flush()
    finally:
        # no_cursor_timeout cursors must be closed explicitly or they
        # linger server-side.
        cursor.close()
    return stats
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
BackfillFn = Callable[..., dict]
# Register new tables here. Order matters only in the sense that
# ``--tables`` without arguments iterates in insertion order — put tables
# with FK dependencies after the tables they reference so a full-run
# backfill doesn't hit FK errors.
BACKFILLERS: dict[str, BackfillFn] = {
"users": _backfill_users,
}
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> int:
    """Run the requested backfillers and report an aggregate exit code.

    Returns:
        0 if every requested table completed, 1 on misconfiguration
        (missing env vars, unknown table names, invalid ``--batch``),
        2 if at least one table failed at runtime (others may still
        have succeeded).
    """
    parser = argparse.ArgumentParser(
        description="Backfill DocsGPT Postgres tables from MongoDB."
    )
    parser.add_argument(
        "--tables",
        default="",
        help=(
            "Comma-separated table names to backfill. "
            f"Defaults to every registered table ({','.join(BACKFILLERS)})."
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Iterate Mongo without writing to Postgres.",
    )
    parser.add_argument(
        "--batch",
        type=int,
        default=500,
        help="How many rows to commit per Postgres statement (default: 500).",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-5s %(name)s %(message)s",
    )

    # Fix: a zero or negative --batch previously slipped through and made
    # the backfillers' flush threshold degenerate. Reject it up front as a
    # misconfiguration, consistent with the documented exit codes.
    if args.batch < 1:
        logger.error("--batch must be >= 1, got %s.", args.batch)
        return 1
    if not settings.POSTGRES_URI:
        logger.error("POSTGRES_URI is not set. Configure it in .env first.")
        return 1
    if not settings.MONGO_URI:
        logger.error("MONGO_URI is not set. Configure it in .env first.")
        return 1

    requested = [t.strip() for t in args.tables.split(",") if t.strip()]
    if not requested:
        requested = list(BACKFILLERS)
    unknown = [t for t in requested if t not in BACKFILLERS]
    if unknown:
        logger.error(
            "Unknown table(s): %s. Available: %s",
            ", ".join(unknown),
            ", ".join(BACKFILLERS),
        )
        return 1

    mongo = MongoDB.get_client()
    mongo_db = mongo[settings.MONGO_DB_NAME]
    engine = get_engine()

    failures = 0
    for table in requested:
        logger.info("backfill %s: start", table)
        try:
            # One transaction per table: a failure rolls back that table
            # only, and the loop continues with the remaining tables.
            with engine.begin() as conn:
                stats = BACKFILLERS[table](
                    conn=conn,
                    mongo_db=mongo_db,
                    batch_size=args.batch,
                    dry_run=args.dry_run,
                )
            logger.info(
                "backfill %s: done %s dry_run=%s", table, stats, args.dry_run
            )
        except Exception:
            failures += 1
            logger.exception("backfill %s: failed", table)

    return 2 if failures else 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,55 @@
"""One-shot bootstrap: run all Alembic migrations against POSTGRES_URI.
Intended use:
* local dev, after setting ``POSTGRES_URI`` in ``.env``::
python scripts/db/init_postgres.py
* CI, as a step before running the pytest suite.
* Docker image build or container start, if the operator wants the
migrations applied automatically on first boot.
This script is a thin wrapper around ``alembic upgrade head``. It exists
separately so the same command is discoverable from the repo root without
remembering the ``-c application/alembic.ini`` invocation.
"""
from __future__ import annotations
import sys
from pathlib import Path
from alembic import command
from alembic.config import Config
REPO_ROOT = Path(__file__).resolve().parents[2]
ALEMBIC_INI = REPO_ROOT / "application" / "alembic.ini"
def main() -> int:
    """Run ``alembic upgrade head`` against the repo's ``alembic.ini``.

    Returns:
        ``0`` when every pending migration applied cleanly, ``1`` when the
        ini file is missing or the upgrade raised. The non-zero code is
        propagated as the process exit status so CI jobs fail loudly.
    """
    if not ALEMBIC_INI.exists():
        print(f"alembic.ini not found at {ALEMBIC_INI}", file=sys.stderr)
        return 1

    alembic_cfg = Config(str(ALEMBIC_INI))
    # Pin `script_location` so the migration scripts resolve regardless of
    # which directory the operator invoked us from.
    alembic_cfg.set_main_option(
        "script_location", str(ALEMBIC_INI.parent / "alembic")
    )

    try:
        command.upgrade(alembic_cfg, "head")
    except Exception as exc:  # noqa: BLE001 — surface everything to the operator
        print(f"alembic upgrade failed: {exc}", file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -543,8 +543,20 @@ function Configure-TTS {
}
}
# Generate INTERNAL_KEY for worker-to-backend auth if not already present.
# Idempotent: reruns of the setup script never rotate an existing key.
function Ensure-InternalKey {
    $content = if (Test-Path $ENV_FILE) { Get-Content $ENV_FILE -Raw } else { "" }
    if ($content -notmatch "(?m)^INTERNAL_KEY=") {
        # Fix: the static RandomNumberGenerator::Fill() overload only exists
        # on .NET Core (PowerShell 7+); Create()/GetBytes() works on Windows
        # PowerShell 5.1 as well.
        $bytes = New-Object byte[] 32
        $rng = [System.Security.Cryptography.RandomNumberGenerator]::Create()
        try {
            $rng.GetBytes($bytes)
        }
        finally {
            $rng.Dispose()
        }
        # 32 random bytes -> 64 lowercase hex chars.
        $internal_key = ($bytes | ForEach-Object { $_.ToString("x2") }) -join ""
        "INTERNAL_KEY=$internal_key" | Add-Content -Path $ENV_FILE -Encoding utf8
    }
}
# Main advanced settings menu
function Prompt-AdvancedSettings {
Ensure-InternalKey
Write-Host ""
$configure_advanced = Read-Host "Would you like to configure advanced settings? (y/N)"
if ($configure_advanced -ne "y" -and $configure_advanced -ne "Y") {

View File

@@ -396,8 +396,18 @@ configure_tts() {
esac
}
# Generate INTERNAL_KEY for worker-to-backend auth if not already present.
# Idempotent: reruns of the setup script never rotate an existing key.
ensure_internal_key() {
    if ! grep -q "^INTERNAL_KEY=" "$ENV_FILE" 2>/dev/null; then
        local internal_key
        # Fix: the od fallback previously read 64 bytes (128 hex chars),
        # producing a key twice as long as the `openssl rand -hex 32` path
        # (64 hex chars). Read 32 bytes so both paths emit the same format.
        internal_key=$(openssl rand -hex 32 2>/dev/null || head -c 32 /dev/urandom | od -An -tx1 | tr -d ' \n')
        echo "INTERNAL_KEY=$internal_key" >> "$ENV_FILE"
    fi
}
# Main advanced settings menu
prompt_advanced_settings() {
ensure_internal_key
echo
read -p "$(echo -e "${DEFAULT_FG}Would you like to configure advanced settings? (y/N): ${NC}")" configure_advanced
if [[ ! "$configure_advanced" =~ ^[yY]$ ]]; then

View File

@@ -28,6 +28,7 @@ def _patch_mcp_globals(monkeypatch):
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
monkeypatch.setattr(mcp_mod, "db", mock_db)
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
monkeypatch.setattr(mcp_mod, "validate_url", lambda url: url)
@pytest.fixture

View File

@@ -18,7 +18,7 @@ class TestPostgresExecuteAction:
with pytest.raises(ValueError, match="Unknown action"):
tool.execute_action("invalid")
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_select_query(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
@@ -37,7 +37,7 @@ class TestPostgresExecuteAction:
assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_insert_query(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
@@ -55,11 +55,11 @@ class TestPostgresExecuteAction:
mock_conn.commit.assert_called_once()
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_db_error(self, mock_connect, tool):
import psycopg2
import psycopg
mock_connect.side_effect = psycopg2.Error("connection refused")
mock_connect.side_effect = psycopg.Error("connection refused")
result = tool.execute_action(
"postgres_execute_sql", sql_query="SELECT 1"
@@ -68,7 +68,7 @@ class TestPostgresExecuteAction:
assert result["status_code"] == 500
assert "Database error" in result["error"]
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_get_schema(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()
@@ -89,24 +89,24 @@ class TestPostgresExecuteAction:
assert result["schema"]["users"][0]["column_name"] == "id"
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_get_schema_db_error(self, mock_connect, tool):
import psycopg2
import psycopg
mock_connect.side_effect = psycopg2.Error("auth failed")
mock_connect.side_effect = psycopg.Error("auth failed")
result = tool.execute_action("postgres_get_schema", db_name="testdb")
assert result["status_code"] == 500
assert "Database error" in result["error"]
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_connection_closed_on_error(self, mock_connect, tool):
import psycopg2
import psycopg
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.execute.side_effect = psycopg2.Error("syntax error")
mock_cur.execute.side_effect = psycopg.Error("syntax error")
mock_conn.cursor.return_value = mock_cur
mock_connect.return_value = mock_conn
@@ -114,7 +114,7 @@ class TestPostgresExecuteAction:
mock_conn.close.assert_called_once()
@patch("application.agents.tools.postgres.psycopg2.connect")
@patch("application.agents.tools.postgres.psycopg.connect")
def test_select_with_no_description(self, mock_connect, tool):
mock_conn = MagicMock()
mock_cur = MagicMock()

View File

@@ -33,6 +33,8 @@ def _patch_mcp_globals(monkeypatch):
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
monkeypatch.setattr(mcp_mod, "db", mock_db)
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
# Bypass DNS-resolving URL validation for tests using fake hostnames.
monkeypatch.setattr(mcp_mod, "validate_url", lambda u, **kw: u)
@pytest.fixture
@@ -136,6 +138,47 @@ class TestMCPToolInit:
})
assert tool.custom_headers == {"X-Custom": "val"}
def test_rejects_metadata_ip(self, monkeypatch):
from application.agents.tools.mcp_tool import MCPTool
from application.core.url_validation import validate_url as real_validate_url
import application.agents.tools.mcp_tool as mcp_mod
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
with pytest.raises(ValueError, match="Invalid MCP server URL"):
MCPTool(config={"server_url": "http://169.254.169.254/latest/meta-data", "auth_type": "none"})
def test_rejects_localhost(self, monkeypatch):
from application.agents.tools.mcp_tool import MCPTool
from application.core.url_validation import validate_url as real_validate_url
import application.agents.tools.mcp_tool as mcp_mod
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
with pytest.raises(ValueError, match="Invalid MCP server URL"):
MCPTool(config={"server_url": "http://localhost:8080/mcp", "auth_type": "none"})
def test_rejects_private_ip(self, monkeypatch):
from application.agents.tools.mcp_tool import MCPTool
from application.core.url_validation import validate_url as real_validate_url
import application.agents.tools.mcp_tool as mcp_mod
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
with pytest.raises(ValueError, match="Invalid MCP server URL"):
MCPTool(config={"server_url": "http://10.0.0.1/mcp", "auth_type": "none"})
def test_accepts_public_url(self):
tool = _make_tool({
"server_url": "https://mcp.example.com/api",
"auth_type": "none",
})
assert tool.server_url == "https://mcp.example.com/api"
def test_empty_server_url_allowed(self):
from application.agents.tools.mcp_tool import MCPTool
with patch.object(MCPTool, "_setup_client"):
tool = MCPTool(config={"server_url": "", "auth_type": "none"})
assert tool.server_url == ""
# =====================================================================
# Redirect URI Resolution

View File

@@ -330,6 +330,170 @@ class TestStreamProcessorDocPrefetch:
assert docs_together is not None
assert "Agent doc content" in docs_together
def test_configure_source_treats_default_string_as_no_docs(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
db = mock_mongo_db[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "agent_default_source_key",
"user": "user_123",
"prompt_id": "default",
"agent_type": "classic",
"source": "default",
}
)
processor = StreamProcessor(
{"question": "Hi", "api_key": "agent_default_source_key"},
None,
)
processor._configure_agent()
processor._configure_source()
assert processor.source == {}
assert processor.all_sources == []
def test_prefetch_skipped_when_no_active_docs(self, mock_mongo_db):
from unittest.mock import MagicMock
from application.api.answer.services.stream_processor import StreamProcessor
processor = StreamProcessor(
{"question": "Hi there"},
{"sub": "user_123"},
)
processor.initialize()
processor.create_retriever = MagicMock()
docs_together, docs = processor.pre_fetch_docs("Hi there")
processor.create_retriever.assert_not_called()
assert docs_together is None
assert docs is None
def test_prefetch_skipped_when_active_docs_is_default(self, mock_mongo_db):
from unittest.mock import MagicMock
from application.api.answer.services.stream_processor import StreamProcessor
processor = StreamProcessor(
{"question": "Hi", "active_docs": "default"},
{"sub": "user_123"},
)
processor.initialize()
processor.create_retriever = MagicMock()
docs_together, docs = processor.pre_fetch_docs("Hi")
processor.create_retriever.assert_not_called()
assert docs_together is None
assert docs is None
def test_agent_retriever_and_chunks_propagate_to_retriever_config(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
db = mock_mongo_db[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
source_id = ObjectId()
db["sources"].insert_one(
{"_id": source_id, "name": "src", "retriever": "hybrid", "chunks": 5}
)
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "agent_ret_key",
"user": "user_123",
"prompt_id": "default",
"agent_type": "classic",
"retriever": "hybrid",
"chunks": 5,
"source": DBRef("sources", source_id),
}
)
processor = StreamProcessor(
{"question": "Test", "api_key": "agent_ret_key"},
None,
)
processor.initialize()
assert processor.retriever_config["retriever_name"] == "hybrid"
assert processor.retriever_config["chunks"] == 5
def test_request_retriever_and_chunks_override_agent_config(self, mock_mongo_db):
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
db = mock_mongo_db[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "agent_override_key",
"user": "user_123",
"prompt_id": "default",
"agent_type": "classic",
"retriever": "hybrid",
"chunks": 5,
}
)
processor = StreamProcessor(
{
"question": "Test",
"api_key": "agent_override_key",
"retriever": "classic",
"chunks": 7,
},
None,
)
processor.initialize()
assert processor.retriever_config["retriever_name"] == "classic"
assert processor.retriever_config["chunks"] == 7
def test_agent_data_fetched_once_per_request(self, mock_mongo_db):
from unittest.mock import patch
from application.api.answer.services.stream_processor import StreamProcessor
from application.core.settings import settings
db = mock_mongo_db[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
agent_id = ObjectId()
agents_collection.insert_one(
{
"_id": agent_id,
"key": "agent_once_key",
"user": "user_123",
"prompt_id": "default",
"agent_type": "classic",
}
)
processor = StreamProcessor(
{"question": "Test", "api_key": "agent_once_key"},
None,
)
with patch.object(
processor, "_get_data_from_api_key", wraps=processor._get_data_from_api_key
) as spy:
processor.initialize()
assert spy.call_count == 1
@pytest.mark.unit
class TestStreamProcessorAttachments:

View File

@@ -566,14 +566,13 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {
sp._agent_data = {
"sources": [
{"id": "src1", "retriever": "classic"},
{"id": "src2", "retriever": "hybrid"},
],
"source": None,
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {"active_docs": ["src1", "src2"]}
assert len(sp.all_sources) == 2
@@ -593,12 +592,11 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {
sp._agent_data = {
"sources": [],
"source": "single_src",
"retriever": "classic",
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {"active_docs": "single_src"}
assert len(sp.all_sources) == 1
@@ -618,8 +616,7 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {"sources": [], "source": None}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._agent_data = {"sources": [], "source": None}
sp._configure_source()
assert sp.source == {}
assert sp.all_sources == []
@@ -639,11 +636,10 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = "agent_key_123"
agent_data = {
sp._agent_data = {
"sources": [{"id": "s1", "retriever": "classic"}],
"source": None,
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {"active_docs": ["s1"]}
@@ -662,11 +658,10 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {
sp._agent_data = {
"sources": [{"id": None}, {"retriever": "classic"}],
"source": None,
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {}
@@ -1189,6 +1184,8 @@ class TestConfigureAgent:
"chunks": "5",
})
sp._configure_agent()
sp.model_id = "test-model"
sp._configure_retriever()
assert sp.agent_config["workflow"] == "wf_123"
assert sp.agent_config["workflow_owner"] == "user1"
assert sp.retriever_config["retriever_name"] == "hybrid"
@@ -1211,6 +1208,8 @@ class TestConfigureAgent:
"chunks": "not_a_number",
})
sp._configure_agent()
sp.model_id = "test-model"
sp._configure_retriever()
assert sp.retriever_config["chunks"] == 2
@@ -1763,8 +1762,8 @@ class TestConfigureAgentAdditionalPaths:
assert sp.decoded_token == {"sub": "owner_user"}
@pytest.mark.unit
def test_configure_agent_with_source_in_data_key(self):
"""Cover line 463-464: data_key has 'source' set."""
def test_configure_source_picks_up_cached_agent_data(self):
"""After _configure_agent caches _agent_data, _configure_source uses it."""
sp = self._make_sp()
sp._resolve_agent_id = MagicMock(return_value="agent_id_1")
sp._get_agent_key = MagicMock(return_value=("agent_key", False, None))
@@ -1780,6 +1779,7 @@ class TestConfigureAgentAdditionalPaths:
"source": "my_source",
})
sp._configure_agent()
sp._configure_source()
assert sp.source == {"active_docs": "my_source"}
@pytest.mark.unit
@@ -2067,7 +2067,7 @@ class TestPreFetchDocsFullPaths:
"chunks": 2,
"doc_token_limit": 50000,
}
sp.source = {}
sp.source = {"active_docs": ["src1"]}
sp.model_id = "test-model"
sp.agent_id = None
return sp

View File

@@ -48,7 +48,7 @@ def internal_app(monkeypatch, mock_mongo_db):
@pytest.mark.unit
class TestVerifyInternalKey:
def test_no_internal_key_configured_allows_access(
def test_no_internal_key_configured_rejects_access(
self, internal_app, monkeypatch
):
app, db = internal_app
@@ -63,9 +63,8 @@ class TestVerifyInternalKey:
),
)
with app.test_client() as client:
# download will fail for missing file but should not be 401
resp = client.get("/api/download?user=u&name=n&file=f")
assert resp.status_code != 401
assert resp.status_code == 401
def test_missing_key_returns_401(self, internal_app, monkeypatch):
app, db = internal_app
@@ -131,9 +130,12 @@ class TestVerifyInternalKey:
@pytest.mark.unit
class TestUploadIndex:
_TEST_INTERNAL_KEY = "test-internal-key"
_AUTH_HEADERS = {"X-Internal-Key": "test-internal-key"}
def _make_settings(self, vector_store="faiss"):
return MagicMock(
INTERNAL_KEY=None,
INTERNAL_KEY=self._TEST_INTERNAL_KEY,
UPLOAD_FOLDER="uploads",
VECTOR_STORE=vector_store,
EMBEDDINGS_NAME="test_embeddings",
@@ -146,7 +148,7 @@ class TestUploadIndex:
"application.api.internal.routes.settings", self._make_settings()
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={})
resp = client.post("/api/upload_index", data={}, headers=self._AUTH_HEADERS)
assert resp.json["status"] == "no user"
def test_missing_name_returns_no_name(self, internal_app, monkeypatch):
@@ -155,7 +157,7 @@ class TestUploadIndex:
"application.api.internal.routes.settings", self._make_settings()
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={"user": "testuser"})
resp = client.post("/api/upload_index", data={"user": "testuser"}, headers=self._AUTH_HEADERS)
assert resp.json["status"] == "no name"
def test_creates_new_source_entry(self, internal_app, monkeypatch):
@@ -182,6 +184,7 @@ class TestUploadIndex:
"id": doc_id,
"type": "local",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -219,6 +222,7 @@ class TestUploadIndex:
"id": str(doc_id),
"type": "remote",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -252,6 +256,7 @@ class TestUploadIndex:
"type": "local",
"directory_structure": json.dumps(dir_struct),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -285,6 +290,7 @@ class TestUploadIndex:
"type": "local",
"directory_structure": "not valid json",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -317,6 +323,7 @@ class TestUploadIndex:
"type": "local",
"file_name_map": json.dumps(fmap),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -349,6 +356,7 @@ class TestUploadIndex:
"id": doc_id,
"type": "local",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file"
@@ -379,6 +387,7 @@ class TestUploadIndex:
"type": "local",
"file_faiss": (io.BytesIO(b""), ""),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file name"
@@ -408,6 +417,7 @@ class TestUploadIndex:
"remote_data": '{"url":"http://example.com"}',
"sync_frequency": "daily",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -443,6 +453,7 @@ class TestUploadIndex:
"file_pkl": (io.BytesIO(b"pkl data"), "index.pkl"),
},
content_type="multipart/form-data",
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -477,6 +488,7 @@ class TestUploadIndex:
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
},
content_type="multipart/form-data",
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file"
@@ -508,9 +520,27 @@ class TestUploadIndex:
"file_pkl": (io.BytesIO(b""), ""),
},
content_type="multipart/form-data",
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file name"
def test_no_internal_key_rejects_upload(self, internal_app, monkeypatch):
"""Verify that upload_index is rejected when INTERNAL_KEY is not set."""
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings",
MagicMock(
INTERNAL_KEY=None,
UPLOAD_FOLDER="uploads",
VECTOR_STORE="faiss",
EMBEDDINGS_NAME="test",
MONGO_DB_NAME="docsgpt",
),
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={"user": "attacker"})
assert resp.status_code == 401
def test_update_existing_with_file_name_map(self, internal_app, monkeypatch):
"""Cover line 124: update existing entry with file_name_map."""
app, db = internal_app
@@ -540,6 +570,7 @@ class TestUploadIndex:
"type": "local",
"file_name_map": json.dumps(fmap),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -572,6 +603,7 @@ class TestUploadIndex:
"type": "local",
"file_name_map": "not valid json{{{",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"

View File

@@ -14,6 +14,13 @@ def app():
return app
@pytest.fixture(autouse=True)
def _bypass_url_validation():
"""Bypass SSRF URL validation so tests using localhost URLs can proceed."""
with patch("application.api.user.tools.mcp.validate_url"):
yield
# ---------------------------------------------------------------------------
# Helper: _sanitize_mcp_transport
# ---------------------------------------------------------------------------

View File

@@ -395,6 +395,9 @@ class TestAvailableTools:
"application.api.user.tools.routes.tool_manager", mock_manager
):
with app.test_request_context("/api/available_tools"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AvailableTools().get()
assert response.status_code == 200
@@ -419,6 +422,9 @@ class TestAvailableTools:
"application.api.user.tools.routes.tool_manager", mock_manager
):
with app.test_request_context("/api/available_tools"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AvailableTools().get()
assert response.status_code == 400
@@ -438,6 +444,9 @@ class TestAvailableTools:
"application.api.user.tools.routes.tool_manager", mock_manager
):
with app.test_request_context("/api/available_tools"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AvailableTools().get()
assert response.status_code == 200

0
tests/api/v1/__init__.py Normal file
View File

View File

@@ -0,0 +1,64 @@
from flask import Flask
from application.api.v1.routes import v1_bp
class _FakeCollection:
def __init__(self, docs):
self.docs = docs
def find_one(self, query):
for doc in self.docs:
if all(doc.get(k) == v for k, v in query.items()):
return doc
return None
def find(self, query):
return [doc for doc in self.docs if all(doc.get(k) == v for k, v in query.items())]
def _build_app():
app = Flask(__name__)
app.register_blueprint(v1_bp)
return app
def test_v1_models_does_not_expose_agent_keys(monkeypatch):
docs = [
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
{"_id": "agent-2", "key": "key-2", "user": "user-1", "name": "Agent Two"},
]
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
app = _build_app()
client = app.test_client()
response = client.get("/v1/models", headers={"Authorization": "Bearer key-1"})
assert response.status_code == 200
payload = response.get_json()
assert payload["object"] == "list"
assert len(payload["data"]) == 2
assert payload["data"][0]["id"] == "agent-1"
assert payload["data"][1]["id"] == "agent-2"
# Keys must never appear as model IDs
assert all(model["id"] != "key-1" for model in payload["data"])
assert all(model["id"] != "key-2" for model in payload["data"])
def test_v1_models_invalid_key_returns_401(monkeypatch):
docs = [
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
]
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
app = _build_app()
client = app.test_client()
response = client.get("/v1/models", headers={"Authorization": "Bearer wrong-key"})
assert response.status_code == 401

144
tests/core/test_db_uri.py Normal file
View File

@@ -0,0 +1,144 @@
"""Tests for ``application.core.db_uri``.
DocsGPT has two Postgres connection strings — ``POSTGRES_URI`` (consumed
by SQLAlchemy) and ``PGVECTOR_CONNECTION_STRING`` (consumed by
``psycopg.connect()`` directly). They need opposite normalization
because SQLAlchemy requires a ``postgresql+psycopg://`` dialect prefix
and libpq rejects it. Each field has its own normalizer so operators
can write whichever form feels natural and cross-pollination between
the two fields is forgiven.
The normalizers live in ``application.core.db_uri`` as plain functions
so these tests can exercise them directly without having to instantiate
``Settings`` (which would pull in ``.env`` file side effects).
"""
from __future__ import annotations
import pytest
from application.core.db_uri import (
normalize_pgvector_connection_string,
normalize_postgres_uri,
)
@pytest.mark.unit
class TestNormalizePostgresUri:
@pytest.mark.parametrize(
"input_value,expected",
[
# User-friendly forms get rewritten to the SQLAlchemy dialect.
(
"postgres://u:p@h:5432/d",
"postgresql+psycopg://u:p@h:5432/d",
),
(
"postgresql://u:p@h:5432/d",
"postgresql+psycopg://u:p@h:5432/d",
),
# Legacy psycopg2 dialect is silently upgraded — psycopg2 is
# no longer in requirements.txt, so there's no way it can work
# as-is, and rewriting is friendlier than failing.
(
"postgresql+psycopg2://u:p@h:5432/d",
"postgresql+psycopg://u:p@h:5432/d",
),
# Already-correct dialect passes through unchanged.
(
"postgresql+psycopg://u:p@h:5432/d",
"postgresql+psycopg://u:p@h:5432/d",
),
# Whitespace is trimmed before rewriting.
(
" postgres://u:p@h/d ",
"postgresql+psycopg://u:p@h/d",
),
# Query-string params (sslmode, options) are preserved verbatim.
(
"postgresql://u:p@h:5432/d?sslmode=require&application_name=docsgpt",
"postgresql+psycopg://u:p@h:5432/d?sslmode=require&application_name=docsgpt",
),
],
)
def test_rewrites_common_forms_to_psycopg_dialect(self, input_value, expected):
assert normalize_postgres_uri(input_value) == expected
@pytest.mark.parametrize(
"input_value",
[None, "", " ", "None", "none"],
)
def test_empty_or_none_like_returns_none(self, input_value):
assert normalize_postgres_uri(input_value) is None
def test_unknown_scheme_passes_through(self):
"""A dialect we don't recognise is left alone so SQLAlchemy can
produce its own error message when the engine tries to connect.
Better than silently eating the config."""
weird = "postgresql+asyncpg://u:p@h/d"
assert normalize_postgres_uri(weird) == weird
def test_non_string_input_passes_through(self):
"""Non-string inputs (e.g. if pydantic ever passes an int) shouldn't
crash the normalizer — let pydantic's own type validation handle it."""
assert normalize_postgres_uri(42) == 42 # type: ignore[arg-type]
@pytest.mark.unit
class TestNormalizePgvectorConnectionString:
"""Symmetric to the POSTGRES_URI normalizer but pulls in the OPPOSITE
direction: strips the SQLAlchemy dialect prefix so libpq accepts it.
"""
@pytest.mark.parametrize(
"input_value,expected",
[
# User-friendly forms pass through — libpq accepts them natively.
(
"postgres://u:p@h:5432/d",
"postgres://u:p@h:5432/d",
),
(
"postgresql://u:p@h:5432/d",
"postgresql://u:p@h:5432/d",
),
# SQLAlchemy dialect prefixes get stripped so libpq accepts them.
# Operators hit this when they copy POSTGRES_URI → PGVECTOR_CONNECTION_STRING.
(
"postgresql+psycopg://u:p@h:5432/d",
"postgresql://u:p@h:5432/d",
),
(
"postgresql+psycopg2://u:p@h:5432/d",
"postgresql://u:p@h:5432/d",
),
# Whitespace is trimmed before rewriting.
(
" postgresql+psycopg://u:p@h/d ",
"postgresql://u:p@h/d",
),
# Query-string params (sslmode, etc.) are preserved verbatim.
(
"postgresql+psycopg://u:p@h:5432/d?sslmode=require",
"postgresql://u:p@h:5432/d?sslmode=require",
),
],
)
def test_rewrites_dialect_forms_to_libpq_compatible(self, input_value, expected):
assert normalize_pgvector_connection_string(input_value) == expected
@pytest.mark.parametrize(
"input_value",
[None, "", " ", "None", "none"],
)
def test_empty_or_none_like_returns_none(self, input_value):
assert normalize_pgvector_connection_string(input_value) is None
def test_unknown_scheme_passes_through(self):
"""A scheme we don't recognise is left alone so libpq can produce
its own error message when the connection is attempted."""
weird = "mysql://u:p@h/d"
assert normalize_pgvector_connection_string(weird) == weird
def test_non_string_input_passes_through(self):
assert normalize_pgvector_connection_string(42) == 42 # type: ignore[arg-type]

View File

@@ -0,0 +1,84 @@
"""Fixtures for integration tests that hit a live Postgres.
These tests are separate from unit tests in two ways:
1. **Directory**: They live under ``tests/integration/``, which is
ignored by the default pytest run via ``--ignore=tests/integration``
in ``pytest.ini``. The CI ``pytest.yml`` workflow therefore skips
them automatically — it runs the fast unit suite only.
2. **Marker**: Each test is marked ``@pytest.mark.integration`` so it
can be selected (or excluded) independently of its directory with
``pytest -m integration`` / ``-m "not integration"``.
Running the Postgres-backed integration tests manually::
.venv/bin/python -m pytest tests/integration/test_users_repository.py \\
--override-ini="addopts=" --no-cov
(The ``--override-ini="addopts="`` is needed because ``pytest.ini``
contains ``--ignore=tests/integration`` in ``addopts``; without the
override pytest would still skip the directory even though you named
it on the command line.)
Tests are skipped automatically if ``POSTGRES_URI`` is unset, so a
contributor who hasn't set up a local Postgres gets clean skips
instead of red tests.
"""
from __future__ import annotations
import pytest
from sqlalchemy import Engine, create_engine, text
from application.core.settings import settings
@pytest.fixture(scope="session")
def pg_engine() -> Engine:
"""Session-scoped SQLAlchemy engine for the Postgres integration DB.
Skips all Postgres-backed tests if ``POSTGRES_URI`` is unset. This
keeps CI and contributor machines without a local Postgres from
erroring out — integration tests that require the DB become
no-ops rather than failures.
"""
if not settings.POSTGRES_URI:
pytest.skip("POSTGRES_URI not set — skipping Postgres integration tests")
engine = create_engine(settings.POSTGRES_URI, future=True, pool_pre_ping=True)
try:
yield engine
finally:
engine.dispose()
@pytest.fixture
def pg_conn(pg_engine: Engine):
"""Per-test Postgres connection wrapped in a rolled-back transaction.
Uses SQLAlchemy's explicit outer-transaction pattern so every test
sees a pristine DB view without having to truncate tables. Any
nested ``begin()`` inside the repository code becomes a SAVEPOINT
under the hood.
"""
conn = pg_engine.connect()
outer = conn.begin()
try:
yield conn
finally:
outer.rollback()
conn.close()
@pytest.fixture
def pg_clean_users(pg_conn):
"""Guarantee a clean ``users`` table view for tests that need it.
The outer transaction rollback handles cleanup, but if a previous
interrupted run left rows committed, this fixture removes them
inside the transaction scope so they are invisible to the test.
``DELETE`` rather than ``TRUNCATE`` because ``TRUNCATE`` in Postgres
cannot be rolled back within a nested transaction the way
``DELETE`` can.
"""
pg_conn.execute(text("DELETE FROM users"))
return pg_conn

View File

@@ -0,0 +1,179 @@
"""Integration tests for ``UsersRepository`` against a live Postgres.
These tests need:
* A running Postgres reachable via ``POSTGRES_URI``
* Alembic migration ``0001_initial`` applied
They are skipped automatically by the default ``pytest`` run because
``pytest.ini`` has ``--ignore=tests/integration`` in ``addopts``. They
are additionally marked ``@pytest.mark.integration`` so they can be
selected/excluded explicitly. Run them locally with::
.venv/bin/python -m pytest tests/integration/test_users_repository.py \\
--override-ini="addopts=" --no-cov
Covers every operation the legacy Mongo code performs on
``users_collection``: upsert + get + add/remove on the pinned and
shared_with_me JSONB arrays, plus explicit tenant-isolation checks
(a user's operations must never touch another user's data).
"""
from __future__ import annotations
import pytest
from application.storage.db.repositories.users import UsersRepository
@pytest.fixture
def repo(pg_clean_users):
return UsersRepository(pg_clean_users)
@pytest.mark.integration
class TestUpsert:
def test_upsert_creates_new_user_with_default_preferences(self, repo):
doc = repo.upsert("alice@example.com")
assert doc["user_id"] == "alice@example.com"
assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}
assert "id" in doc
assert "_id" in doc # Mongo-compat alias
def test_upsert_is_idempotent(self, repo):
first = repo.upsert("alice@example.com")
second = repo.upsert("alice@example.com")
assert first["id"] == second["id"]
assert first["agent_preferences"] == second["agent_preferences"]
def test_upsert_preserves_existing_preferences(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-1")
doc = repo.upsert("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
@pytest.mark.integration
class TestGet:
def test_get_returns_none_for_missing_user(self, repo):
assert repo.get("nobody@example.com") is None
def test_get_returns_dict_for_existing_user(self, repo):
repo.upsert("alice@example.com")
doc = repo.get("alice@example.com")
assert doc is not None
assert doc["user_id"] == "alice@example.com"
assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}
@pytest.mark.integration
class TestAddPinned:
def test_add_pinned_to_new_user_creates_user(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-1")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
def test_add_pinned_is_idempotent(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-1")
repo.add_pinned("alice@example.com", "agent-1")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
def test_add_pinned_preserves_order(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-1")
repo.add_pinned("alice@example.com", "agent-2")
repo.add_pinned("alice@example.com", "agent-3")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-1", "agent-2", "agent-3"]
@pytest.mark.integration
class TestRemovePinned:
def test_remove_pinned_single(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-1")
repo.add_pinned("alice@example.com", "agent-2")
repo.remove_pinned("alice@example.com", "agent-1")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-2"]
def test_remove_pinned_missing_is_noop(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-1")
repo.remove_pinned("alice@example.com", "agent-999")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
def test_remove_pinned_bulk(self, repo):
repo.upsert("alice@example.com")
for i in range(5):
repo.add_pinned("alice@example.com", f"agent-{i}")
repo.remove_pinned_bulk("alice@example.com", ["agent-1", "agent-3", "agent-999"])
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-0", "agent-2", "agent-4"]
@pytest.mark.integration
class TestSharedWithMe:
def test_add_shared_is_idempotent(self, repo):
repo.upsert("alice@example.com")
repo.add_shared("alice@example.com", "agent-x")
repo.add_shared("alice@example.com", "agent-x")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["shared_with_me"] == ["agent-x"]
def test_remove_shared_bulk(self, repo):
repo.upsert("alice@example.com")
for i in range(3):
repo.add_shared("alice@example.com", f"shared-{i}")
repo.remove_shared_bulk("alice@example.com", ["shared-0", "shared-2"])
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["shared_with_me"] == ["shared-1"]
@pytest.mark.integration
class TestRemoveAgentFromAll:
def test_removes_from_both_pinned_and_shared(self, repo):
repo.upsert("alice@example.com")
repo.add_pinned("alice@example.com", "agent-X")
repo.add_pinned("alice@example.com", "agent-keep")
repo.add_shared("alice@example.com", "agent-X")
repo.add_shared("alice@example.com", "agent-keep-2")
repo.remove_agent_from_all("alice@example.com", "agent-X")
doc = repo.get("alice@example.com")
assert doc["agent_preferences"]["pinned"] == ["agent-keep"]
assert doc["agent_preferences"]["shared_with_me"] == ["agent-keep-2"]
@pytest.mark.integration
class TestTenantIsolation:
"""Security-critical: operations on one user must never touch another's data."""
def test_add_pinned_does_not_leak_across_users(self, repo):
repo.upsert("alice@example.com")
repo.upsert("bob@example.com")
repo.add_pinned("alice@example.com", "agent-a")
repo.add_pinned("bob@example.com", "agent-b")
alice = repo.get("alice@example.com")
bob = repo.get("bob@example.com")
assert alice["agent_preferences"]["pinned"] == ["agent-a"]
assert bob["agent_preferences"]["pinned"] == ["agent-b"]
def test_remove_does_not_leak_across_users(self, repo):
repo.upsert("alice@example.com")
repo.upsert("bob@example.com")
repo.add_pinned("alice@example.com", "shared-agent-id")
repo.add_pinned("bob@example.com", "shared-agent-id")
repo.remove_pinned("alice@example.com", "shared-agent-id")
assert repo.get("alice@example.com")["agent_preferences"]["pinned"] == []
assert repo.get("bob@example.com")["agent_preferences"]["pinned"] == [
"shared-agent-id"
]

View File

@@ -20,133 +20,49 @@ def test_epub_init_parser():
assert parser.parser_config_set
def test_epub_parser_ebooklib_import_error(epub_parser):
"""Test that ImportError is raised when ebooklib is not available."""
with patch.dict(sys.modules, {"ebooklib": None}):
with pytest.raises(ValueError, match="`EbookLib` is required to read Epub files"):
def test_epub_parser_fast_ebook_import_error(epub_parser):
"""Test that ImportError is raised when fast-ebook is not available."""
with patch.dict(sys.modules, {"fast_ebook": None}):
with pytest.raises(ValueError, match="`fast-ebook` is required to read Epub files"):
epub_parser.parse_file(Path("test.epub"))
def test_epub_parser_html2text_import_error(epub_parser):
"""Test that ImportError is raised when html2text is not available."""
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_ebooklib.epub = fake_epub
with patch.dict(sys.modules, {"ebooklib": fake_ebooklib, "ebooklib.epub": fake_epub}):
with patch.dict(sys.modules, {"html2text": None}):
with pytest.raises(ValueError, match="`html2text` is required to parse Epub files"):
epub_parser.parse_file(Path("test.epub"))
def test_epub_parser_successful_parsing(epub_parser):
"""Test successful parsing of an epub file."""
fake_fast_ebook = types.ModuleType("fast_ebook")
fake_epub = types.ModuleType("fast_ebook.epub")
fake_fast_ebook.epub = fake_epub
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_html2text = types.ModuleType("html2text")
# Mock ebooklib constants
fake_ebooklib.ITEM_DOCUMENT = "document"
fake_ebooklib.epub = fake_epub
mock_item1 = MagicMock()
mock_item1.get_type.return_value = "document"
mock_item1.get_content.return_value = b"<h1>Chapter 1</h1><p>Content 1</p>"
mock_item2 = MagicMock()
mock_item2.get_type.return_value = "document"
mock_item2.get_content.return_value = b"<h1>Chapter 2</h1><p>Content 2</p>"
mock_item3 = MagicMock()
mock_item3.get_type.return_value = "other" # Should be ignored
mock_item3.get_content.return_value = b"<p>Other content</p>"
mock_book = MagicMock()
mock_book.get_items.return_value = [mock_item1, mock_item2, mock_item3]
mock_book.to_markdown.return_value = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
fake_epub.read_epub = MagicMock(return_value=mock_book)
def mock_html2text_func(html_content):
if "Chapter 1" in html_content:
return "# Chapter 1\n\nContent 1\n"
elif "Chapter 2" in html_content:
return "# Chapter 2\n\nContent 2\n"
return "Other content\n"
fake_html2text.html2text = mock_html2text_func
with patch.dict(sys.modules, {
"ebooklib": fake_ebooklib,
"ebooklib.epub": fake_epub,
"html2text": fake_html2text
"fast_ebook": fake_fast_ebook,
"fast_ebook.epub": fake_epub,
}):
result = epub_parser.parse_file(Path("test.epub"))
expected_result = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
assert result == expected_result
# Verify epub.read_epub was called with correct parameters
fake_epub.read_epub.assert_called_once_with(Path("test.epub"), options={"ignore_ncx": True})
assert result == "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
fake_epub.read_epub.assert_called_once_with(Path("test.epub"))
def test_epub_parser_empty_book(epub_parser):
"""Test parsing an epub file with no document items."""
# Create mock modules
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_html2text = types.ModuleType("html2text")
fake_ebooklib.ITEM_DOCUMENT = "document"
fake_ebooklib.epub = fake_epub
# Create mock book with no document items
"""Test parsing an epub file with no content."""
fake_fast_ebook = types.ModuleType("fast_ebook")
fake_epub = types.ModuleType("fast_ebook.epub")
fake_fast_ebook.epub = fake_epub
mock_book = MagicMock()
mock_book.get_items.return_value = []
mock_book.to_markdown.return_value = ""
fake_epub.read_epub = MagicMock(return_value=mock_book)
fake_html2text.html2text = MagicMock()
with patch.dict(sys.modules, {
"ebooklib": fake_ebooklib,
"ebooklib.epub": fake_epub,
"html2text": fake_html2text
"fast_ebook": fake_fast_ebook,
"fast_ebook.epub": fake_epub,
}):
result = epub_parser.parse_file(Path("empty.epub"))
assert result == ""
fake_html2text.html2text.assert_not_called()
def test_epub_parser_non_document_items_ignored(epub_parser):
"""Test that non-document items are ignored during parsing."""
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_html2text = types.ModuleType("html2text")
fake_ebooklib.ITEM_DOCUMENT = "document"
fake_ebooklib.epub = fake_epub
mock_doc_item = MagicMock()
mock_doc_item.get_type.return_value = "document"
mock_doc_item.get_content.return_value = b"<p>Document content</p>"
mock_other_item = MagicMock()
mock_other_item.get_type.return_value = "image" # Not a document
mock_book = MagicMock()
mock_book.get_items.return_value = [mock_other_item, mock_doc_item]
fake_epub.read_epub = MagicMock(return_value=mock_book)
fake_html2text.html2text = MagicMock(return_value="Document content\n")
with patch.dict(sys.modules, {
"ebooklib": fake_ebooklib,
"ebooklib.epub": fake_epub,
"html2text": fake_html2text
}):
result = epub_parser.parse_file(Path("test.epub"))
assert result == "Document content\n"
fake_html2text.html2text.assert_called_once_with("<p>Document content</p>")

View File

@@ -8,7 +8,7 @@ from application.storage.local import LocalStorage
@pytest.fixture
def temp_base_dir():
return "/tmp/test_storage"
return os.path.realpath("/tmp/test_storage")
@pytest.fixture
@@ -30,12 +30,12 @@ class TestLocalStorageInitialization:
def test_get_full_path_with_relative_path(self, local_storage):
result = local_storage._get_full_path("documents/test.txt")
expected = os.path.join("/tmp/test_storage", "documents/test.txt")
assert os.path.normpath(result) == os.path.normpath(expected)
expected = os.path.realpath(os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt"))
assert result == expected
def test_get_full_path_with_absolute_path(self, local_storage):
result = local_storage._get_full_path("/absolute/path/test.txt")
assert result == "/absolute/path/test.txt"
def test_get_full_path_with_absolute_path_outside_base_raises(self, local_storage):
with pytest.raises(ValueError, match="Path traversal detected"):
local_storage._get_full_path("/absolute/path/test.txt")
@patch("os.makedirs")
@patch("builtins.open", new_callable=mock_open)
@@ -48,8 +48,8 @@ class TestLocalStorageInitialization:
result = local_storage.save_file(file_data, path)
expected_dir = os.path.join("/tmp/test_storage", "documents")
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert mock_makedirs.call_count == 1
assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath(
@@ -74,25 +74,19 @@ class TestLocalStorageInitialization:
result = local_storage.save_file(file_data, path)
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert file_data.save.call_count == 1
assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath(
expected_file
)
assert result == {"storage_type": "local"}
@patch("os.makedirs")
@patch("builtins.open", new_callable=mock_open)
def test_save_file_with_absolute_path(
self, mock_file, mock_makedirs, local_storage
):
def test_save_file_with_absolute_path_outside_base_raises(self, local_storage):
file_data = io.BytesIO(b"test content")
path = "/absolute/path/test.txt"
local_storage.save_file(file_data, path)
mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True)
mock_file.assert_called_once_with("/absolute/path/test.txt", "wb")
with pytest.raises(ValueError, match="Path traversal detected"):
local_storage.save_file(file_data, path)
@pytest.mark.unit
@@ -105,7 +99,7 @@ class TestLocalStorageGetFile:
result = local_storage.get_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
@@ -122,7 +116,7 @@ class TestLocalStorageGetFile:
with pytest.raises(FileNotFoundError, match="File not found"):
local_storage.get_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
@@ -141,7 +135,7 @@ class TestLocalStorageDeleteFile:
result = local_storage.delete_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is True
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -158,7 +152,7 @@ class TestLocalStorageDeleteFile:
result = local_storage.delete_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -175,7 +169,7 @@ class TestLocalStorageFileExists:
result = local_storage.file_exists(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is True
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -188,7 +182,7 @@ class TestLocalStorageFileExists:
result = local_storage.file_exists(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -205,7 +199,7 @@ class TestLocalStorageListFiles:
self, mock_exists, mock_walk, local_storage
):
directory = "documents"
base_dir = os.path.join("/tmp/test_storage", "documents")
base_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
mock_walk.return_value = [
(base_dir, ["subdir"], ["file1.txt", "file2.txt"]),
@@ -228,7 +222,7 @@ class TestLocalStorageListFiles:
result = local_storage.list_files(directory)
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
assert result == []
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -248,7 +242,7 @@ class TestLocalStorageProcessFile:
result = local_storage.process_file(path, processor_func, extra_arg="value")
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result == "processed"
assert processor_func.call_count == 1
call_kwargs = processor_func.call_args[1]
@@ -280,7 +274,7 @@ class TestLocalStorageIsDirectory:
result = local_storage.is_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is True
assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
@@ -295,7 +289,7 @@ class TestLocalStorageIsDirectory:
result = local_storage.is_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is False
assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
@@ -316,7 +310,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is True
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -339,7 +333,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -355,7 +349,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -376,7 +370,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is False
assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
@@ -393,7 +387,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is False
assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(

View File

@@ -2105,23 +2105,35 @@ class TestInternalRoutes:
app.register_blueprint(internal)
return app
_TEST_KEY = "test-key"
_AUTH_HEADERS = {"X-Internal-Key": "test-key"}
def test_upload_index_no_user(self, app):
with app.test_client() as client:
with patch(
"application.api.internal.routes.settings"
) as ms:
ms.INTERNAL_KEY = None
resp = client.post("/api/upload_index")
ms.INTERNAL_KEY = self._TEST_KEY
resp = client.post("/api/upload_index", headers=self._AUTH_HEADERS)
assert resp.get_json()["status"] == "no user"
def test_upload_index_no_name(self, app):
with app.test_client() as client:
with patch(
"application.api.internal.routes.settings"
) as ms:
ms.INTERNAL_KEY = self._TEST_KEY
resp = client.post("/api/upload_index", data={"user": "u1"}, headers=self._AUTH_HEADERS)
assert resp.get_json()["status"] == "no name"
def test_upload_index_rejected_without_internal_key(self, app):
with app.test_client() as client:
with patch(
"application.api.internal.routes.settings"
) as ms:
ms.INTERNAL_KEY = None
resp = client.post("/api/upload_index", data={"user": "u1"})
assert resp.get_json()["status"] == "no name"
assert resp.status_code == 401
# ---------------------------------------------------------------------------

View File

@@ -11,6 +11,7 @@ import pytest
from application.api.v1.translator import (
_get_client_tool_name,
convert_history,
extract_system_prompt,
extract_tool_results,
is_continuation,
translate_request,
@@ -148,6 +149,48 @@ class TestConvertHistory:
assert history == []
# ---------------------------------------------------------------------------
# extract_system_prompt
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestExtractSystemPrompt:
def test_extracts_first_system_message(self):
messages = [
{"role": "system", "content": "You are a pirate"},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == "You are a pirate"
def test_returns_none_when_no_system_message(self):
messages = [{"role": "user", "content": "Hello"}]
assert extract_system_prompt(messages) is None
def test_returns_first_of_multiple_system_messages(self):
messages = [
{"role": "system", "content": "First"},
{"role": "system", "content": "Second"},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == "First"
def test_empty_content_returns_empty_string(self):
messages = [
{"role": "system", "content": ""},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == ""
def test_missing_content_returns_empty_string(self):
messages = [
{"role": "system"},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == ""
# ---------------------------------------------------------------------------
# translate_request
# ---------------------------------------------------------------------------
@@ -167,11 +210,25 @@ class TestTranslateRequest:
result = translate_request(data, "test-key")
assert result["question"] == "What's 2+2?"
assert result["api_key"] == "test-key"
assert result["save_conversation"] is True
# Conversations are not persisted by default on the v1 endpoint.
assert result["save_conversation"] is False
history = json.loads(result["history"])
assert len(history) == 1
assert history[0]["prompt"] == "Hello"
def test_save_conversation_opt_in_via_docsgpt_extension(self):
data = {
"messages": [{"role": "user", "content": "Hi"}],
"docsgpt": {"save_conversation": True},
}
result = translate_request(data, "key")
assert result["save_conversation"] is True
def test_save_conversation_default_false(self):
data = {"messages": [{"role": "user", "content": "Hi"}]}
result = translate_request(data, "key")
assert result["save_conversation"] is False
def test_continuation_request(self):
data = {
"messages": [
@@ -237,6 +294,23 @@ class TestTranslateRequest:
result = translate_request(data, "key")
assert result["attachments"] == ["att1", "att2"]
def test_system_prompt_override_included_when_present(self):
data = {
"messages": [
{"role": "system", "content": "Custom prompt"},
{"role": "user", "content": "Hello"},
],
}
result = translate_request(data, "key")
assert result["system_prompt_override"] == "Custom prompt"
def test_system_prompt_override_absent_when_no_system_message(self):
data = {
"messages": [{"role": "user", "content": "Hello"}],
}
result = translate_request(data, "key")
assert "system_prompt_override" not in result
# ---------------------------------------------------------------------------
# translate_response

View File

@@ -363,6 +363,28 @@ class TestGetVectorstore:
assert get_vectorstore("user/source123") == "indexes/user/source123"
@pytest.mark.parametrize(
"malicious_path",
[
"../outside",
"../../etc/passwd",
"nested/../../../outside",
"/tmp/evil",
"..\\outside",
"valid/../../escape",
],
)
def test_rejects_path_traversal(self, malicious_path):
from application.vectorstore.faiss import get_vectorstore
with pytest.raises(ValueError, match="Invalid source_id path"):
get_vectorstore(malicious_path)
def test_allows_mongodb_style_ids(self):
from application.vectorstore.faiss import get_vectorstore
assert get_vectorstore("65e8f6a8a7a96b1bdad4154f") == "indexes/65e8f6a8a7a96b1bdad4154f"
@pytest.mark.unit
class TestFaissStoreAddChunk:

View File

@@ -16,10 +16,9 @@ def _make_store(
) as mock_settings, patch.dict(
"sys.modules",
{
"psycopg2": MagicMock(),
"psycopg2.extras": MagicMock(),
"psycopg": MagicMock(),
"pgvector": MagicMock(),
"pgvector.psycopg2": MagicMock(),
"pgvector.psycopg": MagicMock(),
},
):
mock_emb = Mock()
@@ -63,10 +62,9 @@ class TestPGVectorStoreInit:
) as mock_settings, patch.dict(
"sys.modules",
{
"psycopg2": MagicMock(),
"psycopg2.extras": MagicMock(),
"psycopg": MagicMock(),
"pgvector": MagicMock(),
"pgvector.psycopg2": MagicMock(),
"pgvector.psycopg": MagicMock(),
},
):
mock_get_emb.return_value = Mock(dimension=768)
@@ -264,13 +262,13 @@ class TestPGVectorStoreConnection:
store, mock_conn, _, _ = _make_store()
mock_conn.closed = True
mock_psycopg2 = MagicMock()
mock_psycopg = MagicMock()
new_conn = MagicMock()
mock_psycopg2.connect.return_value = new_conn
store._psycopg2 = mock_psycopg2
mock_psycopg.connect.return_value = new_conn
store._psycopg = mock_psycopg
conn = store._get_connection()
mock_psycopg2.connect.assert_called_once()
mock_psycopg.connect.assert_called_once()
assert conn is new_conn
def test_get_connection_reuses_open(self):