Compare commits

...

10 Commits

Author SHA1 Message Date
Alex
107eb56b1d fix: mini cors hardening 2026-04-23 11:58:20 +01:00
Alex
5c07f5f340 fix: asgi issues 2026-04-23 10:41:27 +01:00
Alex
7ad78d4219 feat: asgi and mcp tool server 2026-04-22 22:19:20 +01:00
Alex
a5153d5212 feat: asgi and search service 2026-04-22 09:12:29 +01:00
Alex
d4b1c1fd81 chore: 0.17.0 version 2026-04-21 16:16:11 +01:00
Alex
2de84acf81 fix: mini callout 2026-04-21 16:14:08 +01:00
Alex
2702750861 docs: upgrading guide 2026-04-21 15:04:17 +01:00
Alex
2b5f20d0ec fix: safer version 2026-04-21 14:22:32 +01:00
Alex
619b41dc5b fix: better version fetch 2026-04-21 14:07:26 +01:00
Alex
76d8f49ccb feat: security version check 2026-04-21 09:16:52 +01:00
31 changed files with 2039 additions and 443 deletions

View File

@@ -37,6 +37,22 @@ Run the Flask API (if needed):
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
That's the fast inner-loop option — quick startup, the Werkzeug interactive
debugger still works, and it hot-reloads on source changes. It serves the
Flask routes only (`/api/*`, `/stream`, etc.).
If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
or to match the production runtime exactly — run the ASGI composition under
uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
full flag set.
Run the Celery worker in a separate terminal (if needed):
```bash

View File

@@ -88,5 +88,15 @@ EXPOSE 7091
# Switch to non-root user
USER appuser
# Start Gunicorn
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
CMD ["gunicorn", \
"-w", "1", \
"-k", "uvicorn_worker.UvicornWorker", \
"--bind", "0.0.0.0:7091", \
"--timeout", "180", \
"--graceful-timeout", "120", \
"--keep-alive", "5", \
"--worker-tmp-dir", "/dev/shm", \
"--max-requests", "1000", \
"--max-requests-jitter", "100", \
"--config", "application/gunicorn_conf.py", \
"application.asgi:asgi_app"]

View File

@@ -0,0 +1,37 @@
"""0002 app_metadata — singleton key/value table for instance-wide state.
Used by the startup version-check client to persist the anonymous
instance UUID and a one-shot "notice shown" flag. Both values are tiny
plain-text strings; this is a deliberate generic-config table rather
than dedicated columns so future one-off settings (telemetry opt-in
timestamps, feature-flag overrides, etc.) don't each need their own
migration.
Revision ID: 0002_app_metadata
Revises: 0001_initial
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_app_metadata"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``app_metadata`` key/value table (two TEXT columns).

    Raw SQL via ``op.execute`` rather than ``op.create_table`` — the
    table is trivial and this keeps the migration free of SQLAlchemy
    schema imports.
    """
    op.execute(
        """
        CREATE TABLE app_metadata (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        );
        """
    )
def downgrade() -> None:
    """Drop ``app_metadata``; ``IF EXISTS`` keeps the downgrade idempotent."""
    op.execute("DROP TABLE IF EXISTS app_metadata;")

View File

@@ -1,21 +1,21 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from application.api.answer.routes.base import answer_ns
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
"""Fast search endpoint for retrieving relevant documents."""
search_model = answer_ns.model(
"SearchModel",
@@ -32,102 +32,10 @@ class SearchResource(Resource):
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
    """Get source IDs connected to the API key/agent.

    Prefers ``extra_source_ids`` (multi-source agents); falls back to the
    legacy single ``source_id`` field only when that array contributes
    nothing. Returns [] for an unknown key.
    """
    with db_readonly() as conn:
        agent_data = AgentsRepository(conn).find_by_key(api_key)
    if not agent_data:
        return []
    source_ids: List[str] = []
    # extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
    extra = agent_data.get("extra_source_ids") or []
    for src in extra:
        if src:
            source_ids.append(str(src))
    if not source_ids:
        single = agent_data.get("source_id")
        if single:
            source_ids.append(str(single))
    return source_ids
def _search_vectorstores(
    self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
    """Search across vectorstores and return up to ``chunks`` results.

    The chunk budget is split evenly across sources; each source
    over-fetches 2x so de-duplication doesn't starve the result set.
    Per-source errors are logged and skipped so one broken index doesn't
    fail the whole request.
    """
    if not source_ids:
        return []
    results = []
    chunks_per_source = max(1, chunks // len(source_ids))
    seen_texts = set()
    for source_id in source_ids:
        if not source_id or not source_id.strip():
            continue
        try:
            docsearch = VectorCreator.create_vectorstore(
                settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
            )
            docs = docsearch.search(query, k=chunks_per_source * 2)
            for doc in docs:
                if len(results) >= chunks:
                    break
                # Hits may be attribute-style documents or plain dicts.
                if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                    page_content = doc.page_content
                    metadata = doc.metadata
                else:
                    page_content = doc.get("text", doc.get("page_content", ""))
                    metadata = doc.get("metadata", {})
                # Skip duplicates — hashing the first 200 chars is cheap
                # and catches near-identical chunks.
                text_hash = hash(page_content[:200])
                if text_hash in seen_texts:
                    continue
                seen_texts.add(text_hash)
                title = metadata.get(
                    "title", metadata.get("post_title", "")
                )
                if not isinstance(title, str):
                    title = str(title) if title else ""
                # Clean up title: keep only the last path component.
                if title:
                    title = title.split("/")[-1]
                else:
                    # Use filename or first part of content as title
                    title = metadata.get("filename", page_content[:50] + "...")
                source = metadata.get("source", source_id)
                results.append({
                    "text": page_content,
                    "title": title,
                    "source": source,
                })
            if len(results) >= chunks:
                break
        except Exception as e:
            logger.error(
                f"Error searching vectorstore {source_id}: {e}",
                exc_info=True,
            )
            continue
    return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
data = request.get_json() or {}
question = data.get("question")
api_key = data.get("api_key")
@@ -135,32 +43,13 @@ class SearchResource(Resource):
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response(search(api_key, question, chunks), 200)
except InvalidAPIKey:
return make_response({"error": "Invalid API key"}, 401)
except SearchFailed:
logger.exception("/api/search failed")
return make_response({"error": "Search failed"}, 500)

View File

@@ -4,7 +4,7 @@ import platform
import uuid
import dotenv
from flask import Flask, jsonify, redirect, request
from flask import Flask, Response, jsonify, redirect, request
from jose import jwt
from application.auth import handle_auth
@@ -149,12 +149,11 @@ def authenticate_request():
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
response.headers.add(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
return response

33
application/asgi.py Normal file
View File

@@ -0,0 +1,33 @@
"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
from __future__ import annotations
from a2wsgi import WSGIMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Mount
from application.app import app as flask_app
from application.mcp_server import mcp
# Thread-pool size for running the WSGI (Flask) app under ASGI; each
# in-flight Flask request occupies one worker thread.
_WSGI_THREADPOOL = 32

# The FastMCP ASGI sub-app. It is mounted at /mcp below, so its own
# internal path is just "/".
mcp_app = mcp.http_app(path="/")

# Mount order matters: /mcp must come before the "/" catch-all or the
# WSGI mount would swallow MCP traffic. The MCP app's lifespan is passed
# through so its startup/shutdown hooks run under Starlette.
asgi_app = Starlette(
    routes=[
        Mount("/mcp", app=mcp_app),
        Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
    ],
    middleware=[
        Middleware(
            CORSMiddleware,
            # NOTE(review): wildcard origin mirrors the Flask
            # after_request handler — confirm it is intended for prod.
            allow_origins=["*"],
            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
            allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
            expose_headers=["Mcp-Session-Id"],
        ),
    ],
    lifespan=mcp_app.lifespan,
)

View File

@@ -1,6 +1,8 @@
import threading
from celery import Celery
from application.core.settings import settings
from celery.signals import setup_logging, worker_process_init
from celery.signals import setup_logging, worker_process_init, worker_ready
def make_celery(app_name=__name__):
@@ -39,5 +41,25 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
dispose_engine()
@worker_ready.connect
def _run_version_check(*args, **kwargs):
    """Launch the anonymous version check when the worker becomes ready.

    The check runs in a daemon thread so a slow endpoint or broken DNS
    can never delay the worker accepting tasks; ``run_check`` itself is
    fail-silent (see ``application.updates.version_check.run_check``).
    The import is deliberately deferred to call time — mirroring the
    lazy-import pattern of ``_dispose_db_engine_on_fork`` above — so
    nothing fires during module import.
    """
    try:
        from application.updates.version_check import run_check
    except Exception:
        return
    checker = threading.Thread(target=run_check, name="version-check", daemon=True)
    checker.start()
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -149,6 +149,9 @@ class Settings(BaseSettings):
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
# Anonymous startup version check for security issues.
VERSION_CHECK: bool = True
URL_STRATEGY: str = "backend" # backend or s3
JWT_SECRET_KEY: str = ""

View File

@@ -0,0 +1,72 @@
"""Gunicorn config — keeps uvicorn's access log in NCSA format."""
from __future__ import annotations
import logging
import logging.config
# NCSA common log format:
# %(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"
# Uvicorn's access formatter exposes a ``client_addr``/``request_line``/
# ``status_code`` trio but not the full NCSA field set, so we re-derive
# what we can.
_NCSA_FMT = (
    '%(client_addr)s - - [%(asctime)s] "%(request_line)s" %(status_code)s'
)

# dictConfig (schema version 1) consumed by gunicorn via the
# ``logconfig_dict`` setting. Access lines go to stdout, diagnostics to
# stderr, so log shippers can split the streams.
logconfig_dict = {
    "version": 1,
    # Keep third-party loggers alive; only the loggers below are re-routed.
    "disable_existing_loggers": False,
    "formatters": {
        "ncsa_access": {
            # Instantiate uvicorn's AccessFormatter so the %(client_addr)s
            # etc. fields resolve; plain logging formatters don't know them.
            "()": "uvicorn.logging.AccessFormatter",
            "fmt": _NCSA_FMT,
            "datefmt": "%d/%b/%Y:%H:%M:%S %z",
            "use_colors": False,
        },
        "default": {
            "format": "[%(asctime)s] [%(process)d] [%(levelname)s] %(name)s: %(message)s",
        },
    },
    "handlers": {
        "access": {
            "class": "logging.StreamHandler",
            "formatter": "ncsa_access",
            "stream": "ext://sys.stdout",
        },
        "default": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stderr",
        },
    },
    "loggers": {
        # propagate=False throughout so records aren't duplicated via root.
        "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
        "uvicorn.error": {
            "handlers": ["default"],
            "level": "INFO",
            "propagate": False,
        },
        "uvicorn.access": {
            "handlers": ["access"],
            "level": "INFO",
            "propagate": False,
        },
        "gunicorn.error": {
            "handlers": ["default"],
            "level": "INFO",
            "propagate": False,
        },
        "gunicorn.access": {
            "handlers": ["access"],
            "level": "INFO",
            "propagate": False,
        },
    },
    "root": {"handlers": ["default"], "level": "INFO"},
}


def on_starting(server):  # pragma: no cover — gunicorn hook
    """Ensure gunicorn's own loggers use the configured handlers."""
    logging.config.dictConfig(logconfig_dict)

59
application/mcp_server.py Normal file
View File

@@ -0,0 +1,59 @@
"""FastMCP server exposing DocsGPT retrieval over streamable HTTP.
Mounted at ``/mcp`` by ``application/asgi.py``. Bearer tokens are the
existing DocsGPT agent API keys — no new credential surface.
The tool reads the ``Authorization`` header directly via
``get_http_headers(include={"authorization"})``. The ``include`` kwarg
is required: by default ``get_http_headers`` strips ``authorization``
(and a handful of other hop-by-hop headers) so they aren't forwarded
to downstream services — since we deliberately want the caller's
token, we opt it back in.
"""
from __future__ import annotations
import asyncio
import logging
from fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_headers
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
mcp = FastMCP("docsgpt")
def _extract_bearer_token() -> str | None:
    """Pull the Bearer credential off the current MCP request, if any.

    Returns ``None`` for a missing header, a non-Bearer scheme, or an
    empty credential. Scheme matching is case-insensitive; splitting on
    whitespace (``split(None, 1)``) tolerates repeated spaces between
    scheme and token.
    """
    raw = get_http_headers(include={"authorization"}).get("authorization", "")
    pieces = raw.split(None, 1)
    if len(pieces) == 2 and pieces[0].lower() == "bearer" and pieces[1]:
        return pieces[1]
    return None
@mcp.tool
async def search_docs(query: str, chunks: int = 5) -> list[dict]:
    """Search the caller's DocsGPT knowledge base.

    Authentication is via ``Authorization: Bearer <agent-api-key>`` on
    the MCP request — the same opaque key that ``/api/search`` accepts
    in its JSON body. Returns at most ``chunks`` hits, each a dict with
    ``text``, ``title``, ``source`` keys.

    Raises:
        PermissionError: missing Bearer token or unknown API key.
        SearchFailed: unexpected retrieval failure (logged here, then
            re-raised for the MCP layer to report).
    """
    api_key = _extract_bearer_token()
    if not api_key:
        raise PermissionError("Missing Bearer token")
    try:
        # The service layer is synchronous (DB + vectorstore I/O), so run
        # it off the event loop thread.
        return await asyncio.to_thread(search, api_key, query, chunks)
    except InvalidAPIKey as exc:
        raise PermissionError("Invalid API key") from exc
    except SearchFailed:
        logger.exception("search_docs failed")
        raise

View File

@@ -1,5 +1,7 @@
a2wsgi==1.10.10
alembic>=1.13,<2
anthropic==0.88.0
asgiref>=3.11.1
boto3==1.42.83
beautifulsoup4==4.14.3
cel-python==0.5.0
@@ -14,7 +16,7 @@ docx2txt==0.9
ddgs>=8.0.0
fast-ebook
elevenlabs==2.43.0
Flask==3.1.3
Flask==3.1.1
faiss-cpu==1.13.2
fastmcp==3.2.4
flask-restx==1.3.2
@@ -76,6 +78,7 @@ requests==2.33.1
retry==0.9.2
sentence-transformers==5.3.0
sqlalchemy>=2.0,<3
starlette>=1.0,<2
tiktoken==0.12.0
tokenizers==0.22.2
torch==2.11.0
@@ -85,6 +88,8 @@ typing-extensions==4.15.0
typing-inspect==0.9.0
tzdata==2026.1
urllib3==2.6.3
uvicorn[standard]>=0.30,<1
uvicorn-worker>=0.4,<1
vine==5.1.0
wcwidth==0.6.0
werkzeug>=3.1.0

View File

View File

@@ -0,0 +1,153 @@
"""Shared retrieval service used by the HTTP search route and the MCP tool.
Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``)
that callers translate into their own wire protocol (HTTP status codes,
MCP error responses, etc.).
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
class InvalidAPIKey(Exception):
    """The supplied ``api_key`` does not resolve to an agent.

    HTTP callers map this to 401; the MCP tool re-raises it as
    ``PermissionError``.
    """


class SearchFailed(Exception):
    """Unexpected error during retrieval (e.g. DB outage). Caller maps to 5xx."""
def _collect_source_ids(agent: Dict[str, Any]) -> List[str]:
"""Extract the ordered list of source UUIDs to search.
Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents);
falls back to the legacy single ``source_id`` field.
"""
source_ids: List[str] = []
extra = agent.get("extra_source_ids") or []
for src in extra:
if src:
source_ids.append(str(src))
if not source_ids:
single = agent.get("source_id")
if single:
source_ids.append(str(single))
return source_ids
def _search_sources(
    query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
    """Search across each source's vectorstore and return up to ``chunks`` hits.

    Per-source errors are logged and skipped so one broken index doesn't
    take down the whole search. Results are de-duplicated by content hash.
    """
    if chunks <= 0 or not source_ids:
        return []
    results: List[Dict[str, Any]] = []
    # Even split of the budget; each source over-fetches 2x below so
    # de-duplication doesn't starve the result set.
    chunks_per_source = max(1, chunks // len(source_ids))
    seen_texts: set[int] = set()
    for source_id in source_ids:
        if not source_id or not source_id.strip():
            continue
        try:
            docsearch = VectorCreator.create_vectorstore(
                settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
            )
            docs = docsearch.search(query, k=chunks_per_source * 2)
            for doc in docs:
                if len(results) >= chunks:
                    break
                # Hits may be attribute-style documents or plain dicts.
                if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                    page_content = doc.page_content
                    metadata = doc.metadata
                else:
                    page_content = doc.get("text", doc.get("page_content", ""))
                    metadata = doc.get("metadata", {})
                # De-dup on the first 200 chars — cheap, catches
                # near-identical chunks.
                text_hash = hash(page_content[:200])
                if text_hash in seen_texts:
                    continue
                seen_texts.add(text_hash)
                title = metadata.get("title", metadata.get("post_title", ""))
                if not isinstance(title, str):
                    title = str(title) if title else ""
                # Keep only the last path component of the title; fall
                # back to filename or a content snippet.
                if title:
                    title = title.split("/")[-1]
                else:
                    title = metadata.get("filename", page_content[:50] + "...")
                source = metadata.get("source", source_id)
                results.append(
                    {
                        "text": page_content,
                        "title": title,
                        "source": source,
                    }
                )
            if len(results) >= chunks:
                break
        except Exception as e:
            logger.error(
                f"Error searching vectorstore {source_id}: {e}",
                exc_info=True,
            )
            continue
    return results[:chunks]
def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]:
    """Resolve an agent by API key and search its sources.

    Args:
        api_key: Agent API key (the opaque string stored on
            ``agents.key`` in Postgres).
        query: Free-text search query.
        chunks: Max number of hits to return.

    Returns:
        List of hit dicts with ``text``, ``title``, ``source`` keys.
        Empty list if the agent has no sources configured.

    Raises:
        InvalidAPIKey: if ``api_key`` does not resolve to an agent.
        SearchFailed: on unexpected DB / infrastructure errors.
    """
    if chunks <= 0:
        return []
    try:
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
    except Exception as e:
        # Infrastructure failure (DB down, bad URI) — distinct from an
        # unknown key, so callers can map it to a 5xx.
        raise SearchFailed("agent lookup failed") from e
    if not agent:
        raise InvalidAPIKey()
    source_ids = _collect_source_ids(agent)
    if not source_ids:
        # An agent with no sources is valid — just nothing to search.
        return []
    return _search_sources(query, source_ids, chunks)

View File

@@ -117,6 +117,16 @@ stack_logs_table = Table(
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
# Singleton key/value table for instance-wide state (e.g. anonymous
# instance UUID, one-shot notice flags). Added in migration
# ``0002_app_metadata``.
app_metadata_table = Table(
"app_metadata",
metadata,
Column("key", Text, primary_key=True),
Column("value", Text, nullable=False),
)
# --- Phase 2, Tier 2 --------------------------------------------------------

View File

@@ -0,0 +1,60 @@
"""Repository for the ``app_metadata`` singleton key/value table.
Owns the instance-wide state the version-check client needs:
``instance_id`` (anonymous UUID sent with each check) and
``version_check_notice_shown`` (one-shot flag for the first-run
telemetry notice). Kept deliberately generic so future one-off config
values can piggyback without a new migration each time.
"""
from __future__ import annotations
import uuid
from typing import Optional
from sqlalchemy import Connection, text
class AppMetadataRepository:
    """Postgres-backed ``app_metadata`` store. Tiny by design."""

    def __init__(self, conn: Connection) -> None:
        # The caller owns the connection/transaction lifecycle; this
        # repository only issues statements on it.
        self._conn = conn

    def get(self, key: str) -> Optional[str]:
        """Return the value for ``key``, or ``None`` if the row is absent."""
        row = self._conn.execute(
            text("SELECT value FROM app_metadata WHERE key = :key"),
            {"key": key},
        ).fetchone()
        return row[0] if row is not None else None

    def set(self, key: str, value: str) -> None:
        """Upsert ``key`` → ``value`` (``INSERT ... ON CONFLICT DO UPDATE``)."""
        self._conn.execute(
            text(
                "INSERT INTO app_metadata (key, value) VALUES (:key, :value) "
                "ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value"
            ),
            {"key": key, "value": value},
        )

    def get_or_create_instance_id(self) -> str:
        """Return the anonymous instance UUID, generating one if absent.

        Uses ``INSERT ... ON CONFLICT DO NOTHING`` + re-read so two
        workers racing on the very first startup converge on a single
        UUID instead of each persisting their own.
        """
        existing = self.get("instance_id")
        if existing:
            return existing
        candidate = str(uuid.uuid4())
        self._conn.execute(
            text(
                "INSERT INTO app_metadata (key, value) VALUES ('instance_id', :value) "
                "ON CONFLICT (key) DO NOTHING"
            ),
            {"value": candidate},
        )
        # Re-read: if another worker won the race, their UUID is now authoritative.
        winner = self.get("instance_id")
        return winner or candidate

View File

View File

@@ -0,0 +1,302 @@
"""Anonymous startup version-check client.
Called once per Celery worker boot (see ``application/celery_init.py``
``worker_ready`` handler). Posts the running version + anonymous
instance UUID to ``gptcloud.arc53.com/api/check``, caches the response
in Redis, and surfaces any advisories to stdout + logs.
Design invariants — all enforced by a broad ``try/except`` at the top
of :func:`run_check`:
* Never blocks worker startup (fired from a daemon thread).
* Never raises to the caller (every failure is swallowed + logged at
``DEBUG``).
* Opt-out via ``VERSION_CHECK=0`` short-circuits before any Postgres
write, Redis access, or outbound request.
* Redis coordinates multi-worker and multi-replica deployments — the
first worker to acquire ``docsgpt:version_check:lock`` fetches, the
rest read from the cached response on the next cycle.
"""
from __future__ import annotations
import json
import logging
import os
import platform
import socket
import sys
from typing import Any, Dict, Optional
import requests
from application.cache import get_redis_instance
from application.core.settings import settings
from application.storage.db.repositories.app_metadata import AppMetadataRepository
from application.storage.db.session import db_session
from application.version import get_version
logger = logging.getLogger(__name__)
ENDPOINT_URL = "https://gptcloud.arc53.com/api/check"
CLIENT_NAME = "docsgpt-backend"
REQUEST_TIMEOUT_SECONDS = 5
CACHE_KEY = "docsgpt:version_check:response"
LOCK_KEY = "docsgpt:version_check:lock"
CACHE_TTL_SECONDS = 6 * 3600 # 6h default; shortened by response `next_check_after`.
LOCK_TTL_SECONDS = 60
NOTICE_KEY = "version_check_notice_shown"
INSTANCE_ID_KEY = "instance_id"
_HIGH_SEVERITIES = {"high", "critical"}
_ANSI_RESET = "\033[0m"
_ANSI_RED = "\033[31m"
_ANSI_YELLOW = "\033[33m"
def run_check() -> None:
    """Entry point for the worker-startup daemon thread.

    Safe to call unconditionally: the opt-out, Redis-outage, and
    Postgres-outage paths all return silently. No exception propagates.
    """
    try:
        _run_check_inner()
    except Exception as exc:  # noqa: BLE001 — belt-and-braces; nothing escapes.
        logger.debug("version check crashed: %s", exc, exc_info=True)


def _run_check_inner() -> None:
    """One check cycle: opt-out gate → instance id → cache → lock → fetch."""
    if not settings.VERSION_CHECK:
        return
    instance_id = _resolve_instance_id_and_notice()
    if instance_id is None:
        # Postgres unavailable — per spec we skip the check entirely
        # rather than phone home with a synthetic/ephemeral UUID.
        return
    redis_client = get_redis_instance()
    cached = _read_cache(redis_client)
    if cached is not None:
        _render_advisories(cached)
        return
    # Cache miss. Try to win the lock; if another worker has it, skip.
    # ``redis_client is None`` here means Redis is unreachable — per the
    # spec we still proceed uncached (acceptable duplicate calls in
    # multi-worker Redis-less deploys).
    if redis_client is not None and not _acquire_lock(redis_client):
        return
    response = _fetch(instance_id)
    if response is None:
        # Fetch failed: release the lock early so another worker can
        # retry before the lock TTL would have expired it.
        if redis_client is not None:
            _release_lock(redis_client)
        return
    _write_cache(redis_client, response)
    _render_advisories(response)
    if redis_client is not None:
        _release_lock(redis_client)


def _resolve_instance_id_and_notice() -> Optional[str]:
    """Load (or create) the instance UUID and emit the first-run notice.

    The notice is printed at most once across the lifetime of the
    installation — tracked via the ``version_check_notice_shown`` row
    in ``app_metadata``. Both reads and the write happen inside one
    short transaction so two racing workers can't each emit the notice.
    """
    try:
        with db_session() as conn:
            repo = AppMetadataRepository(conn)
            instance_id = repo.get_or_create_instance_id()
            if repo.get(NOTICE_KEY) is None:
                _print_first_run_notice()
                repo.set(NOTICE_KEY, "1")
            return instance_id
    except Exception as exc:  # noqa: BLE001 — Postgres down, bad URI, etc.
        logger.debug("version check: Postgres unavailable (%s)", exc, exc_info=True)
        return None


def _print_first_run_notice() -> None:
    """Print + log the one-shot telemetry disclosure."""
    message = (
        "Anonymous version check enabled — sends version to "
        "gptcloud.arc53.com.\nDisable with VERSION_CHECK=0."
    )
    print(message, flush=True)
    logger.info("version check: first-run notice shown")
def _read_cache(redis_client) -> Optional[Dict[str, Any]]:
if redis_client is None:
return None
try:
raw = redis_client.get(CACHE_KEY)
except Exception as exc: # noqa: BLE001 — Redis transient errors.
logger.debug("version check: cache GET failed (%s)", exc, exc_info=True)
return None
if raw is None:
return None
try:
return json.loads(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
except (ValueError, AttributeError) as exc:
logger.debug("version check: cache decode failed (%s)", exc, exc_info=True)
return None
def _write_cache(redis_client, response: Dict[str, Any]) -> None:
    """Best-effort cache of the check response; failures only log at DEBUG."""
    if redis_client is None:
        return
    ttl = _compute_ttl(response)
    try:
        redis_client.setex(CACHE_KEY, ttl, json.dumps(response))
    except Exception as exc:  # noqa: BLE001
        logger.debug("version check: cache SETEX failed (%s)", exc, exc_info=True)


def _compute_ttl(response: Dict[str, Any]) -> int:
    """Cap the cache at 6h but honor a shorter server-specified window."""
    next_after = response.get("next_check_after")
    # Only positive numeric values count; anything else falls back to the
    # 6h default. The floor of 1 keeps the SETEX expiry argument valid.
    if isinstance(next_after, (int, float)) and next_after > 0:
        return max(1, min(CACHE_TTL_SECONDS, int(next_after)))
    return CACHE_TTL_SECONDS
def _acquire_lock(redis_client) -> bool:
    """Try to take the fetch lock; True means this worker should fetch."""
    try:
        # Owner value is informational (debugging via redis-cli);
        # NX + EX makes the take-with-expiry atomic.
        owner = f"{socket.gethostname()}:{os.getpid()}"
        return bool(
            redis_client.set(LOCK_KEY, owner, nx=True, ex=LOCK_TTL_SECONDS)
        )
    except Exception as exc:  # noqa: BLE001
        # Treat a failing Redis the same as "no lock infra" — skip rather
        # than fire without coordination, because Redis outage is
        # usually transient and one missed cycle is harmless.
        logger.debug("version check: lock acquire failed (%s)", exc, exc_info=True)
        return False


def _release_lock(redis_client) -> None:
    """Delete the fetch lock; best-effort (TTL expiry is the backstop)."""
    try:
        redis_client.delete(LOCK_KEY)
    except Exception as exc:  # noqa: BLE001
        logger.debug("version check: lock release failed (%s)", exc, exc_info=True)
def _fetch(instance_id: str) -> Optional[Dict[str, Any]]:
    """POST version + instance id to the check endpoint.

    Returns the decoded JSON response, or ``None`` on any failure
    (missing version string, network error, non-2xx, undecodable body).
    """
    version = get_version()
    if version in ("", "unknown"):
        # The endpoint rejects payloads without a valid semver, and the
        # rejection is otherwise logged at DEBUG — invisible under the
        # usual ``-l INFO`` Celery worker start. Surface it loudly so a
        # misconfigured release (missing or unset ``__version__``) is
        # obvious instead of silently disabling the check.
        logger.warning(
            "version check: skipping — get_version() returned %r. "
            "Set __version__ in application/version.py to a valid "
            "version string.",
            version,
        )
        return None
    payload = {
        "version": version,
        "instance_id": instance_id,
        "python_version": platform.python_version(),
        "platform": sys.platform,
        "client": CLIENT_NAME,
    }
    try:
        resp = requests.post(
            ENDPOINT_URL,
            json=payload,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        logger.debug("version check: request failed (%s)", exc, exc_info=True)
        return None
    if resp.status_code >= 400:
        logger.debug("version check: non-2xx response %s", resp.status_code)
        return None
    try:
        return resp.json()
    except ValueError as exc:
        logger.debug("version check: response decode failed (%s)", exc, exc_info=True)
        return None
def _render_advisories(response: Dict[str, Any]) -> None:
    """Log every advisory in ``response``; print high/critical ones too.

    Tolerates malformed payloads: a non-list ``advisories`` field is
    ignored and non-dict entries are skipped, so a bad server response
    can never crash the (already fail-silent) check.
    """
    advisories = response.get("advisories") or []
    if not isinstance(advisories, list):
        return
    current_version = get_version()
    for advisory in advisories:
        if not isinstance(advisory, dict):
            continue
        severity = str(advisory.get("severity", "")).lower()
        advisory_id = advisory.get("id", "UNKNOWN")
        title = advisory.get("title", "")
        url = advisory.get("url", "")
        fixed_in = advisory.get("fixed_in")
        summary = advisory.get(
            "summary",
            f"Your DocsGPT version {current_version} is vulnerable.",
        )
        logger.warning(
            "security advisory %s (severity=%s) affects version %s: %s%s%s",
            advisory_id,
            severity or "unknown",
            current_version,
            title or summary,
            f" — fixed in {fixed_in}" if fixed_in else "",
            # Fix: the URL was previously appended with no separator,
            # producing e.g. "...fixed in 1.2.3https://...". Set it off
            # with an em dash like the fixed_in fragment.
            f" — {url}" if url else "",
        )
        if severity in _HIGH_SEVERITIES:
            _print_console_advisory(
                advisory_id=advisory_id,
                title=title,
                severity=severity,
                summary=summary,
                fixed_in=fixed_in,
                url=url,
            )
def _print_console_advisory(
    *,
    advisory_id: str,
    title: str,
    severity: str,
    summary: str,
    fixed_in: Optional[str],
    url: str,
) -> None:
    """Print a colorized advisory banner to stdout.

    Red for ``critical``, yellow otherwise; callers only route
    high/critical severities here.
    """
    color = _ANSI_RED if severity == "critical" else _ANSI_YELLOW
    bar = "=" * 60
    upgrade_line = ""
    if fixed_in and url:
        upgrade_line = f" Upgrade to {fixed_in}+ — {url}"
    elif fixed_in:
        upgrade_line = f" Upgrade to {fixed_in}+"
    elif url:
        upgrade_line = f" {url}"
    lines = [
        bar,
        f"\u26a0 SECURITY ADVISORY: {advisory_id}",
        f" {summary}",
        f" {title} (severity: {severity})" if title else f" severity: {severity}",
    ]
    if upgrade_line:
        lines.append(upgrade_line)
    lines.append(bar)
    # chr(10) is "\n" — f-string expressions can't contain backslashes
    # before Python 3.12.
    print(f"{color}{chr(10).join(lines)}{_ANSI_RESET}", flush=True)

10
application/version.py Normal file
View File

@@ -0,0 +1,10 @@
"""DocsGPT backend version string."""
from __future__ import annotations
# Bumped as part of each release; the startup version-check client reads
# this value via get_version().
__version__ = "0.17.0"


def get_version() -> str:
    """Return the DocsGPT backend version string."""
    return __version__

View File

@@ -104,7 +104,15 @@ To run the DocsGPT backend locally, you'll need to set up a Python environment a
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
This command will launch the backend server, making it accessible on `http://localhost:7091`.
This command will launch the backend server, making it accessible on `http://localhost:7091`. It's the fastest inner-loop option for day-to-day development — the Werkzeug interactive debugger still works and it hot-reloads on source changes. It serves the Flask routes only.
If you need to exercise the full ASGI stack — the `/mcp` endpoint (FastMCP server), or to match the production runtime — run the ASGI composition under uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same `application.asgi:asgi_app` target.
6. **Start the Celery Worker:**

View File

@@ -1,6 +1,7 @@
export default {
"index": "Home",
"quickstart": "Quickstart",
"upgrading": "Upgrading",
"Deploying": "Deploying",
"Models": "Models",
"Tools": "Tools",

View File

@@ -0,0 +1,66 @@
---
title: Upgrading DocsGPT
description: Upgrade your DocsGPT deployment across Docker Compose, source builds, and Kubernetes.
---
import { Callout } from 'nextra/components'
# Upgrading DocsGPT
<Callout type="warning">
**Upgrading from 0.16.x?** User data moved from MongoDB to Postgres in 0.17.0. Follow the [Postgres Migration guide](/Deploying/Postgres-Migration) before running `docker compose pull` or `git pull` — existing deployments will not start cleanly without it.
</Callout>
## Check your version
```bash
docker compose exec backend python -c "from application.version import get_version; print(get_version())"
```
Release notes: [changelog](/changelog). Tags: [GitHub releases](https://github.com/arc53/DocsGPT/releases).
## Docker Compose — hub images
```bash
cd DocsGPT/deployment
docker compose -f docker-compose-hub.yaml pull
docker compose -f docker-compose-hub.yaml up -d
```
`pull` fetches the latest image for whichever tag your compose file references. To move to a specific release, edit `image: arc53/docsgpt:<tag>` first.
## Docker Compose — from source
```bash
cd DocsGPT
git pull
docker compose -f deployment/docker-compose.yaml build
docker compose -f deployment/docker-compose.yaml up -d
```
Swap `git pull` for `git checkout <tag>` if you want to pin a specific release.
## Kubernetes
```bash
kubectl set image deployment/docsgpt-backend backend=arc53/docsgpt:<tag>
kubectl set image deployment/docsgpt-worker worker=arc53/docsgpt:<tag>
kubectl rollout status deployment/docsgpt-backend
kubectl rollout status deployment/docsgpt-worker
```
Full manifests: [Kubernetes deployment guide](/Deploying/Kubernetes-Deploying).
## Migrations
Alembic migrations run on worker startup. To apply manually:
```bash
docker compose exec backend alembic -c application/alembic.ini upgrade head
```
`upgrade head` is idempotent.
## Rollback
Set the image tag to the previous release and `up -d` again. Schema changes are not reversible without a backup — take one before upgrading any release that mentions migrations in the changelog.

View File

@@ -16,6 +16,7 @@ markers =
unit: Unit tests
integration: Integration tests
slow: Slow running tests
asyncio_mode = strict
filterwarnings =
ignore::DeprecationWarning
ignore::PendingDeprecationWarning

137
scripts/mock_llm.py Normal file
View File

@@ -0,0 +1,137 @@
"""Mock OpenAI-compatible LLM server for benchmarking.
Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
chunks in OpenAI's chat.completions streaming format, or a single response
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
"""
import asyncio
import json
import logging
import time
import uuid
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
# Generation pacing: TOKEN_COUNT tokens at TOKEN_DELAY_S seconds each gives
# the fixed 5-second response described in the module docstring.
TOKEN_COUNT = 100
TOKEN_DELAY_S = 0.05  # 100 * 0.05 = 5.0 s

logger = logging.getLogger("mock_llm")
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")

# Canned "lorem ipsum" token stream; leading spaces mimic how real LLM
# tokenizers attach whitespace to the following word.
FILLER_TOKENS = [
    "Lorem", " ipsum", " dolor", " sit", " amet", ",", " consectetur",
    " adipiscing", " elit", ".", " Sed", " do", " eiusmod", " tempor",
    " incididunt", " ut", " labore", " et", " dolore", " magna", " aliqua",
    ".", " Ut", " enim", " ad", " minim", " veniam", ",", " quis", " nostrud",
    " exercitation", " ullamco", " laboris", " nisi", " ut", " aliquip",
    " ex", " ea", " commodo", " consequat", ".", " Duis", " aute", " irure",
    " dolor", " in", " reprehenderit", " in", " voluptate", " velit",
    " esse", " cillum", " dolore", " eu", " fugiat", " nulla", " pariatur",
    ".", " Excepteur", " sint", " occaecat", " cupidatat", " non", " proident",
    ",", " sunt", " in", " culpa", " qui", " officia", " deserunt",
    " mollit", " anim", " id", " est", " laborum", ".", " Curabitur",
    " pretium", " tincidunt", " lacus", ".", " Nulla", " gravida", " orci",
    " a", " odio", ".", " Nullam", " varius", ",", " turpis", " et",
    " commodo", " pharetra", ",", " est", " eros", " bibendum", " elit",
    ".",
]

app = FastAPI()
def _token_stream_id() -> str:
return f"chatcmpl-mock-{uuid.uuid4().hex[:12]}"
def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None) -> str:
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(payload)}\n\n"
async def _stream_response(model: str, req_id: str):
    """Yield an OpenAI-format SSE stream for one completion.

    Emits the role handshake chunk, then ``TOKEN_COUNT`` content chunks
    paced at ``TOKEN_DELAY_S`` seconds each, a final ``finish_reason="stop"``
    chunk, and the ``[DONE]`` sentinel. ``req_id`` is used only for logging.
    """
    completion_id = _token_stream_id()
    # First chunk carries the assistant role with empty content, matching
    # OpenAI's streaming handshake.
    yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
    # Fix: the original used enumerate() but never read the index.
    for tok in FILLER_TOKENS[:TOKEN_COUNT]:
        await asyncio.sleep(TOKEN_DELAY_S)
        yield _sse_chunk(completion_id, model, {"content": tok})
    yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
    yield "data: [DONE]\n\n"
    logger.info("[%s] stream done", req_id)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
body = await request.json()
model = body.get("model", "mock")
stream = bool(body.get("stream", False))
req_id = uuid.uuid4().hex[:8]
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
if stream:
return StreamingResponse(
_stream_response(model, req_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
},
)
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
logger.info("[%s] non-stream done", req_id)
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
completion_id = _token_stream_id()
return JSONResponse(
{
"id": completion_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": TOKEN_COUNT,
"total_tokens": 10 + TOKEN_COUNT,
},
}
)
@app.get("/v1/models")
async def list_models():
return {
"object": "list",
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
}
@app.get("/health")
async def health():
return {"status": "ok"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")

View File

@@ -1,3 +1,17 @@
"""Tests for /api/search route (application/api/answer/routes/search.py).
Retrieval logic lives in ``application/services/search_service.py`` and
has its own unit tests in ``tests/services/test_search_service.py``. The
tests below focus on what the route specifically owns:
* Request validation (400 for missing fields).
* Translation of the service's ``InvalidAPIKey`` / ``SearchFailed``
exceptions to HTTP status codes (401 / 500).
* End-to-end happy path against a real ephemeral Postgres via
``pg_conn``, to catch regressions in the route's wiring to the
service and repositories.
"""
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
@@ -6,254 +20,97 @@ import pytest
@pytest.mark.unit
class TestSearchResourceValidation:
pass
def test_returns_error_when_question_missing(self, mock_mongo_db, flask_app):
def test_returns_400_when_question_missing(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
with flask_app.test_request_context(
json={"api_key": "test_key"}
):
resource = SearchResource()
result = resource.post()
with flask_app.test_request_context(json={"api_key": "test_key"}):
result = SearchResource().post()
assert result.status_code == 400
assert "question" in result.json["error"]
def test_returns_error_when_api_key_missing(self, mock_mongo_db, flask_app):
def test_returns_400_when_api_key_missing(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
with flask_app.test_request_context(
json={"question": "test query"}
):
resource = SearchResource()
result = resource.post()
with flask_app.test_request_context(json={"question": "test query"}):
result = SearchResource().post()
assert result.status_code == 400
assert "api_key" in result.json["error"]
@pytest.mark.unit
class TestGetSourcesFromApiKey:
pass
class TestSearchResourceExceptionMapping:
"""Verify the route maps service exceptions to HTTP status codes.
def test_returns_source_id_via_patched_method(self, mock_mongo_db, flask_app):
"""Test that _get_sources_from_api_key can return multiple sources via patch."""
The service function itself is patched; these tests do not care about
the search logic — only that 401/500/200 are produced correctly from
the three possible service outcomes.
"""
def test_invalid_api_key_returns_401(self, flask_app):
from application.api.answer.routes.search import SearchResource
from application.services.search_service import InvalidAPIKey
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "bad"}
), patch(
"application.api.answer.routes.search.search",
side_effect=InvalidAPIKey(),
):
result = SearchResource().post()
assert result.status_code == 401
assert result.json == {"error": "Invalid API key"}
def test_search_failed_returns_500(self, flask_app):
from application.api.answer.routes.search import SearchResource
from application.services.search_service import SearchFailed
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "k"}
), patch(
"application.api.answer.routes.search.search",
side_effect=SearchFailed("boom"),
):
result = SearchResource().post()
assert result.status_code == 500
assert result.json == {"error": "Search failed"}
def test_happy_path_passes_service_result_through(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
hits = [{"text": "t", "title": "T", "source": "s"}]
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "k", "chunks": 7}
), patch(
"application.api.answer.routes.search.search",
return_value=hits,
) as mock_search:
result = SearchResource().post()
assert result.status_code == 200
assert result.json == hits
mock_search.assert_called_once_with("k", "q", 7)
with patch.object(resource, "_get_sources_from_api_key", return_value=["src1", "src2"]):
result = resource._get_sources_from_api_key("any_key")
assert len(result) == 2
assert "src1" in result
assert "src2" in result
@pytest.mark.unit
class TestSearchVectorstores:
pass
def test_returns_empty_when_no_source_ids(self, mock_mongo_db, flask_app):
def test_default_chunks_is_5(self, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
result = resource._search_vectorstores("test query", [], 5)
assert result == []
def test_skips_empty_source_ids(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = []
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["", " "], 5)
mock_create.assert_not_called()
assert result == []
def test_returns_search_results(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = {
"text": "Test content",
"page_content": "Test content",
"metadata": {
"title": "Test Title",
"source": "/path/to/doc",
},
}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert len(result) == 1
assert result[0]["text"] == "Test content"
assert result[0]["title"] == "Test Title"
assert result[0]["source"] == "/path/to/doc"
def test_handles_langchain_document_format(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = MagicMock()
mock_doc.page_content = "Langchain content"
mock_doc.metadata = {"title": "LC Title", "source": "/lc/path"}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert len(result) == 1
assert result[0]["text"] == "Langchain content"
assert result[0]["title"] == "LC Title"
def test_respects_chunks_limit(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_docs = [
{"text": f"Content {i}", "metadata": {"title": f"Title {i}"}}
for i in range(10)
]
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = mock_docs
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 3)
assert len(result) == 3
def test_deduplicates_results(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
duplicate_text = "Duplicate content " * 20
mock_docs = [
{"text": duplicate_text, "metadata": {"title": "Title 1"}},
{"text": duplicate_text, "metadata": {"title": "Title 2"}},
{"text": "Unique content", "metadata": {"title": "Title 3"}},
]
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = mock_docs
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert len(result) == 2
def test_handles_vectorstore_error_gracefully(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_create.side_effect = Exception("Vectorstore error")
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert result == []
def test_uses_filename_as_title_fallback(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = {
"text": "Content without title",
"metadata": {"filename": "document.pdf"},
}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert result[0]["title"] == "document.pdf"
def test_uses_content_snippet_as_title_last_resort(self, mock_mongo_db, flask_app):
from application.api.answer.routes.search import SearchResource
with flask_app.app_context():
resource = SearchResource()
mock_doc = {
"text": "Content without any title metadata at all",
"metadata": {},
}
with patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore"
) as mock_create:
mock_vectorstore = MagicMock()
mock_vectorstore.search.return_value = [mock_doc]
mock_create.return_value = mock_vectorstore
result = resource._search_vectorstores("test query", ["source_id"], 5)
assert "Content without any title" in result[0]["title"]
assert result[0]["title"].endswith("...")
@pytest.mark.unit
class TestSearchEndpoint:
pass
with flask_app.app_context(), flask_app.test_request_context(
json={"question": "q", "api_key": "k"} # no chunks field
), patch(
"application.api.answer.routes.search.search",
return_value=[],
) as mock_search:
SearchResource().post()
mock_search.assert_called_once_with("k", "q", 5)
# ---------------------------------------------------------------------------
# Real-PG tests for SearchResource.
# End-to-end against a real ephemeral Postgres.
#
# These exercise the full route → service → repository → DB path, patching
# only ``VectorCreator.create_vectorstore`` (so we don't need real embeddings
# or a vector index). ``db_readonly`` is redirected at the *service* module
# since that's where the import now lives.
# ---------------------------------------------------------------------------
@@ -264,7 +121,7 @@ def _patch_search_db(conn):
yield conn
with patch(
"application.api.answer.routes.search.db_readonly", _yield
"application.services.search_service.db_readonly", _yield
):
yield
@@ -298,9 +155,7 @@ class TestSearchResourcePgConn:
def test_search_returns_results(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
src = SourcesRepository(pg_conn).create("src", user_id="u")
AgentsRepository(pg_conn).create(
@@ -315,7 +170,7 @@ class TestSearchResourcePgConn:
]
with _patch_search_db(pg_conn), patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore",
"application.services.search_service.VectorCreator.create_vectorstore",
return_value=fake_vs,
), flask_app.app_context():
with flask_app.test_request_context(
@@ -328,9 +183,7 @@ class TestSearchResourcePgConn:
def test_search_uses_extra_source_ids(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
src1 = SourcesRepository(pg_conn).create("s1", user_id="u")
src2 = SourcesRepository(pg_conn).create("s2", user_id="u")
@@ -345,7 +198,7 @@ class TestSearchResourcePgConn:
{"text": "one", "metadata": {"title": "A"}},
]
with _patch_search_db(pg_conn), patch(
"application.api.answer.routes.search.VectorCreator.create_vectorstore",
"application.services.search_service.VectorCreator.create_vectorstore",
return_value=fake_vs,
), flask_app.app_context():
with flask_app.test_request_context(
@@ -353,71 +206,3 @@ class TestSearchResourcePgConn:
):
result = SearchResource().post()
assert result.status_code == 200
def test_search_exception_returns_500(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
src = SourcesRepository(pg_conn).create("src", user_id="u")
AgentsRepository(pg_conn).create(
"u", "a", "published",
key="err-key",
source_id=str(src["id"]),
)
with _patch_search_db(pg_conn), patch(
"application.api.answer.routes.search.SearchResource._get_sources_from_api_key",
side_effect=RuntimeError("boom"),
), flask_app.app_context():
with flask_app.test_request_context(
json={"question": "q", "api_key": "err-key"},
):
result = SearchResource().post()
assert result.status_code == 500
class TestGetSourcesFromApiKeyPg:
def test_empty_for_unknown_key(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
with _patch_search_db(pg_conn), flask_app.app_context():
got = SearchResource()._get_sources_from_api_key("nope")
assert got == []
def test_returns_extra_source_ids(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
src = SourcesRepository(pg_conn).create("s", user_id="u")
AgentsRepository(pg_conn).create(
"u", "a", "published",
key="sources-key",
extra_source_ids=[str(src["id"])],
)
with _patch_search_db(pg_conn), flask_app.app_context():
got = SearchResource()._get_sources_from_api_key("sources-key")
assert got == [str(src["id"])]
def test_falls_back_to_single_source(self, pg_conn, flask_app):
from application.api.answer.routes.search import SearchResource
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.sources import (
SourcesRepository,
)
src = SourcesRepository(pg_conn).create("s", user_id="u")
AgentsRepository(pg_conn).create(
"u", "a", "published",
key="single-key",
source_id=str(src["id"]),
)
with _patch_search_db(pg_conn), flask_app.app_context():
got = SearchResource()._get_sources_from_api_key("single-key")
assert got == [str(src["id"])]

View File

@@ -1,4 +1,5 @@
pytest>=8.0.0
pytest-asyncio>=0.23
pytest-cov>=4.1.0
coverage>=7.4.0
pytest-postgresql>=6.0.0

View File

View File

@@ -0,0 +1,134 @@
"""Tests for application/mcp_server.py.
The server module exposes one FastMCP tool, ``search_docs``, that reads
the caller's ``Authorization: Bearer <key>`` header via
``get_http_headers()`` and delegates to
``application.services.search_service.search``. These tests exercise
the tool directly by patching ``get_http_headers`` and ``search``; the
full HTTP-layer plumbing (mount, lifespan, session handshake) is
covered by ``tests/test_asgi.py``.
"""
from unittest.mock import patch
import pytest
@pytest.mark.unit
class TestSearchDocsTool:
    """Unit tests for the ``search_docs`` FastMCP tool.

    Each test patches ``get_http_headers`` (the tool's source of the
    Authorization header) and, where relevant, the ``search`` service
    function — so no HTTP layer, DB, or vectorstore is involved.
    """

    @pytest.mark.asyncio
    async def test_missing_bearer_raises_permission_error(self):
        from application.mcp_server import search_docs

        # No Authorization header at all.
        with patch(
            "application.mcp_server.get_http_headers", return_value={}
        ):
            with pytest.raises(PermissionError):
                await search_docs(query="hi")

    @pytest.mark.asyncio
    async def test_non_bearer_header_raises_permission_error(self):
        from application.mcp_server import search_docs

        # Wrong auth scheme (Basic) must be rejected, not parsed.
        with patch(
            "application.mcp_server.get_http_headers",
            return_value={"authorization": "Basic dXNlcjpwYXNz"},
        ):
            with pytest.raises(PermissionError):
                await search_docs(query="hi")

    @pytest.mark.asyncio
    async def test_blank_bearer_token_raises_permission_error(self):
        from application.mcp_server import search_docs

        # "Bearer " with an empty token is treated as missing credentials.
        with patch(
            "application.mcp_server.get_http_headers",
            return_value={"authorization": "Bearer "},
        ):
            with pytest.raises(PermissionError):
                await search_docs(query="hi")

    @pytest.mark.asyncio
    async def test_invalid_api_key_raises_permission_error(self):
        from application.mcp_server import search_docs
        from application.services.search_service import InvalidAPIKey

        # The service's InvalidAPIKey is translated to PermissionError.
        with (
            patch(
                "application.mcp_server.get_http_headers",
                return_value={"authorization": "Bearer bogus"},
            ),
            patch(
                "application.mcp_server.search", side_effect=InvalidAPIKey()
            ),
        ):
            with pytest.raises(PermissionError):
                await search_docs(query="hi")

    @pytest.mark.asyncio
    async def test_search_failed_bubbles_up(self):
        from application.mcp_server import search_docs
        from application.services.search_service import SearchFailed

        # SearchFailed is NOT translated — it propagates to the caller.
        with (
            patch(
                "application.mcp_server.get_http_headers",
                return_value={"authorization": "Bearer k"},
            ),
            patch(
                "application.mcp_server.search",
                side_effect=SearchFailed("boom"),
            ),
        ):
            with pytest.raises(SearchFailed):
                await search_docs(query="hi")

    @pytest.mark.asyncio
    async def test_happy_path_passes_args_and_returns_hits(self):
        from application.mcp_server import search_docs

        hits = [{"text": "t", "title": "T", "source": "s"}]
        with (
            patch(
                "application.mcp_server.get_http_headers",
                return_value={"authorization": "Bearer the-key"},
            ),
            patch(
                "application.mcp_server.search", return_value=hits
            ) as mock_search,
        ):
            out = await search_docs(query="q", chunks=7)
        # Service result passes through untouched; key/query/chunks forwarded.
        assert out == hits
        mock_search.assert_called_once_with("the-key", "q", 7)

    @pytest.mark.asyncio
    async def test_default_chunks_is_5(self):
        from application.mcp_server import search_docs

        with (
            patch(
                "application.mcp_server.get_http_headers",
                return_value={"authorization": "Bearer k"},
            ),
            patch(
                "application.mcp_server.search", return_value=[]
            ) as mock_search,
        ):
            await search_docs(query="q")
        mock_search.assert_called_once_with("k", "q", 5)

    @pytest.mark.asyncio
    async def test_bearer_scheme_case_insensitive(self):
        from application.mcp_server import search_docs

        # RFC 7235 auth schemes are case-insensitive; "bearer" must work.
        with (
            patch(
                "application.mcp_server.get_http_headers",
                return_value={"authorization": "bearer lowercase-scheme"},
            ),
            patch(
                "application.mcp_server.search", return_value=[]
            ) as mock_search,
        ):
            await search_docs(query="q")
        mock_search.assert_called_once_with("lowercase-scheme", "q", 5)

View File

@@ -0,0 +1,230 @@
"""Unit tests for application/services/search_service.py.
Tests exercise the service function in isolation — AgentsRepository is
stubbed via a patched ``db_readonly`` context manager, and
``VectorCreator.create_vectorstore`` is patched to return a fake
vectorstore. No Flask app context, no real DB, no real embeddings.
"""
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
_collect_source_ids,
search,
)
@contextmanager
def _fake_db_readonly(agent_data):
    """Stub the search service's DB access for one test.

    While active, ``db_readonly`` yields a dummy connection and
    ``AgentsRepository.find_by_key`` returns ``agent_data``.
    """
    repo_stub = MagicMock()
    repo_stub.find_by_key.return_value = agent_data

    @contextmanager
    def _fake_conn():
        yield MagicMock()

    db_patch = patch(
        "application.services.search_service.db_readonly", _fake_conn
    )
    repo_patch = patch(
        "application.services.search_service.AgentsRepository",
        return_value=repo_stub,
    )
    with db_patch, repo_patch:
        yield
@pytest.mark.unit
class TestCollectSourceIds:
    """Pure-function tests for ``_collect_source_ids`` — no mocks needed."""

    def test_empty_when_no_sources(self):
        assert _collect_source_ids({}) == []

    def test_returns_extra_source_ids(self):
        # extra_source_ids wins over the legacy single source_id.
        agent = {"extra_source_ids": ["s1", "s2"], "source_id": "legacy"}
        assert _collect_source_ids(agent) == ["s1", "s2"]

    def test_falls_back_to_single_source_id(self):
        agent = {"extra_source_ids": [], "source_id": "s1"}
        assert _collect_source_ids(agent) == ["s1"]

    def test_skips_empty_entries_in_extra(self):
        # Falsy entries in extra_source_ids are dropped; a non-empty
        # remainder still suppresses the source_id fallback.
        agent = {"extra_source_ids": ["", None, "s1"], "source_id": "fallback"}
        assert _collect_source_ids(agent) == ["s1"]
@pytest.mark.unit
class TestSearchInvalidAPIKey:
    """Error paths around the API-key lookup in ``search``."""

    def test_raises_when_key_unknown(self):
        # find_by_key returning None means the key is unknown → InvalidAPIKey.
        with _fake_db_readonly(None):
            with pytest.raises(InvalidAPIKey):
                search("does-not-exist", "hello", 5)

    def test_raises_search_failed_on_db_error(self):
        # A repository exception is wrapped as SearchFailed, not leaked raw.
        @contextmanager
        def _yield_conn():
            yield MagicMock()

        agents_repo = MagicMock()
        agents_repo.find_by_key.side_effect = RuntimeError("db down")
        with patch(
            "application.services.search_service.db_readonly", _yield_conn
        ), patch(
            "application.services.search_service.AgentsRepository",
            return_value=agents_repo,
        ):
            with pytest.raises(SearchFailed):
                search("any-key", "hello", 5)
@pytest.mark.unit
class TestSearchEmptyWhenNoSources:
    """Short-circuit paths that return [] without touching the vectorstore."""

    def test_returns_empty_when_agent_has_no_sources(self):
        with _fake_db_readonly({"extra_source_ids": [], "source_id": None}):
            assert search("k", "q", 5) == []

    def test_returns_empty_for_zero_chunks_without_db_lookup(self):
        # chunks <= 0 returns early — the DB must not even be opened.
        with patch("application.services.search_service.db_readonly") as mock_db:
            assert search("k", "q", 0) == []
            mock_db.assert_not_called()

    def test_returns_empty_for_negative_chunks_without_db_lookup(self):
        with patch("application.services.search_service.db_readonly") as mock_db:
            assert search("k", "q", -1) == []
            mock_db.assert_not_called()
@pytest.mark.unit
class TestSearchResults:
    """Result shaping: hit format, caps, dedup, title fallbacks, bad sources."""

    def test_returns_hit_shape(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [
            {
                "text": "Test content",
                "metadata": {"title": "Test Title", "source": "/path/to/doc"},
            }
        ]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
            # Hits are flattened to exactly {text, title, source}.
            assert results == [
                {"text": "Test content", "title": "Test Title", "source": "/path/to/doc"}
            ]

    def test_handles_langchain_document_format(self):
        # Vectorstores may return LangChain-style objects with
        # .page_content / .metadata instead of plain dicts.
        agent = {"source_id": "src-1", "extra_source_ids": []}
        lc_doc = MagicMock()
        lc_doc.page_content = "Langchain content"
        lc_doc.metadata = {"title": "LC Title", "source": "/lc/path"}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [lc_doc]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
            assert len(results) == 1
            assert results[0]["text"] == "Langchain content"
            assert results[0]["title"] == "LC Title"

    def test_respects_chunks_cap(self):
        # 10 docs available, chunks=3 → exactly 3 returned.
        agent = {"source_id": "src-1", "extra_source_ids": []}
        docs = [
            {"text": f"Content {i}", "metadata": {"title": f"T{i}"}}
            for i in range(10)
        ]
        fake_vs = MagicMock()
        fake_vs.search.return_value = docs
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 3)
            assert len(results) == 3

    def test_deduplicates_results_by_content_prefix(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        dup_text = "Duplicate content " * 20
        docs = [
            {"text": dup_text, "metadata": {"title": "T1"}},
            {"text": dup_text, "metadata": {"title": "T2"}},
            {"text": "Unique content", "metadata": {"title": "T3"}},
        ]
        fake_vs = MagicMock()
        fake_vs.search.return_value = docs
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
            # Two identical texts collapse into one hit.
            assert len(results) == 2

    def test_skips_broken_source_and_returns_from_healthy_ones(self):
        # Two sources — the first raises, the second returns a doc. The
        # caller should still get the healthy source's result.
        agent = {"extra_source_ids": ["broken", "ok"], "source_id": None}
        healthy_vs = MagicMock()
        healthy_vs.search.return_value = [
            {"text": "ok content", "metadata": {"title": "Ok"}}
        ]

        def create_vs(store, source_id, key):
            if source_id == "broken":
                raise RuntimeError("vector index missing")
            return healthy_vs

        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            side_effect=create_vs,
        ):
            results = search("k", "q", 5)
            assert len(results) == 1
            assert results[0]["text"] == "ok content"

    def test_uses_filename_when_title_missing(self):
        agent = {"source_id": "src-1", "extra_source_ids": []}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [
            {"text": "body", "metadata": {"filename": "document.pdf"}}
        ]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
            assert results[0]["title"] == "document.pdf"

    def test_uses_content_snippet_as_title_last_resort(self):
        # No title, no filename → a truncated content snippet ending in "...".
        agent = {"source_id": "src-1", "extra_source_ids": []}
        fake_vs = MagicMock()
        fake_vs.search.return_value = [
            {"text": "Content without any title metadata at all", "metadata": {}}
        ]
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore",
            return_value=fake_vs,
        ):
            results = search("k", "q", 5)
            assert results[0]["title"].endswith("...")
            assert "Content without any title" in results[0]["title"]

    def test_skips_empty_source_ids(self):
        # ``source_id=" "`` only — after strip() this leaves no real source.
        agent = {"extra_source_ids": [" ", ""], "source_id": None}
        with _fake_db_readonly(agent), patch(
            "application.services.search_service.VectorCreator.create_vectorstore"
        ) as mock_create:
            results = search("k", "q", 5)
            mock_create.assert_not_called()
            assert results == []

View File

@@ -105,11 +105,26 @@ class TestAuthenticateRequest:
assert response.status_code == 200
class TestAfterRequest:
class TestFlaskCors:
@pytest.mark.unit
def test_cors_headers(self, client):
response = client.get("/api/health")
assert response.headers.get("Access-Control-Allow-Origin") == "*"
assert "Content-Type" in response.headers.get("Access-Control-Allow-Headers", "")
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")
def test_cors_headers_on_flask_route(self, client):
    """A plain GET with an Origin header gets the full wildcard CORS header set."""
    resp = client.get("/api/health", headers={"Origin": "http://localhost:5173"})
    expected = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
    }
    for header, value in expected.items():
        assert resp.headers[header] == value
@pytest.mark.unit
def test_cors_headers_on_flask_preflight(self, client):
    """An OPTIONS preflight is answered with 200 and the wildcard CORS header set."""
    preflight_headers = {
        "Origin": "http://localhost:5173",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "Content-Type",
    }
    response = client.options("/api/health", headers=preflight_headers)
    assert response.status_code == 200
    assert response.headers["Access-Control-Allow-Origin"] == "*"
    assert (
        response.headers["Access-Control-Allow-Headers"]
        == "Content-Type, Authorization"
    )
    assert (
        response.headers["Access-Control-Allow-Methods"]
        == "GET, POST, PUT, DELETE, OPTIONS"
    )

136
tests/test_asgi.py Normal file
View File

@@ -0,0 +1,136 @@
"""Smoke tests for application/asgi.py.
The goal isn't to re-test Flask or FastMCP internals — it's to catch
regressions in the wiring: mounts resolve, CORS headers emit, lifespan
runs (without it, the /mcp session manager raises "Task group is not
initialized"), routing to ``/`` vs ``/mcp`` doesn't cross paths.
Uses ``starlette.testclient.TestClient`` because it boots the ASGI app
end-to-end and handles the lifespan protocol automatically — ``httpx``
alone does not run lifespan events, which would mask the exact kind of
misconfiguration this test suite exists to catch.
"""
import pytest
@pytest.mark.unit
def test_asgi_app_imports():
    """Importing the composed ASGI app must not raise and must yield an object."""
    from application.asgi import asgi_app

    app = asgi_app
    assert app is not None
@pytest.mark.unit
def test_flask_route_served_through_starlette_mount():
    """GET /api/health should reach the Flask app via a2wsgi and return 200."""
    from starlette.testclient import TestClient

    from application.asgi import asgi_app

    with TestClient(asgi_app) as http:
        response = http.get("/api/health")
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
@pytest.mark.unit
def test_mcp_endpoint_mounted_and_lifespan_runs():
    """/mcp must be reachable AND the FastMCP session manager must start.

    Without ``lifespan=mcp_app.lifespan`` on the outer Starlette app,
    every /mcp request raises ``RuntimeError: Task group is not
    initialized``. Hitting the endpoint under a real lifespan-aware
    client catches that.
    """
    from starlette.testclient import TestClient

    from application.asgi import asgi_app

    # Minimal MCP initialize request. Doesn't need to succeed — we just
    # need a non-404, non-500-with-RuntimeError response to confirm the
    # mount + lifespan are both wired.
    request_headers = {
        "Origin": "http://example.com",
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
    }
    initialize_payload = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-03-26",
            "capabilities": {},
            "clientInfo": {"name": "pytest", "version": "0"},
        },
    }
    with TestClient(asgi_app) as http:
        response = http.post(
            "/mcp/", headers=request_headers, json=initialize_payload
        )
        assert response.status_code != 404, (
            f"/mcp mount unreachable: {response.status_code}"
        )
        # A successful initialize returns 200 with a Mcp-Session-Id header.
        assert response.status_code == 200
        header_names = {name.lower() for name in response.headers.keys()}
        assert "mcp-session-id" in header_names
        assert (
            response.headers.get("access-control-expose-headers")
            == "Mcp-Session-Id"
        )
@pytest.mark.unit
def test_cors_headers_on_flask_route():
    """CORS middleware should emit allow-origin on actual (non-preflight) requests.

    ``allow_origins=["*"]`` → header value is literal ``*`` (not an echo).
    """
    from starlette.testclient import TestClient

    from application.asgi import asgi_app

    with TestClient(asgi_app) as http:
        response = http.get(
            "/api/health", headers={"Origin": "http://example.com"}
        )
        assert response.status_code == 200
        assert response.headers.get("access-control-allow-origin") == "*"
@pytest.mark.unit
def test_cors_preflight_on_flask_route():
    """OPTIONS preflight on a Flask route should be handled by Starlette CORSMiddleware."""
    from starlette.testclient import TestClient

    from application.asgi import asgi_app

    preflight = {
        "Origin": "http://example.com",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "Content-Type",
    }
    with TestClient(asgi_app) as http:
        response = http.options("/api/health", headers=preflight)
        assert response.status_code in (200, 204)
        assert response.headers.get("access-control-allow-origin") == "*"
        assert "GET" in response.headers.get("access-control-allow-methods", "")
@pytest.mark.unit
def test_cors_preflight_on_mcp_route():
    """Browser clients hitting /mcp should be allowed to send session headers."""
    from starlette.testclient import TestClient

    from application.asgi import asgi_app

    preflight = {
        "Origin": "http://example.com",
        "Access-Control-Request-Method": "POST",
        "Access-Control-Request-Headers": (
            "Authorization, Content-Type, Mcp-Session-Id"
        ),
    }
    with TestClient(asgi_app) as http:
        response = http.options("/mcp/", headers=preflight)
        assert response.status_code in (200, 204)
        assert response.headers.get("access-control-allow-origin") == "*"
        assert "Mcp-Session-Id" in response.headers.get(
            "access-control-allow-headers", ""
        )

402
tests/test_version_check.py Normal file
View File

@@ -0,0 +1,402 @@
"""Unit tests for the anonymous startup version-check client.
All external dependencies (Postgres, Redis, HTTP) are mocked so the
suite runs in pure-Python isolation. The focus is on the branching
behavior described in the spec: opt-out, cache-hit, cache-miss,
lock-denied, and the various failure paths that must never propagate.
"""
from __future__ import annotations
import json
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
import requests
from application.updates import version_check as vc_module
class _FakeRepo:
"""Stand-in for AppMetadataRepository backed by a plain dict."""
def __init__(self, store: dict | None = None, *, raise_on_get_instance: bool = False):
self._store: dict[str, str] = dict(store) if store else {}
self._raise = raise_on_get_instance
def get(self, key: str):
return self._store.get(key)
def set(self, key: str, value: str) -> None:
self._store[key] = value
def get_or_create_instance_id(self) -> str:
if self._raise:
raise RuntimeError("simulated Postgres outage")
existing = self._store.get("instance_id")
if existing:
return existing
self._store["instance_id"] = "11111111-2222-3333-4444-555555555555"
return self._store["instance_id"]
@contextmanager
def _fake_db_session():
"""Stand-in for ``db_session()`` — yields ``None`` because the fake
repository ignores its connection argument."""
yield None
def _install_repo(monkeypatch, repo: _FakeRepo):
    """Patch the repo constructor so ``AppMetadataRepository(conn)`` → ``repo``."""
    def fake_constructor(conn):
        # The connection argument is accepted but ignored by the fake.
        return repo

    monkeypatch.setattr(vc_module, "AppMetadataRepository", fake_constructor)
def _install_db_session(monkeypatch, *, raise_exc: Exception | None = None):
    """Point ``vc_module.db_session`` at a fake session factory.

    With ``raise_exc`` set, the fake raises on entry (simulating an
    unreachable database); otherwise it behaves like ``_fake_db_session``.
    """
    if raise_exc is None:
        monkeypatch.setattr(vc_module, "db_session", _fake_db_session)
        return

    @contextmanager
    def failing_session():
        raise raise_exc
        yield  # pragma: no cover - unreachable (keeps this a generator)

    monkeypatch.setattr(vc_module, "db_session", failing_session)
def _make_redis_mock(*, get_return=None, set_return=True):
client = MagicMock()
client.get.return_value = get_return
client.set.return_value = set_return
client.setex.return_value = True
client.delete.return_value = 1
return client
@pytest.fixture
def enable_check(monkeypatch):
    # Opt the test into version checking — most tests need VERSION_CHECK on,
    # since run_check() short-circuits when it is falsy.
    monkeypatch.setattr(vc_module.settings, "VERSION_CHECK", True)
@pytest.mark.unit
def test_opt_out_short_circuits(monkeypatch):
    """VERSION_CHECK=0 → no Postgres, no Redis, no network."""
    monkeypatch.setattr(vc_module.settings, "VERSION_CHECK", False)
    spies = {"db_session": MagicMock(), "get_redis_instance": MagicMock()}
    for name, spy in spies.items():
        monkeypatch.setattr(vc_module, name, spy)
    network_spy = MagicMock()
    monkeypatch.setattr(vc_module.requests, "post", network_spy)

    vc_module.run_check()

    network_spy.assert_not_called()
    for spy in spies.values():
        spy.assert_not_called()
@pytest.mark.unit
def test_cache_hit_renders_without_lock_or_network(monkeypatch, enable_check, capsys):
    """A warm Redis cache renders advisories with no lock, no HTTP, no cache write."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    # Payload shaped like the advisory endpoint's response; "high" severity
    # should trigger the console banner path.
    cached = {
        "advisories": [
            {
                "id": "DOCSGPT-TEST-1",
                "title": "Example",
                "severity": "high",
                "fixed_in": "0.17.0",
                "url": "https://example.test/a",
                "summary": "Upgrade required.",
            }
        ]
    }
    # Redis GET returns bytes, as a real client would.
    redis_client = _make_redis_mock(get_return=json.dumps(cached).encode("utf-8"))
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: redis_client)
    post_spy = MagicMock()
    monkeypatch.setattr(vc_module.requests, "post", post_spy)
    vc_module.run_check()
    # Cache hit: exactly one read, no lock (set), no write (setex), no HTTP.
    redis_client.get.assert_called_once_with(vc_module.CACHE_KEY)
    redis_client.set.assert_not_called()
    redis_client.setex.assert_not_called()
    post_spy.assert_not_called()
    assert "SECURITY ADVISORY: DOCSGPT-TEST-1" in capsys.readouterr().out
@pytest.mark.unit
def test_cache_miss_lock_acquired_fetches_and_caches(monkeypatch, enable_check):
    """Cold cache + won lock → POST payload, cache with server TTL, release lock."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    # get → None (cache miss); set → True (SET NX acquired the lock).
    redis_client = _make_redis_mock(get_return=None, set_return=True)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: redis_client)
    response_body = {
        "advisories": [
            {
                "id": "DOCSGPT-LOW-1",
                "title": "Minor",
                "severity": "low",
                "fixed_in": "0.17.0",
                "url": "https://example.test/low",
            }
        ],
        # Server-side TTL override, deliberately under the default cap.
        "next_check_after": 1800,
    }
    post_response = MagicMock()
    post_response.status_code = 200
    post_response.json.return_value = response_body
    post_spy = MagicMock(return_value=post_response)
    monkeypatch.setattr(vc_module.requests, "post", post_spy)
    vc_module.run_check()
    post_spy.assert_called_once()
    call_kwargs = post_spy.call_args
    assert call_kwargs.args[0] == vc_module.ENDPOINT_URL
    # The anonymous payload must identify the client, carry the persisted
    # instance id, and report version info.
    payload = call_kwargs.kwargs["json"]
    assert payload["client"] == "docsgpt-backend"
    assert payload["instance_id"] == "11111111-2222-3333-4444-555555555555"
    assert "version" in payload and "python_version" in payload
    # Lock acquired with NX EX, cache written with server-specified TTL,
    # lock released.
    redis_client.set.assert_called_once()
    set_kwargs = redis_client.set.call_args.kwargs
    assert set_kwargs == {"nx": True, "ex": vc_module.LOCK_TTL_SECONDS}
    redis_client.setex.assert_called_once()
    setex_args = redis_client.setex.call_args.args
    assert setex_args[0] == vc_module.CACHE_KEY
    assert setex_args[1] == 1800  # server override under 6h
    redis_client.delete.assert_called_once_with(vc_module.LOCK_KEY)
@pytest.mark.unit
def test_cache_miss_lock_denied_skips_silently(monkeypatch, enable_check):
    """When another worker holds the fetch lock, this run does nothing at all."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    # get → None (cache miss); set → False (SET NX lost the race).
    lock_denied_redis = _make_redis_mock(get_return=None, set_return=False)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: lock_denied_redis)
    http_spy = MagicMock()
    monkeypatch.setattr(vc_module.requests, "post", http_spy)

    vc_module.run_check()

    http_spy.assert_not_called()
    lock_denied_redis.setex.assert_not_called()
    lock_denied_redis.delete.assert_not_called()
@pytest.mark.unit
def test_instance_id_persisted_across_runs(monkeypatch, enable_check):
    """The anonymous instance id is minted once and reused on later runs."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    redis_client = _make_redis_mock(get_return=None, set_return=True)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: redis_client)
    ok_response = MagicMock()
    ok_response.status_code = 200
    ok_response.json.return_value = {}
    monkeypatch.setattr(
        vc_module.requests, "post", MagicMock(return_value=ok_response)
    )

    vc_module.run_check()
    first_id = repo.get("instance_id")
    vc_module.run_check()

    assert first_id is not None
    assert repo.get("instance_id") == first_id
@pytest.mark.unit
def test_first_run_notice_emitted_once(monkeypatch, enable_check, capsys):
    """The opt-out notice prints on the first run only; a flag then suppresses it."""
    repo = _FakeRepo()  # empty — notice not shown yet
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    # Cache hit so we don't need to mock HTTP. Notice logic runs before cache.
    cache_hit_redis = _make_redis_mock(get_return=json.dumps({}).encode("utf-8"))
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: cache_hit_redis)

    vc_module.run_check()
    assert "Anonymous version check enabled" in capsys.readouterr().out
    assert repo.get("version_check_notice_shown") == "1"

    vc_module.run_check()
    assert "Anonymous version check enabled" not in capsys.readouterr().out
@pytest.mark.unit
def test_postgres_unavailable_skips_silently(monkeypatch, enable_check):
    """A failing ``db_session()`` aborts the check before Redis or HTTP are touched."""
    _install_db_session(monkeypatch, raise_exc=RuntimeError("db down"))
    redis_spy, http_spy = MagicMock(), MagicMock()
    monkeypatch.setattr(vc_module, "get_redis_instance", redis_spy)
    monkeypatch.setattr(vc_module.requests, "post", http_spy)

    vc_module.run_check()  # must not raise

    redis_spy.assert_not_called()
    http_spy.assert_not_called()
@pytest.mark.unit
def test_postgres_repo_raises_skips_silently(monkeypatch, enable_check):
    """A repository that blows up on instance-id lookup aborts the check quietly."""
    repo = _FakeRepo(raise_on_get_instance=True)
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    redis_spy, http_spy = MagicMock(), MagicMock()
    monkeypatch.setattr(vc_module, "get_redis_instance", redis_spy)
    monkeypatch.setattr(vc_module.requests, "post", http_spy)

    vc_module.run_check()  # must not raise

    redis_spy.assert_not_called()
    http_spy.assert_not_called()
@pytest.mark.unit
def test_redis_unavailable_proceeds_uncached(monkeypatch, enable_check):
    """``get_redis_instance()`` → None should not abort the check."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: None)
    ok_response = MagicMock(status_code=200)
    ok_response.json.return_value = {"advisories": []}
    http_spy = MagicMock(return_value=ok_response)
    monkeypatch.setattr(vc_module.requests, "post", http_spy)

    vc_module.run_check()

    http_spy.assert_called_once()
@pytest.mark.unit
def test_unknown_version_warns_and_skips(monkeypatch, enable_check):
    """get_version() → "unknown" must not hit the endpoint silently."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    # Cache miss + lock acquired, so only the unknown version stops the fetch.
    redis_client = _make_redis_mock(get_return=None, set_return=True)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: redis_client)
    monkeypatch.setattr(vc_module, "get_version", lambda: "unknown")
    post_spy = MagicMock()
    monkeypatch.setattr(vc_module.requests, "post", post_spy)
    with patch.object(vc_module, "logger") as mock_logger:
        vc_module.run_check()
    post_spy.assert_not_called()
    redis_client.setex.assert_not_called()
    # Lock released so the next cycle can retry.
    redis_client.delete.assert_called_once_with(vc_module.LOCK_KEY)
    assert mock_logger.warning.called
    # Accept either an eagerly-formatted message containing "unknown" or a
    # lazy %-style call whose sole extra arg is "unknown".
    assert "unknown" in mock_logger.warning.call_args.args[0].lower() \
        or mock_logger.warning.call_args.args[1:] == ("unknown",)
@pytest.mark.unit
def test_http_5xx_swallowed(monkeypatch, enable_check):
    """A 503 from the advisory endpoint is swallowed; the lock is still released."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    redis_client = _make_redis_mock(get_return=None, set_return=True)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: redis_client)
    unavailable = MagicMock(status_code=503)
    unavailable.json.return_value = {}
    monkeypatch.setattr(
        vc_module.requests, "post", MagicMock(return_value=unavailable)
    )

    vc_module.run_check()  # must not raise

    redis_client.setex.assert_not_called()
    # Lock still released so the next cycle can retry.
    redis_client.delete.assert_called_once_with(vc_module.LOCK_KEY)
@pytest.mark.unit
def test_http_timeout_swallowed(monkeypatch, enable_check):
    """A network timeout never propagates; the lock is still released."""
    repo = _FakeRepo({"version_check_notice_shown": "1"})
    _install_repo(monkeypatch, repo)
    _install_db_session(monkeypatch)
    redis_client = _make_redis_mock(get_return=None, set_return=True)
    monkeypatch.setattr(vc_module, "get_redis_instance", lambda: redis_client)
    timing_out_post = MagicMock(side_effect=requests.Timeout("boom"))
    monkeypatch.setattr(vc_module.requests, "post", timing_out_post)

    vc_module.run_check()  # must not raise

    redis_client.setex.assert_not_called()
    redis_client.delete.assert_called_once_with(vc_module.LOCK_KEY)
@pytest.mark.unit
def test_compute_ttl_honors_server_override():
    """_compute_ttl honors sane overrides and clamps bad/huge/absent ones to default."""
    default = vc_module.CACHE_TTL_SECONDS
    cases = [
        ({"next_check_after": 300}, 300),        # sane override honored
        ({"next_check_after": 60000}, default),  # above cap → default
        ({}, default),                           # absent → default
        ({"next_check_after": "bad"}, default),  # non-numeric → default
        ({"next_check_after": 0}, default),      # zero/negative → default
    ]
    for payload, expected in cases:
        assert vc_module._compute_ttl(payload) == expected
@pytest.mark.unit
def test_render_advisories_logs_warning_and_prints_banner(monkeypatch, capsys):
    """Every advisory is logged as a warning; only high/critical get a banner."""
    with patch.object(vc_module, "logger") as mock_logger:
        vc_module._render_advisories(
            {
                "advisories": [
                    {
                        "id": "DOCSGPT-2025-001",
                        "title": "SSRF",
                        "severity": "critical",
                        "fixed_in": "0.17.0",
                        "url": "https://example.test/a",
                        "summary": "Your DocsGPT is vulnerable.",
                    },
                    {
                        # Minimal advisory — optional fields omitted on purpose.
                        "id": "DOCSGPT-2025-002",
                        "title": "Low-sev",
                        "severity": "low",
                    },
                ]
            }
        )
    # Both advisories logged as warnings.
    assert mock_logger.warning.call_count == 2
    out = capsys.readouterr().out
    # Only the high/critical one gets the console banner.
    assert "DOCSGPT-2025-001" in out
    assert "DOCSGPT-2025-002" not in out