Compare commits

...

20 Commits

Author SHA1 Message Date
Alex
aa938d76d7 Add GitHub Actions zizmor security workflow 2026-04-14 17:56:14 +01:00
Manish Madan
2940628aa6 Merge pull request #2319 from arc53/dependabot/npm_and_yarn/frontend/npm_and_yarn-e5a595f223
chore(deps-dev): bump flatted from 3.4.1 to 3.4.2 in /frontend in the npm_and_yarn group across 1 directory
2026-04-14 21:30:54 +05:30
dependabot[bot]
7f23928134 chore(deps-dev): bump flatted
Bumps the npm_and_yarn group with 1 update in the /frontend directory: [flatted](https://github.com/WebReflection/flatted).


Updates `flatted` from 3.4.1 to 3.4.2
- [Commits](https://github.com/WebReflection/flatted/compare/v3.4.1...v3.4.2)

---
updated-dependencies:
- dependency-name: flatted
  dependency-version: 3.4.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 15:10:18 +00:00
Alex
20e17c84c7 Merge pull request #2379 from arc53/codex/refine-and-review-incident-response-plan
Add INCIDENT_RESPONSE.md and reference it from SECURITY.md
2026-04-14 14:59:04 +01:00
copilot-swe-agent[bot]
389ddf6068 Fix secret references in INCIDENT_RESPONSE.md to match actual DocsGPT config
Agent-Logs-Url: https://github.com/arc53/DocsGPT/sessions/c6bfd68d-4dac-46ec-8404-fe5bfda0e8f3

Co-authored-by: dartpain <15183589+dartpain@users.noreply.github.com>
2026-04-14 10:51:22 +00:00
Alex
1e2443fb90 Update .github/INCIDENT_RESPONSE.md
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 11:49:13 +01:00
Manish Madan
6387bd1892 Merge pull request #2303 from arc53/dependabot/npm_and_yarn/frontend/typescript-eslint/eslint-plugin-8.57.1
chore(deps-dev): bump @typescript-eslint/eslint-plugin from 8.46.3 to 8.57.1 in /frontend
2026-04-14 14:16:35 +05:30
dependabot[bot]
7d22724d1c chore(deps-dev): bump @typescript-eslint/eslint-plugin in /frontend
Bumps [@typescript-eslint/eslint-plugin](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/eslint-plugin) from 8.46.3 to 8.57.1.
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/eslint-plugin/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.57.1/packages/eslint-plugin)

---
updated-dependencies:
- dependency-name: "@typescript-eslint/eslint-plugin"
  dependency-version: 8.57.1
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 08:39:33 +00:00
Manish Madan
f6f12f6895 Merge pull request #2302 from arc53/dependabot/npm_and_yarn/frontend/prettier-plugin-tailwindcss-0.7.2
chore(deps-dev): bump prettier-plugin-tailwindcss from 0.7.1 to 0.7.2 in /frontend
2026-04-14 14:07:38 +05:30
dependabot[bot]
934127f323 chore(deps-dev): bump prettier-plugin-tailwindcss in /frontend
Bumps [prettier-plugin-tailwindcss](https://github.com/tailwindlabs/prettier-plugin-tailwindcss) from 0.7.1 to 0.7.2.
- [Release notes](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/compare/v0.7.1...v0.7.2)

---
updated-dependencies:
- dependency-name: prettier-plugin-tailwindcss
  dependency-version: 0.7.2
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-14 08:25:09 +00:00
Manish Madan
1780e3cc91 Merge pull request #2301 from arc53/dependabot/npm_and_yarn/frontend/react-i18next-16.5.8
chore(deps): bump react-i18next from 16.2.4 to 16.5.8 in /frontend
2026-04-14 13:53:12 +05:30
ManishMadan2882
5e7fab2f34 (chore:fe) i18next 2026-04-14 13:50:03 +05:30
Alex
92ae76f95e Merge pull request #2381 from arc53/pg-3
feat: pre depreciation
2026-04-14 08:33:42 +01:00
Alex
18755bdd9b fix: workflow tests 2026-04-14 00:35:57 +01:00
Alex
0f20adcbf4 feat: pre depreciation 2026-04-14 00:19:50 +01:00
Alex
18e2a829c9 docs: apply revised incident response plan wording 2026-04-13 14:11:45 +01:00
dependabot[bot]
cd44501a71 chore(deps): bump react-i18next from 16.2.4 to 16.5.8 in /frontend
Bumps [react-i18next](https://github.com/i18next/react-i18next) from 16.2.4 to 16.5.8.
- [Changelog](https://github.com/i18next/react-i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/react-i18next/compare/v16.2.4...v16.5.8)

---
updated-dependencies:
- dependency-name: react-i18next
  dependency-version: 16.5.8
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-13 12:36:35 +00:00
Manish Madan
f8ebdf3fd4 Merge pull request #2300 from arc53/dependabot/npm_and_yarn/frontend/i18next-browser-languagedetector-8.2.1
chore(deps): bump i18next-browser-languagedetector from 8.2.0 to 8.2.1 in /frontend
2026-04-13 18:03:46 +05:30
dependabot[bot]
7c6fca18ad chore(deps): bump i18next-browser-languagedetector in /frontend
Bumps [i18next-browser-languagedetector](https://github.com/i18next/i18next-browser-languageDetector) from 8.2.0 to 8.2.1.
- [Changelog](https://github.com/i18next/i18next-browser-languageDetector/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/i18next-browser-languageDetector/compare/v8.2.0...v8.2.1)

---
updated-dependencies:
- dependency-name: i18next-browser-languagedetector
  dependency-version: 8.2.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-13 12:28:26 +00:00
Alex
5fab798707 Merge pull request #2377 from arc53/pg-2
feat: pg-2
2026-04-12 14:09:52 +01:00
52 changed files with 5683 additions and 1284 deletions

.github/INCIDENT_RESPONSE.md (new file, 99 lines)

@@ -0,0 +1,99 @@
# DocsGPT Incident Response Plan (IRP)
This playbook describes how maintainers respond to confirmed or suspected security incidents.
- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)
## Severity
| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | Key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |
## Response workflow
### 1) Triage (target: initial response within 48 hours)
1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.
### 2) Investigation
1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.
### 3) Containment, fix, and disclosure
1. Implement and test fix in private security workflow (GHSA private fork/branch).
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.
### 4) Post-incident
1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.
## Scenario playbooks
### Supply-chain compromise
1. Freeze releases and investigate blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish advisory with exact affected versions and required user actions.
### Data exposure
1. Determine scope (users, documents, keys, logs, time window).
2. Disable affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through standard fix/disclosure workflow.
### Critical regression with security impact
1. Identify introducing change (`git bisect` if needed).
2. Publish workaround within 24 hours (for example, pin to known-good version).
3. Ship patch release with regression test and close incident with public summary.
## AI-specific guidance
Treat confirmed AI-specific abuse as security incidents:
- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**
## Secret rotation quick reference
| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |
## Maintenance
Review this document:
- after every **Critical/High** incident, and
- at least annually.
Changes should be proposed via pull request to `main`.

.github/workflows/zizmor.yml (new file, 25 lines)

@@ -0,0 +1,25 @@
name: GitHub Actions Security Analysis
on:
  push:
    branches: ["master"]
  pull_request:
    branches: ["**"]
permissions: {}
jobs:
  zizmor:
    runs-on: ubuntu-latest
    permissions:
      security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.
    steps:
      - name: Checkout repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
      - name: Run zizmor 🌈
        uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2

View File

@@ -18,5 +18,5 @@ We aim to acknowledge reports within 48 hours.
## Incident Handling
Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.

View File

@@ -15,6 +15,9 @@ from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
logger = logging.getLogger(__name__)
@@ -181,6 +184,9 @@ class WorkflowAgent(BaseAgent):
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
@@ -188,6 +194,7 @@ class WorkflowAgent(BaseAgent):
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
user=owner_id,
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
@@ -196,7 +203,34 @@ class WorkflowAgent(BaseAgent):
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
result = workflow_runs_coll.insert_one(run.to_mongo_doc())
legacy_mongo_id = (
str(result.inserted_id)
if getattr(result, "inserted_id", None) is not None
else None
)
def _pg_write(repo: WorkflowRunsRepository) -> None:
if not self.workflow_id or not owner_id or not legacy_mongo_id:
return
workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
self.workflow_id, owner_id,
)
if workflow is None:
return
repo.create(
workflow["id"],
owner_id,
run.status.value,
inputs=run.inputs,
result=run.outputs,
steps=[step.model_dump(mode="json") for step in run.steps],
started_at=run.created_at,
ended_at=run.completed_at,
legacy_mongo_id=legacy_mongo_id,
)
dual_write(WorkflowRunsRepository, _pg_write)
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")

View File

@@ -211,6 +211,7 @@ class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
user: Optional[str] = None
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
@@ -226,7 +227,7 @@ class WorkflowRun(BaseModel):
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
doc = {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
@@ -235,3 +236,7 @@ class WorkflowRun(BaseModel):
"created_at": self.created_at,
"completed_at": self.completed_at,
}
if self.user:
doc["user"] = self.user
doc["user_id"] = self.user
return doc

File diff suppressed because it is too large.

View File

@@ -1,57 +0,0 @@
"""0002 add unique constraints for notes and connector_sessions.
The memories table already has ``memories_user_tool_path_uidx`` from the
0001 baseline. Notes and connector_sessions were missing unique constraints
that their repository upsert logic depends on.
Before creating the indexes, duplicate rows are cleaned up — keeping only
the row with the latest ``id`` (UUID, lexicographic max) per group.
Revision ID: 0002_add_unique_constraints
Revises: 0001_initial
Create Date: 2026-04-12
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_add_unique_constraints"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Deduplicate notes: keep one row per (user_id, tool_id)
op.execute("""
DELETE FROM notes
WHERE id NOT IN (
SELECT DISTINCT ON (user_id, tool_id) id
FROM notes
ORDER BY user_id, tool_id, created_at DESC
);
""")
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS notes_user_tool_uidx "
"ON notes (user_id, tool_id);"
)
# Deduplicate connector_sessions: keep one row per (user_id, provider)
op.execute("""
DELETE FROM connector_sessions
WHERE id NOT IN (
SELECT DISTINCT ON (user_id, provider) id
FROM connector_sessions
ORDER BY user_id, provider, created_at DESC
);
""")
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS connector_sessions_user_provider_uidx "
"ON connector_sessions (user_id, provider);"
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS connector_sessions_user_provider_uidx;")
op.execute("DROP INDEX IF EXISTS notes_user_tool_uidx;")

View File

@@ -13,6 +13,11 @@ from bson import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
logger = logging.getLogger(__name__)
@@ -107,6 +112,26 @@ class ContinuationService:
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
# Dual-write to Postgres — upsert against the same Mongo conversation
# by resolving its UUID via conversations.legacy_mongo_id.
def _pg_save(_: PendingToolStateRepository) -> None:
conn = _._conn # reuse the existing transaction
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is None:
return
_.save_state(
conv["id"],
user,
messages=_make_serializable(messages),
pending_tool_calls=_make_serializable(pending_tool_calls),
tools_dict=_make_serializable(tools_dict),
tool_schemas=_make_serializable(tool_schemas),
agent_config=_make_serializable(agent_config),
client_tools=_make_serializable(client_tools) if client_tools else None,
)
dual_write(PendingToolStateRepository, _pg_save)
return state_id
def load_state(
@@ -138,4 +163,13 @@ class ContinuationService:
logger.info(
f"Deleted continuation state for conversation {conversation_id}"
)
# Dual-write to Postgres — delete the same row.
def _pg_delete(repo: PendingToolStateRepository) -> None:
conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
if conv is None:
return
repo.delete_state(conv["id"], user)
dual_write(PendingToolStateRepository, _pg_delete)
return result.deleted_count > 0
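
_make_serializable is called above but its body is not included in this diff. A plausible helper with that name would coerce Mongo-specific types into JSON-safe ones before they land in JSONB columns; the sketch below is an assumption for illustration, not the project's code:

from datetime import datetime

from bson import ObjectId


def _make_serializable(value):
    """Recursively convert ObjectId/datetime values into JSON-friendly types."""
    if isinstance(value, ObjectId):
        return str(value)
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, dict):
        return {k: _make_serializable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_make_serializable(v) for v in value]
    return value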

View File

@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from bson import ObjectId
@@ -113,6 +115,26 @@ class ConversationService:
},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
# Dual-write to Postgres: update the message at :index and
# truncate anything after it, mirroring Mongo's $set+$slice.
def _pg_update_at_index(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
return
repo.update_message_at(conv["id"], index, {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
**({"metadata": metadata} if metadata else {}),
})
repo.truncate_after(conv["id"], index)
dual_write(ConversationsRepository, _pg_update_at_index)
return conversation_id
elif conversation_id:
# Append new message to existing conversation
@@ -138,6 +160,25 @@ class ConversationService:
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
# Dual-write to Postgres: append the same message.
def _pg_append(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
return
repo.append_message(conv["id"], {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
"metadata": metadata or {},
})
dual_write(ConversationsRepository, _pg_append)
return conversation_id
else:
# Create new conversation
@@ -193,7 +234,34 @@ class ConversationService:
if agent:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
inserted_id = str(result.inserted_id)
# Dual-write to Postgres: create the conversation row with
# legacy_mongo_id and append the first message.
def _pg_create(repo: ConversationsRepository) -> None:
conv = repo.create(
user_id,
completion,
agent_id=conversation_data.get("agent_id"),
api_key=conversation_data.get("api_key"),
is_shared_usage=conversation_data.get("is_shared_usage", False),
shared_token=conversation_data.get("shared_token"),
legacy_mongo_id=inserted_id,
)
repo.append_message(conv["id"], {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
"metadata": metadata or {},
})
dual_write(ConversationsRepository, _pg_create)
return inserted_id
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
@@ -230,6 +298,24 @@ class ConversationService:
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
# Dual-write to Postgres: mirror $set + $push $slice.
def _pg_compression(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
return
repo.set_compression_flags(
conv["id"],
is_compressed=True,
last_compression_at=compression_metadata.get("timestamp"),
)
repo.append_compression_point(
conv["id"],
compression_metadata,
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
)
dual_write(ConversationsRepository, _pg_compression)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
@@ -266,6 +352,23 @@ class ConversationService:
}
},
)
def _pg_append_summary(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
return
repo.append_message(conv["id"], {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"attachments": [],
"model_id": compression_metadata.get("model_used"),
"timestamp": timestamp,
})
dual_write(ConversationsRepository, _pg_append_summary)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(

View File

@@ -114,6 +114,35 @@ AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def _build_pg_agent_fields(fields: dict) -> dict:
"""Translate Mongo-shaped agent fields into the Postgres mirror subset."""
allowed = {
"name",
"description",
"agent_type",
"status",
"key",
"chunks",
"retriever",
"tools",
"json_schema",
"models",
"default_model_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"incoming_webhook_token",
"lastUsedAt",
}
translated: dict = {}
for key, value in fields.items():
if key not in allowed:
continue
translated["last_used_at" if key == "lastUsedAt" else key] = value
return translated
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
@@ -626,13 +655,14 @@ class CreateAgent(Resource):
new_id = str(resp.inserted_id)
dual_write(
AgentsRepository,
lambda repo, u=user, a=new_agent: repo.create(
lambda repo, u=user, a=new_agent, mid=new_id: repo.create(
u, a.get("name", ""), a.get("status", "draft"),
key=a.get("key"), description=a.get("description"),
retriever=a.get("retriever"), chunks=a.get("chunks"),
tools=a.get("tools"), models=a.get("models"),
shared=a.get("shared", False),
incoming_webhook_token=a.get("incoming_webhook_token"),
legacy_mongo_id=mid,
),
)
except Exception as err:
@@ -1170,6 +1200,14 @@ class UpdateAgent(Resource):
jsonify({"success": False, "message": "Database error during update"}),
500,
)
pg_update_fields = _build_pg_agent_fields(update_fields)
if pg_update_fields:
dual_write(
AgentsRepository,
lambda repo, aid=agent_id, u=user, fields=pg_update_fields: repo.update_by_legacy_id(
aid, u, fields,
),
)
response_data = {
"success": True,
"id": agent_id,
@@ -1199,7 +1237,7 @@ class DeleteAgent(Resource):
)
dual_write(
AgentsRepository,
lambda repo, aid=agent_id, u=user: repo.delete(aid, u),
lambda repo, aid=agent_id, u=user: repo.delete_by_legacy_id(aid, u),
)
if not deleted_agent:
return make_response(
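
A quick usage illustration of the _build_pg_agent_fields helper added above (values invented): only whitelisted keys survive, and the Mongo-era lastUsedAt key is renamed to the Postgres column last_used_at.

update_fields = {
    "name": "Docs agent",
    "lastUsedAt": "2026-04-14T12:00:00Z",
    "user": "user-123",  # not in the allowed set, so it is dropped
}
assert _build_pg_agent_fields(update_fields) == {
    "name": "Docs agent",
    "last_used_at": "2026-04-14T12:00:00Z",
}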

View File

@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import attachments_collection, conversations_collection
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.utils import check_required_fields
conversations_ns = Namespace(
@@ -30,15 +32,23 @@ class DeleteConversation(Resource):
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
user_id = decoded_token["sub"]
try:
conversations_collection.delete_one(
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
{"_id": ObjectId(conversation_id), "user": user_id}
)
except Exception as err:
current_app.logger.error(
f"Error deleting conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
def _pg_delete(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(conversation_id)
if conv is not None:
repo.delete(conv["id"], user_id)
dual_write(ConversationsRepository, _pg_delete)
return make_response(jsonify({"success": True}), 200)
@@ -59,6 +69,11 @@ class DeleteAllConversations(Resource):
f"Error deleting all conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
dual_write(
ConversationsRepository,
lambda r, uid=user_id: r.delete_all_for_user(uid),
)
return make_response(jsonify({"success": True}), 200)
@@ -190,9 +205,10 @@ class UpdateConversationName(Resource):
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user_id = decoded_token.get("sub")
try:
conversations_collection.update_one(
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
{"_id": ObjectId(data["id"]), "user": user_id},
{"$set": {"name": data["name"]}},
)
except Exception as err:
@@ -200,6 +216,13 @@ class UpdateConversationName(Resource):
f"Error updating conversation name: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
def _pg_rename(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(data["id"])
if conv is not None:
repo.rename(conv["id"], user_id, data["name"])
dual_write(ConversationsRepository, _pg_rename)
return make_response(jsonify({"success": True}), 200)
@@ -277,4 +300,21 @@ class SubmitFeedback(Resource):
except Exception as err:
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
# Dual-write to Postgres: mirror the per-message feedback set/unset.
feedback_value = data["feedback"]
question_index = int(data["question_index"])
feedback_payload = (
None if feedback_value is None
else {"text": feedback_value, "timestamp": datetime.datetime.now(
datetime.timezone.utc
).isoformat()}
)
def _pg_feedback(repo: ConversationsRepository) -> None:
conv = repo.get_by_legacy_id(data["conversation_id"])
if conv is not None:
repo.set_feedback(conv["id"], question_index, feedback_payload)
dual_write(ConversationsRepository, _pg_feedback)
return make_response(jsonify({"success": True}), 200)

View File

@@ -53,7 +53,9 @@ class CreatePrompt(Resource):
new_id = str(resp.inserted_id)
dual_write(
PromptsRepository,
lambda repo, u=user, n=data["name"], c=data["content"]: repo.create(u, n, c),
lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
u, n, c, legacy_mongo_id=mid,
),
)
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
@@ -157,7 +159,7 @@ class DeletePrompt(Resource):
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user: repo.delete(pid, u),
lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
)
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
@@ -197,7 +199,9 @@ class UpdatePrompt(Resource):
)
dual_write(
PromptsRepository,
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update(pid, u, n, c),
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
pid, u, n, c,
),
)
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)

View File

@@ -15,8 +15,71 @@ from application.api.user.base import (
conversations_collection,
shared_conversations_collections,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.shared_conversations import (
SharedConversationsRepository,
)
from application.utils import check_required_fields
def _dual_write_share(
mongo_conv_id: str,
share_uuid: str,
user: str,
*,
is_promptable: bool,
first_n_queries: int,
api_key: str | None,
prompt_id: str | None = None,
chunks: int | None = None,
) -> None:
"""Mirror a Mongo share-record insert into Postgres.
Preserves the Mongo-generated UUID so public ``/shared/{uuid}`` URLs
resolve from both stores during cutover.
"""
def _write(repo: SharedConversationsRepository) -> None:
conv = ConversationsRepository(repo._conn).get_by_legacy_id(
mongo_conv_id, user_id=user,
)
if conv is None:
return
# prompt_id / chunks are only meaningful for promptable shares;
# prompt_id is often the string "default" or an ObjectId that
# hasn't been migrated — pass as-is and let the repo drop
# non-UUID values. Scope the prompt lookup by user_id so an
# authenticated caller can't link another user's prompt into
# their share record.
resolved_prompt_id = None
if prompt_id and len(str(prompt_id)) == 24:
from sqlalchemy import text as _text
row = repo._conn.execute(
_text(
"SELECT id FROM prompts "
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
),
{"legacy_id": str(prompt_id), "user_id": user},
).fetchone()
if row:
resolved_prompt_id = str(row[0])
# get_or_create is race-free on the PG side thanks to the
# composite partial unique index on the dedup tuple
# (migration 0008). It converges concurrent share requests to
# a single row.
repo.get_or_create(
conv["id"],
user,
is_promptable=is_promptable,
first_n_queries=first_n_queries,
api_key=api_key,
prompt_id=resolved_prompt_id,
chunks=chunks,
share_uuid=share_uuid,
)
dual_write(SharedConversationsRepository, _write)
sharing_ns = Namespace(
"sharing", description="Conversation sharing operations", path="/api"
)
@@ -124,6 +187,16 @@ class ShareConversation(Resource):
"api_key": api_uuid,
}
)
_dual_write_share(
conversation_id,
str(explicit_binary.as_uuid()),
user,
is_promptable=is_promptable,
first_n_queries=current_n_queries,
api_key=api_uuid,
prompt_id=prompt_id,
chunks=int(chunks) if chunks else None,
)
return make_response(
jsonify(
{
@@ -155,6 +228,16 @@ class ShareConversation(Resource):
"api_key": api_uuid,
}
)
_dual_write_share(
conversation_id,
str(explicit_binary.as_uuid()),
user,
is_promptable=is_promptable,
first_n_queries=current_n_queries,
api_key=api_uuid,
prompt_id=prompt_id,
chunks=int(chunks) if chunks else None,
)
return make_response(
jsonify(
{
@@ -192,6 +275,14 @@ class ShareConversation(Resource):
"user": user,
}
)
_dual_write_share(
conversation_id,
str(explicit_binary.as_uuid()),
user,
is_promptable=is_promptable,
first_n_queries=current_n_queries,
api_key=None,
)
return make_response(
jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}

View File

@@ -134,6 +134,12 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
# Replaces Mongo's TTL index on pending_tool_state.expires_at.
sender.add_periodic_task(
timedelta(seconds=60),
cleanup_pending_tool_state.s(),
name="cleanup-pending-tool-state",
)
@celery.task(bind=True)
@@ -146,3 +152,27 @@ def mcp_oauth_task(self, config, user):
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp
@celery.task(bind=True)
def cleanup_pending_tool_state(self):
"""Delete pending_tool_state rows past their TTL.
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
no native TTL, so this task runs every 60 seconds to keep
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
configured (keeps the task runnable in Mongo-only environments).
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
engine = get_engine()
with engine.begin() as conn:
deleted = PendingToolStateRepository(conn).cleanup_expired()
return {"deleted": deleted}

View File

@@ -11,6 +11,10 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
@@ -35,6 +39,174 @@ def _workflow_error_response(message: str, err: Exception):
return error_response(message)
# ---------------------------------------------------------------------------
# Postgres dual-write helpers
#
# Workflows are unusual relative to other Phase 3 tables: a single user
# action (create / update) writes to three collections in concert
# (workflows + workflow_nodes + workflow_edges) and the edges reference
# nodes by user-provided string ids. The Postgres mirror needs to:
#
# 1. Run all three writes inside one PG transaction (so the just-created
# nodes are visible when we resolve their UUIDs for the edge insert).
# 2. Translate edge source_id/target_id strings → workflow_nodes.id UUIDs
# after the bulk_create returns them.
#
# Each helper opens exactly one ``dual_write`` call (one PG txn) and uses
# the connection from whichever repo it was instantiated with to spin up
# any sibling repos it needs.
# ---------------------------------------------------------------------------
def _dual_write_workflow_create(
mongo_workflow_id: str,
user_id: str,
name: str,
description: str,
nodes_data: List[Dict],
edges_data: List[Dict],
graph_version: int = 1,
) -> None:
"""Mirror a Mongo workflow create into Postgres."""
def _do(repo: WorkflowsRepository) -> None:
conn = repo._conn
wf = repo.create(
user_id,
name,
description=description,
legacy_mongo_id=mongo_workflow_id,
)
_write_graph(conn, wf["id"], graph_version, nodes_data, edges_data)
dual_write(WorkflowsRepository, _do)
def _dual_write_workflow_update(
mongo_workflow_id: str,
user_id: str,
name: str,
description: str,
nodes_data: List[Dict],
edges_data: List[Dict],
next_graph_version: int,
) -> None:
"""Mirror a Mongo workflow update into Postgres.
Mirrors the Mongo route: insert the new graph_version's nodes/edges,
bump the workflow's name/description/current_graph_version, then drop
every other graph_version's nodes/edges.
"""
def _do(repo: WorkflowsRepository) -> None:
conn = repo._conn
wf = _resolve_pg_workflow(conn, mongo_workflow_id)
if wf is None:
return
_write_graph(conn, wf["id"], next_graph_version, nodes_data, edges_data)
repo.update(wf["id"], user_id, {
"name": name,
"description": description,
"current_graph_version": next_graph_version,
})
WorkflowNodesRepository(conn).delete_other_versions(
wf["id"], next_graph_version,
)
WorkflowEdgesRepository(conn).delete_other_versions(
wf["id"], next_graph_version,
)
dual_write(WorkflowsRepository, _do)
def _dual_write_workflow_delete(mongo_workflow_id: str, user_id: str) -> None:
"""Mirror a Mongo workflow delete into Postgres.
The CASCADE on workflows.id → workflow_nodes/workflow_edges takes
care of the children automatically.
"""
def _do(repo: WorkflowsRepository) -> None:
wf = _resolve_pg_workflow(repo._conn, mongo_workflow_id)
if wf is not None:
repo.delete(wf["id"], user_id)
dual_write(WorkflowsRepository, _do)
def _resolve_pg_workflow(conn, mongo_workflow_id: str) -> Optional[Dict]:
"""Look up a Postgres workflow by its Mongo ObjectId string."""
from sqlalchemy import text as _text
row = conn.execute(
_text("SELECT id FROM workflows WHERE legacy_mongo_id = :legacy_id"),
{"legacy_id": mongo_workflow_id},
).fetchone()
return {"id": str(row[0])} if row else None
def _write_graph(
conn,
pg_workflow_id: str,
graph_version: int,
nodes_data: List[Dict],
edges_data: List[Dict],
) -> None:
"""Bulk-create nodes + edges for one graph version inside one txn.
Edges arrive with source/target as user-provided node-id strings
(the same shape the Mongo route stores). We bulk-insert nodes first,
capture their ``node_id → UUID`` map from the returned rows, then
translate edge source/target strings to those UUIDs before the edge
bulk insert. Edges referencing missing nodes are dropped (logged).
"""
nodes_repo = WorkflowNodesRepository(conn)
edges_repo = WorkflowEdgesRepository(conn)
if nodes_data:
created_nodes = nodes_repo.bulk_create(
pg_workflow_id, graph_version,
[
{
"node_id": n["id"],
"node_type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
"legacy_mongo_id": n.get("legacy_mongo_id"),
}
for n in nodes_data
],
)
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
else:
node_uuid_by_str = {}
if edges_data:
translated_edges: List[Dict] = []
for e in edges_data:
src = e.get("source")
tgt = e.get("target")
from_uuid = node_uuid_by_str.get(src)
to_uuid = node_uuid_by_str.get(tgt)
if not from_uuid or not to_uuid:
current_app.logger.warning(
"PG dual-write: dropping edge %s; node refs unresolved "
"(source=%s, target=%s)",
e.get("id"), src, tgt,
)
continue
translated_edges.append({
"edge_id": e["id"],
"from_node_id": from_uuid,
"to_node_id": to_uuid,
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
})
if translated_edges:
edges_repo.bulk_create(pg_workflow_id, graph_version, translated_edges)
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
@@ -317,24 +489,28 @@ def _can_reach_end(
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
) -> List[Dict]:
"""Insert workflow nodes into Mongo and return rows with Mongo ids."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
mongo_nodes = [
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
result = workflow_nodes_collection.insert_many(mongo_nodes)
return [
{**node, "legacy_mongo_id": str(inserted_id)}
for node, inserted_id in zip(nodes_data, result.inserted_ids)
]
return []
def create_workflow_edges(
@@ -399,7 +575,7 @@ class WorkflowList(Resource):
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as err:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
@@ -407,6 +583,15 @@ class WorkflowList(Resource):
workflows_collection.delete_one({"_id": result.inserted_id})
return _workflow_error_response("Failed to create workflow structure", err)
_dual_write_workflow_create(
workflow_id,
user_id,
name,
data.get("description", ""),
created_nodes,
edges_data,
)
return success_response({"id": workflow_id}, 201)
@@ -473,7 +658,9 @@ class WorkflowDetail(Resource):
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
created_nodes = create_workflow_nodes(
workflow_id, nodes_data, next_graph_version,
)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as err:
workflow_nodes_collection.delete_many(
@@ -520,6 +707,16 @@ class WorkflowDetail(Resource):
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
_dual_write_workflow_update(
workflow_id,
user_id,
name,
data.get("description", ""),
created_nodes,
edges_data,
next_graph_version,
)
return success_response()
@require_auth
@@ -543,4 +740,6 @@ class WorkflowDetail(Resource):
except Exception as err:
return _workflow_error_response("Failed to delete workflow", err)
_dual_write_workflow_delete(workflow_id, user_id)
return success_response()
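
To make the node-id translation described in the dual-write helper comments above concrete, here is a small made-up walk-through of what _write_graph does with one edge (ids invented for illustration):

created_nodes = [
    {"node_id": "start", "id": "11111111-1111-1111-1111-111111111111"},
    {"node_id": "llm_1", "id": "22222222-2222-2222-2222-222222222222"},
]
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}

edge = {"id": "e1", "source": "start", "target": "llm_1"}
translated_edge = {
    "edge_id": edge["id"],
    "from_node_id": node_uuid_by_str[edge["source"]],  # start -> 1111...
    "to_node_id": node_uuid_by_str[edge["target"]],    # llm_1 -> 2222...
    "source_handle": edge.get("sourceHandle"),
    "target_handle": edge.get("targetHandle"),
}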

View File

@@ -5,6 +5,14 @@ MongoDB→Postgres migration. The baseline schema in the Alembic migration
(``application/alembic/versions/0001_initial.py``) is the source of truth
for DDL; the ``Table`` definitions below must match it column-for-column.
If the two drift, migrations win — update this file to match.
Cross-table invariant not expressed in the Core ``Table`` definitions
below: every ``user_id`` column is FK-enforced against
``users(user_id)`` with ``ON DELETE RESTRICT``, and a
``BEFORE INSERT OR UPDATE OF user_id`` trigger on each child table
auto-creates the ``users`` row if it does not yet exist. See migration
``0015_user_id_fk``. The FKs are intentionally omitted from the Core
declarations to keep this file readable; the DB is the authority.
"""
from sqlalchemy import (
@@ -13,6 +21,7 @@ from sqlalchemy import (
Column,
DateTime,
ForeignKey,
ForeignKeyConstraint,
Integer,
MetaData,
UniqueConstraint,
@@ -20,7 +29,7 @@ from sqlalchemy import (
Text,
func,
)
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
metadata = MetaData()
@@ -51,6 +60,7 @@ prompts_table = Table(
Column("content", Text, nullable=False),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("legacy_mongo_id", Text),
)
user_tools_table = Table(
@@ -88,17 +98,6 @@ user_logs_table = Table(
Column("data", JSONB),
)
feedback_table = Table(
"feedback",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("conversation_id", UUID(as_uuid=True), nullable=False),
Column("user_id", Text, nullable=False),
Column("question_index", Integer, nullable=False),
Column("feedback_text", Text),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
stack_logs_table = Table(
"stack_logs",
metadata,
@@ -131,7 +130,7 @@ sources_table = Table(
"sources",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text),
Column("user_id", Text, nullable=False),
Column("name", Text, nullable=False),
Column("type", Text),
Column("metadata", JSONB, nullable=False, server_default="{}"),
@@ -148,7 +147,7 @@ agents_table = Table(
Column("description", Text),
Column("agent_type", Text),
Column("status", Text, nullable=False),
Column("key", Text, unique=True),
Column("key", CITEXT, unique=True),
Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
Column("chunks", Integer),
@@ -164,10 +163,11 @@ agents_table = Table(
Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
Column("request_limit", Integer),
Column("shared", Boolean, nullable=False, server_default="false"),
Column("incoming_webhook_token", Text, unique=True),
Column("incoming_webhook_token", CITEXT, unique=True),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("last_used_at", DateTime(timezone=True)),
Column("legacy_mongo_id", Text),
)
attachments_table = Table(
@@ -180,6 +180,7 @@ attachments_table = Table(
Column("mime_type", Text),
Column("size", BigInteger),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("legacy_mongo_id", Text),
)
memories_table = Table(
@@ -230,3 +231,166 @@ connector_sessions_table = Table(
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
)
# --- Phase 3, Tier 3 --------------------------------------------------------
conversations_table = Table(
"conversations",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("agent_id", UUID(as_uuid=True), ForeignKey("agents.id", ondelete="SET NULL")),
Column("name", Text),
Column("api_key", Text),
Column("is_shared_usage", Boolean, nullable=False, server_default="false"),
Column("shared_token", Text),
Column("shared_with", ARRAY(Text), nullable=False, server_default="{}"),
Column("compression_metadata", JSONB),
Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("legacy_mongo_id", Text),
)
conversation_messages_table = Table(
"conversation_messages",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
# Denormalised from conversations.user_id. Auto-filled on insert by a
# BEFORE INSERT trigger when the caller omits it. See migration 0020.
Column("user_id", Text, nullable=False),
Column("position", Integer, nullable=False),
Column("prompt", Text),
Column("response", Text),
Column("thought", Text),
Column("sources", JSONB, nullable=False, server_default="[]"),
Column("tool_calls", JSONB, nullable=False, server_default="[]"),
# Postgres cannot FK-enforce array elements, so the referential
# invariant is kept by an AFTER DELETE trigger on ``attachments``
# that array_removes the id from every row that references it.
# See migration 0017_cleanup_dangling_refs.
Column("attachments", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
Column("model_id", Text),
# Renamed from ``metadata`` in migration 0016 to avoid SQLAlchemy's
# reserved attribute collision on declarative models. The repository
# translates this ↔ API dict key ``metadata`` so external callers
# still see ``metadata``.
Column("message_metadata", JSONB, nullable=False, server_default="{}"),
Column("feedback", JSONB),
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
)
shared_conversations_table = Table(
"shared_conversations",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("uuid", UUID(as_uuid=True), nullable=False, unique=True),
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
Column("user_id", Text, nullable=False),
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
Column("chunks", Integer),
Column("is_promptable", Boolean, nullable=False, server_default="false"),
Column("first_n_queries", Integer, nullable=False, server_default="0"),
Column("api_key", Text),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
pending_tool_state_table = Table(
"pending_tool_state",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
Column("user_id", Text, nullable=False),
Column("messages", JSONB, nullable=False),
Column("pending_tool_calls", JSONB, nullable=False),
Column("tools_dict", JSONB, nullable=False),
Column("tool_schemas", JSONB, nullable=False),
Column("agent_config", JSONB, nullable=False),
Column("client_tools", JSONB),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("expires_at", DateTime(timezone=True), nullable=False),
UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
)
workflows_table = Table(
"workflows",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("name", Text, nullable=False),
Column("description", Text),
Column("current_graph_version", Integer, nullable=False, server_default="1"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("legacy_mongo_id", Text),
)
workflow_nodes_table = Table(
"workflow_nodes",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
Column("graph_version", Integer, nullable=False),
Column("node_id", Text, nullable=False),
Column("node_type", Text, nullable=False),
Column("title", Text),
Column("description", Text),
Column("position", JSONB, nullable=False, server_default='{"x": 0, "y": 0}'),
Column("config", JSONB, nullable=False, server_default="{}"),
Column("legacy_mongo_id", Text),
# Composite UNIQUE so workflow_edges can use a composite FK that
# enforces endpoint nodes belong to the same (workflow, version) as
# the edge itself. See migration 0008.
UniqueConstraint(
"id", "workflow_id", "graph_version",
name="workflow_nodes_id_wf_ver_key",
),
)
workflow_edges_table = Table(
"workflow_edges",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
Column("graph_version", Integer, nullable=False),
Column("edge_id", Text, nullable=False),
Column("from_node_id", UUID(as_uuid=True), nullable=False),
Column("to_node_id", UUID(as_uuid=True), nullable=False),
Column("source_handle", Text),
Column("target_handle", Text),
Column("config", JSONB, nullable=False, server_default="{}"),
# Composite FKs: endpoints must belong to the same (workflow, version)
# as the edge. Prevents cross-workflow / cross-version edges that the
# single-column FKs couldn't catch. See migration 0008.
ForeignKeyConstraint(
["from_node_id", "workflow_id", "graph_version"],
["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
ondelete="CASCADE",
name="workflow_edges_from_node_fk",
),
ForeignKeyConstraint(
["to_node_id", "workflow_id", "graph_version"],
["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
ondelete="CASCADE",
name="workflow_edges_to_node_fk",
),
)
workflow_runs_table = Table(
"workflow_runs",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
Column("user_id", Text, nullable=False),
Column("status", Text, nullable=False),
Column("inputs", JSONB),
Column("result", JSONB),
Column("steps", JSONB, nullable=False, server_default="[]"),
Column("started_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("ended_at", DateTime(timezone=True)),
Column("legacy_mongo_id", Text),
)
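
The user_id foreign keys and the auto-creating trigger mentioned in the models.py docstring (migration 0015_user_id_fk) are not included in this compare view. A rough, hypothetical sketch of what such an Alembic step could contain for one child table, in the op.execute style of the 0002 migration shown earlier; the function, constraint, and trigger names are invented:

from alembic import op


def upgrade() -> None:
    op.execute("""
        CREATE OR REPLACE FUNCTION ensure_user_row() RETURNS trigger AS $$
        BEGIN
            -- Auto-create the parent users row so the FK below never rejects a write.
            INSERT INTO users (user_id) VALUES (NEW.user_id)
            ON CONFLICT (user_id) DO NOTHING;
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
    """)
    op.execute("""
        ALTER TABLE conversations
        ADD CONSTRAINT conversations_user_id_fkey
        FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE RESTRICT;
    """)
    op.execute("""
        CREATE TRIGGER conversations_ensure_user
        BEFORE INSERT OR UPDATE OF user_id ON conversations
        FOR EACH ROW EXECUTE FUNCTION ensure_user_row();
    """)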

View File

@@ -12,7 +12,6 @@ the legacy Mongo code performs on ``agents_collection``:
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, func, text
@@ -26,6 +25,13 @@ class AgentsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
@staticmethod
def _normalize_unique_text(col: str, val):
"""Coerce blank strings for nullable unique text columns to NULL."""
if col == "key" and val == "":
return None
return val
def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
values: dict = {"user_id": user_id, "name": name, "status": status}
@@ -35,14 +41,18 @@ class AgentsRepository:
"source_id", "prompt_id", "folder_id",
"chunks", "token_limit", "request_limit",
"limited_token_mode", "limited_request_mode", "shared",
"tools", "json_schema", "models",
"tools", "json_schema", "models", "legacy_mongo_id",
}
for col, val in kwargs.items():
if col not in _ALLOWED or val is None:
continue
if col in ("tools", "json_schema", "models"):
values[col] = json.dumps(val)
# JSONB columns: pass the Python object directly. SQLAlchemy
# Core's JSONB type processor json.dumps it once during
# bind; pre-serialising would double-encode and the value
# would round-trip as a JSON string instead of the dict.
values[col] = val
elif col in ("chunks", "token_limit", "request_limit"):
values[col] = int(val)
elif col in ("limited_token_mode", "limited_request_mode", "shared"):
@@ -50,7 +60,7 @@ class AgentsRepository:
elif col in ("source_id", "prompt_id", "folder_id"):
values[col] = str(val)
else:
values[col] = val
values[col] = self._normalize_unique_text(col, val)
stmt = pg_insert(agents_table).values(**values).returning(agents_table)
result = self._conn.execute(stmt)
@@ -64,6 +74,17 @@ class AgentsRepository:
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
"""Fetch an agent by the original Mongo ObjectId string."""
sql = "SELECT * FROM agents WHERE legacy_mongo_id = :legacy_id"
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
if user_id is not None:
sql += " AND user_id = :user_id"
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def find_by_key(self, key: str) -> Optional[dict]:
result = self._conn.execute(
text("SELECT * FROM agents WHERE key = :key"),
@@ -108,11 +129,13 @@ class AgentsRepository:
values: dict = {}
for col, val in filtered.items():
if col in ("tools", "json_schema", "models"):
values[col] = json.dumps(val) if not isinstance(val, str) else val
# See note in create(): JSONB columns receive Python
# objects, the type processor handles serialisation.
values[col] = val
elif col in ("source_id", "prompt_id", "folder_id"):
values[col] = str(val) if val else None
else:
values[col] = val
values[col] = self._normalize_unique_text(col, val)
values["updated_at"] = func.now()
t = agents_table
@@ -125,6 +148,13 @@ class AgentsRepository:
result = self._conn.execute(stmt)
return result.rowcount > 0
def update_by_legacy_id(self, legacy_mongo_id: str, user_id: str, fields: dict) -> bool:
"""Update an agent addressed by the Mongo ObjectId string."""
agent = self.get_by_legacy_id(legacy_mongo_id, user_id)
if agent is None:
return False
return self.update(agent["id"], user_id, fields)
def delete(self, agent_id: str, user_id: str) -> bool:
result = self._conn.execute(
text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
@@ -132,6 +162,17 @@ class AgentsRepository:
)
return result.rowcount > 0
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
"""Delete an agent addressed by the Mongo ObjectId string."""
result = self._conn.execute(
text(
"DELETE FROM agents "
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
),
{"legacy_id": legacy_mongo_id, "user_id": user_id},
)
return result.rowcount > 0
def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
self._conn.execute(
text(

View File

@@ -14,12 +14,15 @@ class AttachmentsRepository:
self._conn = conn
def create(self, user_id: str, filename: str, upload_path: str, *,
mime_type: Optional[str] = None, size: Optional[int] = None) -> dict:
mime_type: Optional[str] = None, size: Optional[int] = None,
legacy_mongo_id: Optional[str] = None) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
INSERT INTO attachments
(user_id, filename, upload_path, mime_type, size, legacy_mongo_id)
VALUES
(:user_id, :filename, :upload_path, :mime_type, :size, :legacy_mongo_id)
RETURNING *
"""
),
@@ -29,6 +32,7 @@ class AttachmentsRepository:
"upload_path": upload_path,
"mime_type": mime_type,
"size": size,
"legacy_mongo_id": legacy_mongo_id,
},
)
return row_to_dict(result.fetchone())
@@ -43,6 +47,17 @@ class AttachmentsRepository:
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
"""Fetch an attachment by the original Mongo ObjectId string."""
sql = "SELECT * FROM attachments WHERE legacy_mongo_id = :legacy_id"
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
if user_id is not None:
sql += " AND user_id = :user_id"
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),

View File

@@ -0,0 +1,476 @@
"""Repository for the ``conversations`` and ``conversation_messages`` tables.
Covers every operation the legacy Mongo code performs on
``conversations_collection``:
- create / get / list / delete conversations
- append message (transactional position allocation)
- update message at index (overwrite + optional truncation)
- set / unset feedback on a message
- rename conversation
- update compression metadata
- shared_with access checks
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import conversations_table, conversation_messages_table
def _message_row_to_dict(row) -> dict:
"""Like ``row_to_dict`` but renames the DB column ``message_metadata``
back to the public API key ``metadata`` so callers keep the Mongo-era
shape. See migration 0016 for the column rename rationale."""
out = row_to_dict(row)
if "message_metadata" in out:
out["metadata"] = out.pop("message_metadata")
return out
class ConversationsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
# ------------------------------------------------------------------
# Conversation CRUD
# ------------------------------------------------------------------
def create(
self,
user_id: str,
name: str | None = None,
*,
agent_id: str | None = None,
api_key: str | None = None,
is_shared_usage: bool = False,
shared_token: str | None = None,
legacy_mongo_id: str | None = None,
) -> dict:
"""Create a new conversation.
``legacy_mongo_id`` is used by the dual-write shim so that a
Postgres row inserted *after* a successful Mongo insert carries
the Mongo ``_id`` as a lookup key. Subsequent appends/updates
can then resolve the PG row by that id via
:meth:`get_by_legacy_id`.
"""
values: dict = {
"user_id": user_id,
"name": name,
}
if agent_id:
values["agent_id"] = agent_id
if api_key:
values["api_key"] = api_key
if is_shared_usage:
values["is_shared_usage"] = True
if shared_token:
values["shared_token"] = shared_token
if legacy_mongo_id:
values["legacy_mongo_id"] = legacy_mongo_id
stmt = pg_insert(conversations_table).values(**values).returning(conversations_table)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get_by_legacy_id(
self, legacy_mongo_id: str, user_id: str | None = None,
) -> Optional[dict]:
"""Look up a conversation by the original Mongo ObjectId string.
Used by the dual-write helpers to translate a Mongo ``_id`` into
the Postgres UUID for follow-up writes. When ``user_id`` is
provided, the lookup is scoped to rows owned by that user so
callers can't accidentally resolve another user's conversation.
"""
if user_id is not None:
result = self._conn.execute(
text(
"SELECT * FROM conversations "
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
),
{"legacy_id": legacy_mongo_id, "user_id": user_id},
)
else:
result = self._conn.execute(
text(
"SELECT * FROM conversations WHERE legacy_mongo_id = :legacy_id"
),
{"legacy_id": legacy_mongo_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get(self, conversation_id: str, user_id: str) -> Optional[dict]:
"""Fetch a conversation the user owns or has shared access to."""
result = self._conn.execute(
text(
"SELECT * FROM conversations "
"WHERE id = CAST(:id AS uuid) "
"AND (user_id = :user_id OR :user_id = ANY(shared_with))"
),
{"id": conversation_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_owned(self, conversation_id: str, user_id: str) -> Optional[dict]:
"""Fetch a conversation owned by the user (no shared access)."""
result = self._conn.execute(
text(
"SELECT * FROM conversations "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": conversation_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str, limit: int = 30) -> list[dict]:
"""List conversations for a user, most recent first.
Mirrors the Mongo filter: rows with no api_key, or rows that have an agent_id.
"""
result = self._conn.execute(
text(
"SELECT * FROM conversations "
"WHERE user_id = :user_id "
"AND (api_key IS NULL OR agent_id IS NOT NULL) "
"ORDER BY date DESC LIMIT :limit"
),
{"user_id": user_id, "limit": limit},
)
return [row_to_dict(r) for r in result.fetchall()]
def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
result = self._conn.execute(
text(
"UPDATE conversations SET name = :name, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": conversation_id, "user_id": user_id, "name": name},
)
return result.rowcount > 0
def set_shared_token(self, conversation_id: str, user_id: str, token: str) -> bool:
result = self._conn.execute(
text(
"UPDATE conversations SET shared_token = :token, updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": conversation_id, "user_id": user_id, "token": token},
)
return result.rowcount > 0
def update_compression_metadata(
self, conversation_id: str, user_id: str, metadata: dict,
) -> bool:
"""Replace the entire ``compression_metadata`` JSONB blob.
Prefer :meth:`append_compression_point` + :meth:`set_compression_flags`
to match the Mongo service semantics exactly (those two mirror
``$set`` + ``$push $slice``). This method is retained for callers
that already compute the full merged blob client-side.
"""
result = self._conn.execute(
text(
"UPDATE conversations "
"SET compression_metadata = CAST(:meta AS jsonb), updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": conversation_id, "user_id": user_id, "meta": json.dumps(metadata)},
)
return result.rowcount > 0
def set_compression_flags(
self,
conversation_id: str,
*,
is_compressed: bool,
last_compression_at,
) -> bool:
"""Update ``compression_metadata.is_compressed`` and
``compression_metadata.last_compression_at`` without touching
``compression_points``.
Mirrors the Mongo ``$set`` on those two subfields in
``ConversationService.update_compression_metadata``. Initialises
the surrounding object when the row has no ``compression_metadata``
yet.
"""
result = self._conn.execute(
text(
"""
UPDATE conversations SET
compression_metadata = jsonb_set(
jsonb_set(
COALESCE(compression_metadata, '{}'::jsonb),
'{is_compressed}',
to_jsonb(CAST(:is_compressed AS boolean)), true
),
'{last_compression_at}',
to_jsonb(CAST(:last_compression_at AS text)), true
),
updated_at = now()
WHERE id = CAST(:id AS uuid)
"""
),
{
"id": conversation_id,
"is_compressed": bool(is_compressed),
"last_compression_at": (
str(last_compression_at) if last_compression_at is not None else None
),
},
)
return result.rowcount > 0
def append_compression_point(
self,
conversation_id: str,
point: dict,
*,
max_points: int,
) -> bool:
"""Append one compression point, keeping at most ``max_points``.
Mirrors Mongo's ``$push {"$each": [point], "$slice": -max_points}``
on ``compression_metadata.compression_points``. Preserves the
other top-level keys in ``compression_metadata``.
"""
result = self._conn.execute(
text(
"""
UPDATE conversations SET
compression_metadata = jsonb_set(
COALESCE(compression_metadata, '{}'::jsonb),
'{compression_points}',
COALESCE(
(
SELECT jsonb_agg(elem ORDER BY rn)
FROM (
SELECT
elem,
row_number() OVER () AS rn,
count(*) OVER () AS cnt
FROM jsonb_array_elements(
COALESCE(
compression_metadata -> 'compression_points',
'[]'::jsonb
) || jsonb_build_array(CAST(:point AS jsonb))
) AS elem
) ranked
WHERE rn > cnt - :max_points
),
'[]'::jsonb
),
true
),
updated_at = now()
WHERE id = CAST(:id AS uuid)
"""
),
{
"id": conversation_id,
"point": json.dumps(point, default=str),
"max_points": int(max_points),
},
)
return result.rowcount > 0
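A hedged sketch of the call pattern the ``update_compression_metadata`` docstring recommends instead of replacing the whole blob (the point payload and ``max_points`` value are illustrative, not taken from the service code):
from datetime import datetime, timezone
repo = ConversationsRepository(conn)  # conn: an open SQLAlchemy Connection
repo.append_compression_point(
    conversation_id,
    {"summary": "condensed history", "compressed_up_to": 12},
    max_points=5,
)
repo.set_compression_flags(
    conversation_id,
    is_compressed=True,
    last_compression_at=datetime.now(timezone.utc),
)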
def delete(self, conversation_id: str, user_id: str) -> bool:
result = self._conn.execute(
text(
"DELETE FROM conversations "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": conversation_id, "user_id": user_id},
)
return result.rowcount > 0
def delete_all_for_user(self, user_id: str) -> int:
result = self._conn.execute(
text("DELETE FROM conversations WHERE user_id = :user_id"),
{"user_id": user_id},
)
return result.rowcount
# ------------------------------------------------------------------
# Messages
# ------------------------------------------------------------------
def get_messages(self, conversation_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM conversation_messages "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"ORDER BY position ASC"
),
{"conv_id": conversation_id},
)
return [_message_row_to_dict(r) for r in result.fetchall()]
def get_message_at(self, conversation_id: str, position: int) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM conversation_messages "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"AND position = :pos"
),
{"conv_id": conversation_id, "pos": position},
)
row = result.fetchone()
return _message_row_to_dict(row) if row is not None else None
def append_message(self, conversation_id: str, message: dict) -> dict:
"""Append a message to a conversation.
Uses ``SELECT ... FOR UPDATE`` to allocate the next position
atomically. The caller must be inside a transaction.
Mirrors Mongo's ``$push`` on the ``queries`` array.
"""
# Lock the parent conversation row to serialize concurrent appends.
self._conn.execute(
text(
"SELECT id FROM conversations "
"WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
),
{"conv_id": conversation_id},
)
next_pos_result = self._conn.execute(
text(
"SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
"FROM conversation_messages "
"WHERE conversation_id = CAST(:conv_id AS uuid)"
),
{"conv_id": conversation_id},
)
next_pos = next_pos_result.scalar()
values = {
"conversation_id": conversation_id,
"position": next_pos,
"prompt": message.get("prompt"),
"response": message.get("response"),
"thought": message.get("thought"),
"sources": message.get("sources") or [],
"tool_calls": message.get("tool_calls") or [],
"model_id": message.get("model_id"),
"message_metadata": message.get("metadata") or {},
}
if message.get("timestamp") is not None:
values["timestamp"] = message["timestamp"]
attachments = message.get("attachments")
if attachments:
values["attachments"] = [str(a) for a in attachments]
stmt = (
pg_insert(conversation_messages_table)
.values(**values)
.returning(conversation_messages_table)
)
result = self._conn.execute(stmt)
# Touch the parent conversation's updated_at.
self._conn.execute(
text(
"UPDATE conversations SET updated_at = now() "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conversation_id},
)
return _message_row_to_dict(result.fetchone())
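A usage sketch of the transactional contract stated in the docstring: the caller owns the transaction, so the ``FOR UPDATE`` lock taken here is held until commit (the DSN and values are placeholders):
from sqlalchemy import create_engine
engine = create_engine("postgresql+psycopg://docsgpt:docsgpt@localhost/docsgpt")  # placeholder DSN
with engine.begin() as conn:  # one transaction; commits on exit
    repo = ConversationsRepository(conn)
    conv = repo.create(user_id="local", name="demo")
    # Concurrent appends to the same conversation serialize on the row lock,
    # so positions come out 0, 1, ... without gaps or duplicates.
    repo.append_message(conv["id"], {"prompt": "hi", "response": "hello"})
    repo.append_message(conv["id"], {"prompt": "and?", "response": "that's it"})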
def update_message_at(
self, conversation_id: str, position: int, fields: dict,
) -> bool:
"""Update specific fields on a message at a given position.
Mirrors Mongo's ``$set`` on ``queries.{index}.*``.
"""
allowed = {
"prompt", "response", "thought", "sources", "tool_calls",
"attachments", "model_id", "metadata", "timestamp",
}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return False
# Map public API key ``metadata`` → DB column ``message_metadata``.
api_to_col = {"metadata": "message_metadata"}
set_parts = []
params: dict = {"conv_id": conversation_id, "pos": position}
for key, val in filtered.items():
col = api_to_col.get(key, key)
if key in ("sources", "tool_calls", "metadata"):
set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
params[col] = json.dumps(val) if not isinstance(val, str) else val
elif key == "attachments":
set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
params[col] = [str(a) for a in val] if val else []
else:
set_parts.append(f"{col} = :{col}")
params[col] = val
if "timestamp" not in filtered:
set_parts.append("timestamp = now()")
sql = (
f"UPDATE conversation_messages SET {', '.join(set_parts)} "
"WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
)
result = self._conn.execute(text(sql), params)
return result.rowcount > 0
def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
"""Delete messages with position > keep_up_to.
Mirrors Mongo's ``$push`` + ``$slice`` that trims queries after an
index-based update.
"""
result = self._conn.execute(
text(
"DELETE FROM conversation_messages "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"AND position > :pos"
),
{"conv_id": conversation_id, "pos": keep_up_to},
)
return result.rowcount
def set_feedback(
self, conversation_id: str, position: int, feedback: dict | None,
) -> bool:
"""Set or unset feedback on a message.
``feedback`` is a JSONB value, e.g. ``{"text": "thumbs_up",
"timestamp": "..."}`` or ``None`` to unset.
"""
fb_json = json.dumps(feedback) if feedback is not None else None
result = self._conn.execute(
text(
"UPDATE conversation_messages "
"SET feedback = CAST(:fb AS jsonb) "
"WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
),
{"conv_id": conversation_id, "pos": position, "fb": fb_json},
)
return result.rowcount > 0
def message_count(self, conversation_id: str) -> int:
result = self._conn.execute(
text(
"SELECT COUNT(*) FROM conversation_messages "
"WHERE conversation_id = CAST(:conv_id AS uuid)"
),
{"conv_id": conversation_id},
)
return result.scalar() or 0

View File

@@ -1,57 +0,0 @@
"""Repository for the ``feedback`` table.
The ``feedback_collection`` global is declared in ``base.py`` but currently
has no direct call sites in the application code: all feedback writes go
through the ``feedback`` JSONB field on ``conversation_messages`` (the
``queries`` array of the Mongo ``conversations`` collection). The table
exists for a future in which feedback is denormalised into its own rows;
this repository provides the append-only insert and basic reads needed for that.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
class FeedbackRepository:
"""Postgres-backed replacement for Mongo ``feedback_collection``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
conversation_id: str,
user_id: str,
question_index: int,
feedback_text: Optional[str] = None,
) -> dict:
result = self._conn.execute(
text(
"""
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text)
VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text)
RETURNING *
"""
),
{
"conversation_id": conversation_id,
"user_id": user_id,
"question_index": question_index,
"feedback_text": feedback_text,
},
)
return row_to_dict(result.fetchone())
def list_for_conversation(self, conversation_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM feedback WHERE conversation_id = CAST(:cid AS uuid) ORDER BY question_index"
),
{"cid": conversation_id},
)
return [row_to_dict(r) for r in result.fetchall()]

View File

@@ -0,0 +1,128 @@
"""Repository for the ``pending_tool_state`` table.
Mirrors the continuation service's three operations on
``pending_tool_state`` in Mongo:
- save_state → upsert (INSERT ... ON CONFLICT DO UPDATE)
- load_state → find_one by (conversation_id, user_id)
- delete_state → delete_one by (conversation_id, user_id)
Plus a cleanup method for the Celery beat task that replaces Mongo's
TTL index.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
PENDING_STATE_TTL_SECONDS = 30 * 60  # 30 minutes
class PendingToolStateRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def save_state(
self,
conversation_id: str,
user_id: str,
*,
messages: list,
pending_tool_calls: list,
tools_dict: dict,
tool_schemas: list,
agent_config: dict,
client_tools: list | None = None,
ttl_seconds: int = PENDING_STATE_TTL_SECONDS,
) -> dict:
"""Upsert pending tool state.
Mirrors Mongo's ``replace_one(..., upsert=True)``.
"""
now = datetime.now(timezone.utc)
expires = datetime.fromtimestamp(
now.timestamp() + ttl_seconds, tz=timezone.utc,
)
result = self._conn.execute(
text(
"""
INSERT INTO pending_tool_state
(conversation_id, user_id, messages, pending_tool_calls,
tools_dict, tool_schemas, agent_config, client_tools,
created_at, expires_at)
VALUES
(CAST(:conv_id AS uuid), :user_id,
CAST(:messages AS jsonb), CAST(:pending AS jsonb),
CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
:created_at, :expires_at)
ON CONFLICT (conversation_id, user_id) DO UPDATE SET
messages = EXCLUDED.messages,
pending_tool_calls = EXCLUDED.pending_tool_calls,
tools_dict = EXCLUDED.tools_dict,
tool_schemas = EXCLUDED.tool_schemas,
agent_config = EXCLUDED.agent_config,
client_tools = EXCLUDED.client_tools,
created_at = EXCLUDED.created_at,
expires_at = EXCLUDED.expires_at
RETURNING *
"""
),
{
"conv_id": conversation_id,
"user_id": user_id,
"messages": json.dumps(messages),
"pending": json.dumps(pending_tool_calls),
"tools_dict": json.dumps(tools_dict),
"schemas": json.dumps(tool_schemas),
"agent_config": json.dumps(agent_config),
"client_tools": json.dumps(client_tools) if client_tools is not None else None,
"created_at": now,
"expires_at": expires,
},
)
return row_to_dict(result.fetchone())
def load_state(self, conversation_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM pending_tool_state "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"AND user_id = :user_id"
),
{"conv_id": conversation_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def delete_state(self, conversation_id: str, user_id: str) -> bool:
result = self._conn.execute(
text(
"DELETE FROM pending_tool_state "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"AND user_id = :user_id"
),
{"conv_id": conversation_id, "user_id": user_id},
)
return result.rowcount > 0
def cleanup_expired(self) -> int:
"""Delete rows where ``expires_at < now()``.
Replaces Mongo's ``expireAfterSeconds=0`` TTL index. Intended to
be called from a Celery beat task every 60 seconds.
"""
# clock_timestamp() — not now() — since the latter is frozen to the
# start of the transaction, which would let state that has just
# expired survive one more cleanup tick.
result = self._conn.execute(
text("DELETE FROM pending_tool_state WHERE expires_at < clock_timestamp()")
)
return result.rowcount
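A sketch of the beat task the docstring anticipates; the task name, the 60-second schedule wiring, and the ``get_engine()`` helper are assumptions for illustration, not code from this PR:
from celery import shared_task
@shared_task(name="cleanup_expired_pending_tool_state")  # hypothetical task name
def cleanup_expired_pending_tool_state() -> int:
    # get_engine() stands in for however the application builds its SQLAlchemy
    # engine; schedule this every 60 seconds from celery beat.
    with get_engine().begin() as conn:
        return PendingToolStateRepository(conn).cleanup_expired()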

View File

@@ -27,16 +27,27 @@ class PromptsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, user_id: str, name: str, content: str) -> dict:
def create(
self,
user_id: str,
name: str,
content: str,
*,
legacy_mongo_id: str | None = None,
) -> dict:
sql = """
INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
VALUES (:user_id, :name, :content, :legacy_mongo_id)
RETURNING *
"""
result = self._conn.execute(
text(
"""
INSERT INTO prompts (user_id, name, content)
VALUES (:user_id, :name, :content)
RETURNING *
"""
),
{"user_id": user_id, "name": name, "content": content},
text(sql),
{
"user_id": user_id,
"name": name,
"content": content,
"legacy_mongo_id": legacy_mongo_id,
},
)
return row_to_dict(result.fetchone())
@@ -48,6 +59,17 @@ class PromptsRepository:
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
"""Fetch a prompt by the original Mongo ObjectId string."""
sql = "SELECT * FROM prompts WHERE legacy_mongo_id = :legacy_id"
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
if user_id is not None:
sql += " AND user_id = :user_id"
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
"""Fetch prompt content by ID without user scoping.
@@ -80,12 +102,48 @@ class PromptsRepository:
{"id": prompt_id, "user_id": user_id, "name": name, "content": content},
)
def update_by_legacy_id(
self,
legacy_mongo_id: str,
user_id: str,
name: str,
content: str,
) -> bool:
"""Update a prompt addressed by the Mongo ObjectId string."""
result = self._conn.execute(
text(
"""
UPDATE prompts
SET name = :name, content = :content, updated_at = now()
WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id
"""
),
{
"legacy_id": legacy_mongo_id,
"user_id": user_id,
"name": name,
"content": content,
},
)
return result.rowcount > 0
def delete(self, prompt_id: str, user_id: str) -> None:
self._conn.execute(
text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
{"id": prompt_id, "user_id": user_id},
)
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
"""Delete a prompt addressed by the Mongo ObjectId string."""
result = self._conn.execute(
text(
"DELETE FROM prompts "
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
),
{"legacy_id": legacy_mongo_id, "user_id": user_id},
)
return result.rowcount > 0
def find_or_create(self, user_id: str, name: str, content: str) -> dict:
"""Return existing prompt matching (user, name, content), or create one.

View File

@@ -0,0 +1,205 @@
"""Repository for the ``shared_conversations`` table.
Covers the sharing operations from ``shared_conversations_collections``
in Mongo:
- create a share record (with UUID, conversation_id, user, visibility flags)
- look up by uuid (public access)
- look up by conversation_id + user + flags (dedup check)
"""
from __future__ import annotations
import uuid as uuid_mod
from typing import Optional
from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import shared_conversations_table
class SharedConversationsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
conversation_id: str,
user_id: str,
*,
is_promptable: bool = False,
first_n_queries: int = 0,
api_key: str | None = None,
prompt_id: str | None = None,
chunks: int | None = None,
share_uuid: str | None = None,
) -> dict:
"""Create a share record.
``share_uuid`` allows the dual-write caller to supply the same
UUID that Mongo received, so public ``/shared/{uuid}`` links
keep resolving from both stores during the dual-write window.
Callers that need race-free dedup on the logical share key
should use :meth:`get_or_create` instead — it relies on the
composite partial unique index added in migration 0008 to
collapse concurrent requests to a single row.
"""
final_uuid = share_uuid or str(uuid_mod.uuid4())
values: dict = {
"uuid": final_uuid,
"conversation_id": conversation_id,
"user_id": user_id,
"is_promptable": is_promptable,
"first_n_queries": first_n_queries,
}
if api_key:
values["api_key"] = api_key
if prompt_id:
values["prompt_id"] = prompt_id
if chunks is not None:
values["chunks"] = chunks
stmt = (
pg_insert(shared_conversations_table)
.values(**values)
.returning(shared_conversations_table)
)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get_or_create(
self,
conversation_id: str,
user_id: str,
*,
is_promptable: bool = False,
first_n_queries: int = 0,
api_key: str | None = None,
prompt_id: str | None = None,
chunks: int | None = None,
share_uuid: str | None = None,
) -> dict:
"""Race-free share create/lookup keyed on the logical dedup tuple.
Leverages the partial unique index on
``(conversation_id, user_id, is_promptable, first_n_queries,
COALESCE(api_key, ''))`` added in migration 0008. Concurrent
requests for the same logical share converge on one row. The
returned dict's ``uuid`` is the canonical public identifier.
Dedup key rationale — ``prompt_id`` and ``chunks`` are
deliberately *not* part of the uniqueness key. A share row is
identified by "who shared what conversation under which
visibility rules"; ``prompt_id`` / ``chunks`` are mutable
properties of that share and are last-write-wins on re-share.
This preserves existing public ``/shared/{uuid}`` URLs when a
user updates the prompt or chunk count, matching the Mongo
``find_one`` + ``update`` semantics.
"""
final_uuid = share_uuid or str(uuid_mod.uuid4())
result = self._conn.execute(
text(
"""
INSERT INTO shared_conversations
(uuid, conversation_id, user_id, is_promptable,
first_n_queries, api_key, prompt_id, chunks)
VALUES
(CAST(:uuid AS uuid), CAST(:conversation_id AS uuid),
:user_id, :is_promptable, :first_n_queries,
:api_key, CAST(:prompt_id AS uuid), :chunks)
ON CONFLICT (conversation_id, user_id, is_promptable,
first_n_queries, COALESCE(api_key, ''))
DO UPDATE SET prompt_id = EXCLUDED.prompt_id,
chunks = EXCLUDED.chunks
RETURNING *
"""
),
{
"uuid": final_uuid,
"conversation_id": conversation_id,
"user_id": user_id,
"is_promptable": is_promptable,
"first_n_queries": first_n_queries,
"api_key": api_key,
"prompt_id": prompt_id,
"chunks": chunks,
},
)
return row_to_dict(result.fetchone())
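A hedged usage sketch of the convergence behaviour described above (``conn``, ``conversation_id`` and ``new_prompt_id`` are placeholders):
repo = SharedConversationsRepository(conn)
first = repo.get_or_create(conversation_id, "local", is_promptable=True, first_n_queries=2)
# Re-sharing the same conversation under the same visibility rules keeps the
# public uuid stable and only updates the mutable fields (last-write-wins).
second = repo.get_or_create(
    conversation_id, "local",
    is_promptable=True, first_n_queries=2, prompt_id=new_prompt_id,
)
assert first["uuid"] == second["uuid"]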
def find_by_uuid(self, share_uuid: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM shared_conversations "
"WHERE uuid = CAST(:uuid AS uuid)"
),
{"uuid": share_uuid},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def find_existing(
self,
conversation_id: str,
user_id: str,
is_promptable: bool,
first_n_queries: int,
api_key: str | None = None,
) -> Optional[dict]:
"""Check for an existing share with matching parameters.
Mirrors the Mongo ``find_one`` dedup check before creating a share.
"""
if api_key:
result = self._conn.execute(
text(
"SELECT * FROM shared_conversations "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"AND user_id = :user_id "
"AND is_promptable = :is_promptable "
"AND first_n_queries = :fnq "
"AND api_key = :api_key "
"LIMIT 1"
),
{
"conv_id": conversation_id,
"user_id": user_id,
"is_promptable": is_promptable,
"fnq": first_n_queries,
"api_key": api_key,
},
)
else:
result = self._conn.execute(
text(
"SELECT * FROM shared_conversations "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"AND user_id = :user_id "
"AND is_promptable = :is_promptable "
"AND first_n_queries = :fnq "
"AND api_key IS NULL "
"LIMIT 1"
),
{
"conv_id": conversation_id,
"user_id": user_id,
"is_promptable": is_promptable,
"fnq": first_n_queries,
},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_conversation(self, conversation_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM shared_conversations "
"WHERE conversation_id = CAST(:conv_id AS uuid) "
"ORDER BY created_at DESC"
),
{"conv_id": conversation_id},
)
return [row_to_dict(r) for r in result.fetchall()]

View File

@@ -15,7 +15,7 @@ class SourcesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(self, name: str, *, user_id: Optional[str] = None,
def create(self, name: str, *, user_id: str,
type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
result = self._conn.execute(
text(
@@ -55,12 +55,12 @@ class SourcesRepository:
if not filtered:
return
values: dict = {}
for col, val in filtered.items():
if col == "metadata":
values[col] = json.dumps(val) if isinstance(val, dict) else val
else:
values[col] = val
# Pass Python objects directly for JSONB columns when using
# SQLAlchemy Core .update(): the JSONB type processor runs
# json.dumps on them itself; pre-serialising here would
# double-encode and the value would round-trip as a JSON
# string instead of the original dict.
values: dict = dict(filtered)
values["updated_at"] = func.now()
t = sources_table

View File

@@ -46,7 +46,7 @@ class UserLogsRepository:
{
"user_id": user_id,
"endpoint": endpoint,
"data": json.dumps(data) if data is not None else None,
"data": json.dumps(data, default=str) if data is not None else None,
"timestamp": timestamp,
},
)

View File

@@ -0,0 +1,170 @@
"""Repository for the ``workflow_edges`` table.
Covers bulk insert, find by version, and delete operations that the
workflow routes perform on ``workflow_edges_collection`` in Mongo.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import workflow_edges_table
class WorkflowEdgesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
workflow_id: str,
graph_version: int,
edge_id: str,
from_node_id: str,
to_node_id: str,
*,
source_handle: str | None = None,
target_handle: str | None = None,
config: dict | None = None,
) -> dict:
"""Create a single edge.
``from_node_id`` and ``to_node_id`` are the Postgres **UUID PKs**
of the workflow_nodes rows (not user-provided node_id strings).
"""
values: dict = {
"workflow_id": workflow_id,
"graph_version": graph_version,
"edge_id": edge_id,
"from_node_id": from_node_id,
"to_node_id": to_node_id,
}
if source_handle is not None:
values["source_handle"] = source_handle
if target_handle is not None:
values["target_handle"] = target_handle
if config is not None:
values["config"] = config
stmt = pg_insert(workflow_edges_table).values(**values).returning(workflow_edges_table)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def bulk_create(
self,
workflow_id: str,
graph_version: int,
edges: list[dict],
) -> list[dict]:
"""Insert multiple edges in one statement.
Each element must have ``edge_id``, ``from_node_id`` (UUID PK),
``to_node_id`` (UUID PK). Optional: ``source_handle``,
``target_handle``, ``config``.
"""
if not edges:
return []
rows = []
for e in edges:
rows.append({
"workflow_id": workflow_id,
"graph_version": graph_version,
"edge_id": e["edge_id"],
"from_node_id": e["from_node_id"],
"to_node_id": e["to_node_id"],
"source_handle": e.get("source_handle"),
"target_handle": e.get("target_handle"),
"config": e.get("config", {}),
})
stmt = pg_insert(workflow_edges_table).values(rows).returning(workflow_edges_table)
result = self._conn.execute(stmt)
return [row_to_dict(r) for r in result.fetchall()]
def find_by_version(
self, workflow_id: str, graph_version: int,
) -> list[dict]:
"""List edges for a workflow/version, shaped to match the live API.
Joins ``workflow_nodes`` twice so callers receive the user-provided
node-id strings (``source_id``/``target_id``) that the Mongo code
and the frontend use, not the internal node UUIDs. The raw UUID
columns (``from_node_id``/``to_node_id``) are still included in
case a caller needs them.
"""
result = self._conn.execute(
text(
"""
SELECT e.*,
fn.node_id AS source_id,
tn.node_id AS target_id
FROM workflow_edges e
JOIN workflow_nodes fn ON fn.id = e.from_node_id
JOIN workflow_nodes tn ON tn.id = e.to_node_id
WHERE e.workflow_id = CAST(:wf_id AS uuid)
AND e.graph_version = :ver
ORDER BY e.edge_id
"""
),
{"wf_id": workflow_id, "ver": graph_version},
)
return [row_to_dict(r) for r in result.fetchall()]
def resolve_node_id(
self, workflow_id: str, graph_version: int, node_id: str,
) -> Optional[str]:
"""Look up the UUID PK of a node by its user-provided ``node_id``.
Callers that receive edges in the frontend shape (``source_id`` /
``target_id`` are user-provided strings) use this helper to
translate to the UUID PK before calling :meth:`create` /
:meth:`bulk_create`.
"""
result = self._conn.execute(
text(
"SELECT id FROM workflow_nodes "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version = :ver AND node_id = :node_id"
),
{"wf_id": workflow_id, "ver": graph_version, "node_id": node_id},
)
row = result.fetchone()
return str(row[0]) if row else None
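A sketch of the translation step the docstring describes, turning frontend-shaped edges (user-provided node-id strings) into the UUID PKs that ``bulk_create`` expects; the ``frontend_edges`` input shape is an assumption:
def edges_from_frontend(repo: WorkflowEdgesRepository, workflow_id: str,
                        version: int, frontend_edges: list[dict]) -> list[dict]:
    resolved = []
    for e in frontend_edges:  # e.g. {"edge_id": "e1", "source_id": "start", "target_id": "llm-1"}
        from_pk = repo.resolve_node_id(workflow_id, version, e["source_id"])
        to_pk = repo.resolve_node_id(workflow_id, version, e["target_id"])
        if from_pk is None or to_pk is None:
            continue  # drop edges that reference unknown nodes
        resolved.append({"edge_id": e["edge_id"], "from_node_id": from_pk, "to_node_id": to_pk})
    return repo.bulk_create(workflow_id, version, resolved)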
def delete_by_workflow(self, workflow_id: str) -> int:
result = self._conn.execute(
text(
"DELETE FROM workflow_edges "
"WHERE workflow_id = CAST(:wf_id AS uuid)"
),
{"wf_id": workflow_id},
)
return result.rowcount
def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
result = self._conn.execute(
text(
"DELETE FROM workflow_edges "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version = :ver"
),
{"wf_id": workflow_id, "ver": graph_version},
)
return result.rowcount
def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
"""Delete all edges for a workflow except the specified version."""
result = self._conn.execute(
text(
"DELETE FROM workflow_edges "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version != :ver"
),
{"wf_id": workflow_id, "ver": keep_version},
)
return result.rowcount

View File

@@ -0,0 +1,158 @@
"""Repository for the ``workflow_nodes`` table.
Covers bulk insert, find by version, and delete operations that the
workflow routes perform on ``workflow_nodes_collection`` in Mongo.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import workflow_nodes_table
class WorkflowNodesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
workflow_id: str,
graph_version: int,
node_id: str,
node_type: str,
*,
title: str | None = None,
description: str | None = None,
position: dict | None = None,
config: dict | None = None,
legacy_mongo_id: str | None = None,
) -> dict:
values: dict = {
"workflow_id": workflow_id,
"graph_version": graph_version,
"node_id": node_id,
"node_type": node_type,
}
if title is not None:
values["title"] = title
if description is not None:
values["description"] = description
if position is not None:
values["position"] = position
if config is not None:
values["config"] = config
if legacy_mongo_id is not None:
values["legacy_mongo_id"] = legacy_mongo_id
stmt = pg_insert(workflow_nodes_table).values(**values).returning(workflow_nodes_table)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def bulk_create(
self,
workflow_id: str,
graph_version: int,
nodes: list[dict],
) -> list[dict]:
"""Insert multiple nodes in one statement.
Each element of ``nodes`` should have at least ``node_id`` and
``node_type``; optional keys: ``title``, ``description``,
``position``, ``config``.
"""
if not nodes:
return []
rows = []
for n in nodes:
rows.append({
"workflow_id": workflow_id,
"graph_version": graph_version,
"node_id": n["node_id"],
"node_type": n["node_type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("config", {}),
"legacy_mongo_id": n.get("legacy_mongo_id"),
})
stmt = pg_insert(workflow_nodes_table).values(rows).returning(workflow_nodes_table)
result = self._conn.execute(stmt)
return [row_to_dict(r) for r in result.fetchall()]
def find_by_version(
self, workflow_id: str, graph_version: int,
) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM workflow_nodes "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version = :ver "
"ORDER BY node_id"
),
{"wf_id": workflow_id, "ver": graph_version},
)
return [row_to_dict(r) for r in result.fetchall()]
def find_node(
self, workflow_id: str, graph_version: int, node_id: str,
) -> Optional[dict]:
"""Find a single node by its user-provided ``node_id``."""
result = self._conn.execute(
text(
"SELECT * FROM workflow_nodes "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version = :ver AND node_id = :nid"
),
{"wf_id": workflow_id, "ver": graph_version, "nid": node_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
"""Find a node by the original Mongo ObjectId string."""
result = self._conn.execute(
text("SELECT * FROM workflow_nodes WHERE legacy_mongo_id = :legacy_id"),
{"legacy_id": legacy_mongo_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def delete_by_workflow(self, workflow_id: str) -> int:
result = self._conn.execute(
text(
"DELETE FROM workflow_nodes "
"WHERE workflow_id = CAST(:wf_id AS uuid)"
),
{"wf_id": workflow_id},
)
return result.rowcount
def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
result = self._conn.execute(
text(
"DELETE FROM workflow_nodes "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version = :ver"
),
{"wf_id": workflow_id, "ver": graph_version},
)
return result.rowcount
def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
"""Delete all nodes for a workflow except the specified version."""
result = self._conn.execute(
text(
"DELETE FROM workflow_nodes "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"AND graph_version != :ver"
),
{"wf_id": workflow_id, "ver": keep_version},
)
return result.rowcount

View File

@@ -0,0 +1,83 @@
"""Repository for the ``workflow_runs`` table.
In Mongo, workflow_runs_collection only has ``insert_one`` — runs are
written once after workflow execution completes and never updated.
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import workflow_runs_table
class WorkflowRunsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
workflow_id: str,
user_id: str,
status: str,
*,
inputs: dict | None = None,
result: dict | None = None,
steps: list | None = None,
started_at=None,
ended_at=None,
legacy_mongo_id: str | None = None,
) -> dict:
values: dict = {
"workflow_id": workflow_id,
"user_id": user_id,
"status": status,
}
if inputs is not None:
values["inputs"] = inputs
if result is not None:
values["result"] = result
if steps is not None:
values["steps"] = steps
if started_at is not None:
values["started_at"] = started_at
if ended_at is not None:
values["ended_at"] = ended_at
if legacy_mongo_id is not None:
values["legacy_mongo_id"] = legacy_mongo_id
stmt = pg_insert(workflow_runs_table).values(**values).returning(workflow_runs_table)
res = self._conn.execute(stmt)
return row_to_dict(res.fetchone())
def get(self, run_id: str) -> Optional[dict]:
res = self._conn.execute(
text("SELECT * FROM workflow_runs WHERE id = CAST(:id AS uuid)"),
{"id": run_id},
)
row = res.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
"""Fetch a workflow run by the original Mongo ObjectId string."""
res = self._conn.execute(
text("SELECT * FROM workflow_runs WHERE legacy_mongo_id = :legacy_id"),
{"legacy_id": legacy_mongo_id},
)
row = res.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_workflow(self, workflow_id: str) -> list[dict]:
res = self._conn.execute(
text(
"SELECT * FROM workflow_runs "
"WHERE workflow_id = CAST(:wf_id AS uuid) "
"ORDER BY started_at DESC"
),
{"wf_id": workflow_id},
)
return [row_to_dict(r) for r in res.fetchall()]

View File

@@ -0,0 +1,125 @@
"""Repository for the ``workflows`` table.
Covers CRUD on workflow metadata:
- create / get / list / update / delete
- graph version management
"""
from __future__ import annotations
from typing import Optional
from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import workflows_table
class WorkflowsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
def create(
self,
user_id: str,
name: str,
description: str | None = None,
*,
legacy_mongo_id: str | None = None,
) -> dict:
values: dict = {"user_id": user_id, "name": name}
if description is not None:
values["description"] = description
if legacy_mongo_id is not None:
values["legacy_mongo_id"] = legacy_mongo_id
stmt = pg_insert(workflows_table).values(**values).returning(workflows_table)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get(self, workflow_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM workflows "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": workflow_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_id(self, workflow_id: str) -> Optional[dict]:
"""Fetch a workflow by ID without user check (for internal use)."""
result = self._conn.execute(
text("SELECT * FROM workflows WHERE id = CAST(:id AS uuid)"),
{"id": workflow_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def get_by_legacy_id(
self, legacy_mongo_id: str, user_id: str | None = None,
) -> Optional[dict]:
"""Fetch a workflow by its original Mongo ObjectId string."""
sql = "SELECT * FROM workflows WHERE legacy_mongo_id = :legacy_id"
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
if user_id is not None:
sql += " AND user_id = :user_id"
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM workflows "
"WHERE user_id = :user_id ORDER BY created_at DESC"
),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, workflow_id: str, user_id: str, fields: dict) -> bool:
allowed = {"name", "description", "current_graph_version"}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return False
set_parts = [f"{col} = :{col}" for col in filtered]
set_parts.append("updated_at = now()")
params = {**filtered, "id": workflow_id, "user_id": user_id}
sql = (
f"UPDATE workflows SET {', '.join(set_parts)} "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
)
result = self._conn.execute(text(sql), params)
return result.rowcount > 0
def increment_graph_version(self, workflow_id: str, user_id: str) -> Optional[int]:
"""Atomically increment ``current_graph_version`` and return the new value."""
result = self._conn.execute(
text(
"UPDATE workflows "
"SET current_graph_version = current_graph_version + 1, "
" updated_at = now() "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
"RETURNING current_graph_version"
),
{"id": workflow_id, "user_id": user_id},
)
row = result.fetchone()
return row[0] if row else None
def delete(self, workflow_id: str, user_id: str) -> bool:
result = self._conn.execute(
text(
"DELETE FROM workflows "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": workflow_id, "user_id": user_id},
)
return result.rowcount > 0

View File

@@ -1178,8 +1178,8 @@ def attachment_worker(self, file_info, user):
dual_write(
AttachmentsRepository,
lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type: repo.create(
u, fn, p, mime_type=mt,
lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type, mid=attachment_id: repo.create(
u, fn, p, mime_type=mt, legacy_mongo_id=mid,
),
)
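With the Mongo ``_id`` recorded as ``legacy_mongo_id``, later reads can resolve the same attachment from Postgres; a sketch assuming ``engine`` is the application's SQLAlchemy engine and ``attachment_id`` is the Mongo id the worker just inserted:
with engine.begin() as conn:
    pg_row = AttachmentsRepository(conn).get_by_legacy_id(str(attachment_id), user_id=user)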

File diff suppressed because it is too large

View File

@@ -29,8 +29,8 @@
"clsx": "^2.1.1",
"cmdk": "^1.1.1",
"copy-to-clipboard": "^3.3.3",
"i18next": "^25.8.18",
"i18next-browser-languagedetector": "^8.2.0",
"i18next": "^26.0.4",
"i18next-browser-languagedetector": "^8.2.1",
"lodash": "^4.17.21",
"lucide-react": "^0.562.0",
"mermaid": "^11.12.1",
@@ -41,7 +41,7 @@
"react-dom": "^19.1.1",
"react-dropzone": "^14.3.8",
"react-google-drive-picker": "^1.2.2",
"react-i18next": "^16.2.4",
"react-i18next": "^17.0.2",
"react-markdown": "^9.0.1",
"react-redux": "^9.2.0",
"react-router-dom": "^7.6.1",
@@ -58,7 +58,7 @@
"@types/react": "^19.1.8",
"@types/react-dom": "^19.1.7",
"@types/react-syntax-highlighter": "^15.5.13",
"@typescript-eslint/eslint-plugin": "^8.46.3",
"@typescript-eslint/eslint-plugin": "^8.58.2",
"@typescript-eslint/parser": "^8.46.3",
"@vitejs/plugin-react": "^6.0.1",
"eslint": "^9.39.1",
@@ -73,7 +73,7 @@
"lint-staged": "^16.4.0",
"postcss": "^8.4.49",
"prettier": "^3.5.3",
"prettier-plugin-tailwindcss": "^0.7.1",
"prettier-plugin-tailwindcss": "^0.7.2",
"tailwindcss": "^4.2.1",
"tw-animate-css": "^1.4.0",
"typescript": "^5.8.3",

View File

@@ -32,6 +32,7 @@ import argparse
import json
import logging
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable
@@ -52,6 +53,45 @@ logger = logging.getLogger("backfill")
# ---------------------------------------------------------------------------
def _extract_mongo_id_text(value: Any) -> str | None:
"""Return a Mongo ObjectId-like value as text across legacy shapes.
Handles raw ObjectId values, DBRef-like objects exposing ``.id``, and
dict encodings such as ``{"$id": {"$oid": "..."}}`` that show up in
imported / normalised BSON payloads.
"""
if value is None:
return None
if isinstance(value, dict):
if "$id" in value:
return _extract_mongo_id_text(value["$id"])
if "_id" in value:
return _extract_mongo_id_text(value["_id"])
if "$oid" in value:
return str(value["$oid"])
return None
ref_id = getattr(value, "id", None)
if ref_id is not None:
return _extract_mongo_id_text(ref_id)
return str(value)
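A few concrete shapes this helper normalises, traced from the branches above (the hex string is illustrative):
_extract_mongo_id_text({"$id": {"$oid": "65f0c0ffee0ddba11ca7e57a"}})  # -> "65f0c0ffee0ddba11ca7e57a"
_extract_mongo_id_text({"$oid": "65f0c0ffee0ddba11ca7e57a"})           # -> "65f0c0ffee0ddba11ca7e57a"
_extract_mongo_id_text("65f0c0ffee0ddba11ca7e57a")                     # -> unchanged, via str()
_extract_mongo_id_text(None)                                           # -> None
# Raw ObjectId and DBRef-like values fall through to the .id / str() branches.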
def _coerce_document_timestamp(doc: dict[str, Any], *keys: str):
"""Return the first populated timestamp-like field from ``doc``.
Mongo user data is not fully uniform across older deployments. Some
records only carry ``created_at`` / ``updated_at`` and a few legacy
documents have no explicit timestamp at all. In that final case we
fall back to "now" so the backfill can preserve the row instead of
failing a NOT NULL constraint.
"""
for key in keys:
value = doc.get(key)
if value is not None:
return value
return datetime.now(timezone.utc)
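Traced from the loop above (the timestamp values are illustrative):
doc = {"date": None, "created_at": "2024-11-02T10:00:00Z"}
_coerce_document_timestamp(doc, "date", "created_at", "updated_at")  # -> the created_at value
_coerce_document_timestamp({}, "date", "created_at")                 # -> datetime.now(timezone.utc)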
def _backfill_users(
*,
conn: Connection,
@@ -123,9 +163,13 @@ def _backfill_prompts(
) -> dict:
upsert_sql = text(
"""
INSERT INTO prompts (user_id, name, content)
VALUES (:user_id, :name, :content)
ON CONFLICT DO NOTHING
INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
VALUES (:user_id, :name, :content, :legacy_mongo_id)
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
name = EXCLUDED.name,
content = EXCLUDED.content,
updated_at = now()
"""
)
cursor = mongo_db["prompts"].find({}, no_cursor_timeout=True).batch_size(batch_size)
@@ -142,6 +186,7 @@ def _backfill_prompts(
"user_id": user_id,
"name": doc.get("name", ""),
"content": doc.get("content", ""),
"legacy_mongo_id": str(doc["_id"]),
})
if len(batch) >= batch_size:
if not dry_run:
@@ -198,48 +243,6 @@ def _backfill_user_tools(
return {"seen": seen, "written": written, "skipped_no_user": skipped}
def _backfill_feedback(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
insert_sql = text(
"""
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text, timestamp)
VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text, :timestamp)
ON CONFLICT DO NOTHING
"""
)
cursor = mongo_db["feedback"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
user_id = doc.get("user")
conv_id = doc.get("conversation_id")
if not user_id or not conv_id:
skipped += 1
continue
batch.append({
"conversation_id": str(conv_id),
"user_id": user_id,
"question_index": doc.get("question_index", 0),
"feedback_text": doc.get("feedback_text"),
"timestamp": doc.get("timestamp"),
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_stack_logs(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
@@ -468,15 +471,32 @@ def _backfill_agents(
chunks, retriever, default_model_id,
tools, json_schema, models,
limited_token_mode, token_limit, limited_request_mode, request_limit,
shared, incoming_webhook_token
shared, incoming_webhook_token, legacy_mongo_id
) VALUES (
:user_id, :name, :status, :key, :description, :agent_type,
:chunks, :retriever, :default_model_id,
CAST(:tools AS jsonb), CAST(:json_schema AS jsonb), CAST(:models AS jsonb),
:limited_token_mode, :token_limit, :limited_request_mode, :request_limit,
:shared, :incoming_webhook_token
:shared, :incoming_webhook_token, :legacy_mongo_id
)
ON CONFLICT DO NOTHING
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
name = EXCLUDED.name,
status = EXCLUDED.status,
description = EXCLUDED.description,
agent_type = EXCLUDED.agent_type,
chunks = EXCLUDED.chunks,
retriever = EXCLUDED.retriever,
default_model_id = EXCLUDED.default_model_id,
tools = EXCLUDED.tools,
json_schema = EXCLUDED.json_schema,
models = EXCLUDED.models,
limited_token_mode = EXCLUDED.limited_token_mode,
token_limit = EXCLUDED.token_limit,
limited_request_mode = EXCLUDED.limited_request_mode,
request_limit = EXCLUDED.request_limit,
shared = EXCLUDED.shared,
updated_at = now()
"""
)
cursor = mongo_db["agents"].find({}, no_cursor_timeout=True).batch_size(batch_size)
@@ -493,7 +513,11 @@ def _backfill_agents(
"user_id": user_id,
"name": doc.get("name", ""),
"status": doc.get("status", "draft"),
"key": doc.get("key"),
# Mongo allows multiple agents with key="" but Postgres
# CITEXT UNIQUE treats them as a collision. Coerce empty
# strings to NULL so the unique constraint is only
# enforced for actual API keys.
"key": (doc.get("key") or None),
"description": doc.get("description"),
"agent_type": doc.get("agent_type"),
"chunks": doc.get("chunks"),
@@ -508,6 +532,7 @@ def _backfill_agents(
"request_limit": doc.get("request_limit"),
"shared": bool(doc.get("shared", False)),
"incoming_webhook_token": doc.get("incoming_webhook_token"),
"legacy_mongo_id": str(doc["_id"]),
})
if len(batch) >= batch_size:
if not dry_run:
@@ -528,8 +553,16 @@ def _backfill_attachments(
) -> dict:
insert_sql = text(
"""
INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
INSERT INTO attachments (user_id, filename, upload_path, mime_type, size,
legacy_mongo_id)
VALUES (:user_id, :filename, :upload_path, :mime_type, :size,
:legacy_mongo_id)
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
filename = EXCLUDED.filename,
upload_path = EXCLUDED.upload_path,
mime_type = EXCLUDED.mime_type,
size = EXCLUDED.size
"""
)
cursor = mongo_db["attachments"].find({}, no_cursor_timeout=True).batch_size(batch_size)
@@ -548,6 +581,7 @@ def _backfill_attachments(
"upload_path": doc.get("upload_path", ""),
"mime_type": doc.get("mime_type"),
"size": doc.get("size"),
"legacy_mongo_id": str(doc["_id"]),
})
if len(batch) >= batch_size:
if not dry_run:
@@ -780,6 +814,696 @@ def _backfill_connector_sessions(
return {"seen": seen, "written": written, "skipped": skipped}
# ---------------------------------------------------------------------------
# Phase 3 backfillers
# ---------------------------------------------------------------------------
def _backfill_conversations(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``conversations`` table from Mongo ``conversations`` collection.
Also flattens the nested ``queries`` array into
``conversation_messages`` rows (one per query, position = array index).
Idempotent via the ``legacy_mongo_id`` column: rerunning replaces any
previously migrated row's mutable fields and re-syncs its messages.
"""
agent_id_map = _build_legacy_id_map(conn, "agents")
attachment_id_map = _build_legacy_id_map(conn, "attachments")
conv_sql = text(
"""
INSERT INTO conversations
(user_id, name, agent_id, api_key, is_shared_usage, shared_token,
shared_with, compression_metadata, date, legacy_mongo_id)
VALUES
(:user_id, :name, CAST(:agent_id AS uuid), :api_key,
:is_shared_usage, :shared_token,
CAST(:shared_with AS text[]), CAST(:compression_metadata AS jsonb),
:date, :legacy_mongo_id)
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
name = EXCLUDED.name,
agent_id = EXCLUDED.agent_id,
api_key = EXCLUDED.api_key,
is_shared_usage = EXCLUDED.is_shared_usage,
shared_token = EXCLUDED.shared_token,
shared_with = EXCLUDED.shared_with,
compression_metadata = EXCLUDED.compression_metadata,
updated_at = now()
RETURNING id
"""
)
truncate_sql = text(
"""
DELETE FROM conversation_messages
WHERE conversation_id = CAST(:conv_id AS uuid)
AND position > :max_pos
"""
)
msg_sql = text(
"""
INSERT INTO conversation_messages
(conversation_id, position, prompt, response, thought,
sources, tool_calls, attachments, model_id, message_metadata, feedback,
timestamp)
VALUES
(CAST(:conv_id AS uuid), :position, :prompt, :response, :thought,
CAST(:sources AS jsonb), CAST(:tool_calls AS jsonb),
CAST(:attachments AS uuid[]),
:model_id, CAST(:metadata AS jsonb), CAST(:feedback AS jsonb),
:timestamp)
ON CONFLICT (conversation_id, position) DO UPDATE SET
prompt = EXCLUDED.prompt,
response = EXCLUDED.response,
thought = EXCLUDED.thought,
sources = EXCLUDED.sources,
tool_calls = EXCLUDED.tool_calls,
attachments = EXCLUDED.attachments,
model_id = EXCLUDED.model_id,
message_metadata = EXCLUDED.message_metadata,
feedback = EXCLUDED.feedback,
timestamp = EXCLUDED.timestamp
"""
)
cursor = mongo_db["conversations"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = msg_written = skipped = 0
malformed_messages = 0
unresolved_attachment_refs = 0
try:
for doc in cursor:
seen += 1
user_id = doc.get("user")
if not user_id:
skipped += 1
continue
shared_with = doc.get("shared_with") or []
comp_meta = doc.get("compression_metadata")
if dry_run:
# In dry-run we don't write, so we can't get a returning id.
# Skip message insertion too — they need the FK.
continue
mongo_agent_id = doc.get("agent_id")
pg_agent_id = agent_id_map.get(str(mongo_agent_id)) if mongo_agent_id else None
result = conn.execute(conv_sql, {
"user_id": user_id,
"name": doc.get("name"),
"agent_id": pg_agent_id,
"api_key": doc.get("api_key"),
"is_shared_usage": bool(doc.get("is_shared_usage", False)),
"shared_token": doc.get("shared_token"),
"shared_with": list(shared_with),
"compression_metadata": json.dumps(comp_meta) if comp_meta else None,
"date": _coerce_document_timestamp(doc, "date", "created_at", "updated_at"),
"legacy_mongo_id": str(doc["_id"]),
})
pg_conv_id = str(result.scalar())
written += 1
# Flatten queries array → conversation_messages rows
queries = doc.get("queries") or []
msg_batch: list[dict] = []
for pos, q in enumerate(queries):
if not isinstance(q, dict):
malformed_messages += 1
logger.warning(
"Skipping malformed conversation query during backfill: "
"conversation=%s position=%s type=%s",
doc.get("_id"),
pos,
type(q).__name__,
)
continue
fb = q.get("feedback")
fb_ts = q.get("feedback_timestamp")
feedback_json = None
if fb is not None:
feedback_json = json.dumps({"text": fb, "timestamp": str(fb_ts)} if fb_ts else {"text": fb})
# Resolve attachment ObjectIds → Postgres UUIDs; drop unresolved.
raw_attachments = q.get("attachments") or []
resolved_attachments: list[str] = []
for a in raw_attachments:
if not a:
continue
s = str(a)
if len(s) == 36 and "-" in s:
resolved_attachments.append(s)
else:
mapped = attachment_id_map.get(s)
if mapped:
resolved_attachments.append(mapped)
else:
unresolved_attachment_refs += 1
logger.warning(
"Conversation backfill dropped unresolved attachment ref: "
"conversation=%s position=%s attachment=%s",
doc.get("_id"),
pos,
s,
)
msg_batch.append({
"conv_id": pg_conv_id,
"position": pos,
"prompt": q.get("prompt"),
"response": q.get("response"),
"thought": q.get("thought"),
"sources": json.dumps(q.get("sources") or []),
"tool_calls": json.dumps(q.get("tool_calls") or []),
"attachments": resolved_attachments,
"model_id": q.get("model_id"),
"metadata": json.dumps(q.get("metadata") or {}),
"feedback": feedback_json,
"timestamp": (
q.get("timestamp")
or doc.get("date")
or doc.get("created_at")
or doc.get("updated_at")
or datetime.now(timezone.utc)
),
})
if msg_batch and not dry_run:
conn.execute(msg_sql, msg_batch)
msg_written += len(msg_batch)
# Converge: drop any messages past the Mongo queries length
# (handles the case where a conversation was truncated in Mongo
# after a previous backfill).
if not dry_run:
conn.execute(truncate_sql, {
"conv_id": pg_conv_id,
"max_pos": len(queries) - 1,
})
finally:
cursor.close()
return {
"seen": seen,
"written": written,
"messages_written": msg_written,
"skipped": skipped,
"malformed_messages": malformed_messages,
"unresolved_attachment_refs": unresolved_attachment_refs,
}
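For the ``ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL`` clauses above to act as an idempotency key, Postgres must already have a matching partial unique index on the column. A minimal sketch of that prerequisite, assuming the real index is created by the Alembic migration (not shown in this hunk) and using an illustrative name:

# Hypothetical sketch only; the actual index name and migration live elsewhere.
from sqlalchemy import text

def _ensure_legacy_id_index(conn) -> None:
    conn.execute(text(
        "CREATE UNIQUE INDEX IF NOT EXISTS uq_conversations_legacy_mongo_id "
        "ON conversations (legacy_mongo_id) "
        "WHERE legacy_mongo_id IS NOT NULL"
    ))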
def _build_legacy_id_map(conn: Connection, table: str) -> dict[str, str]:
"""Return ``{legacy_mongo_id: pg_uuid}`` for the given table.
Used by Phase 3 backfills to resolve FK references that were Mongo
ObjectIds in the source data into the new Postgres UUIDs.
"""
rows = conn.execute(
text(
f"SELECT id, legacy_mongo_id FROM {table} "
"WHERE legacy_mongo_id IS NOT NULL"
)
).fetchall()
return {r._mapping["legacy_mongo_id"]: str(r._mapping["id"]) for r in rows}
def _backfill_shared_conversations(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``shared_conversations`` table.
Resolves Mongo ``conversation_id`` (ObjectId) → Postgres
``conversations.id`` (UUID) via the ``conversations.legacy_mongo_id``
column populated during the conversations backfill. Rows whose
parent conversation was not migrated are skipped.
"""
conv_id_map = _build_legacy_id_map(conn, "conversations")
prompt_id_map = _build_legacy_id_map(conn, "prompts")
agent_meta_by_key = {
doc.get("key"): {
"prompt_id": doc.get("prompt_id"),
"chunks": doc.get("chunks"),
}
for doc in mongo_db["agents"].find({}, {"key": 1, "prompt_id": 1, "chunks": 1})
if doc.get("key")
}
insert_sql = text(
"""
INSERT INTO shared_conversations
(uuid, conversation_id, user_id, is_promptable, first_n_queries,
api_key, prompt_id, chunks)
VALUES
(CAST(:uuid AS uuid), CAST(:conv_id AS uuid), :user_id,
:is_promptable, :first_n_queries, :api_key,
CAST(:prompt_id AS uuid), :chunks)
ON CONFLICT (uuid) DO UPDATE SET
conversation_id = EXCLUDED.conversation_id,
user_id = EXCLUDED.user_id,
is_promptable = EXCLUDED.is_promptable,
first_n_queries = EXCLUDED.first_n_queries,
api_key = EXCLUDED.api_key,
prompt_id = EXCLUDED.prompt_id,
chunks = EXCLUDED.chunks
"""
)
cursor = (
mongo_db["shared_conversations"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
user_id = doc.get("user")
mongo_conv_id = _extract_mongo_id_text(doc.get("conversation_id"))
mongo_uuid = doc.get("uuid")
if not user_id or not mongo_conv_id or not mongo_uuid:
skipped += 1
continue
pg_conv_id = conv_id_map.get(mongo_conv_id)
if not pg_conv_id:
skipped += 1
continue
# Mongo stores ``uuid`` as BSON Binary (standard UUID subtype).
# Unwrap to a plain uuid.UUID → string for Postgres CAST.
try:
share_uuid_str = str(mongo_uuid.as_uuid()) if hasattr(mongo_uuid, "as_uuid") else str(mongo_uuid)
except Exception:
share_uuid_str = str(mongo_uuid)
# prompt_id may be either a prompt ObjectId or the literal string
# "default" (see sharing/routes.py); only resolvable ObjectIds
# get a real FK value.
agent_meta = agent_meta_by_key.get(doc.get("api_key")) or {}
raw_prompt_id = doc.get("prompt_id")
if raw_prompt_id is None:
raw_prompt_id = agent_meta.get("prompt_id")
prompt_legacy_id = _extract_mongo_id_text(raw_prompt_id)
resolved_prompt_id = (
prompt_id_map.get(prompt_legacy_id) if prompt_legacy_id else None
)
chunks_raw = doc.get("chunks")
if chunks_raw is None:
chunks_raw = agent_meta.get("chunks")
chunks_val: int | None = None
if chunks_raw is not None:
try:
chunks_val = int(chunks_raw)
except (TypeError, ValueError):
chunks_val = None
batch.append({
"uuid": share_uuid_str,
"conv_id": pg_conv_id,
"user_id": user_id,
"is_promptable": bool(doc.get("isPromptable", False)),
"first_n_queries": doc.get("first_n_queries", 0),
"api_key": doc.get("api_key"),
"prompt_id": resolved_prompt_id,
"chunks": chunks_val,
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_pending_tool_state(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync ``pending_tool_state`` from Mongo.
Most rows will be expired by the time the backfill runs (30-min TTL).
We copy them anyway; the Celery cleanup task will purge stale rows on
its first tick. Resolves ``conversation_id`` via
``conversations.legacy_mongo_id``.
"""
conv_id_map = _build_legacy_id_map(conn, "conversations")
insert_sql = text(
"""
INSERT INTO pending_tool_state
(conversation_id, user_id, messages, pending_tool_calls,
tools_dict, tool_schemas, agent_config, client_tools,
created_at, expires_at)
VALUES
(CAST(:conv_id AS uuid), :user_id,
CAST(:messages AS jsonb), CAST(:pending AS jsonb),
CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
:created_at, :expires_at)
ON CONFLICT (conversation_id, user_id) DO UPDATE SET
messages = EXCLUDED.messages,
pending_tool_calls = EXCLUDED.pending_tool_calls,
tools_dict = EXCLUDED.tools_dict,
tool_schemas = EXCLUDED.tool_schemas,
agent_config = EXCLUDED.agent_config,
client_tools = EXCLUDED.client_tools,
created_at = EXCLUDED.created_at,
expires_at = EXCLUDED.expires_at
"""
)
cursor = (
mongo_db["pending_tool_state"]
.find({}, no_cursor_timeout=True)
.batch_size(batch_size)
)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
mongo_conv_id = doc.get("conversation_id")
user_id = doc.get("user")
if not mongo_conv_id or not user_id:
skipped += 1
continue
pg_conv_id = conv_id_map.get(str(mongo_conv_id))
if not pg_conv_id:
skipped += 1
continue
batch.append({
"conv_id": pg_conv_id,
"user_id": user_id,
"messages": json.dumps(doc.get("messages") or [], default=str),
"pending": json.dumps(doc.get("pending_tool_calls") or [], default=str),
"tools_dict": json.dumps(doc.get("tools_dict") or {}, default=str),
"schemas": json.dumps(doc.get("tool_schemas") or [], default=str),
"agent_config": json.dumps(doc.get("agent_config") or {}, default=str),
"client_tools": json.dumps(doc.get("client_tools"), default=str) if doc.get("client_tools") else None,
"created_at": doc.get("created_at"),
"expires_at": doc.get("expires_at"),
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_workflows(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``workflows`` table from Mongo ``workflows`` collection.
Idempotent via ``legacy_mongo_id``.
"""
insert_sql = text(
"""
INSERT INTO workflows (user_id, name, description, current_graph_version,
legacy_mongo_id)
VALUES (:user_id, :name, :description, :current_graph_version,
:legacy_mongo_id)
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description,
current_graph_version = EXCLUDED.current_graph_version,
updated_at = now()
"""
)
cursor = mongo_db["workflows"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
user_id = doc.get("user")
if not user_id:
skipped += 1
continue
batch.append({
"user_id": user_id,
"name": doc.get("name", ""),
"description": doc.get("description"),
"current_graph_version": doc.get("current_graph_version", 1),
"legacy_mongo_id": str(doc["_id"]),
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_workflow_nodes(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync ``workflow_nodes``.
Resolves Mongo ``workflow_id`` (string ObjectId) →
``workflows.id`` (UUID) via ``workflows.legacy_mongo_id``.
Idempotent via ``legacy_mongo_id``.
"""
workflow_id_map = _build_legacy_id_map(conn, "workflows")
insert_sql = text(
"""
INSERT INTO workflow_nodes
(workflow_id, graph_version, node_id, node_type, title, description,
position, config, legacy_mongo_id)
VALUES
(CAST(:workflow_id AS uuid), :graph_version, :node_id, :node_type,
:title, :description, CAST(:position AS jsonb), CAST(:config AS jsonb),
:legacy_mongo_id)
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
graph_version = EXCLUDED.graph_version,
node_id = EXCLUDED.node_id,
node_type = EXCLUDED.node_type,
title = EXCLUDED.title,
description = EXCLUDED.description,
position = EXCLUDED.position,
config = EXCLUDED.config
"""
)
cursor = mongo_db["workflow_nodes"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
mongo_wf_id = doc.get("workflow_id")
if not mongo_wf_id:
skipped += 1
continue
pg_wf_id = workflow_id_map.get(str(mongo_wf_id))
if not pg_wf_id:
skipped += 1
continue
position = doc.get("position") or {"x": 0, "y": 0}
batch.append({
"workflow_id": pg_wf_id,
"graph_version": doc.get("graph_version", 1),
"node_id": doc.get("id", ""),
"node_type": doc.get("type", ""),
"title": doc.get("title"),
"description": doc.get("description"),
"position": json.dumps(position),
"config": json.dumps(doc.get("config") or {}),
"legacy_mongo_id": str(doc["_id"]),
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_workflow_edges(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``workflow_edges`` table from Mongo ``workflow_edges`` collection.
Must run after ``workflow_nodes`` because ``from_node_id`` and
``to_node_id`` are FKs into ``workflow_nodes``.
The Mongo doc stores ``source_id`` and ``target_id`` as user-provided
node-id strings. We need to resolve them to Postgres UUIDs by looking
up the ``workflow_nodes`` row with matching ``(workflow_id,
graph_version, node_id)``.
"""
workflow_id_map = _build_legacy_id_map(conn, "workflows")
# Build a lookup: (pg_workflow_uuid, graph_version, node_id_str) → pg node UUID
pg_nodes = conn.execute(
text("SELECT id, workflow_id, graph_version, node_id FROM workflow_nodes")
).fetchall()
node_lookup: dict[tuple[str, int, str], str] = {}
for row in pg_nodes:
m = row._mapping
node_lookup[(str(m["workflow_id"]), m["graph_version"], m["node_id"])] = str(m["id"])
insert_sql = text(
"""
INSERT INTO workflow_edges
(workflow_id, graph_version, edge_id, from_node_id, to_node_id,
source_handle, target_handle, config)
VALUES
(CAST(:workflow_id AS uuid), :graph_version, :edge_id,
CAST(:from_node_id AS uuid), CAST(:to_node_id AS uuid),
:source_handle, :target_handle, CAST(:config AS jsonb))
ON CONFLICT (workflow_id, graph_version, edge_id) DO UPDATE SET
from_node_id = EXCLUDED.from_node_id,
to_node_id = EXCLUDED.to_node_id,
source_handle = EXCLUDED.source_handle,
target_handle = EXCLUDED.target_handle,
config = EXCLUDED.config
"""
)
cursor = mongo_db["workflow_edges"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
mongo_wf_id = doc.get("workflow_id")
if not mongo_wf_id:
skipped += 1
continue
pg_wf_id = workflow_id_map.get(str(mongo_wf_id))
if not pg_wf_id:
skipped += 1
continue
gv = doc.get("graph_version", 1)
source_nid = doc.get("source_id", "")
target_nid = doc.get("target_id", "")
from_uuid = node_lookup.get((pg_wf_id, gv, source_nid))
to_uuid = node_lookup.get((pg_wf_id, gv, target_nid))
if not from_uuid or not to_uuid:
skipped += 1
continue
batch.append({
"workflow_id": pg_wf_id,
"graph_version": gv,
"edge_id": doc.get("id", ""),
"from_node_id": from_uuid,
"to_node_id": to_uuid,
"source_handle": doc.get("source_handle"),
"target_handle": doc.get("target_handle"),
"config": json.dumps(doc.get("config") or {}),
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
def _backfill_workflow_runs(
*, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
"""Sync the ``workflow_runs`` table from Mongo ``workflow_runs`` collection.
Resolves Mongo ``workflow_id`` (string) → PG UUID via
``workflows.legacy_mongo_id``. Rows whose parent workflow was never
migrated (e.g. legacy ``workflow_id='unknown'``) are skipped.
"""
workflow_id_map = _build_legacy_id_map(conn, "workflows")
insert_sql = text(
"""
INSERT INTO workflow_runs
(workflow_id, user_id, status, inputs, result, steps,
started_at, ended_at, legacy_mongo_id)
VALUES
(CAST(:workflow_id AS uuid), :user_id, :status,
CAST(:inputs AS jsonb), CAST(:result AS jsonb),
CAST(:steps AS jsonb), :started_at, :ended_at, :legacy_mongo_id)
ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
DO UPDATE SET
status = EXCLUDED.status,
inputs = EXCLUDED.inputs,
result = EXCLUDED.result,
steps = EXCLUDED.steps,
ended_at = EXCLUDED.ended_at
"""
)
cursor = mongo_db["workflow_runs"].find({}, no_cursor_timeout=True).batch_size(batch_size)
seen = written = skipped = 0
batch: list[dict] = []
try:
for doc in cursor:
seen += 1
mongo_wf_id = doc.get("workflow_id")
if not mongo_wf_id:
skipped += 1
continue
pg_wf_id = workflow_id_map.get(str(mongo_wf_id))
if not pg_wf_id:
skipped += 1
continue
batch.append({
"workflow_id": pg_wf_id,
"user_id": doc.get("user_id") or doc.get("user") or "",
"status": doc.get("status", "unknown"),
"inputs": json.dumps(doc.get("inputs") or {}, default=str),
"result": json.dumps(doc.get("outputs") or doc.get("result"), default=str),
"steps": json.dumps(doc.get("steps") or [], default=str),
"started_at": doc.get("started_at") or doc.get("created_at"),
"ended_at": doc.get("ended_at") or doc.get("completed_at"),
"legacy_mongo_id": str(doc["_id"]),
})
if len(batch) >= batch_size:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
batch.clear()
if batch:
if not dry_run:
conn.execute(insert_sql, batch)
written += len(batch)
finally:
cursor.close()
return {"seen": seen, "written": written, "skipped": skipped}
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
@@ -796,7 +1520,6 @@ BACKFILLERS: dict[str, BackfillFn] = {
"users": _backfill_users,
"prompts": _backfill_prompts,
"user_tools": _backfill_user_tools,
"feedback": _backfill_feedback,
"stack_logs": _backfill_stack_logs,
"user_logs": _backfill_user_logs,
"token_usage": _backfill_token_usage,
@@ -809,6 +1532,14 @@ BACKFILLERS: dict[str, BackfillFn] = {
"todos": _backfill_todos,
"notes": _backfill_notes,
"connector_sessions": _backfill_connector_sessions,
# Phase 3 (order: conversations first, then dependents)
"conversations": _backfill_conversations,
"shared_conversations": _backfill_shared_conversations,
"pending_tool_state": _backfill_pending_tool_state,
"workflows": _backfill_workflows,
"workflow_nodes": _backfill_workflow_nodes,
"workflow_edges": _backfill_workflow_edges,
"workflow_runs": _backfill_workflow_runs,
}
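The registry is presumably consumed by a runner that walks it in insertion order, which is what makes the "conversations first, then dependents" comment load-bearing. A minimal sketch of such a loop; the function name and report shape are illustrative, not the project's actual entry point:

# Hypothetical driver; the real CLI/runner is not shown in this diff.
def run_backfills(conn, mongo_db, *, batch_size: int = 500, dry_run: bool = True) -> dict:
    report: dict[str, dict] = {}
    for name, backfill in BACKFILLERS.items():  # dicts preserve insertion order
        report[name] = backfill(
            conn=conn, mongo_db=mongo_db, batch_size=batch_size, dry_run=dry_run
        )
    return report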

View File

@@ -336,6 +336,34 @@ class TestSaveWorkflowRun:
agent._save_workflow_run("query")
mock_collection.insert_one.assert_called_once()
saved_doc = mock_collection.insert_one.call_args.args[0]
assert saved_doc["user"] == "user1"
assert saved_doc["user_id"] == "user1"
@pytest.mark.unit
def test_dual_writes_when_mongo_insert_returns_id(self):
agent = _make_agent(workflow_id="507f1f77bcf86cd799439011")
mock_engine = MagicMock()
mock_engine.state = {"query": "test"}
mock_engine.execution_log = []
mock_engine.get_execution_summary.return_value = []
agent._engine = mock_engine
insert_result = MagicMock()
insert_result.inserted_id = "507f1f77bcf86cd799439012"
mock_collection = MagicMock()
mock_collection.insert_one.return_value = insert_result
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings, \
patch("application.agents.workflow_agent.dual_write") as mock_dual_write:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
agent._save_workflow_run("query")
mock_dual_write.assert_called_once()
@pytest.mark.unit
def test_exception_does_not_propagate(self):

View File

@@ -973,14 +973,21 @@ class TestUpdateAgent:
agent_id = ObjectId()
existing = self._make_existing_agent(agent_id)
mock_col = Mock()
mock_repo = Mock()
mock_col.find_one.return_value = existing
mock_col.update_one.return_value = Mock(matched_count=1, modified_count=1)
mock_handle_img = Mock(return_value=("", None))
def _run_dual_write(_repo_cls, fn):
fn(mock_repo)
with patch(
"application.api.user.agents.routes.agents_collection", mock_col
), patch(
"application.api.user.agents.routes.handle_image_upload", mock_handle_img
), patch(
"application.api.user.agents.routes.dual_write",
side_effect=_run_dual_write,
):
with app.test_request_context(
f"/api/update_agent/{agent_id}",
@@ -993,6 +1000,11 @@ class TestUpdateAgent:
response = UpdateAgent().put(str(agent_id))
assert response.status_code == 200
assert response.json["success"] is True
mock_repo.update_by_legacy_id.assert_called_once_with(
str(agent_id),
"user1",
{"name": "Updated Name"},
)
def test_returns_400_invalid_status(self, app):
from application.api.user.agents.routes import UpdateAgent
@@ -1945,14 +1957,21 @@ class TestDeleteAgent:
agent_id = ObjectId()
mock_col = Mock()
mock_repo = Mock()
mock_col.find_one_and_delete.return_value = {
"_id": agent_id,
"user": "user1",
"agent_type": "classic",
}
def _run_dual_write(_repo_cls, fn):
fn(mock_repo)
with patch(
"application.api.user.agents.routes.agents_collection", mock_col
), patch(
"application.api.user.agents.routes.dual_write",
side_effect=_run_dual_write,
):
with app.test_request_context(f"/api/delete_agent?id={agent_id}"):
from flask import request
@@ -1961,6 +1980,7 @@ class TestDeleteAgent:
response = DeleteAgent().delete()
assert response.status_code == 200
assert response.json["id"] == str(agent_id)
mock_repo.delete_by_legacy_id.assert_called_once_with(str(agent_id), "user1")
def test_deletes_workflow_agent_cleans_up(self, app):
from application.api.user.agents.routes import DeleteAgent

View File

@@ -18,12 +18,19 @@ class TestCreatePrompt:
from application.api.user.prompts.routes import CreatePrompt
mock_collection = Mock()
mock_repo = Mock()
inserted_id = ObjectId()
mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
def _run_dual_write(_repo_cls, fn):
fn(mock_repo)
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
), patch(
"application.api.user.prompts.routes.dual_write",
side_effect=_run_dual_write,
):
with app.test_request_context(
"/api/create_prompt",
@@ -41,6 +48,12 @@ class TestCreatePrompt:
doc = mock_collection.insert_one.call_args[0][0]
assert doc["name"] == "My Prompt"
assert doc["user"] == "user1"
mock_repo.create.assert_called_once_with(
"user1",
"My Prompt",
"You are helpful.",
legacy_mongo_id=str(inserted_id),
)
def test_returns_401_unauthenticated(self, app):
from application.api.user.prompts.routes import CreatePrompt
@@ -204,10 +217,17 @@ class TestDeletePrompt:
prompt_id = ObjectId()
mock_collection = Mock()
mock_repo = Mock()
def _run_dual_write(_repo_cls, fn):
fn(mock_repo)
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
), patch(
"application.api.user.prompts.routes.dual_write",
side_effect=_run_dual_write,
):
with app.test_request_context(
"/api/delete_prompt",
@@ -224,6 +244,7 @@ class TestDeletePrompt:
mock_collection.delete_one.assert_called_once_with(
{"_id": prompt_id, "user": "user1"}
)
mock_repo.delete_by_legacy_id.assert_called_once_with(str(prompt_id), "user1")
def test_returns_400_missing_id(self, app):
from application.api.user.prompts.routes import DeletePrompt
@@ -249,10 +270,17 @@ class TestUpdatePrompt:
prompt_id = ObjectId()
mock_collection = Mock()
mock_repo = Mock()
def _run_dual_write(_repo_cls, fn):
fn(mock_repo)
with patch(
"application.api.user.prompts.routes.prompts_collection",
mock_collection,
), patch(
"application.api.user.prompts.routes.dual_write",
side_effect=_run_dual_write,
):
with app.test_request_context(
"/api/update_prompt",
@@ -271,6 +299,12 @@ class TestUpdatePrompt:
assert response.status_code == 200
assert response.json["success"] is True
mock_collection.update_one.assert_called_once()
mock_repo.update_by_legacy_id.assert_called_once_with(
str(prompt_id),
"user1",
"Updated",
"New content",
)
def test_returns_400_missing_fields(self, app):
from application.api.user.prompts.routes import UpdatePrompt

View File

@@ -200,7 +200,7 @@ class TestSetupPeriodicTasks:
setup_periodic_tasks(sender)
assert sender.add_periodic_task.call_count == 3
assert sender.add_periodic_task.call_count == 4
calls = sender.add_periodic_task.call_args_list
@@ -210,6 +210,8 @@ class TestSetupPeriodicTasks:
assert calls[1][0][0] == timedelta(weeks=1)
# monthly
assert calls[2][0][0] == timedelta(days=30)
# pending_tool_state TTL cleanup (60s)
assert calls[3][0][0] == timedelta(seconds=60)
class TestMcpOauthTask:

View File

@@ -886,7 +886,13 @@ class TestWorkflowListPost:
mock_wf_collection = Mock()
mock_wf_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
mock_nodes_collection = Mock()
mock_nodes_collection.insert_many.return_value = Mock(
inserted_ids=[ObjectId(), ObjectId()]
)
mock_edges_collection = Mock()
mock_edges_collection.insert_many.return_value = Mock(
inserted_ids=[ObjectId()]
)
with patch(
"application.api.user.workflows.routes.workflows_collection",
@@ -1152,7 +1158,13 @@ class TestWorkflowDetailPut:
}
mock_wf_collection.update_one.return_value = Mock()
mock_nodes_collection = Mock()
mock_nodes_collection.insert_many.return_value = Mock(
inserted_ids=[ObjectId(), ObjectId()]
)
mock_edges_collection = Mock()
mock_edges_collection.insert_many.return_value = Mock(
inserted_ids=[ObjectId()]
)
with patch(
"application.api.user.workflows.routes.workflows_collection",

View File

@@ -27,18 +27,15 @@ from application.core.settings import settings
def _run_alembic_upgrade(engine):
"""Run ``alembic upgrade head`` to ensure the full schema is present.
Falls back to inline DDL for CI environments where alembic is not
on PATH (shouldn't happen, but defence in depth).
Non-zero exit is re-raised so genuine schema-drift bugs surface as
test failures. If alembic reports the schema is already at head,
the subprocess still exits zero.
"""
alembic_ini = Path(__file__).resolve().parents[3] / "application" / "alembic.ini"
try:
subprocess.check_call(
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
timeout=30,
)
except Exception:
# Alembic failed — tables likely already exist from a prior run.
pass
subprocess.check_call(
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
timeout=60,
)
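Because ``subprocess.check_call`` raises ``CalledProcessError`` on a non-zero exit, a failed upgrade now fails the suite instead of being silently swallowed. A tiny self-contained illustration of the behaviour the fixture relies on:

# Demonstrates the CalledProcessError path the fixture now depends on.
import subprocess
import sys

try:
    subprocess.check_call([sys.executable, "-c", "raise SystemExit(1)"], timeout=10)
except subprocess.CalledProcessError as exc:
    print(f"non-zero exit is re-raised: returncode={exc.returncode}")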
@pytest.fixture(scope="session")

View File

@@ -44,6 +44,21 @@ class TestCreate:
doc = repo.create("u", "a", "draft")
assert doc["_id"] == doc["id"]
def test_create_with_legacy_mongo_id(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create(
"u",
"a",
"draft",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
def test_create_normalizes_blank_key_to_null(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("u", "a", "draft", key="")
assert doc["key"] is None
class TestGet:
def test_get_existing(self, pg_conn):
@@ -61,6 +76,17 @@ class TestGet:
created = repo.create("user-1", "a", "draft")
assert repo.get(created["id"], "user-other") is None
def test_get_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create(
"user-1",
"a",
"draft",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
assert fetched["id"] == created["id"]
class TestFindByKey:
def test_finds_agent_by_key(self, pg_conn):
@@ -108,6 +134,31 @@ class TestUpdate:
updated = repo.update(created["id"], "user-1", {"id": "bad"})
assert updated is False
def test_update_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
repo.create(
"user-1",
"old",
"draft",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
updated = repo.update_by_legacy_id(
"507f1f77bcf86cd799439011",
"user-1",
{"name": "new", "last_used_at": None},
)
assert updated is True
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
assert fetched["name"] == "new"
def test_update_normalizes_blank_key_to_null(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "old", "draft", key="my-unique-key")
updated = repo.update(created["id"], "user-1", {"key": ""})
assert updated is True
fetched = repo.get(created["id"], "user-1")
assert fetched["key"] is None
class TestDelete:
def test_deletes_agent(self, pg_conn):
@@ -124,6 +175,18 @@ class TestDelete:
assert deleted is False
assert repo.get(created["id"], "user-1") is not None
def test_delete_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create(
"user-1",
"a",
"draft",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
deleted = repo.delete_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
assert deleted is True
assert repo.get(created["id"], "user-1") is None
class TestSetFolder:
def test_assigns_folder(self, pg_conn):

View File

@@ -37,6 +37,16 @@ class TestCreate:
doc = repo.create("u", "f", "/p")
assert doc["_id"] == doc["id"]
def test_create_with_legacy_mongo_id(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create(
"u",
"f",
"/p",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
class TestGet:
def test_get_existing(self, pg_conn):
@@ -54,6 +64,17 @@ class TestGet:
created = repo.create("u", "f", "/p")
assert repo.get(created["id"], "other") is None
def test_get_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create(
"u",
"f",
"/p",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "u")
assert fetched["id"] == created["id"]
class TestListForUser:
def test_lists_only_own_attachments(self, pg_conn):

View File

@@ -0,0 +1,374 @@
"""Tests for ConversationsRepository against a real Postgres instance."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from application.storage.db.repositories.conversations import ConversationsRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> ConversationsRepository:
return ConversationsRepository(conn)
# ------------------------------------------------------------------
# Conversation CRUD
# ------------------------------------------------------------------
class TestCreate:
def test_creates_conversation(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("user-1", "My Chat")
assert doc["user_id"] == "user-1"
assert doc["name"] == "My Chat"
assert doc["id"] is not None
assert doc["_id"] == doc["id"]
def test_create_with_agent(self, pg_conn):
from application.storage.db.repositories.agents import AgentsRepository
agent_repo = AgentsRepository(pg_conn)
agent = agent_repo.create("user-1", "a", "active")
repo = _repo(pg_conn)
doc = repo.create(
"user-1", "Chat",
agent_id=agent["id"],
api_key="ak-123",
is_shared_usage=True,
shared_token="tok-abc",
)
assert str(doc["agent_id"]) == agent["id"]
assert doc["api_key"] == "ak-123"
assert doc["is_shared_usage"] is True
assert doc["shared_token"] == "tok-abc"
class TestGet:
def test_get_owned(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "c")
fetched = repo.get(created["id"], "user-1")
assert fetched["id"] == created["id"]
def test_get_nonexistent(self, pg_conn):
repo = _repo(pg_conn)
assert repo.get("00000000-0000-0000-0000-000000000000", "u") is None
def test_get_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "c")
assert repo.get(created["id"], "user-other") is None
class TestListForUser:
def test_lists_own_conversations(self, pg_conn):
repo = _repo(pg_conn)
repo.create("alice", "c1")
repo.create("alice", "c2")
repo.create("bob", "c3")
results = repo.list_for_user("alice")
assert len(results) == 2
assert all(r["user_id"] == "alice" for r in results)
def test_excludes_api_key_without_agent(self, pg_conn):
repo = _repo(pg_conn)
repo.create("alice", "normal")
repo.create("alice", "api-only", api_key="key-1")
results = repo.list_for_user("alice")
assert len(results) == 1
assert results[0]["name"] == "normal"
class TestRename:
def test_renames(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "old")
assert repo.rename(created["id"], "user-1", "new") is True
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "new"
def test_rename_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "old")
assert repo.rename(created["id"], "user-other", "new") is False
class TestDelete:
def test_deletes(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "c")
assert repo.delete(created["id"], "user-1") is True
assert repo.get(created["id"], "user-1") is None
def test_delete_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "c")
assert repo.delete(created["id"], "user-other") is False
def test_delete_cascades_messages(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "hi", "response": "hello"})
repo.delete(conv["id"], "user-1")
assert repo.get_messages(conv["id"]) == []
def test_delete_all_for_user(self, pg_conn):
repo = _repo(pg_conn)
repo.create("user-1", "c1")
repo.create("user-1", "c2")
repo.create("user-2", "c3")
count = repo.delete_all_for_user("user-1")
assert count == 2
assert repo.list_for_user("user-1") == []
assert len(repo.list_for_user("user-2")) == 1
# ------------------------------------------------------------------
# Messages
# ------------------------------------------------------------------
class TestAppendMessage:
def test_append_first_message(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
msg = repo.append_message(conv["id"], {
"prompt": "hello",
"response": "hi there",
"model_id": "gpt-4",
})
assert msg["position"] == 0
assert msg["prompt"] == "hello"
assert msg["response"] == "hi there"
assert msg["model_id"] == "gpt-4"
def test_append_increments_position(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
m0 = repo.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
m1 = repo.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
assert m0["position"] == 0
assert m1["position"] == 1
def test_append_with_sources_and_tools(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
msg = repo.append_message(conv["id"], {
"prompt": "q",
"response": "a",
"sources": [{"title": "doc1"}],
"tool_calls": [{"name": "search", "args": {}}],
"metadata": {"search_query": "rewritten"},
})
assert msg["sources"] == [{"title": "doc1"}]
assert msg["tool_calls"] == [{"name": "search", "args": {}}]
assert msg["metadata"] == {"search_query": "rewritten"}
def test_append_preserves_explicit_timestamp(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
ts = datetime.now(timezone.utc)
msg = repo.append_message(conv["id"], {
"prompt": "q",
"response": "a",
"timestamp": ts,
})
assert msg["timestamp"] == ts
class TestGetMessages:
def test_returns_ordered_messages(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
repo.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
msgs = repo.get_messages(conv["id"])
assert len(msgs) == 2
assert msgs[0]["position"] == 0
assert msgs[1]["position"] == 1
def test_get_message_at(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
repo.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
msg = repo.get_message_at(conv["id"], 1)
assert msg["prompt"] == "q2"
def test_get_message_at_nonexistent(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
assert repo.get_message_at(conv["id"], 99) is None
class TestUpdateMessageAt:
def test_updates_response(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q", "response": "old"})
assert repo.update_message_at(conv["id"], 0, {"response": "new"}) is True
msg = repo.get_message_at(conv["id"], 0)
assert msg["response"] == "new"
def test_update_disallowed_field(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
assert repo.update_message_at(conv["id"], 0, {"id": "bad"}) is False
def test_updates_explicit_timestamp(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q", "response": "old"})
ts = datetime.now(timezone.utc)
assert repo.update_message_at(
conv["id"], 0, {"response": "new", "timestamp": ts},
) is True
msg = repo.get_message_at(conv["id"], 0)
assert msg["response"] == "new"
assert msg["timestamp"] == ts
class TestTruncateAfter:
def test_truncates_messages(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
for i in range(5):
repo.append_message(conv["id"], {"prompt": f"q{i}", "response": f"a{i}"})
deleted = repo.truncate_after(conv["id"], 2)
assert deleted == 2
msgs = repo.get_messages(conv["id"])
assert len(msgs) == 3
assert [m["position"] for m in msgs] == [0, 1, 2]
class TestSetFeedback:
def test_set_feedback(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
assert repo.set_feedback(conv["id"], 0, {"text": "thumbs_up"}) is True
msg = repo.get_message_at(conv["id"], 0)
assert msg["feedback"] == {"text": "thumbs_up"}
def test_unset_feedback(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
repo.set_feedback(conv["id"], 0, {"text": "thumbs_up"})
assert repo.set_feedback(conv["id"], 0, None) is True
msg = repo.get_message_at(conv["id"], 0)
assert msg["feedback"] is None
def test_set_feedback_nonexistent_position(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
assert repo.set_feedback(conv["id"], 99, {"text": "x"}) is False
class TestMessageCount:
def test_counts_messages(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
assert repo.message_count(conv["id"]) == 0
repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
assert repo.message_count(conv["id"]) == 1
class TestCompressionMetadata:
def test_set_compression_metadata(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
meta = {"is_compressed": True, "last_compression_at": "2026-01-01T00:00:00Z"}
assert repo.update_compression_metadata(conv["id"], "user-1", meta) is True
fetched = repo.get(conv["id"], "user-1")
assert fetched["compression_metadata"]["is_compressed"] is True
def test_set_compression_flags_preserves_points(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
repo.update_compression_metadata(conv["id"], "user-1", {
"is_compressed": False,
"compression_points": [{"summary": "earlier"}],
})
assert repo.set_compression_flags(
conv["id"], is_compressed=True, last_compression_at="2026-01-02",
) is True
fetched = repo.get(conv["id"], "user-1")
assert fetched["compression_metadata"]["is_compressed"] is True
assert fetched["compression_metadata"]["last_compression_at"] == "2026-01-02"
assert fetched["compression_metadata"]["compression_points"] == [
{"summary": "earlier"}
]
def test_append_compression_point_slices_to_max(self, pg_conn):
repo = _repo(pg_conn)
conv = repo.create("user-1", "c")
for i in range(5):
assert repo.append_compression_point(
conv["id"], {"summary": f"p{i}"}, max_points=3,
) is True
fetched = repo.get(conv["id"], "user-1")
points = fetched["compression_metadata"]["compression_points"]
assert [p["summary"] for p in points] == ["p2", "p3", "p4"]
class TestConcurrentAppend:
"""Two threads appending to the same conversation must not race on
``position``. The plan (migration-postgres.md §Phase 3) explicitly
calls this out as the single trickiest invariant, so we exercise it
directly with two parallel connections."""
def test_concurrent_appends_get_distinct_positions(self, pg_engine, pg_conn):
import threading
# Arrange: create one conversation and commit it explicitly so the
# workers' separate sessions can see it. Because it is committed, the
# outer transaction's rollback will not remove it; the finally block
# below handles cleanup instead.
repo_setup = _repo(pg_conn)
conv = repo_setup.create("user-concurrent", "c")
pg_conn.commit()
try:
errors: list[BaseException] = []
def worker() -> None:
try:
with pg_engine.begin() as worker_conn:
ConversationsRepository(worker_conn).append_message(
conv["id"], {"prompt": "q", "response": "a"},
)
except BaseException as e: # noqa: BLE001
errors.append(e)
threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
t.start()
for t in threads:
t.join()
assert errors == [], f"worker threads errored: {errors}"
# Assert — the parent row-lock in append_message must have
# serialised the two inserts so they land at positions {0, 1}.
with pg_engine.connect() as verify_conn:
msgs = ConversationsRepository(verify_conn).get_messages(conv["id"])
positions = sorted(m["position"] for m in msgs)
assert positions == [0, 1], (
f"concurrent appends raced; got positions {positions}"
)
finally:
# Clean up — the conversation was committed, so the transaction
# rollback won't drop it.
with pg_engine.begin() as cleanup_conn:
ConversationsRepository(cleanup_conn).delete(
conv["id"], "user-concurrent"
)
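The closing assertion comment attributes the serialisation to a parent row-lock taken inside ``append_message``. A minimal sketch of that pattern, assuming the repository implementation (which is not part of this diff) looks roughly like this:

# Assumed locking pattern, not the repository's verbatim implementation.
from sqlalchemy import text

def _append_with_parent_lock(conn, conversation_id: str, prompt: str, response: str) -> int:
    # Serialise concurrent appenders on the parent conversation row.
    conn.execute(
        text("SELECT id FROM conversations WHERE id = CAST(:cid AS uuid) FOR UPDATE"),
        {"cid": conversation_id},
    )
    next_pos = conn.execute(
        text(
            "SELECT COALESCE(MAX(position) + 1, 0) FROM conversation_messages "
            "WHERE conversation_id = CAST(:cid AS uuid)"
        ),
        {"cid": conversation_id},
    ).scalar()
    conn.execute(
        text(
            "INSERT INTO conversation_messages (conversation_id, position, prompt, response) "
            "VALUES (CAST(:cid AS uuid), :pos, :prompt, :response)"
        ),
        {"cid": conversation_id, "pos": next_pos, "prompt": prompt, "response": response},
    )
    return next_pos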

View File

@@ -1,79 +0,0 @@
"""Tests for FeedbackRepository against a real Postgres instance."""
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import text
from application.storage.db.repositories.feedback import FeedbackRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _repo(conn) -> FeedbackRepository:
return FeedbackRepository(conn)
def _make_conversation_id(pg_conn) -> str:
"""Insert a minimal conversations row and return its id as string.
feedback has a FK to conversations (added in Tier 3 migration), but
the FK constraint may not exist yet during early phases. We create a
row anyway to keep tests realistic.
"""
cid = str(uuid.uuid4())
# Only insert if the conversations table exists; otherwise use a random UUID.
row = pg_conn.execute(
text(
"SELECT 1 FROM information_schema.tables "
"WHERE table_schema='public' AND table_name='conversations'"
)
).scalar()
if row:
pg_conn.execute(
text("INSERT INTO conversations (id, user_id) VALUES (CAST(:id AS uuid), 'test')"),
{"id": cid},
)
return cid
class TestCreate:
def test_creates_feedback(self, pg_conn):
repo = _repo(pg_conn)
cid = _make_conversation_id(pg_conn)
doc = repo.create(cid, "user-1", 0, "great answer")
assert doc["conversation_id"] is not None
assert doc["user_id"] == "user-1"
assert doc["question_index"] == 0
assert doc["feedback_text"] == "great answer"
def test_allows_null_feedback_text(self, pg_conn):
repo = _repo(pg_conn)
cid = _make_conversation_id(pg_conn)
doc = repo.create(cid, "user-1", 1)
assert doc["feedback_text"] is None
class TestListForConversation:
def test_lists_feedback_for_conversation(self, pg_conn):
repo = _repo(pg_conn)
cid = _make_conversation_id(pg_conn)
repo.create(cid, "user-1", 0, "good")
repo.create(cid, "user-1", 1, "bad")
results = repo.list_for_conversation(cid)
assert len(results) == 2
assert results[0]["question_index"] == 0
assert results[1]["question_index"] == 1
def test_does_not_mix_conversations(self, pg_conn):
repo = _repo(pg_conn)
cid1 = _make_conversation_id(pg_conn)
cid2 = _make_conversation_id(pg_conn)
repo.create(cid1, "user-1", 0, "a")
repo.create(cid2, "user-1", 0, "b")
assert len(repo.list_for_conversation(cid1)) == 1

View File

@@ -0,0 +1,99 @@
"""Tests for PendingToolStateRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import PendingToolStateRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _conv(conn) -> dict:
return ConversationsRepository(conn).create("user-1", "test conv")
def _repo(conn) -> PendingToolStateRepository:
return PendingToolStateRepository(conn)
def _sample_state() -> dict:
return {
"messages": [{"role": "user", "content": "hello"}],
"pending_tool_calls": [{"id": "tc-1", "name": "search"}],
"tools_dict": {"search": {"type": "function"}},
"tool_schemas": [{"name": "search"}],
"agent_config": {"model_id": "gpt-4", "llm_name": "openai"},
}
class TestSaveState:
def test_creates_state(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
state = _sample_state()
doc = repo.save_state(conv["id"], "user-1", **state)
assert doc["user_id"] == "user-1"
assert doc["messages"] == state["messages"]
assert doc["pending_tool_calls"] == state["pending_tool_calls"]
assert doc["expires_at"] is not None
def test_upsert_replaces_existing(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
state = _sample_state()
repo.save_state(conv["id"], "user-1", **state)
state["messages"] = [{"role": "user", "content": "updated"}]
doc2 = repo.save_state(conv["id"], "user-1", **state)
# Same row, updated content
assert doc2["messages"] == [{"role": "user", "content": "updated"}]
def test_save_with_client_tools(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
state = _sample_state()
state["client_tools"] = [{"name": "browser"}]
doc = repo.save_state(conv["id"], "user-1", **state)
assert doc["client_tools"] == [{"name": "browser"}]
class TestLoadState:
def test_loads_existing(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
repo.save_state(conv["id"], "user-1", **_sample_state())
loaded = repo.load_state(conv["id"], "user-1")
assert loaded is not None
assert loaded["agent_config"]["model_id"] == "gpt-4"
def test_load_nonexistent(self, pg_conn):
repo = _repo(pg_conn)
assert repo.load_state("00000000-0000-0000-0000-000000000000", "u") is None
class TestDeleteState:
def test_deletes(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
repo.save_state(conv["id"], "user-1", **_sample_state())
assert repo.delete_state(conv["id"], "user-1") is True
assert repo.load_state(conv["id"], "user-1") is None
def test_delete_nonexistent(self, pg_conn):
repo = _repo(pg_conn)
assert repo.delete_state("00000000-0000-0000-0000-000000000000", "u") is False
class TestCleanupExpired:
def test_cleanup_removes_expired(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
# Create a state with TTL of 0 seconds (already expired)
repo.save_state(conv["id"], "user-1", **_sample_state(), ttl_seconds=0)
deleted = repo.cleanup_expired()
assert deleted >= 1
assert repo.load_state(conv["id"], "user-1") is None

View File

@@ -30,6 +30,11 @@ class TestCreate:
doc = repo.create("user-1", "p", "c")
assert doc["_id"] == doc["id"]
def test_create_with_legacy_mongo_id(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("user-1", "p", "c", legacy_mongo_id="507f1f77bcf86cd799439011")
assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
class TestGet:
def test_get_by_id_and_user(self, pg_conn):
@@ -47,6 +52,17 @@ class TestGet:
repo = _repo(pg_conn)
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
def test_get_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create(
"user-1",
"p",
"c",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
assert fetched["id"] == created["id"]
class TestGetForRendering:
def test_returns_prompt_without_user_scoping(self, pg_conn):
@@ -88,6 +104,24 @@ class TestUpdate:
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "old"
def test_update_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
repo.create(
"user-1",
"old",
"old-content",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
assert repo.update_by_legacy_id(
"507f1f77bcf86cd799439011",
"user-1",
"new",
"new-content",
) is True
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
assert fetched["name"] == "new"
assert fetched["content"] == "new-content"
class TestDelete:
def test_deletes_prompt(self, pg_conn):
@@ -102,6 +136,17 @@ class TestDelete:
repo.delete(created["id"], "user-other")
assert repo.get(created["id"], "user-1") is not None
def test_delete_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create(
"user-1",
"p",
"c",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
assert repo.delete_by_legacy_id("507f1f77bcf86cd799439011", "user-1") is True
assert repo.get(created["id"], "user-1") is None
class TestFindOrCreate:
def test_creates_when_missing(self, pg_conn):

View File

@@ -0,0 +1,91 @@
"""Tests for SharedConversationsRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.shared_conversations import SharedConversationsRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _conv(conn) -> dict:
return ConversationsRepository(conn).create("user-1", "test conv")
def _repo(conn) -> SharedConversationsRepository:
return SharedConversationsRepository(conn)
class TestCreate:
def test_creates_share(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
share = repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
assert share["conversation_id"] is not None
assert share["user_id"] == "user-1"
assert share["is_promptable"] is False
assert share["first_n_queries"] == 3
assert share["uuid"] is not None
def test_create_promptable_with_api_key(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
share = repo.create(
conv["id"], "user-1",
is_promptable=True,
first_n_queries=5,
api_key="ak-prompt",
)
assert share["is_promptable"] is True
assert share["api_key"] == "ak-prompt"
class TestFindByUuid:
def test_finds_by_uuid(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
share = repo.create(conv["id"], "user-1", first_n_queries=2)
found = repo.find_by_uuid(str(share["uuid"]))
assert found["id"] == share["id"]
def test_not_found(self, pg_conn):
repo = _repo(pg_conn)
assert repo.find_by_uuid("00000000-0000-0000-0000-000000000000") is None
class TestFindExisting:
def test_finds_matching_share(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
found = repo.find_existing(conv["id"], "user-1", False, 3)
assert found is not None
assert found["first_n_queries"] == 3
def test_no_match_different_params(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
assert repo.find_existing(conv["id"], "user-1", True, 3) is None
def test_finds_with_api_key(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
repo.create(conv["id"], "user-1", is_promptable=True, first_n_queries=5, api_key="ak-1")
found = repo.find_existing(conv["id"], "user-1", True, 5, api_key="ak-1")
assert found is not None
class TestListForConversation:
def test_lists_shares(self, pg_conn):
conv = _conv(pg_conn)
repo = _repo(pg_conn)
repo.create(conv["id"], "user-1", first_n_queries=1)
repo.create(conv["id"], "user-1", first_n_queries=2)
results = repo.list_for_conversation(conv["id"])
assert len(results) == 2

View File

@@ -25,12 +25,6 @@ class TestCreate:
assert doc["type"] == "url"
assert doc["id"] is not None
def test_creates_system_source_without_user(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("system-src")
assert doc["user_id"] is None
assert doc["name"] == "system-src"
def test_creates_source_with_metadata(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("src", user_id="u", metadata={"url": "https://example.com"})
@@ -38,7 +32,7 @@ class TestCreate:
def test_create_returns_id_and_underscore_id(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("s")
doc = repo.create("s", user_id="u")
assert doc["_id"] == doc["id"]

View File

@@ -0,0 +1,116 @@
"""Tests for WorkflowEdgesRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _setup(conn) -> tuple[dict, dict, dict]:
"""Create a workflow with two nodes and return (workflow, node1, node2)."""
wf = WorkflowsRepository(conn).create("user-1", "test wf")
node_repo = WorkflowNodesRepository(conn)
n1 = node_repo.create(wf["id"], 1, "start-node", "start")
n2 = node_repo.create(wf["id"], 1, "end-node", "end")
return wf, n1, n2
def _repo(conn) -> WorkflowEdgesRepository:
return WorkflowEdgesRepository(conn)
class TestCreate:
def test_creates_edge(self, pg_conn):
wf, n1, n2 = _setup(pg_conn)
repo = _repo(pg_conn)
edge = repo.create(
wf["id"], 1, "edge-1", n1["id"], n2["id"],
source_handle="out", target_handle="in",
)
assert edge["edge_id"] == "edge-1"
assert str(edge["from_node_id"]) == n1["id"]
assert str(edge["to_node_id"]) == n2["id"]
assert edge["source_handle"] == "out"
assert edge["target_handle"] == "in"
class TestBulkCreate:
def test_bulk_creates_edges(self, pg_conn):
wf, n1, n2 = _setup(pg_conn)
repo = _repo(pg_conn)
edges = repo.bulk_create(wf["id"], 1, [
{"edge_id": "e1", "from_node_id": n1["id"], "to_node_id": n2["id"]},
{"edge_id": "e2", "from_node_id": n2["id"], "to_node_id": n1["id"],
"source_handle": "loop"},
])
assert len(edges) == 2
def test_bulk_create_empty(self, pg_conn):
wf, _, _ = _setup(pg_conn)
repo = _repo(pg_conn)
assert repo.bulk_create(wf["id"], 1, []) == []
class TestFindByVersion:
def test_finds_edges(self, pg_conn):
wf, n1, n2 = _setup(pg_conn)
repo = _repo(pg_conn)
repo.create(wf["id"], 1, "e1", n1["id"], n2["id"])
edges = repo.find_by_version(wf["id"], 1)
assert len(edges) == 1
assert edges[0]["edge_id"] == "e1"
def test_no_edges_for_version(self, pg_conn):
wf, _, _ = _setup(pg_conn)
repo = _repo(pg_conn)
assert repo.find_by_version(wf["id"], 99) == []
class TestDelete:
def test_delete_by_workflow(self, pg_conn):
wf, n1, n2 = _setup(pg_conn)
repo = _repo(pg_conn)
repo.create(wf["id"], 1, "e1", n1["id"], n2["id"])
deleted = repo.delete_by_workflow(wf["id"])
assert deleted == 1
assert repo.find_by_version(wf["id"], 1) == []
def test_delete_by_version(self, pg_conn):
wf = WorkflowsRepository(pg_conn).create("user-1", "wf")
node_repo = WorkflowNodesRepository(pg_conn)
n1v1 = node_repo.create(wf["id"], 1, "n1", "start")
n2v1 = node_repo.create(wf["id"], 1, "n2", "end")
n1v2 = node_repo.create(wf["id"], 2, "n1", "start")
n2v2 = node_repo.create(wf["id"], 2, "n2", "end")
repo = _repo(pg_conn)
repo.create(wf["id"], 1, "e1", n1v1["id"], n2v1["id"])
repo.create(wf["id"], 2, "e1", n1v2["id"], n2v2["id"])
repo.delete_by_version(wf["id"], 1)
assert repo.find_by_version(wf["id"], 1) == []
assert len(repo.find_by_version(wf["id"], 2)) == 1
def test_delete_other_versions(self, pg_conn):
wf = WorkflowsRepository(pg_conn).create("user-1", "wf")
node_repo = WorkflowNodesRepository(pg_conn)
n1v1 = node_repo.create(wf["id"], 1, "n1", "start")
n2v1 = node_repo.create(wf["id"], 1, "n2", "end")
n1v2 = node_repo.create(wf["id"], 2, "n1", "start")
n2v2 = node_repo.create(wf["id"], 2, "n2", "end")
repo = _repo(pg_conn)
repo.create(wf["id"], 1, "e1", n1v1["id"], n2v1["id"])
repo.create(wf["id"], 2, "e1", n1v2["id"], n2v2["id"])
repo.delete_other_versions(wf["id"], 2)
assert repo.find_by_version(wf["id"], 1) == []
assert len(repo.find_by_version(wf["id"], 2)) == 1

View File

@@ -0,0 +1,172 @@
"""Tests for WorkflowNodesRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
pytestmark = pytest.mark.skipif(
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
def _wf(conn) -> dict:
return WorkflowsRepository(conn).create("user-1", "test wf")
def _repo(conn) -> WorkflowNodesRepository:
return WorkflowNodesRepository(conn)
class TestCreate:
def test_creates_node(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
node = repo.create(wf["id"], 1, "node-start", "start", title="Start")
assert node["node_id"] == "node-start"
assert node["node_type"] == "start"
assert node["title"] == "Start"
assert node["graph_version"] == 1
def test_create_with_config(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
node = repo.create(
wf["id"], 1, "node-agent", "agent",
config={"agent_type": "classic", "system_prompt": "You are helpful"},
position={"x": 100, "y": 200},
)
assert node["config"]["agent_type"] == "classic"
assert node["position"]["x"] == 100
def test_create_with_legacy_mongo_id(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
node = repo.create(
wf["id"],
1,
"node-agent",
"agent",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
assert node["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
class TestBulkCreate:
def test_bulk_creates_nodes(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
nodes = repo.bulk_create(wf["id"], 1, [
{"node_id": "n1", "node_type": "start", "title": "Start"},
{"node_id": "n2", "node_type": "agent", "config": {"agent_type": "react"}},
{"node_id": "n3", "node_type": "end"},
])
assert len(nodes) == 3
node_ids = {n["node_id"] for n in nodes}
assert node_ids == {"n1", "n2", "n3"}
def test_bulk_create_empty(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
assert repo.bulk_create(wf["id"], 1, []) == []
def test_bulk_create_with_legacy_mongo_ids(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
nodes = repo.bulk_create(wf["id"], 1, [
{
"node_id": "n1",
"node_type": "start",
"legacy_mongo_id": "507f1f77bcf86cd799439011",
},
{
"node_id": "n2",
"node_type": "end",
"legacy_mongo_id": "507f1f77bcf86cd799439012",
},
])
assert {n["legacy_mongo_id"] for n in nodes} == {
"507f1f77bcf86cd799439011",
"507f1f77bcf86cd799439012",
}
class TestFindByVersion:
def test_finds_nodes(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
repo.bulk_create(wf["id"], 1, [
{"node_id": "n1", "node_type": "start"},
{"node_id": "n2", "node_type": "end"},
])
repo.bulk_create(wf["id"], 2, [
{"node_id": "n1", "node_type": "start"},
])
v1_nodes = repo.find_by_version(wf["id"], 1)
v2_nodes = repo.find_by_version(wf["id"], 2)
assert len(v1_nodes) == 2
assert len(v2_nodes) == 1
class TestFindNode:
def test_finds_specific_node(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
repo.create(wf["id"], 1, "node-start", "start")
found = repo.find_node(wf["id"], 1, "node-start")
assert found is not None
assert found["node_type"] == "start"
def test_not_found(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
assert repo.find_node(wf["id"], 1, "nonexistent") is None
def test_get_by_legacy_id(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
created = repo.create(
wf["id"],
1,
"node-start",
"start",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
found = repo.get_by_legacy_id("507f1f77bcf86cd799439011")
assert found["id"] == created["id"]
class TestDelete:
def test_delete_by_workflow(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
repo.bulk_create(wf["id"], 1, [
{"node_id": "n1", "node_type": "start"},
{"node_id": "n2", "node_type": "end"},
])
deleted = repo.delete_by_workflow(wf["id"])
assert deleted == 2
assert repo.find_by_version(wf["id"], 1) == []
def test_delete_by_version(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
repo.bulk_create(wf["id"], 1, [{"node_id": "n1", "node_type": "start"}])
repo.bulk_create(wf["id"], 2, [{"node_id": "n1", "node_type": "start"}])
repo.delete_by_version(wf["id"], 1)
assert repo.find_by_version(wf["id"], 1) == []
assert len(repo.find_by_version(wf["id"], 2)) == 1
def test_delete_other_versions(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
repo.bulk_create(wf["id"], 1, [{"node_id": "n1", "node_type": "start"}])
repo.bulk_create(wf["id"], 2, [{"node_id": "n1", "node_type": "start"}])
repo.bulk_create(wf["id"], 3, [{"node_id": "n1", "node_type": "start"}])
repo.delete_other_versions(wf["id"], 2)
assert repo.find_by_version(wf["id"], 1) == []
assert len(repo.find_by_version(wf["id"], 2)) == 1
assert repo.find_by_version(wf["id"], 3) == []

@@ -0,0 +1,104 @@
"""Tests for WorkflowRunsRepository against a real Postgres instance."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.core.settings import settings

pytestmark = pytest.mark.skipif(
    not settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)
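# Run rows reference a parent workflow, so each test seeds one via _wf.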
def _wf(conn) -> dict:
return WorkflowsRepository(conn).create("user-1", "test wf")
def _repo(conn) -> WorkflowRunsRepository:
return WorkflowRunsRepository(conn)
class TestCreate:
def test_creates_run(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
run = repo.create(wf["id"], "user-1", "completed")
assert run["status"] == "completed"
assert run["user_id"] == "user-1"
assert run["id"] is not None
def test_create_with_details(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
now = datetime.now(timezone.utc)
run = repo.create(
wf["id"], "user-1", "completed",
inputs={"query": "hello"},
result={"output": "world"},
steps=[
{"node_id": "n1", "status": "completed"},
{"node_id": "n2", "status": "completed"},
],
ended_at=now,
)
assert run["inputs"] == {"query": "hello"}
assert run["result"] == {"output": "world"}
assert len(run["steps"]) == 2
assert run["ended_at"] is not None
def test_create_with_started_at_and_legacy_id(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
now = datetime.now(timezone.utc)
run = repo.create(
wf["id"],
"user-1",
"completed",
started_at=now,
legacy_mongo_id="507f1f77bcf86cd799439011",
)
assert run["started_at"] == now
assert run["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
class TestGet:
def test_get_existing(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
created = repo.create(wf["id"], "user-1", "completed")
fetched = repo.get(created["id"])
assert fetched["id"] == created["id"]
def test_get_nonexistent(self, pg_conn):
repo = _repo(pg_conn)
assert repo.get("00000000-0000-0000-0000-000000000000") is None
def test_get_by_legacy_id(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
created = repo.create(
wf["id"], "user-1", "completed",
legacy_mongo_id="507f1f77bcf86cd799439011",
)
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011")
assert fetched["id"] == created["id"]
class TestListForWorkflow:
def test_lists_runs(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
repo.create(wf["id"], "user-1", "completed")
repo.create(wf["id"], "user-1", "failed")
runs = repo.list_for_workflow(wf["id"])
assert len(runs) == 2
def test_empty_list(self, pg_conn):
wf = _wf(pg_conn)
repo = _repo(pg_conn)
assert repo.list_for_workflow(wf["id"]) == []

@@ -0,0 +1,117 @@
"""Tests for WorkflowsRepository against a real Postgres instance."""
from __future__ import annotations
import pytest
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.core.settings import settings

pytestmark = pytest.mark.skipif(
    not settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)
def _repo(conn) -> WorkflowsRepository:
return WorkflowsRepository(conn)
class TestCreate:
def test_creates_workflow(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("user-1", "My Workflow")
assert doc["user_id"] == "user-1"
assert doc["name"] == "My Workflow"
assert doc["current_graph_version"] == 1
assert doc["id"] is not None
assert doc["_id"] == doc["id"]
def test_create_with_description(self, pg_conn):
repo = _repo(pg_conn)
doc = repo.create("user-1", "wf", description="A test workflow")
assert doc["description"] == "A test workflow"
class TestGet:
def test_get_existing(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
fetched = repo.get(created["id"], "user-1")
assert fetched["id"] == created["id"]
def test_get_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
assert repo.get(created["id"], "user-other") is None
def test_get_by_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
fetched = repo.get_by_id(created["id"])
assert fetched["id"] == created["id"]
def test_get_by_legacy_id(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create(
"user-1", "wf", legacy_mongo_id="507f1f77bcf86cd799439011",
)
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
assert fetched["id"] == created["id"]
class TestListForUser:
def test_lists_own(self, pg_conn):
repo = _repo(pg_conn)
repo.create("alice", "wf1")
repo.create("alice", "wf2")
repo.create("bob", "wf3")
results = repo.list_for_user("alice")
assert len(results) == 2
class TestUpdate:
def test_updates_name(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "old")
assert repo.update(created["id"], "user-1", {"name": "new"}) is True
fetched = repo.get(created["id"], "user-1")
assert fetched["name"] == "new"
def test_update_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "old")
assert repo.update(created["id"], "other", {"name": "new"}) is False
def test_update_disallowed_field(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
assert repo.update(created["id"], "user-1", {"id": "bad"}) is False
class TestIncrementGraphVersion:
def test_increments(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
assert created["current_graph_version"] == 1
new_ver = repo.increment_graph_version(created["id"], "user-1")
assert new_ver == 2
fetched = repo.get(created["id"], "user-1")
assert fetched["current_graph_version"] == 2
def test_increment_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
assert repo.increment_graph_version(created["id"], "other") is None
class TestDelete:
def test_deletes(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
assert repo.delete(created["id"], "user-1") is True
assert repo.get(created["id"], "user-1") is None
def test_delete_wrong_user(self, pg_conn):
repo = _repo(pg_conn)
created = repo.create("user-1", "wf")
assert repo.delete(created["id"], "other") is False

@@ -0,0 +1,414 @@
"""Integration coverage for dual-write rows surviving backfill reruns."""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock
import pytest
from bson import ObjectId
from bson.dbref import DBRef
from flask import Flask, request
from application.core.settings import settings
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.repositories.shared_conversations import SharedConversationsRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from scripts.db.backfill import (
_backfill_agents,
_backfill_attachments,
_backfill_conversations,
_backfill_prompts,
_backfill_shared_conversations,
_backfill_workflow_runs,
_backfill_workflow_nodes,
_backfill_workflows,
)
pytestmark = pytest.mark.skipif(
not settings.POSTGRES_URI,
reason="POSTGRES_URI not configured",
)
class _BoundEngine:
"""Expose one pre-opened SQLAlchemy connection as an Engine.begin()."""
def __init__(self, conn):
self._conn = conn
@contextmanager
def begin(self):
yield self._conn
@pytest.fixture
def app():
return Flask(__name__)
@pytest.fixture
def mongo_db(mock_mongo_db):
return mock_mongo_db[settings.MONGO_DB_NAME]
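# Force dual-write mode and point the engine factory at the test's pre-opened
# Postgres connection, so mirrored writes land where the assertions can see them.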
@pytest.fixture
def dual_write_pg(monkeypatch, pg_conn):
monkeypatch.setattr(settings, "USE_POSTGRES", True)
monkeypatch.setattr(
"application.storage.db.engine.get_engine",
lambda: _BoundEngine(pg_conn),
)
return pg_conn
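# Flow shared by the rerun tests below: create a record (dual-writing to Mongo and
# Postgres), mutate the Mongo copy, rerun the backfill, then assert the Postgres row
# was updated in place rather than duplicated.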
def test_prompt_dual_write_row_survives_backfill_rerun(app, mongo_db, dual_write_pg, monkeypatch):
from application.api.user.prompts.routes import CreatePrompt
monkeypatch.setattr(
"application.api.user.prompts.routes.prompts_collection",
mongo_db["prompts"],
)
with app.test_request_context(
"/api/create_prompt",
method="POST",
json={"name": "Greeting", "content": "Hello"},
):
request.decoded_token = {"sub": "user-1"}
response = CreatePrompt().post()
assert response.status_code == 200
mongo_id = response.json["id"]
repo = PromptsRepository(dual_write_pg)
prompt = repo.get_by_legacy_id(mongo_id, "user-1")
assert prompt is not None
assert prompt["content"] == "Hello"
mongo_db["prompts"].update_one(
{"_id": ObjectId(mongo_id)},
{"$set": {"content": "Hello again"}},
)
_backfill_prompts(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)
prompts = repo.list_for_user("user-1")
assert len(prompts) == 1
assert prompts[0]["legacy_mongo_id"] == mongo_id
assert prompts[0]["content"] == "Hello again"
def test_agent_dual_write_row_survives_backfill_rerun(app, mongo_db, dual_write_pg, monkeypatch):
from application.api.user.agents.routes import CreateAgent
monkeypatch.setattr(
"application.api.user.agents.routes.agents_collection",
mongo_db["agents"],
)
monkeypatch.setattr(
"application.api.user.agents.routes.handle_image_upload",
lambda *_args, **_kwargs: ("", None),
)
with app.test_request_context(
"/api/create_agent",
method="POST",
json={"name": "Mirror Agent", "status": "draft"},
):
request.decoded_token = {"sub": "user-1"}
response = CreateAgent().post()
assert response.status_code == 201
mongo_id = response.json["id"]
repo = AgentsRepository(dual_write_pg)
agent = repo.get_by_legacy_id(mongo_id, "user-1")
assert agent is not None
assert agent["name"] == "Mirror Agent"
mongo_db["agents"].update_one(
{"_id": ObjectId(mongo_id)},
{"$set": {"name": "Renamed Agent", "description": "Updated by backfill"}},
)
_backfill_agents(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)
agents = repo.list_for_user("user-1")
assert len(agents) == 1
assert agents[0]["legacy_mongo_id"] == mongo_id
assert agents[0]["name"] == "Renamed Agent"
assert agents[0]["description"] == "Updated by backfill"
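# This case seeds the Mongo document directly and mirrors it through the dual_write
# helper instead of going through an API route.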
def test_attachment_dual_write_row_survives_backfill_rerun(mongo_db, dual_write_pg):
mongo_id = ObjectId()
mongo_db["attachments"].insert_one(
{
"_id": mongo_id,
"user": "user-1",
"filename": "notes.txt",
"upload_path": "/uploads/notes.txt",
"mime_type": "text/plain",
"size": 12,
}
)
dual_write(
AttachmentsRepository,
lambda repo: repo.create(
"user-1",
"notes.txt",
"/uploads/notes.txt",
mime_type="text/plain",
size=12,
legacy_mongo_id=str(mongo_id),
),
)
repo = AttachmentsRepository(dual_write_pg)
attachment = repo.get_by_legacy_id(str(mongo_id), "user-1")
assert attachment is not None
assert attachment["filename"] == "notes.txt"
mongo_db["attachments"].update_one(
{"_id": mongo_id},
{"$set": {"filename": "notes-v2.txt", "size": 24}},
)
_backfill_attachments(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)
attachments = repo.list_for_user("user-1")
assert len(attachments) == 1
assert attachments[0]["legacy_mongo_id"] == str(mongo_id)
assert attachments[0]["filename"] == "notes-v2.txt"
assert attachments[0]["size"] == 24
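# Nodes and edges are written to Mongo via the route helpers and mirrored with
# _dual_write_workflow_create; after both backfills rerun, the renamed node title
# must propagate without duplicating rows.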
def test_workflow_nodes_dual_write_rows_survive_backfill_rerun(mongo_db, dual_write_pg, monkeypatch):
from application.api.user.workflows.routes import (
_dual_write_workflow_create,
create_workflow_edges,
create_workflow_nodes,
)
monkeypatch.setattr(
"application.api.user.workflows.routes.workflows_collection",
mongo_db["workflows"],
)
monkeypatch.setattr(
"application.api.user.workflows.routes.workflow_nodes_collection",
mongo_db["workflow_nodes"],
)
monkeypatch.setattr(
"application.api.user.workflows.routes.workflow_edges_collection",
mongo_db["workflow_edges"],
)
workflow_doc = {
"name": "Workflow",
"description": "test",
"user": "user-1",
"current_graph_version": 1,
}
insert_result = mongo_db["workflows"].insert_one(workflow_doc)
workflow_id = str(insert_result.inserted_id)
nodes_data = [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
]
edges_data = [{"id": "edge-1", "source": "start", "target": "end"}]
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
_dual_write_workflow_create(
workflow_id,
"user-1",
"Workflow",
"test",
created_nodes,
edges_data,
)
workflow_repo = WorkflowsRepository(dual_write_pg)
workflow = workflow_repo.list_for_user("user-1")[0]
node_repo = WorkflowNodesRepository(dual_write_pg)
pg_nodes = node_repo.find_by_version(workflow["id"], 1)
assert len(pg_nodes) == 2
assert all(node["legacy_mongo_id"] for node in pg_nodes)
renamed_node_id = created_nodes[0]["legacy_mongo_id"]
mongo_db["workflow_nodes"].update_one(
{"_id": ObjectId(renamed_node_id)},
{"$set": {"title": "Renamed Start"}},
)
_backfill_workflows(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)
_backfill_workflow_nodes(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)
pg_nodes = node_repo.find_by_version(workflow["id"], 1)
assert len(pg_nodes) == 2
renamed_node = node_repo.get_by_legacy_id(renamed_node_id)
assert renamed_node is not None
assert renamed_node["title"] == "Renamed Start"
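# append_compression_message should mirror the compression summary into Postgres as
# a message row on the dual-written conversation.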
def test_compression_summary_dual_write_appends_pg_message(mongo_db, dual_write_pg):
from application.api.answer.services.conversation_service import ConversationService
mongo_conv_id = ObjectId()
mongo_db["conversations"].insert_one(
{"_id": mongo_conv_id, "user": "user-1", "queries": []},
)
conv = ConversationsRepository(dual_write_pg).create(
"user-1", "Mirror", legacy_mongo_id=str(mongo_conv_id),
)
service = ConversationService()
metadata = {
"compressed_summary": "Compressed context summary",
"timestamp": "2026-04-13T12:00:00+00:00",
"model_used": "gpt-4",
}
service.append_compression_message(str(mongo_conv_id), metadata)
pg_messages = ConversationsRepository(dual_write_pg).get_messages(conv["id"])
assert len(pg_messages) == 1
assert pg_messages[0]["prompt"] == "[Context Compression Summary]"
assert pg_messages[0]["response"] == "Compressed context summary"
assert pg_messages[0]["model_id"] == "gpt-4"
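# _save_workflow_run is driven with a mocked execution engine so the run record is
# dual-written; the Mongo copy is then flipped to "failed" and the backfill rerun
# must overwrite the Postgres row rather than add a second one.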
def test_workflow_run_dual_write_row_survives_backfill_rerun(mongo_db, dual_write_pg):
from application.agents.workflow_agent import WorkflowAgent
mongo_workflow_id = ObjectId()
mongo_db["workflows"].insert_one(
{
"_id": mongo_workflow_id,
"user": "user-1",
"name": "Workflow",
"description": "test",
}
)
workflow = WorkflowsRepository(dual_write_pg).create(
"user-1", "Workflow", description="test",
legacy_mongo_id=str(mongo_workflow_id),
)
agent = WorkflowAgent(
endpoint="https://api.example.com",
llm_name="openai",
model_id="gpt-4",
api_key="test_key",
user_api_key=None,
prompt="You are helpful.",
chat_history=[],
decoded_token={"sub": "user-1"},
attachments=[],
json_schema=None,
workflow_id=str(mongo_workflow_id),
workflow_owner="user-1",
)
agent._engine = MagicMock()
agent._engine.state = {"answer": "ok"}
agent._engine.execution_log = []
agent._engine.get_execution_summary.return_value = []
agent._save_workflow_run("hello")
run_repo = WorkflowRunsRepository(dual_write_pg)
runs = run_repo.list_for_workflow(workflow["id"])
assert len(runs) == 1
assert runs[0]["user_id"] == "user-1"
legacy_mongo_id = runs[0]["legacy_mongo_id"]
mongo_db["workflow_runs"].update_one(
{"_id": ObjectId(legacy_mongo_id)},
{"$set": {"status": "failed", "user": "user-1", "user_id": "user-1"}},
)
_backfill_workflow_runs(
conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False,
)
runs = run_repo.list_for_workflow(workflow["id"])
assert len(runs) == 1
assert runs[0]["status"] == "failed"
assert runs[0]["legacy_mongo_id"] == legacy_mongo_id
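# The shared-conversation backfill has to dereference the conversation DBRef and
# recover chunks/prompt_id from the agent matched by api_key.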
def test_shared_conversation_backfill_recovers_dbref_and_agent_prompt_metadata(
mongo_db, dual_write_pg,
):
conv = ConversationsRepository(dual_write_pg).create(
"user-1", "Conversation", legacy_mongo_id="507f1f77bcf86cd799439011",
)
PromptsRepository(dual_write_pg).create(
"user-1", "Prompt", "Body",
legacy_mongo_id="507f1f77bcf86cd799439012",
)
mongo_db["agents"].insert_one(
{
"_id": ObjectId(),
"key": "share-key",
"prompt_id": ObjectId("507f1f77bcf86cd799439012"),
"chunks": "7",
"user": "user-1",
}
)
mongo_db["shared_conversations"].insert_one(
{
"_id": ObjectId(),
"uuid": "00000000-0000-0000-0000-000000000001",
"conversation_id": DBRef(
"conversations", ObjectId("507f1f77bcf86cd799439011"),
),
"user": "user-1",
"isPromptable": True,
"first_n_queries": 2,
"api_key": "share-key",
}
)
_backfill_shared_conversations(
conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False,
)
shares = SharedConversationsRepository(dual_write_pg).list_for_conversation(conv["id"])
assert len(shares) == 1
assert shares[0]["api_key"] == "share-key"
assert shares[0]["chunks"] == 7
assert shares[0]["prompt_id"] is not None
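# Attachment ObjectIds with no Postgres counterpart should be counted in the backfill
# stats and dropped from the migrated message.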
def test_conversation_backfill_reports_unresolved_attachment_refs(mongo_db, dual_write_pg):
mongo_db["conversations"].insert_one(
{
"_id": ObjectId("507f1f77bcf86cd799439021"),
"user": "user-1",
"name": "Conversation",
"queries": [
{
"prompt": "q1",
"response": "a1",
"attachments": [str(ObjectId("507f1f77bcf86cd799439022"))],
}
],
}
)
stats = _backfill_conversations(
conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False,
)
assert stats["unresolved_attachment_refs"] == 1
conv = ConversationsRepository(dual_write_pg).get_by_legacy_id(
"507f1f77bcf86cd799439021",
)
messages = ConversationsRepository(dual_write_pg).get_messages(conv["id"])
assert messages[0]["attachments"] == []