Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-07 14:34:32 +00:00)

Compare commits: messages-f...pg (35 commits)
Commits:
59d9d4ac50
55717043f6
ececcb8b17
bd03a513e3
fcdb4fb5e8
e787c896eb
23aeaff5db
689dd79597
0c15af90b1
cdd6ff6557
72b3d94453
7e88d09e5d
74a4a237dc
c3f01c6619
6b408823d4
3fc81ac5d8
2652f8a5b0
d711eefe96
79206f3919
de971d9452
1b4d5ca0dd
81989e8258
dc262d1698
69f9c93869
74bf80b25c
d9a92a7208
02e93d993d
6b6495f48c
249dd9ce37
9134ab0478
10ef68c9d0
7d65cf1c2b
13c6cc59c1
648b3f1d20
a75a9e23f9
@@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}

# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
@@ -44,3 +44,8 @@ boolean
bool
hardcode
EOL
Postgres
Supabase
config
backfill
backfills
.gitignore (vendored, 2 changes)
@@ -108,6 +108,8 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/
@@ -1,5 +1,7 @@
MinAlertLevel = warning
StylesPath = .github/styles
Vocab = DocsGPT

[*.{md,mdx}]
BasedOnStyles = DocsGPT
SECURITY.md (18 changes)
@@ -2,13 +2,21 @@

## Supported Versions

Supported Versions:

Currently, we support security patches by committing changes and bumping the version published on Github.
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.

## Reporting a Vulnerability

Found a vulnerability? Please email us:
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security

security@arc53.com
Then click **Report a vulnerability**.

Alternatively, email us at: security@arc53.com

We aim to acknowledge reports within 48 hours.

## Incident Handling

Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.
@@ -24,6 +24,7 @@ from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials

logger = logging.getLogger(__name__)

@@ -61,7 +62,8 @@ class MCPTool(Tool):
        """
        self.config = config
        self.user_id = user_id
        self.server_url = config.get("server_url", "")
        raw_url = config.get("server_url", "")
        self.server_url = self._validate_server_url(raw_url) if raw_url else ""
        self.transport_type = config.get("transport_type", "auto")
        self.auth_type = config.get("auth_type", "none")
        self.timeout = config.get("timeout", 30)

@@ -87,6 +89,18 @@ class MCPTool(Tool):
        if self.server_url and self.auth_type != "oauth":
            self._setup_client()

    @staticmethod
    def _validate_server_url(server_url: str) -> str:
        """Validate server_url to prevent SSRF to internal networks.

        Raises:
            ValueError: If the URL points to a private/internal address.
        """
        try:
            return validate_url(server_url)
        except SSRFError as exc:
            raise ValueError(f"Invalid MCP server URL: {exc}") from exc

    def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
        if configured_redirect_uri:
            return configured_redirect_uri.rstrip("/")

@@ -108,8 +122,9 @@ class MCPTool(Tool):
        auth_key = ""
        if self.auth_type == "oauth":
            scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
            oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
            auth_key = (
                f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
                f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
            )
        elif self.auth_type in ["bearer"]:
            token = self.auth_credentials.get(
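The new `_validate_server_url` helper simply wraps `validate_url` and converts `SSRFError` into a `ValueError`. A minimal standalone sketch of that wrapped behaviour, assuming `validate_url` rejects private or loopback targets (the exact blocking rules live in `application.core.url_validation`, not here):

```python
from application.core.url_validation import SSRFError, validate_url


def check_mcp_url(raw_url: str) -> str:
    """Same shape as MCPTool._validate_server_url: return the validated URL or raise ValueError."""
    try:
        return validate_url(raw_url)
    except SSRFError as exc:
        raise ValueError(f"Invalid MCP server URL: {exc}") from exc


# Illustrative URLs only; which addresses are blocked depends on validate_url's rules.
check_mcp_url("https://mcp.example.com/sse")        # expected to pass
# check_mcp_url("http://169.254.169.254/latest")    # expected to raise ValueError (metadata endpoint)
```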
@@ -1,6 +1,6 @@
import logging

import psycopg2
import psycopg

from application.agents.tools.base import Tool

@@ -33,7 +33,7 @@ class PostgresTool(Tool):
        """
        conn = None
        try:
            conn = psycopg2.connect(self.connection_string)
            conn = psycopg.connect(self.connection_string)
            cur = conn.cursor()
            cur.execute(sql_query)
            conn.commit()

@@ -60,7 +60,7 @@ class PostgresTool(Tool):
                "response_data": response_data,
            }

        except psycopg2.Error as e:
        except psycopg.Error as e:
            error_message = f"Database error: {e}"
            logger.error("PostgreSQL execute_sql error: %s", e)
            return {

@@ -78,7 +78,7 @@ class PostgresTool(Tool):
        """
        conn = None
        try:
            conn = psycopg2.connect(self.connection_string)
            conn = psycopg.connect(self.connection_string)
            cur = conn.cursor()

            cur.execute(

@@ -120,7 +120,7 @@ class PostgresTool(Tool):
                "schema": schema_data,
            }

        except psycopg2.Error as e:
        except psycopg.Error as e:
            error_message = f"Database error: {e}"
            logger.error("PostgreSQL get_schema error: %s", e)
            return {
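This hunk swaps psycopg2 for psycopg (psycopg 3); the `connect`, `cursor`, and `Error` names used above carry over unchanged. A small standalone sketch of the psycopg 3 idiom with context managers, using a placeholder DSN rather than the tool's real `connection_string`:

```python
import psycopg  # psycopg 3

# Placeholder DSN; PostgresTool takes its connection string from the tool config.
dsn = "postgresql://docsgpt:docsgpt@localhost:5432/docsgpt"

# The connection context manager commits on success (rolls back on error) and closes the
# connection; the cursor context manager closes the cursor.
with psycopg.connect(dsn) as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT version();")
        print(cur.fetchone()[0])
```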
application/alembic.ini (new file, 52 lines)
@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
#     alembic -c application/alembic.ini upgrade head

[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os

# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =

[post_write_hooks]

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
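Besides the CLI invocation shown in the comment above, the same ini file can be driven programmatically, which can be convenient in a container entrypoint or a setup script. A sketch using Alembic's public command API (only the ini path comes from the file above; the rest is illustrative and assumes POSTGRES_URI is set so env.py can supply the URL):

```python
from alembic import command
from alembic.config import Config

cfg = Config("application/alembic.ini")
command.upgrade(cfg, "head")  # apply all pending migrations

# Later schema changes could be drafted with autogenerate, e.g.:
# command.revision(cfg, message="describe the change", autogenerate=True)
```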
application/alembic/env.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Alembic environment for the DocsGPT user-data Postgres database.

The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""

import sys
from logging.config import fileConfig
from pathlib import Path

# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from alembic import context  # noqa: E402
from sqlalchemy import engine_from_config, pool  # noqa: E402

from application.core.settings import settings  # noqa: E402
from application.storage.db.models import metadata as target_metadata  # noqa: E402

config = context.config

# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
    config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)

if config.config_file_name is not None:
    fileConfig(config.config_file_name)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode (emits SQL without a live DB)."""
    url = config.get_main_option("sqlalchemy.url")
    if not url:
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True,
    )
    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode against a live connection."""
    if not config.get_main_option("sqlalchemy.url"):
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
        future=True,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,
        )
        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
application/alembic/script.py.mako (new file, 26 lines)
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
application/alembic/versions/0001_initial.py (new file, 462 lines)
@@ -0,0 +1,462 @@
"""0001 initial schema — user-level tables migrated from MongoDB.

Creates every table described in §2.2 of ``migration-postgres.md``: tiers 1,
2, and 3 in one shot. The schema is small enough that splitting the baseline
across multiple revisions would only cost clarity.

Subsequent migrations will add columns / tables incrementally. This file is
hand-written raw DDL rather than Core ``op.create_table`` calls because the
DDL uses several Postgres-specific features (``CITEXT``, partial indexes,
``text_pattern_ops``, JSONB defaults) that are clearer in SQL than in
Alembic's Python API.

Revision ID: 0001_initial
Revises:
Create Date: 2026-04-10
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0001_initial"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ------------------------------------------------------------------
    # Extensions
    # ------------------------------------------------------------------
    op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
    op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')

    # ------------------------------------------------------------------
    # Tier 1: leaf tables, no FKs into other migrated tables
    # ------------------------------------------------------------------
    op.execute("""
        CREATE TABLE users (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL UNIQUE,
            agent_preferences JSONB NOT NULL
                DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX users_user_id_idx ON users (user_id);")

    op.execute("""
        CREATE TABLE prompts (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")

    op.execute("""
        CREATE TABLE user_tools (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            custom_name TEXT,
            display_name TEXT,
            config JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")

    op.execute("""
        CREATE TABLE token_usage (
            id BIGSERIAL PRIMARY KEY,
            user_id TEXT,
            api_key TEXT,
            agent_id UUID, -- FK added later in this migration
            prompt_tokens INTEGER NOT NULL DEFAULT 0,
            generated_tokens INTEGER NOT NULL DEFAULT 0,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, timestamp DESC);")
    op.execute("CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, timestamp DESC);")
    op.execute("CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, timestamp DESC);")

    op.execute("""
        CREATE TABLE user_logs (
            id BIGSERIAL PRIMARY KEY,
            user_id TEXT,
            endpoint TEXT,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
            data JSONB
        );
    """)
    op.execute("CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, timestamp DESC);")

    op.execute("""
        CREATE TABLE feedback (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL, -- FK added later in this migration
            user_id TEXT NOT NULL,
            question_index INTEGER NOT NULL,
            feedback_text TEXT,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX feedback_conv_idx ON feedback (conversation_id);")

    # Append-only debug/error log. The Mongo doc has both `_id` (auto) and an
    # `id` field (the activity id). Here the serial PK owns `id`; the
    # application-level identifier is renamed to `activity_id`.
    op.execute("""
        CREATE TABLE stack_logs (
            id BIGSERIAL PRIMARY KEY,
            activity_id TEXT NOT NULL,
            endpoint TEXT,
            level TEXT,
            user_id TEXT,
            api_key TEXT,
            query TEXT,
            stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX stack_logs_timestamp_idx ON stack_logs (timestamp DESC);")
    op.execute("CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, timestamp DESC);")
    op.execute("CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, timestamp DESC);")
    op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")

    # ------------------------------------------------------------------
    # Tier 2: FK-bearing tables
    # ------------------------------------------------------------------
    op.execute("""
        CREATE TABLE agent_folders (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")

    op.execute("""
        CREATE TABLE sources (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT, -- NULL for system/template sources
            name TEXT NOT NULL,
            type TEXT,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")

    op.execute("""
        CREATE TABLE agents (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            agent_type TEXT,
            status TEXT NOT NULL,
            key CITEXT UNIQUE,
            source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
            extra_source_ids UUID[] NOT NULL DEFAULT '{}',
            chunks INTEGER,
            retriever TEXT,
            prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
            tools JSONB NOT NULL DEFAULT '[]'::jsonb,
            json_schema JSONB,
            models JSONB,
            default_model_id TEXT,
            folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
            limited_token_mode BOOLEAN NOT NULL DEFAULT false,
            token_limit INTEGER,
            limited_request_mode BOOLEAN NOT NULL DEFAULT false,
            request_limit INTEGER,
            shared BOOLEAN NOT NULL DEFAULT false,
            incoming_webhook_token CITEXT UNIQUE,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            last_used_at TIMESTAMPTZ
        );
    """)
    op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
    op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
    op.execute("CREATE INDEX agents_status_idx ON agents (status);")

    # Backfill the token_usage.agent_id FK now that agents exists.
    op.execute("""
        ALTER TABLE token_usage
            ADD CONSTRAINT token_usage_agent_fk
            FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;
    """)

    op.execute("""
        CREATE TABLE attachments (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            filename TEXT NOT NULL,
            upload_path TEXT NOT NULL,
            mime_type TEXT,
            size BIGINT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")

    op.execute("""
        CREATE TABLE memories (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            path TEXT NOT NULL,
            content TEXT NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("""
        CREATE UNIQUE INDEX memories_user_tool_path_uidx
            ON memories (user_id, tool_id, path);
    """)
    op.execute("""
        CREATE INDEX memories_path_prefix_idx
            ON memories (user_id, tool_id, path text_pattern_ops);
    """)

    op.execute("""
        CREATE TABLE todos (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            title TEXT NOT NULL,
            completed BOOLEAN NOT NULL DEFAULT false,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")

    op.execute("""
        CREATE TABLE notes (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            title TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX notes_user_tool_idx ON notes (user_id, tool_id);")

    op.execute("""
        CREATE TABLE connector_sessions (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            provider TEXT NOT NULL,
            session_data JSONB NOT NULL,
            expires_at TIMESTAMPTZ,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("""
        CREATE INDEX connector_sessions_user_provider_idx
            ON connector_sessions (user_id, provider);
    """)
    op.execute("""
        CREATE INDEX connector_sessions_expiry_idx
            ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;
    """)

    # ------------------------------------------------------------------
    # Tier 3: conversations, pending_tool_state, workflows
    # ------------------------------------------------------------------
    op.execute("""
        CREATE TABLE conversations (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
            name TEXT,
            api_key TEXT,
            is_shared_usage BOOLEAN NOT NULL DEFAULT false,
            shared_token TEXT,
            date TIMESTAMPTZ NOT NULL DEFAULT now(),
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
    op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
    op.execute("""
        CREATE INDEX conversations_shared_token_idx
            ON conversations (shared_token) WHERE shared_token IS NOT NULL;
    """)

    op.execute("""
        CREATE TABLE conversation_messages (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            position INTEGER NOT NULL,
            prompt TEXT,
            response TEXT,
            thought TEXT,
            sources JSONB NOT NULL DEFAULT '[]'::jsonb,
            tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
            attachments UUID[] NOT NULL DEFAULT '{}',
            model_id TEXT,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            feedback JSONB,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("""
        CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx
            ON conversation_messages (conversation_id, position);
    """)

    # Backfill the feedback.conversation_id FK now that conversations exists.
    op.execute("""
        ALTER TABLE feedback
            ADD CONSTRAINT feedback_conv_fk
            FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE;
    """)

    op.execute("""
        CREATE TABLE shared_conversations (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
            chunks INTEGER,
            is_promptable BOOLEAN NOT NULL DEFAULT false,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
    op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")

    # Paused-tool continuation state. The Mongo version relies on a TTL index;
    # Postgres has no native TTL, so a Celery beat task (added in Phase 3)
    # deletes rows where expires_at < now() once a minute. The unique
    # constraint on (conversation_id, user_id) matches the existing upsert
    # semantics.
    op.execute("""
        CREATE TABLE pending_tool_state (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            messages JSONB NOT NULL,
            pending_tool_calls JSONB NOT NULL,
            tools_dict JSONB NOT NULL,
            tool_schemas JSONB NOT NULL,
            agent_config JSONB NOT NULL,
            client_tools JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            expires_at TIMESTAMPTZ NOT NULL
        );
    """)
    op.execute("""
        CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx
            ON pending_tool_state (conversation_id, user_id);
    """)
    op.execute("""
        CREATE INDEX pending_tool_state_expires_idx
            ON pending_tool_state (expires_at);
    """)

    # Workflows
    op.execute("""
        CREATE TABLE workflows (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
    """)
    op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")

    op.execute("""
        CREATE TABLE workflow_nodes (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            graph_version INTEGER NOT NULL,
            node_type TEXT NOT NULL,
            config JSONB NOT NULL DEFAULT '{}'::jsonb
        );
    """)
    op.execute("""
        CREATE INDEX workflow_nodes_workflow_version_idx
            ON workflow_nodes (workflow_id, graph_version);
    """)

    op.execute("""
        CREATE TABLE workflow_edges (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            graph_version INTEGER NOT NULL,
            from_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
            to_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
            config JSONB NOT NULL DEFAULT '{}'::jsonb
        );
    """)
    op.execute("""
        CREATE INDEX workflow_edges_workflow_version_idx
            ON workflow_edges (workflow_id, graph_version);
    """)

    op.execute("""
        CREATE TABLE workflow_runs (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            status TEXT NOT NULL,
            started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            ended_at TIMESTAMPTZ,
            result JSONB
        );
    """)
    op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
    op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")


def downgrade() -> None:
    # Reverse dependency order. CASCADE would handle FKs anyway, but explicit
    # is clearer for anyone reading the migration.
    op.execute("DROP TABLE IF EXISTS workflow_runs CASCADE;")
    op.execute("DROP TABLE IF EXISTS workflow_edges CASCADE;")
    op.execute("DROP TABLE IF EXISTS workflow_nodes CASCADE;")
    op.execute("DROP TABLE IF EXISTS workflows CASCADE;")
    op.execute("DROP TABLE IF EXISTS pending_tool_state CASCADE;")
    op.execute("DROP TABLE IF EXISTS shared_conversations CASCADE;")
    op.execute("DROP TABLE IF EXISTS conversation_messages CASCADE;")
    op.execute("DROP TABLE IF EXISTS conversations CASCADE;")
    op.execute("DROP TABLE IF EXISTS connector_sessions CASCADE;")
    op.execute("DROP TABLE IF EXISTS notes CASCADE;")
    op.execute("DROP TABLE IF EXISTS todos CASCADE;")
    op.execute("DROP TABLE IF EXISTS memories CASCADE;")
    op.execute("DROP TABLE IF EXISTS attachments CASCADE;")
    op.execute("DROP TABLE IF EXISTS agents CASCADE;")
    op.execute("DROP TABLE IF EXISTS sources CASCADE;")
    op.execute("DROP TABLE IF EXISTS agent_folders CASCADE;")
    op.execute("DROP TABLE IF EXISTS stack_logs CASCADE;")
    op.execute("DROP TABLE IF EXISTS feedback CASCADE;")
    op.execute("DROP TABLE IF EXISTS user_logs CASCADE;")
    op.execute("DROP TABLE IF EXISTS token_usage CASCADE;")
    op.execute("DROP TABLE IF EXISTS user_tools CASCADE;")
    op.execute("DROP TABLE IF EXISTS prompts CASCADE;")
    op.execute("DROP TABLE IF EXISTS users CASCADE;")
    # Extensions are intentionally left in place — they may be shared with
    # pgvector or other extensions already enabled on the cluster.
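The comment on pending_tool_state defers row expiry to a Celery beat task that is not part of this diff (Phase 3). A hypothetical sketch of what such a task could look like; the task name, schedule, Celery app import path, and connection handling are all assumptions, and only the "delete where expires_at < now()" step mirrors the migration comment. It also assumes POSTGRES_URI is a plain libpq-style URI as in the .env example:

```python
import psycopg

from application.celery_init import celery  # assumed location of the Celery app
from application.core.settings import settings


@celery.task(name="cleanup_pending_tool_state")
def cleanup_pending_tool_state() -> int:
    """Delete expired paused-tool rows; intended to run roughly once a minute via beat."""
    if not settings.POSTGRES_URI:
        return 0
    with psycopg.connect(settings.POSTGRES_URI) as conn:
        with conn.cursor() as cur:
            cur.execute("DELETE FROM pending_tool_state WHERE expires_at < now();")
            return cur.rowcount
```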
@@ -85,6 +85,10 @@ class AnswerResource(Resource, BaseAnswerResource):
            ) = processor.resume_from_tool_actions(
                data["tool_actions"], data["conversation_id"]
            )
            if not processor.decoded_token:
                return make_response({"error": "Unauthorized"}, 401)
            if error := self.check_usage(processor.agent_config):
                return error
            stream = self.complete_stream(
                question="",
                agent=agent,
@@ -92,6 +92,14 @@ class StreamResource(Resource, BaseAnswerResource):
            ) = processor.resume_from_tool_actions(
                data["tool_actions"], data["conversation_id"]
            )
            if not processor.decoded_token:
                return Response(
                    self.error_stream_generate("Unauthorized"),
                    status=401,
                    mimetype="text/event-stream",
                )
            if error := self.check_usage(processor.agent_config):
                return error
            return Response(
                self.complete_stream(
                    question="",
@@ -112,6 +112,7 @@ class StreamProcessor:
        self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
        self.compressed_summary: Optional[str] = None
        self.compressed_summary_tokens: int = 0
        self._agent_data: Optional[Dict[str, Any]] = None

    def initialize(self):
        """Initialize all required components for processing"""

@@ -359,22 +360,29 @@ class StreamProcessor:
        return data

    def _configure_source(self):
        """Configure the source based on agent data"""
        api_key = self.data.get("api_key") or self.agent_key
        """Configure the source based on agent data.

        if api_key:
            agent_data = self._get_data_from_api_key(api_key)
        The literal string ``"default"`` is a placeholder meaning "no
        ingested source" and is normalized to an empty source so that no
        retrieval is attempted.
        """
        if self._agent_data:
            agent_data = self._agent_data

            if agent_data.get("sources") and len(agent_data["sources"]) > 0:
                source_ids = [
                    source["id"] for source in agent_data["sources"] if source.get("id")
                    source["id"]
                    for source in agent_data["sources"]
                    if source.get("id") and source["id"] != "default"
                ]
                if source_ids:
                    self.source = {"active_docs": source_ids}
                else:
                    self.source = {}
                self.all_sources = agent_data["sources"]
            elif agent_data.get("source"):
                self.all_sources = [
                    s for s in agent_data["sources"] if s.get("id") != "default"
                ]
            elif agent_data.get("source") and agent_data["source"] != "default":
                self.source = {"active_docs": agent_data["source"]}
                self.all_sources = [
                    {

@@ -387,11 +395,24 @@ class StreamProcessor:
            self.all_sources = []
            return
        if "active_docs" in self.data:
            self.source = {"active_docs": self.data["active_docs"]}
            active_docs = self.data["active_docs"]
            if active_docs and active_docs != "default":
                self.source = {"active_docs": active_docs}
            else:
                self.source = {}
            return
        self.source = {}
        self.all_sources = []

    def _has_active_docs(self) -> bool:
        """Return True if a real document source is configured for retrieval."""
        active_docs = self.source.get("active_docs") if self.source else None
        if not active_docs:
            return False
        if active_docs == "default":
            return False
        return True

    def _resolve_agent_id(self) -> Optional[str]:
        """Resolve agent_id from request, then fall back to conversation context."""
        request_agent_id = self.data.get("agent_id")

@@ -433,48 +454,39 @@ class StreamProcessor:
        effective_key = self.data.get("api_key") or self.agent_key

        if effective_key:
            data_key = self._get_data_from_api_key(effective_key)
            if data_key.get("_id"):
                self.agent_id = str(data_key.get("_id"))
            self._agent_data = self._get_data_from_api_key(effective_key)
            if self._agent_data.get("_id"):
                self.agent_id = str(self._agent_data.get("_id"))

            self.agent_config.update(
                {
                    "prompt_id": data_key.get("prompt_id", "default"),
                    "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
                    "prompt_id": self._agent_data.get("prompt_id", "default"),
                    "agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
                    "user_api_key": effective_key,
                    "json_schema": data_key.get("json_schema"),
                    "default_model_id": data_key.get("default_model_id", ""),
                    "models": data_key.get("models", []),
                    "json_schema": self._agent_data.get("json_schema"),
                    "default_model_id": self._agent_data.get("default_model_id", ""),
                    "models": self._agent_data.get("models", []),
                    "allow_system_prompt_override": self._agent_data.get(
                        "allow_system_prompt_override", False
                    ),
                }
            )

            # Set identity context
            if self.data.get("api_key"):
                # External API key: use the key owner's identity
                self.initial_user_id = data_key.get("user")
                self.decoded_token = {"sub": data_key.get("user")}
                self.initial_user_id = self._agent_data.get("user")
                self.decoded_token = {"sub": self._agent_data.get("user")}
            elif self.is_shared_usage:
                # Shared agent: keep the caller's identity
                pass
            else:
                # Owner using their own agent
                self.decoded_token = {"sub": data_key.get("user")}
                self.decoded_token = {"sub": self._agent_data.get("user")}

            if data_key.get("source"):
                self.source = {"active_docs": data_key["source"]}
            if data_key.get("workflow"):
                self.agent_config["workflow"] = data_key["workflow"]
                self.agent_config["workflow_owner"] = data_key.get("user")
            if data_key.get("retriever"):
                self.retriever_config["retriever_name"] = data_key["retriever"]
            if data_key.get("chunks") is not None:
                try:
                    self.retriever_config["chunks"] = int(data_key["chunks"])
                except (ValueError, TypeError):
                    logger.warning(
                        f"Invalid chunks value: {data_key['chunks']}, using default value 2"
                    )
                    self.retriever_config["chunks"] = 2
            if self._agent_data.get("workflow"):
                self.agent_config["workflow"] = self._agent_data["workflow"]
                self.agent_config["workflow_owner"] = self._agent_data.get("user")
        else:
            # No API key — default/workflow configuration
            agent_type = settings.AGENT_NAME

@@ -497,14 +509,45 @@ class StreamProcessor:
        )

    def _configure_retriever(self):
        """Assemble retriever config with precedence: request > agent > default."""
        doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)

        # Start with defaults
        retriever_name = "classic"
        chunks = 2

        # Layer agent-level config (if present)
        if self._agent_data:
            if self._agent_data.get("retriever"):
                retriever_name = self._agent_data["retriever"]
            if self._agent_data.get("chunks") is not None:
                try:
                    chunks = int(self._agent_data["chunks"])
                except (ValueError, TypeError):
                    logger.warning(
                        f"Invalid agent chunks value: {self._agent_data['chunks']}, "
                        "using default value 2"
                    )

        # Explicit request values win over agent config
        if "retriever" in self.data:
            retriever_name = self.data["retriever"]
        if "chunks" in self.data:
            try:
                chunks = int(self.data["chunks"])
            except (ValueError, TypeError):
                logger.warning(
                    f"Invalid request chunks value: {self.data['chunks']}, "
                    "using default value 2"
                )

        self.retriever_config = {
            "retriever_name": self.data.get("retriever", "classic"),
            "chunks": int(self.data.get("chunks", 2)),
            "retriever_name": retriever_name,
            "chunks": chunks,
            "doc_token_limit": doc_token_limit,
        }

        # isNoneDoc without an API key forces no retrieval
        api_key = self.data.get("api_key") or self.agent_key
        if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
            self.retriever_config["chunks"] = 0

@@ -528,6 +571,9 @@ class StreamProcessor:
        if self.data.get("isNoneDoc", False) and not self.agent_id:
            logger.info("Pre-fetch skipped: isNoneDoc=True")
            return None, None
        if not self._has_active_docs():
            logger.info("Pre-fetch skipped: no active docs configured")
            return None, None
        try:
            retriever = self.create_retriever()
            logger.info(

@@ -910,15 +956,23 @@ class StreamProcessor:
        raw_prompt = get_prompt(prompt_id, self.prompts_collection)
        self._prompt_content = raw_prompt

        rendered_prompt = self.prompt_renderer.render_prompt(
            prompt_content=raw_prompt,
            user_id=self.initial_user_id,
            request_id=self.data.get("request_id"),
            passthrough_data=self.data.get("passthrough"),
            docs=docs,
            docs_together=docs_together,
            tools_data=tools_data,
        )
        # Allow API callers to override the system prompt when the agent
        # has opted in via allow_system_prompt_override.
        if (
            self.agent_config.get("allow_system_prompt_override", False)
            and self.data.get("system_prompt_override")
        ):
            rendered_prompt = self.data["system_prompt_override"]
        else:
            rendered_prompt = self.prompt_renderer.render_prompt(
                prompt_content=raw_prompt,
                user_id=self.initial_user_id,
                request_id=self.data.get("request_id"),
                passthrough_data=self.data.get("passthrough"),
                docs=docs,
                docs_together=docs_together,
                tools_data=tools_data,
            )

        provider = (
            get_provider_from_model_id(self.model_id)
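With `allow_system_prompt_override` enabled on an agent, an API caller can replace the rendered system prompt by sending `system_prompt_override` in the request body. A hypothetical client-side sketch; the host, port, endpoint path, and the other fields are illustrative (they are not shown in this diff), only `system_prompt_override`, `api_key`, and the opt-in flag come from the code above:

```python
import requests

payload = {
    "question": "Summarise the last release notes",
    "api_key": "<key-of-an-agent-with-allow_system_prompt_override>",
    "system_prompt_override": "You are a terse release-notes summariser.",
}
# Illustrative endpoint; the override only takes effect when the agent opted in.
resp = requests.post("http://localhost:7091/api/answer", json=payload, timeout=60)
print(resp.status_code)
```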
@@ -26,12 +26,20 @@ internal = Blueprint("internal", __name__)

@internal.before_request
def verify_internal_key():
    """Verify INTERNAL_KEY for all internal endpoint requests."""
    if settings.INTERNAL_KEY:
        internal_key = request.headers.get("X-Internal-Key")
        if not internal_key or internal_key != settings.INTERNAL_KEY:
            logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
            return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
    """Verify INTERNAL_KEY for all internal endpoint requests.

    Deny by default: if INTERNAL_KEY is not configured, reject all requests.
    """
    if not settings.INTERNAL_KEY:
        logger.warning(
            f"Internal API request rejected from {request.remote_addr}: "
            "INTERNAL_KEY is not configured"
        )
        return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
    internal_key = request.headers.get("X-Internal-Key")
    if not internal_key or internal_key != settings.INTERNAL_KEY:
        logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
        return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401


@internal.route("/api/download", methods=["get"])
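With this deny-by-default change, every call to the internal blueprint must carry the X-Internal-Key header and the server must have INTERNAL_KEY configured; otherwise a 401 is returned. A small client-side sketch; the header name and the /api/download route come from the diff above, while the host, port, and query parameter are illustrative:

```python
import os

import requests

INTERNAL_KEY = os.environ["INTERNAL_KEY"]  # must match settings.INTERNAL_KEY on the server

resp = requests.get(
    "http://localhost:7091/api/download",      # internal blueprint route shown above
    params={"id": "<source-id>"},              # illustrative parameter
    headers={"X-Internal-Key": INTERNAL_KEY},
    timeout=30,
)
if resp.status_code == 401:
    print("Rejected: missing/invalid key, or INTERNAL_KEY not configured server-side")
```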
@@ -23,6 +23,8 @@ from application.api.user.base import (
    workflow_nodes_collection,
    workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,

@@ -73,6 +75,7 @@ AGENT_TYPE_SCHEMAS = {
        "token_limit",
        "limited_request_mode",
        "request_limit",
        "allow_system_prompt_override",
        "createdAt",
        "updatedAt",
        "lastUsedAt",

@@ -96,6 +99,7 @@ AGENT_TYPE_SCHEMAS = {
        "token_limit",
        "limited_request_mode",
        "request_limit",
        "allow_system_prompt_override",
        "createdAt",
        "updatedAt",
        "lastUsedAt",

@@ -220,6 +224,12 @@ def build_agent_document(
        base_doc["request_limit"] = int(
            data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
        )
    if "allow_system_prompt_override" in allowed_fields:
        base_doc["allow_system_prompt_override"] = (
            data.get("allow_system_prompt_override") == "True"
            if isinstance(data.get("allow_system_prompt_override"), str)
            else bool(data.get("allow_system_prompt_override", False))
        )
    return {k: v for k, v in base_doc.items() if k in allowed_fields}


@@ -292,6 +302,9 @@ class GetAgent(Resource):
                "default_model_id": agent.get("default_model_id", ""),
                "folder_id": agent.get("folder_id"),
                "workflow": agent.get("workflow"),
                "allow_system_prompt_override": agent.get(
                    "allow_system_prompt_override", False
                ),
            }
            return make_response(jsonify(data), 200)
        except Exception as e:

@@ -373,6 +386,9 @@ class GetAgents(Resource):
                    "default_model_id": agent.get("default_model_id", ""),
                    "folder_id": agent.get("folder_id"),
                    "workflow": agent.get("workflow"),
                    "allow_system_prompt_override": agent.get(
                        "allow_system_prompt_override", False
                    ),
                }
                for agent in agents
                if "source" in agent

@@ -450,6 +466,10 @@ class CreateAgent(Resource):
            "folder_id": fields.String(
                required=False, description="Folder ID to organize the agent"
            ),
            "allow_system_prompt_override": fields.Boolean(
                required=False,
                description="Allow API callers to override the system prompt via the v1 endpoint",
            ),
        },
    )

@@ -491,9 +511,9 @@ class CreateAgent(Resource):
                data["json_schema"] = normalize_json_schema_payload(
                    data.get("json_schema")
                )
            except JsonSchemaValidationError as exc:
            except JsonSchemaValidationError:
                return make_response(
                    jsonify({"success": False, "message": f"JSON schema {exc}"}),
                    jsonify({"success": False, "message": "Invalid JSON schema"}),
                    400,
                )
        if data.get("status") not in ["draft", "published"]:

@@ -674,6 +694,10 @@ class UpdateAgent(Resource):
            "folder_id": fields.String(
                required=False, description="Folder ID to organize the agent"
            ),
            "allow_system_prompt_override": fields.Boolean(
                required=False,
                description="Allow API callers to override the system prompt via the v1 endpoint",
            ),
        },
    )

@@ -765,6 +789,7 @@ class UpdateAgent(Resource):
            "default_model_id",
            "folder_id",
            "workflow",
            "allow_system_prompt_override",
        ]

        for field in allowed_fields:

@@ -872,9 +897,9 @@ class UpdateAgent(Resource):
                        update_fields[field] = normalize_json_schema_payload(
                            json_schema
                        )
                    except JsonSchemaValidationError as exc:
                    except JsonSchemaValidationError:
                        return make_response(
                            jsonify({"success": False, "message": f"JSON schema {exc}"}),
                            jsonify({"success": False, "message": "Invalid JSON schema"}),
                            400,
                        )
                else:

@@ -983,6 +1008,13 @@ class UpdateAgent(Resource):
                    if workflow_error:
                        return workflow_error
                    update_fields[field] = workflow_id
                elif field == "allow_system_prompt_override":
                    raw_value = data.get("allow_system_prompt_override", False)
                    update_fields[field] = (
                        raw_value == "True"
                        if isinstance(raw_value, str)
                        else bool(raw_value)
                    )
                else:
                    value = data[field]
                    if field in ["name", "description", "prompt_id", "agent_type"]:

@@ -1220,6 +1252,9 @@ class PinnedAgents(Resource):
                    {"user_id": user_id},
                    {"$pullAll": {"agent_preferences.pinned": stale_ids}},
                )
                dual_write(UsersRepository,
                    lambda repo, uid=user_id, ids=stale_ids: repo.remove_pinned_bulk(uid, ids)
                )
            list_pinned_agents = [
                {
                    "id": str(agent["_id"]),

@@ -1351,12 +1386,18 @@ class PinAgent(Resource):
                    {"user_id": user_id},
                    {"$pull": {"agent_preferences.pinned": agent_id}},
                )
                dual_write(UsersRepository,
                    lambda repo, uid=user_id, aid=agent_id: repo.remove_pinned(uid, aid)
                )
                action = "unpinned"
            else:
                users_collection.update_one(
                    {"user_id": user_id},
                    {"$addToSet": {"agent_preferences.pinned": agent_id}},
                )
                dual_write(UsersRepository,
                    lambda repo, uid=user_id, aid=agent_id: repo.add_pinned(uid, aid)
                )
                action = "pinned"
        except Exception as err:
            current_app.logger.error(f"Error pinning/unpinning agent: {err}")

@@ -1402,6 +1443,9 @@ class RemoveSharedAgent(Resource):
                    }
                },
            )
            dual_write(UsersRepository,
                lambda repo, uid=user_id, aid=agent_id: repo.remove_agent_from_all(uid, aid)
            )

            return make_response(jsonify({"success": True, "action": "removed"}), 200)
        except Exception as err:
@@ -18,6 +18,8 @@ from application.api.user.base import (
    user_tools_collection,
    users_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.utils import generate_image_url

agents_sharing_ns = Namespace(

@@ -105,6 +107,9 @@ class SharedAgent(Resource):
                    {"user_id": user_id},
                    {"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
                )
                dual_write(UsersRepository,
                    lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
                )
            return make_response(jsonify(data), 200)
        except Exception as err:
            current_app.logger.error(f"Error retrieving shared agent: {err}")

@@ -139,6 +144,9 @@ class SharedAgents(Resource):
                    {"user_id": user_id},
                    {"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
                )
                dual_write(UsersRepository,
                    lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
                )
            pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))

            list_shared_agents = [
@@ -612,6 +612,10 @@ class LiveSpeechToTextFinish(Resource):
class ServeImage(Resource):
    @api.doc(description="Serve an image from storage")
    def get(self, image_path):
        if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
            return make_response(
                jsonify({"success": False, "message": "Invalid image path"}), 400
            )
        try:
            from application.api.user.base import storage

@@ -629,6 +633,10 @@ class ServeImage(Resource):
                return make_response(
                    jsonify({"success": False, "message": "Image not found"}), 404
                )
        except ValueError:
            return make_response(
                jsonify({"success": False, "message": "Invalid image path"}), 400
            )
        except Exception as e:
            current_app.logger.error(f"Error serving image: {e}")
            return make_response(
@@ -15,6 +15,8 @@ from werkzeug.utils import secure_filename

from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator

@@ -132,6 +134,9 @@ def ensure_user_doc(user_id):
    if updates:
        users_collection.update_one({"user_id": user_id}, {"$set": updates})
        user_doc = users_collection.find_one({"user_id": user_id})

    dual_write(UsersRepository, lambda repo: repo.upsert(user_id))

    return user_doc
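Throughout these route changes, `dual_write` is called with a repository class and a callback that receives a repository instance, while the Mongo write just above it remains the source of truth. The real helper lives in `application.storage.db.dual_write` and is not shown in this diff; the sketch below is only a guess at its shape, reconstructed from how it is invoked here (Phase 0 semantics: skip when Postgres is not configured, never let a Postgres failure break the Mongo path):

```python
import logging

from application.core.settings import settings

logger = logging.getLogger(__name__)


def dual_write(repo_cls, fn):
    """Hypothetical shape of the dual-write helper used in this diff.

    Best-effort mirror of a Mongo write into Postgres during the migration
    rollout; a no-op when POSTGRES_URI is unset.
    """
    if not settings.POSTGRES_URI:
        return
    try:
        fn(repo_cls())  # e.g. lambda repo: repo.upsert(user_id)
    except Exception:  # noqa: BLE001 - deliberately broad during rollout
        logger.exception("dual_write to Postgres failed; Mongo remains authoritative")
```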
@@ -57,7 +57,7 @@ class ShareConversation(Resource):

        try:
            conversation = conversations_collection.find_one(
                {"_id": ObjectId(conversation_id)}
                {"_id": ObjectId(conversation_id), "user": user}
            )
            if conversation is None:
                return make_response(
@@ -463,6 +463,16 @@ class ManageSourceFiles(Resource):
            removed_files = []
            map_updated = False
            for file_path in file_paths:
                if ".." in str(file_path) or str(file_path).startswith("/"):
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": "Invalid file path",
                            }
                        ),
                        400,
                    )
                full_path = f"{source_file_path}/{file_path}"

                # Remove from storage
@@ -14,6 +14,7 @@ from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields

@@ -63,6 +64,21 @@ def _extract_auth_credentials(config):
    return auth_credentials


def _validate_mcp_server_url(config: dict) -> None:
    """Validate the server_url in an MCP config to prevent SSRF.

    Raises:
        ValueError: If the URL is missing or points to a blocked address.
    """
    server_url = (config.get("server_url") or "").strip()
    if not server_url:
        raise ValueError("server_url is required")
    try:
        validate_url(server_url)
    except SSRFError as exc:
        raise ValueError(f"Invalid server URL: {exc}") from exc


@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
    @api.expect(

@@ -97,6 +113,8 @@ class TestMCPServerConfig(Resource):
                    400,
                )

            _validate_mcp_server_url(config)

            auth_credentials = _extract_auth_credentials(config)
            test_config = config.copy()
            test_config["auth_credentials"] = auth_credentials

@@ -105,15 +123,41 @@ class TestMCPServerConfig(Resource):
            result = mcp_tool.test_connection()

            if result.get("requires_oauth"):
                return make_response(jsonify(result), 200)
                safe_result = {
                    k: v
                    for k, v in result.items()
                    if k in ("success", "requires_oauth", "auth_url")
                }
                return make_response(jsonify(safe_result), 200)

            if not result.get("success") and "message" in result:
            if not result.get("success"):
                current_app.logger.error(
                    f"MCP connection test failed: {result.get('message')}"
                )
                result["message"] = "Connection test failed"
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "message": "Connection test failed",
                            "tools_count": 0,
                        }
                    ),
                    200,
                )

            return make_response(jsonify(result), 200)
            safe_result = {
                "success": True,
                "message": result.get("message", "Connection successful"),
                "tools_count": result.get("tools_count", 0),
                "tools": result.get("tools", []),
            }
            return make_response(jsonify(safe_result), 200)
        except ValueError as e:
            current_app.logger.warning(f"Invalid MCP server test request: {e}")
            return make_response(
                jsonify({"success": False, "error": "Invalid MCP server configuration"}),
                400,
            )
        except Exception as e:
            current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
            return make_response(

@@ -165,6 +209,8 @@ class MCPServerSave(Resource):
                    400,
                )

            _validate_mcp_server_url(config)

            auth_credentials = _extract_auth_credentials(config)
            auth_type = config.get("auth_type", "none")
            mcp_config = config.copy()

@@ -279,6 +325,12 @@ class MCPServerSave(Resource):
                "tools_count": len(transformed_actions),
            }
            return make_response(jsonify(response_data), 200)
        except ValueError as e:
            current_app.logger.warning(f"Invalid MCP server save request: {e}")
            return make_response(
                jsonify({"success": False, "error": "Invalid MCP server configuration"}),
                400,
            )
        except Exception as e:
            current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
            return make_response(
@@ -8,6 +8,7 @@ from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name

@@ -130,6 +131,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
class AvailableTools(Resource):
    @api.doc(description="Get available tools for a user")
    def get(self):
        if not request.decoded_token:
            return make_response(jsonify({"success": False}), 401)
        try:
            tools_metadata = []
            for tool_name, tool_instance in tool_manager.tools.items():

@@ -236,6 +239,16 @@ class CreateTool(Resource):
        if missing_fields:
            return missing_fields
        try:
            if data["name"] == "mcp_tool":
                server_url = (data.get("config", {}).get("server_url") or "").strip()
                if server_url:
                    try:
                        validate_url(server_url)
                    except SSRFError:
                        return make_response(
                            jsonify({"success": False, "message": "Invalid server URL"}),
                            400,
                        )
            tool_instance = tool_manager.tools.get(data["name"])
            if not tool_instance:
                return make_response(

@@ -421,6 +434,16 @@ class UpdateToolConfig(Resource):
            return make_response(jsonify({"success": False}), 404)

        tool_name = tool_doc.get("name")
        if tool_name == "mcp_tool":
            server_url = (data["config"].get("server_url") or "").strip()
            if server_url:
                try:
                    validate_url(server_url)
                except SSRFError:
                    return make_response(
                        jsonify({"success": False, "message": "Invalid server URL"}),
                        400,
                    )
        tool_instance = tool_manager.tools.get(tool_name)
        config_requirements = (
            tool_instance.get_config_requirements() if tool_instance else {}
@@ -138,10 +138,18 @@ def chat_completions():
|
||||
if usage_error:
|
||||
return usage_error
|
||||
|
||||
should_save_conversation = bool(internal_data.get("save_conversation", False))
|
||||
|
||||
if is_stream:
|
||||
return Response(
|
||||
_stream_response(
|
||||
helper, question, agent, processor, model_name, continuation
|
||||
helper,
|
||||
question,
|
||||
agent,
|
||||
processor,
|
||||
model_name,
|
||||
continuation,
|
||||
should_save_conversation,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
@@ -151,7 +159,13 @@ def chat_completions():
|
||||
)
|
||||
else:
|
||||
return _non_stream_response(
|
||||
helper, question, agent, processor, model_name, continuation
|
||||
helper,
|
||||
question,
|
||||
agent,
|
||||
processor,
|
||||
model_name,
|
||||
continuation,
|
||||
should_save_conversation,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
@@ -181,6 +195,7 @@ def _stream_response(
|
||||
processor: StreamProcessor,
|
||||
model_name: str,
|
||||
continuation: Optional[Dict],
|
||||
should_save_conversation: bool,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Generate translated SSE chunks for streaming response."""
|
||||
completion_id = f"chatcmpl-{int(time.time())}"
|
||||
@@ -193,6 +208,7 @@ def _stream_response(
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
|
||||
@@ -225,6 +241,7 @@ def _non_stream_response(
|
||||
processor: StreamProcessor,
|
||||
model_name: str,
|
||||
continuation: Optional[Dict],
|
||||
should_save_conversation: bool,
|
||||
) -> Response:
|
||||
"""Collect full response and return as single JSON."""
|
||||
stream = helper.complete_stream(
|
||||
@@ -235,6 +252,7 @@ def _non_stream_response(
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
|
||||
@@ -293,8 +311,9 @@ def list_models():
|
||||
for ag in user_agents:
|
||||
created = ag.get("createdAt")
|
||||
created_ts = int(created.timestamp()) if created else int(time.time())
|
||||
model_id = str(ag.get("_id") or ag.get("id") or "")
|
||||
models.append({
|
||||
"id": str(ag.get("key", "")),
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": created_ts,
|
||||
"owned_by": "docsgpt",
|
||||
|
||||
@@ -80,6 +80,17 @@ def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
|
||||
"""Extract the first system message content from the messages array.
|
||||
|
||||
Returns None if no system message is present.
|
||||
"""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
return msg.get("content", "")
|
||||
return None
|
||||
|
||||
|
||||
def convert_history(messages: List[Dict]) -> List[Dict]:
|
||||
"""Convert chat completions messages array to DocsGPT history format.
|
||||
|
||||
@@ -148,20 +159,27 @@ def translate_request(
|
||||
break
|
||||
|
||||
history = convert_history(messages)
|
||||
system_prompt_override = extract_system_prompt(messages)
|
||||
|
||||
docsgpt = data.get("docsgpt", {})
|
||||
|
||||
result = {
|
||||
"question": question,
|
||||
"api_key": api_key,
|
||||
"history": json.dumps(history),
|
||||
"save_conversation": True,
|
||||
# Conversations are NOT persisted by default on the v1 endpoint.
|
||||
# Callers opt in via ``docsgpt.save_conversation: true``.
|
||||
"save_conversation": bool(docsgpt.get("save_conversation", False)),
|
||||
}
|
||||
|
||||
if system_prompt_override is not None:
|
||||
result["system_prompt_override"] = system_prompt_override
|
||||
|
||||
# Client tools
|
||||
if data.get("tools"):
|
||||
result["client_tools"] = data["tools"]
|
||||
|
||||
# DocsGPT extensions
|
||||
docsgpt = data.get("docsgpt", {})
|
||||
if docsgpt.get("attachments"):
|
||||
result["attachments"] = docsgpt["attachments"]
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from celery import Celery
|
||||
from application.core.settings import settings
|
||||
from celery.signals import setup_logging
|
||||
from celery.signals import setup_logging, worker_process_init
|
||||
|
||||
|
||||
def make_celery(app_name=__name__):
|
||||
@@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs):
|
||||
setup_logging()
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def _dispose_db_engine_on_fork(*args, **kwargs):
|
||||
"""Dispose the SQLAlchemy engine pool in each forked Celery worker.
|
||||
|
||||
SQLAlchemy connection pools are not fork-safe: file descriptors shared
|
||||
between the parent and a forked worker will corrupt the pool. Disposing
|
||||
on ``worker_process_init`` gives every worker its own fresh pool on
|
||||
first use.
|
||||
|
||||
Imported lazily so Celery workers that don't touch Postgres (or where
|
||||
``POSTGRES_URI`` is unset) don't fail at startup.
|
||||
"""
|
||||
try:
|
||||
from application.storage.db.engine import dispose_engine
|
||||
except Exception:
|
||||
return
|
||||
dispose_engine()
|
||||
|
||||
|
||||
celery = make_celery()
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
|
||||
89
application/core/db_uri.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Normalize user-supplied Postgres URIs for different drivers.
|
||||
|
||||
DocsGPT has two Postgres connection strings pointing at potentially
|
||||
different databases:
|
||||
|
||||
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
|
||||
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
|
||||
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
|
||||
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
|
||||
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
|
||||
dialect prefix is an invalid URI from its point of view.
|
||||
|
||||
The two fields therefore need opposite normalization so operators don't
|
||||
have to know which driver a given field feeds. Each normalizer also
|
||||
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
|
||||
psycopg2 is no longer in the project.
|
||||
|
||||
This module is deliberately separate from ``application/core/settings.py``
|
||||
so the Settings class stays focused on field declarations, and the
|
||||
URI-rewriting logic can be unit-tested without triggering ``.env``
|
||||
file loading from importing Settings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _rewrite_uri_prefixes(v, rewrites):
|
||||
"""Shared URI prefix rewriter used by both normalizers below.
|
||||
|
||||
Strips whitespace, returns ``None`` for empty / ``"none"`` values,
|
||||
applies the first matching rewrite, and passes unrecognised input
|
||||
through so downstream consumers (SQLAlchemy, libpq) can produce
|
||||
their own error messages rather than us silently eating a
|
||||
misconfiguration.
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
v = v.strip()
|
||||
if not v or v.lower() == "none":
|
||||
return None
|
||||
for prefix, target in rewrites:
|
||||
if v.startswith(prefix):
|
||||
return target + v[len(prefix):]
|
||||
return v
|
||||
|
||||
|
||||
# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
|
||||
# dialect prefix to select the psycopg v3 driver. Normalize the
|
||||
# operator-friendly forms TOWARD that dialect.
|
||||
_POSTGRES_URI_REWRITES = (
|
||||
("postgresql+psycopg2://", "postgresql+psycopg://"),
|
||||
("postgresql://", "postgresql+psycopg://"),
|
||||
("postgres://", "postgresql+psycopg://"),
|
||||
)
|
||||
|
||||
|
||||
# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
|
||||
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
|
||||
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
|
||||
# dialect prefix is an invalid URI from libpq's point of view. Strip it
|
||||
# if the operator accidentally copied their POSTGRES_URI value here.
|
||||
_PGVECTOR_CONNECTION_STRING_REWRITES = (
|
||||
("postgresql+psycopg2://", "postgresql://"),
|
||||
("postgresql+psycopg://", "postgresql://"),
|
||||
)
|
||||
|
||||
|
||||
def normalize_postgres_uri(v):
|
||||
"""Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.
|
||||
|
||||
Accepts the forms operators naturally write (``postgres://``,
|
||||
``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
|
||||
Unknown schemes pass through unchanged so SQLAlchemy can produce its
|
||||
own dialect-not-found error.
|
||||
"""
|
||||
return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)
|
||||
|
||||
|
||||
def normalize_pgvector_connection_string(v):
|
||||
"""Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.
|
||||
|
||||
Strips the SQLAlchemy dialect prefix if the operator accidentally
|
||||
copied their POSTGRES_URI value here — libpq can't parse it.
|
||||
User-friendly forms (``postgres://``, ``postgresql://``) pass
|
||||
through unchanged since libpq accepts them natively.
|
||||
"""
|
||||
return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)
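# Illustrative usage sketch for the two normalizers above; the credentials and
# host below are placeholder values, not real configuration.
from application.core.db_uri import (
    normalize_pgvector_connection_string,
    normalize_postgres_uri,
)

# Operator-friendly forms are upgraded to the SQLAlchemy psycopg3 dialect.
assert normalize_postgres_uri(
    "postgres://docsgpt:docsgpt@localhost:5432/docsgpt"
) == "postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt"

# The pgvector string goes the other way: the dialect prefix is stripped so
# libpq can parse it.
assert normalize_pgvector_connection_string(
    "postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt"
) == "postgresql://docsgpt:docsgpt@localhost:5432/docsgpt"

# Empty strings and the literal "none" collapse to None for both helpers.
assert normalize_postgres_uri("   ") is None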
|
||||
@@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
from application.core.db_uri import ( # noqa: E402
|
||||
normalize_pgvector_connection_string,
|
||||
normalize_postgres_uri,
|
||||
)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(extra="ignore")
|
||||
|
||||
@@ -22,6 +28,13 @@ class Settings(BaseSettings):
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MONGO_DB_NAME: str = "docsgpt"
|
||||
# User-data Postgres DB. Optional during the MongoDB→Postgres migration;
# becomes required once the migration is complete.
|
||||
POSTGRES_URI: Optional[str] = None
|
||||
|
||||
# MongoDB→Postgres migration switches
|
||||
USE_POSTGRES: bool = False
|
||||
READ_POSTGRES: bool = False
|
||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
||||
@@ -117,7 +130,10 @@ class Settings(BaseSettings):
|
||||
QDRANT_PATH: Optional[str] = None
|
||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||
|
||||
# PGVector vectorstore config
|
||||
# PGVector vectorstore config. Write the URI in whichever form you
|
||||
# prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
|
||||
# dialect form (``postgresql+psycopg://``) are all accepted and
|
||||
# normalized internally for ``psycopg.connect()``.
|
||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||
# Milvus vectorstore config
|
||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||
@@ -156,6 +172,16 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
return normalize_postgres_uri(v)
|
||||
|
||||
@field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
|
||||
@classmethod
|
||||
def _normalize_pgvector_connection_string_validator(cls, v):
|
||||
return normalize_pgvector_connection_string(v)
|
||||
|
||||
@field_validator(
|
||||
"API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
|
||||
@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
|
||||
def parse_file(self, file: Path, errors: str = "ignore") -> str:
|
||||
"""Parse file."""
|
||||
try:
|
||||
import ebooklib
|
||||
from ebooklib import epub
|
||||
from fast_ebook import epub
|
||||
except ImportError:
|
||||
raise ValueError("`EbookLib` is required to read Epub files.")
|
||||
try:
|
||||
import html2text
|
||||
except ImportError:
|
||||
raise ValueError("`html2text` is required to parse Epub files.")
|
||||
raise ValueError("`fast-ebook` is required to read Epub files.")
|
||||
|
||||
text_list = []
|
||||
book = epub.read_epub(file, options={"ignore_ncx": True})
|
||||
|
||||
# Iterate through all chapters.
|
||||
for item in book.get_items():
|
||||
# Chapters are typically located in epub documents items.
|
||||
if item.get_type() == ebooklib.ITEM_DOCUMENT:
|
||||
text_list.append(
|
||||
html2text.html2text(item.get_content().decode("utf-8"))
|
||||
)
|
||||
|
||||
text = "\n".join(text_list)
|
||||
book = epub.read_epub(file)
|
||||
text = book.to_markdown()
|
||||
return text
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
anthropic==0.86.0
|
||||
boto3==1.42.24
|
||||
alembic>=1.13,<2
|
||||
anthropic==0.88.0
|
||||
boto3==1.42.83
|
||||
beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.3
|
||||
@@ -11,11 +12,11 @@ rapidocr>=1.4.0
|
||||
onnxruntime>=1.19.0
|
||||
docx2txt==0.9
|
||||
ddgs>=8.0.0
|
||||
ebooklib==0.20
|
||||
elevenlabs==2.40.0
|
||||
fast-ebook
|
||||
elevenlabs==2.41.0
|
||||
Flask==3.1.3
|
||||
faiss-cpu==1.13.2
|
||||
fastmcp==2.14.6
|
||||
fastmcp==3.2.0
|
||||
flask-restx==1.3.2
|
||||
google-genai==1.69.0
|
||||
google-api-python-client==2.193.0
|
||||
@@ -23,10 +24,9 @@ google-auth-httplib2==0.3.1
|
||||
google-auth-oauthlib==1.3.1
|
||||
gTTS==2.5.4
|
||||
gunicorn==25.3.0
|
||||
html2text==2025.4.15
|
||||
jinja2==3.1.6
|
||||
jiter==0.13.0
|
||||
jmespath==1.0.1
|
||||
jmespath==1.1.0
|
||||
joblib==1.5.3
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
@@ -34,7 +34,7 @@ kombu==5.6.2
|
||||
langchain==1.2.3
|
||||
langchain-community==0.4.1
|
||||
langchain-core==1.2.23
|
||||
langchain-openai==1.1.7
|
||||
langchain-openai==1.1.12
|
||||
langchain-text-splitters==1.1.1
|
||||
langsmith==0.7.23
|
||||
lazy-object-proxy==1.12.0
|
||||
@@ -53,29 +53,30 @@ orjson==3.11.7
|
||||
packaging==26.0
|
||||
pandas==3.0.2
|
||||
openpyxl==3.1.5
|
||||
pathable==0.4.4
|
||||
pathable==0.5.0
|
||||
pdf2image>=1.17.0
|
||||
pillow
|
||||
portalocker>=2.7.0,<3.0.0
|
||||
portalocker>=2.7.0,<4.0.0
|
||||
prompt-toolkit==3.0.52
|
||||
protobuf==7.34.1
|
||||
psycopg2-binary==2.9.11
|
||||
psycopg[binary,pool]>=3.1,<4
|
||||
py==1.11.0
|
||||
pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pymongo==4.16.0
|
||||
pypdf==6.6.0
|
||||
pypdf==6.9.2
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv
|
||||
python-jose==3.5.0
|
||||
python-pptx==1.0.2
|
||||
redis==7.4.0
|
||||
referencing>=0.28.0,<0.38.0
|
||||
regex==2026.3.32
|
||||
regex==2026.4.4
|
||||
requests==2.33.1
|
||||
retry==0.9.2
|
||||
sentence-transformers==5.3.0
|
||||
sqlalchemy>=2.0,<3
|
||||
tiktoken==0.12.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.11.0
|
||||
|
||||
10
application/storage/db/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""PostgreSQL storage layer for user-level data.
|
||||
|
||||
This package holds the SQLAlchemy Core engine, metadata, repositories, and
|
||||
migration infrastructure for the user-data Postgres database. It is separate
|
||||
from ``application/vectorstore/pgvector.py`` — the two may point at the same
|
||||
cluster or at different clusters depending on operator configuration.
|
||||
|
||||
Repository modules are added in later phases
|
||||
as individual collections are ported.
|
||||
"""
|
||||
39
application/storage/db/base_repository.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Common helpers shared by all repositories.
|
||||
|
||||
Repositories are thin wrappers around SQLAlchemy Core query construction.
|
||||
They take a ``Connection`` on call and return plain ``dict`` rows during the
|
||||
Mongo→Postgres cutover so that call sites don't have to change shape. Once
|
||||
cutover is complete, a follow-up phase may migrate repo return types to
|
||||
Pydantic DTOs (tracked in the migration plan as a post-migration item).
|
||||
"""
|
||||
|
||||
from typing import Any, Mapping
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def row_to_dict(row: Any) -> dict:
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
|
||||
|
||||
During the migration window, API responses and downstream code still
|
||||
expect a string ``_id`` field (matching the Mongo shape). This helper
|
||||
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
|
||||
existing serializers keep working unchanged.
|
||||
|
||||
Args:
|
||||
row: A SQLAlchemy ``Row`` object, or ``None``.
|
||||
|
||||
Returns:
|
||||
A plain dict, or an empty dict if ``row`` is ``None``.
|
||||
"""
|
||||
if row is None:
|
||||
return {}
|
||||
|
||||
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
|
||||
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
|
||||
out = dict(mapping)
|
||||
|
||||
if "id" in out and out["id"] is not None:
|
||||
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
|
||||
out["_id"] = out["id"]
|
||||
|
||||
return out
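# Illustrative read path for row_to_dict; the user id is a made-up example and
# the connection comes from the engine factory introduced below.
from sqlalchemy import text

from application.storage.db.base_repository import row_to_dict
from application.storage.db.engine import get_engine

with get_engine().connect() as conn:
    row = conn.execute(
        text("SELECT * FROM users WHERE user_id = :u"), {"u": "auth0|123"}
    ).fetchone()
    doc = row_to_dict(row)
    # doc is {} when no row matched; otherwise it carries both "id" and "_id"
    # as the same UUID string, so Mongo-shaped serializers keep working.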
|
||||
67
application/storage/db/dual_write.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres
|
||||
migration.
|
||||
|
||||
The helper:
|
||||
|
||||
* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off
|
||||
call sites add literally zero work.
|
||||
* Opens a transactional connection from the user-data SQLAlchemy engine.
|
||||
* Instantiates the caller's repository class on that connection.
|
||||
* Runs the caller's operation.
|
||||
* Swallows and logs any exception. **Mongo remains the source of truth
|
||||
during the dual-write window** — a Postgres-side failure must never
|
||||
break a user-facing request. Drift that builds up from swallowed
|
||||
failures is caught separately by re-running the backfill script.
|
||||
|
||||
Call sites look like::
|
||||
|
||||
users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged
|
||||
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror
|
||||
|
||||
A single parameterised helper rather than one function per collection
|
||||
means a new collection just needs its repository class — no new helper
|
||||
function, no new feature flag. The whole helper is deleted at Phase 5
|
||||
when the migration is complete.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_Repo = TypeVar("_Repo")
|
||||
|
||||
|
||||
def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None:
|
||||
"""Mirror a Mongo write into Postgres via ``repo_cls``, best-effort.
|
||||
|
||||
No-op when ``settings.USE_POSTGRES`` is false. Any exception
|
||||
(connection pool exhaustion, migration drift, SQL error) is logged
|
||||
and swallowed so the caller's primary Mongo write remains the source
|
||||
of truth.
|
||||
|
||||
Args:
|
||||
repo_cls: The repository class to instantiate (e.g. ``UsersRepository``).
|
||||
fn: A callable that takes the instantiated repository and performs
|
||||
the desired write.
|
||||
"""
|
||||
if not settings.USE_POSTGRES:
|
||||
return
|
||||
|
||||
try:
|
||||
# Lazy import so modules that import dual_write don't pay the
|
||||
# SQLAlchemy import cost when the flag is off.
|
||||
from application.storage.db.engine import get_engine
|
||||
|
||||
with get_engine().begin() as conn:
|
||||
fn(repo_cls(conn))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Postgres dual-write failed for %s — Mongo write already committed",
|
||||
repo_cls.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
67
application/storage/db/engine.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""SQLAlchemy Core engine factory for the user-data Postgres database.
|
||||
|
||||
The engine is lazily constructed on first use and cached as a module-level
|
||||
singleton. Repositories and the Alembic env module both obtain connections
|
||||
through this factory, so pool tuning lives in one place.
|
||||
|
||||
``POSTGRES_URI`` can be written in any of the common Postgres URI forms::
|
||||
|
||||
postgres://user:pass@host:5432/docsgpt
|
||||
postgresql://user:pass@host:5432/docsgpt
|
||||
|
||||
Both are accepted and normalized internally to the psycopg3 dialect
|
||||
(``postgresql+psycopg://``) by ``application.core.settings``. Operators
|
||||
don't need to know about SQLAlchemy dialect prefixes.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Engine, create_engine
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that
|
||||
reach this path without a configured URI have a setup bug — the
|
||||
error message points them at the right setting.
|
||||
|
||||
Returns:
|
||||
A SQLAlchemy ``Engine`` configured with a pooled connection to
|
||||
Postgres via psycopg3.
|
||||
"""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
if not settings.POSTGRES_URI:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
_engine = create_engine(
|
||||
settings.POSTGRES_URI,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles
|
||||
pool_recycle=1800,
|
||||
future=True,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
def dispose_engine() -> None:
|
||||
"""Dispose the pooled connections and reset the singleton.
|
||||
|
||||
Called from the Celery ``worker_process_init`` signal so each forked
|
||||
worker gets a fresh pool instead of sharing file descriptors with the
|
||||
parent process (which corrupts the pool on fork).
|
||||
"""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
_engine.dispose()
|
||||
_engine = None
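# Illustrative usage sketch: repositories obtain a transactional connection
# from this factory (the same pattern dual_write relies on). The user and
# agent ids are made-up examples.
from application.storage.db.engine import get_engine
from application.storage.db.repositories.users import UsersRepository

with get_engine().begin() as conn:  # one transaction per unit of work
    repo = UsersRepository(conn)
    repo.upsert("auth0|123")                  # ensure the row exists
    repo.add_pinned("auth0|123", "agent-42")  # atomic JSONB append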
|
||||
38
application/storage/db/models.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""SQLAlchemy Core metadata for the user-data Postgres database.
|
||||
|
||||
Tables are added here one at a time as repositories are built during the
|
||||
MongoDB→Postgres migration. The baseline schema in the Alembic migration
|
||||
(``application/alembic/versions/0001_initial.py``) is the source of truth
|
||||
for DDL; the ``Table`` definitions below must match it column-for-column.
|
||||
If the two drift, migrations win — update this file to match.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
MetaData,
|
||||
Table,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||
|
||||
users_table = Table(
|
||||
"users",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False, unique=True),
|
||||
Column(
|
||||
"agent_preferences",
|
||||
JSONB,
|
||||
nullable=False,
|
||||
server_default='{"pinned": [], "shared_with_me": []}',
|
||||
),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
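# Roughly the DDL this Table maps to; the Alembic baseline migration remains
# the authoritative definition.
#
#   CREATE TABLE users (
#       id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
#       user_id text NOT NULL UNIQUE,
#       agent_preferences jsonb NOT NULL
#           DEFAULT '{"pinned": [], "shared_with_me": []}',
#       created_at timestamptz NOT NULL DEFAULT now(),
#       updated_at timestamptz NOT NULL DEFAULT now()
#   );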
|
||||
11
application/storage/db/repositories/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Repositories for the user-data Postgres database.
|
||||
|
||||
Each module in this package exposes exactly one repository class. Repository
|
||||
methods take a ``Connection`` (either as a constructor argument or as a
|
||||
method argument) and return plain ``dict`` rows via
|
||||
``application.storage.db.base_repository.row_to_dict`` during the
|
||||
MongoDB→Postgres cutover, so call sites don't have to change shape.
|
||||
|
||||
Repositories are added one collection at a time, matching the phased
|
||||
rollout in ``migration-postgres.md``.
|
||||
"""
|
||||
245
application/storage/db/repositories/users.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Repository for the ``users`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``users_collection``:
|
||||
|
||||
1. ``ensure_user_doc`` in ``application/api/user/base.py`` (upsert + get)
|
||||
2. Pin/unpin agents in ``application/api/user/agents/routes.py`` (add/remove
|
||||
on ``agent_preferences.pinned``)
|
||||
3. Share accept/reject in ``application/api/user/agents/sharing.py`` (add/
|
||||
bulk-remove on ``agent_preferences.shared_with_me``)
|
||||
4. Cascade delete of an agent id from both arrays at once
|
||||
|
||||
All array mutations are implemented as single atomic UPDATE statements
|
||||
using JSONB operators (``jsonb_set``, ``jsonb_array_elements``, ``@>``)
|
||||
so there is no read-modify-write race between concurrent writers on the
|
||||
same user row.
|
||||
|
||||
The repository takes a ``Connection`` and does not manage its own
|
||||
transactions. Callers are responsible for wrapping writes in
|
||||
``with engine.begin() as conn:`` (production) or the test fixture's
|
||||
rollback-per-test connection (tests).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_DEFAULT_PREFERENCES = '{"pinned": [], "shared_with_me": []}'
|
||||
|
||||
|
||||
class UsersRepository:
|
||||
"""Postgres-backed replacement for Mongo ``users_collection`` writes/reads."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reads
|
||||
# ------------------------------------------------------------------
|
||||
def get(self, user_id: str) -> Optional[dict]:
|
||||
"""Return the user row as a dict, or ``None`` if missing.
|
||||
|
||||
Args:
|
||||
user_id: Auth-provider ``sub`` (opaque string).
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM users WHERE user_id = :user_id"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Upsert
|
||||
# ------------------------------------------------------------------
|
||||
def upsert(self, user_id: str) -> dict:
|
||||
"""Ensure a row exists for ``user_id`` and return it.
|
||||
|
||||
Matches Mongo's ``find_one_and_update(..., $setOnInsert, upsert=True,
|
||||
return_document=AFTER)`` semantics: if the row exists, preferences
|
||||
are preserved untouched; if it doesn't, a new row is created with
|
||||
default preferences.
|
||||
|
||||
The ``DO UPDATE SET user_id = EXCLUDED.user_id`` branch is a
|
||||
deliberate no-op that lets ``RETURNING *`` fire on both the insert
|
||||
and conflict paths (``DO NOTHING`` would suppress the returning).
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO users (user_id, agent_preferences)
|
||||
VALUES (:user_id, CAST(:default_prefs AS jsonb))
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET user_id = EXCLUDED.user_id
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "default_prefs": _DEFAULT_PREFERENCES},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pinned agents
|
||||
# ------------------------------------------------------------------
|
||||
def add_pinned(self, user_id: str, agent_id: str) -> None:
|
||||
"""Idempotently append ``agent_id`` to ``agent_preferences.pinned``.
|
||||
|
||||
Uses ``@>`` containment so a duplicate add is a no-op rather than a
|
||||
silent double-insert. The whole update is a single atomic statement
|
||||
so concurrent add_pinned calls on the same user cannot interleave
|
||||
into a read-modify-write race.
|
||||
"""
|
||||
self._append_to_jsonb_array(user_id, "pinned", agent_id)
|
||||
|
||||
def remove_pinned(self, user_id: str, agent_id: str) -> None:
|
||||
"""Remove ``agent_id`` from ``agent_preferences.pinned`` if present."""
|
||||
self._remove_from_jsonb_array(user_id, "pinned", [agent_id])
|
||||
|
||||
def remove_pinned_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
|
||||
"""Remove every id in ``agent_ids`` from ``agent_preferences.pinned``.
|
||||
|
||||
No-op if the list is empty. Unknown ids are silently ignored so
|
||||
callers can pass the full "stale" set without pre-filtering.
|
||||
"""
|
||||
ids = list(agent_ids)
|
||||
if not ids:
|
||||
return
|
||||
self._remove_from_jsonb_array(user_id, "pinned", ids)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared-with-me agents
|
||||
# ------------------------------------------------------------------
|
||||
def add_shared(self, user_id: str, agent_id: str) -> None:
|
||||
"""Idempotently append ``agent_id`` to ``agent_preferences.shared_with_me``."""
|
||||
self._append_to_jsonb_array(user_id, "shared_with_me", agent_id)
|
||||
|
||||
def remove_shared_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
|
||||
"""Bulk-remove from ``agent_preferences.shared_with_me``. Empty list is a no-op."""
|
||||
ids = list(agent_ids)
|
||||
if not ids:
|
||||
return
|
||||
self._remove_from_jsonb_array(user_id, "shared_with_me", ids)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Combined removal — called when an agent is hard-deleted
|
||||
# ------------------------------------------------------------------
|
||||
def remove_agent_from_all(self, user_id: str, agent_id: str) -> None:
|
||||
"""Remove ``agent_id`` from BOTH pinned and shared_with_me atomically.
|
||||
|
||||
Mirrors the Mongo ``$pull`` that targets both nested array fields
|
||||
in one ``update_one`` — see ``application/api/user/agents/routes.py``
|
||||
around the agent-delete path.
|
||||
"""
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE users
|
||||
SET
|
||||
agent_preferences = jsonb_set(
|
||||
jsonb_set(
|
||||
agent_preferences,
|
||||
'{pinned}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||
) AS elem
|
||||
WHERE (elem #>> '{}') != :agent_id
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
'{shared_with_me}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||
) AS elem
|
||||
WHERE (elem #>> '{}') != :agent_id
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "agent_id": agent_id},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _append_to_jsonb_array(self, user_id: str, key: str, agent_id: str) -> None:
|
||||
"""Idempotent append of ``agent_id`` to ``agent_preferences.<key>``.
|
||||
|
||||
The ``key`` argument is NOT user input — it's hard-coded by the
|
||||
calling method (``pinned`` / ``shared_with_me``). It goes into the
|
||||
SQL literal because ``jsonb_set`` requires a path literal, not a
|
||||
bind parameter. This is safe as long as callers never pass
|
||||
untrusted strings for ``key``.
|
||||
"""
|
||||
if key not in ("pinned", "shared_with_me"):
|
||||
raise ValueError(f"unsupported jsonb key: {key!r}")
|
||||
self._conn.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE users
|
||||
SET
|
||||
agent_preferences = jsonb_set(
|
||||
agent_preferences,
|
||||
'{{{key}}}',
|
||||
CASE
|
||||
WHEN agent_preferences->'{key}' @> to_jsonb(CAST(:agent_id AS text))
|
||||
THEN agent_preferences->'{key}'
|
||||
ELSE
|
||||
COALESCE(agent_preferences->'{key}', '[]'::jsonb)
|
||||
|| to_jsonb(CAST(:agent_id AS text))
|
||||
END
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "agent_id": agent_id},
|
||||
)
|
||||
|
||||
def _remove_from_jsonb_array(
|
||||
self, user_id: str, key: str, agent_ids: list[str]
|
||||
) -> None:
|
||||
"""Remove every id in ``agent_ids`` from ``agent_preferences.<key>``."""
|
||||
if key not in ("pinned", "shared_with_me"):
|
||||
raise ValueError(f"unsupported jsonb key: {key!r}")
|
||||
self._conn.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE users
|
||||
SET
|
||||
agent_preferences = jsonb_set(
|
||||
agent_preferences,
|
||||
'{{{key}}}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'{key}', '[]'::jsonb)
|
||||
) AS elem
|
||||
WHERE NOT ((elem #>> '{{}}') = ANY(:agent_ids))
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "agent_ids": agent_ids},
|
||||
)
|
||||
@@ -21,10 +21,19 @@ class LocalStorage(BaseStorage):
|
||||
)
|
||||
|
||||
def _get_full_path(self, path: str) -> str:
|
||||
"""Get absolute path by combining base_dir and path."""
|
||||
"""Get absolute path by combining base_dir and path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the resolved path escapes base_dir (path traversal).
|
||||
"""
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.join(self.base_dir, path)
|
||||
resolved = os.path.realpath(path)
|
||||
else:
|
||||
resolved = os.path.realpath(os.path.join(self.base_dir, path))
|
||||
base = os.path.realpath(self.base_dir)
|
||||
if not resolved.startswith(base + os.sep) and resolved != base:
|
||||
raise ValueError(f"Path traversal detected: {path}")
|
||||
return resolved
|
||||
|
||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||
"""Save a file to local storage."""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import io
|
||||
import os
|
||||
import posixpath
|
||||
from typing import BinaryIO, Callable, List
|
||||
|
||||
import boto3
|
||||
@@ -14,6 +15,20 @@ from botocore.exceptions import ClientError
|
||||
class S3Storage(BaseStorage):
|
||||
"""AWS S3 storage implementation."""
|
||||
|
||||
@staticmethod
|
||||
def _validate_path(path: str) -> str:
|
||||
"""Validate and normalize an S3 key to prevent path traversal.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path contains traversal sequences or is absolute.
|
||||
"""
|
||||
if "\x00" in path:
|
||||
raise ValueError(f"Null byte in path: {path}")
|
||||
normalized = posixpath.normpath(path)
|
||||
if normalized.startswith("/") or normalized.startswith(".."):
|
||||
raise ValueError(f"Path traversal detected: {path}")
|
||||
return normalized
|
||||
|
||||
def __init__(self, bucket_name=None):
|
||||
"""
|
||||
Initialize S3 storage.
|
||||
@@ -46,6 +61,7 @@ class S3Storage(BaseStorage):
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Save a file to S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
self.s3.upload_fileobj(
|
||||
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
|
||||
)
|
||||
@@ -61,6 +77,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
def get_file(self, path: str) -> BinaryIO:
|
||||
"""Get a file from S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
if not self.file_exists(path):
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
file_obj = io.BytesIO()
|
||||
@@ -70,6 +87,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
def delete_file(self, path: str) -> bool:
|
||||
"""Delete a file from S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
try:
|
||||
self.s3.delete_object(Bucket=self.bucket_name, Key=path)
|
||||
return True
|
||||
@@ -78,6 +96,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
def file_exists(self, path: str) -> bool:
|
||||
"""Check if a file exists in S3 storage."""
|
||||
path = self._validate_path(path)
|
||||
try:
|
||||
self.s3.head_object(Bucket=self.bucket_name, Key=path)
|
||||
return True
|
||||
@@ -115,6 +134,7 @@ class S3Storage(BaseStorage):
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
path = self._validate_path(path)
|
||||
if not self.file_exists(path):
|
||||
raise FileNotFoundError(f"File not found in S3: {path}")
|
||||
with tempfile.NamedTemporaryFile(
|
||||
|
||||
@@ -11,11 +11,33 @@ from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
def get_vectorstore(path: str) -> str:
|
||||
if path:
|
||||
vectorstore = f"indexes/{path}"
|
||||
else:
|
||||
vectorstore = "indexes"
|
||||
return vectorstore
|
||||
"""Build a safe local path for a FAISS index.
|
||||
|
||||
Args:
|
||||
path: Source identifier provided by the caller.
|
||||
|
||||
Returns:
|
||||
The validated vectorstore path rooted under ``indexes``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``path`` escapes the ``indexes`` directory.
|
||||
"""
|
||||
base_dir = "indexes"
|
||||
if not path:
|
||||
return base_dir
|
||||
|
||||
normalized = str(path).strip()
|
||||
if "\\" in normalized:
|
||||
raise ValueError("Invalid source_id path")
|
||||
|
||||
candidate = os.path.normpath(os.path.join(base_dir, normalized))
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
candidate_abs = os.path.abspath(candidate)
|
||||
|
||||
if not candidate_abs.startswith(base_abs + os.sep) and candidate_abs != base_abs:
|
||||
raise ValueError("Invalid source_id path")
|
||||
|
||||
return candidate
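# Illustrative behavior of the validation above; the source ids are made-up.
#
#   get_vectorstore("user1/source-abc")  # -> "indexes/user1/source-abc"
#   get_vectorstore("")                  # -> "indexes"
#   get_vectorstore("../etc/passwd")     # raises ValueError("Invalid source_id path")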
|
||||
|
||||
|
||||
class FaissStore(BaseVectorStore):
|
||||
|
||||
@@ -37,27 +37,25 @@ class PGVectorStore(BaseVectorStore):
|
||||
)
|
||||
|
||||
try:
|
||||
import psycopg2
|
||||
from psycopg2.extras import Json
|
||||
import pgvector.psycopg2
|
||||
import psycopg
|
||||
from pgvector.psycopg import register_vector
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import required packages. "
|
||||
"Please install with `pip install psycopg2-binary pgvector`."
|
||||
"Please install with `pip install 'psycopg[binary,pool]' pgvector`."
|
||||
)
|
||||
|
||||
self._psycopg2 = psycopg2
|
||||
self._Json = Json
|
||||
self._pgvector = pgvector.psycopg2
|
||||
self._psycopg = psycopg
|
||||
self._register_vector = register_vector
|
||||
self._connection = None
|
||||
self._ensure_table_exists()
|
||||
|
||||
def _get_connection(self):
|
||||
"""Get or create database connection"""
|
||||
if self._connection is None or self._connection.closed:
|
||||
self._connection = self._psycopg2.connect(self._connection_string)
|
||||
self._connection = self._psycopg.connect(self._connection_string)
|
||||
# Register pgvector types
|
||||
self._pgvector.register_vector(self._connection)
|
||||
self._register_vector(self._connection)
|
||||
return self._connection
|
||||
|
||||
def _ensure_table_exists(self):
|
||||
@@ -170,7 +168,7 @@ class PGVectorStore(BaseVectorStore):
|
||||
for text, embedding, metadata in zip(texts, embeddings, metadatas):
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embedding, self._Json(metadata), self._source_id)
|
||||
(text, embedding, metadata, self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
inserted_ids.append(str(inserted_id))
|
||||
@@ -261,7 +259,7 @@ class PGVectorStore(BaseVectorStore):
|
||||
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embeddings[0], self._Json(final_metadata), self._source_id)
|
||||
(text, embeddings[0], final_metadata, self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
@@ -7,6 +7,10 @@ export default {
|
||||
"title": "🔌 Agent API",
|
||||
"href": "/Agents/api"
|
||||
},
|
||||
"openai-compatible": {
|
||||
"title": "🔄 OpenAI-Compatible API",
|
||||
"href": "/Agents/openai-compatible"
|
||||
},
|
||||
"webhooks": {
|
||||
"title": "🪝 Agent Webhooks",
|
||||
"href": "/Agents/webhooks"
|
||||
|
||||
@@ -15,6 +15,10 @@ DocsGPT Agents can be accessed programmatically through API endpoints. This page
|
||||
|
||||
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
|
||||
|
||||
<Callout type="info">
|
||||
Looking to connect an existing OpenAI-compatible client (opencode, aider, the OpenAI SDKs, etc.) to a DocsGPT Agent? Use the [OpenAI-Compatible Chat Completions API](/Agents/openai-compatible) — it speaks the standard chat completions protocol so no adapter code is required.
|
||||
</Callout>
|
||||
|
||||
## Base URL
|
||||
|
||||
<Callout type="info">
|
||||
|
||||
@@ -111,6 +111,7 @@ Once an agent is created, you can:
|
||||
* Modify any of its configuration settings (name, description, source, prompt, tools, type).
|
||||
* **Generate a Public Link:** From the edit screen, you can create a shareable public link that allows others to import and use your agent.
|
||||
* **Get a Webhook URL:** You can also obtain a Webhook URL for the agent. This allows external applications or services to trigger the agent and receive responses programmatically, enabling powerful integrations and automations.
|
||||
* **Use it via API:** Every agent exposes an API key that can be used with the native [Agent API](/Agents/api) or the [OpenAI-Compatible API](/Agents/openai-compatible) so you can drop DocsGPT Agents into any tool that already speaks the chat completions protocol.
|
||||
|
||||
## Seeding Premade Agents from YAML
|
||||
|
||||
|
||||
93
docs/content/Agents/openai-compatible.mdx
Normal file
@@ -0,0 +1,93 @@
|
||||
---
|
||||
title: OpenAI-Compatible API
|
||||
description: Connect any OpenAI-compatible client to DocsGPT Agents via /v1/chat/completions.
|
||||
---
|
||||
|
||||
import { Callout, Tabs } from 'nextra/components';
|
||||
|
||||
# OpenAI-Compatible API
|
||||
|
||||
DocsGPT exposes `/v1/chat/completions` following the standard chat completions protocol. Point any compatible client — **opencode**, **Aider**, **LibreChat** or the OpenAI SDKs — at your DocsGPT Agent by changing only the base URL and API key.
|
||||
|
||||
## Quick Start
|
||||
|
||||
<Tabs items={['Python', 'cURL']}>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:7091/v1", # or https://gptcloud.arc53.com/v1
|
||||
api_key="your_agent_api_key",
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="docsgpt-agent",
|
||||
messages=[{"role": "user", "content": "Summarize our refund policy"}],
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/v1/chat/completions \
|
||||
-H "Authorization: Bearer your_agent_api_key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"docsgpt-agent","messages":[{"role":"user","content":"Summarize our refund policy"}]}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
The `model` field is accepted but ignored — the agent bound to your API key determines the model. The agent's prompt, sources, tools, and default model are loaded automatically.
|
||||
|
||||
## Base URL & Auth
|
||||
|
||||
| Environment | Base URL |
|
||||
| --- | --- |
|
||||
| Local | `http://localhost:7091/v1` |
|
||||
| Cloud | `https://gptcloud.arc53.com/v1` |
|
||||
|
||||
Authenticate with `Authorization: Bearer <agent_api_key>`.
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
| --- | --- | --- |
|
||||
| `POST` | `/v1/chat/completions` | Chat request (streaming or non-streaming) |
|
||||
| `GET` | `/v1/models` | List agents available to your key |
|
||||
|
||||
## Streaming
|
||||
|
||||
Set `"stream": true`. You'll receive SSE chunks with `choices[0].delta.content`. DocsGPT-specific events (sources, tool calls) arrive as extra frames with a `docsgpt` key — standard clients ignore them.
|
||||
|
||||
```python
|
||||
stream = client.chat.completions.create(
|
||||
model="docsgpt-agent",
|
||||
stream=True,
|
||||
messages=[{"role": "user", "content": "Explain vector search"}],
|
||||
)
|
||||
for chunk in stream:
|
||||
print(chunk.choices[0].delta.content or "", end="", flush=True)
|
||||
```
|
||||
|
||||
## System Prompt Override
|
||||
|
||||
System messages are **dropped by default** — the agent's configured prompt is used. To allow callers to override it, enable **Allow prompt override** in the agent's Advanced settings.
|
||||
|
||||
<Callout type="warning">
|
||||
When an override is active, the agent's prompt template is replaced wholesale — template variables like `{summaries}` are not substituted.
|
||||
</Callout>
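Once the toggle is enabled, a standard system message in the request takes effect. A minimal sketch using the client from the Quick Start (the instruction text is just an example):

```python
response = client.chat.completions.create(
    model="docsgpt-agent",
    messages=[
        {"role": "system", "content": "Answer only from the indexed docs and cite sources."},
        {"role": "user", "content": "Summarize our refund policy"},
    ],
)
```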
|
||||
|
||||
## Conversation Persistence
|
||||
|
||||
Conversations are **not persisted by default** (stateless, like most OpenAI clients expect). Opt in per request:
|
||||
|
||||
```json
|
||||
{ "docsgpt": { "save_conversation": true } }
|
||||
```
|
||||
|
||||
The response will include `docsgpt.conversation_id`.
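With the OpenAI Python SDK, the flag can be passed through `extra_body` (a sketch; how your client surfaces the extra `docsgpt` response fields varies):

```python
response = client.chat.completions.create(
    model="docsgpt-agent",
    messages=[{"role": "user", "content": "Summarize our refund policy"}],
    extra_body={"docsgpt": {"save_conversation": True}},
)
```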
|
||||
|
||||
## When to Use Native Endpoints Instead
|
||||
|
||||
Use [`/api/answer` or `/stream`](/Agents/api) if you need server-side attachments, `passthrough` template variables, explicit `conversation_id` reuse, or persistence by default.
|
||||
114
docs/content/Deploying/Postgres-Migration.mdx
Normal file
@@ -0,0 +1,114 @@
|
||||
---
|
||||
title: PostgreSQL for User Data
|
||||
description: Set up PostgreSQL as the user-data store for DocsGPT and migrate from MongoDB at your own pace.
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components'
|
||||
|
||||
# PostgreSQL for User Data
|
||||
|
||||
DocsGPT is progressively moving user data (conversations, agents, prompts,
|
||||
preferences, etc.) from MongoDB to PostgreSQL, one collection at a time.
|
||||
Each collection is guarded by a feature flag so you can opt in and roll
|
||||
back instantly. MongoDB stays the source of truth until you cut over
|
||||
reads; vector stores (`VECTOR_STORE=pgvector`, `faiss`, `qdrant`, `mongodb`, …)
|
||||
are unaffected.
|
||||
|
||||
<Callout type="info" emoji="ℹ️">
|
||||
Which collections are available today is in the [Status](#status)
|
||||
table below. That table is the only part of this page that changes
|
||||
release to release.
|
||||
</Callout>
|
||||
|
||||
## Setup
|
||||
|
||||
1. **Run Postgres 13+.** Native install, Docker, or managed (Neon, RDS,
|
||||
Supabase, Cloud SQL…) — all work. You'll need the `pgcrypto` and
|
||||
`citext` extensions, both standard contrib modules available
|
||||
everywhere.
|
||||
|
||||
2. **Create a database and role** (skip if your managed provider gave
|
||||
you these):
|
||||
|
||||
```sql
|
||||
CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt';
|
||||
CREATE DATABASE docsgpt OWNER docsgpt;
|
||||
```
|
||||
|
||||
3. **Set `POSTGRES_URI` in `.env`.** Any standard Postgres URI works —
|
||||
DocsGPT normalizes it internally.
|
||||
|
||||
```bash
|
||||
POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
# Append ?sslmode=require for managed providers that enforce SSL.
|
||||
```
|
||||
|
||||
4. **Apply the schema** (idempotent — safe to re-run):
|
||||
|
||||
```bash
|
||||
python scripts/db/init_postgres.py
|
||||
```
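To confirm the schema landed, any Postgres client works. A minimal sketch with `psycopg` (installed with the DocsGPT backend requirements; the URI is the example from step 3):

```python
import psycopg

with psycopg.connect("postgresql://docsgpt:docsgpt@localhost:5432/docsgpt") as conn:
    tables = [t[0] for t in conn.execute(
        "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
    )]
    print(tables)  # should include "users" once the baseline schema is applied
```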
|
||||
|
||||
## Migrating data
|
||||
|
||||
Two global flags, no per-collection knobs — every collection marked ✅
|
||||
in the [Status](#status) table is handled automatically.
|
||||
|
||||
1. **Enable dual-write.** Writes go to both Mongo and Postgres; Mongo
|
||||
remains source of truth. Set the flag in `.env` and restart:
|
||||
|
||||
```bash
|
||||
USE_POSTGRES=true
|
||||
```
|
||||
|
||||
2. **Backfill existing data.** Idempotent — re-run any time to re-sync
|
||||
drifted rows. Without arguments, backfills every registered table;
|
||||
pass `--tables` to limit.
|
||||
|
||||
```bash
|
||||
python scripts/db/backfill.py --dry-run # preview everything
|
||||
python scripts/db/backfill.py # real run, everything
|
||||
python scripts/db/backfill.py --tables users # only specific tables
|
||||
```
|
||||
|
||||
3. **Cut over reads** once you trust the Postgres state:
|
||||
|
||||
```bash
|
||||
READ_POSTGRES=true
|
||||
```
|
||||
|
||||
Rollback is instant: unset `READ_POSTGRES` and restart. Dual-write
|
||||
keeps Postgres up to date so you can flip back and forth.
|
||||
|
||||
<Callout type="warning" emoji="⚠️">
|
||||
Don't decommission MongoDB until every collection you use is fully
|
||||
cut over. During the migration window, Mongo is still required.
|
||||
</Callout>
|
||||
|
||||
## Status
|
||||
|
||||
_Last updated: 2026-04-10_
|
||||
|
||||
| Collection | Status |
|
||||
|---|---|
|
||||
| `users` | ✅ Phase 1 |
|
||||
| `prompts`, `user_tools`, `feedback`, `stack_logs`, `user_logs`, `token_usage` | ⏳ Phase 1 |
|
||||
| `agents`, `sources`, `attachments`, `memories`, `todos`, `notes`, `connector_sessions`, `agent_folders` | ⏳ Phase 2 |
|
||||
| `conversations`, `pending_tool_state`, `workflows` | ⏳ Phase 3 |
|
||||
|
||||
Schemas for **every** row above already exist after `init_postgres.py`
|
||||
runs. What's landing progressively is the application-level dual-write
|
||||
wiring and the backfill logic for each collection. Once a collection
|
||||
is ✅, enabling `USE_POSTGRES=true` and running `python scripts/db/backfill.py`
|
||||
picks it up automatically — no per-collection config change.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **`relation "..." does not exist`** — run `python scripts/db/init_postgres.py`.
|
||||
- **`FATAL: role "docsgpt" does not exist`** — run the `CREATE ROLE` /
|
||||
`CREATE DATABASE` statements from step 2 as a Postgres superuser.
|
||||
- **SSL errors on a managed provider** — append `?sslmode=require` to
|
||||
`POSTGRES_URI`.
|
||||
- **Dual-write warnings in the logs** — expected to be non-fatal. Mongo
|
||||
is source of truth, so the user-facing request succeeds. Re-run the
|
||||
backfill to re-sync whichever rows drifted.
|
||||
@@ -19,6 +19,10 @@ export default {
|
||||
"title": "☁️ Hosting DocsGPT",
|
||||
"href": "/Deploying/Hosting-the-app"
|
||||
},
|
||||
"Postgres-Migration": {
|
||||
"title": "🐘 PostgreSQL for User Data",
|
||||
"href": "/Deploying/Postgres-Migration"
|
||||
},
|
||||
"Amazon-Lightsail": {
|
||||
"title": "Hosting DocsGPT on Amazon Lightsail",
|
||||
"href": "/Deploying/Amazon-Lightsail",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import pprint
|
||||
|
||||
@@ -10,6 +12,7 @@ docsgpt_url = os.getenv("docsgpt_url")
|
||||
chatwoot_url = os.getenv("chatwoot_url")
|
||||
docsgpt_key = os.getenv("docsgpt_key")
|
||||
chatwoot_token = os.getenv("chatwoot_token")
|
||||
chatwoot_webhook_secret = os.getenv("chatwoot_webhook_secret", "")
|
||||
# account_id = os.getenv("account_id")
|
||||
# assignee_id = os.getenv("assignee_id")
|
||||
label_stop = "human-requested"
|
||||
@@ -45,12 +48,35 @@ def send_to_chatwoot(account, conversation, message):
|
||||
return r.json()
|
||||
|
||||
|
||||
def is_valid_chatwoot_signature(raw_body: bytes, signature_header: str | None) -> bool:
|
||||
"""Validate Chatwoot webhook signature using shared secret."""
|
||||
if not chatwoot_webhook_secret or not signature_header:
|
||||
return False
|
||||
|
||||
expected = hmac.new(
|
||||
chatwoot_webhook_secret.encode("utf-8"), raw_body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
provided = signature_header.strip()
|
||||
if provided.startswith("sha256="):
|
||||
provided = provided.split("=", maxsplit=1)[1]
|
||||
|
||||
return hmac.compare_digest(provided, expected)
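# Illustrative local check of the validator above; the secret and body are
# made-up values (chatwoot_webhook_secret would be "test-secret" in the env).
#
#   body = b'{"event": "message_created"}'
#   header = "sha256=" + hmac.new(b"test-secret", body, hashlib.sha256).hexdigest()
#   is_valid_chatwoot_signature(body, header)  # -> True when the env secret matches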
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/docsgpt', methods=['POST'])
|
||||
def docsgpt():
|
||||
data = request.get_json()
|
||||
raw_body = request.get_data()
|
||||
signature = request.headers.get("X-Chatwoot-Signature")
|
||||
if not is_valid_chatwoot_signature(raw_body, signature):
|
||||
return "Unauthorized", 401
|
||||
|
||||
data = request.get_json(silent=True)
|
||||
if not isinstance(data, dict):
|
||||
return "Invalid payload", 400
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
pp.pprint(data)
|
||||
try:
|
||||
|
||||
@@ -73,6 +73,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
token_limit: undefined,
|
||||
limited_request_mode: false,
|
||||
request_limit: undefined,
|
||||
allow_system_prompt_override: false,
|
||||
models: [],
|
||||
default_model_id: '',
|
||||
});
|
||||
@@ -241,6 +242,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('request_limit', '0');
|
||||
}
|
||||
|
||||
formData.append(
|
||||
'allow_system_prompt_override',
|
||||
agent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
|
||||
if (imageFile) formData.append('image', imageFile);
|
||||
|
||||
if (agent.tools && agent.tools.length > 0)
|
||||
@@ -361,6 +367,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
formData.append('request_limit', '0');
|
||||
}
|
||||
|
||||
formData.append(
|
||||
'allow_system_prompt_override',
|
||||
agent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
|
||||
if (agent.models && agent.models.length > 0) {
|
||||
formData.append('models', JSON.stringify(agent.models));
|
||||
}
|
||||
@@ -1266,6 +1277,43 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
}`}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-6">
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="min-w-0 flex-1">
|
||||
<h2 className="text-sm font-medium">
|
||||
{t('agents.form.advanced.systemPromptOverride')}
|
||||
</h2>
|
||||
<p className="mt-1 text-xs text-gray-600 dark:text-gray-400">
|
||||
{t(
|
||||
'agents.form.advanced.systemPromptOverrideDescription',
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() =>
|
||||
setAgent({
|
||||
...agent,
|
||||
allow_system_prompt_override:
|
||||
!agent.allow_system_prompt_override,
|
||||
})
|
||||
}
|
||||
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
|
||||
agent.allow_system_prompt_override
|
||||
? 'bg-primary'
|
||||
: 'bg-gray-300 dark:bg-gray-600'
|
||||
}`}
|
||||
>
|
||||
<span
|
||||
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
|
||||
agent.allow_system_prompt_override
|
||||
? ''
|
||||
: '-translate-x-5'
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -36,6 +36,7 @@ export type Agent = {
|
||||
default_model_id?: string;
|
||||
folder_id?: string;
|
||||
workflow?: string;
|
||||
allow_system_prompt_override?: boolean;
|
||||
};
|
||||
|
||||
export type AgentFolder = {
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
X,
|
||||
} from 'lucide-react';
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useNavigate, useParams, useSearchParams } from 'react-router-dom';
|
||||
import ReactFlow, {
|
||||
@@ -301,6 +302,7 @@ function createWorkflowPayload(
|
||||
}
|
||||
|
||||
function WorkflowBuilderInner() {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const token = useSelector(selectToken);
|
||||
const sourceDocs = useSelector(selectSourceDocs);
|
||||
@@ -1142,6 +1144,10 @@ function WorkflowBuilderInner() {
|
||||
workflowDescription || `Workflow agent: ${workflowName}`,
|
||||
);
|
||||
agentFormData.append('status', 'published');
|
||||
agentFormData.append(
|
||||
'allow_system_prompt_override',
|
||||
currentAgent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
if (imageFile) {
|
||||
agentFormData.append('image', imageFile);
|
||||
}
|
||||
@@ -1203,6 +1209,10 @@ function WorkflowBuilderInner() {
|
||||
agentFormData.append('agent_type', 'workflow');
|
||||
agentFormData.append('status', 'published');
|
||||
agentFormData.append('workflow', savedWorkflowId || '');
|
||||
agentFormData.append(
|
||||
'allow_system_prompt_override',
|
||||
currentAgent.allow_system_prompt_override ? 'True' : 'False',
|
||||
);
|
||||
if (imageFile) {
|
||||
agentFormData.append('image', imageFile);
|
||||
}
|
||||
@@ -1454,6 +1464,40 @@ function WorkflowBuilderInner() {
|
||||
Image updates are included the next time you save.
|
||||
</p>
|
||||
</div>
|
||||
<div className="mb-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{t('agents.form.advanced.systemPromptOverride')}
|
||||
</label>
|
||||
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
|
||||
{t('agents.form.advanced.systemPromptOverrideDescription')}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() =>
|
||||
setCurrentAgent((prev) => ({
|
||||
...prev,
|
||||
allow_system_prompt_override:
|
||||
!prev.allow_system_prompt_override,
|
||||
}))
|
||||
}
|
||||
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
|
||||
currentAgent.allow_system_prompt_override
|
||||
? 'bg-primary'
|
||||
: 'bg-gray-300 dark:bg-gray-600'
|
||||
}`}
|
||||
>
|
||||
<span
|
||||
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
|
||||
currentAgent.allow_system_prompt_override
|
||||
? ''
|
||||
: '-translate-x-5'
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={handleWorkflowSettingsDone}
|
||||
disabled={isPublishing}
|
||||
|
||||
@@ -78,6 +78,7 @@ function Dropdown<T extends DropdownOption>({
|
||||
const searchRef = useRef<HTMLInputElement>(null);
|
||||
const [open, setOpen] = useState(false);
|
||||
const [query, setQuery] = useState('');
|
||||
const [dropUp, setDropUp] = useState(false);
|
||||
|
||||
const radius = rounded === '3xl' ? 'rounded-3xl' : 'rounded-xl';
|
||||
const radiusTop = rounded === '3xl' ? 'rounded-t-3xl' : 'rounded-t-xl';
|
||||
@@ -90,14 +91,23 @@ function Dropdown<T extends DropdownOption>({
|
||||
setQuery('');
|
||||
}
|
||||
};
|
||||
document.addEventListener('mousedown', handler);
|
||||
return () => document.removeEventListener('mousedown', handler);
|
||||
document.addEventListener('mousedown', handler, true);
|
||||
return () => document.removeEventListener('mousedown', handler, true);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (open && searchable && searchRef.current) searchRef.current.focus();
|
||||
}, [open, searchable]);
|
||||
|
||||
const handleToggle = () => {
|
||||
if (!open && ref.current) {
|
||||
const rect = ref.current.getBoundingClientRect();
|
||||
const spaceBelow = window.innerHeight - rect.bottom;
|
||||
setDropUp(spaceBelow < 220);
|
||||
}
|
||||
setOpen((v) => !v);
|
||||
};
|
||||
|
||||
const filtered = useMemo(() => {
|
||||
if (!searchable || !query.trim()) return options;
|
||||
const q = query.toLowerCase();
|
||||
@@ -110,8 +120,8 @@ function Dropdown<T extends DropdownOption>({
|
||||
<div className={`relative ${size}`} ref={ref}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setOpen((v) => !v)}
|
||||
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? radiusTop : radius}`}
|
||||
onClick={handleToggle}
|
||||
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? (dropUp ? radiusBottom : radiusTop) : radius}`}
|
||||
>
|
||||
<span
|
||||
className={`truncate ${contentSize} ${displayValue ? '' : 'text-muted-foreground'}`}
|
||||
@@ -125,7 +135,11 @@ function Dropdown<T extends DropdownOption>({
|
||||
|
||||
{open && (
|
||||
<div
|
||||
className={`border-border bg-card absolute inset-x-0 z-20 -mt-px overflow-hidden border border-t-0 shadow-lg ${radiusBottom}`}
|
||||
className={`border-border bg-card absolute inset-x-0 z-20 overflow-hidden border shadow-lg ${
|
||||
dropUp
|
||||
? `bottom-full -mt-px border-b-0 ${radiusTop}`
|
||||
: `-mt-px border-t-0 ${radiusBottom}`
|
||||
}`}
|
||||
>
|
||||
{searchable && (
|
||||
<div className="flex items-center px-3 py-2">
|
||||
|
||||
@@ -10,7 +10,9 @@ interface SkeletonLoaderProps {
|
||||
| 'chatbot'
|
||||
| 'dropdown'
|
||||
| 'chunkCards'
|
||||
| 'sourceCards';
|
||||
| 'sourceCards'
|
||||
| 'toolCards'
|
||||
| 'addToolCards';
|
||||
}
|
||||
|
||||
const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
|
||||
@@ -237,6 +239,55 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
|
||||
</>
|
||||
);
|
||||
|
||||
const renderAddToolCards = () => (
|
||||
<>
|
||||
{Array.from({ length: count }).map((_, idx) => (
|
||||
<div
|
||||
key={`add-tool-skel-${idx}`}
|
||||
className="border-light-gainsboro dark:border-arsenic flex h-52 w-full animate-pulse flex-col justify-between rounded-2xl border p-6"
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center justify-between px-1">
|
||||
<div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="mt-[9px] space-y-2 px-1">
|
||||
<div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
|
||||
const renderToolCards = () => (
|
||||
<>
|
||||
{Array.from({ length: count }).map((_, idx) => (
|
||||
<div
|
||||
key={`tool-skel-${idx}`}
|
||||
className="bg-muted flex h-52 w-[300px] animate-pulse flex-col justify-between rounded-2xl p-6"
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="flex items-center gap-2 px-1">
|
||||
<div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
<div className="mt-[9px] space-y-2 px-1">
|
||||
<div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
|
||||
<div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex justify-end">
|
||||
<div className="h-5 w-9 rounded-full bg-gray-300 dark:bg-gray-600"></div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
|
||||
const componentMap = {
|
||||
fileTable: renderTable,
|
||||
chatbot: renderChatbot,
|
||||
@@ -246,6 +297,8 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
|
||||
analysis: renderAnalysis,
|
||||
chunkCards: renderChunkCards,
|
||||
sourceCards: renderSourceCards,
|
||||
toolCards: renderToolCards,
|
||||
addToolCards: renderAddToolCards,
|
||||
};
|
||||
|
||||
const render = componentMap[component] || componentMap.default;
|
||||
|
||||
@@ -619,7 +619,9 @@
|
||||
"tokenLimiting": "Token-Limitierung",
|
||||
"tokenLimitingDescription": "Begrenze die täglich von diesem Agenten verwendbaren Tokens",
|
||||
"requestLimiting": "Anfrage-Limitierung",
|
||||
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen"
|
||||
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen",
|
||||
"systemPromptOverride": "Prompt-Überschreibung erlauben",
|
||||
"systemPromptOverrideDescription": "Erlaubt v1-API-Aufrufern, den System-Prompt dieses Agenten zu ersetzen"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Veröffentlichte Agenten können hier in der Vorschau angezeigt werden"
|
||||
|
||||
@@ -653,7 +653,9 @@
|
||||
"tokenLimiting": "Token limiting",
|
||||
"tokenLimitingDescription": "Limit daily total tokens that can be used by this agent",
|
||||
"requestLimiting": "Request limiting",
|
||||
"requestLimitingDescription": "Limit daily total requests that can be made to this agent"
|
||||
"requestLimitingDescription": "Limit daily total requests that can be made to this agent",
|
||||
"systemPromptOverride": "Allow prompt override",
|
||||
"systemPromptOverrideDescription": "Let v1 API callers replace this agent's system prompt"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Published agents can be previewed here"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "Límite de tokens",
|
||||
"tokenLimitingDescription": "Limita el total diario de tokens que puede usar este agente",
|
||||
"requestLimiting": "Límite de solicitudes",
|
||||
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente"
|
||||
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente",
|
||||
"systemPromptOverride": "Permitir sobrescribir el prompt",
|
||||
"systemPromptOverrideDescription": "Permitir que los llamadores de la API v1 reemplacen el prompt del sistema de este agente"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Los agentes publicados se pueden previsualizar aquí"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "トークン制限",
|
||||
"tokenLimitingDescription": "このエージェントが使用できる1日の合計トークン数を制限します",
|
||||
"requestLimiting": "リクエスト制限",
|
||||
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します"
|
||||
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します",
|
||||
"systemPromptOverride": "プロンプトの上書きを許可",
|
||||
"systemPromptOverrideDescription": "v1 API呼び出し元がこのエージェントのシステムプロンプトを置き換えることを許可します"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "公開されたエージェントはここでプレビューできます"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "Лимит токенов",
|
||||
"tokenLimitingDescription": "Ограничить ежедневное общее количество токенов, которые может использовать этот агент",
|
||||
"requestLimiting": "Лимит запросов",
|
||||
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту"
|
||||
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту",
|
||||
"systemPromptOverride": "Разрешить замену промпта",
|
||||
"systemPromptOverrideDescription": "Разрешить вызовам API v1 заменять системный промпт этого агента"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "Опубликованные агенты можно просмотреть здесь"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "權杖限制",
|
||||
"tokenLimitingDescription": "限制此代理每天可使用的總權杖數",
|
||||
"requestLimiting": "請求限制",
|
||||
"requestLimitingDescription": "限制每天可向此代理發出的總請求數"
|
||||
"requestLimitingDescription": "限制每天可向此代理發出的總請求數",
|
||||
"systemPromptOverride": "允許覆蓋提示詞",
|
||||
"systemPromptOverrideDescription": "允許 v1 API 呼叫者替換此代理的系統提示詞"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "已發佈的代理可以在此處預覽"
|
||||
|
||||
@@ -641,7 +641,9 @@
|
||||
"tokenLimiting": "令牌限制",
|
||||
"tokenLimitingDescription": "限制此代理每天可使用的总令牌数",
|
||||
"requestLimiting": "请求限制",
|
||||
"requestLimitingDescription": "限制每天可向此代理发出的总请求数"
|
||||
"requestLimitingDescription": "限制每天可向此代理发出的总请求数",
|
||||
"systemPromptOverride": "允许覆盖提示词",
|
||||
"systemPromptOverrideDescription": "允许 v1 API 调用者替换此代理的系统提示词"
|
||||
},
|
||||
"preview": {
|
||||
"publishedPreview": "已发布的代理可以在此处预览"
|
||||
|
||||
@@ -3,8 +3,8 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Spinner from '../components/Spinner';
|
||||
import { useOutsideAlerter } from '../hooks';
|
||||
import SkeletonLoader from '../components/SkeletonLoader';
|
||||
import { useLoaderState, useOutsideAlerter } from '../hooks';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import ConfigToolModal from './ConfigToolModal';
|
||||
@@ -37,7 +37,7 @@ export default function AddToolModal({
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [mcpModalState, setMcpModalState] =
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [loading, setLoading] = useLoaderState(false);
|
||||
|
||||
useOutsideAlerter(modalRef, () => {
|
||||
if (modalState === 'ACTIVE') {
|
||||
@@ -121,8 +121,8 @@ export default function AddToolModal({
|
||||
</h2>
|
||||
<div className="mt-5 h-[73vh] overflow-auto px-3 py-px">
|
||||
{loading ? (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Spinner />
|
||||
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">
|
||||
<SkeletonLoader component="addToolCards" count={6} />
|
||||
</div>
|
||||
) : (
|
||||
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">
|
||||
|
||||
@@ -10,9 +10,9 @@ import NoFilesIcon from '../assets/no-files.svg';
|
||||
import SearchIcon from '../assets/search.svg';
|
||||
import ThreeDotsIcon from '../assets/three-dots.svg';
|
||||
import ContextMenu, { MenuOption } from '../components/ContextMenu';
|
||||
import Spinner from '../components/Spinner';
|
||||
import SkeletonLoader from '../components/SkeletonLoader';
|
||||
import ToggleSwitch from '../components/ToggleSwitch';
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { useDarkTheme, useLoaderState } from '../hooks';
|
||||
import AddToolModal from '../modals/AddToolModal';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import MCPServerModal from '../modals/MCPServerModal';
|
||||
@@ -33,7 +33,7 @@ export default function Tools() {
|
||||
const [selectedTool, setSelectedTool] = React.useState<
|
||||
UserToolType | APIToolType | null
|
||||
>(null);
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
const [loading, setLoading] = useLoaderState(false);
|
||||
const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
|
||||
const menuRefs = React.useRef<{
|
||||
[key: string]: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -242,10 +242,8 @@ export default function Tools() {
|
||||
</div>
|
||||
<div className="border-border dark:border-border mt-5 mb-8 border-b" />
|
||||
{loading ? (
|
||||
<div className="grid grid-cols-2 gap-6 lg:grid-cols-3">
|
||||
<div className="col-span-2 mt-24 flex h-32 items-center justify-center lg:col-span-3">
|
||||
<Spinner />
|
||||
</div>
|
||||
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
|
||||
<SkeletonLoader component="toolCards" count={6} />
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
|
||||
|
||||
218
scripts/db/backfill.py
Normal file
@@ -0,0 +1,218 @@
"""Backfill DocsGPT's Postgres user-data tables from MongoDB.

One script for every migrated collection. Adding a new collection is a
two-step change in this file:

1. Write a ``_backfill_<name>`` function that takes keyword args
   ``conn``, ``mongo_db``, ``batch_size``, ``dry_run`` and returns a
   stats ``dict``.
2. Add a single entry to :data:`BACKFILLERS`.

There are intentionally no per-collection CLI flags or environment
variables — ``USE_POSTGRES`` / ``READ_POSTGRES`` in ``.env`` are the
only knobs operators need. This script discovers what's available from
the :data:`BACKFILLERS` registry and runs whichever tables were asked for.

Usage::

    python scripts/db/backfill.py                  # every registered table
    python scripts/db/backfill.py --tables users   # only specific tables
    python scripts/db/backfill.py --dry-run        # count without writing
    python scripts/db/backfill.py --batch 1000     # tune commit size

Exit codes:
    0 — every requested table completed successfully
    1 — misconfiguration (missing env var, unknown table name)
    2 — at least one table failed at runtime (others may still have succeeded)
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
# Make the project root importable regardless of cwd.
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sqlalchemy import Connection, text # noqa: E402
|
||||
|
||||
from application.core.mongo_db import MongoDB # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.engine import get_engine # noqa: E402
|
||||
|
||||
logger = logging.getLogger("backfill")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-table backfillers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _backfill_users(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``users`` table from Mongo ``users`` collection.
|
||||
|
||||
Overwrites each Postgres row's ``agent_preferences`` with the Mongo
|
||||
state (Mongo is source of truth during the cutover window). Missing
|
||||
``pinned`` / ``shared_with_me`` keys are filled with empty arrays so
|
||||
the Postgres row always has the full shape the application expects.
|
||||
"""
|
||||
upsert_sql = text(
|
||||
"""
|
||||
INSERT INTO users (user_id, agent_preferences)
|
||||
VALUES (:user_id, CAST(:prefs AS jsonb))
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET agent_preferences = EXCLUDED.agent_preferences,
|
||||
updated_at = now()
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["users"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
skipped = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
user_id = doc.get("user_id")
|
||||
if not user_id:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
raw_prefs = doc.get("agent_preferences") or {}
|
||||
prefs = {
|
||||
"pinned": list(raw_prefs.get("pinned") or []),
|
||||
"shared_with_me": list(raw_prefs.get("shared_with_me") or []),
|
||||
}
|
||||
batch.append({"user_id": user_id, "prefs": json.dumps(prefs)})
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(upsert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(upsert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written, "skipped_no_user_id": skipped}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
BackfillFn = Callable[..., dict]
|
||||
|
||||
# Register new tables here. Order matters only in the sense that
|
||||
# ``--tables`` without arguments iterates in insertion order — put tables
|
||||
# with FK dependencies after the tables they reference so a full-run
|
||||
# backfill doesn't hit FK errors.
|
||||
BACKFILLERS: dict[str, BackfillFn] = {
|
||||
"users": _backfill_users,
|
||||
}
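For illustration, registering a second table would look like the sketch below. The ``conversations`` collection, the column names, and the SQL are hypothetical and not part of this change, but the shape (keyword-only args, a stats dict return, one registry entry) is exactly what the module docstring asks for:

# Hypothetical second backfiller, reusing this module's imports (text, json, etc.).
def _backfill_conversations(*, conn, mongo_db, batch_size, dry_run) -> dict:
    upsert = text(
        """
        INSERT INTO conversations (conversation_id, user_id)
        VALUES (:conversation_id, :user_id)
        ON CONFLICT (conversation_id) DO NOTHING
        """
    )
    seen = written = 0
    for doc in mongo_db["conversations"].find({}).batch_size(batch_size):
        seen += 1
        if not dry_run:
            conn.execute(
                upsert,
                {"conversation_id": str(doc["_id"]), "user_id": doc.get("user")},
            )
        written += 1  # counts rows that would be written, mirroring the users backfiller
    return {"seen": seen, "written": written}

# Step 2: one registry entry (or add it to the dict literal above).
BACKFILLERS["conversations"] = _backfill_conversations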
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Backfill DocsGPT Postgres tables from MongoDB."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tables",
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated table names to backfill. "
|
||||
f"Defaults to every registered table ({','.join(BACKFILLERS)})."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Iterate Mongo without writing to Postgres.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch",
|
||||
type=int,
|
||||
default=500,
|
||||
help="How many rows to commit per Postgres statement (default: 500).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-5s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
if not settings.POSTGRES_URI:
|
||||
logger.error("POSTGRES_URI is not set. Configure it in .env first.")
|
||||
return 1
|
||||
if not settings.MONGO_URI:
|
||||
logger.error("MONGO_URI is not set. Configure it in .env first.")
|
||||
return 1
|
||||
|
||||
requested = [t.strip() for t in args.tables.split(",") if t.strip()]
|
||||
if not requested:
|
||||
requested = list(BACKFILLERS)
|
||||
|
||||
unknown = [t for t in requested if t not in BACKFILLERS]
|
||||
if unknown:
|
||||
logger.error(
|
||||
"Unknown table(s): %s. Available: %s",
|
||||
", ".join(unknown),
|
||||
", ".join(BACKFILLERS),
|
||||
)
|
||||
return 1
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
mongo_db = mongo[settings.MONGO_DB_NAME]
|
||||
engine = get_engine()
|
||||
|
||||
failures = 0
|
||||
for table in requested:
|
||||
logger.info("backfill %s: start", table)
|
||||
try:
|
||||
with engine.begin() as conn:
|
||||
stats = BACKFILLERS[table](
|
||||
conn=conn,
|
||||
mongo_db=mongo_db,
|
||||
batch_size=args.batch,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
logger.info(
|
||||
"backfill %s: done %s dry_run=%s", table, stats, args.dry_run
|
||||
)
|
||||
except Exception:
|
||||
failures += 1
|
||||
logger.exception("backfill %s: failed", table)
|
||||
|
||||
return 2 if failures else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
55
scripts/db/init_postgres.py
Normal file
@@ -0,0 +1,55 @@
"""One-shot bootstrap: run all Alembic migrations against POSTGRES_URI.

Intended use:

* local dev, after setting ``POSTGRES_URI`` in ``.env``::

      python scripts/db/init_postgres.py

* CI, as a step before running the pytest suite.

* Docker image build or container start, if the operator wants the
  migrations applied automatically on first boot.

This script is a thin wrapper around ``alembic upgrade head``. It exists
separately so the same command is discoverable from the repo root without
remembering the ``-c application/alembic.ini`` invocation.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
ALEMBIC_INI = REPO_ROOT / "application" / "alembic.ini"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Apply every pending migration up to ``head``.
|
||||
|
||||
Returns:
|
||||
``0`` on success, ``1`` on failure. Non-zero is propagated as the
|
||||
process exit code so CI jobs fail loudly.
|
||||
"""
|
||||
if not ALEMBIC_INI.exists():
|
||||
print(f"alembic.ini not found at {ALEMBIC_INI}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
cfg = Config(str(ALEMBIC_INI))
|
||||
# Make `script_location` resolve correctly when invoked from any cwd.
|
||||
cfg.set_main_option("script_location", str(ALEMBIC_INI.parent / "alembic"))
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, "head")
|
||||
except Exception as exc: # noqa: BLE001 — surface everything to the operator
|
||||
print(f"alembic upgrade failed: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
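Because the script simply returns ``main()``'s status, it is also easy to call programmatically. A hypothetical CI or integration-test hook (the import path and fixture name below are illustrative, not part of this change) could apply migrations once per session before any Postgres-backed tests run:

# conftest.py for a Postgres-backed suite: a sketch, assuming POSTGRES_URI is set.
import pytest

from scripts.db.init_postgres import main as apply_migrations


@pytest.fixture(scope="session", autouse=True)
def _postgres_schema():
    # Fail the whole session loudly if the schema cannot be brought to head.
    assert apply_migrations() == 0, "alembic upgrade head failed"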
|
||||
12
setup.ps1
@@ -543,8 +543,20 @@ function Configure-TTS {
    }
}

# Generate INTERNAL_KEY for worker-to-backend auth if not already present
function Ensure-InternalKey {
    $content = if (Test-Path $ENV_FILE) { Get-Content $ENV_FILE -Raw } else { "" }
    if ($content -notmatch "(?m)^INTERNAL_KEY=") {
        $bytes = New-Object byte[] 32
        [System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
        $internal_key = ($bytes | ForEach-Object { $_.ToString("x2") }) -join ""
        "INTERNAL_KEY=$internal_key" | Add-Content -Path $ENV_FILE -Encoding utf8
    }
}

# Main advanced settings menu
function Prompt-AdvancedSettings {
    Ensure-InternalKey
    Write-Host ""
    $configure_advanced = Read-Host "Would you like to configure advanced settings? (y/N)"
    if ($configure_advanced -ne "y" -and $configure_advanced -ne "Y") {
|
||||
10
setup.sh
@@ -396,8 +396,18 @@ configure_tts() {
    esac
}

# Generate INTERNAL_KEY for worker-to-backend auth if not already present
ensure_internal_key() {
    if ! grep -q "^INTERNAL_KEY=" "$ENV_FILE" 2>/dev/null; then
        local internal_key
        internal_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
        echo "INTERNAL_KEY=$internal_key" >> "$ENV_FILE"
    fi
}

# Main advanced settings menu
prompt_advanced_settings() {
    ensure_internal_key
    echo
    read -p "$(echo -e "${DEFAULT_FG}Would you like to configure advanced settings? (y/N): ${NC}")" configure_advanced
    if [[ ! "$configure_advanced" =~ ^[yY]$ ]]; then
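Both setup scripts only write the key into ``.env``; the worker and backend then read it from the environment. As a rough sketch of the handshake, with the backend address and payload assumed for illustration (the header name comes from the internal-route tests further down in this change):

import os

import requests

internal_key = os.environ["INTERNAL_KEY"]
resp = requests.post(
    "http://localhost:7091/api/upload_index",   # assumed backend address, adjust to your deployment
    data={"user": "local", "name": "my-source"},
    headers={"X-Internal-Key": internal_key},    # requests without a matching key are rejected with 401
    timeout=30,
)
resp.raise_for_status()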
|
||||
@@ -28,6 +28,7 @@ def _patch_mcp_globals(monkeypatch):
|
||||
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
|
||||
monkeypatch.setattr(mcp_mod, "db", mock_db)
|
||||
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", lambda url: url)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestPostgresExecuteAction:
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_select_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
@@ -37,7 +37,7 @@ class TestPostgresExecuteAction:
|
||||
assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_insert_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
@@ -55,11 +55,11 @@ class TestPostgresExecuteAction:
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
import psycopg
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("connection refused")
|
||||
mock_connect.side_effect = psycopg.Error("connection refused")
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT 1"
|
||||
@@ -68,7 +68,7 @@ class TestPostgresExecuteAction:
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_get_schema(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
@@ -89,24 +89,24 @@ class TestPostgresExecuteAction:
|
||||
assert result["schema"]["users"][0]["column_name"] == "id"
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_get_schema_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
import psycopg
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("auth failed")
|
||||
mock_connect.side_effect = psycopg.Error("auth failed")
|
||||
|
||||
result = tool.execute_action("postgres_get_schema", db_name="testdb")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_connection_closed_on_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
import psycopg
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = psycopg2.Error("syntax error")
|
||||
mock_cur.execute.side_effect = psycopg.Error("syntax error")
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
@@ -114,7 +114,7 @@ class TestPostgresExecuteAction:
|
||||
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_select_with_no_description(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
|
||||
@@ -33,6 +33,8 @@ def _patch_mcp_globals(monkeypatch):
|
||||
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
|
||||
monkeypatch.setattr(mcp_mod, "db", mock_db)
|
||||
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
|
||||
# Bypass DNS-resolving URL validation for tests using fake hostnames.
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", lambda u, **kw: u)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -136,6 +138,47 @@ class TestMCPToolInit:
|
||||
})
|
||||
assert tool.custom_headers == {"X-Custom": "val"}
|
||||
|
||||
def test_rejects_metadata_ip(self, monkeypatch):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.core.url_validation import validate_url as real_validate_url
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
|
||||
with pytest.raises(ValueError, match="Invalid MCP server URL"):
|
||||
MCPTool(config={"server_url": "http://169.254.169.254/latest/meta-data", "auth_type": "none"})
|
||||
|
||||
def test_rejects_localhost(self, monkeypatch):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.core.url_validation import validate_url as real_validate_url
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
|
||||
with pytest.raises(ValueError, match="Invalid MCP server URL"):
|
||||
MCPTool(config={"server_url": "http://localhost:8080/mcp", "auth_type": "none"})
|
||||
|
||||
def test_rejects_private_ip(self, monkeypatch):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
from application.core.url_validation import validate_url as real_validate_url
|
||||
import application.agents.tools.mcp_tool as mcp_mod
|
||||
|
||||
monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
|
||||
with pytest.raises(ValueError, match="Invalid MCP server URL"):
|
||||
MCPTool(config={"server_url": "http://10.0.0.1/mcp", "auth_type": "none"})
|
||||
|
||||
def test_accepts_public_url(self):
|
||||
tool = _make_tool({
|
||||
"server_url": "https://mcp.example.com/api",
|
||||
"auth_type": "none",
|
||||
})
|
||||
assert tool.server_url == "https://mcp.example.com/api"
|
||||
|
||||
def test_empty_server_url_allowed(self):
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
with patch.object(MCPTool, "_setup_client"):
|
||||
tool = MCPTool(config={"server_url": "", "auth_type": "none"})
|
||||
assert tool.server_url == ""
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Redirect URI Resolution
|
||||
|
||||
@@ -330,6 +330,170 @@ class TestStreamProcessorDocPrefetch:
|
||||
assert docs_together is not None
|
||||
assert "Agent doc content" in docs_together
|
||||
|
||||
def test_configure_source_treats_default_string_as_no_docs(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_default_source_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
"source": "default",
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Hi", "api_key": "agent_default_source_key"},
|
||||
None,
|
||||
)
|
||||
processor._configure_agent()
|
||||
processor._configure_source()
|
||||
|
||||
assert processor.source == {}
|
||||
assert processor.all_sources == []
|
||||
|
||||
def test_prefetch_skipped_when_no_active_docs(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Hi there"},
|
||||
{"sub": "user_123"},
|
||||
)
|
||||
processor.initialize()
|
||||
processor.create_retriever = MagicMock()
|
||||
|
||||
docs_together, docs = processor.pre_fetch_docs("Hi there")
|
||||
|
||||
processor.create_retriever.assert_not_called()
|
||||
assert docs_together is None
|
||||
assert docs is None
|
||||
|
||||
def test_prefetch_skipped_when_active_docs_is_default(self, mock_mongo_db):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Hi", "active_docs": "default"},
|
||||
{"sub": "user_123"},
|
||||
)
|
||||
processor.initialize()
|
||||
processor.create_retriever = MagicMock()
|
||||
|
||||
docs_together, docs = processor.pre_fetch_docs("Hi")
|
||||
|
||||
processor.create_retriever.assert_not_called()
|
||||
assert docs_together is None
|
||||
assert docs is None
|
||||
|
||||
def test_agent_retriever_and_chunks_propagate_to_retriever_config(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
source_id = ObjectId()
|
||||
db["sources"].insert_one(
|
||||
{"_id": source_id, "name": "src", "retriever": "hybrid", "chunks": 5}
|
||||
)
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_ret_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
"retriever": "hybrid",
|
||||
"chunks": 5,
|
||||
"source": DBRef("sources", source_id),
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Test", "api_key": "agent_ret_key"},
|
||||
None,
|
||||
)
|
||||
processor.initialize()
|
||||
|
||||
assert processor.retriever_config["retriever_name"] == "hybrid"
|
||||
assert processor.retriever_config["chunks"] == 5
|
||||
|
||||
def test_request_retriever_and_chunks_override_agent_config(self, mock_mongo_db):
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_override_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
"retriever": "hybrid",
|
||||
"chunks": 5,
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{
|
||||
"question": "Test",
|
||||
"api_key": "agent_override_key",
|
||||
"retriever": "classic",
|
||||
"chunks": 7,
|
||||
},
|
||||
None,
|
||||
)
|
||||
processor.initialize()
|
||||
|
||||
assert processor.retriever_config["retriever_name"] == "classic"
|
||||
assert processor.retriever_config["chunks"] == 7
|
||||
|
||||
def test_agent_data_fetched_once_per_request(self, mock_mongo_db):
|
||||
from unittest.mock import patch
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.core.settings import settings
|
||||
|
||||
db = mock_mongo_db[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
|
||||
agent_id = ObjectId()
|
||||
agents_collection.insert_one(
|
||||
{
|
||||
"_id": agent_id,
|
||||
"key": "agent_once_key",
|
||||
"user": "user_123",
|
||||
"prompt_id": "default",
|
||||
"agent_type": "classic",
|
||||
}
|
||||
)
|
||||
|
||||
processor = StreamProcessor(
|
||||
{"question": "Test", "api_key": "agent_once_key"},
|
||||
None,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
processor, "_get_data_from_api_key", wraps=processor._get_data_from_api_key
|
||||
) as spy:
|
||||
processor.initialize()
|
||||
assert spy.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamProcessorAttachments:
|
||||
|
||||
@@ -566,14 +566,13 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [
|
||||
{"id": "src1", "retriever": "classic"},
|
||||
{"id": "src2", "retriever": "hybrid"},
|
||||
],
|
||||
"source": None,
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": ["src1", "src2"]}
|
||||
assert len(sp.all_sources) == 2
|
||||
@@ -593,12 +592,11 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [],
|
||||
"source": "single_src",
|
||||
"retriever": "classic",
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": "single_src"}
|
||||
assert len(sp.all_sources) == 1
|
||||
@@ -618,8 +616,7 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {"sources": [], "source": None}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._agent_data = {"sources": [], "source": None}
|
||||
sp._configure_source()
|
||||
assert sp.source == {}
|
||||
assert sp.all_sources == []
|
||||
@@ -639,11 +636,10 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = "agent_key_123"
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [{"id": "s1", "retriever": "classic"}],
|
||||
"source": None,
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": ["s1"]}
|
||||
|
||||
@@ -662,11 +658,10 @@ class TestConfigureSource:
|
||||
decoded_token={"sub": "u"},
|
||||
)
|
||||
sp.agent_key = None
|
||||
agent_data = {
|
||||
sp._agent_data = {
|
||||
"sources": [{"id": None}, {"retriever": "classic"}],
|
||||
"source": None,
|
||||
}
|
||||
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
|
||||
sp._configure_source()
|
||||
assert sp.source == {}
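Taken together, these cases describe the contract of ``_configure_source``: a non-empty ``sources`` list becomes ``{"active_docs": [ids...]}``, a single ``source`` value becomes ``{"active_docs": value}``, and anything else (no sources, ``None``, entries without usable ids, or the ``"default"`` placeholder seen in the prefetch tests above) collapses to ``{}``. A condensed, illustrative sketch of that decision; the real method also carries retriever metadata:

def configure_source_sketch(agent_data: dict) -> dict:
    # Multi-source agents win: keep only entries that actually have an id.
    sources = [s["id"] for s in (agent_data.get("sources") or []) if s.get("id")]
    if sources:
        return {"active_docs": sources}
    single = agent_data.get("source")
    if single and single != "default":
        return {"active_docs": single}
    return {}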
|
||||
|
||||
@@ -1189,6 +1184,8 @@ class TestConfigureAgent:
|
||||
"chunks": "5",
|
||||
})
|
||||
sp._configure_agent()
|
||||
sp.model_id = "test-model"
|
||||
sp._configure_retriever()
|
||||
assert sp.agent_config["workflow"] == "wf_123"
|
||||
assert sp.agent_config["workflow_owner"] == "user1"
|
||||
assert sp.retriever_config["retriever_name"] == "hybrid"
|
||||
@@ -1211,6 +1208,8 @@ class TestConfigureAgent:
|
||||
"chunks": "not_a_number",
|
||||
})
|
||||
sp._configure_agent()
|
||||
sp.model_id = "test-model"
|
||||
sp._configure_retriever()
|
||||
assert sp.retriever_config["chunks"] == 2
|
||||
|
||||
|
||||
@@ -1763,8 +1762,8 @@ class TestConfigureAgentAdditionalPaths:
|
||||
assert sp.decoded_token == {"sub": "owner_user"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_configure_agent_with_source_in_data_key(self):
|
||||
"""Cover line 463-464: data_key has 'source' set."""
|
||||
def test_configure_source_picks_up_cached_agent_data(self):
|
||||
"""After _configure_agent caches _agent_data, _configure_source uses it."""
|
||||
sp = self._make_sp()
|
||||
sp._resolve_agent_id = MagicMock(return_value="agent_id_1")
|
||||
sp._get_agent_key = MagicMock(return_value=("agent_key", False, None))
|
||||
@@ -1780,6 +1779,7 @@ class TestConfigureAgentAdditionalPaths:
|
||||
"source": "my_source",
|
||||
})
|
||||
sp._configure_agent()
|
||||
sp._configure_source()
|
||||
assert sp.source == {"active_docs": "my_source"}
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -2067,7 +2067,7 @@ class TestPreFetchDocsFullPaths:
|
||||
"chunks": 2,
|
||||
"doc_token_limit": 50000,
|
||||
}
|
||||
sp.source = {}
|
||||
sp.source = {"active_docs": ["src1"]}
|
||||
sp.model_id = "test-model"
|
||||
sp.agent_id = None
|
||||
return sp
|
||||
|
||||
@@ -48,7 +48,7 @@ def internal_app(monkeypatch, mock_mongo_db):
|
||||
@pytest.mark.unit
|
||||
class TestVerifyInternalKey:
|
||||
|
||||
def test_no_internal_key_configured_allows_access(
|
||||
def test_no_internal_key_configured_rejects_access(
|
||||
self, internal_app, monkeypatch
|
||||
):
|
||||
app, db = internal_app
|
||||
@@ -63,9 +63,8 @@ class TestVerifyInternalKey:
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
# download will fail for missing file but should not be 401
|
||||
resp = client.get("/api/download?user=u&name=n&file=f")
|
||||
assert resp.status_code != 401
|
||||
assert resp.status_code == 401
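Note the flipped expectation: an unset ``INTERNAL_KEY`` no longer means open access, so internal routes now fail closed. A sketch of the verification these tests imply; the real decorator lives in the internal routes module and may be structured differently:

import hmac
from functools import wraps

from flask import abort, request


def require_internal_key(settings):
    def decorator(view):
        @wraps(view)
        def wrapper(*args, **kwargs):
            provided = request.headers.get("X-Internal-Key")
            # Fail closed: unconfigured key, missing header, or mismatch all yield 401.
            if not settings.INTERNAL_KEY or not provided:
                abort(401)
            if not hmac.compare_digest(provided, settings.INTERNAL_KEY):
                abort(401)
            return view(*args, **kwargs)
        return wrapper
    return decorator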
|
||||
|
||||
def test_missing_key_returns_401(self, internal_app, monkeypatch):
|
||||
app, db = internal_app
|
||||
@@ -131,9 +130,12 @@ class TestVerifyInternalKey:
|
||||
@pytest.mark.unit
|
||||
class TestUploadIndex:
|
||||
|
||||
_TEST_INTERNAL_KEY = "test-internal-key"
|
||||
_AUTH_HEADERS = {"X-Internal-Key": "test-internal-key"}
|
||||
|
||||
def _make_settings(self, vector_store="faiss"):
|
||||
return MagicMock(
|
||||
INTERNAL_KEY=None,
|
||||
INTERNAL_KEY=self._TEST_INTERNAL_KEY,
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE=vector_store,
|
||||
EMBEDDINGS_NAME="test_embeddings",
|
||||
@@ -146,7 +148,7 @@ class TestUploadIndex:
|
||||
"application.api.internal.routes.settings", self._make_settings()
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={})
|
||||
resp = client.post("/api/upload_index", data={}, headers=self._AUTH_HEADERS)
|
||||
assert resp.json["status"] == "no user"
|
||||
|
||||
def test_missing_name_returns_no_name(self, internal_app, monkeypatch):
|
||||
@@ -155,7 +157,7 @@ class TestUploadIndex:
|
||||
"application.api.internal.routes.settings", self._make_settings()
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={"user": "testuser"})
|
||||
resp = client.post("/api/upload_index", data={"user": "testuser"}, headers=self._AUTH_HEADERS)
|
||||
assert resp.json["status"] == "no name"
|
||||
|
||||
def test_creates_new_source_entry(self, internal_app, monkeypatch):
|
||||
@@ -182,6 +184,7 @@ class TestUploadIndex:
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -219,6 +222,7 @@ class TestUploadIndex:
|
||||
"id": str(doc_id),
|
||||
"type": "remote",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -252,6 +256,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"directory_structure": json.dumps(dir_struct),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -285,6 +290,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"directory_structure": "not valid json",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -317,6 +323,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_name_map": json.dumps(fmap),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -349,6 +356,7 @@ class TestUploadIndex:
|
||||
"id": doc_id,
|
||||
"type": "local",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file"
|
||||
|
||||
@@ -379,6 +387,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_faiss": (io.BytesIO(b""), ""),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file name"
|
||||
|
||||
@@ -408,6 +417,7 @@ class TestUploadIndex:
|
||||
"remote_data": '{"url":"http://example.com"}',
|
||||
"sync_frequency": "daily",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -443,6 +453,7 @@ class TestUploadIndex:
|
||||
"file_pkl": (io.BytesIO(b"pkl data"), "index.pkl"),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -477,6 +488,7 @@ class TestUploadIndex:
|
||||
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file"
|
||||
|
||||
@@ -508,9 +520,27 @@ class TestUploadIndex:
|
||||
"file_pkl": (io.BytesIO(b""), ""),
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "no file name"
|
||||
|
||||
def test_no_internal_key_rejects_upload(self, internal_app, monkeypatch):
|
||||
"""Verify that upload_index is rejected when INTERNAL_KEY is not set."""
|
||||
app, db = internal_app
|
||||
monkeypatch.setattr(
|
||||
"application.api.internal.routes.settings",
|
||||
MagicMock(
|
||||
INTERNAL_KEY=None,
|
||||
UPLOAD_FOLDER="uploads",
|
||||
VECTOR_STORE="faiss",
|
||||
EMBEDDINGS_NAME="test",
|
||||
MONGO_DB_NAME="docsgpt",
|
||||
),
|
||||
)
|
||||
with app.test_client() as client:
|
||||
resp = client.post("/api/upload_index", data={"user": "attacker"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_update_existing_with_file_name_map(self, internal_app, monkeypatch):
|
||||
"""Cover line 124: update existing entry with file_name_map."""
|
||||
app, db = internal_app
|
||||
@@ -540,6 +570,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_name_map": json.dumps(fmap),
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
@@ -572,6 +603,7 @@ class TestUploadIndex:
|
||||
"type": "local",
|
||||
"file_name_map": "not valid json{{{",
|
||||
},
|
||||
headers=self._AUTH_HEADERS,
|
||||
)
|
||||
assert resp.json["status"] == "ok"
|
||||
|
||||
|
||||
@@ -14,6 +14,13 @@ def app():
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bypass_url_validation():
|
||||
"""Bypass SSRF URL validation so tests using localhost URLs can proceed."""
|
||||
with patch("application.api.user.tools.mcp.validate_url"):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: _sanitize_mcp_transport
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -395,6 +395,9 @@ class TestAvailableTools:
|
||||
"application.api.user.tools.routes.tool_manager", mock_manager
|
||||
):
|
||||
with app.test_request_context("/api/available_tools"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AvailableTools().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -419,6 +422,9 @@ class TestAvailableTools:
|
||||
"application.api.user.tools.routes.tool_manager", mock_manager
|
||||
):
|
||||
with app.test_request_context("/api/available_tools"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AvailableTools().get()
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -438,6 +444,9 @@ class TestAvailableTools:
|
||||
"application.api.user.tools.routes.tool_manager", mock_manager
|
||||
):
|
||||
with app.test_request_context("/api/available_tools"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = {"sub": "user1"}
|
||||
response = AvailableTools().get()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
0
tests/api/v1/__init__.py
Normal file
64
tests/api/v1/test_routes.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from flask import Flask
|
||||
|
||||
from application.api.v1.routes import v1_bp
|
||||
|
||||
|
||||
class _FakeCollection:
|
||||
def __init__(self, docs):
|
||||
self.docs = docs
|
||||
|
||||
def find_one(self, query):
|
||||
for doc in self.docs:
|
||||
if all(doc.get(k) == v for k, v in query.items()):
|
||||
return doc
|
||||
return None
|
||||
|
||||
def find(self, query):
|
||||
return [doc for doc in self.docs if all(doc.get(k) == v for k, v in query.items())]
|
||||
|
||||
|
||||
def _build_app():
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(v1_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_v1_models_does_not_expose_agent_keys(monkeypatch):
|
||||
docs = [
|
||||
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
|
||||
{"_id": "agent-2", "key": "key-2", "user": "user-1", "name": "Agent Two"},
|
||||
]
|
||||
|
||||
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
|
||||
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
|
||||
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
|
||||
|
||||
app = _build_app()
|
||||
client = app.test_client()
|
||||
response = client.get("/v1/models", headers={"Authorization": "Bearer key-1"})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["object"] == "list"
|
||||
assert len(payload["data"]) == 2
|
||||
assert payload["data"][0]["id"] == "agent-1"
|
||||
assert payload["data"][1]["id"] == "agent-2"
|
||||
# Keys must never appear as model IDs
|
||||
assert all(model["id"] != "key-1" for model in payload["data"])
|
||||
assert all(model["id"] != "key-2" for model in payload["data"])
|
||||
|
||||
|
||||
def test_v1_models_invalid_key_returns_401(monkeypatch):
|
||||
docs = [
|
||||
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
|
||||
]
|
||||
|
||||
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
|
||||
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
|
||||
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
|
||||
|
||||
app = _build_app()
|
||||
client = app.test_client()
|
||||
response = client.get("/v1/models", headers={"Authorization": "Bearer wrong-key"})
|
||||
|
||||
assert response.status_code == 401
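In practice this means an agent's API key is the bearer token for the OpenAI-compatible ``/v1`` surface, while the ids returned are agent ids, never keys. A minimal client call, with the base URL and key shown as placeholders:

import requests

resp = requests.get(
    "http://localhost:7091/v1/models",            # placeholder backend address
    headers={"Authorization": "Bearer <agent-key>"},
    timeout=10,
)
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"])  # agent ids such as "agent-1", never the key itself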
|
||||
144
tests/core/test_db_uri.py
Normal file
@@ -0,0 +1,144 @@
"""Tests for ``application.core.db_uri``.

DocsGPT has two Postgres connection strings — ``POSTGRES_URI`` (consumed
by SQLAlchemy) and ``PGVECTOR_CONNECTION_STRING`` (consumed by
``psycopg.connect()`` directly). They need opposite normalization
because SQLAlchemy requires a ``postgresql+psycopg://`` dialect prefix
and libpq rejects it. Each field has its own normalizer so operators
can write whichever form feels natural and cross-pollination between
the two fields is forgiven.

The normalizers live in ``application.core.db_uri`` as plain functions
so these tests can exercise them directly without having to instantiate
``Settings`` (which would pull in ``.env`` file side effects).
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.db_uri import (
|
||||
normalize_pgvector_connection_string,
|
||||
normalize_postgres_uri,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNormalizePostgresUri:
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,expected",
|
||||
[
|
||||
# User-friendly forms get rewritten to the SQLAlchemy dialect.
|
||||
(
|
||||
"postgres://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
(
|
||||
"postgresql://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
# Legacy psycopg2 dialect is silently upgraded — psycopg2 is
|
||||
# no longer in requirements.txt, so there's no way it can work
|
||||
# as-is, and rewriting is friendlier than failing.
|
||||
(
|
||||
"postgresql+psycopg2://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
# Already-correct dialect passes through unchanged.
|
||||
(
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
# Whitespace is trimmed before rewriting.
|
||||
(
|
||||
" postgres://u:p@h/d ",
|
||||
"postgresql+psycopg://u:p@h/d",
|
||||
),
|
||||
# Query-string params (sslmode, options) are preserved verbatim.
|
||||
(
|
||||
"postgresql://u:p@h:5432/d?sslmode=require&application_name=docsgpt",
|
||||
"postgresql+psycopg://u:p@h:5432/d?sslmode=require&application_name=docsgpt",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rewrites_common_forms_to_psycopg_dialect(self, input_value, expected):
|
||||
assert normalize_postgres_uri(input_value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value",
|
||||
[None, "", " ", "None", "none"],
|
||||
)
|
||||
def test_empty_or_none_like_returns_none(self, input_value):
|
||||
assert normalize_postgres_uri(input_value) is None
|
||||
|
||||
def test_unknown_scheme_passes_through(self):
|
||||
"""A dialect we don't recognise is left alone so SQLAlchemy can
|
||||
produce its own error message when the engine tries to connect.
|
||||
Better than silently eating the config."""
|
||||
weird = "postgresql+asyncpg://u:p@h/d"
|
||||
assert normalize_postgres_uri(weird) == weird
|
||||
|
||||
def test_non_string_input_passes_through(self):
|
||||
"""Non-string inputs (e.g. if pydantic ever passes an int) shouldn't
|
||||
crash the normalizer — let pydantic's own type validation handle it."""
|
||||
assert normalize_postgres_uri(42) == 42 # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNormalizePgvectorConnectionString:
|
||||
"""Symmetric to the POSTGRES_URI normalizer but pulls in the OPPOSITE
|
||||
direction: strips the SQLAlchemy dialect prefix so libpq accepts it.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,expected",
|
||||
[
|
||||
# User-friendly forms pass through — libpq accepts them natively.
|
||||
(
|
||||
"postgres://u:p@h:5432/d",
|
||||
"postgres://u:p@h:5432/d",
|
||||
),
|
||||
(
|
||||
"postgresql://u:p@h:5432/d",
|
||||
"postgresql://u:p@h:5432/d",
|
||||
),
|
||||
# SQLAlchemy dialect prefixes get stripped so libpq accepts them.
|
||||
# Operators hit this when they copy POSTGRES_URI → PGVECTOR_CONNECTION_STRING.
|
||||
(
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
"postgresql://u:p@h:5432/d",
|
||||
),
|
||||
(
|
||||
"postgresql+psycopg2://u:p@h:5432/d",
|
||||
"postgresql://u:p@h:5432/d",
|
||||
),
|
||||
# Whitespace is trimmed before rewriting.
|
||||
(
|
||||
" postgresql+psycopg://u:p@h/d ",
|
||||
"postgresql://u:p@h/d",
|
||||
),
|
||||
# Query-string params (sslmode, etc.) are preserved verbatim.
|
||||
(
|
||||
"postgresql+psycopg://u:p@h:5432/d?sslmode=require",
|
||||
"postgresql://u:p@h:5432/d?sslmode=require",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rewrites_dialect_forms_to_libpq_compatible(self, input_value, expected):
|
||||
assert normalize_pgvector_connection_string(input_value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value",
|
||||
[None, "", " ", "None", "none"],
|
||||
)
|
||||
def test_empty_or_none_like_returns_none(self, input_value):
|
||||
assert normalize_pgvector_connection_string(input_value) is None
|
||||
|
||||
def test_unknown_scheme_passes_through(self):
|
||||
"""A scheme we don't recognise is left alone so libpq can produce
|
||||
its own error message when the connection is attempted."""
|
||||
weird = "mysql://u:p@h/d"
|
||||
assert normalize_pgvector_connection_string(weird) == weird
|
||||
|
||||
def test_non_string_input_passes_through(self):
|
||||
assert normalize_pgvector_connection_string(42) == 42 # type: ignore[arg-type]
|
||||
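The behaviour pinned down above is compact enough to sketch. A minimal, assumption-laden version of the two normalizers, written only from these tests; the real application/core/db_uri.py may differ in detail, hence the _sketch suffixes:

# Sketch only: behaviour inferred from the tests above, not the project's actual module.
def _clean(value):
    if not isinstance(value, str):
        return value  # non-strings pass through; pydantic's own validation handles them
    value = value.strip()
    if not value or value.lower() == "none":
        return None
    return value

def normalize_postgres_uri_sketch(value):
    """Rewrite common Postgres URI forms to the SQLAlchemy psycopg dialect."""
    value = _clean(value)
    if not isinstance(value, str):
        return value
    for prefix in ("postgresql+psycopg2://", "postgres://", "postgresql://"):
        if value.startswith(prefix):
            return "postgresql+psycopg://" + value[len(prefix):]
    return value  # unknown dialects pass through untouched

def normalize_pgvector_connection_string_sketch(value):
    """Strip SQLAlchemy dialect prefixes so libpq accepts the string."""
    value = _clean(value)
    if not isinstance(value, str):
        return value
    for prefix in ("postgresql+psycopg://", "postgresql+psycopg2://"):
        if value.startswith(prefix):
            return "postgresql://" + value[len(prefix):]
    return value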
tests/integration/conftest.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""Fixtures for integration tests that hit a live Postgres.

These tests are separate from unit tests in two ways:

1. **Directory**: They live under ``tests/integration/``, which is
   ignored by the default pytest run via ``--ignore=tests/integration``
   in ``pytest.ini``. The CI ``pytest.yml`` workflow therefore skips
   them automatically — it runs the fast unit suite only.
2. **Marker**: Each test is marked ``@pytest.mark.integration`` so it
   can be selected (or excluded) independently of its directory with
   ``pytest -m integration`` / ``-m "not integration"``.

Running the Postgres-backed integration tests manually::

    .venv/bin/python -m pytest tests/integration/test_users_repository.py \\
        --override-ini="addopts=" --no-cov

(The ``--override-ini="addopts="`` is needed because ``pytest.ini``
contains ``--ignore=tests/integration`` in ``addopts``; without the
override pytest would still skip the directory even though you named
it on the command line.)

Tests are skipped automatically if ``POSTGRES_URI`` is unset, so a
contributor who hasn't set up a local Postgres gets clean skips
instead of red tests.
"""

from __future__ import annotations

import pytest
from sqlalchemy import Engine, create_engine, text

from application.core.settings import settings


@pytest.fixture(scope="session")
def pg_engine() -> Engine:
    """Session-scoped SQLAlchemy engine for the Postgres integration DB.

    Skips all Postgres-backed tests if ``POSTGRES_URI`` is unset. This
    keeps CI and contributor machines without a local Postgres from
    erroring out — integration tests that require the DB become
    no-ops rather than failures.
    """
    if not settings.POSTGRES_URI:
        pytest.skip("POSTGRES_URI not set — skipping Postgres integration tests")
    engine = create_engine(settings.POSTGRES_URI, future=True, pool_pre_ping=True)
    try:
        yield engine
    finally:
        engine.dispose()


@pytest.fixture
def pg_conn(pg_engine: Engine):
    """Per-test Postgres connection wrapped in a rolled-back transaction.

    Uses SQLAlchemy's explicit outer-transaction pattern so every test
    sees a pristine DB view without having to truncate tables. Any
    nested ``begin()`` inside the repository code becomes a SAVEPOINT
    under the hood.
    """
    conn = pg_engine.connect()
    outer = conn.begin()
    try:
        yield conn
    finally:
        outer.rollback()
        conn.close()


@pytest.fixture
def pg_clean_users(pg_conn):
    """Guarantee a clean ``users`` table view for tests that need it.

    The outer transaction rollback handles cleanup, but if a previous
    interrupted run left rows committed, this fixture removes them
    inside the transaction scope so they are invisible to the test.
    ``DELETE`` rather than ``TRUNCATE`` because ``TRUNCATE`` in Postgres
    cannot be rolled back within a nested transaction the way
    ``DELETE`` can.
    """
    pg_conn.execute(text("DELETE FROM users"))
    return pg_conn
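To make the rollback guarantee concrete, a hedged sketch of a test leaning on these fixtures. It assumes ``users.user_id`` is the only column that must be supplied; the real schema may require more:

from sqlalchemy import text

def test_writes_vanish_after_rollback_sketch(pg_clean_users):
    # Sketch only: rows written here exist solely inside the fixture's outer
    # transaction and are discarded when it rolls back after the test.
    pg_clean_users.execute(
        text("INSERT INTO users (user_id) VALUES ('sketch@example.com')")
    )
    count = pg_clean_users.execute(text("SELECT count(*) FROM users")).scalar()
    assert count == 1  # visible inside this transaction only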
tests/integration/test_users_repository.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""Integration tests for ``UsersRepository`` against a live Postgres.

These tests need:

* A running Postgres reachable via ``POSTGRES_URI``
* Alembic migration ``0001_initial`` applied

They are skipped automatically by the default ``pytest`` run because
``pytest.ini`` has ``--ignore=tests/integration`` in ``addopts``. They
are additionally marked ``@pytest.mark.integration`` so they can be
selected/excluded explicitly. Run them locally with::

    .venv/bin/python -m pytest tests/integration/test_users_repository.py \\
        --override-ini="addopts=" --no-cov

Covers every operation the legacy Mongo code performs on
``users_collection``: upsert + get + add/remove on the pinned and
shared_with_me JSONB arrays, plus explicit tenant-isolation checks
(a user's operations must never touch another user's data).
"""

from __future__ import annotations

import pytest

from application.storage.db.repositories.users import UsersRepository


@pytest.fixture
def repo(pg_clean_users):
    return UsersRepository(pg_clean_users)


@pytest.mark.integration
class TestUpsert:
    def test_upsert_creates_new_user_with_default_preferences(self, repo):
        doc = repo.upsert("alice@example.com")
        assert doc["user_id"] == "alice@example.com"
        assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}
        assert "id" in doc
        assert "_id" in doc  # Mongo-compat alias

    def test_upsert_is_idempotent(self, repo):
        first = repo.upsert("alice@example.com")
        second = repo.upsert("alice@example.com")
        assert first["id"] == second["id"]
        assert first["agent_preferences"] == second["agent_preferences"]

    def test_upsert_preserves_existing_preferences(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-1")
        doc = repo.upsert("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-1"]


@pytest.mark.integration
class TestGet:
    def test_get_returns_none_for_missing_user(self, repo):
        assert repo.get("nobody@example.com") is None

    def test_get_returns_dict_for_existing_user(self, repo):
        repo.upsert("alice@example.com")
        doc = repo.get("alice@example.com")
        assert doc is not None
        assert doc["user_id"] == "alice@example.com"
        assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}


@pytest.mark.integration
class TestAddPinned:
    def test_add_pinned_to_new_user_creates_user(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-1")
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-1"]

    def test_add_pinned_is_idempotent(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-1")
        repo.add_pinned("alice@example.com", "agent-1")
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-1"]

    def test_add_pinned_preserves_order(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-1")
        repo.add_pinned("alice@example.com", "agent-2")
        repo.add_pinned("alice@example.com", "agent-3")
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-1", "agent-2", "agent-3"]


@pytest.mark.integration
class TestRemovePinned:
    def test_remove_pinned_single(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-1")
        repo.add_pinned("alice@example.com", "agent-2")
        repo.remove_pinned("alice@example.com", "agent-1")
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-2"]

    def test_remove_pinned_missing_is_noop(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-1")
        repo.remove_pinned("alice@example.com", "agent-999")
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-1"]

    def test_remove_pinned_bulk(self, repo):
        repo.upsert("alice@example.com")
        for i in range(5):
            repo.add_pinned("alice@example.com", f"agent-{i}")
        repo.remove_pinned_bulk("alice@example.com", ["agent-1", "agent-3", "agent-999"])
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-0", "agent-2", "agent-4"]


@pytest.mark.integration
class TestSharedWithMe:
    def test_add_shared_is_idempotent(self, repo):
        repo.upsert("alice@example.com")
        repo.add_shared("alice@example.com", "agent-x")
        repo.add_shared("alice@example.com", "agent-x")
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["shared_with_me"] == ["agent-x"]

    def test_remove_shared_bulk(self, repo):
        repo.upsert("alice@example.com")
        for i in range(3):
            repo.add_shared("alice@example.com", f"shared-{i}")
        repo.remove_shared_bulk("alice@example.com", ["shared-0", "shared-2"])
        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["shared_with_me"] == ["shared-1"]


@pytest.mark.integration
class TestRemoveAgentFromAll:
    def test_removes_from_both_pinned_and_shared(self, repo):
        repo.upsert("alice@example.com")
        repo.add_pinned("alice@example.com", "agent-X")
        repo.add_pinned("alice@example.com", "agent-keep")
        repo.add_shared("alice@example.com", "agent-X")
        repo.add_shared("alice@example.com", "agent-keep-2")

        repo.remove_agent_from_all("alice@example.com", "agent-X")

        doc = repo.get("alice@example.com")
        assert doc["agent_preferences"]["pinned"] == ["agent-keep"]
        assert doc["agent_preferences"]["shared_with_me"] == ["agent-keep-2"]


@pytest.mark.integration
class TestTenantIsolation:
    """Security-critical: operations on one user must never touch another's data."""

    def test_add_pinned_does_not_leak_across_users(self, repo):
        repo.upsert("alice@example.com")
        repo.upsert("bob@example.com")
        repo.add_pinned("alice@example.com", "agent-a")
        repo.add_pinned("bob@example.com", "agent-b")

        alice = repo.get("alice@example.com")
        bob = repo.get("bob@example.com")
        assert alice["agent_preferences"]["pinned"] == ["agent-a"]
        assert bob["agent_preferences"]["pinned"] == ["agent-b"]

    def test_remove_does_not_leak_across_users(self, repo):
        repo.upsert("alice@example.com")
        repo.upsert("bob@example.com")
        repo.add_pinned("alice@example.com", "shared-agent-id")
        repo.add_pinned("bob@example.com", "shared-agent-id")

        repo.remove_pinned("alice@example.com", "shared-agent-id")

        assert repo.get("alice@example.com")["agent_preferences"]["pinned"] == []
        assert repo.get("bob@example.com")["agent_preferences"]["pinned"] == [
            "shared-agent-id"
        ]
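For orientation, a hedged sketch of the repository surface these tests exercise. Table and column names (users, user_id, agent_preferences) come from the tests and fixtures above; the real application/storage/db/repositories/users.py may push this logic into jsonb SQL instead of the read-modify-write shown here, and only two methods are sketched:

import json

from sqlalchemy import text

class UsersRepositorySketch:
    def __init__(self, conn):
        self._conn = conn

    def upsert(self, user_id: str) -> dict:
        row = self._conn.execute(
            text("SELECT id, user_id, agent_preferences FROM users WHERE user_id = :u"),
            {"u": user_id},
        ).mappings().first()
        if row is None:
            row = self._conn.execute(
                text(
                    "INSERT INTO users (user_id, agent_preferences) "
                    "VALUES (:u, CAST(:prefs AS jsonb)) "
                    "RETURNING id, user_id, agent_preferences"
                ),
                {"u": user_id, "prefs": '{"pinned": [], "shared_with_me": []}'},
            ).mappings().first()
        doc = dict(row)
        doc["_id"] = doc["id"]  # Mongo-compat alias expected by the tests
        return doc

    def add_pinned(self, user_id: str, agent_id: str) -> None:
        prefs = self.upsert(user_id)["agent_preferences"]
        if agent_id not in prefs["pinned"]:  # idempotent and order-preserving
            prefs["pinned"].append(agent_id)
            self._conn.execute(
                text("UPDATE users SET agent_preferences = CAST(:p AS jsonb) WHERE user_id = :u"),
                {"p": json.dumps(prefs), "u": user_id},
            )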
@@ -20,133 +20,49 @@ def test_epub_init_parser():
    assert parser.parser_config_set


def test_epub_parser_ebooklib_import_error(epub_parser):
    """Test that ImportError is raised when ebooklib is not available."""
    with patch.dict(sys.modules, {"ebooklib": None}):
        with pytest.raises(ValueError, match="`EbookLib` is required to read Epub files"):
def test_epub_parser_fast_ebook_import_error(epub_parser):
    """Test that ImportError is raised when fast-ebook is not available."""
    with patch.dict(sys.modules, {"fast_ebook": None}):
        with pytest.raises(ValueError, match="`fast-ebook` is required to read Epub files"):
            epub_parser.parse_file(Path("test.epub"))


def test_epub_parser_html2text_import_error(epub_parser):
    """Test that ImportError is raised when html2text is not available."""
    fake_ebooklib = types.ModuleType("ebooklib")
    fake_epub = types.ModuleType("ebooklib.epub")
    fake_ebooklib.epub = fake_epub

    with patch.dict(sys.modules, {"ebooklib": fake_ebooklib, "ebooklib.epub": fake_epub}):
        with patch.dict(sys.modules, {"html2text": None}):
            with pytest.raises(ValueError, match="`html2text` is required to parse Epub files"):
                epub_parser.parse_file(Path("test.epub"))


def test_epub_parser_successful_parsing(epub_parser):
    """Test successful parsing of an epub file."""
    fake_fast_ebook = types.ModuleType("fast_ebook")
    fake_epub = types.ModuleType("fast_ebook.epub")
    fake_fast_ebook.epub = fake_epub

    fake_ebooklib = types.ModuleType("ebooklib")
    fake_epub = types.ModuleType("ebooklib.epub")
    fake_html2text = types.ModuleType("html2text")

    # Mock ebooklib constants
    fake_ebooklib.ITEM_DOCUMENT = "document"
    fake_ebooklib.epub = fake_epub

    mock_item1 = MagicMock()
    mock_item1.get_type.return_value = "document"
    mock_item1.get_content.return_value = b"<h1>Chapter 1</h1><p>Content 1</p>"

    mock_item2 = MagicMock()
    mock_item2.get_type.return_value = "document"
    mock_item2.get_content.return_value = b"<h1>Chapter 2</h1><p>Content 2</p>"

    mock_item3 = MagicMock()
    mock_item3.get_type.return_value = "other"  # Should be ignored
    mock_item3.get_content.return_value = b"<p>Other content</p>"

    mock_book = MagicMock()
    mock_book.get_items.return_value = [mock_item1, mock_item2, mock_item3]

    mock_book.to_markdown.return_value = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"

    fake_epub.read_epub = MagicMock(return_value=mock_book)

    def mock_html2text_func(html_content):
        if "Chapter 1" in html_content:
            return "# Chapter 1\n\nContent 1\n"
        elif "Chapter 2" in html_content:
            return "# Chapter 2\n\nContent 2\n"
        return "Other content\n"

    fake_html2text.html2text = mock_html2text_func


    with patch.dict(sys.modules, {
        "ebooklib": fake_ebooklib,
        "ebooklib.epub": fake_epub,
        "html2text": fake_html2text
        "fast_ebook": fake_fast_ebook,
        "fast_ebook.epub": fake_epub,
    }):
        result = epub_parser.parse_file(Path("test.epub"))

        expected_result = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
        assert result == expected_result

        # Verify epub.read_epub was called with correct parameters
        fake_epub.read_epub.assert_called_once_with(Path("test.epub"), options={"ignore_ncx": True})

        assert result == "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
        fake_epub.read_epub.assert_called_once_with(Path("test.epub"))


def test_epub_parser_empty_book(epub_parser):
    """Test parsing an epub file with no document items."""
    # Create mock modules
    fake_ebooklib = types.ModuleType("ebooklib")
    fake_epub = types.ModuleType("ebooklib.epub")
    fake_html2text = types.ModuleType("html2text")

    fake_ebooklib.ITEM_DOCUMENT = "document"
    fake_ebooklib.epub = fake_epub

    # Create mock book with no document items
    """Test parsing an epub file with no content."""
    fake_fast_ebook = types.ModuleType("fast_ebook")
    fake_epub = types.ModuleType("fast_ebook.epub")
    fake_fast_ebook.epub = fake_epub

    mock_book = MagicMock()
    mock_book.get_items.return_value = []

    mock_book.to_markdown.return_value = ""

    fake_epub.read_epub = MagicMock(return_value=mock_book)
    fake_html2text.html2text = MagicMock()


    with patch.dict(sys.modules, {
        "ebooklib": fake_ebooklib,
        "ebooklib.epub": fake_epub,
        "html2text": fake_html2text
        "fast_ebook": fake_fast_ebook,
        "fast_ebook.epub": fake_epub,
    }):
        result = epub_parser.parse_file(Path("empty.epub"))

        assert result == ""

        fake_html2text.html2text.assert_not_called()


def test_epub_parser_non_document_items_ignored(epub_parser):
    """Test that non-document items are ignored during parsing."""
    fake_ebooklib = types.ModuleType("ebooklib")
    fake_epub = types.ModuleType("ebooklib.epub")
    fake_html2text = types.ModuleType("html2text")

    fake_ebooklib.ITEM_DOCUMENT = "document"
    fake_ebooklib.epub = fake_epub

    mock_doc_item = MagicMock()
    mock_doc_item.get_type.return_value = "document"
    mock_doc_item.get_content.return_value = b"<p>Document content</p>"

    mock_other_item = MagicMock()
    mock_other_item.get_type.return_value = "image"  # Not a document

    mock_book = MagicMock()
    mock_book.get_items.return_value = [mock_other_item, mock_doc_item]

    fake_epub.read_epub = MagicMock(return_value=mock_book)
    fake_html2text.html2text = MagicMock(return_value="Document content\n")

    with patch.dict(sys.modules, {
        "ebooklib": fake_ebooklib,
        "ebooklib.epub": fake_epub,
        "html2text": fake_html2text
    }):
        result = epub_parser.parse_file(Path("test.epub"))

        assert result == "Document content\n"

        fake_html2text.html2text.assert_called_once_with("<p>Document content</p>")
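Reading the old and new expectations side by side, the updated parser appears to delegate HTML-to-Markdown conversion to the book object itself. A hedged sketch of the shape the new tests imply; the real parser class, config handling and error paths may differ:

from pathlib import Path

class EpubParserSketch:
    def parse_file(self, file: Path) -> str:
        try:
            # Imported lazily so the tests can swap the module via patch.dict.
            from fast_ebook import epub
        except ImportError:
            raise ValueError("`fast-ebook` is required to read Epub files.")
        book = epub.read_epub(file)
        return book.to_markdown()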
@@ -8,7 +8,7 @@ from application.storage.local import LocalStorage

@pytest.fixture
def temp_base_dir():
    return "/tmp/test_storage"
    return os.path.realpath("/tmp/test_storage")


@pytest.fixture
@@ -30,12 +30,12 @@ class TestLocalStorageInitialization:

    def test_get_full_path_with_relative_path(self, local_storage):
        result = local_storage._get_full_path("documents/test.txt")
        expected = os.path.join("/tmp/test_storage", "documents/test.txt")
        assert os.path.normpath(result) == os.path.normpath(expected)
        expected = os.path.realpath(os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt"))
        assert result == expected

    def test_get_full_path_with_absolute_path(self, local_storage):
        result = local_storage._get_full_path("/absolute/path/test.txt")
        assert result == "/absolute/path/test.txt"
    def test_get_full_path_with_absolute_path_outside_base_raises(self, local_storage):
        with pytest.raises(ValueError, match="Path traversal detected"):
            local_storage._get_full_path("/absolute/path/test.txt")

    @patch("os.makedirs")
    @patch("builtins.open", new_callable=mock_open)
@@ -48,8 +48,8 @@ class TestLocalStorageInitialization:

        result = local_storage.save_file(file_data, path)

        expected_dir = os.path.join("/tmp/test_storage", "documents")
        expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
        expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")

        assert mock_makedirs.call_count == 1
        assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath(
@@ -74,25 +74,19 @@ class TestLocalStorageInitialization:

        result = local_storage.save_file(file_data, path)

        expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert file_data.save.call_count == 1
        assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath(
            expected_file
        )
        assert result == {"storage_type": "local"}

    @patch("os.makedirs")
    @patch("builtins.open", new_callable=mock_open)
    def test_save_file_with_absolute_path(
        self, mock_file, mock_makedirs, local_storage
    ):
    def test_save_file_with_absolute_path_outside_base_raises(self, local_storage):
        file_data = io.BytesIO(b"test content")
        path = "/absolute/path/test.txt"

        local_storage.save_file(file_data, path)

        mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True)
        mock_file.assert_called_once_with("/absolute/path/test.txt", "wb")
        with pytest.raises(ValueError, match="Path traversal detected"):
            local_storage.save_file(file_data, path)


@pytest.mark.unit
@@ -105,7 +99,7 @@ class TestLocalStorageGetFile:

        result = local_storage.get_file(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
            expected_path
@@ -122,7 +116,7 @@ class TestLocalStorageGetFile:

        with pytest.raises(FileNotFoundError, match="File not found"):
            local_storage.get_file(path)
        expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
            expected_path
@@ -141,7 +135,7 @@ class TestLocalStorageDeleteFile:

        result = local_storage.delete_file(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert result is True
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -158,7 +152,7 @@ class TestLocalStorageDeleteFile:

        result = local_storage.delete_file(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
        assert result is False
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -175,7 +169,7 @@ class TestLocalStorageFileExists:

        result = local_storage.file_exists(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert result is True
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -188,7 +182,7 @@ class TestLocalStorageFileExists:

        result = local_storage.file_exists(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
        assert result is False
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -205,7 +199,7 @@ class TestLocalStorageListFiles:
        self, mock_exists, mock_walk, local_storage
    ):
        directory = "documents"
        base_dir = os.path.join("/tmp/test_storage", "documents")
        base_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")

        mock_walk.return_value = [
            (base_dir, ["subdir"], ["file1.txt", "file2.txt"]),
@@ -228,7 +222,7 @@ class TestLocalStorageListFiles:

        result = local_storage.list_files(directory)

        expected_path = os.path.join("/tmp/test_storage", "nonexistent")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
        assert result == []
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -248,7 +242,7 @@ class TestLocalStorageProcessFile:

        result = local_storage.process_file(path, processor_func, extra_arg="value")

        expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert result == "processed"
        assert processor_func.call_count == 1
        call_kwargs = processor_func.call_args[1]
@@ -280,7 +274,7 @@ class TestLocalStorageIsDirectory:

        result = local_storage.is_directory(path)

        expected_path = os.path.join("/tmp/test_storage", "documents")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
        assert result is True
        assert mock_isdir.call_count == 1
        assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
@@ -295,7 +289,7 @@ class TestLocalStorageIsDirectory:

        result = local_storage.is_directory(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert result is False
        assert mock_isdir.call_count == 1
        assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
@@ -316,7 +310,7 @@ class TestLocalStorageRemoveDirectory:

        result = local_storage.remove_directory(directory)

        expected_path = os.path.join("/tmp/test_storage", "documents")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
        assert result is True
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -339,7 +333,7 @@ class TestLocalStorageRemoveDirectory:

        result = local_storage.remove_directory(directory)

        expected_path = os.path.join("/tmp/test_storage", "nonexistent")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
        assert result is False
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -355,7 +349,7 @@ class TestLocalStorageRemoveDirectory:

        result = local_storage.remove_directory(path)

        expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
        assert result is False
        assert mock_exists.call_count == 1
        assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -376,7 +370,7 @@ class TestLocalStorageRemoveDirectory:

        result = local_storage.remove_directory(directory)

        expected_path = os.path.join("/tmp/test_storage", "documents")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
        assert result is False
        assert mock_rmtree.call_count == 1
        assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
@@ -393,7 +387,7 @@ class TestLocalStorageRemoveDirectory:

        result = local_storage.remove_directory(directory)

        expected_path = os.path.join("/tmp/test_storage", "documents")
        expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
        assert result is False
        assert mock_rmtree.call_count == 1
        assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
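The updated expectations suggest _get_full_path now resolves everything through os.path.realpath and refuses anything that escapes the base directory. A hedged sketch consistent with these tests; the real LocalStorage may structure the check differently:

import os

def _get_full_path_sketch(base_dir: str, path: str) -> str:
    base = os.path.realpath(base_dir)
    # os.path.join drops `base` when `path` is absolute, so absolute paths
    # outside the base directory are caught by the containment check below.
    candidate = os.path.realpath(os.path.join(base, path))
    if candidate != base and not candidate.startswith(base + os.sep):
        raise ValueError(f"Path traversal detected: {path}")
    return candidate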
@@ -2105,23 +2105,35 @@ class TestInternalRoutes:
        app.register_blueprint(internal)
        return app

    _TEST_KEY = "test-key"
    _AUTH_HEADERS = {"X-Internal-Key": "test-key"}

    def test_upload_index_no_user(self, app):
        with app.test_client() as client:
            with patch(
                "application.api.internal.routes.settings"
            ) as ms:
                ms.INTERNAL_KEY = None
                resp = client.post("/api/upload_index")
                ms.INTERNAL_KEY = self._TEST_KEY
                resp = client.post("/api/upload_index", headers=self._AUTH_HEADERS)
                assert resp.get_json()["status"] == "no user"

    def test_upload_index_no_name(self, app):
        with app.test_client() as client:
            with patch(
                "application.api.internal.routes.settings"
            ) as ms:
                ms.INTERNAL_KEY = self._TEST_KEY
                resp = client.post("/api/upload_index", data={"user": "u1"}, headers=self._AUTH_HEADERS)
                assert resp.get_json()["status"] == "no name"

    def test_upload_index_rejected_without_internal_key(self, app):
        with app.test_client() as client:
            with patch(
                "application.api.internal.routes.settings"
            ) as ms:
                ms.INTERNAL_KEY = None
                resp = client.post("/api/upload_index", data={"user": "u1"})
                assert resp.get_json()["status"] == "no name"
                assert resp.status_code == 401


# ---------------------------------------------------------------------------
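These fixtures imply the internal blueprint authenticates callers with an X-Internal-Key header checked against settings.INTERNAL_KEY. A hedged sketch of such a guard, assuming it is wired up as a before_request hook; the real application/api/internal/routes.py may do this differently and return a different body:

from flask import jsonify, request

from application.core.settings import settings

def require_internal_key_sketch():
    provided = request.headers.get("X-Internal-Key")
    if not settings.INTERNAL_KEY or provided != settings.INTERNAL_KEY:
        # Reject when the key is unset on the server or missing/wrong on the request.
        return jsonify({"success": False}), 401
    return None  # fall through to the view function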
@@ -11,6 +11,7 @@ import pytest
from application.api.v1.translator import (
    _get_client_tool_name,
    convert_history,
    extract_system_prompt,
    extract_tool_results,
    is_continuation,
    translate_request,
@@ -148,6 +149,48 @@ class TestConvertHistory:
        assert history == []


# ---------------------------------------------------------------------------
# extract_system_prompt
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestExtractSystemPrompt:

    def test_extracts_first_system_message(self):
        messages = [
            {"role": "system", "content": "You are a pirate"},
            {"role": "user", "content": "Hello"},
        ]
        assert extract_system_prompt(messages) == "You are a pirate"

    def test_returns_none_when_no_system_message(self):
        messages = [{"role": "user", "content": "Hello"}]
        assert extract_system_prompt(messages) is None

    def test_returns_first_of_multiple_system_messages(self):
        messages = [
            {"role": "system", "content": "First"},
            {"role": "system", "content": "Second"},
            {"role": "user", "content": "Hello"},
        ]
        assert extract_system_prompt(messages) == "First"

    def test_empty_content_returns_empty_string(self):
        messages = [
            {"role": "system", "content": ""},
            {"role": "user", "content": "Hello"},
        ]
        assert extract_system_prompt(messages) == ""

    def test_missing_content_returns_empty_string(self):
        messages = [
            {"role": "system"},
            {"role": "user", "content": "Hello"},
        ]
        assert extract_system_prompt(messages) == ""


# ---------------------------------------------------------------------------
# translate_request
# ---------------------------------------------------------------------------
@@ -167,11 +210,25 @@ class TestTranslateRequest:
        result = translate_request(data, "test-key")
        assert result["question"] == "What's 2+2?"
        assert result["api_key"] == "test-key"
        assert result["save_conversation"] is True
        # Conversations are not persisted by default on the v1 endpoint.
        assert result["save_conversation"] is False
        history = json.loads(result["history"])
        assert len(history) == 1
        assert history[0]["prompt"] == "Hello"

    def test_save_conversation_opt_in_via_docsgpt_extension(self):
        data = {
            "messages": [{"role": "user", "content": "Hi"}],
            "docsgpt": {"save_conversation": True},
        }
        result = translate_request(data, "key")
        assert result["save_conversation"] is True

    def test_save_conversation_default_false(self):
        data = {"messages": [{"role": "user", "content": "Hi"}]}
        result = translate_request(data, "key")
        assert result["save_conversation"] is False

    def test_continuation_request(self):
        data = {
            "messages": [
@@ -237,6 +294,23 @@ class TestTranslateRequest:
        result = translate_request(data, "key")
        assert result["attachments"] == ["att1", "att2"]

    def test_system_prompt_override_included_when_present(self):
        data = {
            "messages": [
                {"role": "system", "content": "Custom prompt"},
                {"role": "user", "content": "Hello"},
            ],
        }
        result = translate_request(data, "key")
        assert result["system_prompt_override"] == "Custom prompt"

    def test_system_prompt_override_absent_when_no_system_message(self):
        data = {
            "messages": [{"role": "user", "content": "Hello"}],
        }
        result = translate_request(data, "key")
        assert "system_prompt_override" not in result


# ---------------------------------------------------------------------------
# translate_response

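Pulling the asserted behaviour together, a hedged sketch of the request translation. It reuses convert_history and extract_system_prompt from the module under test; the real translate_request also handles continuations, tool results and attachments, and its internals may differ:

import json

def translate_request_sketch(data: dict, api_key: str) -> dict:
    messages = data.get("messages", [])
    question = next(
        (m.get("content", "") for m in reversed(messages) if m.get("role") == "user"),
        "",
    )
    result = {
        "question": question,
        "api_key": api_key,
        "history": json.dumps(convert_history(messages)),
        # Conversations are only persisted when the caller opts in through the
        # "docsgpt" extension object.
        "save_conversation": bool(data.get("docsgpt", {}).get("save_conversation", False)),
    }
    system_prompt = extract_system_prompt(messages)
    if system_prompt is not None:
        result["system_prompt_override"] = system_prompt
    return result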
@@ -363,6 +363,28 @@ class TestGetVectorstore:

        assert get_vectorstore("user/source123") == "indexes/user/source123"

    @pytest.mark.parametrize(
        "malicious_path",
        [
            "../outside",
            "../../etc/passwd",
            "nested/../../../outside",
            "/tmp/evil",
            "..\\outside",
            "valid/../../escape",
        ],
    )
    def test_rejects_path_traversal(self, malicious_path):
        from application.vectorstore.faiss import get_vectorstore

        with pytest.raises(ValueError, match="Invalid source_id path"):
            get_vectorstore(malicious_path)

    def test_allows_mongodb_style_ids(self):
        from application.vectorstore.faiss import get_vectorstore

        assert get_vectorstore("65e8f6a8a7a96b1bdad4154f") == "indexes/65e8f6a8a7a96b1bdad4154f"


@pytest.mark.unit
class TestFaissStoreAddChunk:

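A hedged sketch of the validation these cases imply for get_vectorstore; the real application/vectorstore/faiss.py may normalize and check the path differently:

import os

def get_vectorstore_sketch(source_id: str) -> str:
    # Reject absolute paths and any ".." segment, including backslash-separated ones.
    parts = source_id.replace("\\", "/").split("/")
    if os.path.isabs(source_id) or ".." in parts:
        raise ValueError(f"Invalid source_id path: {source_id}")
    return os.path.join("indexes", source_id)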
@@ -16,10 +16,9 @@ def _make_store(
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
            "psycopg2": MagicMock(),
            "psycopg2.extras": MagicMock(),
            "psycopg": MagicMock(),
            "pgvector": MagicMock(),
            "pgvector.psycopg2": MagicMock(),
            "pgvector.psycopg": MagicMock(),
        },
    ):
        mock_emb = Mock()
@@ -63,10 +62,9 @@ class TestPGVectorStoreInit:
        ) as mock_settings, patch.dict(
            "sys.modules",
            {
                "psycopg2": MagicMock(),
                "psycopg2.extras": MagicMock(),
                "psycopg": MagicMock(),
                "pgvector": MagicMock(),
                "pgvector.psycopg2": MagicMock(),
                "pgvector.psycopg": MagicMock(),
            },
        ):
            mock_get_emb.return_value = Mock(dimension=768)
@@ -264,13 +262,13 @@ class TestPGVectorStoreConnection:
        store, mock_conn, _, _ = _make_store()
        mock_conn.closed = True

        mock_psycopg2 = MagicMock()
        mock_psycopg = MagicMock()
        new_conn = MagicMock()
        mock_psycopg2.connect.return_value = new_conn
        store._psycopg2 = mock_psycopg2
        mock_psycopg.connect.return_value = new_conn
        store._psycopg = mock_psycopg

        conn = store._get_connection()
        mock_psycopg2.connect.assert_called_once()
        mock_psycopg.connect.assert_called_once()
        assert conn is new_conn

    def test_get_connection_reuses_open(self):