From ececcb8b17e3dbce61ee8a3806e1cbaf640cd331 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 12 Apr 2026 00:07:24 +0100 Subject: [PATCH] feat: init pg migration --- .env-template | 6 + .gitignore | 2 + application/agents/tools/postgres.py | 10 +- application/alembic.ini | 52 ++ application/alembic/env.py | 82 ++++ application/alembic/script.py.mako | 26 + application/alembic/versions/0001_initial.py | 462 ++++++++++++++++++ application/api/user/agents/routes.py | 14 + application/api/user/agents/sharing.py | 8 + application/api/user/base.py | 5 + application/celery_init.py | 21 +- application/core/db_uri.py | 89 ++++ application/core/settings.py | 41 +- application/requirements.txt | 4 +- application/storage/db/__init__.py | 10 + application/storage/db/base_repository.py | 39 ++ application/storage/db/dual_write.py | 67 +++ application/storage/db/engine.py | 67 +++ application/storage/db/models.py | 38 ++ .../storage/db/repositories/__init__.py | 11 + application/storage/db/repositories/users.py | 245 ++++++++++ application/vectorstore/pgvector.py | 20 +- docs/content/Deploying/Postgres-Migration.mdx | 114 +++++ docs/content/Deploying/_meta.js | 4 + scripts/db/backfill.py | 218 +++++++++ scripts/db/init_postgres.py | 55 +++ tests/agents/test_postgres_tool.py | 26 +- tests/core/test_db_uri.py | 144 ++++++ tests/integration/conftest.py | 84 ++++ tests/integration/test_users_repository.py | 179 +++++++ tests/vectorstore/test_pgvector.py | 18 +- 31 files changed, 2119 insertions(+), 42 deletions(-) create mode 100644 application/alembic.ini create mode 100644 application/alembic/env.py create mode 100644 application/alembic/script.py.mako create mode 100644 application/alembic/versions/0001_initial.py create mode 100644 application/core/db_uri.py create mode 100644 application/storage/db/__init__.py create mode 100644 application/storage/db/base_repository.py create mode 100644 application/storage/db/dual_write.py create mode 100644 application/storage/db/engine.py create mode 100644 application/storage/db/models.py create mode 100644 application/storage/db/repositories/__init__.py create mode 100644 application/storage/db/repositories/users.py create mode 100644 docs/content/Deploying/Postgres-Migration.mdx create mode 100644 scripts/db/backfill.py create mode 100644 scripts/db/init_postgres.py create mode 100644 tests/core/test_db_uri.py create mode 100644 tests/integration/conftest.py create mode 100644 tests/integration/test_users_repository.py diff --git a/.env-template b/.env-template index 4e00b75b..9ac711e0 100644 --- a/.env-template +++ b/.env-template @@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id #or "https://login.microsoftonline.com/contoso.onmicrosoft.com". #Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app. MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId} + +# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration). +# Standard Postgres URI — `postgres://` and `postgresql://` both work. +# Leave unset while the migration is still being rolled out; the app will +# fall back to MongoDB for user data until POSTGRES_URI is configured. 
+# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt diff --git a/.gitignore b/.gitignore index 4d8f8c84..1b36e9fc 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,8 @@ celerybeat.pid # Environments .env .venv +# Machine-specific Claude Code guidance (see CLAUDE.md preamble) +CLAUDE.md env/ venv/ ENV/ diff --git a/application/agents/tools/postgres.py b/application/agents/tools/postgres.py index d9d5a2b4..fe7fe9f6 100644 --- a/application/agents/tools/postgres.py +++ b/application/agents/tools/postgres.py @@ -1,6 +1,6 @@ import logging -import psycopg2 +import psycopg from application.agents.tools.base import Tool @@ -33,7 +33,7 @@ class PostgresTool(Tool): """ conn = None try: - conn = psycopg2.connect(self.connection_string) + conn = psycopg.connect(self.connection_string) cur = conn.cursor() cur.execute(sql_query) conn.commit() @@ -60,7 +60,7 @@ class PostgresTool(Tool): "response_data": response_data, } - except psycopg2.Error as e: + except psycopg.Error as e: error_message = f"Database error: {e}" logger.error("PostgreSQL execute_sql error: %s", e) return { @@ -78,7 +78,7 @@ class PostgresTool(Tool): """ conn = None try: - conn = psycopg2.connect(self.connection_string) + conn = psycopg.connect(self.connection_string) cur = conn.cursor() cur.execute( @@ -120,7 +120,7 @@ class PostgresTool(Tool): "schema": schema_data, } - except psycopg2.Error as e: + except psycopg.Error as e: error_message = f"Database error: {e}" logger.error("PostgreSQL get_schema error: %s", e) return { diff --git a/application/alembic.ini b/application/alembic.ini new file mode 100644 index 00000000..9e996a47 --- /dev/null +++ b/application/alembic.ini @@ -0,0 +1,52 @@ +# Alembic configuration for the DocsGPT user-data Postgres database. +# +# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from +# ``application.core.settings.settings.POSTGRES_URI`` so the same config +# source serves the running app and migrations. To run from the project +# root:: +# +# alembic -c application/alembic.ini upgrade head + +[alembic] +script_location = %(here)s/alembic +prepend_sys_path = .. +version_path_separator = os + +# sqlalchemy.url is intentionally left blank — env.py supplies it. +sqlalchemy.url = + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/application/alembic/env.py b/application/alembic/env.py new file mode 100644 index 00000000..68eadc01 --- /dev/null +++ b/application/alembic/env.py @@ -0,0 +1,82 @@ +"""Alembic environment for the DocsGPT user-data Postgres database. + +The URL is pulled from ``application.core.settings`` rather than +``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the +running app and ``alembic`` CLI invocations. +""" + +import sys +from logging.config import fileConfig +from pathlib import Path + +# Make the project root importable regardless of cwd. env.py lives at +# /application/alembic/env.py, so parents[2] is the repo root. 
+_PROJECT_ROOT = Path(__file__).resolve().parents[2] +if str(_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT)) + +from alembic import context # noqa: E402 +from sqlalchemy import engine_from_config, pool # noqa: E402 + +from application.core.settings import settings # noqa: E402 +from application.storage.db.models import metadata as target_metadata # noqa: E402 + +config = context.config + +# Populate the runtime URL from settings. +if settings.POSTGRES_URI: + config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI) + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode (emits SQL without a live DB).""" + url = config.get_main_option("sqlalchemy.url") + if not url: + raise RuntimeError( + "POSTGRES_URI is not configured. Set it in your .env to a " + "psycopg3 URI such as " + "'postgresql+psycopg://user:pass@host:5432/docsgpt'." + ) + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode against a live connection.""" + if not config.get_main_option("sqlalchemy.url"): + raise RuntimeError( + "POSTGRES_URI is not configured. Set it in your .env to a " + "psycopg3 URI such as " + "'postgresql+psycopg://user:pass@host:5432/docsgpt'." + ) + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + future=True, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/application/alembic/script.py.mako b/application/alembic/script.py.mako new file mode 100644 index 00000000..fbc4b07d --- /dev/null +++ b/application/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/application/alembic/versions/0001_initial.py b/application/alembic/versions/0001_initial.py new file mode 100644 index 00000000..84565451 --- /dev/null +++ b/application/alembic/versions/0001_initial.py @@ -0,0 +1,462 @@ +"""0001 initial schema — user-level tables migrated from MongoDB. + +Creates every table described in §2.2 of ``migration-postgres.md``: tiers 1, +2, and 3 in one shot. The schema is small enough that splitting the baseline +across multiple revisions would only cost clarity. + +Subsequent migrations will add columns / tables incrementally. 
This file is +hand-written raw DDL rather than Core ``op.create_table`` calls because the +DDL uses several Postgres-specific features (``CITEXT``, partial indexes, +``text_pattern_ops``, JSONB defaults) that are clearer in SQL than in +Alembic's Python API. + +Revision ID: 0001_initial +Revises: +Create Date: 2026-04-10 +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0001_initial" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ------------------------------------------------------------------ + # Extensions + # ------------------------------------------------------------------ + op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";') + op.execute('CREATE EXTENSION IF NOT EXISTS "citext";') + + # ------------------------------------------------------------------ + # Tier 1: leaf tables, no FKs into other migrated tables + # ------------------------------------------------------------------ + op.execute(""" + CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL UNIQUE, + agent_preferences JSONB NOT NULL + DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX users_user_id_idx ON users (user_id);") + + op.execute(""" + CREATE TABLE prompts ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + name TEXT NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);") + + op.execute(""" + CREATE TABLE user_tools ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + name TEXT NOT NULL, + custom_name TEXT, + display_name TEXT, + config JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);") + + op.execute(""" + CREATE TABLE token_usage ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT, + api_key TEXT, + agent_id UUID, -- FK added later in this migration + prompt_tokens INTEGER NOT NULL DEFAULT 0, + generated_tokens INTEGER NOT NULL DEFAULT 0, + timestamp TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, timestamp DESC);") + op.execute("CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, timestamp DESC);") + op.execute("CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, timestamp DESC);") + + op.execute(""" + CREATE TABLE user_logs ( + id BIGSERIAL PRIMARY KEY, + user_id TEXT, + endpoint TEXT, + timestamp TIMESTAMPTZ NOT NULL DEFAULT now(), + data JSONB + ); + """) + op.execute("CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, timestamp DESC);") + + op.execute(""" + CREATE TABLE feedback ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + conversation_id UUID NOT NULL, -- FK added later in this migration + user_id TEXT NOT NULL, + question_index INTEGER NOT NULL, + feedback_text TEXT, + timestamp TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX feedback_conv_idx ON feedback (conversation_id);") + + # Append-only debug/error log. 
The Mongo doc has both `_id` (auto) and an + # `id` field (the activity id). Here the serial PK owns `id`; the + # application-level identifier is renamed to `activity_id`. + op.execute(""" + CREATE TABLE stack_logs ( + id BIGSERIAL PRIMARY KEY, + activity_id TEXT NOT NULL, + endpoint TEXT, + level TEXT, + user_id TEXT, + api_key TEXT, + query TEXT, + stacks JSONB NOT NULL DEFAULT '[]'::jsonb, + timestamp TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX stack_logs_timestamp_idx ON stack_logs (timestamp DESC);") + op.execute("CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, timestamp DESC);") + op.execute("CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, timestamp DESC);") + op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);") + + # ------------------------------------------------------------------ + # Tier 2: FK-bearing tables + # ------------------------------------------------------------------ + op.execute(""" + CREATE TABLE agent_folders ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);") + + op.execute(""" + CREATE TABLE sources ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT, -- NULL for system/template sources + name TEXT NOT NULL, + type TEXT, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX sources_user_idx ON sources (user_id);") + + op.execute(""" + CREATE TABLE agents ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT, + agent_type TEXT, + status TEXT NOT NULL, + key CITEXT UNIQUE, + source_id UUID REFERENCES sources(id) ON DELETE SET NULL, + extra_source_ids UUID[] NOT NULL DEFAULT '{}', + chunks INTEGER, + retriever TEXT, + prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL, + tools JSONB NOT NULL DEFAULT '[]'::jsonb, + json_schema JSONB, + models JSONB, + default_model_id TEXT, + folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL, + limited_token_mode BOOLEAN NOT NULL DEFAULT false, + token_limit INTEGER, + limited_request_mode BOOLEAN NOT NULL DEFAULT false, + request_limit INTEGER, + shared BOOLEAN NOT NULL DEFAULT false, + incoming_webhook_token CITEXT UNIQUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + last_used_at TIMESTAMPTZ + ); + """) + op.execute("CREATE INDEX agents_user_idx ON agents (user_id);") + op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;") + op.execute("CREATE INDEX agents_status_idx ON agents (status);") + + # Backfill the token_usage.agent_id FK now that agents exists. 
+ op.execute(""" + ALTER TABLE token_usage + ADD CONSTRAINT token_usage_agent_fk + FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL; + """) + + op.execute(""" + CREATE TABLE attachments ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + filename TEXT NOT NULL, + upload_path TEXT NOT NULL, + mime_type TEXT, + size BIGINT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);") + + op.execute(""" + CREATE TABLE memories ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE, + path TEXT NOT NULL, + content TEXT NOT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute(""" + CREATE UNIQUE INDEX memories_user_tool_path_uidx + ON memories (user_id, tool_id, path); + """) + op.execute(""" + CREATE INDEX memories_path_prefix_idx + ON memories (user_id, tool_id, path text_pattern_ops); + """) + + op.execute(""" + CREATE TABLE todos ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE, + title TEXT NOT NULL, + completed BOOLEAN NOT NULL DEFAULT false, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);") + + op.execute(""" + CREATE TABLE notes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE, + title TEXT NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX notes_user_tool_idx ON notes (user_id, tool_id);") + + op.execute(""" + CREATE TABLE connector_sessions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + provider TEXT NOT NULL, + session_data JSONB NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute(""" + CREATE INDEX connector_sessions_user_provider_idx + ON connector_sessions (user_id, provider); + """) + op.execute(""" + CREATE INDEX connector_sessions_expiry_idx + ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL; + """) + + # ------------------------------------------------------------------ + # Tier 3: conversations, pending_tool_state, workflows + # ------------------------------------------------------------------ + op.execute(""" + CREATE TABLE conversations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + agent_id UUID REFERENCES agents(id) ON DELETE SET NULL, + name TEXT, + api_key TEXT, + is_shared_usage BOOLEAN NOT NULL DEFAULT false, + shared_token TEXT, + date TIMESTAMPTZ NOT NULL DEFAULT now(), + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);") + op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);") + op.execute(""" + CREATE INDEX conversations_shared_token_idx + ON conversations (shared_token) WHERE shared_token IS NOT NULL; + """) + + op.execute(""" + CREATE TABLE conversation_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + position INTEGER NOT 
NULL, + prompt TEXT, + response TEXT, + thought TEXT, + sources JSONB NOT NULL DEFAULT '[]'::jsonb, + tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb, + attachments UUID[] NOT NULL DEFAULT '{}', + model_id TEXT, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + feedback JSONB, + timestamp TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute(""" + CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx + ON conversation_messages (conversation_id, position); + """) + + # Backfill the feedback.conversation_id FK now that conversations exists. + op.execute(""" + ALTER TABLE feedback + ADD CONSTRAINT feedback_conv_fk + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE; + """) + + op.execute(""" + CREATE TABLE shared_conversations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + user_id TEXT NOT NULL, + prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL, + chunks INTEGER, + is_promptable BOOLEAN NOT NULL DEFAULT false, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);") + op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);") + + # Paused-tool continuation state. The Mongo version relies on a TTL index; + # Postgres has no native TTL, so a Celery beat task (added in Phase 3) + # deletes rows where expires_at < now() once a minute. The unique + # constraint on (conversation_id, user_id) matches the existing upsert + # semantics. + op.execute(""" + CREATE TABLE pending_tool_state ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + user_id TEXT NOT NULL, + messages JSONB NOT NULL, + pending_tool_calls JSONB NOT NULL, + tools_dict JSONB NOT NULL, + tool_schemas JSONB NOT NULL, + agent_config JSONB NOT NULL, + client_tools JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + expires_at TIMESTAMPTZ NOT NULL + ); + """) + op.execute(""" + CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx + ON pending_tool_state (conversation_id, user_id); + """) + op.execute(""" + CREATE INDEX pending_tool_state_expires_idx + ON pending_tool_state (expires_at); + """) + + # Workflows + op.execute(""" + CREATE TABLE workflows ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + """) + op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);") + + op.execute(""" + CREATE TABLE workflow_nodes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE, + graph_version INTEGER NOT NULL, + node_type TEXT NOT NULL, + config JSONB NOT NULL DEFAULT '{}'::jsonb + ); + """) + op.execute(""" + CREATE INDEX workflow_nodes_workflow_version_idx + ON workflow_nodes (workflow_id, graph_version); + """) + + op.execute(""" + CREATE TABLE workflow_edges ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE, + graph_version INTEGER NOT NULL, + from_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE, + to_node_id UUID NOT NULL REFERENCES workflow_nodes(id) ON DELETE CASCADE, + config JSONB NOT NULL DEFAULT '{}'::jsonb + ); + """) + op.execute(""" + CREATE INDEX 
workflow_edges_workflow_version_idx + ON workflow_edges (workflow_id, graph_version); + """) + + op.execute(""" + CREATE TABLE workflow_runs ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE, + user_id TEXT NOT NULL, + status TEXT NOT NULL, + started_at TIMESTAMPTZ NOT NULL DEFAULT now(), + ended_at TIMESTAMPTZ, + result JSONB + ); + """) + op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);") + op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);") + + +def downgrade() -> None: + # Reverse dependency order. CASCADE would handle FKs anyway, but explicit + # is clearer for anyone reading the migration. + op.execute("DROP TABLE IF EXISTS workflow_runs CASCADE;") + op.execute("DROP TABLE IF EXISTS workflow_edges CASCADE;") + op.execute("DROP TABLE IF EXISTS workflow_nodes CASCADE;") + op.execute("DROP TABLE IF EXISTS workflows CASCADE;") + op.execute("DROP TABLE IF EXISTS pending_tool_state CASCADE;") + op.execute("DROP TABLE IF EXISTS shared_conversations CASCADE;") + op.execute("DROP TABLE IF EXISTS conversation_messages CASCADE;") + op.execute("DROP TABLE IF EXISTS conversations CASCADE;") + op.execute("DROP TABLE IF EXISTS connector_sessions CASCADE;") + op.execute("DROP TABLE IF EXISTS notes CASCADE;") + op.execute("DROP TABLE IF EXISTS todos CASCADE;") + op.execute("DROP TABLE IF EXISTS memories CASCADE;") + op.execute("DROP TABLE IF EXISTS attachments CASCADE;") + op.execute("DROP TABLE IF EXISTS agents CASCADE;") + op.execute("DROP TABLE IF EXISTS sources CASCADE;") + op.execute("DROP TABLE IF EXISTS agent_folders CASCADE;") + op.execute("DROP TABLE IF EXISTS stack_logs CASCADE;") + op.execute("DROP TABLE IF EXISTS feedback CASCADE;") + op.execute("DROP TABLE IF EXISTS user_logs CASCADE;") + op.execute("DROP TABLE IF EXISTS token_usage CASCADE;") + op.execute("DROP TABLE IF EXISTS user_tools CASCADE;") + op.execute("DROP TABLE IF EXISTS prompts CASCADE;") + op.execute("DROP TABLE IF EXISTS users CASCADE;") + # Extensions are intentionally left in place — they may be shared with + # pgvector or other extensions already enabled on the cluster. 
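Once this baseline revision lands, a quick way to confirm it applied cleanly is to run the migration and inspect the result. This is a verification sketch rather than part of the patch: it assumes `POSTGRES_URI` is exported in the shell in its plain `postgresql://` form (which `psql` accepts), and it uses the `alembic` invocation documented in `application/alembic.ini`.

```bash
# Apply the baseline, then confirm the stamped revision and a sample table.
alembic -c application/alembic.ini upgrade head
psql "$POSTGRES_URI" -c "SELECT version_num FROM alembic_version;"
psql "$POSTGRES_URI" -c "\d users"
```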
diff --git a/application/api/user/agents/routes.py b/application/api/user/agents/routes.py index 7d3dc93a..b5d6fa6c 100644 --- a/application/api/user/agents/routes.py +++ b/application/api/user/agents/routes.py @@ -23,6 +23,8 @@ from application.api.user.base import ( workflow_nodes_collection, workflows_collection, ) +from application.storage.db.dual_write import dual_write +from application.storage.db.repositories.users import UsersRepository from application.core.json_schema_utils import ( JsonSchemaValidationError, normalize_json_schema_payload, @@ -1250,6 +1252,9 @@ class PinnedAgents(Resource): {"user_id": user_id}, {"$pullAll": {"agent_preferences.pinned": stale_ids}}, ) + dual_write(UsersRepository, + lambda repo, uid=user_id, ids=stale_ids: repo.remove_pinned_bulk(uid, ids) + ) list_pinned_agents = [ { "id": str(agent["_id"]), @@ -1381,12 +1386,18 @@ class PinAgent(Resource): {"user_id": user_id}, {"$pull": {"agent_preferences.pinned": agent_id}}, ) + dual_write(UsersRepository, + lambda repo, uid=user_id, aid=agent_id: repo.remove_pinned(uid, aid) + ) action = "unpinned" else: users_collection.update_one( {"user_id": user_id}, {"$addToSet": {"agent_preferences.pinned": agent_id}}, ) + dual_write(UsersRepository, + lambda repo, uid=user_id, aid=agent_id: repo.add_pinned(uid, aid) + ) action = "pinned" except Exception as err: current_app.logger.error(f"Error pinning/unpinning agent: {err}") @@ -1432,6 +1443,9 @@ class RemoveSharedAgent(Resource): } }, ) + dual_write(UsersRepository, + lambda repo, uid=user_id, aid=agent_id: repo.remove_agent_from_all(uid, aid) + ) return make_response(jsonify({"success": True, "action": "removed"}), 200) except Exception as err: diff --git a/application/api/user/agents/sharing.py b/application/api/user/agents/sharing.py index 034fc75d..8077773d 100644 --- a/application/api/user/agents/sharing.py +++ b/application/api/user/agents/sharing.py @@ -18,6 +18,8 @@ from application.api.user.base import ( user_tools_collection, users_collection, ) +from application.storage.db.dual_write import dual_write +from application.storage.db.repositories.users import UsersRepository from application.utils import generate_image_url agents_sharing_ns = Namespace( @@ -105,6 +107,9 @@ class SharedAgent(Resource): {"user_id": user_id}, {"$addToSet": {"agent_preferences.shared_with_me": agent_id}}, ) + dual_write(UsersRepository, + lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid) + ) return make_response(jsonify(data), 200) except Exception as err: current_app.logger.error(f"Error retrieving shared agent: {err}") @@ -139,6 +144,9 @@ class SharedAgents(Resource): {"user_id": user_id}, {"$pullAll": {"agent_preferences.shared_with_me": stale_ids}}, ) + dual_write(UsersRepository, + lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids) + ) pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", [])) list_shared_agents = [ diff --git a/application/api/user/base.py b/application/api/user/base.py index afc22268..d5df9ff8 100644 --- a/application/api/user/base.py +++ b/application/api/user/base.py @@ -15,6 +15,8 @@ from werkzeug.utils import secure_filename from application.core.mongo_db import MongoDB from application.core.settings import settings +from application.storage.db.dual_write import dual_write +from application.storage.db.repositories.users import UsersRepository from application.storage.storage_creator import StorageCreator from application.vectorstore.vector_creator import VectorCreator @@ -132,6 +134,9 @@ def 
ensure_user_doc(user_id): if updates: users_collection.update_one({"user_id": user_id}, {"$set": updates}) user_doc = users_collection.find_one({"user_id": user_id}) + + dual_write(UsersRepository, lambda repo: repo.upsert(user_id)) + return user_doc diff --git a/application/celery_init.py b/application/celery_init.py index 3e9c3c57..555182e0 100644 --- a/application/celery_init.py +++ b/application/celery_init.py @@ -1,6 +1,6 @@ from celery import Celery from application.core.settings import settings -from celery.signals import setup_logging +from celery.signals import setup_logging, worker_process_init def make_celery(app_name=__name__): @@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs): setup_logging() +@worker_process_init.connect +def _dispose_db_engine_on_fork(*args, **kwargs): + """Dispose the SQLAlchemy engine pool in each forked Celery worker. + + SQLAlchemy connection pools are not fork-safe: file descriptors shared + between the parent and a forked worker will corrupt the pool. Disposing + on ``worker_process_init`` gives every worker its own fresh pool on + first use. + + Imported lazily so Celery workers that don't touch Postgres (or where + ``POSTGRES_URI`` is unset) don't fail at startup. + """ + try: + from application.storage.db.engine import dispose_engine + except Exception: + return + dispose_engine() + + celery = make_celery() celery.config_from_object("application.celeryconfig") diff --git a/application/core/db_uri.py b/application/core/db_uri.py new file mode 100644 index 00000000..f620da41 --- /dev/null +++ b/application/core/db_uri.py @@ -0,0 +1,89 @@ +"""Normalize user-supplied Postgres URIs for different drivers. + +DocsGPT has two Postgres connection strings pointing at potentially +different databases: + +* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the + ``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver. +* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly + (via libpq) in ``application/vectorstore/pgvector.py``. libpq only + understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy + dialect prefix is an invalid URI from its point of view. + +The two fields therefore need opposite normalization so operators don't +have to know which driver a given field feeds. Each normalizer also +silently upgrades the legacy ``postgresql+psycopg2://`` prefix since +psycopg2 is no longer in the project. + +This module is deliberately separate from ``application/core/settings.py`` +so the Settings class stays focused on field declarations, and the +URI-rewriting logic can be unit-tested without triggering ``.env`` +file loading from importing Settings. +""" + +from __future__ import annotations + + +def _rewrite_uri_prefixes(v, rewrites): + """Shared URI prefix rewriter used by both normalizers below. + + Strips whitespace, returns ``None`` for empty / ``"none"`` values, + applies the first matching rewrite, and passes unrecognised input + through so downstream consumers (SQLAlchemy, libpq) can produce + their own error messages rather than us silently eating a + misconfiguration. + """ + if v is None: + return None + if not isinstance(v, str): + return v + v = v.strip() + if not v or v.lower() == "none": + return None + for prefix, target in rewrites: + if v.startswith(prefix): + return target + v[len(prefix):] + return v + + +# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://`` +# dialect prefix to select the psycopg v3 driver. Normalize the +# operator-friendly forms TOWARD that dialect. 
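+# For example:
+#   postgres://user:pass@host:5432/db   -> postgresql+psycopg://user:pass@host:5432/db
+#   postgresql://user:pass@host:5432/db -> postgresql+psycopg://user:pass@host:5432/db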
+_POSTGRES_URI_REWRITES = ( + ("postgresql+psycopg2://", "postgresql+psycopg://"), + ("postgresql://", "postgresql+psycopg://"), + ("postgres://", "postgresql+psycopg://"), +) + + +# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in +# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only +# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy +# dialect prefix is an invalid URI from libpq's point of view. Strip it +# if the operator accidentally copied their POSTGRES_URI value here. +_PGVECTOR_CONNECTION_STRING_REWRITES = ( + ("postgresql+psycopg2://", "postgresql://"), + ("postgresql+psycopg://", "postgresql://"), +) + + +def normalize_postgres_uri(v): + """Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form. + + Accepts the forms operators naturally write (``postgres://``, + ``postgresql://``) and rewrites them to ``postgresql+psycopg://``. + Unknown schemes pass through unchanged so SQLAlchemy can produce its + own dialect-not-found error. + """ + return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES) + + +def normalize_pgvector_connection_string(v): + """Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq. + + Strips the SQLAlchemy dialect prefix if the operator accidentally + copied their POSTGRES_URI value here — libpq can't parse it. + User-friendly forms (``postgres://``, ``postgresql://``) pass + through unchanged since libpq accepts them natively. + """ + return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES) diff --git a/application/core/settings.py b/application/core/settings.py index 9dbc1584..4e3e64e6 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from application.core.db_uri import ( # noqa: E402 + normalize_pgvector_connection_string, + normalize_postgres_uri, +) + + class Settings(BaseSettings): model_config = SettingsConfigDict(extra="ignore") @@ -22,6 +28,26 @@ class Settings(BaseSettings): CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1" MONGO_URI: str = "mongodb://localhost:27017/docsgpt" MONGO_DB_NAME: str = "docsgpt" + # User-data Postgres DB (see migration-postgres.md). Optional during the + # MongoDB→Postgres migration; becomes required once the migration is + # complete. Write the URI in whichever form you prefer — all of + # postgres://user:pass@host:port/db + # postgresql://user:pass@host:port/db + # postgresql+psycopg://user:pass@host:port/db + # are accepted and normalized internally to the psycopg3 dialect. + POSTGRES_URI: Optional[str] = None + + # MongoDB→Postgres migration — two global switches, no per-collection + # knobs. Everything that has a Postgres repository implementation is + # dual-written when USE_POSTGRES is on; new collections join the set + # automatically as they're implemented. Flip READ_POSTGRES once you + # trust the Postgres state to cut reads over. + # + # Default False everywhere so behaviour is unchanged until an operator + # explicitly opts in. READ_POSTGRES without USE_POSTGRES is nonsensical + # during the migration window; call sites enforce the pairing. 
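+    # Typical rollout order (mirrors docs/content/Deploying/Postgres-Migration.mdx):
+    #   1. USE_POSTGRES=true   (dual-write; Mongo stays the source of truth)
+    #   2. python scripts/db/backfill.py
+    #   3. READ_POSTGRES=true  (cut reads over once the Postgres state is trusted)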
+ USE_POSTGRES: bool = False + READ_POSTGRES: bool = False LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf") DEFAULT_MAX_HISTORY: int = 150 DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry @@ -117,7 +143,10 @@ class Settings(BaseSettings): QDRANT_PATH: Optional[str] = None QDRANT_DISTANCE_FUNC: str = "Cosine" - # PGVector vectorstore config + # PGVector vectorstore config. Write the URI in whichever form you + # prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy + # dialect form (``postgresql+psycopg://``) are all accepted and + # normalized internally for ``psycopg.connect()``. PGVECTOR_CONNECTION_STRING: Optional[str] = None # Milvus vectorstore config MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt" @@ -156,6 +185,16 @@ class Settings(BaseSettings): COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat + @field_validator("POSTGRES_URI", mode="before") + @classmethod + def _normalize_postgres_uri_validator(cls, v): + return normalize_postgres_uri(v) + + @field_validator("PGVECTOR_CONNECTION_STRING", mode="before") + @classmethod + def _normalize_pgvector_connection_string_validator(cls, v): + return normalize_pgvector_connection_string(v) + @field_validator( "API_KEY", "OPENAI_API_KEY", diff --git a/application/requirements.txt b/application/requirements.txt index c7a43582..cd4ff410 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -1,3 +1,4 @@ +alembic>=1.13,<2 anthropic==0.88.0 boto3==1.42.83 beautifulsoup4==4.14.3 @@ -58,7 +59,7 @@ pillow portalocker>=2.7.0,<4.0.0 prompt-toolkit==3.0.52 protobuf==7.34.1 -psycopg2-binary==2.9.11 +psycopg[binary,pool]>=3.1,<4 py==1.11.0 pydantic pydantic-core @@ -75,6 +76,7 @@ regex==2026.4.4 requests==2.33.1 retry==0.9.2 sentence-transformers==5.3.0 +sqlalchemy>=2.0,<3 tiktoken==0.12.0 tokenizers==0.22.2 torch==2.11.0 diff --git a/application/storage/db/__init__.py b/application/storage/db/__init__.py new file mode 100644 index 00000000..e323c352 --- /dev/null +++ b/application/storage/db/__init__.py @@ -0,0 +1,10 @@ +"""PostgreSQL storage layer for user-level data. + +This package holds the SQLAlchemy Core engine, metadata, repositories, and +migration infrastructure for the user-data Postgres database. It is separate +from ``application/vectorstore/pgvector.py`` — the two may point at the same +cluster or at different clusters depending on operator configuration. + +Repository modules are added in later phases +as individual collections are ported. +""" diff --git a/application/storage/db/base_repository.py b/application/storage/db/base_repository.py new file mode 100644 index 00000000..be81b727 --- /dev/null +++ b/application/storage/db/base_repository.py @@ -0,0 +1,39 @@ +"""Common helpers shared by all repositories. + +Repositories are thin wrappers around SQLAlchemy Core query construction. +They take a ``Connection`` on call and return plain ``dict`` rows during the +Mongo→Postgres cutover so that call sites don't have to change shape. Once +cutover is complete, a follow-up phase may migrate repo return types to +Pydantic DTOs (tracked in the migration plan as a post-migration item). +""" + +from typing import Any, Mapping +from uuid import UUID + + +def row_to_dict(row: Any) -> dict: + """Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids. 
+ + During the migration window, API responses and downstream code still + expect a string ``_id`` field (matching the Mongo shape). This helper + normalizes UUID columns to strings and emits both ``id`` and ``_id`` so + existing serializers keep working unchanged. + + Args: + row: A SQLAlchemy ``Row`` object, or ``None``. + + Returns: + A plain dict, or an empty dict if ``row`` is ``None``. + """ + if row is None: + return {} + + # Row has a ``._mapping`` attribute exposing a MappingProxy view. + mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined] + out = dict(mapping) + + if "id" in out and out["id"] is not None: + out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"] + out["_id"] = out["id"] + + return out diff --git a/application/storage/db/dual_write.py b/application/storage/db/dual_write.py new file mode 100644 index 00000000..a13a9b9f --- /dev/null +++ b/application/storage/db/dual_write.py @@ -0,0 +1,67 @@ +"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres +migration. + +The helper: + +* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off + call sites add literally zero work. +* Opens a transactional connection from the user-data SQLAlchemy engine. +* Instantiates the caller's repository class on that connection. +* Runs the caller's operation. +* Swallows and logs any exception. **Mongo remains the source of truth + during the dual-write window** — a Postgres-side failure must never + break a user-facing request. Drift that builds up from swallowed + failures is caught separately by re-running the backfill script. + +Call sites look like:: + + users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged + dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror + +A single parameterised helper rather than one function per collection +means a new collection just needs its repository class — no new helper +function, no new feature flag. The whole helper is deleted at Phase 5 +when the migration is complete. +""" + +from __future__ import annotations + +import logging +from typing import Callable, TypeVar + +from application.core.settings import settings + +logger = logging.getLogger(__name__) + +_Repo = TypeVar("_Repo") + + +def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None: + """Mirror a Mongo write into Postgres via ``repo_cls``, best-effort. + + No-op when ``settings.USE_POSTGRES`` is false. Any exception + (connection pool exhaustion, migration drift, SQL error) is logged + and swallowed so the caller's primary Mongo write remains the source + of truth. + + Args: + repo_cls: The repository class to instantiate (e.g. ``UsersRepository``). + fn: A callable that takes the instantiated repository and performs + the desired write. + """ + if not settings.USE_POSTGRES: + return + + try: + # Lazy import so modules that import dual_write don't pay the + # SQLAlchemy import cost when the flag is off. + from application.storage.db.engine import get_engine + + with get_engine().begin() as conn: + fn(repo_cls(conn)) + except Exception: + logger.warning( + "Postgres dual-write failed for %s — Mongo write already committed", + repo_cls.__name__, + exc_info=True, + ) diff --git a/application/storage/db/engine.py b/application/storage/db/engine.py new file mode 100644 index 00000000..723a1eb7 --- /dev/null +++ b/application/storage/db/engine.py @@ -0,0 +1,67 @@ +"""SQLAlchemy Core engine factory for the user-data Postgres database. 
+ +The engine is lazily constructed on first use and cached as a module-level +singleton. Repositories and the Alembic env module both obtain connections +through this factory, so pool tuning lives in one place. + +``POSTGRES_URI`` can be written in any of the common Postgres URI forms:: + + postgres://user:pass@host:5432/docsgpt + postgresql://user:pass@host:5432/docsgpt + +Both are accepted and normalized internally to the psycopg3 dialect +(``postgresql+psycopg://``) by ``application.core.settings``. Operators +don't need to know about SQLAlchemy dialect prefixes. +""" + +from typing import Optional + +from sqlalchemy import Engine, create_engine + +from application.core.settings import settings + +_engine: Optional[Engine] = None + + +def get_engine() -> Engine: + """Return the process-wide SQLAlchemy Engine, creating it if needed. + + Raises: + RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that + reach this path without a configured URI have a setup bug — the + error message points them at the right setting. + + Returns: + A SQLAlchemy ``Engine`` configured with a pooled connection to + Postgres via psycopg3. + """ + global _engine + if _engine is None: + if not settings.POSTGRES_URI: + raise RuntimeError( + "POSTGRES_URI is not configured. Set it in your .env to a " + "psycopg3 URI such as " + "'postgresql+psycopg://user:pass@host:5432/docsgpt'." + ) + _engine = create_engine( + settings.POSTGRES_URI, + pool_size=10, + max_overflow=20, + pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles + pool_recycle=1800, + future=True, + ) + return _engine + + +def dispose_engine() -> None: + """Dispose the pooled connections and reset the singleton. + + Called from the Celery ``worker_process_init`` signal so each forked + worker gets a fresh pool instead of sharing file descriptors with the + parent process (which corrupts the pool on fork). + """ + global _engine + if _engine is not None: + _engine.dispose() + _engine = None diff --git a/application/storage/db/models.py b/application/storage/db/models.py new file mode 100644 index 00000000..8b3d0527 --- /dev/null +++ b/application/storage/db/models.py @@ -0,0 +1,38 @@ +"""SQLAlchemy Core metadata for the user-data Postgres database. + +Tables are added here one at a time as repositories are built during the +MongoDB→Postgres migration. The baseline schema in the Alembic migration +(``application/alembic/versions/0001_initial.py``) is the source of truth +for DDL; the ``Table`` definitions below must match it column-for-column. +If the two drift, migrations win — update this file to match. 
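+
+A minimal read sketch against ``users_table`` through the engine factory
+(illustrative only; assumes ``POSTGRES_URI`` is configured)::
+
+    from sqlalchemy import select
+
+    from application.storage.db.engine import get_engine
+    from application.storage.db.models import users_table
+
+    with get_engine().connect() as conn:
+        row = conn.execute(
+            select(users_table).where(users_table.c.user_id == "user-123")
+        ).fetchone()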
+""" + +from sqlalchemy import ( + Column, + DateTime, + MetaData, + Table, + Text, + func, +) +from sqlalchemy.dialects.postgresql import JSONB, UUID + +metadata = MetaData() + + +# --- Phase 1, Tier 1 -------------------------------------------------------- + +users_table = Table( + "users", + metadata, + Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()), + Column("user_id", Text, nullable=False, unique=True), + Column( + "agent_preferences", + JSONB, + nullable=False, + server_default='{"pinned": [], "shared_with_me": []}', + ), + Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()), + Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()), +) diff --git a/application/storage/db/repositories/__init__.py b/application/storage/db/repositories/__init__.py new file mode 100644 index 00000000..2d57ab23 --- /dev/null +++ b/application/storage/db/repositories/__init__.py @@ -0,0 +1,11 @@ +"""Repositories for the user-data Postgres database. + +Each module in this package exposes exactly one repository class. Repository +methods take a ``Connection`` (either as a constructor argument or as a +method argument) and return plain ``dict`` rows via +``application.storage.db.base_repository.row_to_dict`` during the +MongoDB→Postgres cutover, so call sites don't have to change shape. + +Repositories are added one collection at a time, matching the phased +rollout in ``migration-postgres.md``. +""" diff --git a/application/storage/db/repositories/users.py b/application/storage/db/repositories/users.py new file mode 100644 index 00000000..2827f518 --- /dev/null +++ b/application/storage/db/repositories/users.py @@ -0,0 +1,245 @@ +"""Repository for the ``users`` table. + +Covers every operation the legacy Mongo code performs on +``users_collection``: + +1. ``ensure_user_doc`` in ``application/api/user/base.py`` (upsert + get) +2. Pin/unpin agents in ``application/api/user/agents/routes.py`` (add/remove + on ``agent_preferences.pinned``) +3. Share accept/reject in ``application/api/user/agents/sharing.py`` (add/ + bulk-remove on ``agent_preferences.shared_with_me``) +4. Cascade delete of an agent id from both arrays at once + +All array mutations are implemented as single atomic UPDATE statements +using JSONB operators (``jsonb_set``, ``jsonb_array_elements``, ``@>``) +so there is no read-modify-write race between concurrent writers on the +same user row. + +The repository takes a ``Connection`` and does not manage its own +transactions. Callers are responsible for wrapping writes in +``with engine.begin() as conn:`` (production) or the test fixture's +rollback-per-test connection (tests). +""" + +from __future__ import annotations + +from typing import Iterable, Optional + +from sqlalchemy import Connection, text + +from application.storage.db.base_repository import row_to_dict + + +_DEFAULT_PREFERENCES = '{"pinned": [], "shared_with_me": []}' + + +class UsersRepository: + """Postgres-backed replacement for Mongo ``users_collection`` writes/reads.""" + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + # ------------------------------------------------------------------ + # Reads + # ------------------------------------------------------------------ + def get(self, user_id: str) -> Optional[dict]: + """Return the user row as a dict, or ``None`` if missing. + + Args: + user_id: Auth-provider ``sub`` (opaque string). 
+ """ + result = self._conn.execute( + text("SELECT * FROM users WHERE user_id = :user_id"), + {"user_id": user_id}, + ) + row = result.fetchone() + return row_to_dict(row) if row is not None else None + + # ------------------------------------------------------------------ + # Upsert + # ------------------------------------------------------------------ + def upsert(self, user_id: str) -> dict: + """Ensure a row exists for ``user_id`` and return it. + + Matches Mongo's ``find_one_and_update(..., $setOnInsert, upsert=True, + return_document=AFTER)`` semantics: if the row exists, preferences + are preserved untouched; if it doesn't, a new row is created with + default preferences. + + The ``DO UPDATE SET user_id = EXCLUDED.user_id`` branch is a + deliberate no-op that lets ``RETURNING *`` fire on both the insert + and conflict paths (``DO NOTHING`` would suppress the returning). + """ + result = self._conn.execute( + text( + """ + INSERT INTO users (user_id, agent_preferences) + VALUES (:user_id, CAST(:default_prefs AS jsonb)) + ON CONFLICT (user_id) DO UPDATE + SET user_id = EXCLUDED.user_id + RETURNING * + """ + ), + {"user_id": user_id, "default_prefs": _DEFAULT_PREFERENCES}, + ) + return row_to_dict(result.fetchone()) + + # ------------------------------------------------------------------ + # Pinned agents + # ------------------------------------------------------------------ + def add_pinned(self, user_id: str, agent_id: str) -> None: + """Idempotently append ``agent_id`` to ``agent_preferences.pinned``. + + Uses ``@>`` containment so a duplicate add is a no-op rather than a + silent double-insert. The whole update is a single atomic statement + so concurrent add_pinned calls on the same user cannot interleave + into a read-modify-write race. + """ + self._append_to_jsonb_array(user_id, "pinned", agent_id) + + def remove_pinned(self, user_id: str, agent_id: str) -> None: + """Remove ``agent_id`` from ``agent_preferences.pinned`` if present.""" + self._remove_from_jsonb_array(user_id, "pinned", [agent_id]) + + def remove_pinned_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None: + """Remove every id in ``agent_ids`` from ``agent_preferences.pinned``. + + No-op if the list is empty. Unknown ids are silently ignored so + callers can pass the full "stale" set without pre-filtering. + """ + ids = list(agent_ids) + if not ids: + return + self._remove_from_jsonb_array(user_id, "pinned", ids) + + # ------------------------------------------------------------------ + # Shared-with-me agents + # ------------------------------------------------------------------ + def add_shared(self, user_id: str, agent_id: str) -> None: + """Idempotently append ``agent_id`` to ``agent_preferences.shared_with_me``.""" + self._append_to_jsonb_array(user_id, "shared_with_me", agent_id) + + def remove_shared_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None: + """Bulk-remove from ``agent_preferences.shared_with_me``. Empty list is a no-op.""" + ids = list(agent_ids) + if not ids: + return + self._remove_from_jsonb_array(user_id, "shared_with_me", ids) + + # ------------------------------------------------------------------ + # Combined removal — called when an agent is hard-deleted + # ------------------------------------------------------------------ + def remove_agent_from_all(self, user_id: str, agent_id: str) -> None: + """Remove ``agent_id`` from BOTH pinned and shared_with_me atomically. 
+ + Mirrors the Mongo ``$pull`` that targets both nested array fields + in one ``update_one`` — see ``application/api/user/agents/routes.py`` + around the agent-delete path. + """ + self._conn.execute( + text( + """ + UPDATE users + SET + agent_preferences = jsonb_set( + jsonb_set( + agent_preferences, + '{pinned}', + COALESCE( + ( + SELECT jsonb_agg(elem) + FROM jsonb_array_elements( + COALESCE(agent_preferences->'pinned', '[]'::jsonb) + ) AS elem + WHERE (elem #>> '{}') != :agent_id + ), + '[]'::jsonb + ) + ), + '{shared_with_me}', + COALESCE( + ( + SELECT jsonb_agg(elem) + FROM jsonb_array_elements( + COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb) + ) AS elem + WHERE (elem #>> '{}') != :agent_id + ), + '[]'::jsonb + ) + ), + updated_at = now() + WHERE user_id = :user_id + """ + ), + {"user_id": user_id, "agent_id": agent_id}, + ) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + def _append_to_jsonb_array(self, user_id: str, key: str, agent_id: str) -> None: + """Idempotent append of ``agent_id`` to ``agent_preferences.``. + + The ``key`` argument is NOT user input — it's hard-coded by the + calling method (``pinned`` / ``shared_with_me``). It goes into the + SQL literal because ``jsonb_set`` requires a path literal, not a + bind parameter. This is safe as long as callers never pass + untrusted strings for ``key``. + """ + if key not in ("pinned", "shared_with_me"): + raise ValueError(f"unsupported jsonb key: {key!r}") + self._conn.execute( + text( + f""" + UPDATE users + SET + agent_preferences = jsonb_set( + agent_preferences, + '{{{key}}}', + CASE + WHEN agent_preferences->'{key}' @> to_jsonb(CAST(:agent_id AS text)) + THEN agent_preferences->'{key}' + ELSE + COALESCE(agent_preferences->'{key}', '[]'::jsonb) + || to_jsonb(CAST(:agent_id AS text)) + END + ), + updated_at = now() + WHERE user_id = :user_id + """ + ), + {"user_id": user_id, "agent_id": agent_id}, + ) + + def _remove_from_jsonb_array( + self, user_id: str, key: str, agent_ids: list[str] + ) -> None: + """Remove every id in ``agent_ids`` from ``agent_preferences.``.""" + if key not in ("pinned", "shared_with_me"): + raise ValueError(f"unsupported jsonb key: {key!r}") + self._conn.execute( + text( + f""" + UPDATE users + SET + agent_preferences = jsonb_set( + agent_preferences, + '{{{key}}}', + COALESCE( + ( + SELECT jsonb_agg(elem) + FROM jsonb_array_elements( + COALESCE(agent_preferences->'{key}', '[]'::jsonb) + ) AS elem + WHERE NOT ((elem #>> '{{}}') = ANY(:agent_ids)) + ), + '[]'::jsonb + ) + ), + updated_at = now() + WHERE user_id = :user_id + """ + ), + {"user_id": user_id, "agent_ids": agent_ids}, + ) diff --git a/application/vectorstore/pgvector.py b/application/vectorstore/pgvector.py index 28233821..eb4e7178 100644 --- a/application/vectorstore/pgvector.py +++ b/application/vectorstore/pgvector.py @@ -37,27 +37,25 @@ class PGVectorStore(BaseVectorStore): ) try: - import psycopg2 - from psycopg2.extras import Json - import pgvector.psycopg2 + import psycopg + from pgvector.psycopg import register_vector except ImportError: raise ImportError( "Could not import required packages. " - "Please install with `pip install psycopg2-binary pgvector`." + "Please install with `pip install 'psycopg[binary,pool]' pgvector`." 
) - self._psycopg2 = psycopg2 - self._Json = Json - self._pgvector = pgvector.psycopg2 + self._psycopg = psycopg + self._register_vector = register_vector self._connection = None self._ensure_table_exists() def _get_connection(self): """Get or create database connection""" if self._connection is None or self._connection.closed: - self._connection = self._psycopg2.connect(self._connection_string) + self._connection = self._psycopg.connect(self._connection_string) # Register pgvector types - self._pgvector.register_vector(self._connection) + self._register_vector(self._connection) return self._connection def _ensure_table_exists(self): @@ -170,7 +168,7 @@ class PGVectorStore(BaseVectorStore): for text, embedding, metadata in zip(texts, embeddings, metadatas): cursor.execute( insert_query, - (text, embedding, self._Json(metadata), self._source_id) + (text, embedding, metadata, self._source_id) ) inserted_id = cursor.fetchone()[0] inserted_ids.append(str(inserted_id)) @@ -261,7 +259,7 @@ class PGVectorStore(BaseVectorStore): cursor.execute( insert_query, - (text, embeddings[0], self._Json(final_metadata), self._source_id) + (text, embeddings[0], final_metadata, self._source_id) ) inserted_id = cursor.fetchone()[0] conn.commit() diff --git a/docs/content/Deploying/Postgres-Migration.mdx b/docs/content/Deploying/Postgres-Migration.mdx new file mode 100644 index 00000000..7c1af0b3 --- /dev/null +++ b/docs/content/Deploying/Postgres-Migration.mdx @@ -0,0 +1,114 @@ +--- +title: PostgreSQL for User Data +description: Set up PostgreSQL as the user-data store for DocsGPT and migrate from MongoDB at your own pace. +--- + +import { Callout } from 'nextra/components' + +# PostgreSQL for User Data + +DocsGPT is progressively moving user data (conversations, agents, prompts, +preferences, etc.) from MongoDB to PostgreSQL, one collection at a time. +Each collection is guarded by a feature flag so you can opt in and roll +back instantly. MongoDB stays the source of truth until you cut over +reads; vector stores (`VECTOR_STORE=pgvector`, `faiss`, `qdrant`, `mongodb`, …) +are unaffected. + + + Which collections are available today is in the [Status](#status) + table below. That table is the only part of this page that changes + release to release. + + +## Setup + +1. **Run Postgres 13+.** Native install, Docker, or managed (Neon, RDS, + Supabase, Cloud SQL…) — all work. You'll need the `pgcrypto` and + `citext` extensions, both standard contrib modules available + everywhere. + +2. **Create a database and role** (skip if your managed provider gave + you these): + + ```sql + CREATE ROLE docsgpt LOGIN PASSWORD 'docsgpt'; + CREATE DATABASE docsgpt OWNER docsgpt; + ``` + +3. **Set `POSTGRES_URI` in `.env`.** Any standard Postgres URI works — + DocsGPT normalizes it internally. + + ```bash + POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt + # Append ?sslmode=require for managed providers that enforce SSL. + ``` + +4. **Apply the schema** (idempotent — safe to re-run): + + ```bash + python scripts/db/init_postgres.py + ``` + +## Migrating data + +Two global flags, no per-collection knobs — every collection marked ✅ +in the [Status](#status) table is handled automatically. + +1. **Enable dual-write.** Writes go to both Mongo and Postgres; Mongo + remains source of truth. Set the flag in `.env` and restart: + + ```bash + USE_POSTGRES=true + ``` + +2. **Backfill existing data.** Idempotent — re-run any time to re-sync + drifted rows. 
+
+3. **Cut over reads** once you trust the Postgres state:
+
+   ```bash
+   READ_POSTGRES=true
+   ```
+
+   Rollback is instant: unset `READ_POSTGRES` and restart. Dual-write
+   keeps Postgres up to date so you can flip back and forth.
+
+<Callout>
+  Don't decommission MongoDB until every collection you use is fully
+  cut over. During the migration window, Mongo is still required.
+</Callout>
+
+## Status
+
+_Last updated: 2026-04-10_
+
+| Collection | Status |
+|---|---|
+| `users` | ✅ Phase 1 |
+| `prompts`, `user_tools`, `feedback`, `stack_logs`, `user_logs`, `token_usage` | ⏳ Phase 1 |
+| `agents`, `sources`, `attachments`, `memories`, `todos`, `notes`, `connector_sessions`, `agent_folders` | ⏳ Phase 2 |
+| `conversations`, `pending_tool_state`, `workflows` | ⏳ Phase 3 |
+
+Schemas for **every** collection above already exist after `init_postgres.py`
+runs. What's landing progressively is the application-level dual-write
+wiring and the backfill logic for each collection. Once a collection
+is ✅, enabling `USE_POSTGRES=true` and running `python scripts/db/backfill.py`
+picks it up automatically — no per-collection config change.
+
+## Troubleshooting
+
+- **`relation "..." does not exist`** — run `python scripts/db/init_postgres.py`.
+- **`FATAL: role "docsgpt" does not exist`** — run the `CREATE ROLE` /
+  `CREATE DATABASE` statements from step 2 as a Postgres superuser.
+- **SSL errors on a managed provider** — append `?sslmode=require` to
+  `POSTGRES_URI`.
+- **Dual-write warnings in the logs** — non-fatal by design. Mongo
+  is source of truth, so the user-facing request succeeds. Re-run the
+  backfill to re-sync whichever rows drifted.
diff --git a/docs/content/Deploying/_meta.js b/docs/content/Deploying/_meta.js
index e0ae379f..71527117 100644
--- a/docs/content/Deploying/_meta.js
+++ b/docs/content/Deploying/_meta.js
@@ -19,6 +19,10 @@ export default {
     "title": "☁️ Hosting DocsGPT",
     "href": "/Deploying/Hosting-the-app"
   },
+  "Postgres-Migration": {
+    "title": "🐘 PostgreSQL for User Data",
+    "href": "/Deploying/Postgres-Migration"
+  },
   "Amazon-Lightsail": {
     "title": "Hosting DocsGPT on Amazon Lightsail",
     "href": "/Deploying/Amazon-Lightsail",
diff --git a/scripts/db/backfill.py b/scripts/db/backfill.py
new file mode 100644
index 00000000..f08c8bf8
--- /dev/null
+++ b/scripts/db/backfill.py
@@ -0,0 +1,218 @@
+"""Backfill DocsGPT's Postgres user-data tables from MongoDB.
+
+One script for every migrated collection. Adding a new collection is a
+two-step change in this file:
+
+1. Write a ``_backfill_<table>`` function that takes keyword args
+   ``conn``, ``mongo_db``, ``batch_size``, ``dry_run`` and returns a
+   stats ``dict``.
+2. Add a single entry to :data:`BACKFILLERS`.
+
+There are intentionally no per-collection CLI flags or environment
+variables — ``USE_POSTGRES`` / ``READ_POSTGRES`` in ``.env`` are the
+only knobs operators need. This script discovers what's available from
+the :data:`BACKFILLERS` registry and runs whichever tables were asked for.
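+
+As a sketch (``prompts`` is hypothetical here — it has no backfiller yet),
+step 1 for a future collection would look roughly like::
+
+    def _backfill_prompts(*, conn, mongo_db, batch_size, dry_run) -> dict:
+        ...  # same batched-upsert shape as _backfill_users below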
+ +Usage:: + + python scripts/db/backfill.py # every registered table + python scripts/db/backfill.py --tables users # only specific tables + python scripts/db/backfill.py --dry-run # count without writing + python scripts/db/backfill.py --batch 1000 # tune commit size + +Exit codes: + 0 — every requested table completed successfully + 1 — misconfiguration (missing env var, unknown table name) + 2 — at least one table failed at runtime (others may still have succeeded) +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from pathlib import Path +from typing import Any, Callable + +# Make the project root importable regardless of cwd. +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from sqlalchemy import Connection, text # noqa: E402 + +from application.core.mongo_db import MongoDB # noqa: E402 +from application.core.settings import settings # noqa: E402 +from application.storage.db.engine import get_engine # noqa: E402 + +logger = logging.getLogger("backfill") + + +# --------------------------------------------------------------------------- +# Per-table backfillers +# --------------------------------------------------------------------------- + + +def _backfill_users( + *, + conn: Connection, + mongo_db: Any, + batch_size: int, + dry_run: bool, +) -> dict: + """Sync the ``users`` table from Mongo ``users`` collection. + + Overwrites each Postgres row's ``agent_preferences`` with the Mongo + state (Mongo is source of truth during the cutover window). Missing + ``pinned`` / ``shared_with_me`` keys are filled with empty arrays so + the Postgres row always has the full shape the application expects. + """ + upsert_sql = text( + """ + INSERT INTO users (user_id, agent_preferences) + VALUES (:user_id, CAST(:prefs AS jsonb)) + ON CONFLICT (user_id) DO UPDATE + SET agent_preferences = EXCLUDED.agent_preferences, + updated_at = now() + """ + ) + + cursor = ( + mongo_db["users"] + .find({}, no_cursor_timeout=True) + .batch_size(batch_size) + ) + + seen = 0 + written = 0 + skipped = 0 + batch: list[dict] = [] + + try: + for doc in cursor: + seen += 1 + user_id = doc.get("user_id") + if not user_id: + skipped += 1 + continue + + raw_prefs = doc.get("agent_preferences") or {} + prefs = { + "pinned": list(raw_prefs.get("pinned") or []), + "shared_with_me": list(raw_prefs.get("shared_with_me") or []), + } + batch.append({"user_id": user_id, "prefs": json.dumps(prefs)}) + + if len(batch) >= batch_size: + if not dry_run: + conn.execute(upsert_sql, batch) + written += len(batch) + batch.clear() + + if batch: + if not dry_run: + conn.execute(upsert_sql, batch) + written += len(batch) + finally: + cursor.close() + + return {"seen": seen, "written": written, "skipped_no_user_id": skipped} + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + + +BackfillFn = Callable[..., dict] + +# Register new tables here. Order matters only in the sense that +# ``--tables`` without arguments iterates in insertion order — put tables +# with FK dependencies after the tables they reference so a full-run +# backfill doesn't hit FK errors. 
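+#
+# A future entry adds one line to the dict below, e.g. (hypothetical,
+# once a ``_backfill_prompts`` function exists):
+#
+#     "prompts": _backfill_prompts,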
+BACKFILLERS: dict[str, BackfillFn] = { + "users": _backfill_users, +} + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Backfill DocsGPT Postgres tables from MongoDB." + ) + parser.add_argument( + "--tables", + default="", + help=( + "Comma-separated table names to backfill. " + f"Defaults to every registered table ({','.join(BACKFILLERS)})." + ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Iterate Mongo without writing to Postgres.", + ) + parser.add_argument( + "--batch", + type=int, + default=500, + help="How many rows to commit per Postgres statement (default: 500).", + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-5s %(name)s %(message)s", + ) + + if not settings.POSTGRES_URI: + logger.error("POSTGRES_URI is not set. Configure it in .env first.") + return 1 + if not settings.MONGO_URI: + logger.error("MONGO_URI is not set. Configure it in .env first.") + return 1 + + requested = [t.strip() for t in args.tables.split(",") if t.strip()] + if not requested: + requested = list(BACKFILLERS) + + unknown = [t for t in requested if t not in BACKFILLERS] + if unknown: + logger.error( + "Unknown table(s): %s. Available: %s", + ", ".join(unknown), + ", ".join(BACKFILLERS), + ) + return 1 + + mongo = MongoDB.get_client() + mongo_db = mongo[settings.MONGO_DB_NAME] + engine = get_engine() + + failures = 0 + for table in requested: + logger.info("backfill %s: start", table) + try: + with engine.begin() as conn: + stats = BACKFILLERS[table]( + conn=conn, + mongo_db=mongo_db, + batch_size=args.batch, + dry_run=args.dry_run, + ) + logger.info( + "backfill %s: done %s dry_run=%s", table, stats, args.dry_run + ) + except Exception: + failures += 1 + logger.exception("backfill %s: failed", table) + + return 2 if failures else 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/db/init_postgres.py b/scripts/db/init_postgres.py new file mode 100644 index 00000000..f59b0e9f --- /dev/null +++ b/scripts/db/init_postgres.py @@ -0,0 +1,55 @@ +"""One-shot bootstrap: run all Alembic migrations against POSTGRES_URI. + +Intended use: + + * local dev, after setting ``POSTGRES_URI`` in ``.env``:: + + python scripts/db/init_postgres.py + + * CI, as a step before running the pytest suite. + + * Docker image build or container start, if the operator wants the + migrations applied automatically on first boot. + +This script is a thin wrapper around ``alembic upgrade head``. It exists +separately so the same command is discoverable from the repo root without +remembering the ``-c application/alembic.ini`` invocation. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +from alembic import command +from alembic.config import Config + +REPO_ROOT = Path(__file__).resolve().parents[2] +ALEMBIC_INI = REPO_ROOT / "application" / "alembic.ini" + + +def main() -> int: + """Apply every pending migration up to ``head``. + + Returns: + ``0`` on success, ``1`` on failure. Non-zero is propagated as the + process exit code so CI jobs fail loudly. + """ + if not ALEMBIC_INI.exists(): + print(f"alembic.ini not found at {ALEMBIC_INI}", file=sys.stderr) + return 1 + + cfg = Config(str(ALEMBIC_INI)) + # Make `script_location` resolve correctly when invoked from any cwd. 
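+    # (If alembic.ini already resolves script_location via %(here)s, this
+    # override should be redundant — kept as a belt-and-braces guard.)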
+ cfg.set_main_option("script_location", str(ALEMBIC_INI.parent / "alembic")) + + try: + command.upgrade(cfg, "head") + except Exception as exc: # noqa: BLE001 — surface everything to the operator + print(f"alembic upgrade failed: {exc}", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/agents/test_postgres_tool.py b/tests/agents/test_postgres_tool.py index 6610883a..b116d1a0 100644 --- a/tests/agents/test_postgres_tool.py +++ b/tests/agents/test_postgres_tool.py @@ -18,7 +18,7 @@ class TestPostgresExecuteAction: with pytest.raises(ValueError, match="Unknown action"): tool.execute_action("invalid") - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_select_query(self, mock_connect, tool): mock_conn = MagicMock() mock_cur = MagicMock() @@ -37,7 +37,7 @@ class TestPostgresExecuteAction: assert result["response_data"]["data"][0] == {"id": 1, "name": "Alice"} mock_conn.close.assert_called_once() - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_insert_query(self, mock_connect, tool): mock_conn = MagicMock() mock_cur = MagicMock() @@ -55,11 +55,11 @@ class TestPostgresExecuteAction: mock_conn.commit.assert_called_once() mock_conn.close.assert_called_once() - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_db_error(self, mock_connect, tool): - import psycopg2 + import psycopg - mock_connect.side_effect = psycopg2.Error("connection refused") + mock_connect.side_effect = psycopg.Error("connection refused") result = tool.execute_action( "postgres_execute_sql", sql_query="SELECT 1" @@ -68,7 +68,7 @@ class TestPostgresExecuteAction: assert result["status_code"] == 500 assert "Database error" in result["error"] - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_get_schema(self, mock_connect, tool): mock_conn = MagicMock() mock_cur = MagicMock() @@ -89,24 +89,24 @@ class TestPostgresExecuteAction: assert result["schema"]["users"][0]["column_name"] == "id" mock_conn.close.assert_called_once() - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_get_schema_db_error(self, mock_connect, tool): - import psycopg2 + import psycopg - mock_connect.side_effect = psycopg2.Error("auth failed") + mock_connect.side_effect = psycopg.Error("auth failed") result = tool.execute_action("postgres_get_schema", db_name="testdb") assert result["status_code"] == 500 assert "Database error" in result["error"] - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_connection_closed_on_error(self, mock_connect, tool): - import psycopg2 + import psycopg mock_conn = MagicMock() mock_cur = MagicMock() - mock_cur.execute.side_effect = psycopg2.Error("syntax error") + mock_cur.execute.side_effect = psycopg.Error("syntax error") mock_conn.cursor.return_value = mock_cur mock_connect.return_value = mock_conn @@ -114,7 +114,7 @@ class TestPostgresExecuteAction: mock_conn.close.assert_called_once() - @patch("application.agents.tools.postgres.psycopg2.connect") + @patch("application.agents.tools.postgres.psycopg.connect") def test_select_with_no_description(self, mock_connect, tool): mock_conn = 
MagicMock() mock_cur = MagicMock() diff --git a/tests/core/test_db_uri.py b/tests/core/test_db_uri.py new file mode 100644 index 00000000..bbce2187 --- /dev/null +++ b/tests/core/test_db_uri.py @@ -0,0 +1,144 @@ +"""Tests for ``application.core.db_uri``. + +DocsGPT has two Postgres connection strings — ``POSTGRES_URI`` (consumed +by SQLAlchemy) and ``PGVECTOR_CONNECTION_STRING`` (consumed by +``psycopg.connect()`` directly). They need opposite normalization +because SQLAlchemy requires a ``postgresql+psycopg://`` dialect prefix +and libpq rejects it. Each field has its own normalizer so operators +can write whichever form feels natural and cross-pollination between +the two fields is forgiven. + +The normalizers live in ``application.core.db_uri`` as plain functions +so these tests can exercise them directly without having to instantiate +``Settings`` (which would pull in ``.env`` file side effects). +""" + +from __future__ import annotations + +import pytest + +from application.core.db_uri import ( + normalize_pgvector_connection_string, + normalize_postgres_uri, +) + + +@pytest.mark.unit +class TestNormalizePostgresUri: + @pytest.mark.parametrize( + "input_value,expected", + [ + # User-friendly forms get rewritten to the SQLAlchemy dialect. + ( + "postgres://u:p@h:5432/d", + "postgresql+psycopg://u:p@h:5432/d", + ), + ( + "postgresql://u:p@h:5432/d", + "postgresql+psycopg://u:p@h:5432/d", + ), + # Legacy psycopg2 dialect is silently upgraded — psycopg2 is + # no longer in requirements.txt, so there's no way it can work + # as-is, and rewriting is friendlier than failing. + ( + "postgresql+psycopg2://u:p@h:5432/d", + "postgresql+psycopg://u:p@h:5432/d", + ), + # Already-correct dialect passes through unchanged. + ( + "postgresql+psycopg://u:p@h:5432/d", + "postgresql+psycopg://u:p@h:5432/d", + ), + # Whitespace is trimmed before rewriting. + ( + " postgres://u:p@h/d ", + "postgresql+psycopg://u:p@h/d", + ), + # Query-string params (sslmode, options) are preserved verbatim. + ( + "postgresql://u:p@h:5432/d?sslmode=require&application_name=docsgpt", + "postgresql+psycopg://u:p@h:5432/d?sslmode=require&application_name=docsgpt", + ), + ], + ) + def test_rewrites_common_forms_to_psycopg_dialect(self, input_value, expected): + assert normalize_postgres_uri(input_value) == expected + + @pytest.mark.parametrize( + "input_value", + [None, "", " ", "None", "none"], + ) + def test_empty_or_none_like_returns_none(self, input_value): + assert normalize_postgres_uri(input_value) is None + + def test_unknown_scheme_passes_through(self): + """A dialect we don't recognise is left alone so SQLAlchemy can + produce its own error message when the engine tries to connect. + Better than silently eating the config.""" + weird = "postgresql+asyncpg://u:p@h/d" + assert normalize_postgres_uri(weird) == weird + + def test_non_string_input_passes_through(self): + """Non-string inputs (e.g. if pydantic ever passes an int) shouldn't + crash the normalizer — let pydantic's own type validation handle it.""" + assert normalize_postgres_uri(42) == 42 # type: ignore[arg-type] + + +@pytest.mark.unit +class TestNormalizePgvectorConnectionString: + """Symmetric to the POSTGRES_URI normalizer but pulls in the OPPOSITE + direction: strips the SQLAlchemy dialect prefix so libpq accepts it. + """ + + @pytest.mark.parametrize( + "input_value,expected", + [ + # User-friendly forms pass through — libpq accepts them natively. 
+ ( + "postgres://u:p@h:5432/d", + "postgres://u:p@h:5432/d", + ), + ( + "postgresql://u:p@h:5432/d", + "postgresql://u:p@h:5432/d", + ), + # SQLAlchemy dialect prefixes get stripped so libpq accepts them. + # Operators hit this when they copy POSTGRES_URI → PGVECTOR_CONNECTION_STRING. + ( + "postgresql+psycopg://u:p@h:5432/d", + "postgresql://u:p@h:5432/d", + ), + ( + "postgresql+psycopg2://u:p@h:5432/d", + "postgresql://u:p@h:5432/d", + ), + # Whitespace is trimmed before rewriting. + ( + " postgresql+psycopg://u:p@h/d ", + "postgresql://u:p@h/d", + ), + # Query-string params (sslmode, etc.) are preserved verbatim. + ( + "postgresql+psycopg://u:p@h:5432/d?sslmode=require", + "postgresql://u:p@h:5432/d?sslmode=require", + ), + ], + ) + def test_rewrites_dialect_forms_to_libpq_compatible(self, input_value, expected): + assert normalize_pgvector_connection_string(input_value) == expected + + @pytest.mark.parametrize( + "input_value", + [None, "", " ", "None", "none"], + ) + def test_empty_or_none_like_returns_none(self, input_value): + assert normalize_pgvector_connection_string(input_value) is None + + def test_unknown_scheme_passes_through(self): + """A scheme we don't recognise is left alone so libpq can produce + its own error message when the connection is attempted.""" + weird = "mysql://u:p@h/d" + assert normalize_pgvector_connection_string(weird) == weird + + def test_non_string_input_passes_through(self): + assert normalize_pgvector_connection_string(42) == 42 # type: ignore[arg-type] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..b1a04957 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,84 @@ +"""Fixtures for integration tests that hit a live Postgres. + +These tests are separate from unit tests in two ways: + +1. **Directory**: They live under ``tests/integration/``, which is + ignored by the default pytest run via ``--ignore=tests/integration`` + in ``pytest.ini``. The CI ``pytest.yml`` workflow therefore skips + them automatically — it runs the fast unit suite only. +2. **Marker**: Each test is marked ``@pytest.mark.integration`` so it + can be selected (or excluded) independently of its directory with + ``pytest -m integration`` / ``-m "not integration"``. + +Running the Postgres-backed integration tests manually:: + + .venv/bin/python -m pytest tests/integration/test_users_repository.py \\ + --override-ini="addopts=" --no-cov + +(The ``--override-ini="addopts="`` is needed because ``pytest.ini`` +contains ``--ignore=tests/integration`` in ``addopts``; without the +override pytest would still skip the directory even though you named +it on the command line.) + +Tests are skipped automatically if ``POSTGRES_URI`` is unset, so a +contributor who hasn't set up a local Postgres gets clean skips +instead of red tests. +""" + +from __future__ import annotations + +import pytest +from sqlalchemy import Engine, create_engine, text + +from application.core.settings import settings + + +@pytest.fixture(scope="session") +def pg_engine() -> Engine: + """Session-scoped SQLAlchemy engine for the Postgres integration DB. + + Skips all Postgres-backed tests if ``POSTGRES_URI`` is unset. This + keeps CI and contributor machines without a local Postgres from + erroring out — integration tests that require the DB become + no-ops rather than failures. 
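+
+    One way to stand up a throwaway Postgres for a local run (a sketch —
+    assumes Docker, a free port 5432, and the stock ``postgres:16``
+    image; adjust names and credentials to taste)::
+
+        docker run --rm -d -p 5432:5432 -e POSTGRES_PASSWORD=docsgpt postgres:16
+        export POSTGRES_URI=postgresql://postgres:docsgpt@localhost:5432/postgres  # or set it in .env
+        python scripts/db/init_postgres.py   # apply the schema first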
+ """ + if not settings.POSTGRES_URI: + pytest.skip("POSTGRES_URI not set — skipping Postgres integration tests") + engine = create_engine(settings.POSTGRES_URI, future=True, pool_pre_ping=True) + try: + yield engine + finally: + engine.dispose() + + +@pytest.fixture +def pg_conn(pg_engine: Engine): + """Per-test Postgres connection wrapped in a rolled-back transaction. + + Uses SQLAlchemy's explicit outer-transaction pattern so every test + sees a pristine DB view without having to truncate tables. Any + nested ``begin()`` inside the repository code becomes a SAVEPOINT + under the hood. + """ + conn = pg_engine.connect() + outer = conn.begin() + try: + yield conn + finally: + outer.rollback() + conn.close() + + +@pytest.fixture +def pg_clean_users(pg_conn): + """Guarantee a clean ``users`` table view for tests that need it. + + The outer transaction rollback handles cleanup, but if a previous + interrupted run left rows committed, this fixture removes them + inside the transaction scope so they are invisible to the test. + ``DELETE`` rather than ``TRUNCATE`` because ``TRUNCATE`` in Postgres + cannot be rolled back within a nested transaction the way + ``DELETE`` can. + """ + pg_conn.execute(text("DELETE FROM users")) + return pg_conn diff --git a/tests/integration/test_users_repository.py b/tests/integration/test_users_repository.py new file mode 100644 index 00000000..cd3b5e59 --- /dev/null +++ b/tests/integration/test_users_repository.py @@ -0,0 +1,179 @@ +"""Integration tests for ``UsersRepository`` against a live Postgres. + +These tests need: + +* A running Postgres reachable via ``POSTGRES_URI`` +* Alembic migration ``0001_initial`` applied + +They are skipped automatically by the default ``pytest`` run because +``pytest.ini`` has ``--ignore=tests/integration`` in ``addopts``. They +are additionally marked ``@pytest.mark.integration`` so they can be +selected/excluded explicitly. Run them locally with:: + + .venv/bin/python -m pytest tests/integration/test_users_repository.py \\ + --override-ini="addopts=" --no-cov + +Covers every operation the legacy Mongo code performs on +``users_collection``: upsert + get + add/remove on the pinned and +shared_with_me JSONB arrays, plus explicit tenant-isolation checks +(a user's operations must never touch another user's data). 
+""" + +from __future__ import annotations + +import pytest + +from application.storage.db.repositories.users import UsersRepository + + +@pytest.fixture +def repo(pg_clean_users): + return UsersRepository(pg_clean_users) + + +@pytest.mark.integration +class TestUpsert: + def test_upsert_creates_new_user_with_default_preferences(self, repo): + doc = repo.upsert("alice@example.com") + assert doc["user_id"] == "alice@example.com" + assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []} + assert "id" in doc + assert "_id" in doc # Mongo-compat alias + + def test_upsert_is_idempotent(self, repo): + first = repo.upsert("alice@example.com") + second = repo.upsert("alice@example.com") + assert first["id"] == second["id"] + assert first["agent_preferences"] == second["agent_preferences"] + + def test_upsert_preserves_existing_preferences(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-1") + doc = repo.upsert("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-1"] + + +@pytest.mark.integration +class TestGet: + def test_get_returns_none_for_missing_user(self, repo): + assert repo.get("nobody@example.com") is None + + def test_get_returns_dict_for_existing_user(self, repo): + repo.upsert("alice@example.com") + doc = repo.get("alice@example.com") + assert doc is not None + assert doc["user_id"] == "alice@example.com" + assert doc["agent_preferences"] == {"pinned": [], "shared_with_me": []} + + +@pytest.mark.integration +class TestAddPinned: + def test_add_pinned_to_new_user_creates_user(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-1") + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-1"] + + def test_add_pinned_is_idempotent(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-1") + repo.add_pinned("alice@example.com", "agent-1") + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-1"] + + def test_add_pinned_preserves_order(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-1") + repo.add_pinned("alice@example.com", "agent-2") + repo.add_pinned("alice@example.com", "agent-3") + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-1", "agent-2", "agent-3"] + + +@pytest.mark.integration +class TestRemovePinned: + def test_remove_pinned_single(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-1") + repo.add_pinned("alice@example.com", "agent-2") + repo.remove_pinned("alice@example.com", "agent-1") + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-2"] + + def test_remove_pinned_missing_is_noop(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-1") + repo.remove_pinned("alice@example.com", "agent-999") + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-1"] + + def test_remove_pinned_bulk(self, repo): + repo.upsert("alice@example.com") + for i in range(5): + repo.add_pinned("alice@example.com", f"agent-{i}") + repo.remove_pinned_bulk("alice@example.com", ["agent-1", "agent-3", "agent-999"]) + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-0", "agent-2", "agent-4"] + + +@pytest.mark.integration +class TestSharedWithMe: + def test_add_shared_is_idempotent(self, 
repo): + repo.upsert("alice@example.com") + repo.add_shared("alice@example.com", "agent-x") + repo.add_shared("alice@example.com", "agent-x") + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["shared_with_me"] == ["agent-x"] + + def test_remove_shared_bulk(self, repo): + repo.upsert("alice@example.com") + for i in range(3): + repo.add_shared("alice@example.com", f"shared-{i}") + repo.remove_shared_bulk("alice@example.com", ["shared-0", "shared-2"]) + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["shared_with_me"] == ["shared-1"] + + +@pytest.mark.integration +class TestRemoveAgentFromAll: + def test_removes_from_both_pinned_and_shared(self, repo): + repo.upsert("alice@example.com") + repo.add_pinned("alice@example.com", "agent-X") + repo.add_pinned("alice@example.com", "agent-keep") + repo.add_shared("alice@example.com", "agent-X") + repo.add_shared("alice@example.com", "agent-keep-2") + + repo.remove_agent_from_all("alice@example.com", "agent-X") + + doc = repo.get("alice@example.com") + assert doc["agent_preferences"]["pinned"] == ["agent-keep"] + assert doc["agent_preferences"]["shared_with_me"] == ["agent-keep-2"] + + +@pytest.mark.integration +class TestTenantIsolation: + """Security-critical: operations on one user must never touch another's data.""" + + def test_add_pinned_does_not_leak_across_users(self, repo): + repo.upsert("alice@example.com") + repo.upsert("bob@example.com") + repo.add_pinned("alice@example.com", "agent-a") + repo.add_pinned("bob@example.com", "agent-b") + + alice = repo.get("alice@example.com") + bob = repo.get("bob@example.com") + assert alice["agent_preferences"]["pinned"] == ["agent-a"] + assert bob["agent_preferences"]["pinned"] == ["agent-b"] + + def test_remove_does_not_leak_across_users(self, repo): + repo.upsert("alice@example.com") + repo.upsert("bob@example.com") + repo.add_pinned("alice@example.com", "shared-agent-id") + repo.add_pinned("bob@example.com", "shared-agent-id") + + repo.remove_pinned("alice@example.com", "shared-agent-id") + + assert repo.get("alice@example.com")["agent_preferences"]["pinned"] == [] + assert repo.get("bob@example.com")["agent_preferences"]["pinned"] == [ + "shared-agent-id" + ] diff --git a/tests/vectorstore/test_pgvector.py b/tests/vectorstore/test_pgvector.py index bbec02bf..08cfc11e 100644 --- a/tests/vectorstore/test_pgvector.py +++ b/tests/vectorstore/test_pgvector.py @@ -16,10 +16,9 @@ def _make_store( ) as mock_settings, patch.dict( "sys.modules", { - "psycopg2": MagicMock(), - "psycopg2.extras": MagicMock(), + "psycopg": MagicMock(), "pgvector": MagicMock(), - "pgvector.psycopg2": MagicMock(), + "pgvector.psycopg": MagicMock(), }, ): mock_emb = Mock() @@ -63,10 +62,9 @@ class TestPGVectorStoreInit: ) as mock_settings, patch.dict( "sys.modules", { - "psycopg2": MagicMock(), - "psycopg2.extras": MagicMock(), + "psycopg": MagicMock(), "pgvector": MagicMock(), - "pgvector.psycopg2": MagicMock(), + "pgvector.psycopg": MagicMock(), }, ): mock_get_emb.return_value = Mock(dimension=768) @@ -264,13 +262,13 @@ class TestPGVectorStoreConnection: store, mock_conn, _, _ = _make_store() mock_conn.closed = True - mock_psycopg2 = MagicMock() + mock_psycopg = MagicMock() new_conn = MagicMock() - mock_psycopg2.connect.return_value = new_conn - store._psycopg2 = mock_psycopg2 + mock_psycopg.connect.return_value = new_conn + store._psycopg = mock_psycopg conn = store._get_connection() - mock_psycopg2.connect.assert_called_once() + mock_psycopg.connect.assert_called_once() assert conn is 
new_conn def test_get_connection_reuses_open(self):