mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 14:34:32 +00:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
18755bdd9b | ||
|
|
0f20adcbf4 | ||
|
|
5fab798707 | ||
|
|
cb30a24e05 | ||
|
|
530761d08c | ||
|
|
73fbc28744 | ||
|
|
b5b6538762 | ||
|
|
a9761061fc | ||
|
|
9388996a15 | ||
|
|
875868b7e5 | ||
|
|
502819ae52 | ||
|
|
cada1a44fc | ||
|
|
6192767451 | ||
|
|
5c3e6eca54 | ||
|
|
3931ccccee | ||
|
|
420e9d3dd5 | ||
|
|
749eed3d0b |
@@ -1,51 +1,80 @@
|
||||
Ollama
|
||||
Qdrant
|
||||
Milvus
|
||||
Chatwoot
|
||||
Nextra
|
||||
VSCode
|
||||
npm
|
||||
LLMs
|
||||
Agentic
|
||||
Anthropic's
|
||||
api
|
||||
APIs
|
||||
Groq
|
||||
SGLang
|
||||
LMDeploy
|
||||
OAuth
|
||||
Vite
|
||||
LLM
|
||||
JSONPath
|
||||
UIs
|
||||
configs
|
||||
uncomment
|
||||
qdrant
|
||||
vectorstore
|
||||
docsgpt
|
||||
llm
|
||||
GPUs
|
||||
kubectl
|
||||
Lightsail
|
||||
enqueues
|
||||
chatbot
|
||||
VSCode's
|
||||
Shareability
|
||||
feedbacks
|
||||
Atlassian
|
||||
automations
|
||||
Premade
|
||||
Signup
|
||||
Repo
|
||||
repo
|
||||
env
|
||||
URl
|
||||
agentic
|
||||
llama_cpp
|
||||
parsable
|
||||
SDKs
|
||||
boolean
|
||||
bool
|
||||
hardcode
|
||||
EOL
|
||||
Postgres
|
||||
Supabase
|
||||
config
|
||||
autoescaping
|
||||
Autoescaping
|
||||
backfill
|
||||
backfills
|
||||
bool
|
||||
boolean
|
||||
brave_web_search
|
||||
chatbot
|
||||
Chatwoot
|
||||
config
|
||||
configs
|
||||
CSVs
|
||||
dev
|
||||
diarization
|
||||
Docling
|
||||
docsgpt
|
||||
docstrings
|
||||
Entra
|
||||
env
|
||||
enqueues
|
||||
EOL
|
||||
ESLint
|
||||
feedbacks
|
||||
Figma
|
||||
GPUs
|
||||
Groq
|
||||
hardcode
|
||||
hardcoding
|
||||
Idempotency
|
||||
JSONPath
|
||||
kubectl
|
||||
Lightsail
|
||||
llama_cpp
|
||||
llm
|
||||
LLM
|
||||
LLMs
|
||||
LMDeploy
|
||||
Milvus
|
||||
Mixtral
|
||||
namespace
|
||||
namespaces
|
||||
needs_auth
|
||||
Nextra
|
||||
Novita
|
||||
npm
|
||||
OAuth
|
||||
Ollama
|
||||
opencode
|
||||
parsable
|
||||
passthrough
|
||||
PDFs
|
||||
pgvector
|
||||
Postgres
|
||||
Premade
|
||||
Pydantic
|
||||
pytest
|
||||
Qdrant
|
||||
qdrant
|
||||
Repo
|
||||
repo
|
||||
Sanitization
|
||||
SDKs
|
||||
SGLang
|
||||
Shareability
|
||||
Signup
|
||||
Supabase
|
||||
UIs
|
||||
uncomment
|
||||
URl
|
||||
vectorstore
|
||||
Vite
|
||||
VSCode
|
||||
VSCode's
|
||||
widget's
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -183,5 +183,6 @@ application/vectors/
|
||||
|
||||
node_modules/
|
||||
.vscode/settings.json
|
||||
.vscode/sftp.json
|
||||
/models/
|
||||
model/
|
||||
|
||||
@@ -73,7 +73,7 @@ class BraveSearchTool(Tool):
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
@@ -118,7 +118,7 @@ class BraveSearchTool(Tool):
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
|
||||
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url)
|
||||
response = requests.get(url, timeout=100)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if currency.upper() in data:
|
||||
|
||||
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data)
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -31,14 +31,14 @@ class TelegramTool(Tool):
|
||||
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -15,6 +15,9 @@ from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.logging import log_activity, LogContext
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -181,6 +184,9 @@ class WorkflowAgent(BaseAgent):
|
||||
def _save_workflow_run(self, query: str) -> None:
|
||||
if not self._engine:
|
||||
return
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
try:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
@@ -188,6 +194,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
run = WorkflowRun(
|
||||
workflow_id=self.workflow_id or "unknown",
|
||||
user=owner_id,
|
||||
status=self._determine_run_status(),
|
||||
inputs={"query": query},
|
||||
outputs=self._serialize_state(self._engine.state),
|
||||
@@ -196,7 +203,34 @@ class WorkflowAgent(BaseAgent):
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||
result = workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||
legacy_mongo_id = (
|
||||
str(result.inserted_id)
|
||||
if getattr(result, "inserted_id", None) is not None
|
||||
else None
|
||||
)
|
||||
|
||||
def _pg_write(repo: WorkflowRunsRepository) -> None:
|
||||
if not self.workflow_id or not owner_id or not legacy_mongo_id:
|
||||
return
|
||||
workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
|
||||
self.workflow_id, owner_id,
|
||||
)
|
||||
if workflow is None:
|
||||
return
|
||||
repo.create(
|
||||
workflow["id"],
|
||||
owner_id,
|
||||
run.status.value,
|
||||
inputs=run.inputs,
|
||||
result=run.outputs,
|
||||
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||
started_at=run.created_at,
|
||||
ended_at=run.completed_at,
|
||||
legacy_mongo_id=legacy_mongo_id,
|
||||
)
|
||||
|
||||
dual_write(WorkflowRunsRepository, _pg_write)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
|
||||
@@ -211,6 +211,7 @@ class WorkflowRun(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
workflow_id: str
|
||||
user: Optional[str] = None
|
||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -226,7 +227,7 @@ class WorkflowRun(BaseModel):
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
doc = {
|
||||
"workflow_id": self.workflow_id,
|
||||
"status": self.status.value,
|
||||
"inputs": self.inputs,
|
||||
@@ -235,3 +236,7 @@ class WorkflowRun(BaseModel):
|
||||
"created_at": self.created_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
if self.user:
|
||||
doc["user"] = self.user
|
||||
doc["user_id"] = self.user
|
||||
return doc
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -469,6 +469,18 @@ class BaseAnswerResource:
|
||||
log_data[key] = value[:10000]
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
|
||||
dual_write(
|
||||
UserLogsRepository,
|
||||
lambda repo, d=log_data: repo.insert(
|
||||
user_id=d.get("user"),
|
||||
endpoint="stream_answer",
|
||||
data=d,
|
||||
),
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
|
||||
@@ -13,6 +13,11 @@ from bson import ObjectId
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,6 +112,26 @@ class ContinuationService:
|
||||
f"Saved continuation state for conversation {conversation_id} "
|
||||
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres — upsert against the same Mongo conversation
|
||||
# by resolving its UUID via conversations.legacy_mongo_id.
|
||||
def _pg_save(_: PendingToolStateRepository) -> None:
|
||||
conn = _._conn # reuse the existing transaction
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
_.save_state(
|
||||
conv["id"],
|
||||
user,
|
||||
messages=_make_serializable(messages),
|
||||
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||
tools_dict=_make_serializable(tools_dict),
|
||||
tool_schemas=_make_serializable(tool_schemas),
|
||||
agent_config=_make_serializable(agent_config),
|
||||
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||
)
|
||||
|
||||
dual_write(PendingToolStateRepository, _pg_save)
|
||||
return state_id
|
||||
|
||||
def load_state(
|
||||
@@ -138,4 +163,13 @@ class ContinuationService:
|
||||
logger.info(
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres — delete the same row.
|
||||
def _pg_delete(repo: PendingToolStateRepository) -> None:
|
||||
conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.delete_state(conv["id"], user)
|
||||
|
||||
dual_write(PendingToolStateRepository, _pg_delete)
|
||||
return result.deleted_count > 0
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@@ -113,6 +115,26 @@ class ConversationService:
|
||||
},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
# Dual-write to Postgres: update the message at :index and
|
||||
# truncate anything after it, mirroring Mongo's $set+$slice.
|
||||
def _pg_update_at_index(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.update_message_at(conv["id"], index, {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
})
|
||||
repo.truncate_after(conv["id"], index)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_update_at_index)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
# Append new message to existing conversation
|
||||
@@ -138,6 +160,25 @@ class ConversationService:
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Dual-write to Postgres: append the same message.
|
||||
def _pg_append(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append)
|
||||
return conversation_id
|
||||
else:
|
||||
# Create new conversation
|
||||
@@ -193,7 +234,34 @@ class ConversationService:
|
||||
if agent:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
inserted_id = str(result.inserted_id)
|
||||
|
||||
# Dual-write to Postgres: create the conversation row with
|
||||
# legacy_mongo_id and append the first message.
|
||||
def _pg_create(repo: ConversationsRepository) -> None:
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
completion,
|
||||
agent_id=conversation_data.get("agent_id"),
|
||||
api_key=conversation_data.get("api_key"),
|
||||
is_shared_usage=conversation_data.get("is_shared_usage", False),
|
||||
shared_token=conversation_data.get("shared_token"),
|
||||
legacy_mongo_id=inserted_id,
|
||||
)
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_create)
|
||||
return inserted_id
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
@@ -230,6 +298,24 @@ class ConversationService:
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres: mirror $set + $push $slice.
|
||||
def _pg_compression(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.set_compression_flags(
|
||||
conv["id"],
|
||||
is_compressed=True,
|
||||
last_compression_at=compression_metadata.get("timestamp"),
|
||||
)
|
||||
repo.append_compression_point(
|
||||
conv["id"],
|
||||
compression_metadata,
|
||||
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_compression)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
@@ -266,6 +352,23 @@ class ConversationService:
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _pg_append_summary(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
"timestamp": timestamp,
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append_summary)
|
||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -13,6 +13,8 @@ from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||
|
||||
agents_folders_ns = Namespace(
|
||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||
@@ -83,6 +85,10 @@ class AgentFolders(Resource):
|
||||
"updated_at": now,
|
||||
}
|
||||
result = agent_folders_collection.insert_one(folder)
|
||||
dual_write(
|
||||
AgentFoldersRepository,
|
||||
lambda repo, u=user, n=data["name"]: repo.create(u, n),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
|
||||
201,
|
||||
@@ -167,6 +173,10 @@ class AgentFolder(Resource):
|
||||
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
|
||||
)
|
||||
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
|
||||
dual_write(
|
||||
AgentFoldersRepository,
|
||||
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@@ -24,6 +24,7 @@ from application.api.user.base import (
|
||||
workflows_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
@@ -113,6 +114,35 @@ AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
|
||||
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
|
||||
|
||||
|
||||
def _build_pg_agent_fields(fields: dict) -> dict:
|
||||
"""Translate Mongo-shaped agent fields into the Postgres mirror subset."""
|
||||
allowed = {
|
||||
"name",
|
||||
"description",
|
||||
"agent_type",
|
||||
"status",
|
||||
"key",
|
||||
"chunks",
|
||||
"retriever",
|
||||
"tools",
|
||||
"json_schema",
|
||||
"models",
|
||||
"default_model_id",
|
||||
"limited_token_mode",
|
||||
"token_limit",
|
||||
"limited_request_mode",
|
||||
"request_limit",
|
||||
"incoming_webhook_token",
|
||||
"lastUsedAt",
|
||||
}
|
||||
translated: dict = {}
|
||||
for key, value in fields.items():
|
||||
if key not in allowed:
|
||||
continue
|
||||
translated["last_used_at" if key == "lastUsedAt" else key] = value
|
||||
return translated
|
||||
|
||||
|
||||
def normalize_workflow_reference(workflow_value):
|
||||
"""Normalize workflow references from form/json payloads."""
|
||||
if workflow_value is None:
|
||||
@@ -623,6 +653,18 @@ class CreateAgent(Resource):
|
||||
new_agent["retriever"] = "classic"
|
||||
resp = agents_collection.insert_one(new_agent)
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
AgentsRepository,
|
||||
lambda repo, u=user, a=new_agent, mid=new_id: repo.create(
|
||||
u, a.get("name", ""), a.get("status", "draft"),
|
||||
key=a.get("key"), description=a.get("description"),
|
||||
retriever=a.get("retriever"), chunks=a.get("chunks"),
|
||||
tools=a.get("tools"), models=a.get("models"),
|
||||
shared=a.get("shared", False),
|
||||
incoming_webhook_token=a.get("incoming_webhook_token"),
|
||||
legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -1158,6 +1200,14 @@ class UpdateAgent(Resource):
|
||||
jsonify({"success": False, "message": "Database error during update"}),
|
||||
500,
|
||||
)
|
||||
pg_update_fields = _build_pg_agent_fields(update_fields)
|
||||
if pg_update_fields:
|
||||
dual_write(
|
||||
AgentsRepository,
|
||||
lambda repo, aid=agent_id, u=user, fields=pg_update_fields: repo.update_by_legacy_id(
|
||||
aid, u, fields,
|
||||
),
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": agent_id,
|
||||
@@ -1185,6 +1235,10 @@ class DeleteAgent(Resource):
|
||||
deleted_agent = agents_collection.find_one_and_delete(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
dual_write(
|
||||
AgentsRepository,
|
||||
lambda repo, aid=agent_id, u=user: repo.delete_by_legacy_id(aid, u),
|
||||
)
|
||||
if not deleted_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
|
||||
@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
@@ -30,15 +32,23 @@ class DeleteConversation(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
user_id = decoded_token["sub"]
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
{"_id": ObjectId(conversation_id), "user": user_id}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_delete(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
repo.delete(conv["id"], user_id)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_delete)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -59,6 +69,11 @@ class DeleteAllConversations(Resource):
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
dual_write(
|
||||
ConversationsRepository,
|
||||
lambda r, uid=user_id: r.delete_all_for_user(uid),
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -190,9 +205,10 @@ class UpdateConversationName(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"_id": ObjectId(data["id"]), "user": user_id},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -200,6 +216,13 @@ class UpdateConversationName(Resource):
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_rename(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["id"])
|
||||
if conv is not None:
|
||||
repo.rename(conv["id"], user_id, data["name"])
|
||||
|
||||
dual_write(ConversationsRepository, _pg_rename)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -277,4 +300,21 @@ class SubmitFeedback(Resource):
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
# Dual-write to Postgres: mirror the per-message feedback set/unset.
|
||||
feedback_value = data["feedback"]
|
||||
question_index = int(data["question_index"])
|
||||
feedback_payload = (
|
||||
None if feedback_value is None
|
||||
else {"text": feedback_value, "timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat()}
|
||||
)
|
||||
|
||||
def _pg_feedback(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["conversation_id"])
|
||||
if conv is not None:
|
||||
repo.set_feedback(conv["id"], question_index, feedback_payload)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_feedback)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
@@ -49,6 +51,12 @@ class CreatePrompt(Resource):
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
|
||||
u, n, c, legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -149,6 +157,10 @@ class DeletePrompt(Resource):
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -185,6 +197,12 @@ class UpdatePrompt(Resource):
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
|
||||
pid, u, n, c,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
@@ -15,8 +15,71 @@ from application.api.user.base import (
|
||||
conversations_collection,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.shared_conversations import (
|
||||
SharedConversationsRepository,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
|
||||
def _dual_write_share(
|
||||
mongo_conv_id: str,
|
||||
share_uuid: str,
|
||||
user: str,
|
||||
*,
|
||||
is_promptable: bool,
|
||||
first_n_queries: int,
|
||||
api_key: str | None,
|
||||
prompt_id: str | None = None,
|
||||
chunks: int | None = None,
|
||||
) -> None:
|
||||
"""Mirror a Mongo share-record insert into Postgres.
|
||||
|
||||
Preserves the Mongo-generated UUID so public ``/shared/{uuid}`` URLs
|
||||
resolve from both stores during cutover.
|
||||
"""
|
||||
def _write(repo: SharedConversationsRepository) -> None:
|
||||
conv = ConversationsRepository(repo._conn).get_by_legacy_id(
|
||||
mongo_conv_id, user_id=user,
|
||||
)
|
||||
if conv is None:
|
||||
return
|
||||
# prompt_id / chunks are only meaningful for promptable shares;
|
||||
# prompt_id is often the string "default" or an ObjectId that
|
||||
# hasn't been migrated — pass as-is and let the repo drop
|
||||
# non-UUID values. Scope the prompt lookup by user_id so an
|
||||
# authenticated caller can't link another user's prompt into
|
||||
# their share record.
|
||||
resolved_prompt_id = None
|
||||
if prompt_id and len(str(prompt_id)) == 24:
|
||||
from sqlalchemy import text as _text
|
||||
row = repo._conn.execute(
|
||||
_text(
|
||||
"SELECT id FROM prompts "
|
||||
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||
),
|
||||
{"legacy_id": str(prompt_id), "user_id": user},
|
||||
).fetchone()
|
||||
if row:
|
||||
resolved_prompt_id = str(row[0])
|
||||
# get_or_create is race-free on the PG side thanks to the
|
||||
# composite partial unique index on the dedup tuple
|
||||
# (migration 0008). It converges concurrent share requests to
|
||||
# a single row.
|
||||
repo.get_or_create(
|
||||
conv["id"],
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=first_n_queries,
|
||||
api_key=api_key,
|
||||
prompt_id=resolved_prompt_id,
|
||||
chunks=chunks,
|
||||
share_uuid=share_uuid,
|
||||
)
|
||||
|
||||
dual_write(SharedConversationsRepository, _write)
|
||||
|
||||
sharing_ns = Namespace(
|
||||
"sharing", description="Conversation sharing operations", path="/api"
|
||||
)
|
||||
@@ -124,6 +187,16 @@ class ShareConversation(Resource):
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_id,
|
||||
chunks=int(chunks) if chunks else None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -155,6 +228,16 @@ class ShareConversation(Resource):
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_id,
|
||||
chunks=int(chunks) if chunks else None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -192,6 +275,14 @@ class ShareConversation(Resource):
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
_dual_write_share(
|
||||
conversation_id,
|
||||
str(explicit_binary.as_uuid()),
|
||||
user,
|
||||
is_promptable=is_promptable,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=None,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
|
||||
@@ -134,6 +134,12 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
# Replaces Mongo's TTL index on pending_tool_state.expires_at.
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=60),
|
||||
cleanup_pending_tool_state.s(),
|
||||
name="cleanup-pending-tool-state",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -146,3 +152,27 @@ def mcp_oauth_task(self, config, user):
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Delete pending_tool_state rows past their TTL.
|
||||
|
||||
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
|
||||
no native TTL, so this task runs every 60 seconds to keep
|
||||
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
|
||||
configured (keeps the task runnable in Mongo-only environments).
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = PendingToolStateRepository(conn).cleanup_expired()
|
||||
return {"deleted": deleted}
|
||||
|
||||
@@ -9,6 +9,8 @@ from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
@@ -294,6 +296,13 @@ class CreateTool(Resource):
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
UserToolsRepository,
|
||||
lambda repo, u=user, t=new_tool: repo.create(
|
||||
u, t["name"], config=t.get("config"),
|
||||
custom_name=t.get("customName"), display_name=t.get("displayName"),
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -581,6 +590,10 @@ class DeleteTool(Resource):
|
||||
result = user_tools_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
dual_write(
|
||||
UserToolsRepository,
|
||||
lambda repo, tid=data["id"], u=user: repo.delete(tid, u),
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
|
||||
@@ -11,6 +11,10 @@ from application.api.user.base import (
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
@@ -35,6 +39,174 @@ def _workflow_error_response(message: str, err: Exception):
|
||||
return error_response(message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Postgres dual-write helpers
|
||||
#
|
||||
# Workflows are unusual relative to other Phase 3 tables: a single user
|
||||
# action (create / update) writes to three collections in concert
|
||||
# (workflows + workflow_nodes + workflow_edges) and the edges reference
|
||||
# nodes by user-provided string ids. The Postgres mirror needs to:
|
||||
#
|
||||
# 1. Run all three writes inside one PG transaction (so the just-created
|
||||
# nodes are visible when we resolve their UUIDs for the edge insert).
|
||||
# 2. Translate edge source_id/target_id strings → workflow_nodes.id UUIDs
|
||||
# after the bulk_create returns them.
|
||||
#
|
||||
# Each helper opens exactly one ``dual_write`` call (one PG txn) and uses
|
||||
# the connection from whichever repo it was instantiated with to spin up
|
||||
# any sibling repos it needs.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _dual_write_workflow_create(
|
||||
mongo_workflow_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
graph_version: int = 1,
|
||||
) -> None:
|
||||
"""Mirror a Mongo workflow create into Postgres."""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
conn = repo._conn
|
||||
wf = repo.create(
|
||||
user_id,
|
||||
name,
|
||||
description=description,
|
||||
legacy_mongo_id=mongo_workflow_id,
|
||||
)
|
||||
_write_graph(conn, wf["id"], graph_version, nodes_data, edges_data)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _dual_write_workflow_update(
|
||||
mongo_workflow_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
next_graph_version: int,
|
||||
) -> None:
|
||||
"""Mirror a Mongo workflow update into Postgres.
|
||||
|
||||
Mirrors the Mongo route: insert the new graph_version's nodes/edges,
|
||||
bump the workflow's name/description/current_graph_version, then drop
|
||||
every other graph_version's nodes/edges.
|
||||
"""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
conn = repo._conn
|
||||
wf = _resolve_pg_workflow(conn, mongo_workflow_id)
|
||||
if wf is None:
|
||||
return
|
||||
_write_graph(conn, wf["id"], next_graph_version, nodes_data, edges_data)
|
||||
repo.update(wf["id"], user_id, {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"current_graph_version": next_graph_version,
|
||||
})
|
||||
WorkflowNodesRepository(conn).delete_other_versions(
|
||||
wf["id"], next_graph_version,
|
||||
)
|
||||
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||
wf["id"], next_graph_version,
|
||||
)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _dual_write_workflow_delete(mongo_workflow_id: str, user_id: str) -> None:
|
||||
"""Mirror a Mongo workflow delete into Postgres.
|
||||
|
||||
The CASCADE on workflows.id → workflow_nodes/workflow_edges takes
|
||||
care of the children automatically.
|
||||
"""
|
||||
|
||||
def _do(repo: WorkflowsRepository) -> None:
|
||||
wf = _resolve_pg_workflow(repo._conn, mongo_workflow_id)
|
||||
if wf is not None:
|
||||
repo.delete(wf["id"], user_id)
|
||||
|
||||
dual_write(WorkflowsRepository, _do)
|
||||
|
||||
|
||||
def _resolve_pg_workflow(conn, mongo_workflow_id: str) -> Optional[Dict]:
|
||||
"""Look up a Postgres workflow by its Mongo ObjectId string."""
|
||||
from sqlalchemy import text as _text
|
||||
row = conn.execute(
|
||||
_text("SELECT id FROM workflows WHERE legacy_mongo_id = :legacy_id"),
|
||||
{"legacy_id": mongo_workflow_id},
|
||||
).fetchone()
|
||||
return {"id": str(row[0])} if row else None
|
||||
|
||||
|
||||
def _write_graph(
|
||||
conn,
|
||||
pg_workflow_id: str,
|
||||
graph_version: int,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
) -> None:
|
||||
"""Bulk-create nodes + edges for one graph version inside one txn.
|
||||
|
||||
Edges arrive with source/target as user-provided node-id strings
|
||||
(the same shape the Mongo route stores). We bulk-insert nodes first,
|
||||
capture their ``node_id → UUID`` map from the returned rows, then
|
||||
translate edge source/target strings to those UUIDs before the edge
|
||||
bulk insert. Edges referencing missing nodes are dropped (logged).
|
||||
"""
|
||||
nodes_repo = WorkflowNodesRepository(conn)
|
||||
edges_repo = WorkflowEdgesRepository(conn)
|
||||
|
||||
if nodes_data:
|
||||
created_nodes = nodes_repo.bulk_create(
|
||||
pg_workflow_id, graph_version,
|
||||
[
|
||||
{
|
||||
"node_id": n["id"],
|
||||
"node_type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
"legacy_mongo_id": n.get("legacy_mongo_id"),
|
||||
}
|
||||
for n in nodes_data
|
||||
],
|
||||
)
|
||||
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
|
||||
else:
|
||||
node_uuid_by_str = {}
|
||||
|
||||
if edges_data:
|
||||
translated_edges: List[Dict] = []
|
||||
for e in edges_data:
|
||||
src = e.get("source")
|
||||
tgt = e.get("target")
|
||||
from_uuid = node_uuid_by_str.get(src)
|
||||
to_uuid = node_uuid_by_str.get(tgt)
|
||||
if not from_uuid or not to_uuid:
|
||||
current_app.logger.warning(
|
||||
"PG dual-write: dropping edge %s; node refs unresolved "
|
||||
"(source=%s, target=%s)",
|
||||
e.get("id"), src, tgt,
|
||||
)
|
||||
continue
|
||||
translated_edges.append({
|
||||
"edge_id": e["id"],
|
||||
"from_node_id": from_uuid,
|
||||
"to_node_id": to_uuid,
|
||||
"source_handle": e.get("sourceHandle"),
|
||||
"target_handle": e.get("targetHandle"),
|
||||
})
|
||||
if translated_edges:
|
||||
edges_repo.bulk_create(pg_workflow_id, graph_version, translated_edges)
|
||||
|
||||
|
||||
def serialize_workflow(w: Dict) -> Dict:
|
||||
"""Serialize workflow document to API response format."""
|
||||
return {
|
||||
@@ -317,24 +489,28 @@ def _can_reach_end(
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
"""Insert workflow nodes into database."""
|
||||
) -> List[Dict]:
|
||||
"""Insert workflow nodes into Mongo and return rows with Mongo ids."""
|
||||
if nodes_data:
|
||||
workflow_nodes_collection.insert_many(
|
||||
[
|
||||
{
|
||||
"id": n["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
]
|
||||
)
|
||||
mongo_nodes = [
|
||||
{
|
||||
"id": n["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
]
|
||||
result = workflow_nodes_collection.insert_many(mongo_nodes)
|
||||
return [
|
||||
{**node, "legacy_mongo_id": str(inserted_id)}
|
||||
for node, inserted_id in zip(nodes_data, result.inserted_ids)
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def create_workflow_edges(
|
||||
@@ -399,7 +575,7 @@ class WorkflowList(Resource):
|
||||
workflow_id = str(result.inserted_id)
|
||||
|
||||
try:
|
||||
create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||
create_workflow_edges(workflow_id, edges_data, 1)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
@@ -407,6 +583,15 @@ class WorkflowList(Resource):
|
||||
workflows_collection.delete_one({"_id": result.inserted_id})
|
||||
return _workflow_error_response("Failed to create workflow structure", err)
|
||||
|
||||
_dual_write_workflow_create(
|
||||
workflow_id,
|
||||
user_id,
|
||||
name,
|
||||
data.get("description", ""),
|
||||
created_nodes,
|
||||
edges_data,
|
||||
)
|
||||
|
||||
return success_response({"id": workflow_id}, 201)
|
||||
|
||||
|
||||
@@ -473,7 +658,9 @@ class WorkflowDetail(Resource):
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
try:
|
||||
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
|
||||
created_nodes = create_workflow_nodes(
|
||||
workflow_id, nodes_data, next_graph_version,
|
||||
)
|
||||
create_workflow_edges(workflow_id, edges_data, next_graph_version)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many(
|
||||
@@ -520,6 +707,16 @@ class WorkflowDetail(Resource):
|
||||
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
|
||||
)
|
||||
|
||||
_dual_write_workflow_update(
|
||||
workflow_id,
|
||||
user_id,
|
||||
name,
|
||||
data.get("description", ""),
|
||||
created_nodes,
|
||||
edges_data,
|
||||
next_graph_version,
|
||||
)
|
||||
|
||||
return success_response()
|
||||
|
||||
@require_auth
|
||||
@@ -543,4 +740,6 @@ class WorkflowDetail(Resource):
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to delete workflow", err)
|
||||
|
||||
_dual_write_workflow_delete(workflow_id, user_id)
|
||||
|
||||
return success_response()
|
||||
|
||||
@@ -28,13 +28,11 @@ class Settings(BaseSettings):
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MONGO_DB_NAME: str = "docsgpt"
|
||||
# User-data Postgres DB Optional during the MongoDB→Postgres migration; becomes required once the migration is
|
||||
# complete.
|
||||
# User-data Postgres DB.
|
||||
POSTGRES_URI: Optional[str] = None
|
||||
|
||||
# MongoDB→Postgres migration switches
|
||||
# MongoDB→Postgres migration: dual-write to Postgres (Mongo stays source of truth)
|
||||
USE_POSTGRES: bool = False
|
||||
READ_POSTGRES: bool = False
|
||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
||||
@@ -72,6 +70,10 @@ class Settings(BaseSettings):
|
||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
|
||||
# Confluence Cloud integration
|
||||
CONFLUENCE_CLIENT_ID: Optional[str] = None
|
||||
CONFLUENCE_CLIENT_SECRET: Optional[str] = None
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
|
||||
@@ -157,5 +157,21 @@ def _log_to_mongodb(
|
||||
user_logs_collection.insert_one(log_entry)
|
||||
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
dual_write(
|
||||
StackLogsRepository,
|
||||
lambda repo, e=log_entry: repo.insert(
|
||||
activity_id=e["id"],
|
||||
endpoint=e.get("endpoint"),
|
||||
level=e.get("level"),
|
||||
user_id=e.get("user"),
|
||||
api_key=e.get("api_key"),
|
||||
query=e.get("query"),
|
||||
stacks=e.get("stacks"),
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)
|
||||
|
||||
4
application/parser/connectors/confluence/__init__.py
Normal file
4
application/parser/connectors/confluence/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .auth import ConfluenceAuth
|
||||
from .loader import ConfluenceLoader
|
||||
|
||||
__all__ = ["ConfluenceAuth", "ConfluenceLoader"]
|
||||
216
application/parser/connectors/confluence/auth.py
Normal file
216
application/parser/connectors/confluence/auth.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfluenceAuth(BaseConnectorAuth):
|
||||
|
||||
SCOPES = [
|
||||
"read:page:confluence",
|
||||
"read:space:confluence",
|
||||
"read:attachment:confluence",
|
||||
"read:me",
|
||||
"offline_access",
|
||||
]
|
||||
|
||||
AUTH_URL = "https://auth.atlassian.com/authorize"
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
|
||||
ME_URL = "https://api.atlassian.com/me"
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.CONFLUENCE_CLIENT_ID
|
||||
self.client_secret = settings.CONFLUENCE_CLIENT_SECRET
|
||||
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError(
|
||||
"Confluence OAuth credentials not configured. "
|
||||
"Please set CONFLUENCE_CLIENT_ID and CONFLUENCE_CLIENT_SECRET in settings."
|
||||
)
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
params = {
|
||||
"audience": "api.atlassian.com",
|
||||
"client_id": self.client_id,
|
||||
"scope": " ".join(self.SCOPES),
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": state,
|
||||
"response_type": "code",
|
||||
"prompt": "consent",
|
||||
}
|
||||
return f"{self.AUTH_URL}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
if not authorization_code:
|
||||
raise ValueError("Authorization code is required")
|
||||
|
||||
response = requests.post(
|
||||
self.TOKEN_URL,
|
||||
json={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": authorization_code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError("OAuth flow did not return an access token")
|
||||
|
||||
refresh_token = token_data.get("refresh_token")
|
||||
if not refresh_token:
|
||||
raise ValueError("OAuth flow did not return a refresh token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
expiry = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expires_in)
|
||||
).isoformat()
|
||||
|
||||
cloud_id = self._fetch_cloud_id(access_token)
|
||||
user_info = self._fetch_user_info(access_token)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_uri": self.TOKEN_URL,
|
||||
"scopes": self.SCOPES,
|
||||
"expiry": expiry,
|
||||
"cloud_id": cloud_id,
|
||||
"user_info": {
|
||||
"name": user_info.get("display_name", ""),
|
||||
"email": user_info.get("email", ""),
|
||||
},
|
||||
}
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
if not refresh_token:
|
||||
raise ValueError("Refresh token is required")
|
||||
|
||||
response = requests.post(
|
||||
self.TOKEN_URL,
|
||||
json={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
access_token = token_data.get("access_token")
|
||||
new_refresh_token = token_data.get("refresh_token", refresh_token)
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
expiry = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(seconds=expires_in)
|
||||
).isoformat()
|
||||
|
||||
cloud_id = self._fetch_cloud_id(access_token)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
"token_uri": self.TOKEN_URL,
|
||||
"scopes": self.SCOPES,
|
||||
"expiry": expiry,
|
||||
"cloud_id": cloud_id,
|
||||
}
|
||||
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
if not token_info:
|
||||
return True
|
||||
|
||||
expiry = token_info.get("expiry")
|
||||
if not expiry:
|
||||
return bool(token_info.get("access_token"))
|
||||
|
||||
try:
|
||||
expiry_dt = datetime.datetime.fromisoformat(expiry)
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
return now >= expiry_dt - datetime.timedelta(seconds=60)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings as app_settings
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[app_settings.MONGO_DB_NAME]
|
||||
|
||||
session = db["connector_sessions"].find_one({"session_token": session_token})
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
|
||||
token_info = session.get("token_info")
|
||||
if not token_info:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
required = ["access_token", "refresh_token", "cloud_id"]
|
||||
missing = [f for f in required if not token_info.get(f)]
|
||||
if missing:
|
||||
raise ValueError(f"Missing required token fields: {missing}")
|
||||
|
||||
return token_info
|
||||
|
||||
def sanitize_token_info(
|
||||
self, token_info: Dict[str, Any], **extra_fields
|
||||
) -> Dict[str, Any]:
|
||||
return super().sanitize_token_info(
|
||||
token_info,
|
||||
cloud_id=token_info.get("cloud_id"),
|
||||
**extra_fields,
|
||||
)
|
||||
|
||||
def _fetch_cloud_id(self, access_token: str) -> str:
|
||||
response = requests.get(
|
||||
self.RESOURCES_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
resources = response.json()
|
||||
|
||||
if not resources:
|
||||
raise ValueError("No accessible Confluence sites found for this account")
|
||||
|
||||
return resources[0]["id"]
|
||||
|
||||
def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
response = requests.get(
|
||||
self.ME_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch user info: %s", e)
|
||||
return {}
|
||||
416
application/parser/connectors/confluence/loader.py
Normal file
416
application/parser/connectors/confluence/loader.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_V2 = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki/api/v2"
|
||||
DOWNLOAD_BASE = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki"
|
||||
|
||||
SUPPORTED_ATTACHMENT_TYPES = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"text/html": ".html",
|
||||
"text/markdown": ".md",
|
||||
"application/json": ".json",
|
||||
"application/epub+zip": ".epub",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
}
|
||||
|
||||
|
||||
def _retry_on_auth_failure(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response is not None and e.response.status_code in (401, 403):
|
||||
logger.info(
|
||||
"Auth failure in %s, refreshing token and retrying", func.__name__
|
||||
)
|
||||
try:
|
||||
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||
self.access_token = new_token_info["access_token"]
|
||||
self.refresh_token = new_token_info.get(
|
||||
"refresh_token", self.refresh_token
|
||||
)
|
||||
self._persist_refreshed_tokens(new_token_info)
|
||||
except Exception as refresh_err:
|
||||
raise ValueError(
|
||||
f"Authentication failed and could not be refreshed: {refresh_err}"
|
||||
) from e
|
||||
return func(self, *args, **kwargs)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ConfluenceLoader(BaseConnectorLoader):
|
||||
|
||||
def __init__(self, session_token: str):
|
||||
self.auth = ConfluenceAuth()
|
||||
self.session_token = session_token
|
||||
|
||||
token_info = self.auth.get_token_info_from_session(session_token)
|
||||
self.access_token = token_info["access_token"]
|
||||
self.refresh_token = token_info["refresh_token"]
|
||||
self.cloud_id = token_info["cloud_id"]
|
||||
|
||||
self.base_url = API_V2.format(cloud_id=self.cloud_id)
|
||||
self.download_base = DOWNLOAD_BASE.format(cloud_id=self.cloud_id)
|
||||
self.next_page_token = None
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def _persist_refreshed_tokens(self, token_info: Dict[str, Any]) -> None:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings as app_settings
|
||||
|
||||
sanitized = self.auth.sanitize_token_info(token_info)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[app_settings.MONGO_DB_NAME]
|
||||
db["connector_sessions"].update_one(
|
||||
{"session_token": self.session_token},
|
||||
{"$set": {"token_info": sanitized}},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist refreshed tokens: %s", e)
|
||||
|
||||
@_retry_on_auth_failure
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
folder_id = inputs.get("folder_id")
|
||||
file_ids = inputs.get("file_ids", [])
|
||||
limit = inputs.get("limit", 100)
|
||||
list_only = inputs.get("list_only", False)
|
||||
page_token = inputs.get("page_token")
|
||||
search_query = inputs.get("search_query")
|
||||
self.next_page_token = None
|
||||
|
||||
if file_ids:
|
||||
return self._load_pages_by_ids(file_ids, list_only, search_query)
|
||||
|
||||
if folder_id:
|
||||
return self._list_pages_in_space(
|
||||
folder_id, limit, list_only, page_token, search_query
|
||||
)
|
||||
|
||||
return self._list_spaces(limit, page_token, search_query)
|
||||
|
||||
@_retry_on_auth_failure
|
||||
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||
config = source_config or getattr(self, "config", {})
|
||||
file_ids = config.get("file_ids", [])
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
files_downloaded = 0
|
||||
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [file_ids]
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [folder_ids]
|
||||
|
||||
for page_id in file_ids:
|
||||
if self._download_page(page_id, local_dir):
|
||||
files_downloaded += 1
|
||||
files_downloaded += self._download_page_attachments(page_id, local_dir)
|
||||
|
||||
for space_id in folder_ids:
|
||||
files_downloaded += self._download_space(space_id, local_dir)
|
||||
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": files_downloaded == 0,
|
||||
"source_type": "confluence",
|
||||
"config_used": config,
|
||||
}
|
||||
|
||||
def _list_spaces(
|
||||
self, limit: int, cursor: Optional[str], search_query: Optional[str]
|
||||
) -> List[Document]:
|
||||
documents: List[Document] = []
|
||||
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/spaces",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
for space in data.get("results", []):
|
||||
name = space.get("name", "")
|
||||
if search_query and search_query.lower() not in name.lower():
|
||||
continue
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
text="",
|
||||
doc_id=space["id"],
|
||||
extra_info={
|
||||
"file_name": name,
|
||||
"mime_type": "folder",
|
||||
"size": None,
|
||||
"created_time": space.get("createdAt"),
|
||||
"modified_time": None,
|
||||
"source": "confluence",
|
||||
"is_folder": True,
|
||||
"space_key": space.get("key"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
self.next_page_token = self._extract_cursor(next_link)
|
||||
return documents
|
||||
|
||||
def _list_pages_in_space(
|
||||
self,
|
||||
space_id: str,
|
||||
limit: int,
|
||||
list_only: bool,
|
||||
cursor: Optional[str],
|
||||
search_query: Optional[str],
|
||||
) -> List[Document]:
|
||||
documents: List[Document] = []
|
||||
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/spaces/{space_id}/pages",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
for page in data.get("results", []):
|
||||
title = page.get("title", "")
|
||||
if search_query and search_query.lower() not in title.lower():
|
||||
continue
|
||||
|
||||
doc = self._page_to_document(
|
||||
page, load_content=not list_only, space_id=space_id
|
||||
)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
self.next_page_token = self._extract_cursor(next_link)
|
||||
return documents
|
||||
|
||||
def _load_pages_by_ids(
|
||||
self, page_ids: List[str], list_only: bool, search_query: Optional[str]
|
||||
) -> List[Document]:
|
||||
documents: List[Document] = []
|
||||
for page_id in page_ids:
|
||||
try:
|
||||
params: Dict[str, str] = {}
|
||||
if not list_only:
|
||||
params["body-format"] = "storage"
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/pages/{page_id}",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
page = response.json()
|
||||
|
||||
title = page.get("title", "")
|
||||
if search_query and search_query.lower() not in title.lower():
|
||||
continue
|
||||
|
||||
doc = self._page_to_document(page, load_content=not list_only)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logger.error("Error loading page %s: %s", page_id, e)
|
||||
return documents
|
||||
|
||||
def _page_to_document(
|
||||
self,
|
||||
page: Dict[str, Any],
|
||||
load_content: bool = False,
|
||||
space_id: Optional[str] = None,
|
||||
) -> Optional[Document]:
|
||||
page_id = page.get("id")
|
||||
title = page.get("title", "Unknown")
|
||||
version = page.get("version", {})
|
||||
modified_time = version.get("createdAt") if isinstance(version, dict) else None
|
||||
created_time = page.get("createdAt")
|
||||
resolved_space_id = space_id or page.get("spaceId")
|
||||
|
||||
text = ""
|
||||
if load_content:
|
||||
body = page.get("body", {})
|
||||
storage = body.get("storage", {}) if isinstance(body, dict) else {}
|
||||
text = storage.get("value", "") if isinstance(storage, dict) else ""
|
||||
|
||||
return Document(
|
||||
text=text,
|
||||
doc_id=str(page_id),
|
||||
extra_info={
|
||||
"file_name": title,
|
||||
"mime_type": "text/html",
|
||||
"size": len(text) if text else None,
|
||||
"created_time": created_time,
|
||||
"modified_time": modified_time,
|
||||
"source": "confluence",
|
||||
"is_folder": False,
|
||||
"page_id": str(page_id),
|
||||
"space_id": resolved_space_id,
|
||||
"cloud_id": self.cloud_id,
|
||||
},
|
||||
)
|
||||
|
||||
def _download_page(self, page_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/pages/{page_id}",
|
||||
headers=self._headers(),
|
||||
params={"body-format": "storage"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
page = response.json()
|
||||
|
||||
title = page.get("title", page_id)
|
||||
safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in title)
|
||||
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||
|
||||
file_path = os.path.join(local_dir, f"{safe_name}.html")
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(body)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error downloading page %s: %s", page_id, e)
|
||||
return False
|
||||
|
||||
def _download_page_attachments(self, page_id: str, local_dir: str) -> int:
|
||||
downloaded = 0
|
||||
try:
|
||||
cursor = None
|
||||
while True:
|
||||
params: Dict[str, Any] = {"limit": 100}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(
|
||||
f"{self.base_url}/pages/{page_id}/attachments",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
for att in data.get("results", []):
|
||||
media_type = att.get("mediaType", "")
|
||||
if media_type not in SUPPORTED_ATTACHMENT_TYPES:
|
||||
continue
|
||||
|
||||
download_link = att.get("_links", {}).get("download")
|
||||
if not download_link:
|
||||
continue
|
||||
|
||||
raw_name = att.get("title", att.get("id", "attachment"))
|
||||
file_name = "".join(
|
||||
c if c.isalnum() or c in " -_." else "_"
|
||||
for c in os.path.basename(raw_name)
|
||||
) or "attachment"
|
||||
file_path = os.path.join(local_dir, file_name)
|
||||
|
||||
url = f"{self.download_base}{download_link}"
|
||||
file_resp = requests.get(
|
||||
url, headers=self._headers(), timeout=60, stream=True
|
||||
)
|
||||
file_resp.raise_for_status()
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in file_resp.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
downloaded += 1
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
cursor = self._extract_cursor(next_link)
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error downloading attachments for page %s: %s", page_id, e)
|
||||
return downloaded
|
||||
|
||||
def _download_space(self, space_id: str, local_dir: str) -> int:
|
||||
downloaded = 0
|
||||
cursor = None
|
||||
while True:
|
||||
params: Dict[str, Any] = {"limit": 250}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/spaces/{space_id}/pages",
|
||||
headers=self._headers(),
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except Exception as e:
|
||||
logger.error("Error listing pages in space %s: %s", space_id, e)
|
||||
break
|
||||
|
||||
for page in data.get("results", []):
|
||||
page_id = page.get("id")
|
||||
if self._download_page(str(page_id), local_dir):
|
||||
downloaded += 1
|
||||
downloaded += self._download_page_attachments(str(page_id), local_dir)
|
||||
|
||||
next_link = data.get("_links", {}).get("next")
|
||||
cursor = self._extract_cursor(next_link)
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
return downloaded
|
||||
|
||||
@staticmethod
|
||||
def _extract_cursor(next_link: Optional[str]) -> Optional[str]:
|
||||
if not next_link:
|
||||
return None
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
parsed = urlparse(next_link)
|
||||
cursors = parse_qs(parsed.query).get("cursor")
|
||||
return cursors[0] if cursors else None
|
||||
@@ -1,5 +1,7 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||
from application.parser.connectors.confluence.loader import ConfluenceLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||
from application.parser.connectors.share_point.loader import SharePointLoader
|
||||
|
||||
@@ -13,11 +15,13 @@ class ConnectorCreator:
|
||||
"""
|
||||
|
||||
connectors = {
|
||||
"confluence": ConfluenceLoader,
|
||||
"google_drive": GoogleDriveLoader,
|
||||
"share_point": SharePointLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"confluence": ConfluenceAuth,
|
||||
"google_drive": GoogleDriveAuth,
|
||||
"share_point": SharePointAuth,
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
try:
|
||||
url = self._get_item_url(file_id)
|
||||
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
file_metadata = response.json()
|
||||
@@ -236,9 +236,9 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
|
||||
else:
|
||||
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
|
||||
response = requests.get(search_url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(search_url, headers=self._get_headers(), params=params, timeout=100)
|
||||
else:
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -307,7 +307,8 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
response = requests.get(
|
||||
f"{self.GRAPH_API_BASE}/me/drive",
|
||||
headers=self._get_headers(),
|
||||
params={'$select': 'webUrl'}
|
||||
params={'$select': 'webUrl'},
|
||||
timeout=100,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get('webUrl')
|
||||
@@ -352,7 +353,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
|
||||
headers = self._get_headers()
|
||||
headers["Content-Type"] = "application/json"
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response = requests.post(url, headers=headers, json=body, timeout=100)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
@@ -472,7 +473,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
|
||||
try:
|
||||
url = f"{self._get_item_url(file_id)}/content"
|
||||
response = requests.get(url, headers=self._get_headers())
|
||||
response = requests.get(url, headers=self._get_headers(), timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
@@ -491,7 +492,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
try:
|
||||
url = self._get_item_url(file_id)
|
||||
params = {'$select': 'id,name,file'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
metadata = response.json()
|
||||
@@ -507,7 +508,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
full_path = os.path.join(local_dir, file_name)
|
||||
|
||||
download_url = f"{self._get_item_url(file_id)}/content"
|
||||
download_response = requests.get(download_url, headers=self._get_headers())
|
||||
download_response = requests.get(download_url, headers=self._get_headers(), timeout=100)
|
||||
download_response.raise_for_status()
|
||||
|
||||
with open(full_path, 'wb') as f:
|
||||
@@ -527,7 +528,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
params = {'$top': 1000}
|
||||
|
||||
while url:
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
@@ -609,7 +610,7 @@ class SharePointLoader(BaseConnectorLoader):
|
||||
try:
|
||||
url = self._get_item_url(folder_id)
|
||||
params = {'$select': 'id,name'}
|
||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
||||
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
|
||||
folder_metadata = response.json()
|
||||
|
||||
@@ -24,7 +24,7 @@ class PDFParser(BaseParser):
|
||||
# alternatively you can use local vision capable LLM
|
||||
with open(file, "rb") as file_loaded:
|
||||
files = {'file': file_loaded}
|
||||
response = requests.post(doc2md_service, files=files)
|
||||
response = requests.post(doc2md_service, files=files, timeout=100)
|
||||
data = response.json()["markdown"]
|
||||
return data
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class ImageParser(BaseParser):
|
||||
# alternatively you can use local vision capable LLM
|
||||
with open(file, "rb") as file_loaded:
|
||||
files = {'file': file_loaded}
|
||||
response = requests.post(doc2md_service, files=files)
|
||||
response = requests.post(doc2md_service, files=files, timeout=100)
|
||||
data = response.json()["markdown"]
|
||||
else:
|
||||
data = ""
|
||||
|
||||
@@ -77,7 +77,7 @@ class GitHubLoader(BaseRemote):
|
||||
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
for attempt in range(max_retries):
|
||||
response = requests.get(url, headers=self.headers)
|
||||
response = requests.get(url, headers=self.headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
|
||||
@@ -23,13 +23,25 @@ from application.core.settings import settings
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
|
||||
def _resolve_uri() -> str:
|
||||
"""Return the Postgres URI for user-data tables.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that
|
||||
reach this path without a configured URI have a setup bug — the
|
||||
error message points them at the right setting.
|
||||
"""
|
||||
if not settings.POSTGRES_URI:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
return settings.POSTGRES_URI
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
|
||||
|
||||
Returns:
|
||||
A SQLAlchemy ``Engine`` configured with a pooled connection to
|
||||
@@ -37,14 +49,8 @@ def get_engine() -> Engine:
|
||||
"""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
if not settings.POSTGRES_URI:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
_engine = create_engine(
|
||||
settings.POSTGRES_URI,
|
||||
_resolve_uri(),
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles
|
||||
|
||||
@@ -5,17 +5,31 @@ MongoDB→Postgres migration. The baseline schema in the Alembic migration
|
||||
(``application/alembic/versions/0001_initial.py``) is the source of truth
|
||||
for DDL; the ``Table`` definitions below must match it column-for-column.
|
||||
If the two drift, migrations win — update this file to match.
|
||||
|
||||
Cross-table invariant not expressed in the Core ``Table`` definitions
|
||||
below: every ``user_id`` column is FK-enforced against
|
||||
``users(user_id)`` with ``ON DELETE RESTRICT``, and a
|
||||
``BEFORE INSERT OR UPDATE OF user_id`` trigger on each child table
|
||||
auto-creates the ``users`` row if it does not yet exist. See migration
|
||||
``0015_user_id_fk``. The FKs are intentionally omitted from the Core
|
||||
declarations to keep this file readable; the DB is the authority.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
ForeignKeyConstraint,
|
||||
Integer,
|
||||
MetaData,
|
||||
UniqueConstraint,
|
||||
Table,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||
|
||||
metadata = MetaData()
|
||||
|
||||
@@ -36,3 +50,347 @@ users_table = Table(
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
prompts_table = Table(
|
||||
"prompts",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
user_tools_table = Table(
|
||||
"user_tools",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("custom_name", Text),
|
||||
Column("display_name", Text),
|
||||
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
token_usage_table = Table(
|
||||
"token_usage",
|
||||
metadata,
|
||||
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||
Column("user_id", Text),
|
||||
Column("api_key", Text),
|
||||
Column("agent_id", UUID(as_uuid=True)),
|
||||
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("generated_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
user_logs_table = Table(
|
||||
"user_logs",
|
||||
metadata,
|
||||
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||
Column("user_id", Text),
|
||||
Column("endpoint", Text),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("data", JSONB),
|
||||
)
|
||||
|
||||
stack_logs_table = Table(
|
||||
"stack_logs",
|
||||
metadata,
|
||||
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||
Column("activity_id", Text, nullable=False),
|
||||
Column("endpoint", Text),
|
||||
Column("level", Text),
|
||||
Column("user_id", Text),
|
||||
Column("api_key", Text),
|
||||
Column("query", Text),
|
||||
Column("stacks", JSONB, nullable=False, server_default="[]"),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 2, Tier 2 --------------------------------------------------------
|
||||
|
||||
agent_folders_table = Table(
|
||||
"agent_folders",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("description", Text),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
sources_table = Table(
|
||||
"sources",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("type", Text),
|
||||
Column("metadata", JSONB, nullable=False, server_default="{}"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
agents_table = Table(
|
||||
"agents",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("description", Text),
|
||||
Column("agent_type", Text),
|
||||
Column("status", Text, nullable=False),
|
||||
Column("key", CITEXT, unique=True),
|
||||
Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
|
||||
Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||
Column("chunks", Integer),
|
||||
Column("retriever", Text),
|
||||
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
|
||||
Column("tools", JSONB, nullable=False, server_default="[]"),
|
||||
Column("json_schema", JSONB),
|
||||
Column("models", JSONB),
|
||||
Column("default_model_id", Text),
|
||||
Column("folder_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
|
||||
Column("limited_token_mode", Boolean, nullable=False, server_default="false"),
|
||||
Column("token_limit", Integer),
|
||||
Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
|
||||
Column("request_limit", Integer),
|
||||
Column("shared", Boolean, nullable=False, server_default="false"),
|
||||
Column("incoming_webhook_token", CITEXT, unique=True),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("last_used_at", DateTime(timezone=True)),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
attachments_table = Table(
|
||||
"attachments",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("filename", Text, nullable=False),
|
||||
Column("upload_path", Text, nullable=False),
|
||||
Column("mime_type", Text),
|
||||
Column("size", BigInteger),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
memories_table = Table(
|
||||
"memories",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
Column("path", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("user_id", "tool_id", "path", name="memories_user_tool_path_uidx"),
|
||||
)
|
||||
|
||||
todos_table = Table(
|
||||
"todos",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
Column("title", Text, nullable=False),
|
||||
Column("completed", Boolean, nullable=False, server_default="false"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
notes_table = Table(
|
||||
"notes",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
Column("title", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("user_id", "tool_id", name="notes_user_tool_uidx"),
|
||||
)
|
||||
|
||||
connector_sessions_table = Table(
|
||||
"connector_sessions",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("provider", Text, nullable=False),
|
||||
Column("session_data", JSONB, nullable=False),
|
||||
Column("expires_at", DateTime(timezone=True)),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||
|
||||
conversations_table = Table(
|
||||
"conversations",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("agent_id", UUID(as_uuid=True), ForeignKey("agents.id", ondelete="SET NULL")),
|
||||
Column("name", Text),
|
||||
Column("api_key", Text),
|
||||
Column("is_shared_usage", Boolean, nullable=False, server_default="false"),
|
||||
Column("shared_token", Text),
|
||||
Column("shared_with", ARRAY(Text), nullable=False, server_default="{}"),
|
||||
Column("compression_metadata", JSONB),
|
||||
Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
conversation_messages_table = Table(
|
||||
"conversation_messages",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||
# Denormalised from conversations.user_id. Auto-filled on insert by a
|
||||
# BEFORE INSERT trigger when the caller omits it. See migration 0020.
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("position", Integer, nullable=False),
|
||||
Column("prompt", Text),
|
||||
Column("response", Text),
|
||||
Column("thought", Text),
|
||||
Column("sources", JSONB, nullable=False, server_default="[]"),
|
||||
Column("tool_calls", JSONB, nullable=False, server_default="[]"),
|
||||
# Postgres cannot FK-enforce array elements, so the referential
|
||||
# invariant is kept by an AFTER DELETE trigger on ``attachments``
|
||||
# that array_removes the id from every row that references it.
|
||||
# See migration 0017_cleanup_dangling_refs.
|
||||
Column("attachments", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||
Column("model_id", Text),
|
||||
# Renamed from ``metadata`` in migration 0016 to avoid SQLAlchemy's
|
||||
# reserved attribute collision on declarative models. The repository
|
||||
# translates this ↔ API dict key ``metadata`` so external callers
|
||||
# still see ``metadata``.
|
||||
Column("message_metadata", JSONB, nullable=False, server_default="{}"),
|
||||
Column("feedback", JSONB),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||
)
|
||||
|
||||
shared_conversations_table = Table(
|
||||
"shared_conversations",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("uuid", UUID(as_uuid=True), nullable=False, unique=True),
|
||||
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
|
||||
Column("chunks", Integer),
|
||||
Column("is_promptable", Boolean, nullable=False, server_default="false"),
|
||||
Column("first_n_queries", Integer, nullable=False, server_default="0"),
|
||||
Column("api_key", Text),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
pending_tool_state_table = Table(
|
||||
"pending_tool_state",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("messages", JSONB, nullable=False),
|
||||
Column("pending_tool_calls", JSONB, nullable=False),
|
||||
Column("tools_dict", JSONB, nullable=False),
|
||||
Column("tool_schemas", JSONB, nullable=False),
|
||||
Column("agent_config", JSONB, nullable=False),
|
||||
Column("client_tools", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("expires_at", DateTime(timezone=True), nullable=False),
|
||||
UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
|
||||
)
|
||||
|
||||
workflows_table = Table(
|
||||
"workflows",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("name", Text, nullable=False),
|
||||
Column("description", Text),
|
||||
Column("current_graph_version", Integer, nullable=False, server_default="1"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
workflow_nodes_table = Table(
|
||||
"workflow_nodes",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("graph_version", Integer, nullable=False),
|
||||
Column("node_id", Text, nullable=False),
|
||||
Column("node_type", Text, nullable=False),
|
||||
Column("title", Text),
|
||||
Column("description", Text),
|
||||
Column("position", JSONB, nullable=False, server_default='{"x": 0, "y": 0}'),
|
||||
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||
Column("legacy_mongo_id", Text),
|
||||
# Composite UNIQUE so workflow_edges can use a composite FK that
|
||||
# enforces endpoint nodes belong to the same (workflow, version) as
|
||||
# the edge itself. See migration 0008.
|
||||
UniqueConstraint(
|
||||
"id", "workflow_id", "graph_version",
|
||||
name="workflow_nodes_id_wf_ver_key",
|
||||
),
|
||||
)
|
||||
|
||||
workflow_edges_table = Table(
|
||||
"workflow_edges",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("graph_version", Integer, nullable=False),
|
||||
Column("edge_id", Text, nullable=False),
|
||||
Column("from_node_id", UUID(as_uuid=True), nullable=False),
|
||||
Column("to_node_id", UUID(as_uuid=True), nullable=False),
|
||||
Column("source_handle", Text),
|
||||
Column("target_handle", Text),
|
||||
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||
# Composite FKs: endpoints must belong to the same (workflow, version)
|
||||
# as the edge. Prevents cross-workflow / cross-version edges that the
|
||||
# single-column FKs couldn't catch. See migration 0008.
|
||||
ForeignKeyConstraint(
|
||||
["from_node_id", "workflow_id", "graph_version"],
|
||||
["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
|
||||
ondelete="CASCADE",
|
||||
name="workflow_edges_from_node_fk",
|
||||
),
|
||||
ForeignKeyConstraint(
|
||||
["to_node_id", "workflow_id", "graph_version"],
|
||||
["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
|
||||
ondelete="CASCADE",
|
||||
name="workflow_edges_to_node_fk",
|
||||
),
|
||||
)
|
||||
|
||||
workflow_runs_table = Table(
|
||||
"workflow_runs",
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("status", Text, nullable=False),
|
||||
Column("inputs", JSONB),
|
||||
Column("result", JSONB),
|
||||
Column("steps", JSONB, nullable=False, server_default="[]"),
|
||||
Column("started_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("ended_at", DateTime(timezone=True)),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
88
application/storage/db/repositories/agent_folders.py
Normal file
88
application/storage/db/repositories/agent_folders.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Repository for the ``agent_folders`` table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class AgentFoldersRepository:
    """Repository for the ``agent_folders`` table (per-user agent folders)."""

    # Columns callers may modify through update(). Keeping this as a class
    # constant lets update() build its SET clause from a fixed whitelist
    # instead of branching once per column combination.
    _UPDATABLE = ("name", "description")

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, user_id: str, name: str, *, description: Optional[str] = None) -> dict:
        """Insert a folder for ``user_id`` and return the new row as a dict."""
        result = self._conn.execute(
            text(
                """
                INSERT INTO agent_folders (user_id, name, description)
                VALUES (:user_id, :name, :description)
                RETURNING *
                """
            ),
            {"user_id": user_id, "name": name, "description": description},
        )
        return row_to_dict(result.fetchone())

    def get(self, folder_id: str, user_id: str) -> Optional[dict]:
        """Fetch a folder owned by ``user_id``; ``None`` when absent."""
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return all folders for a user, oldest first."""
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, folder_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns on a folder.

        Only ``name`` and ``description`` are accepted; anything else in
        ``fields`` is ignored. Returns ``True`` when a row was modified,
        ``False`` when ``fields`` held nothing updatable or no row matched.
        """
        filtered = {col: fields[col] for col in self._UPDATABLE if col in fields}
        if not filtered:
            return False
        # Column names come from the _UPDATABLE whitelist above, never from
        # the caller, so interpolating them into the SQL text is safe; the
        # values themselves stay bound parameters.
        set_clause = ", ".join(f"{col} = :{col}" for col in filtered)
        params: dict = {"id": folder_id, "user_id": user_id, **filtered}
        result = self._conn.execute(
            text(
                f"UPDATE agent_folders SET {set_clause}, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            params,
        )
        return result.rowcount > 0

    def delete(self, folder_id: str, user_id: str) -> bool:
        """Delete a folder owned by ``user_id``; ``True`` when a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        return result.rowcount > 0
|
||||
195
application/storage/db/repositories/agents.py
Normal file
195
application/storage/db/repositories/agents.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Repository for the ``agents`` table.
|
||||
|
||||
This is the most complex Phase 2 repository. Covers every write operation
|
||||
the legacy Mongo code performs on ``agents_collection``:
|
||||
|
||||
- create, update, delete
|
||||
- find by key (API key lookup)
|
||||
- find by webhook token
|
||||
- list for user, list templates
|
||||
- folder assignment
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, func, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import agents_table
|
||||
|
||||
|
||||
class AgentsRepository:
    """Repository for the ``agents`` table.

    This is the most complex Phase 2 repository. Covers every write operation
    the legacy Mongo code performs on ``agents_collection``:

    - create, update, delete
    - find by key (API key lookup)
    - find by webhook token
    - list for user, list templates
    - folder assignment
    """

    # Optional columns accepted by create(); anything else passed via
    # **kwargs is silently ignored.  Hoisted to the class so the set is
    # built once, not on every call.
    _CREATE_COLUMNS = frozenset({
        "description", "agent_type", "key", "retriever",
        "default_model_id", "incoming_webhook_token",
        "source_id", "prompt_id", "folder_id",
        "chunks", "token_limit", "request_limit",
        "limited_token_mode", "limited_request_mode", "shared",
        "tools", "json_schema", "models", "legacy_mongo_id",
    })

    # Columns update() is allowed to modify.
    _UPDATE_COLUMNS = frozenset({
        "name", "description", "agent_type", "status", "key", "source_id",
        "chunks", "retriever", "prompt_id", "tools", "json_schema", "models",
        "default_model_id", "folder_id", "limited_token_mode", "token_limit",
        "limited_request_mode", "request_limit", "shared",
        "incoming_webhook_token", "last_used_at",
    })

    # JSONB columns: pass the Python object directly. SQLAlchemy Core's
    # JSONB type processor json.dumps it once during bind; pre-serialising
    # would double-encode and the value would round-trip as a JSON string
    # instead of the dict.
    _JSONB_COLUMNS = frozenset({"tools", "json_schema", "models"})
    # Reference columns stored as the string form of the given value.
    _UUID_COLUMNS = frozenset({"source_id", "prompt_id", "folder_id"})
    _INT_COLUMNS = frozenset({"chunks", "token_limit", "request_limit"})
    _BOOL_COLUMNS = frozenset({"limited_token_mode", "limited_request_mode", "shared"})

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _normalize_unique_text(col: str, val):
        """Coerce blank strings for nullable unique text columns to NULL."""
        if col == "key" and val == "":
            return None
        return val

    def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
        """Insert an agent row and return it as a dict.

        Only keys listed in ``_CREATE_COLUMNS`` are honoured; ``None``
        values are skipped so the column defaults apply.
        """
        values: dict = {"user_id": user_id, "name": name, "status": status}

        for col, val in kwargs.items():
            if col not in self._CREATE_COLUMNS or val is None:
                continue
            if col in self._JSONB_COLUMNS:
                values[col] = val  # see _JSONB_COLUMNS note above
            elif col in self._INT_COLUMNS:
                values[col] = int(val)
            elif col in self._BOOL_COLUMNS:
                values[col] = bool(val)
            elif col in self._UUID_COLUMNS:
                values[col] = str(val)
            else:
                values[col] = self._normalize_unique_text(col, val)

        stmt = pg_insert(agents_table).values(**values).returning(agents_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get(self, agent_id: str, user_id: str) -> Optional[dict]:
        """Fetch an agent owned by ``user_id``; ``None`` when absent."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch an agent by the original Mongo ObjectId string."""
        sql = "SELECT * FROM agents WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_key(self, key: str) -> Optional[dict]:
        """Look up an agent by its API key (unique column)."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE key = :key"),
            {"key": key},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_webhook_token(self, token: str) -> Optional[dict]:
        """Look up an agent by its incoming webhook token (unique column)."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE incoming_webhook_token = :token"),
            {"token": token},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """All agents owned by a user, newest first."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def list_templates(self) -> list[dict]:
        """Template agents, i.e. rows owned by the reserved 'system' user."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = 'system' ORDER BY name"),
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, agent_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns on an agent.

        Returns ``True`` when a row changed; ``False`` when ``fields``
        held nothing updatable or no row matched.
        """
        filtered = {k: v for k, v in fields.items() if k in self._UPDATE_COLUMNS}
        if not filtered:
            return False

        values: dict = {}
        for col, val in filtered.items():
            if col in self._JSONB_COLUMNS:
                values[col] = val  # see _JSONB_COLUMNS note above
            elif col in self._UUID_COLUMNS:
                # Unlike create(), a falsy value clears the reference.
                values[col] = str(val) if val else None
            else:
                values[col] = self._normalize_unique_text(col, val)
        values["updated_at"] = func.now()

        t = agents_table
        stmt = (
            t.update()
            .where(t.c.id == agent_id)
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        result = self._conn.execute(stmt)
        return result.rowcount > 0

    def update_by_legacy_id(self, legacy_mongo_id: str, user_id: str, fields: dict) -> bool:
        """Update an agent addressed by the Mongo ObjectId string."""
        agent = self.get_by_legacy_id(legacy_mongo_id, user_id)
        if agent is None:
            return False
        return self.update(agent["id"], user_id, fields)

    def delete(self, agent_id: str, user_id: str) -> bool:
        """Delete an agent owned by ``user_id``; ``True`` when a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete an agent addressed by the Mongo ObjectId string."""
        result = self._conn.execute(
            text(
                "DELETE FROM agents "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
        """Assign the agent to a folder (``None`` clears the assignment)."""
        self._conn.execute(
            text(
                """
                UPDATE agents SET folder_id = CAST(:folder_id AS uuid), updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": agent_id, "user_id": user_id, "folder_id": folder_id},
        )

    def clear_folder_for_all(self, folder_id: str, user_id: str) -> None:
        """Remove folder assignment from all agents in a folder (used on folder delete)."""
        self._conn.execute(
            text(
                "UPDATE agents SET folder_id = NULL, updated_at = now() "
                "WHERE folder_id = CAST(:folder_id AS uuid) AND user_id = :user_id"
            ),
            {"folder_id": folder_id, "user_id": user_id},
        )
|
||||
66
application/storage/db/repositories/attachments.py
Normal file
66
application/storage/db/repositories/attachments.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Repository for the ``attachments`` table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class AttachmentsRepository:
    """Data-access layer for the ``attachments`` table."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _maybe_dict(row) -> Optional[dict]:
        # Shared "row or None" conversion used by the single-row getters.
        return row_to_dict(row) if row is not None else None

    def create(self, user_id: str, filename: str, upload_path: str, *,
               mime_type: Optional[str] = None, size: Optional[int] = None,
               legacy_mongo_id: Optional[str] = None) -> dict:
        """Insert an attachment row and return it as a dict."""
        params = {
            "user_id": user_id,
            "filename": filename,
            "upload_path": upload_path,
            "mime_type": mime_type,
            "size": size,
            "legacy_mongo_id": legacy_mongo_id,
        }
        row = self._conn.execute(
            text(
                """
                INSERT INTO attachments
                (user_id, filename, upload_path, mime_type, size, legacy_mongo_id)
                VALUES
                (:user_id, :filename, :upload_path, :mime_type, :size, :legacy_mongo_id)
                RETURNING *
                """
            ),
            params,
        ).fetchone()
        return row_to_dict(row)

    def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
        """Fetch an attachment owned by ``user_id``; ``None`` when absent."""
        row = self._conn.execute(
            text(
                "SELECT * FROM attachments WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": attachment_id, "user_id": user_id},
        ).fetchone()
        return self._maybe_dict(row)

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch an attachment by the original Mongo ObjectId string."""
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        sql = "SELECT * FROM attachments WHERE legacy_mongo_id = :legacy_id"
        if user_id is not None:
            # Scope the lookup to the owner when a user id is supplied.
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        row = self._conn.execute(text(sql), params).fetchone()
        return self._maybe_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """All attachments for a user, newest first."""
        rows = self._conn.execute(
            text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]
|
||||
65
application/storage/db/repositories/connector_sessions.py
Normal file
65
application/storage/db/repositories/connector_sessions.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Repository for the ``connector_sessions`` table.
|
||||
|
||||
Covers operations across connector routes and tools:
|
||||
- upsert session data
|
||||
- find session by user + provider
|
||||
- find session by token
|
||||
- delete session
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class ConnectorSessionsRepository:
    """Data access for the ``connector_sessions`` table.

    One row per (user, provider); the session payload itself lives in a
    JSONB column that is replaced wholesale on upsert.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(self, user_id: str, provider: str, session_data: dict) -> dict:
        """Insert or replace the session blob for (user, provider)."""
        params = {
            "user_id": user_id,
            "provider": provider,
            # Serialised here because the statement casts the bound text
            # parameter to jsonb itself.
            "session_data": json.dumps(session_data),
        }
        sql = text(
            """
            INSERT INTO connector_sessions (user_id, provider, session_data)
            VALUES (:user_id, :provider, CAST(:session_data AS jsonb))
            ON CONFLICT (user_id, provider)
            DO UPDATE SET session_data = EXCLUDED.session_data
            RETURNING *
            """
        )
        return row_to_dict(self._conn.execute(sql, params).fetchone())

    def get_by_user_provider(self, user_id: str, provider: str) -> Optional[dict]:
        """Return the session for (user, provider), or ``None``."""
        row = self._conn.execute(
            text(
                "SELECT * FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"
            ),
            {"user_id": user_id, "provider": provider},
        ).fetchone()
        if row is None:
            return None
        return row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """All connector sessions belonging to a user."""
        rows = self._conn.execute(
            text("SELECT * FROM connector_sessions WHERE user_id = :user_id"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

    def delete(self, user_id: str, provider: str) -> bool:
        """Remove the session for (user, provider); ``True`` when a row was deleted."""
        outcome = self._conn.execute(
            text("DELETE FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"),
            {"user_id": user_id, "provider": provider},
        )
        return outcome.rowcount > 0
|
||||
476
application/storage/db/repositories/conversations.py
Normal file
476
application/storage/db/repositories/conversations.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""Repository for the ``conversations`` and ``conversation_messages`` tables.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``conversations_collection``:
|
||||
|
||||
- create / get / list / delete conversations
|
||||
- append message (transactional position allocation)
|
||||
- update message at index (overwrite + optional truncation)
|
||||
- set / unset feedback on a message
|
||||
- rename conversation
|
||||
- update compression metadata
|
||||
- shared_with access checks
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import conversations_table, conversation_messages_table
|
||||
|
||||
|
||||
def _message_row_to_dict(row) -> dict:
    """Like ``row_to_dict`` but renames the DB column ``message_metadata``
    back to the public API key ``metadata`` so callers keep the Mongo-era
    shape. See migration 0016 for the column rename rationale."""
    converted = row_to_dict(row)
    try:
        converted["metadata"] = converted.pop("message_metadata")
    except KeyError:
        # Row had no message_metadata column; leave the dict untouched.
        pass
    return converted
|
||||
|
||||
|
||||
class ConversationsRepository:
    """Repository for ``conversations`` and ``conversation_messages``.

    Covers every operation the legacy Mongo code performs on
    ``conversations_collection``: conversation CRUD, transactional message
    append, message update/truncate, feedback, rename, compression
    metadata, and shared_with access checks.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    # ------------------------------------------------------------------
    # Conversation CRUD
    # ------------------------------------------------------------------

    def create(
        self,
        user_id: str,
        name: str | None = None,
        *,
        agent_id: str | None = None,
        api_key: str | None = None,
        is_shared_usage: bool = False,
        shared_token: str | None = None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Create a new conversation.

        ``legacy_mongo_id`` is used by the dual-write shim so that a
        Postgres row inserted *after* a successful Mongo insert carries
        the Mongo ``_id`` as a lookup key. Subsequent appends/updates
        can then resolve the PG row by that id via
        :meth:`get_by_legacy_id`.
        """
        values: dict = {
            "user_id": user_id,
            "name": name,
        }
        # Optional columns are only set when truthy so DB defaults apply.
        if agent_id:
            values["agent_id"] = agent_id
        if api_key:
            values["api_key"] = api_key
        if is_shared_usage:
            values["is_shared_usage"] = True
        if shared_token:
            values["shared_token"] = shared_token
        if legacy_mongo_id:
            values["legacy_mongo_id"] = legacy_mongo_id

        stmt = pg_insert(conversations_table).values(**values).returning(conversations_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: str | None = None,
    ) -> Optional[dict]:
        """Look up a conversation by the original Mongo ObjectId string.

        Used by the dual-write helpers to translate a Mongo ``_id`` into
        the Postgres UUID for follow-up writes. When ``user_id`` is
        provided, the lookup is scoped to rows owned by that user so
        callers can't accidentally resolve another user's conversation.
        """
        if user_id is not None:
            result = self._conn.execute(
                text(
                    "SELECT * FROM conversations "
                    "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
                ),
                {"legacy_id": legacy_mongo_id, "user_id": user_id},
            )
        else:
            result = self._conn.execute(
                text(
                    "SELECT * FROM conversations WHERE legacy_mongo_id = :legacy_id"
                ),
                {"legacy_id": legacy_mongo_id},
            )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Fetch a conversation the user owns or has shared access to."""
        result = self._conn.execute(
            text(
                "SELECT * FROM conversations "
                "WHERE id = CAST(:id AS uuid) "
                "AND (user_id = :user_id OR :user_id = ANY(shared_with))"
            ),
            {"id": conversation_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_owned(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Fetch a conversation owned by the user (no shared access)."""
        result = self._conn.execute(
            text(
                "SELECT * FROM conversations "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str, limit: int = 30) -> list[dict]:
        """List conversations for a user, most recent first.

        Mirrors the Mongo query: either no api_key or agent_id exists.
        """
        result = self._conn.execute(
            text(
                "SELECT * FROM conversations "
                "WHERE user_id = :user_id "
                "AND (api_key IS NULL OR agent_id IS NOT NULL) "
                "ORDER BY date DESC LIMIT :limit"
            ),
            {"user_id": user_id, "limit": limit},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
        """Rename a conversation owned by ``user_id``; ``True`` when changed."""
        result = self._conn.execute(
            text(
                "UPDATE conversations SET name = :name, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id, "name": name},
        )
        return result.rowcount > 0

    def set_shared_token(self, conversation_id: str, user_id: str, token: str) -> bool:
        """Attach a share token to a conversation; ``True`` when a row changed."""
        result = self._conn.execute(
            text(
                "UPDATE conversations SET shared_token = :token, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id, "token": token},
        )
        return result.rowcount > 0

    def update_compression_metadata(
        self, conversation_id: str, user_id: str, metadata: dict,
    ) -> bool:
        """Replace the entire ``compression_metadata`` JSONB blob.

        Prefer :meth:`append_compression_point` + :meth:`set_compression_flags`
        to match the Mongo service semantics exactly (those two mirror
        ``$set`` + ``$push $slice``). This method is retained for callers
        that already compute the full merged blob client-side.
        """
        result = self._conn.execute(
            text(
                "UPDATE conversations "
                "SET compression_metadata = CAST(:meta AS jsonb), updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id, "meta": json.dumps(metadata)},
        )
        return result.rowcount > 0

    def set_compression_flags(
        self,
        conversation_id: str,
        *,
        is_compressed: bool,
        last_compression_at,
    ) -> bool:
        """Update ``compression_metadata.is_compressed`` and
        ``compression_metadata.last_compression_at`` without touching
        ``compression_points``.

        Mirrors the Mongo ``$set`` on those two subfields in
        ``ConversationService.update_compression_metadata``. Initialises
        the surrounding object when the row has no ``compression_metadata``
        yet.
        """
        # NOTE: jsonb_set() is STRICT — if its new-value argument is SQL
        # NULL the whole expression (and thus the column) becomes NULL.
        # When ``last_compression_at`` is None the bound parameter is SQL
        # NULL and to_jsonb(NULL) is NULL, which previously wiped the
        # entire compression_metadata blob. COALESCE to a JSON null so a
        # missing timestamp is stored as ``null`` instead.
        result = self._conn.execute(
            text(
                """
                UPDATE conversations SET
                    compression_metadata = jsonb_set(
                        jsonb_set(
                            COALESCE(compression_metadata, '{}'::jsonb),
                            '{is_compressed}',
                            to_jsonb(CAST(:is_compressed AS boolean)), true
                        ),
                        '{last_compression_at}',
                        COALESCE(
                            to_jsonb(CAST(:last_compression_at AS text)),
                            'null'::jsonb
                        ), true
                    ),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid)
                """
            ),
            {
                "id": conversation_id,
                "is_compressed": bool(is_compressed),
                "last_compression_at": (
                    str(last_compression_at) if last_compression_at is not None else None
                ),
            },
        )
        return result.rowcount > 0

    def append_compression_point(
        self,
        conversation_id: str,
        point: dict,
        *,
        max_points: int,
    ) -> bool:
        """Append one compression point, keeping at most ``max_points``.

        Mirrors Mongo's ``$push {"$each": [point], "$slice": -max_points}``
        on ``compression_metadata.compression_points``. Preserves the
        other top-level keys in ``compression_metadata``.
        """
        result = self._conn.execute(
            text(
                """
                UPDATE conversations SET
                    compression_metadata = jsonb_set(
                        COALESCE(compression_metadata, '{}'::jsonb),
                        '{compression_points}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem ORDER BY rn)
                                FROM (
                                    SELECT
                                        elem,
                                        row_number() OVER () AS rn,
                                        count(*) OVER () AS cnt
                                    FROM jsonb_array_elements(
                                        COALESCE(
                                            compression_metadata -> 'compression_points',
                                            '[]'::jsonb
                                        ) || jsonb_build_array(CAST(:point AS jsonb))
                                    ) AS elem
                                ) ranked
                                WHERE rn > cnt - :max_points
                            ),
                            '[]'::jsonb
                        ),
                        true
                    ),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid)
                """
            ),
            {
                "id": conversation_id,
                "point": json.dumps(point, default=str),
                "max_points": int(max_points),
            },
        )
        return result.rowcount > 0

    def delete(self, conversation_id: str, user_id: str) -> bool:
        """Delete a conversation owned by ``user_id``; ``True`` when removed."""
        result = self._conn.execute(
            text(
                "DELETE FROM conversations "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete_all_for_user(self, user_id: str) -> int:
        """Delete every conversation owned by a user; returns the row count."""
        result = self._conn.execute(
            text("DELETE FROM conversations WHERE user_id = :user_id"),
            {"user_id": user_id},
        )
        return result.rowcount

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    def get_messages(self, conversation_id: str) -> list[dict]:
        """All messages for a conversation, ordered by position."""
        result = self._conn.execute(
            text(
                "SELECT * FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "ORDER BY position ASC"
            ),
            {"conv_id": conversation_id},
        )
        return [_message_row_to_dict(r) for r in result.fetchall()]

    def get_message_at(self, conversation_id: str, position: int) -> Optional[dict]:
        """The message at a given position, or ``None``."""
        result = self._conn.execute(
            text(
                "SELECT * FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND position = :pos"
            ),
            {"conv_id": conversation_id, "pos": position},
        )
        row = result.fetchone()
        return _message_row_to_dict(row) if row is not None else None

    def append_message(self, conversation_id: str, message: dict) -> dict:
        """Append a message to a conversation.

        Uses ``SELECT ... FOR UPDATE`` to allocate the next position
        atomically. The caller must be inside a transaction.

        Mirrors Mongo's ``$push`` on the ``queries`` array.
        """
        # Lock the parent conversation row to serialize concurrent appends.
        self._conn.execute(
            text(
                "SELECT id FROM conversations "
                "WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
            ),
            {"conv_id": conversation_id},
        )
        next_pos_result = self._conn.execute(
            text(
                "SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
                "FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid)"
            ),
            {"conv_id": conversation_id},
        )
        next_pos = next_pos_result.scalar()

        values = {
            "conversation_id": conversation_id,
            "position": next_pos,
            "prompt": message.get("prompt"),
            "response": message.get("response"),
            "thought": message.get("thought"),
            "sources": message.get("sources") or [],
            "tool_calls": message.get("tool_calls") or [],
            "model_id": message.get("model_id"),
            # Public API key "metadata" maps to DB column "message_metadata".
            "message_metadata": message.get("metadata") or {},
        }
        if message.get("timestamp") is not None:
            values["timestamp"] = message["timestamp"]

        attachments = message.get("attachments")
        if attachments:
            values["attachments"] = [str(a) for a in attachments]

        stmt = (
            pg_insert(conversation_messages_table)
            .values(**values)
            .returning(conversation_messages_table)
        )
        result = self._conn.execute(stmt)
        # Touch the parent conversation's updated_at.
        self._conn.execute(
            text(
                "UPDATE conversations SET updated_at = now() "
                "WHERE id = CAST(:id AS uuid)"
            ),
            {"id": conversation_id},
        )
        return _message_row_to_dict(result.fetchone())

    def update_message_at(
        self, conversation_id: str, position: int, fields: dict,
    ) -> bool:
        """Update specific fields on a message at a given position.

        Mirrors Mongo's ``$set`` on ``queries.{index}.*``.
        """
        allowed = {
            "prompt", "response", "thought", "sources", "tool_calls",
            "attachments", "model_id", "metadata", "timestamp",
        }
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return False

        # Map public API key ``metadata`` → DB column ``message_metadata``.
        api_to_col = {"metadata": "message_metadata"}

        set_parts = []
        params: dict = {"conv_id": conversation_id, "pos": position}
        for key, val in filtered.items():
            col = api_to_col.get(key, key)
            if key in ("sources", "tool_calls", "metadata"):
                set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
                # String values are assumed to be pre-serialised JSON.
                params[col] = json.dumps(val) if not isinstance(val, str) else val
            elif key == "attachments":
                set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
                params[col] = [str(a) for a in val] if val else []
            else:
                set_parts.append(f"{col} = :{col}")
                params[col] = val

        # An explicit timestamp wins; otherwise stamp the update time.
        if "timestamp" not in filtered:
            set_parts.append("timestamp = now()")
        sql = (
            f"UPDATE conversation_messages SET {', '.join(set_parts)} "
            "WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
        )
        result = self._conn.execute(text(sql), params)
        return result.rowcount > 0

    def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
        """Delete messages with position > keep_up_to.

        Mirrors Mongo's ``$push`` + ``$slice`` that trims queries after an
        index-based update.
        """
        result = self._conn.execute(
            text(
                "DELETE FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND position > :pos"
            ),
            {"conv_id": conversation_id, "pos": keep_up_to},
        )
        return result.rowcount

    def set_feedback(
        self, conversation_id: str, position: int, feedback: dict | None,
    ) -> bool:
        """Set or unset feedback on a message.

        ``feedback`` is a JSONB value, e.g. ``{"text": "thumbs_up",
        "timestamp": "..."}`` or ``None`` to unset.
        """
        # None binds as SQL NULL, which clears the column (intended).
        fb_json = json.dumps(feedback) if feedback is not None else None
        result = self._conn.execute(
            text(
                "UPDATE conversation_messages "
                "SET feedback = CAST(:fb AS jsonb) "
                "WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
            ),
            {"conv_id": conversation_id, "pos": position, "fb": fb_json},
        )
        return result.rowcount > 0

    def message_count(self, conversation_id: str) -> int:
        """Number of messages stored for a conversation."""
        result = self._conn.execute(
            text(
                "SELECT COUNT(*) FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid)"
            ),
            {"conv_id": conversation_id},
        )
        return result.scalar() or 0
|
||||
97
application/storage/db/repositories/memories.py
Normal file
97
application/storage/db/repositories/memories.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Repository for the ``memories`` table.
|
||||
|
||||
Covers the operations in ``application/agents/tools/memory.py``:
|
||||
- upsert (create/overwrite file)
|
||||
- find by path (view file)
|
||||
- find by path prefix (view directory, regex scan)
|
||||
- delete by path / path prefix
|
||||
- rename (update path)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class MemoriesRepository:
    """Postgres-backed store for the memory tool's virtual filesystem.

    Each row is one "file" keyed by (user_id, tool_id, path). Covers the
    operations in ``application/agents/tools/memory.py``: upsert
    (create/overwrite), view by path, view by directory prefix, delete by
    path or prefix, and rename.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    @staticmethod
    def _like_prefix(prefix: str) -> str:
        """Return a LIKE pattern matching paths that start with *prefix* literally.

        ``%`` and ``_`` are LIKE metacharacters, so a naive
        ``prefix + "%"`` would treat a prefix such as ``notes_v1/`` as
        "notes, any one character, v1/" — and a ``%`` in the prefix would
        match (or delete) every row. Backslash-escape the prefix here and
        pair it with ``ESCAPE '\\'`` in the query.
        """
        escaped = (
            prefix.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        return escaped + "%"

    def upsert(self, user_id: str, tool_id: str, path: str, content: str) -> dict:
        """Create the file at *path*, or overwrite its content in place."""
        result = self._conn.execute(
            text(
                """
                INSERT INTO memories (user_id, tool_id, path, content)
                VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content)
                ON CONFLICT (user_id, tool_id, path)
                DO UPDATE SET content = EXCLUDED.content, updated_at = now()
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path, "content": content},
        )
        return row_to_dict(result.fetchone())

    def get_by_path(self, user_id: str, tool_id: str, path: str) -> Optional[dict]:
        """Fetch a single file row, or ``None`` when absent."""
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> list[dict]:
        """List all files whose path starts with *prefix* (matched literally)."""
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix ESCAPE '\\'"
            ),
            {"user_id": user_id, "tool_id": tool_id, "prefix": self._like_prefix(prefix)},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def delete_by_path(self, user_id: str, tool_id: str, path: str) -> int:
        """Delete one file; returns the number of rows removed (0 or 1)."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        return result.rowcount

    def delete_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> int:
        """Delete a "directory": every path starting with *prefix* (literal)."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix ESCAPE '\\'"
            ),
            {"user_id": user_id, "tool_id": tool_id, "prefix": self._like_prefix(prefix)},
        )
        return result.rowcount

    def delete_all(self, user_id: str, tool_id: str) -> int:
        """Wipe every memory file for this (user, tool) pair."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return result.rowcount

    def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
        """Rename a file; returns True when a row at *old_path* existed."""
        result = self._conn.execute(
            text(
                "UPDATE memories SET path = :new_path, updated_at = now() "
                "WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) AND path = :old_path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "old_path": old_path, "new_path": new_path},
        )
        return result.rowcount > 0
|
||||
62
application/storage/db/repositories/notes.py
Normal file
62
application/storage/db/repositories/notes.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Repository for the ``notes`` table.
|
||||
|
||||
Covers the operations in ``application/agents/tools/notes.py``.
|
||||
Note: the Mongo schema stores a single ``note`` text field per (user_id, tool_id),
|
||||
while the Postgres schema has ``title`` + ``content``. During dual-write,
|
||||
title is set to a default and content holds the note text.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class NotesRepository:
    """Postgres replacement for the notes tool's Mongo collection.

    Mongo keeps a single ``note`` text field per (user_id, tool_id); the
    Postgres schema splits it into ``title`` + ``content``. During
    dual-write the title carries a default and content holds the note.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def upsert(self, user_id: str, tool_id: str, title: str, content: str) -> dict:
        """Insert the (user, tool) note, or overwrite its title/content."""
        row = self._conn.execute(
            text(
                """
                INSERT INTO notes (user_id, tool_id, title, content)
                VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
                ON CONFLICT (user_id, tool_id)
                DO UPDATE SET content = EXCLUDED.content, title = EXCLUDED.title, updated_at = now()
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "title": title, "content": content},
        ).fetchone()
        return row_to_dict(row)

    def get_for_user_tool(self, user_id: str, tool_id: str) -> Optional[dict]:
        """Fetch the single note owned by (user_id, tool_id), if any."""
        row = self._conn.execute(
            text(
                "SELECT * FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def get(self, note_id: str, user_id: str) -> Optional[dict]:
        """Fetch a note by primary key, scoped to its owner."""
        row = self._conn.execute(
            text("SELECT * FROM notes WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": note_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def delete(self, user_id: str, tool_id: str) -> bool:
        """Delete the (user, tool) note; True when a row was removed."""
        outcome = self._conn.execute(
            text(
                "DELETE FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return outcome.rowcount > 0
|
||||
128
application/storage/db/repositories/pending_tool_state.py
Normal file
128
application/storage/db/repositories/pending_tool_state.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Repository for the ``pending_tool_state`` table.
|
||||
|
||||
Mirrors the continuation service's three operations on
|
||||
``pending_tool_state`` in Mongo:
|
||||
|
||||
- save_state → upsert (INSERT ... ON CONFLICT DO UPDATE)
|
||||
- load_state → find_one by (conversation_id, user_id)
|
||||
- delete_state → delete_one by (conversation_id, user_id)
|
||||
|
||||
Plus a cleanup method for the Celery beat task that replaces Mongo's
|
||||
TTL index.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 1800 seconds
|
||||
|
||||
|
||||
class PendingToolStateRepository:
    """Postgres replacement for Mongo's ``pending_tool_state`` collection.

    Mirrors the continuation service's three operations —
    ``save_state`` (upsert), ``load_state`` (find by conversation/user),
    ``delete_state`` (delete by the same key) — plus
    :meth:`cleanup_expired`, a Celery-beat stand-in for Mongo's TTL index.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def save_state(
        self,
        conversation_id: str,
        user_id: str,
        *,
        messages: list,
        pending_tool_calls: list,
        tools_dict: dict,
        tool_schemas: list,
        agent_config: dict,
        client_tools: list | None = None,
        ttl_seconds: int = PENDING_STATE_TTL_SECONDS,
    ) -> dict:
        """Upsert pending tool state (Mongo ``replace_one(..., upsert=True)``)."""
        created_at = datetime.now(timezone.utc)
        expires_at = datetime.fromtimestamp(
            created_at.timestamp() + ttl_seconds, tz=timezone.utc,
        )
        params = {
            "conv_id": conversation_id,
            "user_id": user_id,
            "messages": json.dumps(messages),
            "pending": json.dumps(pending_tool_calls),
            "tools_dict": json.dumps(tools_dict),
            "schemas": json.dumps(tool_schemas),
            "agent_config": json.dumps(agent_config),
            "client_tools": None if client_tools is None else json.dumps(client_tools),
            "created_at": created_at,
            "expires_at": expires_at,
        }
        result = self._conn.execute(
            text(
                """
                INSERT INTO pending_tool_state
                    (conversation_id, user_id, messages, pending_tool_calls,
                     tools_dict, tool_schemas, agent_config, client_tools,
                     created_at, expires_at)
                VALUES
                    (CAST(:conv_id AS uuid), :user_id,
                     CAST(:messages AS jsonb), CAST(:pending AS jsonb),
                     CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
                     CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
                     :created_at, :expires_at)
                ON CONFLICT (conversation_id, user_id) DO UPDATE SET
                    messages = EXCLUDED.messages,
                    pending_tool_calls = EXCLUDED.pending_tool_calls,
                    tools_dict = EXCLUDED.tools_dict,
                    tool_schemas = EXCLUDED.tool_schemas,
                    agent_config = EXCLUDED.agent_config,
                    client_tools = EXCLUDED.client_tools,
                    created_at = EXCLUDED.created_at,
                    expires_at = EXCLUDED.expires_at
                RETURNING *
                """
            ),
            params,
        )
        return row_to_dict(result.fetchone())

    def load_state(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Fetch saved state for (conversation, user), or ``None``."""
        row = self._conn.execute(
            text(
                "SELECT * FROM pending_tool_state "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id"
            ),
            {"conv_id": conversation_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def delete_state(self, conversation_id: str, user_id: str) -> bool:
        """Drop saved state; True when a row was deleted."""
        outcome = self._conn.execute(
            text(
                "DELETE FROM pending_tool_state "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id"
            ),
            {"conv_id": conversation_id, "user_id": user_id},
        )
        return outcome.rowcount > 0

    def cleanup_expired(self) -> int:
        """Delete rows whose ``expires_at`` has passed; returns the count.

        Replaces Mongo's ``expireAfterSeconds=0`` TTL index — intended to
        run from a Celery beat task about once a minute. Uses
        ``clock_timestamp()`` rather than ``now()``: the latter is frozen
        to the transaction start, which would let state that expired
        mid-transaction survive one extra cleanup tick.
        """
        outcome = self._conn.execute(
            text("DELETE FROM pending_tool_state WHERE expires_at < clock_timestamp()")
        )
        return outcome.rowcount
|
||||
161
application/storage/db/repositories/prompts.py
Normal file
161
application/storage/db/repositories/prompts.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Repository for the ``prompts`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``prompts_collection``:
|
||||
|
||||
1. ``insert_one`` in prompts/routes.py (create)
|
||||
2. ``find`` by user in prompts/routes.py (list)
|
||||
3. ``find_one`` by id+user in prompts/routes.py (get single)
|
||||
4. ``find_one`` by id only in stream_processor.py (get content for rendering)
|
||||
5. ``update_one`` in prompts/routes.py (update name+content)
|
||||
6. ``delete_one`` in prompts/routes.py (delete)
|
||||
7. ``find_one`` + ``insert_one`` in seeder.py (upsert by user+name+content)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class PromptsRepository:
    """Postgres-backed replacement for Mongo ``prompts_collection``.

    Covers the legacy operations: create, list by user, get by id or by
    legacy Mongo ObjectId, unscoped get for prompt rendering, update,
    delete, and the seeder's find-or-create.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        content: str,
        *,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert a prompt and return the new row.

        ``legacy_mongo_id`` carries the original Mongo ObjectId during
        the dual-write window so both stores address the same prompt.
        """
        row = self._conn.execute(
            text(
                """
                INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
                VALUES (:user_id, :name, :content, :legacy_mongo_id)
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "name": name,
                "content": content,
                "legacy_mongo_id": legacy_mongo_id,
            },
        ).fetchone()
        return row_to_dict(row)

    def get(self, prompt_id: str, user_id: str) -> Optional[dict]:
        """Fetch a prompt by UUID, scoped to its owner."""
        row = self._conn.execute(
            text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": prompt_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch a prompt by the original Mongo ObjectId string.

        Scoped to *user_id* when one is supplied.
        """
        sql = "SELECT * FROM prompts WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        row = self._conn.execute(text(sql), params).fetchone()
        return None if row is None else row_to_dict(row)

    def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
        """Fetch a prompt by ID without user scoping.

        Only for stream_processor, which renders a prompt whose owner is
        unknown at call time. Do NOT use in user-facing routes.
        """
        row = self._conn.execute(
            text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
            {"id": prompt_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """All of a user's prompts, oldest first."""
        rows = self._conn.execute(
            text("SELECT * FROM prompts WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(r) for r in rows]

    def update(self, prompt_id: str, user_id: str, name: str, content: str) -> None:
        """Overwrite a prompt's name and content (owner-scoped)."""
        self._conn.execute(
            text(
                """
                UPDATE prompts
                SET name = :name, content = :content, updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": prompt_id, "user_id": user_id, "name": name, "content": content},
        )

    def update_by_legacy_id(
        self,
        legacy_mongo_id: str,
        user_id: str,
        name: str,
        content: str,
    ) -> bool:
        """Update a prompt addressed by its Mongo ObjectId string."""
        outcome = self._conn.execute(
            text(
                """
                UPDATE prompts
                SET name = :name, content = :content, updated_at = now()
                WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id
                """
            ),
            {
                "legacy_id": legacy_mongo_id,
                "user_id": user_id,
                "name": name,
                "content": content,
            },
        )
        return outcome.rowcount > 0

    def delete(self, prompt_id: str, user_id: str) -> None:
        """Delete a prompt by UUID (owner-scoped)."""
        self._conn.execute(
            text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": prompt_id, "user_id": user_id},
        )

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete a prompt addressed by its Mongo ObjectId string."""
        outcome = self._conn.execute(
            text(
                "DELETE FROM prompts "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return outcome.rowcount > 0

    def find_or_create(self, user_id: str, name: str, content: str) -> dict:
        """Return an existing (user, name, content) prompt or create one.

        Used by the seeder to avoid duplicating template prompts.
        NOTE(review): select-then-insert is not race-free; concurrent
        seeder runs could insert duplicates — acceptable for seeding.
        """
        existing = self._conn.execute(
            text(
                "SELECT * FROM prompts WHERE user_id = :user_id AND name = :name AND content = :content"
            ),
            {"user_id": user_id, "name": name, "content": content},
        ).fetchone()
        if existing is None:
            return self.create(user_id, name, content)
        return row_to_dict(existing)
|
||||
205
application/storage/db/repositories/shared_conversations.py
Normal file
205
application/storage/db/repositories/shared_conversations.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Repository for the ``shared_conversations`` table.
|
||||
|
||||
Covers the sharing operations from ``shared_conversations_collections``
|
||||
in Mongo:
|
||||
|
||||
- create a share record (with UUID, conversation_id, user, visibility flags)
|
||||
- look up by uuid (public access)
|
||||
- look up by conversation_id + user + flags (dedup check)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as uuid_mod
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import shared_conversations_table
|
||||
|
||||
|
||||
class SharedConversationsRepository:
    """Postgres store for public conversation shares.

    Mirrors Mongo's ``shared_conversations_collections``: create a share
    record, look it up by public UUID, dedup-check by
    (conversation, user, visibility flags), and list shares per
    conversation.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def create(
        self,
        conversation_id: str,
        user_id: str,
        *,
        is_promptable: bool = False,
        first_n_queries: int = 0,
        api_key: str | None = None,
        prompt_id: str | None = None,
        chunks: int | None = None,
        share_uuid: str | None = None,
    ) -> dict:
        """Insert a share record and return it.

        ``share_uuid`` lets the dual-write caller reuse the UUID Mongo
        received, so public ``/shared/{uuid}`` links resolve from both
        stores during the dual-write window.

        Callers needing race-free dedup on the logical share key should
        prefer :meth:`get_or_create`, which leans on the composite
        partial unique index from migration 0008 to collapse concurrent
        requests into one row.
        """
        values: dict = {
            "uuid": share_uuid or str(uuid_mod.uuid4()),
            "conversation_id": conversation_id,
            "user_id": user_id,
            "is_promptable": is_promptable,
            "first_n_queries": first_n_queries,
        }
        # Optional fields: only set when truthy/present, matching the
        # Mongo document shape where absent keys are simply omitted.
        for column, value in (("api_key", api_key), ("prompt_id", prompt_id)):
            if value:
                values[column] = value
        if chunks is not None:
            values["chunks"] = chunks

        stmt = (
            pg_insert(shared_conversations_table)
            .values(**values)
            .returning(shared_conversations_table)
        )
        return row_to_dict(self._conn.execute(stmt).fetchone())

    def get_or_create(
        self,
        conversation_id: str,
        user_id: str,
        *,
        is_promptable: bool = False,
        first_n_queries: int = 0,
        api_key: str | None = None,
        prompt_id: str | None = None,
        chunks: int | None = None,
        share_uuid: str | None = None,
    ) -> dict:
        """Race-free share create/lookup keyed on the logical dedup tuple.

        Relies on the partial unique index over
        ``(conversation_id, user_id, is_promptable, first_n_queries,
        COALESCE(api_key, ''))`` from migration 0008, so concurrent
        requests for the same logical share converge on one row; the
        returned ``uuid`` is the canonical public identifier.

        ``prompt_id`` and ``chunks`` are deliberately excluded from the
        uniqueness key: a share is "who shared what conversation under
        which visibility rules", while prompt/chunks are mutable
        last-write-wins properties. Existing public ``/shared/{uuid}``
        URLs therefore survive a re-share that only changes the prompt
        or chunk count — same semantics as the Mongo ``find_one`` +
        ``update`` path.
        """
        result = self._conn.execute(
            text(
                """
                INSERT INTO shared_conversations
                    (uuid, conversation_id, user_id, is_promptable,
                     first_n_queries, api_key, prompt_id, chunks)
                VALUES
                    (CAST(:uuid AS uuid), CAST(:conversation_id AS uuid),
                     :user_id, :is_promptable, :first_n_queries,
                     :api_key, CAST(:prompt_id AS uuid), :chunks)
                ON CONFLICT (conversation_id, user_id, is_promptable,
                             first_n_queries, COALESCE(api_key, ''))
                DO UPDATE SET prompt_id = EXCLUDED.prompt_id,
                              chunks = EXCLUDED.chunks
                RETURNING *
                """
            ),
            {
                "uuid": share_uuid or str(uuid_mod.uuid4()),
                "conversation_id": conversation_id,
                "user_id": user_id,
                "is_promptable": is_promptable,
                "first_n_queries": first_n_queries,
                "api_key": api_key,
                "prompt_id": prompt_id,
                "chunks": chunks,
            },
        )
        return row_to_dict(result.fetchone())

    def find_by_uuid(self, share_uuid: str) -> Optional[dict]:
        """Resolve a public share by its UUID."""
        row = self._conn.execute(
            text(
                "SELECT * FROM shared_conversations "
                "WHERE uuid = CAST(:uuid AS uuid)"
            ),
            {"uuid": share_uuid},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def find_existing(
        self,
        conversation_id: str,
        user_id: str,
        is_promptable: bool,
        first_n_queries: int,
        api_key: str | None = None,
    ) -> Optional[dict]:
        """Dedup check mirroring the Mongo ``find_one`` before share creation."""
        params: dict = {
            "conv_id": conversation_id,
            "user_id": user_id,
            "is_promptable": is_promptable,
            "fnq": first_n_queries,
        }
        if api_key:
            sql = (
                "SELECT * FROM shared_conversations "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id "
                "AND is_promptable = :is_promptable "
                "AND first_n_queries = :fnq "
                "AND api_key = :api_key "
                "LIMIT 1"
            )
            params["api_key"] = api_key
        else:
            # NULL never matches ``=``, so the no-key variant must spell
            # out IS NULL explicitly.
            sql = (
                "SELECT * FROM shared_conversations "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id "
                "AND is_promptable = :is_promptable "
                "AND first_n_queries = :fnq "
                "AND api_key IS NULL "
                "LIMIT 1"
            )
        row = self._conn.execute(text(sql), params).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_conversation(self, conversation_id: str) -> list[dict]:
        """All shares of one conversation, newest first."""
        rows = self._conn.execute(
            text(
                "SELECT * FROM shared_conversations "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "ORDER BY created_at DESC"
            ),
            {"conv_id": conversation_id},
        ).fetchall()
        return [row_to_dict(r) for r in rows]
|
||||
80
application/storage/db/repositories/sources.py
Normal file
80
application/storage/db/repositories/sources.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Repository for the ``sources`` table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, func, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import sources_table
|
||||
|
||||
|
||||
class SourcesRepository:
    """Postgres-backed CRUD for the ``sources`` table."""

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def create(self, name: str, *, user_id: str,
               type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
        """Insert a source row; ``metadata`` defaults to an empty JSONB object."""
        row = self._conn.execute(
            text(
                """
                INSERT INTO sources (user_id, name, type, metadata)
                VALUES (:user_id, :name, :type, CAST(:metadata AS jsonb))
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "name": name,
                "type": type,
                "metadata": json.dumps(metadata or {}),
            },
        ).fetchone()
        return row_to_dict(row)

    def get(self, source_id: str, user_id: str) -> Optional[dict]:
        """Fetch one source by UUID, scoped to its owner."""
        row = self._conn.execute(
            text("SELECT * FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """All of a user's sources, newest first."""
        rows = self._conn.execute(
            text("SELECT * FROM sources WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(r) for r in rows]

    def update(self, source_id: str, user_id: str, fields: dict) -> None:
        """Update a whitelisted subset of columns; no-op when nothing matches."""
        updates: dict = {
            key: fields[key] for key in ("name", "type", "metadata") if key in fields
        }
        if not updates:
            return

        # JSONB values are passed as Python objects on purpose: the
        # SQLAlchemy Core JSONB type processor serialises them itself,
        # so pre-dumping here would double-encode and the value would
        # round-trip as a JSON string instead of the original dict.
        updates["updated_at"] = func.now()

        tbl = sources_table
        self._conn.execute(
            tbl.update()
            .where(tbl.c.id == source_id)
            .where(tbl.c.user_id == user_id)
            .values(**updates)
        )

    def delete(self, source_id: str, user_id: str) -> bool:
        """Delete one source; True when a row was removed."""
        outcome = self._conn.execute(
            text("DELETE FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        )
        return outcome.rowcount > 0
|
||||
58
application/storage/db/repositories/stack_logs.py
Normal file
58
application/storage/db/repositories/stack_logs.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Repository for the ``stack_logs`` table.
|
||||
|
||||
Covers the single operation the legacy Mongo code performs:
|
||||
|
||||
1. ``insert_one`` in logging.py ``_log_to_mongodb`` — append-only debug/error
|
||||
activity log. The Mongo collection is ``stack_logs``; the Mongo variable
|
||||
inside ``_log_to_mongodb`` is misleadingly named ``user_logs_collection``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
|
||||
class StackLogsRepository:
    """Postgres-backed replacement for the Mongo ``stack_logs`` collection.

    Single operation: append-only insert from logging.py's
    ``_log_to_mongodb`` (whose Mongo handle is misleadingly named
    ``user_logs_collection``).
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def insert(
        self,
        *,
        activity_id: str,
        endpoint: Optional[str] = None,
        level: Optional[str] = None,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
        query: Optional[str] = None,
        stacks: Optional[list] = None,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Append one debug/error activity row.

        A missing ``timestamp`` is filled by the database clock
        (``COALESCE(:timestamp, now())``); ``stacks`` defaults to ``[]``.
        """
        params = {
            "activity_id": activity_id,
            "endpoint": endpoint,
            "level": level,
            "user_id": user_id,
            "api_key": api_key,
            "query": query,
            "stacks": json.dumps(stacks or []),
            "timestamp": timestamp,
        }
        self._conn.execute(
            text(
                """
                INSERT INTO stack_logs (activity_id, endpoint, level, user_id, api_key, query, stacks, timestamp)
                VALUES (
                    :activity_id, :endpoint, :level, :user_id, :api_key, :query,
                    CAST(:stacks AS jsonb),
                    COALESCE(:timestamp, now())
                )
                """
            ),
            params,
        )
|
||||
78
application/storage/db/repositories/todos.py
Normal file
78
application/storage/db/repositories/todos.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Repository for the ``todos`` table.
|
||||
|
||||
Covers the operations in ``application/agents/tools/todo_list.py``.
|
||||
Note: the Mongo schema uses ``todo_id`` (sequential int) and ``status`` (text),
|
||||
while the Postgres schema uses ``completed`` (boolean) and the UUID ``id`` as PK.
|
||||
The repository bridges both shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class TodosRepository:
    """Postgres store backing ``application/agents/tools/todo_list.py``.

    The Mongo schema uses a sequential int ``todo_id`` and a text
    ``status``; Postgres uses the UUID ``id`` as PK and a boolean
    ``completed``. This repository bridges both shapes.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def create(self, user_id: str, tool_id: str, title: str) -> dict:
        """Insert a new (incomplete) todo and return the row."""
        row = self._conn.execute(
            text(
                """
                INSERT INTO todos (user_id, tool_id, title)
                VALUES (:user_id, CAST(:tool_id AS uuid), :title)
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "title": title},
        ).fetchone()
        return row_to_dict(row)

    def get(self, todo_id: str, user_id: str) -> Optional[dict]:
        """Fetch one todo by UUID, scoped to its owner."""
        row = self._conn.execute(
            text("SELECT * FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": todo_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user_tool(self, user_id: str, tool_id: str) -> list[dict]:
        """All todos for a (user, tool) pair, oldest first."""
        rows = self._conn.execute(
            text(
                "SELECT * FROM todos WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) ORDER BY created_at"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        ).fetchall()
        return [row_to_dict(r) for r in rows]

    def update_title(self, todo_id: str, user_id: str, title: str) -> bool:
        """Rename a todo; True when a row was updated."""
        outcome = self._conn.execute(
            text(
                "UPDATE todos SET title = :title, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": todo_id, "user_id": user_id, "title": title},
        )
        return outcome.rowcount > 0

    def set_completed(self, todo_id: str, user_id: str, completed: bool = True) -> bool:
        """Mark a todo done (or not); True when a row was updated."""
        outcome = self._conn.execute(
            text(
                "UPDATE todos SET completed = :completed, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": todo_id, "user_id": user_id, "completed": completed},
        )
        return outcome.rowcount > 0

    def delete(self, todo_id: str, user_id: str) -> bool:
        """Delete one todo; True when a row was removed."""
        outcome = self._conn.execute(
            text("DELETE FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": todo_id, "user_id": user_id},
        )
        return outcome.rowcount > 0
|
||||
104
application/storage/db/repositories/token_usage.py
Normal file
104
application/storage/db/repositories/token_usage.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Repository for the ``token_usage`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``token_usage_collection`` / ``usage_collection``:
|
||||
|
||||
1. ``insert_one`` in usage.py (record per-call token counts)
|
||||
2. ``aggregate`` in analytics/routes.py (time-bucketed totals)
|
||||
3. ``aggregate`` in answer/routes/base.py (24h sum for rate limiting)
|
||||
4. ``count_documents`` in answer/routes/base.py (24h request count)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
|
||||
class TokenUsageRepository:
    """Postgres-backed replacement for Mongo ``token_usage_collection``.

    Covers the operations the legacy Mongo code performs on
    ``token_usage_collection`` / ``usage_collection``: per-call inserts,
    time-bucketed totals, and the 24h sum/count used for rate limiting.
    """

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    def insert(
        self,
        *,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
        agent_id: Optional[str] = None,
        prompt_tokens: int = 0,
        generated_tokens: int = 0,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Record one call's token counts.

        ``timestamp`` defaults to the database's ``now()`` when not given.
        ``agent_id`` is cast to uuid server-side; NULL passes through the cast.
        """
        self._conn.execute(
            text(
                """
                INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
                VALUES (
                    :user_id, :api_key,
                    CAST(:agent_id AS uuid),
                    :prompt_tokens, :generated_tokens,
                    COALESCE(:timestamp, now())
                )
                """
            ),
            {
                "user_id": user_id,
                "api_key": api_key,
                "agent_id": agent_id,
                "prompt_tokens": prompt_tokens,
                "generated_tokens": generated_tokens,
                "timestamp": timestamp,
            },
        )

    @staticmethod
    def _range_filter(
        start: datetime,
        end: datetime,
        user_id: Optional[str],
        api_key: Optional[str],
    ) -> tuple:
        """Build the shared WHERE clause and bind params for range queries.

        Returns ``(where_sql, params)``. Only fixed column names are ever
        interpolated into the SQL below; all values travel as bound
        parameters, so the f-string usage in the callers is injection-safe.
        """
        clauses = ["timestamp >= :start", "timestamp <= :end"]
        params: dict = {"start": start, "end": end}
        if user_id is not None:
            clauses.append("user_id = :user_id")
            params["user_id"] = user_id
        if api_key is not None:
            clauses.append("api_key = :api_key")
            params["api_key"] = api_key
        return " AND ".join(clauses), params

    def sum_tokens_in_range(
        self,
        *,
        start: datetime,
        end: datetime,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> int:
        """Total (prompt + generated) tokens in the given time range."""
        where, params = self._range_filter(start, end, user_id, api_key)
        result = self._conn.execute(
            text(f"SELECT COALESCE(SUM(prompt_tokens + generated_tokens), 0) FROM token_usage WHERE {where}"),
            params,
        )
        # COALESCE guarantees a non-NULL value; int() keeps the declared
        # return type honest even if the driver hands back a Decimal.
        return int(result.scalar())

    def count_in_range(
        self,
        *,
        start: datetime,
        end: datetime,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> int:
        """Count of token_usage rows in the given time range (for request limiting)."""
        where, params = self._range_filter(start, end, user_id, api_key)
        result = self._conn.execute(
            text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
            params,
        )
        return int(result.scalar())
|
||||
84
application/storage/db/repositories/user_logs.py
Normal file
84
application/storage/db/repositories/user_logs.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Repository for the ``user_logs`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``user_logs_collection``:
|
||||
|
||||
1. ``insert_one`` in logging.py (per-request activity log via
|
||||
``_log_to_mongodb`` — note: the *Mongo* variable is confusingly named
|
||||
``user_logs_collection`` but points at the ``user_logs`` Mongo
|
||||
collection, not ``stack_logs``)
|
||||
2. ``insert_one`` in answer/routes/base.py (per-stream log entry)
|
||||
3. ``find`` with sort/skip/limit in analytics/routes.py (paginated log list)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class UserLogsRepository:
    """Postgres-backed replacement for Mongo ``user_logs_collection``."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def insert(
        self,
        *,
        user_id: Optional[str] = None,
        endpoint: Optional[str] = None,
        data: Optional[dict] = None,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Append one activity-log row; ``data`` is stored as jsonb."""
        # Serialize up front so the SQL call stays a simple bind.
        payload = None if data is None else json.dumps(data, default=str)
        self._conn.execute(
            text(
                """
                INSERT INTO user_logs (user_id, endpoint, data, timestamp)
                VALUES (:user_id, :endpoint, CAST(:data AS jsonb), COALESCE(:timestamp, now()))
                """
            ),
            {
                "user_id": user_id,
                "endpoint": endpoint,
                "data": payload,
                "timestamp": timestamp,
            },
        )

    def list_paginated(
        self,
        *,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
        page: int = 1,
        page_size: int = 10,
    ) -> tuple[list[dict], bool]:
        """Return ``(rows, has_more)`` for the requested page.

        Mirrors the Mongo ``find(query).sort().skip().limit(page_size+1)``
        pattern used in analytics/routes.py: one extra row is fetched only
        to decide ``has_more`` and is sliced off before returning.
        """
        filters: list[str] = []
        params: dict = {"limit": page_size + 1, "offset": (page - 1) * page_size}
        if user_id is not None:
            filters.append("user_id = :user_id")
            params["user_id"] = user_id
        if api_key is not None:
            filters.append("data->>'api_key' = :api_key")
            params["api_key"] = api_key
        where = ("WHERE " + " AND ".join(filters)) if filters else ""
        res = self._conn.execute(
            text(
                f"SELECT * FROM user_logs {where} ORDER BY timestamp DESC LIMIT :limit OFFSET :offset"
            ),
            params,
        )
        fetched = [row_to_dict(r) for r in res.fetchall()]
        return fetched[:page_size], len(fetched) > page_size
|
||||
114
application/storage/db/repositories/user_tools.py
Normal file
114
application/storage/db/repositories/user_tools.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Repository for the ``user_tools`` table.
|
||||
|
||||
Covers every operation the legacy Mongo code performs on
|
||||
``user_tools_collection``:
|
||||
|
||||
1. ``find`` by user in tools/routes.py and base.py (list all / active)
|
||||
2. ``find_one`` by id in tools/routes.py and sharing.py (get single)
|
||||
3. ``insert_one`` in tools/routes.py and mcp.py (create)
|
||||
4. ``update_one`` in tools/routes.py and mcp.py (update fields)
|
||||
5. ``delete_one`` in tools/routes.py (delete)
|
||||
6. ``find`` by user+status in stream_processor.py and tool_executor.py (active tools)
|
||||
7. ``find_one`` by user+name in mcp.py (upsert check)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class UserToolsRepository:
    """Postgres-backed replacement for Mongo ``user_tools_collection``."""

    def __init__(self, conn: "Connection") -> None:
        self._conn = conn

    @staticmethod
    def _merge_config(config: Optional[dict], extra: Optional[dict]) -> dict:
        """Return a new dict merging ``extra`` into ``config``.

        Copies ``config`` first so the caller's dict is never mutated —
        the previous ``cfg = config or {}; cfg.update(extra)`` pattern
        silently altered the argument passed in.
        """
        merged = dict(config) if config else {}
        if extra:
            merged.update(extra)
        return merged

    def create(self, user_id: str, name: str, *, config: Optional[dict] = None,
               custom_name: Optional[str] = None, display_name: Optional[str] = None,
               extra: Optional[dict] = None) -> dict:
        """Insert a new tool row. ``extra`` is merged into the config JSONB."""
        cfg = self._merge_config(config, extra)
        result = self._conn.execute(
            text(
                """
                INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
                VALUES (:user_id, :name, :custom_name, :display_name, CAST(:config AS jsonb))
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "name": name,
                "custom_name": custom_name,
                "display_name": display_name,
                "config": json.dumps(cfg),
            },
        )
        return row_to_dict(result.fetchone())

    def get(self, tool_id: str, user_id: str) -> Optional[dict]:
        """Fetch a single tool owned by ``user_id``, or None."""
        result = self._conn.execute(
            text("SELECT * FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": tool_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """List all tools for a user, oldest first."""
        result = self._conn.execute(
            text("SELECT * FROM user_tools WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, tool_id: str, user_id: str, fields: dict) -> None:
        """Update arbitrary fields on a tool row.

        ``fields`` maps column names to new values. Only ``name``,
        ``custom_name``, ``display_name``, and ``config`` are allowed.

        NOTE(review): the COALESCE pattern below means a field cannot be
        reset to NULL through this method — a None value is treated as
        "leave unchanged". Confirm no caller relies on null-ing a column.
        """
        allowed = {"name", "custom_name", "display_name", "config"}
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return
        params: dict = {
            "id": tool_id,
            "user_id": user_id,
            "name": filtered.get("name"),
            "custom_name": filtered.get("custom_name"),
            "display_name": filtered.get("display_name"),
            "config": (
                json.dumps(filtered["config"])
                if "config" in filtered and isinstance(filtered["config"], dict)
                else filtered.get("config")
            ),
        }
        self._conn.execute(
            text(
                """
                UPDATE user_tools
                SET
                    name = COALESCE(:name, name),
                    custom_name = COALESCE(:custom_name, custom_name),
                    display_name = COALESCE(:display_name, display_name),
                    config = COALESCE(CAST(:config AS jsonb), config),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            params,
        )

    def delete(self, tool_id: str, user_id: str) -> bool:
        """Delete a tool owned by ``user_id``; return True if a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": tool_id, "user_id": user_id},
        )
        return result.rowcount > 0
|
||||
170
application/storage/db/repositories/workflow_edges.py
Normal file
170
application/storage/db/repositories/workflow_edges.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Repository for the ``workflow_edges`` table.
|
||||
|
||||
Covers bulk insert, find by version, and delete operations that the
|
||||
workflow routes perform on ``workflow_edges_collection`` in Mongo.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflow_edges_table
|
||||
|
||||
|
||||
class WorkflowEdgesRepository:
    """Postgres repository for ``workflow_edges`` (bulk insert, list, delete)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        workflow_id: str,
        graph_version: int,
        edge_id: str,
        from_node_id: str,
        to_node_id: str,
        *,
        source_handle: str | None = None,
        target_handle: str | None = None,
        config: dict | None = None,
    ) -> dict:
        """Create a single edge.

        ``from_node_id`` and ``to_node_id`` are the Postgres **UUID PKs**
        of the workflow_nodes rows (not user-provided node_id strings).
        """
        values: dict = {
            "workflow_id": workflow_id,
            "graph_version": graph_version,
            "edge_id": edge_id,
            "from_node_id": from_node_id,
            "to_node_id": to_node_id,
        }
        # Optional columns fall back to table defaults when omitted.
        optional = {
            "source_handle": source_handle,
            "target_handle": target_handle,
            "config": config,
        }
        values.update({k: v for k, v in optional.items() if v is not None})

        stmt = pg_insert(workflow_edges_table).values(**values).returning(workflow_edges_table)
        res = self._conn.execute(stmt)
        return row_to_dict(res.fetchone())

    def bulk_create(
        self,
        workflow_id: str,
        graph_version: int,
        edges: list[dict],
    ) -> list[dict]:
        """Insert multiple edges in one statement.

        Each element must have ``edge_id``, ``from_node_id`` (UUID PK),
        ``to_node_id`` (UUID PK). Optional: ``source_handle``,
        ``target_handle``, ``config``.
        """
        if not edges:
            return []

        rows = [
            {
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "edge_id": edge["edge_id"],
                "from_node_id": edge["from_node_id"],
                "to_node_id": edge["to_node_id"],
                "source_handle": edge.get("source_handle"),
                "target_handle": edge.get("target_handle"),
                "config": edge.get("config", {}),
            }
            for edge in edges
        ]

        stmt = pg_insert(workflow_edges_table).values(rows).returning(workflow_edges_table)
        res = self._conn.execute(stmt)
        return [row_to_dict(r) for r in res.fetchall()]

    def find_by_version(
        self, workflow_id: str, graph_version: int,
    ) -> list[dict]:
        """List edges for a workflow/version, shaped to match the live API.

        Joins ``workflow_nodes`` twice so callers receive the user-provided
        node-id strings (``source_id``/``target_id``) that the Mongo code
        and the frontend use, not the internal node UUIDs. The raw UUID
        columns (``from_node_id``/``to_node_id``) are still included in
        case a caller needs them.
        """
        res = self._conn.execute(
            text(
                """
                SELECT e.*,
                       fn.node_id AS source_id,
                       tn.node_id AS target_id
                FROM workflow_edges e
                JOIN workflow_nodes fn ON fn.id = e.from_node_id
                JOIN workflow_nodes tn ON tn.id = e.to_node_id
                WHERE e.workflow_id = CAST(:wf_id AS uuid)
                  AND e.graph_version = :ver
                ORDER BY e.edge_id
                """
            ),
            {"wf_id": workflow_id, "ver": graph_version},
        )
        return [row_to_dict(r) for r in res.fetchall()]

    def resolve_node_id(
        self, workflow_id: str, graph_version: int, node_id: str,
    ) -> Optional[str]:
        """Look up the UUID PK of a node by its user-provided ``node_id``.

        Callers that receive edges in the frontend shape (``source_id`` /
        ``target_id`` are user-provided strings) use this helper to
        translate to the UUID PK before calling :meth:`create` /
        :meth:`bulk_create`.
        """
        res = self._conn.execute(
            text(
                "SELECT id FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version = :ver AND node_id = :node_id"
            ),
            {"wf_id": workflow_id, "ver": graph_version, "node_id": node_id},
        )
        match = res.fetchone()
        return str(match[0]) if match else None

    def delete_by_workflow(self, workflow_id: str) -> int:
        """Delete every edge of a workflow; return the number removed."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_edges "
                "WHERE workflow_id = CAST(:wf_id AS uuid)"
            ),
            {"wf_id": workflow_id},
        )
        return res.rowcount

    def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
        """Delete edges of one graph version; return the number removed."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_edges "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version = :ver"
            ),
            {"wf_id": workflow_id, "ver": graph_version},
        )
        return res.rowcount

    def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
        """Delete all edges for a workflow except the specified version."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_edges "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version != :ver"
            ),
            {"wf_id": workflow_id, "ver": keep_version},
        )
        return res.rowcount
|
||||
158
application/storage/db/repositories/workflow_nodes.py
Normal file
158
application/storage/db/repositories/workflow_nodes.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Repository for the ``workflow_nodes`` table.
|
||||
|
||||
Covers bulk insert, find by version, and delete operations that the
|
||||
workflow routes perform on ``workflow_nodes_collection`` in Mongo.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflow_nodes_table
|
||||
|
||||
|
||||
class WorkflowNodesRepository:
    """Postgres repository for ``workflow_nodes`` (bulk insert, find, delete)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        workflow_id: str,
        graph_version: int,
        node_id: str,
        node_type: str,
        *,
        title: str | None = None,
        description: str | None = None,
        position: dict | None = None,
        config: dict | None = None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert one node and return the created row as a dict."""
        values: dict = {
            "workflow_id": workflow_id,
            "graph_version": graph_version,
            "node_id": node_id,
            "node_type": node_type,
        }
        # Optional columns fall back to table defaults when omitted.
        optional = {
            "title": title,
            "description": description,
            "position": position,
            "config": config,
            "legacy_mongo_id": legacy_mongo_id,
        }
        values.update({k: v for k, v in optional.items() if v is not None})

        stmt = pg_insert(workflow_nodes_table).values(**values).returning(workflow_nodes_table)
        res = self._conn.execute(stmt)
        return row_to_dict(res.fetchone())

    def bulk_create(
        self,
        workflow_id: str,
        graph_version: int,
        nodes: list[dict],
    ) -> list[dict]:
        """Insert multiple nodes in one statement.

        Each element of ``nodes`` should have at least ``node_id`` and
        ``node_type``; optional keys: ``title``, ``description``,
        ``position``, ``config``.
        """
        if not nodes:
            return []

        rows = [
            {
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "node_id": node["node_id"],
                "node_type": node["node_type"],
                "title": node.get("title"),
                "description": node.get("description"),
                "position": node.get("position", {"x": 0, "y": 0}),
                "config": node.get("config", {}),
                "legacy_mongo_id": node.get("legacy_mongo_id"),
            }
            for node in nodes
        ]

        stmt = pg_insert(workflow_nodes_table).values(rows).returning(workflow_nodes_table)
        res = self._conn.execute(stmt)
        return [row_to_dict(r) for r in res.fetchall()]

    def find_by_version(
        self, workflow_id: str, graph_version: int,
    ) -> list[dict]:
        """List all nodes for a workflow/version, ordered by node_id."""
        res = self._conn.execute(
            text(
                "SELECT * FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version = :ver "
                "ORDER BY node_id"
            ),
            {"wf_id": workflow_id, "ver": graph_version},
        )
        return [row_to_dict(r) for r in res.fetchall()]

    def find_node(
        self, workflow_id: str, graph_version: int, node_id: str,
    ) -> Optional[dict]:
        """Find a single node by its user-provided ``node_id``."""
        res = self._conn.execute(
            text(
                "SELECT * FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version = :ver AND node_id = :nid"
            ),
            {"wf_id": workflow_id, "ver": graph_version, "nid": node_id},
        )
        match = res.fetchone()
        return row_to_dict(match) if match is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
        """Find a node by the original Mongo ObjectId string."""
        res = self._conn.execute(
            text("SELECT * FROM workflow_nodes WHERE legacy_mongo_id = :legacy_id"),
            {"legacy_id": legacy_mongo_id},
        )
        match = res.fetchone()
        return row_to_dict(match) if match is not None else None

    def delete_by_workflow(self, workflow_id: str) -> int:
        """Delete every node of a workflow; return the number removed."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid)"
            ),
            {"wf_id": workflow_id},
        )
        return res.rowcount

    def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
        """Delete nodes of one graph version; return the number removed."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version = :ver"
            ),
            {"wf_id": workflow_id, "ver": graph_version},
        )
        return res.rowcount

    def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
        """Delete all nodes for a workflow except the specified version."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version != :ver"
            ),
            {"wf_id": workflow_id, "ver": keep_version},
        )
        return res.rowcount
|
||||
83
application/storage/db/repositories/workflow_runs.py
Normal file
83
application/storage/db/repositories/workflow_runs.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Repository for the ``workflow_runs`` table.
|
||||
|
||||
In Mongo, workflow_runs_collection only has ``insert_one`` — runs are
|
||||
written once after workflow execution completes and never updated.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflow_runs_table
|
||||
|
||||
|
||||
class WorkflowRunsRepository:
    """Postgres repository for ``workflow_runs`` (write-once run records)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        workflow_id: str,
        user_id: str,
        status: str,
        *,
        inputs: dict | None = None,
        result: dict | None = None,
        steps: list | None = None,
        started_at=None,
        ended_at=None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert one completed run and return the stored row as a dict."""
        values: dict = {
            "workflow_id": workflow_id,
            "user_id": user_id,
            "status": status,
        }
        # Optional columns fall back to table defaults when omitted.
        optional = {
            "inputs": inputs,
            "result": result,
            "steps": steps,
            "started_at": started_at,
            "ended_at": ended_at,
            "legacy_mongo_id": legacy_mongo_id,
        }
        values.update({k: v for k, v in optional.items() if v is not None})

        stmt = pg_insert(workflow_runs_table).values(**values).returning(workflow_runs_table)
        inserted = self._conn.execute(stmt)
        return row_to_dict(inserted.fetchone())

    def get(self, run_id: str) -> Optional[dict]:
        """Fetch a run by its UUID primary key, or None."""
        found = self._conn.execute(
            text("SELECT * FROM workflow_runs WHERE id = CAST(:id AS uuid)"),
            {"id": run_id},
        )
        row = found.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
        """Fetch a workflow run by the original Mongo ObjectId string."""
        found = self._conn.execute(
            text("SELECT * FROM workflow_runs WHERE legacy_mongo_id = :legacy_id"),
            {"legacy_id": legacy_mongo_id},
        )
        row = found.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_workflow(self, workflow_id: str) -> list[dict]:
        """List all runs of a workflow, newest first."""
        found = self._conn.execute(
            text(
                "SELECT * FROM workflow_runs "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "ORDER BY started_at DESC"
            ),
            {"wf_id": workflow_id},
        )
        return [row_to_dict(r) for r in found.fetchall()]
|
||||
125
application/storage/db/repositories/workflows.py
Normal file
125
application/storage/db/repositories/workflows.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Repository for the ``workflows`` table.
|
||||
|
||||
Covers CRUD on workflow metadata:
|
||||
|
||||
- create / get / list / update / delete
|
||||
- graph version management
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import workflows_table
|
||||
|
||||
|
||||
class WorkflowsRepository:
    """Postgres repository for workflow metadata (CRUD + graph versioning)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        description: str | None = None,
        *,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert a workflow and return the created row as a dict."""
        values: dict = {"user_id": user_id, "name": name}
        if description is not None:
            values["description"] = description
        if legacy_mongo_id is not None:
            values["legacy_mongo_id"] = legacy_mongo_id

        stmt = pg_insert(workflows_table).values(**values).returning(workflows_table)
        res = self._conn.execute(stmt)
        return row_to_dict(res.fetchone())

    def get(self, workflow_id: str, user_id: str) -> Optional[dict]:
        """Fetch a workflow owned by ``user_id``, or None."""
        res = self._conn.execute(
            text(
                "SELECT * FROM workflows "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        match = res.fetchone()
        return row_to_dict(match) if match is not None else None

    def get_by_id(self, workflow_id: str) -> Optional[dict]:
        """Fetch a workflow by ID without user check (for internal use)."""
        res = self._conn.execute(
            text("SELECT * FROM workflows WHERE id = CAST(:id AS uuid)"),
            {"id": workflow_id},
        )
        match = res.fetchone()
        return row_to_dict(match) if match is not None else None

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: str | None = None,
    ) -> Optional[dict]:
        """Fetch a workflow by its original Mongo ObjectId string."""
        sql = "SELECT * FROM workflows WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        match = self._conn.execute(text(sql), params).fetchone()
        return row_to_dict(match) if match is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """List a user's workflows, newest first."""
        res = self._conn.execute(
            text(
                "SELECT * FROM workflows "
                "WHERE user_id = :user_id ORDER BY created_at DESC"
            ),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in res.fetchall()]

    def update(self, workflow_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns; return True if a row was changed.

        Column names in the generated SQL come only from the ``allowed``
        set, never from caller input, so the f-string below is safe.
        """
        allowed = {"name", "description", "current_graph_version"}
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return False

        assignments = [f"{col} = :{col}" for col in filtered]
        assignments.append("updated_at = now()")
        params = {**filtered, "id": workflow_id, "user_id": user_id}

        sql = (
            f"UPDATE workflows SET {', '.join(assignments)} "
            "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
        )
        res = self._conn.execute(text(sql), params)
        return res.rowcount > 0

    def increment_graph_version(self, workflow_id: str, user_id: str) -> Optional[int]:
        """Atomically increment ``current_graph_version`` and return the new value."""
        res = self._conn.execute(
            text(
                "UPDATE workflows "
                "SET current_graph_version = current_graph_version + 1, "
                "    updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
                "RETURNING current_graph_version"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        match = res.fetchone()
        return match[0] if match else None

    def delete(self, workflow_id: str, user_id: str) -> bool:
        """Delete a workflow owned by ``user_id``; return True if removed."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflows "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        return res.rowcount > 0
|
||||
@@ -110,6 +110,20 @@ def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
usage_data["agent_id"] = normalized_agent_id
|
||||
usage_collection.insert_one(usage_data)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
|
||||
dual_write(
|
||||
TokenUsageRepository,
|
||||
lambda repo, d=usage_data: repo.insert(
|
||||
user_id=d.get("user_id"),
|
||||
api_key=d.get("api_key"),
|
||||
agent_id=d.get("agent_id"),
|
||||
prompt_tokens=d["prompt_tokens"],
|
||||
generated_tokens=d["generated_tokens"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
|
||||
@@ -27,13 +27,20 @@ class PGVectorStore(BaseVectorStore):
|
||||
self._metadata_column = metadata_column
|
||||
self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
|
||||
|
||||
# Use provided connection string or fall back to settings
|
||||
# Use provided connection string or fall back to settings.
|
||||
# If PGVECTOR_CONNECTION_STRING is not set but POSTGRES_URI is,
|
||||
# reuse the same cluster — normalize from SQLAlchemy dialect to libpq form.
|
||||
self._connection_string = connection_string or getattr(settings, 'PGVECTOR_CONNECTION_STRING', None)
|
||||
|
||||
|
||||
if not self._connection_string and getattr(settings, 'POSTGRES_URI', None):
|
||||
from application.core.db_uri import normalize_pgvector_connection_string
|
||||
self._connection_string = normalize_pgvector_connection_string(settings.POSTGRES_URI)
|
||||
|
||||
if not self._connection_string:
|
||||
raise ValueError(
|
||||
"PostgreSQL connection string is required. "
|
||||
"Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
|
||||
"Set PGVECTOR_CONNECTION_STRING or POSTGRES_URI in settings, "
|
||||
"or pass connection_string parameter."
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -247,7 +247,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
|
||||
|
||||
def download_file(url, params, dest_path):
|
||||
try:
|
||||
response = requests.get(url, params=params)
|
||||
response = requests.get(url, params=params, timeout=100)
|
||||
response.raise_for_status()
|
||||
with open(dest_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
@@ -284,12 +284,14 @@ def upload_index(full_path, file_data):
|
||||
files=files,
|
||||
data=file_data,
|
||||
headers=headers,
|
||||
timeout=100,
|
||||
)
|
||||
else:
|
||||
response = requests.post(
|
||||
urljoin(settings.API_URL, "/api/upload_index"),
|
||||
data=file_data,
|
||||
headers=headers,
|
||||
timeout=100,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except (requests.RequestException, FileNotFoundError) as e:
|
||||
@@ -1171,6 +1173,16 @@ def attachment_worker(self, file_info, user):
|
||||
}
|
||||
)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
|
||||
dual_write(
|
||||
AttachmentsRepository,
|
||||
lambda repo, u=user, fn=filename, p=relative_path, mt=mime_type, mid=attachment_id: repo.create(
|
||||
u, fn, p, mime_type=mt, legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Stored attachment with ID: {attachment_id}", extra={"user": user}
|
||||
)
|
||||
|
||||
@@ -54,8 +54,8 @@ flowchart LR
|
||||
* **Technology:** Supports multiple LLM APIs and local engines.
|
||||
* **Responsibility:** This layer provides an abstraction for interacting with Large Language Models (LLMs).
|
||||
* **Key Features:**
|
||||
* Supports LLMs from OpenAI, Google, Anthropic, Groq, HuggingFace Inference API, Azure OpenAI, also compatable with local models like Ollama, LLaMa.cpp, Text Generation Inference (TGI), SGLang, vLLM, Aphrodite, FriendliAI, and LMDeploy.
|
||||
* Manages API key handling and request formatting and Tool fromatting.
|
||||
* Supports LLMs from OpenAI, Google, Anthropic, Groq, HuggingFace Inference API, Azure OpenAI, also compatible with local models like Ollama, LLaMa.cpp, Text Generation Inference (TGI), SGLang, vLLM, Aphrodite, FriendliAI, and LMDeploy.
|
||||
* Manages API key handling and request formatting and Tool formatting.
|
||||
* Offers caching mechanisms to improve response times and reduce API usage.
|
||||
* Handles streaming responses for a more interactive user experience.
|
||||
|
||||
@@ -120,7 +120,7 @@ sequenceDiagram
|
||||
|
||||
## Deployment Architecture
|
||||
|
||||
DocsGPT is designed to be deployed using Docker and Kubernetes, here is a qucik overview of a simple k8s deployment.
|
||||
DocsGPT is designed to be deployed using Docker and Kubernetes, here is a quick overview of a simple k8s deployment.
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
|
||||
@@ -7,6 +7,10 @@ export default {
|
||||
"title": "🔗 SharePoint / OneDrive",
|
||||
"href": "/Guides/Integrations/sharepoint-connector"
|
||||
},
|
||||
"confluence-connector": {
|
||||
"title": "🔗 Confluence",
|
||||
"href": "/Guides/Integrations/confluence-connector"
|
||||
},
|
||||
"mcp-tool-integration": {
|
||||
"title": "🔗 MCP Tools",
|
||||
"href": "/Guides/Integrations/mcp-tool-integration"
|
||||
|
||||
67
docs/content/Guides/Integrations/confluence-connector.mdx
Normal file
67
docs/content/Guides/Integrations/confluence-connector.mdx
Normal file
@@ -0,0 +1,67 @@
|
||||
---
|
||||
title: Confluence Connector
|
||||
description: Connect your Confluence Cloud workspace as an external knowledge base to upload and process pages directly.
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components'
|
||||
import { Steps } from 'nextra/components'
|
||||
|
||||
# Confluence Connector
|
||||
|
||||
Connect your Confluence Cloud workspace to upload and process pages directly as an external knowledge base. Supports page content and attachments (PDFs, Office files, text files, images, and more). Authentication is handled via Atlassian OAuth 2.0 with automatic token refresh.
|
||||
|
||||
## Setup
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Create an OAuth 2.0 App in Atlassian
|
||||
|
||||
1. Go to [developer.atlassian.com/console/myapps](https://developer.atlassian.com/console/myapps/) and click **Create** > **OAuth 2.0 integration**
|
||||
2. Under **Authorization**, add a callback URL:
|
||||
- Local: `http://localhost:7091/api/connectors/callback?provider=confluence`
|
||||
- Production: `https://yourdomain.com/api/connectors/callback?provider=confluence`
|
||||
|
||||
### Step 2: Configure Permissions
|
||||
|
||||
In your app settings, go to **Permissions** and add the **Confluence API**. Enable these scopes:
|
||||
- `read:page:confluence`
|
||||
- `read:space:confluence`
|
||||
- `read:attachment:confluence`
|
||||
|
||||
### Step 3: Get Your Credentials
|
||||
|
||||
Go to **Settings** in your app to find the **Client ID** and **Secret**. Copy both.
|
||||
|
||||
### Step 4: Configure Environment Variables
|
||||
|
||||
Add to your backend `.env` file:
|
||||
|
||||
```env
|
||||
CONFLUENCE_CLIENT_ID=your-atlassian-client-id
|
||||
CONFLUENCE_CLIENT_SECRET=your-atlassian-client-secret
|
||||
```
|
||||
|
||||
Add to your frontend `.env` file:
|
||||
|
||||
```env
|
||||
VITE_CONFLUENCE_CLIENT_ID=your-atlassian-client-id
|
||||
```
|
||||
|
||||
| Variable | Description | Required |
|
||||
|----------|-------------|----------|
|
||||
| `CONFLUENCE_CLIENT_ID` | Client ID from your Atlassian OAuth app | Yes |
|
||||
| `CONFLUENCE_CLIENT_SECRET` | Client secret from your Atlassian OAuth app | Yes |
|
||||
| `VITE_CONFLUENCE_CLIENT_ID` | Same Client ID, used by the frontend to show the Confluence option | Yes |
|
||||
|
||||
### Step 5: Restart and Use
|
||||
|
||||
Restart your application, then go to the upload section in DocsGPT and select **Confluence** as the source. You'll be redirected to Atlassian to sign in, then can browse spaces and select pages to process.
|
||||
|
||||
</Steps>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **Option not appearing** — Verify `VITE_CONFLUENCE_CLIENT_ID` is set in the frontend `.env`, then restart.
|
||||
- **Authentication failed** — Check that the callback URL matches exactly, including `?provider=confluence`.
|
||||
- **No accessible sites** — Ensure the authenticating user has access to at least one Confluence Cloud site.
|
||||
- **Permission denied** — Verify that the Confluence API scopes are enabled in your Atlassian app settings.
|
||||
@@ -8,205 +8,66 @@ import { Steps } from 'nextra/components'
|
||||
|
||||
# Google Drive Connector
|
||||
|
||||
The Google Drive Connector allows you to seamlessly connect your Google Drive account as an external knowledge base. This integration enables you to upload and process files directly from your Google Drive without manually downloading and uploading them to DocsGPT.
|
||||
Connect your Google Drive account to upload and process files directly as an external knowledge base. Supports Google Workspace files (Docs, Sheets, Slides), Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Google OAuth 2.0 with automatic token refresh.
|
||||
|
||||
## Features
|
||||
|
||||
- **Direct File Access**: Browse and select files directly from your Google Drive
|
||||
- **Comprehensive File Support**: Supports all major document formats including:
|
||||
- Google Workspace files (Docs, Sheets, Slides)
|
||||
- Microsoft Office files (.docx, .xlsx, .pptx, .doc, .ppt, .xls)
|
||||
- PDF documents
|
||||
- Text files (.txt, .md, .rst, .html, .rtf)
|
||||
- Data files (.csv, .json)
|
||||
- Image files (.png, .jpg, .jpeg)
|
||||
- E-books (.epub)
|
||||
- **Secure Authentication**: Uses OAuth 2.0 for secure access to your Google Drive
|
||||
- **Real-time Sync**: Process files directly from Google Drive without local downloads
|
||||
|
||||
<Callout type="info" emoji="ℹ️">
|
||||
The Google Drive Connector requires proper configuration of Google API credentials. Follow the setup instructions below to enable this feature.
|
||||
</Callout>
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before setting up the Google Drive Connector, you'll need:
|
||||
|
||||
1. A Google Cloud Platform (GCP) project
|
||||
2. Google Drive API enabled
|
||||
3. OAuth 2.0 credentials configured
|
||||
4. DocsGPT instance with proper environment variables
|
||||
|
||||
## Setup Instructions
|
||||
## Setup
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Create a Google Cloud Project
|
||||
|
||||
1. Go to the [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create a new project or select an existing one
|
||||
3. Note down your Project ID for later use
|
||||
1. Go to the [Google Cloud Console](https://console.cloud.google.com/) and create a new project (or select an existing one)
|
||||
2. Navigate to **APIs & Services** > **Library**, search for "Google Drive API", and click **Enable**
|
||||
|
||||
### Step 2: Enable Google Drive API
|
||||
### Step 2: Create OAuth 2.0 Credentials
|
||||
|
||||
1. In the Google Cloud Console, navigate to **APIs & Services** > **Library**
|
||||
2. Search for "Google Drive API"
|
||||
3. Click on "Google Drive API" and click **Enable**
|
||||
1. Go to **APIs & Services** > **Credentials** > **Create Credentials** > **OAuth client ID**
|
||||
2. If prompted, configure the OAuth consent screen (choose **External**, fill in required fields)
|
||||
3. Select **Web application** as the application type
|
||||
4. Add your DocsGPT URL to **Authorized JavaScript origins** (e.g. `http://localhost:3000`)
|
||||
5. Add your callback URL to **Authorized redirect URIs**:
|
||||
- Local: `http://localhost:7091/api/connectors/callback?provider=google_drive`
|
||||
- Production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
|
||||
6. Click **Create** and copy the **Client ID** and **Client Secret**
|
||||
|
||||
### Step 3: Create OAuth 2.0 Credentials
|
||||
### Step 3: Configure Environment Variables
|
||||
|
||||
1. Go to **APIs & Services** > **Credentials**
|
||||
2. Click **Create Credentials** > **OAuth client ID**
|
||||
3. If prompted, configure the OAuth consent screen:
|
||||
- Choose **External** user type (unless you're using Google Workspace)
|
||||
- Fill in the required fields (App name, User support email, Developer contact)
|
||||
- Add your domain to **Authorized domains** if deploying publicly
|
||||
4. For Application type, select **Web application**
|
||||
5. Add your DocsGPT frontend URL to **Authorized JavaScript origins**:
|
||||
- For local development: `http://localhost:3000`
|
||||
- For production: `https://yourdomain.com`
|
||||
6. Add your DocsGPT callback URL to **Authorized redirect URIs**:
|
||||
- For local development: `http://localhost:7091/api/connectors/callback?provider=google_drive`
|
||||
- For production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
|
||||
7. Click **Create** and note down the **Client ID** and **Client Secret**
|
||||
|
||||
|
||||
|
||||
### Step 4: Configure Backend Environment Variables
|
||||
|
||||
Add the following environment variables to your backend configuration:
|
||||
|
||||
**For Docker deployment**, add to your `.env` file in the root directory:
|
||||
Add to your backend `.env` file:
|
||||
|
||||
```env
|
||||
# Google Drive Connector Configuration
|
||||
GOOGLE_CLIENT_ID=your_google_client_id_here
|
||||
GOOGLE_CLIENT_SECRET=your_google_client_secret_here
|
||||
GOOGLE_CLIENT_ID=your-google-client-id
|
||||
GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
```
|
||||
|
||||
**For manual deployment**, set these environment variables in your system or application configuration.
|
||||
|
||||
### Step 5: Configure Frontend Environment Variables
|
||||
|
||||
Add the following environment variables to your frontend `.env` file:
|
||||
Add to your frontend `.env` file:
|
||||
|
||||
```env
|
||||
# Google Drive Frontend Configuration
|
||||
VITE_GOOGLE_CLIENT_ID=your_google_client_id_here
|
||||
VITE_GOOGLE_CLIENT_ID=your-google-client-id
|
||||
```
|
||||
|
||||
| Variable | Description | Required |
|
||||
|----------|-------------|----------|
|
||||
| `GOOGLE_CLIENT_ID` | OAuth Client ID from GCP Credentials | Yes |
|
||||
| `GOOGLE_CLIENT_SECRET` | OAuth Client Secret from GCP Credentials | Yes |
|
||||
| `VITE_GOOGLE_CLIENT_ID` | Same Client ID, used by the frontend to show the Google Drive option | Yes |
|
||||
|
||||
<Callout type="warning" emoji="⚠️">
|
||||
Make sure to use the same Google Client ID in both backend and frontend configurations.
|
||||
</Callout>
|
||||
|
||||
### Step 6: Restart Your Application
|
||||
### Step 4: Restart and Use
|
||||
|
||||
After configuring the environment variables:
|
||||
|
||||
1. **For Docker**: Restart your Docker containers
|
||||
```bash
|
||||
docker-compose down
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
2. **For manual deployment**: Restart both backend and frontend services
|
||||
Restart your application, then go to the upload section in DocsGPT and select **Google Drive** as the source. You'll be redirected to Google to sign in, then can browse and select files to process.
|
||||
|
||||
</Steps>
|
||||
|
||||
## Using the Google Drive Connector
|
||||
|
||||
Once configured, you can use the Google Drive Connector to upload files:
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Access the Upload Interface
|
||||
|
||||
1. Navigate to the DocsGPT interface
|
||||
2. Go to the upload/training section
|
||||
3. You should now see "Google Drive" as an available upload option
|
||||
|
||||
### Step 2: Connect Your Google Account
|
||||
|
||||
1. Select "Google Drive" as your upload method
|
||||
2. Click "Connect to Google Drive"
|
||||
3. You'll be redirected to Google's OAuth consent screen
|
||||
4. Grant the necessary permissions to DocsGPT
|
||||
5. You'll be redirected back to DocsGPT with a successful connection
|
||||
|
||||
### Step 3: Select Files
|
||||
|
||||
1. Once connected, click "Select Files"
|
||||
2. The Google Drive picker will open
|
||||
3. Browse your Google Drive and select the files you want to process
|
||||
4. Click "Select" to confirm your choices
|
||||
|
||||
### Step 4: Process Files
|
||||
|
||||
1. Review your selected files
|
||||
2. Click "Train" or "Upload" to process the files
|
||||
3. DocsGPT will download and process the files from your Google Drive
|
||||
4. Once processing is complete, the files will be available in your knowledge base
|
||||
|
||||
</Steps>
|
||||
|
||||
## Supported File Types
|
||||
|
||||
The Google Drive Connector supports the following file types:
|
||||
|
||||
| File Type | Extensions | Description |
|
||||
|-----------|------------|-------------|
|
||||
| **Google Workspace** | - | Google Docs, Sheets, Slides (automatically converted) |
|
||||
| **Microsoft Office** | .docx, .xlsx, .pptx | Modern Office formats |
|
||||
| **Legacy Office** | .doc, .ppt, .xls | Older Office formats |
|
||||
| **PDF Documents** | .pdf | Portable Document Format |
|
||||
| **Text Files** | .txt, .md, .rst, .html, .rtf | Various text formats |
|
||||
| **Data Files** | .csv, .json | Structured data formats |
|
||||
| **Images** | .png, .jpg, .jpeg | Image files (with OCR if enabled) |
|
||||
| **E-books** | .epub | Electronic publication format |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**"Google Drive option not appearing"**
|
||||
- Verify that `VITE_GOOGLE_CLIENT_ID` is set in frontend environment
|
||||
- Check that `VITE_GOOGLE_CLIENT_ID` environment variable is present in your frontend configuration
|
||||
- Check browser console for any JavaScript errors
|
||||
- Ensure the frontend has been restarted after adding environment variables
|
||||
|
||||
**"Authentication failed"**
|
||||
- Verify that your OAuth 2.0 credentials are correctly configured
|
||||
- Check that the redirect URI `http://<your-domain>/api/connectors/callback?provider=google_drive` is correctly added in GCP console
|
||||
- Ensure the Google Drive API is enabled in your GCP project
|
||||
|
||||
**"Permission denied" errors**
|
||||
- Verify that the OAuth consent screen is properly configured
|
||||
- Check that your Google account has access to the files you're trying to select
|
||||
- Ensure the required scopes are granted during authentication
|
||||
|
||||
**"Files not processing"**
|
||||
- Check that the backend environment variables are correctly set
|
||||
- Verify that the OAuth credentials have the necessary permissions
|
||||
- Check the backend logs for any error messages
|
||||
|
||||
### Environment Variable Checklist
|
||||
|
||||
**Backend (.env in root directory):**
|
||||
- ✅ `GOOGLE_CLIENT_ID`
|
||||
- ✅ `GOOGLE_CLIENT_SECRET`
|
||||
|
||||
**Frontend (.env in frontend directory):**
|
||||
- ✅ `VITE_GOOGLE_CLIENT_ID`
|
||||
|
||||
### Security Considerations
|
||||
|
||||
- Keep your Google Client Secret secure and never expose it in frontend code
|
||||
- Regularly rotate your OAuth credentials
|
||||
- Use HTTPS in production to protect authentication tokens
|
||||
- Ensure proper OAuth consent screen configuration for production use
|
||||
- **Option not appearing** — Verify `VITE_GOOGLE_CLIENT_ID` is set in the frontend `.env`, then restart.
|
||||
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=google_drive`. Ensure the Google Drive API is enabled.
|
||||
- **Permission denied** — Verify the OAuth consent screen is configured and the user has access to the target files.
|
||||
- **Files not processing** — Check backend logs and verify that backend environment variables are correctly set.
|
||||
|
||||
<Callout type="tip" emoji="💡">
|
||||
For production deployments, make sure to add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
|
||||
For production deployments, add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
|
||||
</Callout>
|
||||
|
||||
|
||||
|
||||
4
frontend/src/assets/confluence.svg
Normal file
4
frontend/src/assets/confluence.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M15.7903 2.01315C16.0583 1.53239 16.6644 1.35859 17.1464 1.62428L21.9827 4.28993C22.2157 4.41835 22.3879 4.63433 22.4613 4.89007C22.5346 5.14582 22.503 5.42024 22.3735 5.65262L20.7649 8.53807C19.6743 10.4944 17.9383 11.868 15.9685 12.5426L21.8863 15.8043C22.1193 15.9328 22.2915 16.1488 22.3649 16.4045C22.4382 16.6602 22.4066 16.9347 22.2771 17.167L19.5962 21.9761C19.3282 22.4569 18.7221 22.6307 18.24 22.365L11.4692 18.6331C10.8804 18.3085 10.1413 18.5224 9.81847 19.1015L8.20996 21.987C7.94196 22.4677 7.33584 22.6415 6.8538 22.3758L2.01729 19.7101C1.78429 19.5816 1.61207 19.3657 1.53874 19.1099C1.46541 18.8542 1.49701 18.5798 1.62655 18.3474L3.23506 15.4619C4.32566 13.5056 6.06166 12.132 8.0315 11.4574L2.11368 8.19564C1.88068 8.06721 1.70846 7.85124 1.63513 7.59549C1.56179 7.33975 1.59339 7.06533 1.72294 6.83295L4.40379 2.02389C4.67179 1.54313 5.27791 1.36933 5.75995 1.63502L12.531 5.36708C13.1199 5.69165 13.8589 5.47779 14.1818 4.89861L15.7903 2.01315ZM17.0526 3.85624L15.9287 5.87243C15.067 7.41803 13.1136 7.97187 11.5656 7.11864L5.66611 3.86698L3.9591 6.92911L9.85005 10.1761C13.11 11.9729 17.2146 10.7994 19.018 7.56424L20.1373 5.55645L17.0526 3.85624ZM14.15 13.8239C10.89 12.0271 6.78543 13.2006 4.98197 16.4357L3.86271 18.4435L6.94764 20.1439L8.07157 18.1277C8.93317 16.5821 10.8866 16.0283 12.4346 16.8815L18.3339 20.133L20.0409 17.0709L14.15 13.8239Z" fill="#000000"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
@@ -1,6 +1,7 @@
|
||||
import React, { useRef } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
|
||||
@@ -149,7 +150,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
|
||||
{isConnected ? (
|
||||
<div className="mb-4">
|
||||
<div className="flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium text-[#212121]">
|
||||
<div className="text-eerie-black flex w-full items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-sm font-medium">
|
||||
<div className="flex max-w-[500px] items-center gap-2">
|
||||
<svg className="h-4 w-4" viewBox="0 0 24 24">
|
||||
<path
|
||||
@@ -166,7 +167,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
{onDisconnect && (
|
||||
<button
|
||||
onClick={onDisconnect}
|
||||
className="text-xs font-medium text-[#212121] underline hover:text-gray-700"
|
||||
className="text-eerie-black text-xs font-medium underline hover:text-gray-700"
|
||||
>
|
||||
{t('modals.uploadDoc.connectors.auth.disconnect')}
|
||||
</button>
|
||||
|
||||
@@ -60,6 +60,10 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
displayName: 'SharePoint',
|
||||
rootName: 'My Files',
|
||||
},
|
||||
confluence: {
|
||||
displayName: 'Confluence',
|
||||
rootName: 'Spaces',
|
||||
},
|
||||
} as const;
|
||||
|
||||
const getProviderConfig = (provider: string) => {
|
||||
@@ -202,7 +206,9 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
if (!validateResponse.ok) {
|
||||
removeSessionToken(provider);
|
||||
setIsConnected(false);
|
||||
setAuthError('Session expired. Please reconnect to Google Drive.');
|
||||
setAuthError(
|
||||
`Session expired. Please reconnect to ${getProviderConfig(provider).displayName}.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -398,6 +404,7 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
|
||||
<ConnectorAuth
|
||||
provider={provider}
|
||||
label={`Connect to ${getProviderConfig(provider).displayName}`}
|
||||
onSuccess={(data) => {
|
||||
setUserEmail(data.user_email || 'Connected User');
|
||||
setIsConnected(true);
|
||||
|
||||
@@ -263,8 +263,8 @@ const MermaidRenderer: React.FC<MermaidRendererProps> = ({
|
||||
const errorRender = !isCurrentlyLoading && error;
|
||||
|
||||
return (
|
||||
<div className="w-inherit group border-border bg-card relative rounded-lg border">
|
||||
<div className="bg-platinum flex items-center justify-between px-2 py-1">
|
||||
<div className="w-inherit group border-border bg-card relative overflow-hidden rounded-[14px] border">
|
||||
<div className="bg-platinum dark:bg-muted flex items-center justify-between px-2 py-1">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
mermaid
|
||||
</span>
|
||||
@@ -401,7 +401,7 @@ const MermaidRenderer: React.FC<MermaidRendererProps> = ({
|
||||
|
||||
{showCode && (
|
||||
<div className="border-border border-t">
|
||||
<div className="bg-platinum p-2">
|
||||
<div className="bg-platinum dark:bg-muted p-2">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
Mermaid Code
|
||||
</span>
|
||||
|
||||
@@ -1296,9 +1296,8 @@ export default function MessageInput({
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (autoFocus) inputRef.current?.focus();
|
||||
handleInput();
|
||||
}, [autoFocus, handleInput]);
|
||||
}, [handleInput]);
|
||||
|
||||
const handleChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setValue(e.target.value);
|
||||
@@ -1364,8 +1363,9 @@ export default function MessageInput({
|
||||
) {
|
||||
onSubmit(value);
|
||||
setValue('');
|
||||
// Refocus input after submission if autoFocus is enabled
|
||||
if (autoFocus) {
|
||||
if (isTouch) {
|
||||
inputRef.current?.blur();
|
||||
} else if (autoFocus) {
|
||||
setTimeout(() => {
|
||||
if (isMountedRef.current) {
|
||||
inputRef.current?.focus();
|
||||
@@ -1544,6 +1544,7 @@ export default function MessageInput({
|
||||
id="message-input"
|
||||
ref={inputRef}
|
||||
value={value}
|
||||
autoFocus={autoFocus && !isTouch}
|
||||
onChange={handleChange}
|
||||
readOnly={
|
||||
recordingState === 'recording' ||
|
||||
|
||||
@@ -236,7 +236,7 @@ export default function Conversation() {
|
||||
isSplitArtifactOpen ? 'w-[60%] px-6' : 'w-full'
|
||||
}`}
|
||||
>
|
||||
<div className="min-h-0 flex-1">
|
||||
<div className="relative min-h-0 flex-1 ">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
@@ -255,6 +255,7 @@ export default function Conversation() {
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
<div className="from-background pointer-events-none absolute right-1.5 bottom-0 left-0 h-6 rounded-t-2xl bg-linear-to-t to-transparent" />
|
||||
</div>
|
||||
|
||||
<div
|
||||
|
||||
@@ -559,7 +559,7 @@ const ConversationBubble = forwardRef<
|
||||
|
||||
return match ? (
|
||||
<div className="group border-border relative overflow-hidden rounded-[14px] border">
|
||||
<div className="bg-platinum flex items-center justify-between px-2 py-1">
|
||||
<div className="bg-platinum dark:bg-muted flex items-center justify-between px-2 py-1">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
{language}
|
||||
</span>
|
||||
@@ -1204,7 +1204,7 @@ function Thought({
|
||||
|
||||
return match ? (
|
||||
<div className="group border-border relative overflow-hidden rounded-[14px] border">
|
||||
<div className="bg-platinum flex items-center justify-between px-2 py-1">
|
||||
<div className="bg-platinum dark:bg-muted flex items-center justify-between px-2 py-1">
|
||||
<span className="text-foreground dark:text-foreground text-xs font-medium">
|
||||
{language}
|
||||
</span>
|
||||
|
||||
@@ -62,40 +62,130 @@ export default function ConversationMessages({
|
||||
const { t } = useTranslation();
|
||||
|
||||
const conversationRef = useRef<HTMLDivElement>(null);
|
||||
const [hasScrolledToLast, setHasScrolledToLast] = useState(true);
|
||||
const [userInterruptedScroll, setUserInterruptedScroll] = useState(false);
|
||||
const [scrollButtonVisible, setScrollButtonVisible] = useState(false);
|
||||
const userInterruptedRef = useRef(false);
|
||||
const [interrupted, setInterrupted] = useState(false);
|
||||
const lastTouchYRef = useRef<number | null>(null);
|
||||
const isInitialLoad = useRef(true);
|
||||
const prevQueriesRef = useRef(queries);
|
||||
const isAutoScrollingRef = useRef(false);
|
||||
const smoothScrollTimeoutRef =
|
||||
useRef<ReturnType<typeof setTimeout>>(undefined);
|
||||
const showButtonTimerRef = useRef<ReturnType<typeof setTimeout>>(undefined);
|
||||
|
||||
const handleUserScrollInterruption = useCallback(() => {
|
||||
if (!userInterruptedScroll && status === 'loading') {
|
||||
setUserInterruptedScroll(true);
|
||||
}
|
||||
}, [userInterruptedScroll, status]);
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
clearTimeout(smoothScrollTimeoutRef.current);
|
||||
clearTimeout(showButtonTimerRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const scrollConversationToBottom = useCallback(() => {
|
||||
if (!conversationRef.current || userInterruptedScroll) return;
|
||||
const isAtBottom = useCallback(() => {
|
||||
const el = conversationRef.current;
|
||||
if (!el) return true;
|
||||
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
}, []);
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
if (!conversationRef?.current) return;
|
||||
// Arm on upward scroll intent; requiring !isAtBottom() missed small nudges still inside SCROLL_THRESHOLD.
|
||||
const markInterruptedIfLoading = useCallback(() => {
|
||||
if (userInterruptedRef.current || status !== 'loading') return;
|
||||
userInterruptedRef.current = true;
|
||||
setInterrupted(true);
|
||||
}, [status]);
|
||||
|
||||
if (status === 'idle' || !queries[queries.length - 1]?.response) {
|
||||
conversationRef.current.scrollTo({
|
||||
behavior: 'smooth',
|
||||
top: conversationRef.current.scrollHeight,
|
||||
});
|
||||
} else {
|
||||
conversationRef.current.scrollTop =
|
||||
conversationRef.current.scrollHeight;
|
||||
const handleWheel = useCallback(
|
||||
(e: React.WheelEvent) => {
|
||||
if (e.deltaY < 0) markInterruptedIfLoading();
|
||||
},
|
||||
[markInterruptedIfLoading],
|
||||
);
|
||||
|
||||
const handleTouchStart = useCallback((e: React.TouchEvent) => {
|
||||
lastTouchYRef.current = e.touches[0].clientY;
|
||||
}, []);
|
||||
|
||||
const handleTouchMove = useCallback(
|
||||
(e: React.TouchEvent) => {
|
||||
const y = e.touches[0].clientY;
|
||||
if (lastTouchYRef.current !== null && y > lastTouchYRef.current) {
|
||||
markInterruptedIfLoading();
|
||||
}
|
||||
});
|
||||
}, [userInterruptedScroll, status, queries]);
|
||||
lastTouchYRef.current = y;
|
||||
},
|
||||
[markInterruptedIfLoading],
|
||||
);
|
||||
|
||||
const checkScrollPosition = useCallback(() => {
|
||||
const setButtonHidden = useCallback(() => {
|
||||
clearTimeout(showButtonTimerRef.current);
|
||||
showButtonTimerRef.current = undefined;
|
||||
setScrollButtonVisible(false);
|
||||
}, []);
|
||||
|
||||
const setButtonVisibleDebounced = useCallback(() => {
|
||||
if (showButtonTimerRef.current) return;
|
||||
showButtonTimerRef.current = setTimeout(() => {
|
||||
setScrollButtonVisible(true);
|
||||
showButtonTimerRef.current = undefined;
|
||||
}, 300);
|
||||
}, []);
|
||||
|
||||
const scrollConversationToBottom = useCallback(
|
||||
(instant?: boolean) => {
|
||||
if (!conversationRef.current) return;
|
||||
|
||||
isAutoScrollingRef.current = true;
|
||||
clearTimeout(smoothScrollTimeoutRef.current);
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
if (!conversationRef?.current) return;
|
||||
|
||||
if (instant) {
|
||||
conversationRef.current.scrollTop =
|
||||
conversationRef.current.scrollHeight;
|
||||
if (isAtBottom()) {
|
||||
setButtonHidden();
|
||||
}
|
||||
isAutoScrollingRef.current = false;
|
||||
} else {
|
||||
conversationRef.current.scrollTo({
|
||||
behavior: 'smooth',
|
||||
top: conversationRef.current.scrollHeight,
|
||||
});
|
||||
smoothScrollTimeoutRef.current = setTimeout(() => {
|
||||
if (isAtBottom()) {
|
||||
setButtonHidden();
|
||||
}
|
||||
isAutoScrollingRef.current = false;
|
||||
}, 500);
|
||||
}
|
||||
});
|
||||
},
|
||||
[isAtBottom, setButtonHidden],
|
||||
);
|
||||
|
||||
const handleScroll = useCallback(() => {
|
||||
const el = conversationRef.current;
|
||||
if (!el) return;
|
||||
const isAtBottom =
|
||||
el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
setHasScrolledToLast(isAtBottom);
|
||||
}, [setHasScrolledToLast]);
|
||||
|
||||
const atBottom = isAtBottom();
|
||||
|
||||
if (atBottom && userInterruptedRef.current) {
|
||||
userInterruptedRef.current = false;
|
||||
setInterrupted(false);
|
||||
}
|
||||
|
||||
if (atBottom) {
|
||||
setButtonHidden();
|
||||
isAutoScrollingRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (isAutoScrollingRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
setButtonVisibleDebounced();
|
||||
}, [isAtBottom, setButtonHidden, setButtonVisibleDebounced]);
|
||||
|
||||
const lastQuery = queries[queries.length - 1];
|
||||
const lastQueryResponse = lastQuery?.response;
|
||||
@@ -103,34 +193,46 @@ export default function ConversationMessages({
|
||||
const lastQueryThought = lastQuery?.thought;
|
||||
|
||||
useEffect(() => {
|
||||
if (!userInterruptedScroll) {
|
||||
scrollConversationToBottom();
|
||||
if (interrupted) return;
|
||||
|
||||
const prevQueries = prevQueriesRef.current;
|
||||
const isConversationSwitch =
|
||||
prevQueries !== queries && prevQueries[0] !== queries[0];
|
||||
|
||||
if (isInitialLoad.current || isConversationSwitch) {
|
||||
isInitialLoad.current = false;
|
||||
scrollConversationToBottom(true);
|
||||
prevQueriesRef.current = queries;
|
||||
return;
|
||||
}
|
||||
|
||||
const isNewMessage = queries.length > prevQueries.length;
|
||||
prevQueriesRef.current = queries;
|
||||
|
||||
scrollConversationToBottom(isNewMessage ? false : true);
|
||||
}, [
|
||||
queries.length,
|
||||
lastQueryResponse,
|
||||
lastQueryError,
|
||||
lastQueryThought,
|
||||
userInterruptedScroll,
|
||||
interrupted,
|
||||
scrollConversationToBottom,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (status === 'idle') {
|
||||
setUserInterruptedScroll(false);
|
||||
userInterruptedRef.current = false;
|
||||
setInterrupted(false);
|
||||
}
|
||||
}, [status]);
|
||||
|
||||
useEffect(() => {
|
||||
const currentConversationRef = conversationRef.current;
|
||||
currentConversationRef?.addEventListener('scroll', checkScrollPosition);
|
||||
currentConversationRef?.addEventListener('scroll', handleScroll);
|
||||
return () => {
|
||||
currentConversationRef?.removeEventListener(
|
||||
'scroll',
|
||||
checkScrollPosition,
|
||||
);
|
||||
currentConversationRef?.removeEventListener('scroll', handleScroll);
|
||||
};
|
||||
}, [checkScrollPosition]);
|
||||
}, [handleScroll]);
|
||||
|
||||
const retryIconProps = {
|
||||
width: 12,
|
||||
@@ -208,7 +310,7 @@ export default function ConversationMessages({
|
||||
>
|
||||
<div className="flex max-w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<div className="flex h-[34px] w-[34px] items-center justify-center overflow-hidden rounded-full">
|
||||
<div className="flex h-8.5 w-8.5 items-center justify-center overflow-hidden rounded-full">
|
||||
<img
|
||||
src={DocsGPT3}
|
||||
alt={t('conversation.answer')}
|
||||
@@ -237,18 +339,24 @@ export default function ConversationMessages({
|
||||
return (
|
||||
<div
|
||||
ref={conversationRef}
|
||||
onWheel={handleUserScrollInterruption}
|
||||
onTouchMove={handleUserScrollInterruption}
|
||||
onWheel={handleWheel}
|
||||
onTouchStart={handleTouchStart}
|
||||
onTouchMove={handleTouchMove}
|
||||
className="flex h-full w-full justify-center overflow-y-auto will-change-scroll sm:pt-6 lg:pt-12"
|
||||
>
|
||||
{queries.length > 0 && !hasScrolledToLast && (
|
||||
{queries.length > 0 && (
|
||||
<button
|
||||
onClick={() => {
|
||||
setUserInterruptedScroll(false);
|
||||
userInterruptedRef.current = false;
|
||||
setInterrupted(false);
|
||||
scrollConversationToBottom();
|
||||
}}
|
||||
aria-label={t('Scroll to bottom') || 'Scroll to bottom'}
|
||||
className="border-border bg-card fixed right-14 bottom-40 z-10 flex h-7 w-7 items-center justify-center rounded-full border md:h-9 md:w-9"
|
||||
className={`border-border bg-card fixed bottom-40 left-1/2 z-10 flex h-7 w-7 -translate-x-1/2 items-center justify-center rounded-full border transition-all duration-300 ease-in-out md:right-14 md:left-auto md:h-9 md:w-9 md:translate-x-0 ${
|
||||
scrollButtonVisible
|
||||
? 'pointer-events-auto scale-100 opacity-100'
|
||||
: 'pointer-events-none scale-75 opacity-0'
|
||||
}`}
|
||||
>
|
||||
<img
|
||||
src={ArrowDown}
|
||||
@@ -261,8 +369,8 @@ export default function ConversationMessages({
|
||||
<div
|
||||
className={
|
||||
isSplitView
|
||||
? 'w-full max-w-[1300px] px-2'
|
||||
: 'w-full max-w-[1300px] px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12'
|
||||
? 'w-full max-w-325 px-2'
|
||||
: 'w-full max-w-325 px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12'
|
||||
}
|
||||
>
|
||||
{headerContent}
|
||||
|
||||
@@ -325,6 +325,14 @@
|
||||
"s3": {
|
||||
"label": "Amazon S3",
|
||||
"heading": "Inhalt von Amazon S3 hinzufügen"
|
||||
},
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Von SharePoint hochladen"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Von Confluence hochladen"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -341,6 +341,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Upload from SharePoint"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Upload from Confluence"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Subir desde SharePoint"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Subir desde Confluence"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "SharePointからアップロード"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Confluenceからアップロード"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "Загрузить из SharePoint"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "Загрузить из Confluence"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "從SharePoint上傳"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "從Confluence上傳"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -329,6 +329,10 @@
|
||||
"share_point": {
|
||||
"label": "SharePoint",
|
||||
"heading": "从SharePoint上传"
|
||||
},
|
||||
"confluence": {
|
||||
"label": "Confluence",
|
||||
"heading": "从Confluence上传"
|
||||
}
|
||||
},
|
||||
"connectors": {
|
||||
|
||||
@@ -266,6 +266,23 @@ function Upload({
|
||||
initialSelectedFolders={selectedFolders}
|
||||
/>
|
||||
);
|
||||
case 'confluence_picker':
|
||||
return (
|
||||
<FilePicker
|
||||
key={field.name}
|
||||
onSelectionChange={(
|
||||
selectedFileIds: string[],
|
||||
selectedFolderIds: string[] = [],
|
||||
) => {
|
||||
setSelectedFiles(selectedFileIds);
|
||||
setSelectedFolders(selectedFolderIds);
|
||||
}}
|
||||
provider="confluence"
|
||||
token={token}
|
||||
initialSelectedFiles={selectedFiles}
|
||||
initialSelectedFolders={selectedFolders}
|
||||
/>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
@@ -551,6 +568,9 @@ function Upload({
|
||||
const hasSharePointPicker = schema.some(
|
||||
(field: FormField) => field.type === 'share_point_picker',
|
||||
);
|
||||
const hasConfluencePicker = schema.some(
|
||||
(field: FormField) => field.type === 'confluence_picker',
|
||||
);
|
||||
|
||||
let configData: Record<string, unknown> = { ...ingestor.config };
|
||||
|
||||
@@ -561,7 +581,8 @@ function Upload({
|
||||
} else if (
|
||||
hasRemoteFilePicker ||
|
||||
hasGoogleDrivePicker ||
|
||||
hasSharePointPicker
|
||||
hasSharePointPicker ||
|
||||
hasConfluencePicker
|
||||
) {
|
||||
const sessionToken = getSessionToken(ingestor.type as string);
|
||||
configData = {
|
||||
@@ -721,6 +742,9 @@ function Upload({
|
||||
const hasSharePointPicker = schema.some(
|
||||
(field: FormField) => field.type === 'share_point_picker',
|
||||
);
|
||||
const hasConfluencePicker = schema.some(
|
||||
(field: FormField) => field.type === 'confluence_picker',
|
||||
);
|
||||
|
||||
if (hasLocalFilePicker) {
|
||||
if (files.length === 0) {
|
||||
@@ -729,7 +753,8 @@ function Upload({
|
||||
} else if (
|
||||
hasRemoteFilePicker ||
|
||||
hasGoogleDrivePicker ||
|
||||
hasSharePointPicker
|
||||
hasSharePointPicker ||
|
||||
hasConfluencePicker
|
||||
) {
|
||||
if (selectedFiles.length === 0 && selectedFolders.length === 0) {
|
||||
return true;
|
||||
|
||||
@@ -6,8 +6,10 @@ import RedditIcon from '../../assets/reddit.svg';
|
||||
import DriveIcon from '../../assets/drive.svg';
|
||||
import S3Icon from '../../assets/s3.svg';
|
||||
import SharePoint from '../../assets/sharepoint.svg';
|
||||
import ConfluenceIcon from '../../assets/confluence.svg';
|
||||
|
||||
export type IngestorType =
|
||||
| 'confluence'
|
||||
| 'crawler'
|
||||
| 'github'
|
||||
| 'reddit'
|
||||
@@ -38,7 +40,8 @@ export type FieldType =
|
||||
| 'local_file_picker'
|
||||
| 'remote_file_picker'
|
||||
| 'google_drive_picker'
|
||||
| 'share_point_picker';
|
||||
| 'share_point_picker'
|
||||
| 'confluence_picker';
|
||||
|
||||
export interface FormField {
|
||||
name: string;
|
||||
@@ -214,6 +217,24 @@ export const IngestorFormSchemas: IngestorSchema[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'confluence',
|
||||
label: 'Confluence',
|
||||
icon: ConfluenceIcon,
|
||||
heading: 'Upload from Confluence',
|
||||
validate: () => {
|
||||
const confluenceClientId = import.meta.env.VITE_CONFLUENCE_CLIENT_ID;
|
||||
return !!confluenceClientId;
|
||||
},
|
||||
fields: [
|
||||
{
|
||||
name: 'files',
|
||||
label: 'Select Pages from Confluence',
|
||||
type: 'confluence_picker',
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export const IngestorDefaultConfigs: Record<
|
||||
@@ -261,6 +282,13 @@ export const IngestorDefaultConfigs: Record<
|
||||
recursive: true,
|
||||
},
|
||||
},
|
||||
confluence: {
|
||||
name: '',
|
||||
config: {
|
||||
file_ids: '',
|
||||
folder_ids: '',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export interface IngestorOption {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -336,6 +336,34 @@ class TestSaveWorkflowRun:
|
||||
agent._save_workflow_run("query")
|
||||
|
||||
mock_collection.insert_one.assert_called_once()
|
||||
saved_doc = mock_collection.insert_one.call_args.args[0]
|
||||
assert saved_doc["user"] == "user1"
|
||||
assert saved_doc["user_id"] == "user1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_dual_writes_when_mongo_insert_returns_id(self):
|
||||
agent = _make_agent(workflow_id="507f1f77bcf86cd799439011")
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.state = {"query": "test"}
|
||||
mock_engine.execution_log = []
|
||||
mock_engine.get_execution_summary.return_value = []
|
||||
agent._engine = mock_engine
|
||||
|
||||
insert_result = MagicMock()
|
||||
insert_result.inserted_id = "507f1f77bcf86cd799439012"
|
||||
mock_collection = MagicMock()
|
||||
mock_collection.insert_one.return_value = insert_result
|
||||
mock_db = MagicMock()
|
||||
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
|
||||
|
||||
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
|
||||
patch("application.agents.workflow_agent.settings") as mock_settings, \
|
||||
patch("application.agents.workflow_agent.dual_write") as mock_dual_write:
|
||||
mock_settings.MONGO_DB_NAME = "test_db"
|
||||
MockMongo.get_client.return_value = {"test_db": mock_db}
|
||||
agent._save_workflow_run("query")
|
||||
|
||||
mock_dual_write.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exception_does_not_propagate(self):
|
||||
|
||||
@@ -973,14 +973,21 @@ class TestUpdateAgent:
|
||||
agent_id = ObjectId()
|
||||
existing = self._make_existing_agent(agent_id)
|
||||
mock_col = Mock()
|
||||
mock_repo = Mock()
|
||||
mock_col.find_one.return_value = existing
|
||||
mock_col.update_one.return_value = Mock(matched_count=1, modified_count=1)
|
||||
mock_handle_img = Mock(return_value=("", None))
|
||||
|
||||
def _run_dual_write(_repo_cls, fn):
|
||||
fn(mock_repo)
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.routes.agents_collection", mock_col
|
||||
), patch(
|
||||
"application.api.user.agents.routes.handle_image_upload", mock_handle_img
|
||||
), patch(
|
||||
"application.api.user.agents.routes.dual_write",
|
||||
side_effect=_run_dual_write,
|
||||
):
|
||||
with app.test_request_context(
|
||||
f"/api/update_agent/{agent_id}",
|
||||
@@ -993,6 +1000,11 @@ class TestUpdateAgent:
|
||||
response = UpdateAgent().put(str(agent_id))
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
mock_repo.update_by_legacy_id.assert_called_once_with(
|
||||
str(agent_id),
|
||||
"user1",
|
||||
{"name": "Updated Name"},
|
||||
)
|
||||
|
||||
def test_returns_400_invalid_status(self, app):
|
||||
from application.api.user.agents.routes import UpdateAgent
|
||||
@@ -1945,14 +1957,21 @@ class TestDeleteAgent:
|
||||
|
||||
agent_id = ObjectId()
|
||||
mock_col = Mock()
|
||||
mock_repo = Mock()
|
||||
mock_col.find_one_and_delete.return_value = {
|
||||
"_id": agent_id,
|
||||
"user": "user1",
|
||||
"agent_type": "classic",
|
||||
}
|
||||
|
||||
def _run_dual_write(_repo_cls, fn):
|
||||
fn(mock_repo)
|
||||
|
||||
with patch(
|
||||
"application.api.user.agents.routes.agents_collection", mock_col
|
||||
), patch(
|
||||
"application.api.user.agents.routes.dual_write",
|
||||
side_effect=_run_dual_write,
|
||||
):
|
||||
with app.test_request_context(f"/api/delete_agent?id={agent_id}"):
|
||||
from flask import request
|
||||
@@ -1961,6 +1980,7 @@ class TestDeleteAgent:
|
||||
response = DeleteAgent().delete()
|
||||
assert response.status_code == 200
|
||||
assert response.json["id"] == str(agent_id)
|
||||
mock_repo.delete_by_legacy_id.assert_called_once_with(str(agent_id), "user1")
|
||||
|
||||
def test_deletes_workflow_agent_cleans_up(self, app):
|
||||
from application.api.user.agents.routes import DeleteAgent
|
||||
|
||||
@@ -18,12 +18,19 @@ class TestCreatePrompt:
|
||||
from application.api.user.prompts.routes import CreatePrompt
|
||||
|
||||
mock_collection = Mock()
|
||||
mock_repo = Mock()
|
||||
inserted_id = ObjectId()
|
||||
mock_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
|
||||
|
||||
def _run_dual_write(_repo_cls, fn):
|
||||
fn(mock_repo)
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.prompts.routes.dual_write",
|
||||
side_effect=_run_dual_write,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/create_prompt",
|
||||
@@ -41,6 +48,12 @@ class TestCreatePrompt:
|
||||
doc = mock_collection.insert_one.call_args[0][0]
|
||||
assert doc["name"] == "My Prompt"
|
||||
assert doc["user"] == "user1"
|
||||
mock_repo.create.assert_called_once_with(
|
||||
"user1",
|
||||
"My Prompt",
|
||||
"You are helpful.",
|
||||
legacy_mongo_id=str(inserted_id),
|
||||
)
|
||||
|
||||
def test_returns_401_unauthenticated(self, app):
|
||||
from application.api.user.prompts.routes import CreatePrompt
|
||||
@@ -204,10 +217,17 @@ class TestDeletePrompt:
|
||||
|
||||
prompt_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_repo = Mock()
|
||||
|
||||
def _run_dual_write(_repo_cls, fn):
|
||||
fn(mock_repo)
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.prompts.routes.dual_write",
|
||||
side_effect=_run_dual_write,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/delete_prompt",
|
||||
@@ -224,6 +244,7 @@ class TestDeletePrompt:
|
||||
mock_collection.delete_one.assert_called_once_with(
|
||||
{"_id": prompt_id, "user": "user1"}
|
||||
)
|
||||
mock_repo.delete_by_legacy_id.assert_called_once_with(str(prompt_id), "user1")
|
||||
|
||||
def test_returns_400_missing_id(self, app):
|
||||
from application.api.user.prompts.routes import DeletePrompt
|
||||
@@ -249,10 +270,17 @@ class TestUpdatePrompt:
|
||||
|
||||
prompt_id = ObjectId()
|
||||
mock_collection = Mock()
|
||||
mock_repo = Mock()
|
||||
|
||||
def _run_dual_write(_repo_cls, fn):
|
||||
fn(mock_repo)
|
||||
|
||||
with patch(
|
||||
"application.api.user.prompts.routes.prompts_collection",
|
||||
mock_collection,
|
||||
), patch(
|
||||
"application.api.user.prompts.routes.dual_write",
|
||||
side_effect=_run_dual_write,
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/api/update_prompt",
|
||||
@@ -271,6 +299,12 @@ class TestUpdatePrompt:
|
||||
assert response.status_code == 200
|
||||
assert response.json["success"] is True
|
||||
mock_collection.update_one.assert_called_once()
|
||||
mock_repo.update_by_legacy_id.assert_called_once_with(
|
||||
str(prompt_id),
|
||||
"user1",
|
||||
"Updated",
|
||||
"New content",
|
||||
)
|
||||
|
||||
def test_returns_400_missing_fields(self, app):
|
||||
from application.api.user.prompts.routes import UpdatePrompt
|
||||
|
||||
@@ -200,7 +200,7 @@ class TestSetupPeriodicTasks:
|
||||
|
||||
setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 3
|
||||
assert sender.add_periodic_task.call_count == 4
|
||||
|
||||
calls = sender.add_periodic_task.call_args_list
|
||||
|
||||
@@ -210,6 +210,8 @@ class TestSetupPeriodicTasks:
|
||||
assert calls[1][0][0] == timedelta(weeks=1)
|
||||
# monthly
|
||||
assert calls[2][0][0] == timedelta(days=30)
|
||||
# pending_tool_state TTL cleanup (60s)
|
||||
assert calls[3][0][0] == timedelta(seconds=60)
|
||||
|
||||
|
||||
class TestMcpOauthTask:
|
||||
|
||||
@@ -886,7 +886,13 @@ class TestWorkflowListPost:
|
||||
mock_wf_collection = Mock()
|
||||
mock_wf_collection.insert_one.return_value = Mock(inserted_id=inserted_id)
|
||||
mock_nodes_collection = Mock()
|
||||
mock_nodes_collection.insert_many.return_value = Mock(
|
||||
inserted_ids=[ObjectId(), ObjectId()]
|
||||
)
|
||||
mock_edges_collection = Mock()
|
||||
mock_edges_collection.insert_many.return_value = Mock(
|
||||
inserted_ids=[ObjectId()]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.user.workflows.routes.workflows_collection",
|
||||
@@ -1152,7 +1158,13 @@ class TestWorkflowDetailPut:
|
||||
}
|
||||
mock_wf_collection.update_one.return_value = Mock()
|
||||
mock_nodes_collection = Mock()
|
||||
mock_nodes_collection.insert_many.return_value = Mock(
|
||||
inserted_ids=[ObjectId(), ObjectId()]
|
||||
)
|
||||
mock_edges_collection = Mock()
|
||||
mock_edges_collection.insert_many.return_value = Mock(
|
||||
inserted_ids=[ObjectId()]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"application.api.user.workflows.routes.workflows_collection",
|
||||
|
||||
@@ -31,6 +31,7 @@ class TestGitHubLoaderFetchFileContent:
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.github.com/repos/owner/repo/contents/README.md",
|
||||
headers=loader.headers,
|
||||
timeout=100,
|
||||
)
|
||||
|
||||
@patch("application.parser.remote.github_loader.requests.get")
|
||||
@@ -66,7 +67,7 @@ class TestGitHubLoaderFetchRepoFiles:
|
||||
def test_recurses_directories(self, mock_get):
|
||||
loader = GitHubLoader()
|
||||
|
||||
def side_effect(url, headers=None):
|
||||
def side_effect(url, headers=None, timeout=None):
|
||||
if url.endswith("/contents/"):
|
||||
return make_response([
|
||||
{"type": "file", "path": "README.md"},
|
||||
|
||||
0
tests/storage/db/__init__.py
Normal file
0
tests/storage/db/__init__.py
Normal file
63
tests/storage/db/conftest.py
Normal file
63
tests/storage/db/conftest.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Fixtures for repository tests against a real Postgres instance.
|
||||
|
||||
These tests hit the local dev Postgres (the DBngin instance on this machine,
|
||||
or CI's service container). Each test runs inside a transaction that is
|
||||
rolled back at the end, so tests never leak state into each other and the
|
||||
database stays clean without needing per-test CREATE/DROP overhead.
|
||||
|
||||
Required env:
|
||||
POSTGRES_URI — e.g. postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
|
||||
Tests are skipped automatically when POSTGRES_URI is unset so that
|
||||
contributors without a local Postgres can still run the rest of the suite.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def _run_alembic_upgrade(engine):
|
||||
"""Run ``alembic upgrade head`` to ensure the full schema is present.
|
||||
|
||||
Non-zero exit is re-raised so genuine schema-drift bugs surface as
|
||||
test failures. If alembic reports the schema is already at head,
|
||||
the subprocess still exits zero.
|
||||
"""
|
||||
alembic_ini = Path(__file__).resolve().parents[3] / "application" / "alembic.ini"
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "alembic", "-c", str(alembic_ini), "upgrade", "head"],
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pg_engine():
|
||||
"""Session-scoped engine pointing at the test Postgres."""
|
||||
if not settings.POSTGRES_URI:
|
||||
pytest.skip("POSTGRES_URI not set")
|
||||
engine = create_engine(settings.POSTGRES_URI)
|
||||
_run_alembic_upgrade(engine)
|
||||
yield engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def pg_conn(pg_engine):
|
||||
"""Per-test connection wrapped in a transaction that always rolls back.
|
||||
|
||||
Repositories receive this connection and operate normally. At teardown
|
||||
the outer transaction is rolled back so no data persists between tests.
|
||||
"""
|
||||
conn = pg_engine.connect()
|
||||
txn = conn.begin()
|
||||
yield conn
|
||||
txn.rollback()
|
||||
conn.close()
|
||||
0
tests/storage/db/repositories/__init__.py
Normal file
0
tests/storage/db/repositories/__init__.py
Normal file
116
tests/storage/db/repositories/test_agent_folders.py
Normal file
116
tests/storage/db/repositories/test_agent_folders.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Tests for AgentFoldersRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> AgentFoldersRepository:
|
||||
return AgentFoldersRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_folder(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "My Folder")
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["name"] == "My Folder"
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_create_returns_id_and_underscore_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "f")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
|
||||
class TestGet:
|
||||
def test_get_existing(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "f")
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_get_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
|
||||
|
||||
def test_get_wrong_user_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "f")
|
||||
assert repo.get(created["id"], "user-other") is None
|
||||
|
||||
|
||||
class TestListForUser:
|
||||
def test_lists_only_own_folders(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create("alice", "f1")
|
||||
repo.create("alice", "f2")
|
||||
repo.create("bob", "f3")
|
||||
results = repo.list_for_user("alice")
|
||||
assert len(results) == 2
|
||||
assert all(r["user_id"] == "alice" for r in results)
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
def test_updates_name(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old")
|
||||
updated = repo.update(created["id"], "user-1", {"name": "new"})
|
||||
assert updated is True
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["name"] == "new"
|
||||
|
||||
def test_update_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old")
|
||||
updated = repo.update(created["id"], "user-other", {"name": "new"})
|
||||
assert updated is False
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["name"] == "old"
|
||||
|
||||
def test_update_disallowed_field_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "f")
|
||||
updated = repo.update(created["id"], "user-1", {"id": "00000000-0000-0000-0000-000000000000"})
|
||||
assert updated is False
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_deletes_folder(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "f")
|
||||
deleted = repo.delete(created["id"], "user-1")
|
||||
assert deleted is True
|
||||
assert repo.get(created["id"], "user-1") is None
|
||||
|
||||
def test_delete_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "f")
|
||||
deleted = repo.delete(created["id"], "user-other")
|
||||
assert deleted is False
|
||||
assert repo.get(created["id"], "user-1") is not None
|
||||
|
||||
|
||||
class TestTenantIsolation:
|
||||
def test_user_a_cannot_see_user_b_folders(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
folder_a = repo.create("alice", "private")
|
||||
assert repo.get(folder_a["id"], "bob") is None
|
||||
|
||||
def test_list_returns_only_own_folders(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create("alice", "a1")
|
||||
repo.create("bob", "b1")
|
||||
alice_folders = repo.list_for_user("alice")
|
||||
bob_folders = repo.list_for_user("bob")
|
||||
assert len(alice_folders) == 1
|
||||
assert len(bob_folders) == 1
|
||||
assert alice_folders[0]["name"] == "a1"
|
||||
assert bob_folders[0]["name"] == "b1"
|
||||
226
tests/storage/db/repositories/test_agents.py
Normal file
226
tests/storage/db/repositories/test_agents.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Tests for AgentsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> AgentsRepository:
|
||||
return AgentsRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_agent_minimal(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "My Agent", "draft")
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["name"] == "My Agent"
|
||||
assert doc["status"] == "draft"
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_create_with_kwargs(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create(
|
||||
"user-1", "Agent2", "active",
|
||||
description="A test agent",
|
||||
chunks=5,
|
||||
tools=[{"name": "search"}],
|
||||
shared=True,
|
||||
)
|
||||
assert doc["description"] == "A test agent"
|
||||
assert doc["chunks"] == 5
|
||||
assert doc["tools"] == [{"name": "search"}]
|
||||
assert doc["shared"] is True
|
||||
|
||||
def test_create_returns_id_and_underscore_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("u", "a", "draft")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
def test_create_with_legacy_mongo_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create(
|
||||
"u",
|
||||
"a",
|
||||
"draft",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
|
||||
|
||||
def test_create_normalizes_blank_key_to_null(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("u", "a", "draft", key="")
|
||||
assert doc["key"] is None
|
||||
|
||||
|
||||
class TestGet:
|
||||
def test_get_existing(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "a", "draft")
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_get_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
|
||||
|
||||
def test_get_wrong_user_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "a", "draft")
|
||||
assert repo.get(created["id"], "user-other") is None
|
||||
|
||||
def test_get_by_legacy_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create(
|
||||
"user-1",
|
||||
"a",
|
||||
"draft",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
|
||||
class TestFindByKey:
|
||||
def test_finds_agent_by_key(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("u", "a", "draft", key="my-unique-key")
|
||||
fetched = repo.find_by_key("my-unique-key")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_find_by_key_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.find_by_key("nonexistent-key") is None
|
||||
|
||||
|
||||
class TestListForUser:
|
||||
def test_lists_only_own_agents(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create("alice", "a1", "draft")
|
||||
repo.create("alice", "a2", "active")
|
||||
repo.create("bob", "b1", "draft")
|
||||
results = repo.list_for_user("alice")
|
||||
assert len(results) == 2
|
||||
assert all(r["user_id"] == "alice" for r in results)
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
def test_updates_name(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "draft")
|
||||
updated = repo.update(created["id"], "user-1", {"name": "new"})
|
||||
assert updated is True
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["name"] == "new"
|
||||
|
||||
def test_update_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "draft")
|
||||
updated = repo.update(created["id"], "user-other", {"name": "new"})
|
||||
assert updated is False
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["name"] == "old"
|
||||
|
||||
def test_update_disallowed_field_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "a", "draft")
|
||||
updated = repo.update(created["id"], "user-1", {"id": "bad"})
|
||||
assert updated is False
|
||||
|
||||
def test_update_by_legacy_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create(
|
||||
"user-1",
|
||||
"old",
|
||||
"draft",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
updated = repo.update_by_legacy_id(
|
||||
"507f1f77bcf86cd799439011",
|
||||
"user-1",
|
||||
{"name": "new", "last_used_at": None},
|
||||
)
|
||||
assert updated is True
|
||||
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
|
||||
assert fetched["name"] == "new"
|
||||
|
||||
def test_update_normalizes_blank_key_to_null(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "draft", key="my-unique-key")
|
||||
updated = repo.update(created["id"], "user-1", {"key": ""})
|
||||
assert updated is True
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["key"] is None
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_deletes_agent(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "a", "draft")
|
||||
deleted = repo.delete(created["id"], "user-1")
|
||||
assert deleted is True
|
||||
assert repo.get(created["id"], "user-1") is None
|
||||
|
||||
def test_delete_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "a", "draft")
|
||||
deleted = repo.delete(created["id"], "user-other")
|
||||
assert deleted is False
|
||||
assert repo.get(created["id"], "user-1") is not None
|
||||
|
||||
def test_delete_by_legacy_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create(
|
||||
"user-1",
|
||||
"a",
|
||||
"draft",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
deleted = repo.delete_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
|
||||
assert deleted is True
|
||||
assert repo.get(created["id"], "user-1") is None
|
||||
|
||||
|
||||
class TestSetFolder:
    """Assigning and clearing an agent's folder."""

    def test_assigns_folder(self, pg_conn):
        from application.storage.db.repositories.agent_folders import AgentFoldersRepository

        folder = AgentFoldersRepository(pg_conn).create("user-1", "f")
        agents = _repo(pg_conn)
        agent = agents.create("user-1", "a", "draft")
        agents.set_folder(agent["id"], "user-1", folder["id"])
        refetched = agents.get(agent["id"], "user-1")
        # Compare as strings: ids may come back as UUID objects.
        assert str(refetched["folder_id"]) == str(folder["id"])

    def test_clear_folder(self, pg_conn):
        from application.storage.db.repositories.agent_folders import AgentFoldersRepository

        folder = AgentFoldersRepository(pg_conn).create("user-1", "f")
        agents = _repo(pg_conn)
        agent = agents.create("user-1", "a", "draft", folder_id=folder["id"])
        # Passing None detaches the agent from its folder.
        agents.set_folder(agent["id"], "user-1", None)
        assert agents.get(agent["id"], "user-1")["folder_id"] is None
|
||||
|
||||
|
||||
class TestClearFolderForAll:
    """Bulk folder detachment across all of a user's agents."""

    def test_clears_folder_from_all_agents(self, pg_conn):
        from application.storage.db.repositories.agent_folders import AgentFoldersRepository

        folder = AgentFoldersRepository(pg_conn).create("user-1", "f")
        agents = _repo(pg_conn)
        rows = [
            agents.create("user-1", name, "draft", folder_id=folder["id"])
            for name in ("a1", "a2")
        ]
        agents.clear_folder_for_all(folder["id"], "user-1")
        for row in rows:
            assert agents.get(row["id"], "user-1")["folder_id"] is None
|
||||
92
tests/storage/db/repositories/test_attachments.py
Normal file
92
tests/storage/db/repositories/test_attachments.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for AttachmentsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> AttachmentsRepository:
    """Build an AttachmentsRepository bound to the given connection."""
    return AttachmentsRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
    """Attachment creation and the shape of the returned document."""

    def test_creates_attachment(self, pg_conn):
        attachments = _repo(pg_conn)
        row = attachments.create("user-1", "file.pdf", "/uploads/file.pdf")
        assert row["id"] is not None
        assert row["user_id"] == "user-1"
        assert row["filename"] == "file.pdf"
        assert row["upload_path"] == "/uploads/file.pdf"

    def test_creates_with_optional_fields(self, pg_conn):
        attachments = _repo(pg_conn)
        row = attachments.create(
            "user-1",
            "img.png",
            "/uploads/img.png",
            mime_type="image/png",
            size=1024,
        )
        assert row["mime_type"] == "image/png"
        assert row["size"] == 1024

    def test_create_returns_id_and_underscore_id(self, pg_conn):
        row = _repo(pg_conn).create("u", "f", "/p")
        # Mongo-compat alias: "_id" mirrors the Postgres "id".
        assert row["_id"] == row["id"]

    def test_create_with_legacy_mongo_id(self, pg_conn):
        row = _repo(pg_conn).create(
            "u",
            "f",
            "/p",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        assert row["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
|
||||
|
||||
|
||||
class TestGet:
    """Fetching attachments by id and by legacy Mongo id."""

    def test_get_existing(self, pg_conn):
        attachments = _repo(pg_conn)
        row = attachments.create("u", "f", "/p")
        assert attachments.get(row["id"], "u")["id"] == row["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        missing = "00000000-0000-0000-0000-000000000000"
        assert _repo(pg_conn).get(missing, "u") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        attachments = _repo(pg_conn)
        row = attachments.create("u", "f", "/p")
        # Ownership check: another user must not see the row.
        assert attachments.get(row["id"], "other") is None

    def test_get_by_legacy_id(self, pg_conn):
        attachments = _repo(pg_conn)
        row = attachments.create(
            "u",
            "f",
            "/p",
            legacy_mongo_id="507f1f77bcf86cd799439011",
        )
        found = attachments.get_by_legacy_id("507f1f77bcf86cd799439011", "u")
        assert found["id"] == row["id"]
|
||||
|
||||
|
||||
class TestListForUser:
    """Per-user listing of attachments."""

    def test_lists_only_own_attachments(self, pg_conn):
        attachments = _repo(pg_conn)
        attachments.create("alice", "a1.pdf", "/a1")
        attachments.create("alice", "a2.pdf", "/a2")
        attachments.create("bob", "b1.pdf", "/b1")
        rows = attachments.list_for_user("alice")
        assert len(rows) == 2
        assert {row["user_id"] for row in rows} == {"alice"}

    def test_list_empty_for_unknown_user(self, pg_conn):
        assert _repo(pg_conn).list_for_user("nonexistent") == []
|
||||
94
tests/storage/db/repositories/test_connector_sessions.py
Normal file
94
tests/storage/db/repositories/test_connector_sessions.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Tests for ConnectorSessionsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.connector_sessions import ConnectorSessionsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> ConnectorSessionsRepository:
    """Build a ConnectorSessionsRepository bound to the given connection."""
    return ConnectorSessionsRepository(conn)
|
||||
|
||||
|
||||
class TestUpsert:
    """Session upsert semantics for connector providers."""

    def test_creates_session(self, pg_conn):
        sessions = _repo(pg_conn)
        row = sessions.upsert("user-1", "google", {"token": "abc123"})
        assert row["id"] is not None
        assert row["user_id"] == "user-1"
        assert row["provider"] == "google"
        assert row["session_data"] == {"token": "abc123"}

    def test_upsert_creates_second_session(self, pg_conn):
        sessions = _repo(pg_conn)
        first = sessions.upsert("user-1", "google", {"token": "v1"})
        assert first["session_data"] == {"token": "v1"}
        # Without a UNIQUE(user_id, provider) constraint, a second upsert
        # creates another row (ON CONFLICT DO NOTHING never fires).
        second = sessions.upsert("user-1", "google", {"token": "v2"})
        assert second["session_data"] == {"token": "v2"}
|
||||
|
||||
|
||||
class TestGetByUserProvider:
    """Lookup by the (user, provider) pair."""

    def test_finds_existing(self, pg_conn):
        sessions = _repo(pg_conn)
        sessions.upsert("u", "slack", {"key": "val"})
        row = sessions.get_by_user_provider("u", "slack")
        assert row is not None
        assert row["session_data"] == {"key": "val"}

    def test_returns_none_for_missing(self, pg_conn):
        assert _repo(pg_conn).get_by_user_provider("u", "nonexistent") is None

    def test_different_providers_are_separate(self, pg_conn):
        sessions = _repo(pg_conn)
        sessions.upsert("u", "google", {"g": 1})
        sessions.upsert("u", "slack", {"s": 2})
        # Each provider keeps its own payload.
        assert sessions.get_by_user_provider("u", "google")["session_data"] == {"g": 1}
        assert sessions.get_by_user_provider("u", "slack")["session_data"] == {"s": 2}
|
||||
|
||||
|
||||
class TestListForUser:
    """Per-user listing of connector sessions."""

    def test_lists_all_providers(self, pg_conn):
        sessions = _repo(pg_conn)
        sessions.upsert("alice", "google", {"g": 1})
        sessions.upsert("alice", "slack", {"s": 1})
        sessions.upsert("bob", "google", {"g": 2})
        rows = sessions.list_for_user("alice")
        assert len(rows) == 2
        assert {row["user_id"] for row in rows} == {"alice"}

    def test_list_empty_for_unknown_user(self, pg_conn):
        assert _repo(pg_conn).list_for_user("nonexistent") == []
|
||||
|
||||
|
||||
class TestDelete:
    """Session deletion by (user, provider)."""

    def test_deletes_session(self, pg_conn):
        sessions = _repo(pg_conn)
        sessions.upsert("u", "google", {"t": 1})
        assert sessions.delete("u", "google") is True
        assert sessions.get_by_user_provider("u", "google") is None

    def test_delete_nonexistent_returns_false(self, pg_conn):
        assert _repo(pg_conn).delete("u", "nonexistent") is False

    def test_delete_one_provider_leaves_others(self, pg_conn):
        sessions = _repo(pg_conn)
        sessions.upsert("u", "google", {"g": 1})
        sessions.upsert("u", "slack", {"s": 1})
        sessions.delete("u", "google")
        # Deletion is provider-scoped, not user-wide.
        assert sessions.get_by_user_provider("u", "google") is None
        assert sessions.get_by_user_provider("u", "slack") is not None
|
||||
374
tests/storage/db/repositories/test_conversations.py
Normal file
374
tests/storage/db/repositories/test_conversations.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Tests for ConversationsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> ConversationsRepository:
    """Build a ConversationsRepository bound to the given connection."""
    return ConversationsRepository(conn)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conversation CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreate:
    """Conversation creation, with and without an attached agent."""

    def test_creates_conversation(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "My Chat")
        assert row["id"] is not None
        assert row["user_id"] == "user-1"
        assert row["name"] == "My Chat"
        # Mongo-compat alias: "_id" mirrors "id".
        assert row["_id"] == row["id"]

    def test_create_with_agent(self, pg_conn):
        from application.storage.db.repositories.agents import AgentsRepository

        agent = AgentsRepository(pg_conn).create("user-1", "a", "active")
        row = _repo(pg_conn).create(
            "user-1", "Chat",
            agent_id=agent["id"],
            api_key="ak-123",
            is_shared_usage=True,
            shared_token="tok-abc",
        )
        assert str(row["agent_id"]) == agent["id"]
        assert row["api_key"] == "ak-123"
        assert row["is_shared_usage"] is True
        assert row["shared_token"] == "tok-abc"
|
||||
|
||||
|
||||
class TestGet:
    """Fetching a conversation is scoped to its owner."""

    def test_get_owned(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "c")
        assert conversations.get(row["id"], "user-1")["id"] == row["id"]

    def test_get_nonexistent(self, pg_conn):
        missing = "00000000-0000-0000-0000-000000000000"
        assert _repo(pg_conn).get(missing, "u") is None

    def test_get_wrong_user(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "c")
        assert conversations.get(row["id"], "user-other") is None
|
||||
|
||||
|
||||
class TestListForUser:
    """Per-user conversation listing."""

    def test_lists_own_conversations(self, pg_conn):
        conversations = _repo(pg_conn)
        conversations.create("alice", "c1")
        conversations.create("alice", "c2")
        conversations.create("bob", "c3")
        rows = conversations.list_for_user("alice")
        assert len(rows) == 2
        assert {row["user_id"] for row in rows} == {"alice"}

    def test_excludes_api_key_without_agent(self, pg_conn):
        conversations = _repo(pg_conn)
        conversations.create("alice", "normal")
        # API-key conversations without an agent are hidden from listings.
        conversations.create("alice", "api-only", api_key="key-1")
        rows = conversations.list_for_user("alice")
        assert [row["name"] for row in rows] == ["normal"]
|
||||
|
||||
|
||||
class TestRename:
    """Renaming conversations respects ownership."""

    def test_renames(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "old")
        assert conversations.rename(row["id"], "user-1", "new") is True
        assert conversations.get(row["id"], "user-1")["name"] == "new"

    def test_rename_wrong_user(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "old")
        assert conversations.rename(row["id"], "user-other", "new") is False
|
||||
|
||||
|
||||
class TestDelete:
    """Conversation deletion: ownership, message cascade, and bulk delete."""

    def test_deletes(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "c")
        assert conversations.delete(row["id"], "user-1") is True
        assert conversations.get(row["id"], "user-1") is None

    def test_delete_wrong_user(self, pg_conn):
        conversations = _repo(pg_conn)
        row = conversations.create("user-1", "c")
        assert conversations.delete(row["id"], "user-other") is False

    def test_delete_cascades_messages(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "hi", "response": "hello"})
        conversations.delete(conv["id"], "user-1")
        # Child messages must go with the parent row.
        assert conversations.get_messages(conv["id"]) == []

    def test_delete_all_for_user(self, pg_conn):
        conversations = _repo(pg_conn)
        conversations.create("user-1", "c1")
        conversations.create("user-1", "c2")
        conversations.create("user-2", "c3")
        assert conversations.delete_all_for_user("user-1") == 2
        assert conversations.list_for_user("user-1") == []
        assert len(conversations.list_for_user("user-2")) == 1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Messages
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppendMessage:
    """Appending messages assigns sequential positions and persists payloads."""

    def test_append_first_message(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        payload = {"prompt": "hello", "response": "hi there", "model_id": "gpt-4"}
        msg = conversations.append_message(conv["id"], payload)
        assert msg["position"] == 0
        assert msg["prompt"] == "hello"
        assert msg["response"] == "hi there"
        assert msg["model_id"] == "gpt-4"

    def test_append_increments_position(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        first = conversations.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
        second = conversations.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
        assert (first["position"], second["position"]) == (0, 1)

    def test_append_with_sources_and_tools(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        payload = {
            "prompt": "q",
            "response": "a",
            "sources": [{"title": "doc1"}],
            "tool_calls": [{"name": "search", "args": {}}],
            "metadata": {"search_query": "rewritten"},
        }
        msg = conversations.append_message(conv["id"], payload)
        # JSON payloads round-trip unchanged.
        assert msg["sources"] == [{"title": "doc1"}]
        assert msg["tool_calls"] == [{"name": "search", "args": {}}]
        assert msg["metadata"] == {"search_query": "rewritten"}

    def test_append_preserves_explicit_timestamp(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        ts = datetime.now(timezone.utc)
        msg = conversations.append_message(
            conv["id"], {"prompt": "q", "response": "a", "timestamp": ts}
        )
        assert msg["timestamp"] == ts
|
||||
|
||||
|
||||
class TestGetMessages:
    """Reading messages back in position order."""

    def test_returns_ordered_messages(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
        conversations.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
        msgs = conversations.get_messages(conv["id"])
        assert [m["position"] for m in msgs] == [0, 1]

    def test_get_message_at(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q1", "response": "a1"})
        conversations.append_message(conv["id"], {"prompt": "q2", "response": "a2"})
        assert conversations.get_message_at(conv["id"], 1)["prompt"] == "q2"

    def test_get_message_at_nonexistent(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        assert conversations.get_message_at(conv["id"], 99) is None
|
||||
|
||||
|
||||
class TestUpdateMessageAt:
    """In-place update of a message identified by its position."""

    def test_updates_response(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q", "response": "old"})
        assert conversations.update_message_at(conv["id"], 0, {"response": "new"}) is True
        assert conversations.get_message_at(conv["id"], 0)["response"] == "new"

    def test_update_disallowed_field(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q", "response": "a"})
        # "id" is not an updatable column; the call must be rejected.
        assert conversations.update_message_at(conv["id"], 0, {"id": "bad"}) is False

    def test_updates_explicit_timestamp(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q", "response": "old"})
        ts = datetime.now(timezone.utc)
        ok = conversations.update_message_at(
            conv["id"], 0, {"response": "new", "timestamp": ts},
        )
        assert ok is True
        msg = conversations.get_message_at(conv["id"], 0)
        assert msg["response"] == "new"
        assert msg["timestamp"] == ts
|
||||
|
||||
|
||||
class TestTruncateAfter:
    """Dropping all messages past a given position."""

    def test_truncates_messages(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        for i in range(5):
            conversations.append_message(conv["id"], {"prompt": f"q{i}", "response": f"a{i}"})
        # Keep positions 0..2, drop the two after them.
        assert conversations.truncate_after(conv["id"], 2) == 2
        remaining = conversations.get_messages(conv["id"])
        assert [m["position"] for m in remaining] == [0, 1, 2]
|
||||
|
||||
|
||||
class TestSetFeedback:
    """Setting and clearing feedback on a message position."""

    def test_set_feedback(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q", "response": "a"})
        assert conversations.set_feedback(conv["id"], 0, {"text": "thumbs_up"}) is True
        assert conversations.get_message_at(conv["id"], 0)["feedback"] == {"text": "thumbs_up"}

    def test_unset_feedback(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.append_message(conv["id"], {"prompt": "q", "response": "a"})
        conversations.set_feedback(conv["id"], 0, {"text": "thumbs_up"})
        # Passing None clears the feedback value.
        assert conversations.set_feedback(conv["id"], 0, None) is True
        assert conversations.get_message_at(conv["id"], 0)["feedback"] is None

    def test_set_feedback_nonexistent_position(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        assert conversations.set_feedback(conv["id"], 99, {"text": "x"}) is False
|
||||
|
||||
|
||||
class TestMessageCount:
    """Counting messages in a conversation."""

    def test_counts_messages(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        assert conversations.message_count(conv["id"]) == 0
        conversations.append_message(conv["id"], {"prompt": "q", "response": "a"})
        assert conversations.message_count(conv["id"]) == 1
|
||||
|
||||
|
||||
class TestCompressionMetadata:
    """Compression-metadata helpers on conversations."""

    def test_set_compression_metadata(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        meta = {"is_compressed": True, "last_compression_at": "2026-01-01T00:00:00Z"}
        assert conversations.update_compression_metadata(conv["id"], "user-1", meta) is True
        fetched = conversations.get(conv["id"], "user-1")
        assert fetched["compression_metadata"]["is_compressed"] is True

    def test_set_compression_flags_preserves_points(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        conversations.update_compression_metadata(conv["id"], "user-1", {
            "is_compressed": False,
            "compression_points": [{"summary": "earlier"}],
        })
        # Flipping the flags must not clobber previously stored points.
        ok = conversations.set_compression_flags(
            conv["id"], is_compressed=True, last_compression_at="2026-01-02",
        )
        assert ok is True
        meta = conversations.get(conv["id"], "user-1")["compression_metadata"]
        assert meta["is_compressed"] is True
        assert meta["last_compression_at"] == "2026-01-02"
        assert meta["compression_points"] == [{"summary": "earlier"}]

    def test_append_compression_point_slices_to_max(self, pg_conn):
        conversations = _repo(pg_conn)
        conv = conversations.create("user-1", "c")
        for i in range(5):
            ok = conversations.append_compression_point(
                conv["id"], {"summary": f"p{i}"}, max_points=3,
            )
            assert ok is True
        # Only the most recent max_points entries survive.
        points = conversations.get(conv["id"], "user-1")["compression_metadata"]["compression_points"]
        assert [p["summary"] for p in points] == ["p2", "p3", "p4"]
|
||||
|
||||
|
||||
class TestConcurrentAppend:
    """Two threads appending to the same conversation must not race on
    ``position``. The plan (migration-postgres.md §Phase 3) explicitly
    calls this out as the single trickiest invariant, so we exercise it
    directly with two parallel connections."""

    def test_concurrent_appends_get_distinct_positions(self, pg_engine, pg_conn):
        import threading

        # Arrange — one conversation, created inside the outer test txn so
        # it disappears on teardown even if the workers somehow commit.
        # We commit it explicitly so the workers' separate sessions see it.
        repo_setup = _repo(pg_conn)
        conv = repo_setup.create("user-concurrent", "c")
        pg_conn.commit()

        try:
            # Collected exceptions from worker threads; asserting on this
            # list surfaces worker failures in the main thread.
            errors: list[BaseException] = []

            def worker() -> None:
                # Each worker opens its own engine-level transaction so the
                # two appends truly contend on separate connections.
                try:
                    with pg_engine.begin() as worker_conn:
                        ConversationsRepository(worker_conn).append_message(
                            conv["id"], {"prompt": "q", "response": "a"},
                        )
                except BaseException as e:  # noqa: BLE001
                    errors.append(e)

            threads = [threading.Thread(target=worker) for _ in range(2)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            assert errors == [], f"worker threads errored: {errors}"

            # Assert — the parent row-lock in append_message must have
            # serialised the two inserts so they land at positions {0, 1}.
            with pg_engine.connect() as verify_conn:
                msgs = ConversationsRepository(verify_conn).get_messages(conv["id"])
                positions = sorted(m["position"] for m in msgs)
                assert positions == [0, 1], (
                    f"concurrent appends raced; got positions {positions}"
                )
        finally:
            # Clean up — the conversation was committed, so the transaction
            # rollback won't drop it.
            with pg_engine.begin() as cleanup_conn:
                ConversationsRepository(cleanup_conn).delete(
                    conv["id"], "user-concurrent"
                )
|
||||
135
tests/storage/db/repositories/test_memories.py
Normal file
135
tests/storage/db/repositories/test_memories.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Tests for MemoriesRepository against a real Postgres instance.
|
||||
|
||||
Memories have a FK to user_tools, so each test creates a tool row first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.memories import MemoriesRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> MemoriesRepository:
    """Build a MemoriesRepository bound to the given connection."""
    return MemoriesRepository(conn)
|
||||
|
||||
|
||||
def _make_tool(conn, user_id: str = "test-user", name: str = "mem-tool") -> str:
    """Insert a user_tools row and return its UUID as a string."""
    stmt = text("INSERT INTO user_tools (user_id, name) VALUES (:uid, :name) RETURNING id")
    new_id = conn.execute(stmt, {"uid": user_id, "name": name}).scalar()
    return str(new_id)
|
||||
|
||||
|
||||
class TestUpsert:
    """Memory upsert keyed on (user, tool, path)."""

    def test_creates_memory(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        row = memories.upsert("test-user", tool_id, "/docs/readme.md", "Hello world")
        assert row["id"] is not None
        assert row["path"] == "/docs/readme.md"
        assert row["content"] == "Hello world"

    def test_upsert_overwrites_content(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("test-user", tool_id, "/a.txt", "v1")
        # A second upsert on the same path replaces the content.
        assert memories.upsert("test-user", tool_id, "/a.txt", "v2")["content"] == "v2"

    def test_upsert_is_idempotent_on_same_content(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        first = memories.upsert("test-user", tool_id, "/a.txt", "same")
        second = memories.upsert("test-user", tool_id, "/a.txt", "same")
        # Re-upserting identical content reuses the same row.
        assert first["id"] == second["id"]
|
||||
|
||||
|
||||
class TestGetByPath:
    """Point lookup of a memory by its path."""

    def test_finds_existing(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("u", tool_id, "/x", "content")
        row = memories.get_by_path("u", tool_id, "/x")
        assert row is not None
        assert row["content"] == "content"

    def test_returns_none_for_missing(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        assert memories.get_by_path("u", tool_id, "/nonexistent") is None
|
||||
|
||||
|
||||
class TestListByPrefix:
    """Prefix-based listing of memory paths."""

    def test_lists_matching_prefix(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("u", tool_id, "/docs/a.md", "a")
        memories.upsert("u", tool_id, "/docs/b.md", "b")
        memories.upsert("u", tool_id, "/other/c.md", "c")
        rows = memories.list_by_prefix("u", tool_id, "/docs/")
        assert len(rows) == 2
        assert {row["path"] for row in rows} == {"/docs/a.md", "/docs/b.md"}
|
||||
|
||||
|
||||
class TestDeleteByPath:
    """Deleting a single memory by exact path."""

    def test_deletes_single(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("u", tool_id, "/x", "c")
        assert memories.delete_by_path("u", tool_id, "/x") == 1
        assert memories.get_by_path("u", tool_id, "/x") is None

    def test_delete_nonexistent_returns_zero(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        assert memories.delete_by_path("u", tool_id, "/nope") == 0
|
||||
|
||||
|
||||
class TestDeleteByPrefix:
    """Bulk delete of memories sharing a path prefix."""

    def test_deletes_matching_prefix(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("u", tool_id, "/dir/a", "a")
        memories.upsert("u", tool_id, "/dir/b", "b")
        memories.upsert("u", tool_id, "/other/c", "c")
        assert memories.delete_by_prefix("u", tool_id, "/dir/") == 2
        # Paths outside the prefix are untouched.
        assert memories.get_by_path("u", tool_id, "/other/c") is not None
|
||||
|
||||
|
||||
class TestDeleteAll:
    """Wiping every memory for a (user, tool) pair."""

    def test_deletes_all_for_user_tool(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("u", tool_id, "/a", "a")
        memories.upsert("u", tool_id, "/b", "b")
        assert memories.delete_all("u", tool_id) == 2
        assert memories.list_by_prefix("u", tool_id, "/") == []
|
||||
|
||||
|
||||
class TestUpdatePath:
    """Renaming a memory's path."""

    def test_renames_path(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        memories.upsert("u", tool_id, "/old.txt", "content")
        assert memories.update_path("u", tool_id, "/old.txt", "/new.txt") is True
        # Content moves with the row; the old path is gone.
        assert memories.get_by_path("u", tool_id, "/old.txt") is None
        assert memories.get_by_path("u", tool_id, "/new.txt")["content"] == "content"

    def test_rename_nonexistent_returns_false(self, pg_conn):
        memories = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        assert memories.update_path("u", tool_id, "/nope", "/new") is False
|
||||
100
tests/storage/db/repositories/test_notes.py
Normal file
100
tests/storage/db/repositories/test_notes.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Tests for NotesRepository against a real Postgres instance.
|
||||
|
||||
Notes have a FK to user_tools, so each test creates a tool row first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.notes import NotesRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> NotesRepository:
    """Build a NotesRepository bound to the given connection."""
    return NotesRepository(conn)
|
||||
|
||||
|
||||
def _make_tool(conn, user_id: str = "test-user", name: str = "notes-tool") -> str:
    """Insert a user_tools row and return its UUID as a string."""
    stmt = text("INSERT INTO user_tools (user_id, name) VALUES (:uid, :name) RETURNING id")
    new_id = conn.execute(stmt, {"uid": user_id, "name": name}).scalar()
    return str(new_id)
|
||||
|
||||
|
||||
class TestUpsert:
    """Note upsert for a (user, tool) pair."""

    def test_creates_note(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        row = notes.upsert("test-user", tool_id, "My Note", "Some content")
        assert row["id"] is not None
        assert row["title"] == "My Note"
        assert row["content"] == "Some content"

    def test_second_upsert_also_returns_content(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        first = notes.upsert("test-user", tool_id, "title", "v1")
        assert first["content"] == "v1"
        # A second upsert for the same (user, tool) creates a new note
        # (no unique constraint on (user_id, tool_id) exists).
        second = notes.upsert("test-user", tool_id, "title2", "v2")
        assert second["content"] == "v2"
|
||||
|
||||
|
||||
class TestGetForUserTool:
    """Lookup of a note by its (user, tool) pair."""

    def test_returns_note(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        notes.upsert("u", tool_id, "t", "c")
        row = notes.get_for_user_tool("u", tool_id)
        assert row is not None
        assert row["content"] == "c"

    def test_returns_none_when_missing(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        assert notes.get_for_user_tool("u", tool_id) is None
|
||||
|
||||
|
||||
class TestGetById:
    """Fetching a note by id is scoped to its owner."""

    def test_get_existing(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        row = notes.upsert("u", tool_id, "t", "c")
        assert notes.get(row["id"], "u")["id"] == row["id"]

    def test_get_nonexistent_returns_none(self, pg_conn):
        missing = "00000000-0000-0000-0000-000000000000"
        assert _repo(pg_conn).get(missing, "u") is None

    def test_get_wrong_user_returns_none(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        row = notes.upsert("u", tool_id, "t", "c")
        assert notes.get(row["id"], "other") is None
|
||||
|
||||
|
||||
class TestDelete:
    """Deleting the note for a (user, tool) pair."""

    def test_deletes_note(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        notes.upsert("u", tool_id, "t", "c")
        assert notes.delete("u", tool_id) is True
        assert notes.get_for_user_tool("u", tool_id) is None

    def test_delete_nonexistent_returns_false(self, pg_conn):
        notes = _repo(pg_conn)
        tool_id = _make_tool(pg_conn)
        assert notes.delete("u", tool_id) is False
|
||||
99
tests/storage/db/repositories/test_pending_tool_state.py
Normal file
99
tests/storage/db/repositories/test_pending_tool_state.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for PendingToolStateRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import PendingToolStateRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _conv(conn) -> dict:
|
||||
return ConversationsRepository(conn).create("user-1", "test conv")
|
||||
|
||||
|
||||
def _repo(conn) -> PendingToolStateRepository:
|
||||
return PendingToolStateRepository(conn)
|
||||
|
||||
|
||||
def _sample_state() -> dict:
|
||||
return {
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"pending_tool_calls": [{"id": "tc-1", "name": "search"}],
|
||||
"tools_dict": {"search": {"type": "function"}},
|
||||
"tool_schemas": [{"name": "search"}],
|
||||
"agent_config": {"model_id": "gpt-4", "llm_name": "openai"},
|
||||
}
|
||||
|
||||
|
||||
class TestSaveState:
|
||||
def test_creates_state(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
state = _sample_state()
|
||||
doc = repo.save_state(conv["id"], "user-1", **state)
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["messages"] == state["messages"]
|
||||
assert doc["pending_tool_calls"] == state["pending_tool_calls"]
|
||||
assert doc["expires_at"] is not None
|
||||
|
||||
def test_upsert_replaces_existing(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
state = _sample_state()
|
||||
repo.save_state(conv["id"], "user-1", **state)
|
||||
state["messages"] = [{"role": "user", "content": "updated"}]
|
||||
doc2 = repo.save_state(conv["id"], "user-1", **state)
|
||||
# Same row, updated content
|
||||
assert doc2["messages"] == [{"role": "user", "content": "updated"}]
|
||||
|
||||
def test_save_with_client_tools(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
state = _sample_state()
|
||||
state["client_tools"] = [{"name": "browser"}]
|
||||
doc = repo.save_state(conv["id"], "user-1", **state)
|
||||
assert doc["client_tools"] == [{"name": "browser"}]
|
||||
|
||||
|
||||
class TestLoadState:
|
||||
def test_loads_existing(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
repo.save_state(conv["id"], "user-1", **_sample_state())
|
||||
loaded = repo.load_state(conv["id"], "user-1")
|
||||
assert loaded is not None
|
||||
assert loaded["agent_config"]["model_id"] == "gpt-4"
|
||||
|
||||
def test_load_nonexistent(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.load_state("00000000-0000-0000-0000-000000000000", "u") is None
|
||||
|
||||
|
||||
class TestDeleteState:
|
||||
def test_deletes(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
repo.save_state(conv["id"], "user-1", **_sample_state())
|
||||
assert repo.delete_state(conv["id"], "user-1") is True
|
||||
assert repo.load_state(conv["id"], "user-1") is None
|
||||
|
||||
def test_delete_nonexistent(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.delete_state("00000000-0000-0000-0000-000000000000", "u") is False
|
||||
|
||||
|
||||
class TestCleanupExpired:
|
||||
def test_cleanup_removes_expired(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
# Create a state with TTL of 0 seconds (already expired)
|
||||
repo.save_state(conv["id"], "user-1", **_sample_state(), ttl_seconds=0)
|
||||
deleted = repo.cleanup_expired()
|
||||
assert deleted >= 1
|
||||
assert repo.load_state(conv["id"], "user-1") is None
|
||||
167
tests/storage/db/repositories/test_prompts.py
Normal file
167
tests/storage/db/repositories/test_prompts.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for PromptsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> PromptsRepository:
|
||||
return PromptsRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_prompt(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "greeting", "Hello {{name}}")
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["name"] == "greeting"
|
||||
assert doc["content"] == "Hello {{name}}"
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_create_returns_id_and_underscore_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "p", "c")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
def test_create_with_legacy_mongo_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("user-1", "p", "c", legacy_mongo_id="507f1f77bcf86cd799439011")
|
||||
assert doc["legacy_mongo_id"] == "507f1f77bcf86cd799439011"
|
||||
|
||||
|
||||
class TestGet:
|
||||
def test_get_by_id_and_user(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_get_wrong_user_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
assert repo.get(created["id"], "user-other") is None
|
||||
|
||||
def test_get_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
|
||||
|
||||
def test_get_by_legacy_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create(
|
||||
"user-1",
|
||||
"p",
|
||||
"c",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
|
||||
class TestGetForRendering:
|
||||
def test_returns_prompt_without_user_scoping(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
fetched = repo.get_for_rendering(created["id"])
|
||||
assert fetched is not None
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get_for_rendering("00000000-0000-0000-0000-000000000000") is None
|
||||
|
||||
|
||||
class TestListForUser:
|
||||
def test_lists_only_own_prompts(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create("alice", "a1", "c1")
|
||||
repo.create("alice", "a2", "c2")
|
||||
repo.create("bob", "b1", "c3")
|
||||
results = repo.list_for_user("alice")
|
||||
assert len(results) == 2
|
||||
assert all(r["user_id"] == "alice" for r in results)
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
def test_updates_name_and_content(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "old-content")
|
||||
repo.update(created["id"], "user-1", "new", "new-content")
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["name"] == "new"
|
||||
assert fetched["content"] == "new-content"
|
||||
|
||||
def test_update_wrong_user_is_noop(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "old", "old-content")
|
||||
repo.update(created["id"], "user-other", "new", "new-content")
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["name"] == "old"
|
||||
|
||||
def test_update_by_legacy_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create(
|
||||
"user-1",
|
||||
"old",
|
||||
"old-content",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
assert repo.update_by_legacy_id(
|
||||
"507f1f77bcf86cd799439011",
|
||||
"user-1",
|
||||
"new",
|
||||
"new-content",
|
||||
) is True
|
||||
fetched = repo.get_by_legacy_id("507f1f77bcf86cd799439011", "user-1")
|
||||
assert fetched["name"] == "new"
|
||||
assert fetched["content"] == "new-content"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_deletes_prompt(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
repo.delete(created["id"], "user-1")
|
||||
assert repo.get(created["id"], "user-1") is None
|
||||
|
||||
def test_delete_wrong_user_is_noop(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("user-1", "p", "c")
|
||||
repo.delete(created["id"], "user-other")
|
||||
assert repo.get(created["id"], "user-1") is not None
|
||||
|
||||
def test_delete_by_legacy_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create(
|
||||
"user-1",
|
||||
"p",
|
||||
"c",
|
||||
legacy_mongo_id="507f1f77bcf86cd799439011",
|
||||
)
|
||||
assert repo.delete_by_legacy_id("507f1f77bcf86cd799439011", "user-1") is True
|
||||
assert repo.get(created["id"], "user-1") is None
|
||||
|
||||
|
||||
class TestFindOrCreate:
|
||||
def test_creates_when_missing(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.find_or_create("sys", "template", "content")
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_returns_existing_on_match(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
first = repo.find_or_create("sys", "template", "content")
|
||||
second = repo.find_or_create("sys", "template", "content")
|
||||
assert first["id"] == second["id"]
|
||||
|
||||
def test_different_content_creates_new(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
first = repo.find_or_create("sys", "template", "v1")
|
||||
second = repo.find_or_create("sys", "template", "v2")
|
||||
assert first["id"] != second["id"]
|
||||
91
tests/storage/db/repositories/test_shared_conversations.py
Normal file
91
tests/storage/db/repositories/test_shared_conversations.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for SharedConversationsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.shared_conversations import SharedConversationsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _conv(conn) -> dict:
|
||||
return ConversationsRepository(conn).create("user-1", "test conv")
|
||||
|
||||
|
||||
def _repo(conn) -> SharedConversationsRepository:
|
||||
return SharedConversationsRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_share(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
share = repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
|
||||
assert share["conversation_id"] is not None
|
||||
assert share["user_id"] == "user-1"
|
||||
assert share["is_promptable"] is False
|
||||
assert share["first_n_queries"] == 3
|
||||
assert share["uuid"] is not None
|
||||
|
||||
def test_create_promptable_with_api_key(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
share = repo.create(
|
||||
conv["id"], "user-1",
|
||||
is_promptable=True,
|
||||
first_n_queries=5,
|
||||
api_key="ak-prompt",
|
||||
)
|
||||
assert share["is_promptable"] is True
|
||||
assert share["api_key"] == "ak-prompt"
|
||||
|
||||
|
||||
class TestFindByUuid:
|
||||
def test_finds_by_uuid(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
share = repo.create(conv["id"], "user-1", first_n_queries=2)
|
||||
found = repo.find_by_uuid(str(share["uuid"]))
|
||||
assert found["id"] == share["id"]
|
||||
|
||||
def test_not_found(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.find_by_uuid("00000000-0000-0000-0000-000000000000") is None
|
||||
|
||||
|
||||
class TestFindExisting:
|
||||
def test_finds_matching_share(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
|
||||
found = repo.find_existing(conv["id"], "user-1", False, 3)
|
||||
assert found is not None
|
||||
assert found["first_n_queries"] == 3
|
||||
|
||||
def test_no_match_different_params(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
repo.create(conv["id"], "user-1", is_promptable=False, first_n_queries=3)
|
||||
assert repo.find_existing(conv["id"], "user-1", True, 3) is None
|
||||
|
||||
def test_finds_with_api_key(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
repo.create(conv["id"], "user-1", is_promptable=True, first_n_queries=5, api_key="ak-1")
|
||||
found = repo.find_existing(conv["id"], "user-1", True, 5, api_key="ak-1")
|
||||
assert found is not None
|
||||
|
||||
|
||||
class TestListForConversation:
|
||||
def test_lists_shares(self, pg_conn):
|
||||
conv = _conv(pg_conn)
|
||||
repo = _repo(pg_conn)
|
||||
repo.create(conv["id"], "user-1", first_n_queries=1)
|
||||
repo.create(conv["id"], "user-1", first_n_queries=2)
|
||||
results = repo.list_for_conversation(conv["id"])
|
||||
assert len(results) == 2
|
||||
115
tests/storage/db/repositories/test_sources.py
Normal file
115
tests/storage/db/repositories/test_sources.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Tests for SourcesRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> SourcesRepository:
|
||||
return SourcesRepository(conn)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_source_with_user(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("my-source", user_id="user-1", type="url")
|
||||
assert doc["user_id"] == "user-1"
|
||||
assert doc["name"] == "my-source"
|
||||
assert doc["type"] == "url"
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_creates_source_with_metadata(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("src", user_id="u", metadata={"url": "https://example.com"})
|
||||
assert doc["metadata"] == {"url": "https://example.com"}
|
||||
|
||||
def test_create_returns_id_and_underscore_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
doc = repo.create("s", user_id="u")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
|
||||
class TestGet:
|
||||
def test_get_existing(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("s", user_id="user-1")
|
||||
fetched = repo.get(created["id"], "user-1")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_get_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get("00000000-0000-0000-0000-000000000000", "user-1") is None
|
||||
|
||||
def test_get_wrong_user_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("s", user_id="user-1")
|
||||
assert repo.get(created["id"], "user-other") is None
|
||||
|
||||
|
||||
class TestListForUser:
|
||||
def test_lists_only_own_sources(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.create("s1", user_id="alice")
|
||||
repo.create("s2", user_id="alice")
|
||||
repo.create("s3", user_id="bob")
|
||||
results = repo.list_for_user("alice")
|
||||
assert len(results) == 2
|
||||
assert all(r["user_id"] == "alice" for r in results)
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
def test_updates_name(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("old", user_id="u")
|
||||
repo.update(created["id"], "u", {"name": "new"})
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["name"] == "new"
|
||||
|
||||
def test_updates_metadata(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("s", user_id="u", metadata={"a": 1})
|
||||
repo.update(created["id"], "u", {"metadata": {"a": 2, "b": 3}})
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["metadata"] == {"a": 2, "b": 3}
|
||||
|
||||
def test_update_disallowed_field_is_noop(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("s", user_id="u")
|
||||
repo.update(created["id"], "u", {"id": "00000000-0000-0000-0000-000000000000"})
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_update_wrong_user_is_noop(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("old", user_id="u")
|
||||
repo.update(created["id"], "other-user", {"name": "new"})
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["name"] == "old"
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_deletes_source(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("s", user_id="u")
|
||||
deleted = repo.delete(created["id"], "u")
|
||||
assert deleted is True
|
||||
assert repo.get(created["id"], "u") is None
|
||||
|
||||
def test_delete_nonexistent_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
deleted = repo.delete("00000000-0000-0000-0000-000000000000", "u")
|
||||
assert deleted is False
|
||||
|
||||
def test_delete_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
created = repo.create("s", user_id="u")
|
||||
deleted = repo.delete(created["id"], "other-user")
|
||||
assert deleted is False
|
||||
assert repo.get(created["id"], "u") is not None
|
||||
58
tests/storage/db/repositories/test_stack_logs.py
Normal file
58
tests/storage/db/repositories/test_stack_logs.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for StackLogsRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> StackLogsRepository:
|
||||
return StackLogsRepository(conn)
|
||||
|
||||
|
||||
class TestInsert:
|
||||
def test_inserts_log(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(
|
||||
activity_id="act-1",
|
||||
endpoint="/api/answer",
|
||||
level="info",
|
||||
user_id="u1",
|
||||
api_key="k1",
|
||||
query="what is python?",
|
||||
stacks=[{"component": "retriever", "data": {"docs": 3}}],
|
||||
)
|
||||
row = pg_conn.execute(
|
||||
text("SELECT * FROM stack_logs WHERE activity_id = 'act-1'")
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
mapping = dict(row._mapping)
|
||||
assert mapping["endpoint"] == "/api/answer"
|
||||
assert mapping["level"] == "info"
|
||||
assert mapping["user_id"] == "u1"
|
||||
assert mapping["stacks"] == [{"component": "retriever", "data": {"docs": 3}}]
|
||||
|
||||
def test_inserts_with_empty_stacks(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(activity_id="act-2", level="error")
|
||||
row = pg_conn.execute(
|
||||
text("SELECT stacks FROM stack_logs WHERE activity_id = 'act-2'")
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
assert dict(row._mapping)["stacks"] == []
|
||||
|
||||
def test_truncated_query_stored(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
long_query = "x" * 20000
|
||||
repo.insert(activity_id="act-3", query=long_query)
|
||||
row = pg_conn.execute(
|
||||
text("SELECT query FROM stack_logs WHERE activity_id = 'act-3'")
|
||||
).fetchone()
|
||||
assert len(dict(row._mapping)["query"]) == 20000
|
||||
158
tests/storage/db/repositories/test_todos.py
Normal file
158
tests/storage/db/repositories/test_todos.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for TodosRepository against a real Postgres instance.
|
||||
|
||||
Todos have a FK to user_tools, so each test creates a tool row first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.repositories.todos import TodosRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> TodosRepository:
|
||||
return TodosRepository(conn)
|
||||
|
||||
|
||||
def _make_tool(conn, user_id: str = "test-user", name: str = "todo-tool") -> str:
|
||||
"""Insert a user_tools row and return its UUID as a string."""
|
||||
return str(
|
||||
conn.execute(
|
||||
text("INSERT INTO user_tools (user_id, name) VALUES (:uid, :name) RETURNING id"),
|
||||
{"uid": user_id, "name": name},
|
||||
).scalar()
|
||||
)
|
||||
|
||||
|
||||
class TestCreate:
|
||||
def test_creates_todo(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
doc = repo.create("test-user", tool_id, "Buy milk")
|
||||
assert doc["title"] == "Buy milk"
|
||||
assert doc["completed"] is False
|
||||
assert doc["id"] is not None
|
||||
|
||||
def test_create_returns_id_and_underscore_id(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
doc = repo.create("test-user", tool_id, "t")
|
||||
assert doc["_id"] == doc["id"]
|
||||
|
||||
|
||||
class TestGet:
|
||||
def test_get_existing(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["id"] == created["id"]
|
||||
|
||||
def test_get_nonexistent_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.get("00000000-0000-0000-0000-000000000000", "u") is None
|
||||
|
||||
def test_get_wrong_user_returns_none(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
assert repo.get(created["id"], "other") is None
|
||||
|
||||
|
||||
class TestListForUserTool:
|
||||
def test_lists_todos_for_user_tool(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
repo.create("u", tool_id, "t1")
|
||||
repo.create("u", tool_id, "t2")
|
||||
results = repo.list_for_user_tool("u", tool_id)
|
||||
assert len(results) == 2
|
||||
|
||||
def test_different_tools_are_isolated(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_a = _make_tool(pg_conn, name="tool-a")
|
||||
tool_b = _make_tool(pg_conn, name="tool-b")
|
||||
repo.create("u", tool_a, "a-todo")
|
||||
repo.create("u", tool_b, "b-todo")
|
||||
assert len(repo.list_for_user_tool("u", tool_a)) == 1
|
||||
assert len(repo.list_for_user_tool("u", tool_b)) == 1
|
||||
|
||||
|
||||
class TestUpdateTitle:
|
||||
def test_updates_title(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "old")
|
||||
updated = repo.update_title(created["id"], "u", "new")
|
||||
assert updated is True
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["title"] == "new"
|
||||
|
||||
def test_update_nonexistent_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.update_title("00000000-0000-0000-0000-000000000000", "u", "x") is False
|
||||
|
||||
def test_update_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "old")
|
||||
updated = repo.update_title(created["id"], "other", "new")
|
||||
assert updated is False
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["title"] == "old"
|
||||
|
||||
|
||||
class TestSetCompleted:
|
||||
def test_marks_completed(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
repo.set_completed(created["id"], "u", True)
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["completed"] is True
|
||||
|
||||
def test_unmarks_completed(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
repo.set_completed(created["id"], "u", True)
|
||||
repo.set_completed(created["id"], "u", False)
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["completed"] is False
|
||||
|
||||
def test_set_completed_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
result = repo.set_completed(created["id"], "other", True)
|
||||
assert result is False
|
||||
fetched = repo.get(created["id"], "u")
|
||||
assert fetched["completed"] is False
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_deletes_todo(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
deleted = repo.delete(created["id"], "u")
|
||||
assert deleted is True
|
||||
assert repo.get(created["id"], "u") is None
|
||||
|
||||
def test_delete_nonexistent_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
assert repo.delete("00000000-0000-0000-0000-000000000000", "u") is False
|
||||
|
||||
def test_delete_wrong_user_returns_false(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
tool_id = _make_tool(pg_conn)
|
||||
created = repo.create("u", tool_id, "t")
|
||||
deleted = repo.delete(created["id"], "other")
|
||||
assert deleted is False
|
||||
assert repo.get(created["id"], "u") is not None
|
||||
90
tests/storage/db/repositories/test_token_usage.py
Normal file
90
tests/storage/db/repositories/test_token_usage.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Tests for TokenUsageRepository against a real Postgres instance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not __import__("application.core.settings", fromlist=["settings"]).settings.POSTGRES_URI,
|
||||
reason="POSTGRES_URI not configured",
|
||||
)
|
||||
|
||||
|
||||
def _repo(conn) -> TokenUsageRepository:
|
||||
return TokenUsageRepository(conn)
|
||||
|
||||
|
||||
def _now():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class TestInsert:
|
||||
def test_inserts_row(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(user_id="u1", prompt_tokens=10, generated_tokens=5)
|
||||
total = repo.sum_tokens_in_range(
|
||||
start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="u1"
|
||||
)
|
||||
assert total == 15
|
||||
|
||||
def test_insert_with_api_key(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(api_key="key-1", prompt_tokens=20, generated_tokens=10)
|
||||
total = repo.sum_tokens_in_range(
|
||||
start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), api_key="key-1"
|
||||
)
|
||||
assert total == 30
|
||||
|
||||
|
||||
class TestSumTokensInRange:
|
||||
def test_sums_correctly(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(user_id="u1", prompt_tokens=10, generated_tokens=5)
|
||||
repo.insert(user_id="u1", prompt_tokens=20, generated_tokens=10)
|
||||
repo.insert(user_id="u2", prompt_tokens=100, generated_tokens=50)
|
||||
total = repo.sum_tokens_in_range(
|
||||
start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="u1"
|
||||
)
|
||||
assert total == 45
|
||||
|
||||
def test_returns_zero_when_no_rows(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
total = repo.sum_tokens_in_range(
|
||||
start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="nobody"
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
def test_respects_time_range(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
old = _now() - timedelta(hours=48)
|
||||
repo.insert(user_id="u1", prompt_tokens=100, generated_tokens=0, timestamp=old)
|
||||
repo.insert(user_id="u1", prompt_tokens=10, generated_tokens=0)
|
||||
total = repo.sum_tokens_in_range(
|
||||
start=_now() - timedelta(hours=1), end=_now() + timedelta(minutes=1), user_id="u1"
|
||||
)
|
||||
assert total == 10
|
||||
|
||||
|
||||
class TestCountInRange:
|
||||
def test_counts_rows(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(user_id="u1", prompt_tokens=1, generated_tokens=1)
|
||||
repo.insert(user_id="u1", prompt_tokens=1, generated_tokens=1)
|
||||
repo.insert(user_id="u2", prompt_tokens=1, generated_tokens=1)
|
||||
count = repo.count_in_range(
|
||||
start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), user_id="u1"
|
||||
)
|
||||
assert count == 2
|
||||
|
||||
def test_filters_by_api_key(self, pg_conn):
|
||||
repo = _repo(pg_conn)
|
||||
repo.insert(api_key="k1", prompt_tokens=1, generated_tokens=1)
|
||||
repo.insert(api_key="k2", prompt_tokens=1, generated_tokens=1)
|
||||
count = repo.count_in_range(
|
||||
start=_now() - timedelta(minutes=1), end=_now() + timedelta(minutes=1), api_key="k1"
|
||||
)
|
||||
assert count == 1
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user