Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-07 14:34:32 +00:00)

Compare commits: v1-mini-im ... pg-1 (18 commits)
| SHA1 |
| --- |
| b5b6538762 |
| a9761061fc |
| 9388996a15 |
| 875868b7e5 |
| 502819ae52 |
| cada1a44fc |
| 6192767451 |
| 5c3e6eca54 |
| 59d9d4ac50 |
| 3931ccccee |
| 55717043f6 |
| ececcb8b17 |
| 420e9d3dd5 |
| 749eed3d0b |
| bd03a513e3 |
| fcdb4fb5e8 |
| e787c896eb |
| 23aeaff5db |
@@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}

# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
@@ -1,46 +1,80 @@
Ollama
Qdrant
Milvus
Chatwoot
Nextra
VSCode
npm
LLMs
Agentic
Anthropic's
api
APIs
Groq
SGLang
LMDeploy
OAuth
Vite
LLM
JSONPath
UIs
Atlassian
automations
autoescaping
Autoescaping
backfill
backfills
bool
boolean
brave_web_search
chatbot
Chatwoot
config
configs
uncomment
qdrant
vectorstore
CSVs
dev
diarization
Docling
docsgpt
llm
docstrings
Entra
env
enqueues
EOL
ESLint
feedbacks
Figma
GPUs
Groq
hardcode
hardcoding
Idempotency
JSONPath
kubectl
Lightsail
enqueues
chatbot
VSCode's
Shareability
feedbacks
automations
llama_cpp
llm
LLM
LLMs
LMDeploy
Milvus
Mixtral
namespace
namespaces
needs_auth
Nextra
Novita
npm
OAuth
Ollama
opencode
parsable
passthrough
PDFs
pgvector
Postgres
Premade
Signup
Pydantic
pytest
Qdrant
qdrant
Repo
repo
env
URl
agentic
llama_cpp
parsable
Sanitization
SDKs
boolean
bool
hardcode
EOL
SGLang
Shareability
Signup
Supabase
UIs
uncomment
URl
vectorstore
Vite
VSCode
VSCode's
widget's
.gitignore (vendored): 3 changes

@@ -108,6 +108,8 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/
@@ -181,5 +183,6 @@ application/vectors/

node_modules/
.vscode/settings.json
.vscode/sftp.json
/models/
model/
@@ -1,5 +1,7 @@
MinAlertLevel = warning
StylesPath = .github/styles
Vocab = DocsGPT

[*.{md,mdx}]
BasedOnStyles = DocsGPT
SECURITY.md: 12 changes

@@ -2,9 +2,7 @@

## Supported Versions

Supported Versions:

Currently, we support security patches by committing changes and bumping the version published on Github.
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.

## Reporting a Vulnerability

@@ -14,7 +12,11 @@ https://github.com/arc53/DocsGPT/security
Then click **Report a vulnerability**.

Alternatively:
Alternatively, email us at: security@arc53.com

security@arc53.com
We aim to acknowledge reports within 48 hours.

## Incident Handling

Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.
@@ -73,7 +73,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}

response = requests.get(url, params=params, headers=headers)
response = requests.get(url, params=params, headers=headers, timeout=100)

if response.status_code == 200:
return {
@@ -118,7 +118,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}

response = requests.get(url, params=params, headers=headers)
response = requests.get(url, params=params, headers=headers, timeout=100)

if response.status_code == 200:
return {
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
response = requests.get(url, timeout=100)
if response.status_code == 200:
data = response.json()
if currency.upper() in data:
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data)
response = requests.post(url, headers=headers, data=data, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}

def get_actions_metadata(self):
@@ -1,6 +1,6 @@
import logging

import psycopg2
import psycopg

from application.agents.tools.base import Tool

@@ -33,7 +33,7 @@ class PostgresTool(Tool):
"""
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
@@ -60,7 +60,7 @@ class PostgresTool(Tool):
"response_data": response_data,
}

except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL execute_sql error: %s", e)
return {
@@ -78,7 +78,7 @@ class PostgresTool(Tool):
"""
conn = None
try:
conn = psycopg2.connect(self.connection_string)
conn = psycopg.connect(self.connection_string)
cur = conn.cursor()

cur.execute(
@@ -120,7 +120,7 @@ class PostgresTool(Tool):
"schema": schema_data,
}

except psycopg2.Error as e:
except psycopg.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL get_schema error: %s", e)
return {
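The hunks above only swap the driver module from psycopg2 to psycopg (v3). For context, here is a minimal, hypothetical illustration of the psycopg 3 API the tool now targets; it is not code from the repository. Connections and cursors can also be used as context managers instead of the explicit try/finally the tool keeps.

```python
import psycopg

# Illustrative only: psycopg 3 keeps the familiar connect/cursor/execute API,
# and both objects work as context managers, closing automatically on exit.
def run_query(connection_string: str, sql_query: str):
    # Suitable for SELECT-style queries that return rows.
    with psycopg.connect(connection_string) as conn:
        with conn.cursor() as cur:
            cur.execute(sql_query)
            return cur.fetchall()
```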
@@ -31,14 +31,14 @@ class TelegramTool(Tool):
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
response = requests.post(url, data=payload, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}

def _send_image(self, image_url, chat_id):
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
response = requests.post(url, data=payload, timeout=100)
return {"status_code": response.status_code, "message": "Image sent"}

def get_actions_metadata(self):
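Taken together, the four tool diffs above apply the same hardening: every outbound `requests` call now carries an explicit `timeout` so a stalled endpoint cannot hang an agent run indefinitely. A minimal sketch of the pattern, using a hypothetical helper name and surfacing the timeout to callers:

```python
import requests

def fetch_json(url, *, params=None, headers=None, timeout=100):
    # An explicit timeout turns a hung server into a catchable exception
    # instead of an indefinitely blocked worker.
    try:
        response = requests.get(url, params=params, headers=headers, timeout=timeout)
    except requests.exceptions.Timeout:
        return {"status_code": 408, "message": "Request timed out"}
    if response.status_code == 200:
        return {"status_code": 200, "results": response.json()}
    return {"status_code": response.status_code, "message": "Request failed"}
```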
application/alembic.ini (new file, 52 lines)

@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
#     alembic -c application/alembic.ini upgrade head

[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os

# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =

[post_write_hooks]

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
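The config comment above quotes the upgrade command; the other day-to-day invocations follow the same `-c application/alembic.ini` pattern (standard Alembic CLI flags, assuming `POSTGRES_URI` is set in the environment):

```bash
# Apply all pending migrations from the project root
alembic -c application/alembic.ini upgrade head

# Autogenerate a new revision after changing the SQLAlchemy models
alembic -c application/alembic.ini revision --autogenerate -m "describe the change"

# Roll back the most recent revision
alembic -c application/alembic.ini downgrade -1
```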
82
application/alembic/env.py
Normal file
82
application/alembic/env.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Alembic environment for the DocsGPT user-data Postgres database.
|
||||
|
||||
The URL is pulled from ``application.core.settings`` rather than
|
||||
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
|
||||
running app and ``alembic`` CLI invocations.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
# Make the project root importable regardless of cwd. env.py lives at
|
||||
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from sqlalchemy import engine_from_config, pool # noqa: E402
|
||||
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.models import metadata as target_metadata # noqa: E402
|
||||
|
||||
config = context.config
|
||||
|
||||
# Populate the runtime URL from settings.
|
||||
if settings.POSTGRES_URI:
|
||||
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
if not url:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode against a live connection."""
|
||||
if not config.get_main_option("sqlalchemy.url"):
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
future=True,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
application/alembic/script.py.mako (new file, 26 lines)

@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
462
application/alembic/versions/0001_initial.py
Normal file
462
application/alembic/versions/0001_initial.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""0001 initial schema — user-level tables migrated from MongoDB.
|
||||
|
||||
Creates every table described in §2.2 of ``migration-postgres.md``: tiers 1,
|
||||
2, and 3 in one shot. The schema is small enough that splitting the baseline
|
||||
across multiple revisions would only cost clarity.
|
||||
|
||||
Subsequent migrations will add columns / tables incrementally. This file is
|
||||
hand-written raw DDL rather than Core ``op.create_table`` calls because the
|
||||
DDL uses several Postgres-specific features (``CITEXT``, partial indexes,
|
||||
``text_pattern_ops``, JSONB defaults) that are clearer in SQL than in
|
||||
Alembic's Python API.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
Create Date: 2026-04-10
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0001_initial"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Extensions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tier 1: leaf tables, no FKs into other migrated tables
|
||||
# ------------------------------------------------------------------
|
||||
op.execute("""
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL UNIQUE,
|
||||
agent_preferences JSONB NOT NULL
|
||||
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX users_user_id_idx ON users (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE prompts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE user_tools (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
custom_name TEXT,
|
||||
display_name TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE token_usage (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
agent_id UUID, -- FK added later in this migration
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, timestamp DESC);")
|
||||
op.execute("CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, timestamp DESC);")
|
||||
op.execute("CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, timestamp DESC);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE user_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
endpoint TEXT,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
data JSONB
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, timestamp DESC);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE feedback (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL, -- FK added later in this migration
|
||||
user_id TEXT NOT NULL,
|
||||
question_index INTEGER NOT NULL,
|
||||
feedback_text TEXT,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX feedback_conv_idx ON feedback (conversation_id);")
|
||||
|
||||
# Append-only debug/error log. The Mongo doc has both `_id` (auto) and an
|
||||
# `id` field (the activity id). Here the serial PK owns `id`; the
|
||||
# application-level identifier is renamed to `activity_id`.
|
||||
op.execute("""
|
||||
CREATE TABLE stack_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
activity_id TEXT NOT NULL,
|
||||
endpoint TEXT,
|
||||
level TEXT,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
query TEXT,
|
||||
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX stack_logs_timestamp_idx ON stack_logs (timestamp DESC);")
|
||||
op.execute("CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, timestamp DESC);")
|
||||
op.execute("CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, timestamp DESC);")
|
||||
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tier 2: FK-bearing tables
|
||||
# ------------------------------------------------------------------
|
||||
op.execute("""
|
||||
CREATE TABLE agent_folders (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE sources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT, -- NULL for system/template sources
|
||||
name TEXT NOT NULL,
|
||||
type TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE agents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
agent_type TEXT,
|
||||
status TEXT NOT NULL,
|
||||
key CITEXT UNIQUE,
|
||||
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||
chunks INTEGER,
|
||||
retriever TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
json_schema JSONB,
|
||||
models JSONB,
|
||||
default_model_id TEXT,
|
||||
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
token_limit INTEGER,
|
||||
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
request_limit INTEGER,
|
||||
shared BOOLEAN NOT NULL DEFAULT false,
|
||||
incoming_webhook_token CITEXT UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_used_at TIMESTAMPTZ
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
|
||||
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
|
||||
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
|
||||
|
||||
# Backfill the token_usage.agent_id FK now that agents exists.
|
||||
op.execute("""
|
||||
ALTER TABLE token_usage
|
||||
ADD CONSTRAINT token_usage_agent_fk
|
||||
FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE attachments (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
upload_path TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
size BIGINT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE UNIQUE INDEX memories_user_tool_path_uidx
|
||||
ON memories (user_id, tool_id, path);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE INDEX memories_path_prefix_idx
|
||||
ON memories (user_id, tool_id, path text_pattern_ops);
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE todos (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE notes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX notes_user_tool_idx ON notes (user_id, tool_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE connector_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
session_data JSONB NOT NULL,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE INDEX connector_sessions_user_provider_idx
|
||||
ON connector_sessions (user_id, provider);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE INDEX connector_sessions_expiry_idx
|
||||
ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;
|
||||
""")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tier 3: conversations, pending_tool_state, workflows
|
||||
# ------------------------------------------------------------------
|
||||
op.execute("""
|
||||
CREATE TABLE conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
|
||||
name TEXT,
|
||||
api_key TEXT,
|
||||
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
|
||||
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
|
||||
op.execute("""
|
||||
CREATE INDEX conversations_shared_token_idx
|
||||
ON conversations (shared_token) WHERE shared_token IS NOT NULL;
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE conversation_messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
position INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
response TEXT,
|
||||
thought TEXT,
|
||||
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
attachments UUID[] NOT NULL DEFAULT '{}',
|
||||
model_id TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
feedback JSONB,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx
|
||||
ON conversation_messages (conversation_id, position);
|
||||
""")
|
||||
|
||||
# Backfill the feedback.conversation_id FK now that conversations exists.
|
||||
op.execute("""
|
||||
ALTER TABLE feedback
|
||||
ADD CONSTRAINT feedback_conv_fk
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE;
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE shared_conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
chunks INTEGER,
|
||||
is_promptable BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
|
||||
|
||||
# Paused-tool continuation state. The Mongo version relies on a TTL index;
|
||||
# Postgres has no native TTL, so a Celery beat task (added in Phase 3)
|
||||
# deletes rows where expires_at < now() once a minute. The unique
|
||||
# constraint on (conversation_id, user_id) matches the existing upsert
|
||||
# semantics.
|
||||
op.execute("""
|
||||
CREATE TABLE pending_tool_state (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
messages JSONB NOT NULL,
|
||||
pending_tool_calls JSONB NOT NULL,
|
||||
tools_dict JSONB NOT NULL,
|
||||
tool_schemas JSONB NOT NULL,
|
||||
agent_config JSONB NOT NULL,
|
||||
client_tools JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx
|
||||
ON pending_tool_state (conversation_id, user_id);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE INDEX pending_tool_state_expires_idx
|
||||
ON pending_tool_state (expires_at);
|
||||
""")
|
||||
|
||||
# Workflows
|
||||
op.execute("""
|
||||
CREATE TABLE workflows (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE workflow_nodes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
node_type TEXT NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb
|
||||
);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE INDEX workflow_nodes_workflow_version_idx
|
||||
ON workflow_nodes (workflow_id, graph_version);
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE workflow_edges (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
from_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
|
||||
to_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb
|
||||
);
|
||||
""")
|
||||
op.execute("""
|
||||
CREATE INDEX workflow_edges_workflow_version_idx
|
||||
ON workflow_edges (workflow_id, graph_version);
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
CREATE TABLE workflow_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
ended_at TIMESTAMPTZ,
|
||||
result JSONB
|
||||
);
|
||||
""")
|
||||
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
|
||||
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverse dependency order. CASCADE would handle FKs anyway, but explicit
|
||||
# is clearer for anyone reading the migration.
|
||||
op.execute("DROP TABLE IF EXISTS workflow_runs CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS workflow_edges CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS workflow_nodes CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS workflows CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS pending_tool_state CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS shared_conversations CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS conversation_messages CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS conversations CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS connector_sessions CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS notes CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS todos CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS memories CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS attachments CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS agents CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS sources CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS agent_folders CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS stack_logs CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS feedback CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS user_logs CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS token_usage CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS user_tools CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS prompts CASCADE;")
|
||||
op.execute("DROP TABLE IF EXISTS users CASCADE;")
|
||||
# Extensions are intentionally left in place — they may be shared with
|
||||
# pgvector or other extensions already enabled on the cluster.
|
||||
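The `pending_tool_state` comment in the migration above defers TTL-style cleanup to a Celery beat task planned for Phase 3. That task is not part of this comparison; a minimal sketch of what it could look like, with assumed module paths (`application.worker`, `application.storage.db.engine`) and assumed task and schedule names:

```python
# Sketch only: names and import paths are assumptions, not repository code.
from datetime import datetime, timezone

from sqlalchemy import text

from application.storage.db.engine import get_engine  # assumed helper
from application.worker import celery                  # assumed location of the Celery app

@celery.task(name="cleanup_pending_tool_state")
def cleanup_pending_tool_state():
    """Emulate a TTL index: delete paused-tool rows whose expires_at has passed."""
    with get_engine().begin() as conn:
        conn.execute(
            text("DELETE FROM pending_tool_state WHERE expires_at < :now"),
            {"now": datetime.now(timezone.utc)},
        )

# Scheduled once a minute via Celery beat, e.g.
# celery.conf.beat_schedule["cleanup-pending-tool-state"] = {
#     "task": "cleanup_pending_tool_state", "schedule": 60.0,
# }
```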
@@ -469,6 +469,18 @@ class BaseAnswerResource:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)

from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_logs import UserLogsRepository

dual_write(
UserLogsRepository,
lambda repo, d=log_data: repo.insert(
user_id=d.get("user"),
endpoint="stream_answer",
data=d,
),
)

data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
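`dual_write` is imported here and at every call site below, but its body is not part of this comparison. A minimal sketch of a helper consistent with those call sites (an assumption, not the repository's implementation): it only runs when the Postgres settings are configured, and it swallows failures so the MongoDB path stays authoritative.

```python
# Hypothetical sketch of application/storage/db/dual_write.py; the real
# implementation is not shown in this comparison.
import logging

from application.core.settings import settings

logger = logging.getLogger(__name__)


def dual_write(repository_cls, write_fn):
    """Best-effort secondary write to Postgres during the migration.

    ``write_fn`` receives a repository instance and performs the write.
    Errors are logged and ignored so MongoDB remains the source of truth.
    """
    if not (settings.USE_POSTGRES and settings.POSTGRES_URI):
        return
    try:
        write_fn(repository_cls())
    except Exception:
        logger.exception("dual_write via %s failed", repository_cls.__name__)
```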
@@ -23,6 +23,8 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
@@ -1250,6 +1252,9 @@ class PinnedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_pinned_bulk(uid, ids)
)
list_pinned_agents = [
{
"id": str(agent["_id"]),
@@ -1381,12 +1386,18 @@ class PinAgent(Resource):
{"user_id": user_id},
{"$pull": {"agent_preferences.pinned": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.remove_pinned(uid, aid)
)
action = "unpinned"
else:
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.pinned": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_pinned(uid, aid)
)
action = "pinned"
except Exception as err:
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
@@ -1432,6 +1443,9 @@ class RemoveSharedAgent(Resource):
}
},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.remove_agent_from_all(uid, aid)
)

return make_response(jsonify({"success": True, "action": "removed"}), 200)
except Exception as err:
@@ -18,6 +18,8 @@ from application.api.user.base import (
user_tools_collection,
users_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.utils import generate_image_url

agents_sharing_ns = Namespace(
@@ -105,6 +107,9 @@ class SharedAgent(Resource):
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
@@ -139,6 +144,9 @@ class SharedAgents(Resource):
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
dual_write(UsersRepository,
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))

list_shared_agents = [
@@ -15,6 +15,8 @@ from werkzeug.utils import secure_filename

from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator

@@ -132,6 +134,9 @@ def ensure_user_doc(user_id):
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})

dual_write(UsersRepository, lambda repo: repo.upsert(user_id))

return user_doc
@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource

from application.api import api
from application.api.user.base import current_dir, prompts_collection
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.prompts import PromptsRepository
from application.utils import check_required_fields

prompts_ns = Namespace(
@@ -49,6 +51,10 @@ class CreatePrompt(Resource):
}
)
new_id = str(resp.inserted_id)
dual_write(
PromptsRepository,
lambda repo, u=user, n=data["name"], c=data["content"]: repo.create(u, n, c),
)
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -149,6 +155,10 @@ class DeletePrompt(Resource):
return missing_fields
try:
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user: repo.delete(pid, u),
)
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -185,6 +195,10 @@ class UpdatePrompt(Resource):
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update(pid, u, n, c),
)
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -9,6 +9,8 @@ from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name

@@ -294,6 +296,13 @@ class CreateTool(Resource):
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
dual_write(
UserToolsRepository,
lambda repo, u=user, t=new_tool: repo.create(
u, t["name"], config=t.get("config"),
custom_name=t.get("customName"), display_name=t.get("displayName"),
),
)
except Exception as err:
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -581,6 +590,10 @@ class DeleteTool(Resource):
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
dual_write(
UserToolsRepository,
lambda repo, tid=data["id"], u=user: repo.delete(tid, u),
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
@@ -1,6 +1,6 @@
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging
from celery.signals import setup_logging, worker_process_init


def make_celery(app_name=__name__):
@@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs):
setup_logging()


@worker_process_init.connect
def _dispose_db_engine_on_fork(*args, **kwargs):
"""Dispose the SQLAlchemy engine pool in each forked Celery worker.

SQLAlchemy connection pools are not fork-safe: file descriptors shared
between the parent and a forked worker will corrupt the pool. Disposing
on ``worker_process_init`` gives every worker its own fresh pool on
first use.

Imported lazily so Celery workers that don't touch Postgres (or where
``POSTGRES_URI`` is unset) don't fail at startup.
"""
try:
from application.storage.db.engine import dispose_engine
except Exception:
return
dispose_engine()


celery = make_celery()
celery.config_from_object("application.celeryconfig")
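The handler above imports `dispose_engine` from `application.storage.db.engine`, which is not included in this comparison. A sketch of the shape such a module could take, with assumed names; only the fork-time `dispose_engine()` call is established by the diff.

```python
# Assumed sketch of application/storage/db/engine.py, not repository code.
from typing import Optional

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

from application.core.settings import settings

_engine: Optional[Engine] = None


def get_engine() -> Engine:
    """Create the engine lazily on first use and reuse it afterwards."""
    global _engine
    if _engine is None:
        _engine = create_engine(settings.POSTGRES_URI, pool_pre_ping=True, future=True)
    return _engine


def dispose_engine() -> None:
    """Drop pooled connections, e.g. right after a Celery worker forks."""
    global _engine
    if _engine is not None:
        _engine.dispose()
        _engine = None
```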
89
application/core/db_uri.py
Normal file
89
application/core/db_uri.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Normalize user-supplied Postgres URIs for different drivers.
|
||||
|
||||
DocsGPT has two Postgres connection strings pointing at potentially
|
||||
different databases:
|
||||
|
||||
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
|
||||
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
|
||||
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
|
||||
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
|
||||
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
|
||||
dialect prefix is an invalid URI from its point of view.
|
||||
|
||||
The two fields therefore need opposite normalization so operators don't
|
||||
have to know which driver a given field feeds. Each normalizer also
|
||||
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
|
||||
psycopg2 is no longer in the project.
|
||||
|
||||
This module is deliberately separate from ``application/core/settings.py``
|
||||
so the Settings class stays focused on field declarations, and the
|
||||
URI-rewriting logic can be unit-tested without triggering ``.env``
|
||||
file loading from importing Settings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _rewrite_uri_prefixes(v, rewrites):
|
||||
"""Shared URI prefix rewriter used by both normalizers below.
|
||||
|
||||
Strips whitespace, returns ``None`` for empty / ``"none"`` values,
|
||||
applies the first matching rewrite, and passes unrecognised input
|
||||
through so downstream consumers (SQLAlchemy, libpq) can produce
|
||||
their own error messages rather than us silently eating a
|
||||
misconfiguration.
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
v = v.strip()
|
||||
if not v or v.lower() == "none":
|
||||
return None
|
||||
for prefix, target in rewrites:
|
||||
if v.startswith(prefix):
|
||||
return target + v[len(prefix):]
|
||||
return v
|
||||
|
||||
|
||||
# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
|
||||
# dialect prefix to select the psycopg v3 driver. Normalize the
|
||||
# operator-friendly forms TOWARD that dialect.
|
||||
_POSTGRES_URI_REWRITES = (
|
||||
("postgresql+psycopg2://", "postgresql+psycopg://"),
|
||||
("postgresql://", "postgresql+psycopg://"),
|
||||
("postgres://", "postgresql+psycopg://"),
|
||||
)
|
||||
|
||||
|
||||
# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
|
||||
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
|
||||
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
|
||||
# dialect prefix is an invalid URI from libpq's point of view. Strip it
|
||||
# if the operator accidentally copied their POSTGRES_URI value here.
|
||||
_PGVECTOR_CONNECTION_STRING_REWRITES = (
|
||||
("postgresql+psycopg2://", "postgresql://"),
|
||||
("postgresql+psycopg://", "postgresql://"),
|
||||
)
|
||||
|
||||
|
||||
def normalize_postgres_uri(v):
|
||||
"""Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.
|
||||
|
||||
Accepts the forms operators naturally write (``postgres://``,
|
||||
``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
|
||||
Unknown schemes pass through unchanged so SQLAlchemy can produce its
|
||||
own dialect-not-found error.
|
||||
"""
|
||||
return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)
|
||||
|
||||
|
||||
def normalize_pgvector_connection_string(v):
|
||||
"""Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.
|
||||
|
||||
Strips the SQLAlchemy dialect prefix if the operator accidentally
|
||||
copied their POSTGRES_URI value here — libpq can't parse it.
|
||||
User-friendly forms (``postgres://``, ``postgresql://``) pass
|
||||
through unchanged since libpq accepts them natively.
|
||||
"""
|
||||
return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)
|
||||
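The two normalizers in db_uri.py behave as mirror images, which a quick example makes concrete (the expected values follow directly from the rewrite tables above):

```python
from application.core.db_uri import (
    normalize_pgvector_connection_string,
    normalize_postgres_uri,
)

# SQLAlchemy side: operator-friendly schemes gain the psycopg 3 dialect prefix.
assert (
    normalize_postgres_uri("postgres://docsgpt:docsgpt@localhost:5432/docsgpt")
    == "postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt"
)

# libpq side: an accidentally copied dialect prefix is stripped back off.
assert (
    normalize_pgvector_connection_string("postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt")
    == "postgresql://docsgpt:docsgpt@localhost:5432/docsgpt"
)
```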
@@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


from application.core.db_uri import (  # noqa: E402
normalize_pgvector_connection_string,
normalize_postgres_uri,
)


class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")

@@ -22,6 +28,11 @@ class Settings(BaseSettings):
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
# User-data Postgres DB.
POSTGRES_URI: Optional[str] = None

# MongoDB→Postgres migration: dual-write to Postgres (Mongo stays source of truth)
USE_POSTGRES: bool = False
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
DEFAULT_LLM_TOKEN_LIMIT: int = 128000  # Fallback when model not found in registry
@@ -59,6 +70,10 @@ class Settings(BaseSettings):
MICROSOFT_TENANT_ID: Optional[str] = "common"  # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None  # e.g., "https://login.microsoftonline.com/{tenant_id}"

# Confluence Cloud integration
CONFLUENCE_CLIENT_ID: Optional[str] = None
CONFLUENCE_CLIENT_SECRET: Optional[str] = None

# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access

@@ -117,7 +132,10 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"

# PGVector vectorstore config
# PGVector vectorstore config. Write the URI in whichever form you
# prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
# dialect form (``postgresql+psycopg://``) are all accepted and
# normalized internally for ``psycopg.connect()``.
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
@@ -156,6 +174,16 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0"  # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3  # Keep only last N compression points to prevent DB bloat

@field_validator("POSTGRES_URI", mode="before")
@classmethod
def _normalize_postgres_uri_validator(cls, v):
return normalize_postgres_uri(v)

@field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
@classmethod
def _normalize_pgvector_connection_string_validator(cls, v):
return normalize_pgvector_connection_string(v)

@field_validator(
"API_KEY",
"OPENAI_API_KEY",
@@ -157,5 +157,21 @@ def _log_to_mongodb(
user_logs_collection.insert_one(log_entry)
logging.debug(f"Logged activity to MongoDB: {activity_id}")

from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.stack_logs import StackLogsRepository

dual_write(
StackLogsRepository,
lambda repo, e=log_entry: repo.insert(
activity_id=e["id"],
endpoint=e.get("endpoint"),
level=e.get("level"),
user_id=e.get("user"),
api_key=e.get("api_key"),
query=e.get("query"),
stacks=e.get("stacks"),
),
)

except Exception as e:
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)
application/parser/connectors/confluence/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
from .auth import ConfluenceAuth
from .loader import ConfluenceLoader

__all__ = ["ConfluenceAuth", "ConfluenceLoader"]
216
application/parser/connectors/confluence/auth.py
Normal file
216
application/parser/connectors/confluence/auth.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfluenceAuth(BaseConnectorAuth):
|
||||
|
||||
SCOPES = [
|
||||
"read:page:confluence",
|
||||
"read:space:confluence",
|
||||
"read:attachment:confluence",
|
||||
"read:me",
|
||||
"offline_access",
|
||||
]
|
||||
|
||||
AUTH_URL = "https://auth.atlassian.com/authorize"
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
|
||||
ME_URL = "https://api.atlassian.com/me"
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.CONFLUENCE_CLIENT_ID
|
||||
self.client_secret = settings.CONFLUENCE_CLIENT_SECRET
|
||||
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError(
|
||||
"Confluence OAuth credentials not configured. "
|
||||
"Please set CONFLUENCE_CLIENT_ID and CONFLUENCE_CLIENT_SECRET in settings."
|
||||
)
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
params = {
|
||||
"audience": "api.atlassian.com",
|
||||
"client_id": self.client_id,
|
||||
"scope": " ".join(self.SCOPES),
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": state,
|
||||
"response_type": "code",
|
||||
"prompt": "consent",
|
||||
}
|
||||
return f"{self.AUTH_URL}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
if not authorization_code:
|
||||
raise ValueError("Authorization code is required")
|
||||
|
||||
response = requests.post(
|
||||
self.TOKEN_URL,
|
||||
json={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": authorization_code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError("OAuth flow did not return an access token")
|
||||
|
||||
refresh_token = token_data.get("refresh_token")
|
||||
if not refresh_token:
|
||||
raise ValueError("OAuth flow did not return a refresh token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
expiry = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expires_in)
|
||||
).isoformat()
|
||||
|
||||
cloud_id = self._fetch_cloud_id(access_token)
|
||||
user_info = self._fetch_user_info(access_token)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_uri": self.TOKEN_URL,
|
||||
"scopes": self.SCOPES,
|
||||
"expiry": expiry,
|
||||
"cloud_id": cloud_id,
|
||||
"user_info": {
|
||||
"name": user_info.get("display_name", ""),
|
||||
"email": user_info.get("email", ""),
|
||||
},
|
||||
}
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
if not refresh_token:
|
||||
raise ValueError("Refresh token is required")
|
||||
|
||||
response = requests.post(
|
||||
self.TOKEN_URL,
|
||||
json={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
access_token = token_data.get("access_token")
|
||||
new_refresh_token = token_data.get("refresh_token", refresh_token)
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
expiry = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expires_in)
|
||||
).isoformat()
|
||||
|
||||
cloud_id = self._fetch_cloud_id(access_token)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
"token_uri": self.TOKEN_URL,
|
||||
"scopes": self.SCOPES,
|
||||
"expiry": expiry,
|
||||
"cloud_id": cloud_id,
|
||||
}
|
||||
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
if not token_info:
|
||||
return True
|
||||
|
||||
expiry = token_info.get("expiry")
|
||||
if not expiry:
|
||||
return bool(token_info.get("access_token"))
|
||||
|
||||
try:
|
||||
expiry_dt = datetime.datetime.fromisoformat(expiry)
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
return now >= expiry_dt - datetime.timedelta(seconds=60)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings as app_settings
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[app_settings.MONGO_DB_NAME]
|
||||
|
||||
session = db["connector_sessions"].find_one({"session_token": session_token})
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
|
||||
token_info = session.get("token_info")
|
||||
if not token_info:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
required = ["access_token", "refresh_token", "cloud_id"]
|
||||
missing = [f for f in required if not token_info.get(f)]
|
||||
if missing:
|
||||
raise ValueError(f"Missing required token fields: {missing}")
|
||||
|
||||
return token_info
|
||||
|
||||
def sanitize_token_info(
|
||||
self, token_info: Dict[str, Any], **extra_fields
|
||||
) -> Dict[str, Any]:
|
||||
return super().sanitize_token_info(
|
||||
token_info,
|
||||
cloud_id=token_info.get("cloud_id"),
|
||||
**extra_fields,
|
||||
)
|
||||
|
||||
def _fetch_cloud_id(self, access_token: str) -> str:
|
||||
response = requests.get(
|
||||
self.RESOURCES_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
resources = response.json()
|
||||
|
||||
if not resources:
|
||||
raise ValueError("No accessible Confluence sites found for this account")
|
||||
|
||||
return resources[0]["id"]
|
||||
|
||||
def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
response = requests.get(
|
||||
self.ME_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch user info: %s", e)
|
||||
return {}
|
||||
416
application/parser/connectors/confluence/loader.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_V2 = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki/api/v2"
|
||||
DOWNLOAD_BASE = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki"
|
||||
|
||||
SUPPORTED_ATTACHMENT_TYPES = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"text/html": ".html",
|
||||
"text/markdown": ".md",
|
||||
"application/json": ".json",
|
||||
"application/epub+zip": ".epub",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
}
|
||||
|
||||
|
||||
def _retry_on_auth_failure(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response is not None and e.response.status_code in (401, 403):
|
||||
logger.info(
|
||||
"Auth failure in %s, refreshing token and retrying", func.__name__
|
||||
)
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info["access_token"]
|
||||
self.refresh_token = new_token_info.get(
|
||||
"refresh_token", self.refresh_token
|
||||
)
|
||||
self._persist_refreshed_tokens(new_token_info)
|
||||
except Exception as refresh_err:
|
||||
raise ValueError(
|
||||
f"Authentication failed and could not be refreshed: {refresh_err}"
|
||||
) from e
|
||||
return func(self, *args, **kwargs)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ConfluenceLoader(BaseConnectorLoader):
|
||||
|
||||
def __init__(self, session_token: str):
|
||||
self.auth = ConfluenceAuth()
|
||||
self.session_token = session_token
|
||||
|
||||
token_info = self.auth.get_token_info_from_session(session_token)
|
||||
self.access_token = token_info["access_token"]
|
||||
self.refresh_token = token_info["refresh_token"]
|
||||
self.cloud_id = token_info["cloud_id"]
|
||||
|
||||
self.base_url = API_V2.format(cloud_id=self.cloud_id)
|
||||
self.download_base = DOWNLOAD_BASE.format(cloud_id=self.cloud_id)
|
||||
self.next_page_token = None
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def _persist_refreshed_tokens(self, token_info: Dict[str, Any]) -> None:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings as app_settings
|
||||
|
||||
sanitized = self.auth.sanitize_token_info(token_info)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[app_settings.MONGO_DB_NAME]
|
||||
db["connector_sessions"].update_one(
|
||||
{"session_token": self.session_token},
|
||||
{"$set": {"token_info": sanitized}},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist refreshed tokens: %s", e)
|
||||
|
||||
@_retry_on_auth_failure
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
folder_id = inputs.get("folder_id")
|
||||
file_ids = inputs.get("file_ids", [])
|
||||
limit = inputs.get("limit", 100)
|
||||
list_only = inputs.get("list_only", False)
|
||||
page_token = inputs.get("page_token")
|
||||
search_query = inputs.get("search_query")
|
||||
self.next_page_token = None
|
||||
|
||||
if file_ids:
|
||||
return self._load_pages_by_ids(file_ids, list_only, search_query)
|
||||
|
||||
if folder_id:
|
||||
return self._list_pages_in_space(
|
||||
folder_id, limit, list_only, page_token, search_query
|
||||
)
|
||||
|
||||
return self._list_spaces(limit, page_token, search_query)
|
||||
|
||||
@_retry_on_auth_failure
|
||||
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||
config = source_config or getattr(self, "config", {})
|
||||
file_ids = config.get("file_ids", [])
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
files_downloaded = 0
|
||||
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [file_ids]
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [folder_ids]
|
||||
|
||||
for page_id in file_ids:
|
||||
if self._download_page(page_id, local_dir):
|
||||
files_downloaded += 1
|
||||
files_downloaded += self._download_page_attachments(page_id, local_dir)
|
||||
|
||||
for space_id in folder_ids:
|
||||
files_downloaded += self._download_space(space_id, local_dir)
|
||||
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": files_downloaded == 0,
|
||||
"source_type": "confluence",
|
||||
"config_used": config,
|
||||
}
|
||||
|
||||
def _list_spaces(
|
||||
self, limit: int, cursor: Optional[str], search_query: Optional[str]
|
||||
) -> List[Document]:
|
||||
documents: List[Document] = []
|
||||
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/spaces",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
for space in data.get("results", []):
|
||||
name = space.get("name", "")
|
||||
if search_query and search_query.lower() not in name.lower():
|
||||
continue
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
text="",
|
||||
doc_id=space["id"],
|
||||
extra_info={
|
||||
"file_name": name,
|
||||
"mime_type": "folder",
|
||||
"size": None,
|
||||
"created_time": space.get("createdAt"),
|
||||
"modified_time": None,
|
||||
"source": "confluence",
|
||||
"is_folder": True,
|
||||
"space_key": space.get("key"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
self.next_page_token = self._extract_cursor(next_link)
|
||||
return documents
|
||||
|
||||
def _list_pages_in_space(
|
||||
self,
|
||||
space_id: str,
|
||||
limit: int,
|
||||
list_only: bool,
|
||||
cursor: Optional[str],
|
||||
search_query: Optional[str],
|
||||
) -> List[Document]:
|
||||
documents: List[Document] = []
|
||||
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/spaces/{space_id}/pages",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
for page in data.get("results", []):
|
||||
title = page.get("title", "")
|
||||
if search_query and search_query.lower() not in title.lower():
|
||||
continue
|
||||
|
||||
doc = self._page_to_document(
|
||||
page, load_content=not list_only, space_id=space_id
|
||||
)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
self.next_page_token = self._extract_cursor(next_link)
|
||||
return documents
|
||||
|
||||
def _load_pages_by_ids(
|
||||
self, page_ids: List[str], list_only: bool, search_query: Optional[str]
|
||||
) -> List[Document]:
|
||||
documents: List[Document] = []
|
||||
for page_id in page_ids:
|
||||
try:
|
||||
params: Dict[str, str] = {}
|
||||
if not list_only:
|
||||
params["body-format"] = "storage"
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/pages/{page_id}",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
page = response.json()
|
||||
|
||||
title = page.get("title", "")
|
||||
if search_query and search_query.lower() not in title.lower():
|
||||
continue
|
||||
|
||||
doc = self._page_to_document(page, load_content=not list_only)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logger.error("Error loading page %s: %s", page_id, e)
|
||||
return documents
|
||||
|
||||
def _page_to_document(
|
||||
self,
|
||||
page: Dict[str, Any],
|
||||
load_content: bool = False,
|
||||
space_id: Optional[str] = None,
|
||||
) -> Optional[Document]:
|
||||
page_id = page.get("id")
|
||||
title = page.get("title", "Unknown")
|
||||
version = page.get("version", {})
|
||||
modified_time = version.get("createdAt") if isinstance(version, dict) else None
|
||||
created_time = page.get("createdAt")
|
||||
resolved_space_id = space_id or page.get("spaceId")
|
||||
|
||||
text = ""
|
||||
if load_content:
|
||||
body = page.get("body", {})
|
||||
storage = body.get("storage", {}) if isinstance(body, dict) else {}
|
||||
text = storage.get("value", "") if isinstance(storage, dict) else ""
|
||||
|
||||
return Document(
|
||||
text=text,
|
||||
doc_id=str(page_id),
|
||||
extra_info={
|
||||
"file_name": title,
|
||||
"mime_type": "text/html",
|
||||
"size": len(text) if text else None,
|
||||
"created_time": created_time,
|
||||
"modified_time": modified_time,
|
||||
"source": "confluence",
|
||||
"is_folder": False,
|
||||
"page_id": str(page_id),
|
||||
"space_id": resolved_space_id,
|
||||
"cloud_id": self.cloud_id,
|
||||
},
|
||||
)
|
||||
|
||||
def _download_page(self, page_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/pages/{page_id}",
|
||||
headers=self._headers(),
|
||||
params={"body-format": "storage"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
page = response.json()
|
||||
|
||||
title = page.get("title", page_id)
|
||||
safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in title)
|
||||
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||
|
||||
file_path = os.path.join(local_dir, f"{safe_name}.html")
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(body)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error downloading page %s: %s", page_id, e)
|
||||
return False
|
||||
|
||||
def _download_page_attachments(self, page_id: str, local_dir: str) -> int:
|
||||
downloaded = 0
|
||||
try:
|
||||
cursor = None
|
||||
while True:
|
||||
params: Dict[str, Any] = {"limit": 100}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/pages/{page_id}/attachments",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
for att in data.get("results", []):
|
||||
media_type = att.get("mediaType", "")
|
||||
if media_type not in SUPPORTED_ATTACHMENT_TYPES:
|
||||
continue
|
||||
|
||||
download_link = att.get("_links", {}).get("download")
|
||||
if not download_link:
|
||||
continue
|
||||
|
||||
raw_name = att.get("title", att.get("id", "attachment"))
|
||||
file_name = "".join(
|
||||
c if c.isalnum() or c in " -_." else "_"
|
||||
for c in os.path.basename(raw_name)
|
||||
) or "attachment"
|
||||
file_path = os.path.join(local_dir, file_name)
|
||||
|
||||
url = f"{self.download_base}{download_link}"
|
||||
file_resp = requests.get(
|
||||
url, headers=self._headers(), timeout=60, stream=True
|
||||
)
|
||||
file_resp.raise_for_status()
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in file_resp.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
downloaded += 1
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
cursor = self._extract_cursor(next_link)
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error downloading attachments for page %s: %s", page_id, e)
|
||||
return downloaded
|
||||
|
||||
def _download_space(self, space_id: str, local_dir: str) -> int:
|
||||
downloaded = 0
|
||||
cursor = None
|
||||
while True:
|
||||
params: Dict[str, Any] = {"limit": 250}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/spaces/{space_id}/pages",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except Exception as e:
|
||||
logger.error("Error listing pages in space %s: %s", space_id, e)
|
||||
break
|
||||
|
||||
for page in data.get("results", []):
|
||||
page_id = page.get("id")
|
||||
if self._download_page(str(page_id), local_dir):
|
||||
downloaded += 1
|
||||
downloaded += self._download_page_attachments(str(page_id), local_dir)
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
cursor = self._extract_cursor(next_link)
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
return downloaded
|
||||
|
||||
@staticmethod
|
||||
def _extract_cursor(next_link: Optional[str]) -> Optional[str]:
|
||||
if not next_link:
|
||||
return None
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
parsed = urlparse(next_link)
|
||||
cursors = parse_qs(parsed.query).get("cursor")
|
||||
return cursors[0] if cursors else None
|
||||
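A rough usage sketch for the cursor pagination above: the loader stores the next cursor on ``next_page_token`` after each ``load_data`` call. The loop below is illustrative, not code from this diff.

# Hypothetical: page through every space visible to the connected account.
loader = ConfluenceLoader(session_token)  # session_token assumed to exist
spaces, cursor = [], None
while True:
    spaces.extend(loader.load_data({"limit": 100, "page_token": cursor, "list_only": True}))
    cursor = loader.next_page_token
    if not cursor:
        break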
@@ -1,5 +1,7 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||
from application.parser.connectors.confluence.loader import ConfluenceLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.connectors.share_point.loader import SharePointLoader
|
||||
|
||||
@@ -13,11 +15,13 @@ class ConnectorCreator:
|
||||
"""
|
||||
|
||||
connectors = {
|
||||
"confluence": ConfluenceLoader,
|
||||
"google_drive": GoogleDriveLoader,
|
||||
"share_point": SharePointLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"confluence": ConfluenceAuth,
|
||||
"google_drive": GoogleDriveAuth,
|
||||
"share_point": SharePointAuth,
|
||||
}
|
||||
|
||||
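With the registry change above, a new provider only needs these two entries. A hedged sketch of how the mapping might be consumed (the exact factory method on ConnectorCreator is not shown in this diff):

# Illustrative lookup only; the real factory method on ConnectorCreator may differ.
loader_cls = ConnectorCreator.connectors["confluence"]
loader = loader_cls(session_token)  # session_token assumed to exist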
@@ -205,7 +205,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
try:
|
||||
url = self._get_item_url(file_id)
|
||||
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
file_metadata = response.json()
|
||||
@@ -236,9 +236,9 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
|
||||
else:
|
||||
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
|
||||
response = requests.get(search_url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(search_url, headers=self._get_headers(), params=params, timeout=100)
|
||||
else:
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -307,7 +307,8 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
response = requests.get(
|
||||
f"{self.GRAPH_API_BASE}/me/drive",
|
||||
headers=self._get_headers(),
|
||||
params={'$select': 'webUrl'}
|
||||
params={'$select': 'webUrl'},
|
||||
timeout=100,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get('webUrl')
|
||||
@@ -352,7 +353,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
|
||||
headers = self._get_headers()
|
||||
headers["Content-Type"] = "application/json"
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response = requests.post(url, headers=headers, json=body, timeout=100)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
@@ -472,7 +473,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
|
||||
try:
|
||||
url = f"{self._get_item_url(file_id)}/content"
|
||||
response = requests.get(url, headers=self._get_headers())
|
||||
response = requests.get(url, headers=self._get_headers(), timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
@@ -491,7 +492,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
try:
|
||||
url = self._get_item_url(file_id)
|
||||
params = {'$select': 'id,name,file'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
metadata = response.json()
|
||||
@@ -507,7 +508,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
full_path = os.path.join(local_dir, file_name)
|
||||
|
||||
download_url = f"{self._get_item_url(file_id)}/content"
|
||||
download_response = requests.get(download_url, headers=self._get_headers())
|
||||
download_response = requests.get(download_url, headers=self._get_headers(), timeout=100)
|
||||
download_response.raise_for_status()
|
||||
|
||||
with open(full_path, 'wb') as f:
|
||||
@@ -527,7 +528,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
params = {'$top': 1000}
|
||||
|
||||
while url:
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
@@ -609,7 +610,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
try:
|
||||
url = self._get_item_url(folder_id)
|
||||
params = {'$select': 'id,name'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
folder_metadata = response.json()
|
||||
|
||||
@@ -24,7 +24,7 @@ class PDFParser(BaseParser):
|
||||
# alternatively you can use local vision capable LLM
|
||||
with open(file, "rb") as file_loaded:
|
||||
files = {'file': file_loaded}
|
||||
response = requests.post(doc2md_service, files=files)
|
||||
response = requests.post(doc2md_service, files=files, timeout=100)
|
||||
data = response.json()["markdown"]
|
||||
return data
|
||||
|
||||
|
||||
@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
|
||||
def parse_file(self, file: Path, errors: str = "ignore") -> str:
|
||||
"""Parse file."""
|
||||
try:
|
||||
import ebooklib
|
||||
from ebooklib import epub
|
||||
from fast_ebook import epub
|
||||
except ImportError:
|
||||
raise ValueError("`EbookLib` is required to read Epub files.")
|
||||
try:
|
||||
import html2text
|
||||
except ImportError:
|
||||
raise ValueError("`html2text` is required to parse Epub files.")
|
||||
raise ValueError("`fast-ebook` is required to read Epub files.")
|
||||
|
||||
text_list = []
|
||||
book = epub.read_epub(file, options={"ignore_ncx": True})
|
||||
|
||||
# Iterate through all chapters.
|
||||
for item in book.get_items():
|
||||
# Chapters are typically located in epub documents items.
|
||||
if item.get_type() == ebooklib.ITEM_DOCUMENT:
|
||||
text_list.append(
|
||||
html2text.html2text(item.get_content().decode("utf-8"))
|
||||
)
|
||||
|
||||
text = "\n".join(text_list)
|
||||
book = epub.read_epub(file)
|
||||
text = book.to_markdown()
|
||||
return text
|
||||
|
||||
@@ -24,7 +24,7 @@ class ImageParser(BaseParser):
|
||||
# alternatively you can use local vision capable LLM
|
||||
with open(file, "rb") as file_loaded:
|
||||
files = {'file': file_loaded}
|
||||
response = requests.post(doc2md_service, files=files)
|
||||
response = requests.post(doc2md_service, files=files, timeout=100)
|
||||
data = response.json()["markdown"]
|
||||
else:
|
||||
data = ""
|
||||
|
||||
@@ -77,7 +77,7 @@ class GitHubLoader(BaseRemote):
|
||||
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
for attempt in range(max_retries):
|
||||
response = requests.get(url, headers=self.headers)
|
||||
response = requests.get(url, headers=self.headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
alembic>=1.13,<2
|
||||
anthropic==0.88.0
|
||||
boto3==1.42.83
|
||||
beautifulsoup4==4.14.3
|
||||
@@ -11,7 +12,7 @@ rapidocr>=1.4.0
|
||||
onnxruntime>=1.19.0
|
||||
docx2txt==0.9
|
||||
ddgs>=8.0.0
|
||||
ebooklib==0.20
|
||||
fast-ebook
|
||||
elevenlabs==2.41.0
|
||||
Flask==3.1.3
|
||||
faiss-cpu==1.13.2
|
||||
@@ -23,7 +24,6 @@ google-auth-httplib2==0.3.1
|
||||
google-auth-oauthlib==1.3.1
|
||||
gTTS==2.5.4
|
||||
gunicorn==25.3.0
|
||||
html2text==2025.4.15
|
||||
jinja2==3.1.6
|
||||
jiter==0.13.0
|
||||
jmespath==1.1.0
|
||||
@@ -59,7 +59,7 @@ pillow
|
||||
portalocker>=2.7.0,<4.0.0
|
||||
prompt-toolkit==3.0.52
|
||||
protobuf==7.34.1
|
||||
psycopg2-binary==2.9.11
|
||||
psycopg[binary,pool]>=3.1,<4
|
||||
py==1.11.0
|
||||
pydantic
|
||||
pydantic-core
|
||||
@@ -76,6 +76,7 @@ regex==2026.4.4
|
||||
requests==2.33.1
|
||||
retry==0.9.2
|
||||
sentence-transformers==5.3.0
|
||||
sqlalchemy>=2.0,<3
|
||||
tiktoken==0.12.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.11.0
|
||||
|
||||
10
application/storage/db/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""PostgreSQL storage layer for user-level data.
|
||||
|
||||
This package holds the SQLAlchemy Core engine, metadata, repositories, and
|
||||
migration infrastructure for the user-data Postgres database. It is separate
|
||||
from ``application/vectorstore/pgvector.py`` — the two may point at the same
|
||||
cluster or at different clusters depending on operator configuration.
|
||||
|
||||
Repository modules are added in later phases
|
||||
as individual collections are ported.
|
||||
"""
|
||||
39
application/storage/db/base_repository.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Common helpers shared by all repositories.
|
||||
|
||||
Repositories are thin wrappers around SQLAlchemy Core query construction.
|
||||
They take a ``Connection`` on call and return plain ``dict`` rows during the
|
||||
Mongo→Postgres cutover so that call sites don't have to change shape. Once
|
||||
cutover is complete, a follow-up phase may migrate repo return types to
|
||||
Pydantic DTOs (tracked in the migration plan as a post-migration item).
|
||||
"""
|
||||
|
||||
from typing import Any, Mapping
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def row_to_dict(row: Any) -> dict:
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
|
||||
|
||||
During the migration window, API responses and downstream code still
|
||||
expect a string ``_id`` field (matching the Mongo shape). This helper
|
||||
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
|
||||
existing serializers keep working unchanged.
|
||||
|
||||
Args:
|
||||
row: A SQLAlchemy ``Row`` object, or ``None``.
|
||||
|
||||
Returns:
|
||||
A plain dict, or an empty dict if ``row`` is ``None``.
|
||||
"""
|
||||
if row is None:
|
||||
return {}
|
||||
|
||||
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
|
||||
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
|
||||
out = dict(mapping)
|
||||
|
||||
if "id" in out and out["id"] is not None:
|
||||
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
|
||||
out["_id"] = out["id"]
|
||||
|
||||
return out
|
||||
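A small illustration of the Mongo-compatible shape ``row_to_dict`` produces, assuming an open connection and at least one existing row; the table choice is arbitrary:

from sqlalchemy import text

row = conn.execute(text("SELECT * FROM prompts LIMIT 1")).fetchone()  # conn: an open Connection
doc = row_to_dict(row)
assert doc == {} or doc["_id"] == doc["id"]  # both keys carry the stringified UUID when a row exists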
67
application/storage/db/dual_write.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres
|
||||
migration.
|
||||
|
||||
The helper:
|
||||
|
||||
* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off
|
||||
call sites add literally zero work.
|
||||
* Opens a transactional connection from the user-data SQLAlchemy engine.
|
||||
* Instantiates the caller's repository class on that connection.
|
||||
* Runs the caller's operation.
|
||||
* Swallows and logs any exception. **Mongo remains the source of truth
|
||||
during the dual-write window** — a Postgres-side failure must never
|
||||
break a user-facing request. Drift that builds up from swallowed
|
||||
failures is caught separately by re-running the backfill script.
|
||||
|
||||
Call sites look like::
|
||||
|
||||
users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged
|
||||
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror
|
||||
|
||||
A single parameterised helper rather than one function per collection
|
||||
means a new collection just needs its repository class — no new helper
|
||||
function, no new feature flag. The whole helper is deleted at Phase 5
|
||||
when the migration is complete.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_Repo = TypeVar("_Repo")
|
||||
|
||||
|
||||
def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None:
|
||||
"""Mirror a Mongo write into Postgres via ``repo_cls``, best-effort.
|
||||
|
||||
No-op when ``settings.USE_POSTGRES`` is false. Any exception
|
||||
(connection pool exhaustion, migration drift, SQL error) is logged
|
||||
and swallowed so the caller's primary Mongo write remains the source
|
||||
of truth.
|
||||
|
||||
Args:
|
||||
repo_cls: The repository class to instantiate (e.g. ``UsersRepository``).
|
||||
fn: A callable that takes the instantiated repository and performs
|
||||
the desired write.
|
||||
"""
|
||||
if not settings.USE_POSTGRES:
|
||||
return
|
||||
|
||||
try:
|
||||
# Lazy import so modules that import dual_write don't pay the
|
||||
# SQLAlchemy import cost when the flag is off.
|
||||
from application.storage.db.engine import get_engine
|
||||
|
||||
with get_engine().begin() as conn:
|
||||
fn(repo_cls(conn))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Postgres dual-write failed for %s — Mongo write already committed",
|
||||
repo_cls.__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
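A sketch of a call site, mirroring the pattern in the module docstring: the Mongo write stays exactly as it is today and the Postgres mirror rides alongside it. The filter document and field path below are assumptions, not code from this diff.

from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.users import UsersRepository

# Hypothetical route-handler fragment: pin an agent for a user.
users_collection.update_one(
    {"user_id": uid},                                   # assumed filter shape
    {"$addToSet": {"agent_preferences.pinned": aid}},   # Mongo write, unchanged
)
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid))  # best-effort Postgres mirror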
73
application/storage/db/engine.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""SQLAlchemy Core engine factory for the user-data Postgres database.
|
||||
|
||||
The engine is lazily constructed on first use and cached as a module-level
|
||||
singleton. Repositories and the Alembic env module both obtain connections
|
||||
through this factory, so pool tuning lives in one place.
|
||||
|
||||
``POSTGRES_URI`` can be written in any of the common Postgres URI forms::
|
||||
|
||||
postgres://user:pass@host:5432/docsgpt
|
||||
postgresql://user:pass@host:5432/docsgpt
|
||||
|
||||
Both are accepted and normalized internally to the psycopg3 dialect
|
||||
(``postgresql+psycopg://``) by ``application.core.settings``. Operators
|
||||
don't need to know about SQLAlchemy dialect prefixes.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Engine, create_engine
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
|
||||
def _resolve_uri() -> str:
|
||||
"""Return the Postgres URI for user-data tables.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that
|
||||
reach this path without a configured URI have a setup bug — the
|
||||
error message points them at the right setting.
|
||||
"""
|
||||
if not settings.POSTGRES_URI:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
return settings.POSTGRES_URI
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
|
||||
|
||||
Returns:
|
||||
A SQLAlchemy ``Engine`` configured with a pooled connection to
|
||||
Postgres via psycopg3.
|
||||
"""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = create_engine(
|
||||
_resolve_uri(),
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles
|
||||
pool_recycle=1800,
|
||||
future=True,
|
||||
)
|
||||
return _engine
|
||||
|
||||
|
||||
def dispose_engine() -> None:
|
||||
"""Dispose the pooled connections and reset the singleton.
|
||||
|
||||
Called from the Celery ``worker_process_init`` signal so each forked
|
||||
worker gets a fresh pool instead of sharing file descriptors with the
|
||||
parent process (which corrupts the pool on fork).
|
||||
"""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
_engine.dispose()
|
||||
_engine = None
|
||||
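The ``dispose_engine`` hook described above could be wired to Celery roughly like this; the module where the signal handler actually lives is not part of this diff.

from celery.signals import worker_process_init

from application.storage.db.engine import dispose_engine


@worker_process_init.connect
def reset_user_data_pool(**_kwargs):
    # Each forked worker drops the inherited pool and lazily rebuilds its own engine.
    dispose_engine()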
111
application/storage/db/models.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""SQLAlchemy Core metadata for the user-data Postgres database.
|
||||
|
||||
Tables are added here one at a time as repositories are built during the
|
||||
MongoDB→Postgres migration. The baseline schema in the Alembic migration
|
||||
(``application/alembic/versions/0001_initial.py``) is the source of truth
|
||||
for DDL; the ``Table`` definitions below must match it column-for-column.
|
||||
If the two drift, migrations win — update this file to match.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
MetaData,
|
||||
Table,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||
|
||||
users_table = Table(
|
||||
"users",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False, unique=True),
|
||||
Column(
|
||||
"agent_preferences",
|
||||
JSONB,
|
||||
nullable=False,
|
||||
server_default='{"pinned": [], "shared_with_me": []}',
|
||||
),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
prompts_table = Table(
|
||||
"prompts",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
user_tools_table = Table(
|
||||
"user_tools",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("custom_name", Text),
|
||||
Column("display_name", Text),
|
||||
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
token_usage_table = Table(
|
||||
"token_usage",
|
||||
metadata,
|
||||
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||
Column("user_id", Text),
|
||||
Column("api_key", Text),
|
||||
Column("agent_id", UUID(as_uuid=True)),
|
||||
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("generated_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
user_logs_table = Table(
|
||||
"user_logs",
|
||||
metadata,
|
||||
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||
Column("user_id", Text),
|
||||
Column("endpoint", Text),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("data", JSONB),
|
||||
)
|
||||
|
||||
feedback_table = Table(
|
||||
"feedback",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("conversation_id", UUID(as_uuid=True), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("question_index", Integer, nullable=False),
|
||||
Column("feedback_text", Text),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
stack_logs_table = Table(
|
||||
"stack_logs",
|
||||
metadata,
|
||||
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||
Column("activity_id", Text, nullable=False),
|
||||
Column("endpoint", Text),
|
||||
Column("level", Text),
|
||||
Column("user_id", Text),
|
||||
Column("api_key", Text),
|
||||
Column("query", Text),
|
||||
Column("stacks", JSONB, nullable=False, server_default="[]"),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
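Since the Alembic migration is the DDL source of truth, ``metadata`` is mainly useful for spinning up throwaway schemas in tests. A hedged sketch, with a made-up test database URI:

# Local/test use only; production schemas come from `alembic upgrade head`.
# Assumes PostgreSQL 13+, where gen_random_uuid() is built in.
from sqlalchemy import create_engine

test_engine = create_engine("postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt_test")
metadata.create_all(test_engine)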
11
application/storage/db/repositories/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Repositories for the user-data Postgres database.
|
||||
|
||||
Each module in this package exposes exactly one repository class. Repository
|
||||
methods take a ``Connection`` (either as a constructor argument or as a
|
||||
method argument) and return plain ``dict`` rows via
|
||||
``application.storage.db.base_repository.row_to_dict`` during the
|
||||
MongoDB→Postgres cutover, so call sites don't have to change shape.
|
||||
|
||||
Repositories are added one collection at a time, matching the phased
|
||||
rollout in ``migration-postgres.md``.
|
||||
"""
|
||||
57
application/storage/db/repositories/feedback.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Repository for the ``feedback`` table.
|
||||
|
||||
The ``feedback_collection`` global is declared in ``base.py`` but currently
|
||||
has zero direct call sites in the application code (all feedback writes go
|
||||
through ``conversation_messages.feedback`` JSONB field on the conversations
|
||||
collection). The table exists for when feedback is denormalized into its own
|
||||
rows. This repository provides the append-only insert and basic reads
|
||||
needed for that future.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
"""Postgres-backed replacement for Mongo ``feedback_collection``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
question_index: int,
|
||||
feedback_text: Optional[str] = None,
|
||||
) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text)
|
||||
VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"question_index": question_index,
|
||||
"feedback_text": feedback_text,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def list_for_conversation(self, conversation_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM feedback WHERE conversation_id = CAST(:cid AS uuid) ORDER BY question_index"
|
||||
),
|
||||
{"cid": conversation_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
96
application/storage/db/repositories/prompts.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Repository for the ``prompts`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``prompts_collection``:
|
||||
|
||||
1. ``insert_one`` in prompts/routes.py (create)
|
||||
2. ``find`` by user in prompts/routes.py (list)
|
||||
3. ``find_one`` by id+user in prompts/routes.py (get single)
|
||||
4. ``find_one`` by id only in stream_processor.py (get content for rendering)
|
||||
5. ``update_one`` in prompts/routes.py (update name+content)
|
||||
6. ``delete_one`` in prompts/routes.py (delete)
|
||||
7. ``find_one`` + ``insert_one`` in seeder.py (upsert by user+name+content)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class PromptsRepository:
|
||||
"""Postgres-backed replacement for Mongo ``prompts_collection``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, name: str, content: str) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO prompts (user_id, name, content)
|
||||
VALUES (:user_id, :name, :content)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "name": name, "content": content},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def get(self, prompt_id: str, user_id: Optional[str] = None) -> Optional[dict]:
|
||||
if user_id is not None:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": prompt_id, "user_id": user_id},
|
||||
)
|
||||
else:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": prompt_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM prompts WHERE user_id = :user_id ORDER BY created_at"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, prompt_id: str, user_id: str, name: str, content: str) -> None:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE prompts
|
||||
SET name = :name, content = :content, updated_at = now()
|
||||
WHERE id = CAST(:id AS uuid) AND user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"id": prompt_id, "user_id": user_id, "name": name, "content": content},
|
||||
)
|
||||
|
||||
def delete(self, prompt_id: str, user_id: str) -> None:
|
||||
self._conn.execute(
|
||||
text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": prompt_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
def find_or_create(self, user_id: str, name: str, content: str) -> dict:
|
||||
"""Return existing prompt matching (user, name, content), or create one.
|
||||
|
||||
Used by the seeder to avoid duplicating template prompts.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM prompts WHERE user_id = :user_id AND name = :name AND content = :content"
|
||||
),
|
||||
{"user_id": user_id, "name": name, "content": content},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is not None:
|
||||
return row_to_dict(row)
|
||||
return self.create(user_id, name, content)
|
||||
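A brief sketch of the seeder-style upsert from item 7 in the docstring above, using the transaction pattern the repositories expect; the user id and prompt values are placeholders:

from application.storage.db.engine import get_engine

with get_engine().begin() as conn:
    repo = PromptsRepository(conn)
    prompt = repo.find_or_create("local", "default", "You are a helpful assistant.")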
58
application/storage/db/repositories/stack_logs.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Repository for the ``stack_logs`` table.
|
||||
|
||||
Covers the single operation the legacy Mongo code performs:
|
||||
|
||||
1. ``insert_one`` in logging.py ``_log_to_mongodb`` — append-only debug/error
|
||||
activity log. The Mongo collection is ``stack_logs``; the Mongo variable
|
||||
inside ``_log_to_mongodb`` is misleadingly named ``user_logs_collection``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
|
||||
class StackLogsRepository:
|
||||
"""Postgres-backed replacement for Mongo ``stack_logs`` collection."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def insert(
|
||||
self,
|
||||
*,
|
||||
activity_id: str,
|
||||
endpoint: Optional[str] = None,
|
||||
level: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
stacks: Optional[list] = None,
|
||||
timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO stack_logs (activity_id, endpoint, level, user_id, api_key, query, stacks, timestamp)
|
||||
VALUES (
|
||||
:activity_id, :endpoint, :level, :user_id, :api_key, :query,
|
||||
CAST(:stacks AS jsonb),
|
||||
COALESCE(:timestamp, now())
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"activity_id": activity_id,
|
||||
"endpoint": endpoint,
|
||||
"level": level,
|
||||
"user_id": user_id,
|
||||
"api_key": api_key,
|
||||
"query": query,
|
||||
"stacks": json.dumps(stacks or []),
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
104
application/storage/db/repositories/token_usage.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Repository for the ``token_usage`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``token_usage_collection`` / ``usage_collection``:
|
||||
|
||||
1. ``insert_one`` in usage.py (record per-call token counts)
|
||||
2. ``aggregate`` in analytics/routes.py (time-bucketed totals)
|
||||
3. ``aggregate`` in answer/routes/base.py (24h sum for rate limiting)
|
||||
4. ``count_documents`` in answer/routes/base.py (24h request count)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
|
||||
class TokenUsageRepository:
|
||||
"""Postgres-backed replacement for Mongo ``token_usage_collection``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def insert(
|
||||
self,
|
||||
*,
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
prompt_tokens: int = 0,
|
||||
generated_tokens: int = 0,
|
||||
timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
|
||||
VALUES (
|
||||
:user_id, :api_key,
|
||||
CAST(:agent_id AS uuid),
|
||||
:prompt_tokens, :generated_tokens,
|
||||
COALESCE(:timestamp, now())
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"api_key": api_key,
|
||||
"agent_id": agent_id,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"generated_tokens": generated_tokens,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
def sum_tokens_in_range(
|
||||
self,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Total (prompt + generated) tokens in the given time range."""
|
||||
clauses = ["timestamp >= :start", "timestamp <= :end"]
|
||||
params: dict = {"start": start, "end": end}
|
||||
if user_id is not None:
|
||||
clauses.append("user_id = :user_id")
|
||||
params["user_id"] = user_id
|
||||
if api_key is not None:
|
||||
clauses.append("api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
result = self._conn.execute(
|
||||
text(f"SELECT COALESCE(SUM(prompt_tokens + generated_tokens), 0) FROM token_usage WHERE {where}"),
|
||||
params,
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
def count_in_range(
|
||||
self,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Count of token_usage rows in the given time range (for request limiting)."""
|
||||
clauses = ["timestamp >= :start", "timestamp <= :end"]
|
||||
params: dict = {"start": start, "end": end}
|
||||
if user_id is not None:
|
||||
clauses.append("user_id = :user_id")
|
||||
params["user_id"] = user_id
|
||||
if api_key is not None:
|
||||
clauses.append("api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
result = self._conn.execute(
|
||||
text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
|
||||
params,
|
||||
)
|
||||
return result.scalar()
|
||||
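A rough example of the 24-hour rate-limit reads mentioned in items 3 and 4 of the docstring, with placeholder identifiers:

from datetime import datetime, timedelta, timezone

from application.storage.db.engine import get_engine

now = datetime.now(timezone.utc)
day_ago = now - timedelta(hours=24)
with get_engine().connect() as conn:
    repo = TokenUsageRepository(conn)
    tokens_24h = repo.sum_tokens_in_range(start=day_ago, end=now, user_id="local")
    requests_24h = repo.count_in_range(start=day_ago, end=now, user_id="local")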
84
application/storage/db/repositories/user_logs.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Repository for the ``user_logs`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``user_logs_collection``:
|
||||
|
||||
1. ``insert_one`` in logging.py (per-request activity log via
|
||||
``_log_to_mongodb`` — note: the *Mongo* variable is confusingly named
|
||||
``user_logs_collection`` but points at the ``user_logs`` Mongo
|
||||
collection, not ``stack_logs``)
|
||||
2. ``insert_one`` in answer/routes/base.py (per-stream log entry)
|
||||
3. ``find`` with sort/skip/limit in analytics/routes.py (paginated log list)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class UserLogsRepository:
|
||||
"""Postgres-backed replacement for Mongo ``user_logs_collection``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def insert(
|
||||
self,
|
||||
*,
|
||||
user_id: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
data: Optional[dict] = None,
|
||||
timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO user_logs (user_id, endpoint, data, timestamp)
|
||||
VALUES (:user_id, :endpoint, CAST(:data AS jsonb), COALESCE(:timestamp, now()))
|
||||
"""
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"endpoint": endpoint,
|
||||
"data": json.dumps(data) if data is not None else None,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
def list_paginated(
|
||||
self,
|
||||
*,
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
) -> tuple[list[dict], bool]:
|
||||
"""Return ``(rows, has_more)`` for the requested page.
|
||||
|
||||
Mirrors the Mongo ``find(query).sort().skip().limit(page_size+1)``
|
||||
pattern used in analytics/routes.py.
|
||||
"""
|
||||
clauses: list[str] = []
|
||||
params: dict = {"limit": page_size + 1, "offset": (page - 1) * page_size}
|
||||
if user_id is not None:
|
||||
clauses.append("user_id = :user_id")
|
||||
params["user_id"] = user_id
|
||||
if api_key is not None:
|
||||
clauses.append("data->>'api_key' = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
f"SELECT * FROM user_logs {where} ORDER BY timestamp DESC LIMIT :limit OFFSET :offset"
|
||||
),
|
||||
params,
|
||||
)
|
||||
rows = [row_to_dict(r) for r in result.fetchall()]
|
||||
has_more = len(rows) > page_size
|
||||
return rows[:page_size], has_more
|
||||
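The ``(rows, has_more)`` return value mirrors the Mongo page_size+1 trick; a hedged read-side sketch with a placeholder user id:

from application.storage.db.engine import get_engine

with get_engine().connect() as conn:
    logs, has_more = UserLogsRepository(conn).list_paginated(user_id="local", page=1, page_size=10)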
114
application/storage/db/repositories/user_tools.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Repository for the ``user_tools`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``user_tools_collection``:
|
||||
|
||||
1. ``find`` by user in tools/routes.py and base.py (list all / active)
|
||||
2. ``find_one`` by id in tools/routes.py and sharing.py (get single)
|
||||
3. ``insert_one`` in tools/routes.py and mcp.py (create)
|
||||
4. ``update_one`` in tools/routes.py and mcp.py (update fields)
|
||||
5. ``delete_one`` in tools/routes.py (delete)
|
||||
6. ``find`` by user+status in stream_processor.py and tool_executor.py (active tools)
|
||||
7. ``find_one`` by user+name in mcp.py (upsert check)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class UserToolsRepository:
|
||||
"""Postgres-backed replacement for Mongo ``user_tools_collection``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, name: str, *, config: Optional[dict] = None,
|
||||
custom_name: Optional[str] = None, display_name: Optional[str] = None,
|
||||
extra: Optional[dict] = None) -> dict:
|
||||
"""Insert a new tool row. ``extra`` is merged into the config JSONB."""
|
||||
cfg = config or {}
|
||||
if extra:
|
||||
cfg.update(extra)
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
|
||||
VALUES (:user_id, :name, :custom_name, :display_name, CAST(:config AS jsonb))
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"custom_name": custom_name,
|
||||
"display_name": display_name,
|
||||
"config": json.dumps(cfg),
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def get(self, tool_id: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM user_tools WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": tool_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM user_tools WHERE user_id = :user_id ORDER BY created_at"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, tool_id: str, user_id: str, fields: dict) -> None:
|
||||
"""Update arbitrary fields on a tool row.
|
||||
|
||||
``fields`` maps column names to new values. Only ``name``,
|
||||
``custom_name``, ``display_name``, and ``config`` are allowed.
|
||||
"""
|
||||
allowed = {"name", "custom_name", "display_name", "config"}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not filtered:
|
||||
return
|
||||
params: dict = {
|
||||
"id": tool_id,
|
||||
"user_id": user_id,
|
||||
"name": filtered.get("name"),
|
||||
"custom_name": filtered.get("custom_name"),
|
||||
"display_name": filtered.get("display_name"),
|
||||
"config": (
|
||||
json.dumps(filtered["config"])
|
||||
if "config" in filtered and isinstance(filtered["config"], dict)
|
||||
else filtered.get("config")
|
||||
),
|
||||
}
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE user_tools
|
||||
SET
|
||||
name = COALESCE(:name, name),
|
||||
custom_name = COALESCE(:custom_name, custom_name),
|
||||
display_name = COALESCE(:display_name, display_name),
|
||||
config = COALESCE(CAST(:config AS jsonb), config),
|
||||
updated_at = now()
|
||||
WHERE id = CAST(:id AS uuid) AND user_id = :user_id
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
|
||||
def delete(self, tool_id: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": tool_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
245
application/storage/db/repositories/users.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Repository for the ``users`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``users_collection``:
|
||||
|
||||
1. ``ensure_user_doc`` in ``application/api/user/base.py`` (upsert + get)
|
||||
2. Pin/unpin agents in ``application/api/user/agents/routes.py`` (add/remove
|
||||
on ``agent_preferences.pinned``)
|
||||
3. Share accept/reject in ``application/api/user/agents/sharing.py`` (add/
|
||||
bulk-remove on ``agent_preferences.shared_with_me``)
|
||||
4. Cascade delete of an agent id from both arrays at once
|
||||
|
||||
All array mutations are implemented as single atomic UPDATE statements
|
||||
using JSONB operators (``jsonb_set``, ``jsonb_array_elements``, ``@>``)
|
||||
so there is no read-modify-write race between concurrent writers on the
|
||||
same user row.
|
||||
|
||||
The repository takes a ``Connection`` and does not manage its own
|
||||
transactions. Callers are responsible for wrapping writes in
|
||||
``with engine.begin() as conn:`` (production) or the test fixture's
|
||||
rollback-per-test connection (tests).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_DEFAULT_PREFERENCES = '{"pinned": [], "shared_with_me": []}'
|
||||
|
||||
|
||||
class UsersRepository:
|
||||
"""Postgres-backed replacement for Mongo ``users_collection`` writes/reads."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reads
|
||||
# ------------------------------------------------------------------
|
||||
def get(self, user_id: str) -> Optional[dict]:
|
||||
"""Return the user row as a dict, or ``None`` if missing.
|
||||
|
||||
Args:
|
||||
user_id: Auth-provider ``sub`` (opaque string).
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM users WHERE user_id = :user_id"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Upsert
|
||||
# ------------------------------------------------------------------
|
||||
def upsert(self, user_id: str) -> dict:
|
||||
"""Ensure a row exists for ``user_id`` and return it.
|
||||
|
||||
Matches Mongo's ``find_one_and_update(..., $setOnInsert, upsert=True,
|
||||
return_document=AFTER)`` semantics: if the row exists, preferences
|
||||
are preserved untouched; if it doesn't, a new row is created with
|
||||
default preferences.
|
||||
|
||||
The ``DO UPDATE SET user_id = EXCLUDED.user_id`` branch is a
|
||||
deliberate no-op that lets ``RETURNING *`` fire on both the insert
|
||||
and conflict paths (``DO NOTHING`` would suppress the returning).
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO users (user_id, agent_preferences)
|
||||
VALUES (:user_id, CAST(:default_prefs AS jsonb))
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET user_id = EXCLUDED.user_id
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "default_prefs": _DEFAULT_PREFERENCES},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pinned agents
|
||||
# ------------------------------------------------------------------
|
||||
def add_pinned(self, user_id: str, agent_id: str) -> None:
|
||||
"""Idempotently append ``agent_id`` to ``agent_preferences.pinned``.
|
||||
|
||||
Uses ``@>`` containment so a duplicate add is a no-op rather than a
|
||||
silent double-insert. The whole update is a single atomic statement
|
||||
so concurrent add_pinned calls on the same user cannot interleave
|
||||
into a read-modify-write race.
|
||||
"""
|
||||
self._append_to_jsonb_array(user_id, "pinned", agent_id)
|
||||
|
||||
def remove_pinned(self, user_id: str, agent_id: str) -> None:
|
||||
"""Remove ``agent_id`` from ``agent_preferences.pinned`` if present."""
|
||||
self._remove_from_jsonb_array(user_id, "pinned", [agent_id])
|
||||
|
||||
def remove_pinned_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
|
||||
"""Remove every id in ``agent_ids`` from ``agent_preferences.pinned``.
|
||||
|
||||
No-op if the list is empty. Unknown ids are silently ignored so
|
||||
callers can pass the full "stale" set without pre-filtering.
|
||||
"""
|
||||
ids = list(agent_ids)
|
||||
if not ids:
|
||||
return
|
||||
self._remove_from_jsonb_array(user_id, "pinned", ids)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared-with-me agents
|
||||
# ------------------------------------------------------------------
|
||||
def add_shared(self, user_id: str, agent_id: str) -> None:
|
||||
"""Idempotently append ``agent_id`` to ``agent_preferences.shared_with_me``."""
|
||||
self._append_to_jsonb_array(user_id, "shared_with_me", agent_id)
|
||||
|
||||
def remove_shared_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
|
||||
"""Bulk-remove from ``agent_preferences.shared_with_me``. Empty list is a no-op."""
|
||||
ids = list(agent_ids)
|
||||
if not ids:
|
||||
return
|
||||
self._remove_from_jsonb_array(user_id, "shared_with_me", ids)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Combined removal — called when an agent is hard-deleted
|
||||
# ------------------------------------------------------------------
|
||||
def remove_agent_from_all(self, user_id: str, agent_id: str) -> None:
|
||||
"""Remove ``agent_id`` from BOTH pinned and shared_with_me atomically.
|
||||
|
||||
Mirrors the Mongo ``$pull`` that targets both nested array fields
|
||||
in one ``update_one`` — see ``application/api/user/agents/routes.py``
|
||||
around the agent-delete path.
|
||||
"""
|
||||
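# For comparison, the legacy Mongo call this replaces was roughly the
# following (illustrative sketch, not copied verbatim from routes.py):
#
#     users_collection.update_one(
#         {"user_id": user_id},
#         {"$pull": {
#             "agent_preferences.pinned": agent_id,
#             "agent_preferences.shared_with_me": agent_id,
#         }},
#     )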
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE users
|
||||
SET
|
||||
agent_preferences = jsonb_set(
|
||||
jsonb_set(
|
||||
agent_preferences,
|
||||
'{pinned}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||
) AS elem
|
||||
WHERE (elem #>> '{}') != :agent_id
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
'{shared_with_me}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||
) AS elem
|
||||
WHERE (elem #>> '{}') != :agent_id
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "agent_id": agent_id},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _append_to_jsonb_array(self, user_id: str, key: str, agent_id: str) -> None:
|
||||
"""Idempotent append of ``agent_id`` to ``agent_preferences.<key>``.
|
||||
|
||||
The ``key`` argument is NOT user input — it's hard-coded by the
|
||||
calling method (``pinned`` / ``shared_with_me``). It goes into the
|
||||
SQL literal because ``jsonb_set`` requires a path literal, not a
|
||||
bind parameter. This is safe as long as callers never pass
|
||||
untrusted strings for ``key``.
|
||||
"""
|
||||
if key not in ("pinned", "shared_with_me"):
|
||||
raise ValueError(f"unsupported jsonb key: {key!r}")
|
||||
self._conn.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE users
|
||||
SET
|
||||
agent_preferences = jsonb_set(
|
||||
agent_preferences,
|
||||
'{{{key}}}',
|
||||
CASE
|
||||
WHEN agent_preferences->'{key}' @> to_jsonb(CAST(:agent_id AS text))
|
||||
THEN agent_preferences->'{key}'
|
||||
ELSE
|
||||
COALESCE(agent_preferences->'{key}', '[]'::jsonb)
|
||||
|| to_jsonb(CAST(:agent_id AS text))
|
||||
END
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "agent_id": agent_id},
|
||||
)
|
||||
|
||||
def _remove_from_jsonb_array(
|
||||
self, user_id: str, key: str, agent_ids: list[str]
|
||||
) -> None:
|
||||
"""Remove every id in ``agent_ids`` from ``agent_preferences.<key>``."""
|
||||
if key not in ("pinned", "shared_with_me"):
|
||||
raise ValueError(f"unsupported jsonb key: {key!r}")
|
||||
self._conn.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE users
|
||||
SET
|
||||
agent_preferences = jsonb_set(
|
||||
agent_preferences,
|
||||
'{{{key}}}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'{key}', '[]'::jsonb)
|
||||
) AS elem
|
||||
WHERE NOT ((elem #>> '{{}}') = ANY(:agent_ids))
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "agent_ids": agent_ids},
|
||||
)
|
||||
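# Usage sketch (illustrative only): callers own the transaction, e.g.
#
#     from application.storage.db.engine import get_engine
#
#     with get_engine().begin() as conn:
#         repo = UsersRepository(conn)
#         repo.upsert(user_id)
#         repo.add_pinned(user_id, agent_id)
#
# ``get_engine`` is assumed here from application.storage.db.engine (the
# same import used by scripts/db/backfill.py); adjust if your entry point
# obtains the engine differently.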
@@ -110,6 +110,20 @@ def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
usage_data["agent_id"] = normalized_agent_id
|
||||
usage_collection.insert_one(usage_data)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
|
||||
dual_write(
|
||||
TokenUsageRepository,
|
||||
lambda repo, d=usage_data: repo.insert(
|
||||
user_id=d.get("user_id"),
|
||||
api_key=d.get("api_key"),
|
||||
agent_id=d.get("agent_id"),
|
||||
prompt_tokens=d["prompt_tokens"],
|
||||
generated_tokens=d["generated_tokens"],
|
||||
),
|
||||
)
|
||||
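# A minimal sketch of what the dual_write helper above is expected to do
# (illustrative only; the real implementation lives in
# application.storage.db.dual_write):
#
#     def dual_write(repo_cls, fn):
#         if not settings.USE_POSTGRES:
#             return
#         try:
#             with get_engine().begin() as conn:
#                 fn(repo_cls(conn))
#         except Exception:
#             # Non-fatal by design: Mongo stays the source of truth, so the
#             # user-facing request still succeeds and the backfill re-syncs later.
#             logging.warning("dual-write to Postgres failed", exc_info=True)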
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
|
||||
@@ -27,37 +27,42 @@ class PGVectorStore(BaseVectorStore):
|
||||
self._metadata_column = metadata_column
|
||||
self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
|
||||
|
||||
# Use provided connection string or fall back to settings
|
||||
# Use provided connection string or fall back to settings.
|
||||
# If PGVECTOR_CONNECTION_STRING is not set but POSTGRES_URI is,
|
||||
# reuse the same cluster — normalize from SQLAlchemy dialect to libpq form.
|
||||
self._connection_string = connection_string or getattr(settings, 'PGVECTOR_CONNECTION_STRING', None)
|
||||
|
||||
|
||||
if not self._connection_string and getattr(settings, 'POSTGRES_URI', None):
|
||||
from application.core.db_uri import normalize_pgvector_connection_string
|
||||
self._connection_string = normalize_pgvector_connection_string(settings.POSTGRES_URI)
|
||||
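# Illustrative example of the expected normalization: a SQLAlchemy-style URI
# such as "postgresql+psycopg2://user:pw@host:5432/db" should come back in
# plain libpq form, "postgresql://user:pw@host:5432/db", so psycopg accepts it.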
|
||||
if not self._connection_string:
|
||||
raise ValueError(
|
||||
"PostgreSQL connection string is required. "
|
||||
"Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
|
||||
"Set PGVECTOR_CONNECTION_STRING or POSTGRES_URI in settings, "
|
||||
"or pass connection_string parameter."
|
||||
)
|
||||
|
||||
try:
|
||||
import psycopg2
|
||||
from psycopg2.extras import Json
|
||||
import pgvector.psycopg2
|
||||
import psycopg
|
||||
from pgvector.psycopg import register_vector
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import required packages. "
|
||||
"Please install with `pip install psycopg2-binary pgvector`."
|
||||
"Please install with `pip install 'psycopg[binary,pool]' pgvector`."
|
||||
)
|
||||
|
||||
self._psycopg2 = psycopg2
|
||||
self._Json = Json
|
||||
self._pgvector = pgvector.psycopg2
|
||||
self._psycopg = psycopg
|
||||
self._register_vector = register_vector
|
||||
self._connection = None
|
||||
self._ensure_table_exists()
|
||||
|
||||
def _get_connection(self):
|
||||
"""Get or create database connection"""
|
||||
if self._connection is None or self._connection.closed:
|
||||
self._connection = self._psycopg2.connect(self._connection_string)
|
||||
self._connection = self._psycopg.connect(self._connection_string)
|
||||
# Register pgvector types
|
||||
self._pgvector.register_vector(self._connection)
|
||||
self._register_vector(self._connection)
|
||||
return self._connection
|
||||
|
||||
def _ensure_table_exists(self):
|
||||
@@ -170,7 +175,7 @@ class PGVectorStore(BaseVectorStore):
|
||||
for text, embedding, metadata in zip(texts, embeddings, metadatas):
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embedding, self._Json(metadata), self._source_id)
|
||||
(text, embedding, metadata, self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
inserted_ids.append(str(inserted_id))
|
||||
@@ -261,7 +266,7 @@ class PGVectorStore(BaseVectorStore):
|
||||
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embeddings[0], self._Json(final_metadata), self._source_id)
|
||||
(text, embeddings[0], final_metadata, self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
@@ -247,7 +247,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
|
||||
|
||||
def download_file(url, params, dest_path):
|
||||
try:
|
||||
response = requests.get(url, params=params)
|
||||
response = requests.get(url, params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
with open(dest_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
@@ -284,12 +284,14 @@ def upload_index(full_path, file_data):
|
||||
files=files,
|
||||
data=file_data,
|
||||
headers=headers,
|
||||
timeout=100,
|
||||
)
|
||||
else:
|
||||
response = requests.post(
|
||||
urljoin(settings.API_URL, "/api/upload_index"),
|
||||
data=file_data,
|
||||
headers=headers,
|
||||
timeout=100,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except (requests.RequestException, FileNotFoundError) as e:
|
||||
|
||||
114
docs/content/Deploying/Postgres-Migration.mdx
Normal file
@@ -0,0 +1,114 @@
|
||||
---
title: PostgreSQL for User Data
description: Set up PostgreSQL as the user-data store for DocsGPT and migrate from MongoDB at your own pace.
---

import { Callout } from 'nextra/components'

# PostgreSQL for User Data

DocsGPT is progressively moving user data (conversations, agents, prompts,
preferences, etc.) from MongoDB to PostgreSQL, one collection at a time.
Each collection is guarded by a feature flag so you can opt in and roll
back instantly. MongoDB stays the source of truth until you cut over
reads; vector stores (`VECTOR_STORE=pgvector`, `faiss`, `qdrant`, `mongodb`, …)
are unaffected.

<Callout type="info" emoji="ℹ️">
  Which collections are available today is in the [Status](#status)
  table below. That table is the only part of this page that changes
  release to release.
</Callout>

## Setup

1. **Run Postgres 13+.** Native install, Docker, or managed (Neon, RDS,
   Supabase, Cloud SQL…) — all work. You'll need the `pgcrypto` and
   `citext` extensions, both standard contrib modules available
   everywhere.

2. **Create a database and role** (skip if your managed provider gave
   you these):

   ```sql
   CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt';
   CREATE DATABASE docsgpt OWNER docsgpt;
   ```

3. **Set `POSTGRES_URI` in `.env`.** Any standard Postgres URI works —
   DocsGPT normalizes it internally.

   ```bash
   POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
   # Append ?sslmode=require for managed providers that enforce SSL.
   ```

4. **Apply the schema** (idempotent — safe to re-run):

   ```bash
   python scripts/db/init_postgres.py
   ```

## Migrating data

Two global flags, no per-collection knobs — every collection marked ✅
in the [Status](#status) table is handled automatically.

1. **Enable dual-write.** Writes go to both Mongo and Postgres; Mongo
   remains source of truth. Set the flag in `.env` and restart:

   ```bash
   USE_POSTGRES=true
   ```

2. **Backfill existing data.** Idempotent — re-run any time to re-sync
   drifted rows. Without arguments, backfills every registered table;
   pass `--tables` to limit.

   ```bash
   python scripts/db/backfill.py --dry-run        # preview everything
   python scripts/db/backfill.py                  # real run, everything
   python scripts/db/backfill.py --tables users   # only specific tables
   ```

3. **Cut over reads** once you trust the Postgres state:

   ```bash
   READ_POSTGRES=true
   ```

   Rollback is instant: unset `READ_POSTGRES` and restart. Dual-write
   keeps Postgres up to date so you can flip back and forth.

<Callout type="warning" emoji="⚠️">
  Don't decommission MongoDB until every collection you use is fully
  cut over. During the migration window, Mongo is still required.
</Callout>

## Status

_Last updated: 2026-04-10_

| Collection | Status |
|---|---|
| `users` | ✅ Phase 1 |
| `prompts`, `user_tools`, `feedback`, `stack_logs`, `user_logs`, `token_usage` | ⏳ Phase 1 |
| `agents`, `sources`, `attachments`, `memories`, `todos`, `notes`, `connector_sessions`, `agent_folders` | ⏳ Phase 2 |
| `conversations`, `pending_tool_state`, `workflows` | ⏳ Phase 3 |

Schemas for **every** row above already exist after `init_postgres.py`
runs. What's landing progressively is the application-level dual-write
wiring and the backfill logic for each collection. Once a collection
is ✅, enabling `USE_POSTGRES=true` and running `python scripts/db/backfill.py`
picks it up automatically — no per-collection config change.

## Troubleshooting

- **`relation "..." does not exist`** — run `python scripts/db/init_postgres.py`.
- **`FATAL: role "docsgpt" does not exist`** — run the `CREATE ROLE` /
  `CREATE DATABASE` statements from step 2 as a Postgres superuser.
- **SSL errors on a managed provider** — append `?sslmode=require` to
  `POSTGRES_URI`.
- **Dual-write warnings in the logs** — expected to be non-fatal. Mongo
  is source of truth, so the user-facing request succeeds. Re-run the
  backfill to re-sync whichever rows drifted.
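If you want an extra sanity check after a backfill run, comparing row counts
between the two stores is usually enough. The sketch below is illustrative
only: it assumes `pymongo` and SQLAlchemy are installed, that `MONGO_URI`
includes the database name, and that `POSTGRES_URI` is in `postgresql://`
form; swap in whichever table/collection you actually migrated.

```python
# Compare Mongo and Postgres row counts for one migrated collection.
# Illustrative sketch: connection details and names come from your own .env.
import os

from pymongo import MongoClient
from sqlalchemy import create_engine, text

mongo_db = MongoClient(os.environ["MONGO_URI"]).get_default_database()
engine = create_engine(os.environ["POSTGRES_URI"])

mongo_count = mongo_db["users"].count_documents({})
with engine.connect() as conn:
    pg_count = conn.execute(text("SELECT count(*) FROM users")).scalar_one()

print(f"users: mongo={mongo_count} postgres={pg_count}")
```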
@@ -19,6 +19,10 @@ export default {
|
||||
"title": "☁️ Hosting DocsGPT",
|
||||
"href": "/Deploying/Hosting-the-app"
|
||||
},
|
||||
"Postgres-Migration": {
|
||||
"title": "🐘 PostgreSQL for User Data",
|
||||
"href": "/Deploying/Postgres-Migration"
|
||||
},
|
||||
"Amazon-Lightsail": {
|
||||
"title": "Hosting DocsGPT on Amazon Lightsail",
|
||||
"href": "/Deploying/Amazon-Lightsail",
|
||||
|
||||
@@ -54,8 +54,8 @@ flowchart LR
|
||||
* **Technology:** Supports multiple LLM APIs and local engines.
|
||||
* **Responsibility:** This layer provides an abstraction for interacting with Large Language Models (LLMs).
|
||||
* **Key Features:**
|
||||
* Supports LLMs from OpenAI, Google, Anthropic, Groq, HuggingFace Inference API, Azure OpenAI, also compatable with local models like Ollama, LLaMa.cpp, Text Generation Inference (TGI), SGLang, vLLM, Aphrodite, FriendliAI, and LMDeploy.
|
||||
* Manages API key handling and request formatting and Tool fromatting.
|
||||
* Supports LLMs from OpenAI, Google, Anthropic, Groq, HuggingFace Inference API, Azure OpenAI, also compatible with local models like Ollama, LLaMa.cpp, Text Generation Inference (TGI), SGLang, vLLM, Aphrodite, FriendliAI, and LMDeploy.
|
||||
* Manages API key handling and request formatting and Tool formatting.
|
||||
* Offers caching mechanisms to improve response times and reduce API usage.
|
||||
* Handles streaming responses for a more interactive user experience.
|
||||
|
||||
@@ -120,7 +120,7 @@ sequenceDiagram
|
||||
|
||||
## Deployment Architecture
|
||||
|
||||
DocsGPT is designed to be deployed using Docker and Kubernetes, here is a qucik overview of a simple k8s deployment.
|
||||
DocsGPT is designed to be deployed using Docker and Kubernetes, here is a quick overview of a simple k8s deployment.
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
|
||||
@@ -7,6 +7,10 @@ export default {
|
||||
"title": "🔗 SharePoint / OneDrive",
|
||||
"href": "/Guides/Integrations/sharepoint-connector"
|
||||
},
|
||||
"confluence-connector": {
|
||||
"title": "🔗 Confluence",
|
||||
"href": "/Guides/Integrations/confluence-connector"
|
||||
},
|
||||
"mcp-tool-integration": {
|
||||
"title": "🔗 MCP Tools",
|
||||
"href": "/Guides/Integrations/mcp-tool-integration"
|
||||
|
||||
67
docs/content/Guides/Integrations/confluence-connector.mdx
Normal file
@@ -0,0 +1,67 @@
|
||||
---
title: Confluence Connector
description: Connect your Confluence Cloud workspace as an external knowledge base to upload and process pages directly.
---

import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'

# Confluence Connector

Connect your Confluence Cloud workspace to upload and process pages directly as an external knowledge base. Supports page content and attachments (PDFs, Office files, text files, images, and more). Authentication is handled via Atlassian OAuth 2.0 with automatic token refresh.

## Setup

<Steps>

### Step 1: Create an OAuth 2.0 App in Atlassian

1. Go to [developer.atlassian.com/console/myapps](https://developer.atlassian.com/console/myapps/) and click **Create** > **OAuth 2.0 integration**
2. Under **Authorization**, add a callback URL:
   - Local: `http://localhost:7091/api/connectors/callback?provider=confluence`
   - Production: `https://yourdomain.com/api/connectors/callback?provider=confluence`

### Step 2: Configure Permissions

In your app settings, go to **Permissions** and add the **Confluence API**. Enable these scopes:
- `read:page:confluence`
- `read:space:confluence`
- `read:attachment:confluence`

### Step 3: Get Your Credentials

Go to **Settings** in your app to find the **Client ID** and **Secret**. Copy both.

### Step 4: Configure Environment Variables

Add to your backend `.env` file:

```env
CONFLUENCE_CLIENT_ID=your-atlassian-client-id
CONFLUENCE_CLIENT_SECRET=your-atlassian-client-secret
```

Add to your frontend `.env` file:

```env
VITE_CONFLUENCE_CLIENT_ID=your-atlassian-client-id
```

| Variable | Description | Required |
|----------|-------------|----------|
| `CONFLUENCE_CLIENT_ID` | Client ID from your Atlassian OAuth app | Yes |
| `CONFLUENCE_CLIENT_SECRET` | Client secret from your Atlassian OAuth app | Yes |
| `VITE_CONFLUENCE_CLIENT_ID` | Same Client ID, used by the frontend to show the Confluence option | Yes |

### Step 5: Restart and Use

Restart your application, then go to the upload section in DocsGPT and select **Confluence** as the source. You'll be redirected to Atlassian to sign in, and can then browse spaces and select pages to process.

</Steps>

## Troubleshooting

- **Option not appearing** — Verify `VITE_CONFLUENCE_CLIENT_ID` is set in the frontend `.env`, then restart.
- **Authentication failed** — Check that the callback URL matches exactly, including `?provider=confluence`.
- **No accessible sites** — Ensure the authenticating user has access to at least one Confluence Cloud site.
- **Permission denied** — Verify that the Confluence API scopes are enabled in your Atlassian app settings.
@@ -8,205 +8,66 @@ import { Steps } from 'nextra/components'
|
||||
|
||||
# Google Drive Connector
|
||||
|
||||
The Google Drive Connector allows you to seamlessly connect your Google Drive account as an external knowledge base. This integration enables you to upload and process files directly from your Google Drive without manually downloading and uploading them to DocsGPT.
|
||||
Connect your Google Drive account to upload and process files directly as an external knowledge base. Supports Google Workspace files (Docs, Sheets, Slides), Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Google OAuth 2.0 with automatic token refresh.
|
||||
|
||||
## Features
|
||||
|
||||
- **Direct File Access**: Browse and select files directly from your Google Drive
|
||||
- **Comprehensive File Support**: Supports all major document formats including:
|
||||
- Google Workspace files (Docs, Sheets, Slides)
|
||||
- Microsoft Office files (.docx, .xlsx, .pptx, .doc, .ppt, .xls)
|
||||
- PDF documents
|
||||
- Text files (.txt, .md, .rst, .html, .rtf)
|
||||
- Data files (.csv, .json)
|
||||
- Image files (.png, .jpg, .jpeg)
|
||||
- E-books (.epub)
|
||||
- **Secure Authentication**: Uses OAuth 2.0 for secure access to your Google Drive
|
||||
- **Real-time Sync**: Process files directly from Google Drive without local downloads
|
||||
|
||||
<Callout type="info" emoji="ℹ️">
|
||||
The Google Drive Connector requires proper configuration of Google API credentials. Follow the setup instructions below to enable this feature.
|
||||
</Callout>
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before setting up the Google Drive Connector, you'll need:
|
||||
|
||||
1. A Google Cloud Platform (GCP) project
|
||||
2. Google Drive API enabled
|
||||
3. OAuth 2.0 credentials configured
|
||||
4. DocsGPT instance with proper environment variables
|
||||
|
||||
## Setup Instructions
|
||||
## Setup
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Create a Google Cloud Project
|
||||
|
||||
1. Go to the [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create a new project or select an existing one
|
||||
3. Note down your Project ID for later use
|
||||
1. Go to the [Google Cloud Console](https://console.cloud.google.com/) and create a new project (or select an existing one)
|
||||
2. Navigate to **APIs & Services** > **Library**, search for "Google Drive API", and click **Enable**
|
||||
|
||||
### Step 2: Enable Google Drive API
|
||||
### Step 2: Create OAuth 2.0 Credentials
|
||||
|
||||
1. In the Google Cloud Console, navigate to **APIs & Services** > **Library**
|
||||
2. Search for "Google Drive API"
|
||||
3. Click on "Google Drive API" and click **Enable**
|
||||
1. Go to **APIs & Services** > **Credentials** > **Create Credentials** > **OAuth client ID**
|
||||
2. If prompted, configure the OAuth consent screen (choose **External**, fill in required fields)
|
||||
3. Select **Web application** as the application type
|
||||
4. Add your DocsGPT URL to **Authorized JavaScript origins** (e.g. `http://localhost:3000`)
|
||||
5. Add your callback URL to **Authorized redirect URIs**:
|
||||
- Local: `http://localhost:7091/api/connectors/callback?provider=google_drive`
|
||||
- Production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
|
||||
6. Click **Create** and copy the **Client ID** and **Client Secret**
|
||||
|
||||
### Step 3: Create OAuth 2.0 Credentials
|
||||
### Step 3: Configure Environment Variables
|
||||
|
||||
1. Go to **APIs & Services** > **Credentials**
|
||||
2. Click **Create Credentials** > **OAuth client ID**
|
||||
3. If prompted, configure the OAuth consent screen:
|
||||
- Choose **External** user type (unless you're using Google Workspace)
|
||||
- Fill in the required fields (App name, User support email, Developer contact)
|
||||
- Add your domain to **Authorized domains** if deploying publicly
|
||||
4. For Application type, select **Web application**
|
||||
5. Add your DocsGPT frontend URL to **Authorized JavaScript origins**:
|
||||
- For local development: `http://localhost:3000`
|
||||
- For production: `https://yourdomain.com`
|
||||
6. Add your DocsGPT callback URL to **Authorized redirect URIs**:
|
||||
- For local development: `http://localhost:7091/api/connectors/callback?provider=google_drive`
|
||||
- For production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
|
||||
7. Click **Create** and note down the **Client ID** and **Client Secret**
|
||||
|
||||
|
||||
|
||||
### Step 4: Configure Backend Environment Variables
|
||||
|
||||
Add the following environment variables to your backend configuration:
|
||||
|
||||
**For Docker deployment**, add to your `.env` file in the root directory:
|
||||
Add to your backend `.env` file:
|
||||
|
||||
```env
|
||||
# Google Drive Connector Configuration
|
||||
GOOGLE_CLIENT_ID=your_google_client_id_here
|
||||
GOOGLE_CLIENT_SECRET=your_google_client_secret_here
|
||||
GOOGLE_CLIENT_ID=your-google-client-id
|
||||
GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
```
|
||||
|
||||
**For manual deployment**, set these environment variables in your system or application configuration.
|
||||
|
||||
### Step 5: Configure Frontend Environment Variables
|
||||
|
||||
Add the following environment variables to your frontend `.env` file:
|
||||
Add to your frontend `.env` file:
|
||||
|
||||
```env
|
||||
# Google Drive Frontend Configuration
|
||||
VITE_GOOGLE_CLIENT_ID=your_google_client_id_here
|
||||
VITE_GOOGLE_CLIENT_ID=your-google-client-id
|
||||
```
|
||||
|
||||
| Variable | Description | Required |
|
||||
|----------|-------------|----------|
|
||||
| `GOOGLE_CLIENT_ID` | OAuth Client ID from GCP Credentials | Yes |
|
||||
| `GOOGLE_CLIENT_SECRET` | OAuth Client Secret from GCP Credentials | Yes |
|
||||
| `VITE_GOOGLE_CLIENT_ID` | Same Client ID, used by the frontend to show the Google Drive option | Yes |
|
||||
|
||||
<Callout type="warning" emoji="⚠️">
|
||||
Make sure to use the same Google Client ID in both backend and frontend configurations.
|
||||
</Callout>
|
||||
|
||||
### Step 6: Restart Your Application
|
||||
### Step 4: Restart and Use
|
||||
|
||||
After configuring the environment variables:
|
||||
|
||||
1. **For Docker**: Restart your Docker containers
|
||||
```bash
|
||||
docker-compose down
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
2. **For manual deployment**: Restart both backend and frontend services
|
||||
Restart your application, then go to the upload section in DocsGPT and select **Google Drive** as the source. You'll be redirected to Google to sign in, and can then browse and select files to process.
|
||||
|
||||
</Steps>
|
||||
|
||||
## Using the Google Drive Connector
|
||||
|
||||
Once configured, you can use the Google Drive Connector to upload files:
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Access the Upload Interface
|
||||
|
||||
1. Navigate to the DocsGPT interface
|
||||
2. Go to the upload/training section
|
||||
3. You should now see "Google Drive" as an available upload option
|
||||
|
||||
### Step 2: Connect Your Google Account
|
||||
|
||||
1. Select "Google Drive" as your upload method
|
||||
2. Click "Connect to Google Drive"
|
||||
3. You'll be redirected to Google's OAuth consent screen
|
||||
4. Grant the necessary permissions to DocsGPT
|
||||
5. You'll be redirected back to DocsGPT with a successful connection
|
||||
|
||||
### Step 3: Select Files
|
||||
|
||||
1. Once connected, click "Select Files"
|
||||
2. The Google Drive picker will open
|
||||
3. Browse your Google Drive and select the files you want to process
|
||||
4. Click "Select" to confirm your choices
|
||||
|
||||
### Step 4: Process Files
|
||||
|
||||
1. Review your selected files
|
||||
2. Click "Train" or "Upload" to process the files
|
||||
3. DocsGPT will download and process the files from your Google Drive
|
||||
4. Once processing is complete, the files will be available in your knowledge base
|
||||
|
||||
</Steps>
|
||||
|
||||
## Supported File Types
|
||||
|
||||
The Google Drive Connector supports the following file types:
|
||||
|
||||
| File Type | Extensions | Description |
|
||||
|-----------|------------|-------------|
|
||||
| **Google Workspace** | - | Google Docs, Sheets, Slides (automatically converted) |
|
||||
| **Microsoft Office** | .docx, .xlsx, .pptx | Modern Office formats |
|
||||
| **Legacy Office** | .doc, .ppt, .xls | Older Office formats |
|
||||
| **PDF Documents** | .pdf | Portable Document Format |
|
||||
| **Text Files** | .txt, .md, .rst, .html, .rtf | Various text formats |
|
||||
| **Data Files** | .csv, .json | Structured data formats |
|
||||
| **Images** | .png, .jpg, .jpeg | Image files (with OCR if enabled) |
|
||||
| **E-books** | .epub | Electronic publication format |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**"Google Drive option not appearing"**
|
||||
- Verify that `VITE_GOOGLE_CLIENT_ID` is set in frontend environment
|
||||
- Check that `VITE_GOOGLE_CLIENT_ID` environment variable is present in your frontend configuration
|
||||
- Check browser console for any JavaScript errors
|
||||
- Ensure the frontend has been restarted after adding environment variables
|
||||
|
||||
**"Authentication failed"**
|
||||
- Verify that your OAuth 2.0 credentials are correctly configured
|
||||
- Check that the redirect URI `http://<your-domain>/api/connectors/callback?provider=google_drive` is correctly added in GCP console
|
||||
- Ensure the Google Drive API is enabled in your GCP project
|
||||
|
||||
**"Permission denied" errors**
|
||||
- Verify that the OAuth consent screen is properly configured
|
||||
- Check that your Google account has access to the files you're trying to select
|
||||
- Ensure the required scopes are granted during authentication
|
||||
|
||||
**"Files not processing"**
|
||||
- Check that the backend environment variables are correctly set
|
||||
- Verify that the OAuth credentials have the necessary permissions
|
||||
- Check the backend logs for any error messages
|
||||
|
||||
### Environment Variable Checklist
|
||||
|
||||
**Backend (.env in root directory):**
|
||||
- ✅ `GOOGLE_CLIENT_ID`
|
||||
- ✅ `GOOGLE_CLIENT_SECRET`
|
||||
|
||||
**Frontend (.env in frontend directory):**
|
||||
- ✅ `VITE_GOOGLE_CLIENT_ID`
|
||||
|
||||
### Security Considerations
|
||||
|
||||
- Keep your Google Client Secret secure and never expose it in frontend code
|
||||
- Regularly rotate your OAuth credentials
|
||||
- Use HTTPS in production to protect authentication tokens
|
||||
- Ensure proper OAuth consent screen configuration for production use
|
||||
- **Option not appearing** — Verify `VITE_GOOGLE_CLIENT_ID` is set in the frontend `.env`, then restart.
|
||||
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=google_drive`. Ensure the Google Drive API is enabled.
|
||||
- **Permission denied** — Verify the OAuth consent screen is configured and the user has access to the target files.
|
||||
- **Files not processing** — Check backend logs and verify that backend environment variables are correctly set.
|
||||
|
||||
<Callout type="tip" emoji="💡">
|
||||
For production deployments, make sure to add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
|
||||
For production deployments, add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
|
||||
</Callout>
|
||||
|
||||
|
||||
|
||||
4
frontend/src/assets/confluence.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M15.7903 2.01315C16.0583 1.53239 16.6644 1.35859 17.1464 1.62428L21.9827 4.28993C22.2157 4.41835 22.3879 4.63433 22.4613 4.89007C22.5346 5.14582 22.503 5.42024 22.3735 5.65262L20.7649 8.53807C19.6743 10.4944 17.9383 11.868 15.9685 12.5426L21.8863 15.8043C22.1193 15.9328 22.2915 16.1488 22.3649 16.4045C22.4382 16.6602 22.4066 16.9347 22.2771 17.167L19.5962 21.9761C19.3282 22.4569 18.7221 22.6307 18.24 22.365L11.4692 18.6331C10.8804 18.3085 10.1413 18.5224 9.81847 19.1015L8.20996 21.987C7.94196 22.4677 7.33584 22.6415 6.8538 22.3758L2.01729 19.7101C1.78429 19.5816 1.61207 19.3657 1.53874 19.1099C1.46541 18.8542 1.49701 18.5798 1.62655 18.3474L3.23506 15.4619C4.32566 13.5056 6.06166 12.132 8.0315 11.4574L2.11368 8.19564C1.88068 8.06721 1.70846 7.85124 1.63513 7.59549C1.56179 7.33975 1.59339 7.06533 1.72294 6.83295L4.40379 2.02389C4.67179 1.54313 5.27791 1.36933 5.75995 1.63502L12.531 5.36708C13.1199 5.69165 13.8589 5.47779 14.1818 4.89861L15.7903 2.01315ZM17.0526 3.85624L15.9287 5.87243C15.067 7.41803 13.1136 7.97187 11.5656 7.11864L5.66611 3.86698L3.9591 6.92911L9.85005 10.1761C13.11 11.9729 17.2146 10.7994 19.018 7.56424L20.1373 5.55645L17.0526 3.85624ZM14.15 13.8239C10.89 12.0271 6.78543 13.2006 4.98197 16.4357L3.86271 18.4435L6.94764 20.1439L8.07157 18.1277C8.93317 16.5821 10.8866 16.0283 12.4346 16.8815L18.3339 20.133L20.0409 17.0709L14.15 13.8239Z" fill="#000000"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
@@ -1,6 +1,7 @@
|
||||
import React, { useRef } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
|
||||
@@ -149,7 +150,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
|
||||
{isConnected ? (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium text-[#212121]">
|
||||
<div className="text-eerie-black flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium">
|
||||
<div className="flex max-w-[500px] items-center gap-2">
|
||||
<svg className="h-4 w-4" viewBox="0 0 24 24">
|
||||
<path
|
||||
@@ -166,7 +167,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
{onDisconnect && (
|
||||
<button
|
||||
onClick={onDisconnect}
|
||||
className="text-xs font-medium text-[#212121] underline hover:text-gray-700"
|
||||
className="text-eerie-black text-xs font-medium underline hover:text-gray-700"
|
||||
>
|
||||
{t('modals.uploadDoc.connectors.auth.disconnect')}
|
||||
</button>
|
||||
|
||||
@@ -60,6 +60,10 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
displayName: 'SharePoint',
|
||||
rootName: 'My Files',
|
||||
},
|
||||
confluence: {
|
||||
displayName: 'Confluence',
|
||||
rootName: 'Spaces',
|
||||
},
|
||||
} as const;
|
||||
|
||||
const getProviderConfig = (provider: string) => {
|
||||
@@ -202,7 +206,9 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
if (!validateResponse.ok) {
|
||||
removeSessionToken(provider);
|
||||
setIsConnected(false);
|
||||
setAuthError('Session expired. Please reconnect to Google Drive.');
|
||||
setAuthError(
|
||||
`Session expired. Please reconnect to ${getProviderConfig(provider).displayName}.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -398,6 +404,7 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
|
||||
<ConnectorAuth
|
||||
provider={provider}
|
||||
label={`Connect to ${getProviderConfig(provider).displayName}`}
|
||||
onSuccess={(data) => {
|
||||
setUserEmail(data.user_email || 'Connected User');
|
||||
setIsConnected(true);
|
||||
|
||||
@@ -263,8 +263,8 @@ const MermaidRenderer: React.FC<MermaidRendererProps> = ({
|
||||
const errorRender = !isCurrentlyLoading && error;
|
||||
|
||||
return (
|
||||
<div className="w-inherit group border-border bg-card relative rounded-lg border">
|
||||
<div className="bg-platinum flex items-center justify-between px-2 py-1">
|
||||
<div className="w-inherit group border-border bg-card relative overflow-hidden rounded-[14px] border">
|
||||
<div className="bg-platinum dark:bg-muted flex items-center justify-between px-2 py-1">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
mermaid
|
||||
</span>
|
||||
@@ -401,7 +401,7 @@ const MermaidRenderer: React.FC<MermaidRendererProps> = ({
|
||||
|
||||
{showCode && (
|
||||
<div className="border-border border-t">
|
||||
<div className="bg-platinum p-2">
|
||||
<div className="bg-platinum dark:bg-muted p-2">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
Mermaid Code
|
||||
</span>
|
||||
|
||||
@@ -1296,9 +1296,8 @@ export default function MessageInput({
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoFocus) inputRef.current?.focus();
|
||||
handleInput();
|
||||
}, [autoFocus, handleInput]);
|
||||
}, [handleInput]);
|
||||
|
||||
const handleChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setValue(e.target.value);
|
||||
@@ -1364,8 +1363,9 @@ export default function MessageInput({
|
||||
) {
|
||||
onSubmit(value);
|
||||
setValue('');
|
||||
// Refocus input after submission if autoFocus is enabled
|
||||
if (autoFocus) {
|
||||
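// On touch devices, blur instead of refocusing so the on-screen keyboard is dismissed after sending.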
if (isTouch) {
|
||||
inputRef.current?.blur();
|
||||
} else if (autoFocus) {
|
||||
setTimeout(() => {
|
||||
if (isMountedRef.current) {
|
||||
inputRef.current?.focus();
|
||||
@@ -1544,6 +1544,7 @@ export default function MessageInput({
|
||||
id="message-input"
|
||||
ref={inputRef}
|
||||
value={value}
|
||||
autoFocus={autoFocus && !isTouch}
|
||||
onChange={handleChange}
|
||||
readOnly={
|
||||
recordingState === 'recording' ||
|
||||
|
||||
@@ -236,7 +236,7 @@ export default function Conversation() {
|
||||
isSplitArtifactOpen ? 'w-[60%] px-6' : 'w-full'
|
||||
}`}
|
||||
>
|
||||
<div className="min-h-0 flex-1">
|
||||
<div className="relative min-h-0 flex-1 ">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
@@ -255,6 +255,7 @@ export default function Conversation() {
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
<div className="from-background pointer-events-none absolute right-1.5 bottom-0 left-0 h-6 rounded-t-2xl bg-linear-to-t to-transparent" />
|
||||
</div>
|
||||
|
||||
<div
|
||||
|
||||
@@ -559,7 +559,7 @@ const ConversationBubble = forwardRef<
|
||||
|
||||
return match ? (
|
||||
<div className="group border-border relative overflow-hidden rounded-[14px] border">
|
||||
<div className="bg-platinum flex items-center justify-between px-2 py-1">
|
||||
<div className="bg-platinum dark:bg-muted flex items-center justify-between px-2 py-1">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
{language}
|
||||
</span>
|
||||
@@ -1204,7 +1204,7 @@ function Thought({
|
||||
|
||||
return match ? (
|
||||
<div className="group border-border relative overflow-hidden rounded-[14px] border">
|
||||
<div className="bg-platinum flex items-center justify-between px-2 py-1">
|
||||
<div className="bg-platinum dark:bg-muted flex items-center justify-between px-2 py-1">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
{language}
|
||||
</span>
|
||||
|
||||
@@ -62,40 +62,130 @@ export default function ConversationMessages({
|
||||
const { t } = useTranslation();
|
||||
|
||||
const conversationRef = useRef<HTMLDivElement>(null);
|
||||
const [hasScrolledToLast, setHasScrolledToLast] = useState(true);
|
||||
const [userInterruptedScroll, setUserInterruptedScroll] = useState(false);
|
||||
const [scrollButtonVisible, setScrollButtonVisible] = useState(false);
|
||||
const userInterruptedRef = useRef(false);
|
||||
const [interrupted, setInterrupted] = useState(false);
|
||||
const lastTouchYRef = useRef<number | null>(null);
|
||||
const isInitialLoad = useRef(true);
|
||||
const prevQueriesRef = useRef(queries);
|
||||
const isAutoScrollingRef = useRef(false);
|
||||
const smoothScrollTimeoutRef =
|
||||
useRef<ReturnType<typeof setTimeout>>(undefined);
|
||||
const showButtonTimerRef = useRef<ReturnType<typeof setTimeout>>(undefined);
|
||||
|
||||
const handleUserScrollInterruption = useCallback(() => {
|
||||
if (!userInterruptedScroll && status === 'loading') {
|
||||
setUserInterruptedScroll(true);
|
||||
}
|
||||
}, [userInterruptedScroll, status]);
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
clearTimeout(smoothScrollTimeoutRef.current);
|
||||
clearTimeout(showButtonTimerRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const scrollConversationToBottom = useCallback(() => {
|
||||
if (!conversationRef.current || userInterruptedScroll) return;
|
||||
const isAtBottom = useCallback(() => {
|
||||
const el = conversationRef.current;
|
||||
if (!el) return true;
|
||||
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
}, []);
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
if (!conversationRef?.current) return;
|
||||
// Arm on upward scroll intent; requiring !isAtBottom() missed small nudges still inside SCROLL_THRESHOLD.
|
||||
const markInterruptedIfLoading = useCallback(() => {
|
||||
if (userInterruptedRef.current || status !== 'loading') return;
|
||||
userInterruptedRef.current = true;
|
||||
setInterrupted(true);
|
||||
}, [status]);
|
||||
|
||||
if (status === 'idle' || !queries[queries.length - 1]?.response) {
|
||||
conversationRef.current.scrollTo({
|
||||
behavior: 'smooth',
|
||||
top: conversationRef.current.scrollHeight,
|
||||
});
|
||||
} else {
|
||||
conversationRef.current.scrollTop =
|
||||
conversationRef.current.scrollHeight;
|
||||
const handleWheel = useCallback(
|
||||
(e: React.WheelEvent) => {
|
||||
if (e.deltaY < 0) markInterruptedIfLoading();
|
||||
},
|
||||
[markInterruptedIfLoading],
|
||||
);
|
||||
|
||||
const handleTouchStart = useCallback((e: React.TouchEvent) => {
|
||||
lastTouchYRef.current = e.touches[0].clientY;
|
||||
}, []);
|
||||
|
||||
const handleTouchMove = useCallback(
|
||||
(e: React.TouchEvent) => {
|
||||
const y = e.touches[0].clientY;
|
||||
if (lastTouchYRef.current !== null && y > lastTouchYRef.current) {
|
||||
markInterruptedIfLoading();
|
||||
}
|
||||
});
|
||||
}, [userInterruptedScroll, status, queries]);
|
||||
lastTouchYRef.current = y;
|
||||
},
|
||||
[markInterruptedIfLoading],
|
||||
);
|
||||
|
||||
const checkScrollPosition = useCallback(() => {
|
||||
const setButtonHidden = useCallback(() => {
|
||||
clearTimeout(showButtonTimerRef.current);
|
||||
showButtonTimerRef.current = undefined;
|
||||
setScrollButtonVisible(false);
|
||||
}, []);
|
||||
|
||||
const setButtonVisibleDebounced = useCallback(() => {
|
||||
if (showButtonTimerRef.current) return;
|
||||
showButtonTimerRef.current = setTimeout(() => {
|
||||
setScrollButtonVisible(true);
|
||||
showButtonTimerRef.current = undefined;
|
||||
}, 300);
|
||||
}, []);
|
||||
|
||||
const scrollConversationToBottom = useCallback(
|
||||
(instant?: boolean) => {
|
||||
if (!conversationRef.current) return;
|
||||
|
||||
isAutoScrollingRef.current = true;
|
||||
clearTimeout(smoothScrollTimeoutRef.current);
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
if (!conversationRef?.current) return;
|
||||
|
||||
if (instant) {
|
||||
conversationRef.current.scrollTop =
|
||||
conversationRef.current.scrollHeight;
|
||||
if (isAtBottom()) {
|
||||
setButtonHidden();
|
||||
}
|
||||
isAutoScrollingRef.current = false;
|
||||
} else {
|
||||
conversationRef.current.scrollTo({
|
||||
behavior: 'smooth',
|
||||
top: conversationRef.current.scrollHeight,
|
||||
});
|
||||
smoothScrollTimeoutRef.current = setTimeout(() => {
|
||||
if (isAtBottom()) {
|
||||
setButtonHidden();
|
||||
}
|
||||
isAutoScrollingRef.current = false;
|
||||
}, 500);
|
||||
}
|
||||
});
|
||||
},
|
||||
[isAtBottom, setButtonHidden],
|
||||
);
|
||||
|
||||
const handleScroll = useCallback(() => {
|
||||
const el = conversationRef.current;
|
||||
if (!el) return;
|
||||
const isAtBottom =
|
||||
el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
setHasScrolledToLast(isAtBottom);
|
||||
}, [setHasScrolledToLast]);
|
||||
|
||||
const atBottom = isAtBottom();
|
||||
|
||||
if (atBottom && userInterruptedRef.current) {
|
||||
userInterruptedRef.current = false;
|
||||
setInterrupted(false);
|
||||
}
|
||||
|
||||
if (atBottom) {
|
||||
setButtonHidden();
|
||||
isAutoScrollingRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (isAutoScrollingRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
setButtonVisibleDebounced();
|
||||
}, [isAtBottom, setButtonHidden, setButtonVisibleDebounced]);
|
||||
|
||||
const lastQuery = queries[queries.length - 1];
|
||||
const lastQueryResponse = lastQuery?.response;
|
||||
@@ -103,34 +193,46 @@ export default function ConversationMessages({
|
||||
const lastQueryThought = lastQuery?.thought;
|
||||
|
||||
useEffect(() => {
|
||||
if (!userInterruptedScroll) {
|
||||
scrollConversationToBottom();
|
||||
if (interrupted) return;
|
||||
|
||||
const prevQueries = prevQueriesRef.current;
|
||||
const isConversationSwitch =
|
||||
prevQueries !== queries && prevQueries[0] !== queries[0];
|
||||
|
||||
if (isInitialLoad.current || isConversationSwitch) {
|
||||
isInitialLoad.current = false;
|
||||
scrollConversationToBottom(true);
|
||||
prevQueriesRef.current = queries;
|
||||
return;
|
||||
}
|
||||
|
||||
const isNewMessage = queries.length > prevQueries.length;
|
||||
prevQueriesRef.current = queries;
|
||||
|
||||
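// Smooth-scroll when a brand-new message is appended; jump instantly for streaming updates to the existing last answer.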
scrollConversationToBottom(isNewMessage ? false : true);
|
||||
}, [
|
||||
queries.length,
|
||||
lastQueryResponse,
|
||||
lastQueryError,
|
||||
lastQueryThought,
|
||||
userInterruptedScroll,
|
||||
interrupted,
|
||||
scrollConversationToBottom,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (status === 'idle') {
|
||||
setUserInterruptedScroll(false);
|
||||
userInterruptedRef.current = false;
|
||||
setInterrupted(false);
|
||||
}
|
||||
}, [status]);
|
||||
|
||||
useEffect(() => {
|
||||
const currentConversationRef = conversationRef.current;
|
||||
currentConversationRef?.addEventListener('scroll', checkScrollPosition);
|
||||
currentConversationRef?.addEventListener('scroll', handleScroll);
|
||||
return () => {
|
||||
currentConversationRef?.removeEventListener(
|
||||
'scroll',
|
||||
checkScrollPosition,
|
||||
);
|
||||
currentConversationRef?.removeEventListener('scroll', handleScroll);
|
||||
};
|
||||
}, [checkScrollPosition]);
|
||||
}, [handleScroll]);
|
||||
|
||||
const retryIconProps = {
|
||||
width: 12,
|
||||
@@ -208,7 +310,7 @@ export default function ConversationMessages({
|
||||
>
|
||||
<div className="flex max-w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<div className="flex h-[34px] w-[34px] items-center justify-center overflow-hidden rounded-full">
|
||||
<div className="flex h-8.5 w-8.5 items-center justify-center overflow-hidden rounded-full">
|
||||
<img
|
||||
src={DocsGPT3}
|
||||
alt={t('conversation.answer')}
|
||||
@@ -237,18 +339,24 @@ export default function ConversationMessages({
|
||||
return (
|
||||
<div
|
||||
ref={conversationRef}
|
||||
onWheel={handleUserScrollInterruption}
|
||||
onTouchMove={handleUserScrollInterruption}
|
||||
onWheel={handleWheel}
|
||||
onTouchStart={handleTouchStart}
|
||||
onTouchMove={handleTouchMove}
|
||||
className="flex h-full w-full justify-center overflow-y-auto will-change-scroll sm:pt-6 lg:pt-12"
|
||||
>
|
||||
{queries.length > 0 && !hasScrolledToLast && (
|
||||
{queries.length > 0 && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setUserInterruptedScroll(false);
|
||||
userInterruptedRef.current = false;
|
||||
setInterrupted(false);
|
||||
scrollConversationToBottom();
|
||||
}}
|
||||
aria-label={t('Scroll to bottom') || 'Scroll to bottom'}
|
||||
className="border-border bg-card fixed right-14 bottom-40 z-10 flex h-7 w-7 items-center justify-center rounded-full border md:h-9 md:w-9"
|
||||
className={`border-border bg-card fixed bottom-40 left-1/2 z-10 flex h-7 w-7 -translate-x-1/2 items-center justify-center rounded-full border transition-all duration-300 ease-in-out md:right-14 md:left-auto md:h-9 md:w-9 md:translate-x-0 ${
|
||||
scrollButtonVisible
|
||||
? 'pointer-events-auto scale-100 opacity-100'
|
||||
: 'pointer-events-none scale-75 opacity-0'
|
||||
}`}
|
||||
>
|
||||
<img
|
||||
src={ArrowDown}
|
||||
@@ -261,8 +369,8 @@ export default function ConversationMessages({
|
||||
<div
|
||||
className={
|
||||
isSplitView
|
||||
? 'w-full max-w-[1300px] px-2'
|
||||
: 'w-full max-w-[1300px] px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12'
|
||||
? 'w-full max-w-325 px-2'
|
||||
: 'w-full max-w-325 px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12'
|
||||
}
|
||||
>
|
||||
{headerContent}
|
||||
|
||||
@@ -325,6 +325,14 @@
|
||||
"s3": {
|
||||
"label": "Amazon S3",
|
||||
"heading": "Inhalt von Amazon S3 hinzufügen"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Von SharePoint hochladen"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Von Confluence hochladen"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -341,6 +341,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Upload from SharePoint"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Upload from Confluence"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Subir desde SharePoint"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Subir desde Confluence"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "SharePointからアップロード"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Confluenceからアップロード"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Загрузить из SharePoint"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Загрузить из Confluence"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "從SharePoint上傳"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "從Confluence上傳"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "从SharePoint上传"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "从Confluence上传"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -266,6 +266,23 @@ function Upload({
|
||||
initialSelectedFolders={selectedFolders}
|
||||
/>
|
||||
);
|
||||
case 'confluence_picker':
|
||||
return (
|
||||
<FilePicker
|
||||
key={field.name}
|
||||
onSelectionChange={(
|
||||
selectedFileIds: string[],
|
||||
selectedFolderIds: string[] = [],
|
||||
) => {
|
||||
setSelectedFiles(selectedFileIds);
|
||||
setSelectedFolders(selectedFolderIds);
|
||||
}}
|
||||
provider="confluence"
|
||||
token={token}
|
||||
initialSelectedFiles={selectedFiles}
|
||||
initialSelectedFolders={selectedFolders}
|
||||
/>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
@@ -551,6 +568,9 @@ function Upload({
|
||||
const hasSharePointPicker = schema.some(
|
||||
(field: FormField) => field.type === 'share_point_picker',
|
||||
);
|
||||
const hasConfluencePicker = schema.some(
|
||||
(field: FormField) => field.type === 'confluence_picker',
|
||||
);
|
||||
|
||||
let configData: Record<string, unknown> = { ...ingestor.config };
|
||||
|
||||
@@ -561,7 +581,8 @@ function Upload({
|
||||
} else if (
|
||||
hasRemoteFilePicker ||
|
||||
hasGoogleDrivePicker ||
|
||||
hasSharePointPicker
|
||||
hasSharePointPicker ||
|
||||
hasConfluencePicker
|
||||
) {
|
||||
const sessionToken = getSessionToken(ingestor.type as string);
|
||||
configData = {
|
||||
@@ -721,6 +742,9 @@ function Upload({
|
||||
const hasSharePointPicker = schema.some(
|
||||
(field: FormField) => field.type === 'share_point_picker',
|
||||
);
|
||||
const hasConfluencePicker = schema.some(
|
||||
(field: FormField) => field.type === 'confluence_picker',
|
||||
);
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
if (files.length === 0) {
|
||||
@@ -729,7 +753,8 @@ function Upload({
|
||||
} else if (
|
||||
hasRemoteFilePicker ||
|
||||
hasGoogleDrivePicker ||
|
||||
hasSharePointPicker
|
||||
hasSharePointPicker ||
|
||||
hasConfluencePicker
|
||||
) {
|
||||
if (selectedFiles.length === 0 && selectedFolders.length === 0) {
|
||||
return true;
|
||||
|
||||
@@ -6,8 +6,10 @@ import RedditIcon from '../../assets/reddit.svg';
|
||||
import DriveIcon from '../../assets/drive.svg';
|
||||
import S3Icon from '../../assets/s3.svg';
|
||||
import SharePoint from '../../assets/sharepoint.svg';
|
||||
import ConfluenceIcon from '../../assets/confluence.svg';
|
||||
|
||||
export type IngestorType =
|
||||
| 'confluence'
|
||||
| 'crawler'
|
||||
| 'github'
|
||||
| 'reddit'
|
||||
@@ -38,7 +40,8 @@ export type FieldType =
|
||||
| 'local_file_picker'
|
||||
| 'remote_file_picker'
|
||||
| 'google_drive_picker'
|
||||
| 'share_point_picker';
|
||||
| 'share_point_picker'
|
||||
| 'confluence_picker';
|
||||
|
||||
export interface FormField {
|
||||
name: string;
|
||||
@@ -214,6 +217,24 @@ export const IngestorFormSchemas: IngestorSchema[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'confluence',
|
||||
label: 'Confluence',
|
||||
icon: ConfluenceIcon,
|
||||
heading: 'Upload from Confluence',
|
||||
validate: () => {
|
||||
const confluenceClientId = import.meta.env.VITE_CONFLUENCE_CLIENT_ID;
|
||||
return !!confluenceClientId;
|
||||
},
|
||||
fields: [
|
||||
{
|
||||
name: 'files',
|
||||
label: 'Select Pages from Confluence',
|
||||
type: 'confluence_picker',
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export const IngestorDefaultConfigs: Record<
|
||||
@@ -261,6 +282,13 @@ export const IngestorDefaultConfigs: Record<
|
||||
recursive: true,
|
||||
},
|
||||
},
|
||||
confluence: {
|
||||
name: '',
|
||||
config: {
|
||||
file_ids: '',
|
||||
folder_ids: '',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export interface IngestorOption {
|
||||
|
||||
577
scripts/db/backfill.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""Backfill DocsGPT's Postgres user-data tables from MongoDB.
|
||||
|
||||
One script for every migrated collection. Adding a new collection is a
|
||||
two-step change in this file:
|
||||
|
||||
1. Write a ``_backfill_<name>`` function that takes keyword args
|
||||
``conn``, ``mongo_db``, ``batch_size``, ``dry_run`` and returns a
|
||||
stats ``dict``.
|
||||
2. Add a single entry to :data:`BACKFILLERS`.
|
||||
|
||||
There are intentionally no per-collection CLI flags or environment
|
||||
variables — ``USE_POSTGRES`` in ``.env`` is the only knob operators
|
||||
need during the migration window. This script discovers what's available from
|
||||
the :data:`BACKFILLERS` registry and runs whichever tables were asked for.
|
||||
|
||||
Usage::
|
||||
|
||||
python scripts/db/backfill.py # every registered table
|
||||
python scripts/db/backfill.py --tables users # only specific tables
|
||||
python scripts/db/backfill.py --dry-run # count without writing
|
||||
python scripts/db/backfill.py --batch 1000 # tune commit size
|
||||
|
||||
Exit codes:
|
||||
0 — every requested table completed successfully
|
||||
1 — misconfiguration (missing env var, unknown table name)
|
||||
2 — at least one table failed at runtime (others may still have succeeded)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
# Make the project root importable regardless of cwd.
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from sqlalchemy import Connection, text # noqa: E402
|
||||
|
||||
from application.core.mongo_db import MongoDB # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.engine import get_engine # noqa: E402
|
||||
|
||||
logger = logging.getLogger("backfill")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-table backfillers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _backfill_users(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``users`` table from Mongo ``users`` collection.
|
||||
|
||||
Overwrites each Postgres row's ``agent_preferences`` with the Mongo
|
||||
state (Mongo is source of truth during the cutover window). Missing
|
||||
``pinned`` / ``shared_with_me`` keys are filled with empty arrays so
|
||||
the Postgres row always has the full shape the application expects.
|
||||
"""
|
||||
upsert_sql = text(
|
||||
"""
|
||||
INSERT INTO users (user_id, agent_preferences)
|
||||
VALUES (:user_id, CAST(:prefs AS jsonb))
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET agent_preferences = EXCLUDED.agent_preferences,
|
||||
updated_at = now()
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["users"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
skipped = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
user_id = doc.get("user_id")
|
||||
if not user_id:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
raw_prefs = doc.get("agent_preferences") or {}
|
||||
prefs = {
|
||||
"pinned": list(raw_prefs.get("pinned") or []),
|
||||
"shared_with_me": list(raw_prefs.get("shared_with_me") or []),
|
||||
}
|
||||
batch.append({"user_id": user_id, "prefs": json.dumps(prefs)})
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(upsert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(upsert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written, "skipped_no_user_id": skipped}
|
||||
|
||||
|
||||
def _backfill_prompts(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``prompts`` table from Mongo ``prompts`` collection."""
|
||||
upsert_sql = text(
|
||||
"""
|
||||
INSERT INTO prompts (user_id, name, content)
|
||||
VALUES (:user_id, :name, :content)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["prompts"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
skipped = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
user_id = doc.get("user")
|
||||
if not user_id:
|
||||
skipped += 1
|
||||
continue
|
||||
batch.append({
|
||||
"user_id": user_id,
|
||||
"name": doc.get("name", ""),
|
||||
"content": doc.get("content", ""),
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(upsert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(upsert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written, "skipped_no_user": skipped}
|
||||
|
||||
|
||||
def _backfill_user_tools(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``user_tools`` table from Mongo ``user_tools`` collection."""
|
||||
insert_sql = text(
|
||||
"""
|
||||
INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
|
||||
VALUES (:user_id, :name, :custom_name, :display_name, CAST(:config AS jsonb))
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["user_tools"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
skipped = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
user_id = doc.get("user")
|
||||
if not user_id:
|
||||
skipped += 1
|
||||
continue
|
||||
batch.append({
|
||||
"user_id": user_id,
|
||||
"name": doc.get("name", ""),
|
||||
"custom_name": doc.get("customName"),
|
||||
"display_name": doc.get("displayName"),
|
||||
"config": json.dumps(doc.get("config") or {}),
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written, "skipped_no_user": skipped}
|
||||
|
||||
|
||||
def _backfill_feedback(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``feedback`` table from Mongo ``feedback`` collection.
|
||||
|
||||
feedback.conversation_id is stored as a string UUID. Rows whose
|
||||
conversation_id cannot be cast to UUID are skipped.
|
||||
"""
|
||||
insert_sql = text(
|
||||
"""
|
||||
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text, timestamp)
|
||||
VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text, :timestamp)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["feedback"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
skipped = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
user_id = doc.get("user")
|
||||
conv_id = doc.get("conversation_id")
|
||||
if not user_id or not conv_id:
|
||||
skipped += 1
|
||||
continue
|
||||
batch.append({
|
||||
"conversation_id": str(conv_id),
|
||||
"user_id": user_id,
|
||||
"question_index": doc.get("question_index", 0),
|
||||
"feedback_text": doc.get("feedback_text"),
|
||||
"timestamp": doc.get("timestamp"),
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written, "skipped": skipped}
|
||||
|
||||
|
||||
def _backfill_stack_logs(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``stack_logs`` table from Mongo ``stack_logs`` collection."""
|
||||
insert_sql = text(
|
||||
"""
|
||||
INSERT INTO stack_logs (activity_id, endpoint, level, user_id, api_key, query, stacks, timestamp)
|
||||
VALUES (:activity_id, :endpoint, :level, :user_id, :api_key, :query, CAST(:stacks AS jsonb), :timestamp)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["stack_logs"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
skipped = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
activity_id = doc.get("id")
|
||||
if not activity_id:
|
||||
skipped += 1
|
||||
continue
|
||||
batch.append({
|
||||
"activity_id": str(activity_id),
|
||||
"endpoint": doc.get("endpoint"),
|
||||
"level": doc.get("level"),
|
||||
"user_id": doc.get("user"),
|
||||
"api_key": doc.get("api_key"),
|
||||
"query": doc.get("query"),
|
||||
"stacks": json.dumps(doc.get("stacks") or []),
|
||||
"timestamp": doc.get("timestamp"),
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written, "skipped_no_id": skipped}
|
||||
|
||||
|
||||
def _backfill_user_logs(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``user_logs`` table from Mongo ``user_logs`` collection."""
|
||||
insert_sql = text(
|
||||
"""
|
||||
INSERT INTO user_logs (user_id, endpoint, data, timestamp)
|
||||
VALUES (:user_id, :endpoint, CAST(:data AS jsonb), :timestamp)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["user_logs"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
# Build a JSONB payload from the full doc (minus Mongo internals).
|
||||
data_payload = {k: v for k, v in doc.items() if k != "_id"}
|
||||
# Stringify ObjectId values inside the payload.
|
||||
for k, v in data_payload.items():
|
||||
if hasattr(v, "__str__") and type(v).__name__ == "ObjectId":
|
||||
data_payload[k] = str(v)
|
||||
batch.append({
|
||||
"user_id": doc.get("user"),
|
||||
"endpoint": doc.get("action") or doc.get("endpoint"),
|
||||
"data": json.dumps(data_payload, default=str),
|
||||
"timestamp": doc.get("timestamp"),
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written}
|
||||
|
||||
|
||||
def _backfill_token_usage(
|
||||
*,
|
||||
conn: Connection,
|
||||
mongo_db: Any,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
"""Sync the ``token_usage`` table from Mongo ``token_usage`` collection."""
|
||||
insert_sql = text(
|
||||
"""
|
||||
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
|
||||
VALUES (
|
||||
:user_id, :api_key,
|
||||
CAST(:agent_id AS uuid),
|
||||
:prompt_tokens, :generated_tokens, :timestamp
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor = (
|
||||
mongo_db["token_usage"]
|
||||
.find({}, no_cursor_timeout=True)
|
||||
.batch_size(batch_size)
|
||||
)
|
||||
|
||||
seen = 0
|
||||
written = 0
|
||||
batch: list[dict] = []
|
||||
|
||||
try:
|
||||
for doc in cursor:
|
||||
seen += 1
|
||||
agent_id = doc.get("agent_id")
|
||||
# agent_id may be an ObjectId string or None — only pass if
|
||||
# it looks like a valid UUID (from dual-write) or skip it.
|
||||
agent_id_str = None
|
||||
if agent_id:
|
||||
s = str(agent_id)
|
||||
if len(s) == 36 and "-" in s:
|
||||
agent_id_str = s
|
||||
batch.append({
|
||||
"user_id": doc.get("user_id"),
|
||||
"api_key": doc.get("api_key"),
|
||||
"agent_id": agent_id_str,
|
||||
"prompt_tokens": doc.get("prompt_tokens", 0),
|
||||
"generated_tokens": doc.get("generated_tokens", 0),
|
||||
"timestamp": doc.get("timestamp"),
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
if not dry_run:
|
||||
conn.execute(insert_sql, batch)
|
||||
written += len(batch)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return {"seen": seen, "written": written}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
BackfillFn = Callable[..., dict]
|
||||
|
||||
# Register new tables here. Order matters only in the sense that
|
||||
# ``--tables`` without arguments iterates in insertion order — put tables
|
||||
# with FK dependencies after the tables they reference so a full-run
|
||||
# backfill doesn't hit FK errors.
|
||||
BACKFILLERS: dict[str, BackfillFn] = {
|
||||
"users": _backfill_users,
|
||||
"prompts": _backfill_prompts,
|
||||
"user_tools": _backfill_user_tools,
|
||||
"feedback": _backfill_feedback,
|
||||
"stack_logs": _backfill_stack_logs,
|
||||
"user_logs": _backfill_user_logs,
|
||||
"token_usage": _backfill_token_usage,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Backfill DocsGPT Postgres tables from MongoDB."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tables",
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated table names to backfill. "
|
||||
f"Defaults to every registered table ({','.join(BACKFILLERS)})."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Iterate Mongo without writing to Postgres.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch",
|
||||
type=int,
|
||||
default=500,
|
||||
help="How many rows to commit per Postgres statement (default: 500).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-5s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
if not settings.POSTGRES_URI:
|
||||
logger.error("POSTGRES_URI is not set. Configure it in .env first.")
|
||||
return 1
|
||||
if not settings.MONGO_URI:
|
||||
logger.error("MONGO_URI is not set. Configure it in .env first.")
|
||||
return 1
|
||||
|
||||
requested = [t.strip() for t in args.tables.split(",") if t.strip()]
|
||||
if not requested:
|
||||
requested = list(BACKFILLERS)
|
||||
|
||||
unknown = [t for t in requested if t not in BACKFILLERS]
|
||||
if unknown:
|
||||
logger.error(
|
||||
"Unknown table(s): %s. Available: %s",
|
||||
", ".join(unknown),
|
||||
", ".join(BACKFILLERS),
|
||||
)
|
||||
return 1
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
mongo_db = mongo[settings.MONGO_DB_NAME]
|
||||
engine = get_engine()
|
||||
|
||||
failures = 0
|
||||
for table in requested:
|
||||
logger.info("backfill %s: start", table)
|
||||
try:
|
||||
with engine.begin() as conn:
|
||||
stats = BACKFILLERS[table](
|
||||
conn=conn,
|
||||
mongo_db=mongo_db,
|
||||
batch_size=args.batch,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
logger.info(
|
||||
"backfill %s: done %s dry_run=%s", table, stats, args.dry_run
|
||||
)
|
||||
except Exception:
|
||||
failures += 1
|
||||
logger.exception("backfill %s: failed", table)
|
||||
|
||||
return 2 if failures else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
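The two-step registration the module docstring describes can be illustrated with a short sketch. The collection/table name below (api_keys) is purely hypothetical and not part of this change; the body only mirrors the batching pattern the real backfillers above already use.

def _backfill_api_keys(*, conn, mongo_db, batch_size, dry_run) -> dict:
    # Hypothetical example only — "api_keys" is not a migrated collection here.
    insert_sql = text(
        """
        INSERT INTO api_keys (user_id, key_name)
        VALUES (:user_id, :key_name)
        ON CONFLICT DO NOTHING
        """
    )
    seen = written = 0
    batch: list[dict] = []
    cursor = mongo_db["api_keys"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            batch.append({"user_id": doc.get("user"), "key_name": doc.get("name", "")})
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written}

# Step 2: a single registry entry makes the new table reachable via --tables.
BACKFILLERS["api_keys"] = _backfill_api_keys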
|
||||
55
scripts/db/init_postgres.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""One-shot bootstrap: run all Alembic migrations against POSTGRES_URI.
|
||||
|
||||
Intended use:
|
||||
|
||||
* local dev, after setting ``POSTGRES_URI`` in ``.env``::
|
||||
|
||||
python scripts/db/init_postgres.py
|
||||
|
||||
* CI, as a step before running the pytest suite.
|
||||
|
||||
* Docker image build or container start, if the operator wants the
|
||||
migrations applied automatically on first boot.
|
||||
|
||||
This script is a thin wrapper around ``alembic upgrade head``. It exists
|
||||
separately so the same command is discoverable from the repo root without
|
||||
remembering the ``-c application/alembic.ini`` invocation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
ALEMBIC_INI = REPO_ROOT / "application" / "alembic.ini"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Apply every pending migration up to ``head``.
|
||||
|
||||
Returns:
|
||||
``0`` on success, ``1`` on failure. Non-zero is propagated as the
|
||||
process exit code so CI jobs fail loudly.
|
||||
"""
|
||||
if not ALEMBIC_INI.exists():
|
||||
print(f"alembic.ini not found at {ALEMBIC_INI}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
cfg = Config(str(ALEMBIC_INI))
|
||||
# Make `script_location` resolve correctly when invoked from any cwd.
|
||||
cfg.set_main_option("script_location", str(ALEMBIC_INI.parent / "alembic"))
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, "head")
|
||||
except Exception as exc: # noqa: BLE001 — surface everything to the operator
|
||||
print(f"alembic upgrade failed: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
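For reference, the invocation this wrapper saves operators from remembering (and the one the test fixtures elsewhere in this change run via subprocess) is:

    python -m alembic -c application/alembic.ini upgrade head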
|
||||
@@ -18,7 +18,7 @@ class TestPostgresExecuteAction:
|
||||
with pytest.raises(ValueError, match="Unknown action"):
|
||||
tool.execute_action("invalid")
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_select_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
@@ -37,7 +37,7 @@ class TestPostgresExecuteAction:
|
||||
assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"}
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_insert_query(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
@@ -55,11 +55,11 @@ class TestPostgresExecuteAction:
|
||||
mock_conn.commit.assert_called_once()
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
import psycopg
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("connection refused")
|
||||
mock_connect.side_effect = psycopg.Error("connection refused")
|
||||
|
||||
result = tool.execute_action(
|
||||
"postgres_execute_sql", sql_query="SELECT 1"
|
||||
@@ -68,7 +68,7 @@ class TestPostgresExecuteAction:
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_get_schema(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
@@ -89,24 +89,24 @@ class TestPostgresExecuteAction:
|
||||
assert result["schema"]["users"][0]["column_name"] == "id"
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_get_schema_db_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
import psycopg
|
||||
|
||||
mock_connect.side_effect = psycopg2.Error("auth failed")
|
||||
mock_connect.side_effect = psycopg.Error("auth failed")
|
||||
|
||||
result = tool.execute_action("postgres_get_schema", db_name="testdb")
|
||||
|
||||
assert result["status_code"] == 500
|
||||
assert "Database error" in result["error"]
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_connection_closed_on_error(self, mock_connect, tool):
|
||||
import psycopg2
|
||||
import psycopg
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = psycopg2.Error("syntax error")
|
||||
mock_cur.execute.side_effect = psycopg.Error("syntax error")
|
||||
mock_conn.cursor.return_value = mock_cur
|
||||
mock_connect.return_value = mock_conn
|
||||
|
||||
@@ -114,7 +114,7 @@ class TestPostgresExecuteAction:
|
||||
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
@patch("application.agents.tools.postgres.psycopg2.connect")
|
||||
@patch("application.agents.tools.postgres.psycopg.connect")
|
||||
def test_select_with_no_description(self, mock_connect, tool):
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
|
||||
144
tests/core/test_db_uri.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Tests for ``application.core.db_uri``.
|
||||
|
||||
DocsGPT has two Postgres connection strings — ``POSTGRES_URI`` (consumed
|
||||
by SQLAlchemy) and ``PGVECTOR_CONNECTION_STRING`` (consumed by
|
||||
``psycopg.connect()`` directly). They need opposite normalization
|
||||
because SQLAlchemy requires a ``postgresql+psycopg://`` dialect prefix
|
||||
and libpq rejects it. Each field has its own normalizer so operators
|
||||
can write whichever form feels natural, and accidentally copying one
field's value into the other is tolerated.
|
||||
|
||||
The normalizers live in ``application.core.db_uri`` as plain functions
|
||||
so these tests can exercise them directly without having to instantiate
|
||||
``Settings`` (which would pull in ``.env`` file side effects).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.db_uri import (
|
||||
normalize_pgvector_connection_string,
|
||||
normalize_postgres_uri,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNormalizePostgresUri:
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,expected",
|
||||
[
|
||||
# User-friendly forms get rewritten to the SQLAlchemy dialect.
|
||||
(
|
||||
"postgres://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
(
|
||||
"postgresql://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
# Legacy psycopg2 dialect is silently upgraded — psycopg2 is
|
||||
# no longer in requirements.txt, so there's no way it can work
|
||||
# as-is, and rewriting is friendlier than failing.
|
||||
(
|
||||
"postgresql+psycopg2://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
# Already-correct dialect passes through unchanged.
|
||||
(
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
),
|
||||
# Whitespace is trimmed before rewriting.
|
||||
(
|
||||
" postgres://u:p@h/d ",
|
||||
"postgresql+psycopg://u:p@h/d",
|
||||
),
|
||||
# Query-string params (sslmode, options) are preserved verbatim.
|
||||
(
|
||||
"postgresql://u:p@h:5432/d?sslmode=require&application_name=docsgpt",
|
||||
"postgresql+psycopg://u:p@h:5432/d?sslmode=require&application_name=docsgpt",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rewrites_common_forms_to_psycopg_dialect(self, input_value, expected):
|
||||
assert normalize_postgres_uri(input_value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value",
|
||||
[None, "", " ", "None", "none"],
|
||||
)
|
||||
def test_empty_or_none_like_returns_none(self, input_value):
|
||||
assert normalize_postgres_uri(input_value) is None
|
||||
|
||||
def test_unknown_scheme_passes_through(self):
|
||||
"""A dialect we don't recognise is left alone so SQLAlchemy can
|
||||
produce its own error message when the engine tries to connect.
|
||||
Better than silently eating the config."""
|
||||
weird = "postgresql+asyncpg://u:p@h/d"
|
||||
assert normalize_postgres_uri(weird) == weird
|
||||
|
||||
def test_non_string_input_passes_through(self):
|
||||
"""Non-string inputs (e.g. if pydantic ever passes an int) shouldn't
|
||||
crash the normalizer — let pydantic's own type validation handle it."""
|
||||
assert normalize_postgres_uri(42) == 42 # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestNormalizePgvectorConnectionString:
|
||||
"""Symmetric to the POSTGRES_URI normalizer but pulls in the OPPOSITE
|
||||
direction: strips the SQLAlchemy dialect prefix so libpq accepts it.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value,expected",
|
||||
[
|
||||
# User-friendly forms pass through — libpq accepts them natively.
|
||||
(
|
||||
"postgres://u:p@h:5432/d",
|
||||
"postgres://u:p@h:5432/d",
|
||||
),
|
||||
(
|
||||
"postgresql://u:p@h:5432/d",
|
||||
"postgresql://u:p@h:5432/d",
|
||||
),
|
||||
# SQLAlchemy dialect prefixes get stripped so libpq accepts them.
|
||||
# Operators hit this when they copy POSTGRES_URI → PGVECTOR_CONNECTION_STRING.
|
||||
(
|
||||
"postgresql+psycopg://u:p@h:5432/d",
|
||||
"postgresql://u:p@h:5432/d",
|
||||
),
|
||||
(
|
||||
"postgresql+psycopg2://u:p@h:5432/d",
|
||||
"postgresql://u:p@h:5432/d",
|
||||
),
|
||||
# Whitespace is trimmed before rewriting.
|
||||
(
|
||||
" postgresql+psycopg://u:p@h/d ",
|
||||
"postgresql://u:p@h/d",
|
||||
),
|
||||
# Query-string params (sslmode, etc.) are preserved verbatim.
|
||||
(
|
||||
"postgresql+psycopg://u:p@h:5432/d?sslmode=require",
|
||||
"postgresql://u:p@h:5432/d?sslmode=require",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rewrites_dialect_forms_to_libpq_compatible(self, input_value, expected):
|
||||
assert normalize_pgvector_connection_string(input_value) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value",
|
||||
[None, "", " ", "None", "none"],
|
||||
)
|
||||
def test_empty_or_none_like_returns_none(self, input_value):
|
||||
assert normalize_pgvector_connection_string(input_value) is None
|
||||
|
||||
def test_unknown_scheme_passes_through(self):
|
||||
"""A scheme we don't recognise is left alone so libpq can produce
|
||||
its own error message when the connection is attempted."""
|
||||
weird = "mysql://u:p@h/d"
|
||||
assert normalize_pgvector_connection_string(weird) == weird
|
||||
|
||||
def test_non_string_input_passes_through(self):
|
||||
assert normalize_pgvector_connection_string(42) == 42 # type: ignore[arg-type]
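The application.core.db_uri module itself is not shown in this change set; a minimal sketch of normalize_postgres_uri that satisfies the contract these tests pin down (an illustration, not the shipped implementation) could be:

def normalize_postgres_uri(value):
    # Sketch inferred from the tests above — not the actual module code.
    if not isinstance(value, str):
        return value  # non-strings pass through; pydantic handles type errors
    trimmed = value.strip()
    if not trimmed or trimmed.lower() == "none":
        return None  # empty / "None"-like values mean "unset"
    for prefix in ("postgresql+psycopg2://", "postgresql://", "postgres://"):
        if trimmed.startswith(prefix):
            return "postgresql+psycopg://" + trimmed[len(prefix):]
    return trimmed  # unknown dialects are left for SQLAlchemy to report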
|
||||
84
tests/integration/conftest.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Fixtures for integration tests that hit a live Postgres.
|
||||
|
||||
These tests are separate from unit tests in two ways:
|
||||
|
||||
1. **Directory**: They live under ``tests/integration/``, which is
|
||||
ignored by the default pytest run via ``--ignore=tests/integration``
|
||||
in ``pytest.ini``. The CI ``pytest.yml`` workflow therefore skips
|
||||
them automatically — it runs the fast unit suite only.
|
||||
2. **Marker**: Each test is marked ``@pytest.mark.integration`` so it
|
||||
can be selected (or excluded) independently of its directory with
|
||||
``pytest -m integration`` / ``-m "not integration"``.
|
||||
|
||||
Running the Postgres-backed integration tests manually::
|
||||
|
||||
.venv/bin/python -m pytest tests/integration/test_users_repository.py \\
|
||||
--override-ini="addopts=" --no-cov
|
||||
|
||||
(The ``--override-ini="addopts="`` is needed because ``pytest.ini``
|
||||
contains ``--ignore=tests/integration`` in ``addopts``; without the
|
||||
override pytest would still skip the directory even though you named
|
||||
it on the command line.)
|
||||
|
||||
Tests are skipped automatically if ``POSTGRES_URI`` is unset, so a
|
||||
contributor who hasn't set up a local Postgres gets clean skips
|
||||
instead of red tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, create_engine, text
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pg_engine() -> Engine:
|
||||
"""Session-scoped SQLAlchemy engine for the Postgres integration DB.
|
||||
|
||||
Skips all Postgres-backed tests if ``POSTGRES_URI`` is unset. This
|
||||
keeps CI and contributor machines without a local Postgres from
|
||||
erroring out — integration tests that require the DB become
|
||||
no-ops rather than failures.
|
||||
"""
|
||||
if not settings.POSTGRES_URI:
|
||||
pytest.skip("POSTGRES_URI not set — skipping Postgres integration tests")
|
||||
engine = create_engine(settings.POSTGRES_URI, future=True, pool_pre_ping=True)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_conn(pg_engine: Engine):
|
||||
"""Per-test Postgres connection wrapped in a rolled-back transaction.
|
||||
|
||||
Uses SQLAlchemy's explicit outer-transaction pattern so every test
|
||||
sees a pristine DB view without having to truncate tables. Any
|
||||
nested ``begin()`` inside the repository code becomes a SAVEPOINT
|
||||
under the hood.
|
||||
"""
|
||||
conn = pg_engine.connect()
|
||||
outer = conn.begin()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
outer.rollback()
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_clean_users(pg_conn):
|
||||
"""Guarantee a clean ``users`` table view for tests that need it.
|
||||
|
||||
The outer transaction rollback handles cleanup, but if a previous
|
||||
interrupted run left rows committed, this fixture removes them
|
||||
inside the transaction scope so they are invisible to the test.
|
||||
``DELETE`` rather than ``TRUNCATE`` because ``TRUNCATE`` in Postgres
|
||||
cannot be rolled back within a nested transaction the way
|
||||
``DELETE`` can.
|
||||
"""
|
||||
pg_conn.execute(text("DELETE FROM users"))
|
||||
return pg_conn
|
||||
179
tests/integration/test_users_repository.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Integration tests for ``UsersRepository`` against a live Postgres.
|
||||
|
||||
These tests need:
|
||||
|
||||
* A running Postgres reachable via ``POSTGRES_URI``
|
||||
* Alembic migration ``0001_initial`` applied
|
||||
|
||||
They are skipped automatically by the default ``pytest`` run because
|
||||
``pytest.ini`` has ``--ignore=tests/integration`` in ``addopts``. They
|
||||
are additionally marked ``@pytest.mark.integration`` so they can be
|
||||
selected/excluded explicitly. Run them locally with::
|
||||
|
||||
.venv/bin/python -m pytest tests/integration/test_users_repository.py \\
|
||||
--override-ini="addopts=" --no-cov
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``users_collection``: upsert + get + add/remove on the pinned and
|
||||
shared_with_me JSONB arrays, plus explicit tenant-isolation checks
|
||||
(a user's operations must never touch another user's data).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo(pg_clean_users):
|
||||
return UsersRepository(pg_clean_users)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestUpsert:
|
||||
def test_upsert_creates_new_user_with_default_preferences(self, repo):
|
||||
doc = repo.upsert("alice@example.com")
|
||||
assert doc["user_id"] == "alice@example.com"
|
||||
assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}
|
||||
assert "id" in doc
|
||||
assert "_id" in doc # Mongo-compat alias
|
||||
|
||||
def test_upsert_is_idempotent(self, repo):
|
||||
first = repo.upsert("alice@example.com")
|
||||
second = repo.upsert("alice@example.com")
|
||||
assert first["id"] == second["id"]
|
||||
assert first["agent_preferences"] == second["agent_preferences"]
|
||||
|
||||
def test_upsert_preserves_existing_preferences(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
doc = repo.upsert("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestGet:
|
||||
def test_get_returns_none_for_missing_user(self, repo):
|
||||
assert repo.get("nobody@example.com") is None
|
||||
|
||||
def test_get_returns_dict_for_existing_user(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc is not None
|
||||
assert doc["user_id"] == "alice@example.com"
|
||||
assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestAddPinned:
|
||||
def test_add_pinned_to_new_user_creates_user(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
|
||||
|
||||
def test_add_pinned_is_idempotent(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
|
||||
|
||||
def test_add_pinned_preserves_order(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
repo.add_pinned("alice@example.com", "agent-2")
|
||||
repo.add_pinned("alice@example.com", "agent-3")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-1", "agent-2", "agent-3"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRemovePinned:
|
||||
def test_remove_pinned_single(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
repo.add_pinned("alice@example.com", "agent-2")
|
||||
repo.remove_pinned("alice@example.com", "agent-1")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-2"]
|
||||
|
||||
def test_remove_pinned_missing_is_noop(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-1")
|
||||
repo.remove_pinned("alice@example.com", "agent-999")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-1"]
|
||||
|
||||
def test_remove_pinned_bulk(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
for i in range(5):
|
||||
repo.add_pinned("alice@example.com", f"agent-{i}")
|
||||
repo.remove_pinned_bulk("alice@example.com", ["agent-1", "agent-3", "agent-999"])
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-0", "agent-2", "agent-4"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestSharedWithMe:
|
||||
def test_add_shared_is_idempotent(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_shared("alice@example.com", "agent-x")
|
||||
repo.add_shared("alice@example.com", "agent-x")
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["shared_with_me"] == ["agent-x"]
|
||||
|
||||
def test_remove_shared_bulk(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
for i in range(3):
|
||||
repo.add_shared("alice@example.com", f"shared-{i}")
|
||||
repo.remove_shared_bulk("alice@example.com", ["shared-0", "shared-2"])
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["shared_with_me"] == ["shared-1"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRemoveAgentFromAll:
|
||||
def test_removes_from_both_pinned_and_shared(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-X")
|
||||
repo.add_pinned("alice@example.com", "agent-keep")
|
||||
repo.add_shared("alice@example.com", "agent-X")
|
||||
repo.add_shared("alice@example.com", "agent-keep-2")
|
||||
|
||||
repo.remove_agent_from_all("alice@example.com", "agent-X")
|
||||
|
||||
doc = repo.get("alice@example.com")
|
||||
assert doc["agent_preferences"]["pinned"] == ["agent-keep"]
|
||||
assert doc["agent_preferences"]["shared_with_me"] == ["agent-keep-2"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTenantIsolation:
|
||||
"""Security-critical: operations on one user must never touch another's data."""
|
||||
|
||||
def test_add_pinned_does_not_leak_across_users(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.upsert("bob@example.com")
|
||||
repo.add_pinned("alice@example.com", "agent-a")
|
||||
repo.add_pinned("bob@example.com", "agent-b")
|
||||
|
||||
alice = repo.get("alice@example.com")
|
||||
bob = repo.get("bob@example.com")
|
||||
assert alice["agent_preferences"]["pinned"] == ["agent-a"]
|
||||
assert bob["agent_preferences"]["pinned"] == ["agent-b"]
|
||||
|
||||
def test_remove_does_not_leak_across_users(self, repo):
|
||||
repo.upsert("alice@example.com")
|
||||
repo.upsert("bob@example.com")
|
||||
repo.add_pinned("alice@example.com", "shared-agent-id")
|
||||
repo.add_pinned("bob@example.com", "shared-agent-id")
|
||||
|
||||
repo.remove_pinned("alice@example.com", "shared-agent-id")
|
||||
|
||||
assert repo.get("alice@example.com")["agent_preferences"]["pinned"] == []
|
||||
assert repo.get("bob@example.com")["agent_preferences"]["pinned"] == [
|
||||
"shared-agent-id"
|
||||
]
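UsersRepository is not part of the excerpt above. One way to meet the idempotency, ordering, and tenant-isolation requirements these tests encode — a sketch under that assumption, not necessarily the shipped SQL — is a guarded JSONB append keyed on the user:

ADD_PINNED_SQL = text(
    """
    UPDATE users
    SET agent_preferences = jsonb_set(
            agent_preferences,
            '{pinned}',
            (agent_preferences -> 'pinned') || to_jsonb(CAST(:agent_id AS text))
        ),
        updated_at = now()
    WHERE user_id = :user_id
      AND NOT jsonb_exists(agent_preferences -> 'pinned', :agent_id)
    """
)
# The jsonb_exists guard makes repeated pins no-ops, appending preserves
# insertion order, and the user_id predicate scopes the write to one tenant.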
|
||||
@@ -20,133 +20,49 @@ def test_epub_init_parser():
|
||||
assert parser.parser_config_set
|
||||
|
||||
|
||||
def test_epub_parser_ebooklib_import_error(epub_parser):
|
||||
"""Test that ImportError is raised when ebooklib is not available."""
|
||||
with patch.dict(sys.modules, {"ebooklib": None}):
|
||||
with pytest.raises(ValueError, match="`EbookLib` is required to read Epub files"):
|
||||
def test_epub_parser_fast_ebook_import_error(epub_parser):
|
||||
"""Test that ImportError is raised when fast-ebook is not available."""
|
||||
with patch.dict(sys.modules, {"fast_ebook": None}):
|
||||
with pytest.raises(ValueError, match="`fast-ebook` is required to read Epub files"):
|
||||
epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
|
||||
def test_epub_parser_html2text_import_error(epub_parser):
|
||||
"""Test that ImportError is raised when html2text is not available."""
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
with patch.dict(sys.modules, {"ebooklib": fake_ebooklib, "ebooklib.epub": fake_epub}):
|
||||
with patch.dict(sys.modules, {"html2text": None}):
|
||||
with pytest.raises(ValueError, match="`html2text` is required to parse Epub files"):
|
||||
epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
|
||||
def test_epub_parser_successful_parsing(epub_parser):
|
||||
"""Test successful parsing of an epub file."""
|
||||
fake_fast_ebook = types.ModuleType("fast_ebook")
|
||||
fake_epub = types.ModuleType("fast_ebook.epub")
|
||||
fake_fast_ebook.epub = fake_epub
|
||||
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_html2text = types.ModuleType("html2text")
|
||||
|
||||
# Mock ebooklib constants
|
||||
fake_ebooklib.ITEM_DOCUMENT = "document"
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
mock_item1 = MagicMock()
|
||||
mock_item1.get_type.return_value = "document"
|
||||
mock_item1.get_content.return_value = b"<h1>Chapter 1</h1><p>Content 1</p>"
|
||||
|
||||
mock_item2 = MagicMock()
|
||||
mock_item2.get_type.return_value = "document"
|
||||
mock_item2.get_content.return_value = b"<h1>Chapter 2</h1><p>Content 2</p>"
|
||||
|
||||
mock_item3 = MagicMock()
|
||||
mock_item3.get_type.return_value = "other" # Should be ignored
|
||||
mock_item3.get_content.return_value = b"<p>Other content</p>"
|
||||
|
||||
mock_book = MagicMock()
|
||||
mock_book.get_items.return_value = [mock_item1, mock_item2, mock_item3]
|
||||
|
||||
mock_book.to_markdown.return_value = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
|
||||
|
||||
fake_epub.read_epub = MagicMock(return_value=mock_book)
|
||||
|
||||
def mock_html2text_func(html_content):
|
||||
if "Chapter 1" in html_content:
|
||||
return "# Chapter 1\n\nContent 1\n"
|
||||
elif "Chapter 2" in html_content:
|
||||
return "# Chapter 2\n\nContent 2\n"
|
||||
return "Other content\n"
|
||||
|
||||
fake_html2text.html2text = mock_html2text_func
|
||||
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"ebooklib": fake_ebooklib,
|
||||
"ebooklib.epub": fake_epub,
|
||||
"html2text": fake_html2text
|
||||
"fast_ebook": fake_fast_ebook,
|
||||
"fast_ebook.epub": fake_epub,
|
||||
}):
|
||||
result = epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
expected_result = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
|
||||
assert result == expected_result
|
||||
|
||||
# Verify epub.read_epub was called with correct parameters
|
||||
fake_epub.read_epub.assert_called_once_with(Path("test.epub"), options={"ignore_ncx": True})
|
||||
|
||||
assert result == "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
|
||||
fake_epub.read_epub.assert_called_once_with(Path("test.epub"))
|
||||
|
||||
|
||||
def test_epub_parser_empty_book(epub_parser):
|
||||
"""Test parsing an epub file with no document items."""
|
||||
# Create mock modules
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_html2text = types.ModuleType("html2text")
|
||||
|
||||
fake_ebooklib.ITEM_DOCUMENT = "document"
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
# Create mock book with no document items
|
||||
"""Test parsing an epub file with no content."""
|
||||
fake_fast_ebook = types.ModuleType("fast_ebook")
|
||||
fake_epub = types.ModuleType("fast_ebook.epub")
|
||||
fake_fast_ebook.epub = fake_epub
|
||||
|
||||
mock_book = MagicMock()
|
||||
mock_book.get_items.return_value = []
|
||||
|
||||
mock_book.to_markdown.return_value = ""
|
||||
|
||||
fake_epub.read_epub = MagicMock(return_value=mock_book)
|
||||
fake_html2text.html2text = MagicMock()
|
||||
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"ebooklib": fake_ebooklib,
|
||||
"ebooklib.epub": fake_epub,
|
||||
"html2text": fake_html2text
|
||||
"fast_ebook": fake_fast_ebook,
|
||||
"fast_ebook.epub": fake_epub,
|
||||
}):
|
||||
result = epub_parser.parse_file(Path("empty.epub"))
|
||||
|
||||
assert result == ""
|
||||
|
||||
fake_html2text.html2text.assert_not_called()
|
||||
|
||||
|
||||
def test_epub_parser_non_document_items_ignored(epub_parser):
|
||||
"""Test that non-document items are ignored during parsing."""
|
||||
fake_ebooklib = types.ModuleType("ebooklib")
|
||||
fake_epub = types.ModuleType("ebooklib.epub")
|
||||
fake_html2text = types.ModuleType("html2text")
|
||||
|
||||
fake_ebooklib.ITEM_DOCUMENT = "document"
|
||||
fake_ebooklib.epub = fake_epub
|
||||
|
||||
mock_doc_item = MagicMock()
|
||||
mock_doc_item.get_type.return_value = "document"
|
||||
mock_doc_item.get_content.return_value = b"<p>Document content</p>"
|
||||
|
||||
mock_other_item = MagicMock()
|
||||
mock_other_item.get_type.return_value = "image" # Not a document
|
||||
|
||||
mock_book = MagicMock()
|
||||
mock_book.get_items.return_value = [mock_other_item, mock_doc_item]
|
||||
|
||||
fake_epub.read_epub = MagicMock(return_value=mock_book)
|
||||
fake_html2text.html2text = MagicMock(return_value="Document content\n")
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"ebooklib": fake_ebooklib,
|
||||
"ebooklib.epub": fake_epub,
|
||||
"html2text": fake_html2text
|
||||
}):
|
||||
result = epub_parser.parse_file(Path("test.epub"))
|
||||
|
||||
assert result == "Document content\n"
|
||||
|
||||
fake_html2text.html2text.assert_called_once_with("<p>Document content</p>")
|
||||
|
||||
@@ -31,6 +31,7 @@ class TestGitHubLoaderFetchFileContent:
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.github.com/repos/owner/repo/contents/README.md",
|
||||
headers=loader.headers,
|
||||
timeout=100,
|
||||
)
|
||||
|
||||
@patch("application.parser.remote.github_loader.requests.get")
|
||||
@@ -66,7 +67,7 @@ class TestGitHubLoaderFetchRepoFiles:
|
||||
def test_recurses_directories(self, mock_get):
|
||||
loader = GitHubLoader()
|
||||
|
||||
def side_effect(url, headers=None):
|
||||
def side_effect(url, headers=None, timeout=None):
|
||||
if url.endswith("/contents/"):
|
||||
return make_response([
|
||||
{"type": "file", "path": "README.md"},
|
||||
|
||||
0
tests/storage/db/__init__.py
Normal file
66
tests/storage/db/conftest.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Fixtures for repository tests against a real Postgres instance.
|
||||
|
||||
These tests hit the local dev Postgres (the DBngin instance on this machine,
|
||||
or CI's service container). Each test runs inside a transaction that is
|
||||
rolled back at the end, so tests never leak state into each other and the
|
||||
database stays clean without needing per-test CREATE/DROP overhead.
|
||||
|
||||
Required env:
|
||||
POSTGRES_URI — e.g. postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
|
||||
Tests are skipped automatically when POSTGRES_URI is unset so that
|
||||
contributors without a local Postgres can still run the rest of the suite.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def _run_alembic_upgrade(engine):
|
||||
"""Run ``alembic upgrade head`` to ensure the full schema is present.
|
||||
|
||||
Failures are swallowed: if the upgrade cannot run, the tables are
assumed to already exist from a prior run.
|
||||
"""
|
||||
alembic_ini = Path(__file__).resolve().parents[3] / "application" / "alembic.ini"
|
||||
try:
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||
timeout=30,
|
||||
)
|
||||
except Exception:
|
||||
# Alembic failed — tables likely already exist from a prior run.
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pg_engine():
|
||||
"""Session-scoped engine pointing at the test Postgres."""
|
||||
if not settings.POSTGRES_URI:
|
||||
pytest.skip("POSTGRES_URI not set")
|
||||
engine = create_engine(settings.POSTGRES_URI)
|
||||
_run_alembic_upgrade(engine)
|
||||
yield engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def pg_conn(pg_engine):
|
||||
"""Per-test connection wrapped in a transaction that always rolls back.
|
||||
|
||||
Repositories receive this connection and operate normally. At teardown
|
||||
the outer transaction is rolled back so no data persists between tests.
|
||||
"""
|
||||
conn = pg_engine.connect()
|
||||
txn = conn.begin()
|
||||
yield conn
|
||||
txn.rollback()
|
||||
conn.close()
|
||||
0
tests/storage/db/repositories/__init__.py
Normal file
79
tests/storage/db/repositories/test_feedback.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for FeedbackRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.feedback import FeedbackRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> FeedbackRepository:
|
||||
return FeedbackRepository(conn)
|
||||
|
||||
|
||||
def _make_conversation_id(pg_conn) -> str:
|
||||
"""Insert a minimal conversations row and return its id as string.
|
||||
|
||||
feedback has a FK to conversations (added in Tier 3 migration), but
|
||||
the FK constraint may not exist yet during early phases. We create a
|
||||
row anyway to keep tests realistic.
|
||||
"""
|
||||
cid = str(uuid.uuid4())
|
||||
# Only insert if the conversations table exists; otherwise use a random UUID.
|
||||
row = pg_conn.execute(
|
||||
text(
|
||||
"SELECT 1 FROM information_schema.tables "
|
||||
"WHERE table_schema='public' AND table_name='conversations'"
|
||||
)
|
||||
).scalar()
|
||||
if row:
|
||||
pg_conn.execute(
|
||||
text("INSERT INTO conversations (id, user_id) VALUES (CAST(:id AS uuid), 'test')"),
|
||||
{"id": cid},
|
||||
)
|
||||
return cid
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_feedback(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
cid = _make_conversation_id(pg_conn)
|
||||
doc = repo.create(cid, "user-1", 0, "great answer")
|
||||
assert doc["conversation_id"] is not None
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["question_index"] == 0
|
||||
assert doc["feedback_text"] == "great answer"
|
||||
|
||||
def test_allows_null_feedback_text(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
cid = _make_conversation_id(pg_conn)
|
||||
doc = repo.create(cid, "user-1", 1)
|
||||
assert doc["feedback_text"] is None
|
||||
|
||||
|
||||
class TestListForConversation:
|
||||
def test_lists_feedback_for_conversation(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
cid = _make_conversation_id(pg_conn)
|
||||
repo.create(cid, "user-1", 0, "good")
|
||||
repo.create(cid, "user-1", 1, "bad")
|
||||
results = repo.list_for_conversation(cid)
|
||||
assert len(results) == 2
|
||||
assert results[0]["question_index"] == 0
|
||||
assert results[1]["question_index"] == 1
|
||||
|
||||
def test_does_not_mix_conversations(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
cid1 = _make_conversation_id(pg_conn)
|
||||
cid2 = _make_conversation_id(pg_conn)
|
||||
repo.create(cid1, "user-1", 0, "a")
|
||||
repo.create(cid2, "user-1", 0, "b")
|
||||
assert len(repo.list_for_conversation(cid1)) == 1
|
||||
115
tests/storage/db/repositories/test_prompts.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Tests for PromptsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> PromptsRepository:
|
||||
return PromptsRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_prompt(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "greeting", "Hello {{name}}")
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["name"] == "greeting"
|
||||
assert doc["content"] == "Hello {{name}}"
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_create_returns_id_and_underscore_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "p", "c")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
|
||||
class TestGet:
|
||||
def test_get_by_id_and_user(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
fetched = repo.get(created["id"], user_id="user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_get_by_id_only(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
fetched = repo.get(created["id"])
|
||||
assert fetched is not None
|
||||
|
||||
def test_get_wrong_user_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
assert repo.get(created["id"], user_id="user-other") is None
|
||||
|
||||
def test_get_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get("00000000-0000-0000-0000-000000000000") is None
|
||||
|
||||
|
||||
class TestListForUser:
|
||||
def test_lists_only_own_prompts(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create("alice", "a1", "c1")
|
||||
repo.create("alice", "a2", "c2")
|
||||
repo.create("bob", "b1", "c3")
|
||||
results = repo.list_for_user("alice")
|
||||
assert len(results) == 2
|
||||
assert all(r["user_id"] == "alice" for r in results)
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
def test_updates_name_and_content(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "old-content")
|
||||
repo.update(created["id"], "user-1", "new", "new-content")
|
||||
fetched = repo.get(created["id"])
|
||||
assert fetched["name"] == "new"
|
||||
assert fetched["content"] == "new-content"
|
||||
|
||||
def test_update_wrong_user_is_noop(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "old-content")
|
||||
repo.update(created["id"], "user-other", "new", "new-content")
|
||||
fetched = repo.get(created["id"])
|
||||
assert fetched["name"] == "old"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_deletes_prompt(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
repo.delete(created["id"], "user-1")
|
||||
assert repo.get(created["id"]) is None
|
||||
|
||||
def test_delete_wrong_user_is_noop(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
repo.delete(created["id"], "user-other")
|
||||
assert repo.get(created["id"]) is not None
|
||||
|
||||
|
||||
class TestFindOrCreate:
|
||||
def test_creates_when_missing(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.find_or_create("sys", "template", "content")
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_returns_existing_on_match(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
first = repo.find_or_create("sys", "template", "content")
|
||||
second = repo.find_or_create("sys", "template", "content")
|
||||
assert first["id"] == second["id"]
|
||||
|
||||
def test_different_content_creates_new(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
first = repo.find_or_create("sys", "template", "v1")
|
||||
second = repo.find_or_create("sys", "template", "v2")
|
||||
assert first["id"] != second["id"]
|
||||
58
tests/storage/db/repositories/test_stack_logs.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for StackLogsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> StackLogsRepository:
|
||||
return StackLogsRepository(conn)
|
||||
|
||||
|
||||
class TestInsert:
|
||||
def test_inserts_log(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(
|
||||
activity_id="act-1",
|
||||
endpoint="/api/answer",
|
||||
level="info",
|
||||
user_id="u1",
|
||||
api_key="k1",
|
||||
query="what is python?",
|
||||
stacks=[{"component": "retriever", "data": {"docs": 3}}],
|
||||
)
|
||||
row = pg_conn.execute(
|
||||
text("SELECT * FROM stack_logs WHERE activity_id = 'act-1'")
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
mapping = dict(row._mapping)
|
||||
assert mapping["endpoint"] == "/api/answer"
|
||||
assert mapping["level"] == "info"
|
||||
assert mapping["user_id"] == "u1"
|
||||
assert mapping["stacks"] == [{"component": "retriever", "data": {"docs": 3}}]
|
||||
|
||||
def test_inserts_with_empty_stacks(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(activity_id="act-2", level="error")
|
||||
row = pg_conn.execute(
|
||||
text("SELECT stacks FROM stack_logs WHERE activity_id = 'act-2'")
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
assert dict(row._mapping)["stacks"] == []
|
||||
|
||||
def test_truncated_query_stored(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
long_query = "x" * 20000
|
||||
repo.insert(activity_id="act-3", query=long_query)
|
||||
row = pg_conn.execute(
|
||||
text("SELECT query FROM stack_logs WHERE activity_id = 'act-3'")
|
||||
).fetchone()
|
||||
assert len(dict(row._mapping)["query"]) == 20000
|
||||
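Note: every new repository test module in this changeset guards itself with the same skipif marker, written with an inline __import__ of the settings module. Functionally it is just a settings check; a plain-import spelling of the same guard, shown here only for comparison (the committed files keep the __import__ form), would be:

import pytest

from application.core.settings import settings

# Skip every test in the module when no Postgres instance is configured for the run.
pytestmark = pytest.mark.skipif(
    not settings.POSTGRES_URI, reason="POSTGRES_URI not configured"
)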
90
tests/storage/db/repositories/test_token_usage.py
Normal file
@@ -0,0 +1,90 @@
"""Tests for TokenUsageRepository against a real Postgres instance."""

from __future__ import annotations

from datetime import datetime, timedelta, timezone

import pytest

from application.storage.db.repositories.token_usage import TokenUsageRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> TokenUsageRepository:
    return TokenUsageRepository(conn)


def _now():
    return datetime.now(timezone.utc)


class TestInsert:
    def test_inserts_row(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(user_id="u1", prompt_tokens=10, generated_tokens=5)
        total = repo.sum_tokens_in_range(
            start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="u1"
        )
        assert total == 15

    def test_insert_with_api_key(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(api_key="key-1", prompt_tokens=20, generated_tokens=10)
        total = repo.sum_tokens_in_range(
            start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), api_key="key-1"
        )
        assert total == 30


class TestSumTokensInRange:
    def test_sums_correctly(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(user_id="u1", prompt_tokens=10, generated_tokens=5)
        repo.insert(user_id="u1", prompt_tokens=20, generated_tokens=10)
        repo.insert(user_id="u2", prompt_tokens=100, generated_tokens=50)
        total = repo.sum_tokens_in_range(
            start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="u1"
        )
        assert total == 45

    def test_returns_zero_when_no_rows(self, pg_conn):
        repo = _repo(pg_conn)
        total = repo.sum_tokens_in_range(
            start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="nobody"
        )
        assert total == 0

    def test_respects_time_range(self, pg_conn):
        repo = _repo(pg_conn)
        old = _now() - timedelta(hours=48)
        repo.insert(user_id="u1", prompt_tokens=100, generated_tokens=0, timestamp=old)
        repo.insert(user_id="u1", prompt_tokens=10, generated_tokens=0)
        total = repo.sum_tokens_in_range(
            start=_now() - timedelta(hours=1), end=_now() + timedelta(minutes=1), user_id="u1"
        )
        assert total == 10


class TestCountInRange:
    def test_counts_rows(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(user_id="u1", prompt_tokens=1, generated_tokens=1)
        repo.insert(user_id="u1", prompt_tokens=1, generated_tokens=1)
        repo.insert(user_id="u2", prompt_tokens=1, generated_tokens=1)
        count = repo.count_in_range(
            start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="u1"
        )
        assert count == 2

    def test_filters_by_api_key(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(api_key="k1", prompt_tokens=1, generated_tokens=1)
        repo.insert(api_key="k2", prompt_tokens=1, generated_tokens=1)
        count = repo.count_in_range(
            start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), api_key="k1"
        )
        assert count == 1
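The tests above pin down sum_tokens_in_range and count_in_range only through their observable behaviour; the SQL the repository actually issues is not part of this diff. Purely as a sketch of the kind of aggregate the tests imply — assuming a token_usage table whose columns match the keyword arguments passed to insert(), which is an inference rather than a confirmed schema — it could look like:

# Hypothetical sketch; table and column names are inferred from the tests, not confirmed.
from sqlalchemy import text


def sum_tokens_in_range_sketch(conn, start, end, user_id=None, api_key=None):
    clauses = ["timestamp >= :start", "timestamp <= :end"]
    params = {"start": start, "end": end}
    if user_id is not None:
        clauses.append("user_id = :user_id")
        params["user_id"] = user_id
    if api_key is not None:
        clauses.append("api_key = :api_key")
        params["api_key"] = api_key
    query = text(
        "SELECT COALESCE(SUM(prompt_tokens + generated_tokens), 0) FROM token_usage WHERE "
        + " AND ".join(clauses)
    )
    return conn.execute(query, params).scalar_one()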
64
tests/storage/db/repositories/test_user_logs.py
Normal file
@@ -0,0 +1,64 @@
"""Tests for UserLogsRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.user_logs import UserLogsRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> UserLogsRepository:
    return UserLogsRepository(conn)


class TestInsert:
    def test_inserts_log(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(user_id="u1", endpoint="/api/answer", data={"question": "hi"})
        rows, _ = repo.list_paginated(user_id="u1")
        assert len(rows) == 1
        assert rows[0]["data"]["question"] == "hi"

    def test_allows_null_data(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(user_id="u1")
        rows, _ = repo.list_paginated(user_id="u1")
        assert len(rows) == 1
        assert rows[0]["data"] is None


class TestListPaginated:
    def test_paginates(self, pg_conn):
        repo = _repo(pg_conn)
        for i in range(5):
            repo.insert(user_id="u1", data={"i": i})
        page1, has_more1 = repo.list_paginated(user_id="u1", page=1, page_size=3)
        assert len(page1) == 3
        assert has_more1 is True
        page2, has_more2 = repo.list_paginated(user_id="u1", page=2, page_size=3)
        assert len(page2) == 2
        assert has_more2 is False

    def test_filters_by_user(self, pg_conn):
        repo = _repo(pg_conn)
        repo.insert(user_id="alice", data={"x": 1})
        repo.insert(user_id="bob", data={"x": 2})
        rows, _ = repo.list_paginated(user_id="alice")
        assert len(rows) == 1
        assert rows[0]["user_id"] == "alice"

    def test_ordered_by_timestamp_desc(self, pg_conn):
        from datetime import datetime, timedelta, timezone

        repo = _repo(pg_conn)
        earlier = datetime.now(timezone.utc) - timedelta(minutes=5)
        later = datetime.now(timezone.utc)
        repo.insert(user_id="u1", data={"order": "first"}, timestamp=earlier)
        repo.insert(user_id="u1", data={"order": "second"}, timestamp=later)
        rows, _ = repo.list_paginated(user_id="u1")
        assert rows[0]["data"]["order"] == "second"
100
tests/storage/db/repositories/test_user_tools.py
Normal file
@@ -0,0 +1,100 @@
"""Tests for UserToolsRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.user_tools import UserToolsRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> UserToolsRepository:
    return UserToolsRepository(conn)


class TestCreate:
    def test_creates_tool(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "my_tool", config={"key": "val"})
        assert doc["user_id"] == "user-1"
        assert doc["name"] == "my_tool"
        assert doc["config"] == {"key": "val"}
        assert doc["id"] is not None

    def test_create_with_display_names(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "t", custom_name="Custom", display_name="Display")
        assert doc["custom_name"] == "Custom"
        assert doc["display_name"] == "Display"


class TestGet:
    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "t")
        fetched = repo.get(created["id"])
        assert fetched["id"] == created["id"]

    def test_get_nonexistent(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000") is None


class TestListForUser:
    def test_lists_only_own_tools(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "t1")
        repo.create("alice", "t2")
        repo.create("bob", "t3")
        results = repo.list_for_user("alice")
        assert len(results) == 2


class TestUpdate:
    def test_updates_name(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old_name")
        repo.update(created["id"], "user-1", {"name": "new_name"})
        fetched = repo.get(created["id"])
        assert fetched["name"] == "new_name"

    def test_updates_config(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "t", config={"a": 1})
        repo.update(created["id"], "user-1", {"config": {"a": 2, "b": 3}})
        fetched = repo.get(created["id"])
        assert fetched["config"] == {"a": 2, "b": 3}

    def test_update_wrong_user_is_noop(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        repo.update(created["id"], "user-other", {"name": "new"})
        fetched = repo.get(created["id"])
        assert fetched["name"] == "old"

    def test_ignores_disallowed_fields(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "t")
        repo.update(created["id"], "user-1", {"id": "00000000-0000-0000-0000-000000000000"})
        fetched = repo.get(created["id"])
        assert fetched["id"] == created["id"]


class TestDelete:
    def test_deletes_tool(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "t")
        deleted = repo.delete(created["id"], "user-1")
        assert deleted is True
        assert repo.get(created["id"]) is None

    def test_delete_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "t")
        deleted = repo.delete(created["id"], "user-other")
        assert deleted is False
        assert repo.get(created["id"]) is not None
222
tests/storage/db/repositories/test_users.py
Normal file
@@ -0,0 +1,222 @@
"""Tests for UsersRepository against a real Postgres instance.

Every test runs inside a rolled-back transaction (see ``pg_conn`` fixture
in the parent conftest) so no data leaks between tests.
"""

from __future__ import annotations

import pytest

from application.storage.db.repositories.users import UsersRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------

def _repo(conn) -> UsersRepository:
    return UsersRepository(conn)


# ------------------------------------------------------------------
# upsert / get
# ------------------------------------------------------------------

class TestUpsert:
    def test_creates_new_user_with_defaults(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.upsert("user-new")
        assert doc["user_id"] == "user-new"
        assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []}
        assert "id" in doc
        assert doc["_id"] == doc["id"]

    def test_upsert_is_idempotent(self, pg_conn):
        repo = _repo(pg_conn)
        first = repo.upsert("user-idem")
        second = repo.upsert("user-idem")
        assert first["id"] == second["id"]
        assert first["agent_preferences"] == second["agent_preferences"]

    def test_upsert_preserves_existing_preferences(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-prefs")
        repo.add_pinned("user-prefs", "agent-1")
        doc = repo.upsert("user-prefs")
        assert "agent-1" in doc["agent_preferences"]["pinned"]


class TestGet:
    def test_returns_none_for_missing_user(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("nonexistent") is None

    def test_returns_user_after_upsert(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-get")
        doc = repo.get("user-get")
        assert doc is not None
        assert doc["user_id"] == "user-get"


# ------------------------------------------------------------------
# pinned agents
# ------------------------------------------------------------------

class TestPinned:
    def test_add_pinned(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-pin")
        repo.add_pinned("user-pin", "a1")
        doc = repo.get("user-pin")
        assert doc["agent_preferences"]["pinned"] == ["a1"]

    def test_add_pinned_is_idempotent(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-pin2")
        repo.add_pinned("user-pin2", "a1")
        repo.add_pinned("user-pin2", "a1")
        doc = repo.get("user-pin2")
        assert doc["agent_preferences"]["pinned"] == ["a1"]

    def test_add_multiple_pinned(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-pin3")
        repo.add_pinned("user-pin3", "a1")
        repo.add_pinned("user-pin3", "a2")
        doc = repo.get("user-pin3")
        assert set(doc["agent_preferences"]["pinned"]) == {"a1", "a2"}

    def test_remove_pinned(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-unpin")
        repo.add_pinned("user-unpin", "a1")
        repo.add_pinned("user-unpin", "a2")
        repo.remove_pinned("user-unpin", "a1")
        doc = repo.get("user-unpin")
        assert doc["agent_preferences"]["pinned"] == ["a2"]

    def test_remove_pinned_nonexistent_is_noop(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-unpin2")
        repo.add_pinned("user-unpin2", "a1")
        repo.remove_pinned("user-unpin2", "zzz")
        doc = repo.get("user-unpin2")
        assert doc["agent_preferences"]["pinned"] == ["a1"]

    def test_remove_pinned_bulk(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-bulk")
        repo.add_pinned("user-bulk", "a1")
        repo.add_pinned("user-bulk", "a2")
        repo.add_pinned("user-bulk", "a3")
        repo.remove_pinned_bulk("user-bulk", ["a1", "a3"])
        doc = repo.get("user-bulk")
        assert doc["agent_preferences"]["pinned"] == ["a2"]

    def test_remove_pinned_bulk_empty_list_is_noop(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-bulk2")
        repo.add_pinned("user-bulk2", "a1")
        repo.remove_pinned_bulk("user-bulk2", [])
        doc = repo.get("user-bulk2")
        assert doc["agent_preferences"]["pinned"] == ["a1"]


# ------------------------------------------------------------------
# shared_with_me
# ------------------------------------------------------------------

class TestShared:
    def test_add_shared(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-share")
        repo.add_shared("user-share", "s1")
        doc = repo.get("user-share")
        assert doc["agent_preferences"]["shared_with_me"] == ["s1"]

    def test_add_shared_is_idempotent(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-share2")
        repo.add_shared("user-share2", "s1")
        repo.add_shared("user-share2", "s1")
        doc = repo.get("user-share2")
        assert doc["agent_preferences"]["shared_with_me"] == ["s1"]

    def test_remove_shared_bulk(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-rshare")
        repo.add_shared("user-rshare", "s1")
        repo.add_shared("user-rshare", "s2")
        repo.remove_shared_bulk("user-rshare", ["s1"])
        doc = repo.get("user-rshare")
        assert doc["agent_preferences"]["shared_with_me"] == ["s2"]


# ------------------------------------------------------------------
# remove_agent_from_all (cascade on agent delete)
# ------------------------------------------------------------------

class TestRemoveAgentFromAll:
    def test_removes_from_both_arrays(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-cascade")
        repo.add_pinned("user-cascade", "agent-x")
        repo.add_shared("user-cascade", "agent-x")
        repo.remove_agent_from_all("user-cascade", "agent-x")
        doc = repo.get("user-cascade")
        assert "agent-x" not in doc["agent_preferences"]["pinned"]
        assert "agent-x" not in doc["agent_preferences"]["shared_with_me"]

    def test_leaves_other_agents_untouched(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-cascade2")
        repo.add_pinned("user-cascade2", "keep")
        repo.add_pinned("user-cascade2", "remove")
        repo.add_shared("user-cascade2", "keep")
        repo.add_shared("user-cascade2", "remove")
        repo.remove_agent_from_all("user-cascade2", "remove")
        doc = repo.get("user-cascade2")
        assert doc["agent_preferences"]["pinned"] == ["keep"]
        assert doc["agent_preferences"]["shared_with_me"] == ["keep"]

    def test_noop_when_agent_not_present(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("user-cascade3")
        repo.add_pinned("user-cascade3", "a1")
        repo.remove_agent_from_all("user-cascade3", "nonexistent")
        doc = repo.get("user-cascade3")
        assert doc["agent_preferences"]["pinned"] == ["a1"]


# ------------------------------------------------------------------
# tenant isolation
# ------------------------------------------------------------------

class TestTenantIsolation:
    def test_get_cannot_see_other_users(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("alice")
        repo.upsert("bob")
        repo.add_pinned("alice", "a-private")
        bob_doc = repo.get("bob")
        assert "a-private" not in bob_doc["agent_preferences"]["pinned"]

    def test_mutations_on_one_user_dont_affect_another(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("alice")
        repo.upsert("bob")
        repo.add_pinned("alice", "a1")
        repo.add_shared("bob", "s1")
        alice = repo.get("alice")
        bob = repo.get("bob")
        assert alice["agent_preferences"]["pinned"] == ["a1"]
        assert alice["agent_preferences"]["shared_with_me"] == []
        assert bob["agent_preferences"]["pinned"] == []
        assert bob["agent_preferences"]["shared_with_me"] == ["s1"]
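The docstring of test_users.py above points at a pg_conn fixture in the parent conftest that wraps every test in a transaction and rolls it back afterwards. That conftest is not included in this compare view; as a rough sketch of the pattern being described (fixture scope, engine handling and naming here are assumptions, not the project's actual code), it could be implemented along these lines:

# Illustrative sketch of a transaction-per-test fixture; the real conftest may differ.
import pytest
from sqlalchemy import create_engine

from application.core.settings import settings


@pytest.fixture()
def pg_conn():
    # A per-test engine keeps the sketch simple; a driver-specific URI prefix such as
    # postgresql+psycopg:// may be needed depending on the installed driver.
    engine = create_engine(settings.POSTGRES_URI)
    conn = engine.connect()
    txn = conn.begin()  # open an outer transaction for the whole test
    try:
        yield conn  # repositories under test run against this connection
    finally:
        txn.rollback()  # discard whatever the test wrote
        conn.close()
        engine.dispose()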
@@ -342,7 +342,7 @@ class TestDownloadFile:
        dest = str(tmp_path / "downloaded.txt")
        download_file("http://example.com/file", {"key": "val"}, dest)

-        mock_get.assert_called_once_with("http://example.com/file", params={"key": "val"})
+        mock_get.assert_called_once_with("http://example.com/file", params={"key": "val"}, timeout=100)
        with open(dest, "rb") as f:
            assert f.read() == b"file content"
@@ -16,10 +16,9 @@ def _make_store(
    ) as mock_settings, patch.dict(
        "sys.modules",
        {
-            "psycopg2": MagicMock(),
-            "psycopg2.extras": MagicMock(),
+            "psycopg": MagicMock(),
            "pgvector": MagicMock(),
-            "pgvector.psycopg2": MagicMock(),
+            "pgvector.psycopg": MagicMock(),
        },
    ):
        mock_emb = Mock()
@@ -63,15 +62,15 @@ class TestPGVectorStoreInit:
        ) as mock_settings, patch.dict(
            "sys.modules",
            {
-                "psycopg2": MagicMock(),
-                "psycopg2.extras": MagicMock(),
+                "psycopg": MagicMock(),
                "pgvector": MagicMock(),
-                "pgvector.psycopg2": MagicMock(),
+                "pgvector.psycopg": MagicMock(),
            },
        ):
            mock_get_emb.return_value = Mock(dimension=768)
            mock_settings.EMBEDDINGS_NAME = "test_model"
-            mock_settings.PGVECTOR_CONNECTION_STRING = None
+            mock_settings.POSTGRES_URI = None

            from application.vectorstore.pgvector import PGVectorStore

@@ -264,13 +263,13 @@ class TestPGVectorStoreConnection:
        store, mock_conn, _, _ = _make_store()
        mock_conn.closed = True

-        mock_psycopg2 = MagicMock()
+        mock_psycopg = MagicMock()
        new_conn = MagicMock()
-        mock_psycopg2.connect.return_value = new_conn
-        store._psycopg2 = mock_psycopg2
+        mock_psycopg.connect.return_value = new_conn
+        store._psycopg = mock_psycopg

        conn = store._get_connection()
-        mock_psycopg2.connect.assert_called_once()
+        mock_psycopg.connect.assert_called_once()
        assert conn is new_conn

    def test_get_connection_reuses_open(self):
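The pgvector test hunks above track two changes in the store itself: the mocked psycopg2 modules become psycopg (psycopg 3, plus pgvector.psycopg), and the connection setting moves from PGVECTOR_CONNECTION_STRING to the shared POSTGRES_URI. As a minimal illustration of what that dependency swap amounts to (PGVectorStore's real connection code is not shown in this diff, so treat this as an assumption-laden sketch):

# Sketch only; PGVectorStore's actual connection handling is not part of this changeset.
import psycopg
from pgvector.psycopg import register_vector

from application.core.settings import settings

# psycopg 3 accepts a standard postgresql:// URI directly.
conn = psycopg.connect(settings.POSTGRES_URI)
# Register the pgvector adapter so vector columns round-trip as Python values
# (assumes the "vector" extension is installed in the target database).
register_vector(conn)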