Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-07 06:30:03 +00:00)

Compare commits: 20 commits, `pg-2` ... `codex/add-`
| SHA1 |
|---|
| aa938d76d7 |
| 2940628aa6 |
| 7f23928134 |
| 20e17c84c7 |
| 389ddf6068 |
| 1e2443fb90 |
| 6387bd1892 |
| 7d22724d1c |
| f6f12f6895 |
| 934127f323 |
| 1780e3cc91 |
| 5e7fab2f34 |
| 92ae76f95e |
| 18755bdd9b |
| 0f20adcbf4 |
| 18e2a829c9 |
| cd44501a71 |
| f8ebdf3fd4 |
| 7c6fca18ad |
| 5fab798707 |
.github/INCIDENT_RESPONSE.md (vendored, normal file, 99 lines)
@@ -0,0 +1,99 @@
# DocsGPT Incident Response Plan (IRP)

This playbook describes how maintainers respond to confirmed or suspected security incidents.

- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)

## Severity

| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | Key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |

## Response workflow

### 1) Triage (target: initial response within 48 hours)

1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.

### 2) Investigation

1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.

### 3) Containment, fix, and disclosure

1. Implement and test fix in private security workflow (GHSA private fork/branch).
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.

### 4) Post-incident

1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.

## Scenario playbooks

### Supply-chain compromise

1. Freeze releases and investigate blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish advisory with exact affected versions and required user actions.

### Data exposure

1. Determine scope (users, documents, keys, logs, time window).
2. Disable affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through standard fix/disclosure workflow.

### Critical regression with security impact

1. Identify introducing change (`git bisect` if needed).
2. Publish workaround within 24 hours (for example, pin to known-good version).
3. Ship patch release with regression test and close incident with public summary.

## AI-specific guidance

Treat confirmed AI-specific abuse as security incidents:

- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**

## Secret rotation quick reference

| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |

## Maintenance

Review this document:

- after every **Critical/High** incident, and
- at least annually.

Changes should be proposed via pull request to `main`.
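As a purely illustrative aid for the secret rotation quick reference above: a minimal sketch for generating replacement values for the three application-level secrets, assuming they are opaque random strings supplied via environment variables (how they are stored and redeployed is deployment-specific).

```python
# Sketch only: print fresh values for the env-style application secrets named
# in the rotation table. Length/format assumptions may differ per deployment.
import secrets

for name in ("JWT_SECRET_KEY", "ENCRYPTION_SECRET_KEY", "INTERNAL_KEY"):
    print(f"{name}={secrets.token_hex(32)}")
```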
.github/workflows/zizmor.yml (vendored, normal file, 25 lines)
@@ -0,0 +1,25 @@
name: GitHub Actions Security Analysis

on:
  push:
    branches: ["master"]
  pull_request:
    branches: ["**"]

permissions: {}

jobs:
  zizmor:
    runs-on: ubuntu-latest

    permissions:
      security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.

    steps:
      - name: Checkout repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - name: Run zizmor 🌈
        uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
@@ -18,5 +18,5 @@ We aim to acknowledge reports within 48 hours.
## Incident Handling

Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
@@ -15,6 +15,9 @@ from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflows import WorkflowsRepository

logger = logging.getLogger(__name__)
@@ -181,6 +184,9 @@ class WorkflowAgent(BaseAgent):
    def _save_workflow_run(self, query: str) -> None:
        if not self._engine:
            return
        owner_id = self.workflow_owner
        if not owner_id and isinstance(self.decoded_token, dict):
            owner_id = self.decoded_token.get("sub")
        try:
            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
@@ -188,6 +194,7 @@ class WorkflowAgent(BaseAgent):

            run = WorkflowRun(
                workflow_id=self.workflow_id or "unknown",
                user=owner_id,
                status=self._determine_run_status(),
                inputs={"query": query},
                outputs=self._serialize_state(self._engine.state),
@@ -196,7 +203,34 @@ class WorkflowAgent(BaseAgent):
                completed_at=datetime.now(timezone.utc),
            )

            workflow_runs_coll.insert_one(run.to_mongo_doc())
            result = workflow_runs_coll.insert_one(run.to_mongo_doc())
            legacy_mongo_id = (
                str(result.inserted_id)
                if getattr(result, "inserted_id", None) is not None
                else None
            )

            def _pg_write(repo: WorkflowRunsRepository) -> None:
                if not self.workflow_id or not owner_id or not legacy_mongo_id:
                    return
                workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
                    self.workflow_id, owner_id,
                )
                if workflow is None:
                    return
                repo.create(
                    workflow["id"],
                    owner_id,
                    run.status.value,
                    inputs=run.inputs,
                    result=run.outputs,
                    steps=[step.model_dump(mode="json") for step in run.steps],
                    started_at=run.created_at,
                    ended_at=run.completed_at,
                    legacy_mongo_id=legacy_mongo_id,
                )

            dual_write(WorkflowRunsRepository, _pg_write)
        except Exception as e:
            logger.error(f"Failed to save workflow run: {e}")
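The `dual_write` helper itself is not part of this diff. A minimal sketch of its presumed contract, inferred from the call sites above and from the `POSTGRES_URI` guard on the Celery cleanup task later in this changeset; the real `application/storage/db/dual_write.py` may differ.

```python
# Hypothetical sketch of dual_write, inferred from how it is called in this
# changeset; not the actual implementation shipped in dual_write.py.
import logging
from typing import Callable, Type, TypeVar

logger = logging.getLogger(__name__)
R = TypeVar("R")


def dual_write(repo_cls: Type[R], fn: Callable[[R], None]) -> None:
    """Run ``fn`` against ``repo_cls`` bound to one Postgres transaction.

    Assumed behaviour: no-op when POSTGRES_URI is unset, and never let a
    Postgres failure break the primary Mongo write path.
    """
    from application.core.settings import settings
    from application.storage.db.engine import get_engine

    if not settings.POSTGRES_URI:
        return
    try:
        with get_engine().begin() as conn:  # one transaction per dual-write call
            fn(repo_cls(conn))
    except Exception:
        logger.exception("Postgres dual-write failed; Mongo remains authoritative")
```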
@@ -211,6 +211,7 @@ class WorkflowRun(BaseModel):
    model_config = ConfigDict(extra="allow")
    id: Optional[str] = Field(None, alias="_id")
    workflow_id: str
    user: Optional[str] = None
    status: ExecutionStatus = ExecutionStatus.PENDING
    inputs: Dict[str, str] = Field(default_factory=dict)
    outputs: Dict[str, Any] = Field(default_factory=dict)
@@ -226,7 +227,7 @@ class WorkflowRun(BaseModel):
        return v

    def to_mongo_doc(self) -> Dict[str, Any]:
        return {
        doc = {
            "workflow_id": self.workflow_id,
            "status": self.status.value,
            "inputs": self.inputs,
@@ -235,3 +236,7 @@ class WorkflowRun(BaseModel):
            "created_at": self.created_at,
            "completed_at": self.completed_at,
        }
        if self.user:
            doc["user"] = self.user
            doc["user_id"] = self.user
        return doc
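For illustration only (field values are hypothetical): with the change above, `to_mongo_doc` now emits both the legacy `user` key and the new `user_id` key whenever a user is set, and neither key when it is not.

```python
# Hypothetical values, just to show the dual user/user_id keys added above.
run = WorkflowRun(workflow_id="wf-123", user="auth0|abc", inputs={"query": "hi"})
doc = run.to_mongo_doc()
assert doc["user"] == "auth0|abc"
assert doc["user_id"] == "auth0|abc"
assert "user" not in WorkflowRun(workflow_id="wf-123").to_mongo_doc()
```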
File diff suppressed because it is too large
@@ -1,57 +0,0 @@
"""0002 add unique constraints for notes and connector_sessions.

The memories table already has ``memories_user_tool_path_uidx`` from the
0001 baseline. Notes and connector_sessions were missing unique constraints
that their repository upsert logic depends on.

Before creating the indexes, duplicate rows are cleaned up — keeping only
the row with the latest ``id`` (UUID, lexicographic max) per group.

Revision ID: 0002_add_unique_constraints
Revises: 0001_initial
Create Date: 2026-04-12
"""

from typing import Sequence, Union

from alembic import op

revision: str = "0002_add_unique_constraints"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Deduplicate notes: keep one row per (user_id, tool_id)
    op.execute("""
        DELETE FROM notes
        WHERE id NOT IN (
            SELECT DISTINCT ON (user_id, tool_id) id
            FROM notes
            ORDER BY user_id, tool_id, created_at DESC
        );
    """)
    op.execute(
        "CREATE UNIQUE INDEX IF NOT EXISTS notes_user_tool_uidx "
        "ON notes (user_id, tool_id);"
    )

    # Deduplicate connector_sessions: keep one row per (user_id, provider)
    op.execute("""
        DELETE FROM connector_sessions
        WHERE id NOT IN (
            SELECT DISTINCT ON (user_id, provider) id
            FROM connector_sessions
            ORDER BY user_id, provider, created_at DESC
        );
    """)
    op.execute(
        "CREATE UNIQUE INDEX IF NOT EXISTS connector_sessions_user_provider_uidx "
        "ON connector_sessions (user_id, provider);"
    )


def downgrade() -> None:
    op.execute("DROP INDEX IF EXISTS connector_sessions_user_provider_uidx;")
    op.execute("DROP INDEX IF EXISTS notes_user_tool_uidx;")
@@ -13,6 +13,11 @@ from bson import ObjectId

from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
    PendingToolStateRepository,
)

logger = logging.getLogger(__name__)
@@ -107,6 +112,26 @@ class ContinuationService:
            f"Saved continuation state for conversation {conversation_id} "
            f"with {len(pending_tool_calls)} pending tool call(s)"
        )

        # Dual-write to Postgres — upsert against the same Mongo conversation
        # by resolving its UUID via conversations.legacy_mongo_id.
        def _pg_save(_: PendingToolStateRepository) -> None:
            conn = _._conn  # reuse the existing transaction
            conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
            if conv is None:
                return
            _.save_state(
                conv["id"],
                user,
                messages=_make_serializable(messages),
                pending_tool_calls=_make_serializable(pending_tool_calls),
                tools_dict=_make_serializable(tools_dict),
                tool_schemas=_make_serializable(tool_schemas),
                agent_config=_make_serializable(agent_config),
                client_tools=_make_serializable(client_tools) if client_tools else None,
            )

        dual_write(PendingToolStateRepository, _pg_save)
        return state_id

    def load_state(
@@ -138,4 +163,13 @@ class ContinuationService:
        logger.info(
            f"Deleted continuation state for conversation {conversation_id}"
        )

        # Dual-write to Postgres — delete the same row.
        def _pg_delete(repo: PendingToolStateRepository) -> None:
            conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
            if conv is None:
                return
            repo.delete_state(conv["id"], user)

        dual_write(PendingToolStateRepository, _pg_delete)
        return result.deleted_count > 0
@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@@ -113,6 +115,26 @@ class ConversationService:
|
||||
},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
# Dual-write to Postgres: update the message at :index and
|
||||
# truncate anything after it, mirroring Mongo's $set+$slice.
|
||||
def _pg_update_at_index(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.update_message_at(conv["id"], index, {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
})
|
||||
repo.truncate_after(conv["id"], index)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_update_at_index)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
# Append new message to existing conversation
|
||||
@@ -138,6 +160,25 @@ class ConversationService:
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Dual-write to Postgres: append the same message.
|
||||
def _pg_append(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append)
|
||||
return conversation_id
|
||||
else:
|
||||
# Create new conversation
|
||||
@@ -193,7 +234,34 @@ class ConversationService:
|
||||
if agent:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
inserted_id = str(result.inserted_id)
|
||||
|
||||
# Dual-write to Postgres: create the conversation row with
|
||||
# legacy_mongo_id and append the first message.
|
||||
def _pg_create(repo: ConversationsRepository) -> None:
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
completion,
|
||||
agent_id=conversation_data.get("agent_id"),
|
||||
api_key=conversation_data.get("api_key"),
|
||||
is_shared_usage=conversation_data.get("is_shared_usage", False),
|
||||
shared_token=conversation_data.get("shared_token"),
|
||||
legacy_mongo_id=inserted_id,
|
||||
)
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_create)
|
||||
return inserted_id
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
@@ -230,6 +298,24 @@ class ConversationService:
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres: mirror $set + $push $slice.
|
||||
def _pg_compression(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.set_compression_flags(
|
||||
conv["id"],
|
||||
is_compressed=True,
|
||||
last_compression_at=compression_metadata.get("timestamp"),
|
||||
)
|
||||
repo.append_compression_point(
|
||||
conv["id"],
|
||||
compression_metadata,
|
||||
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_compression)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
@@ -266,6 +352,23 @@ class ConversationService:
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _pg_append_summary(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
"timestamp": timestamp,
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append_summary)
|
||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -114,6 +114,35 @@ AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
|
||||
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
|
||||
|
||||
|
||||
def _build_pg_agent_fields(fields: dict) -> dict:
|
||||
"""Translate Mongo-shaped agent fields into the Postgres mirror subset."""
|
||||
allowed = {
|
||||
"name",
|
||||
"description",
|
||||
"agent_type",
|
||||
"status",
|
||||
"key",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"tools",
|
||||
"json_schema",
|
||||
"models",
|
||||
"default_model_id",
|
||||
"limited_token_mode",
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"incoming_webhook_token",
|
||||
"lastUsedAt",
|
||||
}
|
||||
translated: dict = {}
|
||||
for key, value in fields.items():
|
||||
if key not in allowed:
|
||||
continue
|
||||
translated["last_used_at" if key == "lastUsedAt" else key] = value
|
||||
return translated
|
||||
|
||||
|
||||
def normalize_workflow_reference(workflow_value):
|
||||
"""Normalize workflow references from form/json payloads."""
|
||||
if workflow_value is None:
|
||||
@@ -626,13 +655,14 @@ class CreateAgent(Resource):
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
AgentsRepository,
|
||||
lambda repo, u=user, a=new_agent: repo.create(
|
||||
lambda repo, u=user, a=new_agent, mid=new_id: repo.create(
|
||||
u, a.get("name", ""), a.get("status", "draft"),
|
||||
key=a.get("key"), description=a.get("description"),
|
||||
retriever=a.get("retriever"), chunks=a.get("chunks"),
|
||||
tools=a.get("tools"), models=a.get("models"),
|
||||
shared=a.get("shared", False),
|
||||
incoming_webhook_token=a.get("incoming_webhook_token"),
|
||||
legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -1170,6 +1200,14 @@ class UpdateAgent(Resource):
|
||||
jsonify({"success": False, "message": "Database error during update"}),
|
||||
500,
|
||||
)
|
||||
pg_update_fields = _build_pg_agent_fields(update_fields)
|
||||
if pg_update_fields:
|
||||
dual_write(
|
||||
AgentsRepository,
|
||||
lambda repo, aid=agent_id, u=user, fields=pg_update_fields: repo.update_by_legacy_id(
|
||||
aid, u, fields,
|
||||
),
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": agent_id,
|
||||
@@ -1199,7 +1237,7 @@ class DeleteAgent(Resource):
|
||||
)
|
||||
dual_write(
|
||||
AgentsRepository,
|
||||
lambda repo, aid=agent_id, u=user: repo.delete(aid, u),
|
||||
lambda repo, aid=agent_id, u=user: repo.delete_by_legacy_id(aid, u),
|
||||
)
|
||||
if not deleted_agent:
|
||||
return make_response(
|
||||
|
||||
@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
@@ -30,15 +32,23 @@ class DeleteConversation(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
user_id = decoded_token["sub"]
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
{"_id": ObjectId(conversation_id), "user": user_id}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_delete(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
repo.delete(conv["id"], user_id)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_delete)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -59,6 +69,11 @@ class DeleteAllConversations(Resource):
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
dual_write(
|
||||
ConversationsRepository,
|
||||
lambda r, uid=user_id: r.delete_all_for_user(uid),
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -190,9 +205,10 @@ class UpdateConversationName(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"_id": ObjectId(data["id"]), "user": user_id},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -200,6 +216,13 @@ class UpdateConversationName(Resource):
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_rename(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["id"])
|
||||
if conv is not None:
|
||||
repo.rename(conv["id"], user_id, data["name"])
|
||||
|
||||
dual_write(ConversationsRepository, _pg_rename)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -277,4 +300,21 @@ class SubmitFeedback(Resource):
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
# Dual-write to Postgres: mirror the per-message feedback set/unset.
|
||||
feedback_value = data["feedback"]
|
||||
question_index = int(data["question_index"])
|
||||
feedback_payload = (
|
||||
None if feedback_value is None
|
||||
else {"text": feedback_value, "timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat()}
|
||||
)
|
||||
|
||||
def _pg_feedback(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["conversation_id"])
|
||||
if conv is not None:
|
||||
repo.set_feedback(conv["id"], question_index, feedback_payload)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_feedback)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@@ -53,7 +53,9 @@ class CreatePrompt(Resource):
            new_id = str(resp.inserted_id)
            dual_write(
                PromptsRepository,
                lambda repo, u=user, n=data["name"], c=data["content"]: repo.create(u, n, c),
                lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
                    u, n, c, legacy_mongo_id=mid,
                ),
            )
        except Exception as err:
            current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
@@ -157,7 +159,7 @@ class DeletePrompt(Resource):
            prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
            dual_write(
                PromptsRepository,
                lambda repo, pid=data["id"], u=user: repo.delete(pid, u),
                lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
            )
        except Exception as err:
            current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
@@ -197,7 +199,9 @@ class UpdatePrompt(Resource):
            )
            dual_write(
                PromptsRepository,
                lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update(pid, u, n, c),
                lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
                    pid, u, n, c,
                ),
            )
        except Exception as err:
            current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
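The `*_by_legacy_id` repository methods introduced here are not shown in this diff. A hedged sketch of the shape they presumably take, mirroring the `legacy_mongo_id` lookup pattern used elsewhere in this changeset (the real `PromptsRepository` may differ):

```python
# Hypothetical sketch of PromptsRepository.delete_by_legacy_id(); illustrative
# only, keyed on the legacy Mongo ObjectId string plus the owning user.
from sqlalchemy import text

def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> int:
    result = self._conn.execute(
        text(
            "DELETE FROM prompts "
            "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
        ),
        {"legacy_id": legacy_mongo_id, "user_id": user_id},
    )
    return result.rowcount or 0
```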
@@ -15,8 +15,71 @@ from application.api.user.base import (
|
||||
conversations_collection,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.shared_conversations import (
|
||||
SharedConversationsRepository,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
|
||||
def _dual_write_share(
|
||||
mongo_conv_id: str,
|
||||
share_uuid: str,
|
||||
user: str,
|
||||
*,
|
||||
is_promptable: bool,
|
||||
first_n_queries: int,
|
||||
api_key: str | None,
|
||||
prompt_id: str | None = None,
|
||||
chunks: int | None = None,
|
||||
) -> None:
|
||||
"""Mirror a Mongo share-record insert into Postgres.
|
||||
|
||||
Preserves the Mongo-generated UUID so public ``/shared/{uuid}`` URLs
|
||||
resolve from both stores during cutover.
|
||||
"""
|
||||
def _write(repo: SharedConversationsRepository) -> None:
|
||||
conv = ConversationsRepository(repo._conn).get_by_legacy_id(
|
||||
mongo_conv_id, user_id=user,
|
||||
)
|
||||
if conv is None:
|
||||
return
|
||||
# prompt_id / chunks are only meaningful for promptable shares;
|
||||
# prompt_id is often the string "default" or an ObjectId that
|
||||
# hasn't been migrated — pass as-is and let the repo drop
|
||||
# non-UUID values. Scope the prompt lookup by user_id so an
|
||||
# authenticated caller can't link another user's prompt into
|
||||
# their share record.
|
||||
resolved_prompt_id = None
|
||||
if prompt_id and len(str(prompt_id)) == 24:
|
||||
from sqlalchemy import text as _text
|
||||
row = repo._conn.execute(
|
||||
_text(
|
||||
"SELECT id FROM prompts "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": str(prompt_id), "user_id": user},
|
||||
).fetchone()
|
||||
if row:
|
||||
resolved_prompt_id = str(row[0])
|
||||
# get_or_create is race-free on the PG side thanks to the
|
||||
# composite partial unique index on the dedup tuple
|
||||
# (migration 0008). It converges concurrent share requests to
|
||||
# a single row.
|
||||
repo.get_or_create(
|
||||
conv["id"],
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=first_n_queries,
|
||||
api_key=api_key,
|
||||
prompt_id=resolved_prompt_id,
|
||||
chunks=chunks,
|
||||
share_uuid=share_uuid,
|
||||
)
|
||||
|
||||
dual_write(SharedConversationsRepository, _write)
|
||||
|
||||
sharing_ns = Namespace(
|
||||
"sharing", description="Conversation sharing operations", path="/api"
|
||||
)
|
||||
@@ -124,6 +187,16 @@ class ShareConversation(Resource):
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_id,
|
||||
chunks=int(chunks) if chunks else None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -155,6 +228,16 @@ class ShareConversation(Resource):
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_id,
|
||||
chunks=int(chunks) if chunks else None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -192,6 +275,14 @@ class ShareConversation(Resource):
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
|
||||
@@ -134,6 +134,12 @@ def setup_periodic_tasks(sender, **kwargs):
        timedelta(days=30),
        schedule_syncs.s("monthly"),
    )
    # Replaces Mongo's TTL index on pending_tool_state.expires_at.
    sender.add_periodic_task(
        timedelta(seconds=60),
        cleanup_pending_tool_state.s(),
        name="cleanup-pending-tool-state",
    )


@celery.task(bind=True)
@@ -146,3 +152,27 @@ def mcp_oauth_task(self, config, user):
def mcp_oauth_status_task(self, task_id):
    resp = mcp_oauth_status(self, task_id)
    return resp


@celery.task(bind=True)
def cleanup_pending_tool_state(self):
    """Delete pending_tool_state rows past their TTL.

    Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
    no native TTL, so this task runs every 60 seconds to keep
    ``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
    configured (keeps the task runnable in Mongo-only environments).
    """
    from application.core.settings import settings
    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.pending_tool_state import (
        PendingToolStateRepository,
    )

    engine = get_engine()
    with engine.begin() as conn:
        deleted = PendingToolStateRepository(conn).cleanup_expired()
    return {"deleted": deleted}
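`cleanup_expired()` itself is not shown in this diff. Given the `pending_tool_state` table definition later in this changeset (an `expires_at` timestamptz column), a plausible shape is a single bounded DELETE; a sketch, not the actual repository method:

```python
# Hypothetical sketch of PendingToolStateRepository.cleanup_expired(): remove
# every row whose TTL has passed and report how many rows were deleted.
from sqlalchemy import text

def cleanup_expired(self) -> int:
    result = self._conn.execute(
        text("DELETE FROM pending_tool_state WHERE expires_at <= now()")
    )
    return result.rowcount or 0
```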
@@ -11,6 +11,10 @@ from application.api.user.base import (
    workflow_nodes_collection,
    workflows_collection,
)
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,
@@ -35,6 +39,174 @@ def _workflow_error_response(message: str, err: Exception):
    return error_response(message)


# ---------------------------------------------------------------------------
# Postgres dual-write helpers
#
# Workflows are unusual relative to other Phase 3 tables: a single user
# action (create / update) writes to three collections in concert
# (workflows + workflow_nodes + workflow_edges) and the edges reference
# nodes by user-provided string ids. The Postgres mirror needs to:
#
# 1. Run all three writes inside one PG transaction (so the just-created
#    nodes are visible when we resolve their UUIDs for the edge insert).
# 2. Translate edge source_id/target_id strings → workflow_nodes.id UUIDs
#    after the bulk_create returns them.
#
# Each helper opens exactly one ``dual_write`` call (one PG txn) and uses
# the connection from whichever repo it was instantiated with to spin up
# any sibling repos it needs.
# ---------------------------------------------------------------------------

def _dual_write_workflow_create(
|
||||
mongo_workflow_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
graph_version: int = 1,
|
||||
) -> None:
|
||||
"""Mirror a Mongo workflow create into Postgres."""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
conn = repo._conn
|
||||
wf = repo.create(
|
||||
user_id,
|
||||
name,
|
||||
description=description,
|
||||
legacy_mongo_id=mongo_workflow_id,
|
||||
)
|
||||
_write_graph(conn, wf["id"], graph_version, nodes_data, edges_data)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _dual_write_workflow_update(
|
||||
mongo_workflow_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
next_graph_version: int,
|
||||
) -> None:
|
||||
"""Mirror a Mongo workflow update into Postgres.
|
||||
|
||||
Mirrors the Mongo route: insert the new graph_version's nodes/edges,
|
||||
bump the workflow's name/description/current_graph_version, then drop
|
||||
every other graph_version's nodes/edges.
|
||||
"""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
conn = repo._conn
|
||||
wf = _resolve_pg_workflow(conn, mongo_workflow_id)
|
||||
if wf is None:
|
||||
return
|
||||
_write_graph(conn, wf["id"], next_graph_version, nodes_data, edges_data)
|
||||
repo.update(wf["id"], user_id, {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"current_graph_version": next_graph_version,
|
||||
})
|
||||
WorkflowNodesRepository(conn).delete_other_versions(
|
||||
wf["id"], next_graph_version,
|
||||
)
|
||||
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||
wf["id"], next_graph_version,
|
||||
)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _dual_write_workflow_delete(mongo_workflow_id: str, user_id: str) -> None:
|
||||
"""Mirror a Mongo workflow delete into Postgres.
|
||||
|
||||
The CASCADE on workflows.id → workflow_nodes/workflow_edges takes
|
||||
care of the children automatically.
|
||||
"""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
wf = _resolve_pg_workflow(repo._conn, mongo_workflow_id)
|
||||
if wf is not None:
|
||||
repo.delete(wf["id"], user_id)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _resolve_pg_workflow(conn, mongo_workflow_id: str) -> Optional[Dict]:
|
||||
"""Look up a Postgres workflow by its Mongo ObjectId string."""
|
||||
from sqlalchemy import text as _text
|
||||
row = conn.execute(
|
||||
_text("SELECT id FROM workflows WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": mongo_workflow_id},
|
||||
).fetchone()
|
||||
return {"id": str(row[0])} if row else None
|
||||
|
||||
|
||||
def _write_graph(
|
||||
conn,
|
||||
pg_workflow_id: str,
|
||||
graph_version: int,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
) -> None:
|
||||
"""Bulk-create nodes + edges for one graph version inside one txn.
|
||||
|
||||
Edges arrive with source/target as user-provided node-id strings
|
||||
(the same shape the Mongo route stores). We bulk-insert nodes first,
|
||||
capture their ``node_id → UUID`` map from the returned rows, then
|
||||
translate edge source/target strings to those UUIDs before the edge
|
||||
bulk insert. Edges referencing missing nodes are dropped (logged).
|
||||
"""
|
||||
nodes_repo = WorkflowNodesRepository(conn)
|
||||
edges_repo = WorkflowEdgesRepository(conn)
|
||||
|
||||
if nodes_data:
|
||||
created_nodes = nodes_repo.bulk_create(
|
||||
pg_workflow_id, graph_version,
|
||||
[
|
||||
{
|
||||
"node_id": n["id"],
|
||||
"node_type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
"legacy_mongo_id": n.get("legacy_mongo_id"),
|
||||
}
|
||||
for n in nodes_data
|
||||
],
|
||||
)
|
||||
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
|
||||
else:
|
||||
node_uuid_by_str = {}
|
||||
|
||||
if edges_data:
|
||||
translated_edges: List[Dict] = []
|
||||
for e in edges_data:
|
||||
src = e.get("source")
|
||||
tgt = e.get("target")
|
||||
from_uuid = node_uuid_by_str.get(src)
|
||||
to_uuid = node_uuid_by_str.get(tgt)
|
||||
if not from_uuid or not to_uuid:
|
||||
current_app.logger.warning(
|
||||
"PG dual-write: dropping edge %s; node refs unresolved "
|
||||
"(source=%s, target=%s)",
|
||||
e.get("id"), src, tgt,
|
||||
)
|
||||
continue
|
||||
translated_edges.append({
|
||||
"edge_id": e["id"],
|
||||
"from_node_id": from_uuid,
|
||||
"to_node_id": to_uuid,
|
||||
"source_handle": e.get("sourceHandle"),
|
||||
"target_handle": e.get("targetHandle"),
|
||||
})
|
||||
if translated_edges:
|
||||
edges_repo.bulk_create(pg_workflow_id, graph_version, translated_edges)
|
||||
|
||||
|
||||
def serialize_workflow(w: Dict) -> Dict:
|
||||
"""Serialize workflow document to API response format."""
|
||||
return {
|
||||
@@ -317,24 +489,28 @@ def _can_reach_end(
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
"""Insert workflow nodes into database."""
|
||||
) -> List[Dict]:
|
||||
"""Insert workflow nodes into Mongo and return rows with Mongo ids."""
|
||||
if nodes_data:
|
||||
workflow_nodes_collection.insert_many(
|
||||
[
|
||||
{
|
||||
"id": n["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
]
|
||||
)
|
||||
mongo_nodes = [
|
||||
{
|
||||
"id": n["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
]
|
||||
result = workflow_nodes_collection.insert_many(mongo_nodes)
|
||||
return [
|
||||
{**node, "legacy_mongo_id": str(inserted_id)}
|
||||
for node, inserted_id in zip(nodes_data, result.inserted_ids)
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def create_workflow_edges(
|
||||
@@ -399,7 +575,7 @@ class WorkflowList(Resource):
|
||||
workflow_id = str(result.inserted_id)
|
||||
|
||||
try:
|
||||
create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||
create_workflow_edges(workflow_id, edges_data, 1)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
@@ -407,6 +583,15 @@ class WorkflowList(Resource):
|
||||
workflows_collection.delete_one({"_id": result.inserted_id})
|
||||
return _workflow_error_response("Failed to create workflow structure", err)
|
||||
|
||||
_dual_write_workflow_create(
|
||||
workflow_id,
|
||||
user_id,
|
||||
name,
|
||||
data.get("description", ""),
|
||||
created_nodes,
|
||||
edges_data,
|
||||
)
|
||||
|
||||
return success_response({"id": workflow_id}, 201)
|
||||
|
||||
|
||||
@@ -473,7 +658,9 @@ class WorkflowDetail(Resource):
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
try:
|
||||
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
|
||||
created_nodes = create_workflow_nodes(
|
||||
workflow_id, nodes_data, next_graph_version,
|
||||
)
|
||||
create_workflow_edges(workflow_id, edges_data, next_graph_version)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many(
|
||||
@@ -520,6 +707,16 @@ class WorkflowDetail(Resource):
|
||||
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
|
||||
)
|
||||
|
||||
_dual_write_workflow_update(
|
||||
workflow_id,
|
||||
user_id,
|
||||
name,
|
||||
data.get("description", ""),
|
||||
created_nodes,
|
||||
edges_data,
|
||||
next_graph_version,
|
||||
)
|
||||
|
||||
return success_response()
|
||||
|
||||
@require_auth
|
||||
@@ -543,4 +740,6 @@ class WorkflowDetail(Resource):
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to delete workflow", err)
|
||||
|
||||
_dual_write_workflow_delete(workflow_id, user_id)
|
||||
|
||||
return success_response()
|
||||
|
||||
@@ -5,6 +5,14 @@ MongoDB→Postgres migration. The baseline schema in the Alembic migration
(``application/alembic/versions/0001_initial.py``) is the source of truth
for DDL; the ``Table`` definitions below must match it column-for-column.
If the two drift, migrations win — update this file to match.

Cross-table invariant not expressed in the Core ``Table`` definitions
below: every ``user_id`` column is FK-enforced against
``users(user_id)`` with ``ON DELETE RESTRICT``, and a
``BEFORE INSERT OR UPDATE OF user_id`` trigger on each child table
auto-creates the ``users`` row if it does not yet exist. See migration
``0015_user_id_fk``. The FKs are intentionally omitted from the Core
declarations to keep this file readable; the DB is the authority.
"""

from sqlalchemy import (
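Migration `0015_user_id_fk` is not included in this changeset. As a hedged sketch of the kind of DDL the docstring above describes for one child table (names and exact form are assumptions, not the actual migration):

```python
# Sketch only: FK to users(user_id) plus a BEFORE INSERT/UPDATE trigger that
# auto-creates the users row, illustrated for the conversations table.
from alembic import op

op.execute("""
    CREATE OR REPLACE FUNCTION ensure_user_exists() RETURNS trigger AS $$
    BEGIN
        INSERT INTO users (user_id) VALUES (NEW.user_id)
        ON CONFLICT (user_id) DO NOTHING;
        RETURN NEW;
    END;
    $$ LANGUAGE plpgsql;
""")
op.execute("""
    ALTER TABLE conversations
        ADD CONSTRAINT conversations_user_id_fkey
        FOREIGN KEY (user_id) REFERENCES users (user_id) ON DELETE RESTRICT;
""")
op.execute("""
    CREATE TRIGGER conversations_ensure_user
        BEFORE INSERT OR UPDATE OF user_id ON conversations
        FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();
""")
```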
@@ -13,6 +21,7 @@ from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
ForeignKeyConstraint,
|
||||
Integer,
|
||||
MetaData,
|
||||
UniqueConstraint,
|
||||
@@ -20,7 +29,7 @@ from sqlalchemy import (
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||
|
||||
metadata = MetaData()
|
||||
|
||||
@@ -51,6 +60,7 @@ prompts_table = Table(
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
user_tools_table = Table(
|
||||
@@ -88,17 +98,6 @@ user_logs_table = Table(
|
||||
Column("data", JSONB),
|
||||
)
|
||||
|
||||
feedback_table = Table(
|
||||
"feedback",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("conversation_id", UUID(as_uuid=True), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("question_index", Integer, nullable=False),
|
||||
Column("feedback_text", Text),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
stack_logs_table = Table(
|
||||
"stack_logs",
|
||||
metadata,
|
||||
@@ -131,7 +130,7 @@ sources_table = Table(
|
||||
"sources",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("type", Text),
|
||||
Column("metadata", JSONB, nullable=False, server_default="{}"),
|
||||
@@ -148,7 +147,7 @@ agents_table = Table(
|
||||
Column("description", Text),
|
||||
Column("agent_type", Text),
|
||||
Column("status", Text, nullable=False),
|
||||
Column("key", Text, unique=True),
|
||||
Column("key", CITEXT, unique=True),
|
||||
Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
|
||||
Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||
Column("chunks", Integer),
|
||||
@@ -164,10 +163,11 @@ agents_table = Table(
|
||||
Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
|
||||
Column("request_limit", Integer),
|
||||
Column("shared", Boolean, nullable=False, server_default="false"),
|
||||
Column("incoming_webhook_token", Text, unique=True),
|
||||
Column("incoming_webhook_token", CITEXT, unique=True),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("last_used_at", DateTime(timezone=True)),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
attachments_table = Table(
|
||||
@@ -180,6 +180,7 @@ attachments_table = Table(
|
||||
Column("mime_type", Text),
|
||||
Column("size", BigInteger),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
memories_table = Table(
|
||||
@@ -230,3 +231,166 @@ connector_sessions_table = Table(
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||
|
||||
conversations_table = Table(
|
||||
"conversations",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("agent_id", UUID(as_uuid=True), ForeignKey("agents.id", ondelete="SET NULL")),
|
||||
Column("name", Text),
|
||||
Column("api_key", Text),
|
||||
Column("is_shared_usage", Boolean, nullable=False, server_default="false"),
|
||||
Column("shared_token", Text),
|
||||
Column("shared_with", ARRAY(Text), nullable=False, server_default="{}"),
|
||||
Column("compression_metadata", JSONB),
|
||||
Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
conversation_messages_table = Table(
|
||||
"conversation_messages",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||
# Denormalised from conversations.user_id. Auto-filled on insert by a
|
||||
# BEFORE INSERT trigger when the caller omits it. See migration 0020.
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("position", Integer, nullable=False),
|
||||
Column("prompt", Text),
|
||||
Column("response", Text),
|
||||
Column("thought", Text),
|
||||
Column("sources", JSONB, nullable=False, server_default="[]"),
|
||||
Column("tool_calls", JSONB, nullable=False, server_default="[]"),
|
||||
# Postgres cannot FK-enforce array elements, so the referential
|
||||
# invariant is kept by an AFTER DELETE trigger on ``attachments``
|
||||
# that array_removes the id from every row that references it.
|
||||
# See migration 0017_cleanup_dangling_refs.
|
||||
Column("attachments", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||
Column("model_id", Text),
|
||||
# Renamed from ``metadata`` in migration 0016 to avoid SQLAlchemy's
|
||||
# reserved attribute collision on declarative models. The repository
|
||||
# translates this ↔ API dict key ``metadata`` so external callers
|
||||
# still see ``metadata``.
|
||||
Column("message_metadata", JSONB, nullable=False, server_default="{}"),
|
||||
Column("feedback", JSONB),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||
)
|
||||
|
||||
shared_conversations_table = Table(
|
||||
"shared_conversations",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("uuid", UUID(as_uuid=True), nullable=False, unique=True),
|
||||
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
|
||||
Column("chunks", Integer),
|
||||
Column("is_promptable", Boolean, nullable=False, server_default="false"),
|
||||
Column("first_n_queries", Integer, nullable=False, server_default="0"),
|
||||
Column("api_key", Text),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
pending_tool_state_table = Table(
|
||||
"pending_tool_state",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("messages", JSONB, nullable=False),
|
||||
Column("pending_tool_calls", JSONB, nullable=False),
|
||||
Column("tools_dict", JSONB, nullable=False),
|
||||
Column("tool_schemas", JSONB, nullable=False),
|
||||
Column("agent_config", JSONB, nullable=False),
|
||||
Column("client_tools", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("expires_at", DateTime(timezone=True), nullable=False),
|
||||
    UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
)

workflows_table = Table(
    "workflows",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("description", Text),
    Column("current_graph_version", Integer, nullable=False, server_default="1"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

workflow_nodes_table = Table(
    "workflow_nodes",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
    Column("graph_version", Integer, nullable=False),
    Column("node_id", Text, nullable=False),
    Column("node_type", Text, nullable=False),
    Column("title", Text),
    Column("description", Text),
    Column("position", JSONB, nullable=False, server_default='{"x": 0, "y": 0}'),
    Column("config", JSONB, nullable=False, server_default="{}"),
    Column("legacy_mongo_id", Text),
    # Composite UNIQUE so workflow_edges can use a composite FK that
    # enforces endpoint nodes belong to the same (workflow, version) as
    # the edge itself. See migration 0008.
    UniqueConstraint(
        "id", "workflow_id", "graph_version",
        name="workflow_nodes_id_wf_ver_key",
    ),
)

workflow_edges_table = Table(
    "workflow_edges",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
    Column("graph_version", Integer, nullable=False),
    Column("edge_id", Text, nullable=False),
    Column("from_node_id", UUID(as_uuid=True), nullable=False),
    Column("to_node_id", UUID(as_uuid=True), nullable=False),
    Column("source_handle", Text),
    Column("target_handle", Text),
    Column("config", JSONB, nullable=False, server_default="{}"),
    # Composite FKs: endpoints must belong to the same (workflow, version)
    # as the edge. Prevents cross-workflow / cross-version edges that the
    # single-column FKs couldn't catch. See migration 0008.
    ForeignKeyConstraint(
        ["from_node_id", "workflow_id", "graph_version"],
        ["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
        ondelete="CASCADE",
        name="workflow_edges_from_node_fk",
    ),
    ForeignKeyConstraint(
        ["to_node_id", "workflow_id", "graph_version"],
        ["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
        ondelete="CASCADE",
        name="workflow_edges_to_node_fk",
    ),
)

workflow_runs_table = Table(
    "workflow_runs",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
    Column("user_id", Text, nullable=False),
    Column("status", Text, nullable=False),
    Column("inputs", JSONB),
    Column("result", JSONB),
    Column("steps", JSONB, nullable=False, server_default="[]"),
    Column("started_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("ended_at", DateTime(timezone=True)),
    Column("legacy_mongo_id", Text),
)

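# Illustrative sketch (not part of the diff): the composite FKs above mean an
# edge can only point at nodes that belong to the same workflow and graph
# version. The UUID placeholders below are assumptions for the example.
#
#     INSERT INTO workflow_edges (workflow_id, graph_version, edge_id,
#                                 from_node_id, to_node_id)
#     VALUES ('<workflow A>', 2, 'e1',
#             '<node PK from workflow A, version 1>',
#             '<node PK from workflow B, version 2>');
#     -- rejected: violates "workflow_edges_from_node_fk" / "workflow_edges_to_node_fk"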
@@ -12,7 +12,6 @@ the legacy Mongo code performs on ``agents_collection``:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, func, text
|
||||
@@ -26,6 +25,13 @@ class AgentsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
@staticmethod
|
||||
def _normalize_unique_text(col: str, val):
|
||||
"""Coerce blank strings for nullable unique text columns to NULL."""
|
||||
if col == "key" and val == "":
|
||||
return None
|
||||
return val
|
||||
|
||||
def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
|
||||
values: dict = {"user_id": user_id, "name": name, "status": status}
|
||||
|
||||
@@ -35,14 +41,18 @@ class AgentsRepository:
|
||||
"source_id", "prompt_id", "folder_id",
|
||||
"chunks", "token_limit", "request_limit",
|
||||
"limited_token_mode", "limited_request_mode", "shared",
|
||||
"tools", "json_schema", "models",
|
||||
"tools", "json_schema", "models", "legacy_mongo_id",
|
||||
}
|
||||
|
||||
for col, val in kwargs.items():
|
||||
if col not in _ALLOWED or val is None:
|
||||
continue
|
||||
if col in ("tools", "json_schema", "models"):
|
||||
values[col] = json.dumps(val)
|
||||
# JSONB columns: pass the Python object directly. SQLAlchemy
|
||||
# Core's JSONB type processor json.dumps it once during
|
||||
# bind; pre-serialising would double-encode and the value
|
||||
# would round-trip as a JSON string instead of the dict.
|
||||
values[col] = val
|
||||
elif col in ("chunks", "token_limit", "request_limit"):
|
||||
values[col] = int(val)
|
||||
elif col in ("limited_token_mode", "limited_request_mode", "shared"):
|
||||
@@ -50,7 +60,7 @@ class AgentsRepository:
|
||||
elif col in ("source_id", "prompt_id", "folder_id"):
|
||||
values[col] = str(val)
|
||||
else:
|
||||
values[col] = val
|
||||
values[col] = self._normalize_unique_text(col, val)
|
||||
|
||||
stmt = pg_insert(agents_table).values(**values).returning(agents_table)
|
||||
result = self._conn.execute(stmt)
|
||||
@@ -64,6 +74,17 @@ class AgentsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||
"""Fetch an agent by the original Mongo ObjectId string."""
|
||||
sql = "SELECT * FROM agents WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def find_by_key(self, key: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM agents WHERE key = :key"),
|
||||
@@ -108,11 +129,13 @@ class AgentsRepository:
|
||||
values: dict = {}
|
||||
for col, val in filtered.items():
|
||||
if col in ("tools", "json_schema", "models"):
|
||||
values[col] = json.dumps(val) if not isinstance(val, str) else val
|
||||
# See note in create(): JSONB columns receive Python
|
||||
# objects, the type processor handles serialisation.
|
||||
values[col] = val
|
||||
elif col in ("source_id", "prompt_id", "folder_id"):
|
||||
values[col] = str(val) if val else None
|
||||
else:
|
||||
values[col] = val
|
||||
values[col] = self._normalize_unique_text(col, val)
|
||||
values["updated_at"] = func.now()
|
||||
|
||||
t = agents_table
|
||||
@@ -125,6 +148,13 @@ class AgentsRepository:
|
||||
result = self._conn.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
|
||||
def update_by_legacy_id(self, legacy_mongo_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Update an agent addressed by the Mongo ObjectId string."""
|
||||
agent = self.get_by_legacy_id(legacy_mongo_id, user_id)
|
||||
if agent is None:
|
||||
return False
|
||||
return self.update(agent["id"], user_id, fields)
|
||||
|
||||
def delete(self, agent_id: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
@@ -132,6 +162,17 @@ class AgentsRepository:
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
|
||||
"""Delete an agent addressed by the Mongo ObjectId string."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM agents "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": legacy_mongo_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
|
||||
self._conn.execute(
|
||||
text(
|
||||
|
||||
@@ -14,12 +14,15 @@ class AttachmentsRepository:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, filename: str, upload_path: str, *,
|
||||
mime_type: Optional[str] = None, size: Optional[int] = None) -> dict:
|
||||
mime_type: Optional[str] = None, size: Optional[int] = None,
|
||||
legacy_mongo_id: Optional[str] = None) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
|
||||
VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
|
||||
INSERT INTO attachments
|
||||
(user_id, filename, upload_path, mime_type, size, legacy_mongo_id)
|
||||
VALUES
|
||||
(:user_id, :filename, :upload_path, :mime_type, :size, :legacy_mongo_id)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
@@ -29,6 +32,7 @@ class AttachmentsRepository:
|
||||
"upload_path": upload_path,
|
||||
"mime_type": mime_type,
|
||||
"size": size,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
@@ -43,6 +47,17 @@ class AttachmentsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||
"""Fetch an attachment by the original Mongo ObjectId string."""
|
||||
sql = "SELECT * FROM attachments WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
|
||||
|
||||
476
application/storage/db/repositories/conversations.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""Repository for the ``conversations`` and ``conversation_messages`` tables.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``conversations_collection``:
|
||||
|
||||
- create / get / list / delete conversations
|
||||
- append message (transactional position allocation)
|
||||
- update message at index (overwrite + optional truncation)
|
||||
- set / unset feedback on a message
|
||||
- rename conversation
|
||||
- update compression metadata
|
||||
- shared_with access checks
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import conversations_table, conversation_messages_table
|
||||
|
||||
|
||||
def _message_row_to_dict(row) -> dict:
|
||||
"""Like ``row_to_dict`` but renames the DB column ``message_metadata``
|
||||
back to the public API key ``metadata`` so callers keep the Mongo-era
|
||||
shape. See migration 0016 for the column rename rationale."""
|
||||
out = row_to_dict(row)
|
||||
if "message_metadata" in out:
|
||||
out["metadata"] = out.pop("message_metadata")
|
||||
return out
|
||||
|
||||
|
||||
class ConversationsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conversation CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str | None = None,
|
||||
*,
|
||||
agent_id: str | None = None,
|
||||
api_key: str | None = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: str | None = None,
|
||||
legacy_mongo_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a new conversation.
|
||||
|
||||
``legacy_mongo_id`` is used by the dual-write shim so that a
|
||||
Postgres row inserted *after* a successful Mongo insert carries
|
||||
the Mongo ``_id`` as a lookup key. Subsequent appends/updates
|
||||
can then resolve the PG row by that id via
|
||||
:meth:`get_by_legacy_id`.
|
||||
"""
|
||||
values: dict = {
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
}
|
||||
if agent_id:
|
||||
values["agent_id"] = agent_id
|
||||
if api_key:
|
||||
values["api_key"] = api_key
|
||||
if is_shared_usage:
|
||||
values["is_shared_usage"] = True
|
||||
if shared_token:
|
||||
values["shared_token"] = shared_token
|
||||
if legacy_mongo_id:
|
||||
values["legacy_mongo_id"] = legacy_mongo_id
|
||||
|
||||
stmt = pg_insert(conversations_table).values(**values).returning(conversations_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return row_to_dict(result.fetchone())
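    # Illustrative sketch, not code from this diff: how the dual-write shim the
    # docstring mentions is expected to use ``legacy_mongo_id``. ``engine`` and
    # the Mongo-side insert are assumptions for the example.
    #
    #     with engine.begin() as conn:
    #         repo = ConversationsRepository(conn)
    #         row = repo.create("user-1", "My chat",
    #                           legacy_mongo_id=str(mongo_inserted_id))
    #         # later writes resolve the Postgres row from the Mongo _id
    #         same = repo.get_by_legacy_id(str(mongo_inserted_id), "user-1")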
|
||||
|
||||
def get_by_legacy_id(
|
||||
self, legacy_mongo_id: str, user_id: str | None = None,
|
||||
) -> Optional[dict]:
|
||||
"""Look up a conversation by the original Mongo ObjectId string.
|
||||
|
||||
Used by the dual-write helpers to translate a Mongo ``_id`` into
|
||||
the Postgres UUID for follow-up writes. When ``user_id`` is
|
||||
provided, the lookup is scoped to rows owned by that user so
|
||||
callers can't accidentally resolve another user's conversation.
|
||||
"""
|
||||
if user_id is not None:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversations "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": legacy_mongo_id, "user_id": user_id},
|
||||
)
|
||||
else:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversations WHERE legacy_mongo_id = :legacy_id"
|
||||
),
|
||||
{"legacy_id": legacy_mongo_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Fetch a conversation the user owns or has shared access to."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversations "
|
||||
"WHERE id = CAST(:id AS uuid) "
|
||||
"AND (user_id = :user_id OR :user_id = ANY(shared_with))"
|
||||
),
|
||||
{"id": conversation_id, "user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_owned(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Fetch a conversation owned by the user (no shared access)."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversations "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": conversation_id, "user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str, limit: int = 30) -> list[dict]:
|
||||
"""List conversations for a user, most recent first.
|
||||
|
||||
Mirrors the Mongo query: keep rows with no api_key, plus rows where agent_id is set.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversations "
|
||||
"WHERE user_id = :user_id "
|
||||
"AND (api_key IS NULL OR agent_id IS NOT NULL) "
|
||||
"ORDER BY date DESC LIMIT :limit"
|
||||
),
|
||||
{"user_id": user_id, "limit": limit},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations SET name = :name, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": conversation_id, "user_id": user_id, "name": name},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def set_shared_token(self, conversation_id: str, user_id: str, token: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations SET shared_token = :token, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": conversation_id, "user_id": user_id, "token": token},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, user_id: str, metadata: dict,
|
||||
) -> bool:
|
||||
"""Replace the entire ``compression_metadata`` JSONB blob.
|
||||
|
||||
Prefer :meth:`append_compression_point` + :meth:`set_compression_flags`
|
||||
to match the Mongo service semantics exactly (those two mirror
|
||||
``$set`` + ``$push $slice``). This method is retained for callers
|
||||
that already compute the full merged blob client-side.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations "
|
||||
"SET compression_metadata = CAST(:meta AS jsonb), updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": conversation_id, "user_id": user_id, "meta": json.dumps(metadata)},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def set_compression_flags(
|
||||
self,
|
||||
conversation_id: str,
|
||||
*,
|
||||
is_compressed: bool,
|
||||
last_compression_at,
|
||||
) -> bool:
|
||||
"""Update ``compression_metadata.is_compressed`` and
|
||||
``compression_metadata.last_compression_at`` without touching
|
||||
``compression_points``.
|
||||
|
||||
Mirrors the Mongo ``$set`` on those two subfields in
|
||||
``ConversationService.update_compression_metadata``. Initialises
|
||||
the surrounding object when the row has no ``compression_metadata``
|
||||
yet.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE conversations SET
|
||||
compression_metadata = jsonb_set(
|
||||
jsonb_set(
|
||||
COALESCE(compression_metadata, '{}'::jsonb),
|
||||
'{is_compressed}',
|
||||
to_jsonb(CAST(:is_compressed AS boolean)), true
|
||||
),
|
||||
'{last_compression_at}',
|
||||
to_jsonb(CAST(:last_compression_at AS text)), true
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": conversation_id,
|
||||
"is_compressed": bool(is_compressed),
|
||||
"last_compression_at": (
|
||||
str(last_compression_at) if last_compression_at is not None else None
|
||||
),
|
||||
},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def append_compression_point(
|
||||
self,
|
||||
conversation_id: str,
|
||||
point: dict,
|
||||
*,
|
||||
max_points: int,
|
||||
) -> bool:
|
||||
"""Append one compression point, keeping at most ``max_points``.
|
||||
|
||||
Mirrors Mongo's ``$push {"$each": [point], "$slice": -max_points}``
|
||||
on ``compression_metadata.compression_points``. Preserves the
|
||||
other top-level keys in ``compression_metadata``.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE conversations SET
|
||||
compression_metadata = jsonb_set(
|
||||
COALESCE(compression_metadata, '{}'::jsonb),
|
||||
'{compression_points}',
|
||||
COALESCE(
|
||||
(
|
||||
SELECT jsonb_agg(elem ORDER BY rn)
|
||||
FROM (
|
||||
SELECT
|
||||
elem,
|
||||
row_number() OVER () AS rn,
|
||||
count(*) OVER () AS cnt
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(
|
||||
compression_metadata -> 'compression_points',
|
||||
'[]'::jsonb
|
||||
) || jsonb_build_array(CAST(:point AS jsonb))
|
||||
) AS elem
|
||||
) ranked
|
||||
WHERE rn > cnt - :max_points
|
||||
),
|
||||
'[]'::jsonb
|
||||
),
|
||||
true
|
||||
),
|
||||
updated_at = now()
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": conversation_id,
|
||||
"point": json.dumps(point, default=str),
|
||||
"max_points": int(max_points),
|
||||
},
|
||||
)
|
||||
return result.rowcount > 0
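    # Illustrative sketch (assumed calling pattern, not present in this diff):
    # the two helpers above are meant to be combined, inside one transaction,
    # to mirror the Mongo ``$set`` + ``$push``/``$slice`` compression update.
    # The point's keys are placeholders.
    #
    #     repo.append_compression_point(conv_id,
    #                                   {"summary": "...", "compressed_up_to": 40},
    #                                   max_points=5)
    #     repo.set_compression_flags(conv_id, is_compressed=True,
    #                                last_compression_at=datetime.now(timezone.utc))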
|
||||
|
||||
def delete(self, conversation_id: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM conversations "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": conversation_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete_all_for_user(self, user_id: str) -> int:
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM conversations WHERE user_id = :user_id"),
|
||||
{"user_id": user_id},
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Messages
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_messages(self, conversation_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"ORDER BY position ASC"
|
||||
),
|
||||
{"conv_id": conversation_id},
|
||||
)
|
||||
return [_message_row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def get_message_at(self, conversation_id: str, position: int) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"AND position = :pos"
|
||||
),
|
||||
{"conv_id": conversation_id, "pos": position},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return _message_row_to_dict(row) if row is not None else None
|
||||
|
||||
def append_message(self, conversation_id: str, message: dict) -> dict:
|
||||
"""Append a message to a conversation.
|
||||
|
||||
Uses ``SELECT ... FOR UPDATE`` to allocate the next position
|
||||
atomically. The caller must be inside a transaction.
|
||||
|
||||
Mirrors Mongo's ``$push`` on the ``queries`` array.
|
||||
"""
|
||||
# Lock the parent conversation row to serialize concurrent appends.
|
||||
self._conn.execute(
|
||||
text(
|
||||
"SELECT id FROM conversations "
|
||||
"WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
|
||||
),
|
||||
{"conv_id": conversation_id},
|
||||
)
|
||||
next_pos_result = self._conn.execute(
|
||||
text(
|
||||
"SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
|
||||
"FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid)"
|
||||
),
|
||||
{"conv_id": conversation_id},
|
||||
)
|
||||
next_pos = next_pos_result.scalar()
|
||||
|
||||
values = {
|
||||
"conversation_id": conversation_id,
|
||||
"position": next_pos,
|
||||
"prompt": message.get("prompt"),
|
||||
"response": message.get("response"),
|
||||
"thought": message.get("thought"),
|
||||
"sources": message.get("sources") or [],
|
||||
"tool_calls": message.get("tool_calls") or [],
|
||||
"model_id": message.get("model_id"),
|
||||
"message_metadata": message.get("metadata") or {},
|
||||
}
|
||||
if message.get("timestamp") is not None:
|
||||
values["timestamp"] = message["timestamp"]
|
||||
|
||||
attachments = message.get("attachments")
|
||||
if attachments:
|
||||
values["attachments"] = [str(a) for a in attachments]
|
||||
|
||||
stmt = (
|
||||
pg_insert(conversation_messages_table)
|
||||
.values(**values)
|
||||
.returning(conversation_messages_table)
|
||||
)
|
||||
result = self._conn.execute(stmt)
|
||||
# Touch the parent conversation's updated_at.
|
||||
self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversations SET updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conversation_id},
|
||||
)
|
||||
return _message_row_to_dict(result.fetchone())
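    # Illustrative sketch: the row lock above only serialises concurrent
    # appends if the caller runs inside a transaction, e.g. (``engine`` is an
    # assumed name):
    #
    #     with engine.begin() as conn:
    #         ConversationsRepository(conn).append_message(
    #             conv_id, {"prompt": "Hi", "response": "Hello!"})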
|
||||
|
||||
def update_message_at(
|
||||
self, conversation_id: str, position: int, fields: dict,
|
||||
) -> bool:
|
||||
"""Update specific fields on a message at a given position.
|
||||
|
||||
Mirrors Mongo's ``$set`` on ``queries.{index}.*``.
|
||||
"""
|
||||
allowed = {
|
||||
"prompt", "response", "thought", "sources", "tool_calls",
|
||||
"attachments", "model_id", "metadata", "timestamp",
|
||||
}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not filtered:
|
||||
return False
|
||||
|
||||
# Map public API key ``metadata`` → DB column ``message_metadata``.
|
||||
api_to_col = {"metadata": "message_metadata"}
|
||||
|
||||
set_parts = []
|
||||
params: dict = {"conv_id": conversation_id, "pos": position}
|
||||
for key, val in filtered.items():
|
||||
col = api_to_col.get(key, key)
|
||||
if key in ("sources", "tool_calls", "metadata"):
|
||||
set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||
params[col] = json.dumps(val) if not isinstance(val, str) else val
|
||||
elif key == "attachments":
|
||||
set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
|
||||
params[col] = [str(a) for a in val] if val else []
|
||||
else:
|
||||
set_parts.append(f"{col} = :{col}")
|
||||
params[col] = val
|
||||
|
||||
if "timestamp" not in filtered:
|
||||
set_parts.append("timestamp = now()")
|
||||
sql = (
|
||||
f"UPDATE conversation_messages SET {', '.join(set_parts)} "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
|
||||
)
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
|
||||
def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
|
||||
"""Delete messages with position > keep_up_to.
|
||||
|
||||
Mirrors Mongo's ``$push`` + ``$slice`` that trims queries after an
|
||||
index-based update.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"AND position > :pos"
|
||||
),
|
||||
{"conv_id": conversation_id, "pos": keep_up_to},
|
||||
)
|
||||
return result.rowcount
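    # Illustrative sketch (assumed caller, mirroring the Mongo index-based
    # edit): overwrite one message, then drop everything after it so the
    # conversation continues from that point.
    #
    #     repo.update_message_at(conv_id, 3, {"response": "edited answer"})
    #     repo.truncate_after(conv_id, keep_up_to=3)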
|
||||
|
||||
def set_feedback(
|
||||
self, conversation_id: str, position: int, feedback: dict | None,
|
||||
) -> bool:
|
||||
"""Set or unset feedback on a message.
|
||||
|
||||
``feedback`` is a JSONB value, e.g. ``{"text": "thumbs_up",
|
||||
"timestamp": "..."}`` or ``None`` to unset.
|
||||
"""
|
||||
fb_json = json.dumps(feedback) if feedback is not None else None
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE conversation_messages "
|
||||
"SET feedback = CAST(:fb AS jsonb) "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
|
||||
),
|
||||
{"conv_id": conversation_id, "pos": position, "fb": fb_json},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def message_count(self, conversation_id: str) -> int:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT COUNT(*) FROM conversation_messages "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid)"
|
||||
),
|
||||
{"conv_id": conversation_id},
|
||||
)
|
||||
return result.scalar() or 0
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Repository for the ``feedback`` table.
|
||||
|
||||
The ``feedback_collection`` global is declared in ``base.py`` but currently
|
||||
has zero direct call sites in the application code (all feedback writes go
|
||||
through ``conversation_messages.feedback`` JSONB field on the conversations
|
||||
collection). The table exists for when feedback is denormalized into its own
|
||||
rows. This repository provides the append-only insert and basic reads
|
||||
needed for that future.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
"""Postgres-backed replacement for Mongo ``feedback_collection``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
question_index: int,
|
||||
feedback_text: Optional[str] = None,
|
||||
) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text)
|
||||
VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"question_index": question_index,
|
||||
"feedback_text": feedback_text,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def list_for_conversation(self, conversation_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM feedback WHERE conversation_id = CAST(:cid AS uuid) ORDER BY question_index"
|
||||
),
|
||||
{"cid": conversation_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
128
application/storage/db/repositories/pending_tool_state.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Repository for the ``pending_tool_state`` table.
|
||||
|
||||
Mirrors the continuation service's three operations on
|
||||
``pending_tool_state`` in Mongo:
|
||||
|
||||
- save_state → upsert (INSERT ... ON CONFLICT DO UPDATE)
|
||||
- load_state → find_one by (conversation_id, user_id)
|
||||
- delete_state → delete_one by (conversation_id, user_id)
|
||||
|
||||
Plus a cleanup method for the Celery beat task that replaces Mongo's
|
||||
TTL index.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60  # 30 minutes
|
||||
|
||||
|
||||
class PendingToolStateRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
messages: list,
|
||||
pending_tool_calls: list,
|
||||
tools_dict: dict,
|
||||
tool_schemas: list,
|
||||
agent_config: dict,
|
||||
client_tools: list | None = None,
|
||||
ttl_seconds: int = PENDING_STATE_TTL_SECONDS,
|
||||
) -> dict:
|
||||
"""Upsert pending tool state.
|
||||
|
||||
Mirrors Mongo's ``replace_one(..., upsert=True)``.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = datetime.fromtimestamp(
|
||||
now.timestamp() + ttl_seconds, tz=timezone.utc,
|
||||
)
|
||||
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO pending_tool_state
|
||||
(conversation_id, user_id, messages, pending_tool_calls,
|
||||
tools_dict, tool_schemas, agent_config, client_tools,
|
||||
created_at, expires_at)
|
||||
VALUES
|
||||
(CAST(:conv_id AS uuid), :user_id,
|
||||
CAST(:messages AS jsonb), CAST(:pending AS jsonb),
|
||||
CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
|
||||
CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
|
||||
:created_at, :expires_at)
|
||||
ON CONFLICT (conversation_id, user_id) DO UPDATE SET
|
||||
messages = EXCLUDED.messages,
|
||||
pending_tool_calls = EXCLUDED.pending_tool_calls,
|
||||
tools_dict = EXCLUDED.tools_dict,
|
||||
tool_schemas = EXCLUDED.tool_schemas,
|
||||
agent_config = EXCLUDED.agent_config,
|
||||
client_tools = EXCLUDED.client_tools,
|
||||
created_at = EXCLUDED.created_at,
|
||||
expires_at = EXCLUDED.expires_at
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"conv_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"messages": json.dumps(messages),
|
||||
"pending": json.dumps(pending_tool_calls),
|
||||
"tools_dict": json.dumps(tools_dict),
|
||||
"schemas": json.dumps(tool_schemas),
|
||||
"agent_config": json.dumps(agent_config),
|
||||
"client_tools": json.dumps(client_tools) if client_tools is not None else None,
|
||||
"created_at": now,
|
||||
"expires_at": expires,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def load_state(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM pending_tool_state "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"AND user_id = :user_id"
|
||||
),
|
||||
{"conv_id": conversation_id, "user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def delete_state(self, conversation_id: str, user_id: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM pending_tool_state "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"AND user_id = :user_id"
|
||||
),
|
||||
{"conv_id": conversation_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Delete rows where ``expires_at < now()``.
|
||||
|
||||
Replaces Mongo's ``expireAfterSeconds=0`` TTL index. Intended to
|
||||
be called from a Celery beat task every 60 seconds.
|
||||
"""
|
||||
# clock_timestamp() — not now() — since the latter is frozen to the
|
||||
# start of the transaction, which would let state that has just
|
||||
# expired survive one more cleanup tick.
|
||||
result = self._conn.execute(
|
||||
text("DELETE FROM pending_tool_state WHERE expires_at < clock_timestamp()")
|
||||
)
|
||||
return result.rowcount
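    # Illustrative sketch of the Celery beat wiring the docstring refers to.
    # The task name, schedule key, and ``db_engine`` are assumptions for the
    # example, not definitions from this diff.
    #
    #     @celery.task
    #     def cleanup_pending_tool_state():
    #         with db_engine.begin() as conn:
    #             return PendingToolStateRepository(conn).cleanup_expired()
    #
    #     celery.conf.beat_schedule["cleanup-pending-tool-state"] = {
    #         "task": "cleanup_pending_tool_state",
    #         "schedule": 60.0,
    #     }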
|
||||
@@ -27,16 +27,27 @@ class PromptsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, user_id: str, name: str, content: str) -> dict:
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
content: str,
|
||||
*,
|
||||
legacy_mongo_id: str | None = None,
|
||||
) -> dict:
|
||||
sql = """
|
||||
INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
|
||||
VALUES (:user_id, :name, :content, :legacy_mongo_id)
|
||||
RETURNING *
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO prompts (user_id, name, content)
|
||||
VALUES (:user_id, :name, :content)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{"user_id": user_id, "name": name, "content": content},
|
||||
text(sql),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"content": content,
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
@@ -48,6 +59,17 @@ class PromptsRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||
"""Fetch a prompt by the original Mongo ObjectId string."""
|
||||
sql = "SELECT * FROM prompts WHERE legacy_mongo_id = :legacy_id"
|
||||
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||
if user_id is not None:
|
||||
sql += " AND user_id = :user_id"
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
|
||||
"""Fetch prompt content by ID without user scoping.
|
||||
|
||||
@@ -80,12 +102,48 @@ class PromptsRepository:
|
||||
{"id": prompt_id, "user_id": user_id, "name": name, "content": content},
|
||||
)
|
||||
|
||||
def update_by_legacy_id(
|
||||
self,
|
||||
legacy_mongo_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Update a prompt addressed by the Mongo ObjectId string."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE prompts
|
||||
SET name = :name, content = :content, updated_at = now()
|
||||
WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"legacy_id": legacy_mongo_id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def delete(self, prompt_id: str, user_id: str) -> None:
|
||||
self._conn.execute(
|
||||
text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
|
||||
{"id": prompt_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
|
||||
"""Delete a prompt addressed by the Mongo ObjectId string."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM prompts "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": legacy_mongo_id, "user_id": user_id},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def find_or_create(self, user_id: str, name: str, content: str) -> dict:
|
||||
"""Return existing prompt matching (user, name, content), or create one.
|
||||
|
||||
|
||||
205
application/storage/db/repositories/shared_conversations.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Repository for the ``shared_conversations`` table.
|
||||
|
||||
Covers the sharing operations from ``shared_conversations_collections``
|
||||
in Mongo:
|
||||
|
||||
- create a share record (with UUID, conversation_id, user, visibility flags)
|
||||
- look up by uuid (public access)
|
||||
- look up by conversation_id + user + flags (dedup check)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as uuid_mod
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import shared_conversations_table
|
||||
|
||||
|
||||
class SharedConversationsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
is_promptable: bool = False,
|
||||
first_n_queries: int = 0,
|
||||
api_key: str | None = None,
|
||||
prompt_id: str | None = None,
|
||||
chunks: int | None = None,
|
||||
share_uuid: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a share record.
|
||||
|
||||
``share_uuid`` allows the dual-write caller to supply the same
|
||||
UUID that Mongo received, so public ``/shared/{uuid}`` links
|
||||
keep resolving from both stores during the dual-write window.
|
||||
|
||||
Callers that need race-free dedup on the logical share key
|
||||
should use :meth:`get_or_create` instead — it relies on the
|
||||
composite partial unique index added in migration 0008 to
|
||||
collapse concurrent requests to a single row.
|
||||
"""
|
||||
final_uuid = share_uuid or str(uuid_mod.uuid4())
|
||||
values: dict = {
|
||||
"uuid": final_uuid,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"is_promptable": is_promptable,
|
||||
"first_n_queries": first_n_queries,
|
||||
}
|
||||
if api_key:
|
||||
values["api_key"] = api_key
|
||||
if prompt_id:
|
||||
values["prompt_id"] = prompt_id
|
||||
if chunks is not None:
|
||||
values["chunks"] = chunks
|
||||
|
||||
stmt = (
|
||||
pg_insert(shared_conversations_table)
|
||||
.values(**values)
|
||||
.returning(shared_conversations_table)
|
||||
)
|
||||
result = self._conn.execute(stmt)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
is_promptable: bool = False,
|
||||
first_n_queries: int = 0,
|
||||
api_key: str | None = None,
|
||||
prompt_id: str | None = None,
|
||||
chunks: int | None = None,
|
||||
share_uuid: str | None = None,
|
||||
) -> dict:
|
||||
"""Race-free share create/lookup keyed on the logical dedup tuple.
|
||||
|
||||
Leverages the partial unique index on
|
||||
``(conversation_id, user_id, is_promptable, first_n_queries,
|
||||
COALESCE(api_key, ''))`` added in migration 0008. Concurrent
|
||||
requests for the same logical share converge on one row. The
|
||||
returned dict's ``uuid`` is the canonical public identifier.
|
||||
|
||||
Dedup key rationale — ``prompt_id`` and ``chunks`` are
|
||||
deliberately *not* part of the uniqueness key. A share row is
|
||||
identified by "who shared what conversation under which
|
||||
visibility rules"; ``prompt_id`` / ``chunks`` are mutable
|
||||
properties of that share and are last-write-wins on re-share.
|
||||
This preserves existing public ``/shared/{uuid}`` URLs when a
|
||||
user updates the prompt or chunk count, matching the Mongo
|
||||
``find_one`` + ``update`` semantics.
|
||||
"""
|
||||
final_uuid = share_uuid or str(uuid_mod.uuid4())
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO shared_conversations
|
||||
(uuid, conversation_id, user_id, is_promptable,
|
||||
first_n_queries, api_key, prompt_id, chunks)
|
||||
VALUES
|
||||
(CAST(:uuid AS uuid), CAST(:conversation_id AS uuid),
|
||||
:user_id, :is_promptable, :first_n_queries,
|
||||
:api_key, CAST(:prompt_id AS uuid), :chunks)
|
||||
ON CONFLICT (conversation_id, user_id, is_promptable,
|
||||
first_n_queries, COALESCE(api_key, ''))
|
||||
DO UPDATE SET prompt_id = EXCLUDED.prompt_id,
|
||||
chunks = EXCLUDED.chunks
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"uuid": final_uuid,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"is_promptable": is_promptable,
|
||||
"first_n_queries": first_n_queries,
|
||||
"api_key": api_key,
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
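    # Illustrative sketch of the migration-0008 index this upsert relies on.
    # The index name is an assumption; the column/expression list is only
    # implied by the ON CONFLICT target above.
    #
    #     CREATE UNIQUE INDEX shared_conversations_dedup_uidx
    #         ON shared_conversations (conversation_id, user_id, is_promptable,
    #                                  first_n_queries, COALESCE(api_key, ''));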
|
||||
|
||||
def find_by_uuid(self, share_uuid: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM shared_conversations "
|
||||
"WHERE uuid = CAST(:uuid AS uuid)"
|
||||
),
|
||||
{"uuid": share_uuid},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def find_existing(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
is_promptable: bool,
|
||||
first_n_queries: int,
|
||||
api_key: str | None = None,
|
||||
) -> Optional[dict]:
|
||||
"""Check for an existing share with matching parameters.
|
||||
|
||||
Mirrors the Mongo ``find_one`` dedup check before creating a share.
|
||||
"""
|
||||
if api_key:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM shared_conversations "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"AND user_id = :user_id "
|
||||
"AND is_promptable = :is_promptable "
|
||||
"AND first_n_queries = :fnq "
|
||||
"AND api_key = :api_key "
|
||||
"LIMIT 1"
|
||||
),
|
||||
{
|
||||
"conv_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"is_promptable": is_promptable,
|
||||
"fnq": first_n_queries,
|
||||
"api_key": api_key,
|
||||
},
|
||||
)
|
||||
else:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM shared_conversations "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"AND user_id = :user_id "
|
||||
"AND is_promptable = :is_promptable "
|
||||
"AND first_n_queries = :fnq "
|
||||
"AND api_key IS NULL "
|
||||
"LIMIT 1"
|
||||
),
|
||||
{
|
||||
"conv_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"is_promptable": is_promptable,
|
||||
"fnq": first_n_queries,
|
||||
},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_conversation(self, conversation_id: str) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM shared_conversations "
|
||||
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||
"ORDER BY created_at DESC"
|
||||
),
|
||||
{"conv_id": conversation_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
@@ -15,7 +15,7 @@ class SourcesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(self, name: str, *, user_id: Optional[str] = None,
|
||||
def create(self, name: str, *, user_id: str,
|
||||
type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
@@ -55,12 +55,12 @@ class SourcesRepository:
|
||||
if not filtered:
|
||||
return
|
||||
|
||||
values: dict = {}
|
||||
for col, val in filtered.items():
|
||||
if col == "metadata":
|
||||
values[col] = json.dumps(val) if isinstance(val, dict) else val
|
||||
else:
|
||||
values[col] = val
|
||||
# Pass Python objects directly for JSONB columns when using
|
||||
# SQLAlchemy Core .update() — the JSONB type processor json.dumps
|
||||
# them itself; pre-serialising here would double-encode and the
|
||||
# value would round-trip as a JSON string instead of the original
|
||||
# dict.
|
||||
values: dict = dict(filtered)
|
||||
values["updated_at"] = func.now()
|
||||
|
||||
t = sources_table
|
||||
|
||||
@@ -46,7 +46,7 @@ class UserLogsRepository:
|
||||
{
|
||||
"user_id": user_id,
|
||||
"endpoint": endpoint,
|
||||
"data": json.dumps(data) if data is not None else None,
|
||||
"data": json.dumps(data, default=str) if data is not None else None,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
170
application/storage/db/repositories/workflow_edges.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Repository for the ``workflow_edges`` table.
|
||||
|
||||
Covers bulk insert, find by version, and delete operations that the
|
||||
workflow routes perform on ``workflow_edges_collection`` in Mongo.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflow_edges_table
|
||||
|
||||
|
||||
class WorkflowEdgesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
graph_version: int,
|
||||
edge_id: str,
|
||||
from_node_id: str,
|
||||
to_node_id: str,
|
||||
*,
|
||||
source_handle: str | None = None,
|
||||
target_handle: str | None = None,
|
||||
config: dict | None = None,
|
||||
) -> dict:
|
||||
"""Create a single edge.
|
||||
|
||||
``from_node_id`` and ``to_node_id`` are the Postgres **UUID PKs**
|
||||
of the workflow_nodes rows (not user-provided node_id strings).
|
||||
"""
|
||||
values: dict = {
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"edge_id": edge_id,
|
||||
"from_node_id": from_node_id,
|
||||
"to_node_id": to_node_id,
|
||||
}
|
||||
if source_handle is not None:
|
||||
values["source_handle"] = source_handle
|
||||
if target_handle is not None:
|
||||
values["target_handle"] = target_handle
|
||||
if config is not None:
|
||||
values["config"] = config
|
||||
|
||||
stmt = pg_insert(workflow_edges_table).values(**values).returning(workflow_edges_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def bulk_create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
graph_version: int,
|
||||
edges: list[dict],
|
||||
) -> list[dict]:
|
||||
"""Insert multiple edges in one statement.
|
||||
|
||||
Each element must have ``edge_id``, ``from_node_id`` (UUID PK),
|
||||
``to_node_id`` (UUID PK). Optional: ``source_handle``,
|
||||
``target_handle``, ``config``.
|
||||
"""
|
||||
if not edges:
|
||||
return []
|
||||
|
||||
rows = []
|
||||
for e in edges:
|
||||
rows.append({
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"edge_id": e["edge_id"],
|
||||
"from_node_id": e["from_node_id"],
|
||||
"to_node_id": e["to_node_id"],
|
||||
"source_handle": e.get("source_handle"),
|
||||
"target_handle": e.get("target_handle"),
|
||||
"config": e.get("config", {}),
|
||||
})
|
||||
|
||||
stmt = pg_insert(workflow_edges_table).values(rows).returning(workflow_edges_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def find_by_version(
|
||||
self, workflow_id: str, graph_version: int,
|
||||
) -> list[dict]:
|
||||
"""List edges for a workflow/version, shaped to match the live API.
|
||||
|
||||
Joins ``workflow_nodes`` twice so callers receive the user-provided
|
||||
node-id strings (``source_id``/``target_id``) that the Mongo code
|
||||
and the frontend use, not the internal node UUIDs. The raw UUID
|
||||
columns (``from_node_id``/``to_node_id``) are still included in
|
||||
case a caller needs them.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT e.*,
|
||||
fn.node_id AS source_id,
|
||||
tn.node_id AS target_id
|
||||
FROM workflow_edges e
|
||||
JOIN workflow_nodes fn ON fn.id = e.from_node_id
|
||||
JOIN workflow_nodes tn ON tn.id = e.to_node_id
|
||||
WHERE e.workflow_id = CAST(:wf_id AS uuid)
|
||||
AND e.graph_version = :ver
|
||||
ORDER BY e.edge_id
|
||||
"""
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": graph_version},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def resolve_node_id(
|
||||
self, workflow_id: str, graph_version: int, node_id: str,
|
||||
) -> Optional[str]:
|
||||
"""Look up the UUID PK of a node by its user-provided ``node_id``.
|
||||
|
||||
Callers that receive edges in the frontend shape (``source_id`` /
|
||||
``target_id`` are user-provided strings) use this helper to
|
||||
translate to the UUID PK before calling :meth:`create` /
|
||||
:meth:`bulk_create`.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT id FROM workflow_nodes "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version = :ver AND node_id = :node_id"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": graph_version, "node_id": node_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return str(row[0]) if row else None
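    # Illustrative sketch (assumed calling pattern): translate the frontend's
    # node-id strings into UUID PKs before inserting edges for a graph version.
    #
    #     src = repo.resolve_node_id(wf_id, ver, "start")
    #     dst = repo.resolve_node_id(wf_id, ver, "llm-1")
    #     repo.bulk_create(wf_id, ver, [
    #         {"edge_id": "e1", "from_node_id": src, "to_node_id": dst},
    #     ])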
|
||||
|
||||
def delete_by_workflow(self, workflow_id: str) -> int:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflow_edges "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid)"
|
||||
),
|
||||
{"wf_id": workflow_id},
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflow_edges "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version = :ver"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": graph_version},
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
|
||||
"""Delete all edges for a workflow except the specified version."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflow_edges "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version != :ver"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": keep_version},
|
||||
)
|
||||
return result.rowcount
|
||||
158
application/storage/db/repositories/workflow_nodes.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Repository for the ``workflow_nodes`` table.
|
||||
|
||||
Covers bulk insert, find by version, and delete operations that the
|
||||
workflow routes perform on ``workflow_nodes_collection`` in Mongo.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflow_nodes_table
|
||||
|
||||
|
||||
class WorkflowNodesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
graph_version: int,
|
||||
node_id: str,
|
||||
node_type: str,
|
||||
*,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
position: dict | None = None,
|
||||
config: dict | None = None,
|
||||
legacy_mongo_id: str | None = None,
|
||||
) -> dict:
|
||||
values: dict = {
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"node_id": node_id,
|
||||
"node_type": node_type,
|
||||
}
|
||||
if title is not None:
|
||||
values["title"] = title
|
||||
if description is not None:
|
||||
values["description"] = description
|
||||
if position is not None:
|
||||
values["position"] = position
|
||||
if config is not None:
|
||||
values["config"] = config
|
||||
if legacy_mongo_id is not None:
|
||||
values["legacy_mongo_id"] = legacy_mongo_id
|
||||
|
||||
stmt = pg_insert(workflow_nodes_table).values(**values).returning(workflow_nodes_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return row_to_dict(result.fetchone())
|
||||
|
||||
def bulk_create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
graph_version: int,
|
||||
nodes: list[dict],
|
||||
) -> list[dict]:
|
||||
"""Insert multiple nodes in one statement.
|
||||
|
||||
Each element of ``nodes`` should have at least ``node_id`` and
|
||||
``node_type``; optional keys: ``title``, ``description``,
|
||||
``position``, ``config``.
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
rows = []
|
||||
for n in nodes:
|
||||
rows.append({
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"node_id": n["node_id"],
|
||||
"node_type": n["node_type"],
|
||||
"title": n.get("title"),
|
||||
"description": n.get("description"),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("config", {}),
|
||||
"legacy_mongo_id": n.get("legacy_mongo_id"),
|
||||
})
|
||||
|
||||
stmt = pg_insert(workflow_nodes_table).values(rows).returning(workflow_nodes_table)
|
||||
result = self._conn.execute(stmt)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def find_by_version(
|
||||
self, workflow_id: str, graph_version: int,
|
||||
) -> list[dict]:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM workflow_nodes "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version = :ver "
|
||||
"ORDER BY node_id"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": graph_version},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def find_node(
|
||||
self, workflow_id: str, graph_version: int, node_id: str,
|
||||
) -> Optional[dict]:
|
||||
"""Find a single node by its user-provided ``node_id``."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM workflow_nodes "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version = :ver AND node_id = :nid"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": graph_version, "nid": node_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
|
||||
"""Find a node by the original Mongo ObjectId string."""
|
||||
result = self._conn.execute(
|
||||
text("SELECT * FROM workflow_nodes WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": legacy_mongo_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def delete_by_workflow(self, workflow_id: str) -> int:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflow_nodes "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid)"
|
||||
),
|
||||
{"wf_id": workflow_id},
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflow_nodes "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version = :ver"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": graph_version},
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
|
||||
"""Delete all nodes for a workflow except the specified version."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM workflow_nodes "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"AND graph_version != :ver"
|
||||
),
|
||||
{"wf_id": workflow_id, "ver": keep_version},
|
||||
)
|
||||
return result.rowcount
|
||||
83
application/storage/db/repositories/workflow_runs.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Repository for the ``workflow_runs`` table.
|
||||
|
||||
In Mongo, workflow_runs_collection only has ``insert_one`` — runs are
|
||||
written once after workflow execution completes and never updated.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflow_runs_table
|
||||
|
||||
|
||||
class WorkflowRunsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
status: str,
|
||||
*,
|
||||
inputs: dict | None = None,
|
||||
result: dict | None = None,
|
||||
steps: list | None = None,
|
||||
started_at=None,
|
||||
ended_at=None,
|
||||
legacy_mongo_id: str | None = None,
|
||||
) -> dict:
|
||||
values: dict = {
|
||||
"workflow_id": workflow_id,
|
||||
"user_id": user_id,
|
||||
"status": status,
|
||||
}
|
||||
if inputs is not None:
|
||||
values["inputs"] = inputs
|
||||
if result is not None:
|
||||
values["result"] = result
|
||||
if steps is not None:
|
||||
values["steps"] = steps
|
||||
if started_at is not None:
|
||||
values["started_at"] = started_at
|
||||
if ended_at is not None:
|
||||
values["ended_at"] = ended_at
|
||||
if legacy_mongo_id is not None:
|
||||
values["legacy_mongo_id"] = legacy_mongo_id
|
||||
|
||||
stmt = pg_insert(workflow_runs_table).values(**values).returning(workflow_runs_table)
|
||||
res = self._conn.execute(stmt)
|
||||
return row_to_dict(res.fetchone())
|
||||
|
||||
def get(self, run_id: str) -> Optional[dict]:
|
||||
res = self._conn.execute(
|
||||
text("SELECT * FROM workflow_runs WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": run_id},
|
||||
)
|
||||
row = res.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
|
||||
"""Fetch a workflow run by the original Mongo ObjectId string."""
|
||||
res = self._conn.execute(
|
||||
text("SELECT * FROM workflow_runs WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": legacy_mongo_id},
|
||||
)
|
||||
row = res.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_workflow(self, workflow_id: str) -> list[dict]:
|
||||
res = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM workflow_runs "
|
||||
"WHERE workflow_id = CAST(:wf_id AS uuid) "
|
||||
"ORDER BY started_at DESC"
|
||||
),
|
||||
{"wf_id": workflow_id},
|
||||
)
|
||||
return [row_to_dict(r) for r in res.fetchall()]
|
||||
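A minimal usage sketch of the repository above (the engine wiring here is illustrative and not part of this change; any SQLAlchemy Connection bound to the same Postgres schema would do):

    # Sketch only — assumes an existing SQLAlchemy engine for the DocsGPT Postgres database.
    with engine.begin() as conn:
        runs = WorkflowRunsRepository(conn)
        run = runs.create(workflow_id, "local", "completed", steps=[{"node": "start"}])
        assert runs.get(run["id"])["status"] == "completed"
        print(runs.list_for_workflow(workflow_id))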
application/storage/db/repositories/workflows.py (new file, 125 lines)
@@ -0,0 +1,125 @@
"""Repository for the ``workflows`` table.

Covers CRUD on workflow metadata:

- create / get / list / update / delete
- graph version management
"""

from __future__ import annotations

from typing import Optional

from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert

from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import workflows_table


class WorkflowsRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        description: str | None = None,
        *,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        values: dict = {"user_id": user_id, "name": name}
        if description is not None:
            values["description"] = description
        if legacy_mongo_id is not None:
            values["legacy_mongo_id"] = legacy_mongo_id

        stmt = pg_insert(workflows_table).values(**values).returning(workflows_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get(self, workflow_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM workflows "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_id(self, workflow_id: str) -> Optional[dict]:
        """Fetch a workflow by ID without user check (for internal use)."""
        result = self._conn.execute(
            text("SELECT * FROM workflows WHERE id = CAST(:id AS uuid)"),
            {"id": workflow_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: str | None = None,
    ) -> Optional[dict]:
        """Fetch a workflow by its original Mongo ObjectId string."""
        sql = "SELECT * FROM workflows WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM workflows "
                "WHERE user_id = :user_id ORDER BY created_at DESC"
            ),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, workflow_id: str, user_id: str, fields: dict) -> bool:
        allowed = {"name", "description", "current_graph_version"}
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return False

        set_parts = [f"{col} = :{col}" for col in filtered]
        set_parts.append("updated_at = now()")
        params = {**filtered, "id": workflow_id, "user_id": user_id}

        sql = (
            f"UPDATE workflows SET {', '.join(set_parts)} "
            "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
        )
        result = self._conn.execute(text(sql), params)
        return result.rowcount > 0

    def increment_graph_version(self, workflow_id: str, user_id: str) -> Optional[int]:
        """Atomically increment ``current_graph_version`` and return the new value."""
        result = self._conn.execute(
            text(
                "UPDATE workflows "
                "SET current_graph_version = current_graph_version + 1, "
                "    updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
                "RETURNING current_graph_version"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row[0] if row else None

    def delete(self, workflow_id: str, user_id: str) -> bool:
        result = self._conn.execute(
            text(
                "DELETE FROM workflows "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        return result.rowcount > 0
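The version counter above is what the workflow routes rely on when saving a new graph; a sketch of the intended call sequence (connection handling is illustrative only, as in the previous sketch):

    with engine.begin() as conn:
        workflows = WorkflowsRepository(conn)
        wf = workflows.create("local", "My workflow")
        new_version = workflows.increment_graph_version(wf["id"], "local")
        # nodes and edges for the new graph are then written with graph_version=new_version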
@@ -1178,8 +1178,8 @@ def attachment_worker(self, file_info, user):

    dual_write(
        AttachmentsRepository,
-       lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type: repo.create(
-           u, fn, p, mime_type=mt,
+       lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type, mid=attachment_id: repo.create(
+           u, fn, p, mime_type=mt, legacy_mongo_id=mid,
        ),
    )

frontend/package-lock.json (generated, 852 lines): file diff suppressed because it is too large.
@@ -29,8 +29,8 @@
     "clsx": "^2.1.1",
     "cmdk": "^1.1.1",
     "copy-to-clipboard": "^3.3.3",
-    "i18next": "^25.8.18",
-    "i18next-browser-languagedetector": "^8.2.0",
+    "i18next": "^26.0.4",
+    "i18next-browser-languagedetector": "^8.2.1",
     "lodash": "^4.17.21",
     "lucide-react": "^0.562.0",
     "mermaid": "^11.12.1",
@@ -41,7 +41,7 @@
     "react-dom": "^19.1.1",
     "react-dropzone": "^14.3.8",
     "react-google-drive-picker": "^1.2.2",
-    "react-i18next": "^16.2.4",
+    "react-i18next": "^17.0.2",
     "react-markdown": "^9.0.1",
     "react-redux": "^9.2.0",
     "react-router-dom": "^7.6.1",
@@ -58,7 +58,7 @@
     "@types/react": "^19.1.8",
     "@types/react-dom": "^19.1.7",
     "@types/react-syntax-highlighter": "^15.5.13",
-    "@typescript-eslint/eslint-plugin": "^8.46.3",
+    "@typescript-eslint/eslint-plugin": "^8.58.2",
     "@typescript-eslint/parser": "^8.46.3",
     "@vitejs/plugin-react": "^6.0.1",
     "eslint": "^9.39.1",
@@ -73,7 +73,7 @@
     "lint-staged": "^16.4.0",
     "postcss": "^8.4.49",
     "prettier": "^3.5.3",
-    "prettier-plugin-tailwindcss": "^0.7.1",
+    "prettier-plugin-tailwindcss": "^0.7.2",
     "tailwindcss": "^4.2.1",
     "tw-animate-css": "^1.4.0",
     "typescript": "^5.8.3",

@@ -32,6 +32,7 @@ import argparse
import json
import logging
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable

@@ -52,6 +53,45 @@ logger = logging.getLogger("backfill")
# ---------------------------------------------------------------------------


def _extract_mongo_id_text(value: Any) -> str | None:
    """Return a Mongo ObjectId-like value as text across legacy shapes.

    Handles raw ObjectId values, DBRef-like objects exposing ``.id``, and
    dict encodings such as ``{"$id": {"$oid": "..."}}`` that show up in
    imported / normalised BSON payloads.
    """
    if value is None:
        return None
    if isinstance(value, dict):
        if "$id" in value:
            return _extract_mongo_id_text(value["$id"])
        if "_id" in value:
            return _extract_mongo_id_text(value["_id"])
        if "$oid" in value:
            return str(value["$oid"])
        return None
    ref_id = getattr(value, "id", None)
    if ref_id is not None:
        return _extract_mongo_id_text(ref_id)
    return str(value)

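# Illustration (not part of the diff): the legacy shapes the helper above accepts.
# _extract_mongo_id_text("507f1f77bcf86cd799439011")                    -> "507f1f77bcf86cd799439011"
# _extract_mongo_id_text({"$oid": "507f1f77bcf86cd799439011"})          -> "507f1f77bcf86cd799439011"
# _extract_mongo_id_text({"$id": {"$oid": "507f1f77bcf86cd799439011"}}) -> "507f1f77bcf86cd799439011"
# _extract_mongo_id_text(None)                                          -> None
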
def _coerce_document_timestamp(doc: dict[str, Any], *keys: str):
    """Return the first populated timestamp-like field from ``doc``.

    Mongo user data is not fully uniform across older deployments. Some
    records only carry ``created_at`` / ``updated_at`` and a few legacy
    documents have no explicit timestamp at all. In that final case we
    fall back to "now" so the backfill can preserve the row instead of
    failing a NOT NULL constraint.
    """
    for key in keys:
        value = doc.get(key)
        if value is not None:
            return value
    return datetime.now(timezone.utc)


def _backfill_users(
    *,
    conn: Connection,
@@ -123,9 +163,13 @@ def _backfill_prompts(
 ) -> dict:
     upsert_sql = text(
         """
-        INSERT INTO prompts (user_id, name, content)
-        VALUES (:user_id, :name, :content)
-        ON CONFLICT DO NOTHING
+        INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
+        VALUES (:user_id, :name, :content, :legacy_mongo_id)
+        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
+        DO UPDATE SET
+            name = EXCLUDED.name,
+            content = EXCLUDED.content,
+            updated_at = now()
         """
     )
     cursor = mongo_db["prompts"].find({}, no_cursor_timeout=True).batch_size(batch_size)
@@ -142,6 +186,7 @@
                 "user_id": user_id,
                 "name": doc.get("name", ""),
                 "content": doc.get("content", ""),
+                "legacy_mongo_id": str(doc["_id"]),
             })
             if len(batch) >= batch_size:
                 if not dry_run:
@@ -198,48 +243,6 @@ def _backfill_user_tools(
     return {"seen": seen, "written": written, "skipped_no_user": skipped}


-def _backfill_feedback(
-    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
-) -> dict:
-    insert_sql = text(
-        """
-        INSERT INTO feedback (conversation_id, user_id, question_index, feedback_text, timestamp)
-        VALUES (CAST(:conversation_id AS uuid), :user_id, :question_index, :feedback_text, :timestamp)
-        ON CONFLICT DO NOTHING
-        """
-    )
-    cursor = mongo_db["feedback"].find({}, no_cursor_timeout=True).batch_size(batch_size)
-    seen = written = skipped = 0
-    batch: list[dict] = []
-    try:
-        for doc in cursor:
-            seen += 1
-            user_id = doc.get("user")
-            conv_id = doc.get("conversation_id")
-            if not user_id or not conv_id:
-                skipped += 1
-                continue
-            batch.append({
-                "conversation_id": str(conv_id),
-                "user_id": user_id,
-                "question_index": doc.get("question_index", 0),
-                "feedback_text": doc.get("feedback_text"),
-                "timestamp": doc.get("timestamp"),
-            })
-            if len(batch) >= batch_size:
-                if not dry_run:
-                    conn.execute(insert_sql, batch)
-                written += len(batch)
-                batch.clear()
-        if batch:
-            if not dry_run:
-                conn.execute(insert_sql, batch)
-            written += len(batch)
-    finally:
-        cursor.close()
-    return {"seen": seen, "written": written, "skipped": skipped}
-
-
 def _backfill_stack_logs(
     *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
 ) -> dict:
@@ -468,15 +471,32 @@ def _backfill_agents(
             chunks, retriever, default_model_id,
             tools, json_schema, models,
             limited_token_mode, token_limit, limited_request_mode, request_limit,
-            shared, incoming_webhook_token
+            shared, incoming_webhook_token, legacy_mongo_id
         ) VALUES (
             :user_id, :name, :status, :key, :description, :agent_type,
             :chunks, :retriever, :default_model_id,
             CAST(:tools AS jsonb), CAST(:json_schema AS jsonb), CAST(:models AS jsonb),
             :limited_token_mode, :token_limit, :limited_request_mode, :request_limit,
-            :shared, :incoming_webhook_token
+            :shared, :incoming_webhook_token, :legacy_mongo_id
         )
-        ON CONFLICT DO NOTHING
+        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
+        DO UPDATE SET
+            name = EXCLUDED.name,
+            status = EXCLUDED.status,
+            description = EXCLUDED.description,
+            agent_type = EXCLUDED.agent_type,
+            chunks = EXCLUDED.chunks,
+            retriever = EXCLUDED.retriever,
+            default_model_id = EXCLUDED.default_model_id,
+            tools = EXCLUDED.tools,
+            json_schema = EXCLUDED.json_schema,
+            models = EXCLUDED.models,
+            limited_token_mode = EXCLUDED.limited_token_mode,
+            token_limit = EXCLUDED.token_limit,
+            limited_request_mode = EXCLUDED.limited_request_mode,
+            request_limit = EXCLUDED.request_limit,
+            shared = EXCLUDED.shared,
+            updated_at = now()
         """
     )
     cursor = mongo_db["agents"].find({}, no_cursor_timeout=True).batch_size(batch_size)
@@ -493,7 +513,11 @@
                 "user_id": user_id,
                 "name": doc.get("name", ""),
                 "status": doc.get("status", "draft"),
-                "key": doc.get("key"),
+                # Mongo allows multiple agents with key="" but Postgres
+                # CITEXT UNIQUE treats them as a collision. Coerce empty
+                # strings to NULL so the unique constraint is only
+                # enforced for actual API keys.
+                "key": (doc.get("key") or None),
                 "description": doc.get("description"),
                 "agent_type": doc.get("agent_type"),
                 "chunks": doc.get("chunks"),
@@ -508,6 +532,7 @@
                 "request_limit": doc.get("request_limit"),
                 "shared": bool(doc.get("shared", False)),
                 "incoming_webhook_token": doc.get("incoming_webhook_token"),
+                "legacy_mongo_id": str(doc["_id"]),
             })
             if len(batch) >= batch_size:
                 if not dry_run:
@@ -528,8 +553,16 @@ def _backfill_attachments(
 ) -> dict:
     insert_sql = text(
         """
-        INSERT INTO attachments (user_id, filename, upload_path, mime_type, size)
-        VALUES (:user_id, :filename, :upload_path, :mime_type, :size)
+        INSERT INTO attachments (user_id, filename, upload_path, mime_type, size,
+                                 legacy_mongo_id)
+        VALUES (:user_id, :filename, :upload_path, :mime_type, :size,
+                :legacy_mongo_id)
+        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
+        DO UPDATE SET
+            filename = EXCLUDED.filename,
+            upload_path = EXCLUDED.upload_path,
+            mime_type = EXCLUDED.mime_type,
+            size = EXCLUDED.size
         """
     )
     cursor = mongo_db["attachments"].find({}, no_cursor_timeout=True).batch_size(batch_size)
@@ -548,6 +581,7 @@
                 "upload_path": doc.get("upload_path", ""),
                 "mime_type": doc.get("mime_type"),
                 "size": doc.get("size"),
+                "legacy_mongo_id": str(doc["_id"]),
             })
             if len(batch) >= batch_size:
                 if not dry_run:
@@ -780,6 +814,696 @@ def _backfill_connector_sessions(
    return {"seen": seen, "written": written, "skipped": skipped}


# ---------------------------------------------------------------------------
# Phase 3 backfillers
# ---------------------------------------------------------------------------


def _backfill_conversations(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``conversations`` table from Mongo ``conversations`` collection.

    Also flattens the nested ``queries`` array into
    ``conversation_messages`` rows (one per query, position = array index).

    Idempotent via the ``legacy_mongo_id`` column: rerunning replaces any
    previously migrated row's mutable fields and re-syncs its messages.
    """
    agent_id_map = _build_legacy_id_map(conn, "agents")
    attachment_id_map = _build_legacy_id_map(conn, "attachments")

    conv_sql = text(
        """
        INSERT INTO conversations
            (user_id, name, agent_id, api_key, is_shared_usage, shared_token,
             shared_with, compression_metadata, date, legacy_mongo_id)
        VALUES
            (:user_id, :name, CAST(:agent_id AS uuid), :api_key,
             :is_shared_usage, :shared_token,
             CAST(:shared_with AS text[]), CAST(:compression_metadata AS jsonb),
             :date, :legacy_mongo_id)
        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
        DO UPDATE SET
            name = EXCLUDED.name,
            agent_id = EXCLUDED.agent_id,
            api_key = EXCLUDED.api_key,
            is_shared_usage = EXCLUDED.is_shared_usage,
            shared_token = EXCLUDED.shared_token,
            shared_with = EXCLUDED.shared_with,
            compression_metadata = EXCLUDED.compression_metadata,
            updated_at = now()
        RETURNING id
        """
    )
    truncate_sql = text(
        """
        DELETE FROM conversation_messages
        WHERE conversation_id = CAST(:conv_id AS uuid)
          AND position > :max_pos
        """
    )
    msg_sql = text(
        """
        INSERT INTO conversation_messages
            (conversation_id, position, prompt, response, thought,
             sources, tool_calls, attachments, model_id, message_metadata, feedback,
             timestamp)
        VALUES
            (CAST(:conv_id AS uuid), :position, :prompt, :response, :thought,
             CAST(:sources AS jsonb), CAST(:tool_calls AS jsonb),
             CAST(:attachments AS uuid[]),
             :model_id, CAST(:metadata AS jsonb), CAST(:feedback AS jsonb),
             :timestamp)
        ON CONFLICT (conversation_id, position) DO UPDATE SET
            prompt = EXCLUDED.prompt,
            response = EXCLUDED.response,
            thought = EXCLUDED.thought,
            sources = EXCLUDED.sources,
            tool_calls = EXCLUDED.tool_calls,
            attachments = EXCLUDED.attachments,
            model_id = EXCLUDED.model_id,
            message_metadata = EXCLUDED.message_metadata,
            feedback = EXCLUDED.feedback,
            timestamp = EXCLUDED.timestamp
        """
    )

    cursor = mongo_db["conversations"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    seen = written = msg_written = skipped = 0
    malformed_messages = 0
    unresolved_attachment_refs = 0

    try:
        for doc in cursor:
            seen += 1
            user_id = doc.get("user")
            if not user_id:
                skipped += 1
                continue

            shared_with = doc.get("shared_with") or []
            comp_meta = doc.get("compression_metadata")

            if dry_run:
                # In dry-run we don't write, so we can't get a returning id.
                # Skip message insertion too — they need the FK.
                continue

            mongo_agent_id = doc.get("agent_id")
            pg_agent_id = agent_id_map.get(str(mongo_agent_id)) if mongo_agent_id else None

            result = conn.execute(conv_sql, {
                "user_id": user_id,
                "name": doc.get("name"),
                "agent_id": pg_agent_id,
                "api_key": doc.get("api_key"),
                "is_shared_usage": bool(doc.get("is_shared_usage", False)),
                "shared_token": doc.get("shared_token"),
                "shared_with": list(shared_with),
                "compression_metadata": json.dumps(comp_meta) if comp_meta else None,
                "date": _coerce_document_timestamp(doc, "date", "created_at", "updated_at"),
                "legacy_mongo_id": str(doc["_id"]),
            })
            pg_conv_id = str(result.scalar())
            written += 1

            # Flatten queries array → conversation_messages rows
            queries = doc.get("queries") or []
            msg_batch: list[dict] = []
            for pos, q in enumerate(queries):
                if not isinstance(q, dict):
                    malformed_messages += 1
                    logger.warning(
                        "Skipping malformed conversation query during backfill: "
                        "conversation=%s position=%s type=%s",
                        doc.get("_id"),
                        pos,
                        type(q).__name__,
                    )
                    continue
                fb = q.get("feedback")
                fb_ts = q.get("feedback_timestamp")
                feedback_json = None
                if fb is not None:
                    feedback_json = json.dumps({"text": fb, "timestamp": str(fb_ts)} if fb_ts else {"text": fb})

                # Resolve attachment ObjectIds → Postgres UUIDs; drop unresolved.
                raw_attachments = q.get("attachments") or []
                resolved_attachments: list[str] = []
                for a in raw_attachments:
                    if not a:
                        continue
                    s = str(a)
                    if len(s) == 36 and "-" in s:
                        resolved_attachments.append(s)
                    else:
                        mapped = attachment_id_map.get(s)
                        if mapped:
                            resolved_attachments.append(mapped)
                        else:
                            unresolved_attachment_refs += 1
                            logger.warning(
                                "Conversation backfill dropped unresolved attachment ref: "
                                "conversation=%s position=%s attachment=%s",
                                doc.get("_id"),
                                pos,
                                s,
                            )

                msg_batch.append({
                    "conv_id": pg_conv_id,
                    "position": pos,
                    "prompt": q.get("prompt"),
                    "response": q.get("response"),
                    "thought": q.get("thought"),
                    "sources": json.dumps(q.get("sources") or []),
                    "tool_calls": json.dumps(q.get("tool_calls") or []),
                    "attachments": resolved_attachments,
                    "model_id": q.get("model_id"),
                    "metadata": json.dumps(q.get("metadata") or {}),
                    "feedback": feedback_json,
                    "timestamp": (
                        q.get("timestamp")
                        or doc.get("date")
                        or doc.get("created_at")
                        or doc.get("updated_at")
                        or datetime.now(timezone.utc)
                    ),
                })

            if msg_batch and not dry_run:
                conn.execute(msg_sql, msg_batch)
                msg_written += len(msg_batch)

            # Converge: drop any messages past the Mongo queries length
            # (handles the case where a conversation was truncated in Mongo
            # after a previous backfill).
            if not dry_run:
                conn.execute(truncate_sql, {
                    "conv_id": pg_conv_id,
                    "max_pos": len(queries) - 1,
                })

    finally:
        cursor.close()

    return {
        "seen": seen,
        "written": written,
        "messages_written": msg_written,
        "skipped": skipped,
        "malformed_messages": malformed_messages,
        "unresolved_attachment_refs": unresolved_attachment_refs,
    }


def _build_legacy_id_map(conn: Connection, table: str) -> dict[str, str]:
    """Return ``{legacy_mongo_id: pg_uuid}`` for the given table.

    Used by Phase 3 backfills to resolve FK references that were Mongo
    ObjectIds in the source data into the new Postgres UUIDs.
    """
    rows = conn.execute(
        text(
            f"SELECT id, legacy_mongo_id FROM {table} "
            "WHERE legacy_mongo_id IS NOT NULL"
        )
    ).fetchall()
    return {r._mapping["legacy_mongo_id"]: str(r._mapping["id"]) for r in rows}

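# Illustration (not part of the diff): downstream backfillers resolve legacy FKs like this.
# agent_id_map = _build_legacy_id_map(conn, "agents")
# pg_agent_id = agent_id_map.get(str(doc.get("agent_id")))  # None if that agent was never migrated
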
def _backfill_shared_conversations(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``shared_conversations`` table.

    Resolves Mongo ``conversation_id`` (ObjectId) → Postgres
    ``conversations.id`` (UUID) via the ``conversations.legacy_mongo_id``
    column populated during the conversations backfill. Rows whose
    parent conversation was not migrated are skipped.
    """
    conv_id_map = _build_legacy_id_map(conn, "conversations")
    prompt_id_map = _build_legacy_id_map(conn, "prompts")
    agent_meta_by_key = {
        doc.get("key"): {
            "prompt_id": doc.get("prompt_id"),
            "chunks": doc.get("chunks"),
        }
        for doc in mongo_db["agents"].find({}, {"key": 1, "prompt_id": 1, "chunks": 1})
        if doc.get("key")
    }
    insert_sql = text(
        """
        INSERT INTO shared_conversations
            (uuid, conversation_id, user_id, is_promptable, first_n_queries,
             api_key, prompt_id, chunks)
        VALUES
            (CAST(:uuid AS uuid), CAST(:conv_id AS uuid), :user_id,
             :is_promptable, :first_n_queries, :api_key,
             CAST(:prompt_id AS uuid), :chunks)
        ON CONFLICT (uuid) DO UPDATE SET
            conversation_id = EXCLUDED.conversation_id,
            user_id = EXCLUDED.user_id,
            is_promptable = EXCLUDED.is_promptable,
            first_n_queries = EXCLUDED.first_n_queries,
            api_key = EXCLUDED.api_key,
            prompt_id = EXCLUDED.prompt_id,
            chunks = EXCLUDED.chunks
        """
    )
    cursor = (
        mongo_db["shared_conversations"]
        .find({}, no_cursor_timeout=True)
        .batch_size(batch_size)
    )
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            user_id = doc.get("user")
            mongo_conv_id = _extract_mongo_id_text(doc.get("conversation_id"))
            mongo_uuid = doc.get("uuid")
            if not user_id or not mongo_conv_id or not mongo_uuid:
                skipped += 1
                continue
            pg_conv_id = conv_id_map.get(mongo_conv_id)
            if not pg_conv_id:
                skipped += 1
                continue

            # Mongo stores ``uuid`` as BSON Binary (standard UUID subtype).
            # Unwrap to a plain uuid.UUID → string for Postgres CAST.
            try:
                share_uuid_str = str(mongo_uuid.as_uuid()) if hasattr(mongo_uuid, "as_uuid") else str(mongo_uuid)
            except Exception:
                share_uuid_str = str(mongo_uuid)

            # prompt_id may be either a prompt ObjectId or the literal string
            # "default" (see sharing/routes.py); only resolvable ObjectIds
            # get a real FK value.
            agent_meta = agent_meta_by_key.get(doc.get("api_key")) or {}
            raw_prompt_id = doc.get("prompt_id")
            if raw_prompt_id is None:
                raw_prompt_id = agent_meta.get("prompt_id")
            prompt_legacy_id = _extract_mongo_id_text(raw_prompt_id)
            resolved_prompt_id = (
                prompt_id_map.get(prompt_legacy_id) if prompt_legacy_id else None
            )

            chunks_raw = doc.get("chunks")
            if chunks_raw is None:
                chunks_raw = agent_meta.get("chunks")
            chunks_val: int | None = None
            if chunks_raw is not None:
                try:
                    chunks_val = int(chunks_raw)
                except (TypeError, ValueError):
                    chunks_val = None

            batch.append({
                "uuid": share_uuid_str,
                "conv_id": pg_conv_id,
                "user_id": user_id,
                "is_promptable": bool(doc.get("isPromptable", False)),
                "first_n_queries": doc.get("first_n_queries", 0),
                "api_key": doc.get("api_key"),
                "prompt_id": resolved_prompt_id,
                "chunks": chunks_val,
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}

def _backfill_pending_tool_state(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync ``pending_tool_state`` from Mongo.

    Most rows will be expired by the time the backfill runs (30-min TTL).
    We copy them anyway; the Celery cleanup task will purge stale rows on
    its first tick. Resolves ``conversation_id`` via
    ``conversations.legacy_mongo_id``.
    """
    conv_id_map = _build_legacy_id_map(conn, "conversations")
    insert_sql = text(
        """
        INSERT INTO pending_tool_state
            (conversation_id, user_id, messages, pending_tool_calls,
             tools_dict, tool_schemas, agent_config, client_tools,
             created_at, expires_at)
        VALUES
            (CAST(:conv_id AS uuid), :user_id,
             CAST(:messages AS jsonb), CAST(:pending AS jsonb),
             CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
             CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
             :created_at, :expires_at)
        ON CONFLICT (conversation_id, user_id) DO UPDATE SET
            messages = EXCLUDED.messages,
            pending_tool_calls = EXCLUDED.pending_tool_calls,
            tools_dict = EXCLUDED.tools_dict,
            tool_schemas = EXCLUDED.tool_schemas,
            agent_config = EXCLUDED.agent_config,
            client_tools = EXCLUDED.client_tools,
            created_at = EXCLUDED.created_at,
            expires_at = EXCLUDED.expires_at
        """
    )
    cursor = (
        mongo_db["pending_tool_state"]
        .find({}, no_cursor_timeout=True)
        .batch_size(batch_size)
    )
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            mongo_conv_id = doc.get("conversation_id")
            user_id = doc.get("user")
            if not mongo_conv_id or not user_id:
                skipped += 1
                continue
            pg_conv_id = conv_id_map.get(str(mongo_conv_id))
            if not pg_conv_id:
                skipped += 1
                continue
            batch.append({
                "conv_id": pg_conv_id,
                "user_id": user_id,
                "messages": json.dumps(doc.get("messages") or [], default=str),
                "pending": json.dumps(doc.get("pending_tool_calls") or [], default=str),
                "tools_dict": json.dumps(doc.get("tools_dict") or {}, default=str),
                "schemas": json.dumps(doc.get("tool_schemas") or [], default=str),
                "agent_config": json.dumps(doc.get("agent_config") or {}, default=str),
                "client_tools": json.dumps(doc.get("client_tools"), default=str) if doc.get("client_tools") else None,
                "created_at": doc.get("created_at"),
                "expires_at": doc.get("expires_at"),
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}

def _backfill_workflows(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``workflows`` table from Mongo ``workflows`` collection.

    Idempotent via ``legacy_mongo_id``.
    """
    insert_sql = text(
        """
        INSERT INTO workflows (user_id, name, description, current_graph_version,
                               legacy_mongo_id)
        VALUES (:user_id, :name, :description, :current_graph_version,
                :legacy_mongo_id)
        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
        DO UPDATE SET
            name = EXCLUDED.name,
            description = EXCLUDED.description,
            current_graph_version = EXCLUDED.current_graph_version,
            updated_at = now()
        """
    )
    cursor = mongo_db["workflows"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            user_id = doc.get("user")
            if not user_id:
                skipped += 1
                continue
            batch.append({
                "user_id": user_id,
                "name": doc.get("name", ""),
                "description": doc.get("description"),
                "current_graph_version": doc.get("current_graph_version", 1),
                "legacy_mongo_id": str(doc["_id"]),
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}

def _backfill_workflow_nodes(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync ``workflow_nodes``.

    Resolves Mongo ``workflow_id`` (string ObjectId) →
    ``workflows.id`` (UUID) via ``workflows.legacy_mongo_id``.
    Idempotent via ``legacy_mongo_id``.
    """
    workflow_id_map = _build_legacy_id_map(conn, "workflows")
    insert_sql = text(
        """
        INSERT INTO workflow_nodes
            (workflow_id, graph_version, node_id, node_type, title, description,
             position, config, legacy_mongo_id)
        VALUES
            (CAST(:workflow_id AS uuid), :graph_version, :node_id, :node_type,
             :title, :description, CAST(:position AS jsonb), CAST(:config AS jsonb),
             :legacy_mongo_id)
        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
        DO UPDATE SET
            graph_version = EXCLUDED.graph_version,
            node_id = EXCLUDED.node_id,
            node_type = EXCLUDED.node_type,
            title = EXCLUDED.title,
            description = EXCLUDED.description,
            position = EXCLUDED.position,
            config = EXCLUDED.config
        """
    )
    cursor = mongo_db["workflow_nodes"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            mongo_wf_id = doc.get("workflow_id")
            if not mongo_wf_id:
                skipped += 1
                continue
            pg_wf_id = workflow_id_map.get(str(mongo_wf_id))
            if not pg_wf_id:
                skipped += 1
                continue
            position = doc.get("position") or {"x": 0, "y": 0}
            batch.append({
                "workflow_id": pg_wf_id,
                "graph_version": doc.get("graph_version", 1),
                "node_id": doc.get("id", ""),
                "node_type": doc.get("type", ""),
                "title": doc.get("title"),
                "description": doc.get("description"),
                "position": json.dumps(position),
                "config": json.dumps(doc.get("config") or {}),
                "legacy_mongo_id": str(doc["_id"]),
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}

def _backfill_workflow_edges(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``workflow_edges`` table from Mongo ``workflow_edges`` collection.

    Must run after ``workflow_nodes`` because ``from_node_id`` and
    ``to_node_id`` are FKs into ``workflow_nodes``.

    The Mongo doc stores ``source_id`` and ``target_id`` as user-provided
    node-id strings. We need to resolve them to Postgres UUIDs by looking
    up the ``workflow_nodes`` row with matching ``(workflow_id,
    graph_version, node_id)``.
    """
    workflow_id_map = _build_legacy_id_map(conn, "workflows")
    # Build a lookup: (pg_workflow_uuid, graph_version, node_id_str) → pg node UUID
    pg_nodes = conn.execute(
        text("SELECT id, workflow_id, graph_version, node_id FROM workflow_nodes")
    ).fetchall()
    node_lookup: dict[tuple[str, int, str], str] = {}
    for row in pg_nodes:
        m = row._mapping
        node_lookup[(str(m["workflow_id"]), m["graph_version"], m["node_id"])] = str(m["id"])

    insert_sql = text(
        """
        INSERT INTO workflow_edges
            (workflow_id, graph_version, edge_id, from_node_id, to_node_id,
             source_handle, target_handle, config)
        VALUES
            (CAST(:workflow_id AS uuid), :graph_version, :edge_id,
             CAST(:from_node_id AS uuid), CAST(:to_node_id AS uuid),
             :source_handle, :target_handle, CAST(:config AS jsonb))
        ON CONFLICT (workflow_id, graph_version, edge_id) DO UPDATE SET
            from_node_id = EXCLUDED.from_node_id,
            to_node_id = EXCLUDED.to_node_id,
            source_handle = EXCLUDED.source_handle,
            target_handle = EXCLUDED.target_handle,
            config = EXCLUDED.config
        """
    )
    cursor = mongo_db["workflow_edges"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            mongo_wf_id = doc.get("workflow_id")
            if not mongo_wf_id:
                skipped += 1
                continue

            pg_wf_id = workflow_id_map.get(str(mongo_wf_id))
            if not pg_wf_id:
                skipped += 1
                continue

            gv = doc.get("graph_version", 1)
            source_nid = doc.get("source_id", "")
            target_nid = doc.get("target_id", "")

            from_uuid = node_lookup.get((pg_wf_id, gv, source_nid))
            to_uuid = node_lookup.get((pg_wf_id, gv, target_nid))
            if not from_uuid or not to_uuid:
                skipped += 1
                continue

            batch.append({
                "workflow_id": pg_wf_id,
                "graph_version": gv,
                "edge_id": doc.get("id", ""),
                "from_node_id": from_uuid,
                "to_node_id": to_uuid,
                "source_handle": doc.get("source_handle"),
                "target_handle": doc.get("target_handle"),
                "config": json.dumps(doc.get("config") or {}),
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}

def _backfill_workflow_runs(
    *, conn: Connection, mongo_db: Any, batch_size: int, dry_run: bool,
) -> dict:
    """Sync the ``workflow_runs`` table from Mongo ``workflow_runs`` collection.

    Resolves Mongo ``workflow_id`` (string) → PG UUID via
    ``workflows.legacy_mongo_id``. Rows whose parent workflow was never
    migrated (e.g. legacy ``workflow_id='unknown'``) are skipped.
    """
    workflow_id_map = _build_legacy_id_map(conn, "workflows")
    insert_sql = text(
        """
        INSERT INTO workflow_runs
            (workflow_id, user_id, status, inputs, result, steps,
             started_at, ended_at, legacy_mongo_id)
        VALUES
            (CAST(:workflow_id AS uuid), :user_id, :status,
             CAST(:inputs AS jsonb), CAST(:result AS jsonb),
             CAST(:steps AS jsonb), :started_at, :ended_at, :legacy_mongo_id)
        ON CONFLICT (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL
        DO UPDATE SET
            status = EXCLUDED.status,
            inputs = EXCLUDED.inputs,
            result = EXCLUDED.result,
            steps = EXCLUDED.steps,
            ended_at = EXCLUDED.ended_at
        """
    )
    cursor = mongo_db["workflow_runs"].find({}, no_cursor_timeout=True).batch_size(batch_size)
    seen = written = skipped = 0
    batch: list[dict] = []
    try:
        for doc in cursor:
            seen += 1
            mongo_wf_id = doc.get("workflow_id")
            if not mongo_wf_id:
                skipped += 1
                continue
            pg_wf_id = workflow_id_map.get(str(mongo_wf_id))
            if not pg_wf_id:
                skipped += 1
                continue
            batch.append({
                "workflow_id": pg_wf_id,
                "user_id": doc.get("user_id") or doc.get("user") or "",
                "status": doc.get("status", "unknown"),
                "inputs": json.dumps(doc.get("inputs") or {}, default=str),
                "result": json.dumps(doc.get("outputs") or doc.get("result"), default=str),
                "steps": json.dumps(doc.get("steps") or [], default=str),
                "started_at": doc.get("started_at") or doc.get("created_at"),
                "ended_at": doc.get("ended_at") or doc.get("completed_at"),
                "legacy_mongo_id": str(doc["_id"]),
            })
            if len(batch) >= batch_size:
                if not dry_run:
                    conn.execute(insert_sql, batch)
                written += len(batch)
                batch.clear()
        if batch:
            if not dry_run:
                conn.execute(insert_sql, batch)
            written += len(batch)
    finally:
        cursor.close()
    return {"seen": seen, "written": written, "skipped": skipped}


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
@@ -796,7 +1520,6 @@ BACKFILLERS: dict[str, BackfillFn] = {
     "users": _backfill_users,
     "prompts": _backfill_prompts,
     "user_tools": _backfill_user_tools,
-    "feedback": _backfill_feedback,
     "stack_logs": _backfill_stack_logs,
     "user_logs": _backfill_user_logs,
     "token_usage": _backfill_token_usage,
@@ -809,6 +1532,14 @@
     "todos": _backfill_todos,
     "notes": _backfill_notes,
     "connector_sessions": _backfill_connector_sessions,
+    # Phase 3 (order: conversations first, then dependents)
+    "conversations": _backfill_conversations,
+    "shared_conversations": _backfill_shared_conversations,
+    "pending_tool_state": _backfill_pending_tool_state,
+    "workflows": _backfill_workflows,
+    "workflow_nodes": _backfill_workflow_nodes,
+    "workflow_edges": _backfill_workflow_edges,
+    "workflow_runs": _backfill_workflow_runs,
 }

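How the registry is driven is not shown in this diff; a minimal sketch of the expected loop, assuming `conn` and `mongo_db` are already open handles (argument names match the backfiller signatures above):

    for name, backfill in BACKFILLERS.items():
        stats = backfill(conn=conn, mongo_db=mongo_db, batch_size=500, dry_run=True)
        logger.info("backfill %s: %s", name, stats)
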
@@ -336,6 +336,34 @@ class TestSaveWorkflowRun:
            agent._save_workflow_run("query")

        mock_collection.insert_one.assert_called_once()
        saved_doc = mock_collection.insert_one.call_args.args[0]
        assert saved_doc["user"] == "user1"
        assert saved_doc["user_id"] == "user1"

    @pytest.mark.unit
    def test_dual_writes_when_mongo_insert_returns_id(self):
        agent = _make_agent(workflow_id="507f1f77bcf86cd799439011")
        mock_engine = MagicMock()
        mock_engine.state = {"query": "test"}
        mock_engine.execution_log = []
        mock_engine.get_execution_summary.return_value = []
        agent._engine = mock_engine

        insert_result = MagicMock()
        insert_result.inserted_id = "507f1f77bcf86cd799439012"
        mock_collection = MagicMock()
        mock_collection.insert_one.return_value = insert_result
        mock_db = MagicMock()
        mock_db.__getitem__ = MagicMock(return_value=mock_collection)

        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings, \
                patch("application.agents.workflow_agent.dual_write") as mock_dual_write:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": mock_db}
            agent._save_workflow_run("query")

        mock_dual_write.assert_called_once()

    @pytest.mark.unit
    def test_exception_does_not_propagate(self):

@@ -973,14 +973,21 @@ class TestUpdateAgent:
        agent_id = ObjectId()
        existing = self._make_existing_agent(agent_id)
        mock_col = Mock()
        mock_repo = Mock()
        mock_col.find_one.return_value = existing
        mock_col.update_one.return_value = Mock(matched_count=1, modified_count=1)
        mock_handle_img = Mock(return_value=("", None))

        def _run_dual_write(_repo_cls, fn):
            fn(mock_repo)

        with patch(
            "application.api.user.agents.routes.agents_collection", mock_col
        ), patch(
            "application.api.user.agents.routes.handle_image_upload", mock_handle_img
        ), patch(
            "application.api.user.agents.routes.dual_write",
            side_effect=_run_dual_write,
        ):
            with app.test_request_context(
                f"/api/update_agent/{agent_id}",
@@ -993,6 +1000,11 @@
                response = UpdateAgent().put(str(agent_id))
        assert response.status_code == 200
        assert response.json["success"] is True
        mock_repo.update_by_legacy_id.assert_called_once_with(
            str(agent_id),
            "user1",
            {"name": "Updated Name"},
        )

    def test_returns_400_invalid_status(self, app):
        from application.api.user.agents.routes import UpdateAgent
@@ -1945,14 +1957,21 @@ class TestDeleteAgent:

        agent_id = ObjectId()
        mock_col = Mock()
        mock_repo = Mock()
        mock_col.find_one_and_delete.return_value = {
            "_id": agent_id,
            "user": "user1",
            "agent_type": "classic",
        }

        def _run_dual_write(_repo_cls, fn):
            fn(mock_repo)

        with patch(
            "application.api.user.agents.routes.agents_collection", mock_col
        ), patch(
            "application.api.user.agents.routes.dual_write",
            side_effect=_run_dual_write,
        ):
            with app.test_request_context(f"/api/delete_agent?id={agent_id}"):
                from flask import request
@@ -1961,6 +1980,7 @@
                response = DeleteAgent().delete()
        assert response.status_code == 200
        assert response.json["id"] == str(agent_id)
        mock_repo.delete_by_legacy_id.assert_called_once_with(str(agent_id), "user1")

    def test_deletes_workflow_agent_cleans_up(self, app):
        from application.api.user.agents.routes import DeleteAgent

@@ -18,12 +18,19 @@ class TestCreatePrompt:
        from application.api.user.prompts.routes import CreatePrompt

        mock_collection = Mock()
        mock_repo = Mock()
        inserted_id = ObjectId()
        mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)

        def _run_dual_write(_repo_cls, fn):
            fn(mock_repo)

        with patch(
            "application.api.user.prompts.routes.prompts_collection",
            mock_collection,
        ), patch(
            "application.api.user.prompts.routes.dual_write",
            side_effect=_run_dual_write,
        ):
            with app.test_request_context(
                "/api/create_prompt",
@@ -41,6 +48,12 @@
        doc = mock_collection.insert_one.call_args[0][0]
        assert doc["name"] == "My Prompt"
        assert doc["user"] == "user1"
        mock_repo.create.assert_called_once_with(
            "user1",
            "My Prompt",
            "You are helpful.",
            legacy_mongo_id=str(inserted_id),
        )

    def test_returns_401_unauthenticated(self, app):
        from application.api.user.prompts.routes import CreatePrompt
@@ -204,10 +217,17 @@ class TestDeletePrompt:

        prompt_id = ObjectId()
        mock_collection = Mock()
        mock_repo = Mock()

        def _run_dual_write(_repo_cls, fn):
            fn(mock_repo)

        with patch(
            "application.api.user.prompts.routes.prompts_collection",
            mock_collection,
        ), patch(
            "application.api.user.prompts.routes.dual_write",
            side_effect=_run_dual_write,
        ):
            with app.test_request_context(
                "/api/delete_prompt",
@@ -224,6 +244,7 @@
        mock_collection.delete_one.assert_called_once_with(
            {"_id": prompt_id, "user": "user1"}
        )
        mock_repo.delete_by_legacy_id.assert_called_once_with(str(prompt_id), "user1")

    def test_returns_400_missing_id(self, app):
        from application.api.user.prompts.routes import DeletePrompt
@@ -249,10 +270,17 @@ class TestUpdatePrompt:

        prompt_id = ObjectId()
        mock_collection = Mock()
        mock_repo = Mock()

        def _run_dual_write(_repo_cls, fn):
            fn(mock_repo)

        with patch(
            "application.api.user.prompts.routes.prompts_collection",
            mock_collection,
        ), patch(
            "application.api.user.prompts.routes.dual_write",
            side_effect=_run_dual_write,
        ):
            with app.test_request_context(
                "/api/update_prompt",
@@ -271,6 +299,12 @@
        assert response.status_code == 200
        assert response.json["success"] is True
        mock_collection.update_one.assert_called_once()
        mock_repo.update_by_legacy_id.assert_called_once_with(
            str(prompt_id),
            "user1",
            "Updated",
            "New content",
        )

    def test_returns_400_missing_fields(self, app):
        from application.api.user.prompts.routes import UpdatePrompt

@@ -200,7 +200,7 @@ class TestSetupPeriodicTasks:

        setup_periodic_tasks(sender)

-       assert sender.add_periodic_task.call_count == 3
+       assert sender.add_periodic_task.call_count == 4

        calls = sender.add_periodic_task.call_args_list

@@ -210,6 +210,8 @@
        assert calls[1][0][0] == timedelta(weeks=1)
        # monthly
        assert calls[2][0][0] == timedelta(days=30)
+       # pending_tool_state TTL cleanup (60s)
+       assert calls[3][0][0] == timedelta(seconds=60)


class TestMcpOauthTask:

@@ -886,7 +886,13 @@ class TestWorkflowListPost:
        mock_wf_collection = Mock()
        mock_wf_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
        mock_nodes_collection = Mock()
        mock_nodes_collection.insert_many.return_value = Mock(
            inserted_ids=[ObjectId(), ObjectId()]
        )
        mock_edges_collection = Mock()
        mock_edges_collection.insert_many.return_value = Mock(
            inserted_ids=[ObjectId()]
        )

        with patch(
            "application.api.user.workflows.routes.workflows_collection",
@@ -1152,7 +1158,13 @@ class TestWorkflowDetailPut:
        }
        mock_wf_collection.update_one.return_value = Mock()
        mock_nodes_collection = Mock()
        mock_nodes_collection.insert_many.return_value = Mock(
            inserted_ids=[ObjectId(), ObjectId()]
        )
        mock_edges_collection = Mock()
        mock_edges_collection.insert_many.return_value = Mock(
            inserted_ids=[ObjectId()]
        )

        with patch(
            "application.api.user.workflows.routes.workflows_collection",

@@ -27,18 +27,15 @@ from application.core.settings import settings
 def _run_alembic_upgrade(engine):
     """Run ``alembic upgrade head`` to ensure the full schema is present.

-    Falls back to inline DDL for CI environments where alembic is not
-    on PATH (shouldn't happen, but defence in depth).
+    Non-zero exit is re-raised so genuine schema-drift bugs surface as
+    test failures. If alembic reports the schema is already at head,
+    the subprocess still exits zero.
     """
     alembic_ini = Path(__file__).resolve().parents[3] / "application" / "alembic.ini"
-    try:
-        subprocess.check_call(
-            [sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
-            timeout=30,
-        )
-    except Exception:
-        # Alembic failed — tables likely already exist from a prior run.
-        pass
+    subprocess.check_call(
+        [sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
+        timeout=60,
+    )


 @pytest.fixture(scope="session")

@@ -44,6 +44,21 @@ class TestCreate:
        doc = repo.create("u", "a", "draft")
        assert doc["_id"] == doc["id"]

    def test_create_with_legacy_mongo_id(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create(
            "u",
            "a",
            "draft",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"

    def test_create_normalizes_blank_key_to_null(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("u", "a", "draft", key="")
        assert doc["key"] is None


class TestGet:
    def test_get_existing(self, pg_conn):
@@ -61,6 +76,17 @@ class TestGet:
        created = repo.create("user-1", "a", "draft")
        assert repo.get(created["id"], "user-other") is None

    def test_get_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create(
            "user-1",
            "a",
            "draft",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
        assert fetched["id"] == created["id"]


class TestFindByKey:
    def test_finds_agent_by_key(self, pg_conn):
@@ -108,6 +134,31 @@ class TestUpdate:
        updated = repo.update(created["id"], "user-1", {"id": "bad"})
        assert updated is False

    def test_update_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create(
            "user-1",
            "old",
            "draft",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        updated = repo.update_by_legacy_id(
            "507f1f77bcf86cd799439011",
            "user-1",
            {"name": "new", "last_used_at": None},
        )
        assert updated is True
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
        assert fetched["name"] == "new"

    def test_update_normalizes_blank_key_to_null(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old", "draft", key="my-unique-key")
        updated = repo.update(created["id"], "user-1", {"key": ""})
        assert updated is True
        fetched = repo.get(created["id"], "user-1")
        assert fetched["key"] is None


class TestDelete:
    def test_deletes_agent(self, pg_conn):
@@ -124,6 +175,18 @@ class TestDelete:
        assert deleted is False
        assert repo.get(created["id"], "user-1") is not None

    def test_delete_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create(
            "user-1",
            "a",
            "draft",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        deleted = repo.delete_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
        assert deleted is True
        assert repo.get(created["id"], "user-1") is None


class TestSetFolder:
    def test_assigns_folder(self, pg_conn):

@@ -37,6 +37,16 @@ class TestCreate:
        doc = repo.create("u", "f", "/p")
        assert doc["_id"] == doc["id"]

    def test_create_with_legacy_mongo_id(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create(
            "u",
            "f",
            "/p",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"


class TestGet:
    def test_get_existing(self, pg_conn):
@@ -54,6 +64,17 @@ class TestGet:
        created = repo.create("u", "f", "/p")
        assert repo.get(created["id"], "other") is None

    def test_get_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create(
            "u",
            "f",
            "/p",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "u")
        assert fetched["id"] == created["id"]


class TestListForUser:
    def test_lists_only_own_attachments(self, pg_conn):

374
tests/storage/db/repositories/test_conversations.py
Normal file
374
tests/storage/db/repositories/test_conversations.py
Normal file
@@ -0,0 +1,374 @@
"""Tests for ConversationsRepository against a real Postgres instance."""

from __future__ import annotations

from datetime import datetime, timezone

import pytest

from application.storage.db.repositories.conversations import ConversationsRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> ConversationsRepository:
    return ConversationsRepository(conn)


# ------------------------------------------------------------------
# Conversation CRUD
# ------------------------------------------------------------------


class TestCreate:
    def test_creates_conversation(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "My Chat")
        assert doc["user_id"] == "user-1"
        assert doc["name"] == "My Chat"
        assert doc["id"] is not None
        assert doc["_id"] == doc["id"]

    def test_create_with_agent(self, pg_conn):
        from application.storage.db.repositories.agents import AgentsRepository

        agent_repo = AgentsRepository(pg_conn)
        agent = agent_repo.create("user-1", "a", "active")
        repo = _repo(pg_conn)
        doc = repo.create(
            "user-1", "Chat",
            agent_id=agent["id"],
            api_key="ak-123",
            is_shared_usage=True,
            shared_token="tok-abc",
        )
        assert str(doc["agent_id"]) == agent["id"]
        assert doc["api_key"] == "ak-123"
        assert doc["is_shared_usage"] is True
        assert doc["shared_token"] == "tok-abc"


class TestGet:
    def test_get_owned(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "c")
        fetched = repo.get(created["id"], "user-1")
        assert fetched["id"] == created["id"]

    def test_get_nonexistent(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000", "u") is None

    def test_get_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "c")
        assert repo.get(created["id"], "user-other") is None


class TestListForUser:
    def test_lists_own_conversations(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "c1")
        repo.create("alice", "c2")
        repo.create("bob", "c3")
        results = repo.list_for_user("alice")
        assert len(results) == 2
        assert all(r["user_id"] == "alice" for r in results)

    def test_excludes_api_key_without_agent(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "normal")
        repo.create("alice", "api-only", api_key="key-1")
        results = repo.list_for_user("alice")
        assert len(results) == 1
        assert results[0]["name"] == "normal"


class TestRename:
    def test_renames(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        assert repo.rename(created["id"], "user-1", "new") is True
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "new"

    def test_rename_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        assert repo.rename(created["id"], "user-other", "new") is False


class TestDelete:
    def test_deletes(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "c")
        assert repo.delete(created["id"], "user-1") is True
        assert repo.get(created["id"], "user-1") is None

    def test_delete_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "c")
        assert repo.delete(created["id"], "user-other") is False

    def test_delete_cascades_messages(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "hi", "response": "hello"})
        repo.delete(conv["id"], "user-1")
        assert repo.get_messages(conv["id"]) == []

    def test_delete_all_for_user(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("user-1", "c1")
        repo.create("user-1", "c2")
        repo.create("user-2", "c3")
        count = repo.delete_all_for_user("user-1")
        assert count == 2
        assert repo.list_for_user("user-1") == []
        assert len(repo.list_for_user("user-2")) == 1


# ------------------------------------------------------------------
# Messages
# ------------------------------------------------------------------


class TestAppendMessage:
    def test_append_first_message(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        msg = repo.append_message(conv["id"], {
            "prompt": "hello",
            "response": "hi there",
            "model_id": "gpt-4",
        })
        assert msg["position"] == 0
        assert msg["prompt"] == "hello"
        assert msg["response"] == "hi there"
        assert msg["model_id"] == "gpt-4"

    def test_append_increments_position(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        m0 = repo.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
        m1 = repo.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
        assert m0["position"] == 0
        assert m1["position"] == 1

    def test_append_with_sources_and_tools(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        msg = repo.append_message(conv["id"], {
            "prompt": "q",
            "response": "a",
            "sources": [{"title": "doc1"}],
            "tool_calls": [{"name": "search", "args": {}}],
            "metadata": {"search_query": "rewritten"},
        })
        assert msg["sources"] == [{"title": "doc1"}]
        assert msg["tool_calls"] == [{"name": "search", "args": {}}]
        assert msg["metadata"] == {"search_query": "rewritten"}

    def test_append_preserves_explicit_timestamp(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        ts = datetime.now(timezone.utc)
        msg = repo.append_message(conv["id"], {
            "prompt": "q",
            "response": "a",
            "timestamp": ts,
        })
        assert msg["timestamp"] == ts


class TestGetMessages:
    def test_returns_ordered_messages(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
        repo.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
        msgs = repo.get_messages(conv["id"])
        assert len(msgs) == 2
        assert msgs[0]["position"] == 0
        assert msgs[1]["position"] == 1

    def test_get_message_at(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
        repo.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
        msg = repo.get_message_at(conv["id"], 1)
        assert msg["prompt"] == "q2"

    def test_get_message_at_nonexistent(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        assert repo.get_message_at(conv["id"], 99) is None


class TestUpdateMessageAt:
    def test_updates_response(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q", "response": "old"})
        assert repo.update_message_at(conv["id"], 0, {"response": "new"}) is True
        msg = repo.get_message_at(conv["id"], 0)
        assert msg["response"] == "new"

    def test_update_disallowed_field(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
        assert repo.update_message_at(conv["id"], 0, {"id": "bad"}) is False

    def test_updates_explicit_timestamp(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q", "response": "old"})
        ts = datetime.now(timezone.utc)
        assert repo.update_message_at(
            conv["id"], 0, {"response": "new", "timestamp": ts},
        ) is True
        msg = repo.get_message_at(conv["id"], 0)
        assert msg["response"] == "new"
        assert msg["timestamp"] == ts


class TestTruncateAfter:
    def test_truncates_messages(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        for i in range(5):
            repo.append_message(conv["id"], {"prompt": f"q{i}", "response": f"a{i}"})
        deleted = repo.truncate_after(conv["id"], 2)
        assert deleted == 2
        msgs = repo.get_messages(conv["id"])
        assert len(msgs) == 3
        assert [m["position"] for m in msgs] == [0, 1, 2]


class TestSetFeedback:
    def test_set_feedback(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
        assert repo.set_feedback(conv["id"], 0, {"text": "thumbs_up"}) is True
        msg = repo.get_message_at(conv["id"], 0)
        assert msg["feedback"] == {"text": "thumbs_up"}

    def test_unset_feedback(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
        repo.set_feedback(conv["id"], 0, {"text": "thumbs_up"})
        assert repo.set_feedback(conv["id"], 0, None) is True
        msg = repo.get_message_at(conv["id"], 0)
        assert msg["feedback"] is None

    def test_set_feedback_nonexistent_position(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        assert repo.set_feedback(conv["id"], 99, {"text": "x"}) is False


class TestMessageCount:
    def test_counts_messages(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        assert repo.message_count(conv["id"]) == 0
        repo.append_message(conv["id"], {"prompt": "q", "response": "a"})
        assert repo.message_count(conv["id"]) == 1


class TestCompressionMetadata:
    def test_set_compression_metadata(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        meta = {"is_compressed": True, "last_compression_at": "2026-01-01T00:00:00Z"}
        assert repo.update_compression_metadata(conv["id"], "user-1", meta) is True
        fetched = repo.get(conv["id"], "user-1")
        assert fetched["compression_metadata"]["is_compressed"] is True

    def test_set_compression_flags_preserves_points(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        repo.update_compression_metadata(conv["id"], "user-1", {
            "is_compressed": False,
            "compression_points": [{"summary": "earlier"}],
        })
        assert repo.set_compression_flags(
            conv["id"], is_compressed=True, last_compression_at="2026-01-02",
        ) is True
        fetched = repo.get(conv["id"], "user-1")
        assert fetched["compression_metadata"]["is_compressed"] is True
        assert fetched["compression_metadata"]["last_compression_at"] == "2026-01-02"
        assert fetched["compression_metadata"]["compression_points"] == [
            {"summary": "earlier"}
        ]

    def test_append_compression_point_slices_to_max(self, pg_conn):
        repo = _repo(pg_conn)
        conv = repo.create("user-1", "c")
        for i in range(5):
            assert repo.append_compression_point(
                conv["id"], {"summary": f"p{i}"}, max_points=3,
            ) is True
        fetched = repo.get(conv["id"], "user-1")
        points = fetched["compression_metadata"]["compression_points"]
        assert [p["summary"] for p in points] == ["p2", "p3", "p4"]


class TestConcurrentAppend:
    """Two threads appending to the same conversation must not race on
    ``position``. The plan (migration-postgres.md §Phase 3) explicitly
    calls this out as the single trickiest invariant, so we exercise it
    directly with two parallel connections."""

    def test_concurrent_appends_get_distinct_positions(self, pg_engine, pg_conn):
        import threading

        # Arrange — one conversation, created inside the outer test txn so
        # it disappears on teardown even if the workers somehow commit.
        # We commit it explicitly so the workers' separate sessions see it.
        repo_setup = _repo(pg_conn)
        conv = repo_setup.create("user-concurrent", "c")
        pg_conn.commit()

        try:
            errors: list[BaseException] = []

            def worker() -> None:
                try:
                    with pg_engine.begin() as worker_conn:
                        ConversationsRepository(worker_conn).append_message(
                            conv["id"], {"prompt": "q", "response": "a"},
                        )
                except BaseException as e:  # noqa: BLE001
                    errors.append(e)

            threads = [threading.Thread(target=worker) for _ in range(2)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            assert errors == [], f"worker threads errored: {errors}"

            # Assert — the parent row-lock in append_message must have
            # serialised the two inserts so they land at positions {0, 1}.
            with pg_engine.connect() as verify_conn:
                msgs = ConversationsRepository(verify_conn).get_messages(conv["id"])
                positions = sorted(m["position"] for m in msgs)
                assert positions == [0, 1], (
                    f"concurrent appends raced; got positions {positions}"
                )
        finally:
            # Clean up — the conversation was committed, so the transaction
            # rollback won't drop it.
            with pg_engine.begin() as cleanup_conn:
                ConversationsRepository(cleanup_conn).delete(
                    conv["id"], "user-concurrent"
                )
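Note: the row-lock invariant that TestConcurrentAppend describes is easiest to see in isolation. The sketch below is an illustration only, not the repository's actual implementation; the conversations/conversation_messages table and column names are assumptions.

from sqlalchemy import text


def append_with_row_lock(conn, conversation_id, prompt, response):
    """Sketch: lock the parent conversation row, then insert at the next position."""
    # FOR UPDATE serialises concurrent appenders on the same conversation,
    # so two workers cannot both read the same MAX(position).
    conn.execute(
        text("SELECT id FROM conversations WHERE id = :cid FOR UPDATE"),
        {"cid": conversation_id},
    )
    next_pos = conn.execute(
        text(
            "SELECT COALESCE(MAX(position) + 1, 0) "
            "FROM conversation_messages WHERE conversation_id = :cid"
        ),
        {"cid": conversation_id},
    ).scalar_one()
    conn.execute(
        text(
            "INSERT INTO conversation_messages (conversation_id, position, prompt, response) "
            "VALUES (:cid, :pos, :prompt, :response)"
        ),
        {"cid": conversation_id, "pos": next_pos, "prompt": prompt, "response": response},
    )
    return next_pos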
@@ -1,79 +0,0 @@
"""Tests for FeedbackRepository against a real Postgres instance."""

from __future__ import annotations

import uuid

import pytest
from sqlalchemy import text

from application.storage.db.repositories.feedback import FeedbackRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> FeedbackRepository:
    return FeedbackRepository(conn)


def _make_conversation_id(pg_conn) -> str:
    """Insert a minimal conversations row and return its id as string.

    feedback has a FK to conversations (added in Tier 3 migration), but
    the FK constraint may not exist yet during early phases. We create a
    row anyway to keep tests realistic.
    """
    cid = str(uuid.uuid4())
    # Only insert if the conversations table exists; otherwise use a random UUID.
    row = pg_conn.execute(
        text(
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_schema='public' AND table_name='conversations'"
        )
    ).scalar()
    if row:
        pg_conn.execute(
            text("INSERT INTO conversations (id, user_id) VALUES (CAST(:id AS uuid), 'test')"),
            {"id": cid},
        )
    return cid


class TestCreate:
    def test_creates_feedback(self, pg_conn):
        repo = _repo(pg_conn)
        cid = _make_conversation_id(pg_conn)
        doc = repo.create(cid, "user-1", 0, "great answer")
        assert doc["conversation_id"] is not None
        assert doc["user_id"] == "user-1"
        assert doc["question_index"] == 0
        assert doc["feedback_text"] == "great answer"

    def test_allows_null_feedback_text(self, pg_conn):
        repo = _repo(pg_conn)
        cid = _make_conversation_id(pg_conn)
        doc = repo.create(cid, "user-1", 1)
        assert doc["feedback_text"] is None


class TestListForConversation:
    def test_lists_feedback_for_conversation(self, pg_conn):
        repo = _repo(pg_conn)
        cid = _make_conversation_id(pg_conn)
        repo.create(cid, "user-1", 0, "good")
        repo.create(cid, "user-1", 1, "bad")
        results = repo.list_for_conversation(cid)
        assert len(results) == 2
        assert results[0]["question_index"] == 0
        assert results[1]["question_index"] == 1

    def test_does_not_mix_conversations(self, pg_conn):
        repo = _repo(pg_conn)
        cid1 = _make_conversation_id(pg_conn)
        cid2 = _make_conversation_id(pg_conn)
        repo.create(cid1, "user-1", 0, "a")
        repo.create(cid2, "user-1", 0, "b")
        assert len(repo.list_for_conversation(cid1)) == 1
99 tests/storage/db/repositories/test_pending_tool_state.py Normal file
@@ -0,0 +1,99 @@
"""Tests for PendingToolStateRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import PendingToolStateRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _conv(conn) -> dict:
    return ConversationsRepository(conn).create("user-1", "test conv")


def _repo(conn) -> PendingToolStateRepository:
    return PendingToolStateRepository(conn)


def _sample_state() -> dict:
    return {
        "messages": [{"role": "user", "content": "hello"}],
        "pending_tool_calls": [{"id": "tc-1", "name": "search"}],
        "tools_dict": {"search": {"type": "function"}},
        "tool_schemas": [{"name": "search"}],
        "agent_config": {"model_id": "gpt-4", "llm_name": "openai"},
    }


class TestSaveState:
    def test_creates_state(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        state = _sample_state()
        doc = repo.save_state(conv["id"], "user-1", **state)
        assert doc["user_id"] == "user-1"
        assert doc["messages"] == state["messages"]
        assert doc["pending_tool_calls"] == state["pending_tool_calls"]
        assert doc["expires_at"] is not None

    def test_upsert_replaces_existing(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        state = _sample_state()
        repo.save_state(conv["id"], "user-1", **state)
        state["messages"] = [{"role": "user", "content": "updated"}]
        doc2 = repo.save_state(conv["id"], "user-1", **state)
        # Same row, updated content
        assert doc2["messages"] == [{"role": "user", "content": "updated"}]

    def test_save_with_client_tools(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        state = _sample_state()
        state["client_tools"] = [{"name": "browser"}]
        doc = repo.save_state(conv["id"], "user-1", **state)
        assert doc["client_tools"] == [{"name": "browser"}]


class TestLoadState:
    def test_loads_existing(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        repo.save_state(conv["id"], "user-1", **_sample_state())
        loaded = repo.load_state(conv["id"], "user-1")
        assert loaded is not None
        assert loaded["agent_config"]["model_id"] == "gpt-4"

    def test_load_nonexistent(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.load_state("00000000-0000-0000-0000-000000000000", "u") is None


class TestDeleteState:
    def test_deletes(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        repo.save_state(conv["id"], "user-1", **_sample_state())
        assert repo.delete_state(conv["id"], "user-1") is True
        assert repo.load_state(conv["id"], "user-1") is None

    def test_delete_nonexistent(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.delete_state("00000000-0000-0000-0000-000000000000", "u") is False


class TestCleanupExpired:
    def test_cleanup_removes_expired(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        # Create a state with TTL of 0 seconds (already expired)
        repo.save_state(conv["id"], "user-1", **_sample_state(), ttl_seconds=0)
        deleted = repo.cleanup_expired()
        assert deleted >= 1
        assert repo.load_state(conv["id"], "user-1") is None
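The TTL behaviour exercised by TestCleanupExpired (save_state taking ttl_seconds, cleanup_expired sweeping stale rows) can be pictured with a small sketch. This is an assumption-laden illustration, not the repository's actual schema; the pending_tool_state table and column names are invented for the example.

from datetime import datetime, timedelta, timezone

from sqlalchemy import text


def expires_at_for(ttl_seconds: int) -> datetime:
    """How save_state might derive expires_at from a ttl_seconds argument."""
    return datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)


def cleanup_expired_sketch(conn) -> int:
    """Delete rows whose expires_at has passed and return how many were removed."""
    result = conn.execute(
        text("DELETE FROM pending_tool_state WHERE expires_at < :now"),
        {"now": datetime.now(timezone.utc)},
    )
    return result.rowcount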
@@ -30,6 +30,11 @@ class TestCreate:
        doc = repo.create("user-1", "p", "c")
        assert doc["_id"] == doc["id"]

    def test_create_with_legacy_mongo_id(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "p", "c", legacy_mongo_id="507f1f77bcf86cd799439011")
        assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"


class TestGet:
    def test_get_by_id_and_user(self, pg_conn):
@@ -47,6 +52,17 @@ class TestGet:
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None

    def test_get_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create(
            "user-1",
            "p",
            "c",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
        assert fetched["id"] == created["id"]


class TestGetForRendering:
    def test_returns_prompt_without_user_scoping(self, pg_conn):
@@ -88,6 +104,24 @@ class TestUpdate:
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "old"

    def test_update_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create(
            "user-1",
            "old",
            "old-content",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert repo.update_by_legacy_id(
            "507f1f77bcf86cd799439011",
            "user-1",
            "new",
            "new-content",
        ) is True
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
        assert fetched["name"] == "new"
        assert fetched["content"] == "new-content"


class TestDelete:
    def test_deletes_prompt(self, pg_conn):
@@ -102,6 +136,17 @@ class TestDelete:
        repo.delete(created["id"], "user-other")
        assert repo.get(created["id"], "user-1") is not None

    def test_delete_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create(
            "user-1",
            "p",
            "c",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert repo.delete_by_legacy_id("507f1f77bcf86cd799439011", "user-1") is True
        assert repo.get(created["id"], "user-1") is None


class TestFindOrCreate:
    def test_creates_when_missing(self, pg_conn):
91 tests/storage/db/repositories/test_shared_conversations.py Normal file
@@ -0,0 +1,91 @@
"""Tests for SharedConversationsRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.shared_conversations import SharedConversationsRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _conv(conn) -> dict:
    return ConversationsRepository(conn).create("user-1", "test conv")


def _repo(conn) -> SharedConversationsRepository:
    return SharedConversationsRepository(conn)


class TestCreate:
    def test_creates_share(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        share = repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
        assert share["conversation_id"] is not None
        assert share["user_id"] == "user-1"
        assert share["is_promptable"] is False
        assert share["first_n_queries"] == 3
        assert share["uuid"] is not None

    def test_create_promptable_with_api_key(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        share = repo.create(
            conv["id"], "user-1",
            is_promptable=True,
            first_n_queries=5,
            api_key="ak-prompt",
        )
        assert share["is_promptable"] is True
        assert share["api_key"] == "ak-prompt"


class TestFindByUuid:
    def test_finds_by_uuid(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        share = repo.create(conv["id"], "user-1", first_n_queries=2)
        found = repo.find_by_uuid(str(share["uuid"]))
        assert found["id"] == share["id"]

    def test_not_found(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.find_by_uuid("00000000-0000-0000-0000-000000000000") is None


class TestFindExisting:
    def test_finds_matching_share(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
        found = repo.find_existing(conv["id"], "user-1", False, 3)
        assert found is not None
        assert found["first_n_queries"] == 3

    def test_no_match_different_params(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
        assert repo.find_existing(conv["id"], "user-1", True, 3) is None

    def test_finds_with_api_key(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        repo.create(conv["id"], "user-1", is_promptable=True, first_n_queries=5, api_key="ak-1")
        found = repo.find_existing(conv["id"], "user-1", True, 5, api_key="ak-1")
        assert found is not None


class TestListForConversation:
    def test_lists_shares(self, pg_conn):
        conv = _conv(pg_conn)
        repo = _repo(pg_conn)
        repo.create(conv["id"], "user-1", first_n_queries=1)
        repo.create(conv["id"], "user-1", first_n_queries=2)
        results = repo.list_for_conversation(conv["id"])
        assert len(results) == 2
@@ -25,12 +25,6 @@ class TestCreate:
        assert doc["type"] == "url"
        assert doc["id"] is not None

    def test_creates_system_source_without_user(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("system-src")
        assert doc["user_id"] is None
        assert doc["name"] == "system-src"

    def test_creates_source_with_metadata(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("src", user_id="u", metadata={"url": "https://example.com"})
@@ -38,7 +32,7 @@ class TestCreate:

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("s")
        doc = repo.create("s", user_id="u")
        assert doc["_id"] == doc["id"]
116 tests/storage/db/repositories/test_workflow_edges.py Normal file
@@ -0,0 +1,116 @@
"""Tests for WorkflowEdgesRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _setup(conn) -> tuple[dict, dict, dict]:
    """Create a workflow with two nodes and return (workflow, node1, node2)."""
    wf = WorkflowsRepository(conn).create("user-1", "test wf")
    node_repo = WorkflowNodesRepository(conn)
    n1 = node_repo.create(wf["id"], 1, "start-node", "start")
    n2 = node_repo.create(wf["id"], 1, "end-node", "end")
    return wf, n1, n2


def _repo(conn) -> WorkflowEdgesRepository:
    return WorkflowEdgesRepository(conn)


class TestCreate:
    def test_creates_edge(self, pg_conn):
        wf, n1, n2 = _setup(pg_conn)
        repo = _repo(pg_conn)
        edge = repo.create(
            wf["id"], 1, "edge-1", n1["id"], n2["id"],
            source_handle="out", target_handle="in",
        )
        assert edge["edge_id"] == "edge-1"
        assert str(edge["from_node_id"]) == n1["id"]
        assert str(edge["to_node_id"]) == n2["id"]
        assert edge["source_handle"] == "out"
        assert edge["target_handle"] == "in"


class TestBulkCreate:
    def test_bulk_creates_edges(self, pg_conn):
        wf, n1, n2 = _setup(pg_conn)
        repo = _repo(pg_conn)
        edges = repo.bulk_create(wf["id"], 1, [
            {"edge_id": "e1", "from_node_id": n1["id"], "to_node_id": n2["id"]},
            {"edge_id": "e2", "from_node_id": n2["id"], "to_node_id": n1["id"],
             "source_handle": "loop"},
        ])
        assert len(edges) == 2

    def test_bulk_create_empty(self, pg_conn):
        wf, _, _ = _setup(pg_conn)
        repo = _repo(pg_conn)
        assert repo.bulk_create(wf["id"], 1, []) == []


class TestFindByVersion:
    def test_finds_edges(self, pg_conn):
        wf, n1, n2 = _setup(pg_conn)
        repo = _repo(pg_conn)
        repo.create(wf["id"], 1, "e1", n1["id"], n2["id"])
        edges = repo.find_by_version(wf["id"], 1)
        assert len(edges) == 1
        assert edges[0]["edge_id"] == "e1"

    def test_no_edges_for_version(self, pg_conn):
        wf, _, _ = _setup(pg_conn)
        repo = _repo(pg_conn)
        assert repo.find_by_version(wf["id"], 99) == []


class TestDelete:
    def test_delete_by_workflow(self, pg_conn):
        wf, n1, n2 = _setup(pg_conn)
        repo = _repo(pg_conn)
        repo.create(wf["id"], 1, "e1", n1["id"], n2["id"])
        deleted = repo.delete_by_workflow(wf["id"])
        assert deleted == 1
        assert repo.find_by_version(wf["id"], 1) == []

    def test_delete_by_version(self, pg_conn):
        wf = WorkflowsRepository(pg_conn).create("user-1", "wf")
        node_repo = WorkflowNodesRepository(pg_conn)
        n1v1 = node_repo.create(wf["id"], 1, "n1", "start")
        n2v1 = node_repo.create(wf["id"], 1, "n2", "end")
        n1v2 = node_repo.create(wf["id"], 2, "n1", "start")
        n2v2 = node_repo.create(wf["id"], 2, "n2", "end")

        repo = _repo(pg_conn)
        repo.create(wf["id"], 1, "e1", n1v1["id"], n2v1["id"])
        repo.create(wf["id"], 2, "e1", n1v2["id"], n2v2["id"])

        repo.delete_by_version(wf["id"], 1)
        assert repo.find_by_version(wf["id"], 1) == []
        assert len(repo.find_by_version(wf["id"], 2)) == 1

    def test_delete_other_versions(self, pg_conn):
        wf = WorkflowsRepository(pg_conn).create("user-1", "wf")
        node_repo = WorkflowNodesRepository(pg_conn)
        n1v1 = node_repo.create(wf["id"], 1, "n1", "start")
        n2v1 = node_repo.create(wf["id"], 1, "n2", "end")
        n1v2 = node_repo.create(wf["id"], 2, "n1", "start")
        n2v2 = node_repo.create(wf["id"], 2, "n2", "end")

        repo = _repo(pg_conn)
        repo.create(wf["id"], 1, "e1", n1v1["id"], n2v1["id"])
        repo.create(wf["id"], 2, "e1", n1v2["id"], n2v2["id"])

        repo.delete_other_versions(wf["id"], 2)
        assert repo.find_by_version(wf["id"], 1) == []
        assert len(repo.find_by_version(wf["id"], 2)) == 1
172 tests/storage/db/repositories/test_workflow_nodes.py Normal file
@@ -0,0 +1,172 @@
"""Tests for WorkflowNodesRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _wf(conn) -> dict:
    return WorkflowsRepository(conn).create("user-1", "test wf")


def _repo(conn) -> WorkflowNodesRepository:
    return WorkflowNodesRepository(conn)


class TestCreate:
    def test_creates_node(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        node = repo.create(wf["id"], 1, "node-start", "start", title="Start")
        assert node["node_id"] == "node-start"
        assert node["node_type"] == "start"
        assert node["title"] == "Start"
        assert node["graph_version"] == 1

    def test_create_with_config(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        node = repo.create(
            wf["id"], 1, "node-agent", "agent",
            config={"agent_type": "classic", "system_prompt": "You are helpful"},
            position={"x": 100, "y": 200},
        )
        assert node["config"]["agent_type"] == "classic"
        assert node["position"]["x"] == 100

    def test_create_with_legacy_mongo_id(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        node = repo.create(
            wf["id"],
            1,
            "node-agent",
            "agent",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert node["legacy_mongo_id"] == "507f1f77bcf86cd799439011"


class TestBulkCreate:
    def test_bulk_creates_nodes(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        nodes = repo.bulk_create(wf["id"], 1, [
            {"node_id": "n1", "node_type": "start", "title": "Start"},
            {"node_id": "n2", "node_type": "agent", "config": {"agent_type": "react"}},
            {"node_id": "n3", "node_type": "end"},
        ])
        assert len(nodes) == 3
        node_ids = {n["node_id"] for n in nodes}
        assert node_ids == {"n1", "n2", "n3"}

    def test_bulk_create_empty(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        assert repo.bulk_create(wf["id"], 1, []) == []

    def test_bulk_create_with_legacy_mongo_ids(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        nodes = repo.bulk_create(wf["id"], 1, [
            {
                "node_id": "n1",
                "node_type": "start",
                "legacy_mongo_id": "507f1f77bcf86cd799439011",
            },
            {
                "node_id": "n2",
                "node_type": "end",
                "legacy_mongo_id": "507f1f77bcf86cd799439012",
            },
        ])
        assert {n["legacy_mongo_id"] for n in nodes} == {
            "507f1f77bcf86cd799439011",
            "507f1f77bcf86cd799439012",
        }


class TestFindByVersion:
    def test_finds_nodes(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        repo.bulk_create(wf["id"], 1, [
            {"node_id": "n1", "node_type": "start"},
            {"node_id": "n2", "node_type": "end"},
        ])
        repo.bulk_create(wf["id"], 2, [
            {"node_id": "n1", "node_type": "start"},
        ])
        v1_nodes = repo.find_by_version(wf["id"], 1)
        v2_nodes = repo.find_by_version(wf["id"], 2)
        assert len(v1_nodes) == 2
        assert len(v2_nodes) == 1


class TestFindNode:
    def test_finds_specific_node(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        repo.create(wf["id"], 1, "node-start", "start")
        found = repo.find_node(wf["id"], 1, "node-start")
        assert found is not None
        assert found["node_type"] == "start"

    def test_not_found(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        assert repo.find_node(wf["id"], 1, "nonexistent") is None

    def test_get_by_legacy_id(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        created = repo.create(
            wf["id"],
            1,
            "node-start",
            "start",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        found = repo.get_by_legacy_id("507f1f77bcf86cd799439011")
        assert found["id"] == created["id"]


class TestDelete:
    def test_delete_by_workflow(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        repo.bulk_create(wf["id"], 1, [
            {"node_id": "n1", "node_type": "start"},
            {"node_id": "n2", "node_type": "end"},
        ])
        deleted = repo.delete_by_workflow(wf["id"])
        assert deleted == 2
        assert repo.find_by_version(wf["id"], 1) == []

    def test_delete_by_version(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        repo.bulk_create(wf["id"], 1, [{"node_id": "n1", "node_type": "start"}])
        repo.bulk_create(wf["id"], 2, [{"node_id": "n1", "node_type": "start"}])
        repo.delete_by_version(wf["id"], 1)
        assert repo.find_by_version(wf["id"], 1) == []
        assert len(repo.find_by_version(wf["id"], 2)) == 1

    def test_delete_other_versions(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        repo.bulk_create(wf["id"], 1, [{"node_id": "n1", "node_type": "start"}])
        repo.bulk_create(wf["id"], 2, [{"node_id": "n1", "node_type": "start"}])
        repo.bulk_create(wf["id"], 3, [{"node_id": "n1", "node_type": "start"}])
        repo.delete_other_versions(wf["id"], 2)
        assert repo.find_by_version(wf["id"], 1) == []
        assert len(repo.find_by_version(wf["id"], 2)) == 1
        assert repo.find_by_version(wf["id"], 3) == []
104 tests/storage/db/repositories/test_workflow_runs.py Normal file
@@ -0,0 +1,104 @@
"""Tests for WorkflowRunsRepository against a real Postgres instance."""

from __future__ import annotations

from datetime import datetime, timezone

import pytest

from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _wf(conn) -> dict:
    return WorkflowsRepository(conn).create("user-1", "test wf")


def _repo(conn) -> WorkflowRunsRepository:
    return WorkflowRunsRepository(conn)


class TestCreate:
    def test_creates_run(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        run = repo.create(wf["id"], "user-1", "completed")
        assert run["status"] == "completed"
        assert run["user_id"] == "user-1"
        assert run["id"] is not None

    def test_create_with_details(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        now = datetime.now(timezone.utc)
        run = repo.create(
            wf["id"], "user-1", "completed",
            inputs={"query": "hello"},
            result={"output": "world"},
            steps=[
                {"node_id": "n1", "status": "completed"},
                {"node_id": "n2", "status": "completed"},
            ],
            ended_at=now,
        )
        assert run["inputs"] == {"query": "hello"}
        assert run["result"] == {"output": "world"}
        assert len(run["steps"]) == 2
        assert run["ended_at"] is not None

    def test_create_with_started_at_and_legacy_id(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        now = datetime.now(timezone.utc)
        run = repo.create(
            wf["id"],
            "user-1",
            "completed",
            started_at=now,
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert run["started_at"] == now
        assert run["legacy_mongo_id"] == "507f1f77bcf86cd799439011"


class TestGet:
    def test_get_existing(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        created = repo.create(wf["id"], "user-1", "completed")
        fetched = repo.get(created["id"])
        assert fetched["id"] == created["id"]

    def test_get_nonexistent(self, pg_conn):
        repo = _repo(pg_conn)
        assert repo.get("00000000-0000-0000-0000-000000000000") is None

    def test_get_by_legacy_id(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        created = repo.create(
            wf["id"], "user-1", "completed",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011")
        assert fetched["id"] == created["id"]


class TestListForWorkflow:
    def test_lists_runs(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        repo.create(wf["id"], "user-1", "completed")
        repo.create(wf["id"], "user-1", "failed")
        runs = repo.list_for_workflow(wf["id"])
        assert len(runs) == 2

    def test_empty_list(self, pg_conn):
        wf = _wf(pg_conn)
        repo = _repo(pg_conn)
        assert repo.list_for_workflow(wf["id"]) == []
117 tests/storage/db/repositories/test_workflows.py Normal file
@@ -0,0 +1,117 @@
"""Tests for WorkflowsRepository against a real Postgres instance."""

from __future__ import annotations

import pytest

from application.storage.db.repositories.workflows import WorkflowsRepository

pytestmark = pytest.mark.skipif(
    not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


def _repo(conn) -> WorkflowsRepository:
    return WorkflowsRepository(conn)


class TestCreate:
    def test_creates_workflow(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "My Workflow")
        assert doc["user_id"] == "user-1"
        assert doc["name"] == "My Workflow"
        assert doc["current_graph_version"] == 1
        assert doc["id"] is not None
        assert doc["_id"] == doc["id"]

    def test_create_with_description(self, pg_conn):
        repo = _repo(pg_conn)
        doc = repo.create("user-1", "wf", description="A test workflow")
        assert doc["description"] == "A test workflow"


class TestGet:
    def test_get_existing(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        fetched = repo.get(created["id"], "user-1")
        assert fetched["id"] == created["id"]

    def test_get_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        assert repo.get(created["id"], "user-other") is None

    def test_get_by_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        fetched = repo.get_by_id(created["id"])
        assert fetched["id"] == created["id"]

    def test_get_by_legacy_id(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create(
            "user-1", "wf", legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
        assert fetched["id"] == created["id"]


class TestListForUser:
    def test_lists_own(self, pg_conn):
        repo = _repo(pg_conn)
        repo.create("alice", "wf1")
        repo.create("alice", "wf2")
        repo.create("bob", "wf3")
        results = repo.list_for_user("alice")
        assert len(results) == 2


class TestUpdate:
    def test_updates_name(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        assert repo.update(created["id"], "user-1", {"name": "new"}) is True
        fetched = repo.get(created["id"], "user-1")
        assert fetched["name"] == "new"

    def test_update_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "old")
        assert repo.update(created["id"], "other", {"name": "new"}) is False

    def test_update_disallowed_field(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        assert repo.update(created["id"], "user-1", {"id": "bad"}) is False


class TestIncrementGraphVersion:
    def test_increments(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        assert created["current_graph_version"] == 1
        new_ver = repo.increment_graph_version(created["id"], "user-1")
        assert new_ver == 2
        fetched = repo.get(created["id"], "user-1")
        assert fetched["current_graph_version"] == 2

    def test_increment_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        assert repo.increment_graph_version(created["id"], "other") is None


class TestDelete:
    def test_deletes(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        assert repo.delete(created["id"], "user-1") is True
        assert repo.get(created["id"], "user-1") is None

    def test_delete_wrong_user(self, pg_conn):
        repo = _repo(pg_conn)
        created = repo.create("user-1", "wf")
        assert repo.delete(created["id"], "other") is False
414 tests/storage/db/test_dual_write_backfill.py Normal file
@@ -0,0 +1,414 @@
"""Integration coverage for dual-write rows surviving backfill reruns."""

from __future__ import annotations

from contextlib import contextmanager
from unittest.mock import MagicMock

import pytest
from bson import ObjectId
from bson.dbref import DBRef
from flask import Flask, request

from application.core.settings import settings
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.dual_write import dual_write
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.repositories.shared_conversations import SharedConversationsRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from scripts.db.backfill import (
    _backfill_agents,
    _backfill_attachments,
    _backfill_conversations,
    _backfill_prompts,
    _backfill_shared_conversations,
    _backfill_workflow_runs,
    _backfill_workflow_nodes,
    _backfill_workflows,
)


pytestmark = pytest.mark.skipif(
    not settings.POSTGRES_URI,
    reason="POSTGRES_URI not configured",
)


class _BoundEngine:
    """Expose one pre-opened SQLAlchemy connection as an Engine.begin()."""

    def __init__(self, conn):
        self._conn = conn

    @contextmanager
    def begin(self):
        yield self._conn


@pytest.fixture
def app():
    return Flask(__name__)


@pytest.fixture
def mongo_db(mock_mongo_db):
    return mock_mongo_db[settings.MONGO_DB_NAME]


@pytest.fixture
def dual_write_pg(monkeypatch, pg_conn):
    monkeypatch.setattr(settings, "USE_POSTGRES", True)
    monkeypatch.setattr(
        "application.storage.db.engine.get_engine",
        lambda: _BoundEngine(pg_conn),
    )
    return pg_conn


def test_prompt_dual_write_row_survives_backfill_rerun(app, mongo_db, dual_write_pg, monkeypatch):
    from application.api.user.prompts.routes import CreatePrompt

    monkeypatch.setattr(
        "application.api.user.prompts.routes.prompts_collection",
        mongo_db["prompts"],
    )

    with app.test_request_context(
        "/api/create_prompt",
        method="POST",
        json={"name": "Greeting", "content": "Hello"},
    ):
        request.decoded_token = {"sub": "user-1"}
        response = CreatePrompt().post()

    assert response.status_code == 200
    mongo_id = response.json["id"]

    repo = PromptsRepository(dual_write_pg)
    prompt = repo.get_by_legacy_id(mongo_id, "user-1")
    assert prompt is not None
    assert prompt["content"] == "Hello"

    mongo_db["prompts"].update_one(
        {"_id": ObjectId(mongo_id)},
        {"$set": {"content": "Hello again"}},
    )
    _backfill_prompts(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)

    prompts = repo.list_for_user("user-1")
    assert len(prompts) == 1
    assert prompts[0]["legacy_mongo_id"] == mongo_id
    assert prompts[0]["content"] == "Hello again"


def test_agent_dual_write_row_survives_backfill_rerun(app, mongo_db, dual_write_pg, monkeypatch):
    from application.api.user.agents.routes import CreateAgent

    monkeypatch.setattr(
        "application.api.user.agents.routes.agents_collection",
        mongo_db["agents"],
    )
    monkeypatch.setattr(
        "application.api.user.agents.routes.handle_image_upload",
        lambda *_args, **_kwargs: ("", None),
    )

    with app.test_request_context(
        "/api/create_agent",
        method="POST",
        json={"name": "Mirror Agent", "status": "draft"},
    ):
        request.decoded_token = {"sub": "user-1"}
        response = CreateAgent().post()

    assert response.status_code == 201
    mongo_id = response.json["id"]

    repo = AgentsRepository(dual_write_pg)
    agent = repo.get_by_legacy_id(mongo_id, "user-1")
    assert agent is not None
    assert agent["name"] == "Mirror Agent"

    mongo_db["agents"].update_one(
        {"_id": ObjectId(mongo_id)},
        {"$set": {"name": "Renamed Agent", "description": "Updated by backfill"}},
    )
    _backfill_agents(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)

    agents = repo.list_for_user("user-1")
    assert len(agents) == 1
    assert agents[0]["legacy_mongo_id"] == mongo_id
    assert agents[0]["name"] == "Renamed Agent"
    assert agents[0]["description"] == "Updated by backfill"


def test_attachment_dual_write_row_survives_backfill_rerun(mongo_db, dual_write_pg):
    mongo_id = ObjectId()
    mongo_db["attachments"].insert_one(
        {
            "_id": mongo_id,
            "user": "user-1",
            "filename": "notes.txt",
            "upload_path": "/uploads/notes.txt",
            "mime_type": "text/plain",
            "size": 12,
        }
    )

    dual_write(
        AttachmentsRepository,
        lambda repo: repo.create(
            "user-1",
            "notes.txt",
            "/uploads/notes.txt",
            mime_type="text/plain",
            size=12,
            legacy_mongo_id=str(mongo_id),
        ),
    )

    repo = AttachmentsRepository(dual_write_pg)
    attachment = repo.get_by_legacy_id(str(mongo_id), "user-1")
    assert attachment is not None
    assert attachment["filename"] == "notes.txt"

    mongo_db["attachments"].update_one(
        {"_id": mongo_id},
        {"$set": {"filename": "notes-v2.txt", "size": 24}},
    )
    _backfill_attachments(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)

    attachments = repo.list_for_user("user-1")
    assert len(attachments) == 1
    assert attachments[0]["legacy_mongo_id"] == str(mongo_id)
    assert attachments[0]["filename"] == "notes-v2.txt"
    assert attachments[0]["size"] == 24


def test_workflow_nodes_dual_write_rows_survive_backfill_rerun(mongo_db, dual_write_pg, monkeypatch):
    from application.api.user.workflows.routes import (
        _dual_write_workflow_create,
        create_workflow_edges,
        create_workflow_nodes,
    )

    monkeypatch.setattr(
        "application.api.user.workflows.routes.workflows_collection",
        mongo_db["workflows"],
    )
    monkeypatch.setattr(
        "application.api.user.workflows.routes.workflow_nodes_collection",
        mongo_db["workflow_nodes"],
    )
    monkeypatch.setattr(
        "application.api.user.workflows.routes.workflow_edges_collection",
        mongo_db["workflow_edges"],
    )

    workflow_doc = {
        "name": "Workflow",
        "description": "test",
        "user": "user-1",
        "current_graph_version": 1,
    }
    insert_result = mongo_db["workflows"].insert_one(workflow_doc)
    workflow_id = str(insert_result.inserted_id)
    nodes_data = [
        {"id": "start", "type": "start", "title": "Start"},
        {"id": "end", "type": "end", "title": "End"},
    ]
    edges_data = [{"id": "edge-1", "source": "start", "target": "end"}]

    created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
    create_workflow_edges(workflow_id, edges_data, 1)
    _dual_write_workflow_create(
        workflow_id,
        "user-1",
        "Workflow",
        "test",
        created_nodes,
        edges_data,
    )

    workflow_repo = WorkflowsRepository(dual_write_pg)
    workflow = workflow_repo.list_for_user("user-1")[0]
    node_repo = WorkflowNodesRepository(dual_write_pg)
    pg_nodes = node_repo.find_by_version(workflow["id"], 1)
    assert len(pg_nodes) == 2
    assert all(node["legacy_mongo_id"] for node in pg_nodes)

    renamed_node_id = created_nodes[0]["legacy_mongo_id"]
    mongo_db["workflow_nodes"].update_one(
        {"_id": ObjectId(renamed_node_id)},
        {"$set": {"title": "Renamed Start"}},
    )

    _backfill_workflows(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)
    _backfill_workflow_nodes(conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False)

    pg_nodes = node_repo.find_by_version(workflow["id"], 1)
    assert len(pg_nodes) == 2
    renamed_node = node_repo.get_by_legacy_id(renamed_node_id)
    assert renamed_node is not None
    assert renamed_node["title"] == "Renamed Start"


def test_compression_summary_dual_write_appends_pg_message(mongo_db, dual_write_pg):
    from application.api.answer.services.conversation_service import ConversationService

    mongo_conv_id = ObjectId()
    mongo_db["conversations"].insert_one(
        {"_id": mongo_conv_id, "user": "user-1", "queries": []},
    )
    conv = ConversationsRepository(dual_write_pg).create(
        "user-1", "Mirror", legacy_mongo_id=str(mongo_conv_id),
    )

    service = ConversationService()
    metadata = {
        "compressed_summary": "Compressed context summary",
        "timestamp": "2026-04-13T12:00:00+00:00",
        "model_used": "gpt-4",
    }
    service.append_compression_message(str(mongo_conv_id), metadata)

    pg_messages = ConversationsRepository(dual_write_pg).get_messages(conv["id"])
    assert len(pg_messages) == 1
    assert pg_messages[0]["prompt"] == "[Context Compression Summary]"
    assert pg_messages[0]["response"] == "Compressed context summary"
    assert pg_messages[0]["model_id"] == "gpt-4"


def test_workflow_run_dual_write_row_survives_backfill_rerun(mongo_db, dual_write_pg):
    from application.agents.workflow_agent import WorkflowAgent

    mongo_workflow_id = ObjectId()
    mongo_db["workflows"].insert_one(
        {
            "_id": mongo_workflow_id,
            "user": "user-1",
            "name": "Workflow",
            "description": "test",
        }
    )
    workflow = WorkflowsRepository(dual_write_pg).create(
        "user-1", "Workflow", description="test",
        legacy_mongo_id=str(mongo_workflow_id),
    )

    agent = WorkflowAgent(
        endpoint="https://api.example.com",
        llm_name="openai",
        model_id="gpt-4",
        api_key="test_key",
        user_api_key=None,
        prompt="You are helpful.",
        chat_history=[],
        decoded_token={"sub": "user-1"},
        attachments=[],
        json_schema=None,
        workflow_id=str(mongo_workflow_id),
        workflow_owner="user-1",
    )
    agent._engine = MagicMock()
    agent._engine.state = {"answer": "ok"}
    agent._engine.execution_log = []
    agent._engine.get_execution_summary.return_value = []

    agent._save_workflow_run("hello")

    run_repo = WorkflowRunsRepository(dual_write_pg)
    runs = run_repo.list_for_workflow(workflow["id"])
    assert len(runs) == 1
    assert runs[0]["user_id"] == "user-1"
    legacy_mongo_id = runs[0]["legacy_mongo_id"]

    mongo_db["workflow_runs"].update_one(
        {"_id": ObjectId(legacy_mongo_id)},
        {"$set": {"status": "failed", "user": "user-1", "user_id": "user-1"}},
    )
    _backfill_workflow_runs(
        conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False,
    )

    runs = run_repo.list_for_workflow(workflow["id"])
    assert len(runs) == 1
    assert runs[0]["status"] == "failed"
    assert runs[0]["legacy_mongo_id"] == legacy_mongo_id


def test_shared_conversation_backfill_recovers_dbref_and_agent_prompt_metadata(
    mongo_db, dual_write_pg,
):
    conv = ConversationsRepository(dual_write_pg).create(
        "user-1", "Conversation", legacy_mongo_id="507f1f77bcf86cd799439011",
    )
    PromptsRepository(dual_write_pg).create(
        "user-1", "Prompt", "Body",
        legacy_mongo_id="507f1f77bcf86cd799439012",
    )

    mongo_db["agents"].insert_one(
        {
            "_id": ObjectId(),
            "key": "share-key",
            "prompt_id": ObjectId("507f1f77bcf86cd799439012"),
            "chunks": "7",
            "user": "user-1",
        }
    )
    mongo_db["shared_conversations"].insert_one(
        {
            "_id": ObjectId(),
            "uuid": "00000000-0000-0000-0000-000000000001",
            "conversation_id": DBRef(
                "conversations", ObjectId("507f1f77bcf86cd799439011"),
            ),
            "user": "user-1",
            "isPromptable": True,
            "first_n_queries": 2,
            "api_key": "share-key",
        }
    )

    _backfill_shared_conversations(
        conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False,
    )

    shares = SharedConversationsRepository(dual_write_pg).list_for_conversation(conv["id"])
    assert len(shares) == 1
    assert shares[0]["api_key"] == "share-key"
    assert shares[0]["chunks"] == 7
    assert shares[0]["prompt_id"] is not None


def test_conversation_backfill_reports_unresolved_attachment_refs(mongo_db, dual_write_pg):
    mongo_db["conversations"].insert_one(
        {
            "_id": ObjectId("507f1f77bcf86cd799439021"),
            "user": "user-1",
            "name": "Conversation",
            "queries": [
                {
                    "prompt": "q1",
                    "response": "a1",
                    "attachments": [str(ObjectId("507f1f77bcf86cd799439022"))],
                }
            ],
        }
    )

    stats = _backfill_conversations(
        conn=dual_write_pg, mongo_db=mongo_db, batch_size=50, dry_run=False,
    )

    assert stats["unresolved_attachment_refs"] == 1
    conv = ConversationsRepository(dual_write_pg).get_by_legacy_id(
        "507f1f77bcf86cd799439021",
    )
    messages = ConversationsRepository(dual_write_pg).get_messages(conv["id"])
    assert messages[0]["attachments"] == []
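For orientation, the dual_write helper these tests exercise is only shown being called (dual_write(AttachmentsRepository, lambda repo: repo.create(...))), so the following is a hedged sketch of the shape implied by the tests, not the real implementation in application.storage.db.dual_write. It reuses only names that appear above (settings.USE_POSTGRES, get_engine and its begin() context manager); everything else is an assumption.

import logging

from application.core.settings import settings
from application.storage.db.engine import get_engine

logger = logging.getLogger(__name__)


def dual_write_sketch(repo_cls, operation):
    """Run operation(repo) against Postgres when dual-writes are enabled."""
    if not settings.USE_POSTGRES:
        return None  # Mongo remains the only write path.
    try:
        # In these tests get_engine() is patched to a _BoundEngine over pg_conn.
        with get_engine().begin() as conn:
            return operation(repo_cls(conn))
    except Exception:  # noqa: BLE001 - a mirror write must not break the Mongo path
        logger.exception("dual write to Postgres failed")
        return None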