Compare commits

..

3 Commits
pg-1 ... pg-2

Author SHA1 Message Date
Alex
cb30a24e05 feat: fixes on pg2 2026-04-12 13:51:29 +01:00
Alex
530761d08c feat: pg-2 2026-04-12 13:35:32 +01:00
Alex
73fbc28744 Merge pull request #2376 from arc53/pg-1
Pg 1
2026-04-12 12:44:12 +01:00
26 changed files with 2343 additions and 158 deletions

View File

@@ -0,0 +1,57 @@
"""0002 add unique constraints for notes and connector_sessions.
The memories table already has ``memories_user_tool_path_uidx`` from the
0001 baseline. Notes and connector_sessions were missing unique constraints
that their repository upsert logic depends on.
Before creating the indexes, duplicate rows are cleaned up — keeping only
the row with the latest ``id`` (UUID, lexicographic max) per group.
Revision ID: 0002_add_unique_constraints
Revises: 0001_initial
Create Date: 2026-04-12
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_add_unique_constraints"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Deduplicate notes and connector_sessions, then add unique indexes.

    ``DISTINCT ON`` keeps the first row per group under the given ORDER BY,
    so sorting ``created_at DESC`` keeps the newest row. ``id DESC`` is a
    deterministic tiebreaker: without it, rows that share a ``created_at``
    would be kept or deleted arbitrarily between runs.
    """
    # Deduplicate notes: keep one row per (user_id, tool_id)
    op.execute("""
        DELETE FROM notes
        WHERE id NOT IN (
            SELECT DISTINCT ON (user_id, tool_id) id
            FROM notes
            ORDER BY user_id, tool_id, created_at DESC, id DESC
        );
    """)
    op.execute(
        "CREATE UNIQUE INDEX IF NOT EXISTS notes_user_tool_uidx "
        "ON notes (user_id, tool_id);"
    )

    # Deduplicate connector_sessions: keep one row per (user_id, provider)
    op.execute("""
        DELETE FROM connector_sessions
        WHERE id NOT IN (
            SELECT DISTINCT ON (user_id, provider) id
            FROM connector_sessions
            ORDER BY user_id, provider, created_at DESC, id DESC
        );
    """)
    op.execute(
        "CREATE UNIQUE INDEX IF NOT EXISTS connector_sessions_user_provider_uidx "
        "ON connector_sessions (user_id, provider);"
    )
def downgrade() -> None:
    """Drop the unique indexes added by upgrade(), in reverse order."""
    for index_name in (
        "connector_sessions_user_provider_uidx",
        "notes_user_tool_uidx",
    ):
        op.execute(f"DROP INDEX IF EXISTS {index_name};")

View File

@@ -13,6 +13,8 @@ from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
@@ -83,6 +85,10 @@ class AgentFolders(Resource):
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
dual_write(
AgentFoldersRepository,
lambda repo, u=user, n=data["name"]: repo.create(u, n),
)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
@@ -167,6 +173,10 @@ class AgentFolder(Resource):
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
dual_write(
AgentFoldersRepository,
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
)
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)

View File

@@ -24,6 +24,7 @@ from application.api.user.base import (
workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.users import UsersRepository
from application.core.json_schema_utils import (
JsonSchemaValidationError,
@@ -623,6 +624,17 @@ class CreateAgent(Resource):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
dual_write(
AgentsRepository,
lambda repo, u=user, a=new_agent: repo.create(
u, a.get("name", ""), a.get("status", "draft"),
key=a.get("key"), description=a.get("description"),
retriever=a.get("retriever"), chunks=a.get("chunks"),
tools=a.get("tools"), models=a.get("models"),
shared=a.get("shared", False),
incoming_webhook_token=a.get("incoming_webhook_token"),
),
)
except Exception as err:
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1185,6 +1197,10 @@ class DeleteAgent(Resource):
deleted_agent = agents_collection.find_one_and_delete(
{"_id": ObjectId(agent_id), "user": user}
)
dual_write(
AgentsRepository,
lambda repo, aid=agent_id, u=user: repo.delete(aid, u),
)
if not deleted_agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404

View File

@@ -9,15 +9,18 @@ If the two drift, migrations win — update this file to match.
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
MetaData,
UniqueConstraint,
Table,
Text,
func,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
metadata = MetaData()
@@ -109,3 +112,121 @@ stack_logs_table = Table(
Column("stacks", JSONB, nullable=False, server_default="[]"),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# --- Phase 2, Tier 2 --------------------------------------------------------
# Folders used to organise a user's agents; referenced by agents.folder_id.
agent_folders_table = Table(
    "agent_folders",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("description", Text),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# Document/data sources. user_id is nullable, so ownerless rows are allowed
# (presumably system/global sources — TODO confirm against callers).
sources_table = Table(
    "sources",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text),
    Column("name", Text, nullable=False),
    Column("type", Text),
    Column("metadata", JSONB, nullable=False, server_default="{}"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# Core agent definition. FK targets use SET NULL so deleting a source,
# prompt, or folder orphans the agent instead of cascading its deletion.
# ``key`` and ``incoming_webhook_token`` are unique lookup handles.
agents_table = Table(
    "agents",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("description", Text),
    Column("agent_type", Text),
    Column("status", Text, nullable=False),
    Column("key", Text, unique=True),
    Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
    Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
    Column("chunks", Integer),
    Column("retriever", Text),
    Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
    Column("tools", JSONB, nullable=False, server_default="[]"),
    Column("json_schema", JSONB),
    Column("models", JSONB),
    Column("default_model_id", Text),
    Column("folder_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
    Column("limited_token_mode", Boolean, nullable=False, server_default="false"),
    Column("token_limit", Integer),
    Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
    Column("request_limit", Integer),
    Column("shared", Boolean, nullable=False, server_default="false"),
    Column("incoming_webhook_token", Text, unique=True),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("last_used_at", DateTime(timezone=True)),
)
# Uploaded file metadata; the payload lives at ``upload_path``, not in the DB.
attachments_table = Table(
    "attachments",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("filename", Text, nullable=False),
    Column("upload_path", Text, nullable=False),
    Column("mime_type", Text),
    Column("size", BigInteger),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# File-like memory entries; one row per (user, tool, path). The unique
# constraint is the ON CONFLICT target for the repository's upsert.
memories_table = Table(
    "memories",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
    Column("path", Text, nullable=False),
    Column("content", Text, nullable=False),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    UniqueConstraint("user_id", "tool_id", "path", name="memories_user_tool_path_uidx"),
)
# Todo items per user/tool; ``completed`` is a boolean, defaulting to false.
todos_table = Table(
    "todos",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
    Column("title", Text, nullable=False),
    Column("completed", Boolean, nullable=False, server_default="false"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# One note per (user, tool); the unique constraint backs the notes upsert.
notes_table = Table(
    "notes",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
    Column("title", Text, nullable=False),
    Column("content", Text, nullable=False),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    UniqueConstraint("user_id", "tool_id", name="notes_user_tool_uidx"),
)
# One session per (user, provider); unique constraint backs the session upsert.
connector_sessions_table = Table(
    "connector_sessions",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("provider", Text, nullable=False),
    Column("session_data", JSONB, nullable=False),
    Column("expires_at", DateTime(timezone=True)),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
)

View File

@@ -0,0 +1,88 @@
"""Repository for the ``agent_folders`` table."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class AgentFoldersRepository:
    """CRUD access to the ``agent_folders`` table, always scoped by user."""

    # Columns a caller may change through update(); anything else is ignored.
    _UPDATABLE = ("name", "description")

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, user_id: str, name: str, *, description: Optional[str] = None) -> dict:
        """Insert a folder and return the created row as a dict."""
        result = self._conn.execute(
            text(
                """
                INSERT INTO agent_folders (user_id, name, description)
                VALUES (:user_id, :name, :description)
                RETURNING *
                """
            ),
            {"user_id": user_id, "name": name, "description": description},
        )
        return row_to_dict(result.fetchone())

    def get(self, folder_id: str, user_id: str) -> Optional[dict]:
        """Return one folder owned by *user_id*, or None if absent."""
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return all folders belonging to *user_id*, oldest first."""
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, folder_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns; return True if a row changed.

        Builds a single UPDATE from whichever allowed keys are present,
        instead of enumerating every key combination in separate branches.
        Column names come only from the ``_UPDATABLE`` whitelist, so the
        f-string SQL is not injectable; values are bound parameters.
        """
        filtered = {k: fields[k] for k in self._UPDATABLE if k in fields}
        if not filtered:
            return False
        assignments = ", ".join(f"{col} = :{col}" for col in filtered)
        result = self._conn.execute(
            text(
                f"UPDATE agent_folders SET {assignments}, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": folder_id, "user_id": user_id, **filtered},
        )
        return result.rowcount > 0

    def delete(self, folder_id: str, user_id: str) -> bool:
        """Delete one folder owned by *user_id*; True if a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        return result.rowcount > 0

View File

@@ -0,0 +1,154 @@
"""Repository for the ``agents`` table.
This is the most complex Phase 2 repository. Covers every write operation
the legacy Mongo code performs on ``agents_collection``:
- create, update, delete
- find by key (API key lookup)
- find by webhook token
- list for user, list templates
- folder assignment
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, func, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import agents_table
class AgentsRepository:
    """Write/read access to the ``agents`` table.

    Covers every write operation the legacy Mongo code performs on
    ``agents_collection``: create/update/delete, lookup by API key and
    webhook token, listing, and folder assignment.
    """

    # Optional columns accepted by create(); hoisted to the class so the
    # set is built once instead of on every call.
    _CREATE_ALLOWED = frozenset({
        "description", "agent_type", "key", "retriever",
        "default_model_id", "incoming_webhook_token",
        "source_id", "prompt_id", "folder_id",
        "chunks", "token_limit", "request_limit",
        "limited_token_mode", "limited_request_mode", "shared",
        "tools", "json_schema", "models",
    })
    # Columns update() may touch.
    _UPDATE_ALLOWED = frozenset({
        "name", "description", "agent_type", "status", "key", "source_id",
        "chunks", "retriever", "prompt_id", "tools", "json_schema", "models",
        "default_model_id", "folder_id", "limited_token_mode", "token_limit",
        "limited_request_mode", "request_limit", "shared",
        "incoming_webhook_token", "last_used_at",
    })
    # Column groups that need coercion before hitting Postgres.
    _JSON_COLS = ("tools", "json_schema", "models")
    _INT_COLS = ("chunks", "token_limit", "request_limit")
    _BOOL_COLS = ("limited_token_mode", "limited_request_mode", "shared")
    _UUID_COLS = ("source_id", "prompt_id", "folder_id")

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
        """Insert an agent row and return it as a dict.

        Unknown keyword arguments and explicit None values are silently
        dropped so the column defaults apply.
        """
        values: dict = {"user_id": user_id, "name": name, "status": status}
        for col, val in kwargs.items():
            if col not in self._CREATE_ALLOWED or val is None:
                continue
            if col in self._JSON_COLS:
                values[col] = json.dumps(val)
            elif col in self._INT_COLS:
                values[col] = int(val)
            elif col in self._BOOL_COLS:
                values[col] = bool(val)
            elif col in self._UUID_COLS:
                values[col] = str(val)
            else:
                values[col] = val
        stmt = pg_insert(agents_table).values(**values).returning(agents_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get(self, agent_id: str, user_id: str) -> Optional[dict]:
        """Return one agent owned by *user_id*, or None if absent."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_key(self, key: str) -> Optional[dict]:
        """Look up an agent by its unique API key (no user scoping)."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE key = :key"),
            {"key": key},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_webhook_token(self, token: str) -> Optional[dict]:
        """Look up an agent by its unique incoming-webhook token."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE incoming_webhook_token = :token"),
            {"token": token},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return the user's agents, newest first."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def list_templates(self) -> list[dict]:
        """Return template agents (owned by the 'system' pseudo-user)."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = 'system' ORDER BY name"),
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, agent_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns; return True if a row matched."""
        filtered = {k: v for k, v in fields.items() if k in self._UPDATE_ALLOWED}
        if not filtered:
            return False
        values: dict = {}
        for col, val in filtered.items():
            if col in self._JSON_COLS:
                # Callers may pass pre-serialized JSON strings; avoid double-encoding.
                values[col] = json.dumps(val) if not isinstance(val, str) else val
            elif col in self._UUID_COLS:
                # Falsy ids (None, "") clear the reference.
                values[col] = str(val) if val else None
            else:
                values[col] = val
        values["updated_at"] = func.now()
        t = agents_table
        stmt = (
            t.update()
            .where(t.c.id == agent_id)
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        result = self._conn.execute(stmt)
        return result.rowcount > 0

    def delete(self, agent_id: str, user_id: str) -> bool:
        """Delete one agent owned by *user_id*; True if a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
        """Assign (or, with None, clear) an agent's folder."""
        self._conn.execute(
            text(
                """
                UPDATE agents SET folder_id = CAST(:folder_id AS uuid), updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": agent_id, "user_id": user_id, "folder_id": folder_id},
        )

    def clear_folder_for_all(self, folder_id: str, user_id: str) -> None:
        """Remove folder assignment from all agents in a folder (used on folder delete)."""
        self._conn.execute(
            text(
                "UPDATE agents SET folder_id = NULL, updated_at = now() "
                "WHERE folder_id = CAST(:folder_id AS uuid) AND user_id = :user_id"
            ),
            {"folder_id": folder_id, "user_id": user_id},
        )

View File

@@ -0,0 +1,51 @@
"""Repository for the ``attachments`` table."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class AttachmentsRepository:
    """Read/write helpers for the ``attachments`` table."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, user_id: str, filename: str, upload_path: str, *,
               mime_type: Optional[str] = None, size: Optional[int] = None) -> dict:
        """Insert an attachment row and return it as a dict."""
        insert_sql = text(
            """
            INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
            VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
            RETURNING *
            """
        )
        params = {
            "user_id": user_id,
            "filename": filename,
            "upload_path": upload_path,
            "mime_type": mime_type,
            "size": size,
        }
        return row_to_dict(self._conn.execute(insert_sql, params).fetchone())

    def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
        """Return one attachment owned by *user_id*, or None if absent."""
        row = self._conn.execute(
            text(
                "SELECT * FROM attachments WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": attachment_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return the user's attachments, newest first."""
        rows = self._conn.execute(
            text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

View File

@@ -0,0 +1,65 @@
"""Repository for the ``connector_sessions`` table.
Covers operations across connector routes and tools:
- upsert session data
- find session by user + provider
- find session by token
- delete session
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class ConnectorSessionsRepository:
    """Per-user, per-provider connector session storage."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(self, user_id: str, provider: str, session_data: dict) -> dict:
        """Insert or replace the session row for (user_id, provider)."""
        upsert_sql = text(
            """
            INSERT INTO connector_sessions (user_id, provider, session_data)
            VALUES (:user_id, :provider, CAST(:session_data AS jsonb))
            ON CONFLICT (user_id, provider)
            DO UPDATE SET session_data = EXCLUDED.session_data
            RETURNING *
            """
        )
        params = {
            "user_id": user_id,
            "provider": provider,
            "session_data": json.dumps(session_data),
        }
        return row_to_dict(self._conn.execute(upsert_sql, params).fetchone())

    def get_by_user_provider(self, user_id: str, provider: str) -> Optional[dict]:
        """Return the session for (user_id, provider), or None if absent."""
        row = self._conn.execute(
            text(
                "SELECT * FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"
            ),
            {"user_id": user_id, "provider": provider},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return every connector session belonging to *user_id*."""
        rows = self._conn.execute(
            text("SELECT * FROM connector_sessions WHERE user_id = :user_id"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

    def delete(self, user_id: str, provider: str) -> bool:
        """Delete the session for (user_id, provider); True if one existed."""
        result = self._conn.execute(
            text("DELETE FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"),
            {"user_id": user_id, "provider": provider},
        )
        return result.rowcount > 0

View File

@@ -0,0 +1,97 @@
"""Repository for the ``memories`` table.
Covers the operations in ``application/agents/tools/memory.py``:
- upsert (create/overwrite file)
- find by path (view file)
- find by path prefix (view directory, regex scan)
- delete by path / path prefix
- rename (update path)
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class MemoriesRepository:
    """File-like memory storage keyed by (user_id, tool_id, path)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _escape_like(fragment: str) -> str:
        """Escape LIKE wildcards so *fragment* matches literally.

        Paths routinely contain ``_``, which LIKE treats as "any single
        character", so prefix filters must escape ``\\``, ``%`` and ``_``
        before appending the trailing ``%`` wildcard. Backslash is the
        default LIKE escape character in Postgres.
        """
        return (
            fragment.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    def upsert(self, user_id: str, tool_id: str, path: str, content: str) -> dict:
        """Create or overwrite the memory file at *path*; returns the row."""
        result = self._conn.execute(
            text(
                """
                INSERT INTO memories (user_id, tool_id, path, content)
                VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content)
                ON CONFLICT (user_id, tool_id, path)
                DO UPDATE SET content = EXCLUDED.content, updated_at = now()
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path, "content": content},
        )
        return row_to_dict(result.fetchone())

    def get_by_path(self, user_id: str, tool_id: str, path: str) -> Optional[dict]:
        """Return the memory at the exact *path*, or None if absent."""
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> list[dict]:
        """Return memories whose path starts with the literal *prefix*."""
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix"
            ),
            # Escape wildcards: a bare ``prefix + "%"`` would let ``_``/``%``
            # inside the path act as wildcards and match unrelated rows.
            {"user_id": user_id, "tool_id": tool_id, "prefix": self._escape_like(prefix) + "%"},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def delete_by_path(self, user_id: str, tool_id: str, path: str) -> int:
        """Delete the memory at the exact *path*; returns rows deleted."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        return result.rowcount

    def delete_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> int:
        """Delete memories whose path starts with the literal *prefix*."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix"
            ),
            {"user_id": user_id, "tool_id": tool_id, "prefix": self._escape_like(prefix) + "%"},
        )
        return result.rowcount

    def delete_all(self, user_id: str, tool_id: str) -> int:
        """Delete every memory for (user_id, tool_id); returns rows deleted."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return result.rowcount

    def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
        """Rename a memory file; True if a row matched *old_path*."""
        result = self._conn.execute(
            text(
                "UPDATE memories SET path = :new_path, updated_at = now() "
                "WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) AND path = :old_path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "old_path": old_path, "new_path": new_path},
        )
        return result.rowcount > 0

View File

@@ -0,0 +1,62 @@
"""Repository for the ``notes`` table.
Covers the operations in ``application/agents/tools/notes.py``.
Note: the Mongo schema stores a single ``note`` text field per (user_id, tool_id),
while the Postgres schema has ``title`` + ``content``. During dual-write,
title is set to a default and content holds the note text.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class NotesRepository:
    """One note per (user_id, tool_id), stored as title + content."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(self, user_id: str, tool_id: str, title: str, content: str) -> dict:
        """Create or overwrite the note for (user_id, tool_id)."""
        upsert_sql = text(
            """
            INSERT INTO notes (user_id, tool_id, title, content)
            VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
            ON CONFLICT (user_id, tool_id)
            DO UPDATE SET content = EXCLUDED.content, title = EXCLUDED.title, updated_at = now()
            RETURNING *
            """
        )
        params = {"user_id": user_id, "tool_id": tool_id, "title": title, "content": content}
        return row_to_dict(self._conn.execute(upsert_sql, params).fetchone())

    def get_for_user_tool(self, user_id: str, tool_id: str) -> Optional[dict]:
        """Return the note for (user_id, tool_id), or None if absent."""
        row = self._conn.execute(
            text(
                "SELECT * FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def get(self, note_id: str, user_id: str) -> Optional[dict]:
        """Return the note with primary key *note_id* owned by *user_id*."""
        row = self._conn.execute(
            text("SELECT * FROM notes WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": note_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def delete(self, user_id: str, tool_id: str) -> bool:
        """Delete the note for (user_id, tool_id); True if a row was removed."""
        result = self._conn.execute(
            text(
                "DELETE FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return result.rowcount > 0

View File

@@ -40,17 +40,24 @@ class PromptsRepository:
)
return row_to_dict(result.fetchone())
def get(self, prompt_id: str, user_id: Optional[str] = None) -> Optional[dict]:
if user_id is not None:
result = self._conn.execute(
text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": prompt_id, "user_id": user_id},
)
else:
result = self._conn.execute(
text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
{"id": prompt_id},
)
def get(self, prompt_id: str, user_id: str) -> Optional[dict]:
    """Return the prompt with *prompt_id* owned by *user_id*, or None.

    User scoping is mandatory here; unscoped lookups go through
    ``get_for_rendering`` instead.
    """
    result = self._conn.execute(
        text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
        {"id": prompt_id, "user_id": user_id},
    )
    row = result.fetchone()
    return row_to_dict(row) if row is not None else None
def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
    """Fetch prompt content by ID without user scoping.

    Used only by stream_processor to render a prompt whose owner is
    not known at call time. Do NOT use in user-facing routes.
    """
    row = self._conn.execute(
        text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
        {"id": prompt_id},
    ).fetchone()
    return None if row is None else row_to_dict(row)

View File

@@ -0,0 +1,80 @@
"""Repository for the ``sources`` table."""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, func, text
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import sources_table
class SourcesRepository:
    """CRUD access to the ``sources`` table."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, name: str, *, user_id: Optional[str] = None,
               type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
        """Insert a source row; *metadata* defaults to an empty JSON object."""
        insert_sql = text(
            """
            INSERT INTO sources (user_id, name, type, metadata)
            VALUES (:user_id, :name, :type, CAST(:metadata AS jsonb))
            RETURNING *
            """
        )
        params = {
            "user_id": user_id,
            "name": name,
            "type": type,
            "metadata": json.dumps(metadata or {}),
        }
        return row_to_dict(self._conn.execute(insert_sql, params).fetchone())

    def get(self, source_id: str, user_id: str) -> Optional[dict]:
        """Return one source owned by *user_id*, or None if absent."""
        row = self._conn.execute(
            text("SELECT * FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return the user's sources, newest first."""
        rows = self._conn.execute(
            text("SELECT * FROM sources WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

    def update(self, source_id: str, user_id: str, fields: dict) -> None:
        """Update whitelisted columns; unknown keys are silently ignored."""
        allowed = ("name", "type", "metadata")
        filtered = {key: value for key, value in fields.items() if key in allowed}
        if not filtered:
            return
        values: dict = {}
        for key, value in filtered.items():
            if key == "metadata" and isinstance(value, dict):
                values[key] = json.dumps(value)
            else:
                values[key] = value
        values["updated_at"] = func.now()
        table = sources_table
        self._conn.execute(
            table.update()
            .where(table.c.id == source_id)
            .where(table.c.user_id == user_id)
            .values(**values)
        )

    def delete(self, source_id: str, user_id: str) -> bool:
        """Delete one source owned by *user_id*; True if a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        )
        return result.rowcount > 0

View File

@@ -0,0 +1,78 @@
"""Repository for the ``todos`` table.
Covers the operations in ``application/agents/tools/todo_list.py``.
Note: the Mongo schema uses ``todo_id`` (sequential int) and ``status`` (text),
while the Postgres schema uses ``completed`` (boolean) and the UUID ``id`` as PK.
The repository bridges both shapes.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class TodosRepository:
    """CRUD access to the ``todos`` table, scoped by user and tool."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, user_id: str, tool_id: str, title: str) -> dict:
        """Insert a todo (column defaults apply) and return the row."""
        insert_sql = text(
            """
            INSERT INTO todos (user_id, tool_id, title)
            VALUES (:user_id, CAST(:tool_id AS uuid), :title)
            RETURNING *
            """
        )
        params = {"user_id": user_id, "tool_id": tool_id, "title": title}
        return row_to_dict(self._conn.execute(insert_sql, params).fetchone())

    def get(self, todo_id: str, user_id: str) -> Optional[dict]:
        """Return the todo with *todo_id* owned by *user_id*, or None."""
        row = self._conn.execute(
            text("SELECT * FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": todo_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user_tool(self, user_id: str, tool_id: str) -> list[dict]:
        """Return the tool's todos for *user_id*, oldest first."""
        rows = self._conn.execute(
            text(
                "SELECT * FROM todos WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) ORDER BY created_at"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

    def update_title(self, todo_id: str, user_id: str, title: str) -> bool:
        """Rename a todo; True if a row matched."""
        result = self._conn.execute(
            text(
                "UPDATE todos SET title = :title, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": todo_id, "user_id": user_id, "title": title},
        )
        return result.rowcount > 0

    def set_completed(self, todo_id: str, user_id: str, completed: bool = True) -> bool:
        """Set the todo's completed flag; True if a row matched."""
        result = self._conn.execute(
            text(
                "UPDATE todos SET completed = :completed, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": todo_id, "user_id": user_id, "completed": completed},
        )
        return result.rowcount > 0

    def delete(self, todo_id: str, user_id: str) -> bool:
        """Delete one todo owned by *user_id*; True if a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": todo_id, "user_id": user_id},
        )
        return result.rowcount > 0

View File

@@ -53,10 +53,10 @@ class UserToolsRepository:
)
return row_to_dict(result.fetchone())
def get(self, tool_id: str) -> Optional[dict]:
def get(self, tool_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM user_tools WHERE id = CAST(:id AS uuid)"),
{"id": tool_id},
text("SELECT * FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": tool_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None

View File

@@ -1173,6 +1173,16 @@ def attachment_worker(self, file_info, user):
}
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.attachments import AttachmentsRepository
dual_write(
AttachmentsRepository,
lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type: repo.create(
u, fn, p, mime_type=mt,
),
)
logging.info(
f"Stored attachment with ID: {attachment_id}", extra={"user": user}
)

View File

@@ -9,8 +9,8 @@ two-step change in this file:
2. Add a single entry to :data:`BACKFILLERS`.
There are intentionally no per-collection CLI flags or environment
variables — ``USE_POSTGRES`` in ``.env`` is the only knob operators
need during the migration window. This script discovers what's available from
variables — ``USE_POSTGRES`` / ``READ_POSTGRES`` in ``.env`` are the
only knobs operators need. This script discovers what's available from
the :data:`BACKFILLERS` registry and runs whichever tables were asked for.
Usage::
@@ -119,13 +119,8 @@ def _backfill_users(
def _backfill_prompts(
*,
conn: Connection,
mongo_db: Any,
batch_size: int,
dry_run: bool,
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``prompts`` table from Mongo ``prompts`` collection."""
upsert_sql = text(
"""
INSERT INTO prompts (user_id, name, content)
@@ -133,18 +128,9 @@ def _backfill_prompts(
ON CONFLICT DO NOTHING
"""
)
cursor = (
mongo_db["prompts"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = 0
written = 0
skipped = 0
cursor = mongo_db["prompts"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
@@ -162,25 +148,18 @@ def _backfill_prompts(
conn.execute(upsert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(upsert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped_no_user": skipped}
def _backfill_user_tools(
*,
conn: Connection,
mongo_db: Any,
batch_size: int,
dry_run: bool,
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``user_tools`` table from Mongo ``user_tools`` collection."""
insert_sql = text(
"""
INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
@@ -188,18 +167,9 @@ def _backfill_user_tools(
ON CONFLICT DO NOTHING
"""
)
cursor = (
mongo_db["user_tools"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = 0
written = 0
skipped = 0
cursor = mongo_db["user_tools"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
@@ -219,29 +189,18 @@ def _backfill_user_tools(
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped_no_user": skipped}
def _backfill_feedback(
*,
conn: Connection,
mongo_db: Any,
batch_size: int,
dry_run: bool,
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``feedback`` table from Mongo ``feedback`` collection.
feedback.conversation_id is stored as a string UUID. Rows whose
conversation_id cannot be cast to UUID are skipped.
"""
insert_sql = text(
"""
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text, timestamp)
@@ -249,18 +208,9 @@ def _backfill_feedback(
ON CONFLICT DO NOTHING
"""
)
cursor = (
mongo_db["feedback"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = 0
written = 0
skipped = 0
cursor = mongo_db["feedback"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
@@ -281,43 +231,27 @@ def _backfill_feedback(
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_stack_logs(
*,
conn: Connection,
mongo_db: Any,
batch_size: int,
dry_run: bool,
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``stack_logs`` table from Mongo ``stack_logs`` collection."""
insert_sql = text(
"""
INSERT INTO stack_logs (activity_id, endpoint, level, user_id, api_key, query, stacks, timestamp)
VALUES (:activity_id, :endpoint, :level, :user_id, :api_key, :query, CAST(:stacks AS jsonb), :timestamp)
"""
)
cursor = (
mongo_db["stack_logs"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = 0
written = 0
skipped = 0
cursor = mongo_db["stack_logs"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
@@ -340,48 +274,31 @@ def _backfill_stack_logs(
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped_no_id": skipped}
def _backfill_user_logs(
*,
conn: Connection,
mongo_db: Any,
batch_size: int,
dry_run: bool,
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``user_logs`` table from Mongo ``user_logs`` collection."""
insert_sql = text(
"""
INSERT INTO user_logs (user_id, endpoint, data, timestamp)
VALUES (:user_id, :endpoint, CAST(:data AS jsonb), :timestamp)
"""
)
cursor = (
mongo_db["user_logs"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = 0
written = 0
cursor = mongo_db["user_logs"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
# Build a JSONB payload from the full doc (minus Mongo internals).
data_payload = {k: v for k, v in doc.items() if k != "_id"}
# Stringify ObjectId values inside the payload.
for k, v in data_payload.items():
if hasattr(v, "__str__") and type(v).__name__ == "ObjectId":
data_payload[k] = str(v)
@@ -396,25 +313,18 @@ def _backfill_user_logs(
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written}
def _backfill_token_usage(
*,
conn: Connection,
mongo_db: Any,
batch_size: int,
dry_run: bool,
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``token_usage`` table from Mongo ``token_usage`` collection."""
insert_sql = text(
"""
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
@@ -425,23 +335,13 @@ def _backfill_token_usage(
)
"""
)
cursor = (
mongo_db["token_usage"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = 0
written = 0
cursor = mongo_db["token_usage"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
agent_id = doc.get("agent_id")
# agent_id may be an ObjectId string or None — only pass if
# it looks like a valid UUID (from dual-write) or skip it.
agent_id_str = None
if agent_id:
s = str(agent_id)
@@ -460,17 +360,426 @@ def _backfill_token_usage(
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written}
# ---------------------------------------------------------------------------
# Phase 2 backfillers
# ---------------------------------------------------------------------------
def _backfill_agent_folders(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``agent_folders`` table from the Mongo ``agent_folders`` collection.

    Docs without a ``user`` field are counted as skipped. Rows are flushed in
    batches of *batch_size*; with ``dry_run`` the writes are suppressed but the
    ``written`` counter still advances, matching the other backfillers.
    """
    upsert_sql = text(
        """
        INSERT INTO agent_folders (user_id, name, description)
        VALUES (:user_id, :name, :description)
        ON CONFLICT DO NOTHING
        """
    )
    seen = 0
    written = 0
    skipped = 0
    pending: list[dict] = []

    def _flush() -> int:
        # Returns the number of rows drained, even under dry_run.
        n = len(pending)
        if pending and not dry_run:
            conn.execute(upsert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["agent_folders"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            owner = doc.get("user")
            if not owner:
                skipped += 1
                continue
            pending.append(
                {
                    "user_id": owner,
                    "name": doc.get("name", ""),
                    "description": doc.get("description"),
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_sources(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``sources`` table from the Mongo ``sources`` collection.

    ``user`` may be absent for system sources, so no rows are skipped here.
    Metadata is JSON-encoded with top-level ObjectIds stringified; anything
    else non-serializable falls back to ``default=str``.
    """
    insert_sql = text(
        """
        INSERT INTO sources (user_id, name, type, metadata)
        VALUES (:user_id, :name, :type, CAST(:metadata AS jsonb))
        ON CONFLICT DO NOTHING
        """
    )
    seen = 0
    written = 0
    pending: list[dict] = []

    def _flush() -> int:
        n = len(pending)
        if pending and not dry_run:
            conn.execute(insert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["sources"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            raw_meta = doc.get("metadata") or {}
            clean_meta = {
                k: str(v) if type(v).__name__ == "ObjectId" else v
                for k, v in raw_meta.items()
            }
            pending.append(
                {
                    "user_id": doc.get("user"),
                    "name": doc.get("name", ""),
                    "type": doc.get("type"),
                    "metadata": json.dumps(clean_meta, default=str),
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written}
def _backfill_agents(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``agents`` table from the Mongo ``agents`` collection.

    Docs without a ``user`` field are counted as skipped. JSON-ish fields
    (tools / json_schema / models) are serialized to text and cast to jsonb
    in the INSERT. Rows are flushed in batches of *batch_size*; with
    ``dry_run`` the writes are suppressed but ``written`` still advances.
    """
    insert_sql = text(
        """
        INSERT INTO agents (
            user_id, name, status, key, description, agent_type,
            chunks, retriever, default_model_id,
            tools, json_schema, models,
            limited_token_mode, token_limit, limited_request_mode, request_limit,
            shared, incoming_webhook_token
        ) VALUES (
            :user_id, :name, :status, :key, :description, :agent_type,
            :chunks, :retriever, :default_model_id,
            CAST(:tools AS jsonb), CAST(:json_schema AS jsonb), CAST(:models AS jsonb),
            :limited_token_mode, :token_limit, :limited_request_mode, :request_limit,
            :shared, :incoming_webhook_token
        )
        ON CONFLICT DO NOTHING
        """
    )
    cursor = mongo_db["agents"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            user_id = doc.get("user")
            if not user_id:
                skipped += 1
                continue
            batch.append({
                "user_id": user_id,
                "name": doc.get("name", ""),
                "status": doc.get("status", "draft"),
                "key": doc.get("key"),
                "description": doc.get("description"),
                "agent_type": doc.get("agent_type"),
                "chunks": doc.get("chunks"),
                "retriever": doc.get("retriever"),
                "default_model_id": doc.get("default_model_id"),
                # Missing tools become '[]' (a jsonb array, never SQL NULL).
                "tools": json.dumps(doc.get("tools") or []),
                # Absent/falsy json_schema and models become SQL NULL, not 'null'.
                "json_schema": json.dumps(doc.get("json_schema")) if doc.get("json_schema") else None,
                "models": json.dumps(doc.get("models")) if doc.get("models") else None,
                "limited_token_mode": bool(doc.get("limited_token_mode", False)),
                "token_limit": doc.get("token_limit"),
                "limited_request_mode": bool(doc.get("limited_request_mode", False)),
                "request_limit": doc.get("request_limit"),
                "shared": bool(doc.get("shared", False)),
                "incoming_webhook_token": doc.get("incoming_webhook_token"),
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                # Counted even on dry_run, matching the other backfillers.
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_attachments(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``attachments`` table from the Mongo ``attachments`` collection.

    Docs without a ``user`` field are counted as skipped.

    NOTE(review): this INSERT has no ON CONFLICT clause, so re-running the
    backfill will insert duplicate attachment rows — confirm that is intended.
    """
    insert_sql = text(
        """
        INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
        VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
        """
    )
    seen = 0
    written = 0
    skipped = 0
    pending: list[dict] = []

    def _flush() -> int:
        n = len(pending)
        if pending and not dry_run:
            conn.execute(insert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["attachments"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            owner = doc.get("user")
            if not owner:
                skipped += 1
                continue
            pending.append(
                {
                    "user_id": owner,
                    "filename": doc.get("filename", ""),
                    "upload_path": doc.get("upload_path", ""),
                    "mime_type": doc.get("mime_type"),
                    "size": doc.get("size"),
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
def _build_tool_id_map(conn: Connection, mongo_db: Any) -> dict[str, str]:
    """Build a mapping from Mongo user_tools ObjectId → Postgres user_tools UUID.

    The Mongo ``_id`` (ObjectId) for each user_tools doc has no equivalent in
    Postgres, so rows are matched on the natural key ``(user_id, name)``.
    Returns ``{str(mongo_oid): str(pg_uuid)}``; docs missing either key, or
    with no matching Postgres row, are simply omitted.

    Called once before the memories/todos/notes backfills so those collections
    can resolve their ``tool_id`` foreign keys.
    """
    # Postgres side: (user_id, name) -> UUID string.
    pg_lookup: dict[tuple[str, str], str] = {
        (row._mapping["user_id"], row._mapping["name"]): str(row._mapping["id"])
        for row in conn.execute(text("SELECT id, user_id, name FROM user_tools")).fetchall()
    }
    # Mongo side: project only the three fields needed for the join.
    mapping: dict[str, str] = {}
    for doc in mongo_db["user_tools"].find({}, {"_id": 1, "user": 1, "name": 1}):
        owner = doc.get("user")
        tool_name = doc.get("name")
        if not owner or not tool_name:
            continue
        matched = pg_lookup.get((owner, tool_name))
        if matched:
            mapping[str(doc["_id"])] = matched
    return mapping
def _resolve_tool_id(tool_id_raw: Any, tool_id_map: dict[str, str]) -> str | None:
"""Convert a Mongo tool_id (ObjectId or string) to a Postgres UUID string.
Returns the mapped UUID, or None if the tool_id can't be resolved.
"""
if not tool_id_raw:
return None
s = str(tool_id_raw)
# Already a UUID (36 chars with dashes) — pass through
if len(s) == 36 and "-" in s:
return s
# Mongo ObjectId (24 hex chars) — look up in map
return tool_id_map.get(s)
def _backfill_memories(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``memories`` table from the Mongo ``memories`` collection.

    Docs whose tool_id cannot be mapped to a Postgres user_tools UUID (see
    :func:`_build_tool_id_map`), or that lack a user_id, are skipped.
    """
    tool_id_map = _build_tool_id_map(conn, mongo_db)
    insert_sql = text(
        """
        INSERT INTO memories (user_id, tool_id, path, content)
        VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content)
        ON CONFLICT DO NOTHING
        """
    )
    seen = 0
    written = 0
    skipped = 0
    pending: list[dict] = []

    def _flush() -> int:
        n = len(pending)
        if pending and not dry_run:
            conn.execute(insert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["memories"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            owner = doc.get("user_id")
            tool_uuid = _resolve_tool_id(doc.get("tool_id"), tool_id_map)
            if not owner or not tool_uuid:
                skipped += 1
                continue
            pending.append(
                {
                    "user_id": owner,
                    "tool_id": tool_uuid,
                    "path": doc.get("path", "/"),
                    "content": doc.get("content", ""),
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_todos(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``todos`` table from the Mongo ``todos`` collection.

    Mongo stores a ``status`` string; Postgres stores boolean ``completed``
    (true only when status == "completed"). Docs with an unresolvable tool_id
    or no user_id are skipped.

    NOTE(review): no ON CONFLICT clause — re-running the backfill duplicates
    todo rows; confirm that is intended.
    """
    tool_id_map = _build_tool_id_map(conn, mongo_db)
    insert_sql = text(
        """
        INSERT INTO todos (user_id, tool_id, title, completed)
        VALUES (:user_id, CAST(:tool_id AS uuid), :title, :completed)
        """
    )
    seen = 0
    written = 0
    skipped = 0
    pending: list[dict] = []

    def _flush() -> int:
        n = len(pending)
        if pending and not dry_run:
            conn.execute(insert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["todos"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            owner = doc.get("user_id")
            tool_uuid = _resolve_tool_id(doc.get("tool_id"), tool_id_map)
            if not owner or not tool_uuid:
                skipped += 1
                continue
            pending.append(
                {
                    "user_id": owner,
                    "tool_id": tool_uuid,
                    "title": doc.get("title", ""),
                    "completed": doc.get("status", "open") == "completed",
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_notes(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``notes`` table from the Mongo ``notes`` collection.

    Upserts on (user_id, tool_id): the last Mongo doc processed for a pair
    wins. Docs with an unresolvable tool_id or no user_id are skipped.
    """
    tool_id_map = _build_tool_id_map(conn, mongo_db)
    insert_sql = text(
        """
        INSERT INTO notes (user_id, tool_id, title, content)
        VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
        ON CONFLICT (user_id, tool_id) DO UPDATE
        SET content = EXCLUDED.content, title = EXCLUDED.title
        """
    )
    seen = 0
    written = 0
    skipped = 0
    pending: list[dict] = []

    def _flush() -> int:
        n = len(pending)
        if pending and not dry_run:
            conn.execute(insert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["notes"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            owner = doc.get("user_id")
            tool_uuid = _resolve_tool_id(doc.get("tool_id"), tool_id_map)
            if not owner or not tool_uuid:
                skipped += 1
                continue
            pending.append(
                {
                    "user_id": owner,
                    "tool_id": tool_uuid,
                    "title": doc.get("title", "note"),
                    # Docs may store the text under either "note" or "content".
                    "content": doc.get("note") or doc.get("content", ""),
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_connector_sessions(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync ``connector_sessions`` from the Mongo collection of the same name.

    Upserts on (user_id, provider): the last Mongo doc processed for a pair
    wins. Docs missing both user keys, or missing provider, are skipped.
    """
    insert_sql = text(
        """
        INSERT INTO connector_sessions (user_id, provider, session_data)
        VALUES (:user_id, :provider, CAST(:session_data AS jsonb))
        ON CONFLICT (user_id, provider) DO UPDATE
        SET session_data = EXCLUDED.session_data
        """
    )
    seen = 0
    written = 0
    skipped = 0
    pending: list[dict] = []

    def _flush() -> int:
        n = len(pending)
        if pending and not dry_run:
            conn.execute(insert_sql, pending)
        pending.clear()
        return n

    cursor = mongo_db["connector_sessions"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    try:
        for doc in cursor:
            seen += 1
            owner = doc.get("user_id") or doc.get("user")
            provider = doc.get("provider")
            if not owner or not provider:
                skipped += 1
                continue
            # Everything except the identity keys becomes the JSON payload.
            payload = {
                k: v
                for k, v in doc.items()
                if k not in ("_id", "user_id", "user", "provider")
            }
            pending.append(
                {
                    "user_id": owner,
                    "provider": provider,
                    "session_data": json.dumps(payload, default=str),
                }
            )
            if len(pending) >= batch_size:
                written += _flush()
        written += _flush()
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
@@ -483,6 +792,7 @@ BackfillFn = Callable[..., dict]
# with FK dependencies after the tables they reference so a full-run
# backfill doesn't hit FK errors.
BACKFILLERS: dict[str, BackfillFn] = {
# Phase 1
"users": _backfill_users,
"prompts": _backfill_prompts,
"user_tools": _backfill_user_tools,
@@ -490,6 +800,15 @@ BACKFILLERS: dict[str, BackfillFn] = {
"stack_logs": _backfill_stack_logs,
"user_logs": _backfill_user_logs,
"token_usage": _backfill_token_usage,
# Phase 2 (order: FK targets first)
"agent_folders": _backfill_agent_folders,
"sources": _backfill_sources,
"attachments": _backfill_attachments,
"agents": _backfill_agents,
"memories": _backfill_memories,
"todos": _backfill_todos,
"notes": _backfill_notes,
"connector_sessions": _backfill_connector_sessions,
}

View File

@@ -0,0 +1,116 @@
"""Tests for AgentFoldersRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
# Skip every test in this module unless a Postgres DSN is configured.
# NOTE(review): ``__import__`` is used instead of a top-level import,
# presumably to defer settings loading until skip evaluation — verify.
pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> AgentFoldersRepository:
    """Build a repository bound to the per-test Postgres connection fixture."""
    return AgentFoldersRepository(conn)


class TestCreate:
    """``create`` returns the inserted row with its generated id."""

    def test_creates_folder(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "My Folder")
        assert doc["user_id"] == "user-1"
        assert doc["name"] == "My Folder"
        assert doc["id"] is not None

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        # Mongo-compat: callers may still read "_id"; it mirrors "id".
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "f")
        assert doc["_id"] == doc["id"]


class TestGet:
    """``get`` is owner-scoped: wrong user or unknown id yields None."""

    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "f")
        fetched = repo.get(created["id"], "user-1")
        assert fetched["id"] == created["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "f")
        assert repo.get(created["id"], "user-other") is None


class TestListForUser:
    """``list_for_user`` returns only the caller's folders."""

    def test_lists_only_own_folders(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "f1")
        repo.create("alice", "f2")
        repo.create("bob", "f3")
        results = repo.list_for_user("alice")
        assert len(results) == 2
        assert all(r["user_id"] == "alice" for r in results)


class TestUpdate:
    """``update`` mutates allowed fields, and only for the owner."""

    def test_updates_name(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        updated = repo.update(created["id"], "user-1", {"name": "new"})
        assert updated is True
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "new"

    def test_update_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        updated = repo.update(created["id"], "user-other", {"name": "new"})
        assert updated is False
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "old"

    def test_update_disallowed_field_returns_false(self, pg_conn):
        # "id" is not an updatable field; the call is rejected, not ignored.
        repo = _repo(pg_conn)
        created = repo.create("user-1", "f")
        updated = repo.update(created["id"], "user-1", {"id": "00000000-0000-0000-0000-000000000000"})
        assert updated is False


class TestDelete:
    """``delete`` removes the row only when the caller owns it."""

    def test_deletes_folder(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "f")
        deleted = repo.delete(created["id"], "user-1")
        assert deleted is True
        assert repo.get(created["id"], "user-1") is None

    def test_delete_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "f")
        deleted = repo.delete(created["id"], "user-other")
        assert deleted is False
        assert repo.get(created["id"], "user-1") is not None


class TestTenantIsolation:
    """Cross-user access is denied at the repository layer."""

    def test_user_a_cannot_see_user_b_folders(self, pg_conn):
        repo = _repo(pg_conn)
        folder_a = repo.create("alice", "private")
        assert repo.get(folder_a["id"], "bob") is None

    def test_list_returns_only_own_folders(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "a1")
        repo.create("bob", "b1")
        alice_folders = repo.list_for_user("alice")
        bob_folders = repo.list_for_user("bob")
        assert len(alice_folders) == 1
        assert len(bob_folders) == 1
        assert alice_folders[0]["name"] == "a1"
        assert bob_folders[0]["name"] == "b1"

View File

@@ -0,0 +1,163 @@
"""Tests for AgentsRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.agents import AgentsRepository
# Skip every test in this module unless a Postgres DSN is configured.
pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> AgentsRepository:
    """Build a repository bound to the per-test Postgres connection fixture."""
    return AgentsRepository(conn)


class TestCreate:
    """``create`` returns the inserted agent row with its generated id."""

    def test_creates_agent_minimal(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "My Agent", "draft")
        assert doc["user_id"] == "user-1"
        assert doc["name"] == "My Agent"
        assert doc["status"] == "draft"
        assert doc["id"] is not None

    def test_create_with_kwargs(self, pg_conn):
        # Optional columns are passed through as keyword arguments.
        repo = _repo(pg_conn)
        doc = repo.create(
            "user-1", "Agent2", "active",
            description="A test agent",
            chunks=5,
            tools=[{"name": "search"}],
            shared=True,
        )
        assert doc["description"] == "A test agent"
        assert doc["chunks"] == 5
        assert doc["tools"] == [{"name": "search"}]
        assert doc["shared"] is True

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        # Mongo-compat: "_id" mirrors "id".
        repo = _repo(pg_conn)
        doc = repo.create("u", "a", "draft")
        assert doc["_id"] == doc["id"]


class TestGet:
    """``get`` is owner-scoped: wrong user or unknown id yields None."""

    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "a", "draft")
        fetched = repo.get(created["id"], "user-1")
        assert fetched["id"] == created["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "a", "draft")
        assert repo.get(created["id"], "user-other") is None


class TestFindByKey:
    """``find_by_key`` looks an agent up by its API key, across users."""

    def test_finds_agent_by_key(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("u", "a", "draft", key="my-unique-key")
        fetched = repo.find_by_key("my-unique-key")
        assert fetched["id"] == created["id"]

    def test_find_by_key_nonexistent_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.find_by_key("nonexistent-key") is None


class TestListForUser:
    """``list_for_user`` returns only the caller's agents."""

    def test_lists_only_own_agents(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "a1", "draft")
        repo.create("alice", "a2", "active")
        repo.create("bob", "b1", "draft")
        results = repo.list_for_user("alice")
        assert len(results) == 2
        assert all(r["user_id"] == "alice" for r in results)


class TestUpdate:
    """``update`` mutates allowed fields, and only for the owner."""

    def test_updates_name(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old", "draft")
        updated = repo.update(created["id"], "user-1", {"name": "new"})
        assert updated is True
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "new"

    def test_update_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old", "draft")
        updated = repo.update(created["id"], "user-other", {"name": "new"})
        assert updated is False
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "old"

    def test_update_disallowed_field_returns_false(self, pg_conn):
        # "id" is not an updatable field; the call is rejected, not ignored.
        repo = _repo(pg_conn)
        created = repo.create("user-1", "a", "draft")
        updated = repo.update(created["id"], "user-1", {"id": "bad"})
        assert updated is False


class TestDelete:
    """``delete`` removes the row only when the caller owns it."""

    def test_deletes_agent(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "a", "draft")
        deleted = repo.delete(created["id"], "user-1")
        assert deleted is True
        assert repo.get(created["id"], "user-1") is None

    def test_delete_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "a", "draft")
        deleted = repo.delete(created["id"], "user-other")
        assert deleted is False
        assert repo.get(created["id"], "user-1") is not None


class TestSetFolder:
    """``set_folder`` assigns or clears an agent's folder_id."""

    def test_assigns_folder(self, pg_conn):
        from application.storage.db.repositories.agent_folders import AgentFoldersRepository
        folder_repo = AgentFoldersRepository(pg_conn)
        folder = folder_repo.create("user-1", "f")
        repo = _repo(pg_conn)
        agent = repo.create("user-1", "a", "draft")
        repo.set_folder(agent["id"], "user-1", folder["id"])
        fetched = repo.get(agent["id"], "user-1")
        assert str(fetched["folder_id"]) == str(folder["id"])

    def test_clear_folder(self, pg_conn):
        from application.storage.db.repositories.agent_folders import AgentFoldersRepository
        folder_repo = AgentFoldersRepository(pg_conn)
        folder = folder_repo.create("user-1", "f")
        repo = _repo(pg_conn)
        agent = repo.create("user-1", "a", "draft", folder_id=folder["id"])
        # Passing None detaches the agent from its folder.
        repo.set_folder(agent["id"], "user-1", None)
        fetched = repo.get(agent["id"], "user-1")
        assert fetched["folder_id"] is None


class TestClearFolderForAll:
    """``clear_folder_for_all`` detaches every agent from a given folder."""

    def test_clears_folder_from_all_agents(self, pg_conn):
        from application.storage.db.repositories.agent_folders import AgentFoldersRepository
        folder_repo = AgentFoldersRepository(pg_conn)
        folder = folder_repo.create("user-1", "f")
        repo = _repo(pg_conn)
        a1 = repo.create("user-1", "a1", "draft", folder_id=folder["id"])
        a2 = repo.create("user-1", "a2", "draft", folder_id=folder["id"])
        repo.clear_folder_for_all(folder["id"], "user-1")
        assert repo.get(a1["id"], "user-1")["folder_id"] is None
        assert repo.get(a2["id"], "user-1")["folder_id"] is None

View File

@@ -0,0 +1,71 @@
"""Tests for AttachmentsRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.attachments import AttachmentsRepository
# Skip every test in this module unless a Postgres DSN is configured.
pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> AttachmentsRepository:
    """Build a repository bound to the per-test Postgres connection fixture."""
    return AttachmentsRepository(conn)


class TestCreate:
    """``create`` returns the inserted attachment row with its generated id."""

    def test_creates_attachment(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "file.pdf", "/uploads/file.pdf")
        assert doc["user_id"] == "user-1"
        assert doc["filename"] == "file.pdf"
        assert doc["upload_path"] == "/uploads/file.pdf"
        assert doc["id"] is not None

    def test_creates_with_optional_fields(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "img.png", "/uploads/img.png",
                          mime_type="image/png", size=1024)
        assert doc["mime_type"] == "image/png"
        assert doc["size"] == 1024

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        # Mongo-compat: "_id" mirrors "id".
        repo = _repo(pg_conn)
        doc = repo.create("u", "f", "/p")
        assert doc["_id"] == doc["id"]


class TestGet:
    """``get`` is owner-scoped: wrong user or unknown id yields None."""

    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("u", "f", "/p")
        fetched = repo.get(created["id"], "u")
        assert fetched["id"] == created["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000", "u") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("u", "f", "/p")
        assert repo.get(created["id"], "other") is None


class TestListForUser:
    """``list_for_user`` returns only the caller's attachments."""

    def test_lists_only_own_attachments(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "a1.pdf", "/a1")
        repo.create("alice", "a2.pdf", "/a2")
        repo.create("bob", "b1.pdf", "/b1")
        results = repo.list_for_user("alice")
        assert len(results) == 2
        assert all(r["user_id"] == "alice" for r in results)

    def test_list_empty_for_unknown_user(self, pg_conn):
        repo = _repo(pg_conn)
        results = repo.list_for_user("nonexistent")
        assert results == []

View File

@@ -0,0 +1,94 @@
"""Tests for ConnectorSessionsRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.connector_sessions import ConnectorSessionsRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> ConnectorSessionsRepository:
return ConnectorSessionsRepository(conn)
class TestUpsert:
def test_creates_session(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.upsert("user-1", "google", {"token": "abc123"})
assert doc["user_id"] == "user-1"
assert doc["provider"] == "google"
assert doc["session_data"] == {"token": "abc123"}
assert doc["id"] is not None
def test_upsert_creates_second_session(self, pg_conn):
repo = _repo(pg_conn)
first = repo.upsert("user-1", "google", {"token": "v1"})
assert first["session_data"] == {"token": "v1"}
# Without a UNIQUE(user_id, provider) constraint, a second upsert
# creates another row (ON CONFLICT DO NOTHING never fires).
second = repo.upsert("user-1", "google", {"token": "v2"})
assert second["session_data"] == {"token": "v2"}
class TestGetByUserProvider:
def test_finds_existing(self, pg_conn):
repo = _repo(pg_conn)
repo.upsert("u", "slack", {"key": "val"})
fetched = repo.get_by_user_provider("u", "slack")
assert fetched is not None
assert fetched["session_data"] == {"key": "val"}
def test_returns_none_for_missing(self, pg_conn):
repo = _repo(pg_conn)
assert repo.get_by_user_provider("u", "nonexistent") is None
def test_different_providers_are_separate(self, pg_conn):
repo = _repo(pg_conn)
repo.upsert("u", "google", {"g": 1})
repo.upsert("u", "slack", {"s": 2})
g = repo.get_by_user_provider("u", "google")
s = repo.get_by_user_provider("u", "slack")
assert g["session_data"] == {"g": 1}
assert s["session_data"] == {"s": 2}
class TestListForUser:
    """Listing every connector session that belongs to one user."""

    def test_lists_all_providers(self, pg_conn):
        repo = _repo(pg_conn)
        for owner, provider, payload in (
            ("alice", "google", {"g": 1}),
            ("alice", "slack", {"s": 1}),
            ("bob", "google", {"g": 2}),
        ):
            repo.upsert(owner, provider, payload)
        rows = repo.list_for_user("alice")
        assert len(rows) == 2
        assert {row["user_id"] for row in rows} == {"alice"}

    def test_list_empty_for_unknown_user(self, pg_conn):
        assert _repo(pg_conn).list_for_user("nonexistent") == []
class TestDelete:
    """Deleting a session by (user, provider), reporting success."""

    def test_deletes_session(self, pg_conn):
        repo = _repo(pg_conn)
        repo.upsert("u", "google", {"t": 1})
        assert repo.delete("u", "google") is True
        assert repo.get_by_user_provider("u", "google") is None

    def test_delete_nonexistent_returns_false(self, pg_conn):
        assert _repo(pg_conn).delete("u", "nonexistent") is False

    def test_delete_one_provider_leaves_others(self, pg_conn):
        repo = _repo(pg_conn)
        for provider, payload in (("google", {"g": 1}), ("slack", {"s": 1})):
            repo.upsert("u", provider, payload)
        repo.delete("u", "google")
        # Only the targeted provider's session is gone.
        assert repo.get_by_user_provider("u", "google") is None
        assert repo.get_by_user_provider("u", "slack") is not None

View File

@@ -0,0 +1,135 @@
"""Tests for MemoriesRepository against a real Postgres instance.
Memories have a FK to user_tools, so each test creates a tool row first.
"""
from __future__ import annotations
import pytest
from sqlalchemy import text
from application.storage.db.repositories.memories import MemoriesRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> MemoriesRepository:
return MemoriesRepository(conn)
def _make_tool(conn, user_id: str = "test-user", name: str = "mem-tool") -> str:
"""Insert a user_tools row and return its UUID as a string."""
return str(
conn.execute(
text("INSERT INTO user_tools (user_id, name) VALUES (:uid, :name) RETURNING id"),
{"uid": user_id, "name": name},
).scalar()
)
class TestUpsert:
    """Create-or-replace semantics keyed on (user, tool, path)."""

    def test_creates_memory(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        record = repo.upsert("test-user", tool_id, "/docs/readme.md", "Hello world")
        assert record["id"] is not None
        assert record["path"] == "/docs/readme.md"
        assert record["content"] == "Hello world"

    def test_upsert_overwrites_content(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        repo.upsert("test-user", tool_id, "/a.txt", "v1")
        # A second write to the same path must replace the content.
        assert repo.upsert("test-user", tool_id, "/a.txt", "v2")["content"] == "v2"

    def test_upsert_is_idempotent_on_same_content(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        first = repo.upsert("test-user", tool_id, "/a.txt", "same")
        second = repo.upsert("test-user", tool_id, "/a.txt", "same")
        # Re-upserting identical content keeps the same row.
        assert first["id"] == second["id"]
class TestGetByPath:
    """Point lookup of a memory by exact path."""

    def test_finds_existing(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        repo.upsert("u", tool_id, "/x", "content")
        got = repo.get_by_path("u", tool_id, "/x")
        assert got is not None
        assert got["content"] == "content"

    def test_returns_none_for_missing(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get_by_path("u", _make_tool(pg_conn), "/nonexistent") is None
class TestListByPrefix:
    """Prefix-scoped listing of memories."""

    def test_lists_matching_prefix(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        for path, body in (("/docs/a.md", "a"), ("/docs/b.md", "b"), ("/other/c.md", "c")):
            repo.upsert("u", tool_id, path, body)
        hits = repo.list_by_prefix("u", tool_id, "/docs/")
        # Only the two "/docs/" entries match; "/other/" is excluded.
        assert len(hits) == 2
        assert {hit["path"] for hit in hits} == {"/docs/a.md", "/docs/b.md"}
class TestDeleteByPath:
    """Single-path deletion returns the number of rows removed."""

    def test_deletes_single(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        repo.upsert("u", tool_id, "/x", "c")
        assert repo.delete_by_path("u", tool_id, "/x") == 1
        assert repo.get_by_path("u", tool_id, "/x") is None

    def test_delete_nonexistent_returns_zero(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.delete_by_path("u", _make_tool(pg_conn), "/nope") == 0
class TestDeleteByPrefix:
    """Bulk deletion of every memory under a path prefix."""

    def test_deletes_matching_prefix(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        for path, body in (("/dir/a", "a"), ("/dir/b", "b"), ("/other/c", "c")):
            repo.upsert("u", tool_id, path, body)
        # Both "/dir/" entries are removed; the unrelated path survives.
        assert repo.delete_by_prefix("u", tool_id, "/dir/") == 2
        assert repo.get_by_path("u", tool_id, "/other/c") is not None
class TestDeleteAll:
    """Wholesale removal of a (user, tool) pair's memories."""

    def test_deletes_all_for_user_tool(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        for path, body in (("/a", "a"), ("/b", "b")):
            repo.upsert("u", tool_id, path, body)
        assert repo.delete_all("u", tool_id) == 2
        # Nothing remains under any prefix afterwards.
        assert repo.list_by_prefix("u", tool_id, "/") == []
class TestUpdatePath:
    """Renaming a memory's path in place."""

    def test_renames_path(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        repo.upsert("u", tool_id, "/old.txt", "content")
        assert repo.update_path("u", tool_id, "/old.txt", "/new.txt") is True
        # Old location gone, content intact at the new location.
        assert repo.get_by_path("u", tool_id, "/old.txt") is None
        moved = repo.get_by_path("u", tool_id, "/new.txt")
        assert moved["content"] == "content"

    def test_rename_nonexistent_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        assert repo.update_path("u", tool_id, "/nope", "/new") is False

View File

@@ -0,0 +1,100 @@
"""Tests for NotesRepository against a real Postgres instance.
Notes have a FK to user_tools, so each test creates a tool row first.
"""
from __future__ import annotations
import pytest
from sqlalchemy import text
from application.storage.db.repositories.notes import NotesRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> NotesRepository:
return NotesRepository(conn)
def _make_tool(conn, user_id: str = "test-user", name: str = "notes-tool") -> str:
"""Insert a user_tools row and return its UUID as a string."""
return str(
conn.execute(
text("INSERT INTO user_tools (user_id, name) VALUES (:uid, :name) RETURNING id"),
{"uid": user_id, "name": name},
).scalar()
)
class TestUpsert:
    """Create-or-replace semantics for the single note per (user, tool)."""

    def test_creates_note(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        doc = repo.upsert("test-user", tool_id, "My Note", "Some content")
        assert doc["title"] == "My Note"
        assert doc["content"] == "Some content"
        assert doc["id"] is not None

    def test_upsert_replaces_existing_note(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        first = repo.upsert("test-user", tool_id, "title", "v1")
        assert first["content"] == "v1"
        # With the UNIQUE(user_id, tool_id) index added in migration 0002,
        # a second upsert for the same (user, tool) replaces the existing
        # note rather than inserting a duplicate row.
        second = repo.upsert("test-user", tool_id, "title2", "v2")
        assert second["content"] == "v2"
        # The pair lookup must now see the replacement.
        assert repo.get_for_user_tool("test-user", tool_id)["content"] == "v2"
class TestGetForUserTool:
    """Fetching the note attached to a (user, tool) pair."""

    def test_returns_note(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        repo.upsert("u", tool_id, "t", "c")
        got = repo.get_for_user_tool("u", tool_id)
        assert got is not None
        assert got["content"] == "c"

    def test_returns_none_when_missing(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get_for_user_tool("u", _make_tool(pg_conn)) is None
class TestGetById:
    """Lookup by primary key, scoped to the owning user."""

    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.upsert("u", _make_tool(pg_conn), "t", "c")
        assert repo.get(row["id"], "u")["id"] == row["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        assert _repo(pg_conn).get("00000000-0000-0000-0000-000000000000", "u") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.upsert("u", _make_tool(pg_conn), "t", "c")
        # Ownership check: a different user must not see the note.
        assert repo.get(row["id"], "other") is None
class TestDelete:
    """Deleting the note for a (user, tool) pair."""

    def test_deletes_note(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        repo.upsert("u", tool_id, "t", "c")
        assert repo.delete("u", tool_id) is True
        assert repo.get_for_user_tool("u", tool_id) is None

    def test_delete_nonexistent_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.delete("u", _make_tool(pg_conn)) is False

View File

@@ -35,23 +35,30 @@ class TestGet:
def test_get_by_id_and_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "p", "c")
fetched = repo.get(created["id"], user_id="user-1")
fetched = repo.get(created["id"], "user-1")
assert fetched["id"] == created["id"]
def test_get_by_id_only(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "p", "c")
fetched = repo.get(created["id"])
assert fetched is not None
def test_get_wrong_user_returns_none(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "p", "c")
assert repo.get(created["id"], user_id="user-other") is None
assert repo.get(created["id"], "user-other") is None
def test_get_nonexistent_returns_none(self, pg_conn):
repo = _repo(pg_conn)
assert repo.get("00000000-0000-0000-0000-000000000000") is None
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
class TestGetForRendering:
    """get_for_rendering fetches a prompt without user scoping."""

    def test_returns_prompt_without_user_scoping(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("user-1", "p", "c")
        got = repo.get_for_rendering(row["id"])
        assert got is not None
        assert got["id"] == row["id"]

    def test_nonexistent_returns_none(self, pg_conn):
        assert _repo(pg_conn).get_for_rendering("00000000-0000-0000-0000-000000000000") is None
class TestListForUser:
@@ -70,7 +77,7 @@ class TestUpdate:
repo = _repo(pg_conn)
created = repo.create("user-1", "old", "old-content")
repo.update(created["id"], "user-1", "new", "new-content")
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "new"
assert fetched["content"] == "new-content"
@@ -78,7 +85,7 @@ class TestUpdate:
repo = _repo(pg_conn)
created = repo.create("user-1", "old", "old-content")
repo.update(created["id"], "user-other", "new", "new-content")
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "old"
@@ -87,13 +94,13 @@ class TestDelete:
repo = _repo(pg_conn)
created = repo.create("user-1", "p", "c")
repo.delete(created["id"], "user-1")
assert repo.get(created["id"]) is None
assert repo.get(created["id"], "user-1") is None
def test_delete_wrong_user_is_noop(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "p", "c")
repo.delete(created["id"], "user-other")
assert repo.get(created["id"]) is not None
assert repo.get(created["id"], "user-1") is not None
class TestFindOrCreate:

View File

@@ -0,0 +1,121 @@
"""Tests for SourcesRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.sources import SourcesRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> SourcesRepository:
return SourcesRepository(conn)
class TestCreate:
    """Creating user-owned and system-owned sources."""

    def test_creates_source_with_user(self, pg_conn):
        row = _repo(pg_conn).create("my-source", user_id="user-1", type="url")
        assert row["id"] is not None
        assert row["user_id"] == "user-1"
        assert row["name"] == "my-source"
        assert row["type"] == "url"

    def test_creates_system_source_without_user(self, pg_conn):
        # Omitting user_id yields a system-wide source with a NULL owner.
        row = _repo(pg_conn).create("system-src")
        assert row["user_id"] is None
        assert row["name"] == "system-src"

    def test_creates_source_with_metadata(self, pg_conn):
        row = _repo(pg_conn).create("src", user_id="u", metadata={"url": "https://example.com"})
        assert row["metadata"] == {"url": "https://example.com"}

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        # "_id" is returned as an alias of "id" in the created document.
        row = _repo(pg_conn).create("s")
        assert row["_id"] == row["id"]
class TestGet:
    """Primary-key lookup scoped to the owning user."""

    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("s", user_id="user-1")
        assert repo.get(row["id"], "user-1")["id"] == row["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        assert _repo(pg_conn).get("00000000-0000-0000-0000-000000000000", "user-1") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("s", user_id="user-1")
        # Ownership check: other users must not see the source.
        assert repo.get(row["id"], "user-other") is None
class TestListForUser:
    """Listing returns only the requesting user's sources."""

    def test_lists_only_own_sources(self, pg_conn):
        repo = _repo(pg_conn)
        for name, owner in (("s1", "alice"), ("s2", "alice"), ("s3", "bob")):
            repo.create(name, user_id=owner)
        rows = repo.list_for_user("alice")
        assert len(rows) == 2
        assert {row["user_id"] for row in rows} == {"alice"}
class TestUpdate:
    """Partial updates with user scoping and a field allow-list."""

    def test_updates_name(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("old", user_id="u")
        repo.update(row["id"], "u", {"name": "new"})
        assert repo.get(row["id"], "u")["name"] == "new"

    def test_updates_metadata(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("s", user_id="u", metadata={"a": 1})
        repo.update(row["id"], "u", {"metadata": {"a": 2, "b": 3}})
        assert repo.get(row["id"], "u")["metadata"] == {"a": 2, "b": 3}

    def test_update_disallowed_field_is_noop(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("s", user_id="u")
        # "id" is not an updatable field; the call must leave it untouched.
        repo.update(row["id"], "u", {"id": "00000000-0000-0000-0000-000000000000"})
        assert repo.get(row["id"], "u")["id"] == row["id"]

    def test_update_wrong_user_is_noop(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("old", user_id="u")
        # Updates attempted by a non-owner change nothing.
        repo.update(row["id"], "other-user", {"name": "new"})
        assert repo.get(row["id"], "u")["name"] == "old"
class TestDelete:
    """Deletion is scoped to the owning user and reports success."""

    def test_deletes_source(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("s", user_id="u")
        assert repo.delete(row["id"], "u") is True
        assert repo.get(row["id"], "u") is None

    def test_delete_nonexistent_returns_false(self, pg_conn):
        assert _repo(pg_conn).delete("00000000-0000-0000-0000-000000000000", "u") is False

    def test_delete_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("s", user_id="u")
        assert repo.delete(row["id"], "other-user") is False
        # The row must still exist for its real owner.
        assert repo.get(row["id"], "u") is not None

View File

@@ -0,0 +1,158 @@
"""Tests for TodosRepository against a real Postgres instance.
Todos have a FK to user_tools, so each test creates a tool row first.
"""
from __future__ import annotations
import pytest
from sqlalchemy import text
from application.storage.db.repositories.todos import TodosRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> TodosRepository:
return TodosRepository(conn)
def _make_tool(conn, user_id: str = "test-user", name: str = "todo-tool") -> str:
"""Insert a user_tools row and return its UUID as a string."""
return str(
conn.execute(
text("INSERT INTO user_tools (user_id, name) VALUES (:uid, :name) RETURNING id"),
{"uid": user_id, "name": name},
).scalar()
)
class TestCreate:
    """Creating todos under a user tool."""

    def test_creates_todo(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("test-user", _make_tool(pg_conn), "Buy milk")
        assert row["id"] is not None
        assert row["title"] == "Buy milk"
        # New todos start out not completed.
        assert row["completed"] is False

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        # "_id" is returned as an alias of "id" in the created document.
        row = _repo(pg_conn).create("test-user", _make_tool(pg_conn), "t")
        assert row["_id"] == row["id"]
class TestGet:
    """Primary-key lookup scoped to the owning user."""

    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        assert repo.get(row["id"], "u")["id"] == row["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        assert _repo(pg_conn).get("00000000-0000-0000-0000-000000000000", "u") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        # Ownership check: a different user must not see the todo.
        assert repo.get(row["id"], "other") is None
class TestListForUserTool:
    """Listing is scoped to a single (user, tool) pair."""

    def test_lists_todos_for_user_tool(self, pg_conn):
        repo = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        for title in ("t1", "t2"):
            repo.create("u", tool_id, title)
        assert len(repo.list_for_user_tool("u", tool_id)) == 2

    def test_different_tools_are_isolated(self, pg_conn):
        repo = _repo(pg_conn)
        tool_a = _make_tool(pg_conn, name="tool-a")
        tool_b = _make_tool(pg_conn, name="tool-b")
        repo.create("u", tool_a, "a-todo")
        repo.create("u", tool_b, "b-todo")
        # Each tool sees only its own todo.
        assert len(repo.list_for_user_tool("u", tool_a)) == 1
        assert len(repo.list_for_user_tool("u", tool_b)) == 1
class TestUpdateTitle:
    """Renaming a todo, with ownership enforcement."""

    def test_updates_title(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "old")
        assert repo.update_title(row["id"], "u", "new") is True
        assert repo.get(row["id"], "u")["title"] == "new"

    def test_update_nonexistent_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.update_title("00000000-0000-0000-0000-000000000000", "u", "x") is False

    def test_update_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "old")
        assert repo.update_title(row["id"], "other", "new") is False
        # The title is unchanged for the real owner.
        assert repo.get(row["id"], "u")["title"] == "old"
class TestSetCompleted:
    """Toggling a todo's completed flag."""

    def test_marks_completed(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        repo.set_completed(row["id"], "u", True)
        assert repo.get(row["id"], "u")["completed"] is True

    def test_unmarks_completed(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        # Round-trip: complete the todo, then reopen it.
        repo.set_completed(row["id"], "u", True)
        repo.set_completed(row["id"], "u", False)
        assert repo.get(row["id"], "u")["completed"] is False

    def test_set_completed_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        assert repo.set_completed(row["id"], "other", True) is False
        # The flag is unchanged for the real owner.
        assert repo.get(row["id"], "u")["completed"] is False
class TestDelete:
    """Deleting todos with ownership enforcement."""

    def test_deletes_todo(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        assert repo.delete(row["id"], "u") is True
        assert repo.get(row["id"], "u") is None

    def test_delete_nonexistent_returns_false(self, pg_conn):
        assert _repo(pg_conn).delete("00000000-0000-0000-0000-000000000000", "u") is False

    def test_delete_wrong_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        row = repo.create("u", _make_tool(pg_conn), "t")
        assert repo.delete(row["id"], "other") is False
        # The row must still exist for its real owner.
        assert repo.get(row["id"], "u") is not None

View File

@@ -36,12 +36,17 @@ class TestGet:
def test_get_existing(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "t")
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["id"] == created["id"]
def test_get_nonexistent(self, pg_conn):
repo = _repo(pg_conn)
assert repo.get("00000000-0000-0000-0000-000000000000") is None
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
def test_get_wrong_user_returns_none(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "t")
assert repo.get(created["id"], "user-other") is None
class TestListForUser:
@@ -59,28 +64,28 @@ class TestUpdate:
repo = _repo(pg_conn)
created = repo.create("user-1", "old_name")
repo.update(created["id"], "user-1", {"name": "new_name"})
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "new_name"
def test_updates_config(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "t", config={"a": 1})
repo.update(created["id"], "user-1", {"config": {"a": 2, "b": 3}})
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["config"] == {"a": 2, "b": 3}
def test_update_wrong_user_is_noop(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "old")
repo.update(created["id"], "user-other", {"name": "new"})
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "old"
def test_ignores_disallowed_fields(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "t")
repo.update(created["id"], "user-1", {"id": "00000000-0000-0000-0000-000000000000"})
fetched = repo.get(created["id"])
fetched = repo.get(created["id"], "user-1")
assert fetched["id"] == created["id"]
@@ -90,11 +95,11 @@ class TestDelete:
created = repo.create("user-1", "t")
deleted = repo.delete(created["id"], "user-1")
assert deleted is True
assert repo.get(created["id"]) is None
assert repo.get(created["id"], "user-1") is None
def test_delete_wrong_user_returns_false(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "t")
deleted = repo.delete(created["id"], "user-other")
assert deleted is False
assert repo.get(created["id"]) is not None
assert repo.get(created["id"], "user-1") is not None